Format rust files

Signed-off-by: Lee Smet <lee.smet@hotmail.com>
This commit is contained in:
Lee Smet
2025-08-25 11:16:25 +02:00
parent e9675aafed
commit ff0659b933
25 changed files with 2267 additions and 1318 deletions

View File

@@ -16,7 +16,9 @@ fn read_reply(s: &mut TcpStream) -> String {
} }
fn parse_two_bulk(reply: &str) -> Option<(String, String)> { fn parse_two_bulk(reply: &str) -> Option<(String, String)> {
let mut lines = reply.split("\r\n"); let mut lines = reply.split("\r\n");
if lines.next()? != "*2" { return None; } if lines.next()? != "*2" {
return None;
}
let _n = lines.next()?; let _n = lines.next()?;
let a = lines.next()?.to_string(); let a = lines.next()?.to_string();
let _m = lines.next()?; let _m = lines.next()?;
@@ -26,13 +28,17 @@ fn parse_two_bulk(reply: &str) -> Option<(String,String)> {
fn parse_bulk(reply: &str) -> Option<String> { fn parse_bulk(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n"); let mut lines = reply.split("\r\n");
let hdr = lines.next()?; let hdr = lines.next()?;
if !hdr.starts_with('$') { return None; } if !hdr.starts_with('$') {
return None;
}
Some(lines.next()?.to_string()) Some(lines.next()?.to_string())
} }
fn parse_simple(reply: &str) -> Option<String> { fn parse_simple(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n"); let mut lines = reply.split("\r\n");
let hdr = lines.next()?; let hdr = lines.next()?;
if !hdr.starts_with('+') { return None; } if !hdr.starts_with('+') {
return None;
}
Some(hdr[1..].to_string()) Some(hdr[1..].to_string())
} }
@@ -45,31 +51,37 @@ fn main() {
let mut s = TcpStream::connect(addr).expect("connect"); let mut s = TcpStream::connect(addr).expect("connect");
// Generate & persist X25519 enc keys under name "alice" // Generate & persist X25519 enc keys under name "alice"
s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap(); s.write_all(arr(&["age", "keygen", "alice"]).as_bytes())
.unwrap();
let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc"); let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc");
// Generate & persist Ed25519 signing key under name "signer" // Generate & persist Ed25519 signing key under name "signer"
s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap(); s.write_all(arr(&["age", "signkeygen", "signer"]).as_bytes())
.unwrap();
let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign"); let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign");
// Encrypt by name // Encrypt by name
let msg = "hello from persistent keys"; let msg = "hello from persistent keys";
s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap(); s.write_all(arr(&["age", "encryptname", "alice", msg]).as_bytes())
.unwrap();
let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64"); let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64");
println!("ciphertext b64: {}", ct_b64); println!("ciphertext b64: {}", ct_b64);
// Decrypt by name // Decrypt by name
s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap(); s.write_all(arr(&["age", "decryptname", "alice", &ct_b64]).as_bytes())
.unwrap();
let pt = parse_bulk(&read_reply(&mut s)).expect("pt"); let pt = parse_bulk(&read_reply(&mut s)).expect("pt");
assert_eq!(pt, msg); assert_eq!(pt, msg);
println!("decrypted ok"); println!("decrypted ok");
// Sign by name // Sign by name
s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap(); s.write_all(arr(&["age", "signname", "signer", msg]).as_bytes())
.unwrap();
let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64"); let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64");
// Verify by name // Verify by name
s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap(); s.write_all(arr(&["age", "verifyname", "signer", msg, &sig_b64]).as_bytes())
.unwrap();
let ok = parse_simple(&read_reply(&mut s)).expect("verify"); let ok = parse_simple(&read_reply(&mut s)).expect("verify");
assert_eq!(ok, "1"); assert_eq!(ok, "1");
println!("signature verified"); println!("signature verified");

View File

