Compare commits

..

8 Commits
main ... blpop

8 changed files with 1762 additions and 70 deletions

View File

@ -1,5 +1,7 @@
use crate::{error::DBError, protocol::Protocol, server::Server}; use crate::{error::DBError, protocol::Protocol, server::Server};
use serde::Serialize; use serde::Serialize;
use tokio::time::{timeout, Duration};
use futures::future::select_all;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Cmd { pub enum Cmd {
@ -10,7 +12,12 @@ pub enum Cmd {
Set(String, String), Set(String, String),
SetPx(String, String, u128), SetPx(String, String, u128),
SetEx(String, String, u128), SetEx(String, String, u128),
// Advanced SET with options: (key, value, ex_ms, nx, xx, get)
SetOpts(String, String, Option<u128>, bool, bool, bool),
MGet(Vec<String>),
MSet(Vec<(String, String)>),
Keys, Keys,
DbSize,
ConfigGet(String), ConfigGet(String),
Info(Option<String>), Info(Option<String>),
Del(String), Del(String),
@ -30,19 +37,31 @@ pub enum Cmd {
HLen(String), HLen(String),
HMGet(String, Vec<String>), HMGet(String, Vec<String>),
HSetNx(String, String, String), HSetNx(String, String, String),
HIncrBy(String, String, i64),
HIncrByFloat(String, String, f64),
HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count HScan(String, u64, Option<String>, Option<u64>), // key, cursor, pattern, count
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
Ttl(String), Ttl(String),
Expire(String, i64),
PExpire(String, i64),
ExpireAt(String, i64),
PExpireAt(String, i64),
Persist(String),
Exists(String), Exists(String),
ExistsMulti(Vec<String>),
DelMulti(Vec<String>),
Quit, Quit,
Client(Vec<String>), Client(Vec<String>),
ClientSetName(String), ClientSetName(String),
ClientGetName, ClientGetName,
Command(Vec<String>),
// List commands // List commands
LPush(String, Vec<String>), LPush(String, Vec<String>),
RPush(String, Vec<String>), RPush(String, Vec<String>),
LPop(String, Option<u64>), LPop(String, Option<u64>),
RPop(String, Option<u64>), RPop(String, Option<u64>),
BLPop(Vec<String>, f64),
BRPop(Vec<String>, f64),
LLen(String), LLen(String),
LRem(String, i64, String), LRem(String, i64, String),
LTrim(String, i64, i64), LTrim(String, i64, i64),
@ -90,14 +109,51 @@ impl Cmd {
"ping" => Cmd::Ping, "ping" => Cmd::Ping,
"get" => Cmd::Get(cmd[1].clone()), "get" => Cmd::Get(cmd[1].clone()),
"set" => { "set" => {
if cmd.len() == 5 && cmd[3].to_lowercase() == "px" { if cmd.len() < 3 {
Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) return Err(DBError("wrong number of arguments for SET".to_string()));
} else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" { }
Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap()) let key = cmd[1].clone();
} else if cmd.len() == 3 { let val = cmd[2].clone();
Cmd::Set(cmd[1].clone(), cmd[2].clone())
// Parse optional flags: EX sec | PX ms | NX | XX | GET
let mut ex_ms: Option<u128> = None;
let mut nx = false;
let mut xx = false;
let mut getflag = false;
let mut i = 3;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"ex" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
let secs: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
ex_ms = Some(secs * 1000);
i += 2;
}
"px" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
let ms: u128 = cmd[i + 1].parse().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
ex_ms = Some(ms);
i += 2;
}
"nx" => { nx = true; i += 1; }
"xx" => { xx = true; i += 1; }
"get" => { getflag = true; i += 1; }
_ => {
return Err(DBError(format!("unsupported cmd {:?}", cmd)));
}
}
}
// If no options, keep legacy behavior
if ex_ms.is_none() && !nx && !xx && !getflag {
Cmd::Set(key, val)
} else { } else {
return Err(DBError(format!("unsupported cmd {:?}", cmd))); Cmd::SetOpts(key, val, ex_ms, nx, xx, getflag)
} }
} }
"setex" => { "setex" => {
@ -106,6 +162,24 @@ impl Cmd {
} }
Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap()) Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap())
} }
"mget" => {
if cmd.len() < 2 {
return Err(DBError("wrong number of arguments for MGET command".to_string()));
}
Cmd::MGet(cmd[1..].to_vec())
}
"mset" => {
if cmd.len() < 3 || ((cmd.len() - 1) % 2 != 0) {
return Err(DBError("wrong number of arguments for MSET command".to_string()));
}
let mut pairs = Vec::new();
let mut i = 1;
while i + 1 < cmd.len() {
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::MSet(pairs)
}
"config" => { "config" => {
if cmd.len() != 3 || cmd[1].to_lowercase() != "get" { if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
return Err(DBError(format!("unsupported cmd {:?}", cmd))); return Err(DBError(format!("unsupported cmd {:?}", cmd)));
@ -120,6 +194,12 @@ impl Cmd {
Cmd::Keys Cmd::Keys
} }
} }
"dbsize" => {
if cmd.len() != 1 {
return Err(DBError(format!("wrong number of arguments for DBSIZE command")));
}
Cmd::DbSize
}
"info" => { "info" => {
let section = if cmd.len() == 2 { let section = if cmd.len() == 2 {
Some(cmd[1].clone()) Some(cmd[1].clone())
@ -129,10 +209,14 @@ impl Cmd {
Cmd::Info(section) Cmd::Info(section)
} }
"del" => { "del" => {
if cmd.len() != 2 { if cmd.len() < 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd))); return Err(DBError(format!("wrong number of arguments for DEL command")));
}
if cmd.len() == 2 {
Cmd::Del(cmd[1].clone())
} else {
Cmd::DelMulti(cmd[1..].to_vec())
} }
Cmd::Del(cmd[1].clone())
} }
"type" => { "type" => {
if cmd.len() != 2 { if cmd.len() != 2 {
@ -226,6 +310,20 @@ impl Cmd {
} }
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
} }
"hincrby" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for HINCRBY command")));
}
let delta = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::HIncrBy(cmd[1].clone(), cmd[2].clone(), delta)
}
"hincrbyfloat" => {
if cmd.len() != 4 {
return Err(DBError(format!("wrong number of arguments for HINCRBYFLOAT command")));
}
let delta = cmd[3].parse::<f64>().map_err(|_| DBError("ERR value is not a valid float".to_string()))?;
Cmd::HIncrByFloat(cmd[1].clone(), cmd[2].clone(), delta)
}
"hscan" => { "hscan" => {
if cmd.len() < 3 { if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for HSCAN command"))); return Err(DBError(format!("wrong number of arguments for HSCAN command")));
@ -307,11 +405,49 @@ impl Cmd {
} }
Cmd::Ttl(cmd[1].clone()) Cmd::Ttl(cmd[1].clone())
} }
"exists" => { "expire" => {
if cmd.len() != 3 {
return Err(DBError("wrong number of arguments for EXPIRE command".to_string()));
}
let secs = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::Expire(cmd[1].clone(), secs)
}
"pexpire" => {
if cmd.len() != 3 {
return Err(DBError("wrong number of arguments for PEXPIRE command".to_string()));
}
let ms = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::PExpire(cmd[1].clone(), ms)
}
"expireat" => {
if cmd.len() != 3 {
return Err(DBError("wrong number of arguments for EXPIREAT command".to_string()));
}
let ts = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::ExpireAt(cmd[1].clone(), ts)
}
"pexpireat" => {
if cmd.len() != 3 {
return Err(DBError("wrong number of arguments for PEXPIREAT command".to_string()));
}
let ts_ms = cmd[2].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
Cmd::PExpireAt(cmd[1].clone(), ts_ms)
}
"persist" => {
if cmd.len() != 2 { if cmd.len() != 2 {
return Err(DBError("wrong number of arguments for PERSIST command".to_string()));
}
Cmd::Persist(cmd[1].clone())
}
"exists" => {
if cmd.len() < 2 {
return Err(DBError(format!("wrong number of arguments for EXISTS command"))); return Err(DBError(format!("wrong number of arguments for EXISTS command")));
} }
Cmd::Exists(cmd[1].clone()) if cmd.len() == 2 {
Cmd::Exists(cmd[1].clone())
} else {
Cmd::ExistsMulti(cmd[1..].to_vec())
}
} }
"quit" => { "quit" => {
if cmd.len() != 1 { if cmd.len() != 1 {
@ -342,6 +478,10 @@ impl Cmd {
Cmd::Client(vec![]) Cmd::Client(vec![])
} }
} }
"command" => {
let args = if cmd.len() > 1 { cmd[1..].to_vec() } else { vec![] };
Cmd::Command(args)
}
"lpush" => { "lpush" => {
if cmd.len() < 3 { if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for LPUSH command"))); return Err(DBError(format!("wrong number of arguments for LPUSH command")));
@ -376,6 +516,28 @@ impl Cmd {
}; };
Cmd::RPop(cmd[1].clone(), count) Cmd::RPop(cmd[1].clone(), count)
} }
"blpop" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for BLPOP command")));
}
// keys are all but the last argument
let keys = cmd[1..cmd.len()-1].to_vec();
let timeout_f = cmd[cmd.len()-1]
.parse::<f64>()
.map_err(|_| DBError("ERR timeout is not a number".to_string()))?;
Cmd::BLPop(keys, timeout_f)
}
"brpop" => {
if cmd.len() < 3 {
return Err(DBError(format!("wrong number of arguments for BRPOP command")));
}
// keys are all but the last argument
let keys = cmd[1..cmd.len()-1].to_vec();
let timeout_f = cmd[cmd.len()-1]
.parse::<f64>()
.map_err(|_| DBError("ERR timeout is not a number".to_string()))?;
Cmd::BRPop(keys, timeout_f)
}
"llen" => { "llen" => {
if cmd.len() != 2 { if cmd.len() != 2 {
return Err(DBError(format!("wrong number of arguments for LLEN command"))); return Err(DBError(format!("wrong number of arguments for LLEN command")));
@ -488,9 +650,14 @@ impl Cmd {
Cmd::Set(k, v) => set_cmd(server, &k, &v).await, Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::SetOpts(k, v, ex_ms, nx, xx, getflag) => set_with_opts_cmd(server, &k, &v, ex_ms, nx, xx, getflag).await,
Cmd::MGet(keys) => mget_cmd(server, &keys).await,
Cmd::MSet(pairs) => mset_cmd(server, &pairs).await,
Cmd::Del(k) => del_cmd(server, &k).await, Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::DelMulti(keys) => del_multi_cmd(server, &keys).await,
Cmd::ConfigGet(name) => config_get_cmd(&name, server), Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await, Cmd::Keys => keys_cmd(server).await,
Cmd::DbSize => dbsize_cmd(server).await,
Cmd::Info(section) => info_cmd(server, &section).await, Cmd::Info(section) => info_cmd(server, &section).await,
Cmd::Type(k) => type_cmd(server, &k).await, Cmd::Type(k) => type_cmd(server, &k).await,
Cmd::Incr(key) => incr_cmd(server, &key).await, Cmd::Incr(key) => incr_cmd(server, &key).await,
@ -518,19 +685,30 @@ impl Cmd {
Cmd::HLen(key) => hlen_cmd(server, &key).await, Cmd::HLen(key) => hlen_cmd(server, &key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await, Cmd::HMGet(key, fields) => hmget_cmd(server, &key, &fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await, Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, &key, &field, &value).await,
Cmd::HIncrBy(key, field, delta) => hincrby_cmd(server, &key, &field, delta).await,
Cmd::HIncrByFloat(key, field, delta) => hincrbyfloat_cmd(server, &key, &field, delta).await,
Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await, Cmd::HScan(key, cursor, pattern, count) => hscan_cmd(server, &key, &cursor, pattern.as_deref(), &count).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, &key).await, Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Expire(key, secs) => expire_cmd(server, &key, secs).await,
Cmd::PExpire(key, ms) => pexpire_cmd(server, &key, ms).await,
Cmd::ExpireAt(key, ts_secs) => expireat_cmd(server, &key, ts_secs).await,
Cmd::PExpireAt(key, ts_ms) => pexpireat_cmd(server, &key, ts_ms).await,
Cmd::Persist(key) => persist_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, &key).await, Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::ExistsMulti(keys) => exists_multi_cmd(server, &keys).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await, Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
Cmd::ClientGetName => client_getname_cmd(server).await, Cmd::ClientGetName => client_getname_cmd(server).await,
Cmd::Command(args) => command_cmd(&args),
// List commands // List commands
Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await, Cmd::LPush(key, elements) => lpush_cmd(server, &key, &elements).await,
Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await, Cmd::RPush(key, elements) => rpush_cmd(server, &key, &elements).await,
Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await, Cmd::LPop(key, count) => lpop_cmd(server, &key, &count).await,
Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await, Cmd::RPop(key, count) => rpop_cmd(server, &key, &count).await,
Cmd::BLPop(keys, timeout) => blpop_cmd(server, &keys, timeout).await,
Cmd::BRPop(keys, timeout) => brpop_cmd(server, &keys, timeout).await,
Cmd::LLen(key) => llen_cmd(server, &key).await, Cmd::LLen(key) => llen_cmd(server, &key).await,
Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await, Cmd::LRem(key, count, element) => lrem_cmd(server, &key, count, &element).await,
Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await, Cmd::LTrim(key, start, stop) => ltrim_cmd(server, &key, start, stop).await,
@ -661,16 +839,188 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro
} }
} }
// BLPOP implementation
async fn blpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result<Protocol, DBError> {
// Immediate, non-blocking attempt in key order
for k in keys {
let elems = server.current_storage()?.lpop(k, 1)?;
if !elems.is_empty() {
return Ok(Protocol::Array(vec![
Protocol::BulkString(k.clone()),
Protocol::BulkString(elems[0].clone()),
]));
}
}
// If timeout is zero, return immediately with Null
if timeout_secs <= 0.0 {
return Ok(Protocol::Null);
}
// Register waiters for each key
let db_index = server.selected_db;
let mut ids: Vec<u64> = Vec::with_capacity(keys.len());
let mut names: Vec<String> = Vec::with_capacity(keys.len());
let mut rxs: Vec<tokio::sync::oneshot::Receiver<(String, String)>> = Vec::with_capacity(keys.len());
for k in keys {
let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Left).await;
ids.push(id);
names.push(k.clone());
rxs.push(rx);
}
// Wait for the first delivery or timeout
let wait_fut = async move {
let mut futures_vec = rxs;
loop {
if futures_vec.is_empty() {
return None;
}
let (res, idx, remaining) = select_all(futures_vec).await;
match res {
Ok((k, elem)) => {
return Some((k, elem, idx, remaining));
}
Err(_canceled) => {
// That waiter was canceled; continue with the rest
futures_vec = remaining;
continue;
}
}
}
};
match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await {
Ok(Some((k, elem, idx, _remaining))) => {
// Unregister other waiters
for (i, key_name) in names.iter().enumerate() {
if i != idx {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
}
Ok(Protocol::Array(vec![
Protocol::BulkString(k),
Protocol::BulkString(elem),
]))
}
Ok(None) => {
// No futures left; unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
Err(_elapsed) => {
// Timeout: unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
}
}
// BRPOP implementation (mirror of BLPOP, popping from the right)
async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Result<Protocol, DBError> {
// Immediate, non-blocking attempt in key order using RPOP
for k in keys {
let elems = server.current_storage()?.rpop(k, 1)?;
if !elems.is_empty() {
return Ok(Protocol::Array(vec![
Protocol::BulkString(k.clone()),
Protocol::BulkString(elems[0].clone()),
]));
}
}
// If timeout is zero, return immediately with Null
if timeout_secs <= 0.0 {
return Ok(Protocol::Null);
}
// Register waiters for each key (Right side)
let db_index = server.selected_db;
let mut ids: Vec<u64> = Vec::with_capacity(keys.len());
let mut names: Vec<String> = Vec::with_capacity(keys.len());
let mut rxs: Vec<tokio::sync::oneshot::Receiver<(String, String)>> = Vec::with_capacity(keys.len());
for k in keys {
let (id, rx) = server.register_waiter(db_index, k, crate::server::PopSide::Right).await;
ids.push(id);
names.push(k.clone());
rxs.push(rx);
}
// Wait for the first delivery or timeout
let wait_fut = async move {
let mut futures_vec = rxs;
loop {
if futures_vec.is_empty() {
return None;
}
let (res, idx, remaining) = select_all(futures_vec).await;
match res {
Ok((k, elem)) => {
return Some((k, elem, idx, remaining));
}
Err(_canceled) => {
// That waiter was canceled; continue with the rest
futures_vec = remaining;
continue;
}
}
}
};
match timeout(Duration::from_secs_f64(timeout_secs), wait_fut).await {
Ok(Some((k, elem, idx, _remaining))) => {
// Unregister other waiters
for (i, key_name) in names.iter().enumerate() {
if i != idx {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
}
Ok(Protocol::Array(vec![
Protocol::BulkString(k),
Protocol::BulkString(elem),
]))
}
Ok(None) => {
// No futures left; unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
Err(_elapsed) => {
// Timeout: unregister all waiters
for (i, key_name) in names.iter().enumerate() {
server.unregister_waiter(db_index, key_name, ids[i]).await;
}
Ok(Protocol::Null)
}
}
}
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> { async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.lpush(key, elements.to_vec()) { match server.current_storage()?.lpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())), Ok(len) => {
// Attempt to deliver to any blocked BLPOP waiters
let _ = server.drain_waiters_after_push(key).await;
Ok(Protocol::SimpleString(len.to_string()))
}
Err(e) => Ok(Protocol::err(&e.0)), Err(e) => Ok(Protocol::err(&e.0)),
} }
} }
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> { async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
match server.current_storage()?.rpush(key, elements.to_vec()) { match server.current_storage()?.rpush(key, elements.to_vec()) {
Ok(len) => Ok(Protocol::SimpleString(len.to_string())), Ok(len) => {
// Attempt to deliver to any blocked BLPOP waiters
let _ = server.drain_waiters_after_push(key).await;
Ok(Protocol::SimpleString(len.to_string()))
}
Err(e) => Ok(Protocol::err(&e.0)), Err(e) => Ok(Protocol::err(&e.0)),
} }
} }
@ -736,6 +1086,13 @@ async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
)) ))
} }
async fn dbsize_cmd(server: &Server) -> Result<Protocol, DBError> {
match server.current_storage()?.dbsize() {
Ok(n) => Ok(Protocol::SimpleString(n.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
#[derive(Serialize)] #[derive(Serialize)]
struct ServerInfo { struct ServerInfo {
redis_version: String, redis_version: String,
@ -757,17 +1114,19 @@ async fn info_cmd(server: &Server, section: &Option<String>) -> Result<Protocol,
info_string.push_str(&format!("# Keyspace\n")); info_string.push_str(&format!("# Keyspace\n"));
info_string.push_str(&format!("db{}:keys=0,expires=0,avg_ttl=0\n", info.selected_db)); info_string.push_str(&format!("db{}:keys=0,expires=0,avg_ttl=0\n", info.selected_db));
match section { match section {
Some(s) => match s.as_str() { Some(s) => {
"replication" => Ok(Protocol::BulkString( let sl = s.to_lowercase();
"role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string() if sl == "replication" {
)), Ok(Protocol::BulkString(
_ => Err(DBError(format!("unsupported section {:?}", s))), "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
}, ))
None => { } else {
Ok(Protocol::BulkString(info_string)) // Return general info for unknown sections (e.g., SERVER)
Ok(Protocol::BulkString(info_string))
}
} }
None => Ok(Protocol::BulkString(info_string)),
} }
} }
@ -808,6 +1167,109 @@ async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError>
Ok(Protocol::SimpleString("OK".to_string())) Ok(Protocol::SimpleString("OK".to_string()))
} }
// Advanced SET with options: EX/PX/NX/XX/GET
async fn set_with_opts_cmd(
server: &Server,
key: &str,
value: &str,
ex_ms: Option<u128>,
nx: bool,
xx: bool,
get_old: bool,
) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
// Determine existence (for NX/XX)
let exists = storage.exists(key)?;
// If both NX and XX, condition can never be satisfied -> no-op
let mut should_set = true;
if nx && exists {
should_set = false;
}
if xx && !exists {
should_set = false;
}
// Fetch old value if needed for GET
let old_val = if get_old {
storage.get(key)?
} else {
None
};
if should_set {
if let Some(ms) = ex_ms {
storage.setx(key.to_string(), value.to_string(), ms)?;
} else {
storage.set(key.to_string(), value.to_string())?;
}
}
if get_old {
// Return previous value (or Null), regardless of NX/XX outcome only if set executed?
// We follow Redis semantics: return old value if set executed, else Null
if should_set {
Ok(old_val.map_or(Protocol::Null, Protocol::BulkString))
} else {
Ok(Protocol::Null)
}
} else {
if should_set {
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Ok(Protocol::Null)
}
}
}
// MGET: return array of bulk strings or Null for missing
async fn mget_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
let mut out: Vec<Protocol> = Vec::with_capacity(keys.len());
let storage = server.current_storage()?;
for k in keys {
match storage.get(k)? {
Some(v) => out.push(Protocol::BulkString(v)),
None => out.push(Protocol::Null),
}
}
Ok(Protocol::Array(out))
}
// MSET: set multiple key/value pairs, return OK
async fn mset_cmd(server: &Server, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
for (k, v) in pairs {
storage.set(k.clone(), v.clone())?;
}
Ok(Protocol::SimpleString("OK".to_string()))
}
// DEL with multiple keys: return count of keys actually deleted
async fn del_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let mut deleted = 0i64;
for k in keys {
if storage.exists(k)? {
storage.del(k.clone())?;
deleted += 1;
}
}
Ok(Protocol::SimpleString(deleted.to_string()))
}
// EXISTS with multiple keys: return count existing
async fn exists_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let mut count = 0i64;
for k in keys {
if storage.exists(k)? {
count += 1;
}
}
Ok(Protocol::SimpleString(count.to_string()))
}
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> { async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.current_storage()?.get(k)?; let v = server.current_storage()?.get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString)) Ok(v.map_or(Protocol::Null, Protocol::BulkString))
@ -900,6 +1362,32 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res
} }
} }
async fn hincrby_cmd(server: &Server, key: &str, field: &str, delta: i64) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let current = storage.hget(key, field)?;
let base: i64 = match current {
Some(v) => v.parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?,
None => 0,
};
let new_val = base.checked_add(delta).ok_or_else(|| DBError("ERR increment or decrement would overflow".to_string()))?;
// Update the field
storage.hset(key, vec![(field.to_string(), new_val.to_string())])?;
Ok(Protocol::SimpleString(new_val.to_string()))
}
async fn hincrbyfloat_cmd(server: &Server, key: &str, field: &str, delta: f64) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let current = storage.hget(key, field)?;
let base: f64 = match current {
Some(v) => v.parse::<f64>().map_err(|_| DBError("ERR value is not a valid float".to_string()))?,
None => 0.0,
};
let new_val = base + delta;
// Update the field
storage.hset(key, vec![(field.to_string(), new_val.to_string())])?;
Ok(Protocol::SimpleString(new_val.to_string()))
}
async fn scan_cmd( async fn scan_cmd(
server: &Server, server: &Server,
cursor: &u64, cursor: &u64,
@ -957,6 +1445,51 @@ async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
} }
} }
// EXPIRE key seconds -> 1 if timeout set, 0 otherwise
async fn expire_cmd(server: &Server, key: &str, secs: i64) -> Result<Protocol, DBError> {
if secs < 0 {
return Ok(Protocol::SimpleString("0".to_string()));
}
match server.current_storage()?.expire_seconds(key, secs as u64) {
Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
// PEXPIRE key milliseconds -> 1 if timeout set, 0 otherwise
async fn pexpire_cmd(server: &Server, key: &str, ms: i64) -> Result<Protocol, DBError> {
if ms < 0 {
return Ok(Protocol::SimpleString("0".to_string()));
}
match server.current_storage()?.pexpire_millis(key, ms as u128) {
Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
// PERSIST key -> 1 if timeout removed, 0 otherwise
async fn persist_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
match server.current_storage()?.persist(key) {
Ok(removed) => Ok(Protocol::SimpleString(if removed { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
// EXPIREAT key timestamp-seconds -> 1 if timeout set, 0 otherwise
async fn expireat_cmd(server: &Server, key: &str, ts_secs: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.expire_at_seconds(key, ts_secs) {
Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
// PEXPIREAT key timestamp-milliseconds -> 1 if timeout set, 0 otherwise
async fn pexpireat_cmd(server: &Server, key: &str, ts_ms: i64) -> Result<Protocol, DBError> {
match server.current_storage()?.pexpire_at_millis(key, ts_ms) {
Ok(applied) => Ok(Protocol::SimpleString(if applied { "1" } else { "0" }.to_string())),
Err(e) => Ok(Protocol::err(&e.0)),
}
}
async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> { async fn client_setname_cmd(server: &mut Server, name: &str) -> Result<Protocol, DBError> {
server.client_name = Some(name.to_string()); server.client_name = Some(name.to_string());
Ok(Protocol::SimpleString("OK".to_string())) Ok(Protocol::SimpleString("OK".to_string()))
@ -968,3 +1501,19 @@ async fn client_getname_cmd(server: &Server) -> Result<Protocol, DBError> {
None => Ok(Protocol::Null), None => Ok(Protocol::Null),
} }
} }
// Minimal COMMAND subcommands stub to satisfy redis-cli probes.
// - COMMAND DOCS ... => return empty array
// - COMMAND INFO ... => return empty array
// - Any other => empty array
fn command_cmd(args: &[String]) -> Result<Protocol, DBError> {
if args.is_empty() {
return Ok(Protocol::Array(vec![]));
}
let sub = args[0].to_lowercase();
match sub.as_str() {
"docs" => Ok(Protocol::Array(vec![])),
"info" => Ok(Protocol::Array(vec![])),
_ => Ok(Protocol::Array(vec![])),
}
}

View File

@ -19,6 +19,10 @@ impl fmt::Display for Protocol {
impl Protocol { impl Protocol {
pub fn from(protocol: &str) -> Result<(Self, &str), DBError> { pub fn from(protocol: &str) -> Result<(Self, &str), DBError> {
if protocol.is_empty() {
// Incomplete frame; caller should read more bytes
return Err(DBError("[incomplete] empty".to_string()));
}
let ret = match protocol.chars().nth(0) { let ret = match protocol.chars().nth(0) {
Some('+') => Self::parse_simple_string_sfx(&protocol[1..]), Some('+') => Self::parse_simple_string_sfx(&protocol[1..]),
Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]), Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]),
@ -101,21 +105,20 @@ impl Protocol {
let size = Self::parse_usize(&protocol[..len_end])?; let size = Self::parse_usize(&protocol[..len_end])?;
let data_start = len_end + 2; let data_start = len_end + 2;
let data_end = data_start + size; let data_end = data_start + size;
let s = Self::parse_string(&protocol[data_start..data_end])?;
if protocol.len() < data_end + 2 || &protocol[data_end..data_end+2] != "\r\n" { // If we don't yet have the full bulk payload + trailing CRLF, signal INCOMPLETE
Err(DBError(format!( if protocol.len() < data_end + 2 {
"[new bulk string] unmatched string length in prototocl {:?}", return Err(DBError("[incomplete] bulk body".to_string()));
protocol,
)))
} else {
Ok((Protocol::BulkString(s), &protocol[data_end + 2..]))
} }
if &protocol[data_end..data_end + 2] != "\r\n" {
return Err(DBError("[incomplete] bulk terminator".to_string()));
}
let s = Self::parse_string(&protocol[data_start..data_end])?;
Ok((Protocol::BulkString(s), &protocol[data_end + 2..]))
} else { } else {
Err(DBError(format!( // No CRLF after bulk length header yet
"[new bulk string] unsupported protocol: {:?}", Err(DBError("[incomplete] bulk header".to_string()))
protocol
)))
} }
} }
@ -125,16 +128,25 @@ impl Protocol {
let mut remaining = &s[len_end + 2..]; let mut remaining = &s[len_end + 2..];
let mut vec = vec![]; let mut vec = vec![];
for _ in 0..array_len { for _ in 0..array_len {
let (p, rem) = Protocol::from(remaining)?; match Protocol::from(remaining) {
vec.push(p); Ok((p, rem)) => {
remaining = rem; vec.push(p);
remaining = rem;
}
Err(e) => {
// Propagate incomplete so caller can read more bytes
if e.0.starts_with("[incomplete]") {
return Err(e);
} else {
return Err(e);
}
}
}
} }
Ok((Protocol::Array(vec), remaining)) Ok((Protocol::Array(vec), remaining))
} else { } else {
Err(DBError(format!( // No CRLF after array header yet
"[new array] unsupported protocol: {:?}", Err(DBError("[incomplete] array header".to_string()))
s
)))
} }
} }

View File

@ -3,6 +3,9 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::sync::{Mutex, oneshot};
use std::sync::atomic::{AtomicU64, Ordering};
use crate::cmd::Cmd; use crate::cmd::Cmd;
use crate::error::DBError; use crate::error::DBError;
@ -17,6 +20,22 @@ pub struct Server {
pub client_name: Option<String>, pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64 pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>, pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
// BLPOP waiter registry: per (db_index, key) FIFO of waiters
pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
pub waiter_seq: Arc<AtomicU64>,
}
pub struct Waiter {
pub id: u64,
pub side: PopSide,
pub tx: oneshot::Sender<(String, String)>, // (key, element)
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PopSide {
Left,
Right,
} }
impl Server { impl Server {
@ -27,6 +46,9 @@ impl Server {
client_name: None, client_name: None,
selected_db: 0, selected_db: 0,
queued_cmd: None, queued_cmd: None,
list_waiters: Arc::new(Mutex::new(HashMap::new())),
waiter_seq: Arc::new(AtomicU64::new(1)),
} }
} }
@ -66,35 +88,122 @@ impl Server {
self.option.encrypt && db_index >= 10 self.option.encrypt && db_index >= 10
} }
// ----- BLPOP waiter helpers -----
pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) {
let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel::<(String, String)>();
let mut guard = self.list_waiters.lock().await;
let per_db = guard.entry(db_index).or_insert_with(HashMap::new);
let q = per_db.entry(key.to_string()).or_insert_with(Vec::new);
q.push(Waiter { id, side, tx });
(id, rx)
}
pub async fn unregister_waiter(&self, db_index: u64, key: &str, id: u64) {
let mut guard = self.list_waiters.lock().await;
if let Some(per_db) = guard.get_mut(&db_index) {
if let Some(q) = per_db.get_mut(key) {
q.retain(|w| w.id != id);
if q.is_empty() {
per_db.remove(key);
}
}
if per_db.is_empty() {
guard.remove(&db_index);
}
}
}
// Called after LPUSH/RPUSH to deliver to blocked BLPOP waiters.
pub async fn drain_waiters_after_push(&self, key: &str) -> Result<(), DBError> {
let db_index = self.selected_db;
loop {
// Check if any waiter exists
let maybe_waiter = {
let mut guard = self.list_waiters.lock().await;
if let Some(per_db) = guard.get_mut(&db_index) {
if let Some(q) = per_db.get_mut(key) {
if !q.is_empty() {
// Pop FIFO
Some(q.remove(0))
} else {
None
}
} else {
None
}
} else {
None
}
};
let waiter = if let Some(w) = maybe_waiter { w } else { break };
// Pop one element depending on waiter side
let elems = match waiter.side {
PopSide::Left => self.current_storage()?.lpop(key, 1)?,
PopSide::Right => self.current_storage()?.rpop(key, 1)?,
};
if elems.is_empty() {
// Nothing to deliver; re-register waiter at the front to preserve order
let mut guard = self.list_waiters.lock().await;
let per_db = guard.entry(db_index).or_insert_with(HashMap::new);
let q = per_db.entry(key.to_string()).or_insert_with(Vec::new);
q.insert(0, waiter);
break;
} else {
let elem = elems[0].clone();
// Send to waiter; if receiver dropped, just continue
let _ = waiter.tx.send((key.to_string(), elem));
// Loop to try to satisfy more waiters if more elements remain
continue;
}
}
Ok(())
}
pub async fn handle( pub async fn handle(
&mut self, &mut self,
mut stream: tokio::net::TcpStream, mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> { ) -> Result<(), DBError> {
let mut buf = [0; 512]; // Accumulate incoming bytes to handle partial RESP frames
let mut acc = String::new();
let mut buf = vec![0u8; 8192];
loop { loop {
let len = match stream.read(&mut buf).await { let n = match stream.read(&mut buf).await {
Ok(0) => { Ok(0) => {
println!("[handle] connection closed"); println!("[handle] connection closed");
return Ok(()); return Ok(());
} }
Ok(len) => len, Ok(n) => n,
Err(e) => { Err(e) => {
println!("[handle] read error: {:?}", e); println!("[handle] read error: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
}; };
let mut s = str::from_utf8(&buf[..len])?; // Append to accumulator. RESP for our usage is ASCII-safe.
while !s.is_empty() { acc.push_str(str::from_utf8(&buf[..n])?);
let (cmd, protocol, remaining) = match Cmd::from(s) {
// Try to parse as many complete commands as are available in 'acc'.
loop {
let parsed = Cmd::from(&acc);
let (cmd, protocol, remaining) = match parsed {
Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining),
Err(e) => { Err(_e) => {
println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); // Incomplete or invalid frame; assume incomplete and wait for more data.
(Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") // This avoids emitting spurious protocol_error for split frames.
break;
} }
}; };
s = remaining;
// Advance the accumulator to the unparsed remainder
acc = remaining.to_string();
if self.option.debug { if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
@ -130,6 +239,11 @@ impl Server {
println!("[handle] QUIT command received, closing connection"); println!("[handle] QUIT command received, closing connection");
return Ok(()); return Ok(());
} }
// Continue parsing any further complete commands already in 'acc'
if acc.is_empty() {
break;
}
} }
} }
} }

View File

@ -216,3 +216,30 @@ impl Storage {
Ok(keys) Ok(keys)
} }
} }
impl Storage {
pub fn dbsize(&self) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
let mut count: i64 = 0;
let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let key = entry.0.value();
let ty = entry.1.value();
if ty == "string" {
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
// Skip logically expired string keys
continue;
}
}
}
count += 1;
}
Ok(count)
}
}

