...
This commit is contained in:
		
							
								
								
									
										99
									
								
								src/cmd.rs
									
									
									
									
									
								
							
							
						
						
									
										99
									
								
								src/cmd.rs
									
									
									
									
									
								
							@@ -4,7 +4,7 @@ use crate::{error::DBError, protocol::Protocol, server::Server};
 | 
			
		||||
pub enum Cmd {
 | 
			
		||||
    Ping,
 | 
			
		||||
    Echo(String),
 | 
			
		||||
    Select(u16),
 | 
			
		||||
    Select(u64), // Changed from u16 to u64
 | 
			
		||||
    Get(String),
 | 
			
		||||
    Set(String, String),
 | 
			
		||||
    SetPx(String, String, u128),
 | 
			
		||||
@@ -47,6 +47,7 @@ pub enum Cmd {
 | 
			
		||||
    LTrim(String, i64, i64),
 | 
			
		||||
    LIndex(String, i64),
 | 
			
		||||
    LRange(String, i64, i64),
 | 
			
		||||
    FlushDb,
 | 
			
		||||
    Unknow(String),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -65,7 +66,7 @@ impl Cmd {
 | 
			
		||||
                            if cmd.len() != 2 {
 | 
			
		||||
                                return Err(DBError("wrong number of arguments for SELECT".to_string()));
 | 
			
		||||
                            }
 | 
			
		||||
                            let idx = cmd[1].parse::<u16>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
 | 
			
		||||
                            let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
 | 
			
		||||
                            Cmd::Select(idx)
 | 
			
		||||
                        }
 | 
			
		||||
                        "echo" => Cmd::Echo(cmd[1].clone()),
 | 
			
		||||
@@ -394,6 +395,12 @@ impl Cmd {
 | 
			
		||||
                            let stop = cmd[3].parse::<i64>().map_err(|_| DBError("ERR value is not an integer or out of range".to_string()))?;
 | 
			
		||||
                            Cmd::LRange(cmd[1].clone(), start, stop)
 | 
			
		||||
                        }
 | 
			
		||||
                        "flushdb" => {
 | 
			
		||||
                            if cmd.len() != 1 {
 | 
			
		||||
                                return Err(DBError("wrong number of arguments for FLUSHDB command".to_string()));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::FlushDb
 | 
			
		||||
                        }
 | 
			
		||||
                        _ => Cmd::Unknow(cmd[0].clone()),
 | 
			
		||||
                    },
 | 
			
		||||
                    protocol,
 | 
			
		||||
@@ -482,6 +489,7 @@ impl Cmd {
 | 
			
		||||
            Cmd::LTrim(key, start, stop) => ltrim_cmd(server, key, *start, *stop).await,
 | 
			
		||||
            Cmd::LIndex(key, index) => lindex_cmd(server, key, *index).await,
 | 
			
		||||
            Cmd::LRange(key, start, stop) => lrange_cmd(server, key, *start, *stop).await,
 | 
			
		||||
            Cmd::FlushDb => flushdb_cmd(server).await,
 | 
			
		||||
            Cmd::Unknow(s) => {
 | 
			
		||||
                println!("\x1b[31;1munknown command: {}\x1b[0m", s);
 | 
			
		||||
                Ok(Protocol::err(&format!("ERR unknown command '{}'", s)))
 | 
			
		||||
@@ -489,17 +497,25 @@ impl Cmd {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
async fn select_cmd(server: &mut Server, db: u16) -> Result<Protocol, DBError> {
 | 
			
		||||
    let idx = db as usize;
 | 
			
		||||
    if idx >= server.storages.len() {
 | 
			
		||||
        return Ok(Protocol::err("ERR DB index is out of range"));
 | 
			
		||||
 | 
			
		||||
async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage()?.flushdb() {
 | 
			
		||||
        Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
 | 
			
		||||
    // Test if we can access the database (this will create it if needed)
 | 
			
		||||
    server.selected_db = db;
 | 
			
		||||
    match server.current_storage() {
 | 
			
		||||
        Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
    server.selected_db = idx;
 | 
			
		||||
    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().lindex(key, index) {
 | 
			
		||||
    match server.current_storage()?.lindex(key, index) {
 | 
			
		||||
        Ok(Some(element)) => Ok(Protocol::BulkString(element)),
 | 
			
		||||
        Ok(None) => Ok(Protocol::Null),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
@@ -507,35 +523,35 @@ async fn lindex_cmd(server: &Server, key: &str, index: i64) -> Result<Protocol,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn lrange_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().lrange(key, start, stop) {
 | 
			
		||||
    match server.current_storage()?.lrange(key, start, stop) {
 | 
			
		||||
        Ok(elements) => Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn ltrim_cmd(server: &Server, key: &str, start: i64, stop: i64) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().ltrim(key, start, stop) {
 | 
			
		||||
    match server.current_storage()?.ltrim(key, start, stop) {
 | 
			
		||||
        Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn lrem_cmd(server: &Server, key: &str, count: i64, element: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().lrem(key, count, element) {
 | 
			
		||||
    match server.current_storage()?.lrem(key, count, element) {
 | 
			
		||||
        Ok(removed_count) => Ok(Protocol::SimpleString(removed_count.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn llen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().llen(key) {
 | 
			
		||||
    match server.current_storage()?.llen(key) {
 | 
			
		||||
        Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().lpop(key, *count) {
 | 
			
		||||
    match server.current_storage()?.lpop(key, *count) {
 | 
			
		||||
        Ok(Some(elements)) => {
 | 
			
		||||
            if count.is_some() {
 | 
			
		||||
                Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
 | 
			
		||||
@@ -555,7 +571,7 @@ async fn lpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().rpop(key, *count) {
 | 
			
		||||
    match server.current_storage()?.rpop(key, *count) {
 | 
			
		||||
        Ok(Some(elements)) => {
 | 
			
		||||
            if count.is_some() {
 | 
			
		||||
                Ok(Protocol::Array(elements.into_iter().map(Protocol::BulkString).collect()))
 | 
			
		||||
@@ -575,14 +591,14 @@ async fn rpop_cmd(server: &Server, key: &str, count: &Option<u64>) -> Result<Pro
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().lpush(key, elements.to_vec()) {
 | 
			
		||||
    match server.current_storage()?.lpush(key, elements.to_vec()) {
 | 
			
		||||
        Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn rpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().rpush(key, elements.to_vec()) {
 | 
			
		||||
    match server.current_storage()?.rpush(key, elements.to_vec()) {
 | 
			
		||||
        Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
@@ -606,7 +622,8 @@ async fn exec_cmd(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
 | 
			
		||||
    let current_value = server.current_storage().get(key)?;
 | 
			
		||||
    let storage = server.current_storage()?;
 | 
			
		||||
    let current_value = storage.get(key)?;
 | 
			
		||||
    
 | 
			
		||||
    let new_value = match current_value {
 | 
			
		||||
        Some(v) => {
 | 
			
		||||
@@ -618,7 +635,7 @@ async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
 | 
			
		||||
        None => 1,
 | 
			
		||||
    };
 | 
			
		||||
    
 | 
			
		||||
    server.current_storage().set(key.clone(), new_value.to_string())?;
 | 
			
		||||
    storage.set(key.clone(), new_value.to_string())?;
 | 
			
		||||
    Ok(Protocol::SimpleString(new_value.to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -634,14 +651,14 @@ fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
 | 
			
		||||
        ])),
 | 
			
		||||
        "databases" => Ok(Protocol::Array(vec![
 | 
			
		||||
            Protocol::BulkString(name.clone()),
 | 
			
		||||
            Protocol::BulkString(server.option.databases.to_string()),
 | 
			
		||||
            Protocol::BulkString(server.option.max_databases.unwrap_or(0).to_string()),
 | 
			
		||||
        ])),
 | 
			
		||||
        _ => Ok(Protocol::Array(vec![])),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
 | 
			
		||||
    let keys = server.current_storage().keys("*")?;
 | 
			
		||||
    let keys = server.current_storage()?.keys("*")?;
 | 
			
		||||
    Ok(Protocol::Array(
 | 
			
		||||
        keys.into_iter().map(Protocol::BulkString).collect(),
 | 
			
		||||
    ))
 | 
			
		||||
@@ -660,14 +677,14 @@ fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().get_key_type(k)? {
 | 
			
		||||
    match server.current_storage()?.get_key_type(k)? {
 | 
			
		||||
        Some(type_str) => Ok(Protocol::SimpleString(type_str)),
 | 
			
		||||
        None => Ok(Protocol::SimpleString("none".to_string())),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    server.current_storage().del(k.to_string())?;
 | 
			
		||||
    server.current_storage()?.del(k.to_string())?;
 | 
			
		||||
    Ok(Protocol::SimpleString("1".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -677,7 +694,7 @@ async fn set_ex_cmd(
 | 
			
		||||
    v: &str,
 | 
			
		||||
    x: &u128,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    server.current_storage().setx(k.to_string(), v.to_string(), *x * 1000)?;
 | 
			
		||||
    server.current_storage()?.setx(k.to_string(), v.to_string(), *x * 1000)?;
 | 
			
		||||
    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -687,28 +704,28 @@ async fn set_px_cmd(
 | 
			
		||||
    v: &str,
 | 
			
		||||
    x: &u128,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    server.current_storage().setx(k.to_string(), v.to_string(), *x)?;
 | 
			
		||||
    server.current_storage()?.setx(k.to_string(), v.to_string(), *x)?;
 | 
			
		||||
    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    server.current_storage().set(k.to_string(), v.to_string())?;
 | 
			
		||||
    server.current_storage()?.set(k.to_string(), v.to_string())?;
 | 
			
		||||
    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    let v = server.current_storage().get(k)?;
 | 
			
		||||
    let v = server.current_storage()?.get(k)?;
 | 
			
		||||
    Ok(v.map_or(Protocol::Null, Protocol::BulkString))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Hash command implementations
 | 
			
		||||
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
 | 
			
		||||
    let new_fields = server.current_storage().hset(key, pairs)?;
 | 
			
		||||
    let new_fields = server.current_storage()?.hset(key, pairs)?;
 | 
			
		||||
    Ok(Protocol::SimpleString(new_fields.to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hget(key, field) {
 | 
			
		||||
    match server.current_storage()?.hget(key, field) {
 | 
			
		||||
        Ok(Some(value)) => Ok(Protocol::BulkString(value)),
 | 
			
		||||
        Ok(None) => Ok(Protocol::Null),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
@@ -716,7 +733,7 @@ async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, D
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hgetall(key) {
 | 
			
		||||
    match server.current_storage()?.hgetall(key) {
 | 
			
		||||
        Ok(pairs) => {
 | 
			
		||||
            let mut result = Vec::new();
 | 
			
		||||
            for (field, value) in pairs {
 | 
			
		||||
@@ -730,21 +747,21 @@ async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hdel(key, fields) {
 | 
			
		||||
    match server.current_storage()?.hdel(key, fields) {
 | 
			
		||||
        Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hexists(key, field) {
 | 
			
		||||
    match server.current_storage()?.hexists(key, field) {
 | 
			
		||||
        Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hkeys(key) {
 | 
			
		||||
    match server.current_storage()?.hkeys(key) {
 | 
			
		||||
        Ok(keys) => Ok(Protocol::Array(
 | 
			
		||||
            keys.into_iter().map(Protocol::BulkString).collect(),
 | 
			
		||||
        )),
 | 
			
		||||
@@ -753,7 +770,7 @@ async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hvals(key) {
 | 
			
		||||
    match server.current_storage()?.hvals(key) {
 | 
			
		||||
        Ok(values) => Ok(Protocol::Array(
 | 
			
		||||
            values.into_iter().map(Protocol::BulkString).collect(),
 | 
			
		||||
        )),
 | 
			
		||||
@@ -762,14 +779,14 @@ async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hlen(key) {
 | 
			
		||||
    match server.current_storage()?.hlen(key) {
 | 
			
		||||
        Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hmget(key, fields) {
 | 
			
		||||
    match server.current_storage()?.hmget(key, fields) {
 | 
			
		||||
        Ok(values) => {
 | 
			
		||||
            let result: Vec<Protocol> = values
 | 
			
		||||
                .into_iter()
 | 
			
		||||
@@ -782,14 +799,14 @@ async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Prot
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hsetnx(key, field, value) {
 | 
			
		||||
    match server.current_storage()?.hsetnx(key, field, value) {
 | 
			
		||||
        Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().scan(*cursor, pattern, *count) {
 | 
			
		||||
    match server.current_storage()?.scan(*cursor, pattern, *count) {
 | 
			
		||||
        Ok((next_cursor, keys)) => {
 | 
			
		||||
            let mut result = Vec::new();
 | 
			
		||||
            result.push(Protocol::BulkString(next_cursor.to_string()));
 | 
			
		||||
@@ -803,7 +820,7 @@ async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().hscan(key, *cursor, pattern, *count) {
 | 
			
		||||
    match server.current_storage()?.hscan(key, *cursor, pattern, *count) {
 | 
			
		||||
        Ok((next_cursor, fields)) => {
 | 
			
		||||
            let mut result = Vec::new();
 | 
			
		||||
            result.push(Protocol::BulkString(next_cursor.to_string()));
 | 
			
		||||
@@ -817,14 +834,14 @@ async fn hscan_cmd(server: &Server, key: &str, cursor: &u64, pattern: Option<&st
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn ttl_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().ttl(key) {
 | 
			
		||||
    match server.current_storage()?.ttl(key) {
 | 
			
		||||
        Ok(ttl) => Ok(Protocol::SimpleString(ttl.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn exists_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.current_storage().exists(key) {
 | 
			
		||||
    match server.current_storage()?.exists(key) {
 | 
			
		||||
        Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										73
									
								
								src/crypto.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								src/crypto.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,73 @@
 | 
			
		||||
use chacha20poly1305::{
 | 
			
		||||
    aead::{Aead, KeyInit, OsRng},
 | 
			
		||||
    XChaCha20Poly1305, XNonce,
 | 
			
		||||
};
 | 
			
		||||
use rand::RngCore;
 | 
			
		||||
use sha2::{Digest, Sha256};
 | 
			
		||||
 | 
			
		||||
const VERSION: u8 = 1;
 | 
			
		||||
const NONCE_LEN: usize = 24;
 | 
			
		||||
const TAG_LEN: usize = 16;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum CryptoError {
 | 
			
		||||
    Format,         // wrong length / header
 | 
			
		||||
    Version(u8),    // unknown version
 | 
			
		||||
    Decrypt,        // wrong key or corrupted data
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<CryptoError> for crate::error::DBError {
 | 
			
		||||
    fn from(e: CryptoError) -> Self {
 | 
			
		||||
        crate::error::DBError(format!("Crypto error: {:?}", e))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Super-simple factory: new(secret) + encrypt(bytes) + decrypt(bytes)
 | 
			
		||||
pub struct CryptoFactory {
 | 
			
		||||
    key: chacha20poly1305::Key,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl CryptoFactory {
 | 
			
		||||
    /// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
 | 
			
		||||
    pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
 | 
			
		||||
        let mut h = Sha256::new();
 | 
			
		||||
        h.update(b"xchacha20poly1305-factory:v1"); // domain separation
 | 
			
		||||
        h.update(secret.as_ref());
 | 
			
		||||
        let digest = h.finalize(); // 32 bytes
 | 
			
		||||
        let key = chacha20poly1305::Key::from_slice(&digest).to_owned();
 | 
			
		||||
        Self { key }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Output layout: [version:1][nonce:24][ciphertext||tag]
 | 
			
		||||
    pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
 | 
			
		||||
        let cipher = XChaCha20Poly1305::new(&self.key);
 | 
			
		||||
 | 
			
		||||
        let mut nonce_bytes = [0u8; NONCE_LEN];
 | 
			
		||||
        OsRng.fill_bytes(&mut nonce_bytes);
 | 
			
		||||
        let nonce = XNonce::from_slice(&nonce_bytes);
 | 
			
		||||
 | 
			
		||||
        let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
 | 
			
		||||
        out.push(VERSION);
 | 
			
		||||
        out.extend_from_slice(&nonce_bytes);
 | 
			
		||||
 | 
			
		||||
        let ct = cipher.encrypt(nonce, plaintext).expect("encrypt");
 | 
			
		||||
        out.extend_from_slice(&ct);
 | 
			
		||||
        out
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn decrypt(&self, blob: &[u8]) -> Result<Vec<u8>, CryptoError> {
 | 
			
		||||
        if blob.len() < 1 + NONCE_LEN + TAG_LEN {
 | 
			
		||||
            return Err(CryptoError::Format);
 | 
			
		||||
        }
 | 
			
		||||
        let ver = blob[0];
 | 
			
		||||
        if ver != VERSION {
 | 
			
		||||
            return Err(CryptoError::Version(ver));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
 | 
			
		||||
        let ct = &blob[1 + NONCE_LEN..];
 | 
			
		||||
 | 
			
		||||
        let cipher = XChaCha20Poly1305::new(&self.key);
 | 
			
		||||
        cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,4 +1,5 @@
 | 
			
		||||
pub mod cmd;
 | 
			
		||||
pub mod crypto;
 | 
			
		||||
pub mod error;
 | 
			
		||||
pub mod options;
 | 
			
		||||
pub mod protocol;
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										14
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								src/main.rs
									
									
									
									
									
								
							@@ -14,7 +14,6 @@ struct Args {
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    dir: String,
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /// The port of the Redis server, default is 6379 if not specified
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    port: Option<u16>,
 | 
			
		||||
@@ -23,9 +22,13 @@ struct Args {
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    debug: bool,
 | 
			
		||||
 | 
			
		||||
    /// Number of logical databases (SELECT 0..N-1)
 | 
			
		||||
    #[arg(long, default_value_t = 16)]
 | 
			
		||||
    databases: u16,
 | 
			
		||||
    /// Maximum number of logical databases (None = unlimited)
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    max_databases: Option<u64>,
 | 
			
		||||
 | 
			
		||||
    /// Master encryption key for encrypted databases
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    encryption_key: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[tokio::main]
 | 
			
		||||
@@ -45,7 +48,8 @@ async fn main() {
 | 
			
		||||
        dir: args.dir,
 | 
			
		||||
        port,
 | 
			
		||||
        debug: args.debug,
 | 
			
		||||
        databases: args.databases,
 | 
			
		||||
        max_databases: args.max_databases,
 | 
			
		||||
        encryption_key: args.encryption_key,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // new server
 | 
			
		||||
 
 | 
			
		||||
@@ -3,5 +3,6 @@ pub struct DBOption {
 | 
			
		||||
    pub dir: String,
 | 
			
		||||
    pub port: u16,
 | 
			
		||||
    pub debug: bool,
 | 
			
		||||
    pub databases: u16, // number of logical DBs (default 16)
 | 
			
		||||
    pub max_databases: Option<u64>, // None = unlimited, Some(n) = limit to n
 | 
			
		||||
    pub encryption_key: Option<String>, // Master encryption key
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
use core::str;
 | 
			
		||||
use std::path::PathBuf;
 | 
			
		||||
use std::collections::HashMap;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use tokio::io::AsyncReadExt;
 | 
			
		||||
use tokio::io::AsyncWriteExt;
 | 
			
		||||
@@ -12,34 +12,56 @@ use crate::storage::Storage;
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct Server {
 | 
			
		||||
    pub storages: Vec<Arc<Storage>>,
 | 
			
		||||
    pub db_cache: std::sync::Arc<std::sync::RwLock<HashMap<u64, Arc<Storage>>>>,
 | 
			
		||||
    pub option: options::DBOption,
 | 
			
		||||
    pub client_name: Option<String>,
 | 
			
		||||
    pub selected_db: usize, // per-connection
 | 
			
		||||
    pub selected_db: u64, // Changed from usize to u64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Server {
 | 
			
		||||
    pub async fn new(option: options::DBOption) -> Self {
 | 
			
		||||
        // Eagerly create N db files: <dir>/<index>.db
 | 
			
		||||
        let mut storages = Vec::with_capacity(option.databases as usize);
 | 
			
		||||
        for i in 0..option.databases {
 | 
			
		||||
            let db_file_path = PathBuf::from(option.dir.clone()).join(format!("{}.db", i));
 | 
			
		||||
            println!("will open db file path (db {}): {}", i, db_file_path.display());
 | 
			
		||||
            let storage = Storage::new(db_file_path).expect("Failed to initialize storage");
 | 
			
		||||
            storages.push(Arc::new(storage));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Server {
 | 
			
		||||
            storages,
 | 
			
		||||
            db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())),
 | 
			
		||||
            option,
 | 
			
		||||
            client_name: None,
 | 
			
		||||
            selected_db: 0,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[inline]
 | 
			
		||||
    pub fn current_storage(&self) -> &Storage {
 | 
			
		||||
        self.storages[self.selected_db].as_ref()
 | 
			
		||||
    pub fn current_storage(&self) -> Result<Arc<Storage>, DBError> {
 | 
			
		||||
        let mut cache = self.db_cache.write().unwrap();
 | 
			
		||||
        
 | 
			
		||||
        if let Some(storage) = cache.get(&self.selected_db) {
 | 
			
		||||
            return Ok(storage.clone());
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        // Check database limit if set
 | 
			
		||||
        if let Some(max_db) = self.option.max_databases {
 | 
			
		||||
            if self.selected_db >= max_db {
 | 
			
		||||
                return Err(DBError(format!("DB index {} is out of range (max: {})", self.selected_db, max_db - 1)));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        // Create new database file
 | 
			
		||||
        let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
 | 
			
		||||
            .join(format!("{}.db", self.selected_db));
 | 
			
		||||
        
 | 
			
		||||
        println!("Creating new db file: {}", db_file_path.display());
 | 
			
		||||
        
 | 
			
		||||
        let storage = Arc::new(Storage::new(
 | 
			
		||||
            db_file_path,
 | 
			
		||||
            self.should_encrypt_db(self.selected_db),
 | 
			
		||||
            self.option.encryption_key.as_deref()
 | 
			
		||||
        )?);
 | 
			
		||||
        
 | 
			
		||||
        cache.insert(self.selected_db, storage.clone());
 | 
			
		||||
        Ok(storage)
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    fn should_encrypt_db(&self, db_index: u64) -> bool {
 | 
			
		||||
        // You can implement logic here to determine which databases should be encrypted
 | 
			
		||||
        // For now, let's say databases with even numbers are encrypted if key is provided
 | 
			
		||||
        self.option.encryption_key.is_some() && db_index % 2 == 0
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn handle(
 | 
			
		||||
@@ -104,6 +126,5 @@ impl Server {
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										281
									
								
								src/storage.rs
									
									
									
									
									
								
							
							
						
						
									
										281
									
								
								src/storage.rs
									
									
									
									
									
								
							@@ -6,8 +6,101 @@ use std::{
 | 
			
		||||
use redb::{Database, ReadableTable, TableDefinition};
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
 | 
			
		||||
use crate::crypto::CryptoFactory;
 | 
			
		||||
use crate::error::DBError;
 | 
			
		||||
 | 
			
		||||
// Add this glob matching function
 | 
			
		||||
fn glob_match(pattern: &str, text: &str) -> bool {
 | 
			
		||||
    fn match_recursive(pattern: &[char], text: &[char], p_idx: usize, t_idx: usize) -> bool {
 | 
			
		||||
        if p_idx >= pattern.len() {
 | 
			
		||||
            return t_idx >= text.len();
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        match pattern[p_idx] {
 | 
			
		||||
            '*' => {
 | 
			
		||||
                // Try matching zero characters
 | 
			
		||||
                if match_recursive(pattern, text, p_idx + 1, t_idx) {
 | 
			
		||||
                    return true;
 | 
			
		||||
                }
 | 
			
		||||
                // Try matching one or more characters
 | 
			
		||||
                for i in t_idx..text.len() {
 | 
			
		||||
                    if match_recursive(pattern, text, p_idx + 1, i + 1) {
 | 
			
		||||
                        return true;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                false
 | 
			
		||||
            }
 | 
			
		||||
            '?' => {
 | 
			
		||||
                if t_idx >= text.len() {
 | 
			
		||||
                    false
 | 
			
		||||
                } else {
 | 
			
		||||
                    match_recursive(pattern, text, p_idx + 1, t_idx + 1)
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            '[' => {
 | 
			
		||||
                // Find the closing bracket
 | 
			
		||||
                let mut bracket_end = p_idx + 1;
 | 
			
		||||
                while bracket_end < pattern.len() && pattern[bracket_end] != ']' {
 | 
			
		||||
                    bracket_end += 1;
 | 
			
		||||
                }
 | 
			
		||||
                if bracket_end >= pattern.len() || t_idx >= text.len() {
 | 
			
		||||
                    return false;
 | 
			
		||||
                }
 | 
			
		||||
                
 | 
			
		||||
                let bracket_content = &pattern[p_idx + 1..bracket_end];
 | 
			
		||||
                let char_to_match = text[t_idx];
 | 
			
		||||
                let mut matched = false;
 | 
			
		||||
                
 | 
			
		||||
                let mut i = 0;
 | 
			
		||||
                while i < bracket_content.len() {
 | 
			
		||||
                    if i + 2 < bracket_content.len() && bracket_content[i + 1] == '-' {
 | 
			
		||||
                        // Range like [a-z]
 | 
			
		||||
                        if char_to_match >= bracket_content[i] && char_to_match <= bracket_content[i + 2] {
 | 
			
		||||
                            matched = true;
 | 
			
		||||
                            break;
 | 
			
		||||
                        }
 | 
			
		||||
                        i += 3;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        // Single character
 | 
			
		||||
                        if char_to_match == bracket_content[i] {
 | 
			
		||||
                            matched = true;
 | 
			
		||||
                            break;
 | 
			
		||||
                        }
 | 
			
		||||
                        i += 1;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                
 | 
			
		||||
                if matched {
 | 
			
		||||
                    match_recursive(pattern, text, bracket_end + 1, t_idx + 1)
 | 
			
		||||
                } else {
 | 
			
		||||
                    false
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            '\\' => {
 | 
			
		||||
                // Escape next character
 | 
			
		||||
                if p_idx + 1 >= pattern.len() || t_idx >= text.len() {
 | 
			
		||||
                    false
 | 
			
		||||
                } else if pattern[p_idx + 1] == text[t_idx] {
 | 
			
		||||
                    match_recursive(pattern, text, p_idx + 2, t_idx + 1)
 | 
			
		||||
                } else {
 | 
			
		||||
                    false
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            c => {
 | 
			
		||||
                if t_idx >= text.len() || c != text[t_idx] {
 | 
			
		||||
                    false
 | 
			
		||||
                } else {
 | 
			
		||||
                    match_recursive(pattern, text, p_idx + 1, t_idx + 1)
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    let pattern_chars: Vec<char> = pattern.chars().collect();
 | 
			
		||||
    let text_chars: Vec<char> = text.chars().collect();
 | 
			
		||||
    match_recursive(&pattern_chars, &text_chars, 0, 0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Table definitions for different Redis data types
 | 
			
		||||
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
 | 
			
		||||
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
 | 
			
		||||
@@ -15,6 +108,7 @@ const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("
 | 
			
		||||
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
 | 
			
		||||
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
 | 
			
		||||
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
 | 
			
		||||
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize, Debug, Clone)]
 | 
			
		||||
pub struct StringValue {
 | 
			
		||||
@@ -42,10 +136,11 @@ pub fn now_in_millis() -> u128 {
 | 
			
		||||
 | 
			
		||||
pub struct Storage {
 | 
			
		||||
    db: Database,
 | 
			
		||||
    crypto: Option<CryptoFactory>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Storage {
 | 
			
		||||
    pub fn new(path: impl AsRef<Path>) -> Result<Self, DBError> {
 | 
			
		||||
    pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
 | 
			
		||||
        let db = Database::create(path)?;
 | 
			
		||||
        
 | 
			
		||||
        // Create tables if they don't exist
 | 
			
		||||
@@ -57,10 +152,109 @@ impl Storage {
 | 
			
		||||
            let _ = write_txn.open_table(LISTS_TABLE)?;
 | 
			
		||||
            let _ = write_txn.open_table(STREAMS_META_TABLE)?;
 | 
			
		||||
            let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
 | 
			
		||||
            let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
 | 
			
		||||
        }
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        
 | 
			
		||||
        Ok(Storage { db })
 | 
			
		||||
        // Check if database was previously encrypted
 | 
			
		||||
        let read_txn = db.begin_read()?;
 | 
			
		||||
        let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
 | 
			
		||||
        let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
 | 
			
		||||
        drop(read_txn);
 | 
			
		||||
        
 | 
			
		||||
        let crypto = if should_encrypt || was_encrypted {
 | 
			
		||||
            if let Some(key) = master_key {
 | 
			
		||||
                Some(CryptoFactory::new(key.as_bytes()))
 | 
			
		||||
            } else {
 | 
			
		||||
                return Err(DBError("Encryption requested but no master key provided".to_string()));
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            None
 | 
			
		||||
        };
 | 
			
		||||
        
 | 
			
		||||
        // If we're enabling encryption for the first time, mark it
 | 
			
		||||
        if should_encrypt && !was_encrypted {
 | 
			
		||||
            let write_txn = db.begin_write()?;
 | 
			
		||||
            {
 | 
			
		||||
                let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
 | 
			
		||||
                encrypted_table.insert("encrypted", &1u8)?;
 | 
			
		||||
            }
 | 
			
		||||
            write_txn.commit()?;
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        Ok(Storage {
 | 
			
		||||
            db,
 | 
			
		||||
            crypto,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    pub fn is_encrypted(&self) -> bool {
 | 
			
		||||
        self.crypto.is_some()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Helper methods for encryption
 | 
			
		||||
    fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
 | 
			
		||||
        if let Some(crypto) = &self.crypto {
 | 
			
		||||
            Ok(crypto.encrypt(data))
 | 
			
		||||
        } else {
 | 
			
		||||
            Ok(data.to_vec())
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
 | 
			
		||||
        if let Some(crypto) = &self.crypto {
 | 
			
		||||
            Ok(crypto.decrypt(data)?)
 | 
			
		||||
        } else {
 | 
			
		||||
            Ok(data.to_vec())
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn flushdb(&self) -> Result<(), DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
 | 
			
		||||
            let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
 | 
			
		||||
            let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
 | 
			
		||||
            let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
 | 
			
		||||
 | 
			
		||||
            // inefficient, but there is no other way
 | 
			
		||||
            let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
 | 
			
		||||
            for key in keys {
 | 
			
		||||
                types_table.remove(key.as_str())?;
 | 
			
		||||
            }
 | 
			
		||||
            let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
 | 
			
		||||
            for key in keys {
 | 
			
		||||
                strings_table.remove(key.as_str())?;
 | 
			
		||||
            }
 | 
			
		||||
            let keys: Vec<(String,String)> = hashes_table.iter()?.map(|item| {
 | 
			
		||||
                let binding = item.unwrap();
 | 
			
		||||
                let (key, field) = binding.0.value();
 | 
			
		||||
                (key.to_string(), field.to_string())
 | 
			
		||||
            }).collect();
 | 
			
		||||
            for (key,field) in keys {
 | 
			
		||||
                hashes_table.remove((key.as_str(), field.as_str()))?;
 | 
			
		||||
            }
 | 
			
		||||
            let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
 | 
			
		||||
            for key in keys {
 | 
			
		||||
                lists_table.remove(key.as_str())?;
 | 
			
		||||
            }
 | 
			
		||||
            let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
 | 
			
		||||
            for key in keys {
 | 
			
		||||
                streams_meta_table.remove(key.as_str())?;
 | 
			
		||||
            }
 | 
			
		||||
            let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
 | 
			
		||||
                let binding = item.unwrap();
 | 
			
		||||
                let (key, field) = binding.0.value();
 | 
			
		||||
                (key.to_string(), field.to_string())
 | 
			
		||||
            }).collect();
 | 
			
		||||
            for (key, field) in keys {
 | 
			
		||||
                streams_data_table.remove((key.as_str(), field.as_str()))?;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
 | 
			
		||||
@@ -73,22 +267,22 @@ impl Storage {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Update the get method to use decryption
 | 
			
		||||
    pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        // Check if key exists and is of string type
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "string" => {
 | 
			
		||||
                let strings_table = read_txn.open_table(STRINGS_TABLE)?;
 | 
			
		||||
                match strings_table.get(key)? {
 | 
			
		||||
                    Some(data) => {
 | 
			
		||||
                        let string_value: StringValue = bincode::deserialize(data.value())?;
 | 
			
		||||
                        let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                        let string_value: StringValue = bincode::deserialize(&decrypted)?;
 | 
			
		||||
                        
 | 
			
		||||
                        // Check if expired
 | 
			
		||||
                        if let Some(expires_at) = string_value.expires_at_ms {
 | 
			
		||||
                            if now_in_millis() > expires_at {
 | 
			
		||||
                                // Key expired, remove it
 | 
			
		||||
                                drop(read_txn);
 | 
			
		||||
                                self.del(key.to_string())?;
 | 
			
		||||
                                return Ok(None);
 | 
			
		||||
@@ -103,7 +297,11 @@ impl Storage {
 | 
			
		||||
            _ => Ok(None),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    // Apply similar encryption/decryption to other methods (setx, hset, lpush, etc.)
 | 
			
		||||
    // ... (you'll need to update all methods that store/retrieve serialized data)
 | 
			
		||||
 | 
			
		||||
    // Update the set method to use encryption
 | 
			
		||||
    pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        
 | 
			
		||||
@@ -117,7 +315,8 @@ impl Storage {
 | 
			
		||||
                expires_at_ms: None,
 | 
			
		||||
            };
 | 
			
		||||
            let serialized = bincode::serialize(&string_value)?;
 | 
			
		||||
            strings_table.insert(key.as_str(), serialized.as_slice())?;
 | 
			
		||||
            let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
            strings_table.insert(key.as_str(), encrypted.as_slice())?;
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
@@ -137,7 +336,8 @@ impl Storage {
 | 
			
		||||
                expires_at_ms: Some(expire_ms + now_in_millis()),
 | 
			
		||||
            };
 | 
			
		||||
            let serialized = bincode::serialize(&string_value)?;
 | 
			
		||||
            strings_table.insert(key.as_str(), serialized.as_slice())?;
 | 
			
		||||
            let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
            strings_table.insert(key.as_str(), encrypted.as_slice())?;
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
@@ -191,7 +391,7 @@ impl Storage {
 | 
			
		||||
        let mut iter = table.iter()?;
 | 
			
		||||
        while let Some(entry) = iter.next() {
 | 
			
		||||
            let key = entry?.0.value().to_string();
 | 
			
		||||
            if pattern == "*" || key.contains(pattern) {
 | 
			
		||||
            if pattern == "*" || glob_match(pattern, &key) {
 | 
			
		||||
                keys.push(key);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
@@ -251,7 +451,7 @@ impl Storage {
 | 
			
		||||
                    None => Ok(None),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(None),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@@ -649,7 +849,7 @@ impl Storage {
 | 
			
		||||
    // List operations
 | 
			
		||||
    pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        let mut new_len = 0u64;
 | 
			
		||||
        let new_len;
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
@@ -671,7 +871,10 @@ impl Storage {
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let mut list_value: ListValue = match lists_table.get(key)? {
 | 
			
		||||
                Some(data) => bincode::deserialize(data.value())?,
 | 
			
		||||
                Some(data) => {
 | 
			
		||||
                    let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                    bincode::deserialize(&decrypted)?
 | 
			
		||||
                },
 | 
			
		||||
                None => ListValue { elements: Vec::new() },
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
@@ -681,7 +884,8 @@ impl Storage {
 | 
			
		||||
            new_len = list_value.elements.len() as u64;
 | 
			
		||||
 | 
			
		||||
            let serialized = bincode::serialize(&list_value)?;
 | 
			
		||||
            lists_table.insert(key, serialized.as_slice())?;
 | 
			
		||||
            let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
            lists_table.insert(key, encrypted.as_slice())?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
@@ -690,7 +894,7 @@ impl Storage {
 | 
			
		||||
 | 
			
		||||
    pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<u64, DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        let mut new_len = 0u64;
 | 
			
		||||
        let new_len;
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
@@ -712,7 +916,10 @@ impl Storage {
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let mut list_value: ListValue = match lists_table.get(key)? {
 | 
			
		||||
                Some(data) => bincode::deserialize(data.value())?,
 | 
			
		||||
                Some(data) => {
 | 
			
		||||
                    let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                    bincode::deserialize(&decrypted)?
 | 
			
		||||
                },
 | 
			
		||||
                None => ListValue { elements: Vec::new() },
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
@@ -722,7 +929,8 @@ impl Storage {
 | 
			
		||||
            new_len = list_value.elements.len() as u64;
 | 
			
		||||
 | 
			
		||||
            let serialized = bincode::serialize(&list_value)?;
 | 
			
		||||
            lists_table.insert(key, serialized.as_slice())?;
 | 
			
		||||
            let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
            lists_table.insert(key, encrypted.as_slice())?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
@@ -748,7 +956,10 @@ impl Storage {
 | 
			
		||||
                }
 | 
			
		||||
                Some(_) => {
 | 
			
		||||
                    let mut list_value: ListValue = match lists_table.get(key)? {
 | 
			
		||||
                        Some(data) => bincode::deserialize(data.value())?,
 | 
			
		||||
                        Some(data) => {
 | 
			
		||||
                            let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                            bincode::deserialize(&decrypted)?
 | 
			
		||||
                        },
 | 
			
		||||
                        None => return Ok(None), // Key exists but list is empty (shouldn't happen if type is "list")
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
@@ -766,7 +977,8 @@ impl Storage {
 | 
			
		||||
                        types_table.remove(key)?;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        let serialized = bincode::serialize(&list_value)?;
 | 
			
		||||
                        lists_table.insert(key, serialized.as_slice())?;
 | 
			
		||||
                        let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
                        lists_table.insert(key, encrypted.as_slice())?;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                None => return Ok(None),
 | 
			
		||||
@@ -800,7 +1012,10 @@ impl Storage {
 | 
			
		||||
                }
 | 
			
		||||
                Some(_) => {
 | 
			
		||||
                    let mut list_value: ListValue = match lists_table.get(key)? {
 | 
			
		||||
                        Some(data) => bincode::deserialize(data.value())?,
 | 
			
		||||
                        Some(data) => {
 | 
			
		||||
                            let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                            bincode::deserialize(&decrypted)?
 | 
			
		||||
                        }
 | 
			
		||||
                        None => return Ok(None),
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
@@ -818,7 +1033,8 @@ impl Storage {
 | 
			
		||||
                        types_table.remove(key)?;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        let serialized = bincode::serialize(&list_value)?;
 | 
			
		||||
                        lists_table.insert(key, serialized.as_slice())?;
 | 
			
		||||
                        let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
                        lists_table.insert(key, encrypted.as_slice())?;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                None => return Ok(None),
 | 
			
		||||
@@ -842,7 +1058,8 @@ impl Storage {
 | 
			
		||||
                let lists_table = read_txn.open_table(LISTS_TABLE)?;
 | 
			
		||||
                match lists_table.get(key)? {
 | 
			
		||||
                    Some(data) => {
 | 
			
		||||
                        let list_value: ListValue = bincode::deserialize(data.value())?;
 | 
			
		||||
                        let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                        let list_value: ListValue = bincode::deserialize(&decrypted)?;
 | 
			
		||||
                        Ok(list_value.elements.len() as u64)
 | 
			
		||||
                    }
 | 
			
		||||
                    None => Ok(0), // Key exists but list is empty
 | 
			
		||||
@@ -855,7 +1072,7 @@ impl Storage {
 | 
			
		||||
    
 | 
			
		||||
    pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<u64, DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        let mut removed_count = 0u64;
 | 
			
		||||
        let removed_count;
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
@@ -872,7 +1089,10 @@ impl Storage {
 | 
			
		||||
                }
 | 
			
		||||
                Some(_) => {
 | 
			
		||||
                    let mut list_value: ListValue = match lists_table.get(key)? {
 | 
			
		||||
                        Some(data) => bincode::deserialize(data.value())?,
 | 
			
		||||
                        Some(data) => {
 | 
			
		||||
                            let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                            bincode::deserialize(&decrypted)?
 | 
			
		||||
                        }
 | 
			
		||||
                        None => return Ok(0),
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
@@ -910,7 +1130,8 @@ impl Storage {
 | 
			
		||||
                        types_table.remove(key)?;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        let serialized = bincode::serialize(&list_value)?;
 | 
			
		||||
                        lists_table.insert(key, serialized.as_slice())?;
 | 
			
		||||
                        let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
                        lists_table.insert(key, encrypted.as_slice())?;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                None => return Ok(0),
 | 
			
		||||
@@ -939,7 +1160,10 @@ impl Storage {
 | 
			
		||||
                }
 | 
			
		||||
                Some(_) => {
 | 
			
		||||
                    let mut list_value: ListValue = match lists_table.get(key)? {
 | 
			
		||||
                        Some(data) => bincode::deserialize(data.value())?,
 | 
			
		||||
                        Some(data) => {
 | 
			
		||||
                            let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                            bincode::deserialize(&decrypted)?
 | 
			
		||||
                        }
 | 
			
		||||
                        None => return Ok(()),
 | 
			
		||||
                    };
 | 
			
		||||
 | 
			
		||||
@@ -974,7 +1198,8 @@ impl Storage {
 | 
			
		||||
                        types_table.remove(key)?;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        let serialized = bincode::serialize(&list_value)?;
 | 
			
		||||
                        lists_table.insert(key, serialized.as_slice())?;
 | 
			
		||||
                        let encrypted = self.encrypt_if_needed(&serialized)?;
 | 
			
		||||
                        lists_table.insert(key, encrypted.as_slice())?;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                None => {}
 | 
			
		||||
@@ -994,7 +1219,8 @@ impl Storage {
 | 
			
		||||
                let lists_table = read_txn.open_table(LISTS_TABLE)?;
 | 
			
		||||
                match lists_table.get(key)? {
 | 
			
		||||
                    Some(data) => {
 | 
			
		||||
                        let list_value: ListValue = bincode::deserialize(data.value())?;
 | 
			
		||||
                        let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                        let list_value: ListValue = bincode::deserialize(&decrypted)?;
 | 
			
		||||
                        let len = list_value.elements.len() as i64;
 | 
			
		||||
                        let mut index = index;
 | 
			
		||||
                        if index < 0 {
 | 
			
		||||
@@ -1023,7 +1249,8 @@ impl Storage {
 | 
			
		||||
                let lists_table = read_txn.open_table(LISTS_TABLE)?;
 | 
			
		||||
                match lists_table.get(key)? {
 | 
			
		||||
                    Some(data) => {
 | 
			
		||||
                        let list_value: ListValue = bincode::deserialize(data.value())?;
 | 
			
		||||
                        let decrypted = self.decrypt_if_needed(data.value())?;
 | 
			
		||||
                        let list_value: ListValue = bincode::deserialize(&decrypted)?;
 | 
			
		||||
                        let len = list_value.elements.len() as i64;
 | 
			
		||||
                        let mut start = start;
 | 
			
		||||
                        let mut stop = stop;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user