@@ -12,17 +12,17 @@
use std::str::FromStr; use std::str::FromStr;
use secrecy::ExposeSecret;
use age::{Decryptor, Encryptor};
use age::x25519; use age::x25519;
use age::{Decryptor, Encryptor};
use secrecy::ExposeSecret;
use ed25519_dalek::{Signature, Signer, Verifier, SigningKey, VerifyingKey}; use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use base64::{engine::general_purpose::STANDARD as B64, Engine as _}; use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use crate::error::DBError;
use crate::protocol::Protocol; use crate::protocol::Protocol;
use crate::server::Server; use crate::server::Server;
use crate::error::DBError;
// ---------- Internal helpers ---------- // ---------- Internal helpers ----------
@@ -83,8 +83,8 @@ pub fn gen_enc_keypair() -> (String, String) {
} }
pub fn gen_sign_keypair() -> (String, String) { pub fn gen_sign_keypair() -> (String, String) {
use rand::RngCore;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rand::RngCore;
// Generate random 32 bytes for the signing key // Generate random 32 bytes for the signing key
let mut secret_bytes = [0u8; 32]; let mut secret_bytes = [0u8; 32];
@@ -103,14 +103,18 @@ pub fn gen_sign_keypair() -> (String, String) {
/// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext). /// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext).
pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> { pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> {
let recipient = parse_recipient(recipient_str)?; let recipient = parse_recipient(recipient_str)?;
let enc = Encryptor::with_recipients(vec![Box::new(recipient)]) let enc =
.expect("failed to create encryptor"); // Handle Option<Encryptor> Encryptor::with_recipients(vec![Box::new(recipient)]).expect("failed to create encryptor"); // Handle Option<Encryptor>
let mut out = Vec::new(); let mut out = Vec::new();
{ {
use std::io::Write; use std::io::Write;
let mut w = enc.wrap_output(&mut out).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let mut w = enc
w.write_all(msg.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; .wrap_output(&mut out)
w.finish().map_err(|e| AgeWireError::Crypto(e.to_string()))?; .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.write_all(msg.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.finish()
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
} }
Ok(B64.encode(out)) Ok(B64.encode(out))
} }
@@ -118,19 +122,27 @@ pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireErro
/// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String. /// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String.
pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> { pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> {
let id = parse_identity(identity_str)?; let id = parse_identity(identity_str)?;
let ct = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let ct = B64
.decode(ct_b64.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
// The decrypt method returns a Result<StreamReader, DecryptError> // The decrypt method returns a Result<StreamReader, DecryptError>
let mut r = match dec { let mut r = match dec {
Decryptor::Recipients(d) => d.decrypt(std::iter::once(&id as &dyn age::Identity)) Decryptor::Recipients(d) => d
.decrypt(std::iter::once(&id as &dyn age::Identity))
.map_err(|e| AgeWireError::Crypto(e.to_string()))?, .map_err(|e| AgeWireError::Crypto(e.to_string()))?,
Decryptor::Passphrase(_) => return Err(AgeWireError::Crypto("Expected recipients, got passphrase".to_string())), Decryptor::Passphrase(_) => {
return Err(AgeWireError::Crypto(
"Expected recipients, got passphrase".to_string(),
))
}
}; };
let mut pt = Vec::new(); let mut pt = Vec::new();
use std::io::Read; use std::io::Read;
r.read_to_end(&mut pt).map_err(|e| AgeWireError::Crypto(e.to_string()))?; r.read_to_end(&mut pt)
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
String::from_utf8(pt).map_err(|_| AgeWireError::Utf8) String::from_utf8(pt).map_err(|_| AgeWireError::Utf8)
} }
@@ -144,7 +156,9 @@ pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AgeWireEr
/// Verify detached signature (base64) for `msg` with pubkey. /// Verify detached signature (base64) for `msg` with pubkey.
pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> { pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> {
let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?; let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
let sig_bytes = B64.decode(sig_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?; let sig_bytes = B64
.decode(sig_b64.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
if sig_bytes.len() != 64 { if sig_bytes.len() != 64 {
return Err(AgeWireError::SignatureLen); return Err(AgeWireError::SignatureLen);
} }
@@ -155,30 +169,49 @@ pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool
// ---------- Storage helpers ---------- // ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> { fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?; let st = server
.current_storage()
.map_err(|e| AgeWireError::Storage(e.0))?;
st.get(key).map_err(|e| AgeWireError::Storage(e.0)) st.get(key).map_err(|e| AgeWireError::Storage(e.0))
} }
fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> { fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?; let st = server
st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0)) .current_storage()
.map_err(|e| AgeWireError::Storage(e.0))?;
st.set(key.to_string(), val.to_string())
.map_err(|e| AgeWireError::Storage(e.0))
} }
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") } fn enc_pub_key_key(name: &str) -> String {
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") } format!("age:key:{name}")
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") } }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") } fn enc_priv_key_key(name: &str) -> String {
format!("age:privkey:{name}")
}
fn sign_pub_key_key(name: &str) -> String {
format!("age:signpub:{name}")
}
fn sign_priv_key_key(name: &str) -> String {
format!("age:signpriv:{name}")
}
// ---------- Command handlers (RESP Protocol) ---------- // ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness // Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol { pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = gen_enc_keypair(); let (recip, ident) = gen_enc_keypair();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)]) Protocol::Array(vec![
Protocol::BulkString(recip),
Protocol::BulkString(ident),
])
} }
pub async fn cmd_age_gensign() -> Protocol { pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = gen_sign_keypair(); let (verify, secret) = gen_sign_keypair();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)]) Protocol::Array(vec![
Protocol::BulkString(verify),
Protocol::BulkString(secret),
])
} }
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol { pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
@@ -214,16 +247,30 @@ pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> P
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol { pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = gen_enc_keypair(); let (recip, ident) = gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); } if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) {
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); } return e.to_protocol();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)]) }
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) {
return e.to_protocol();
}
Protocol::Array(vec![
Protocol::BulkString(recip),
Protocol::BulkString(ident),
])
} }
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol { pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = gen_sign_keypair(); let (verify, secret) = gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); } if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) {
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); } return e.to_protocol();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)]) }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) {
return e.to_protocol();
}
Protocol::Array(vec![
Protocol::BulkString(verify),
Protocol::BulkString(secret),
])
} }
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol { pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
@@ -253,7 +300,9 @@ pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) ->
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol { pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) { let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v, Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(), Ok(None) => {
return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol()
}
Err(e) => return e.to_protocol(), Err(e) => return e.to_protocol(),
}; };
match sign_b64(&sec, message) { match sign_b64(&sec, message) {
@@ -262,10 +311,17 @@ pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Pr
} }
} }
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol { pub async fn cmd_age_verify_name(
server: &Server,
name: &str,
message: &str,
sig_b64: &str,
) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) { let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v, Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(), Ok(None) => {
return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol()
}
Err(e) => return e.to_protocol(), Err(e) => return e.to_protocol(),
}; };
match verify_b64(&pubk, message, sig_b64) { match verify_b64(&pubk, message, sig_b64) {
@@ -277,25 +333,43 @@ pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig
pub async fn cmd_age_list(server: &Server) -> Protocol { pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...] // Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) }; let st = match server.current_storage() {
Ok(s) => s,
Err(e) => return Protocol::err(&e.0),
};
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> { let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?; let keys = st.keys(pat)?;
let mut names: Vec<String> = keys.into_iter() let mut names: Vec<String> = keys
.into_iter()
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string())) .filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
.collect(); .collect();
names.sort(); names.sort();
Ok(names) Ok(names)
}; };
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; let encpub = match pull("age:key:*", "age:key:") {
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; Ok(v) => v,
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; Err(e) => return Protocol::err(&e.0),
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) }; };
let encpriv = match pull("age:privkey:*", "age:privkey:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let signpub = match pull("age:signpub:*", "age:signpub:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let signpriv = match pull("age:signpriv:*", "age:signpriv:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let to_arr = |label: &str, v: Vec<String>| { let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())]; let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect())); out.push(Protocol::Array(
v.into_iter().map(Protocol::BulkString).collect(),
));
Protocol::Array(out) Protocol::Array(out)
}; };

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,8 @@
use std::num::ParseIntError; use std::num::ParseIntError;
use tokio::sync::mpsc;
use redb;
use bincode; use bincode;
use redb;
use tokio::sync::mpsc;
// todo: more error types // todo: more error types
#[derive(Debug)] #[derive(Debug)]

View File

@@ -7,6 +7,6 @@ pub mod protocol;
pub mod search_cmd; // Add this pub mod search_cmd; // Add this
pub mod server; pub mod server;
pub mod storage; pub mod storage;
pub mod storage_trait; // Add this
pub mod storage_sled; // Add this pub mod storage_sled; // Add this
pub mod storage_trait; // Add this
pub mod tantivy_search; pub mod tantivy_search;

View File