View File

@ -98,6 +98,116 @@ impl Storage {
None => Ok(false), // Key does not exist None => Ok(false), // Key does not exist
} }
} }
// -------- Expiration helpers (string keys only, consistent with TTL/EXISTS) --------
// Set expiry in seconds; returns true if applied (key exists and is string), false otherwise
pub fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
// Determine eligibility first to avoid holding borrows across commit
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = now_in_millis() + (secs as u128) * 1000;
expiration_table.insert(key, &(expires_at as u64))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
// Set expiry in milliseconds; returns true if applied (key exists and is string), false otherwise
pub fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = now_in_millis() + ms;
expiration_table.insert(key, &(expires_at as u64))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
// Remove expiry if present; returns true if removed, false otherwise
pub fn persist(&self, key: &str) -> Result<bool, DBError> {
let mut removed = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
if expiration_table.remove(key)?.is_some() {
removed = true;
}
}
}
write_txn.commit()?;
Ok(removed)
}
// Absolute EXPIREAT in seconds since epoch
// Returns true if applied (key exists and is string), false otherwise
pub fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
expiration_table.insert(key, &((expires_at_ms as u64)))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
// Absolute PEXPIREAT in milliseconds since epoch
// Returns true if applied (key exists and is string), false otherwise
pub fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
expiration_table.insert(key, &((expires_at_ms as u64)))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
} }
// Utility function for glob pattern matching // Utility function for glob pattern matching

