Format rust files

Signed-off-by: Lee Smet <lee.smet@hotmail.com>
This commit is contained in:
Lee Smet
2025-08-25 11:16:25 +02:00
parent e9675aafed
commit ff0659b933
25 changed files with 2267 additions and 1318 deletions

View File

@@ -14,25 +14,31 @@ fn read_reply(s: &mut TcpStream) -> String {
let n = s.read(&mut buf).unwrap(); let n = s.read(&mut buf).unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
} }
fn parse_two_bulk(reply: &str) -> Option<(String,String)> { fn parse_two_bulk(reply: &str) -> Option<(String, String)> {
let mut lines = reply.split("\r\n"); let mut lines = reply.split("\r\n");
if lines.next()? != "*2" { return None; } if lines.next()? != "*2" {
return None;
}
let _n = lines.next()?; let _n = lines.next()?;
let a = lines.next()?.to_string(); let a = lines.next()?.to_string();
let _m = lines.next()?; let _m = lines.next()?;
let b = lines.next()?.to_string(); let b = lines.next()?.to_string();
Some((a,b)) Some((a, b))
} }
fn parse_bulk(reply: &str) -> Option<String> { fn parse_bulk(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n"); let mut lines = reply.split("\r\n");
let hdr = lines.next()?; let hdr = lines.next()?;
if !hdr.starts_with('$') { return None; } if !hdr.starts_with('$') {
return None;
}
Some(lines.next()?.to_string()) Some(lines.next()?.to_string())
} }
fn parse_simple(reply: &str) -> Option<String> { fn parse_simple(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n"); let mut lines = reply.split("\r\n");
let hdr = lines.next()?; let hdr = lines.next()?;
if !hdr.starts_with('+') { return None; } if !hdr.starts_with('+') {
return None;
}
Some(hdr[1..].to_string()) Some(hdr[1..].to_string())
} }
@@ -45,39 +51,45 @@ fn main() {
let mut s = TcpStream::connect(addr).expect("connect"); let mut s = TcpStream::connect(addr).expect("connect");
// Generate & persist X25519 enc keys under name "alice" // Generate & persist X25519 enc keys under name "alice"
s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap(); s.write_all(arr(&["age", "keygen", "alice"]).as_bytes())
.unwrap();
let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc"); let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc");
// Generate & persist Ed25519 signing key under name "signer" // Generate & persist Ed25519 signing key under name "signer"
s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap(); s.write_all(arr(&["age", "signkeygen", "signer"]).as_bytes())
.unwrap();
let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign"); let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign");
// Encrypt by name // Encrypt by name
let msg = "hello from persistent keys"; let msg = "hello from persistent keys";
s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap(); s.write_all(arr(&["age", "encryptname", "alice", msg]).as_bytes())
.unwrap();
let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64"); let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64");
println!("ciphertext b64: {}", ct_b64); println!("ciphertext b64: {}", ct_b64);
// Decrypt by name // Decrypt by name
s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap(); s.write_all(arr(&["age", "decryptname", "alice", &ct_b64]).as_bytes())
.unwrap();
let pt = parse_bulk(&read_reply(&mut s)).expect("pt"); let pt = parse_bulk(&read_reply(&mut s)).expect("pt");
assert_eq!(pt, msg); assert_eq!(pt, msg);
println!("decrypted ok"); println!("decrypted ok");
// Sign by name // Sign by name
s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap(); s.write_all(arr(&["age", "signname", "signer", msg]).as_bytes())
.unwrap();
let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64"); let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64");
// Verify by name // Verify by name
s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap(); s.write_all(arr(&["age", "verifyname", "signer", msg, &sig_b64]).as_bytes())
.unwrap();
let ok = parse_simple(&read_reply(&mut s)).expect("verify"); let ok = parse_simple(&read_reply(&mut s)).expect("verify");
assert_eq!(ok, "1"); assert_eq!(ok, "1");
println!("signature verified"); println!("signature verified");
// List names // List names
s.write_all(arr(&["age","list"]).as_bytes()).unwrap(); s.write_all(arr(&["age", "list"]).as_bytes()).unwrap();
let list = read_reply(&mut s); let list = read_reply(&mut s);
println!("LIST -> {list}"); println!("LIST -> {list}");
println!("✔ persistent AGE workflow complete."); println!("✔ persistent AGE workflow complete.");
} }

View File

@@ -12,17 +12,17 @@
use std::str::FromStr; use std::str::FromStr;
use secrecy::ExposeSecret;
use age::{Decryptor, Encryptor};
use age::x25519; use age::x25519;
use age::{Decryptor, Encryptor};
use secrecy::ExposeSecret;
use ed25519_dalek::{Signature, Signer, Verifier, SigningKey, VerifyingKey}; use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use base64::{engine::general_purpose::STANDARD as B64, Engine as _}; use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use crate::error::DBError;
use crate::protocol::Protocol; use crate::protocol::Protocol;
use crate::server::Server; use crate::server::Server;
use crate::error::DBError;
// ---------- Internal helpers ---------- // ---------- Internal helpers ----------
@@ -32,7 +32,7 @@ pub enum AgeWireError {
Crypto(String), Crypto(String),
Utf8, Utf8,
SignatureLen, SignatureLen,
NotFound(&'static str), // which kind of key was missing NotFound(&'static str), // which kind of key was missing
Storage(String), Storage(String),
} }
@@ -83,34 +83,38 @@ pub fn gen_enc_keypair() -> (String, String) {
} }
pub fn gen_sign_keypair() -> (String, String) { pub fn gen_sign_keypair() -> (String, String) {
use rand::RngCore;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rand::RngCore;
// Generate random 32 bytes for the signing key // Generate random 32 bytes for the signing key
let mut secret_bytes = [0u8; 32]; let mut secret_bytes = [0u8; 32];
OsRng.fill_bytes(&mut secret_bytes); OsRng.fill_bytes(&mut secret_bytes);
let signing_key = SigningKey::from_bytes(&secret_bytes); let signing_key = SigningKey::from_bytes(&secret_bytes);
let verifying_key = signing_key.verifying_key(); let verifying_key = signing_key.verifying_key();
// Encode as base64 for storage // Encode as base64 for storage
let signing_key_b64 = B64.encode(signing_key.to_bytes()); let signing_key_b64 = B64.encode(signing_key.to_bytes());
let verifying_key_b64 = B64.encode(verifying_key.to_bytes()); let verifying_key_b64 = B64.encode(verifying_key.to_bytes());
(verifying_key_b64, signing_key_b64) // (verify_pub, signing_secret) (verifying_key_b64, signing_key_b64) // (verify_pub, signing_secret)
} }
/// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext). /// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext).
pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> { pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> {
let recipient = parse_recipient(recipient_str)?; let recipient = parse_recipient(recipient_str)?;
let enc = Encryptor::with_recipients(vec![Box::new(recipient)]) let enc =
.expect("failed to create encryptor"); // Handle Option<Encryptor> Encryptor::with_recipients(vec![Box::new(recipient)]).expect("failed to create encryptor"); // Handle Option<Encryptor>
let mut out = Vec::new(); let mut out = Vec::new();
{ {
use std::io::Write; use std::io::Write;
let mut w = enc.wrap_output(&mut out).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let mut w = enc
w.write_all(msg.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; .wrap_output(&mut out)
w.finish().map_err(|e| AgeWireError::Crypto(e.to_string()))?; .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.write_all(msg.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.finish()
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
} }
Ok(B64.encode(out)) Ok(B64.encode(out))
} }
@@ -118,19 +122,27 @@ pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireErro
/// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String. /// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String.
pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> { pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> {
let id = parse_identity(identity_str)?; let id = parse_identity(identity_str)?;
let ct = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let ct = B64
.decode(ct_b64.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
// The decrypt method returns a Result<StreamReader, DecryptError> // The decrypt method returns a Result<StreamReader, DecryptError>
let mut r = match dec { let mut r = match dec {
Decryptor::Recipients(d) => d.decrypt(std::iter::once(&id as &dyn age::Identity)) Decryptor::Recipients(d) => d
.decrypt(std::iter::once(&id as &dyn age::Identity))
.map_err(|e| AgeWireError::Crypto(e.to_string()))?, .map_err(|e| AgeWireError::Crypto(e.to_string()))?,
Decryptor::Passphrase(_) => return Err(AgeWireError::Crypto("Expected recipients, got passphrase".to_string())), Decryptor::Passphrase(_) => {
return Err(AgeWireError::Crypto(
"Expected recipients, got passphrase".to_string(),
))
}
}; };
let mut pt = Vec::new(); let mut pt = Vec::new();
use std::io::Read; use std::io::Read;
r.read_to_end(&mut pt).map_err(|e| AgeWireError::Crypto(e.to_string()))?; r.read_to_end(&mut pt)
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
String::from_utf8(pt).map_err(|_| AgeWireError::Utf8) String::from_utf8(pt).map_err(|_| AgeWireError::Utf8)
} }
@@ -144,7 +156,9 @@ pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AgeWireEr
/// Verify detached signature (base64) for `msg` with pubkey. /// Verify detached signature (base64) for `msg` with pubkey.
pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> { pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> {
let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?; let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
let sig_bytes = B64.decode(sig_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let sig_bytes = B64
.decode(sig_b64.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
if sig_bytes.len() != 64 { if sig_bytes.len() != 64 {
return Err(AgeWireError::SignatureLen); return Err(AgeWireError::SignatureLen);
} }
@@ -155,30 +169,49 @@ pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool
// ---------- Storage helpers ---------- // ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> { fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?; let st = server
.current_storage()
.map_err(|e| AgeWireError::Storage(e.0))?;
st.get(key).map_err(|e| AgeWireError::Storage(e.0)) st.get(key).map_err(|e| AgeWireError::Storage(e.0))
} }
fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> { fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?; let st = server
st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0)) .current_storage()
.map_err(|e| AgeWireError::Storage(e.0))?;
st.set(key.to_string(), val.to_string())
.map_err(|e| AgeWireError::Storage(e.0))
} }
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") } fn enc_pub_key_key(name: &str) -> String {
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") } format!("age:key:{name}")
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") } }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") } fn enc_priv_key_key(name: &str) -> String {
format!("age:privkey:{name}")
}
fn sign_pub_key_key(name: &str) -> String {
format!("age:signpub:{name}")
}
fn sign_priv_key_key(name: &str) -> String {
format!("age:signpriv:{name}")
}
// ---------- Command handlers (RESP Protocol) ---------- // ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness // Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol { pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = gen_enc_keypair(); let (recip, ident) = gen_enc_keypair();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)]) Protocol::Array(vec![
Protocol::BulkString(recip),
Protocol::BulkString(ident),
])
} }
pub async fn cmd_age_gensign() -> Protocol { pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = gen_sign_keypair(); let (verify, secret) = gen_sign_keypair();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)]) Protocol::Array(vec![
Protocol::BulkString(verify),
Protocol::BulkString(secret),
])
} }
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol { pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
@@ -214,16 +247,30 @@ pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> P
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol { pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = gen_enc_keypair(); let (recip, ident) = gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); } if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) {
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); } return e.to_protocol();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)]) }
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) {
return e.to_protocol();
}
Protocol::Array(vec![
Protocol::BulkString(recip),
Protocol::BulkString(ident),
])
} }
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol { pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = gen_sign_keypair(); let (verify, secret) = gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); } if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) {
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); } return e.to_protocol();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)]) }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) {
return e.to_protocol();
}
Protocol::Array(vec![
Protocol::BulkString(verify),
Protocol::BulkString(secret),
])
} }
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol { pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
@@ -253,7 +300,9 @@ pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) ->
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol { pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) { let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v, Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(), Ok(None) => {
return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol()
}
Err(e) => return e.to_protocol(), Err(e) => return e.to_protocol(),
}; };
match sign_b64(&sec, message) { match sign_b64(&sec, message) {
@@ -262,10 +311,17 @@ pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Pr
} }
} }
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol { pub async fn cmd_age_verify_name(
server: &Server,
name: &str,
message: &str,
sig_b64: &str,
) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) { let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v, Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(), Ok(None) => {
return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol()
}
Err(e) => return e.to_protocol(), Err(e) => return e.to_protocol(),
}; };
match verify_b64(&pubk, message, sig_b64) { match verify_b64(&pubk, message, sig_b64) {
@@ -277,25 +333,43 @@ pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig
pub async fn cmd_age_list(server: &Server) -> Protocol { pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...] // Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) }; let st = match server.current_storage() {
Ok(s) => s,
Err(e) => return Protocol::err(&e.0),
};
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> { let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?; let keys = st.keys(pat)?;
let mut names: Vec<String> = keys.into_iter() let mut names: Vec<String> = keys
.into_iter()
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string())) .filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
.collect(); .collect();
names.sort(); names.sort();
Ok(names) Ok(names)
}; };
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; let encpub = match pull("age:key:*", "age:key:") {
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; Ok(v) => v,
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; Err(e) => return Protocol::err(&e.0),
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; };
let encpriv = match pull("age:privkey:*", "age:privkey:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let signpub = match pull("age:signpub:*", "age:signpub:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let signpriv = match pull("age:signpriv:*", "age:signpriv:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let to_arr = |label: &str, v: Vec<String>| { let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())]; let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect())); out.push(Protocol::Array(
v.into_iter().map(Protocol::BulkString).collect(),
));
Protocol::Array(out) Protocol::Array(out)
}; };
@@ -305,4 +379,4 @@ pub async fn cmd_age_list(server: &Server) -> Protocol {
to_arr("signpub", signpub), to_arr("signpub", signpub),
to_arr("signpriv", signpriv), to_arr("signpriv", signpriv),
]) ])
} }

1116
src/cmd.rs

File diff suppressed because it is too large Load Diff

View File

@@ -11,9 +11,9 @@ const TAG_LEN: usize = 16;
#[derive(Debug)] #[derive(Debug)]
pub enum CryptoError { pub enum CryptoError {
Format, // wrong length / header Format, // wrong length / header
Version(u8), // unknown version Version(u8), // unknown version
Decrypt, // wrong key or corrupted data Decrypt, // wrong key or corrupted data
} }
impl From<CryptoError> for crate::error::DBError { impl From<CryptoError> for crate::error::DBError {
@@ -71,4 +71,4 @@ impl CryptoFactory {
let cipher = XChaCha20Poly1305::new(&self.key); let cipher = XChaCha20Poly1305::new(&self.key);
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt) cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
} }
} }

View File

@@ -1,9 +1,8 @@
use std::num::ParseIntError; use std::num::ParseIntError;
use tokio::sync::mpsc;
use redb;
use bincode; use bincode;
use redb;
use tokio::sync::mpsc;
// todo: more error types // todo: more error types
#[derive(Debug)] #[derive(Debug)]

View File

@@ -1,12 +1,12 @@
pub mod age; // NEW pub mod age; // NEW
pub mod cmd; pub mod cmd;
pub mod crypto; pub mod crypto;
pub mod error; pub mod error;
pub mod options; pub mod options;
pub mod protocol; pub mod protocol;
pub mod search_cmd; // Add this pub mod search_cmd; // Add this
pub mod server; pub mod server;
pub mod storage; pub mod storage;
pub mod storage_trait; // Add this pub mod storage_sled; // Add this
pub mod storage_sled; // Add this pub mod storage_trait; // Add this
pub mod tantivy_search; pub mod tantivy_search;