@@ -22,7 +22,6 @@ struct Args {
#[arg(long)] #[arg(long)]
debug: bool, debug: bool,
/// Master encryption key for encrypted databases /// Master encryption key for encrypted databases
#[arg(long)] #[arg(long)]
encryption_key: Option<String>, encryption_key: Option<String>,

View File

@@ -92,7 +92,10 @@ impl Protocol {
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> { fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
match protocol.find("\r\n") { match protocol.find("\r\n") {
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])), Some(x) => Ok((
Self::SimpleString(protocol[..x].to_string()),
&protocol[x + 2..],
)),
_ => Err(DBError(format!( _ => Err(DBError(format!(
"[new simple string] unsupported protocol: {:?}", "[new simple string] unsupported protocol: {:?}",
protocol protocol

View File

@@ -3,8 +3,7 @@ use crate::{
protocol::Protocol, protocol::Protocol,
server::Server, server::Server,
tantivy_search::{ tantivy_search::{
TantivySearch, FieldDef, NumericType, IndexConfig, FieldDef, Filter, FilterType, IndexConfig, NumericType, SearchOptions, TantivySearch,
SearchOptions, Filter, FilterType
}, },
}; };
use std::collections::HashMap; use std::collections::HashMap;
@@ -86,9 +85,7 @@ pub async fn ft_create_cmd(
case_sensitive, case_sensitive,
} }
} }
"GEO" => { "GEO" => FieldDef::Geo { stored: true },
FieldDef::Geo { stored: true }
}
_ => { _ => {
return Err(DBError(format!("Unknown field type: {}", field_type))); return Err(DBError(format!("Unknown field type: {}", field_type)));
} }
@@ -101,7 +98,10 @@ pub async fn ft_create_cmd(
let search_path = server.search_index_path(); let search_path = server.search_index_path();
let config = IndexConfig::default(); let config = IndexConfig::default();
println!("Creating search index '{}' at path: {:?}", index_name, search_path); println!(
"Creating search index '{}' at path: {:?}",
index_name, search_path
);
println!("Field definitions: {:?}", field_definitions); println!("Field definitions: {:?}", field_definitions);
let search_index = TantivySearch::new_with_schema( let search_index = TantivySearch::new_with_schema(
@@ -129,7 +129,8 @@ pub async fn ft_add_cmd(
) -> Result<Protocol, DBError> { ) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let search_index = indexes.get(&index_name) let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
search_index.add_document_with_fields(&doc_id, fields)?; search_index.add_document_with_fields(&doc_id, fields)?;
@@ -148,16 +149,18 @@ pub async fn ft_search_cmd(
) -> Result<Protocol, DBError> { ) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let search_index = indexes.get(&index_name) let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// Convert filters to search filters // Convert filters to search filters
let search_filters = filters.into_iter().map(|(field, value)| { let search_filters = filters
Filter { .into_iter()
.map(|(field, value)| Filter {
field, field,
filter_type: FilterType::Equals(value), filter_type: FilterType::Equals(value),
} })
}).collect(); .collect();
let options = SearchOptions { let options = SearchOptions {
limit: limit.unwrap_or(10), limit: limit.unwrap_or(10),
@@ -209,7 +212,8 @@ pub async fn ft_del_cmd(
) -> Result<Protocol, DBError> { ) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let _search_index = indexes.get(&index_name) let _search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// For now, return success // For now, return success
@@ -219,13 +223,11 @@ pub async fn ft_del_cmd(
Ok(Protocol::SimpleString("1".to_string())) Ok(Protocol::SimpleString("1".to_string()))
} }
pub async fn ft_info_cmd( pub async fn ft_info_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
server: &Server,
index_name: String,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap(); let indexes = server.search_indexes.read().unwrap();
let search_index = indexes.get(&index_name) let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?; .ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
let info = search_index.get_info()?; let info = search_index.get_info()?;
@@ -243,7 +245,9 @@ pub async fn ft_info_cmd(
response.push(Protocol::BulkString(info.fields.len().to_string())); response.push(Protocol::BulkString(info.fields.len().to_string()));
response.push(Protocol::BulkString("fields".to_string())); response.push(Protocol::BulkString("fields".to_string()));
let fields_str = info.fields.iter() let fields_str = info
.fields
.iter()
.map(|f| format!("{}:{}", f.name, f.field_type)) .map(|f| format!("{}:{}", f.name, f.field_type))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "); .join(", ");
@@ -252,10 +256,7 @@ pub async fn ft_info_cmd(
Ok(Protocol::Array(response)) Ok(Protocol::Array(response))
} }
pub async fn ft_drop_cmd( pub async fn ft_drop_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
server: &Server,
index_name: String,
) -> Result<Protocol, DBError> {
let mut indexes = server.search_indexes.write().unwrap(); let mut indexes = server.search_indexes.write().unwrap();
if indexes.remove(&index_name).is_some() { if indexes.remove(&index_name).is_some() {

View File

@@ -1,10 +1,10 @@
use core::str; use core::str;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::RwLock;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::sync::{Mutex, oneshot}; use tokio::sync::{oneshot, Mutex};
use std::sync::RwLock;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
@@ -65,7 +65,6 @@ impl Server {
return Ok(storage.clone()); return Ok(storage.clone());
} }
// Create new database file // Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone()) let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db)); .join(format!("{}.db", self.selected_db));
@@ -73,27 +72,27 @@ impl Server {
// Ensure the directory exists before creating the database file // Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() { if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| { std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e)) DBError(format!(
"Failed to create directory {}: {}",
parent_dir.display(),
e
))
})?; })?;
} }
println!("Creating new db file: {}", db_file_path.display()); println!("Creating new db file: {}", db_file_path.display());
let storage: Arc<dyn StorageBackend> = match self.option.backend { let storage: Arc<dyn StorageBackend> = match self.option.backend {
options::BackendType::Redb => { options::BackendType::Redb => Arc::new(Storage::new(
Arc::new(Storage::new(
db_file_path, db_file_path,
self.should_encrypt_db(self.selected_db), self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref() self.option.encryption_key.as_deref(),
)?) )?),
} options::BackendType::Sled => Arc::new(SledStorage::new(
options::BackendType::Sled => {
Arc::new(SledStorage::new(
db_file_path, db_file_path,
self.should_encrypt_db(self.selected_db), self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref() self.option.encryption_key.as_deref(),
)?) )?),
}
}; };
cache.insert(self.selected_db, storage.clone()); cache.insert(self.selected_db, storage.clone());
@@ -112,7 +111,12 @@ impl Server {
// ----- BLPOP waiter helpers ----- // ----- BLPOP waiter helpers -----
pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) { pub async fn register_waiter(
&self,
db_index: u64,
key: &str,
side: PopSide,
) -> (u64, oneshot::Receiver<(String, String)>) {
let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel::<(String, String)>(); let (tx, rx) = oneshot::channel::<(String, String)>();
@@ -188,10 +192,7 @@ impl Server {
Ok(()) Ok(())
} }
pub async fn handle( pub async fn handle(&mut self, mut stream: tokio::net::TcpStream) -> Result<(), DBError> {
&mut self,
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
// Accumulate incoming bytes to handle partial RESP frames // Accumulate incoming bytes to handle partial RESP frames
let mut acc = String::new(); let mut acc = String::new();
let mut buf = vec![0u8; 8192]; let mut buf = vec![0u8; 8192];
@@ -228,7 +229,10 @@ impl Server {
acc = remaining.to_string(); acc = remaining.to_string();
if self.option.debug { if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); println!(
"\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m",
cmd, protocol
);
} else { } else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol); println!("got command: {:?}, protocol: {:?}", cmd, protocol);
} }

View File

@@ -12,9 +12,9 @@ use crate::error::DBError;
// Re-export modules // Re-export modules
mod storage_basic; mod storage_basic;
mod storage_extra;
mod storage_hset; mod storage_hset;
mod storage_lists; mod storage_lists;
mod storage_extra;
// Re-export implementations // Re-export implementations
// Note: These imports are used by the impl blocks in the submodules // Note: These imports are used by the impl blocks in the submodules
@@ -28,7 +28,8 @@ const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("string
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes"); const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists"); const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta"); const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data"); const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> =
TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted"); const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration"); const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
@@ -55,7 +56,11 @@ pub struct Storage {
} }
impl Storage { impl Storage {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> { pub fn new(
path: impl AsRef<Path>,
should_encrypt: bool,
master_key: Option<&str>,
) -> Result<Self, DBError> {
let db = Database::create(path)?; let db = Database::create(path)?;
// Create tables if they don't exist // Create tables if they don't exist
@@ -75,14 +80,19 @@ impl Storage {
// Check if database was previously encrypted // Check if database was previously encrypted
let read_txn = db.begin_read()?; let read_txn = db.begin_read()?;
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?; let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false); let was_encrypted = encrypted_table
.get("encrypted")?
.map(|v| v.value() == 1)
.unwrap_or(false);
drop(read_txn); drop(read_txn);
let crypto = if should_encrypt || was_encrypted { let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key { if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes())) Some(CryptoFactory::new(key.as_bytes()))
} else { } else {
return Err(DBError("Encryption requested but no master key provided".to_string())); return Err(DBError(
"Encryption requested but no master key provided".to_string(),
));
} }
} else { } else {
None None
@@ -98,10 +108,7 @@ impl Storage {
write_txn.commit()?; write_txn.commit()?;
} }
Ok(Storage { Ok(Storage { db, crypto })
db,
crypto,
})
} }
pub fn is_encrypted(&self) -> bool { pub fn is_encrypted(&self) -> bool {
@@ -165,11 +172,22 @@ impl StorageBackend for Storage {
self.get_key_type(key) self.get_key_type(key)
} }
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
self.scan(cursor, pattern, count) self.scan(cursor, pattern, count)
} }
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
self.hscan(key, cursor, pattern, count) self.hscan(key, cursor, pattern, count)
} }

View File

