From ff0659b933a0ead81ac919e89f2207d4f24ff7a9 Mon Sep 17 00:00:00 2001 From: Lee Smet Date: Mon, 25 Aug 2025 11:16:25 +0200 Subject: [PATCH] Format rust files Signed-off-by: Lee Smet --- examples/age_persist_demo.rs | 38 +- src/age.rs | 170 +++-- src/cmd.rs | 1116 ++++++++++++++++++++---------- src/crypto.rs | 8 +- src/error.rs | 5 +- src/lib.rs | 8 +- src/main.rs | 1 - src/protocol.rs | 13 +- src/search_cmd.rs | 131 ++-- src/server.rs | 68 +- src/storage/mod.rs | 58 +- src/storage/storage_basic.rs | 87 ++- src/storage/storage_extra.rs | 59 +- src/storage/storage_hset.rs | 144 ++-- src/storage/storage_lists.rs | 120 ++-- src/storage_sled/mod.rs | 469 +++++++------ src/storage_trait.rs | 29 +- tests/debug_hset.rs | 44 +- tests/debug_hset_simple.rs | 32 +- tests/debug_protocol.rs | 20 +- tests/redis_integration_tests.rs | 31 +- tests/redis_tests.rs | 398 +++++++---- tests/simple_integration_test.rs | 153 ++-- tests/simple_redis_test.rs | 122 ++-- tests/usage_suite.rs | 261 ++++--- 25 files changed, 2267 insertions(+), 1318 deletions(-) diff --git a/examples/age_persist_demo.rs b/examples/age_persist_demo.rs index 9caf3bd..c363158 100644 --- a/examples/age_persist_demo.rs +++ b/examples/age_persist_demo.rs @@ -14,25 +14,31 @@ fn read_reply(s: &mut TcpStream) -> String { let n = s.read(&mut buf).unwrap(); String::from_utf8_lossy(&buf[..n]).to_string() } -fn parse_two_bulk(reply: &str) -> Option<(String,String)> { +fn parse_two_bulk(reply: &str) -> Option<(String, String)> { let mut lines = reply.split("\r\n"); - if lines.next()? != "*2" { return None; } + if lines.next()? != "*2" { + return None; + } let _n = lines.next()?; let a = lines.next()?.to_string(); let _m = lines.next()?; let b = lines.next()?.to_string(); - Some((a,b)) + Some((a, b)) } fn parse_bulk(reply: &str) -> Option { let mut lines = reply.split("\r\n"); let hdr = lines.next()?; - if !hdr.starts_with('$') { return None; } + if !hdr.starts_with('$') { + return None; + } Some(lines.next()?.to_string()) } fn parse_simple(reply: &str) -> Option { let mut lines = reply.split("\r\n"); let hdr = lines.next()?; - if !hdr.starts_with('+') { return None; } + if !hdr.starts_with('+') { + return None; + } Some(hdr[1..].to_string()) } @@ -45,39 +51,45 @@ fn main() { let mut s = TcpStream::connect(addr).expect("connect"); // Generate & persist X25519 enc keys under name "alice" - s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap(); + s.write_all(arr(&["age", "keygen", "alice"]).as_bytes()) + .unwrap(); let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc"); // Generate & persist Ed25519 signing key under name "signer" - s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap(); + s.write_all(arr(&["age", "signkeygen", "signer"]).as_bytes()) + .unwrap(); let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign"); // Encrypt by name let msg = "hello from persistent keys"; - s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap(); + s.write_all(arr(&["age", "encryptname", "alice", msg]).as_bytes()) + .unwrap(); let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64"); println!("ciphertext b64: {}", ct_b64); // Decrypt by name - s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap(); + s.write_all(arr(&["age", "decryptname", "alice", &ct_b64]).as_bytes()) + .unwrap(); let pt = parse_bulk(&read_reply(&mut s)).expect("pt"); assert_eq!(pt, msg); println!("decrypted ok"); // Sign by name - s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap(); + s.write_all(arr(&["age", "signname", "signer", msg]).as_bytes()) + .unwrap(); let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64"); // Verify by name - s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap(); + s.write_all(arr(&["age", "verifyname", "signer", msg, &sig_b64]).as_bytes()) + .unwrap(); let ok = parse_simple(&read_reply(&mut s)).expect("verify"); assert_eq!(ok, "1"); println!("signature verified"); // List names - s.write_all(arr(&["age","list"]).as_bytes()).unwrap(); + s.write_all(arr(&["age", "list"]).as_bytes()).unwrap(); let list = read_reply(&mut s); println!("LIST -> {list}"); println!("✔ persistent AGE workflow complete."); -} \ No newline at end of file +} diff --git a/src/age.rs b/src/age.rs index 77501da..3f334e5 100644 --- a/src/age.rs +++ b/src/age.rs @@ -12,17 +12,17 @@ use std::str::FromStr; -use secrecy::ExposeSecret; -use age::{Decryptor, Encryptor}; use age::x25519; +use age::{Decryptor, Encryptor}; +use secrecy::ExposeSecret; -use ed25519_dalek::{Signature, Signer, Verifier, SigningKey, VerifyingKey}; +use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey}; use base64::{engine::general_purpose::STANDARD as B64, Engine as _}; +use crate::error::DBError; use crate::protocol::Protocol; use crate::server::Server; -use crate::error::DBError; // ---------- Internal helpers ---------- @@ -32,7 +32,7 @@ pub enum AgeWireError { Crypto(String), Utf8, SignatureLen, - NotFound(&'static str), // which kind of key was missing + NotFound(&'static str), // which kind of key was missing Storage(String), } @@ -83,34 +83,38 @@ pub fn gen_enc_keypair() -> (String, String) { } pub fn gen_sign_keypair() -> (String, String) { - use rand::RngCore; use rand::rngs::OsRng; - + use rand::RngCore; + // Generate random 32 bytes for the signing key let mut secret_bytes = [0u8; 32]; OsRng.fill_bytes(&mut secret_bytes); - + let signing_key = SigningKey::from_bytes(&secret_bytes); let verifying_key = signing_key.verifying_key(); - + // Encode as base64 for storage let signing_key_b64 = B64.encode(signing_key.to_bytes()); let verifying_key_b64 = B64.encode(verifying_key.to_bytes()); - + (verifying_key_b64, signing_key_b64) // (verify_pub, signing_secret) } /// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext). pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result { let recipient = parse_recipient(recipient_str)?; - let enc = Encryptor::with_recipients(vec![Box::new(recipient)]) - .expect("failed to create encryptor"); // Handle Option + let enc = + Encryptor::with_recipients(vec![Box::new(recipient)]).expect("failed to create encryptor"); // Handle Option let mut out = Vec::new(); { use std::io::Write; - let mut w = enc.wrap_output(&mut out).map_err(|e| AgeWireError::Crypto(e.to_string()))?; - w.write_all(msg.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; - w.finish().map_err(|e| AgeWireError::Crypto(e.to_string()))?; + let mut w = enc + .wrap_output(&mut out) + .map_err(|e| AgeWireError::Crypto(e.to_string()))?; + w.write_all(msg.as_bytes()) + .map_err(|e| AgeWireError::Crypto(e.to_string()))?; + w.finish() + .map_err(|e| AgeWireError::Crypto(e.to_string()))?; } Ok(B64.encode(out)) } @@ -118,19 +122,27 @@ pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result Result { let id = parse_identity(identity_str)?; - let ct = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; + let ct = B64 + .decode(ct_b64.as_bytes()) + .map_err(|e| AgeWireError::Crypto(e.to_string()))?; let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?; - + // The decrypt method returns a Result let mut r = match dec { - Decryptor::Recipients(d) => d.decrypt(std::iter::once(&id as &dyn age::Identity)) + Decryptor::Recipients(d) => d + .decrypt(std::iter::once(&id as &dyn age::Identity)) .map_err(|e| AgeWireError::Crypto(e.to_string()))?, - Decryptor::Passphrase(_) => return Err(AgeWireError::Crypto("Expected recipients, got passphrase".to_string())), + Decryptor::Passphrase(_) => { + return Err(AgeWireError::Crypto( + "Expected recipients, got passphrase".to_string(), + )) + } }; - + let mut pt = Vec::new(); use std::io::Read; - r.read_to_end(&mut pt).map_err(|e| AgeWireError::Crypto(e.to_string()))?; + r.read_to_end(&mut pt) + .map_err(|e| AgeWireError::Crypto(e.to_string()))?; String::from_utf8(pt).map_err(|_| AgeWireError::Utf8) } @@ -144,7 +156,9 @@ pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result Result { let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?; - let sig_bytes = B64.decode(sig_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; + let sig_bytes = B64 + .decode(sig_b64.as_bytes()) + .map_err(|e| AgeWireError::Crypto(e.to_string()))?; if sig_bytes.len() != 64 { return Err(AgeWireError::SignatureLen); } @@ -155,30 +169,49 @@ pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result Result, AgeWireError> { - let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?; + let st = server + .current_storage() + .map_err(|e| AgeWireError::Storage(e.0))?; st.get(key).map_err(|e| AgeWireError::Storage(e.0)) } fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> { - let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?; - st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0)) + let st = server + .current_storage() + .map_err(|e| AgeWireError::Storage(e.0))?; + st.set(key.to_string(), val.to_string()) + .map_err(|e| AgeWireError::Storage(e.0)) } -fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") } -fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") } -fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") } -fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") } +fn enc_pub_key_key(name: &str) -> String { + format!("age:key:{name}") +} +fn enc_priv_key_key(name: &str) -> String { + format!("age:privkey:{name}") +} +fn sign_pub_key_key(name: &str) -> String { + format!("age:signpub:{name}") +} +fn sign_priv_key_key(name: &str) -> String { + format!("age:signpriv:{name}") +} // ---------- Command handlers (RESP Protocol) ---------- // Basic (stateless) ones kept for completeness pub async fn cmd_age_genenc() -> Protocol { let (recip, ident) = gen_enc_keypair(); - Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)]) + Protocol::Array(vec![ + Protocol::BulkString(recip), + Protocol::BulkString(ident), + ]) } pub async fn cmd_age_gensign() -> Protocol { let (verify, secret) = gen_sign_keypair(); - Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)]) + Protocol::Array(vec![ + Protocol::BulkString(verify), + Protocol::BulkString(secret), + ]) } pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol { @@ -214,16 +247,30 @@ pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> P pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol { let (recip, ident) = gen_enc_keypair(); - if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); } - if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); } - Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)]) + if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { + return e.to_protocol(); + } + if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { + return e.to_protocol(); + } + Protocol::Array(vec![ + Protocol::BulkString(recip), + Protocol::BulkString(ident), + ]) } pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol { let (verify, secret) = gen_sign_keypair(); - if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); } - if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); } - Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)]) + if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { + return e.to_protocol(); + } + if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { + return e.to_protocol(); + } + Protocol::Array(vec![ + Protocol::BulkString(verify), + Protocol::BulkString(secret), + ]) } pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol { @@ -253,7 +300,9 @@ pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol { let sec = match sget(server, &sign_priv_key_key(name)) { Ok(Some(v)) => v, - Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(), + Ok(None) => { + return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol() + } Err(e) => return e.to_protocol(), }; match sign_b64(&sec, message) { @@ -262,10 +311,17 @@ pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Pr } } -pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol { +pub async fn cmd_age_verify_name( + server: &Server, + name: &str, + message: &str, + sig_b64: &str, +) -> Protocol { let pubk = match sget(server, &sign_pub_key_key(name)) { Ok(Some(v)) => v, - Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(), + Ok(None) => { + return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol() + } Err(e) => return e.to_protocol(), }; match verify_b64(&pubk, message, sig_b64) { @@ -277,25 +333,43 @@ pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig pub async fn cmd_age_list(server: &Server) -> Protocol { // Returns 4 arrays: ["encpub", ], ["encpriv", ...], ["signpub", ...], ["signpriv", ...] - let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) }; + let st = match server.current_storage() { + Ok(s) => s, + Err(e) => return Protocol::err(&e.0), + }; let pull = |pat: &str, prefix: &str| -> Result, DBError> { let keys = st.keys(pat)?; - let mut names: Vec = keys.into_iter() + let mut names: Vec = keys + .into_iter() .filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string())) .collect(); names.sort(); Ok(names) }; - let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; - let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; - let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; - let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; + let encpub = match pull("age:key:*", "age:key:") { + Ok(v) => v, + Err(e) => return Protocol::err(&e.0), + }; + let encpriv = match pull("age:privkey:*", "age:privkey:") { + Ok(v) => v, + Err(e) => return Protocol::err(&e.0), + }; + let signpub = match pull("age:signpub:*", "age:signpub:") { + Ok(v) => v, + Err(e) => return Protocol::err(&e.0), + }; + let signpriv = match pull("age:signpriv:*", "age:signpriv:") { + Ok(v) => v, + Err(e) => return Protocol::err(&e.0), + }; let to_arr = |label: &str, v: Vec| { let mut out = vec![Protocol::BulkString(label.to_string())]; - out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect())); + out.push(Protocol::Array( + v.into_iter().map(Protocol::BulkString).collect(), + )); Protocol::Array(out) }; @@ -305,4 +379,4 @@ pub async fn cmd_age_list(server: &Server) -> Protocol { to_arr("signpub", signpub), to_arr("signpriv", signpriv), ]) -} \ No newline at end of file +} diff --git a/src/cmd.rs b/src/cmd.rs index f3df94b..1f1ed8c 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -1,7 +1,7 @@ -use crate::{error::DBError, protocol::Protocol, server::Server, search_cmd}; -use tokio::time::{timeout, Duration}; +use crate::{error::DBError, protocol::Protocol, search_cmd, server::Server}; use futures::future::select_all; use std::collections::HashMap; +use tokio::time::{timeout, Duration}; #[derive(Debug, Clone)] pub enum Cmd { @@ -40,7 +40,7 @@ pub enum Cmd { HIncrBy(String, String, i64), HIncrByFloat(String, String, f64), HScan(String, u64, Option, Option), // key, cursor, pattern, count - Scan(u64, Option, Option), // cursor, pattern, count + Scan(u64, Option, Option), // cursor, pattern, count Ttl(String), Expire(String, i64), PExpire(String, i64), @@ -72,10 +72,10 @@ pub enum Cmd { // AGE (rage) commands — stateless AgeGenEnc, AgeGenSign, - AgeEncrypt(String, String), // recipient, message - AgeDecrypt(String, String), // identity, ciphertext_b64 - AgeSign(String, String), // signing_secret, message - AgeVerify(String, String, String), // verify_pub, message, signature_b64 + AgeEncrypt(String, String), // recipient, message + AgeDecrypt(String, String), // identity, ciphertext_b64 + AgeSign(String, String), // signing_secret, message + AgeVerify(String, String, String), // verify_pub, message, signature_b64 // NEW: persistent named-key commands AgeKeygen(String), // name @@ -100,14 +100,14 @@ pub enum Cmd { FtSearch { index_name: String, query: String, - filters: Vec<(String, String)>, // field, value pairs + filters: Vec<(String, String)>, // field, value pairs limit: Option, offset: Option, return_fields: Option>, }, - FtDel(String, String), // index_name, doc_id - FtInfo(String), // index_name - FtDrop(String), // index_name + FtDel(String, String), // index_name, doc_id + FtInfo(String), // index_name + FtDrop(String), // index_name FtAlter { index_name: String, field_name: String, @@ -135,9 +135,13 @@ impl Cmd { match cmd[0].to_lowercase().as_str() { "select" => { if cmd.len() != 2 { - return Err(DBError("wrong number of arguments for SELECT".to_string())); + return Err(DBError( + "wrong number of arguments for SELECT".to_string(), + )); } - let idx = cmd[1].parse::().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?; + let idx = cmd[1].parse::().map_err(|_| { + DBError("ERR DB index is not an integer".to_string()) + })?; Cmd::Select(idx) } "echo" => Cmd::Echo(cmd[1].clone()), @@ -145,7 +149,9 @@ impl Cmd { "get" => Cmd::Get(cmd[1].clone()), "set" => { if cmd.len() < 3 { - return Err(DBError("wrong number of arguments for SET".to_string())); + return Err(DBError( + "wrong number of arguments for SET".to_string(), + )); } let key = cmd[1].clone(); let val = cmd[2].clone(); @@ -163,7 +169,12 @@ impl Cmd { if i + 1 >= cmd.len() { return Err(DBError("ERR syntax error".to_string())); } - let secs: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let secs: u128 = cmd[i + 1].parse().map_err(|_| { + DBError( + "ERR value is not an integer or out of range" + .to_string(), + ) + })?; ex_ms = Some(secs * 1000); i += 2; } @@ -171,13 +182,27 @@ impl Cmd { if i + 1 >= cmd.len() { return Err(DBError("ERR syntax error".to_string())); } - let ms: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let ms: u128 = cmd[i + 1].parse().map_err(|_| { + DBError( + "ERR value is not an integer or out of range" + .to_string(), + ) + })?; ex_ms = Some(ms); i += 2; } - "nx" => { nx = true; i += 1; } - "xx" => { xx = true; i += 1; } - "get" => { getflag = true; i += 1; } + "nx" => { + nx = true; + i += 1; + } + "xx" => { + xx = true; + i += 1; + } + "get" => { + getflag = true; + i += 1; + } _ => { return Err(DBError(format!("unsupported cmd {:?}", cmd))); } @@ -193,19 +218,25 @@ impl Cmd { } "setex" => { if cmd.len() != 4 { - return Err(DBError(format!("wrong number of arguments for SETEX command"))); + return Err(DBError(format!( + "wrong number of arguments for SETEX command" + ))); } Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap()) } "mget" => { if cmd.len() < 2 { - return Err(DBError("wrong number of arguments for MGET command".to_string())); + return Err(DBError( + "wrong number of arguments for MGET command".to_string(), + )); } Cmd::MGet(cmd[1..].to_vec()) } "mset" => { if cmd.len() < 3 || ((cmd.len() - 1) % 2 != 0) { - return Err(DBError("wrong number of arguments for MSET command".to_string())); + return Err(DBError( + "wrong number of arguments for MSET command".to_string(), + )); } let mut pairs = Vec::new(); let mut i = 1; @@ -231,7 +262,9 @@ impl Cmd { } "dbsize" => { if cmd.len() != 1 { - return Err(DBError(format!("wrong number of arguments for DBSIZE command"))); + return Err(DBError(format!( + "wrong number of arguments for DBSIZE command" + ))); } Cmd::DbSize } @@ -245,7 +278,9 @@ impl Cmd { } "del" => { if cmd.len() < 2 { - return Err(DBError(format!("wrong number of arguments for DEL command"))); + return Err(DBError(format!( + "wrong number of arguments for DEL command" + ))); } if cmd.len() == 2 { Cmd::Del(cmd[1].clone()) @@ -281,7 +316,9 @@ impl Cmd { // Hash commands "hset" => { if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 { - return Err(DBError(format!("wrong number of arguments for HSET command"))); + return Err(DBError(format!( + "wrong number of arguments for HSET command" + ))); } let mut pairs = Vec::new(); let mut i = 2; @@ -293,85 +330,114 @@ impl Cmd { } "hget" => { if cmd.len() != 3 { - return Err(DBError(format!("wrong number of arguments for HGET command"))); + return Err(DBError(format!( + "wrong number of arguments for HGET command" + ))); } Cmd::HGet(cmd[1].clone(), cmd[2].clone()) } "hgetall" => { if cmd.len() != 2 { - return Err(DBError(format!("wrong number of arguments for HGETALL command"))); + return Err(DBError(format!( + "wrong number of arguments for HGETALL command" + ))); } Cmd::HGetAll(cmd[1].clone()) } "hdel" => { if cmd.len() < 3 { - return Err(DBError(format!("wrong number of arguments for HDEL command"))); + return Err(DBError(format!( + "wrong number of arguments for HDEL command" + ))); } Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec()) } "hexists" => { if cmd.len() != 3 { - return Err(DBError(format!("wrong number of arguments for HEXISTS command"))); + return Err(DBError(format!( + "wrong number of arguments for HEXISTS command" + ))); } Cmd::HExists(cmd[1].clone(), cmd[2].clone()) } "hkeys" => { if cmd.len() != 2 { - return Err(DBError(format!("wrong number of arguments for HKEYS command"))); + return Err(DBError(format!( + "wrong number of arguments for HKEYS command" + ))); } Cmd::HKeys(cmd[1].clone()) } "hvals" => { if cmd.len() != 2 { - return Err(DBError(format!("wrong number of arguments for HVALS command"))); + return Err(DBError(format!( + "wrong number of arguments for HVALS command" + ))); } Cmd::HVals(cmd[1].clone()) } "hlen" => { if cmd.len() != 2 { - return Err(DBError(format!("wrong number of arguments for HLEN command"))); + return Err(DBError(format!( + "wrong number of arguments for HLEN command" + ))); } Cmd::HLen(cmd[1].clone()) } "hmget" => { if cmd.len() < 3 { - return Err(DBError(format!("wrong number of arguments for HMGET command"))); + return Err(DBError(format!( + "wrong number of arguments for HMGET command" + ))); } Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec()) } "hsetnx" => { if cmd.len() != 4 { - return Err(DBError(format!("wrong number of arguments for HSETNX command"))); + return Err(DBError(format!( + "wrong number of arguments for HSETNX command" + ))); } Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) } "hincrby" => { if cmd.len() != 4 { - return Err(DBError(format!("wrong number of arguments for HINCRBY command"))); + return Err(DBError(format!( + "wrong number of arguments for HINCRBY command" + ))); } - let delta = cmd[3].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let delta = cmd[3].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::HIncrBy(cmd[1].clone(), cmd[2].clone(), delta) } "hincrbyfloat" => { if cmd.len() != 4 { - return Err(DBError(format!("wrong number of arguments for HINCRBYFLOAT command"))); + return Err(DBError(format!( + "wrong number of arguments for HINCRBYFLOAT command" + ))); } - let delta = cmd[3].parse::().map_err(|_| DBError("ERR value is not a valid float".to_string()))?; + let delta = cmd[3].parse::().map_err(|_| { + DBError("ERR value is not a valid float".to_string()) + })?; Cmd::HIncrByFloat(cmd[1].clone(), cmd[2].clone(), delta) } "hscan" => { if cmd.len() < 3 { - return Err(DBError(format!("wrong number of arguments for HSCAN command"))); + return Err(DBError(format!( + "wrong number of arguments for HSCAN command" + ))); } - + let key = cmd[1].clone(); - let cursor = cmd[2].parse::().map_err(|_| - DBError("ERR invalid cursor".to_string()))?; - + let cursor = cmd[2] + .parse::() + .map_err(|_| DBError("ERR invalid cursor".to_string()))?; + let mut pattern = None; let mut count = None; let mut i = 3; - + while i < cmd.len() { match cmd[i].to_lowercase().as_str() { "match" => { @@ -385,8 +451,12 @@ impl Cmd { if i + 1 >= cmd.len() { return Err(DBError("ERR syntax error".to_string())); } - count = Some(cmd[i + 1].parse::().map_err(|_| - DBError("ERR value is not an integer or out of range".to_string()))?); + count = Some(cmd[i + 1].parse::().map_err(|_| { + DBError( + "ERR value is not an integer or out of range" + .to_string(), + ) + })?); i += 2; } _ => { @@ -394,21 +464,24 @@ impl Cmd { } } } - + Cmd::HScan(key, cursor, pattern, count) } "scan" => { if cmd.len() < 2 { - return Err(DBError(format!("wrong number of arguments for SCAN command"))); + return Err(DBError(format!( + "wrong number of arguments for SCAN command" + ))); } - - let cursor = cmd[1].parse::().map_err(|_| - DBError("ERR invalid cursor".to_string()))?; - + + let cursor = cmd[1] + .parse::() + .map_err(|_| DBError("ERR invalid cursor".to_string()))?; + let mut pattern = None; let mut count = None; let mut i = 2; - + while i < cmd.len() { match cmd[i].to_lowercase().as_str() { "match" => { @@ -422,8 +495,12 @@ impl Cmd { if i + 1 >= cmd.len() { return Err(DBError("ERR syntax error".to_string())); } - count = Some(cmd[i + 1].parse::().map_err(|_| - DBError("ERR value is not an integer or out of range".to_string()))?); + count = Some(cmd[i + 1].parse::().map_err(|_| { + DBError( + "ERR value is not an integer or out of range" + .to_string(), + ) + })?); i += 2; } _ => { @@ -431,52 +508,74 @@ impl Cmd { } } } - + Cmd::Scan(cursor, pattern, count) } "ttl" => { if cmd.len() != 2 { - return Err(DBError(format!("wrong number of arguments for TTL command"))); + return Err(DBError(format!( + "wrong number of arguments for TTL command" + ))); } Cmd::Ttl(cmd[1].clone()) } "expire" => { if cmd.len() != 3 { - return Err(DBError("wrong number of arguments for EXPIRE command".to_string())); + return Err(DBError( + "wrong number of arguments for EXPIRE command".to_string(), + )); } - let secs = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let secs = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::Expire(cmd[1].clone(), secs) } "pexpire" => { if cmd.len() != 3 { - return Err(DBError("wrong number of arguments for PEXPIRE command".to_string())); + return Err(DBError( + "wrong number of arguments for PEXPIRE command".to_string(), + )); } - let ms = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let ms = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::PExpire(cmd[1].clone(), ms) } "expireat" => { if cmd.len() != 3 { - return Err(DBError("wrong number of arguments for EXPIREAT command".to_string())); + return Err(DBError( + "wrong number of arguments for EXPIREAT command".to_string(), + )); } - let ts = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let ts = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::ExpireAt(cmd[1].clone(), ts) } "pexpireat" => { if cmd.len() != 3 { - return Err(DBError("wrong number of arguments for PEXPIREAT command".to_string())); + return Err(DBError( + "wrong number of arguments for PEXPIREAT command".to_string(), + )); } - let ts_ms = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let ts_ms = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::PExpireAt(cmd[1].clone(), ts_ms) } "persist" => { if cmd.len() != 2 { - return Err(DBError("wrong number of arguments for PERSIST command".to_string())); + return Err(DBError( + "wrong number of arguments for PERSIST command".to_string(), + )); } Cmd::Persist(cmd[1].clone()) } "exists" => { if cmd.len() < 2 { - return Err(DBError(format!("wrong number of arguments for EXISTS command"))); + return Err(DBError(format!( + "wrong number of arguments for EXISTS command" + ))); } if cmd.len() == 2 { Cmd::Exists(cmd[1].clone()) @@ -486,7 +585,9 @@ impl Cmd { } "quit" => { if cmd.len() != 1 { - return Err(DBError(format!("wrong number of arguments for QUIT command"))); + return Err(DBError(format!( + "wrong number of arguments for QUIT command" + ))); } Cmd::Quit } @@ -514,27 +615,41 @@ impl Cmd { } } "command" => { - let args = if cmd.len() > 1 { cmd[1..].to_vec() } else { vec![] }; + let args = if cmd.len() > 1 { + cmd[1..].to_vec() + } else { + vec![] + }; Cmd::Command(args) } "lpush" => { if cmd.len() < 3 { - return Err(DBError(format!("wrong number of arguments for LPUSH command"))); + return Err(DBError(format!( + "wrong number of arguments for LPUSH command" + ))); } Cmd::LPush(cmd[1].clone(), cmd[2..].to_vec()) } "rpush" => { if cmd.len() < 3 { - return Err(DBError(format!("wrong number of arguments for RPUSH command"))); + return Err(DBError(format!( + "wrong number of arguments for RPUSH command" + ))); } Cmd::RPush(cmd[1].clone(), cmd[2..].to_vec()) } "lpop" => { if cmd.len() < 2 || cmd.len() > 3 { - return Err(DBError(format!("wrong number of arguments for LPOP command"))); + return Err(DBError(format!( + "wrong number of arguments for LPOP command" + ))); } let count = if cmd.len() == 3 { - Some(cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?) + Some(cmd[2].parse::().map_err(|_| { + DBError( + "ERR value is not an integer or out of range".to_string(), + ) + })?) } else { None }; @@ -542,10 +657,16 @@ impl Cmd { } "rpop" => { if cmd.len() < 2 || cmd.len() > 3 { - return Err(DBError(format!("wrong number of arguments for RPOP command"))); + return Err(DBError(format!( + "wrong number of arguments for RPOP command" + ))); } let count = if cmd.len() == 3 { - Some(cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?) + Some(cmd[2].parse::().map_err(|_| { + DBError( + "ERR value is not an integer or out of range".to_string(), + ) + })?) } else { None }; @@ -553,257 +674,398 @@ impl Cmd { } "blpop" => { if cmd.len() < 3 { - return Err(DBError(format!("wrong number of arguments for BLPOP command"))); + return Err(DBError(format!( + "wrong number of arguments for BLPOP command" + ))); } // keys are all but the last argument - let keys = cmd[1..cmd.len()-1].to_vec(); - let timeout_f = cmd[cmd.len()-1] + let keys = cmd[1..cmd.len() - 1].to_vec(); + let timeout_f = cmd[cmd.len() - 1] .parse::() .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; Cmd::BLPop(keys, timeout_f) } "brpop" => { if cmd.len() < 3 { - return Err(DBError(format!("wrong number of arguments for BRPOP command"))); + return Err(DBError(format!( + "wrong number of arguments for BRPOP command" + ))); } // keys are all but the last argument - let keys = cmd[1..cmd.len()-1].to_vec(); - let timeout_f = cmd[cmd.len()-1] + let keys = cmd[1..cmd.len() - 1].to_vec(); + let timeout_f = cmd[cmd.len() - 1] .parse::() .map_err(|_| DBError("ERR timeout is not a number".to_string()))?; Cmd::BRPop(keys, timeout_f) } "llen" => { if cmd.len() != 2 { - return Err(DBError(format!("wrong number of arguments for LLEN command"))); + return Err(DBError(format!( + "wrong number of arguments for LLEN command" + ))); } Cmd::LLen(cmd[1].clone()) } "lrem" => { if cmd.len() != 4 { - return Err(DBError(format!("wrong number of arguments for LREM command"))); + return Err(DBError(format!( + "wrong number of arguments for LREM command" + ))); } - let count = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let count = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::LRem(cmd[1].clone(), count, cmd[3].clone()) } "ltrim" => { if cmd.len() != 4 { - return Err(DBError(format!("wrong number of arguments for LTRIM command"))); + return Err(DBError(format!( + "wrong number of arguments for LTRIM command" + ))); } - let start = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; - let stop = cmd[3].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let start = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; + let stop = cmd[3].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::LTrim(cmd[1].clone(), start, stop) } "lindex" => { if cmd.len() != 3 { - return Err(DBError(format!("wrong number of arguments for LINDEX command"))); + return Err(DBError(format!( + "wrong number of arguments for LINDEX command" + ))); } - let index = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let index = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::LIndex(cmd[1].clone(), index) } "lrange" => { if cmd.len() != 4 { - return Err(DBError(format!("wrong number of arguments for LRANGE command"))); + return Err(DBError(format!( + "wrong number of arguments for LRANGE command" + ))); } - let start = cmd[2].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; - let stop = cmd[3].parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?; + let start = cmd[2].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; + let stop = cmd[3].parse::().map_err(|_| { + DBError("ERR value is not an integer or out of range".to_string()) + })?; Cmd::LRange(cmd[1].clone(), start, stop) } "flushdb" => { if cmd.len() != 1 { - return Err(DBError("wrong number of arguments for FLUSHDB command".to_string())); + return Err(DBError( + "wrong number of arguments for FLUSHDB command".to_string(), + )); } Cmd::FlushDb } "age" => { if cmd.len() < 2 { - return Err(DBError("wrong number of arguments for AGE".to_string())); + return Err(DBError( + "wrong number of arguments for AGE".to_string(), + )); } match cmd[1].to_lowercase().as_str() { // stateless - "genenc" => { if cmd.len() != 2 { return Err(DBError("AGE GENENC takes no args".to_string())); } - Cmd::AgeGenEnc } - "gensign" => { if cmd.len() != 2 { return Err(DBError("AGE GENSIGN takes no args".to_string())); } - Cmd::AgeGenSign } - "encrypt" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPT ".to_string())); } - Cmd::AgeEncrypt(cmd[2].clone(), cmd[3].clone()) } - "decrypt" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPT ".to_string())); } - Cmd::AgeDecrypt(cmd[2].clone(), cmd[3].clone()) } - "sign" => { if cmd.len() != 4 { return Err(DBError("AGE SIGN ".to_string())); } - Cmd::AgeSign(cmd[2].clone(), cmd[3].clone()) } - "verify" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFY ".to_string())); } - Cmd::AgeVerify(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) } + "genenc" => { + if cmd.len() != 2 { + return Err(DBError( + "AGE GENENC takes no args".to_string(), + )); + } + Cmd::AgeGenEnc + } + "gensign" => { + if cmd.len() != 2 { + return Err(DBError( + "AGE GENSIGN takes no args".to_string(), + )); + } + Cmd::AgeGenSign + } + "encrypt" => { + if cmd.len() != 4 { + return Err(DBError( + "AGE ENCRYPT ".to_string(), + )); + } + Cmd::AgeEncrypt(cmd[2].clone(), cmd[3].clone()) + } + "decrypt" => { + if cmd.len() != 4 { + return Err(DBError( + "AGE DECRYPT ".to_string(), + )); + } + Cmd::AgeDecrypt(cmd[2].clone(), cmd[3].clone()) + } + "sign" => { + if cmd.len() != 4 { + return Err(DBError( + "AGE SIGN ".to_string(), + )); + } + Cmd::AgeSign(cmd[2].clone(), cmd[3].clone()) + } + "verify" => { + if cmd.len() != 5 { + return Err(DBError( + "AGE VERIFY " + .to_string(), + )); + } + Cmd::AgeVerify(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) + } // persistent names - "keygen" => { if cmd.len() != 3 { return Err(DBError("AGE KEYGEN ".to_string())); } - Cmd::AgeKeygen(cmd[2].clone()) } - "signkeygen" => { if cmd.len() != 3 { return Err(DBError("AGE SIGNKEYGEN ".to_string())); } - Cmd::AgeSignKeygen(cmd[2].clone()) } - "encryptname" => { if cmd.len() != 4 { return Err(DBError("AGE ENCRYPTNAME ".to_string())); } - Cmd::AgeEncryptName(cmd[2].clone(), cmd[3].clone()) } - "decryptname" => { if cmd.len() != 4 { return Err(DBError("AGE DECRYPTNAME ".to_string())); } - Cmd::AgeDecryptName(cmd[2].clone(), cmd[3].clone()) } - "signname" => { if cmd.len() != 4 { return Err(DBError("AGE SIGNNAME ".to_string())); } - Cmd::AgeSignName(cmd[2].clone(), cmd[3].clone()) } - "verifyname" => { if cmd.len() != 5 { return Err(DBError("AGE VERIFYNAME ".to_string())); } - Cmd::AgeVerifyName(cmd[2].clone(), cmd[3].clone(), cmd[4].clone()) } - "list" => { if cmd.len() != 2 { return Err(DBError("AGE LIST".to_string())); } - Cmd::AgeList } - _ => return Err(DBError(format!("unsupported AGE subcommand {:?}", cmd))), + "keygen" => { + if cmd.len() != 3 { + return Err(DBError("AGE KEYGEN ".to_string())); + } + Cmd::AgeKeygen(cmd[2].clone()) + } + "signkeygen" => { + if cmd.len() != 3 { + return Err(DBError("AGE SIGNKEYGEN ".to_string())); + } + Cmd::AgeSignKeygen(cmd[2].clone()) + } + "encryptname" => { + if cmd.len() != 4 { + return Err(DBError( + "AGE ENCRYPTNAME ".to_string(), + )); + } + Cmd::AgeEncryptName(cmd[2].clone(), cmd[3].clone()) + } + "decryptname" => { + if cmd.len() != 4 { + return Err(DBError( + "AGE DECRYPTNAME ".to_string(), + )); + } + Cmd::AgeDecryptName(cmd[2].clone(), cmd[3].clone()) + } + "signname" => { + if cmd.len() != 4 { + return Err(DBError( + "AGE SIGNNAME ".to_string(), + )); + } + Cmd::AgeSignName(cmd[2].clone(), cmd[3].clone()) + } + "verifyname" => { + if cmd.len() != 5 { + return Err(DBError( + "AGE VERIFYNAME " + .to_string(), + )); + } + Cmd::AgeVerifyName( + cmd[2].clone(), + cmd[3].clone(), + cmd[4].clone(), + ) + } + "list" => { + if cmd.len() != 2 { + return Err(DBError("AGE LIST".to_string())); + } + Cmd::AgeList + } + _ => { + return Err(DBError(format!( + "unsupported AGE subcommand {:?}", + cmd + ))) + } } } - "ft.create" => { - if cmd.len() < 4 || cmd[2].to_uppercase() != "SCHEMA" { - return Err(DBError("ERR FT.CREATE requires: indexname SCHEMA field1 type1 [options] ...".to_string())); - } - - let index_name = cmd[1].clone(); - let mut schema = Vec::new(); - let mut i = 3; - - while i < cmd.len() { - if i + 1 >= cmd.len() { - return Err(DBError("ERR incomplete field definition".to_string())); - } - - let field_name = cmd[i].clone(); - let field_type = cmd[i + 1].to_uppercase(); - let mut options = Vec::new(); - i += 2; - - // Parse field options until we hit another field name or end - while i < cmd.len() && ["WEIGHT", "SORTABLE", "NOINDEX", "SEPARATOR", "CASESENSITIVE"].contains(&cmd[i].to_uppercase().as_str()) { - options.push(cmd[i].to_uppercase()); - i += 1; - - // If this option takes a value, consume it too - if i > 0 && ["SEPARATOR", "WEIGHT"].contains(&cmd[i-1].to_uppercase().as_str()) && i < cmd.len() { - options.push(cmd[i].clone()); - i += 1; - } - } - - schema.push((field_name, field_type, options)); - } + "ft.create" => { + if cmd.len() < 4 || cmd[2].to_uppercase() != "SCHEMA" { + return Err(DBError("ERR FT.CREATE requires: indexname SCHEMA field1 type1 [options] ...".to_string())); + } - Cmd::FtCreate { - index_name, - schema, - } - } - "ft.add" => { - if cmd.len() < 5 { - return Err(DBError("ERR FT.ADD requires: index_name doc_id score field value ...".to_string())); - } - - let index_name = cmd[1].clone(); - let doc_id = cmd[2].clone(); - let score = cmd[3].parse::() - .map_err(|_| DBError("ERR score must be a number".to_string()))?; - - let mut fields = HashMap::new(); - let mut i = 4; - - while i + 1 < cmd.len() { - fields.insert(cmd[i].clone(), cmd[i + 1].clone()); - i += 2; - } - - Cmd::FtAdd { - index_name, - doc_id, - score, - fields, - } - } - "ft.search" => { - if cmd.len() < 3 { - return Err(DBError("ERR FT.SEARCH requires: index_name query [options]".to_string())); - } - - let index_name = cmd[1].clone(); - let query = cmd[2].clone(); - - let mut filters = Vec::new(); - let mut limit = None; - let mut offset = None; - let mut return_fields = None; - - let mut i = 3; - while i < cmd.len() { - match cmd[i].to_uppercase().as_str() { - "FILTER" => { - if i + 3 >= cmd.len() { - return Err(DBError("ERR FILTER requires field and value".to_string())); - } - filters.push((cmd[i + 1].clone(), cmd[i + 2].clone())); - i += 3; - } - "LIMIT" => { - if i + 2 >= cmd.len() { - return Err(DBError("ERR LIMIT requires offset and num".to_string())); - } - offset = Some(cmd[i + 1].parse().unwrap_or(0)); - limit = Some(cmd[i + 2].parse().unwrap_or(10)); - i += 3; - } - "RETURN" => { - if i + 1 >= cmd.len() { - return Err(DBError("ERR RETURN requires field count".to_string())); - } - let count: usize = cmd[i + 1].parse().unwrap_or(0); - i += 2; - - let mut fields = Vec::new(); - for _ in 0..count { - if i < cmd.len() { - fields.push(cmd[i].clone()); - i += 1; - } - } - return_fields = Some(fields); - } - _ => i += 1, - } - } - - Cmd::FtSearch { - index_name, - query, - filters, - limit, - offset, - return_fields, - } - } - "ft.del" => { - if cmd.len() != 3 { - return Err(DBError("ERR FT.DEL requires: index_name doc_id".to_string())); - } - Cmd::FtDel(cmd[1].clone(), cmd[2].clone()) - } - "ft.info" => { - if cmd.len() != 2 { - return Err(DBError("ERR FT.INFO requires: index_name".to_string())); - } - Cmd::FtInfo(cmd[1].clone()) - } - "ft.drop" => { - if cmd.len() != 2 { - return Err(DBError("ERR FT.DROP requires: index_name".to_string())); - } - Cmd::FtDrop(cmd[1].clone()) - } + let index_name = cmd[1].clone(); + let mut schema = Vec::new(); + let mut i = 3; + + while i < cmd.len() { + if i + 1 >= cmd.len() { + return Err(DBError( + "ERR incomplete field definition".to_string(), + )); + } + + let field_name = cmd[i].clone(); + let field_type = cmd[i + 1].to_uppercase(); + let mut options = Vec::new(); + i += 2; + + // Parse field options until we hit another field name or end + while i < cmd.len() + && [ + "WEIGHT", + "SORTABLE", + "NOINDEX", + "SEPARATOR", + "CASESENSITIVE", + ] + .contains(&cmd[i].to_uppercase().as_str()) + { + options.push(cmd[i].to_uppercase()); + i += 1; + + // If this option takes a value, consume it too + if i > 0 + && ["SEPARATOR", "WEIGHT"] + .contains(&cmd[i - 1].to_uppercase().as_str()) + && i < cmd.len() + { + options.push(cmd[i].clone()); + i += 1; + } + } + + schema.push((field_name, field_type, options)); + } + + Cmd::FtCreate { index_name, schema } + } + "ft.add" => { + if cmd.len() < 5 { + return Err(DBError( + "ERR FT.ADD requires: index_name doc_id score field value ..." + .to_string(), + )); + } + + let index_name = cmd[1].clone(); + let doc_id = cmd[2].clone(); + let score = cmd[3] + .parse::() + .map_err(|_| DBError("ERR score must be a number".to_string()))?; + + let mut fields = HashMap::new(); + let mut i = 4; + + while i + 1 < cmd.len() { + fields.insert(cmd[i].clone(), cmd[i + 1].clone()); + i += 2; + } + + Cmd::FtAdd { + index_name, + doc_id, + score, + fields, + } + } + "ft.search" => { + if cmd.len() < 3 { + return Err(DBError( + "ERR FT.SEARCH requires: index_name query [options]" + .to_string(), + )); + } + + let index_name = cmd[1].clone(); + let query = cmd[2].clone(); + + let mut filters = Vec::new(); + let mut limit = None; + let mut offset = None; + let mut return_fields = None; + + let mut i = 3; + while i < cmd.len() { + match cmd[i].to_uppercase().as_str() { + "FILTER" => { + if i + 3 >= cmd.len() { + return Err(DBError( + "ERR FILTER requires field and value".to_string(), + )); + } + filters.push((cmd[i + 1].clone(), cmd[i + 2].clone())); + i += 3; + } + "LIMIT" => { + if i + 2 >= cmd.len() { + return Err(DBError( + "ERR LIMIT requires offset and num".to_string(), + )); + } + offset = Some(cmd[i + 1].parse().unwrap_or(0)); + limit = Some(cmd[i + 2].parse().unwrap_or(10)); + i += 3; + } + "RETURN" => { + if i + 1 >= cmd.len() { + return Err(DBError( + "ERR RETURN requires field count".to_string(), + )); + } + let count: usize = cmd[i + 1].parse().unwrap_or(0); + i += 2; + + let mut fields = Vec::new(); + for _ in 0..count { + if i < cmd.len() { + fields.push(cmd[i].clone()); + i += 1; + } + } + return_fields = Some(fields); + } + _ => i += 1, + } + } + + Cmd::FtSearch { + index_name, + query, + filters, + limit, + offset, + return_fields, + } + } + "ft.del" => { + if cmd.len() != 3 { + return Err(DBError( + "ERR FT.DEL requires: index_name doc_id".to_string(), + )); + } + Cmd::FtDel(cmd[1].clone(), cmd[2].clone()) + } + "ft.info" => { + if cmd.len() != 2 { + return Err(DBError( + "ERR FT.INFO requires: index_name".to_string(), + )); + } + Cmd::FtInfo(cmd[1].clone()) + } + "ft.drop" => { + if cmd.len() != 2 { + return Err(DBError( + "ERR FT.DROP requires: index_name".to_string(), + )); + } + Cmd::FtDrop(cmd[1].clone()) + } _ => Cmd::Unknow(cmd[0].clone()), }, protocol, - remaining + remaining, )) } - _ => Err(DBError(format!( - "fail to parse as cmd for {:?}", - protocol - ))), + _ => Err(DBError(format!("fail to parse as cmd for {:?}", protocol))), } } @@ -827,7 +1089,9 @@ impl Cmd { Cmd::Set(k, v) => set_cmd(server, &k, &v).await, Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, - Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await, + Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => { + set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await + } Cmd::MGet(keys) => mget_cmd(server, &keys).await, Cmd::MSet(pairs) => mset_cmd(server, &pairs).await, Cmd::Del(k) => del_cmd(server, &k).await, @@ -863,9 +1127,15 @@ impl Cmd { Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await, Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await, Cmd::HIncrBy(key, field, delta) => hincrby_cmd(server, &key, &field, delta).await, - Cmd::HIncrByFloat(key, field, delta) => hincrbyfloat_cmd(server, &key, &field, delta).await, - Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, - Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, + Cmd::HIncrByFloat(key, field, delta) => { + hincrbyfloat_cmd(server, &key, &field, delta).await + } + Cmd::HScan(key, cursor, pattern, count) => { + hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await + } + Cmd::Scan(cursor, pattern, count) => { + scan_cmd(server, &cursor, pattern.as_deref(), &count).await + } Cmd::Ttl(key) => ttl_cmd(server, &key).await, Cmd::Expire(key, secs) => expire_cmd(server, &key, secs).await, Cmd::PExpire(key, ms) => pexpire_cmd(server, &key, ms).await, @@ -895,39 +1165,68 @@ impl Cmd { // AGE (rage): stateless Cmd::AgeGenEnc => Ok(crate::age::cmd_age_genenc().await), Cmd::AgeGenSign => Ok(crate::age::cmd_age_gensign().await), - Cmd::AgeEncrypt(recipient, message) => Ok(crate::age::cmd_age_encrypt(&recipient, &message).await), - Cmd::AgeDecrypt(identity, ct_b64) => Ok(crate::age::cmd_age_decrypt(&identity, &ct_b64).await), + Cmd::AgeEncrypt(recipient, message) => { + Ok(crate::age::cmd_age_encrypt(&recipient, &message).await) + } + Cmd::AgeDecrypt(identity, ct_b64) => { + Ok(crate::age::cmd_age_decrypt(&identity, &ct_b64).await) + } Cmd::AgeSign(secret, message) => Ok(crate::age::cmd_age_sign(&secret, &message).await), - Cmd::AgeVerify(vpub, msg, sig_b64) => Ok(crate::age::cmd_age_verify(&vpub, &msg, &sig_b64).await), + Cmd::AgeVerify(vpub, msg, sig_b64) => { + Ok(crate::age::cmd_age_verify(&vpub, &msg, &sig_b64).await) + } // AGE (rage): persistent named keys Cmd::AgeKeygen(name) => Ok(crate::age::cmd_age_keygen(server, &name).await), Cmd::AgeSignKeygen(name) => Ok(crate::age::cmd_age_signkeygen(server, &name).await), - Cmd::AgeEncryptName(name, message) => Ok(crate::age::cmd_age_encrypt_name(server, &name, &message).await), - Cmd::AgeDecryptName(name, ct_b64) => Ok(crate::age::cmd_age_decrypt_name(server, &name, &ct_b64).await), - Cmd::AgeSignName(name, message) => Ok(crate::age::cmd_age_sign_name(server, &name, &message).await), - Cmd::AgeVerifyName(name, message, sig_b64) => Ok(crate::age::cmd_age_verify_name(server, &name, &message, &sig_b64).await), + Cmd::AgeEncryptName(name, message) => { + Ok(crate::age::cmd_age_encrypt_name(server, &name, &message).await) + } + Cmd::AgeDecryptName(name, ct_b64) => { + Ok(crate::age::cmd_age_decrypt_name(server, &name, &ct_b64).await) + } + Cmd::AgeSignName(name, message) => { + Ok(crate::age::cmd_age_sign_name(server, &name, &message).await) + } + Cmd::AgeVerifyName(name, message, sig_b64) => { + Ok(crate::age::cmd_age_verify_name(server, &name, &message, &sig_b64).await) + } Cmd::AgeList => Ok(crate::age::cmd_age_list(server).await), // Full-text search commands Cmd::FtCreate { index_name, schema } => { search_cmd::ft_create_cmd(server, index_name, schema).await } - Cmd::FtAdd { index_name, doc_id, score, fields } => { - search_cmd::ft_add_cmd(server, index_name, doc_id, score, fields).await - } - Cmd::FtSearch { index_name, query, filters, limit, offset, return_fields } => { - search_cmd::ft_search_cmd(server, index_name, query, filters, limit, offset, return_fields).await + Cmd::FtAdd { + index_name, + doc_id, + score, + fields, + } => search_cmd::ft_add_cmd(server, index_name, doc_id, score, fields).await, + Cmd::FtSearch { + index_name, + query, + filters, + limit, + offset, + return_fields, + } => { + search_cmd::ft_search_cmd( + server, + index_name, + query, + filters, + limit, + offset, + return_fields, + ) + .await } Cmd::FtDel(index_name, doc_id) => { search_cmd::ft_del_cmd(server, index_name, doc_id).await } - Cmd::FtInfo(index_name) => { - search_cmd::ft_info_cmd(server, index_name).await - } - Cmd::FtDrop(index_name) => { - search_cmd::ft_drop_cmd(server, index_name).await - } + Cmd::FtInfo(index_name) => search_cmd::ft_info_cmd(server, index_name).await, + Cmd::FtDrop(index_name) => search_cmd::ft_drop_cmd(server, index_name).await, Cmd::FtAlter { .. } => { // Not implemented yet Ok(Protocol::err("FT.ALTER not implemented yet")) @@ -939,15 +1238,28 @@ impl Cmd { Cmd::Unknow(s) => Ok(Protocol::err(&format!("ERR unknown command `{}`", s))), } } - + pub fn to_protocol(self) -> Protocol { match self { - Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]), + Cmd::Select(db) => Protocol::Array(vec![ + Protocol::BulkString("select".to_string()), + Protocol::BulkString(db.to_string()), + ]), Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]), - Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]), - Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]), - Cmd::Set(k, v) => Protocol::Array(vec![Protocol::BulkString("set".to_string()), Protocol::BulkString(k), Protocol::BulkString(v)]), - _ => Protocol::SimpleString("...".to_string()) + Cmd::Echo(s) => Protocol::Array(vec![ + Protocol::BulkString("echo".to_string()), + Protocol::BulkString(s), + ]), + Cmd::Get(k) => Protocol::Array(vec![ + Protocol::BulkString("get".to_string()), + Protocol::BulkString(k), + ]), + Cmd::Set(k, v) => Protocol::Array(vec![ + Protocol::BulkString("set".to_string()), + Protocol::BulkString(k), + Protocol::BulkString(v), + ]), + _ => Protocol::SimpleString("...".to_string()), } } } @@ -976,9 +1288,16 @@ async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result Result { +async fn lrange_cmd( + server: &Server, + key: &str, + start: i64, + stop: i64, +) -> Result { match server.current_storage()?.lrange(key, start, stop) { - Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())), + Ok(elements) => Ok(Protocol::Array( + elements.into_iter().map(Protocol::BulkString).collect(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } @@ -990,7 +1309,12 @@ async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result< } } -async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result { +async fn lrem_cmd( + server: &Server, + key: &str, + count: i64, + element: &str, +) -> Result { match server.current_storage()?.lrem(key, count, element) { Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())), Err(e) => Ok(Protocol::err(&e.0)), @@ -1015,11 +1339,13 @@ async fn lpop_cmd(server: &Server, key: &str, count: &Option) -> Result Ok(Protocol::err(&e.0)), } } @@ -1035,17 +1361,23 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option) -> Result Ok(Protocol::err(&e.0)), } } // BLPOP implementation -async fn blpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result { +async fn blpop_cmd( + server: &Server, + keys: &[String], + timeout_secs: f64, +) -> Result { // Immediate, non-blocking attempt in key order for k in keys { let elems = server.current_storage()?.lpop(k, 1)?; @@ -1066,10 +1398,13 @@ async fn blpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul let db_index = server.selected_db; let mut ids: Vec = Vec::with_capacity(keys.len()); let mut names: Vec = Vec::with_capacity(keys.len()); - let mut rxs: Vec> = Vec::with_capacity(keys.len()); + let mut rxs: Vec> = + Vec::with_capacity(keys.len()); for k in keys { - let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Left).await; + let (id, rx) = server + .register_waiter(db_index, k, crate::server::PopSide::Left) + .await; ids.push(id); names.push(k.clone()); rxs.push(rx); @@ -1127,7 +1462,11 @@ async fn blpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul } // BRPOP implementation (mirror of BLPOP, popping from the right) -async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result { +async fn brpop_cmd( + server: &Server, + keys: &[String], + timeout_secs: f64, +) -> Result { // Immediate, non-blocking attempt in key order using RPOP for k in keys { let elems = server.current_storage()?.rpop(k, 1)?; @@ -1148,10 +1487,13 @@ async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul let db_index = server.selected_db; let mut ids: Vec = Vec::with_capacity(keys.len()); let mut names: Vec = Vec::with_capacity(keys.len()); - let mut rxs: Vec> = Vec::with_capacity(keys.len()); + let mut rxs: Vec> = + Vec::with_capacity(keys.len()); for k in keys { - let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Right).await; + let (id, rx) = server + .register_waiter(db_index, k, crate::server::PopSide::Right) + .await; ids.push(id); names.push(k.clone()); rxs.push(rx); @@ -1250,17 +1592,15 @@ async fn exec_cmd(server: &mut Server) -> Result { async fn incr_cmd(server: &Server, key: &String) -> Result { let storage = server.current_storage()?; let current_value = storage.get(key)?; - + let new_value = match current_value { - Some(v) => { - match v.parse::() { - Ok(num) => num + 1, - Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")), - } - } + Some(v) => match v.parse::() { + Ok(num) => num + 1, + Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")), + }, None => 1, }; - + storage.set(key.clone(), new_value.to_string())?; Ok(Protocol::SimpleString(new_value.to_string())) } @@ -1300,21 +1640,34 @@ async fn dbsize_cmd(server: &Server) -> Result { async fn info_cmd(server: &Server, section: &Option) -> Result { let storage_info = server.current_storage()?.info()?; - let mut info_map: std::collections::HashMap = storage_info.into_iter().collect(); + let mut info_map: std::collections::HashMap = + storage_info.into_iter().collect(); info_map.insert("redis_version".to_string(), "7.0.0".to_string()); info_map.insert("selected_db".to_string(), server.selected_db.to_string()); - info_map.insert("backend".to_string(), format!("{:?}", server.option.backend)); - + info_map.insert( + "backend".to_string(), + format!("{:?}", server.option.backend), + ); let mut info_string = String::new(); info_string.push_str("# Server\n"); - info_string.push_str(&format!("redis_version:{}\n", info_map.get("redis_version").unwrap())); + info_string.push_str(&format!( + "redis_version:{}\n", + info_map.get("redis_version").unwrap() + )); info_string.push_str(&format!("backend:{}\n", info_map.get("backend").unwrap())); - info_string.push_str(&format!("encrypted:{}\n", info_map.get("is_encrypted").unwrap())); - + info_string.push_str(&format!( + "encrypted:{}\n", + info_map.get("is_encrypted").unwrap() + )); + info_string.push_str("# Keyspace\n"); - info_string.push_str(&format!("db{}:keys={},expires=0,avg_ttl=0\n", info_map.get("selected_db").unwrap(), info_map.get("db_size").unwrap())); + info_string.push_str(&format!( + "db{}:keys={},expires=0,avg_ttl=0\n", + info_map.get("selected_db").unwrap(), + info_map.get("db_size").unwrap() + )); match section { Some(s) => { @@ -1344,28 +1697,24 @@ async fn del_cmd(server: &Server, k: &str) -> Result { Ok(Protocol::SimpleString("1".to_string())) } -async fn set_ex_cmd( - server: &Server, - k: &str, - v: &str, - x: &u128, -) -> Result { - server.current_storage()?.setx(k.to_string(), v.to_string(), *x * 1000)?; +async fn set_ex_cmd(server: &Server, k: &str, v: &str, x: &u128) -> Result { + server + .current_storage()? + .setx(k.to_string(), v.to_string(), *x * 1000)?; Ok(Protocol::SimpleString("OK".to_string())) } -async fn set_px_cmd( - server: &Server, - k: &str, - v: &str, - x: &u128, -) -> Result { - server.current_storage()?.setx(k.to_string(), v.to_string(), *x)?; +async fn set_px_cmd(server: &Server, k: &str, v: &str, x: &u128) -> Result { + server + .current_storage()? + .setx(k.to_string(), v.to_string(), *x)?; Ok(Protocol::SimpleString("OK".to_string())) } async fn set_cmd(server: &Server, k: &str, v: &str) -> Result { - server.current_storage()?.set(k.to_string(), v.to_string())?; + server + .current_storage()? + .set(k.to_string(), v.to_string())?; Ok(Protocol::SimpleString("OK".to_string())) } @@ -1394,11 +1743,7 @@ async fn set_with_opts_cmd( } // Fetch old value if needed for GET - let old_val = if get_old { - storage.get(key)? - } else { - None - }; + let old_val = if get_old { storage.get(key)? } else { None }; if should_set { if let Some(ms) = ex_ms { @@ -1478,7 +1823,11 @@ async fn get_cmd(server: &Server, k: &str) -> Result { } // Hash command implementations -async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result { +async fn hset_cmd( + server: &Server, + key: &str, + pairs: &[(String, String)], +) -> Result { let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?; Ok(Protocol::SimpleString(new_fields.to_string())) } @@ -1514,7 +1863,9 @@ async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result Result { match server.current_storage()?.hexists(key, field) { - Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())), + Ok(exists) => Ok(Protocol::SimpleString( + if exists { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } @@ -1557,31 +1908,54 @@ async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result Result { +async fn hsetnx_cmd( + server: &Server, + key: &str, + field: &str, + value: &str, +) -> Result { match server.current_storage()?.hsetnx(key, field, value) { - Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())), + Ok(was_set) => Ok(Protocol::SimpleString( + if was_set { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } -async fn hincrby_cmd(server: &Server, key: &str, field: &str, delta: i64) -> Result { +async fn hincrby_cmd( + server: &Server, + key: &str, + field: &str, + delta: i64, +) -> Result { let storage = server.current_storage()?; let current = storage.hget(key, field)?; let base: i64 = match current { - Some(v) => v.parse::().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?, + Some(v) => v + .parse::() + .map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?, None => 0, }; - let new_val = base.checked_add(delta).ok_or_else(|| DBError("ERR increment or decrement would overflow".to_string()))?; + let new_val = base + .checked_add(delta) + .ok_or_else(|| DBError("ERR increment or decrement would overflow".to_string()))?; // Update the field storage.hset(key, vec![(field.to_string(), new_val.to_string())])?; Ok(Protocol::SimpleString(new_val.to_string())) } -async fn hincrbyfloat_cmd(server: &Server, key: &str, field: &str, delta: f64) -> Result { +async fn hincrbyfloat_cmd( + server: &Server, + key: &str, + field: &str, + delta: f64, +) -> Result { let storage = server.current_storage()?; let current = storage.hget(key, field)?; let base: f64 = match current { - Some(v) => v.parse::().map_err(|_| DBError("ERR value is not a valid float".to_string()))?, + Some(v) => v + .parse::() + .map_err(|_| DBError("ERR value is not a valid float".to_string()))?, None => 0.0, }; let new_val = base + delta; @@ -1594,14 +1968,17 @@ async fn scan_cmd( server: &Server, cursor: &u64, pattern: Option<&str>, - count: &Option + count: &Option, ) -> Result { match server.current_storage()?.scan(*cursor, pattern, *count) { Ok((next_cursor, key_value_pairs)) => { let mut result = Vec::new(); result.push(Protocol::BulkString(next_cursor.to_string())); // For SCAN, we only return the keys, not the values - let keys: Vec = key_value_pairs.into_iter().map(|(key, _)| Protocol::BulkString(key)).collect(); + let keys: Vec = key_value_pairs + .into_iter() + .map(|(key, _)| Protocol::BulkString(key)) + .collect(); result.push(Protocol::Array(keys)); Ok(Protocol::Array(result)) } @@ -1614,9 +1991,12 @@ async fn hscan_cmd( key: &str, cursor: &u64, pattern: Option<&str>, - count: &Option + count: &Option, ) -> Result { - match server.current_storage()?.hscan(key, *cursor, pattern, *count) { + match server + .current_storage()? + .hscan(key, *cursor, pattern, *count) + { Ok((next_cursor, field_value_pairs)) => { let mut result = Vec::new(); result.push(Protocol::BulkString(next_cursor.to_string())); @@ -1642,7 +2022,9 @@ async fn ttl_cmd(server: &Server, key: &str) -> Result { async fn exists_cmd(server: &Server, key: &str) -> Result { match server.current_storage()?.exists(key) { - Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())), + Ok(exists) => Ok(Protocol::SimpleString( + if exists { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } @@ -1653,7 +2035,9 @@ async fn expire_cmd(server: &Server, key: &str, secs: i64) -> Result Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Ok(applied) => Ok(Protocol::SimpleString( + if applied { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } @@ -1664,7 +2048,9 @@ async fn pexpire_cmd(server: &Server, key: &str, ms: i64) -> Result Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Ok(applied) => Ok(Protocol::SimpleString( + if applied { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } @@ -1672,14 +2058,18 @@ async fn pexpire_cmd(server: &Server, key: &str, ms: i64) -> Result 1 if timeout removed, 0 otherwise async fn persist_cmd(server: &Server, key: &str) -> Result { match server.current_storage()?.persist(key) { - Ok(removed) => Ok(Protocol::SimpleString(if removed { "1" } else { "0" }.to_string())), + Ok(removed) => Ok(Protocol::SimpleString( + if removed { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } // EXPIREAT key timestamp-seconds -> 1 if timeout set, 0 otherwise async fn expireat_cmd(server: &Server, key: &str, ts_secs: i64) -> Result { match server.current_storage()?.expire_at_seconds(key, ts_secs) { - Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Ok(applied) => Ok(Protocol::SimpleString( + if applied { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } @@ -1687,7 +2077,9 @@ async fn expireat_cmd(server: &Server, key: &str, ts_secs: i64) -> Result 1 if timeout set, 0 otherwise async fn pexpireat_cmd(server: &Server, key: &str, ts_ms: i64) -> Result { match server.current_storage()?.pexpire_at_millis(key, ts_ms) { - Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())), + Ok(applied) => Ok(Protocol::SimpleString( + if applied { "1" } else { "0" }.to_string(), + )), Err(e) => Ok(Protocol::err(&e.0)), } } diff --git a/src/crypto.rs b/src/crypto.rs index 48a9f8c..db7a3ec 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -11,9 +11,9 @@ const TAG_LEN: usize = 16; #[derive(Debug)] pub enum CryptoError { - Format, // wrong length / header - Version(u8), // unknown version - Decrypt, // wrong key or corrupted data + Format, // wrong length / header + Version(u8), // unknown version + Decrypt, // wrong key or corrupted data } impl From for crate::error::DBError { @@ -71,4 +71,4 @@ impl CryptoFactory { let cipher = XChaCha20Poly1305::new(&self.key); cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt) } -} \ No newline at end of file +} diff --git a/src/error.rs b/src/error.rs index 3037c70..25314ff 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,8 @@ use std::num::ParseIntError; -use tokio::sync::mpsc; -use redb; use bincode; - +use redb; +use tokio::sync::mpsc; // todo: more error types #[derive(Debug)] diff --git a/src/lib.rs b/src/lib.rs index 2108082..b2e5bfd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,12 @@ -pub mod age; // NEW +pub mod age; // NEW pub mod cmd; pub mod crypto; pub mod error; pub mod options; pub mod protocol; -pub mod search_cmd; // Add this +pub mod search_cmd; // Add this pub mod server; pub mod storage; -pub mod storage_trait; // Add this -pub mod storage_sled; // Add this +pub mod storage_sled; // Add this +pub mod storage_trait; // Add this pub mod tantivy_search; diff --git a/src/main.rs b/src/main.rs index dce569b..c2c9ed6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,6 @@ struct Args { #[arg(long)] debug: bool, - /// Master encryption key for encrypted databases #[arg(long)] encryption_key: Option, diff --git a/src/protocol.rs b/src/protocol.rs index 6025074..22587eb 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -81,18 +81,21 @@ impl Protocol { pub fn encode(&self) -> String { match self { Protocol::SimpleString(s) => format!("+{}\r\n", s), - Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s), - Protocol::Array(ss) => { + Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s), + Protocol::Array(ss) => { format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::() } - Protocol::Null => "$-1\r\n".to_string(), - Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error + Protocol::Null => "$-1\r\n".to_string(), + Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error } } fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> { match protocol.find("\r\n") { - Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])), + Some(x) => Ok(( + Self::SimpleString(protocol[..x].to_string()), + &protocol[x + 2..], + )), _ => Err(DBError(format!( "[new simple string] unsupported protocol: {:?}", protocol diff --git a/src/search_cmd.rs b/src/search_cmd.rs index d37df2a..c227cfe 100644 --- a/src/search_cmd.rs +++ b/src/search_cmd.rs @@ -3,8 +3,7 @@ use crate::{ protocol::Protocol, server::Server, tantivy_search::{ - TantivySearch, FieldDef, NumericType, IndexConfig, - SearchOptions, Filter, FilterType + FieldDef, Filter, FilterType, IndexConfig, NumericType, SearchOptions, TantivySearch, }, }; use std::collections::HashMap; @@ -17,14 +16,14 @@ pub async fn ft_create_cmd( ) -> Result { // Parse schema into field definitions let mut field_definitions = Vec::new(); - + for (field_name, field_type, options) in schema { let field_def = match field_type.to_uppercase().as_str() { "TEXT" => { let mut weight = 1.0; let mut sortable = false; let mut no_index = false; - + for opt in &options { match opt.to_uppercase().as_str() { "WEIGHT" => { @@ -40,7 +39,7 @@ pub async fn ft_create_cmd( _ => {} } } - + FieldDef::Text { stored: true, indexed: !no_index, @@ -50,13 +49,13 @@ pub async fn ft_create_cmd( } "NUMERIC" => { let mut sortable = false; - + for opt in &options { if opt.to_uppercase() == "SORTABLE" { sortable = true; } } - + FieldDef::Numeric { stored: true, indexed: true, @@ -67,7 +66,7 @@ pub async fn ft_create_cmd( "TAG" => { let mut separator = ",".to_string(); let mut case_sensitive = false; - + for i in 0..options.len() { match options[i].to_uppercase().as_str() { "SEPARATOR" => { @@ -79,44 +78,45 @@ pub async fn ft_create_cmd( _ => {} } } - + FieldDef::Tag { stored: true, separator, case_sensitive, } } - "GEO" => { - FieldDef::Geo { stored: true } - } + "GEO" => FieldDef::Geo { stored: true }, _ => { return Err(DBError(format!("Unknown field type: {}", field_type))); } }; - + field_definitions.push((field_name, field_def)); } - + // Create the search index let search_path = server.search_index_path(); let config = IndexConfig::default(); - - println!("Creating search index '{}' at path: {:?}", index_name, search_path); + + println!( + "Creating search index '{}' at path: {:?}", + index_name, search_path + ); println!("Field definitions: {:?}", field_definitions); - + let search_index = TantivySearch::new_with_schema( search_path, index_name.clone(), field_definitions, Some(config), )?; - + println!("Search index '{}' created successfully", index_name); - + // Store in registry let mut indexes = server.search_indexes.write().unwrap(); indexes.insert(index_name, Arc::new(search_index)); - + Ok(Protocol::SimpleString("OK".to_string())) } @@ -128,12 +128,13 @@ pub async fn ft_add_cmd( fields: HashMap, ) -> Result { let indexes = server.search_indexes.read().unwrap(); - - let search_index = indexes.get(&index_name) + + let search_index = indexes + .get(&index_name) .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; - + search_index.add_document_with_fields(&doc_id, fields)?; - + Ok(Protocol::SimpleString("OK".to_string())) } @@ -147,18 +148,20 @@ pub async fn ft_search_cmd( return_fields: Option>, ) -> Result { let indexes = server.search_indexes.read().unwrap(); - - let search_index = indexes.get(&index_name) + + let search_index = indexes + .get(&index_name) .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; - + // Convert filters to search filters - let search_filters = filters.into_iter().map(|(field, value)| { - Filter { + let search_filters = filters + .into_iter() + .map(|(field, value)| Filter { field, filter_type: FilterType::Equals(value), - } - }).collect(); - + }) + .collect(); + let options = SearchOptions { limit: limit.unwrap_or(10), offset: offset.unwrap_or(0), @@ -167,27 +170,27 @@ pub async fn ft_search_cmd( return_fields, highlight: false, }; - + let results = search_index.search_with_options(&query, options)?; - + // Format results as Redis protocol let mut response = Vec::new(); - + // First element is the total count response.push(Protocol::SimpleString(results.total.to_string())); - + // Then each document for doc in results.documents { let mut doc_array = Vec::new(); - + // Add document ID if it exists if let Some(id) = doc.fields.get("_id") { doc_array.push(Protocol::BulkString(id.clone())); } - + // Add score doc_array.push(Protocol::BulkString(doc.score.to_string())); - + // Add fields as key-value pairs for (field_name, field_value) in doc.fields { if field_name != "_id" { @@ -195,10 +198,10 @@ pub async fn ft_search_cmd( doc_array.push(Protocol::BulkString(field_value)); } } - + response.push(Protocol::Array(doc_array)); } - + Ok(Protocol::Array(response)) } @@ -208,56 +211,54 @@ pub async fn ft_del_cmd( doc_id: String, ) -> Result { let indexes = server.search_indexes.read().unwrap(); - - let _search_index = indexes.get(&index_name) + + let _search_index = indexes + .get(&index_name) .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; - + // For now, return success // In a full implementation, we'd need to add a delete method to TantivySearch println!("Deleting document '{}' from index '{}'", doc_id, index_name); - + Ok(Protocol::SimpleString("1".to_string())) } -pub async fn ft_info_cmd( - server: &Server, - index_name: String, -) -> Result { +pub async fn ft_info_cmd(server: &Server, index_name: String) -> Result { let indexes = server.search_indexes.read().unwrap(); - - let search_index = indexes.get(&index_name) + + let search_index = indexes + .get(&index_name) .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; - + let info = search_index.get_info()?; - + // Format info as Redis protocol let mut response = Vec::new(); - + response.push(Protocol::BulkString("index_name".to_string())); response.push(Protocol::BulkString(info.name)); - + response.push(Protocol::BulkString("num_docs".to_string())); response.push(Protocol::BulkString(info.num_docs.to_string())); - + response.push(Protocol::BulkString("num_fields".to_string())); response.push(Protocol::BulkString(info.fields.len().to_string())); - + response.push(Protocol::BulkString("fields".to_string())); - let fields_str = info.fields.iter() + let fields_str = info + .fields + .iter() .map(|f| format!("{}:{}", f.name, f.field_type)) .collect::>() .join(", "); response.push(Protocol::BulkString(fields_str)); - + Ok(Protocol::Array(response)) } -pub async fn ft_drop_cmd( - server: &Server, - index_name: String, -) -> Result { +pub async fn ft_drop_cmd(server: &Server, index_name: String) -> Result { let mut indexes = server.search_indexes.write().unwrap(); - + if indexes.remove(&index_name).is_some() { // Also remove the index files from disk let index_path = server.search_index_path().join(&index_name); @@ -269,4 +270,4 @@ pub async fn ft_drop_cmd( } else { Err(DBError(format!("Index '{}' not found", index_name))) } -} \ No newline at end of file +} diff --git a/src/server.rs b/src/server.rs index f9333e8..af631f8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,10 +1,10 @@ use core::str; use std::collections::HashMap; use std::sync::Arc; +use std::sync::RwLock; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; -use tokio::sync::{Mutex, oneshot}; -use std::sync::RwLock; +use tokio::sync::{oneshot, Mutex}; use std::sync::atomic::{AtomicU64, Ordering}; @@ -60,51 +60,50 @@ impl Server { pub fn current_storage(&self) -> Result, DBError> { let mut cache = self.db_cache.write().unwrap(); - + if let Some(storage) = cache.get(&self.selected_db) { return Ok(storage.clone()); } - - + // Create new database file let db_file_path = std::path::PathBuf::from(self.option.dir.clone()) .join(format!("{}.db", self.selected_db)); - + // Ensure the directory exists before creating the database file if let Some(parent_dir) = db_file_path.parent() { std::fs::create_dir_all(parent_dir).map_err(|e| { - DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e)) + DBError(format!( + "Failed to create directory {}: {}", + parent_dir.display(), + e + )) })?; } - + println!("Creating new db file: {}", db_file_path.display()); - + let storage: Arc = match self.option.backend { - options::BackendType::Redb => { - Arc::new(Storage::new( - db_file_path, - self.should_encrypt_db(self.selected_db), - self.option.encryption_key.as_deref() - )?) - } - options::BackendType::Sled => { - Arc::new(SledStorage::new( - db_file_path, - self.should_encrypt_db(self.selected_db), - self.option.encryption_key.as_deref() - )?) - } + options::BackendType::Redb => Arc::new(Storage::new( + db_file_path, + self.should_encrypt_db(self.selected_db), + self.option.encryption_key.as_deref(), + )?), + options::BackendType::Sled => Arc::new(SledStorage::new( + db_file_path, + self.should_encrypt_db(self.selected_db), + self.option.encryption_key.as_deref(), + )?), }; - + cache.insert(self.selected_db, storage.clone()); Ok(storage) } - + fn should_encrypt_db(&self, db_index: u64) -> bool { // DB 0-9 are non-encrypted, DB 10+ are encrypted self.option.encrypt && db_index >= 10 } - + // Add method to get search index path pub fn search_index_path(&self) -> std::path::PathBuf { std::path::PathBuf::from(&self.option.dir).join("search_indexes") @@ -112,7 +111,12 @@ impl Server { // ----- BLPOP waiter helpers ----- - pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) { + pub async fn register_waiter( + &self, + db_index: u64, + key: &str, + side: PopSide, + ) -> (u64, oneshot::Receiver<(String, String)>) { let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); let (tx, rx) = oneshot::channel::<(String, String)>(); @@ -188,10 +192,7 @@ impl Server { Ok(()) } - pub async fn handle( - &mut self, - mut stream: tokio::net::TcpStream, - ) -> Result<(), DBError> { + pub async fn handle(&mut self, mut stream: tokio::net::TcpStream) -> Result<(), DBError> { // Accumulate incoming bytes to handle partial RESP frames let mut acc = String::new(); let mut buf = vec![0u8; 8192]; @@ -228,7 +229,10 @@ impl Server { acc = remaining.to_string(); if self.option.debug { - println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); + println!( + "\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", + cmd, protocol + ); } else { println!("got command: {:?}, protocol: {:?}", cmd, protocol); } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index abc2cd5..c52d456 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -12,9 +12,9 @@ use crate::error::DBError; // Re-export modules mod storage_basic; +mod storage_extra; mod storage_hset; mod storage_lists; -mod storage_extra; // Re-export implementations // Note: These imports are used by the impl blocks in the submodules @@ -28,7 +28,8 @@ const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("string const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes"); const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); -const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); +const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = + TableDefinition::new("streams_data"); const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted"); const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration"); @@ -55,9 +56,13 @@ pub struct Storage { } impl Storage { - pub fn new(path: impl AsRef, should_encrypt: bool, master_key: Option<&str>) -> Result { + pub fn new( + path: impl AsRef, + should_encrypt: bool, + master_key: Option<&str>, + ) -> Result { let db = Database::create(path)?; - + // Create tables if they don't exist let write_txn = db.begin_write()?; { @@ -71,23 +76,28 @@ impl Storage { let _ = write_txn.open_table(EXPIRATION_TABLE)?; } write_txn.commit()?; - + // Check if database was previously encrypted let read_txn = db.begin_read()?; let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?; - let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false); + let was_encrypted = encrypted_table + .get("encrypted")? + .map(|v| v.value() == 1) + .unwrap_or(false); drop(read_txn); - + let crypto = if should_encrypt || was_encrypted { if let Some(key) = master_key { Some(CryptoFactory::new(key.as_bytes())) } else { - return Err(DBError("Encryption requested but no master key provided".to_string())); + return Err(DBError( + "Encryption requested but no master key provided".to_string(), + )); } } else { None }; - + // If we're enabling encryption for the first time, mark it if should_encrypt && !was_encrypted { let write_txn = db.begin_write()?; @@ -97,13 +107,10 @@ impl Storage { } write_txn.commit()?; } - - Ok(Storage { - db, - crypto, - }) + + Ok(Storage { db, crypto }) } - + pub fn is_encrypted(&self) -> bool { self.crypto.is_some() } @@ -116,7 +123,7 @@ impl Storage { Ok(data.to_vec()) } } - + fn decrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { if let Some(crypto) = &self.crypto { Ok(crypto.decrypt(data)?) @@ -165,11 +172,22 @@ impl StorageBackend for Storage { self.get_key_type(key) } - fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + fn scan( + &self, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError> { self.scan(cursor, pattern, count) } - fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + fn hscan( + &self, + key: &str, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError> { self.hscan(key, cursor, pattern, count) } @@ -276,7 +294,7 @@ impl StorageBackend for Storage { fn is_encrypted(&self) -> bool { self.is_encrypted() } - + fn info(&self) -> Result, DBError> { self.info() } @@ -284,4 +302,4 @@ impl StorageBackend for Storage { fn clone_arc(&self) -> Arc { unimplemented!("Storage cloning not yet implemented for redb backend") } -} \ No newline at end of file +} diff --git a/src/storage/storage_basic.rs b/src/storage/storage_basic.rs index 1594b87..fbc7f15 100644 --- a/src/storage/storage_basic.rs +++ b/src/storage/storage_basic.rs @@ -1,6 +1,6 @@ -use redb::{ReadableTable}; -use crate::error::DBError; use super::*; +use crate::error::DBError; +use redb::ReadableTable; impl Storage { pub fn flushdb(&self) -> Result<(), DBError> { @@ -15,11 +15,17 @@ impl Storage { let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; // inefficient, but there is no other way - let keys: Vec = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + let keys: Vec = types_table + .iter()? + .map(|item| item.unwrap().0.value().to_string()) + .collect(); for key in keys { types_table.remove(key.as_str())?; } - let keys: Vec = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + let keys: Vec = strings_table + .iter()? + .map(|item| item.unwrap().0.value().to_string()) + .collect(); for key in keys { strings_table.remove(key.as_str())?; } @@ -34,23 +40,35 @@ impl Storage { for (key, field) in keys { hashes_table.remove((key.as_str(), field.as_str()))?; } - let keys: Vec = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + let keys: Vec = lists_table + .iter()? + .map(|item| item.unwrap().0.value().to_string()) + .collect(); for key in keys { lists_table.remove(key.as_str())?; } - let keys: Vec = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + let keys: Vec = streams_meta_table + .iter()? + .map(|item| item.unwrap().0.value().to_string()) + .collect(); for key in keys { streams_meta_table.remove(key.as_str())?; } - let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| { - let binding = item.unwrap(); - let (key, field) = binding.0.value(); - (key.to_string(), field.to_string()) - }).collect(); + let keys: Vec<(String, String)> = streams_data_table + .iter()? + .map(|item| { + let binding = item.unwrap(); + let (key, field) = binding.0.value(); + (key.to_string(), field.to_string()) + }) + .collect(); for (key, field) in keys { streams_data_table.remove((key.as_str(), field.as_str()))?; } - let keys: Vec = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); + let keys: Vec = expiration_table + .iter()? + .map(|item| item.unwrap().0.value().to_string()) + .collect(); for key in keys { expiration_table.remove(key.as_str())?; } @@ -62,7 +80,7 @@ impl Storage { pub fn get_key_type(&self, key: &str) -> Result, DBError> { let read_txn = self.db.begin_read()?; let table = read_txn.open_table(TYPES_TABLE)?; - + // Before returning type, check for expiration if let Some(type_val) = table.get(key)? { if type_val.value() == "string" { @@ -83,7 +101,7 @@ impl Storage { // ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted pub fn get(&self, key: &str) -> Result, DBError> { let read_txn = self.db.begin_read()?; - + let types_table = read_txn.open_table(TYPES_TABLE)?; match types_table.get(key)? { Some(type_val) if type_val.value() == "string" => { @@ -96,7 +114,7 @@ impl Storage { return Ok(None); } } - + // Get and decrypt value let strings_table = read_txn.open_table(STRINGS_TABLE)?; match strings_table.get(key)? { @@ -115,21 +133,21 @@ impl Storage { // ✅ ENCRYPTION APPLIED: Value is encrypted before storage pub fn set(&self, key: String, value: String) -> Result<(), DBError> { let write_txn = self.db.begin_write()?; - + { let mut types_table = write_txn.open_table(TYPES_TABLE)?; types_table.insert(key.as_str(), "string")?; - + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; // Only encrypt the value, not expiration let encrypted = self.encrypt_if_needed(value.as_bytes())?; strings_table.insert(key.as_str(), encrypted.as_slice())?; - + // Remove any existing expiration since this is a regular SET let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; expiration_table.remove(key.as_str())?; } - + write_txn.commit()?; Ok(()) } @@ -137,41 +155,42 @@ impl Storage { // ✅ ENCRYPTION APPLIED: Value is encrypted before storage pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { let write_txn = self.db.begin_write()?; - + { let mut types_table = write_txn.open_table(TYPES_TABLE)?; types_table.insert(key.as_str(), "string")?; - + let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; // Only encrypt the value let encrypted = self.encrypt_if_needed(value.as_bytes())?; strings_table.insert(key.as_str(), encrypted.as_slice())?; - + // Store expiration separately (unencrypted) let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let expires_at = expire_ms + now_in_millis(); expiration_table.insert(key.as_str(), &(expires_at as u64))?; } - + write_txn.commit()?; Ok(()) } pub fn del(&self, key: String) -> Result<(), DBError> { let write_txn = self.db.begin_write()?; - + { let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; - let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?; + let mut hashes_table: redb::Table<(&str, &str), &[u8]> = + write_txn.open_table(HASHES_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - + // Remove from type table types_table.remove(key.as_str())?; - + // Remove from strings table strings_table.remove(key.as_str())?; - + // Remove all hash fields for this key let mut to_remove = Vec::new(); let mut iter = hashes_table.iter()?; @@ -183,19 +202,19 @@ impl Storage { } } drop(iter); - + for (hash_key, field) in to_remove { hashes_table.remove((hash_key.as_str(), field.as_str()))?; } // Remove from lists table lists_table.remove(key.as_str())?; - + // Also remove expiration let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; expiration_table.remove(key.as_str())?; } - + write_txn.commit()?; Ok(()) } @@ -203,7 +222,7 @@ impl Storage { pub fn keys(&self, pattern: &str) -> Result, DBError> { let read_txn = self.db.begin_read()?; let table = read_txn.open_table(TYPES_TABLE)?; - + let mut keys = Vec::new(); let mut iter = table.iter()?; while let Some(entry) = iter.next() { @@ -212,7 +231,7 @@ impl Storage { keys.push(key); } } - + Ok(keys) } } @@ -242,4 +261,4 @@ impl Storage { } Ok(count) } -} \ No newline at end of file +} diff --git a/src/storage/storage_extra.rs b/src/storage/storage_extra.rs index d918b58..bc73641 100644 --- a/src/storage/storage_extra.rs +++ b/src/storage/storage_extra.rs @@ -1,24 +1,29 @@ -use redb::{ReadableTable}; -use crate::error::DBError; use super::*; +use crate::error::DBError; +use redb::ReadableTable; impl Storage { // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval - pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + pub fn scan( + &self, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError> { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; let strings_table = read_txn.open_table(STRINGS_TABLE)?; - + let mut result = Vec::new(); let mut current_cursor = 0u64; let limit = count.unwrap_or(10) as usize; - + let mut iter = types_table.iter()?; while let Some(entry) = iter.next() { let entry = entry?; let key = entry.0.value().to_string(); let key_type = entry.1.value().to_string(); - + if current_cursor >= cursor { // Apply pattern matching if specified let matches = if let Some(pat) = pattern { @@ -26,7 +31,7 @@ impl Storage { } else { true }; - + if matches { // For scan, we return key-value pairs for string types if key_type == "string" { @@ -41,7 +46,7 @@ impl Storage { // For non-string types, just return the key with type as value result.push((key, key_type)); } - + if result.len() >= limit { break; } @@ -49,15 +54,19 @@ impl Storage { } current_cursor += 1; } - - let next_cursor = if result.len() < limit { 0 } else { current_cursor }; + + let next_cursor = if result.len() < limit { + 0 + } else { + current_cursor + }; Ok((next_cursor, result)) } pub fn ttl(&self, key: &str) -> Result { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; - + match types_table.get(key)? { Some(type_val) if type_val.value() == "string" => { let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; @@ -75,14 +84,14 @@ impl Storage { } } Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types) - None => Ok(-2), // Key does not exist + None => Ok(-2), // Key does not exist } } pub fn exists(&self, key: &str) -> Result { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; - + match types_table.get(key)? { Some(type_val) if type_val.value() == "string" => { // Check if string key has expired @@ -95,7 +104,7 @@ impl Storage { Ok(true) } Some(_) => Ok(true), // Key exists and is not a string - None => Ok(false), // Key does not exist + None => Ok(false), // Key does not exist } } @@ -178,8 +187,12 @@ impl Storage { .unwrap_or(false); if is_string { let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; - let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; - expiration_table.insert(key, &((expires_at_ms as u64)))?; + let expires_at_ms: u128 = if ts_secs <= 0 { + 0 + } else { + (ts_secs as u128) * 1000 + }; + expiration_table.insert(key, &(expires_at_ms as u64))?; applied = true; } } @@ -201,7 +214,7 @@ impl Storage { if is_string { let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; - expiration_table.insert(key, &((expires_at_ms as u64)))?; + expiration_table.insert(key, &(expires_at_ms as u64))?; applied = true; } } @@ -223,21 +236,21 @@ pub fn glob_match(pattern: &str, text: &str) -> bool { if pattern == "*" { return true; } - + // Simple glob matching - supports * and ? wildcards let pattern_chars: Vec = pattern.chars().collect(); let text_chars: Vec = text.chars().collect(); - + fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool { if pi >= pattern.len() { return ti >= text.len(); } - + if ti >= text.len() { // Check if remaining pattern is all '*' return pattern[pi..].iter().all(|&c| c == '*'); } - + match pattern[pi] { '*' => { // Try matching zero or more characters @@ -262,7 +275,7 @@ pub fn glob_match(pattern: &str, text: &str) -> bool { } } } - + match_recursive(&pattern_chars, &text_chars, 0, 0) } @@ -283,4 +296,4 @@ mod tests { assert!(glob_match("*test*", "this_is_a_test_string")); assert!(!glob_match("*test*", "this_is_a_string")); } -} \ No newline at end of file +} diff --git a/src/storage/storage_hset.rs b/src/storage/storage_hset.rs index dfe9394..9c6d230 100644 --- a/src/storage/storage_hset.rs +++ b/src/storage/storage_hset.rs @@ -1,44 +1,50 @@ -use redb::{ReadableTable}; -use crate::error::DBError; use super::*; +use crate::error::DBError; +use redb::ReadableTable; impl Storage { // ✅ ENCRYPTION APPLIED: Values are encrypted before storage pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result { let write_txn = self.db.begin_write()?; let mut new_fields = 0i64; - + { let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; - + let key_type = { let access_guard = types_table.get(key)?; access_guard.map(|v| v.value().to_string()) }; match key_type.as_deref() { - Some("hash") | None => { // Proceed if hash or new key + Some("hash") | None => { + // Proceed if hash or new key // Set the type to hash (only if new key or existing hash) types_table.insert(key, "hash")?; - + for (field, value) in pairs { // Check if field already exists let exists = hashes_table.get((key, field.as_str()))?.is_some(); - + // Encrypt the value before storing let encrypted = self.encrypt_if_needed(value.as_bytes())?; hashes_table.insert((key, field.as_str()), encrypted.as_slice())?; - + if !exists { new_fields += 1; } } } - Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => { + return Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value" + .to_string(), + )) + } } } - + write_txn.commit()?; Ok(new_fields) } @@ -47,7 +53,7 @@ impl Storage { pub fn hget(&self, key: &str, field: &str) -> Result, DBError> { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; - + let key_type = types_table.get(key)?.map(|v| v.value().to_string()); match key_type.as_deref() { @@ -62,7 +68,9 @@ impl Storage { None => Ok(None), } } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok(None), } } @@ -80,7 +88,7 @@ impl Storage { Some("hash") => { let hashes_table = read_txn.open_table(HASHES_TABLE)?; let mut result = Vec::new(); - + let mut iter = hashes_table.iter()?; while let Some(entry) = iter.next() { let entry = entry?; @@ -91,10 +99,12 @@ impl Storage { result.push((field.to_string(), value)); } } - + Ok(result) } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok(Vec::new()), } } @@ -102,24 +112,24 @@ impl Storage { pub fn hdel(&self, key: &str, fields: Vec) -> Result { let write_txn = self.db.begin_write()?; let mut deleted = 0i64; - + // First check if key exists and is a hash let key_type = { let types_table = write_txn.open_table(TYPES_TABLE)?; let access_guard = types_table.get(key)?; access_guard.map(|v| v.value().to_string()) }; - + match key_type.as_deref() { Some("hash") => { let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; - + for field in fields { if hashes_table.remove((key, field.as_str()))?.is_some() { deleted += 1; } } - + // Check if hash is now empty and remove type if so let mut has_fields = false; let mut iter = hashes_table.iter()?; @@ -132,16 +142,20 @@ impl Storage { } } drop(iter); - + if !has_fields { let mut types_table = write_txn.open_table(TYPES_TABLE)?; types_table.remove(key)?; } } - Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => { + return Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )) + } None => {} // Key does not exist, nothing to delete, return 0 deleted } - + write_txn.commit()?; Ok(deleted) } @@ -159,7 +173,9 @@ impl Storage { let hashes_table = read_txn.open_table(HASHES_TABLE)?; Ok(hashes_table.get((key, field))?.is_some()) } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok(false), } } @@ -176,7 +192,7 @@ impl Storage { Some("hash") => { let hashes_table = read_txn.open_table(HASHES_TABLE)?; let mut result = Vec::new(); - + let mut iter = hashes_table.iter()?; while let Some(entry) = iter.next() { let entry = entry?; @@ -185,10 +201,12 @@ impl Storage { result.push(field.to_string()); } } - + Ok(result) } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok(Vec::new()), } } @@ -206,7 +224,7 @@ impl Storage { Some("hash") => { let hashes_table = read_txn.open_table(HASHES_TABLE)?; let mut result = Vec::new(); - + let mut iter = hashes_table.iter()?; while let Some(entry) = iter.next() { let entry = entry?; @@ -217,10 +235,12 @@ impl Storage { result.push(value); } } - + Ok(result) } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok(Vec::new()), } } @@ -237,7 +257,7 @@ impl Storage { Some("hash") => { let hashes_table = read_txn.open_table(HASHES_TABLE)?; let mut count = 0i64; - + let mut iter = hashes_table.iter()?; while let Some(entry) = iter.next() { let entry = entry?; @@ -246,10 +266,12 @@ impl Storage { count += 1; } } - + Ok(count) } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok(0), } } @@ -267,7 +289,7 @@ impl Storage { Some("hash") => { let hashes_table = read_txn.open_table(HASHES_TABLE)?; let mut result = Vec::new(); - + for field in fields { match hashes_table.get((key, field.as_str()))? { Some(data) => { @@ -278,10 +300,12 @@ impl Storage { None => result.push(None), } } - + Ok(result) } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok(fields.into_iter().map(|_| None).collect()), } } @@ -290,39 +314,51 @@ impl Storage { pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result { let write_txn = self.db.begin_write()?; let mut result = false; - + { let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; - + let key_type = { let access_guard = types_table.get(key)?; access_guard.map(|v| v.value().to_string()) }; match key_type.as_deref() { - Some("hash") | None => { // Proceed if hash or new key + Some("hash") | None => { + // Proceed if hash or new key // Check if field already exists if hashes_table.get((key, field))?.is_none() { // Set the type to hash (only if new key or existing hash) types_table.insert(key, "hash")?; - + // Encrypt the value before storing let encrypted = self.encrypt_if_needed(value.as_bytes())?; hashes_table.insert((key, field), encrypted.as_slice())?; result = true; } } - Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => { + return Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value" + .to_string(), + )) + } } } - + write_txn.commit()?; Ok(result) } // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval - pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + pub fn hscan( + &self, + key: &str, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError> { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; let key_type = { @@ -336,28 +372,28 @@ impl Storage { let mut result = Vec::new(); let mut current_cursor = 0u64; let limit = count.unwrap_or(10) as usize; - + let mut iter = hashes_table.iter()?; while let Some(entry) = iter.next() { let entry = entry?; let (hash_key, field) = entry.0.value(); - + if hash_key == key { if current_cursor >= cursor { let field_str = field.to_string(); - + // Apply pattern matching if specified let matches = if let Some(pat) = pattern { super::storage_extra::glob_match(pat, &field_str) } else { true }; - + if matches { let decrypted = self.decrypt_if_needed(entry.1.value())?; let value = String::from_utf8(decrypted)?; result.push((field_str, value)); - + if result.len() >= limit { break; } @@ -366,12 +402,18 @@ impl Storage { current_cursor += 1; } } - - let next_cursor = if result.len() < limit { 0 } else { current_cursor }; + + let next_cursor = if result.len() < limit { + 0 + } else { + current_cursor + }; Ok((next_cursor, result)) } - Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + Some(_) => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), None => Ok((0, Vec::new())), } } -} \ No newline at end of file +} diff --git a/src/storage/storage_lists.rs b/src/storage/storage_lists.rs index 93a2ef6..7bfb3e0 100644 --- a/src/storage/storage_lists.rs +++ b/src/storage/storage_lists.rs @@ -1,20 +1,20 @@ -use redb::{ReadableTable}; -use crate::error::DBError; use super::*; +use crate::error::DBError; +use redb::ReadableTable; impl Storage { // ✅ ENCRYPTION APPLIED: Elements are encrypted before storage pub fn lpush(&self, key: &str, elements: Vec) -> Result { let write_txn = self.db.begin_write()?; let mut _length = 0i64; - + { let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - + // Set the type to list types_table.insert(key, "list")?; - + // Get current list or create empty one let mut list: Vec = match lists_table.get(key)? { Some(data) => { @@ -23,20 +23,20 @@ impl Storage { } None => Vec::new(), }; - + // Add elements to the front (left) for element in elements.into_iter() { list.insert(0, element); } - + _length = list.len() as i64; - + // Encrypt and store the updated list let serialized = serde_json::to_vec(&list)?; let encrypted = self.encrypt_if_needed(&serialized)?; lists_table.insert(key, encrypted.as_slice())?; } - + write_txn.commit()?; Ok(_length) } @@ -45,14 +45,14 @@ impl Storage { pub fn rpush(&self, key: &str, elements: Vec) -> Result { let write_txn = self.db.begin_write()?; let mut _length = 0i64; - + { let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?; - + // Set the type to list types_table.insert(key, "list")?; - + // Get current list or create empty one let mut list: Vec = match lists_table.get(key)? { Some(data) => { @@ -61,17 +61,17 @@ impl Storage { } None => Vec::new(), }; - + // Add elements to the end (right) list.extend(elements); _length = list.len() as i64; - + // Encrypt and store the updated list let serialized = serde_json::to_vec(&list)?; let encrypted = self.encrypt_if_needed(&serialized)?; lists_table.insert(key, encrypted.as_slice())?; } - + write_txn.commit()?; Ok(_length) } @@ -80,12 +80,12 @@ impl Storage { pub fn lpop(&self, key: &str, count: u64) -> Result, DBError> { let write_txn = self.db.begin_write()?; let mut result = Vec::new(); - + // First check if key exists and is a list, and get the data let list_data = { let types_table = write_txn.open_table(TYPES_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?; - + let result = match types_table.get(key)? { Some(type_val) if type_val.value() == "list" => { if let Some(data) = lists_table.get(key)? { @@ -100,7 +100,7 @@ impl Storage { }; result }; - + if let Some(mut list) = list_data { let pop_count = std::cmp::min(count as usize, list.len()); for _ in 0..pop_count { @@ -108,7 +108,7 @@ impl Storage { result.push(list.remove(0)); } } - + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; if list.is_empty() { // Remove the key if list is empty @@ -122,7 +122,7 @@ impl Storage { lists_table.insert(key, encrypted.as_slice())?; } } - + write_txn.commit()?; Ok(result) } @@ -131,12 +131,12 @@ impl Storage { pub fn rpop(&self, key: &str, count: u64) -> Result, DBError> { let write_txn = self.db.begin_write()?; let mut result = Vec::new(); - + // First check if key exists and is a list, and get the data let list_data = { let types_table = write_txn.open_table(TYPES_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?; - + let result = match types_table.get(key)? { Some(type_val) if type_val.value() == "list" => { if let Some(data) = lists_table.get(key)? { @@ -151,7 +151,7 @@ impl Storage { }; result }; - + if let Some(mut list) = list_data { let pop_count = std::cmp::min(count as usize, list.len()); for _ in 0..pop_count { @@ -159,7 +159,7 @@ impl Storage { result.push(list.pop().unwrap()); } } - + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; if list.is_empty() { // Remove the key if list is empty @@ -173,7 +173,7 @@ impl Storage { lists_table.insert(key, encrypted.as_slice())?; } } - + write_txn.commit()?; Ok(result) } @@ -181,7 +181,7 @@ impl Storage { pub fn llen(&self, key: &str) -> Result { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; - + match types_table.get(key)? { Some(type_val) if type_val.value() == "list" => { let lists_table = read_txn.open_table(LISTS_TABLE)?; @@ -202,7 +202,7 @@ impl Storage { pub fn lindex(&self, key: &str, index: i64) -> Result, DBError> { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; - + match types_table.get(key)? { Some(type_val) if type_val.value() == "list" => { let lists_table = read_txn.open_table(LISTS_TABLE)?; @@ -210,13 +210,13 @@ impl Storage { Some(data) => { let decrypted = self.decrypt_if_needed(data.value())?; let list: Vec = serde_json::from_slice(&decrypted)?; - + let actual_index = if index < 0 { list.len() as i64 + index } else { index }; - + if actual_index >= 0 && (actual_index as usize) < list.len() { Ok(Some(list[actual_index as usize].clone())) } else { @@ -234,7 +234,7 @@ impl Storage { pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError> { let read_txn = self.db.begin_read()?; let types_table = read_txn.open_table(TYPES_TABLE)?; - + match types_table.get(key)? { Some(type_val) if type_val.value() == "list" => { let lists_table = read_txn.open_table(LISTS_TABLE)?; @@ -242,22 +242,30 @@ impl Storage { Some(data) => { let decrypted = self.decrypt_if_needed(data.value())?; let list: Vec = serde_json::from_slice(&decrypted)?; - + if list.is_empty() { return Ok(Vec::new()); } - + let len = list.len() as i64; - let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; - let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; - + let start_idx = if start < 0 { + std::cmp::max(0, len + start) + } else { + std::cmp::min(start, len) + }; + let stop_idx = if stop < 0 { + std::cmp::max(-1, len + stop) + } else { + std::cmp::min(stop, len - 1) + }; + if start_idx > stop_idx || start_idx >= len { return Ok(Vec::new()); } - + let start_usize = start_idx as usize; let stop_usize = (stop_idx + 1) as usize; - + Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec()) } None => Ok(Vec::new()), @@ -270,12 +278,12 @@ impl Storage { // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { let write_txn = self.db.begin_write()?; - + // First check if key exists and is a list, and get the data let list_data = { let types_table = write_txn.open_table(TYPES_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?; - + let result = match types_table.get(key)? { Some(type_val) if type_val.value() == "list" => { if let Some(data) = lists_table.get(key)? { @@ -290,17 +298,25 @@ impl Storage { }; result }; - + if let Some(list) = list_data { if list.is_empty() { write_txn.commit()?; return Ok(()); } - + let len = list.len() as i64; - let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; - let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; - + let start_idx = if start < 0 { + std::cmp::max(0, len + start) + } else { + std::cmp::min(start, len) + }; + let stop_idx = if stop < 0 { + std::cmp::max(-1, len + stop) + } else { + std::cmp::min(stop, len - 1) + }; + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; if start_idx > stop_idx || start_idx >= len { // Remove the entire list @@ -311,7 +327,7 @@ impl Storage { let start_usize = start_idx as usize; let stop_usize = (stop_idx + 1) as usize; let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec(); - + if trimmed.is_empty() { lists_table.remove(key)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?; @@ -324,7 +340,7 @@ impl Storage { } } } - + write_txn.commit()?; Ok(()) } @@ -333,12 +349,12 @@ impl Storage { pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result { let write_txn = self.db.begin_write()?; let mut removed = 0i64; - + // First check if key exists and is a list, and get the data let list_data = { let types_table = write_txn.open_table(TYPES_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?; - + let result = match types_table.get(key)? { Some(type_val) if type_val.value() == "list" => { if let Some(data) = lists_table.get(key)? { @@ -353,7 +369,7 @@ impl Storage { }; result }; - + if let Some(mut list) = list_data { if count == 0 { // Remove all occurrences @@ -383,7 +399,7 @@ impl Storage { } } } - + let mut lists_table = write_txn.open_table(LISTS_TABLE)?; if list.is_empty() { lists_table.remove(key)?; @@ -396,8 +412,8 @@ impl Storage { lists_table.insert(key, encrypted.as_slice())?; } } - + write_txn.commit()?; Ok(removed) } -} \ No newline at end of file +} diff --git a/src/storage_sled/mod.rs b/src/storage_sled/mod.rs index ec22b88..d5514d7 100644 --- a/src/storage_sled/mod.rs +++ b/src/storage_sled/mod.rs @@ -1,12 +1,12 @@ // src/storage_sled/mod.rs -use std::path::Path; -use std::sync::Arc; -use std::collections::HashMap; -use std::time::{SystemTime, UNIX_EPOCH}; -use serde::{Deserialize, Serialize}; +use crate::crypto::CryptoFactory; use crate::error::DBError; use crate::storage_trait::StorageBackend; -use crate::crypto::CryptoFactory; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; #[derive(Serialize, Deserialize, Debug, Clone)] enum ValueType { @@ -28,44 +28,56 @@ pub struct SledStorage { } impl SledStorage { - pub fn new(path: impl AsRef, should_encrypt: bool, master_key: Option<&str>) -> Result { + pub fn new( + path: impl AsRef, + should_encrypt: bool, + master_key: Option<&str>, + ) -> Result { let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?; - let types = db.open_tree("types").map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?; - + let types = db + .open_tree("types") + .map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?; + // Check if database was previously encrypted - let encrypted_tree = db.open_tree("encrypted").map_err(|e| DBError(e.to_string()))?; - let was_encrypted = encrypted_tree.get("encrypted") + let encrypted_tree = db + .open_tree("encrypted") + .map_err(|e| DBError(e.to_string()))?; + let was_encrypted = encrypted_tree + .get("encrypted") .map_err(|e| DBError(e.to_string()))? .map(|v| v[0] == 1) .unwrap_or(false); - + let crypto = if should_encrypt || was_encrypted { if let Some(key) = master_key { Some(CryptoFactory::new(key.as_bytes())) } else { - return Err(DBError("Encryption requested but no master key provided".to_string())); + return Err(DBError( + "Encryption requested but no master key provided".to_string(), + )); } } else { None }; - + // Mark database as encrypted if enabling encryption if should_encrypt && !was_encrypted { - encrypted_tree.insert("encrypted", &[1u8]) + encrypted_tree + .insert("encrypted", &[1u8]) .map_err(|e| DBError(e.to_string()))?; encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?; } - + Ok(SledStorage { db, types, crypto }) } - + fn now_millis() -> u128 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_millis() } - + fn encrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { if let Some(crypto) = &self.crypto { Ok(crypto.encrypt(data)) @@ -73,7 +85,7 @@ impl SledStorage { Ok(data.to_vec()) } } - + fn decrypt_if_needed(&self, data: &[u8]) -> Result, DBError> { if let Some(crypto) = &self.crypto { Ok(crypto.decrypt(data)?) @@ -81,14 +93,14 @@ impl SledStorage { Ok(data.to_vec()) } } - + fn get_storage_value(&self, key: &str) -> Result, DBError> { match self.db.get(key).map_err(|e| DBError(e.to_string()))? { Some(encrypted_data) => { let decrypted = self.decrypt_if_needed(&encrypted_data)?; let storage_val: StorageValue = bincode::deserialize(&decrypted) .map_err(|e| DBError(format!("Deserialization error: {}", e)))?; - + // Check expiration if let Some(expires_at) = storage_val.expires_at { if Self::now_millis() > expires_at { @@ -98,47 +110,51 @@ impl SledStorage { return Ok(None); } } - + Ok(Some(storage_val)) } - None => Ok(None) + None => Ok(None), } } - + fn set_storage_value(&self, key: &str, storage_val: StorageValue) -> Result<(), DBError> { let data = bincode::serialize(&storage_val) .map_err(|e| DBError(format!("Serialization error: {}", e)))?; let encrypted = self.encrypt_if_needed(&data)?; - self.db.insert(key, encrypted).map_err(|e| DBError(e.to_string()))?; - + self.db + .insert(key, encrypted) + .map_err(|e| DBError(e.to_string()))?; + // Store type info (unencrypted for efficiency) let type_str = match &storage_val.value { ValueType::String(_) => "string", ValueType::Hash(_) => "hash", ValueType::List(_) => "list", }; - self.types.insert(key, type_str.as_bytes()).map_err(|e| DBError(e.to_string()))?; - + self.types + .insert(key, type_str.as_bytes()) + .map_err(|e| DBError(e.to_string()))?; + Ok(()) } - + fn glob_match(pattern: &str, text: &str) -> bool { if pattern == "*" { return true; } - + let pattern_chars: Vec = pattern.chars().collect(); let text_chars: Vec = text.chars().collect(); - + fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool { if pi >= pattern.len() { return ti >= text.len(); } - + if ti >= text.len() { return pattern[pi..].iter().all(|&c| c == '*'); } - + match pattern[pi] { '*' => { for i in ti..=text.len() { @@ -158,7 +174,7 @@ impl SledStorage { } } } - + match_recursive(&pattern_chars, &text_chars, 0, 0) } } @@ -168,12 +184,12 @@ impl StorageBackend for SledStorage { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::String(s) => Ok(Some(s)), - _ => Ok(None) - } - None => Ok(None) + _ => Ok(None), + }, + None => Ok(None), } } - + fn set(&self, key: String, value: String) -> Result<(), DBError> { let storage_val = StorageValue { value: ValueType::String(value), @@ -183,7 +199,7 @@ impl StorageBackend for SledStorage { self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } - + fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { let storage_val = StorageValue { value: ValueType::String(value), @@ -193,25 +209,27 @@ impl StorageBackend for SledStorage { self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } - + fn del(&self, key: String) -> Result<(), DBError> { self.db.remove(&key).map_err(|e| DBError(e.to_string()))?; - self.types.remove(&key).map_err(|e| DBError(e.to_string()))?; + self.types + .remove(&key) + .map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } - + fn exists(&self, key: &str) -> Result { // Check with expiration Ok(self.get_storage_value(key)?.is_some()) } - + fn keys(&self, pattern: &str) -> Result, DBError> { let mut keys = Vec::new(); for item in self.types.iter() { let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?; let key = String::from_utf8_lossy(&key_bytes).to_string(); - + // Check if key is expired if self.get_storage_value(&key)?.is_some() { if Self::glob_match(pattern, &key) { @@ -221,24 +239,29 @@ impl StorageBackend for SledStorage { } Ok(keys) } - - fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + + fn scan( + &self, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError> { let mut result = Vec::new(); let mut current_cursor = 0u64; let limit = count.unwrap_or(10) as usize; - + for item in self.types.iter() { if current_cursor >= cursor { let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?; let key = String::from_utf8_lossy(&key_bytes).to_string(); - + // Check pattern match let matches = if let Some(pat) = pattern { Self::glob_match(pat, &key) } else { true }; - + if matches { // Check if key is expired and get value if let Some(storage_val) = self.get_storage_value(&key)? { @@ -247,7 +270,7 @@ impl StorageBackend for SledStorage { _ => String::from_utf8_lossy(&type_bytes).to_string(), }; result.push((key, value)); - + if result.len() >= limit { current_cursor += 1; break; @@ -257,11 +280,15 @@ impl StorageBackend for SledStorage { } current_cursor += 1; } - - let next_cursor = if result.len() < limit { 0 } else { current_cursor }; + + let next_cursor = if result.len() < limit { + 0 + } else { + current_cursor + }; Ok((next_cursor, result)) } - + fn dbsize(&self) -> Result { let mut count = 0i64; for item in self.types.iter() { @@ -273,38 +300,42 @@ impl StorageBackend for SledStorage { } Ok(count) } - + fn flushdb(&self) -> Result<(), DBError> { self.db.clear().map_err(|e| DBError(e.to_string()))?; self.types.clear().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(()) } - + fn get_key_type(&self, key: &str) -> Result, DBError> { // First check if key exists (handles expiration) if self.get_storage_value(key)?.is_some() { match self.types.get(key).map_err(|e| DBError(e.to_string()))? { Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())), - None => Ok(None) + None => Ok(None), } } else { Ok(None) } } - + // Hash operations fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::Hash(HashMap::new()), expires_at: None, }); - + let hash = match &mut storage_val.value { ValueType::Hash(h) => h, - _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + _ => { + return Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )) + } }; - + let mut new_fields = 0i64; for (field, value) in pairs { if !hash.contains_key(&field) { @@ -312,40 +343,46 @@ impl StorageBackend for SledStorage { } hash.insert(field, value); } - + self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(new_fields) } - + fn hget(&self, key: &str, field: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.get(field).cloned()), - _ => Ok(None) - } - None => Ok(None) + _ => Ok(None), + }, + None => Ok(None), } } - + fn hgetall(&self, key: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.into_iter().collect()), - _ => Ok(Vec::new()) - } - None => Ok(Vec::new()) + _ => Ok(Vec::new()), + }, + None => Ok(Vec::new()), } } - - fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError> { + + fn hscan( + &self, + key: &str, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => { let mut result = Vec::new(); let mut current_cursor = 0u64; let limit = count.unwrap_or(10) as usize; - + for (field, value) in h.iter() { if current_cursor >= cursor { let matches = if let Some(pat) = pattern { @@ -353,7 +390,7 @@ impl StorageBackend for SledStorage { } else { true }; - + if matches { result.push((field.clone(), value.clone())); if result.len() >= limit { @@ -364,107 +401,115 @@ impl StorageBackend for SledStorage { } current_cursor += 1; } - - let next_cursor = if result.len() < limit { 0 } else { current_cursor }; + + let next_cursor = if result.len() < limit { + 0 + } else { + current_cursor + }; Ok((next_cursor, result)) } - _ => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())) - } - None => Ok((0, Vec::new())) + _ => Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )), + }, + None => Ok((0, Vec::new())), } } - + fn hdel(&self, key: &str, fields: Vec) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(0) + None => return Ok(0), }; - + let hash = match &mut storage_val.value { ValueType::Hash(h) => h, - _ => return Ok(0) + _ => return Ok(0), }; - + let mut deleted = 0i64; for field in fields { if hash.remove(&field).is_some() { deleted += 1; } } - + if hash.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } - + Ok(deleted) } - + fn hexists(&self, key: &str, field: &str) -> Result { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.contains_key(field)), - _ => Ok(false) - } - None => Ok(false) + _ => Ok(false), + }, + None => Ok(false), } } - + fn hkeys(&self, key: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.keys().cloned().collect()), - _ => Ok(Vec::new()) - } - None => Ok(Vec::new()) + _ => Ok(Vec::new()), + }, + None => Ok(Vec::new()), } } - + fn hvals(&self, key: &str) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.values().cloned().collect()), - _ => Ok(Vec::new()) - } - None => Ok(Vec::new()) + _ => Ok(Vec::new()), + }, + None => Ok(Vec::new()), } } - + fn hlen(&self, key: &str) -> Result { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::Hash(h) => Ok(h.len() as i64), - _ => Ok(0) - } - None => Ok(0) + _ => Ok(0), + }, + None => Ok(0), } } - + fn hmget(&self, key: &str, fields: Vec) -> Result>, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { - ValueType::Hash(h) => { - Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()) - } - _ => Ok(fields.into_iter().map(|_| None).collect()) - } - None => Ok(fields.into_iter().map(|_| None).collect()) + ValueType::Hash(h) => Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()), + _ => Ok(fields.into_iter().map(|_| None).collect()), + }, + None => Ok(fields.into_iter().map(|_| None).collect()), } } - + fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::Hash(HashMap::new()), expires_at: None, }); - + let hash = match &mut storage_val.value { ValueType::Hash(h) => h, - _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + _ => { + return Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )) + } }; - + if hash.contains_key(field) { Ok(false) } else { @@ -474,58 +519,66 @@ impl StorageBackend for SledStorage { Ok(true) } } - + // List operations fn lpush(&self, key: &str, elements: Vec) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::List(Vec::new()), expires_at: None, }); - + let list = match &mut storage_val.value { ValueType::List(l) => l, - _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + _ => { + return Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )) + } }; - + for element in elements.into_iter().rev() { list.insert(0, element); } - + let len = list.len() as i64; self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(len) } - + fn rpush(&self, key: &str, elements: Vec) -> Result { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { value: ValueType::List(Vec::new()), expires_at: None, }); - + let list = match &mut storage_val.value { ValueType::List(l) => l, - _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), + _ => { + return Err(DBError( + "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), + )) + } }; - + list.extend(elements); let len = list.len() as i64; self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(len) } - + fn lpop(&self, key: &str, count: u64) -> Result, DBError> { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(Vec::new()) + None => return Ok(Vec::new()), }; - + let list = match &mut storage_val.value { ValueType::List(l) => l, - _ => return Ok(Vec::new()) + _ => return Ok(Vec::new()), }; - + let mut result = Vec::new(); for _ in 0..count.min(list.len() as u64) { if let Some(elem) = list.first() { @@ -533,55 +586,55 @@ impl StorageBackend for SledStorage { list.remove(0); } } - + if list.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } - + Ok(result) } - + fn rpop(&self, key: &str, count: u64) -> Result, DBError> { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(Vec::new()) + None => return Ok(Vec::new()), }; - + let list = match &mut storage_val.value { ValueType::List(l) => l, - _ => return Ok(Vec::new()) + _ => return Ok(Vec::new()), }; - + let mut result = Vec::new(); for _ in 0..count.min(list.len() as u64) { if let Some(elem) = list.pop() { result.push(elem); } } - + if list.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } - + Ok(result) } - + fn llen(&self, key: &str) -> Result { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { ValueType::List(l) => Ok(l.len() as i64), - _ => Ok(0) - } - None => Ok(0) + _ => Ok(0), + }, + None => Ok(0), } } - + fn lindex(&self, key: &str, index: i64) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { @@ -591,19 +644,19 @@ impl StorageBackend for SledStorage { } else { index }; - + if actual_index >= 0 && (actual_index as usize) < list.len() { Ok(Some(list[actual_index as usize].clone())) } else { Ok(None) } } - _ => Ok(None) - } - None => Ok(None) + _ => Ok(None), + }, + None => Ok(None), } } - + fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError> { match self.get_storage_value(key)? { Some(storage_val) => match storage_val.value { @@ -611,68 +664,68 @@ impl StorageBackend for SledStorage { if list.is_empty() { return Ok(Vec::new()); } - + let len = list.len() as i64; - let start_idx = if start < 0 { - std::cmp::max(0, len + start) - } else { - std::cmp::min(start, len) + let start_idx = if start < 0 { + std::cmp::max(0, len + start) + } else { + std::cmp::min(start, len) }; - let stop_idx = if stop < 0 { - std::cmp::max(-1, len + stop) - } else { - std::cmp::min(stop, len - 1) + let stop_idx = if stop < 0 { + std::cmp::max(-1, len + stop) + } else { + std::cmp::min(stop, len - 1) }; - + if start_idx > stop_idx || start_idx >= len { return Ok(Vec::new()); } - + let start_usize = start_idx as usize; let stop_usize = (stop_idx + 1) as usize; - + Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec()) } - _ => Ok(Vec::new()) - } - None => Ok(Vec::new()) + _ => Ok(Vec::new()), + }, + None => Ok(Vec::new()), } } - + fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(()) + None => return Ok(()), }; - + let list = match &mut storage_val.value { ValueType::List(l) => l, - _ => return Ok(()) + _ => return Ok(()), }; - + if list.is_empty() { return Ok(()); } - + let len = list.len() as i64; - let start_idx = if start < 0 { - std::cmp::max(0, len + start) - } else { - std::cmp::min(start, len) + let start_idx = if start < 0 { + std::cmp::max(0, len + start) + } else { + std::cmp::min(start, len) }; - let stop_idx = if stop < 0 { - std::cmp::max(-1, len + stop) - } else { - std::cmp::min(stop, len - 1) + let stop_idx = if stop < 0 { + std::cmp::max(-1, len + stop) + } else { + std::cmp::min(stop, len - 1) }; - + if start_idx > stop_idx || start_idx >= len { self.del(key.to_string())?; } else { let start_usize = start_idx as usize; let stop_usize = (stop_idx + 1) as usize; *list = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec(); - + if list.is_empty() { self.del(key.to_string())?; } else { @@ -680,23 +733,23 @@ impl StorageBackend for SledStorage { self.db.flush().map_err(|e| DBError(e.to_string()))?; } } - + Ok(()) } - + fn lrem(&self, key: &str, count: i64, element: &str) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(0) + None => return Ok(0), }; - + let list = match &mut storage_val.value { ValueType::List(l) => l, - _ => return Ok(0) + _ => return Ok(0), }; - + let mut removed = 0i64; - + if count == 0 { // Remove all occurrences let original_len = list.len(); @@ -725,17 +778,17 @@ impl StorageBackend for SledStorage { } } } - + if list.is_empty() { self.del(key.to_string())?; } else { self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; } - + Ok(removed) } - + // Expiration fn ttl(&self, key: &str) -> Result { match self.get_storage_value(key)? { @@ -751,40 +804,40 @@ impl StorageBackend for SledStorage { Ok(-1) // Key exists but has no expiration } } - None => Ok(-2) // Key does not exist + None => Ok(-2), // Key does not exist } } - + fn expire_seconds(&self, key: &str, secs: u64) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(false) + None => return Ok(false), }; - + storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } - + fn pexpire_millis(&self, key: &str, ms: u128) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(false) + None => return Ok(false), }; - + storage_val.expires_at = Some(Self::now_millis() + ms); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } - + fn persist(&self, key: &str) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(false) + None => return Ok(false), }; - + if storage_val.expires_at.is_some() { storage_val.expires_at = None; self.set_storage_value(key, storage_val)?; @@ -794,37 +847,41 @@ impl StorageBackend for SledStorage { Ok(false) } } - + fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(false) + None => return Ok(false), + }; + + let expires_at_ms: u128 = if ts_secs <= 0 { + 0 + } else { + (ts_secs as u128) * 1000 }; - - let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; storage_val.expires_at = Some(expires_at_ms); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } - + fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result { let mut storage_val = match self.get_storage_value(key)? { Some(sv) => sv, - None => return Ok(false) + None => return Ok(false), }; - + let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; storage_val.expires_at = Some(expires_at_ms); self.set_storage_value(key, storage_val)?; self.db.flush().map_err(|e| DBError(e.to_string()))?; Ok(true) } - + fn is_encrypted(&self) -> bool { self.crypto.is_some() } - + fn info(&self) -> Result, DBError> { let dbsize = self.dbsize()?; Ok(vec![ @@ -842,4 +899,4 @@ impl StorageBackend for SledStorage { crypto: self.crypto.clone(), }) } -} \ No newline at end of file +} diff --git a/src/storage_trait.rs b/src/storage_trait.rs index 13fe11e..4e4ef1e 100644 --- a/src/storage_trait.rs +++ b/src/storage_trait.rs @@ -13,11 +13,22 @@ pub trait StorageBackend: Send + Sync { fn dbsize(&self) -> Result; fn flushdb(&self) -> Result<(), DBError>; fn get_key_type(&self, key: &str) -> Result, DBError>; - + // Scanning - fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError>; - fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option) -> Result<(u64, Vec<(String, String)>), DBError>; - + fn scan( + &self, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError>; + fn hscan( + &self, + key: &str, + cursor: u64, + pattern: Option<&str>, + count: Option, + ) -> Result<(u64, Vec<(String, String)>), DBError>; + // Hash operations fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result; fn hget(&self, key: &str, field: &str) -> Result, DBError>; @@ -29,7 +40,7 @@ pub trait StorageBackend: Send + Sync { fn hlen(&self, key: &str) -> Result; fn hmget(&self, key: &str, fields: Vec) -> Result>, DBError>; fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result; - + // List operations fn lpush(&self, key: &str, elements: Vec) -> Result; fn rpush(&self, key: &str, elements: Vec) -> Result; @@ -40,7 +51,7 @@ pub trait StorageBackend: Send + Sync { fn lrange(&self, key: &str, start: i64, stop: i64) -> Result, DBError>; fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError>; fn lrem(&self, key: &str, count: i64, element: &str) -> Result; - + // Expiration fn ttl(&self, key: &str) -> Result; fn expire_seconds(&self, key: &str, secs: u64) -> Result; @@ -48,11 +59,11 @@ pub trait StorageBackend: Send + Sync { fn persist(&self, key: &str) -> Result; fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result; fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result; - + // Metadata fn is_encrypted(&self) -> bool; fn info(&self) -> Result, DBError>; - + // Clone to Arc for sharing fn clone_arc(&self) -> Arc; -} \ No newline at end of file +} diff --git a/tests/debug_hset.rs b/tests/debug_hset.rs index 7930be8..b921d09 100644 --- a/tests/debug_hset.rs +++ b/tests/debug_hset.rs @@ -1,4 +1,4 @@ -use herodb::{server::Server, options::DBOption}; +use herodb::{options::DBOption, server::Server}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -7,7 +7,7 @@ use tokio::time::sleep; // Helper function to send command and get response async fn send_command(stream: &mut TcpStream, command: &str) -> String { stream.write_all(command.as_bytes()).await.unwrap(); - + let mut buffer = [0; 1024]; let n = stream.read(&mut buffer).await.unwrap(); String::from_utf8_lossy(&buffer[..n]).to_string() @@ -19,7 +19,7 @@ async fn debug_hset_simple() { let test_dir = "/tmp/herodb_debug_hset"; let _ = std::fs::remove_dir_all(test_dir); std::fs::create_dir_all(test_dir).unwrap(); - + let port = 16500; let option = DBOption { dir: test_dir.to_string(), @@ -29,35 +29,49 @@ async fn debug_hset_simple() { encryption_key: None, backend: herodb::options::BackendType::Redb, }; - + let mut server = Server::new(option).await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(200)).await; - - let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); - + + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); + // Test simple HSET println!("Testing HSET..."); - let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; + let response = send_command( + &mut stream, + "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n", + ) + .await; println!("HSET response: {}", response); assert!(response.contains("1"), "Expected '1' but got: {}", response); - + // Test HGET println!("Testing HGET..."); - let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n", + ) + .await; println!("HGET response: {}", response); - assert!(response.contains("value1"), "Expected 'value1' but got: {}", response); -} \ No newline at end of file + assert!( + response.contains("value1"), + "Expected 'value1' but got: {}", + response + ); +} diff --git a/tests/debug_hset_simple.rs b/tests/debug_hset_simple.rs index 356e704..621b962 100644 --- a/tests/debug_hset_simple.rs +++ b/tests/debug_hset_simple.rs @@ -1,4 +1,4 @@ -use herodb::{server::Server, options::DBOption}; +use herodb::{options::DBOption, server::Server}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -7,11 +7,11 @@ use tokio::time::sleep; #[tokio::test] async fn debug_hset_return_value() { let test_dir = "/tmp/herodb_debug_hset_return"; - + // Clean up any existing test data let _ = std::fs::remove_dir_all(&test_dir); std::fs::create_dir_all(&test_dir).unwrap(); - + let option = DBOption { dir: test_dir.to_string(), port: 16390, @@ -20,38 +20,42 @@ async fn debug_hset_return_value() { encryption_key: None, backend: herodb::options::BackendType::Redb, }; - + let mut server = Server::new(option).await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind("127.0.0.1:16390") .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(200)).await; - + // Connect and test HSET let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap(); - + // Send HSET command let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n"; stream.write_all(cmd.as_bytes()).await.unwrap(); - + let mut buffer = [0; 1024]; let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); - + println!("HSET response: {}", response); println!("Response bytes: {:?}", &buffer[..n]); - + // Check if response contains "1" - assert!(response.contains("1"), "Expected response to contain '1', got: {}", response); -} \ No newline at end of file + assert!( + response.contains("1"), + "Expected response to contain '1', got: {}", + response + ); +} diff --git a/tests/debug_protocol.rs b/tests/debug_protocol.rs index 8df61e7..0e9e305 100644 --- a/tests/debug_protocol.rs +++ b/tests/debug_protocol.rs @@ -1,12 +1,15 @@ -use herodb::protocol::Protocol; use herodb::cmd::Cmd; +use herodb::protocol::Protocol; #[test] fn test_protocol_parsing() { // Test TYPE command parsing let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n"; - println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n")); - + println!( + "Parsing TYPE command: {}", + type_cmd.replace("\r\n", "\\r\\n") + ); + match Protocol::from(type_cmd) { Ok((protocol, _)) => { println!("Protocol parsed successfully: {:?}", protocol); @@ -17,11 +20,14 @@ fn test_protocol_parsing() { } Err(e) => println!("Protocol parsing failed: {:?}", e), } - + // Test HEXISTS command parsing let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n"; - println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n")); - + println!( + "\nParsing HEXISTS command: {}", + hexists_cmd.replace("\r\n", "\\r\\n") + ); + match Protocol::from(hexists_cmd) { Ok((protocol, _)) => { println!("Protocol parsed successfully: {:?}", protocol); @@ -32,4 +38,4 @@ fn test_protocol_parsing() { } Err(e) => println!("Protocol parsing failed: {:?}", e), } -} \ No newline at end of file +} diff --git a/tests/redis_integration_tests.rs b/tests/redis_integration_tests.rs index 47033e1..a647551 100644 --- a/tests/redis_integration_tests.rs +++ b/tests/redis_integration_tests.rs @@ -81,13 +81,13 @@ fn setup_server() -> (ServerProcessGuard, u16) { ]) .spawn() .expect("Failed to start server process"); - + // Create a new guard that also owns the test directory path let guard = ServerProcessGuard { process: child, test_dir, }; - + // Give the server time to build and start (cargo run may compile first) std::thread::sleep(Duration::from_millis(2500)); @@ -206,7 +206,9 @@ async fn test_expiration(conn: &mut Connection) { async fn test_scan_operations(conn: &mut Connection) { cleanup_keys(conn).await; for i in 0..5 { - let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap(); + let _: () = conn + .set(format!("key{}", i), format!("value{}", i)) + .unwrap(); } let result: (u64, Vec) = redis::cmd("SCAN") .arg(0) @@ -253,7 +255,9 @@ async fn test_scan_with_count(conn: &mut Connection) { async fn test_hscan_operations(conn: &mut Connection) { cleanup_keys(conn).await; for i in 0..3 { - let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap(); + let _: () = conn + .hset("testhash", format!("field{}", i), format!("value{}", i)) + .unwrap(); } let result: (u64, Vec) = redis::cmd("HSCAN") .arg("testhash") @@ -273,8 +277,16 @@ async fn test_hscan_operations(conn: &mut Connection) { async fn test_transaction_operations(conn: &mut Connection) { cleanup_keys(conn).await; let _: () = redis::cmd("MULTI").query(conn).unwrap(); - let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap(); - let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap(); + let _: () = redis::cmd("SET") + .arg("key1") + .arg("value1") + .query(conn) + .unwrap(); + let _: () = redis::cmd("SET") + .arg("key2") + .arg("value2") + .query(conn) + .unwrap(); let _: Vec = redis::cmd("EXEC").query(conn).unwrap(); let result: String = conn.get("key1").unwrap(); assert_eq!(result, "value1"); @@ -286,7 +298,11 @@ async fn test_transaction_operations(conn: &mut Connection) { async fn test_discard_transaction(conn: &mut Connection) { cleanup_keys(conn).await; let _: () = redis::cmd("MULTI").query(conn).unwrap(); - let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap(); + let _: () = redis::cmd("SET") + .arg("discard") + .arg("value") + .query(conn) + .unwrap(); let _: () = redis::cmd("DISCARD").query(conn).unwrap(); let result: Option = conn.get("discard").unwrap(); assert_eq!(result, None); @@ -306,7 +322,6 @@ async fn test_type_command(conn: &mut Connection) { cleanup_keys(conn).await; } - async fn test_info_command(conn: &mut Connection) { cleanup_keys(conn).await; let result: String = redis::cmd("INFO").query(conn).unwrap(); diff --git a/tests/redis_tests.rs b/tests/redis_tests.rs index f6e8a13..589577c 100644 --- a/tests/redis_tests.rs +++ b/tests/redis_tests.rs @@ -1,4 +1,4 @@ -use herodb::{server::Server, options::DBOption}; +use herodb::{options::DBOption, server::Server}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -8,14 +8,14 @@ use tokio::time::sleep; async fn start_test_server(test_name: &str) -> (Server, u16) { use std::sync::atomic::{AtomicU16, Ordering}; static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379); - + let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let test_dir = format!("/tmp/herodb_test_{}", test_name); - + // Clean up and create test directory let _ = std::fs::remove_dir_all(&test_dir); std::fs::create_dir_all(&test_dir).unwrap(); - + let option = DBOption { dir: test_dir, port, @@ -24,7 +24,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { encryption_key: None, backend: herodb::options::BackendType::Redb, }; - + let server = Server::new(option).await; (server, port) } @@ -47,7 +47,7 @@ async fn connect_to_server(port: u16) -> TcpStream { // Helper function to send command and get response async fn send_command(stream: &mut TcpStream, command: &str) -> String { stream.write_all(command.as_bytes()).await.unwrap(); - + let mut buffer = [0; 1024]; let n = stream.read(&mut buffer).await.unwrap(); String::from_utf8_lossy(&buffer[..n]).to_string() @@ -56,22 +56,22 @@ async fn send_command(stream: &mut TcpStream, command: &str) -> String { #[tokio::test] async fn test_basic_ping() { let (mut server, port) = start_test_server("ping").await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await; assert!(response.contains("PONG")); @@ -80,40 +80,44 @@ async fn test_basic_ping() { #[tokio::test] async fn test_string_operations() { let (mut server, port) = start_test_server("string").await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test SET - let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n", + ) + .await; assert!(response.contains("OK")); - + // Test GET let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await; assert!(response.contains("value")); - + // Test GET non-existent key let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await; assert!(response.contains("$-1")); // NULL response - + // Test DEL let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await; assert!(response.contains("1")); - + // Test GET after DEL let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await; assert!(response.contains("$-1")); // NULL response @@ -122,33 +126,37 @@ async fn test_string_operations() { #[tokio::test] async fn test_incr_operations() { let (mut server, port) = start_test_server("incr").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test INCR on non-existent key let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await; assert!(response.contains("1")); - + // Test INCR on existing key let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await; assert!(response.contains("2")); - + // Test INCR on string value (should fail) - send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await; + send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n", + ) + .await; let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await; assert!(response.contains("ERR")); } @@ -156,63 +164,83 @@ async fn test_incr_operations() { #[tokio::test] async fn test_hash_operations() { let (mut server, port) = start_test_server("hash").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test HSET - let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; + let response = send_command( + &mut stream, + "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n", + ) + .await; assert!(response.contains("1")); // 1 new field - + // Test HGET - let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n", + ) + .await; assert!(response.contains("value1")); - + // Test HSET multiple fields let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await; assert!(response.contains("2")); // 2 new fields - + // Test HGETALL let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await; assert!(response.contains("field1")); assert!(response.contains("value1")); assert!(response.contains("field2")); assert!(response.contains("value2")); - + // Test HLEN let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await; assert!(response.contains("3")); - + // Test HEXISTS - let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n", + ) + .await; assert!(response.contains("1")); - - let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; + + let response = send_command( + &mut stream, + "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n", + ) + .await; assert!(response.contains("0")); - + // Test HDEL - let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n", + ) + .await; assert!(response.contains("1")); - + // Test HKEYS let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await; assert!(response.contains("field2")); assert!(response.contains("field3")); assert!(!response.contains("field1")); // Should be deleted - + // Test HVALS let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await; assert!(response.contains("value2")); @@ -222,46 +250,50 @@ async fn test_hash_operations() { #[tokio::test] async fn test_expiration() { let (mut server, port) = start_test_server("expiration").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test SETEX (expire in 1 second) - let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await; + let response = send_command( + &mut stream, + "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n", + ) + .await; assert!(response.contains("OK")); - + // Test TTL let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await; assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds - + // Test EXISTS let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await; assert!(response.contains("1")); - + // Wait for expiration sleep(Duration::from_millis(1100)).await; - + // Test GET after expiration let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await; assert!(response.contains("$-1")); // Should be NULL - + // Test TTL after expiration let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await; assert!(response.contains("-2")); // Key doesn't exist - + // Test EXISTS after expiration let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await; assert!(response.contains("0")); @@ -270,33 +302,37 @@ async fn test_expiration() { #[tokio::test] async fn test_scan_operations() { let (mut server, port) = start_test_server("scan").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Set up test data for i in 0..5 { let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i); send_command(&mut stream, &cmd).await; } - + // Test SCAN - let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await; + let response = send_command( + &mut stream, + "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n", + ) + .await; assert!(response.contains("key")); - + // Test KEYS let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await; assert!(response.contains("key0")); @@ -306,29 +342,32 @@ async fn test_scan_operations() { #[tokio::test] async fn test_hscan_operations() { let (mut server, port) = start_test_server("hscan").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Set up hash data for i in 0..3 { - let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i); + let cmd = format!( + "*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", + i, i + ); send_command(&mut stream, &cmd).await; } - + // Test HSCAN let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await; assert!(response.contains("field")); @@ -338,42 +377,50 @@ async fn test_hscan_operations() { #[tokio::test] async fn test_transaction_operations() { let (mut server, port) = start_test_server("transaction").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test MULTI let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await; assert!(response.contains("OK")); - + // Test queued commands - let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n", + ) + .await; assert!(response.contains("QUEUED")); - - let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await; + + let response = send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n", + ) + .await; assert!(response.contains("QUEUED")); - + // Test EXEC let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await; assert!(response.contains("OK")); // Should contain results of executed commands - + // Verify commands were executed let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await; assert!(response.contains("value1")); - + let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await; assert!(response.contains("value2")); } @@ -381,35 +428,39 @@ async fn test_transaction_operations() { #[tokio::test] async fn test_discard_transaction() { let (mut server, port) = start_test_server("discard").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test MULTI let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await; assert!(response.contains("OK")); - + // Test queued command - let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n", + ) + .await; assert!(response.contains("QUEUED")); - + // Test DISCARD let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await; assert!(response.contains("OK")); - + // Verify command was not executed let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await; assert!(response.contains("$-1")); // Should be NULL @@ -418,33 +469,41 @@ async fn test_discard_transaction() { #[tokio::test] async fn test_type_command() { let (mut server, port) = start_test_server("type").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test string type - send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; + send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n", + ) + .await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await; assert!(response.contains("string")); - + // Test hash type - send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; + send_command( + &mut stream, + "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n", + ) + .await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await; assert!(response.contains("hash")); - + // Test non-existent key let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await; assert!(response.contains("none")); @@ -453,30 +512,38 @@ async fn test_type_command() { #[tokio::test] async fn test_config_commands() { let (mut server, port) = start_test_server("config").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test CONFIG GET databases - let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n", + ) + .await; assert!(response.contains("databases")); assert!(response.contains("16")); - + // Test CONFIG GET dir - let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n", + ) + .await; assert!(response.contains("dir")); assert!(response.contains("/tmp/herodb_test_config")); } @@ -484,27 +551,27 @@ async fn test_config_commands() { #[tokio::test] async fn test_info_command() { let (mut server, port) = start_test_server("info").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test INFO let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await; assert!(response.contains("redis_version")); - + // Test INFO replication let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await; assert!(response.contains("role:master")); @@ -513,36 +580,44 @@ async fn test_info_command() { #[tokio::test] async fn test_error_handling() { let (mut server, port) = start_test_server("error").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test WRONGTYPE error - try to use hash command on string - send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; - let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await; + send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n", + ) + .await; + let response = send_command( + &mut stream, + "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n", + ) + .await; assert!(response.contains("WRONGTYPE")); - + // Test unknown command let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await; assert!(response.contains("unknown cmd") || response.contains("ERR")); - + // Test EXEC without MULTI let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await; assert!(response.contains("ERR")); - + // Test DISCARD without MULTI let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await; assert!(response.contains("ERR")); @@ -551,29 +626,37 @@ async fn test_error_handling() { #[tokio::test] async fn test_list_operations() { let (mut server, port) = start_test_server("list").await; - + tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(100)).await; - + let mut stream = connect_to_server(port).await; - + // Test LPUSH - let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await; + let response = send_command( + &mut stream, + "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n", + ) + .await; assert!(response.contains("2")); // 2 elements - + // Test RPUSH - let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await; + let response = send_command( + &mut stream, + "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n", + ) + .await; assert!(response.contains("4")); // 4 elements // Test LLEN @@ -581,29 +664,52 @@ async fn test_list_operations() { assert!(response.contains("4")); // Test LRANGE - let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await; - assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"); - + let response = send_command( + &mut stream, + "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n", + ) + .await; + assert_eq!( + response, + "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n" + ); + // Test LINDEX - let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n", + ) + .await; assert_eq!(response, "$1\r\nb\r\n"); - + // Test LPOP let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await; assert_eq!(response, "$1\r\nb\r\n"); - + // Test RPOP let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await; assert_eq!(response, "$1\r\nd\r\n"); // Test LREM - send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a - let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await; + send_command( + &mut stream, + "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n", + ) + .await; // list is now a, c, a + let response = send_command( + &mut stream, + "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n", + ) + .await; assert!(response.contains("1")); // Test LTRIM - let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await; + let response = send_command( + &mut stream, + "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n", + ) + .await; assert!(response.contains("OK")); let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await; assert!(response.contains("1")); -} \ No newline at end of file +} diff --git a/tests/simple_integration_test.rs b/tests/simple_integration_test.rs index 42269df..d1704e3 100644 --- a/tests/simple_integration_test.rs +++ b/tests/simple_integration_test.rs @@ -1,23 +1,23 @@ -use herodb::{server::Server, options::DBOption}; +use herodb::{options::DBOption, server::Server}; use std::time::Duration; -use tokio::time::sleep; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; +use tokio::time::sleep; // Helper function to start a test server with clean data directory async fn start_test_server(test_name: &str) -> (Server, u16) { use std::sync::atomic::{AtomicU16, Ordering}; static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000); - + // Get a unique port for this test let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); - + let test_dir = format!("/tmp/herodb_test_{}", test_name); - + // Clean up any existing test data let _ = std::fs::remove_dir_all(&test_dir); std::fs::create_dir_all(&test_dir).unwrap(); - + let option = DBOption { dir: test_dir, port, @@ -26,16 +26,18 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { encryption_key: None, backend: herodb::options::BackendType::Redb, }; - + let server = Server::new(option).await; (server, port) } // Helper function to send Redis command and get response async fn send_redis_command(port: u16, command: &str) -> String { - let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); stream.write_all(command.as_bytes()).await.unwrap(); - + let mut buffer = [0; 1024]; let n = stream.read(&mut buffer).await.unwrap(); String::from_utf8_lossy(&buffer[..n]).to_string() @@ -44,13 +46,13 @@ async fn send_redis_command(port: u16, command: &str) -> String { #[tokio::test] async fn test_basic_redis_functionality() { let (mut server, port) = start_test_server("basic").await; - + // Start server in background with timeout let server_handle = tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + // Accept only a few connections for testing for _ in 0..10 { if let Ok((stream, _)) = listener.accept().await { @@ -58,68 +60,79 @@ async fn test_basic_redis_functionality() { } } }); - + sleep(Duration::from_millis(100)).await; - + // Test PING let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await; assert!(response.contains("PONG")); - + // Test SET - let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; + let response = + send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; assert!(response.contains("OK")); - + // Test GET let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await; assert!(response.contains("value")); - + // Test HSET - let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; + let response = send_redis_command( + port, + "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n", + ) + .await; assert!(response.contains("1")); - + // Test HGET - let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await; + let response = + send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await; assert!(response.contains("value")); - + // Test EXISTS let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await; assert!(response.contains("1")); - + // Test TTL let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await; assert!(response.contains("-1")); // No expiration - + // Test TYPE let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await; assert!(response.contains("string")); - + // Test QUIT to close connection gracefully - let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); - stream.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes()).await.unwrap(); + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); + stream + .write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes()) + .await + .unwrap(); let mut buffer = [0; 1024]; let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); assert!(response.contains("OK")); - + // Ensure the stream is closed stream.shutdown().await.unwrap(); // Stop the server server_handle.abort(); - + println!("✅ All basic Redis functionality tests passed!"); } #[tokio::test] async fn test_hash_operations() { let (mut server, port) = start_test_server("hash_ops").await; - + // Start server in background with timeout let server_handle = tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + // Accept only a few connections for testing for _ in 0..5 { if let Ok((stream, _)) = listener.accept().await { @@ -127,53 +140,57 @@ async fn test_hash_operations() { } } }); - + sleep(Duration::from_millis(100)).await; - + // Test HSET multiple fields let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await; assert!(response.contains("2")); // 2 new fields - + // Test HGETALL let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await; assert!(response.contains("field1")); assert!(response.contains("value1")); assert!(response.contains("field2")); assert!(response.contains("value2")); - + // Test HEXISTS - let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; + let response = send_redis_command( + port, + "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n", + ) + .await; assert!(response.contains("1")); - + // Test HLEN let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await; assert!(response.contains("2")); - + // Test HSCAN let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await; assert!(response.contains("field1")); assert!(response.contains("value1")); assert!(response.contains("field2")); assert!(response.contains("value2")); - + // Stop the server // For hash operations, we don't have a persistent stream, so we'll just abort the server. // The server should handle closing its connections. server_handle.abort(); - + println!("✅ All hash operations tests passed!"); } #[tokio::test] async fn test_transaction_operations() { let (mut server, port) = start_test_server("transactions").await; - + // Start server in background with timeout let server_handle = tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + // Accept only a few connections for testing for _ in 0..5 { if let Ok((stream, _)) = listener.accept().await { @@ -181,49 +198,69 @@ async fn test_transaction_operations() { } } }); - + sleep(Duration::from_millis(100)).await; - + // Use a single connection for the transaction - let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); - + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); + // Test MULTI - stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap(); + stream + .write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()) + .await + .unwrap(); let mut buffer = [0; 1024]; let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); assert!(response.contains("OK")); - + // Test queued commands - stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap(); + stream + .write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()) + .await + .unwrap(); let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); assert!(response.contains("QUEUED")); - - stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap(); + + stream + .write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()) + .await + .unwrap(); let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); assert!(response.contains("QUEUED")); - + // Test EXEC - stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap(); + stream + .write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()) + .await + .unwrap(); let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); assert!(response.contains("OK")); // Should contain array of OK responses - + // Verify commands were executed - stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes()).await.unwrap(); + stream + .write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes()) + .await + .unwrap(); let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); assert!(response.contains("value1")); - - stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes()).await.unwrap(); + + stream + .write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes()) + .await + .unwrap(); let n = stream.read(&mut buffer).await.unwrap(); let response = String::from_utf8_lossy(&buffer[..n]); assert!(response.contains("value2")); // Stop the server server_handle.abort(); - + println!("✅ All transaction operations tests passed!"); -} \ No newline at end of file +} diff --git a/tests/simple_redis_test.rs b/tests/simple_redis_test.rs index 8afb304..bf9aee7 100644 --- a/tests/simple_redis_test.rs +++ b/tests/simple_redis_test.rs @@ -1,4 +1,4 @@ -use herodb::{server::Server, options::DBOption}; +use herodb::{options::DBOption, server::Server}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -8,14 +8,14 @@ use tokio::time::sleep; async fn start_test_server(test_name: &str) -> (Server, u16) { use std::sync::atomic::{AtomicU16, Ordering}; static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500); - + let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let test_dir = format!("/tmp/herodb_simple_test_{}", test_name); - + // Clean up any existing test data let _ = std::fs::remove_dir_all(&test_dir); std::fs::create_dir_all(&test_dir).unwrap(); - + let option = DBOption { dir: test_dir, port, @@ -24,7 +24,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { encryption_key: None, backend: herodb::options::BackendType::Redb, }; - + let server = Server::new(option).await; (server, port) } @@ -32,7 +32,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) { // Helper function to send command and get response async fn send_command(stream: &mut TcpStream, command: &str) -> String { stream.write_all(command.as_bytes()).await.unwrap(); - + let mut buffer = [0; 1024]; let n = stream.read(&mut buffer).await.unwrap(); String::from_utf8_lossy(&buffer[..n]).to_string() @@ -56,22 +56,22 @@ async fn connect_to_server(port: u16) -> TcpStream { #[tokio::test] async fn test_basic_ping_simple() { let (mut server, port) = start_test_server("ping").await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(200)).await; - + let mut stream = connect_to_server(port).await; let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await; assert!(response.contains("PONG")); @@ -80,31 +80,43 @@ async fn test_basic_ping_simple() { #[tokio::test] async fn test_hset_clean_db() { let (mut server, port) = start_test_server("hset_clean").await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(200)).await; - + let mut stream = connect_to_server(port).await; - + // Test HSET - should return 1 for new field - let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; + let response = send_command( + &mut stream, + "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n", + ) + .await; println!("HSET response: {}", response); - assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response); - + assert!( + response.contains("1"), + "Expected HSET to return 1, got: {}", + response + ); + // Test HGET - let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n", + ) + .await; println!("HGET response: {}", response); assert!(response.contains("value1")); } @@ -112,73 +124,101 @@ async fn test_hset_clean_db() { #[tokio::test] async fn test_type_command_simple() { let (mut server, port) = start_test_server("type").await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(200)).await; - + let mut stream = connect_to_server(port).await; - + // Test string type - send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; + send_command( + &mut stream, + "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n", + ) + .await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await; println!("TYPE string response: {}", response); assert!(response.contains("string")); - + // Test hash type - send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; + send_command( + &mut stream, + "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n", + ) + .await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await; println!("TYPE hash response: {}", response); assert!(response.contains("hash")); - + // Test non-existent key let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await; println!("TYPE noexist response: {}", response); - assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response); + assert!( + response.contains("none"), + "Expected 'none' for non-existent key, got: {}", + response + ); } #[tokio::test] async fn test_hexists_simple() { let (mut server, port) = start_test_server("hexists").await; - + // Start server in background tokio::spawn(async move { let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await .unwrap(); - + loop { if let Ok((stream, _)) = listener.accept().await { let _ = server.handle(stream).await; } } }); - + sleep(Duration::from_millis(200)).await; - + let mut stream = connect_to_server(port).await; - + // Set up hash - send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; - + send_command( + &mut stream, + "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n", + ) + .await; + // Test HEXISTS for existing field - let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n", + ) + .await; println!("HEXISTS existing field response: {}", response); assert!(response.contains("1")); - + // Test HEXISTS for non-existent field - let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; + let response = send_command( + &mut stream, + "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n", + ) + .await; println!("HEXISTS non-existent field response: {}", response); - assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response); -} \ No newline at end of file + assert!( + response.contains("0"), + "Expected HEXISTS to return 0 for non-existent field, got: {}", + response + ); +} diff --git a/tests/usage_suite.rs b/tests/usage_suite.rs index d7298cc..6874edb 100644 --- a/tests/usage_suite.rs +++ b/tests/usage_suite.rs @@ -325,7 +325,11 @@ async fn test_03_scan_and_keys() { let mut s = connect(port).await; for i in 0..5 { - let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await; + let _ = send_cmd( + &mut s, + &["SET", &format!("key{}", i), &format!("value{}", i)], + ) + .await; } let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await; @@ -358,7 +362,11 @@ async fn test_04_hashes_suite() { assert_contains(&h2, "2", "HSET added 2 new fields"); // HMGET - let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await; + let hmg = send_cmd( + &mut s, + &["HMGET", "profile:1", "name", "age", "city", "nope"], + ) + .await; assert_contains(&hmg, "alice", "HMGET name"); assert_contains(&hmg, "30", "HMGET age"); assert_contains(&hmg, "paris", "HMGET city"); @@ -392,7 +400,11 @@ async fn test_04_hashes_suite() { assert_contains(&hnx1, "1", "HSETNX new field -> 1"); // HSCAN - let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await; + let hscan = send_cmd( + &mut s, + &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"], + ) + .await; assert_contains(&hscan, "name", "HSCAN matches fields starting with n"); assert_contains(&hscan, "nickname", "HSCAN nickname present"); @@ -424,13 +436,21 @@ async fn test_05_lists_suite_including_blpop() { assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b"); let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; - assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]"); + assert_eq_resp( + &lr, + "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", + "LRANGE q:jobs 0 -1 should be [b,a,c]", + ); // LTRIM let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await; assert_contains(<rim, "OK", "LTRIM OK"); let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; - assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]"); + assert_eq_resp( + &lr_post, + "*2\r\n$1\r\nb\r\n$1\r\na\r\n", + "After LTRIM, list [b,a]", + ); // LREM remove first occurrence of b let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await; @@ -444,7 +464,11 @@ async fn test_05_lists_suite_including_blpop() { // LPOP with count on empty -> [] let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await; - assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array"); + assert_eq_resp( + &lpop0, + "*0\r\n", + "LPOP with count on empty returns empty array", + ); // BLPOP: block on one client, push from another let c1 = connect(port).await; @@ -513,7 +537,7 @@ async fn test_07_age_stateless_suite() { // naive parse for tests let mut lines = resp.lines(); let _ = lines.next(); // *2 - // $len + // $len let _ = lines.next(); let recip = lines.next().unwrap_or("").to_string(); let _ = lines.next(); @@ -548,8 +572,16 @@ async fn test_07_age_stateless_suite() { let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await; assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature"); - let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await; - assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature"); + let v_bad = send_cmd( + &mut s, + &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64], + ) + .await; + assert_contains( + &v_bad, + "0", + "VERIFY should be 0 for invalid message/signature", + ); } #[tokio::test] @@ -581,7 +613,7 @@ async fn test_08_age_persistent_named_suite() { skg ); - let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"] ).await; + let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"]).await; let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME"); let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await; assert_contains(&v1, "1", "VERIFYNAME valid => 1"); @@ -597,60 +629,69 @@ async fn test_08_age_persistent_named_suite() { #[tokio::test] async fn test_10_expire_pexpire_persist() { - let (server, port) = start_test_server("expire_suite").await; - spawn_listener(server, port).await; - sleep(Duration::from_millis(150)).await; + let (server, port) = start_test_server("expire_suite").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; - let mut s = connect(port).await; + let mut s = connect(port).await; - // EXPIRE: seconds - let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await; - let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await; - assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)"); - let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await; - assert!( - ttl1.contains("1") || ttl1.contains("0"), - "TTL exp:s should be 1 or 0, got: {}", - ttl1 - ); - sleep(Duration::from_millis(1100)).await; - let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await; - assert_contains(&get_after, "$-1", "GET after expiry should be Null"); - let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await; - assert_contains(&ttl_after, "-2", "TTL after expiry -> -2"); - let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await; - assert_contains(&exists_after, "0", "EXISTS after expiry -> 0"); + // EXPIRE: seconds + let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await; + let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await; + assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)"); + let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await; + assert!( + ttl1.contains("1") || ttl1.contains("0"), + "TTL exp:s should be 1 or 0, got: {}", + ttl1 + ); + sleep(Duration::from_millis(1100)).await; + let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await; + assert_contains(&get_after, "$-1", "GET after expiry should be Null"); + let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await; + assert_contains(&ttl_after, "-2", "TTL after expiry -> -2"); + let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await; + assert_contains(&exists_after, "0", "EXISTS after expiry -> 0"); - // PEXPIRE: milliseconds - let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await; - let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await; - assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)"); - let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await; - assert!( - ttl_ms1.contains("1") || ttl_ms1.contains("0"), - "TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}", - ttl_ms1 - ); - sleep(Duration::from_millis(1600)).await; - let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await; - assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0"); + // PEXPIRE: milliseconds + let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await; + let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await; + assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)"); + let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await; + assert!( + ttl_ms1.contains("1") || ttl_ms1.contains("0"), + "TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}", + ttl_ms1 + ); + sleep(Duration::from_millis(1600)).await; + let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await; + assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0"); - // PERSIST: remove expiration - let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await; - let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await; - let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await; - assert!( - ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"), - "TTL exp:persist should be >=0 before persist, got: {}", - ttl_pre - ); - let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; - assert_contains(&persist1, "1", "PERSIST should remove expiration"); - let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await; - assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)"); - // Second persist should return 0 (nothing to remove) - let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; - assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); + // PERSIST: remove expiration + let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await; + let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await; + let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await; + assert!( + ttl_pre.contains("5") + || ttl_pre.contains("4") + || ttl_pre.contains("3") + || ttl_pre.contains("2") + || ttl_pre.contains("1") + || ttl_pre.contains("0"), + "TTL exp:persist should be >=0 before persist, got: {}", + ttl_pre + ); + let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; + assert_contains(&persist1, "1", "PERSIST should remove expiration"); + let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await; + assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)"); + // Second persist should return 0 (nothing to remove) + let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; + assert_contains( + &persist2, + "0", + "PERSIST again -> 0 (no expiration to remove)", + ); } #[tokio::test] @@ -663,7 +704,11 @@ async fn test_11_set_with_options() { // SET with GET on non-existing key -> returns Null, sets value let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await; - assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist"); + assert_contains( + &set_get1, + "$-1", + "SET s1 v1 GET returns Null when key didn't exist", + ); let g1 = send_cmd(&mut s, &["GET", "s1"]).await; assert_contains(&g1, "v1", "GET s1 after first SET"); @@ -707,42 +752,42 @@ async fn test_11_set_with_options() { #[tokio::test] async fn test_09_mget_mset_and_variadic_exists_del() { - let (server, port) = start_test_server("mget_mset_variadic").await; - spawn_listener(server, port).await; - sleep(Duration::from_millis(150)).await; + let (server, port) = start_test_server("mget_mset_variadic").await; + spawn_listener(server, port).await; + sleep(Duration::from_millis(150)).await; - let mut s = connect(port).await; + let mut s = connect(port).await; - // MSET multiple keys - let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await; - assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK"); + // MSET multiple keys + let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await; + assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK"); - // MGET should return values and Null for missing - let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await; - // Expect an array with 4 entries; verify payloads - assert_contains(&mget, "v1", "MGET k1"); - assert_contains(&mget, "v2", "MGET k2"); - assert_contains(&mget, "v3", "MGET k3"); - assert_contains(&mget, "$-1", "MGET missing returns Null"); + // MGET should return values and Null for missing + let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await; + // Expect an array with 4 entries; verify payloads + assert_contains(&mget, "v1", "MGET k1"); + assert_contains(&mget, "v2", "MGET k2"); + assert_contains(&mget, "v3", "MGET k3"); + assert_contains(&mget, "$-1", "MGET missing returns Null"); - // EXISTS variadic: count how many exist - let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await; - // Server returns SimpleString numeric, e.g. +2 - assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2"); + // EXISTS variadic: count how many exist + let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await; + // Server returns SimpleString numeric, e.g. +2 + assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2"); - // DEL variadic: delete multiple keys, return count deleted - let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await; - assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2"); + // DEL variadic: delete multiple keys, return count deleted + let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await; + assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2"); - // Verify deletion - let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await; - assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0"); + // Verify deletion + let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await; + assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0"); - // MGET after deletion should include Nulls for deleted keys - let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await; - assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); - assert_contains(&mget_after, "v2", "MGET k2 remains"); - assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); + // MGET after deletion should include Nulls for deleted keys + let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await; + assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); + assert_contains(&mget_after, "v2", "MGET k2 remains"); + assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); } #[tokio::test] async fn test_12_hash_incr() { @@ -862,9 +907,16 @@ async fn test_14_expireat_pexpireat() { let mut s = connect(port).await; // EXPIREAT: seconds since epoch - let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; + let now_secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await; - let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await; + let exat = send_cmd( + &mut s, + &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)], + ) + .await; assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)"); let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await; assert!( @@ -874,12 +926,23 @@ async fn test_14_expireat_pexpireat() { ); sleep(Duration::from_millis(1200)).await; let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await; - assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0"); + assert_contains( + &exists_after_exat, + "0", + "EXISTS exp:at:s after EXPIREAT expiry -> 0", + ); // PEXPIREAT: milliseconds since epoch - let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64; + let now_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as i64; let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await; - let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await; + let pexat = send_cmd( + &mut s, + &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)], + ) + .await; assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)"); let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await; assert!( @@ -889,5 +952,9 @@ async fn test_14_expireat_pexpireat() { ); sleep(Duration::from_millis(600)).await; let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await; - assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0"); -} \ No newline at end of file + assert_contains( + &exists_after_pexat, + "0", + "EXISTS exp:at:ms after PEXPIREAT expiry -> 0", + ); +}