View File

@@ -22,7 +22,6 @@ struct Args {
#[arg(long)] #[arg(long)]
debug: bool, debug: bool,
/// Master encryption key for encrypted databases /// Master encryption key for encrypted databases
#[arg(long)] #[arg(long)]
encryption_key: Option<String>, encryption_key: Option<String>,

View File

@@ -81,18 +81,21 @@ impl Protocol {
pub fn encode(&self) -> String { pub fn encode(&self) -> String {
match self { match self {
Protocol::SimpleString(s) => format!("+{}\r\n", s), Protocol::SimpleString(s) => format!("+{}\r\n", s),
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s), Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => { Protocol::Array(ss) => {
format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>() format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
} }
Protocol::Null => "$-1\r\n".to_string(), Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
} }
} }
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> { fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
match protocol.find("\r\n") { match protocol.find("\r\n") {
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])), Some(x) => Ok((
Self::SimpleString(protocol[..x].to_string()),
&protocol[x + 2..],
)),
_ => Err(DBError(format!( _ => Err(DBError(format!(
"[new simple string] unsupported protocol: {:?}", "[new simple string] unsupported protocol: {:?}",
protocol protocol

View File

@@ -3,8 +3,7 @@ use crate::{
protocol::Protocol, protocol::Protocol,
server::Server, server::Server,
tantivy_search::{ tantivy_search::{
TantivySearch, FieldDef, NumericType, IndexConfig, FieldDef, Filter, FilterType, IndexConfig, NumericType, SearchOptions, TantivySearch,
SearchOptions, Filter, FilterType
}, },
}; };
use std::collections::HashMap; use std::collections::HashMap;
@@ -17,14 +16,14 @@ pub async fn ft_create_cmd(
) -> Result<Protocol, DBError> { ) -> Result<Protocol, DBError> {
// Parse schema into field definitions // Parse schema into field definitions
let mut field_definitions = Vec::new(); let mut field_definitions = Vec::new();
for (field_name, field_type, options) in schema { for (field_name, field_type, options) in schema {
let field_def = match field_type.to_uppercase().as_str() { let field_def = match field_type.to_uppercase().as_str() {
"TEXT" => { "TEXT" => {
let mut weight = 1.0; let mut weight = 1.0;
let mut sortable = false; let mut sortable = false;
let mut no_index = false; let mut no_index = false;
for opt in &options { for opt in &options {
match opt.to_uppercase().as_str() { match opt.to_uppercase().as_str() {
"WEIGHT" => { "WEIGHT" => {
@@ -40,7 +39,7 @@ pub async fn ft_create_cmd(
_ => {} _ => {}
} }
} }
FieldDef::Text { FieldDef::Text {
stored: true, stored: true,
indexed: !no_index, indexed: !no_index,
@@ -50,13 +49,13 @@ pub async fn ft_create_cmd(
} }
"NUMERIC" => { "NUMERIC" => {
let mut sortable = false; let mut sortable = false;
for opt in &options { for opt in &options {
if opt.to_uppercase() == "SORTABLE" { if opt.to_uppercase() == "SORTABLE" {
sortable = true; sortable = true;
} }
} }
FieldDef::Numeric { FieldDef::Numeric {
stored: true, stored: true,
indexed: true, indexed: true,
@@ -67,7 +66,7 @@ pub async fn ft_create_cmd(
"TAG" => { "TAG" => {
let mut separator = ",".to_string(); let mut separator = ",".to_string();
let mut case_sensitive = false; let mut case_sensitive = false;
for i in 0..options.len() { for i in 0..options.len() {
match options[i].to_uppercase().as_str() { match options[i].to_uppercase().as_str() {
"SEPARATOR" => { "SEPARATOR" => {
@@ -79,44 +78,45 @@ pub async fn ft_create_cmd(
_ => {} _ => {}
} }
} }
FieldDef::Tag { FieldDef::Tag {
stored: true, stored: true,
separator, separator,
case_sensitive, case_sensitive,
} }
} }
"GEO" => { "GEO" => FieldDef::Geo { stored: true },
FieldDef::Geo { stored: true }
}
_ => { _ => {
return Err(DBError(format!("Unknown field type: {}", field_type))); return Err(DBError(format!("Unknown field type: {}", field_type)));
} }
}; };
field_definitions.push((field_name, field_def)); field_definitions.push((field_name, field_def));
} }
// Create the search index // Create the search index
let search_path = server.search_index_path(); let search_path = server.search_index_path();
let config = IndexConfig::default(); let config = IndexConfig::default();
println!("Creating search index '{}' at path: {:?}", index_name, search_path); println!(
"Creating search index '{}' at path: {:?}",
index_name, search_path
);
println!("Field definitions: {:?}", field_definitions); println!("Field definitions: {:?}", field_definitions);
let search_index = TantivySearch::new_with_schema( let search_index = TantivySearch::new_with_schema(
search_path, search_path,
index_name.clone(), index_name.clone(),
field_definitions, field_definitions,
Some(config), Some(config),
)?; )?;
println!("Search index '{}' created successfully", index_name); println!("Search index '{}' created successfully", index_name);
// Store in registry // Store in registry
let mut indexes = server.search_indexes.write().unwrap(); let mut indexes = server.search_indexes.write().unwrap();
indexes.insert(index_name, Arc::new(search_index)); indexes.insert(index_name, Arc::new(search_index));
Ok(Protocol::SimpleString("OK".to_string())) Ok(Protocol::SimpleString("OK".to_string()))
} }
@@ -128,12 +128,13 @@ pub async fn ft_add_cmd(
fields: HashMap<String, String>, fields: HashMap<String, String>,
) -> Result<Protocol, DBError> { ) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let search_index = indexes.get(&index_name) let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
search_index.add_document_with_fields(&doc_id, fields)?; search_index.add_document_with_fields(&doc_id, fields)?;
Ok(Protocol::SimpleString("OK".to_string())) Ok(Protocol::SimpleString("OK".to_string()))
} }
@@ -147,18 +148,20 @@ pub async fn ft_search_cmd(
return_fields: Option<Vec<String>>, return_fields: Option<Vec<String>>,
) -> Result<Protocol, DBError> { ) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let search_index = indexes.get(&index_name) let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// Convert filters to search filters // Convert filters to search filters
let search_filters = filters.into_iter().map(|(field, value)| { let search_filters = filters
Filter { .into_iter()
.map(|(field, value)| Filter {
field, field,
filter_type: FilterType::Equals(value), filter_type: FilterType::Equals(value),
} })
}).collect(); .collect();
let options = SearchOptions { let options = SearchOptions {
limit: limit.unwrap_or(10), limit: limit.unwrap_or(10),
offset: offset.unwrap_or(0), offset: offset.unwrap_or(0),
@@ -167,27 +170,27 @@ pub async fn ft_search_cmd(
return_fields, return_fields,
highlight: false, highlight: false,
}; };
let results = search_index.search_with_options(&query, options)?; let results = search_index.search_with_options(&query, options)?;
// Format results as Redis protocol // Format results as Redis protocol
let mut response = Vec::new(); let mut response = Vec::new();
// First element is the total count // First element is the total count
response.push(Protocol::SimpleString(results.total.to_string())); response.push(Protocol::SimpleString(results.total.to_string()));
// Then each document // Then each document
for doc in results.documents { for doc in results.documents {
let mut doc_array = Vec::new(); let mut doc_array = Vec::new();
// Add document ID if it exists // Add document ID if it exists
if let Some(id) = doc.fields.get("_id") { if let Some(id) = doc.fields.get("_id") {
doc_array.push(Protocol::BulkString(id.clone())); doc_array.push(Protocol::BulkString(id.clone()));
} }
// Add score // Add score
doc_array.push(Protocol::BulkString(doc.score.to_string())); doc_array.push(Protocol::BulkString(doc.score.to_string()));
// Add fields as key-value pairs // Add fields as key-value pairs
for (field_name, field_value) in doc.fields { for (field_name, field_value) in doc.fields {
if field_name != "_id" { if field_name != "_id" {
@@ -195,10 +198,10 @@ pub async fn ft_search_cmd(
doc_array.push(Protocol::BulkString(field_value)); doc_array.push(Protocol::BulkString(field_value));
} }
} }
response.push(Protocol::Array(doc_array)); response.push(Protocol::Array(doc_array));
} }
Ok(Protocol::Array(response)) Ok(Protocol::Array(response))
} }
@@ -208,56 +211,54 @@ pub async fn ft_del_cmd(
doc_id: String, doc_id: String,
) -> Result<Protocol, DBError> { ) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let _search_index = indexes.get(&index_name) let _search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// For now, return success // For now, return success
// In a full implementation, we'd need to add a delete method to TantivySearch // In a full implementation, we'd need to add a delete method to TantivySearch
println!("Deleting document '{}' from index '{}'", doc_id, index_name); println!("Deleting document '{}' from index '{}'", doc_id, index_name);
Ok(Protocol::SimpleString("1".to_string())) Ok(Protocol::SimpleString("1".to_string()))
} }
pub async fn ft_info_cmd( pub async fn ft_info_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
server: &Server,
index_name: String,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let search_index = indexes.get(&index_name) let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
let info = search_index.get_info()?; let info = search_index.get_info()?;
// Format info as Redis protocol // Format info as Redis protocol
let mut response = Vec::new(); let mut response = Vec::new();
response.push(Protocol::BulkString("index_name".to_string())); response.push(Protocol::BulkString("index_name".to_string()));
response.push(Protocol::BulkString(info.name)); response.push(Protocol::BulkString(info.name));
response.push(Protocol::BulkString("num_docs".to_string())); response.push(Protocol::BulkString("num_docs".to_string()));
response.push(Protocol::BulkString(info.num_docs.to_string())); response.push(Protocol::BulkString(info.num_docs.to_string()));
response.push(Protocol::BulkString("num_fields".to_string())); response.push(Protocol::BulkString("num_fields".to_string()));
response.push(Protocol::BulkString(info.fields.len().to_string())); response.push(Protocol::BulkString(info.fields.len().to_string()));
response.push(Protocol::BulkString("fields".to_string())); response.push(Protocol::BulkString("fields".to_string()));
let fields_str = info.fields.iter() let fields_str = info
.fields
.iter()
.map(|f| format!("{}:{}", f.name, f.field_type)) .map(|f| format!("{}:{}", f.name, f.field_type))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "); .join(", ");
response.push(Protocol::BulkString(fields_str)); response.push(Protocol::BulkString(fields_str));
Ok(Protocol::Array(response)) Ok(Protocol::Array(response))
} }
pub async fn ft_drop_cmd( pub async fn ft_drop_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
server: &Server,
index_name: String,
) -> Result<Protocol, DBError> {
let mut indexes = server.search_indexes.write().unwrap(); let mut indexes = server.search_indexes.write().unwrap();
if indexes.remove(&index_name).is_some() { if indexes.remove(&index_name).is_some() {
// Also remove the index files from disk // Also remove the index files from disk
let index_path = server.search_index_path().join(&index_name); let index_path = server.search_index_path().join(&index_name);
@@ -269,4 +270,4 @@ pub async fn ft_drop_cmd(
} else { } else {
Err(DBError(format!("Index '{}' not found", index_name))) Err(DBError(format!("Index '{}' not found", index_name)))
} }
} }

View File

@@ -1,10 +1,10 @@
use core::str; use core::str;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::RwLock;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::sync::{Mutex, oneshot}; use tokio::sync::{oneshot, Mutex};
use std::sync::RwLock;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
@@ -60,51 +60,50 @@ impl Server {
pub fn current_storage(&self) -> Result<Arc<dyn StorageBackend>, DBError> { pub fn current_storage(&self) -> Result<Arc<dyn StorageBackend>, DBError> {
let mut cache = self.db_cache.write().unwrap(); let mut cache = self.db_cache.write().unwrap();
if let Some(storage) = cache.get(&self.selected_db) { if let Some(storage) = cache.get(&self.selected_db) {
return Ok(storage.clone()); return Ok(storage.clone());
} }
// Create new database file // Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone()) let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db)); .join(format!("{}.db", self.selected_db));
// Ensure the directory exists before creating the database file // Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() { if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| { std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e)) DBError(format!(
"Failed to create directory {}: {}",
parent_dir.display(),
e
))
})?; })?;
} }
println!("Creating new db file: {}", db_file_path.display()); println!("Creating new db file: {}", db_file_path.display());
let storage: Arc<dyn StorageBackend> = match self.option.backend { let storage: Arc<dyn StorageBackend> = match self.option.backend {
options::BackendType::Redb => { options::BackendType::Redb => Arc::new(Storage::new(
Arc::new(Storage::new( db_file_path,
db_file_path, self.should_encrypt_db(self.selected_db),
self.should_encrypt_db(self.selected_db), self.option.encryption_key.as_deref(),
self.option.encryption_key.as_deref() )?),
)?) options::BackendType::Sled => Arc::new(SledStorage::new(
} db_file_path,
options::BackendType::Sled => { self.should_encrypt_db(self.selected_db),
Arc::new(SledStorage::new( self.option.encryption_key.as_deref(),
db_file_path, )?),
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref()
)?)
}
}; };
cache.insert(self.selected_db, storage.clone()); cache.insert(self.selected_db, storage.clone());
Ok(storage) Ok(storage)
} }
fn should_encrypt_db(&self, db_index: u64) -> bool { fn should_encrypt_db(&self, db_index: u64) -> bool {
// DB 0-9 are non-encrypted, DB 10+ are encrypted // DB 0-9 are non-encrypted, DB 10+ are encrypted
self.option.encrypt && db_index >= 10 self.option.encrypt && db_index >= 10
} }
// Add method to get search index path // Add method to get search index path
pub fn search_index_path(&self) -> std::path::PathBuf { pub fn search_index_path(&self) -> std::path::PathBuf {
std::path::PathBuf::from(&self.option.dir).join("search_indexes") std::path::PathBuf::from(&self.option.dir).join("search_indexes")
@@ -112,7 +111,12 @@ impl Server {
// ----- BLPOP waiter helpers ----- // ----- BLPOP waiter helpers -----
pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) { pub async fn register_waiter(
&self,
db_index: u64,
key: &str,
side: PopSide,
) -> (u64, oneshot::Receiver<(String, String)>) {
let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel::<(String, String)>(); let (tx, rx) = oneshot::channel::<(String, String)>();
@@ -188,10 +192,7 @@ impl Server {
Ok(()) Ok(())
} }
pub async fn handle( pub async fn handle(&mut self, mut stream: tokio::net::TcpStream) -> Result<(), DBError> {
&mut self,
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
// Accumulate incoming bytes to handle partial RESP frames // Accumulate incoming bytes to handle partial RESP frames
let mut acc = String::new(); let mut acc = String::new();
let mut buf = vec![0u8; 8192]; let mut buf = vec![0u8; 8192];
@@ -228,7 +229,10 @@ impl Server {
acc = remaining.to_string(); acc = remaining.to_string();
if self.option.debug { if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); println!(
"\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m",
cmd, protocol
);
} else { } else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol); println!("got command: {:?}, protocol: {:?}", cmd, protocol);
} }

View File

@@ -12,9 +12,9 @@ use crate::error::DBError;
// Re-export modules // Re-export modules
mod storage_basic; mod storage_basic;
mod storage_extra;
mod storage_hset; mod storage_hset;
mod storage_lists; mod storage_lists;
mod storage_extra;
// Re-export implementations // Re-export implementations
// Note: These imports are used by the impl blocks in the submodules // Note: These imports are used by the impl blocks in the submodules
@@ -28,7 +28,8 @@ const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("string
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes"); const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> =
TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted"); const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration"); const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
@@ -55,9 +56,13 @@ pub struct Storage {
} }
impl Storage { impl Storage {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> { pub fn new(
path: impl AsRef<Path>,
should_encrypt: bool,
master_key: Option<&str>,
) -> Result<Self, DBError> {
let db = Database::create(path)?; let db = Database::create(path)?;
// Create tables if they don't exist // Create tables if they don't exist
let write_txn = db.begin_write()?; let write_txn = db.begin_write()?;
{ {
@@ -71,23 +76,28 @@ impl Storage {
let _ = write_txn.open_table(EXPIRATION_TABLE)?; let _ = write_txn.open_table(EXPIRATION_TABLE)?;
} }
write_txn.commit()?; write_txn.commit()?;
// Check if database was previously encrypted // Check if database was previously encrypted
let read_txn = db.begin_read()?; let read_txn = db.begin_read()?;
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?; let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false); let was_encrypted = encrypted_table
.get("encrypted")?
.map(|v| v.value() == 1)
.unwrap_or(false);
drop(read_txn); drop(read_txn);
let crypto = if should_encrypt || was_encrypted { let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key { if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes())) Some(CryptoFactory::new(key.as_bytes()))
} else { } else {
return Err(DBError("Encryption requested but no master key provided".to_string())); return Err(DBError(
"Encryption requested but no master key provided".to_string(),
));
} }
} else { } else {
None None
}; };
// If we're enabling encryption for the first time, mark it // If we're enabling encryption for the first time, mark it
if should_encrypt && !was_encrypted { if should_encrypt && !was_encrypted {
let write_txn = db.begin_write()?; let write_txn = db.begin_write()?;
@@ -97,13 +107,10 @@ impl Storage {
} }
write_txn.commit()?; write_txn.commit()?;
} }
Ok(Storage { Ok(Storage { db, crypto })
db,
crypto,
})
} }
pub fn is_encrypted(&self) -> bool { pub fn is_encrypted(&self) -> bool {
self.crypto.is_some() self.crypto.is_some()
} }
@@ -116,7 +123,7 @@ impl Storage {
Ok(data.to_vec()) Ok(data.to_vec())
} }
} }
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> { fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto { if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data)?) Ok(crypto.decrypt(data)?)
@@ -165,11 +172,22 @@ impl StorageBackend for Storage {
self.get_key_type(key) self.get_key_type(key)
} }
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
self.scan(cursor, pattern, count) self.scan(cursor, pattern, count)
} }
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
self.hscan(key, cursor, pattern, count) self.hscan(key, cursor, pattern, count)
} }
@@ -276,7 +294,7 @@ impl StorageBackend for Storage {
fn is_encrypted(&self) -> bool { fn is_encrypted(&self) -> bool {
self.is_encrypted() self.is_encrypted()
} }
fn info(&self) -> Result<Vec<(String, String)>, DBError> { fn info(&self) -> Result<Vec<(String, String)>, DBError> {
self.info() self.info()
} }
@@ -284,4 +302,4 @@ impl StorageBackend for Storage {
fn clone_arc(&self) -> Arc<dyn StorageBackend> { fn clone_arc(&self) -> Arc<dyn StorageBackend> {
unimplemented!("Storage cloning not yet implemented for redb backend") unimplemented!("Storage cloning not yet implemented for redb backend")
} }
} }