@@ -1,6 +1,6 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
pub fn flushdb(&self) -> Result<(), DBError> { pub fn flushdb(&self) -> Result<(), DBError> {
@@ -15,11 +15,17 @@ impl Storage {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way // inefficient, but there is no other way
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = types_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
types_table.remove(key.as_str())?; types_table.remove(key.as_str())?;
} }
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = strings_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
strings_table.remove(key.as_str())?; strings_table.remove(key.as_str())?;
} }
@@ -34,23 +40,35 @@ impl Storage {
for (key, field) in keys { for (key, field) in keys {
hashes_table.remove((key.as_str(), field.as_str()))?; hashes_table.remove((key.as_str(), field.as_str()))?;
} }
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = lists_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
lists_table.remove(key.as_str())?; lists_table.remove(key.as_str())?;
} }
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = streams_meta_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
streams_meta_table.remove(key.as_str())?; streams_meta_table.remove(key.as_str())?;
} }
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| { let keys: Vec<(String, String)> = streams_data_table
.iter()?
.map(|item| {
let binding = item.unwrap(); let binding = item.unwrap();
let (key, field) = binding.0.value(); let (key, field) = binding.0.value();
(key.to_string(), field.to_string()) (key.to_string(), field.to_string())
}).collect(); })
.collect();
for (key, field) in keys { for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?; streams_data_table.remove((key.as_str(), field.as_str()))?;
} }
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect(); let keys: Vec<String> = expiration_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
for key in keys { for key in keys {
expiration_table.remove(key.as_str())?; expiration_table.remove(key.as_str())?;
} }
@@ -163,7 +181,8 @@ impl Storage {
{ {
let mut types_table = write_txn.open_table(TYPES_TABLE)?; let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?; let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?; let mut hashes_table: redb::Table<(&str, &str), &[u8]> =
write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table // Remove from type table

View File

@@ -1,10 +1,15 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { pub fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let strings_table = read_txn.open_table(STRINGS_TABLE)?; let strings_table = read_txn.open_table(STRINGS_TABLE)?;
@@ -50,7 +55,11 @@ impl Storage {
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
@@ -178,8 +187,12 @@ impl Storage {
.unwrap_or(false); .unwrap_or(false);
if is_string { if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; let expires_at_ms: u128 = if ts_secs <= 0 {
expiration_table.insert(key, &((expires_at_ms as u64)))?; 0
} else {
(ts_secs as u128) * 1000
};
expiration_table.insert(key, &(expires_at_ms as u64))?;
applied = true; applied = true;
} }
} }
@@ -201,7 +214,7 @@ impl Storage {
if is_string { if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
expiration_table.insert(key, &((expires_at_ms as u64)))?; expiration_table.insert(key, &(expires_at_ms as u64))?;
applied = true; applied = true;
} }
} }

View File

@@ -1,6 +1,6 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
// ✅ ENCRYPTION APPLIED: Values are encrypted before storage // ✅ ENCRYPTION APPLIED: Values are encrypted before storage
@@ -18,7 +18,8 @@ impl Storage {
}; };
match key_type.as_deref() { match key_type.as_deref() {
Some("hash") | None => { // Proceed if hash or new key Some("hash") | None => {
// Proceed if hash or new key
// Set the type to hash (only if new key or existing hash) // Set the type to hash (only if new key or existing hash)
types_table.insert(key, "hash")?; types_table.insert(key, "hash")?;
@@ -35,7 +36,12 @@ impl Storage {
} }
} }
} }
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value"
.to_string(),
))
}
} }
} }
@@ -62,7 +68,9 @@ impl Storage {
None => Ok(None), None => Ok(None),
} }
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(None), None => Ok(None),
} }
} }
@@ -94,7 +102,9 @@ impl Storage {
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(Vec::new()), None => Ok(Vec::new()),
} }
} }
@@ -138,7 +148,11 @@ impl Storage {
types_table.remove(key)?; types_table.remove(key)?;
} }
} }
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
None => {} // Key does not exist, nothing to delete, return 0 deleted None => {} // Key does not exist, nothing to delete, return 0 deleted
} }
@@ -159,7 +173,9 @@ impl Storage {
let hashes_table = read_txn.open_table(HASHES_TABLE)?; let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some()) Ok(hashes_table.get((key, field))?.is_some())
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(false), None => Ok(false),
} }
} }
@@ -188,7 +204,9 @@ impl Storage {
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(Vec::new()), None => Ok(Vec::new()),
} }
} }
@@ -220,7 +238,9 @@ impl Storage {
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(Vec::new()), None => Ok(Vec::new()),
} }
} }
@@ -249,7 +269,9 @@ impl Storage {
Ok(count) Ok(count)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(0), None => Ok(0),
} }
} }
@@ -281,7 +303,9 @@ impl Storage {
Ok(result) Ok(result)
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok(fields.into_iter().map(|_| None).collect()), None => Ok(fields.into_iter().map(|_| None).collect()),
} }
} }
@@ -301,7 +325,8 @@ impl Storage {
}; };
match key_type.as_deref() { match key_type.as_deref() {
Some("hash") | None => { // Proceed if hash or new key Some("hash") | None => {
// Proceed if hash or new key
// Check if field already exists // Check if field already exists
if hashes_table.get((key, field))?.is_none() { if hashes_table.get((key, field))?.is_none() {
// Set the type to hash (only if new key or existing hash) // Set the type to hash (only if new key or existing hash)
@@ -313,7 +338,12 @@ impl Storage {
result = true; result = true;
} }
} }
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value"
.to_string(),
))
}
} }
} }
@@ -322,7 +352,13 @@ impl Storage {
} }
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { pub fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
@@ -367,10 +403,16 @@ impl Storage {
} }
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
None => Ok((0, Vec::new())), None => Ok((0, Vec::new())),
} }
} }

View File