View File

@ -148,8 +148,6 @@ impl Storage {
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> { pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -168,8 +166,6 @@ impl Storage {
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> { pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -200,8 +196,6 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> { pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -233,8 +227,6 @@ impl Storage {
pub fn hlen(&self, key: &str) -> Result<i64, DBError> { pub fn hlen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -265,8 +257,6 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> { pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -334,8 +324,6 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;

View File

@ -16,9 +16,9 @@ fn get_redis_connection(port: u16) -> Connection {
} }
} }
Err(e) => { Err(e) => {
if attempts >= 20 { if attempts >= 120 {
panic!( panic!(
"Failed to connect to Redis server after 20 attempts: {}", "Failed to connect to Redis server after 120 attempts: {}",
e e
); );
} }
@ -88,8 +88,8 @@ fn setup_server() -> (ServerProcessGuard, u16) {
test_dir, test_dir,
}; };
// Give the server a moment to start // Give the server time to build and start (cargo run may compile first)
std::thread::sleep(Duration::from_millis(500)); std::thread::sleep(Duration::from_millis(2500));
(guard, port) (guard, port)
} }

892
herodb/tests/usage_suite.rs Normal file
View File

@ -0,0 +1,892 @@
use herodb::{options::DBOption, server::Server};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::{sleep, Duration};
// =========================
// Helpers
// =========================
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17100);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_usage_suite_{}", test_name);
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
async fn spawn_listener(server: Server, port: u16) {
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("bind listener");
loop {
match listener.accept().await {
Ok((stream, _)) => {
let mut s_clone = server.clone();
tokio::spawn(async move {
let _ = s_clone.handle(stream).await;
});
}
Err(_e) => break,
}
}
});
}
/// Build RESP array for args ["PING"] -> "*1\r\n$4\r\nPING\r\n"
fn build_resp(args: &[&str]) -> String {
let mut s = format!("*{}\r\n", args.len());
for a in args {
s.push_str(&format!("${}\r\n{}\r\n", a.len(), a));
}
s
}
async fn connect(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(s) => return s,
Err(_) if attempts < 30 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect: {}", e),
}
}
}
fn find_crlf(buf: &[u8], start: usize) -> Option<usize> {
let mut i = start;
while i + 1 < buf.len() {
if buf[i] == b'\r' && buf[i + 1] == b'\n' {
return Some(i);
}
i += 1;
}
None
}
fn parse_number_i64(buf: &[u8], start: usize, end: usize) -> Option<i64> {
let s = std::str::from_utf8(&buf[start..end]).ok()?;
s.parse::<i64>().ok()
}
// Return number of bytes that make up a complete RESP element starting at 'i', or None if incomplete.
fn parse_elem(buf: &[u8], i: usize) -> Option<usize> {
if i >= buf.len() {
return None;
}
match buf[i] {
b'+' | b'-' | b':' => {
let end = find_crlf(buf, i + 1)?;
Some(end + 2 - i)
}
b'$' => {
let hdr_end = find_crlf(buf, i + 1)?;
let n = parse_number_i64(buf, i + 1, hdr_end)?;
if n < 0 {
// Null bulk string: only header
Some(hdr_end + 2 - i)
} else {
let need = hdr_end + 2 + (n as usize) + 2;
if need <= buf.len() {
Some(need - i)
} else {
None
}
}
}
b'*' => {
let hdr_end = find_crlf(buf, i + 1)?;
let n = parse_number_i64(buf, i + 1, hdr_end)?;
if n < 0 {
// Null array: only header
Some(hdr_end + 2 - i)
} else {
let mut j = hdr_end + 2;
for _ in 0..(n as usize) {
let consumed = parse_elem(buf, j)?;
j += consumed;
}
Some(j - i)
}
}
_ => None,
}
}
fn resp_frame_len(buf: &[u8]) -> Option<usize> {
parse_elem(buf, 0)
}
async fn read_full_resp(stream: &mut TcpStream) -> String {
let mut buf: Vec<u8> = Vec::with_capacity(8192);
let mut tmp = vec![0u8; 4096];
loop {
if let Some(total) = resp_frame_len(&buf) {
if buf.len() >= total {
return String::from_utf8_lossy(&buf[..total]).to_string();
}
}
match tokio::time::timeout(Duration::from_secs(2), stream.read(&mut tmp)).await {
Ok(Ok(n)) => {
if n == 0 {
if let Some(total) = resp_frame_len(&buf) {
if buf.len() >= total {
return String::from_utf8_lossy(&buf[..total]).to_string();
}
}
return String::from_utf8_lossy(&buf).to_string();
}
buf.extend_from_slice(&tmp[..n]);
}
Ok(Err(e)) => panic!("read error: {}", e),
Err(_) => panic!("timeout waiting for reply"),
}
if buf.len() > 8 * 1024 * 1024 {
panic!("reply too large");
}
}
}
async fn send_cmd(stream: &mut TcpStream, args: &[&str]) -> String {
let req = build_resp(args);
stream.write_all(req.as_bytes()).await.unwrap();
read_full_resp(stream).await
}
// Assert helpers with clearer output
fn assert_contains(haystack: &str, needle: &str, ctx: &str) {
assert!(
haystack.contains(needle),
"ASSERT CONTAINS failed: '{}' not found in response.\nContext: {}\nResponse:\n{}",
needle,
ctx,
haystack
);
}
fn assert_eq_resp(actual: &str, expected: &str, ctx: &str) {
assert!(
actual == expected,
"ASSERT EQUAL failed.\nContext: {}\nExpected:\n{:?}\nActual:\n{:?}",
ctx,
expected,
actual
);
}
/// Extract the payload of a single RESP Bulk String reply.
/// Example input:
/// "$5\r\nhello\r\n" -> Some("hello".to_string())
fn extract_bulk_payload(resp: &str) -> Option<String> {
// find first CRLF after "$len"
let first = resp.find("\r\n")?;
let after = &resp[(first + 2)..];
// find next CRLF ending payload
let second = after.find("\r\n")?;
Some(after[..second].to_string())
}
// =========================
// Test suites
// =========================
#[tokio::test]
async fn test_01_connection_and_info() {
let (server, port) = start_test_server("conn_info").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// redis-cli may send COMMAND DOCS, our server replies empty array; harmless.
let pong = send_cmd(&mut s, &["PING"]).await;
assert_contains(&pong, "PONG", "PING should return PONG");
let echo = send_cmd(&mut s, &["ECHO", "hello"]).await;
assert_contains(&echo, "hello", "ECHO hello");
// INFO (general)
let info = send_cmd(&mut s, &["INFO"]).await;
assert_contains(&info, "redis_version", "INFO should include redis_version");
// INFO REPLICATION (static stub)
let repl = send_cmd(&mut s, &["INFO", "replication"]).await;
assert_contains(&repl, "role:master", "INFO replication role");
// CONFIG GET subset
let cfg = send_cmd(&mut s, &["CONFIG", "GET", "databases"]).await;
assert_contains(&cfg, "databases", "CONFIG GET databases");
assert_contains(&cfg, "16", "CONFIG GET databases value");
// CLIENT name
let setname = send_cmd(&mut s, &["CLIENT", "SETNAME", "myapp"]).await;
assert_contains(&setname, "OK", "CLIENT SETNAME");
let getname = send_cmd(&mut s, &["CLIENT", "GETNAME"]).await;
assert_contains(&getname, "myapp", "CLIENT GETNAME");
// SELECT db
let sel = send_cmd(&mut s, &["SELECT", "0"]).await;
assert_contains(&sel, "OK", "SELECT 0");
// QUIT should close connection after sending OK
let quit = send_cmd(&mut s, &["QUIT"]).await;
assert_contains(&quit, "OK", "QUIT should return OK");
}
#[tokio::test]
async fn test_02_strings_and_expiry() {
let (server, port) = start_test_server("strings").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// SET / GET
let set = send_cmd(&mut s, &["SET", "user:1", "alice"]).await;
assert_contains(&set, "OK", "SET user:1 alice");
let get = send_cmd(&mut s, &["GET", "user:1"]).await;
assert_contains(&get, "alice", "GET user:1");
// EXISTS / DEL
let ex1 = send_cmd(&mut s, &["EXISTS", "user:1"]).await;
assert_contains(&ex1, "1", "EXISTS user:1");
let del = send_cmd(&mut s, &["DEL", "user:1"]).await;
assert_contains(&del, "1", "DEL user:1");
let ex0 = send_cmd(&mut s, &["EXISTS", "user:1"]).await;
assert_contains(&ex0, "0", "EXISTS after DEL");
// INCR behavior
let i1 = send_cmd(&mut s, &["INCR", "count"]).await;
assert_contains(&i1, "1", "INCR new key -> 1");
let i2 = send_cmd(&mut s, &["INCR", "count"]).await;
assert_contains(&i2, "2", "INCR existing -> 2");
let _ = send_cmd(&mut s, &["SET", "notnum", "abc"]).await;
let ierr = send_cmd(&mut s, &["INCR", "notnum"]).await;
assert_contains(&ierr, "ERR", "INCR on non-numeric should ERR");
// Expiration via SET EX
let setex = send_cmd(&mut s, &["SET", "tmp:1", "boom", "EX", "1"]).await;
assert_contains(&setex, "OK", "SET tmp:1 EX 1");
let g_immediate = send_cmd(&mut s, &["GET", "tmp:1"]).await;
assert_contains(&g_immediate, "boom", "GET tmp:1 immediately");
let ttl = send_cmd(&mut s, &["TTL", "tmp:1"]).await;
// Implementation returns a SimpleString, accept any numeric content
assert!(
ttl.contains("1") || ttl.contains("0"),
"TTL should be 1 or 0, got: {}",
ttl
);
sleep(Duration::from_millis(1100)).await;
let g_after = send_cmd(&mut s, &["GET", "tmp:1"]).await;
assert_contains(&g_after, "$-1", "GET tmp:1 after expiry -> Null");
// TYPE
let _ = send_cmd(&mut s, &["SET", "t", "v"]).await;
let ty = send_cmd(&mut s, &["TYPE", "t"]).await;
assert_contains(&ty, "string", "TYPE string key");
let ty_none = send_cmd(&mut s, &["TYPE", "noexist"]).await;
assert_contains(&ty_none, "none", "TYPE nonexistent");
}
#[tokio::test]
async fn test_03_scan_and_keys() {
let (server, port) = start_test_server("scan").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
for i in 0..5 {
let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await;
}
let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await;
assert_contains(&scan, "key0", "SCAN should return keys with MATCH");
assert_contains(&scan, "key4", "SCAN should return last key");
let keys = send_cmd(&mut s, &["KEYS", "*"]).await;
assert_contains(&keys, "key0", "KEYS * includes key0");
assert_contains(&keys, "key4", "KEYS * includes key4");
}
#[tokio::test]
async fn test_04_hashes_suite() {
let (server, port) = start_test_server("hashes").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// HSET (single, returns number of new fields)
let h1 = send_cmd(&mut s, &["HSET", "profile:1", "name", "alice"]).await;
assert_contains(&h1, "1", "HSET new field -> 1");
// HGET
let hg = send_cmd(&mut s, &["HGET", "profile:1", "name"]).await;
assert_contains(&hg, "alice", "HGET existing field");
// HSET multiple
let h2 = send_cmd(&mut s, &["HSET", "profile:1", "age", "30", "city", "paris"]).await;
assert_contains(&h2, "2", "HSET added 2 new fields");
// HMGET
let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await;
assert_contains(&hmg, "alice", "HMGET name");
assert_contains(&hmg, "30", "HMGET age");
assert_contains(&hmg, "paris", "HMGET city");
assert_contains(&hmg, "$-1", "HMGET non-existent -> Null");
// HGETALL
let hga = send_cmd(&mut s, &["HGETALL", "profile:1"]).await;
assert_contains(&hga, "name", "HGETALL contains name");
assert_contains(&hga, "alice", "HGETALL contains alice");
// HLEN
let hlen = send_cmd(&mut s, &["HLEN", "profile:1"]).await;
assert_contains(&hlen, "3", "HLEN is 3");
// HEXISTS
let hex1 = send_cmd(&mut s, &["HEXISTS", "profile:1", "age"]).await;
assert_contains(&hex1, "1", "HEXISTS age true");
let hex0 = send_cmd(&mut s, &["HEXISTS", "profile:1", "nope"]).await;
assert_contains(&hex0, "0", "HEXISTS nope false");
// HKEYS / HVALS
let hkeys = send_cmd(&mut s, &["HKEYS", "profile:1"]).await;
assert_contains(&hkeys, "name", "HKEYS includes name");
let hvals = send_cmd(&mut s, &["HVALS", "profile:1"]).await;
assert_contains(&hvals, "alice", "HVALS includes alice");
// HSETNX
let hnx0 = send_cmd(&mut s, &["HSETNX", "profile:1", "name", "bob"]).await;
assert_contains(&hnx0, "0", "HSETNX existing field -> 0");
let hnx1 = send_cmd(&mut s, &["HSETNX", "profile:1", "nickname", "ali"]).await;
assert_contains(&hnx1, "1", "HSETNX new field -> 1");
// HSCAN
let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await;
assert_contains(&hscan, "name", "HSCAN matches fields starting with n");
assert_contains(&hscan, "nickname", "HSCAN nickname present");
// HDEL
let hdel = send_cmd(&mut s, &["HDEL", "profile:1", "city", "age"]).await;
assert_contains(&hdel, "2", "HDEL removed two fields");
}
#[tokio::test]
async fn test_05_lists_suite_including_blpop() {
let (server, port) = start_test_server("lists").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut a = connect(port).await;
// LPUSH / RPUSH / LLEN
let lp = send_cmd(&mut a, &["LPUSH", "q:jobs", "a", "b"]).await;
assert_contains(&lp, "2", "LPUSH added 2, length 2");
let rp = send_cmd(&mut a, &["RPUSH", "q:jobs", "c"]).await;
assert_contains(&rp, "3", "RPUSH now length 3");
let llen = send_cmd(&mut a, &["LLEN", "q:jobs"]).await;
assert_contains(&llen, "3", "LLEN 3");
// LINDEX / LRANGE
let lidx = send_cmd(&mut a, &["LINDEX", "q:jobs", "0"]).await;
assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b");
let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]");
// LTRIM
let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await;
assert_contains(&ltrim, "OK", "LTRIM OK");
let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]");
// LREM remove first occurrence of b
let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await;
assert_contains(&lrem, "1", "LREM removed 1");
// LPOP and RPOP
let lpop1 = send_cmd(&mut a, &["LPOP", "q:jobs"]).await;
assert_contains(&lpop1, "$1\r\na\r\n", "LPOP returns a");
let rpop_empty = send_cmd(&mut a, &["RPOP", "q:jobs"]).await; // empty now
assert_contains(&rpop_empty, "$-1", "RPOP on empty -> Null");
// LPOP with count on empty -> []
let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await;
assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array");
// BLPOP: block on one client, push from another
let c1 = connect(port).await;
let mut c2 = connect(port).await;
// Start BLPOP on c1
let blpop_task = tokio::spawn(async move {
let mut c1_local = c1;
send_cmd(&mut c1_local, &["BLPOP", "q:block", "5"]).await
});
// Give it time to register waiter
sleep(Duration::from_millis(150)).await;
// Push from c2 to wake BLPOP
let _ = send_cmd(&mut c2, &["LPUSH", "q:block", "x"]).await;
// Await BLPOP result
let blpop_res = blpop_task.await.expect("BLPOP task join");
assert_contains(&blpop_res, "q:block", "BLPOP returned key");
assert_contains(&blpop_res, "x", "BLPOP returned element");
}
#[tokio::test]
async fn test_06_flushdb_suite() {
let (server, port) = start_test_server("flushdb").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
let _ = send_cmd(&mut s, &["SET", "k1", "v1"]).await;
let _ = send_cmd(&mut s, &["HSET", "h1", "f", "v"]).await;
let _ = send_cmd(&mut s, &["LPUSH", "l1", "a"]).await;
let keys_before = send_cmd(&mut s, &["KEYS", "*"]).await;
assert_contains(&keys_before, "k1", "have string key before FLUSHDB");
assert_contains(&keys_before, "h1", "have hash key before FLUSHDB");
assert_contains(&keys_before, "l1", "have list key before FLUSHDB");
let fl = send_cmd(&mut s, &["FLUSHDB"]).await;
assert_contains(&fl, "OK", "FLUSHDB OK");
let keys_after = send_cmd(&mut s, &["KEYS", "*"]).await;
assert_eq_resp(&keys_after, "*0\r\n", "DB should be empty after FLUSHDB");
}
#[tokio::test]
async fn test_07_age_stateless_suite() {
let (server, port) = start_test_server("age_stateless").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// GENENC -> [recipient, identity]
let gen = send_cmd(&mut s, &["AGE", "GENENC"]).await;
assert!(
gen.starts_with("*2\r\n$"),
"AGE GENENC should return array [recipient, identity], got:\n{}",
gen
);
// Parse simple RESP array of two bulk strings to extract keys
fn parse_two_bulk_array(resp: &str) -> (String, String) {
// naive parse for tests
let mut lines = resp.lines();
let _ = lines.next(); // *2
// $len
let _ = lines.next();
let recip = lines.next().unwrap_or("").to_string();
let _ = lines.next();
let ident = lines.next().unwrap_or("").to_string();
(recip, ident)
}
let (recipient, identity) = parse_two_bulk_array(&gen);
assert!(
recipient.starts_with("age1") && identity.starts_with("AGE-SECRET-KEY-1"),
"Unexpected AGE key formats.\nrecipient: {}\nidentity: {}",
recipient,
identity
);
// ENCRYPT / DECRYPT
let ct = send_cmd(&mut s, &["AGE", "ENCRYPT", &recipient, "hello world"]).await;
let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPT");
let pt = send_cmd(&mut s, &["AGE", "DECRYPT", &identity, &ct_b64]).await;
assert_contains(&pt, "hello world", "AGE DECRYPT round-trip");
// GENSIGN -> [verify_pub_b64, sign_secret_b64]
let gensign = send_cmd(&mut s, &["AGE", "GENSIGN"]).await;
let (verify_pub, sign_secret) = parse_two_bulk_array(&gensign);
assert!(
!verify_pub.is_empty() && !sign_secret.is_empty(),
"GENSIGN returned empty keys"
);
// SIGN / VERIFY
let sig = send_cmd(&mut s, &["AGE", "SIGN", &sign_secret, "msg"]).await;
let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGN");
let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await;
assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature");
let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await;
assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature");
}
#[tokio::test]
async fn test_08_age_persistent_named_suite() {
let (server, port) = start_test_server("age_persistent").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// KEYGEN + ENCRYPTNAME/DECRYPTNAME
let kg = send_cmd(&mut s, &["AGE", "KEYGEN", "app1"]).await;
assert!(
kg.starts_with("*2\r\n"),
"AGE KEYGEN should return [recipient, identity], got:\n{}",
kg
);
let ct = send_cmd(&mut s, &["AGE", "ENCRYPTNAME", "app1", "hello"]).await;
let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPTNAME");
let pt = send_cmd(&mut s, &["AGE", "DECRYPTNAME", "app1", &ct_b64]).await;
assert_contains(&pt, "hello", "DECRYPTNAME round-trip");
// SIGNKEYGEN + SIGNNAME/VERIFYNAME
let skg = send_cmd(&mut s, &["AGE", "SIGNKEYGEN", "app1"]).await;
assert!(
skg.starts_with("*2\r\n"),
"AGE SIGNKEYGEN should return [verify_pub, sign_secret], got:\n{}",
skg
);
let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"] ).await;
let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME");
let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await;
assert_contains(&v1, "1", "VERIFYNAME valid => 1");
let v0 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "bad", &sig_b64]).await;
assert_contains(&v0, "0", "VERIFYNAME invalid => 0");
// AGE LIST
let lst = send_cmd(&mut s, &["AGE", "LIST"]).await;
assert_contains(&lst, "encpub", "AGE LIST label encpub");
assert_contains(&lst, "app1", "AGE LIST includes app1");
}
#[tokio::test]
async fn test_10_expire_pexpire_persist() {
let (server, port) = start_test_server("expire_suite").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// EXPIRE: seconds
let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await;
let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await;
assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert!(
ttl1.contains("1") || ttl1.contains("0"),
"TTL exp:s should be 1 or 0, got: {}",
ttl1
);
sleep(Duration::from_millis(1100)).await;
let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await;
assert_contains(&get_after, "$-1", "GET after expiry should be Null");
let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert_contains(&ttl_after, "-2", "TTL after expiry -> -2");
let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await;
assert_contains(&exists_after, "0", "EXISTS after expiry -> 0");
// PEXPIRE: milliseconds
let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await;
let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await;
assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)");
let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await;
assert!(
ttl_ms1.contains("1") || ttl_ms1.contains("0"),
"TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}",
ttl_ms1
);
sleep(Duration::from_millis(1600)).await;
let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await;
assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0");
// PERSIST: remove expiration
let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await;
let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await;
let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert!(
ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"),
"TTL exp:persist should be >=0 before persist, got: {}",
ttl_pre
);
let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist1, "1", "PERSIST should remove expiration");
let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)");
// Second persist should return 0 (nothing to remove)
let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)");
}
#[tokio::test]
async fn test_11_set_with_options() {
let (server, port) = start_test_server("set_opts").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// SET with GET on non-existing key -> returns Null, sets value
let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await;
assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist");
let g1 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g1, "v1", "GET s1 after first SET");
// SET with GET should return old value, then set to new
let set_get2 = send_cmd(&mut s, &["SET", "s1", "v2", "GET"]).await;
assert_contains(&set_get2, "v1", "SET s1 v2 GET returns previous value v1");
let g2 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g2, "v2", "GET s1 now v2");
// NX prevents update when key exists; with GET should return Null and not change
let set_nx = send_cmd(&mut s, &["SET", "s1", "v3", "NX", "GET"]).await;
assert_contains(&set_nx, "$-1", "SET s1 v3 NX GET returns Null when not set");
let g3 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g3, "v2", "GET s1 remains v2 after NX prevented write");
// NX allows set when key does not exist
let set_nx2 = send_cmd(&mut s, &["SET", "s2", "v10", "NX"]).await;
assert_contains(&set_nx2, "OK", "SET s2 v10 NX -> OK for new key");
let g4 = send_cmd(&mut s, &["GET", "s2"]).await;
assert_contains(&g4, "v10", "GET s2 is v10");
// XX requires existing key; with GET returns old value and sets new
let set_xx = send_cmd(&mut s, &["SET", "s2", "v11", "XX", "GET"]).await;
assert_contains(&set_xx, "v10", "SET s2 v11 XX GET returns previous v10");
let g5 = send_cmd(&mut s, &["GET", "s2"]).await;
assert_contains(&g5, "v11", "GET s2 is now v11");
// PX expiration path via SET options
let set_px = send_cmd(&mut s, &["SET", "s3", "vpx", "PX", "500"]).await;
assert_contains(&set_px, "OK", "SET s3 vpx PX 500 -> OK");
let ttl_px1 = send_cmd(&mut s, &["TTL", "s3"]).await;
assert!(
ttl_px1.contains("0") || ttl_px1.contains("1"),
"TTL s3 immediately after PX should be 1 or 0, got: {}",
ttl_px1
);
sleep(Duration::from_millis(650)).await;
let g6 = send_cmd(&mut s, &["GET", "s3"]).await;
assert_contains(&g6, "$-1", "GET s3 after PX expiry -> Null");
}
#[tokio::test]
async fn test_09_mget_mset_and_variadic_exists_del() {
let (server, port) = start_test_server("mget_mset_variadic").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// MSET multiple keys
let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await;
assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK");
// MGET should return values and Null for missing
let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await;
// Expect an array with 4 entries; verify payloads
assert_contains(&mget, "v1", "MGET k1");
assert_contains(&mget, "v2", "MGET k2");
assert_contains(&mget, "v3", "MGET k3");
assert_contains(&mget, "$-1", "MGET missing returns Null");
// EXISTS variadic: count how many exist
let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await;
// Server returns SimpleString numeric, e.g. +2
assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2");
// DEL variadic: delete multiple keys, return count deleted
let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await;
assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2");
// Verify deletion
let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await;
assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0");
// MGET after deletion should include Nulls for deleted keys
let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await;
assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null");
assert_contains(&mget_after, "v2", "MGET k2 remains");
assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null");
}
#[tokio::test]
async fn test_12_hash_incr() {
let (server, port) = start_test_server("hash_incr").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// Integer increments
let _ = send_cmd(&mut s, &["HSET", "hinc", "a", "1"]).await;
let r1 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "2"]).await;
assert_contains(&r1, "3", "HINCRBY hinc a 2 -&gt; 3");
let r2 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "-1"]).await;
assert_contains(&r2, "2", "HINCRBY hinc a -1 -&gt; 2");
let r3 = send_cmd(&mut s, &["HINCRBY", "hinc", "b", "5"]).await;
assert_contains(&r3, "5", "HINCRBY hinc b 5 -&gt; 5");
// HINCRBY error on non-integer field
let _ = send_cmd(&mut s, &["HSET", "hinc", "s", "x"]).await;
let r_err = send_cmd(&mut s, &["HINCRBY", "hinc", "s", "1"]).await;
assert_contains(&r_err, "ERR", "HINCRBY on non-integer field should ERR");
// Float increments
let r4 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "1.5"]).await;
assert_contains(&r4, "1.5", "HINCRBYFLOAT hinc f 1.5 -&gt; 1.5");
let r5 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "2.5"]).await;
// Could be "4", "4.0", or "4.000000", accept "4" substring
assert_contains(&r5, "4", "HINCRBYFLOAT hinc f 2.5 -&gt; 4");
// HINCRBYFLOAT error on non-float field
let _ = send_cmd(&mut s, &["HSET", "hinc", "notf", "abc"]).await;
let r6 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "notf", "1"]).await;
assert_contains(&r6, "ERR", "HINCRBYFLOAT on non-float field should ERR");
}
#[tokio::test]
async fn test_05b_brpop_suite() {
let (server, port) = start_test_server("lists_brpop").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut a = connect(port).await;
// RPUSH some initial data, BRPOP should take from the right
let _ = send_cmd(&mut a, &["RPUSH", "q:rjobs", "1", "2"]).await;
let br_nonblock = send_cmd(&mut a, &["BRPOP", "q:rjobs", "0"]).await;
// Should pop the rightmost element "2"
assert_contains(&br_nonblock, "q:rjobs", "BRPOP returns key");
assert_contains(&br_nonblock, "2", "BRPOP returns rightmost element");
// Now test blocking BRPOP: start blocked client, then RPUSH from another client
let c1 = connect(port).await;
let mut c2 = connect(port).await;
// Start BRPOP on c1
let brpop_task = tokio::spawn(async move {
let mut c1_local = c1;
send_cmd(&mut c1_local, &["BRPOP", "q:blockr", "5"]).await
});
// Give it time to register waiter
sleep(Duration::from_millis(150)).await;
// Push from right to wake BRPOP
let _ = send_cmd(&mut c2, &["RPUSH", "q:blockr", "X"]).await;
// Await BRPOP result
let brpop_res = brpop_task.await.expect("BRPOP task join");
assert_contains(&brpop_res, "q:blockr", "BRPOP returned key");
assert_contains(&brpop_res, "X", "BRPOP returned element");
}
#[tokio::test]
async fn test_13_dbsize() {
let (server, port) = start_test_server("dbsize").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// Initially empty
let n0 = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n0, "0", "DBSIZE initial should be 0");
// Add a string, a hash, and a list -> dbsize = 3
let _ = send_cmd(&mut s, &["SET", "s", "v"]).await;
let _ = send_cmd(&mut s, &["HSET", "h", "f", "v"]).await;
let _ = send_cmd(&mut s, &["LPUSH", "l", "a", "b"]).await;
let n3 = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n3, "3", "DBSIZE after adding s,h,l should be 3");
// Expire the string and wait, dbsize should drop to 2
let _ = send_cmd(&mut s, &["PEXPIRE", "s", "400"]).await;
sleep(Duration::from_millis(500)).await;
let n2 = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n2, "2", "DBSIZE after string expiry should be 2");
// Delete remaining keys and confirm 0
let _ = send_cmd(&mut s, &["DEL", "h"]).await;
let _ = send_cmd(&mut s, &["DEL", "l"]).await;
let n_final = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n_final, "0", "DBSIZE after deleting all keys should be 0");
}
#[tokio::test]
async fn test_14_expireat_pexpireat() {
use std::time::{SystemTime, UNIX_EPOCH};
let (server, port) = start_test_server("expireat_suite").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// EXPIREAT: seconds since epoch
let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await;
let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await;
assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await;
assert!(
ttl1.contains("1") || ttl1.contains("0"),
"TTL exp:at:s should be 1 or 0 shortly after EXPIREAT, got: {}",
ttl1
);
sleep(Duration::from_millis(1200)).await;
let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await;
assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0");
// PEXPIREAT: milliseconds since epoch
let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await;
let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await;
assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)");
let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await;
assert!(
ttl2.contains("0") || ttl2.contains("1"),
"TTL exp:at:ms should be 0..1 soon after PEXPIREAT, got: {}",
ttl2
);
sleep(Duration::from_millis(600)).await;
let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await;
assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0");
}