View File

@@ -1,6 +1,6 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
pub fn flushdb(&self) -> Result<(), DBError> { pub fn flushdb(&self) -> Result<(), DBError> {
@@ -15,11 +15,17 @@ impl Storage {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way // inefficient, but there is no other way
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = types_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
types_table.remove(key.as_str())?; types_table.remove(key.as_str())?;
} }
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = strings_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
strings_table.remove(key.as_str())?; strings_table.remove(key.as_str())?;
} }
@@ -34,23 +40,35 @@ impl Storage {
for (key, field) in keys { for (key, field) in keys {
hashes_table.remove((key.as_str(), field.as_str()))?; hashes_table.remove((key.as_str(), field.as_str()))?;
} }
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = lists_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
lists_table.remove(key.as_str())?; lists_table.remove(key.as_str())?;
} }
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = streams_meta_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
streams_meta_table.remove(key.as_str())?; streams_meta_table.remove(key.as_str())?;
} }
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| { let keys: Vec<(String, String)> = streams_data_table
let binding = item.unwrap(); .iter()?
let (key, field) = binding.0.value(); .map(|item| {
(key.to_string(), field.to_string()) let binding = item.unwrap();
}).collect(); let (key, field) = binding.0.value();
(key.to_string(), field.to_string())
})
.collect();
for (key, field) in keys { for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?; streams_data_table.remove((key.as_str(), field.as_str()))?;
} }
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = expiration_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
expiration_table.remove(key.as_str())?; expiration_table.remove(key.as_str())?;
} }
@@ -62,7 +80,7 @@ impl Storage {
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> { pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?; let table = read_txn.open_table(TYPES_TABLE)?;
// Before returning type, check for expiration // Before returning type, check for expiration
if let Some(type_val) = table.get(key)? { if let Some(type_val) = table.get(key)? {
if type_val.value() == "string" { if type_val.value() == "string" {
@@ -83,7 +101,7 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted // ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> { pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? { match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => { Some(type_val) if type_val.value() == "string" => {
@@ -96,7 +114,7 @@ impl Storage {
return Ok(None); return Ok(None);
} }
} }
// Get and decrypt value // Get and decrypt value
let strings_table = read_txn.open_table(STRINGS_TABLE)?; let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? { match strings_table.get(key)? {
@@ -115,21 +133,21 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage // ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn set(&self, key: String, value: String) -> Result<(), DBError> { pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?; types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value, not expiration // Only encrypt the value, not expiration
let encrypted = self.encrypt_if_needed(value.as_bytes())?; let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?; strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Remove any existing expiration since this is a regular SET // Remove any existing expiration since this is a regular SET
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?; expiration_table.remove(key.as_str())?;
} }
write_txn.commit()?; write_txn.commit()?;
Ok(()) Ok(())
} }
@@ -137,41 +155,42 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage // ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?; types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value // Only encrypt the value
let encrypted = self.encrypt_if_needed(value.as_bytes())?; let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?; strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Store expiration separately (unencrypted) // Store expiration separately (unencrypted)
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = expire_ms + now_in_millis(); let expires_at = expire_ms + now_in_millis();
expiration_table.insert(key.as_str(), &(expires_at as u64))?; expiration_table.insert(key.as_str(), &(expires_at as u64))?;
} }
write_txn.commit()?; write_txn.commit()?;
Ok(()) Ok(())
} }
pub fn del(&self, key: String) -> Result<(), DBError> { pub fn del(&self, key: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?; let mut hashes_table: redb::Table<(&str, &str), &[u8]> =
write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table // Remove from type table
types_table.remove(key.as_str())?; types_table.remove(key.as_str())?;
// Remove from strings table // Remove from strings table
strings_table.remove(key.as_str())?; strings_table.remove(key.as_str())?;
// Remove all hash fields for this key // Remove all hash fields for this key
let mut to_remove = Vec::new(); let mut to_remove = Vec::new();
let mut iter = hashes_table.iter()?; let mut iter = hashes_table.iter()?;
@@ -183,19 +202,19 @@ impl Storage {
} }
} }
drop(iter); drop(iter);
for (hash_key, field) in to_remove { for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?; hashes_table.remove((hash_key.as_str(), field.as_str()))?;
} }
// Remove from lists table // Remove from lists table
lists_table.remove(key.as_str())?; lists_table.remove(key.as_str())?;
// Also remove expiration // Also remove expiration
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?; expiration_table.remove(key.as_str())?;
} }
write_txn.commit()?; write_txn.commit()?;
Ok(()) Ok(())
} }
@@ -203,7 +222,7 @@ impl Storage {
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> { pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?; let table = read_txn.open_table(TYPES_TABLE)?;
let mut keys = Vec::new(); let mut keys = Vec::new();
let mut iter = table.iter()?; let mut iter = table.iter()?;
while let Some(entry) = iter.next() { while let Some(entry) = iter.next() {
@@ -212,7 +231,7 @@ impl Storage {
keys.push(key); keys.push(key);
} }
} }
Ok(keys) Ok(keys)
} }
} }
@@ -242,4 +261,4 @@ impl Storage {
} }
Ok(count) Ok(count)
} }
} }

View File

@@ -1,24 +1,29 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { pub fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let strings_table = read_txn.open_table(STRINGS_TABLE)?; let strings_table = read_txn.open_table(STRINGS_TABLE)?;
let mut result = Vec::new(); let mut result = Vec::new();
let mut current_cursor = 0u64; let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize; let limit = count.unwrap_or(10) as usize;
let mut iter = types_table.iter()?; let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() { while let Some(entry) = iter.next() {
let entry = entry?; let entry = entry?;
let key = entry.0.value().to_string(); let key = entry.0.value().to_string();
let key_type = entry.1.value().to_string(); let key_type = entry.1.value().to_string();
if current_cursor >= cursor { if current_cursor >= cursor {
// Apply pattern matching if specified // Apply pattern matching if specified
let matches = if let Some(pat) = pattern { let matches = if let Some(pat) = pattern {
@@ -26,7 +31,7 @@ impl Storage {
} else { } else {
true true
}; };
if matches { if matches {
// For scan, we return key-value pairs for string types // For scan, we return key-value pairs for string types
if key_type == "string" { if key_type == "string" {
@@ -41,7 +46,7 @@ impl Storage {
// For non-string types, just return the key with type as value // For non-string types, just return the key with type as value
result.push((key, key_type)); result.push((key, key_type));
} }
if result.len() >= limit { if result.len() >= limit {
break; break;
} }
@@ -49,15 +54,19 @@ impl Storage {
} }
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
pub fn ttl(&self, key: &str) -> Result<i64, DBError> { pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? { match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => { Some(type_val) if type_val.value() == "string" => {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?; let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
@@ -75,14 +84,14 @@ impl Storage {
} }
} }
Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types) Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types)
None => Ok(-2), // Key does not exist None => Ok(-2), // Key does not exist
} }
} }
pub fn exists(&self, key: &str) -> Result<bool, DBError> { pub fn exists(&self, key: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? { match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => { Some(type_val) if type_val.value() == "string" => {
// Check if string key has expired // Check if string key has expired
@@ -95,7 +104,7 @@ impl Storage {
Ok(true) Ok(true)
} }
Some(_) => Ok(true), // Key exists and is not a string Some(_) => Ok(true), // Key exists and is not a string
None => Ok(false), // Key does not exist None => Ok(false), // Key does not exist
} }
} }
@@ -178,8 +187,12 @@ impl Storage {
.unwrap_or(false); .unwrap_or(false);
if is_string { if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; let expires_at_ms: u128 = if ts_secs <= 0 {
expiration_table.insert(key, &((expires_at_ms as u64)))?; 0
} else {
(ts_secs as u128) * 1000
};
expiration_table.insert(key, &(expires_at_ms as u64))?;
applied = true; applied = true;
} }
} }
@@ -201,7 +214,7 @@ impl Storage {
if is_string { if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
expiration_table.insert(key, &((expires_at_ms as u64)))?; expiration_table.insert(key, &(expires_at_ms as u64))?;
applied = true; applied = true;
} }
} }
@@ -223,21 +236,21 @@ pub fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" { if pattern == "*" {
return true; return true;
} }
// Simple glob matching - supports * and ? wildcards // Simple glob matching - supports * and ? wildcards
let pattern_chars: Vec<char> = pattern.chars().collect(); let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect(); let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool { fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() { if pi >= pattern.len() {
return ti >= text.len(); return ti >= text.len();
} }
if ti >= text.len() { if ti >= text.len() {
// Check if remaining pattern is all '*' // Check if remaining pattern is all '*'
return pattern[pi..].iter().all(|&c| c == '*'); return pattern[pi..].iter().all(|&c| c == '*');
} }
match pattern[pi] { match pattern[pi] {
'*' => { '*' => {
// Try matching zero or more characters // Try matching zero or more characters
@@ -262,7 +275,7 @@ pub fn glob_match(pattern: &str, text: &str) -> bool {
} }
} }
} }
match_recursive(&pattern_chars, &text_chars, 0, 0) match_recursive(&pattern_chars, &text_chars, 0, 0)
} }
@@ -283,4 +296,4 @@ mod tests {
assert!(glob_match("*test*", "this_is_a_test_string")); assert!(glob_match("*test*", "this_is_a_test_string"));
assert!(!glob_match("*test*", "this_is_a_string")); assert!(!glob_match("*test*", "this_is_a_string"));
} }
} }

View File