@@ -1,6 +1,6 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*; use super::*;
use crate::error::DBError;
use redb::ReadableTable;
impl Storage { impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage // ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
@@ -248,8 +248,16 @@ impl Storage {
} }
let len = list.len() as i64; let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; let start_idx = if start < 0 {
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
if start_idx > stop_idx || start_idx >= len { if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new()); return Ok(Vec::new());
@@ -298,8 +306,16 @@ impl Storage {
} }
let len = list.len() as i64; let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) }; let start_idx = if start < 0 {
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) }; std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
let mut lists_table = write_txn.open_table(LISTS_TABLE)?; let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if start_idx > stop_idx || start_idx >= len { if start_idx > stop_idx || start_idx >= len {

View File

@@ -1,12 +1,12 @@
// src/storage_sled/mod.rs // src/storage_sled/mod.rs
use std::path::Path; use crate::crypto::CryptoFactory;
use std::sync::Arc;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::error::DBError; use crate::error::DBError;
use crate::storage_trait::StorageBackend; use crate::storage_trait::StorageBackend;
use crate::crypto::CryptoFactory; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
enum ValueType { enum ValueType {
@@ -28,13 +28,22 @@ pub struct SledStorage {
} }
impl SledStorage { impl SledStorage {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> { pub fn new(
path: impl AsRef<Path>,
should_encrypt: bool,
master_key: Option<&str>,
) -> Result<Self, DBError> {
let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?; let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?;
let types = db.open_tree("types").map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?; let types = db
.open_tree("types")
.map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?;
// Check if database was previously encrypted // Check if database was previously encrypted
let encrypted_tree = db.open_tree("encrypted").map_err(|e| DBError(e.to_string()))?; let encrypted_tree = db
let was_encrypted = encrypted_tree.get("encrypted") .open_tree("encrypted")
.map_err(|e| DBError(e.to_string()))?;
let was_encrypted = encrypted_tree
.get("encrypted")
.map_err(|e| DBError(e.to_string()))? .map_err(|e| DBError(e.to_string()))?
.map(|v| v[0] == 1) .map(|v| v[0] == 1)
.unwrap_or(false); .unwrap_or(false);
@@ -43,7 +52,9 @@ impl SledStorage {
if let Some(key) = master_key { if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes())) Some(CryptoFactory::new(key.as_bytes()))
} else { } else {
return Err(DBError("Encryption requested but no master key provided".to_string())); return Err(DBError(
"Encryption requested but no master key provided".to_string(),
));
} }
} else { } else {
None None
@@ -51,7 +62,8 @@ impl SledStorage {
// Mark database as encrypted if enabling encryption // Mark database as encrypted if enabling encryption
if should_encrypt && !was_encrypted { if should_encrypt && !was_encrypted {
encrypted_tree.insert("encrypted", &[1u8]) encrypted_tree
.insert("encrypted", &[1u8])
.map_err(|e| DBError(e.to_string()))?; .map_err(|e| DBError(e.to_string()))?;
encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?; encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?;
} }
@@ -101,7 +113,7 @@ impl SledStorage {
Ok(Some(storage_val)) Ok(Some(storage_val))
} }
None => Ok(None) None => Ok(None),
} }
} }
@@ -109,7 +121,9 @@ impl SledStorage {
let data = bincode::serialize(&storage_val) let data = bincode::serialize(&storage_val)
.map_err(|e| DBError(format!("Serialization error: {}", e)))?; .map_err(|e| DBError(format!("Serialization error: {}", e)))?;
let encrypted = self.encrypt_if_needed(&data)?; let encrypted = self.encrypt_if_needed(&data)?;
self.db.insert(key, encrypted).map_err(|e| DBError(e.to_string()))?; self.db
.insert(key, encrypted)
.map_err(|e| DBError(e.to_string()))?;
// Store type info (unencrypted for efficiency) // Store type info (unencrypted for efficiency)
let type_str = match &storage_val.value { let type_str = match &storage_val.value {
@@ -117,7 +131,9 @@ impl SledStorage {
ValueType::Hash(_) => "hash", ValueType::Hash(_) => "hash",
ValueType::List(_) => "list", ValueType::List(_) => "list",
}; };
self.types.insert(key, type_str.as_bytes()).map_err(|e| DBError(e.to_string()))?; self.types
.insert(key, type_str.as_bytes())
.map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -168,9 +184,9 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::String(s) => Ok(Some(s)), ValueType::String(s) => Ok(Some(s)),
_ => Ok(None) _ => Ok(None),
} },
None => Ok(None) None => Ok(None),
} }
} }
@@ -196,7 +212,9 @@ impl StorageBackend for SledStorage {
fn del(&self, key: String) -> Result<(), DBError> { fn del(&self, key: String) -> Result<(), DBError> {
self.db.remove(&key).map_err(|e| DBError(e.to_string()))?; self.db.remove(&key).map_err(|e| DBError(e.to_string()))?;
self.types.remove(&key).map_err(|e| DBError(e.to_string()))?; self.types
.remove(&key)
.map_err(|e| DBError(e.to_string()))?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -222,7 +240,12 @@ impl StorageBackend for SledStorage {
Ok(keys) Ok(keys)
} }
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
let mut result = Vec::new(); let mut result = Vec::new();
let mut current_cursor = 0u64; let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize; let limit = count.unwrap_or(10) as usize;
@@ -258,7 +281,11 @@ impl StorageBackend for SledStorage {
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
@@ -286,7 +313,7 @@ impl StorageBackend for SledStorage {
if self.get_storage_value(key)?.is_some() { if self.get_storage_value(key)?.is_some() {
match self.types.get(key).map_err(|e| DBError(e.to_string()))? { match self.types.get(key).map_err(|e| DBError(e.to_string()))? {
Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())), Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())),
None => Ok(None) None => Ok(None),
} }
} else { } else {
Ok(None) Ok(None)
@@ -302,7 +329,11 @@ impl StorageBackend for SledStorage {
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
let mut new_fields = 0i64; let mut new_fields = 0i64;
@@ -322,9 +353,9 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.get(field).cloned()), ValueType::Hash(h) => Ok(h.get(field).cloned()),
_ => Ok(None) _ => Ok(None),
} },
None => Ok(None) None => Ok(None),
} }
} }
@@ -332,13 +363,19 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.into_iter().collect()), ValueType::Hash(h) => Ok(h.into_iter().collect()),
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => { ValueType::Hash(h) => {
@@ -365,24 +402,30 @@ impl StorageBackend for SledStorage {
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { 0 } else { current_cursor }; let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
_ => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())) _ => Err(DBError(
} "WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
None => Ok((0, Vec::new())) )),
},
None => Ok((0, Vec::new())),
} }
} }
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> { fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(0) None => return Ok(0),
}; };
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => return Ok(0) _ => return Ok(0),
}; };
let mut deleted = 0i64; let mut deleted = 0i64;
@@ -406,9 +449,9 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.contains_key(field)), ValueType::Hash(h) => Ok(h.contains_key(field)),
_ => Ok(false) _ => Ok(false),
} },
None => Ok(false) None => Ok(false),
} }
} }
@@ -416,9 +459,9 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.keys().cloned().collect()), ValueType::Hash(h) => Ok(h.keys().cloned().collect()),
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
@@ -426,9 +469,9 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.values().cloned().collect()), ValueType::Hash(h) => Ok(h.values().cloned().collect()),
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
@@ -436,21 +479,19 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.len() as i64), ValueType::Hash(h) => Ok(h.len() as i64),
_ => Ok(0) _ => Ok(0),
} },
None => Ok(0) None => Ok(0),
} }
} }
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> { fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => { ValueType::Hash(h) => Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()),
Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()) _ => Ok(fields.into_iter().map(|_| None).collect()),
} },
_ => Ok(fields.into_iter().map(|_| None).collect()) None => Ok(fields.into_iter().map(|_| None).collect()),
}
None => Ok(fields.into_iter().map(|_| None).collect())
} }
} }
@@ -462,7 +503,11 @@ impl StorageBackend for SledStorage {
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
if hash.contains_key(field) { if hash.contains_key(field) {
@@ -484,7 +529,11 @@ impl StorageBackend for SledStorage {
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
for element in elements.into_iter().rev() { for element in elements.into_iter().rev() {
@@ -505,7 +554,11 @@ impl StorageBackend for SledStorage {
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())), _ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
list.extend(elements); list.extend(elements);
@@ -518,12 +571,12 @@ impl StorageBackend for SledStorage {
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> { fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(Vec::new()) None => return Ok(Vec::new()),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(Vec::new()) _ => return Ok(Vec::new()),
}; };
let mut result = Vec::new(); let mut result = Vec::new();
@@ -547,12 +600,12 @@ impl StorageBackend for SledStorage {
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> { fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(Vec::new()) None => return Ok(Vec::new()),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(Vec::new()) _ => return Ok(Vec::new()),
}; };
let mut result = Vec::new(); let mut result = Vec::new();
@@ -576,9 +629,9 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::List(l) => Ok(l.len() as i64), ValueType::List(l) => Ok(l.len() as i64),
_ => Ok(0) _ => Ok(0),
} },
None => Ok(0) None => Ok(0),
} }
} }
@@ -598,9 +651,9 @@ impl StorageBackend for SledStorage {
Ok(None) Ok(None)
} }
} }
_ => Ok(None) _ => Ok(None),
} },
None => Ok(None) None => Ok(None),
} }
} }
@@ -633,21 +686,21 @@ impl StorageBackend for SledStorage {
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec()) Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
} }
_ => Ok(Vec::new()) _ => Ok(Vec::new()),
} },
None => Ok(Vec::new()) None => Ok(Vec::new()),
} }
} }
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> { fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(()) None => return Ok(()),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(()) _ => return Ok(()),
}; };
if list.is_empty() { if list.is_empty() {
@@ -687,12 +740,12 @@ impl StorageBackend for SledStorage {
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> { fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(0) None => return Ok(0),
}; };
let list = match &mut storage_val.value { let list = match &mut storage_val.value {
ValueType::List(l) => l, ValueType::List(l) => l,
_ => return Ok(0) _ => return Ok(0),
}; };
let mut removed = 0i64; let mut removed = 0i64;
@@ -751,14 +804,14 @@ impl StorageBackend for SledStorage {
Ok(-1) // Key exists but has no expiration Ok(-1) // Key exists but has no expiration
} }
} }
None => Ok(-2) // Key does not exist None => Ok(-2), // Key does not exist
} }
} }
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> { fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000); storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000);
@@ -770,7 +823,7 @@ impl StorageBackend for SledStorage {
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> { fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
storage_val.expires_at = Some(Self::now_millis() + ms); storage_val.expires_at = Some(Self::now_millis() + ms);
@@ -782,7 +835,7 @@ impl StorageBackend for SledStorage {
fn persist(&self, key: &str) -> Result<bool, DBError> { fn persist(&self, key: &str) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
if storage_val.expires_at.is_some() { if storage_val.expires_at.is_some() {
@@ -798,10 +851,14 @@ impl StorageBackend for SledStorage {
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> { fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 }; let expires_at_ms: u128 = if ts_secs <= 0 {
0
} else {
(ts_secs as u128) * 1000
};
storage_val.expires_at = Some(expires_at_ms); storage_val.expires_at = Some(expires_at_ms);
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
@@ -811,7 +868,7 @@ impl StorageBackend for SledStorage {
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> { fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(false) None => return Ok(false),
}; };
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };

View File

@@ -15,8 +15,19 @@ pub trait StorageBackend: Send + Sync {
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError>; fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError>;
// Scanning // Scanning
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>; fn scan(
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>; &self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError>;
fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError>;
// Hash operations // Hash operations
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError>; fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError>;

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -47,17 +47,31 @@ async fn debug_hset_simple() {
sleep(Duration::from_millis(200)).await; sleep(Duration::from_millis(200)).await;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Test simple HSET // Test simple HSET
println!("Testing HSET..."); println!("Testing HSET...");
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
println!("HSET response: {}", response); println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected '1' but got: {}", response); assert!(response.contains("1"), "Expected '1' but got: {}", response);
// Test HGET // Test HGET
println!("Testing HGET..."); println!("Testing HGET...");
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
println!("HGET response: {}", response); println!("HGET response: {}", response);
assert!(response.contains("value1"), "Expected 'value1' but got: {}", response); assert!(
response.contains("value1"),
"Expected 'value1' but got: {}",
response
);
} }

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -53,5 +53,9 @@ async fn debug_hset_return_value() {
println!("Response bytes: {:?}", &buffer[..n]); println!("Response bytes: {:?}", &buffer[..n]);
// Check if response contains "1" // Check if response contains "1"
assert!(response.contains("1"), "Expected response to contain '1', got: {}", response); assert!(
response.contains("1"),
"Expected response to contain '1', got: {}",
response
);
} }

View File

@@ -1,11 +1,14 @@
use herodb::protocol::Protocol;
use herodb::cmd::Cmd; use herodb::cmd::Cmd;
use herodb::protocol::Protocol;
#[test] #[test]
fn test_protocol_parsing() { fn test_protocol_parsing() {
// Test TYPE command parsing // Test TYPE command parsing
let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n"; let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n";
println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n")); println!(
"Parsing TYPE command: {}",
type_cmd.replace("\r\n", "\\r\\n")
);
match Protocol::from(type_cmd) { match Protocol::from(type_cmd) {
Ok((protocol, _)) => { Ok((protocol, _)) => {
@@ -20,7 +23,10 @@ fn test_protocol_parsing() {
// Test HEXISTS command parsing // Test HEXISTS command parsing
let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n"; let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n";
println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n")); println!(
"\nParsing HEXISTS command: {}",
hexists_cmd.replace("\r\n", "\\r\\n")
);
match Protocol::from(hexists_cmd) { match Protocol::from(hexists_cmd) {
Ok((protocol, _)) => { Ok((protocol, _)) => {

View File

@@ -206,7 +206,9 @@ async fn test_expiration(conn: &mut Connection) {
async fn test_scan_operations(conn: &mut Connection) { async fn test_scan_operations(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
for i in 0..5 { for i in 0..5 {
let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap(); let _: () = conn
.set(format!("key{}", i), format!("value{}", i))
.unwrap();
} }
let result: (u64, Vec<String>) = redis::cmd("SCAN") let result: (u64, Vec<String>) = redis::cmd("SCAN")
.arg(0) .arg(0)
@@ -253,7 +255,9 @@ async fn test_scan_with_count(conn: &mut Connection) {
async fn test_hscan_operations(conn: &mut Connection) { async fn test_hscan_operations(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
for i in 0..3 { for i in 0..3 {
let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap(); let _: () = conn
.hset("testhash", format!("field{}", i), format!("value{}", i))
.unwrap();
} }
let result: (u64, Vec<String>) = redis::cmd("HSCAN") let result: (u64, Vec<String>) = redis::cmd("HSCAN")
.arg("testhash") .arg("testhash")
@@ -273,8 +277,16 @@ async fn test_hscan_operations(conn: &mut Connection) {
async fn test_transaction_operations(conn: &mut Connection) { async fn test_transaction_operations(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap(); let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap(); let _: () = redis::cmd("SET")
let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap(); .arg("key1")
.arg("value1")
.query(conn)
.unwrap();
let _: () = redis::cmd("SET")
.arg("key2")
.arg("value2")
.query(conn)
.unwrap();
let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap(); let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap();
let result: String = conn.get("key1").unwrap(); let result: String = conn.get("key1").unwrap();
assert_eq!(result, "value1"); assert_eq!(result, "value1");
@@ -286,7 +298,11 @@ async fn test_transaction_operations(conn: &mut Connection) {
async fn test_discard_transaction(conn: &mut Connection) { async fn test_discard_transaction(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap(); let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap(); let _: () = redis::cmd("SET")
.arg("discard")
.arg("value")
.query(conn)
.unwrap();
let _: () = redis::cmd("DISCARD").query(conn).unwrap(); let _: () = redis::cmd("DISCARD").query(conn).unwrap();
let result: Option<String> = conn.get("discard").unwrap(); let result: Option<String> = conn.get("discard").unwrap();
assert_eq!(result, None); assert_eq!(result, None);
@@ -306,7 +322,6 @@ async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
} }
async fn test_info_command(conn: &mut Connection) { async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await; cleanup_keys(conn).await;
let result: String = redis::cmd("INFO").query(conn).unwrap(); let result: String = redis::cmd("INFO").query(conn).unwrap();

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -99,7 +99,11 @@ async fn test_string_operations() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test SET // Test SET
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
)
.await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test GET // Test GET
@@ -148,7 +152,11 @@ async fn test_incr_operations() {
assert!(response.contains("2")); assert!(response.contains("2"));
// Test INCR on string value (should fail) // Test INCR on string value (should fail)
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await; send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await;
assert!(response.contains("ERR")); assert!(response.contains("ERR"));
} }
@@ -174,11 +182,19 @@ async fn test_hash_operations() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test HSET // Test HSET
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
assert!(response.contains("1")); // 1 new field assert!(response.contains("1")); // 1 new field
// Test HGET // Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("value1")); assert!(response.contains("value1"));
// Test HSET multiple fields // Test HSET multiple fields
@@ -197,14 +213,26 @@ async fn test_hash_operations() {
assert!(response.contains("3")); assert!(response.contains("3"));
// Test HEXISTS // Test HEXISTS
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n",
)
.await;
assert!(response.contains("0")); assert!(response.contains("0"));
// Test HDEL // Test HDEL
let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HKEYS // Test HKEYS
@@ -240,7 +268,11 @@ async fn test_expiration() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test SETEX (expire in 1 second) // Test SETEX (expire in 1 second)
let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await; let response = send_command(
&mut stream,
"*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n",
)
.await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test TTL // Test TTL
@@ -294,7 +326,11 @@ async fn test_scan_operations() {
} }
// Test SCAN // Test SCAN
let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await; let response = send_command(
&mut stream,
"*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n",
)
.await;
assert!(response.contains("key")); assert!(response.contains("key"));
// Test KEYS // Test KEYS
@@ -325,7 +361,10 @@ async fn test_hscan_operations() {
// Set up hash data // Set up hash data
for i in 0..3 { for i in 0..3 {
let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i); let cmd = format!(
"*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n",
i, i
);
send_command(&mut stream, &cmd).await; send_command(&mut stream, &cmd).await;
} }
@@ -360,10 +399,18 @@ async fn test_transaction_operations() {
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test queued commands // Test queued commands
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n",
)
.await;
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n",
)
.await;
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
// Test EXEC // Test EXEC
@@ -403,7 +450,11 @@ async fn test_discard_transaction() {
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test queued command // Test queued command
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n",
)
.await;
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
// Test DISCARD // Test DISCARD
@@ -436,12 +487,20 @@ async fn test_type_command() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test string type // Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
assert!(response.contains("string")); assert!(response.contains("string"));
// Test hash type // Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
assert!(response.contains("hash")); assert!(response.contains("hash"));
@@ -471,12 +530,20 @@ async fn test_config_commands() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test CONFIG GET databases // Test CONFIG GET databases
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n",
)
.await;
assert!(response.contains("databases")); assert!(response.contains("databases"));
assert!(response.contains("16")); assert!(response.contains("16"));
// Test CONFIG GET dir // Test CONFIG GET dir
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n",
)
.await;
assert!(response.contains("dir")); assert!(response.contains("dir"));
assert!(response.contains("/tmp/herodb_test_config")); assert!(response.contains("/tmp/herodb_test_config"));
} }
@@ -531,8 +598,16 @@ async fn test_error_handling() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test WRONGTYPE error - try to use hash command on string // Test WRONGTYPE error - try to use hash command on string
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; send_command(
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await; &mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n",
)
.await;
assert!(response.contains("WRONGTYPE")); assert!(response.contains("WRONGTYPE"));
// Test unknown command // Test unknown command
@@ -569,11 +644,19 @@ async fn test_list_operations() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test LPUSH // Test LPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n",
)
.await;
assert!(response.contains("2")); // 2 elements assert!(response.contains("2")); // 2 elements
// Test RPUSH // Test RPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n",
)
.await;
assert!(response.contains("4")); // 4 elements assert!(response.contains("4")); // 4 elements
// Test LLEN // Test LLEN
@@ -581,11 +664,22 @@ async fn test_list_operations() {
assert!(response.contains("4")); assert!(response.contains("4"));
// Test LRANGE // Test LRANGE
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await; let response = send_command(
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"); &mut stream,
"*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n",
)
.await;
assert_eq!(
response,
"*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"
);
// Test LINDEX // Test LINDEX
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n",
)
.await;
assert_eq!(response, "$1\r\nb\r\n"); assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP // Test LPOP
@@ -597,12 +691,24 @@ async fn test_list_operations() {
assert_eq!(response, "$1\r\nd\r\n"); assert_eq!(response, "$1\r\nd\r\n");
// Test LREM // Test LREM
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a send_command(
let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await; &mut stream,
"*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n",
)
.await; // list is now a, c, a
let response = send_command(
&mut stream,
"*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test LTRIM // Test LTRIM
let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n",
)
.await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));

View File

@@ -1,8 +1,8 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::time::sleep;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server with clean data directory // Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Server, u16) {
@@ -33,7 +33,9 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
// Helper function to send Redis command and get response // Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String { async fn send_redis_command(port: u16, command: &str) -> String {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
stream.write_all(command.as_bytes()).await.unwrap(); stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
@@ -66,7 +68,8 @@ async fn test_basic_redis_functionality() {
assert!(response.contains("PONG")); assert!(response.contains("PONG"));
// Test SET // Test SET
let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; let response =
send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test GET // Test GET
@@ -74,11 +77,16 @@ async fn test_basic_redis_functionality() {
assert!(response.contains("value")); assert!(response.contains("value"));
// Test HSET // Test HSET
let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; let response = send_redis_command(
port,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HGET // Test HGET
let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await; let response =
send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
assert!(response.contains("value")); assert!(response.contains("value"));
// Test EXISTS // Test EXISTS
@@ -94,8 +102,13 @@ async fn test_basic_redis_functionality() {
assert!(response.contains("string")); assert!(response.contains("string"));
// Test QUIT to close connection gracefully // Test QUIT to close connection gracefully
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
stream.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes()).await.unwrap(); .await
.unwrap();
stream
.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes())
.await
.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
@@ -142,7 +155,11 @@ async fn test_hash_operations() {
assert!(response.contains("value2")); assert!(response.contains("value2"));
// Test HEXISTS // Test HEXISTS
let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_redis_command(
port,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HLEN // Test HLEN
@@ -185,39 +202,59 @@ async fn test_transaction_operations() {
sleep(Duration::from_millis(100)).await; sleep(Duration::from_millis(100)).await;
// Use a single connection for the transaction // Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap(); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Test MULTI // Test MULTI
stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap(); stream
.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes())
.await
.unwrap();
let mut buffer = [0; 1024]; let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); assert!(response.contains("OK"));
// Test queued commands // Test queued commands
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap(); stream
.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap(); stream
.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED")); assert!(response.contains("QUEUED"));
// Test EXEC // Test EXEC
stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap(); stream
.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); // Should contain array of OK responses assert!(response.contains("OK")); // Should contain array of OK responses
// Verify commands were executed // Verify commands were executed
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes()).await.unwrap(); stream
.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value1")); assert!(response.contains("value1"));
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes()).await.unwrap(); stream
.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes())
.await
.unwrap();
let n = stream.read(&mut buffer).await.unwrap(); let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]); let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value2")); assert!(response.contains("value2"));

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption}; use herodb::{options::DBOption, server::Server};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -99,12 +99,24 @@ async fn test_hset_clean_db() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test HSET - should return 1 for new field // Test HSET - should return 1 for new field
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
println!("HSET response: {}", response); println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response); assert!(
response.contains("1"),
"Expected HSET to return 1, got: {}",
response
);
// Test HGET // Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
println!("HGET response: {}", response); println!("HGET response: {}", response);
assert!(response.contains("value1")); assert!(response.contains("value1"));
} }
@@ -131,13 +143,21 @@ async fn test_type_command_simple() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Test string type // Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
println!("TYPE string response: {}", response); println!("TYPE string response: {}", response);
assert!(response.contains("string")); assert!(response.contains("string"));
// Test hash type // Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await; send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
println!("TYPE hash response: {}", response); println!("TYPE hash response: {}", response);
assert!(response.contains("hash")); assert!(response.contains("hash"));
@@ -145,7 +165,11 @@ async fn test_type_command_simple() {
// Test non-existent key // Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
println!("TYPE noexist response: {}", response); println!("TYPE noexist response: {}", response);
assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response); assert!(
response.contains("none"),
"Expected 'none' for non-existent key, got: {}",
response
);
} }
#[tokio::test] #[tokio::test]
@@ -170,15 +194,31 @@ async fn test_hexists_simple() {
let mut stream = connect_to_server(port).await; let mut stream = connect_to_server(port).await;
// Set up hash // Set up hash
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await; send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
// Test HEXISTS for existing field // Test HEXISTS for existing field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
println!("HEXISTS existing field response: {}", response); println!("HEXISTS existing field response: {}", response);
assert!(response.contains("1")); assert!(response.contains("1"));
// Test HEXISTS for non-existent field // Test HEXISTS for non-existent field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await; let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n",
)
.await;
println!("HEXISTS non-existent field response: {}", response); println!("HEXISTS non-existent field response: {}", response);
assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response); assert!(
response.contains("0"),
"Expected HEXISTS to return 0 for non-existent field, got: {}",
response
);
} }

View File

@@ -325,7 +325,11 @@ async fn test_03_scan_and_keys() {
let mut s = connect(port).await; let mut s = connect(port).await;
for i in 0..5 { for i in 0..5 {
let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await; let _ = send_cmd(
&mut s,
&["SET", &format!("key{}", i), &format!("value{}", i)],
)
.await;
} }
let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await; let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await;
@@ -358,7 +362,11 @@ async fn test_04_hashes_suite() {
assert_contains(&h2, "2", "HSET added 2 new fields"); assert_contains(&h2, "2", "HSET added 2 new fields");
// HMGET // HMGET
let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await; let hmg = send_cmd(
&mut s,
&["HMGET", "profile:1", "name", "age", "city", "nope"],
)
.await;
assert_contains(&hmg, "alice", "HMGET name"); assert_contains(&hmg, "alice", "HMGET name");
assert_contains(&hmg, "30", "HMGET age"); assert_contains(&hmg, "30", "HMGET age");
assert_contains(&hmg, "paris", "HMGET city"); assert_contains(&hmg, "paris", "HMGET city");
@@ -392,7 +400,11 @@ async fn test_04_hashes_suite() {
assert_contains(&hnx1, "1", "HSETNX new field -> 1"); assert_contains(&hnx1, "1", "HSETNX new field -> 1");
// HSCAN // HSCAN
let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await; let hscan = send_cmd(
&mut s,
&["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"],
)
.await;
assert_contains(&hscan, "name", "HSCAN matches fields starting with n"); assert_contains(&hscan, "name", "HSCAN matches fields starting with n");
assert_contains(&hscan, "nickname", "HSCAN nickname present"); assert_contains(&hscan, "nickname", "HSCAN nickname present");
@@ -424,13 +436,21 @@ async fn test_05_lists_suite_including_blpop() {
assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b"); assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b");
let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]"); assert_eq_resp(
&lr,
"*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n",
"LRANGE q:jobs 0 -1 should be [b,a,c]",
);
// LTRIM // LTRIM
let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await; let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await;
assert_contains(&ltrim, "OK", "LTRIM OK"); assert_contains(&ltrim, "OK", "LTRIM OK");
let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await; let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]"); assert_eq_resp(
&lr_post,
"*2\r\n$1\r\nb\r\n$1\r\na\r\n",
"After LTRIM, list [b,a]",
);
// LREM remove first occurrence of b // LREM remove first occurrence of b
let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await; let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await;
@@ -444,7 +464,11 @@ async fn test_05_lists_suite_including_blpop() {
// LPOP with count on empty -> [] // LPOP with count on empty -> []
let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await; let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await;
assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array"); assert_eq_resp(
&lpop0,
"*0\r\n",
"LPOP with count on empty returns empty array",
);
// BLPOP: block on one client, push from another // BLPOP: block on one client, push from another
let c1 = connect(port).await; let c1 = connect(port).await;
@@ -548,8 +572,16 @@ async fn test_07_age_stateless_suite() {
let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await; let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await;
assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature"); assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature");
let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await; let v_bad = send_cmd(
assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature"); &mut s,
&["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64],
)
.await;
assert_contains(
&v_bad,
"0",
"VERIFY should be 0 for invalid message/signature",
);
} }
#[tokio::test] #[tokio::test]
@@ -640,7 +672,12 @@ async fn test_10_expire_pexpire_persist() {
let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await; let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await;
let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await; let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert!( assert!(
ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"), ttl_pre.contains("5")
|| ttl_pre.contains("4")
|| ttl_pre.contains("3")
|| ttl_pre.contains("2")
|| ttl_pre.contains("1")
|| ttl_pre.contains("0"),
"TTL exp:persist should be >=0 before persist, got: {}", "TTL exp:persist should be >=0 before persist, got: {}",
ttl_pre ttl_pre
); );
@@ -650,7 +687,11 @@ async fn test_10_expire_pexpire_persist() {
assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)"); assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)");
// Second persist should return 0 (nothing to remove) // Second persist should return 0 (nothing to remove)
let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await; let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)"); assert_contains(
&persist2,
"0",
"PERSIST again -> 0 (no expiration to remove)",
);
} }
#[tokio::test] #[tokio::test]
@@ -663,7 +704,11 @@ async fn test_11_set_with_options() {
// SET with GET on non-existing key -> returns Null, sets value // SET with GET on non-existing key -> returns Null, sets value
let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await; let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await;
assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist"); assert_contains(
&set_get1,
"$-1",
"SET s1 v1 GET returns Null when key didn't exist",
);
let g1 = send_cmd(&mut s, &["GET", "s1"]).await; let g1 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g1, "v1", "GET s1 after first SET"); assert_contains(&g1, "v1", "GET s1 after first SET");
@@ -862,9 +907,16 @@ async fn test_14_expireat_pexpireat() {
let mut s = connect(port).await; let mut s = connect(port).await;
// EXPIREAT: seconds since epoch // EXPIREAT: seconds since epoch
let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await; let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await;
let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await; let exat = send_cmd(
&mut s,
&["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)],
)
.await;
assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)"); assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await; let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await;
assert!( assert!(
@@ -874,12 +926,23 @@ async fn test_14_expireat_pexpireat() {
); );
sleep(Duration::from_millis(1200)).await; sleep(Duration::from_millis(1200)).await;
let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await; let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await;
assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0"); assert_contains(
&exists_after_exat,
"0",
"EXISTS exp:at:s after EXPIREAT expiry -> 0",
);
// PEXPIREAT: milliseconds since epoch // PEXPIREAT: milliseconds since epoch
let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64; let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await; let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await;
let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await; let pexat = send_cmd(
&mut s,
&["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)],
)
.await;
assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)"); assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)");
let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await; let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await;
assert!( assert!(
@@ -889,5 +952,9 @@ async fn test_14_expireat_pexpireat() {
); );
sleep(Duration::from_millis(600)).await; sleep(Duration::from_millis(600)).await;
let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await; let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await;
assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0"); assert_contains(
&exists_after_pexat,
"0",
"EXISTS exp:at:ms after PEXPIREAT expiry -> 0",
);
} }