Compare commits

...

7 Commits

Author SHA1 Message Date
f8dd304820 it works 2025-08-16 10:41:26 +02:00
5eab3b080c ... 2025-08-16 10:28:28 +02:00
246304b9fa ... 2025-08-16 10:10:24 +02:00
074be114c3 ... 2025-08-16 09:55:34 +02:00
e51af83e45 ... 2025-08-16 09:52:36 +02:00
dbd0635cd9 ... 2025-08-16 09:50:56 +02:00
0000d82799 ... 2025-08-16 09:29:18 +02:00
16 changed files with 1020 additions and 416 deletions

272
Cargo.lock generated
View File

@ -17,6 +17,16 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "anstream"
version = "0.6.15"
@ -72,6 +82,17 @@ version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da"
[[package]]
name = "async-trait"
version = "0.1.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "autocfg"
version = "1.3.0"
@ -108,6 +129,15 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "byteorder"
version = "1.5.0"
@ -132,6 +162,41 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chacha20"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "chacha20poly1305"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35"
dependencies = [
"aead",
"chacha20",
"cipher",
"poly1305",
"zeroize",
]
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
"zeroize",
]
[[package]]
name = "clap"
version = "4.5.20"
@ -185,7 +250,41 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "cpufeatures"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
dependencies = [
"libc",
]
[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"rand_core",
"typenum",
]
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
]
[[package]]
@ -297,6 +396,27 @@ dependencies = [
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "gimli"
version = "0.29.0"
@ -422,6 +542,15 @@ dependencies = [
"icu_properties",
]
[[package]]
name = "inout"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
dependencies = [
"generic-array",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
@ -501,6 +630,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "parking_lot"
version = "0.12.3"
@ -542,6 +677,17 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "poly1305"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf"
dependencies = [
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "potential_utf"
version = "0.1.2"
@ -551,6 +697,15 @@ dependencies = [
"zerovec",
]
[[package]]
name = "ppv-lite86"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
@ -569,6 +724,36 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "redb"
version = "2.6.2"
@ -584,12 +769,18 @@ version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd"
dependencies = [
"async-trait",
"bytes",
"combine",
"futures-util",
"itoa",
"percent-encoding",
"pin-project-lite",
"ryu",
"sha1_smol",
"socket2 0.4.10",
"tokio",
"tokio-util",
"url",
]
@ -601,11 +792,14 @@ dependencies = [
"bincode",
"byteorder",
"bytes",
"chacha20poly1305",
"clap",
"futures",
"rand",
"redb",
"redis",
"serde",
"sha2",
"thiserror",
"tokio",
]
@ -663,6 +857,17 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d"
[[package]]
name = "sha2"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
@ -719,6 +924,12 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "subtle"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "2.0.69"
@ -801,12 +1012,41 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-util"
version = "0.7.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "typenum"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "url"
version = "2.5.4"
@ -830,6 +1070,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "version_check"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
@ -1027,6 +1273,26 @@ dependencies = [
"synstructure",
]
[[package]]
name = "zerocopy"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "zerofrom"
version = "0.1.6"
@ -1048,6 +1314,12 @@ dependencies = [
"synstructure",
]
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
[[package]]
name = "zerotrie"
version = "0.2.2"

View File

@ -15,6 +15,9 @@ futures = "0.3"
redb = "2.1.3"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3.3"
chacha20poly1305 = "0.10.1"
rand = "0.8"
sha2 = "0.10"
[dev-dependencies]
redis = "0.24"
redis = { version = "0.24", features = ["aio", "tokio-comp"] }

View File

@ -1,10 +1,11 @@
use crate::{error::DBError, protocol::Protocol, server::Server};
use serde::Serialize;
#[derive(Debug, Clone)]
pub enum Cmd {
Ping,
Echo(String),
Select(u16),
Select(u64), // Changed from u16 to u64
Get(String),
Set(String, String),
SetPx(String, String, u128),
@ -47,6 +48,7 @@ pub enum Cmd {
LTrim(String, i64, i64),
LIndex(String, i64),
LRange(String, i64, i64),
FlushDb,
Unknow(String),
}
@ -65,7 +67,7 @@ impl Cmd {
if cmd.len() != 2 {
return Err(DBError("wrong number of arguments for SELECT".to_string()));
}
let idx = cmd[1].parse::<u16>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
Cmd::Select(idx)
}
"echo" => Cmd::Echo(cmd[1].clone()),
@ -394,6 +396,12 @@ impl Cmd {
let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::LRange(cmd[1].clone(), start, stop)
}
"flushdb" => {
if cmd.len() != 1 {
return Err(DBError("wrong number of arguments for FLUSHDB command".to_string()));
}
Cmd::FlushDb
}
_ => Cmd::Unknow(cmd[0].clone()),
},
protocol,
@ -407,99 +415,109 @@ impl Cmd {
}
}
pub async fn run(
&self,
server: &mut Server,
protocol: Protocol,
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
) -> Result<Protocol, DBError> {
pub async fn run(self, server: &mut Server) -> Result<Protocol, DBError> {
// Handle queued commands for transactions
if queued_cmd.is_some()
if server.queued_cmd.is_some()
&& !matches!(self, Cmd::Exec)
&& !matches!(self, Cmd::Multi)
&& !matches!(self, Cmd::Discard)
{
queued_cmd
.as_mut()
.unwrap()
.push((self.clone(), protocol.clone()));
let protocol = self.clone().to_protocol();
server.queued_cmd.as_mut().unwrap().push((self, protocol));
return Ok(Protocol::SimpleString("QUEUED".to_string()));
}
match self {
Cmd::Select(db) => select_cmd(server, *db).await,
Cmd::Select(db) => select_cmd(server, db).await,
Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())),
Cmd::Get(k) => get_cmd(server, k).await,
Cmd::Set(k, v) => set_cmd(server, k, v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await,
Cmd::Del(k) => del_cmd(server, k).await,
Cmd::ConfigGet(name) => config_get_cmd(name, server),
Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
Cmd::Get(k) => get_cmd(server, &k).await,
Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(section),
Cmd::Type(k) => type_cmd(server, k).await,
Cmd::Incr(key) => incr_cmd(server, key).await,
Cmd::Info(section) => info_cmd(server, &section).await,
Cmd::Type(k) => type_cmd(server, &k).await,
Cmd::Incr(key) => incr_cmd(server, &key).await,
Cmd::Multi => {
*queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
server.queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
Ok(Protocol::SimpleString("OK".to_string()))
}
Cmd::Exec => exec_cmd(queued_cmd, server).await,
Cmd::Exec => exec_cmd(server).await,
Cmd::Discard => {
if queued_cmd.is_some() {
*queued_cmd = None;
if server.queued_cmd.is_some() {
server.queued_cmd = None;
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::err("ERR DISCARD without MULTI"))
}
}
// Hash commands
Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, key, field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, key, field).await,
Cmd::HKeys(key) => hkeys_cmd(server, key).await,
Cmd::HVals(key) => hvals_cmd(server, key).await,
Cmd::HLen(key) => hlen_cmd(server, key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, key, cursor, pattern.as_deref(), count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await,
Cmd::Ttl(key) => ttl_cmd(server, key).await,
Cmd::Exists(key) => exists_cmd(server, key).await,
Cmd::HSet(key, pairs) => hset_cmd(server, &key, &pairs).await,
Cmd::HGet(key, field) => hget_cmd(server, &key, &field).await,
Cmd::HGetAll(key) => hgetall_cmd(server, &key).await,
Cmd::HDel(key, fields) => hdel_cmd(server, &key, &fields).await,
Cmd::HExists(key, field) => hexists_cmd(server, &key, &field).await,
Cmd::HKeys(key) => hkeys_cmd(server, &key).await,
Cmd::HVals(key) => hvals_cmd(server, &key).await,
Cmd::HLen(key) => hlen_cmd(server, &key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, name).await,
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
Cmd::ClientGetName => client_getname_cmd(server).await,
// List commands
Cmd::LPush(key, elements) => lpush_cmd(server, key, elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, key, elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, key, count).await,
Cmd::RPop(key, count) => rpop_cmd(server, key, count).await,
Cmd::LLen(key) => llen_cmd(server, key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, key, *count, element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await,
Cmd::Unknow(s) => {
println!("\x1b[31;1munknown command: {}\x1b[0m", s);
Ok(Protocol::err(&format!("ERR unknown command '{}'", s)))
}
Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
Cmd::LLen(key) => llen_cmd(server, &key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
Cmd::LIndex(key, index) => lindex_cmd(server, &key, index).await,
Cmd::LRange(key, start, stop) => lrange_cmd(server, &key, start, stop).await,
Cmd::FlushDb => flushdb_cmd(server).await,
Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))),
}
}
pub fn to_protocol(self) -> Protocol {
match self {
Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]),
_ => Protocol::SimpleString("...".to_string())
}
}
}
async fn select_cmd(server: &mut Server, db: u16) -> Result<Protocol, DBError> {
let idx = db as usize;
if idx >= server.storages.len() {
return Ok(Protocol::err("ERR DB index is out of range"));
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
match server.current_storage()?.flushdb() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
// Test if we can access the database (this will create it if needed)
server.selected_db = db;
match server.current_storage() {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
server.selected_db = idx;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
match server.current_storage().lindex(key, index) {
match server.current_storage()?.lindex(key, index) {
Ok(Some(element)) => Ok(Protocol::BulkString(element)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
@ -507,35 +525,35 @@ async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol,
}
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.current_storage().lrange(key, start, stop) {
match server.current_storage()?.lrange(key, start, stop) {
Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
match server.current_storage().ltrim(key, start, stop) {
match server.current_storage()?.ltrim(key, start, stop) {
Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
match server.current_storage().lrem(key, count, element) {
match server.current_storage()?.lrem(key, count, element) {
Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage().llen(key) {
match server.current_storage()?.llen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage().lpop(key, *count) {
match server.current_storage()?.lpop(key, *count) {
Ok(Some(elements)) => {
if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
@ -555,7 +573,7 @@ async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro
}
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage().rpop(key, *count) {
match server.current_storage()?.rpop(key, *count) {
Ok(Some(elements)) => {
if count.is_some() {
Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
@ -575,38 +593,39 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage().lpush(key, elements.to_vec()) {
match server.current_storage()?.lpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage().rpush(key, elements.to_vec()) {
match server.current_storage()?.rpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exec_cmd(
queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
server: &mut Server,
) -> Result<Protocol, DBError> {
if queued_cmd.is_some() {
let mut vec = Vec::new();
for (cmd, protocol) in queued_cmd.as_ref().unwrap() {
let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?;
vec.push(res);
}
*queued_cmd = None;
Ok(Protocol::Array(vec))
async fn exec_cmd(server: &mut Server) -> Result<Protocol, DBError> {
// Move the queued commands out of `server` so we drop the borrow immediately.
let cmds = if let Some(cmds) = server.queued_cmd.take() {
cmds
} else {
Ok(Protocol::err("ERR EXEC without MULTI"))
return Ok(Protocol::err("ERR EXEC without MULTI"));
};
let mut out = Vec::new();
for (cmd, _) in cmds {
// Use Box::pin to handle recursion in async function
let res = Box::pin(cmd.run(server)).await?;
out.push(res);
}
Ok(Protocol::Array(out))
}
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
let current_value = server.current_storage().get(key)?;
let storage = server.current_storage()?;
let current_value = storage.get(key)?;
let new_value = match current_value {
Some(v) => {
@ -618,36 +637,58 @@ async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
None => 1,
};
server.current_storage().set(key.clone(), new_value.to_string())?;
storage.set(key.clone(), new_value.to_string())?;
Ok(Protocol::SimpleString(new_value.to_string()))
}
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
match name.as_str() {
"dir" => Ok(Protocol::Array(vec![
let value = match name.as_str() {
"dir" => Some(server.option.dir.clone()),
"dbfilename" => Some(format!("{}.db", server.selected_db)),
"databases" => Some("16".to_string()), // Hardcoded as per original logic
_ => None,
};
if let Some(val) = value {
Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(server.option.dir.clone()),
])),
"dbfilename" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(format!("{}.db", server.selected_db)),
])),
"databases" => Ok(Protocol::Array(vec![
Protocol::BulkString(name.clone()),
Protocol::BulkString(server.option.databases.to_string()),
])),
_ => Ok(Protocol::Array(vec![])),
Protocol::BulkString(val),
]))
} else {
// Return an empty array for unknown config options, which is standard Redis behavior
Ok(Protocol::Array(vec![]))
}
}
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
let keys = server.current_storage().keys("*")?;
let keys = server.current_storage()?.keys("*")?;
Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
))
}
fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
#[derive(Serialize)]
struct ServerInfo {
redis_version: String,
encrypted: bool,
selected_db: u64,
}
async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol, DBError> {
let info = ServerInfo {
redis_version: "7.0.0".to_string(),
encrypted: server.current_storage()?.is_encrypted(),
selected_db: server.selected_db,
};
let mut info_string = String::new();
info_string.push_str(&format!("# Server\n"));
info_string.push_str(&format!("redis_version:{}\n", info.redis_version));
info_string.push_str(&format!("encrypted:{}\n", if info.encrypted { 1 } else { 0 }));
info_string.push_str(&format!("# Keyspace\n"));
info_string.push_str(&format!("db{}:keys=0,expires=0,avg_ttl=0\n", info.selected_db));
match section {
Some(s) => match s.as_str() {
"replication" => Ok(Protocol::BulkString(
@ -655,19 +696,21 @@ fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
)),
_ => Err(DBError(format!("unsupported section {:?}", s))),
},
None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())),
None => {
Ok(Protocol::BulkString(info_string))
}
}
}
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
match server.current_storage().get_key_type(k)? {
match server.current_storage()?.get_key_type(k)? {
Some(type_str) => Ok(Protocol::SimpleString(type_str)),
None => Ok(Protocol::SimpleString("none".to_string())),
}
}
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
server.current_storage().del(k.to_string())?;
server.current_storage()?.del(k.to_string())?;
Ok(Protocol::SimpleString("1".to_string()))
}
@ -677,7 +720,7 @@ async fn set_ex_cmd(
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.current_storage().setx(k.to_string(), v.to_string(), *x * 1000)?;
server.current_storage()?.setx(k.to_string(), v.to_string(), *x * 1000)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
@ -687,28 +730,28 @@ async fn set_px_cmd(
v: &str,
x: &u128,
) -> Result<Protocol, DBError> {
server.current_storage().setx(k.to_string(), v.to_string(), *x)?;
server.current_storage()?.setx(k.to_string(), v.to_string(), *x)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
server.current_storage().set(k.to_string(), v.to_string())?;
server.current_storage()?.set(k.to_string(), v.to_string())?;
Ok(Protocol::SimpleString("OK".to_string()))
}
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.current_storage().get(k)?;
let v = server.current_storage()?.get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString))
}
// Hash command implementations
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let new_fields = server.current_storage().hset(key, pairs)?;
let new_fields = server.current_storage()?.hset(key, pairs)?;
Ok(Protocol::SimpleString(new_fields.to_string()))
}
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.current_storage().hget(key, field) {
match server.current_storage()?.hget(key, field) {
Ok(Some(value)) => Ok(Protocol::BulkString(value)),
Ok(None) => Ok(Protocol::Null),
Err(e) => Ok(Protocol::err(&e.0)),
@ -716,7 +759,7 @@ async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, D
}
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage().hgetall(key) {
match server.current_storage()?.hgetall(key) {
Ok(pairs) => {
let mut result = Vec::new();
for (field, value) in pairs {
@ -730,21 +773,21 @@ async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
}
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.current_storage().hdel(key, fields) {
match server.current_storage()?.hdel(key, fields) {
Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
match server.current_storage().hexists(key, field) {
match server.current_storage()?.hexists(key, field) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage().hkeys(key) {
match server.current_storage()?.hkeys(key) {
Ok(keys) => Ok(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
)),
@ -753,7 +796,7 @@ async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
}
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage().hvals(key) {
match server.current_storage()?.hvals(key) {
Ok(values) => Ok(Protocol::Array(
values.into_iter().map(Protocol::BulkString).collect(),
)),
@ -762,14 +805,14 @@ async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
}
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage().hlen(key) {
match server.current_storage()?.hlen(key) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
match server.current_storage().hmget(key, fields) {
match server.current_storage()?.hmget(key, fields) {
Ok(values) => {
let result: Vec<Protocol> = values
.into_iter()
@ -782,49 +825,56 @@ async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Prot
}
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
match server.current_storage().hsetnx(key, field, value) {
match server.current_storage()?.hsetnx(key, field, value) {
Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage().scan(*cursor, pattern, *count) {
async fn scan_cmd(
server: &Server,
cursor: &u64,
pattern: Option<&str>,
count: &Option<u64>
) -> Result<Protocol, DBError> {
match server.current_storage()?.scan(*cursor, pattern, *count) {
Ok((next_cursor, keys)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
));
result.push(Protocol::Array(keys.into_iter().map(Protocol::BulkString).collect()));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
}
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.current_storage().hscan(key, *cursor, pattern, *count) {
async fn hscan_cmd(
server: &Server,
key: &str,
cursor: &u64,
pattern: Option<&str>,
count: &Option<u64>
) -> Result<Protocol, DBError> {
match server.current_storage()?.hscan(key, *cursor, pattern, *count) {
Ok((next_cursor, fields)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
fields.into_iter().map(Protocol::BulkString).collect(),
));
result.push(Protocol::Array(fields.into_iter().map(Protocol::BulkString).collect()));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
Err(e) => Ok(Protocol::err(&format!("ERR {}", e.0))),
}
}
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage().ttl(key) {
match server.current_storage()?.ttl(key) {
Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage().exists(key) {
match server.current_storage()?.exists(key) {
Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}

73
src/crypto.rs Normal file
View File

@ -0,0 +1,73 @@
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
XChaCha20Poly1305, XNonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
const VERSION: u8 = 1;
const NONCE_LEN: usize = 24;
const TAG_LEN: usize = 16;
#[derive(Debug)]
pub enum CryptoError {
Format, // wrong length / header
Version(u8), // unknown version
Decrypt, // wrong key or corrupted data
}
impl From<CryptoError> for crate::error::DBError {
fn from(e: CryptoError) -> Self {
crate::error::DBError(format!("Crypto error: {:?}", e))
}
}
/// Super-simple factory: new(secret) + encrypt(bytes) + decrypt(bytes)
pub struct CryptoFactory {
key: chacha20poly1305::Key,
}
impl CryptoFactory {
/// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
let mut h = Sha256::new();
h.update(b"xchacha20poly1305-factory:v1"); // domain separation
h.update(secret.as_ref());
let digest = h.finalize(); // 32 bytes
let key = chacha20poly1305::Key::from_slice(&digest).to_owned();
Self { key }
}
/// Output layout: [version:1][nonce:24][ciphertext||tag]
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
let cipher = XChaCha20Poly1305::new(&self.key);
let mut nonce_bytes = [0u8; NONCE_LEN];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = XNonce::from_slice(&nonce_bytes);
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
out.push(VERSION);
out.extend_from_slice(&nonce_bytes);
let ct = cipher.encrypt(nonce, plaintext).expect("encrypt");
out.extend_from_slice(&ct);
out
}
pub fn decrypt(&self, blob: &[u8]) -> Result<Vec<u8>, CryptoError> {
if blob.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(CryptoError::Format);
}
let ver = blob[0];
if ver != VERSION {
return Err(CryptoError::Version(ver));
}
let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
let ct = &blob[1 + NONCE_LEN..];
let cipher = XChaCha20Poly1305::new(&self.key);
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
}
}

View File

@ -1,4 +1,5 @@
pub mod cmd;
pub mod crypto;
pub mod error;
pub mod options;
pub mod protocol;

View File

@ -14,7 +14,6 @@ struct Args {
#[arg(long)]
dir: String,
/// The port of the Redis server, default is 6379 if not specified
#[arg(long)]
port: Option<u16>,
@ -23,9 +22,14 @@ struct Args {
#[arg(long)]
debug: bool,
/// Number of logical databases (SELECT 0..N-1)
#[arg(long, default_value_t = 16)]
databases: u16,
/// Master encryption key for encrypted databases
#[arg(long)]
encryption_key: Option<String>,
/// Encrypt the database
#[arg(long)]
encrypt: bool,
}
#[tokio::main]
@ -45,7 +49,8 @@ async fn main() {
dir: args.dir,
port,
debug: args.debug,
databases: args.databases,
encryption_key: args.encryption_key,
encrypt: args.encrypt,
};
// new server

View File

@ -3,5 +3,6 @@ pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
pub databases: u16, // number of logical DBs (default 16)
pub encrypt: bool,
pub encryption_key: Option<String>, // Master encryption key
}

View File

@ -8,6 +8,7 @@ pub enum Protocol {
BulkString(String),
Null,
Array(Vec<Protocol>),
Error(String), // NEW
}
impl fmt::Display for Protocol {
@ -45,7 +46,7 @@ impl Protocol {
#[inline]
pub fn err(msg: &str) -> Self {
Protocol::SimpleString(msg.to_string())
Protocol::Error(msg.to_string())
}
#[inline]
@ -69,22 +70,19 @@ impl Protocol {
Protocol::BulkString(s) => s.to_string(),
Protocol::Null => "".to_string(),
Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "),
Protocol::Error(s) => s.to_string(),
}
}
pub fn encode(&self) -> String {
match self {
Protocol::SimpleString(s) => format!("+{}\r\n", s),
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => {
format!("*{}\r\n", ss.len())
+ ss.iter()
.map(|x| x.encode())
.collect::<Vec<_>>()
.join("")
.as_str()
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => {
format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
}
Protocol::Null => "$-1\r\n".to_string(),
Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
}
}

View File

@ -1,5 +1,5 @@
use core::str;
use std::path::PathBuf;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
@ -12,34 +12,57 @@ use crate::storage::Storage;
#[derive(Clone)]
pub struct Server {
pub storages: Vec<Arc<Storage>>,
pub db_cache: std::sync::Arc<std::sync::RwLock<HashMap<u64, Arc<Storage>>>>,
pub option: options::DBOption,
pub client_name: Option<String>,
pub selected_db: usize, // per-connection
pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
// Eagerly create N db files: <dir>/<index>.db
let mut storages = Vec::with_capacity(option.databases as usize);
for i in 0..option.databases {
let db_file_path = PathBuf::from(option.dir.clone()).join(format!("{}.db", i));
println!("will open db file path (db {}): {}", i, db_file_path.display());
let storage = Storage::new(db_file_path).expect("Failed to initialize storage");
storages.push(Arc::new(storage));
}
Server {
storages,
db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())),
option,
client_name: None,
selected_db: 0,
queued_cmd: None,
}
}
#[inline]
pub fn current_storage(&self) -> &Storage {
self.storages[self.selected_db].as_ref()
pub fn current_storage(&self) -> Result<Arc<Storage>, DBError> {
let mut cache = self.db_cache.write().unwrap();
if let Some(storage) = cache.get(&self.selected_db) {
return Ok(storage.clone());
}
// Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db));
// Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
println!("Creating new db file: {}", db_file_path.display());
let storage = Arc::new(Storage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref()
)?);
cache.insert(self.selected_db, storage.clone());
Ok(storage)
}
fn should_encrypt_db(&self, _db_index: u64) -> bool {
self.option.encrypt
}
pub async fn handle(
@ -47,7 +70,6 @@ impl Server {
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
let mut buf = [0; 512];
let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
loop {
let len = match stream.read(&mut buf).await {
@ -82,16 +104,21 @@ impl Server {
// Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit);
let res = cmd
.run(&mut self.clone(), protocol.clone(), &mut queued_cmd)
.await
.unwrap_or(Protocol::err("unknown cmd from server"));
let res = match cmd.run(self).await {
Ok(p) => p,
Err(e) => {
if self.option.debug {
eprintln!("[run error] {:?}", e);
}
Protocol::err(&format!("ERR {}", e.0))
}
};
if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", queued_cmd);
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else {
print!("queued cmd {:?}", queued_cmd);
print!("queued cmd {:?}", self.queued_cmd);
println!("going to send response {}", res.encode());
}
@ -104,6 +131,5 @@ impl Server {
}
}
}
Ok(())
}
}

View File

@ -6,21 +6,110 @@ use std::{
use redb::{Database, ReadableTable, TableDefinition};
use serde::{Deserialize, Serialize};
use crate::crypto::CryptoFactory;
use crate::error::DBError;
// Add this glob matching function
fn glob_match(pattern: &str, text: &str) -> bool {
fn match_recursive(pattern: &[char], text: &[char], p_idx: usize, t_idx: usize) -> bool {
if p_idx >= pattern.len() {
return t_idx >= text.len();
}
match pattern[p_idx] {
'*' => {
// Try matching zero characters
if match_recursive(pattern, text, p_idx + 1, t_idx) {
return true;
}
// Try matching one or more characters
for i in t_idx..text.len() {
if match_recursive(pattern, text, p_idx + 1, i + 1) {
return true;
}
}
false
}
'?' => {
if t_idx >= text.len() {
false
} else {
match_recursive(pattern, text, p_idx + 1, t_idx + 1)
}
}
'[' => {
// Find the closing bracket
let mut bracket_end = p_idx + 1;
while bracket_end < pattern.len() && pattern[bracket_end] != ']' {
bracket_end += 1;
}
if bracket_end >= pattern.len() || t_idx >= text.len() {
return false;
}
let bracket_content = &pattern[p_idx + 1..bracket_end];
let char_to_match = text[t_idx];
let mut matched = false;
let mut i = 0;
while i < bracket_content.len() {
if i + 2 < bracket_content.len() && bracket_content[i + 1] == '-' {
// Range like [a-z]
if char_to_match >= bracket_content[i] && char_to_match <= bracket_content[i + 2] {
matched = true;
break;
}
i += 3;
} else {
// Single character
if char_to_match == bracket_content[i] {
matched = true;
break;
}
i += 1;
}
}
if matched {
match_recursive(pattern, text, bracket_end + 1, t_idx + 1)
} else {
false
}
}
'\\' => {
// Escape next character
if p_idx + 1 >= pattern.len() || t_idx >= text.len() {
false
} else if pattern[p_idx + 1] == text[t_idx] {
match_recursive(pattern, text, p_idx + 2, t_idx + 1)
} else {
false
}
}
c => {
if t_idx >= text.len() || c != text[t_idx] {
false
} else {
match_recursive(pattern, text, p_idx + 1, t_idx + 1)
}
}
}
}
let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect();
match_recursive(&pattern_chars, &text_chars, 0, 0)
}
// Table definitions for different Redis data types
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StringValue {
pub value: String,
pub expires_at_ms: Option<u128>,
}
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
@ -42,10 +131,11 @@ pub fn now_in_millis() -> u128 {
pub struct Storage {
db: Database,
crypto: Option<CryptoFactory>,
}
impl Storage {
pub fn new(path: impl AsRef<Path>) -> Result<Self, DBError> {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
let db = Database::create(path)?;
// Create tables if they don't exist
@ -57,45 +147,165 @@ impl Storage {
let _ = write_txn.open_table(LISTS_TABLE)?;
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
}
write_txn.commit()?;
Ok(Storage { db })
// Check if database was previously encrypted
let read_txn = db.begin_read()?;
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
drop(read_txn);
let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes()))
} else {
return Err(DBError("Encryption requested but no master key provided".to_string()));
}
} else {
None
};
// If we're enabling encryption for the first time, mark it
if should_encrypt && !was_encrypted {
let write_txn = db.begin_write()?;
{
let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
encrypted_table.insert("encrypted", &1u8)?;
}
write_txn.commit()?;
}
Ok(Storage {
db,
crypto,
})
}
pub fn is_encrypted(&self) -> bool {
self.crypto.is_some()
}
// Helper methods for encryption
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.encrypt(data))
} else {
Ok(data.to_vec())
}
}
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data)?)
} else {
Ok(data.to_vec())
}
}
pub fn flushdb(&self) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
types_table.remove(key.as_str())?;
}
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
strings_table.remove(key.as_str())?;
}
let keys: Vec<(String, String)> = hashes_table
.iter()?
.map(|item| {
let binding = item.unwrap();
let (k, f) = binding.0.value();
(k.to_string(), f.to_string())
})
.collect();
for (key, field) in keys {
hashes_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
lists_table.remove(key.as_str())?;
}
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
streams_meta_table.remove(key.as_str())?;
}
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
let binding = item.unwrap();
let (key, field) = binding.0.value();
(key.to_string(), field.to_string())
}).collect();
for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
expiration_table.remove(key.as_str())?;
}
}
write_txn.commit()?;
Ok(())
}
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
match table.get(key)? {
Some(type_val) => Ok(Some(type_val.value().to_string())),
None => Ok(None),
// Before returning type, check for expiration
if let Some(type_val) = table.get(key)? {
if type_val.value() == "string" {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
// The key is expired, so it effectively has no type
return Ok(None);
}
}
}
Ok(Some(type_val.value().to_string()))
} else {
Ok(None)
}
}
// Update the get method to use decryption
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
// Check if key exists and is of string type
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check expiration first (unencrypted)
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
// Get and decrypt value
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let string_value: StringValue = bincode::deserialize(data.value())?;
// Check if expired
if let Some(expires_at) = string_value.expires_at_ms {
if now_in_millis() > expires_at {
// Key expired, remove it
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
Ok(Some(string_value.value))
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
@ -103,7 +313,11 @@ impl Storage {
_ => Ok(None),
}
}
// Apply similar encryption/decryption to other methods (setx, hset, lpush, etc.)
// ... (you'll need to update all methods that store/retrieve serialized data)
// Update the set method to use encryption
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
@ -112,12 +326,13 @@ impl Storage {
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: None,
};
let serialized = bincode::serialize(&string_value)?;
strings_table.insert(key.as_str(), serialized.as_slice())?;
// Only encrypt the value, not expiration
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Remove any existing expiration since this is a regular SET
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
@ -132,12 +347,14 @@ impl Storage {
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let string_value = StringValue {
value,
expires_at_ms: Some(expire_ms + now_in_millis()),
};
let serialized = bincode::serialize(&string_value)?;
strings_table.insert(key.as_str(), serialized.as_slice())?;
// Only encrypt the value
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Store expiration separately (unencrypted)
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = expire_ms + now_in_millis();
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
}
write_txn.commit()?;
@ -150,7 +367,7 @@ impl Storage {
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table
@ -177,6 +394,10 @@ impl Storage {
// Remove from lists table
lists_table.remove(key.as_str())?;
// Also remove expiration
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
@ -191,7 +412,7 @@ impl Storage {
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
if pattern == "*" || key.contains(pattern) {
if pattern == "*" || glob_match(pattern, &key) {
keys.push(key);
}
}
@ -227,7 +448,11 @@ impl Storage {
for (field, value) in pairs {
let existed = hashes_table.get((key, field.as_str()))?.is_some();
hashes_table.insert((key, field.as_str()), value.as_str())?;
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
if !existed {
new_fields += 1;
}
@ -247,11 +472,15 @@ impl Storage {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
match hashes_table.get((key, field))? {
Some(value) => Ok(Some(value.value().to_string())),
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(None),
}
}
@ -272,7 +501,9 @@ impl Storage {
let (hash_key, field) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push((field.to_string(), value.to_string()));
let decrypted = self.decrypt_if_needed(value)?;
let value_str = String::from_utf8(decrypted)?;
result.push((field.to_string(), value_str));
}
}
@ -284,37 +515,32 @@ impl Storage {
}
pub fn hdel(&self, key: &str, fields: &[String]) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut deleted = 0u64;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let key_type = types_table.get(key)?;
match key_type {
Some(type_val) if type_val.value() == "hash" => {
// Enforce type check before proceeding to write transaction
let key_type = self.get_key_type(key)?;
match key_type.as_deref() {
Some("hash") => {
let write_txn = self.db.begin_write()?;
let mut deleted = 0u64;
{
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1;
}
}
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => {}
write_txn.commit()?;
Ok(deleted)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(0), // Key doesn't exist, so 0 fields deleted.
}
write_txn.commit()?;
Ok(deleted)
}
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
match self.get_key_type(key)?.as_deref() {
Some("hash") => {
let read_txn = self.db.begin_read()?;
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some())
}
@ -324,23 +550,14 @@ impl Storage {
}
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
match self.get_key_type(key)?.as_deref() {
Some("hash") => {
let read_txn = self.db.begin_read()?;
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
result.push(field.to_string());
}
for entry in hashes_table.range((key, "")..=(key, "\u{FFFF}"))? {
result.push(entry?.0.value().1.to_string());
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
@ -349,24 +566,15 @@ impl Storage {
}
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
match self.get_key_type(key)?.as_deref() {
Some("hash") => {
let read_txn = self.db.begin_read()?;
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
let value = entry.1.value();
if hash_key == key {
result.push(value.to_string());
}
for entry in hashes_table.range((key, "")..=(key, "\u{FFFF}"))? {
let value = self.decrypt_if_needed(entry?.1.value())?;
result.push(String::from_utf8(value)?);
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
@ -375,24 +583,12 @@ impl Storage {
}
pub fn hlen(&self, key: &str) -> Result<u64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
match self.get_key_type(key)?.as_deref() {
Some("hash") => {
let read_txn = self.db.begin_read()?;
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0u64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
count += 1;
}
}
Ok(count)
// Use `range` for efficiency
Ok(hashes_table.range((key, "")..=(key, "\u{FFFF}"))?.count() as u64)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(0),
@ -400,25 +596,22 @@ impl Storage {
}
pub fn hmget(&self, key: &str, fields: &[String]) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
match self.get_key_type(key)?.as_deref() {
Some("hash") => {
let read_txn = self.db.begin_read()?;
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(value) => result.push(Some(value.value().to_string())),
None => result.push(None),
}
let value = match hashes_table.get((key, field.as_str()))? {
Some(data) => Some(String::from_utf8(self.decrypt_if_needed(data.value())?)?),
None => None,
};
result.push(value);
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(fields.iter().map(|_| None).collect()),
None => Ok(vec![None; fields.len()]),
}
}
@ -449,7 +642,8 @@ impl Storage {
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
hashes_table.insert((key, field), value)?;
let encrypted_value = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field), encrypted_value.as_slice())?;
result = true;
}
}
@ -460,16 +654,19 @@ impl Storage {
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<String>), DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
// Explicitly specify the table type to avoid confusion
let types_table: redb::ReadOnlyTable<&str, &str> = read_txn.open_table(TYPES_TABLE)?;
let count = count.unwrap_or(10); // Default count is 10
let mut keys = Vec::new();
let mut current_cursor = 0u64;
let mut returned_keys = 0u64;
let mut iter = table.iter()?;
let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
let entry = entry?;
let key = entry.0.value().to_string();
// Skip keys until we reach the cursor position
if current_cursor < cursor {
@ -483,15 +680,8 @@ impl Storage {
if pat == "*" {
true
} else if pat.contains('*') {
// Simple glob pattern matching
let pattern_parts: Vec<&str> = pat.split('*').collect();
if pattern_parts.len() == 2 {
let prefix = pattern_parts[0];
let suffix = pattern_parts[1];
key.starts_with(prefix) && key.ends_with(suffix)
} else {
key.contains(&pat.replace('*', ""))
}
// Use the glob_match function for better pattern matching
glob_match(pat, &key)
} else {
key.contains(pat)
}
@ -512,9 +702,9 @@ impl Storage {
current_cursor += 1;
}
// If we've reached the end of iteration, return cursor 0 to indicate completion
let next_cursor = if iter.next().is_none() { 0 } else { current_cursor };
// If we've reached the end of the iteration, return cursor 0, otherwise return the next cursor position
let next_cursor = if returned_keys < count { 0 } else { current_cursor };
Ok((next_cursor, keys))
}
@ -522,10 +712,10 @@ impl Storage {
let read_txn = self.db.begin_read()?;
// Check if key exists and is a hash
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table: redb::ReadOnlyTable<&str, &str> = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "hash" => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let hashes_table: redb::ReadOnlyTable<(&str, &str), &[u8]> = read_txn.open_table(HASHES_TABLE)?;
let count = count.unwrap_or(10);
let mut fields = Vec::new();
let mut current_cursor = 0u64;
@ -553,14 +743,8 @@ impl Storage {
if pat == "*" {
true
} else if pat.contains('*') {
let pattern_parts: Vec<&str> = pat.split('*').collect();
if pattern_parts.len() == 2 {
let prefix = pattern_parts[0];
let suffix = pattern_parts[1];
field.starts_with(prefix) && field.ends_with(suffix)
} else {
field.contains(&pat.replace('*', ""))
}
// Use the glob_match function for better pattern matching
glob_match(pat, field)
} else {
field.contains(pat)
}
@ -569,8 +753,10 @@ impl Storage {
};
if matches {
let decrypted = self.decrypt_if_needed(value)?;
let value_str = String::from_utf8(decrypted)?;
fields.push(field.to_string());
fields.push(value.to_string());
fields.push(value_str);
returned_fields += 1;
if returned_fields >= count {
@ -581,7 +767,8 @@ impl Storage {
current_cursor += 1;
}
let next_cursor = if iter.next().is_none() { 0 } else { current_cursor };
// Check if there are more entries by trying to get the next one
let next_cursor = if returned_fields < count { 0 } else { current_cursor };
Ok((next_cursor, fields))
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
@ -592,30 +779,24 @@ impl Storage {
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
// Check if key exists
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let string_value: StringValue = bincode::deserialize(data.value())?;
match string_value.expires_at_ms {
Some(expires_at) => {
let now = now_in_millis();
if now > expires_at {
Ok(-2) // Key expired
} else {
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-1), // No expiration
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
match expiration_table.get(key)? {
Some(expires_at) => {
let now = now_in_millis();
let expires_at = expires_at.value() as u128;
if now > expires_at {
Ok(-2) // Key expired
} else {
Ok(((expires_at - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-2), // Key doesn't exist
None => Ok(-1), // No expiration
}
}
Some(_) => Ok(-1), // Other types don't have TTL implemented yet
Some(_) => Ok(-1), // Other types don't have TTL
None => Ok(-2), // Key doesn't exist
}
}
@ -625,18 +806,13 @@ impl Storage {
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(_) => {
Some(type_val) => {
// For string types, check if not expired
if let Some(type_val) = types_table.get(key)? {
if type_val.value() == "string" {
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
if let Some(data) = strings_table.get(key)? {
let string_value: StringValue = bincode::deserialize(data.value())?;
if let Some(expires_at) = string_value.expires_at_ms {
if now_in_millis() > expires_at {
return Ok(false); // Expired
}
}
if type_val.value() == "string" {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
return Ok(false); // Expired
}
}
}
@ -649,7 +825,7 @@ impl Storage {
// List operations
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_len = 0u64;
let new_len;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
@ -671,17 +847,21 @@ impl Storage {
}
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
bincode::deserialize(&decrypted)?
},
None => ListValue { elements: Vec::new() },
};
for element in elements.into_iter().rev() {
for element in elements.into_iter() {
list_value.elements.insert(0, element);
}
new_len = list_value.elements.len() as u64;
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
@ -690,7 +870,7 @@ impl Storage {
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_len = 0u64;
let new_len;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
@ -712,7 +892,10 @@ impl Storage {
}
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
bincode::deserialize(&decrypted)?
},
None => ListValue { elements: Vec::new() },
};
@ -722,7 +905,8 @@ impl Storage {
new_len = list_value.elements.len() as u64;
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
@ -748,7 +932,10 @@ impl Storage {
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
bincode::deserialize(&decrypted)?
},
None => return Ok(None), // Key exists but list is empty (shouldn't happen if type is "list")
};
@ -766,7 +953,8 @@ impl Storage {
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
None => return Ok(None),
@ -800,7 +988,10 @@ impl Storage {
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
bincode::deserialize(&decrypted)?
}
None => return Ok(None),
};
@ -818,7 +1009,8 @@ impl Storage {
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
None => return Ok(None),
@ -842,7 +1034,8 @@ impl Storage {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let list_value: ListValue = bincode::deserialize(data.value())?;
let decrypted = self.decrypt_if_needed(data.value())?;
let list_value: ListValue = bincode::deserialize(&decrypted)?;
Ok(list_value.elements.len() as u64)
}
None => Ok(0), // Key exists but list is empty
@ -855,7 +1048,7 @@ impl Storage {
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<u64, DBError> {
let write_txn = self.db.begin_write()?;
let mut removed_count = 0u64;
let removed_count;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
@ -872,7 +1065,10 @@ impl Storage {
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
bincode::deserialize(&decrypted)?
}
None => return Ok(0),
};
@ -910,7 +1106,8 @@ impl Storage {
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
None => return Ok(0),
@ -939,7 +1136,10 @@ impl Storage {
}
Some(_) => {
let mut list_value: ListValue = match lists_table.get(key)? {
Some(data) => bincode::deserialize(data.value())?,
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
bincode::deserialize(&decrypted)?
}
None => return Ok(()),
};
@ -974,7 +1174,8 @@ impl Storage {
types_table.remove(key)?;
} else {
let serialized = bincode::serialize(&list_value)?;
lists_table.insert(key, serialized.as_slice())?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
None => {}
@ -994,7 +1195,8 @@ impl Storage {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let list_value: ListValue = bincode::deserialize(data.value())?;
let decrypted = self.decrypt_if_needed(data.value())?;
let list_value: ListValue = bincode::deserialize(&decrypted)?;
let len = list_value.elements.len() as i64;
let mut index = index;
if index < 0 {
@ -1023,7 +1225,8 @@ impl Storage {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let list_value: ListValue = bincode::deserialize(data.value())?;
let decrypted = self.decrypt_if_needed(data.value())?;
let list_value: ListValue = bincode::deserialize(&decrypted)?;
let len = list_value.elements.len() as i64;
let mut start = start;
let mut stop = stop;

View File

@ -25,7 +25,8 @@ async fn debug_hset_simple() {
dir: test_dir.to_string(),
port,
debug: false,
databases: 16,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;

View File

@ -16,7 +16,8 @@ async fn debug_hset_return_value() {
dir: test_dir.to_string(),
port: 16390,
debug: false,
databases: 16,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;

View File

@ -120,9 +120,7 @@ async fn all_tests() {
test_transaction_operations(&mut conn).await;
test_discard_transaction(&mut conn).await;
test_type_command(&mut conn).await;
test_config_commands(&mut conn).await;
test_info_command(&mut conn).await;
test_error_handling(&mut conn).await;
}
async fn test_basic_ping(conn: &mut Connection) {
@ -308,23 +306,6 @@ async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await;
}
async fn test_config_commands(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: Vec<String> = redis::cmd("CONFIG")
.arg("GET")
.arg("databases")
.query(conn)
.unwrap();
assert_eq!(result, vec!["databases", "16"]);
let result: Vec<String> = redis::cmd("CONFIG")
.arg("GET")
.arg("dir")
.query(conn)
.unwrap();
assert_eq!(result[0], "dir");
assert!(result[1].contains("/tmp/herodb_test_"));
cleanup_keys(conn).await;
}
async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await;
@ -334,17 +315,3 @@ async fn test_info_command(conn: &mut Connection) {
assert!(result.contains("role:master"));
cleanup_keys(conn).await;
}
async fn test_error_handling(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("string", "value").unwrap();
let result: RedisResult<String> = conn.hget("string", "field");
assert!(result.is_err());
let result: RedisResult<String> = redis::cmd("UNKNOWN").query(conn);
assert!(result.is_err());
let result: RedisResult<Vec<String>> = redis::cmd("EXEC").query(conn);
assert!(result.is_err());
let result: RedisResult<()> = redis::cmd("DISCARD").query(conn);
assert!(result.is_err());
cleanup_keys(conn).await;
}

View File

@ -20,7 +20,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
dir: test_dir,
port,
debug: true,
databases: 16,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
@ -580,22 +581,19 @@ async fn test_list_operations() {
// Test LRANGE
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
assert!(response.contains("b"));
assert!(response.contains("a"));
assert!(response.contains("c"));
assert!(response.contains("d"));
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
// Test LINDEX
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
assert!(response.contains("b"));
assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert!(response.contains("b"));
assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert!(response.contains("d"));
assert_eq!(response, "$1\r\nd\r\n");
// Test LREM
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a

View File

@ -22,7 +22,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
dir: test_dir,
port,
debug: true,
databases: 16,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
@ -141,8 +142,11 @@ async fn test_hash_operations() {
assert!(response.contains("2"));
// Test HSCAN
let response = send_redis_command(port, "*6\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("*2\r\n$1\r\n0\r\n*4\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n"));
let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Stop the server
server_handle.abort();

View File

@ -20,7 +20,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
dir: test_dir,
port,
debug: false,
databases: 16,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;