@@ -1,44 +1,50 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
// ✅ ENCRYPTION APPLIED: Values are encrypted before storage // ✅ ENCRYPTION APPLIED: Values are encrypted before storage
pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> { pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut new_fields = 0i64; let mut new_fields = 0i64;
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string()) access_guard.map(|v| v.value().to_string())
}; };
match key_type.as_deref() { match key_type.as_deref() {
Some("hash") | None => { // Proceed if hash or new key Some("hash") | None => {
// Proceed if hash or new key
// Set the type to hash (only if new key or existing hash) // Set the type to hash (only if new key or existing hash)
types_table.insert(key, "hash")?; types_table.insert(key, "hash")?;
for (field, value) in pairs { for (field, value) in pairs {
// Check if field already exists // Check if field already exists
let exists = hashes_table.get((key, field.as_str()))?.is_some(); let exists = hashes_table.get((key, field.as_str()))?.is_some();
// Encrypt the value before storing // Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?; let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?; hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
if !exists { if !exists {
new_fields += 1; new_fields += 1;
} }
} }
} }
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value"
.to_string(),
))
}
} }
} }
write_txn.commit()?; write_txn.commit()?;
Ok(new_fields) Ok(new_fields)
} }
@@ -47,7 +53,7 @@ impl Storage {
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> { pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = types_table.get(key)?.map(|v| v.value().to_string()); let key_type = types_table.get(key)?.map(|v| v.value().to_string());
match key_type.as_deref() { match key_type.as_deref() {
@@ -62,7 +68,9 @@ impl Storage {
None => Ok(None), None => Ok(None),
} }
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(None), None => Ok(None),
} }
} }
@@ -80,7 +88,7 @@ impl Storage {
Some("hash") => { Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?; let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new(); let mut result = Vec::new();
let mut iter = hashes_table.iter()?; let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() { while let Some(entry) = iter.next() {
let entry = entry?; let entry = entry?;
@@ -91,10 +99,12 @@ impl Storage {
result.push((field.to_string(), value)); result.push((field.to_string(), value));
} }
} }
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(Vec::new()), None => Ok(Vec::new()),
} }
} }
@@ -102,24 +112,24 @@ impl Storage {
pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> { pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut deleted = 0i64; let mut deleted = 0i64;
// First check if key exists and is a hash // First check if key exists and is a hash
let key_type = { let key_type = {
let types_table = write_txn.open_table(TYPES_TABLE)?; let types_table = write_txn.open_table(TYPES_TABLE)?;
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string()) access_guard.map(|v| v.value().to_string())
}; };
match key_type.as_deref() { match key_type.as_deref() {
Some("hash") => { Some("hash") => {
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields { for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() { if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1; deleted += 1;
} }
} }
// Check if hash is now empty and remove type if so // Check if hash is now empty and remove type if so
let mut has_fields = false; let mut has_fields = false;
let mut iter = hashes_table.iter()?; let mut iter = hashes_table.iter()?;
@@ -132,16 +142,20 @@ impl Storage {
} }
} }
drop(iter); drop(iter);
if !has_fields { if !has_fields {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?; types_table.remove(key)?;
} }
} }
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
None => {} // Key does not exist, nothing to delete, return 0 deleted None => {} // Key does not exist, nothing to delete, return 0 deleted
} }
write_txn.commit()?; write_txn.commit()?;
Ok(deleted) Ok(deleted)
} }
@@ -159,7 +173,9 @@ impl Storage {
let hashes_table = read_txn.open_table(HASHES_TABLE)?; let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some()) Ok(hashes_table.get((key, field))?.is_some())
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(false), None => Ok(false),
} }
} }
@@ -176,7 +192,7 @@ impl Storage {
Some("hash") => { Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?; let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new(); let mut result = Vec::new();
let mut iter = hashes_table.iter()?; let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() { while let Some(entry) = iter.next() {
let entry = entry?; let entry = entry?;
@@ -185,10 +201,12 @@ impl Storage {
result.push(field.to_string()); result.push(field.to_string());
} }
} }
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(Vec::new()), None => Ok(Vec::new()),
} }
} }
@@ -206,7 +224,7 @@ impl Storage {
Some("hash") => { Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?; let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new(); let mut result = Vec::new();
let mut iter = hashes_table.iter()?; let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() { while let Some(entry) = iter.next() {
let entry = entry?; let entry = entry?;
@@ -217,10 +235,12 @@ impl Storage {
result.push(value); result.push(value);
} }
} }
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(Vec::new()), None => Ok(Vec::new()),
} }
} }
@@ -237,7 +257,7 @@ impl Storage {
Some("hash") => { Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?; let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0i64; let mut count = 0i64;
let mut iter = hashes_table.iter()?; let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() { while let Some(entry) = iter.next() {
let entry = entry?; let entry = entry?;
@@ -246,10 +266,12 @@ impl Storage {
count += 1; count += 1;
} }
} }
Ok(count) Ok(count)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(0), None => Ok(0),
} }
} }
@@ -267,7 +289,7 @@ impl Storage {
Some("hash") => { Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?; let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new(); let mut result = Vec::new();
for field in fields { for field in fields {
match hashes_table.get((key, field.as_str()))? { match hashes_table.get((key, field.as_str()))? {
Some(data) => { Some(data) => {
@@ -278,10 +300,12 @@ impl Storage {
None => result.push(None), None => result.push(None),
} }
} }
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(fields.into_iter().map(|_| None).collect()), None => Ok(fields.into_iter().map(|_| None).collect()),
} }
} }
@@ -290,39 +314,51 @@ impl Storage {
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> { pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut result = false; let mut result = false;
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?; let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string()) access_guard.map(|v| v.value().to_string())
}; };
match key_type.as_deref() { match key_type.as_deref() {
Some("hash") | None => { // Proceed if hash or new key Some("hash") | None => {
// Proceed if hash or new key
// Check if field already exists // Check if field already exists
if hashes_table.get((key, field))?.is_none() { if hashes_table.get((key, field))?.is_none() {
// Set the type to hash (only if new key or existing hash) // Set the type to hash (only if new key or existing hash)
types_table.insert(key, "hash")?; types_table.insert(key, "hash")?;
// Encrypt the value before storing // Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?; let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field), encrypted.as_slice())?; hashes_table.insert((key, field), encrypted.as_slice())?;
result = true; result = true;
} }
} }
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value"
.to_string(),
))
}
} }
} }
write_txn.commit()?; write_txn.commit()?;
Ok(result) Ok(result)
} }
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { pub fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
@@ -336,28 +372,28 @@ impl Storage {
let mut result = Vec::new(); let mut result = Vec::new();
let mut current_cursor = 0u64; let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize; let limit = count.unwrap_or(10) as usize;
let mut iter = hashes_table.iter()?; let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() { while let Some(entry) = iter.next() {
let entry = entry?; let entry = entry?;
let (hash_key, field) = entry.0.value(); let (hash_key, field) = entry.0.value();
if hash_key == key { if hash_key == key {
if current_cursor >= cursor { if current_cursor >= cursor {
let field_str = field.to_string(); let field_str = field.to_string();
// Apply pattern matching if specified // Apply pattern matching if specified
let matches = if let Some(pat) = pattern { let matches = if let Some(pat) = pattern {
super::storage_extra::glob_match(pat, &field_str) super::storage_extra::glob_match(pat, &field_str)
} else { } else {
true true
}; };
if matches { if matches {
let decrypted = self.decrypt_if_needed(entry.1.value())?; let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?; let value = String::from_utf8(decrypted)?;
result.push((field_str, value)); result.push((field_str, value));
if result.len() >= limit { if result.len() >= limit {
break; break;
} }
@@ -366,12 +402,18 @@ impl Storage {
current_cursor += 1; current_cursor += 1;
} }
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok((0, Vec::new())), None => Ok((0, Vec::new())),
} }
} }
} }

View File

@@ -1,20 +1,20 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage // ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> { pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut _length = 0i64; let mut _length = 0i64;
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list // Set the type to list
types_table.insert(key, "list")?; types_table.insert(key, "list")?;
// Get current list or create empty one // Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? { let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => { Some(data) => {
@@ -23,20 +23,20 @@ impl Storage {
} }
None => Vec::new(), None => Vec::new(),
}; };
// Add elements to the front (left) // Add elements to the front (left)
for element in elements.into_iter() { for element in elements.into_iter() {
list.insert(0, element); list.insert(0, element);
} }
_length = list.len() as i64; _length = list.len() as i64;
// Encrypt and store the updated list // Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?; let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?; let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?; lists_table.insert(key, encrypted.as_slice())?;
} }
write_txn.commit()?; write_txn.commit()?;
Ok(_length) Ok(_length)
} }
@@ -45,14 +45,14 @@ impl Storage {
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> { pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut _length = 0i64; let mut _length = 0i64;
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list // Set the type to list
types_table.insert(key, "list")?; types_table.insert(key, "list")?;
// Get current list or create empty one // Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? { let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => { Some(data) => {
@@ -61,17 +61,17 @@ impl Storage {
} }
None => Vec::new(), None => Vec::new(),
}; };
// Add elements to the end (right) // Add elements to the end (right)
list.extend(elements); list.extend(elements);
_length = list.len() as i64; _length = list.len() as i64;
// Encrypt and store the updated list // Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?; let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?; let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?; lists_table.insert(key, encrypted.as_slice())?;
} }
write_txn.commit()?; write_txn.commit()?;
Ok(_length) Ok(_length)
} }
@@ -80,12 +80,12 @@ impl Storage {
pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> { pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut result = Vec::new(); let mut result = Vec::new();
// First check if key exists and is a list, and get the data // First check if key exists and is a list, and get the data
let list_data = { let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?; let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? { let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => { Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? { if let Some(data) = lists_table.get(key)? {
@@ -100,7 +100,7 @@ impl Storage {
}; };
result result
}; };
if let Some(mut list) = list_data { if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len()); let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count { for _ in 0..pop_count {
@@ -108,7 +108,7 @@ impl Storage {
result.push(list.remove(0)); result.push(list.remove(0));
} }
} }
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() { if list.is_empty() {
// Remove the key if list is empty // Remove the key if list is empty
@@ -122,7 +122,7 @@ impl Storage {
lists_table.insert(key, encrypted.as_slice())?; lists_table.insert(key, encrypted.as_slice())?;
} }
} }
write_txn.commit()?; write_txn.commit()?;
Ok(result) Ok(result)
} }
@@ -131,12 +131,12 @@ impl Storage {
pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> { pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut result = Vec::new(); let mut result = Vec::new();
// First check if key exists and is a list, and get the data // First check if key exists and is a list, and get the data
let list_data = { let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?; let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? { let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => { Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? { if let Some(data) = lists_table.get(key)? {
@@ -151,7 +151,7 @@ impl Storage {
}; };
result result
}; };
if let Some(mut list) = list_data { if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len()); let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count { for _ in 0..pop_count {
@@ -159,7 +159,7 @@ impl Storage {
result.push(list.pop().unwrap()); result.push(list.pop().unwrap());
} }
} }
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() { if list.is_empty() {
// Remove the key if list is empty // Remove the key if list is empty
@@ -173,7 +173,7 @@ impl Storage {
lists_table.insert(key, encrypted.as_slice())?; lists_table.insert(key, encrypted.as_slice())?;
} }
} }
write_txn.commit()?; write_txn.commit()?;
Ok(result) Ok(result)
} }
@@ -181,7 +181,7 @@ impl Storage {
pub fn llen(&self, key: &str) -> Result<i64, DBError> { pub fn llen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? { match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => { Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?; let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -202,7 +202,7 @@ impl Storage {
pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> { pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? { match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => { Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?; let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -210,13 +210,13 @@ impl Storage {
Some(data) => { Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?; let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?; let list: Vec<String> = serde_json::from_slice(&decrypted)?;
let actual_index = if index < 0 { let actual_index = if index < 0 {
list.len() as i64 + index list.len() as i64 + index
} else { } else {
index index
}; };
if actual_index >= 0 && (actual_index as usize) < list.len() { if actual_index >= 0 && (actual_index as usize) < list.len() {
Ok(Some(list[actual_index as usize].clone())) Ok(Some(list[actual_index as usize].clone()))
} else { } else {
@@ -234,7 +234,7 @@ impl Storage {
pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> { pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? { match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => { Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?; let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -242,22 +242,30 @@ impl Storage {
Some(data) => { Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?; let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?; let list: Vec<String> = serde_json::from_slice(&decrypted)?;
if list.is_empty() { if list.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let len = list.len() as i64; let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; let start_idx = if start < 0 {
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
if start_idx > stop_idx || start_idx >= len { if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let start_usize = start_idx as usize; let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize; let stop_usize = (stop_idx + 1) as usize;
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec()) Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
} }
None => Ok(Vec::new()), None => Ok(Vec::new()),
@@ -270,12 +278,12 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
// First check if key exists and is a list, and get the data // First check if key exists and is a list, and get the data
let list_data = { let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?; let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? { let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => { Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? { if let Some(data) = lists_table.get(key)? {
@@ -290,17 +298,25 @@ impl Storage {
}; };
result result
}; };
if let Some(list) = list_data { if let Some(list) = list_data {
if list.is_empty() { if list.is_empty() {
write_txn.commit()?; write_txn.commit()?;
return Ok(()); return Ok(());
} }
let len = list.len() as i64; let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; let start_idx = if start < 0 {
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if start_idx > stop_idx || start_idx >= len { if start_idx > stop_idx || start_idx >= len {
// Remove the entire list // Remove the entire list
@@ -311,7 +327,7 @@ impl Storage {
let start_usize = start_idx as usize; let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize; let stop_usize = (stop_idx + 1) as usize;
let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec(); let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
if trimmed.is_empty() { if trimmed.is_empty() {
lists_table.remove(key)?; lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
@@ -324,7 +340,7 @@ impl Storage {
} }
} }
} }
write_txn.commit()?; write_txn.commit()?;
Ok(()) Ok(())
} }
@@ -333,12 +349,12 @@ impl Storage {
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> { pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?; let write_txn = self.db.begin_write()?;
let mut removed = 0i64; let mut removed = 0i64;
// First check if key exists and is a list, and get the data // First check if key exists and is a list, and get the data
let list_data = { let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?; let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?; let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? { let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => { Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? { if let Some(data) = lists_table.get(key)? {
@@ -353,7 +369,7 @@ impl Storage {
}; };
result result
}; };
if let Some(mut list) = list_data { if let Some(mut list) = list_data {
if count == 0 { if count == 0 {
// Remove all occurrences // Remove all occurrences
@@ -383,7 +399,7 @@ impl Storage {
} }
} }
} }
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() { if list.is_empty() {
lists_table.remove(key)?; lists_table.remove(key)?;
@@ -396,8 +412,8 @@ impl Storage {
lists_table.insert(key, encrypted.as_slice())?; lists_table.insert(key, encrypted.as_slice())?;
} }
} }
write_txn.commit()?; write_txn.commit()?;
Ok(removed) Ok(removed)
} }
} }

View File

@@ -1,12 +1,12 @@
// src/storage_sled/mod.rs // src/storage_sled/mod.rs
use std::path::Path; use crate::crypto::CryptoFactory;
use std::sync::Arc;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::error::DBError; use crate::error::DBError;
use crate::storage_trait::StorageBackend; use crate::storage_trait::StorageBackend;
use crate::crypto::CryptoFactory; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
enum ValueType { enum ValueType {
@@ -28,44 +28,56 @@ pub struct SledStorage {
} }
impl SledStorage { impl SledStorage {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> { pub fn new(
path: impl AsRef<Path>,
should_encrypt: bool,
master_key: Option<&str>,
) -> Result<Self, DBError> {
let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?; let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?;
let types = db.open_tree("types").map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?; let types = db
.open_tree("types")
.map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?;
// Check if database was previously encrypted // Check if database was previously encrypted
let encrypted_tree = db.open_tree("encrypted").map_err(|e| DBError(e.to_string()))?; let encrypted_tree = db
let was_encrypted = encrypted_tree.get("encrypted") .open_tree("encrypted")
.map_err(|e| DBError(e.to_string()))?;
let was_encrypted = encrypted_tree
.get("encrypted")
.map_err(|e| DBError(e.to_string()))? .map_err(|e| DBError(e.to_string()))?
.map(|v| v[0] == 1) .map(|v| v[0] == 1)
.unwrap_or(false); .unwrap_or(false);
let crypto = if should_encrypt || was_encrypted { let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key { if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes())) Some(CryptoFactory::new(key.as_bytes()))
} else { } else {
return Err(DBError("Encryption requested but no master key provided".to_string())); return Err(DBError(
"Encryption requested but no master key provided".to_string(),
));
} }
} else { } else {
None None
}; };
// Mark database as encrypted if enabling encryption // Mark database as encrypted if enabling encryption
if should_encrypt && !was_encrypted { if should_encrypt && !was_encrypted {
encrypted_tree.insert("encrypted", &[1u8]) encrypted_tree
.insert("encrypted", &[1u8])
.map_err(|e| DBError(e.to_string()))?; .map_err(|e| DBError(e.to_string()))?;
encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?; encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?;
} }
Ok(SledStorage { db, types, crypto }) Ok(SledStorage { db, types, crypto })
} }
fn now_millis() -> u128 { fn now_millis() -> u128 {
SystemTime::now() SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .unwrap()
.as_millis() .as_millis()
} }
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> { fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto { if let Some(crypto) = &self.crypto {
Ok(crypto.encrypt(data)) Ok(crypto.encrypt(data))
@@ -73,7 +85,7 @@ impl SledStorage {
Ok(data.to_vec()) Ok(data.to_vec())
} }
} }
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> { fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto { if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data)?) Ok(crypto.decrypt(data)?)
@@ -81,14 +93,14 @@ impl SledStorage {
Ok(data.to_vec()) Ok(data.to_vec())
} }
} }
fn get_storage_value(&self, key: &str) -> Result<Option<StorageValue>, DBError> { fn get_storage_value(&self, key: &str) -> Result<Option<StorageValue>, DBError> {
match self.db.get(key).map_err(|e| DBError(e.to_string()))? { match self.db.get(key).map_err(|e| DBError(e.to_string()))? {
Some(encrypted_data) => { Some(encrypted_data) => {
let decrypted = self.decrypt_if_needed(&encrypted_data)?; let decrypted = self.decrypt_if_needed(&encrypted_data)?;
let storage_val: StorageValue = bincode::deserialize(&decrypted) let storage_val: StorageValue = bincode::deserialize(&decrypted)
.map_err(|e| DBError(format!("Deserialization error: {}", e)))?; .map_err(|e| DBError(format!("Deserialization error: {}", e)))?;
// Check expiration // Check expiration
if let Some(expires_at) = storage_val.expires_at { if let Some(expires_at) = storage_val.expires_at {
if Self::now_millis() > expires_at { if Self::now_millis() > expires_at {
@@ -98,47 +110,51 @@ impl SledStorage {
return Ok(None); return Ok(None);
} }
} }
Ok(Some(storage_val)) Ok(Some(storage_val))
} }
None => Ok(None) None => Ok(None),
} }
} }
fn set_storage_value(&self, key: &str, storage_val: StorageValue) -> Result<(), DBError> { fn set_storage_value(&self, key: &str, storage_val: StorageValue) -> Result<(), DBError> {
let data = bincode::serialize(&storage_val) let data = bincode::serialize(&storage_val)
.map_err(|e| DBError(format!("Serialization error: {}", e)))?; .map_err(|e| DBError(format!("Serialization error: {}", e)))?;
let encrypted = self.encrypt_if_needed(&data)?; let encrypted = self.encrypt_if_needed(&data)?;
self.db.insert(key, encrypted).map_err(|e| DBError(e.to_string()))?; self.db
.insert(key, encrypted)
.map_err(|e| DBError(e.to_string()))?;
// Store type info (unencrypted for efficiency) // Store type info (unencrypted for efficiency)
let type_str = match &storage_val.value { let type_str = match &storage_val.value {
ValueType::String(_) => "string", ValueType::String(_) => "string",
ValueType::Hash(_) => "hash", ValueType::Hash(_) => "hash",
ValueType::List(_) => "list", ValueType::List(_) => "list",
}; };
self.types.insert(key, type_str.as_bytes()).map_err(|e| DBError(e.to_string()))?; self.types
.insert(key, type_str.as_bytes())
.map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn glob_match(pattern: &str, text: &str) -> bool { fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" { if pattern == "*" {
return true; return true;
} }
let pattern_chars: Vec<char> = pattern.chars().collect(); let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect(); let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool { fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() { if pi >= pattern.len() {
return ti >= text.len(); return ti >= text.len();
} }
if ti >= text.len() { if ti >= text.len() {
return pattern[pi..].iter().all(|&c| c == '*'); return pattern[pi..].iter().all(|&c| c == '*');
} }
match pattern[pi] { match pattern[pi] {
'*' => { '*' => {
for i in ti..=text.len() { for i in ti..=text.len() {
@@ -158,7 +174,7 @@ impl SledStorage {
} }
} }
} }
match_recursive(&pattern_chars, &text_chars, 0, 0) match_recursive(&pattern_chars, &text_chars, 0, 0)
} }
} }
@@ -168,12 +184,12 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::String(s) => Ok(Some(s)), ValueType::String(s) => Ok(Some(s)),
_ => Ok(None) _ => Ok(None),
} },
None => Ok(None) None => Ok(None),
} }
} }
fn set(&self, key: String, value: String) -> Result<(), DBError> { fn set(&self, key: String, value: String) -> Result<(), DBError> {
let storage_val = StorageValue { let storage_val = StorageValue {
value: ValueType::String(value), value: ValueType::String(value),
@@ -183,7 +199,7 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let storage_val = StorageValue { let storage_val = StorageValue {
value: ValueType::String(value), value: ValueType::String(value),
@@ -193,25 +209,27 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn del(&self, key: String) -> Result<(), DBError> { fn del(&self, key: String) -> Result<(), DBError> {
self.db.remove(&key).map_err(|e| DBError(e.to_string()))?; self.db.remove(&key).map_err(|e| DBError(e.to_string()))?;
self.types.remove(&key).map_err(|e| DBError(e.to_string()))?; self.types
.remove(&key)
.map_err(|e| DBError(e.to_string()))?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn exists(&self, key: &str) -> Result<bool, DBError> { fn exists(&self, key: &str) -> Result<bool, DBError> {
// Check with expiration // Check with expiration
Ok(self.get_storage_value(key)?.is_some()) Ok(self.get_storage_value(key)?.is_some())
} }
fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> { fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let mut keys = Vec::new(); let mut keys = Vec::new();
for item in self.types.iter() { for item in self.types.iter() {
let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?; let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?;
let key = String::from_utf8_lossy(&key_bytes).to_string(); let key = String::from_utf8_lossy(&key_bytes).to_string();
// Check if key is expired // Check if key is expired
if self.get_storage_value(&key)?.is_some() { if self.get_storage_value(&key)?.is_some() {
if Self::glob_match(pattern, &key) { if Self::glob_match(pattern, &key) {
@@ -221,24 +239,29 @@ impl StorageBackend for SledStorage {
} }
Ok(keys) Ok(keys)
} }
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
let mut result = Vec::new(); let mut result = Vec::new();
let mut current_cursor = 0u64; let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize; let limit = count.unwrap_or(10) as usize;
for item in self.types.iter() { for item in self.types.iter() {
if current_cursor >= cursor { if current_cursor >= cursor {
let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?; let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?;
let key = String::from_utf8_lossy(&key_bytes).to_string(); let key = String::from_utf8_lossy(&key_bytes).to_string();
// Check pattern match // Check pattern match
let matches = if let Some(pat) = pattern { let matches = if let Some(pat) = pattern {
Self::glob_match(pat, &key) Self::glob_match(pat, &key)
} else { } else {
true true
}; };
if matches { if matches {
// Check if key is expired and get value // Check if key is expired and get value
if let Some(storage_val) = self.get_storage_value(&key)? { if let Some(storage_val) = self.get_storage_value(&key)? {
@@ -247,7 +270,7 @@ impl StorageBackend for SledStorage {
_ => String::from_utf8_lossy(&type_bytes).to_string(), _ => String::from_utf8_lossy(&type_bytes).to_string(),
}; };
result.push((key, value)); result.push((key, value));
if result.len() >= limit { if result.len() >= limit {
current_cursor += 1; current_cursor += 1;
break; break;
@@ -257,11 +280,15 @@ impl StorageBackend for SledStorage {
} }
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
fn dbsize(&self) -> Result<i64, DBError> { fn dbsize(&self) -> Result<i64, DBError> {
let mut count = 0i64; let mut count = 0i64;
for item in self.types.iter() { for item in self.types.iter() {
@@ -273,38 +300,42 @@ impl StorageBackend for SledStorage {
} }
Ok(count) Ok(count)
} }
fn flushdb(&self) -> Result<(), DBError> { fn flushdb(&self) -> Result<(), DBError> {
self.db.clear().map_err(|e| DBError(e.to_string()))?; self.db.clear().map_err(|e| DBError(e.to_string()))?;
self.types.clear().map_err(|e| DBError(e.to_string()))?; self.types.clear().map_err(|e| DBError(e.to_string()))?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> { fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
// First check if key exists (handles expiration) // First check if key exists (handles expiration)
if self.get_storage_value(key)?.is_some() { if self.get_storage_value(key)?.is_some() {
match self.types.get(key).map_err(|e| DBError(e.to_string()))? { match self.types.get(key).map_err(|e| DBError(e.to_string()))? {
Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())), Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())),
None => Ok(None) None => Ok(None),
} }
} else { } else {
Ok(None) Ok(None)
} }
} }
// Hash operations // Hash operations
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> { fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::Hash(HashMap::new()), value: ValueType::Hash(HashMap::new()),
expires_at: None, expires_at: None,
}); });
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
let mut new_fields = 0i64; let mut new_fields = 0i64;
for (field, value) in pairs { for (field, value) in pairs {
if !hash.contains_key(&field) { if !hash.contains_key(&field) {
@@ -312,40 +343,46 @@ impl StorageBackend for SledStorage {
} }
hash.insert(field, value); hash.insert(field, value);
} }
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(new_fields) Ok(new_fields)
} }
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> { fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.get(field).cloned()), ValueType::Hash(h) => Ok(h.get(field).cloned()),
_ => Ok(None) _ => Ok(None),
} },
None => Ok(None) None => Ok(None),
} }
} }
fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> { fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.into_iter().collect()), ValueType::Hash(h) => Ok(h.into_iter().collect()),
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => { ValueType::Hash(h) => {
let mut result = Vec::new(); let mut result = Vec::new();
let mut current_cursor = 0u64; let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize; let limit = count.unwrap_or(10) as usize;
for (field, value) in h.iter() { for (field, value) in h.iter() {
if current_cursor >= cursor { if current_cursor >= cursor {
let matches = if let Some(pat) = pattern { let matches = if let Some(pat) = pattern {
@@ -353,7 +390,7 @@ impl StorageBackend for SledStorage {
} else { } else {
true true
}; };
if matches { if matches {
result.push((field.clone(), value.clone())); result.push((field.clone(), value.clone()));
if result.len() >= limit { if result.len() >= limit {
@@ -364,107 +401,115 @@ impl StorageBackend for SledStorage {
} }
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
_ => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())) _ => Err(DBError(
} "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
None => Ok((0, Vec::new())) )),
},
None => Ok((0, Vec::new())),
} }
} }
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> { fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(0) None => return Ok(0),
}; };
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => return Ok(0) _ => return Ok(0),
}; };
let mut deleted = 0i64; let mut deleted = 0i64;
for field in fields { for field in fields {
if hash.remove(&field).is_some() { if hash.remove(&field).is_some() {
deleted += 1; deleted += 1;
} }
} }
if hash.is_empty() { if hash.is_empty() {
self.del(key.to_string())?; self.del(key.to_string())?;
} else { } else {
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
} }
Ok(deleted) Ok(deleted)
} }
fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> { fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.contains_key(field)), ValueType::Hash(h) => Ok(h.contains_key(field)),
_ => Ok(false) _ => Ok(false),
} },
None => Ok(false) None => Ok(false),
} }
} }
fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> { fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.keys().cloned().collect()), ValueType::Hash(h) => Ok(h.keys().cloned().collect()),
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> { fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.values().cloned().collect()), ValueType::Hash(h) => Ok(h.values().cloned().collect()),
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
fn hlen(&self, key: &str) -> Result<i64, DBError> { fn hlen(&self, key: &str) -> Result<i64, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.len() as i64), ValueType::Hash(h) => Ok(h.len() as i64),
_ => Ok(0) _ => Ok(0),
} },
None => Ok(0) None => Ok(0),
} }
} }
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> { fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => { ValueType::Hash(h) => Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()),
Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()) _ => Ok(fields.into_iter().map(|_| None).collect()),
} },
_ => Ok(fields.into_iter().map(|_| None).collect()) None => Ok(fields.into_iter().map(|_| None).collect()),
}
None => Ok(fields.into_iter().map(|_| None).collect())
} }
} }
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> { fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::Hash(HashMap::new()), value: ValueType::Hash(HashMap::new()),
expires_at: None, expires_at: None,
}); });
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
if hash.contains_key(field) { if hash.contains_key(field) {
Ok(false) Ok(false)
} else { } else {
@@ -474,58 +519,66 @@ impl StorageBackend for SledStorage {
Ok(true) Ok(true)
} }
} }
// List operations // List operations
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> { fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::List(Vec::new()), value: ValueType::List(Vec::new()),
expires_at: None, expires_at: None,
}); });
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
for element in elements.into_iter().rev() { for element in elements.into_iter().rev() {
list.insert(0, element); list.insert(0, element);
} }
let len = list.len() as i64; let len = list.len() as i64;
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(len) Ok(len)
} }
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> { fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::List(Vec::new()), value: ValueType::List(Vec::new()),
expires_at: None, expires_at: None,
}); });
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
list.extend(elements); list.extend(elements);
let len = list.len() as i64; let len = list.len() as i64;
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(len) Ok(len)
} }
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> { fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(Vec::new()) None => return Ok(Vec::new()),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(Vec::new()) _ => return Ok(Vec::new()),
}; };
let mut result = Vec::new(); let mut result = Vec::new();
for _ in 0..count.min(list.len() as u64) { for _ in 0..count.min(list.len() as u64) {
if let Some(elem) = list.first() { if let Some(elem) = list.first() {
@@ -533,55 +586,55 @@ impl StorageBackend for SledStorage {
list.remove(0); list.remove(0);
} }
} }
if list.is_empty() { if list.is_empty() {
self.del(key.to_string())?; self.del(key.to_string())?;
} else { } else {
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
} }
Ok(result) Ok(result)
} }
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> { fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(Vec::new()) None => return Ok(Vec::new()),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(Vec::new()) _ => return Ok(Vec::new()),
}; };
let mut result = Vec::new(); let mut result = Vec::new();
for _ in 0..count.min(list.len() as u64) { for _ in 0..count.min(list.len() as u64) {
if let Some(elem) = list.pop() { if let Some(elem) = list.pop() {
result.push(elem); result.push(elem);
} }
} }
if list.is_empty() { if list.is_empty() {
self.del(key.to_string())?; self.del(key.to_string())?;
} else { } else {
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
} }
Ok(result) Ok(result)
} }
fn llen(&self, key: &str) -> Result<i64, DBError> { fn llen(&self, key: &str) -> Result<i64, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::List(l) => Ok(l.len() as i64), ValueType::List(l) => Ok(l.len() as i64),
_ => Ok(0) _ => Ok(0),
} },
None => Ok(0) None => Ok(0),
} }
} }
fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> { fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
@@ -591,19 +644,19 @@ impl StorageBackend for SledStorage {
} else { } else {
index index
}; };
if actual_index >= 0 && (actual_index as usize) < list.len() { if actual_index >= 0 && (actual_index as usize) < list.len() {
Ok(Some(list[actual_index as usize].clone())) Ok(Some(list[actual_index as usize].clone()))
} else { } else {
Ok(None) Ok(None)
} }
} }
_ => Ok(None) _ => Ok(None),
} },
None => Ok(None) None => Ok(None),
} }
} }
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> { fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
@@ -611,68 +664,68 @@ impl StorageBackend for SledStorage {
if list.is_empty() { if list.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let len = list.len() as i64; let len = list.len() as i64;
let start_idx = if start < 0 { let start_idx = if start < 0 {
std::cmp::max(0, len + start) std::cmp::max(0, len + start)
} else { } else {
std::cmp::min(start, len) std::cmp::min(start, len)
}; };
let stop_idx = if stop < 0 { let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop) std::cmp::max(-1, len + stop)
} else { } else {
std::cmp::min(stop, len - 1) std::cmp::min(stop, len - 1)
}; };
if start_idx > stop_idx || start_idx >= len { if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let start_usize = start_idx as usize; let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize; let stop_usize = (stop_idx + 1) as usize;
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec()) Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
} }
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(()) None => return Ok(()),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(()) _ => return Ok(()),
}; };
if list.is_empty() { if list.is_empty() {
return Ok(()); return Ok(());
} }
let len = list.len() as i64; let len = list.len() as i64;
let start_idx = if start < 0 { let start_idx = if start < 0 {
std::cmp::max(0, len + start) std::cmp::max(0, len + start)
} else { } else {
std::cmp::min(start, len) std::cmp::min(start, len)
}; };
let stop_idx = if stop < 0 { let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop) std::cmp::max(-1, len + stop)
} else { } else {
std::cmp::min(stop, len - 1) std::cmp::min(stop, len - 1)
}; };
if start_idx > stop_idx || start_idx >= len { if start_idx > stop_idx || start_idx >= len {
self.del(key.to_string())?; self.del(key.to_string())?;
} else { } else {
let start_usize = start_idx as usize; let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize; let stop_usize = (stop_idx + 1) as usize;
*list = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec(); *list = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
if list.is_empty() { if list.is_empty() {
self.del(key.to_string())?; self.del(key.to_string())?;
} else { } else {
@@ -680,23 +733,23 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
} }
} }
Ok(()) Ok(())
} }
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> { fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(0) None => return Ok(0),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(0) _ => return Ok(0),
}; };
let mut removed = 0i64; let mut removed = 0i64;
if count == 0 { if count == 0 {
// Remove all occurrences // Remove all occurrences
let original_len = list.len(); let original_len = list.len();
@@ -725,17 +778,17 @@ impl StorageBackend for SledStorage {
} }
} }
} }
if list.is_empty() { if list.is_empty() {
self.del(key.to_string())?; self.del(key.to_string())?;
} else { } else {
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
} }
Ok(removed) Ok(removed)
} }
// Expiration // Expiration
fn ttl(&self, key: &str) -> Result<i64, DBError> { fn ttl(&self, key: &str) -> Result<i64, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
@@ -751,40 +804,40 @@ impl StorageBackend for SledStorage {
Ok(-1) // Key exists but has no expiration Ok(-1) // Key exists but has no expiration
} }
} }
None => Ok(-2) // Key does not exist None => Ok(-2), // Key does not exist
} }
} }
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> { fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000); storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000);
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true) Ok(true)
} }
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> { fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
storage_val.expires_at = Some(Self::now_millis() + ms); storage_val.expires_at = Some(Self::now_millis() + ms);
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true) Ok(true)
} }
fn persist(&self, key: &str) -> Result<bool, DBError> { fn persist(&self, key: &str) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
if storage_val.expires_at.is_some() { if storage_val.expires_at.is_some() {
storage_val.expires_at = None; storage_val.expires_at = None;
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
@@ -794,37 +847,41 @@ impl StorageBackend for SledStorage {
Ok(false) Ok(false)
} }
} }
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> { fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
};
let expires_at_ms: u128 = if ts_secs <= 0 {
0
} else {
(ts_secs as u128) * 1000
}; };
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
storage_val.expires_at = Some(expires_at_ms); storage_val.expires_at = Some(expires_at_ms);
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true) Ok(true)
} }
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> { fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
storage_val.expires_at = Some(expires_at_ms); storage_val.expires_at = Some(expires_at_ms);
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true) Ok(true)
} }
fn is_encrypted(&self) -> bool { fn is_encrypted(&self) -> bool {
self.crypto.is_some() self.crypto.is_some()
} }
fn info(&self) -> Result<Vec<(String, String)>, DBError> { fn info(&self) -> Result<Vec<(String, String)>, DBError> {
let dbsize = self.dbsize()?; let dbsize = self.dbsize()?;
Ok(vec![ Ok(vec![
@@ -842,4 +899,4 @@ impl StorageBackend for SledStorage {
crypto: self.crypto.clone(), crypto: self.crypto.clone(),
}) })
} }
} }

