implement MGET/MSET and variadic DEL/EXISTS

This commit is contained in:
Maxime Van Hees 2025-08-19 11:16:00 +02:00
parent a306544a34
commit b644bf873f
2 changed files with 86 additions and 17 deletions

View File

@ -12,6 +12,8 @@ pub enum Cmd {
Set(String, String), Set(String, String),
SetPx(String, String, u128), SetPx(String, String, u128),
SetEx(String, String, u128), SetEx(String, String, u128),
MGet(Vec<String>),
MSet(Vec<(String, String)>),
Keys, Keys,
ConfigGet(String), ConfigGet(String),
Info(Option<String>), Info(Option<String>),
@ -36,6 +38,8 @@ pub enum Cmd {
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
Ttl(String), Ttl(String),
Exists(String), Exists(String),
ExistsMulti(Vec<String>),
DelMulti(Vec<String>),
Quit, Quit,
Client(Vec<String>), Client(Vec<String>),
ClientSetName(String), ClientSetName(String),
@ -110,6 +114,24 @@ impl Cmd {
} }
Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap()) Cmd::SetEx(cmd[1].clone(), cmd[3].clone(), cmd[2].parse().unwrap())
} }
"mget" => {
if cmd.len() < 2 {
return Err(DBError("wrong number of arguments for MGET command".to_string()));
}
Cmd::MGet(cmd[1..].to_vec())
}
"mset" => {
if cmd.len() < 3 || ((cmd.len() - 1) % 2 != 0) {
return Err(DBError("wrong number of arguments for MSET command".to_string()));
}
let mut pairs = Vec::new();
let mut i = 1;
while i + 1 < cmd.len() {
pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
i += 2;
}
Cmd::MSet(pairs)
}
"config" => { "config" => {
if cmd.len() != 3 || cmd[1].to_lowercase() != "get" { if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
return Err(DBError(format!("unsupported cmd {:?}", cmd))); return Err(DBError(format!("unsupported cmd {:?}", cmd)));
@ -133,10 +155,14 @@ impl Cmd {
Cmd::Info(section) Cmd::Info(section)
} }
"del" => { "del" => {
if cmd.len() != 2 { if cmd.len() < 2 {
return Err(DBError(format!("unsupported cmd {:?}", cmd))); return Err(DBError(format!("wrong number of arguments for DEL command")));
} }
if cmd.len() == 2 {
Cmd::Del(cmd[1].clone()) Cmd::Del(cmd[1].clone())
} else {
Cmd::DelMulti(cmd[1..].to_vec())
}
} }
"type" => { "type" => {
if cmd.len() != 2 { if cmd.len() != 2 {
@ -312,10 +338,14 @@ impl Cmd {
Cmd::Ttl(cmd[1].clone()) Cmd::Ttl(cmd[1].clone())
} }
"exists" => { "exists" => {
if cmd.len() != 2 { if cmd.len() < 2 {
return Err(DBError(format!("wrong number of arguments for EXISTS command"))); return Err(DBError(format!("wrong number of arguments for EXISTS command")));
} }
if cmd.len() == 2 {
Cmd::Exists(cmd[1].clone()) Cmd::Exists(cmd[1].clone())
} else {
Cmd::ExistsMulti(cmd[1..].to_vec())
}
} }
"quit" => { "quit" => {
if cmd.len() != 1 { if cmd.len() != 1 {
@ -507,7 +537,10 @@ impl Cmd {
Cmd::Set(k, v) => set_cmd(server, &k, &v).await, Cmd::Set(k, v) => set_cmd(server, &k, &v).await,
Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await, Cmd::SetPx(k, v, x) => set_px_cmd(server, &k, &v, &x).await,
Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await, Cmd::SetEx(k, v, x) => set_ex_cmd(server, &k, &v, &x).await,
Cmd::MGet(keys) => mget_cmd(server, &keys).await,
Cmd::MSet(pairs) => mset_cmd(server, &pairs).await,
Cmd::Del(k) => del_cmd(server, &k).await, Cmd::Del(k) => del_cmd(server, &k).await,
Cmd::DelMulti(keys) => del_multi_cmd(server, &keys).await,
Cmd::ConfigGet(name) => config_get_cmd(&name, server), Cmd::ConfigGet(name) => config_get_cmd(&name, server),
Cmd::Keys => keys_cmd(server).await, Cmd::Keys => keys_cmd(server).await,
Cmd::Info(section) => info_cmd(server, &section).await, Cmd::Info(section) => info_cmd(server, &section).await,
@ -541,6 +574,7 @@ impl Cmd {
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await, Cmd::Scan(cursor, pattern, count) => scan_cmd(server, &cursor, pattern.as_deref(), &count).await,
Cmd::Ttl(key) => ttl_cmd(server, &key).await, Cmd::Ttl(key) => ttl_cmd(server, &key).await,
Cmd::Exists(key) => exists_cmd(server, &key).await, Cmd::Exists(key) => exists_cmd(server, &key).await,
Cmd::ExistsMulti(keys) => exists_multi_cmd(server, &keys).await,
Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Quit => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())), Cmd::Client(_) => Ok(Protocol::SimpleString("OK".to_string())),
Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await, Cmd::ClientSetName(name) => client_setname_cmd(server, &name).await,
@ -921,6 +955,53 @@ async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError>
Ok(Protocol::SimpleString("OK".to_string())) Ok(Protocol::SimpleString("OK".to_string()))
} }
// MGET: return array of bulk strings or Null for missing
async fn mget_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
let mut out: Vec<Protocol> = Vec::with_capacity(keys.len());
let storage = server.current_storage()?;
for k in keys {
match storage.get(k)? {
Some(v) => out.push(Protocol::BulkString(v)),
None => out.push(Protocol::Null),
}
}
Ok(Protocol::Array(out))
}
// MSET: set multiple key/value pairs, return OK
async fn mset_cmd(server: &Server, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
for (k, v) in pairs {
storage.set(k.clone(), v.clone())?;
}
Ok(Protocol::SimpleString("OK".to_string()))
}
// DEL with multiple keys: return count of keys actually deleted
async fn del_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let mut deleted = 0i64;
for k in keys {
if storage.exists(k)? {
storage.del(k.clone())?;
deleted += 1;
}
}
Ok(Protocol::SimpleString(deleted.to_string()))
}
// EXISTS with multiple keys: return count existing
async fn exists_multi_cmd(server: &Server, keys: &[String]) -> Result<Protocol, DBError> {
let storage = server.current_storage()?;
let mut count = 0i64;
for k in keys {
if storage.exists(k)? {
count += 1;
}
}
Ok(Protocol::SimpleString(count.to_string()))
}
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> { async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
let v = server.current_storage()?.get(k)?; let v = server.current_storage()?.get(k)?;
Ok(v.map_or(Protocol::Null, Protocol::BulkString)) Ok(v.map_or(Protocol::Null, Protocol::BulkString))

View File

@ -148,8 +148,6 @@ impl Storage {
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> { pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -168,8 +166,6 @@ impl Storage {
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> { pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -200,8 +196,6 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> { pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -233,8 +227,6 @@ impl Storage {
pub fn hlen(&self, key: &str) -> Result<i64, DBError> { pub fn hlen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -265,8 +257,6 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> { pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;
@ -334,8 +324,6 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> { pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?; let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let types_table = read_txn.open_table(TYPES_TABLE)?; let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = { let key_type = {
let access_guard = types_table.get(key)?; let access_guard = types_table.get(key)?;