View File

@@ -13,11 +13,22 @@ pub trait StorageBackend: Send + Sync {
fn dbsize(&self) -> Result<i64, DBError>; fn dbsize(&self) -> Result<i64, DBError>;
fn flushdb(&self) -> Result<(), DBError>; fn flushdb(&self) -> Result<(), DBError>;
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError>; fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError>;
// Scanning // Scanning
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>; fn scan(
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>; &self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError>;
fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError>;
// Hash operations // Hash operations
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError>; fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError>;
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError>; fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError>;
@@ -29,7 +40,7 @@ pub trait StorageBackend: Send + Sync {
fn hlen(&self, key: &str) -> Result<i64, DBError>; fn hlen(&self, key: &str) -> Result<i64, DBError>;
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError>; fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError>;
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError>; fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError>;
// List operations // List operations
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>; fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>; fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
@@ -40,7 +51,7 @@ pub trait StorageBackend: Send + Sync {
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError>; fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError>;
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError>; fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError>;
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError>; fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError>;
// Expiration // Expiration
fn ttl(&self, key: &str) -> Result<i64, DBError>; fn ttl(&self, key: &str) -> Result<i64, DBError>;
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError>; fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError>;
@@ -48,11 +59,11 @@ pub trait StorageBackend: Send + Sync {
fn persist(&self, key: &str) -> Result<bool, DBError>; fn persist(&self, key: &str) -> Result<bool, DBError>;
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError>; fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError>;
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError>; fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError>;
// Metadata // Metadata
fn is_encrypted(&self) -> bool; fn is_encrypted(&self) -> bool;
fn info(&self) -> Result<Vec<(String, String)>, DBError>; fn info(&self) -> Result<Vec<(String, String)>, DBError>;
// Clone to Arc for sharing // Clone to Arc for sharing
fn clone_arc(&self) -> Arc<dyn StorageBackend>; fn clone_arc(&self) -> Arc<dyn StorageBackend>;
} }

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -7,7 +7,7 @@ use tokio::time::sleep;
// Helper function to send command and get response // Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String { async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap(); stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string() String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -19,7 +19,7 @@ async fn debug_hset_simple() {
let test_dir = "/tmp/herodb_debug_hset"; let test_dir = "/tmp/herodb_debug_hset";
let _ = std::fs::remove_dir_all(test_dir); let _ = std::fs::remove_dir_all(test_dir);
std::fs::create_dir_all(test_dir).unwrap(); std::fs::create_dir_all(test_dir).unwrap();
let port = 16500; let port = 16500;
let option = DBOption { let option = DBOption {
dir: test_dir.to_string(), dir: test_dir.to_string(),
@@ -29,35 +29,49 @@ async fn debug_hset_simple() {
encryption_key: None, encryption_key: None,
backend: herodb::options::BackendType::Redb, backend: herodb::options::BackendType::Redb,
}; };
let mut server = Server::new(option).await; let mut server = Server::new(option).await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Test simple HSET // Test simple HSET
println!("Testing HSET..."); println!("Testing HSET...");
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
println!("HSET response: {}", response); println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected '1' but got: {}", response); assert!(response.contains("1"), "Expected '1' but got: {}", response);
// Test HGET // Test HGET
println!("Testing HGET..."); println!("Testing HGET...");
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
println!("HGET response: {}", response); println!("HGET response: {}", response);
assert!(response.contains("value1"), "Expected 'value1' but got: {}", response); assert!(
} response.contains("value1"),
"Expected 'value1' but got: {}",
response
);
}

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -7,11 +7,11 @@ use tokio::time::sleep;
#[tokio::test] #[tokio::test]
async fn debug_hset_return_value() { async fn debug_hset_return_value() {
let test_dir = "/tmp/herodb_debug_hset_return"; let test_dir = "/tmp/herodb_debug_hset_return";
// Clean up any existing test data // Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir); let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
dir: test_dir.to_string(), dir: test_dir.to_string(),
port: 16390, port: 16390,
@@ -20,38 +20,42 @@ async fn debug_hset_return_value() {
encryption_key: None, encryption_key: None,
backend: herodb::options::BackendType::Redb, backend: herodb::options::BackendType::Redb,
}; };
let mut server = Server::new(option).await; let mut server = Server::new(option).await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind("127.0.0.1:16390") let listener = tokio::net::TcpListener::bind("127.0.0.1:16390")
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
// Connect and test HSET // Connect and test HSET
let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap(); let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap();
// Send HSET command // Send HSET command
let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n"; let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n";
stream.write_all(cmd.as_bytes()).await.unwrap(); stream.write_all(cmd.as_bytes()).await.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
println!("HSET response: {}", response); println!("HSET response: {}", response);
println!("Response bytes: {:?}", &buffer[..n]); println!("Response bytes: {:?}", &buffer[..n]);
// Check if response contains "1" // Check if response contains "1"
assert!(response.contains("1"), "Expected response to contain '1', got: {}", response); assert!(
} response.contains("1"),
"Expected response to contain '1', got: {}",
response
);
}

View File

@@ -1,12 +1,15 @@
use herodb::protocol::Protocol;
use herodb::cmd::Cmd; use herodb::cmd::Cmd;
use herodb::protocol::Protocol;
#[test] #[test]
fn test_protocol_parsing() { fn test_protocol_parsing() {
// Test TYPE command parsing // Test TYPE command parsing
let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n"; let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n";
println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n")); println!(
"Parsing TYPE command: {}",
type_cmd.replace("\r\n", "\\r\\n")
);
match Protocol::from(type_cmd) { match Protocol::from(type_cmd) {
Ok((protocol, _)) => { Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol); println!("Protocol parsed successfully: {:?}", protocol);
@@ -17,11 +20,14 @@ fn test_protocol_parsing() {
} }
Err(e) => println!("Protocol parsing failed: {:?}", e), Err(e) => println!("Protocol parsing failed: {:?}", e),
} }
// Test HEXISTS command parsing // Test HEXISTS command parsing
let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n"; let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n";
println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n")); println!(
"\nParsing HEXISTS command: {}",
hexists_cmd.replace("\r\n", "\\r\\n")
);
match Protocol::from(hexists_cmd) { match Protocol::from(hexists_cmd) {
Ok((protocol, _)) => { Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol); println!("Protocol parsed successfully: {:?}", protocol);
@@ -32,4 +38,4 @@ fn test_protocol_parsing() {
} }
Err(e) => println!("Protocol parsing failed: {:?}", e), Err(e) => println!("Protocol parsing failed: {:?}", e),
} }
} }

View File

@@ -81,13 +81,13 @@ fn setup_server() -> (ServerProcessGuard, u16) {
]) ])
.spawn() .spawn()
.expect("Failed to start server process"); .expect("Failed to start server process");
// Create a new guard that also owns the test directory path // Create a new guard that also owns the test directory path
let guard = ServerProcessGuard { let guard = ServerProcessGuard {
process: child, process: child,
test_dir, test_dir,
}; };
// Give the server time to build and start (cargo run may compile first) // Give the server time to build and start (cargo run may compile first)
std::thread::sleep(Duration::from_millis(2500)); std::thread::sleep(Duration::from_millis(2500));
@@ -206,7 +206,9 @@ async fn test_expiration(conn: &mut Connection) {
async fn test_scan_operations(conn: &mut Connection) { async fn test_scan_operations(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
for i in 0..5 { for i in 0..5 {
let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap(); let _: () = conn
.set(format!("key{}", i), format!("value{}", i))
.unwrap();
} }
let result: (u64, Vec<String>) = redis::cmd("SCAN") let result: (u64, Vec<String>) = redis::cmd("SCAN")
.arg(0) .arg(0)
@@ -253,7 +255,9 @@ async fn test_scan_with_count(conn: &mut Connection) {
async fn test_hscan_operations(conn: &mut Connection) { async fn test_hscan_operations(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
for i in 0..3 { for i in 0..3 {
let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap(); let _: () = conn
.hset("testhash", format!("field{}", i), format!("value{}", i))
.unwrap();
} }
let result: (u64, Vec<String>) = redis::cmd("HSCAN") let result: (u64, Vec<String>) = redis::cmd("HSCAN")
.arg("testhash") .arg("testhash")
@@ -273,8 +277,16 @@ async fn test_hscan_operations(conn: &mut Connection) {
async fn test_transaction_operations(conn: &mut Connection) { async fn test_transaction_operations(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap(); let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap(); let _: () = redis::cmd("SET")
let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap(); .arg("key1")
.arg("value1")
.query(conn)
.unwrap();
let _: () = redis::cmd("SET")
.arg("key2")
.arg("value2")
.query(conn)
.unwrap();
let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap(); let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap();
let result: String = conn.get("key1").unwrap(); let result: String = conn.get("key1").unwrap();
assert_eq!(result, "value1"); assert_eq!(result, "value1");
@@ -286,7 +298,11 @@ async fn test_transaction_operations(conn: &mut Connection) {
async fn test_discard_transaction(conn: &mut Connection) { async fn test_discard_transaction(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap(); let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap(); let _: () = redis::cmd("SET")
.arg("discard")
.arg("value")
.query(conn)
.unwrap();
let _: () = redis::cmd("DISCARD").query(conn).unwrap(); let _: () = redis::cmd("DISCARD").query(conn).unwrap();
let result: Option<String> = conn.get("discard").unwrap(); let result: Option<String> = conn.get("discard").unwrap();
assert_eq!(result, None); assert_eq!(result, None);
@@ -306,7 +322,6 @@ async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
} }
async fn test_info_command(conn: &mut Connection) { async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
let result: String = redis::cmd("INFO").query(conn).unwrap(); let result: String = redis::cmd("INFO").query(conn).unwrap();

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -8,14 +8,14 @@ use tokio::time::sleep;
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379); static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name); let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up and create test directory // Clean up and create test directory
let _ = std::fs::remove_dir_all(&test_dir); let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
dir: test_dir, dir: test_dir,
port, port,
@@ -24,7 +24,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encryption_key: None, encryption_key: None,
backend: herodb::options::BackendType::Redb, backend: herodb::options::BackendType::Redb,
}; };
let server = Server::new(option).await; let server = Server::new(option).await;
(server, port) (server, port)
} }
@@ -47,7 +47,7 @@ async fn connect_to_server(port: u16) -> TcpStream {
// Helper function to send command and get response // Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String { async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap(); stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string() String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -56,22 +56,22 @@ async fn send_command(stream: &mut TcpStream, command: &str) -> String {
#[tokio::test] #[tokio::test]
async fn test_basic_ping() { async fn test_basic_ping() {
let (mut server, port) = start_test_server("ping").await; let (mut server, port) = start_test_server("ping").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await; let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG")); assert!(response.contains("PONG"));
@@ -80,40 +80,44 @@ async fn test_basic_ping() {
#[tokio::test] #[tokio::test]
async fn test_string_operations() { async fn test_string_operations() {
let (mut server, port) = start_test_server("string").await; let (mut server, port) = start_test_server("string").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test SET // Test SET
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
)
.await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test GET // Test GET
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value")); assert!(response.contains("value"));
// Test GET non-existent key // Test GET non-existent key
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("$-1")); // NULL response assert!(response.contains("$-1")); // NULL response
// Test DEL // Test DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test GET after DEL // Test GET after DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("$-1")); // NULL response assert!(response.contains("$-1")); // NULL response
@@ -122,33 +126,37 @@ async fn test_string_operations() {
#[tokio::test] #[tokio::test]
async fn test_incr_operations() { async fn test_incr_operations() {
let (mut server, port) = start_test_server("incr").await; let (mut server, port) = start_test_server("incr").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test INCR on non-existent key // Test INCR on non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test INCR on existing key // Test INCR on existing key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("2")); assert!(response.contains("2"));
// Test INCR on string value (should fail) // Test INCR on string value (should fail)
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await; send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await;
assert!(response.contains("ERR")); assert!(response.contains("ERR"));
} }
@@ -156,63 +164,83 @@ async fn test_incr_operations() {
#[tokio::test] #[tokio::test]
async fn test_hash_operations() { async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash").await; let (mut server, port) = start_test_server("hash").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test HSET // Test HSET
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
assert!(response.contains("1")); // 1 new field assert!(response.contains("1")); // 1 new field
// Test HGET // Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("value1")); assert!(response.contains("value1"));
// Test HSET multiple fields // Test HSET multiple fields
let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await; let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await;
assert!(response.contains("2")); // 2 new fields assert!(response.contains("2")); // 2 new fields
// Test HGETALL // Test HGETALL
let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1")); assert!(response.contains("field1"));
assert!(response.contains("value1")); assert!(response.contains("value1"));
assert!(response.contains("field2")); assert!(response.contains("field2"));
assert!(response.contains("value2")); assert!(response.contains("value2"));
// Test HLEN // Test HLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("3")); assert!(response.contains("3"));
// Test HEXISTS // Test HEXISTS
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n",
)
.await;
assert!(response.contains("0")); assert!(response.contains("0"));
// Test HDEL // Test HDEL
let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HKEYS // Test HKEYS
let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field2")); assert!(response.contains("field2"));
assert!(response.contains("field3")); assert!(response.contains("field3"));
assert!(!response.contains("field1")); // Should be deleted assert!(!response.contains("field1")); // Should be deleted
// Test HVALS // Test HVALS
let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("value2")); assert!(response.contains("value2"));
@@ -222,46 +250,50 @@ async fn test_hash_operations() {
#[tokio::test] #[tokio::test]
async fn test_expiration() { async fn test_expiration() {
let (mut server, port) = start_test_server("expiration").await; let (mut server, port) = start_test_server("expiration").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test SETEX (expire in 1 second) // Test SETEX (expire in 1 second)
let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await; let response = send_command(
&mut stream,
"*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n",
)
.await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test TTL // Test TTL
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds
// Test EXISTS // Test EXISTS
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Wait for expiration // Wait for expiration
sleep(Duration::from_millis(1100)).await; sleep(Duration::from_millis(1100)).await;
// Test GET after expiration // Test GET after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("$-1")); // Should be NULL assert!(response.contains("$-1")); // Should be NULL
// Test TTL after expiration // Test TTL after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("-2")); // Key doesn't exist assert!(response.contains("-2")); // Key doesn't exist
// Test EXISTS after expiration // Test EXISTS after expiration
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await; let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("0")); assert!(response.contains("0"));
@@ -270,33 +302,37 @@ async fn test_expiration() {
#[tokio::test] #[tokio::test]
async fn test_scan_operations() { async fn test_scan_operations() {
let (mut server, port) = start_test_server("scan").await; let (mut server, port) = start_test_server("scan").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Set up test data // Set up test data
for i in 0..5 { for i in 0..5 {
let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i); let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await; send_command(&mut stream, &cmd).await;
} }
// Test SCAN // Test SCAN
let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await; let response = send_command(
&mut stream,
"*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n",
)
.await;
assert!(response.contains("key")); assert!(response.contains("key"));
// Test KEYS // Test KEYS
let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await;
assert!(response.contains("key0")); assert!(response.contains("key0"));
@@ -306,29 +342,32 @@ async fn test_scan_operations() {
#[tokio::test] #[tokio::test]
async fn test_hscan_operations() { async fn test_hscan_operations() {
let (mut server, port) = start_test_server("hscan").await; let (mut server, port) = start_test_server("hscan").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Set up hash data // Set up hash data
for i in 0..3 { for i in 0..3 {
let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i); let cmd = format!(
"*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n",
i, i
);
send_command(&mut stream, &cmd).await; send_command(&mut stream, &cmd).await;
} }
// Test HSCAN // Test HSCAN
let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await; let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field")); assert!(response.contains("field"));
@@ -338,42 +377,50 @@ async fn test_hscan_operations() {
#[tokio::test] #[tokio::test]
async fn test_transaction_operations() { async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transaction").await; let (mut server, port) = start_test_server("transaction").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test MULTI // Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await; let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test queued commands // Test queued commands
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n",
)
.await;
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n",
)
.await;
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
// Test EXEC // Test EXEC
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await; let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("OK")); // Should contain results of executed commands assert!(response.contains("OK")); // Should contain results of executed commands
// Verify commands were executed // Verify commands were executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
assert!(response.contains("value1")); assert!(response.contains("value1"));
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await;
assert!(response.contains("value2")); assert!(response.contains("value2"));
} }
@@ -381,35 +428,39 @@ async fn test_transaction_operations() {
#[tokio::test] #[tokio::test]
async fn test_discard_transaction() { async fn test_discard_transaction() {
let (mut server, port) = start_test_server("discard").await; let (mut server, port) = start_test_server("discard").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test MULTI // Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await; let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test queued command // Test queued command
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n",
)
.await;
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
// Test DISCARD // Test DISCARD
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await; let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Verify command was not executed // Verify command was not executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await; let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await;
assert!(response.contains("$-1")); // Should be NULL assert!(response.contains("$-1")); // Should be NULL
@@ -418,33 +469,41 @@ async fn test_discard_transaction() {
#[tokio::test] #[tokio::test]
async fn test_type_command() { async fn test_type_command() {
let (mut server, port) = start_test_server("type").await; let (mut server, port) = start_test_server("type").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test string type // Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
assert!(response.contains("string")); assert!(response.contains("string"));
// Test hash type // Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
assert!(response.contains("hash")); assert!(response.contains("hash"));
// Test non-existent key // Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("none")); assert!(response.contains("none"));
@@ -453,30 +512,38 @@ async fn test_type_command() {
#[tokio::test] #[tokio::test]
async fn test_config_commands() { async fn test_config_commands() {
let (mut server, port) = start_test_server("config").await; let (mut server, port) = start_test_server("config").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test CONFIG GET databases // Test CONFIG GET databases
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n",
)
.await;
assert!(response.contains("databases")); assert!(response.contains("databases"));
assert!(response.contains("16")); assert!(response.contains("16"));
// Test CONFIG GET dir // Test CONFIG GET dir
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n",
)
.await;
assert!(response.contains("dir")); assert!(response.contains("dir"));
assert!(response.contains("/tmp/herodb_test_config")); assert!(response.contains("/tmp/herodb_test_config"));
} }
@@ -484,27 +551,27 @@ async fn test_config_commands() {
#[tokio::test] #[tokio::test]
async fn test_info_command() { async fn test_info_command() {
let (mut server, port) = start_test_server("info").await; let (mut server, port) = start_test_server("info").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test INFO // Test INFO
let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await; let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await;
assert!(response.contains("redis_version")); assert!(response.contains("redis_version"));
// Test INFO replication // Test INFO replication
let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await;
assert!(response.contains("role:master")); assert!(response.contains("role:master"));
@@ -513,36 +580,44 @@ async fn test_info_command() {
#[tokio::test] #[tokio::test]
async fn test_error_handling() { async fn test_error_handling() {
let (mut server, port) = start_test_server("error").await; let (mut server, port) = start_test_server("error").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test WRONGTYPE error - try to use hash command on string // Test WRONGTYPE error - try to use hash command on string
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; send_command(
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await; &mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n",
)
.await;
assert!(response.contains("WRONGTYPE")); assert!(response.contains("WRONGTYPE"));
// Test unknown command // Test unknown command
let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await; let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await;
assert!(response.contains("unknown cmd") || response.contains("ERR")); assert!(response.contains("unknown cmd") || response.contains("ERR"));
// Test EXEC without MULTI // Test EXEC without MULTI
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await; let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("ERR")); assert!(response.contains("ERR"));
// Test DISCARD without MULTI // Test DISCARD without MULTI
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await; let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("ERR")); assert!(response.contains("ERR"));
@@ -551,29 +626,37 @@ async fn test_error_handling() {
#[tokio::test] #[tokio::test]
async fn test_list_operations() { async fn test_list_operations() {
let (mut server, port) = start_test_server("list").await; let (mut server, port) = start_test_server("list").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test LPUSH // Test LPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n",
)
.await;
assert!(response.contains("2")); // 2 elements assert!(response.contains("2")); // 2 elements
// Test RPUSH // Test RPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n",
)
.await;
assert!(response.contains("4")); // 4 elements assert!(response.contains("4")); // 4 elements
// Test LLEN // Test LLEN
@@ -581,29 +664,52 @@ async fn test_list_operations() {
assert!(response.contains("4")); assert!(response.contains("4"));
// Test LRANGE // Test LRANGE
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await; let response = send_command(
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"); &mut stream,
"*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n",
)
.await;
assert_eq!(
response,
"*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"
);
// Test LINDEX // Test LINDEX
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n",
)
.await;
assert_eq!(response, "$1\r\nb\r\n"); assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP // Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nb\r\n"); assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP // Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nd\r\n"); assert_eq!(response, "$1\r\nd\r\n");
// Test LREM // Test LREM
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a send_command(
let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await; &mut stream,
"*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n",
)
.await; // list is now a, c, a
let response = send_command(
&mut stream,
"*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test LTRIM // Test LTRIM
let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n",
)
.await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));
} }

View File

@@ -1,23 +1,23 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::time::sleep;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server with clean data directory // Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000); static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000);
// Get a unique port for this test // Get a unique port for this test
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name); let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up any existing test data // Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir); let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
dir: test_dir, dir: test_dir,
port, port,
@@ -26,16 +26,18 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encryption_key: None, encryption_key: None,
backend: herodb::options::BackendType::Redb, backend: herodb::options::BackendType::Redb,
}; };
let server = Server::new(option).await; let server = Server::new(option).await;
(server, port) (server, port)
} }
// Helper function to send Redis command and get response // Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String { async fn send_redis_command(port: u16, command: &str) -> String {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
stream.write_all(command.as_bytes()).await.unwrap(); stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string() String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -44,13 +46,13 @@ async fn send_redis_command(port: u16, command: &str) -> String {
#[tokio::test] #[tokio::test]
async fn test_basic_redis_functionality() { async fn test_basic_redis_functionality() {
let (mut server, port) = start_test_server("basic").await; let (mut server, port) = start_test_server("basic").await;
// Start server in background with timeout // Start server in background with timeout
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
// Accept only a few connections for testing // Accept only a few connections for testing
for _ in 0..10 { for _ in 0..10 {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
@@ -58,68 +60,79 @@ async fn test_basic_redis_functionality() {
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// Test PING // Test PING
let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await; let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG")); assert!(response.contains("PONG"));
// Test SET // Test SET
let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; let response =
send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test GET // Test GET
let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await; let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value")); assert!(response.contains("value"));
// Test HSET // Test HSET
let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; let response = send_redis_command(
port,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HGET // Test HGET
let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await; let response =
send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
assert!(response.contains("value")); assert!(response.contains("value"));
// Test EXISTS // Test EXISTS
let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await; let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test TTL // Test TTL
let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await; let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("-1")); // No expiration assert!(response.contains("-1")); // No expiration
// Test TYPE // Test TYPE
let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await; let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await;
assert!(response.contains("string")); assert!(response.contains("string"));
// Test QUIT to close connection gracefully // Test QUIT to close connection gracefully
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
stream.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes()).await.unwrap(); .await
.unwrap();
stream
.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes())
.await
.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Ensure the stream is closed // Ensure the stream is closed
stream.shutdown().await.unwrap(); stream.shutdown().await.unwrap();
// Stop the server // Stop the server
server_handle.abort(); server_handle.abort();
println!("✅ All basic Redis functionality tests passed!"); println!("✅ All basic Redis functionality tests passed!");
} }
#[tokio::test] #[tokio::test]
async fn test_hash_operations() { async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash_ops").await; let (mut server, port) = start_test_server("hash_ops").await;
// Start server in background with timeout // Start server in background with timeout
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
// Accept only a few connections for testing // Accept only a few connections for testing
for _ in 0..5 { for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
@@ -127,53 +140,57 @@ async fn test_hash_operations() {
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// Test HSET multiple fields // Test HSET multiple fields
let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await; let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("2")); // 2 new fields assert!(response.contains("2")); // 2 new fields
// Test HGETALL // Test HGETALL
let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await; let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1")); assert!(response.contains("field1"));
assert!(response.contains("value1")); assert!(response.contains("value1"));
assert!(response.contains("field2")); assert!(response.contains("field2"));
assert!(response.contains("value2")); assert!(response.contains("value2"));
// Test HEXISTS // Test HEXISTS
let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_redis_command(
port,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HLEN // Test HLEN
let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await; let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("2")); assert!(response.contains("2"));
// Test HSCAN // Test HSCAN
let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await; let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field1")); assert!(response.contains("field1"));
assert!(response.contains("value1")); assert!(response.contains("value1"));
assert!(response.contains("field2")); assert!(response.contains("field2"));
assert!(response.contains("value2")); assert!(response.contains("value2"));
// Stop the server // Stop the server
// For hash operations, we don't have a persistent stream, so we'll just abort the server. // For hash operations, we don't have a persistent stream, so we'll just abort the server.
// The server should handle closing its connections. // The server should handle closing its connections.
server_handle.abort(); server_handle.abort();
println!("✅ All hash operations tests passed!"); println!("✅ All hash operations tests passed!");
} }
#[tokio::test] #[tokio::test]
async fn test_transaction_operations() { async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transactions").await; let (mut server, port) = start_test_server("transactions").await;
// Start server in background with timeout // Start server in background with timeout
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
// Accept only a few connections for testing // Accept only a few connections for testing
for _ in 0..5 { for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
@@ -181,49 +198,69 @@ async fn test_transaction_operations() {
} }
} }
}); });
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// Use a single connection for the transaction // Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Test MULTI // Test MULTI
stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap(); stream
.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes())
.await
.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test queued commands // Test queued commands
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap(); stream
.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap(); stream
.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
// Test EXEC // Test EXEC
stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap(); stream
.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); // Should contain array of OK responses assert!(response.contains("OK")); // Should contain array of OK responses
// Verify commands were executed // Verify commands were executed
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes()).await.unwrap(); stream
.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value1")); assert!(response.contains("value1"));
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes()).await.unwrap(); stream
.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value2")); assert!(response.contains("value2"));
// Stop the server // Stop the server
server_handle.abort(); server_handle.abort();
println!("✅ All transaction operations tests passed!"); println!("✅ All transaction operations tests passed!");
} }

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -8,14 +8,14 @@ use tokio::time::sleep;
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500); static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_simple_test_{}", test_name); let test_dir = format!("/tmp/herodb_simple_test_{}", test_name);
// Clean up any existing test data // Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir); let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
dir: test_dir, dir: test_dir,
port, port,
@@ -24,7 +24,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encryption_key: None, encryption_key: None,
backend: herodb::options::BackendType::Redb, backend: herodb::options::BackendType::Redb,
}; };
let server = Server::new(option).await; let server = Server::new(option).await;
(server, port) (server, port)
} }
@@ -32,7 +32,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
// Helper function to send command and get response // Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String { async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap(); stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string() String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -56,22 +56,22 @@ async fn connect_to_server(port: u16) -> TcpStream {
#[tokio::test] #[tokio::test]
async fn test_basic_ping_simple() { async fn test_basic_ping_simple() {
let (mut server, port) = start_test_server("ping").await; let (mut server, port) = start_test_server("ping").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await; let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG")); assert!(response.contains("PONG"));
@@ -80,31 +80,43 @@ async fn test_basic_ping_simple() {
#[tokio::test] #[tokio::test]
async fn test_hset_clean_db() { async fn test_hset_clean_db() {
let (mut server, port) = start_test_server("hset_clean").await; let (mut server, port) = start_test_server("hset_clean").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test HSET - should return 1 for new field // Test HSET - should return 1 for new field
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
println!("HSET response: {}", response); println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response); assert!(
response.contains("1"),
"Expected HSET to return 1, got: {}",
response
);
// Test HGET // Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
println!("HGET response: {}", response); println!("HGET response: {}", response);
assert!(response.contains("value1")); assert!(response.contains("value1"));
} }
@@ -112,73 +124,101 @@ async fn test_hset_clean_db() {
#[tokio::test] #[tokio::test]
async fn test_type_command_simple() { async fn test_type_command_simple() {
let (mut server, port) = start_test_server("type").await; let (mut server, port) = start_test_server("type").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test string type // Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
println!("TYPE string response: {}", response); println!("TYPE string response: {}", response);
assert!(response.contains("string")); assert!(response.contains("string"));
// Test hash type // Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
println!("TYPE hash response: {}", response); println!("TYPE hash response: {}", response);
assert!(response.contains("hash")); assert!(response.contains("hash"));
// Test non-existent key // Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
println!("TYPE noexist response: {}", response); println!("TYPE noexist response: {}", response);
assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response); assert!(
response.contains("none"),
"Expected 'none' for non-existent key, got: {}",
response
);
} }
#[tokio::test] #[tokio::test]
async fn test_hexists_simple() { async fn test_hexists_simple() {
let (mut server, port) = start_test_server("hexists").await; let (mut server, port) = start_test_server("hexists").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = server.handle(stream).await;
} }
} }
}); });
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Set up hash // Set up hash
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
// Test HEXISTS for existing field // Test HEXISTS for existing field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
println!("HEXISTS existing field response: {}", response); println!("HEXISTS existing field response: {}", response);
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HEXISTS for non-existent field // Test HEXISTS for non-existent field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n",
)
.await;
println!("HEXISTS non-existent field response: {}", response); println!("HEXISTS non-existent field response: {}", response);
assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response); assert!(
} response.contains("0"),
"Expected HEXISTS to return 0 for non-existent field, got: {}",
response
);
}

View File

@@ -325,7 +325,11 @@ async fn test_03_scan_and_keys() {
let mut s = connect(port).await; let mut s = connect(port).await;
for i in 0..5 { for i in 0..5 {
let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await; let _ = send_cmd(
&mut s,
&["SET", &format!("key{}", i), &format!("value{}", i)],
)
.await;
} }
let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await; let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await;
@@ -358,7 +362,11 @@ async fn test_04_hashes_suite() {
assert_contains(&h2, "2", "HSET added 2 new fields"); assert_contains(&h2, "2", "HSET added 2 new fields");
// HMGET // HMGET
let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await; let hmg = send_cmd(
&mut s,
&["HMGET", "profile:1", "name", "age", "city", "nope"],
)
.await;
assert_contains(&hmg, "alice", "HMGET name"); assert_contains(&hmg, "alice", "HMGET name");
assert_contains(&hmg, "30", "HMGET age"); assert_contains(&hmg, "30", "HMGET age");
assert_contains(&hmg, "paris", "HMGET city"); assert_contains(&hmg, "paris", "HMGET city");
@@ -392,7 +400,11 @@ async fn test_04_hashes_suite() {
assert_contains(&hnx1, "1", "HSETNX new field -> 1"); assert_contains(&hnx1, "1", "HSETNX new field -> 1");
// HSCAN // HSCAN
let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await; let hscan = send_cmd(
&mut s,
&["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"],
)
.await;
assert_contains(&hscan, "name", "HSCAN matches fields starting with n"); assert_contains(&hscan, "name", "HSCAN matches fields starting with n");
assert_contains(&hscan, "nickname", "HSCAN nickname present"); assert_contains(&hscan, "nickname", "HSCAN nickname present");
@@ -424,13 +436,21 @@ async fn test_05_lists_suite_including_blpop() {
assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b"); assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b");
let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]"); assert_eq_resp(
&lr,
"*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n",
"LRANGE q:jobs 0 -1 should be [b,a,c]",
);
// LTRIM // LTRIM
let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await; let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await;
assert_contains(&ltrim, "OK", "LTRIM OK"); assert_contains(&ltrim, "OK", "LTRIM OK");
let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]"); assert_eq_resp(
&lr_post,
"*2\r\n$1\r\nb\r\n$1\r\na\r\n",
"After LTRIM, list [b,a]",
);
// LREM remove first occurrence of b // LREM remove first occurrence of b
let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await; let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await;
@@ -444,7 +464,11 @@ async fn test_05_lists_suite_including_blpop() {
// LPOP with count on empty -> [] // LPOP with count on empty -> []
let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await; let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await;
assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array"); assert_eq_resp(
&lpop0,
"*0\r\n",
"LPOP with count on empty returns empty array",
);
// BLPOP: block on one client, push from another // BLPOP: block on one client, push from another
let c1 = connect(port).await; let c1 = connect(port).await;
@@ -513,7 +537,7 @@ async fn test_07_age_stateless_suite() {
// naive parse for tests // naive parse for tests
let mut lines = resp.lines(); let mut lines = resp.lines();
let _ = lines.next(); // *2 let _ = lines.next(); // *2
// $len // $len
let _ = lines.next(); let _ = lines.next();
let recip = lines.next().unwrap_or("").to_string(); let recip = lines.next().unwrap_or("").to_string();
let _ = lines.next(); let _ = lines.next();
@@ -548,8 +572,16 @@ async fn test_07_age_stateless_suite() {
let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await; let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await;
assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature"); assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature");
let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await; let v_bad = send_cmd(
assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature"); &mut s,
&["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64],
)
.await;
assert_contains(
&v_bad,
"0",
"VERIFY should be 0 for invalid message/signature",
);
} }
#[tokio::test] #[tokio::test]
@@ -581,7 +613,7 @@ async fn test_08_age_persistent_named_suite() {
skg skg
); );
let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"] ).await; let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"]).await;
let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME"); let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME");
let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await; let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await;
assert_contains(&v1, "1", "VERIFYNAME valid => 1"); assert_contains(&v1, "1", "VERIFYNAME valid => 1");
@@ -597,60 +629,69 @@ async fn test_08_age_persistent_named_suite() {
#[tokio::test] #[tokio::test]
async fn test_10_expire_pexpire_persist() { async fn test_10_expire_pexpire_persist() {
let (server, port) = start_test_server("expire_suite").await; let (server, port) = start_test_server("expire_suite").await;
spawn_listener(server, port).await; spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await; sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await; let mut s = connect(port).await;
// EXPIRE: seconds // EXPIRE: seconds
let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await; let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await;
let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await; let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await;
assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)"); assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await; let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert!( assert!(
ttl1.contains("1") || ttl1.contains("0"), ttl1.contains("1") || ttl1.contains("0"),
"TTL exp:s should be 1 or 0, got: {}", "TTL exp:s should be 1 or 0, got: {}",
ttl1 ttl1
); );
sleep(Duration::from_millis(1100)).await; sleep(Duration::from_millis(1100)).await;
let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await; let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await;
assert_contains(&get_after, "$-1", "GET after expiry should be Null"); assert_contains(&get_after, "$-1", "GET after expiry should be Null");
let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await; let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert_contains(&ttl_after, "-2", "TTL after expiry -> -2"); assert_contains(&ttl_after, "-2", "TTL after expiry -> -2");
let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await; let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await;
assert_contains(&exists_after, "0", "EXISTS after expiry -> 0"); assert_contains(&exists_after, "0", "EXISTS after expiry -> 0");
// PEXPIRE: milliseconds // PEXPIRE: milliseconds
let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await; let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await;
let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await; let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await;
assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)"); assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)");
let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await; let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await;
assert!( assert!(
ttl_ms1.contains("1") || ttl_ms1.contains("0"), ttl_ms1.contains("1") || ttl_ms1.contains("0"),
"TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}", "TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}",
ttl_ms1 ttl_ms1
); );
sleep(Duration::from_millis(1600)).await; sleep(Duration::from_millis(1600)).await;
let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await; let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await;
assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0"); assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0");
// PERSIST: remove expiration // PERSIST: remove expiration
let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await; let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await;
let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await; let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await;
let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await; let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert!( assert!(
ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"), ttl_pre.contains("5")
"TTL exp:persist should be >=0 before persist, got: {}", || ttl_pre.contains("4")
ttl_pre || ttl_pre.contains("3")
); || ttl_pre.contains("2")
let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; || ttl_pre.contains("1")
assert_contains(&persist1, "1", "PERSIST should remove expiration"); || ttl_pre.contains("0"),
let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await; "TTL exp:persist should be >=0 before persist, got: {}",
assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)"); ttl_pre
// Second persist should return 0 (nothing to remove) );
let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); assert_contains(&persist1, "1", "PERSIST should remove expiration");
let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)");
// Second persist should return 0 (nothing to remove)
let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(
&persist2,
"0",
"PERSIST again -> 0 (no expiration to remove)",
);
} }
#[tokio::test] #[tokio::test]
@@ -663,7 +704,11 @@ async fn test_11_set_with_options() {
// SET with GET on non-existing key -> returns Null, sets value // SET with GET on non-existing key -> returns Null, sets value
let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await; let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await;
assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist"); assert_contains(
&set_get1,
"$-1",
"SET s1 v1 GET returns Null when key didn't exist",
);
let g1 = send_cmd(&mut s, &["GET", "s1"]).await; let g1 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g1, "v1", "GET s1 after first SET"); assert_contains(&g1, "v1", "GET s1 after first SET");
@@ -707,42 +752,42 @@ async fn test_11_set_with_options() {
#[tokio::test] #[tokio::test]
async fn test_09_mget_mset_and_variadic_exists_del() { async fn test_09_mget_mset_and_variadic_exists_del() {
let (server, port) = start_test_server("mget_mset_variadic").await; let (server, port) = start_test_server("mget_mset_variadic").await;
spawn_listener(server, port).await; spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await; sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await; let mut s = connect(port).await;
// MSET multiple keys // MSET multiple keys
let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await; let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await;
assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK"); assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK");
// MGET should return values and Null for missing // MGET should return values and Null for missing
let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await; let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await;
// Expect an array with 4 entries; verify payloads // Expect an array with 4 entries; verify payloads
assert_contains(&mget, "v1", "MGET k1"); assert_contains(&mget, "v1", "MGET k1");
assert_contains(&mget, "v2", "MGET k2"); assert_contains(&mget, "v2", "MGET k2");
assert_contains(&mget, "v3", "MGET k3"); assert_contains(&mget, "v3", "MGET k3");
assert_contains(&mget, "$-1", "MGET missing returns Null"); assert_contains(&mget, "$-1", "MGET missing returns Null");
// EXISTS variadic: count how many exist // EXISTS variadic: count how many exist
let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await; let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await;
// Server returns SimpleString numeric, e.g. +2 // Server returns SimpleString numeric, e.g. +2
assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2"); assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2");
// DEL variadic: delete multiple keys, return count deleted // DEL variadic: delete multiple keys, return count deleted
let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await; let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await;
assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2"); assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2");
// Verify deletion // Verify deletion
let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await; let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await;
assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0"); assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0");
// MGET after deletion should include Nulls for deleted keys // MGET after deletion should include Nulls for deleted keys
let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await; let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await;
assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null"); assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null");
assert_contains(&mget_after, "v2", "MGET k2 remains"); assert_contains(&mget_after, "v2", "MGET k2 remains");
assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null"); assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null");
} }
#[tokio::test] #[tokio::test]
async fn test_12_hash_incr() { async fn test_12_hash_incr() {
@@ -862,9 +907,16 @@ async fn test_14_expireat_pexpireat() {
let mut s = connect(port).await; let mut s = connect(port).await;
// EXPIREAT: seconds since epoch // EXPIREAT: seconds since epoch
let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await; let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await;
let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await; let exat = send_cmd(
&mut s,
&["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)],
)
.await;
assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)"); assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await; let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await;
assert!( assert!(
@@ -874,12 +926,23 @@ async fn test_14_expireat_pexpireat() {
); );
sleep(Duration::from_millis(1200)).await; sleep(Duration::from_millis(1200)).await;
let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await; let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await;
assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0"); assert_contains(
&exists_after_exat,
"0",
"EXISTS exp:at:s after EXPIREAT expiry -> 0",
);
// PEXPIREAT: milliseconds since epoch // PEXPIREAT: milliseconds since epoch
let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64; let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await; let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await;
let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await; let pexat = send_cmd(
&mut s,
&["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)],
)
.await;
assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)"); assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)");
let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await; let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await;
assert!( assert!(
@@ -889,5 +952,9 @@ async fn test_14_expireat_pexpireat() {
); );
sleep(Duration::from_millis(600)).await; sleep(Duration::from_millis(600)).await;
let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await; let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await;
assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0"); assert_contains(
} &exists_after_pexat,
"0",
"EXISTS exp:at:ms after PEXPIREAT expiry -> 0",
);
}