...
This commit is contained in:
		
							
								
								
									
										645
									
								
								src/cmd.rs
									
									
									
									
									
								
							
							
						
						
									
										645
									
								
								src/cmd.rs
									
									
									
									
									
								
							@@ -1,8 +1,4 @@
 | 
			
		||||
use std::{collections::BTreeMap, ops::Bound, time::Duration, u64};
 | 
			
		||||
 | 
			
		||||
use tokio::sync::mpsc;
 | 
			
		||||
 | 
			
		||||
use crate::{error::DBError, protocol::Protocol, server::Server, storage::now_in_millis};
 | 
			
		||||
use crate::{error::DBError, protocol::Protocol, server::Server};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub enum Cmd {
 | 
			
		||||
@@ -16,17 +12,23 @@ pub enum Cmd {
 | 
			
		||||
    ConfigGet(String),
 | 
			
		||||
    Info(Option<String>),
 | 
			
		||||
    Del(String),
 | 
			
		||||
    Replconf(String),
 | 
			
		||||
    Psync,
 | 
			
		||||
    Type(String),
 | 
			
		||||
    Xadd(String, String, Vec<(String, String)>),
 | 
			
		||||
    Xrange(String, String, String),
 | 
			
		||||
    Xread(Vec<String>, Vec<String>, Option<u64>),
 | 
			
		||||
    Incr(String),
 | 
			
		||||
    Multi,
 | 
			
		||||
    Exec,
 | 
			
		||||
    Unknow,
 | 
			
		||||
    Discard,
 | 
			
		||||
    // Hash commands
 | 
			
		||||
    HSet(String, Vec<(String, String)>),
 | 
			
		||||
    HGet(String, String),
 | 
			
		||||
    HGetAll(String),
 | 
			
		||||
    HDel(String, Vec<String>),
 | 
			
		||||
    HExists(String, String),
 | 
			
		||||
    HKeys(String),
 | 
			
		||||
    HVals(String),
 | 
			
		||||
    HLen(String),
 | 
			
		||||
    HMGet(String, Vec<String>),
 | 
			
		||||
    HSetNx(String, String, String),
 | 
			
		||||
    Unknow,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Cmd {
 | 
			
		||||
@@ -39,14 +41,14 @@ impl Cmd {
 | 
			
		||||
                    return Err(DBError("cmd length is 0".to_string()));
 | 
			
		||||
                }
 | 
			
		||||
                Ok((
 | 
			
		||||
                    match cmd[0].as_str() {
 | 
			
		||||
                    match cmd[0].to_lowercase().as_str() {
 | 
			
		||||
                        "echo" => Cmd::Echo(cmd[1].clone()),
 | 
			
		||||
                        "ping" => Cmd::Ping,
 | 
			
		||||
                        "get" => Cmd::Get(cmd[1].clone()),
 | 
			
		||||
                        "set" => {
 | 
			
		||||
                            if cmd.len() == 5 && cmd[3] == "px" {
 | 
			
		||||
                            if cmd.len() == 5 && cmd[3].to_lowercase() == "px" {
 | 
			
		||||
                                Cmd::SetPx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
 | 
			
		||||
                            } else if cmd.len() == 5 && cmd[3] == "ex" {
 | 
			
		||||
                            } else if cmd.len() == 5 && cmd[3].to_lowercase() == "ex" {
 | 
			
		||||
                                Cmd::SetEx(cmd[1].clone(), cmd[2].clone(), cmd[4].parse().unwrap())
 | 
			
		||||
                            } else if cmd.len() == 3 {
 | 
			
		||||
                                Cmd::Set(cmd[1].clone(), cmd[2].clone())
 | 
			
		||||
@@ -55,7 +57,7 @@ impl Cmd {
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        "config" => {
 | 
			
		||||
                            if cmd.len() != 3 || cmd[1] != "get" {
 | 
			
		||||
                            if cmd.len() != 3 || cmd[1].to_lowercase() != "get" {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
                            } else {
 | 
			
		||||
                                Cmd::ConfigGet(cmd[2].clone())
 | 
			
		||||
@@ -76,18 +78,6 @@ impl Cmd {
 | 
			
		||||
                            };
 | 
			
		||||
                            Cmd::Info(section)
 | 
			
		||||
                        }
 | 
			
		||||
                        "replconf" => {
 | 
			
		||||
                            if cmd.len() < 3 {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::Replconf(cmd[1].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "psync" => {
 | 
			
		||||
                            if cmd.len() != 3 {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::Psync
 | 
			
		||||
                        }
 | 
			
		||||
                        "del" => {
 | 
			
		||||
                            if cmd.len() != 2 {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
@@ -100,44 +90,6 @@ impl Cmd {
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::Type(cmd[1].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "xadd" => {
 | 
			
		||||
                            if cmd.len() < 5 {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            let mut key_value = Vec::<(String, String)>::new();
 | 
			
		||||
                            let mut i = 3;
 | 
			
		||||
                            while i < cmd.len() - 1 {
 | 
			
		||||
                                key_value.push((cmd[i].clone(), cmd[i + 1].clone()));
 | 
			
		||||
                                i += 2;
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::Xadd(cmd[1].clone(), cmd[2].clone(), key_value)
 | 
			
		||||
                        }
 | 
			
		||||
                        "xrange" => {
 | 
			
		||||
                            if cmd.len() != 4 {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::Xrange(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "xread" => {
 | 
			
		||||
                            if cmd.len() < 4 || cmd.len() % 2 != 0 {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
                            }
 | 
			
		||||
                            let mut offset = 2;
 | 
			
		||||
                            // block cmd
 | 
			
		||||
                            let mut block = None;
 | 
			
		||||
                            if cmd[1] == "block" {
 | 
			
		||||
                                offset += 2;
 | 
			
		||||
                                if let Ok(block_time) = cmd[2].parse() {
 | 
			
		||||
                                    block = Some(block_time);
 | 
			
		||||
                                } else {
 | 
			
		||||
                                    return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                            let cmd2 = &cmd[offset..];
 | 
			
		||||
                            let len2 = cmd2.len() / 2;
 | 
			
		||||
                            Cmd::Xread(cmd2[0..len2].to_vec(), cmd2[len2..].to_vec(), block)
 | 
			
		||||
                        }
 | 
			
		||||
                        "incr" => {
 | 
			
		||||
                            if cmd.len() != 2 {
 | 
			
		||||
                                return Err(DBError(format!("unsupported cmd {:?}", cmd)));
 | 
			
		||||
@@ -157,6 +109,73 @@ impl Cmd {
 | 
			
		||||
                            Cmd::Exec
 | 
			
		||||
                        }
 | 
			
		||||
                        "discard" => Cmd::Discard,
 | 
			
		||||
                        // Hash commands
 | 
			
		||||
                        "hset" => {
 | 
			
		||||
                            if cmd.len() < 4 || (cmd.len() - 2) % 2 != 0 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HSET command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            let mut pairs = Vec::new();
 | 
			
		||||
                            let mut i = 2;
 | 
			
		||||
                            while i < cmd.len() - 1 {
 | 
			
		||||
                                pairs.push((cmd[i].clone(), cmd[i + 1].clone()));
 | 
			
		||||
                                i += 2;
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HSet(cmd[1].clone(), pairs)
 | 
			
		||||
                        }
 | 
			
		||||
                        "hget" => {
 | 
			
		||||
                            if cmd.len() != 3 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HGET command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HGet(cmd[1].clone(), cmd[2].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hgetall" => {
 | 
			
		||||
                            if cmd.len() != 2 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HGETALL command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HGetAll(cmd[1].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hdel" => {
 | 
			
		||||
                            if cmd.len() < 3 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HDEL command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HDel(cmd[1].clone(), cmd[2..].to_vec())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hexists" => {
 | 
			
		||||
                            if cmd.len() != 3 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HEXISTS command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HExists(cmd[1].clone(), cmd[2].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hkeys" => {
 | 
			
		||||
                            if cmd.len() != 2 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HKEYS command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HKeys(cmd[1].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hvals" => {
 | 
			
		||||
                            if cmd.len() != 2 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HVALS command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HVals(cmd[1].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hlen" => {
 | 
			
		||||
                            if cmd.len() != 2 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HLEN command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HLen(cmd[1].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hmget" => {
 | 
			
		||||
                            if cmd.len() < 3 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HMGET command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HMGet(cmd[1].clone(), cmd[2..].to_vec())
 | 
			
		||||
                        }
 | 
			
		||||
                        "hsetnx" => {
 | 
			
		||||
                            if cmd.len() != 4 {
 | 
			
		||||
                                return Err(DBError(format!("wrong number of arguments for HSETNX command")));
 | 
			
		||||
                            }
 | 
			
		||||
                            Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
 | 
			
		||||
                        }
 | 
			
		||||
                        _ => Cmd::Unknow,
 | 
			
		||||
                    },
 | 
			
		||||
                    protocol.0,
 | 
			
		||||
@@ -171,13 +190,11 @@ impl Cmd {
 | 
			
		||||
 | 
			
		||||
    pub async fn run(
 | 
			
		||||
        &self,
 | 
			
		||||
        server: &mut Server,
 | 
			
		||||
        server: &Server,
 | 
			
		||||
        protocol: Protocol,
 | 
			
		||||
        is_rep_con: bool,
 | 
			
		||||
        queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
 | 
			
		||||
    ) -> Result<Protocol, DBError> {
 | 
			
		||||
        // return if the command is a write command
 | 
			
		||||
        let p = protocol.clone();
 | 
			
		||||
        // Handle queued commands for transactions
 | 
			
		||||
        if queued_cmd.is_some()
 | 
			
		||||
            && !matches!(self, Cmd::Exec)
 | 
			
		||||
            && !matches!(self, Cmd::Multi)
 | 
			
		||||
@@ -189,71 +206,57 @@ impl Cmd {
 | 
			
		||||
                .push((self.clone(), protocol.clone()));
 | 
			
		||||
            return Ok(Protocol::SimpleString("QUEUED".to_string()));
 | 
			
		||||
        }
 | 
			
		||||
        let ret = match self {
 | 
			
		||||
 | 
			
		||||
        match self {
 | 
			
		||||
            Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
 | 
			
		||||
            Cmd::Echo(s) => Ok(Protocol::SimpleString(s.clone())),
 | 
			
		||||
            Cmd::Get(k) => get_cmd(server, k).await,
 | 
			
		||||
            Cmd::Set(k, v) => set_cmd(server, k, v, protocol, is_rep_con).await,
 | 
			
		||||
            Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x, protocol, is_rep_con).await,
 | 
			
		||||
            Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x, protocol, is_rep_con).await,
 | 
			
		||||
            Cmd::Del(k) => del_cmd(server, k, protocol, is_rep_con).await,
 | 
			
		||||
            Cmd::Set(k, v) => set_cmd(server, k, v).await,
 | 
			
		||||
            Cmd::SetPx(k, v, x) => set_px_cmd(server, k, v, x).await,
 | 
			
		||||
            Cmd::SetEx(k, v, x) => set_ex_cmd(server, k, v, x).await,
 | 
			
		||||
            Cmd::Del(k) => del_cmd(server, k).await,
 | 
			
		||||
            Cmd::ConfigGet(name) => config_get_cmd(name, server),
 | 
			
		||||
            Cmd::Keys => keys_cmd(server).await,
 | 
			
		||||
            Cmd::Info(section) => info_cmd(section, server),
 | 
			
		||||
            Cmd::Replconf(sub_cmd) => replconf_cmd(sub_cmd, server),
 | 
			
		||||
            Cmd::Psync => psync_cmd(server),
 | 
			
		||||
            Cmd::Info(section) => info_cmd(section),
 | 
			
		||||
            Cmd::Type(k) => type_cmd(server, k).await,
 | 
			
		||||
            Cmd::Xadd(stream_key, offset, kvps) => {
 | 
			
		||||
                xadd_cmd(
 | 
			
		||||
                    offset.as_str(),
 | 
			
		||||
                    server,
 | 
			
		||||
                    stream_key.as_str(),
 | 
			
		||||
                    kvps,
 | 
			
		||||
                    protocol,
 | 
			
		||||
                    is_rep_con,
 | 
			
		||||
                )
 | 
			
		||||
                .await
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            Cmd::Xrange(stream_key, start, end) => xrange_cmd(server, stream_key, start, end).await,
 | 
			
		||||
            Cmd::Xread(stream_keys, starts, block) => {
 | 
			
		||||
                xread_cmd(starts, server, stream_keys, block).await
 | 
			
		||||
            }
 | 
			
		||||
            Cmd::Incr(key) => incr_cmd(server, key).await,
 | 
			
		||||
            Cmd::Multi => {
 | 
			
		||||
                *queued_cmd = Some(Vec::<(Cmd, Protocol)>::new());
 | 
			
		||||
                Ok(Protocol::SimpleString("ok".to_string()))
 | 
			
		||||
                Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
            }
 | 
			
		||||
            Cmd::Exec => exec_cmd(queued_cmd, server, is_rep_con).await,
 | 
			
		||||
            Cmd::Exec => exec_cmd(queued_cmd, server).await,
 | 
			
		||||
            Cmd::Discard => {
 | 
			
		||||
                if queued_cmd.is_some() {
 | 
			
		||||
                    *queued_cmd = None;
 | 
			
		||||
                    Ok(Protocol::SimpleString("ok".to_string()))
 | 
			
		||||
                    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
                } else {
 | 
			
		||||
                    Ok(Protocol::err("ERR Discard without MULTI"))
 | 
			
		||||
                    Ok(Protocol::err("ERR DISCARD without MULTI"))
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Cmd::Unknow => Ok(Protocol::err("unknow cmd")),
 | 
			
		||||
        };
 | 
			
		||||
        if ret.is_ok() {
 | 
			
		||||
            server.offset.fetch_add(
 | 
			
		||||
                p.encode().len() as u64,
 | 
			
		||||
                std::sync::atomic::Ordering::Relaxed,
 | 
			
		||||
            );
 | 
			
		||||
            // Hash commands
 | 
			
		||||
            Cmd::HSet(key, pairs) => hset_cmd(server, key, pairs).await,
 | 
			
		||||
            Cmd::HGet(key, field) => hget_cmd(server, key, field).await,
 | 
			
		||||
            Cmd::HGetAll(key) => hgetall_cmd(server, key).await,
 | 
			
		||||
            Cmd::HDel(key, fields) => hdel_cmd(server, key, fields).await,
 | 
			
		||||
            Cmd::HExists(key, field) => hexists_cmd(server, key, field).await,
 | 
			
		||||
            Cmd::HKeys(key) => hkeys_cmd(server, key).await,
 | 
			
		||||
            Cmd::HVals(key) => hvals_cmd(server, key).await,
 | 
			
		||||
            Cmd::HLen(key) => hlen_cmd(server, key).await,
 | 
			
		||||
            Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
 | 
			
		||||
            Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
 | 
			
		||||
            Cmd::Unknow => Ok(Protocol::err("unknown cmd")),
 | 
			
		||||
        }
 | 
			
		||||
        ret
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn exec_cmd(
 | 
			
		||||
    queued_cmd: &mut Option<Vec<(Cmd, Protocol)>>,
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    is_rep_con: bool,
 | 
			
		||||
    server: &Server,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    if queued_cmd.is_some() {
 | 
			
		||||
        let mut vec = Vec::new();
 | 
			
		||||
        for (cmd, protocol) in queued_cmd.as_ref().unwrap() {
 | 
			
		||||
            let res = Box::pin(cmd.run(server, protocol.clone(), is_rep_con, &mut None)).await?;
 | 
			
		||||
            let res = Box::pin(cmd.run(server, protocol.clone(), &mut None)).await?;
 | 
			
		||||
            vec.push(res);
 | 
			
		||||
        }
 | 
			
		||||
        *queued_cmd = None;
 | 
			
		||||
@@ -263,22 +266,24 @@ async fn exec_cmd(
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn incr_cmd(server: &mut Server, key: &String) -> Result<Protocol, DBError> {
 | 
			
		||||
    let mut storage = server.storage.lock().await;
 | 
			
		||||
    let v = storage.get(key);
 | 
			
		||||
    // return 1 if key is missing
 | 
			
		||||
    let v = v.map_or("1".to_string(), |v| v);
 | 
			
		||||
 | 
			
		||||
    if let Ok(x) = v.parse::<u64>() {
 | 
			
		||||
        let v = (x + 1).to_string();
 | 
			
		||||
        storage.set(key.clone(), v.clone());
 | 
			
		||||
        Ok(Protocol::SimpleString(v))
 | 
			
		||||
    } else {
 | 
			
		||||
        Ok(Protocol::err("ERR value is not an integer or out of range"))
 | 
			
		||||
    }
 | 
			
		||||
async fn incr_cmd(server: &Server, key: &String) -> Result<Protocol, DBError> {
 | 
			
		||||
    let current_value = server.storage.get(key)?;
 | 
			
		||||
    
 | 
			
		||||
    let new_value = match current_value {
 | 
			
		||||
        Some(v) => {
 | 
			
		||||
            match v.parse::<i64>() {
 | 
			
		||||
                Ok(num) => num + 1,
 | 
			
		||||
                Err(_) => return Ok(Protocol::err("ERR value is not an integer or out of range")),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        None => 1,
 | 
			
		||||
    };
 | 
			
		||||
    
 | 
			
		||||
    server.storage.set(key.clone(), new_value.to_string())?;
 | 
			
		||||
    Ok(Protocol::SimpleString(new_value.to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn config_get_cmd(name: &String, server: &mut Server) -> Result<Protocol, DBError> {
 | 
			
		||||
fn config_get_cmd(name: &String, server: &Server) -> Result<Protocol, DBError> {
 | 
			
		||||
    match name.as_str() {
 | 
			
		||||
        "dir" => Ok(Protocol::Array(vec![
 | 
			
		||||
            Protocol::BulkString(name.clone()),
 | 
			
		||||
@@ -286,336 +291,156 @@ fn config_get_cmd(name: &String, server: &mut Server) -> Result<Protocol, DBErro
 | 
			
		||||
        ])),
 | 
			
		||||
        "dbfilename" => Ok(Protocol::Array(vec![
 | 
			
		||||
            Protocol::BulkString(name.clone()),
 | 
			
		||||
            Protocol::BulkString(server.option.db_file_name.clone()),
 | 
			
		||||
            Protocol::BulkString("herodb.redb".to_string()),
 | 
			
		||||
        ])),
 | 
			
		||||
        _ => Err(DBError(format!("unsupported config {:?}", name))),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn keys_cmd(server: &mut Server) -> Result<Protocol, DBError> {
 | 
			
		||||
    let keys = { server.storage.lock().await.keys() };
 | 
			
		||||
async fn keys_cmd(server: &Server) -> Result<Protocol, DBError> {
 | 
			
		||||
    let keys = server.storage.keys("*")?;
 | 
			
		||||
    Ok(Protocol::Array(
 | 
			
		||||
        keys.into_iter().map(Protocol::BulkString).collect(),
 | 
			
		||||
    ))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn info_cmd(section: &Option<String>, server: &mut Server) -> Result<Protocol, DBError> {
 | 
			
		||||
fn info_cmd(section: &Option<String>) -> Result<Protocol, DBError> {
 | 
			
		||||
    match section {
 | 
			
		||||
        Some(s) => match s.as_str() {
 | 
			
		||||
            "replication" => Ok(Protocol::BulkString(format!(
 | 
			
		||||
                "role:{}\nmaster_replid:{}\nmaster_repl_offset:{}\n",
 | 
			
		||||
                server.option.replication.role,
 | 
			
		||||
                server.option.replication.master_replid,
 | 
			
		||||
                server.option.replication.master_repl_offset
 | 
			
		||||
            ))),
 | 
			
		||||
            "replication" => Ok(Protocol::BulkString(
 | 
			
		||||
                "role:master\nmaster_replid:8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea\nmaster_repl_offset:0\n".to_string()
 | 
			
		||||
            )),
 | 
			
		||||
            _ => Err(DBError(format!("unsupported section {:?}", s))),
 | 
			
		||||
        },
 | 
			
		||||
        None => Ok(Protocol::BulkString("default".to_string())),
 | 
			
		||||
        None => Ok(Protocol::BulkString("# Server\nredis_version:7.0.0\n".to_string())),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn xread_cmd(
 | 
			
		||||
    starts: &[String],
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    stream_keys: &[String],
 | 
			
		||||
    block_millis: &Option<u64>,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    if let Some(t) = block_millis {
 | 
			
		||||
        if t > &0 {
 | 
			
		||||
            tokio::time::sleep(Duration::from_millis(*t)).await;
 | 
			
		||||
        } else {
 | 
			
		||||
            let (sender, mut receiver) = mpsc::channel(4);
 | 
			
		||||
            {
 | 
			
		||||
                let mut blocker = server.stream_reader_blocker.lock().await;
 | 
			
		||||
                blocker.push(sender.clone());
 | 
			
		||||
            }
 | 
			
		||||
            while let Some(_) = receiver.recv().await {
 | 
			
		||||
                println!("get new xadd cmd, release block");
 | 
			
		||||
                // break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    let streams = server.streams.lock().await;
 | 
			
		||||
    let mut ret = Vec::new();
 | 
			
		||||
    for (i, stream_key) in stream_keys.iter().enumerate() {
 | 
			
		||||
        let stream = streams.get(stream_key);
 | 
			
		||||
        if let Some(s) = stream {
 | 
			
		||||
            let (offset_id, mut offset_seq, _) = split_offset(starts[i].as_str());
 | 
			
		||||
            offset_seq += 1;
 | 
			
		||||
            let start = format!("{}-{}", offset_id, offset_seq);
 | 
			
		||||
            let end = format!("{}-{}", u64::MAX - 1, 0);
 | 
			
		||||
 | 
			
		||||
            // query stream range
 | 
			
		||||
            let range = s.range::<String, _>((Bound::Included(&start), Bound::Included(&end)));
 | 
			
		||||
            let mut array = Vec::new();
 | 
			
		||||
            for (k, v) in range {
 | 
			
		||||
                array.push(Protocol::BulkString(k.clone()));
 | 
			
		||||
                array.push(Protocol::from_vec(
 | 
			
		||||
                    v.iter()
 | 
			
		||||
                        .flat_map(|(a, b)| vec![a.as_str(), b.as_str()])
 | 
			
		||||
                        .collect(),
 | 
			
		||||
                ))
 | 
			
		||||
            }
 | 
			
		||||
            ret.push(Protocol::BulkString(stream_key.clone()));
 | 
			
		||||
            ret.push(Protocol::Array(array));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    Ok(Protocol::Array(ret))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn replconf_cmd(sub_cmd: &str, server: &mut Server) -> Result<Protocol, DBError> {
 | 
			
		||||
    match sub_cmd {
 | 
			
		||||
        "getack" => Ok(Protocol::from_vec(vec![
 | 
			
		||||
            "REPLCONF",
 | 
			
		||||
            "ACK",
 | 
			
		||||
            server
 | 
			
		||||
                .offset
 | 
			
		||||
                .load(std::sync::atomic::Ordering::Relaxed)
 | 
			
		||||
                .to_string()
 | 
			
		||||
                .as_str(),
 | 
			
		||||
        ])),
 | 
			
		||||
        _ => Ok(Protocol::SimpleString("OK".to_string())),
 | 
			
		||||
async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.get_key_type(k)? {
 | 
			
		||||
        Some(type_str) => Ok(Protocol::SimpleString(type_str)),
 | 
			
		||||
        None => Ok(Protocol::SimpleString("none".to_string())),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn xrange_cmd(
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    stream_key: &String,
 | 
			
		||||
    start: &String,
 | 
			
		||||
    end: &String,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    let streams = server.streams.lock().await;
 | 
			
		||||
    let stream = streams.get(stream_key);
 | 
			
		||||
    Ok(stream.map_or(Protocol::none(), |s| {
 | 
			
		||||
        // support query with '-'
 | 
			
		||||
        let start = if start == "-" {
 | 
			
		||||
            "0".to_string()
 | 
			
		||||
        } else {
 | 
			
		||||
            start.clone()
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        // support query with '+'
 | 
			
		||||
        let end = if end == "+" {
 | 
			
		||||
            u64::MAX.to_string()
 | 
			
		||||
        } else {
 | 
			
		||||
            end.clone()
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        // query stream range
 | 
			
		||||
        let range = s.range::<String, _>((Bound::Included(&start), Bound::Included(&end)));
 | 
			
		||||
        let mut array = Vec::new();
 | 
			
		||||
        for (k, v) in range {
 | 
			
		||||
            array.push(Protocol::BulkString(k.clone()));
 | 
			
		||||
            array.push(Protocol::from_vec(
 | 
			
		||||
                v.iter()
 | 
			
		||||
                    .flat_map(|(a, b)| vec![a.as_str(), b.as_str()])
 | 
			
		||||
                    .collect(),
 | 
			
		||||
            ))
 | 
			
		||||
        }
 | 
			
		||||
        println!("after xrange: {:?}", array);
 | 
			
		||||
        Protocol::Array(array)
 | 
			
		||||
    }))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn xadd_cmd(
 | 
			
		||||
    offset: &str,
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    stream_key: &str,
 | 
			
		||||
    kvps: &Vec<(String, String)>,
 | 
			
		||||
    protocol: Protocol,
 | 
			
		||||
    is_rep_con: bool,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    let mut offset = offset.to_string();
 | 
			
		||||
    if offset == "*" {
 | 
			
		||||
        offset = format!("{}-*", now_in_millis() as u64);
 | 
			
		||||
    }
 | 
			
		||||
    let (offset_id, mut offset_seq, has_wildcard) = split_offset(offset.as_str());
 | 
			
		||||
    if offset_id == 0 && offset_seq == 0 && !has_wildcard {
 | 
			
		||||
        return Ok(Protocol::err(
 | 
			
		||||
            "ERR The ID specified in XADD must be greater than 0-0",
 | 
			
		||||
        ));
 | 
			
		||||
    }
 | 
			
		||||
    {
 | 
			
		||||
        let mut streams = server.streams.lock().await;
 | 
			
		||||
        let stream = streams
 | 
			
		||||
            .entry(stream_key.to_string())
 | 
			
		||||
            .or_insert_with(BTreeMap::new);
 | 
			
		||||
 | 
			
		||||
        if let Some((last_offset, _)) = stream.last_key_value() {
 | 
			
		||||
            let (last_offset_id, last_offset_seq, _) = split_offset(last_offset.as_str());
 | 
			
		||||
            if last_offset_id > offset_id
 | 
			
		||||
                || (last_offset_id == offset_id && last_offset_seq >= offset_seq && !has_wildcard)
 | 
			
		||||
            {
 | 
			
		||||
                return Ok(Protocol::err("ERR The ID specified in XADD is equal or smaller than the target stream top item"));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if last_offset_id == offset_id && last_offset_seq >= offset_seq && has_wildcard {
 | 
			
		||||
                offset_seq = last_offset_seq + 1;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let offset = format!("{}-{}", offset_id, offset_seq);
 | 
			
		||||
 | 
			
		||||
        let s = stream.entry(offset.clone()).or_insert_with(Vec::new);
 | 
			
		||||
        for (key, value) in kvps {
 | 
			
		||||
            s.push((key.clone(), value.clone()));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    {
 | 
			
		||||
        let mut blocker = server.stream_reader_blocker.lock().await;
 | 
			
		||||
        for sender in blocker.iter() {
 | 
			
		||||
            sender.send(()).await?;
 | 
			
		||||
        }
 | 
			
		||||
        blocker.clear();
 | 
			
		||||
    }
 | 
			
		||||
    resp_and_replicate(
 | 
			
		||||
        server,
 | 
			
		||||
        Protocol::BulkString(offset.to_string()),
 | 
			
		||||
        protocol,
 | 
			
		||||
        is_rep_con,
 | 
			
		||||
    )
 | 
			
		||||
    .await
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn type_cmd(server: &mut Server, k: &String) -> Result<Protocol, DBError> {
 | 
			
		||||
    let v = { server.storage.lock().await.get(k) };
 | 
			
		||||
    if v.is_some() {
 | 
			
		||||
        return Ok(Protocol::SimpleString("string".to_string()));
 | 
			
		||||
    }
 | 
			
		||||
    let streams = server.streams.lock().await;
 | 
			
		||||
    let v = streams.get(k);
 | 
			
		||||
    Ok(v.map_or(Protocol::none(), |_| {
 | 
			
		||||
        Protocol::SimpleString("stream".to_string())
 | 
			
		||||
    }))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn psync_cmd(server: &mut Server) -> Result<Protocol, DBError> {
 | 
			
		||||
    if server.is_master() {
 | 
			
		||||
        Ok(Protocol::SimpleString(format!(
 | 
			
		||||
            "FULLRESYNC {} 0",
 | 
			
		||||
            server.option.replication.master_replid
 | 
			
		||||
        )))
 | 
			
		||||
    } else {
 | 
			
		||||
        Ok(Protocol::psync_on_slave_err())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn del_cmd(
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    k: &str,
 | 
			
		||||
    protocol: Protocol,
 | 
			
		||||
    is_rep_con: bool,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    // offset
 | 
			
		||||
    let _ = {
 | 
			
		||||
        let mut s = server.storage.lock().await;
 | 
			
		||||
        s.del(k.to_string());
 | 
			
		||||
        server
 | 
			
		||||
            .offset
 | 
			
		||||
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
 | 
			
		||||
    };
 | 
			
		||||
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
 | 
			
		||||
async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    server.storage.del(k.to_string())?;
 | 
			
		||||
    Ok(Protocol::SimpleString("1".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn set_ex_cmd(
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    server: &Server,
 | 
			
		||||
    k: &str,
 | 
			
		||||
    v: &str,
 | 
			
		||||
    x: &u128,
 | 
			
		||||
    protocol: Protocol,
 | 
			
		||||
    is_rep_con: bool,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    // offset
 | 
			
		||||
    let _ = {
 | 
			
		||||
        let mut s = server.storage.lock().await;
 | 
			
		||||
        s.setx(k.to_string(), v.to_string(), *x * 1000);
 | 
			
		||||
        server
 | 
			
		||||
            .offset
 | 
			
		||||
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
 | 
			
		||||
    };
 | 
			
		||||
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
 | 
			
		||||
    server.storage.setx(k.to_string(), v.to_string(), *x * 1000)?;
 | 
			
		||||
    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn set_px_cmd(
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    server: &Server,
 | 
			
		||||
    k: &str,
 | 
			
		||||
    v: &str,
 | 
			
		||||
    x: &u128,
 | 
			
		||||
    protocol: Protocol,
 | 
			
		||||
    is_rep_con: bool,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    // offset
 | 
			
		||||
    let _ = {
 | 
			
		||||
        let mut s = server.storage.lock().await;
 | 
			
		||||
        s.setx(k.to_string(), v.to_string(), *x);
 | 
			
		||||
        server
 | 
			
		||||
            .offset
 | 
			
		||||
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
 | 
			
		||||
    };
 | 
			
		||||
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
 | 
			
		||||
    server.storage.setx(k.to_string(), v.to_string(), *x)?;
 | 
			
		||||
    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn set_cmd(
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    k: &str,
 | 
			
		||||
    v: &str,
 | 
			
		||||
    protocol: Protocol,
 | 
			
		||||
    is_rep_con: bool,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    // offset
 | 
			
		||||
    let _ = {
 | 
			
		||||
        let mut s = server.storage.lock().await;
 | 
			
		||||
        s.set(k.to_string(), v.to_string());
 | 
			
		||||
        server
 | 
			
		||||
            .offset
 | 
			
		||||
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
 | 
			
		||||
            + 1
 | 
			
		||||
    };
 | 
			
		||||
    resp_and_replicate(server, Protocol::ok(), protocol, is_rep_con).await
 | 
			
		||||
async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    server.storage.set(k.to_string(), v.to_string())?;
 | 
			
		||||
    Ok(Protocol::SimpleString("OK".to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn get_cmd(server: &mut Server, k: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    let v = {
 | 
			
		||||
        let mut s = server.storage.lock().await;
 | 
			
		||||
        s.get(k)
 | 
			
		||||
    };
 | 
			
		||||
    Ok(v.map_or(Protocol::Null, Protocol::SimpleString))
 | 
			
		||||
async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    let v = server.storage.get(k)?;
 | 
			
		||||
    Ok(v.map_or(Protocol::Null, Protocol::BulkString))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn resp_and_replicate(
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
    resp: Protocol,
 | 
			
		||||
    replication: Protocol,
 | 
			
		||||
    is_rep_con: bool,
 | 
			
		||||
) -> Result<Protocol, DBError> {
 | 
			
		||||
    if server.is_master() {
 | 
			
		||||
        server
 | 
			
		||||
            .master_repl_clients
 | 
			
		||||
            .lock()
 | 
			
		||||
            .await
 | 
			
		||||
            .as_mut()
 | 
			
		||||
            .unwrap()
 | 
			
		||||
            .send_command(replication)
 | 
			
		||||
            .await?;
 | 
			
		||||
        Ok(resp)
 | 
			
		||||
    } else if !is_rep_con {
 | 
			
		||||
        Ok(Protocol::write_on_slave_err())
 | 
			
		||||
    } else {
 | 
			
		||||
        Ok(resp)
 | 
			
		||||
// Hash command implementations
 | 
			
		||||
async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
 | 
			
		||||
    let new_fields = server.storage.hset(key, pairs)?;
 | 
			
		||||
    Ok(Protocol::SimpleString(new_fields.to_string()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hget_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hget(key, field) {
 | 
			
		||||
        Ok(Some(value)) => Ok(Protocol::BulkString(value)),
 | 
			
		||||
        Ok(None) => Ok(Protocol::Null),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn split_offset(offset: &str) -> (u64, u64, bool) {
 | 
			
		||||
    let offset_split = offset.split('-').collect::<Vec<_>>();
 | 
			
		||||
    let offset_id = offset_split[0].parse::<u64>().expect(&format!(
 | 
			
		||||
        "ERR The ID specified in XADD must be a number: {}",
 | 
			
		||||
        offset
 | 
			
		||||
    ));
 | 
			
		||||
 | 
			
		||||
    if offset_split.len() == 1 || offset_split[1] == "*" {
 | 
			
		||||
        return (offset_id, if offset_id == 0 { 1 } else { 0 }, true);
 | 
			
		||||
async fn hgetall_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hgetall(key) {
 | 
			
		||||
        Ok(pairs) => {
 | 
			
		||||
            let mut result = Vec::new();
 | 
			
		||||
            for (field, value) in pairs {
 | 
			
		||||
                result.push(Protocol::BulkString(field));
 | 
			
		||||
                result.push(Protocol::BulkString(value));
 | 
			
		||||
            }
 | 
			
		||||
            Ok(Protocol::Array(result))
 | 
			
		||||
        }
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hdel_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hdel(key, fields) {
 | 
			
		||||
        Ok(deleted) => Ok(Protocol::SimpleString(deleted.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hexists_cmd(server: &Server, key: &str, field: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hexists(key, field) {
 | 
			
		||||
        Ok(exists) => Ok(Protocol::SimpleString(if exists { "1" } else { "0" }.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hkeys_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hkeys(key) {
 | 
			
		||||
        Ok(keys) => Ok(Protocol::Array(
 | 
			
		||||
            keys.into_iter().map(Protocol::BulkString).collect(),
 | 
			
		||||
        )),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hvals_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hvals(key) {
 | 
			
		||||
        Ok(values) => Ok(Protocol::Array(
 | 
			
		||||
            values.into_iter().map(Protocol::BulkString).collect(),
 | 
			
		||||
        )),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hlen_cmd(server: &Server, key: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hlen(key) {
 | 
			
		||||
        Ok(len) => Ok(Protocol::SimpleString(len.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hmget_cmd(server: &Server, key: &str, fields: &[String]) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hmget(key, fields) {
 | 
			
		||||
        Ok(values) => {
 | 
			
		||||
            let result: Vec<Protocol> = values
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .map(|v| v.map_or(Protocol::Null, Protocol::BulkString))
 | 
			
		||||
                .collect();
 | 
			
		||||
            Ok(Protocol::Array(result))
 | 
			
		||||
        }
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Result<Protocol, DBError> {
 | 
			
		||||
    match server.storage.hsetnx(key, field, value) {
 | 
			
		||||
        Ok(was_set) => Ok(Protocol::SimpleString(if was_set { "1" } else { "0" }.to_string())),
 | 
			
		||||
        Err(e) => Ok(Protocol::err(&e.0)),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let offset_seq = offset_split[1].parse::<u64>().unwrap();
 | 
			
		||||
    (offset_id, offset_seq, false)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										45
									
								
								src/error.rs
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								src/error.rs
									
									
									
									
									
								
							@@ -1,6 +1,8 @@
 | 
			
		||||
use std::num::ParseIntError;
 | 
			
		||||
 | 
			
		||||
use tokio::sync::mpsc;
 | 
			
		||||
use redb;
 | 
			
		||||
use bincode;
 | 
			
		||||
 | 
			
		||||
use crate::protocol::Protocol;
 | 
			
		||||
 | 
			
		||||
@@ -32,11 +34,48 @@ impl From<std::string::FromUtf8Error> for DBError {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<mpsc::error::SendError<(Protocol, u64)>> for DBError {
 | 
			
		||||
    fn from(item: mpsc::error::SendError<(Protocol, u64)>) -> Self {
 | 
			
		||||
        DBError(item.to_string().clone())
 | 
			
		||||
impl From<redb::Error> for DBError {
 | 
			
		||||
    fn from(item: redb::Error) -> Self {
 | 
			
		||||
        DBError(item.to_string())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<redb::DatabaseError> for DBError {
 | 
			
		||||
    fn from(item: redb::DatabaseError) -> Self {
 | 
			
		||||
        DBError(item.to_string())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<redb::TransactionError> for DBError {
 | 
			
		||||
    fn from(item: redb::TransactionError) -> Self {
 | 
			
		||||
        DBError(item.to_string())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<redb::TableError> for DBError {
 | 
			
		||||
    fn from(item: redb::TableError) -> Self {
 | 
			
		||||
        DBError(item.to_string())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<redb::StorageError> for DBError {
 | 
			
		||||
    fn from(item: redb::StorageError) -> Self {
 | 
			
		||||
        DBError(item.to_string())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<redb::CommitError> for DBError {
 | 
			
		||||
    fn from(item: redb::CommitError) -> Self {
 | 
			
		||||
        DBError(item.to_string())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<Box<bincode::ErrorKind>> for DBError {
 | 
			
		||||
    fn from(item: Box<bincode::ErrorKind>) -> Self {
 | 
			
		||||
        DBError(item.to_string())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
 | 
			
		||||
    fn from(item: mpsc::error::SendError<()>) -> Self {
 | 
			
		||||
        DBError(item.to_string().clone())
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,5 @@ mod cmd;
 | 
			
		||||
pub mod error;
 | 
			
		||||
pub mod options;
 | 
			
		||||
mod protocol;
 | 
			
		||||
mod rdb;
 | 
			
		||||
mod replication_client;
 | 
			
		||||
pub mod server;
 | 
			
		||||
mod storage;
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										44
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										44
									
								
								src/main.rs
									
									
									
									
									
								
							@@ -2,7 +2,7 @@
 | 
			
		||||
 | 
			
		||||
use tokio::net::TcpListener;
 | 
			
		||||
 | 
			
		||||
use redis_rs::{options::ReplicationOption, server};
 | 
			
		||||
use redis_rs::server;
 | 
			
		||||
 | 
			
		||||
use clap::Parser;
 | 
			
		||||
 | 
			
		||||
@@ -14,17 +14,10 @@ struct Args {
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    dir: String,
 | 
			
		||||
 | 
			
		||||
    /// The name of the Redis DB file
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    dbfilename: String,
 | 
			
		||||
 | 
			
		||||
    /// The port of the Redis server, default is 6379 if not specified
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    port: Option<u16>,
 | 
			
		||||
 | 
			
		||||
    /// The address of the master Redis server, if the server is a replica. None if the server is a master.
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    replicaof: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[tokio::main]
 | 
			
		||||
@@ -42,42 +35,11 @@ async fn main() {
 | 
			
		||||
    // new DB option
 | 
			
		||||
    let option = redis_rs::options::DBOption {
 | 
			
		||||
        dir: args.dir,
 | 
			
		||||
        db_file_name: args.dbfilename,
 | 
			
		||||
        port,
 | 
			
		||||
        replication: ReplicationOption {
 | 
			
		||||
            role: if let Some(_) = args.replicaof {
 | 
			
		||||
                "slave".to_string()
 | 
			
		||||
            } else {
 | 
			
		||||
                "master".to_string()
 | 
			
		||||
            },
 | 
			
		||||
            master_replid: "8371b4fb1155b71f4a04d3e1bc3e18c4a990aeea".to_string(), // should be a random string but hard code for now
 | 
			
		||||
            master_repl_offset: 0,
 | 
			
		||||
            replica_of: args.replicaof,
 | 
			
		||||
        },
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // new server
 | 
			
		||||
    let mut server = server::Server::new(option).await;
 | 
			
		||||
 | 
			
		||||
    //start receive replication cmds for slave
 | 
			
		||||
    if server.is_slave() {
 | 
			
		||||
        let mut sc = server.clone();
 | 
			
		||||
 | 
			
		||||
        let mut follower_repl_client = server.get_follower_repl_client().await.unwrap();
 | 
			
		||||
        follower_repl_client.ping_master().await.unwrap();
 | 
			
		||||
        follower_repl_client
 | 
			
		||||
            .report_port(server.option.port)
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        follower_repl_client.report_sync_protocol().await.unwrap();
 | 
			
		||||
        follower_repl_client.start_psync(&mut sc).await.unwrap();
 | 
			
		||||
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            if let Err(e) = sc.handle(follower_repl_client.stream, true).await {
 | 
			
		||||
                println!("error: {:?}, will close the connection. Bye", e);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
    let server = server::Server::new(option).await;
 | 
			
		||||
 | 
			
		||||
    // accept new connections
 | 
			
		||||
    loop {
 | 
			
		||||
@@ -88,7 +50,7 @@ async fn main() {
 | 
			
		||||
 | 
			
		||||
                let mut sc = server.clone();
 | 
			
		||||
                tokio::spawn(async move {
 | 
			
		||||
                    if let Err(e) = sc.handle(stream, false).await {
 | 
			
		||||
                    if let Err(e) = sc.handle(stream).await {
 | 
			
		||||
                        println!("error: {:?}, will close the connection. Bye", e);
 | 
			
		||||
                    }
 | 
			
		||||
                });
 | 
			
		||||
 
 | 
			
		||||
@@ -1,15 +1,5 @@
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct DBOption {
 | 
			
		||||
    pub dir: String,
 | 
			
		||||
    pub db_file_name: String,
 | 
			
		||||
    pub replication: ReplicationOption,
 | 
			
		||||
    pub port: u16,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct ReplicationOption {
 | 
			
		||||
    pub role: String,
 | 
			
		||||
    pub master_replid: String,
 | 
			
		||||
    pub master_repl_offset: u64,
 | 
			
		||||
    pub replica_of: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										201
									
								
								src/rdb.rs
									
									
									
									
									
								
							
							
						
						
									
										201
									
								
								src/rdb.rs
									
									
									
									
									
								
							@@ -1,201 +0,0 @@
 | 
			
		||||
// parse Redis RDB file format: https://rdb.fnordig.de/file_format.html
 | 
			
		||||
 | 
			
		||||
use tokio::{
 | 
			
		||||
    fs,
 | 
			
		||||
    io::{AsyncRead, AsyncReadExt, BufReader},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{error::DBError, server::Server};
 | 
			
		||||
 | 
			
		||||
use futures::pin_mut;
 | 
			
		||||
 | 
			
		||||
enum StringEncoding {
 | 
			
		||||
    Raw,
 | 
			
		||||
    I8,
 | 
			
		||||
    I16,
 | 
			
		||||
    I32,
 | 
			
		||||
    LZF,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RDB file format.
 | 
			
		||||
const MAGIC: &[u8; 5] = b"REDIS";
 | 
			
		||||
const META: u8 = 0xFA;
 | 
			
		||||
const DB_SELECT: u8 = 0xFE;
 | 
			
		||||
const TABLE_SIZE_INFO: u8 = 0xFB;
 | 
			
		||||
pub const EOF: u8 = 0xFF;
 | 
			
		||||
 | 
			
		||||
pub async fn parse_rdb<R: AsyncRead + Unpin>(
 | 
			
		||||
    reader: &mut R,
 | 
			
		||||
    server: &mut Server,
 | 
			
		||||
) -> Result<(), DBError> {
 | 
			
		||||
    let mut storage = server.storage.lock().await;
 | 
			
		||||
    parse_magic(reader).await?;
 | 
			
		||||
    let _version = parse_version(reader).await?;
 | 
			
		||||
    pin_mut!(reader);
 | 
			
		||||
    loop {
 | 
			
		||||
        let op = reader.read_u8().await?;
 | 
			
		||||
        match op {
 | 
			
		||||
            META => {
 | 
			
		||||
                let _ = parse_aux(&mut *reader).await?;
 | 
			
		||||
                let _ = parse_aux(&mut *reader).await?;
 | 
			
		||||
                // just ignore the aux info for now
 | 
			
		||||
            }
 | 
			
		||||
            DB_SELECT => {
 | 
			
		||||
                let (_, _) = parse_len(&mut *reader).await?;
 | 
			
		||||
                // just ignore the db index for now
 | 
			
		||||
            }
 | 
			
		||||
            TABLE_SIZE_INFO => {
 | 
			
		||||
                let size_no_expire = parse_len(&mut *reader).await?.0;
 | 
			
		||||
                let size_expire = parse_len(&mut *reader).await?.0;
 | 
			
		||||
                for _ in 0..size_no_expire {
 | 
			
		||||
                    let (k, v) = parse_no_expire_entry(&mut *reader).await?;
 | 
			
		||||
                    storage.set(k, v);
 | 
			
		||||
                }
 | 
			
		||||
                for _ in 0..size_expire {
 | 
			
		||||
                    let (k, v, expire_timestamp) = parse_expire_entry(&mut *reader).await?;
 | 
			
		||||
                    storage.setx(k, v, expire_timestamp);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            EOF => {
 | 
			
		||||
                // not verify crc for now
 | 
			
		||||
                let _crc = reader.read_u64().await?;
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
            _ => return Err(DBError(format!("unexpected op: {}", op))),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn parse_rdb_file(f: &mut fs::File, server: &mut Server) -> Result<(), DBError> {
 | 
			
		||||
    let mut reader = BufReader::new(f);
 | 
			
		||||
    parse_rdb(&mut reader, server).await
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn parse_no_expire_entry<R: AsyncRead + Unpin>(
 | 
			
		||||
    input: &mut R,
 | 
			
		||||
) -> Result<(String, String), DBError> {
 | 
			
		||||
    let b = input.read_u8().await?;
 | 
			
		||||
    if b != 0 {
 | 
			
		||||
        return Err(DBError(format!("unexpected key type: {}", b)));
 | 
			
		||||
    }
 | 
			
		||||
    let k = parse_aux(input).await?;
 | 
			
		||||
    let v = parse_aux(input).await?;
 | 
			
		||||
    Ok((k, v))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn parse_expire_entry<R: AsyncRead + Unpin>(
 | 
			
		||||
    input: &mut R,
 | 
			
		||||
) -> Result<(String, String, u128), DBError> {
 | 
			
		||||
    let b = input.read_u8().await?;
 | 
			
		||||
    match b {
 | 
			
		||||
        0xFC => {
 | 
			
		||||
            // expire in milliseconds
 | 
			
		||||
            let expire_stamp = input.read_u64_le().await?;
 | 
			
		||||
            let (k, v) = parse_no_expire_entry(input).await?;
 | 
			
		||||
            Ok((k, v, expire_stamp as u128))
 | 
			
		||||
        }
 | 
			
		||||
        0xFD => {
 | 
			
		||||
            // expire in seconds
 | 
			
		||||
            let expire_timestamp = input.read_u32_le().await?;
 | 
			
		||||
            let (k, v) = parse_no_expire_entry(input).await?;
 | 
			
		||||
            Ok((k, v, (expire_timestamp * 1000) as u128))
 | 
			
		||||
        }
 | 
			
		||||
        _ => return Err(DBError(format!("unexpected expire type: {}", b))),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn parse_magic<R: AsyncRead + Unpin>(input: &mut R) -> Result<(), DBError> {
 | 
			
		||||
    let mut magic = [0; 5];
 | 
			
		||||
    let size_read = input.read(&mut magic).await?;
 | 
			
		||||
    if size_read != 5 {
 | 
			
		||||
        Err(DBError("expected 5 chars for magic number".to_string()))
 | 
			
		||||
    } else if magic.as_slice() == MAGIC {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    } else {
 | 
			
		||||
        Err(DBError(format!(
 | 
			
		||||
            "expected magic string {:?}, but got: {:?}",
 | 
			
		||||
            MAGIC, magic
 | 
			
		||||
        )))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn parse_version<R: AsyncRead + Unpin>(input: &mut R) -> Result<[u8; 4], DBError> {
 | 
			
		||||
    let mut version = [0; 4];
 | 
			
		||||
    let size_read = input.read(&mut version).await?;
 | 
			
		||||
    if size_read != 4 {
 | 
			
		||||
        Err(DBError("expected 4 chars for redis version".to_string()))
 | 
			
		||||
    } else {
 | 
			
		||||
        Ok(version)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn parse_aux<R: AsyncRead + Unpin>(input: &mut R) -> Result<String, DBError> {
 | 
			
		||||
    let (len, encoding) = parse_len(input).await?;
 | 
			
		||||
    let s = parse_string(input, len, encoding).await?;
 | 
			
		||||
    Ok(s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn parse_len<R: AsyncRead + Unpin>(input: &mut R) -> Result<(u32, StringEncoding), DBError> {
 | 
			
		||||
    let first = input.read_u8().await?;
 | 
			
		||||
    match first & 0xC0 {
 | 
			
		||||
        0x00 => {
 | 
			
		||||
            // The size is the remaining 6 bits of the byte.
 | 
			
		||||
            Ok((first as u32, StringEncoding::Raw))
 | 
			
		||||
        }
 | 
			
		||||
        0x04 => {
 | 
			
		||||
            // The size is the next 14 bits of the byte.
 | 
			
		||||
            let second = input.read_u8().await?;
 | 
			
		||||
            Ok((
 | 
			
		||||
                (((first & 0x3F) as u32) << 8 | second as u32) as u32,
 | 
			
		||||
                StringEncoding::Raw,
 | 
			
		||||
            ))
 | 
			
		||||
        }
 | 
			
		||||
        0x80 => {
 | 
			
		||||
            //Ignore the remaining 6 bits of the first byte.  The size is the next 4 bytes, in big-endian
 | 
			
		||||
            let second = input.read_u32().await?;
 | 
			
		||||
            Ok((second, StringEncoding::Raw))
 | 
			
		||||
        }
 | 
			
		||||
        0xC0 => {
 | 
			
		||||
            // The remaining 6 bits specify a type of string encoding.
 | 
			
		||||
            match first {
 | 
			
		||||
                0xC0 => Ok((1, StringEncoding::I8)),
 | 
			
		||||
                0xC1 => Ok((2, StringEncoding::I16)),
 | 
			
		||||
                0xC2 => Ok((4, StringEncoding::I32)),
 | 
			
		||||
                0xC3 => Ok((0, StringEncoding::LZF)), // not supported yet
 | 
			
		||||
                _ => Err(DBError(format!("unexpected string encoding: {}", first))),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        _ => Err(DBError(format!("unexpected len prefix: {}", first))),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn parse_string<R: AsyncRead + Unpin>(
 | 
			
		||||
    input: &mut R,
 | 
			
		||||
    len: u32,
 | 
			
		||||
    encoding: StringEncoding,
 | 
			
		||||
) -> Result<String, DBError> {
 | 
			
		||||
    match encoding {
 | 
			
		||||
        StringEncoding::Raw => {
 | 
			
		||||
            let mut s = vec![0; len as usize];
 | 
			
		||||
            input.read_exact(&mut s).await?;
 | 
			
		||||
            Ok(String::from_utf8(s).unwrap())
 | 
			
		||||
        }
 | 
			
		||||
        StringEncoding::I8 => {
 | 
			
		||||
            let b = input.read_u8().await?;
 | 
			
		||||
            Ok(b.to_string())
 | 
			
		||||
        }
 | 
			
		||||
        StringEncoding::I16 => {
 | 
			
		||||
            let b = input.read_u16_le().await?;
 | 
			
		||||
            Ok(b.to_string())
 | 
			
		||||
        }
 | 
			
		||||
        StringEncoding::I32 => {
 | 
			
		||||
            let b = input.read_u32_le().await?;
 | 
			
		||||
            Ok(b.to_string())
 | 
			
		||||
        }
 | 
			
		||||
        StringEncoding::LZF => {
 | 
			
		||||
            // not supported yet
 | 
			
		||||
            Err(DBError("LZF encoding not supported yet".to_string()))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,155 +0,0 @@
 | 
			
		||||
use std::{num::ParseIntError, sync::Arc};
 | 
			
		||||
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
 | 
			
		||||
    net::TcpStream,
 | 
			
		||||
    sync::Mutex,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::{error::DBError, protocol::Protocol, rdb, server::Server};
 | 
			
		||||
 | 
			
		||||
const EMPTY_RDB_FILE_HEX_STRING: &str = "524544495330303131fa0972656469732d76657205372e322e30fa0a72656469732d62697473c040fa056374696d65c26d08bc65fa08757365642d6d656dc2b0c41000fa08616f662d62617365c000fff06e3bfec0ff5aa2";
 | 
			
		||||
 | 
			
		||||
pub struct FollowerReplicationClient {
 | 
			
		||||
    pub stream: TcpStream,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl FollowerReplicationClient {
 | 
			
		||||
    pub async fn new(addr: String) -> FollowerReplicationClient {
 | 
			
		||||
        FollowerReplicationClient {
 | 
			
		||||
            stream: TcpStream::connect(addr).await.unwrap(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn ping_master(self: &mut Self) -> Result<(), DBError> {
 | 
			
		||||
        let protocol = Protocol::Array(vec![Protocol::BulkString("PING".to_string())]);
 | 
			
		||||
        self.stream.write_all(protocol.encode().as_bytes()).await?;
 | 
			
		||||
 | 
			
		||||
        self.check_resp("PONG").await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn report_port(self: &mut Self, port: u16) -> Result<(), DBError> {
 | 
			
		||||
        let protocol = Protocol::from_vec(vec![
 | 
			
		||||
            "REPLCONF",
 | 
			
		||||
            "listening-port",
 | 
			
		||||
            port.to_string().as_str(),
 | 
			
		||||
        ]);
 | 
			
		||||
        self.stream.write_all(protocol.encode().as_bytes()).await?;
 | 
			
		||||
 | 
			
		||||
        self.check_resp("OK").await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn report_sync_protocol(self: &mut Self) -> Result<(), DBError> {
 | 
			
		||||
        let p = Protocol::from_vec(vec!["REPLCONF", "capa", "psync2"]);
 | 
			
		||||
        self.stream.write_all(p.encode().as_bytes()).await?;
 | 
			
		||||
        self.check_resp("OK").await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn start_psync(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
 | 
			
		||||
        let p = Protocol::from_vec(vec!["PSYNC", "?", "-1"]);
 | 
			
		||||
        self.stream.write_all(p.encode().as_bytes()).await?;
 | 
			
		||||
        self.recv_rdb_file(server).await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn recv_rdb_file(self: &mut Self, server: &mut Server) -> Result<(), DBError> {
 | 
			
		||||
        let mut reader = BufReader::new(&mut self.stream);
 | 
			
		||||
 | 
			
		||||
        let mut buf = Vec::new();
 | 
			
		||||
        let _ = reader.read_until(b'\n', &mut buf).await?;
 | 
			
		||||
        buf.pop();
 | 
			
		||||
        buf.pop();
 | 
			
		||||
 | 
			
		||||
        let replication_info = String::from_utf8(buf)?;
 | 
			
		||||
        let replication_info = replication_info
 | 
			
		||||
            .split_whitespace()
 | 
			
		||||
            .map(|x| x.to_string())
 | 
			
		||||
            .collect::<Vec<String>>();
 | 
			
		||||
        if replication_info.len() != 3 {
 | 
			
		||||
            return Err(DBError(format!(
 | 
			
		||||
                "expect 3 args but found {:?}",
 | 
			
		||||
                replication_info
 | 
			
		||||
            )));
 | 
			
		||||
        }
 | 
			
		||||
        println!(
 | 
			
		||||
            "Get replication info: {:?} {:?} {:?}",
 | 
			
		||||
            replication_info[0], replication_info[1], replication_info[2]
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let c = reader.read_u8().await?;
 | 
			
		||||
        if c != b'$' {
 | 
			
		||||
            return Err(DBError(format!("expect $ but found {}", c)));
 | 
			
		||||
        }
 | 
			
		||||
        let mut buf = Vec::new();
 | 
			
		||||
        reader.read_until(b'\n', &mut buf).await?;
 | 
			
		||||
        buf.pop();
 | 
			
		||||
        buf.pop();
 | 
			
		||||
        let rdb_file_len = String::from_utf8(buf)?.parse::<usize>()?;
 | 
			
		||||
        println!("rdb file len: {}", rdb_file_len);
 | 
			
		||||
 | 
			
		||||
        // receive rdb file content
 | 
			
		||||
        rdb::parse_rdb(&mut reader, server).await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn check_resp(&mut self, expected: &str) -> Result<(), DBError> {
 | 
			
		||||
        let mut buf = [0; 1024];
 | 
			
		||||
        let n_bytes = self.stream.read(&mut buf).await?;
 | 
			
		||||
        println!(
 | 
			
		||||
            "check resp: recv {:?}",
 | 
			
		||||
            String::from_utf8(buf[..n_bytes].to_vec()).unwrap()
 | 
			
		||||
        );
 | 
			
		||||
        let expect = Protocol::SimpleString(expected.to_string()).encode();
 | 
			
		||||
        if expect.as_bytes() != &buf[..n_bytes] {
 | 
			
		||||
            return Err(DBError(format!(
 | 
			
		||||
                "expect response {:?} but found {:?}",
 | 
			
		||||
                expect,
 | 
			
		||||
                &buf[..n_bytes]
 | 
			
		||||
            )));
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct MasterReplicationClient {
 | 
			
		||||
    pub streams: Arc<Mutex<Vec<TcpStream>>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl MasterReplicationClient {
 | 
			
		||||
    pub fn new() -> MasterReplicationClient {
 | 
			
		||||
        MasterReplicationClient {
 | 
			
		||||
            streams: Arc::new(Mutex::new(Vec::new())),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn send_rdb_file(&mut self, stream: &mut TcpStream) -> Result<(), DBError> {
 | 
			
		||||
        let empty_rdb_file_bytes = (0..EMPTY_RDB_FILE_HEX_STRING.len())
 | 
			
		||||
            .step_by(2)
 | 
			
		||||
            .map(|i| u8::from_str_radix(&EMPTY_RDB_FILE_HEX_STRING[i..i + 2], 16))
 | 
			
		||||
            .collect::<Result<Vec<u8>, ParseIntError>>()?;
 | 
			
		||||
 | 
			
		||||
        println!("going to send rdb file");
 | 
			
		||||
        _ = stream.write("$".as_bytes()).await?;
 | 
			
		||||
        _ = stream
 | 
			
		||||
            .write(empty_rdb_file_bytes.len().to_string().as_bytes())
 | 
			
		||||
            .await?;
 | 
			
		||||
        _ = stream.write_all("\r\n".as_bytes()).await?;
 | 
			
		||||
        _ = stream.write_all(&empty_rdb_file_bytes).await?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn add_stream(&mut self, stream: TcpStream) -> Result<(), DBError> {
 | 
			
		||||
        let mut streams = self.streams.lock().await;
 | 
			
		||||
        streams.push(stream);
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn send_command(&mut self, protocol: Protocol) -> Result<(), DBError> {
 | 
			
		||||
        let mut streams = self.streams.lock().await;
 | 
			
		||||
        for stream in streams.iter_mut() {
 | 
			
		||||
            stream.write_all(protocol.encode().as_bytes()).await?;
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										116
									
								
								src/server.rs
									
									
									
									
									
								
							
							
						
						
									
										116
									
								
								src/server.rs
									
									
									
									
									
								
							@@ -1,143 +1,63 @@
 | 
			
		||||
use core::str;
 | 
			
		||||
use std::collections::BTreeMap;
 | 
			
		||||
use std::collections::HashMap;
 | 
			
		||||
use std::path::PathBuf;
 | 
			
		||||
use std::sync::atomic::AtomicU64;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use tokio::fs::OpenOptions;
 | 
			
		||||
use tokio::io::AsyncReadExt;
 | 
			
		||||
use tokio::io::AsyncWriteExt;
 | 
			
		||||
use tokio::sync::mpsc::Sender;
 | 
			
		||||
use tokio::sync::Mutex;
 | 
			
		||||
 | 
			
		||||
use crate::cmd::Cmd;
 | 
			
		||||
use crate::error::DBError;
 | 
			
		||||
use crate::options;
 | 
			
		||||
use crate::protocol::Protocol;
 | 
			
		||||
use crate::rdb;
 | 
			
		||||
use crate::replication_client::FollowerReplicationClient;
 | 
			
		||||
use crate::replication_client::MasterReplicationClient;
 | 
			
		||||
use crate::storage::Storage;
 | 
			
		||||
 | 
			
		||||
type Stream = BTreeMap<String, Vec<(String, String)>>;
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct Server {
 | 
			
		||||
    pub storage: Arc<Mutex<Storage>>,
 | 
			
		||||
    pub streams: Arc<Mutex<HashMap<String, Stream>>>,
 | 
			
		||||
    pub storage: Arc<Storage>,
 | 
			
		||||
    pub option: options::DBOption,
 | 
			
		||||
    pub offset: Arc<AtomicU64>,
 | 
			
		||||
    pub master_repl_clients: Arc<Mutex<Option<MasterReplicationClient>>>,
 | 
			
		||||
    pub stream_reader_blocker: Arc<Mutex<Vec<Sender<()>>>>,
 | 
			
		||||
    master_addr: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Server {
 | 
			
		||||
    pub async fn new(option: options::DBOption) -> Self {
 | 
			
		||||
        let master_addr = match option.replication.role.as_str() {
 | 
			
		||||
            "slave" => Some(
 | 
			
		||||
                option
 | 
			
		||||
                    .replication
 | 
			
		||||
                    .replica_of
 | 
			
		||||
                    .clone()
 | 
			
		||||
                    .unwrap()
 | 
			
		||||
                    .replace(' ', ":"),
 | 
			
		||||
            ),
 | 
			
		||||
            _ => None,
 | 
			
		||||
        };
 | 
			
		||||
        // Create database file path with fixed filename
 | 
			
		||||
        let db_file_path = PathBuf::from(option.dir.clone()).join("herodb.redb");
 | 
			
		||||
        println!("will open db file path: {}", db_file_path.display());
 | 
			
		||||
 | 
			
		||||
        let is_master = option.replication.role == "master";
 | 
			
		||||
        // Initialize storage with redb
 | 
			
		||||
        let storage = Storage::new(db_file_path).expect("Failed to initialize storage");
 | 
			
		||||
 | 
			
		||||
        let mut server = Server {
 | 
			
		||||
            storage: Arc::new(Mutex::new(Storage::new())),
 | 
			
		||||
            streams: Arc::new(Mutex::new(HashMap::new())),
 | 
			
		||||
        Server {
 | 
			
		||||
            storage: Arc::new(storage),
 | 
			
		||||
            option,
 | 
			
		||||
            master_repl_clients: if is_master {
 | 
			
		||||
                Arc::new(Mutex::new(Some(MasterReplicationClient::new())))
 | 
			
		||||
            } else {
 | 
			
		||||
                Arc::new(Mutex::new(None))
 | 
			
		||||
            },
 | 
			
		||||
            offset: Arc::new(AtomicU64::new(0)),
 | 
			
		||||
            stream_reader_blocker: Arc::new(Mutex::new(Vec::new())),
 | 
			
		||||
            master_addr,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        server.init().await.unwrap();
 | 
			
		||||
        server
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn init(&mut self) -> Result<(), DBError> {
 | 
			
		||||
        // master initialization
 | 
			
		||||
        if self.is_master() {
 | 
			
		||||
            println!("Start as master\n");
 | 
			
		||||
            let db_file_path =
 | 
			
		||||
                PathBuf::from(self.option.dir.clone()).join(self.option.db_file_name.clone());
 | 
			
		||||
            println!("will open db file path: {}", db_file_path.display());
 | 
			
		||||
 | 
			
		||||
            // create empty db file if not exits
 | 
			
		||||
            let mut file = OpenOptions::new()
 | 
			
		||||
                .read(true)
 | 
			
		||||
                .write(true)
 | 
			
		||||
                .create(true)
 | 
			
		||||
                .truncate(false)
 | 
			
		||||
                .open(db_file_path.clone())
 | 
			
		||||
                .await?;
 | 
			
		||||
 | 
			
		||||
            if file.metadata().await?.len() != 0 {
 | 
			
		||||
                rdb::parse_rdb_file(&mut file, self).await?;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn get_follower_repl_client(&mut self) -> Option<FollowerReplicationClient> {
 | 
			
		||||
        if self.is_slave() {
 | 
			
		||||
            Some(FollowerReplicationClient::new(self.master_addr.clone().unwrap()).await)
 | 
			
		||||
        } else {
 | 
			
		||||
            None
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn handle(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        mut stream: tokio::net::TcpStream,
 | 
			
		||||
        is_rep_conn: bool,
 | 
			
		||||
    ) -> Result<(), DBError> {
 | 
			
		||||
        let mut buf = [0; 512];
 | 
			
		||||
        let mut queued_cmd: Option<Vec<(Cmd, Protocol)>> = None;
 | 
			
		||||
        
 | 
			
		||||
        loop {
 | 
			
		||||
            if let Ok(len) = stream.read(&mut buf).await {
 | 
			
		||||
                if len == 0 {
 | 
			
		||||
                    println!("[handle] connection closed");
 | 
			
		||||
                    return Ok(());
 | 
			
		||||
                }
 | 
			
		||||
                
 | 
			
		||||
                let s = str::from_utf8(&buf[..len])?;
 | 
			
		||||
                let (cmd, protocol) =
 | 
			
		||||
                    Cmd::from(s).unwrap_or((Cmd::Unknow, Protocol::err("unknow cmd")));
 | 
			
		||||
                println!("got command: {:?}, protocol: {:?}", cmd, protocol);
 | 
			
		||||
 | 
			
		||||
                let res = cmd
 | 
			
		||||
                    .run(self, protocol, is_rep_conn, &mut queued_cmd)
 | 
			
		||||
                    .run(self, protocol, &mut queued_cmd)
 | 
			
		||||
                    .await
 | 
			
		||||
                    .unwrap_or(Protocol::err("unknow cmd"));
 | 
			
		||||
                print!("queued 2 cmd {:?}", queued_cmd);
 | 
			
		||||
                print!("queued cmd {:?}", queued_cmd);
 | 
			
		||||
 | 
			
		||||
                // only send response to normal client, do not send response to replication client
 | 
			
		||||
                if !is_rep_conn {
 | 
			
		||||
                    println!("going to send response {}", res.encode());
 | 
			
		||||
                    _ = stream.write(res.encode().as_bytes()).await?;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                // send a full RDB file to slave
 | 
			
		||||
                if self.is_master() {
 | 
			
		||||
                    if let Cmd::Psync = cmd {
 | 
			
		||||
                        let mut master_rep_client = self.master_repl_clients.lock().await;
 | 
			
		||||
                        let master_rep_client = master_rep_client.as_mut().unwrap();
 | 
			
		||||
                        master_rep_client.send_rdb_file(&mut stream).await?;
 | 
			
		||||
                        master_rep_client.add_stream(stream).await?;
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                println!("going to send response {}", res.encode());
 | 
			
		||||
                _ = stream.write(res.encode().as_bytes()).await?;
 | 
			
		||||
            } else {
 | 
			
		||||
                println!("[handle] going to break");
 | 
			
		||||
                break;
 | 
			
		||||
@@ -145,12 +65,4 @@ impl Server {
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn is_slave(&self) -> bool {
 | 
			
		||||
        self.option.replication.role == "slave"
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn is_master(&self) -> bool {
 | 
			
		||||
        !self.is_slave()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										445
									
								
								src/storage.rs
									
									
									
									
									
								
							
							
						
						
									
										445
									
								
								src/storage.rs
									
									
									
									
									
								
							@@ -1,13 +1,29 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    collections::HashMap,
 | 
			
		||||
    path::Path,
 | 
			
		||||
    time::{SystemTime, UNIX_EPOCH},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
pub type ValueType = (String, Option<u128>);
 | 
			
		||||
use redb::{Database, Error, ReadableTable, Table, TableDefinition, WriteTransaction, ReadTransaction};
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
 | 
			
		||||
pub struct Storage {
 | 
			
		||||
    // key -> (value, (insert/update time, expire milli seconds))
 | 
			
		||||
    set: HashMap<String, ValueType>,
 | 
			
		||||
use crate::error::DBError;
 | 
			
		||||
 | 
			
		||||
// Table definitions for different Redis data types
 | 
			
		||||
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
 | 
			
		||||
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
 | 
			
		||||
const HASHES_TABLE: TableDefinition<(&str, &str), &str> = TableDefinition::new("hashes");
 | 
			
		||||
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
 | 
			
		||||
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize, Debug, Clone)]
 | 
			
		||||
pub struct StringValue {
 | 
			
		||||
    pub value: String,
 | 
			
		||||
    pub expires_at_ms: Option<u128>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Serialize, Deserialize, Debug, Clone)]
 | 
			
		||||
pub struct StreamEntry {
 | 
			
		||||
    pub fields: Vec<(String, String)>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[inline]
 | 
			
		||||
@@ -17,43 +33,416 @@ pub fn now_in_millis() -> u128 {
 | 
			
		||||
    duration_since_epoch.as_millis()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct Storage {
 | 
			
		||||
    db: Database,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Storage {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Storage {
 | 
			
		||||
            set: HashMap::new(),
 | 
			
		||||
    pub fn new(path: impl AsRef<Path>) -> Result<Self, DBError> {
 | 
			
		||||
        let db = Database::create(path)?;
 | 
			
		||||
        
 | 
			
		||||
        // Create tables if they don't exist
 | 
			
		||||
        let write_txn = db.begin_write()?;
 | 
			
		||||
        {
 | 
			
		||||
            let _ = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            let _ = write_txn.open_table(STRINGS_TABLE)?;
 | 
			
		||||
            let _ = write_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
            let _ = write_txn.open_table(STREAMS_META_TABLE)?;
 | 
			
		||||
            let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
 | 
			
		||||
        }
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        
 | 
			
		||||
        Ok(Storage { db })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        let table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        
 | 
			
		||||
        match table.get(key)? {
 | 
			
		||||
            Some(type_val) => Ok(Some(type_val.value().to_string())),
 | 
			
		||||
            None => Ok(None),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn get(self: &mut Self, k: &str) -> Option<String> {
 | 
			
		||||
        match self.set.get(k) {
 | 
			
		||||
            Some((ss, expire_timestamp)) => match expire_timestamp {
 | 
			
		||||
                Some(expire_time_stamp) => {
 | 
			
		||||
                    if now_in_millis() > *expire_time_stamp {
 | 
			
		||||
                        self.set.remove(k);
 | 
			
		||||
                        None
 | 
			
		||||
                    } else {
 | 
			
		||||
                        Some(ss.clone())
 | 
			
		||||
    pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        // Check if key exists and is of string type
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "string" => {
 | 
			
		||||
                let strings_table = read_txn.open_table(STRINGS_TABLE)?;
 | 
			
		||||
                match strings_table.get(key)? {
 | 
			
		||||
                    Some(data) => {
 | 
			
		||||
                        let string_value: StringValue = bincode::deserialize(data.value())?;
 | 
			
		||||
                        
 | 
			
		||||
                        // Check if expired
 | 
			
		||||
                        if let Some(expires_at) = string_value.expires_at_ms {
 | 
			
		||||
                            if now_in_millis() > expires_at {
 | 
			
		||||
                                // Key expired, remove it
 | 
			
		||||
                                drop(read_txn);
 | 
			
		||||
                                self.del(key.to_string())?;
 | 
			
		||||
                                return Ok(None);
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        
 | 
			
		||||
                        Ok(Some(string_value.value))
 | 
			
		||||
                    }
 | 
			
		||||
                    None => Ok(None),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            _ => Ok(None),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            types_table.insert(key.as_str(), "string")?;
 | 
			
		||||
            
 | 
			
		||||
            let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
 | 
			
		||||
            let string_value = StringValue {
 | 
			
		||||
                value,
 | 
			
		||||
                expires_at_ms: None,
 | 
			
		||||
            };
 | 
			
		||||
            let serialized = bincode::serialize(&string_value)?;
 | 
			
		||||
            strings_table.insert(key.as_str(), serialized.as_slice())?;
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            types_table.insert(key.as_str(), "string")?;
 | 
			
		||||
            
 | 
			
		||||
            let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
 | 
			
		||||
            let string_value = StringValue {
 | 
			
		||||
                value,
 | 
			
		||||
                expires_at_ms: Some(expire_ms + now_in_millis()),
 | 
			
		||||
            };
 | 
			
		||||
            let serialized = bincode::serialize(&string_value)?;
 | 
			
		||||
            strings_table.insert(key.as_str(), serialized.as_slice())?;
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn del(&self, key: String) -> Result<(), DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
 | 
			
		||||
            let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
            
 | 
			
		||||
            // Remove from type table
 | 
			
		||||
            types_table.remove(key.as_str())?;
 | 
			
		||||
            
 | 
			
		||||
            // Remove from strings table
 | 
			
		||||
            strings_table.remove(key.as_str())?;
 | 
			
		||||
            
 | 
			
		||||
            // Remove all hash fields for this key
 | 
			
		||||
            let mut to_remove = Vec::new();
 | 
			
		||||
            let mut iter = hashes_table.iter()?;
 | 
			
		||||
            while let Some(entry) = iter.next() {
 | 
			
		||||
                let entry = entry?;
 | 
			
		||||
                let (hash_key, field) = entry.0.value();
 | 
			
		||||
                if hash_key == key.as_str() {
 | 
			
		||||
                    to_remove.push((hash_key.to_string(), field.to_string()));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            drop(iter);
 | 
			
		||||
            
 | 
			
		||||
            for (hash_key, field) in to_remove {
 | 
			
		||||
                hashes_table.remove((hash_key.as_str(), field.as_str()))?;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        let table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        
 | 
			
		||||
        let mut keys = Vec::new();
 | 
			
		||||
        let mut iter = table.iter()?;
 | 
			
		||||
        while let Some(entry) = iter.next() {
 | 
			
		||||
            let key = entry?.0.value().to_string();
 | 
			
		||||
            if pattern == "*" || key.contains(pattern) {
 | 
			
		||||
                keys.push(key);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        Ok(keys)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Hash operations
 | 
			
		||||
    pub fn hset(&self, key: &str, pairs: &[(String, String)]) -> Result<u64, DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        let mut new_fields = 0u64;
 | 
			
		||||
        
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
            
 | 
			
		||||
            // Check if key exists and is of correct type
 | 
			
		||||
            let existing_type = match types_table.get(key)? {
 | 
			
		||||
                Some(type_val) => Some(type_val.value().to_string()),
 | 
			
		||||
                None => None,
 | 
			
		||||
            };
 | 
			
		||||
            
 | 
			
		||||
            match existing_type {
 | 
			
		||||
                Some(ref type_str) if type_str != "hash" => {
 | 
			
		||||
                    return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
 | 
			
		||||
                }
 | 
			
		||||
                None => {
 | 
			
		||||
                    // Set type to hash
 | 
			
		||||
                    types_table.insert(key, "hash")?;
 | 
			
		||||
                }
 | 
			
		||||
                _ => {}
 | 
			
		||||
            }
 | 
			
		||||
            
 | 
			
		||||
            for (field, value) in pairs {
 | 
			
		||||
                let existed = hashes_table.get((key, field.as_str()))?.is_some();
 | 
			
		||||
                hashes_table.insert((key, field.as_str()), value.as_str())?;
 | 
			
		||||
                if !existed {
 | 
			
		||||
                    new_fields += 1;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        Ok(new_fields)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        // Check type
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                match hashes_table.get((key, field))? {
 | 
			
		||||
                    Some(value) => Ok(Some(value.value().to_string())),
 | 
			
		||||
                    None => Ok(None),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(None),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        // Check type
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                let mut result = Vec::new();
 | 
			
		||||
                
 | 
			
		||||
                let mut iter = hashes_table.iter()?;
 | 
			
		||||
                while let Some(entry) = iter.next() {
 | 
			
		||||
                    let entry = entry?;
 | 
			
		||||
                    let (hash_key, field) = entry.0.value();
 | 
			
		||||
                    let value = entry.1.value();
 | 
			
		||||
                    if hash_key == key {
 | 
			
		||||
                        result.push((field.to_string(), value.to_string()));
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                _ => Some(ss.clone()),
 | 
			
		||||
            },
 | 
			
		||||
            _ => None,
 | 
			
		||||
                
 | 
			
		||||
                Ok(result)
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(Vec::new()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn set(self: &mut Self, k: String, v: String) {
 | 
			
		||||
        self.set.insert(k, (v, None));
 | 
			
		||||
    pub fn hdel(&self, key: &str, fields: &[String]) -> Result<u64, DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        let mut deleted = 0u64;
 | 
			
		||||
        
 | 
			
		||||
        {
 | 
			
		||||
            let types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            let key_type = types_table.get(key)?;
 | 
			
		||||
            match key_type {
 | 
			
		||||
                Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                    let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                    
 | 
			
		||||
                    for field in fields {
 | 
			
		||||
                        if hashes_table.remove((key, field.as_str()))?.is_some() {
 | 
			
		||||
                            deleted += 1;
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
                None => {}
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        Ok(deleted)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn setx(self: &mut Self, k: String, v: String, expire_ms: u128) {
 | 
			
		||||
        self.set.insert(k, (v, Some(expire_ms + now_in_millis())));
 | 
			
		||||
    pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                Ok(hashes_table.get((key, field))?.is_some())
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(false),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn del(self: &mut Self, k: String) {
 | 
			
		||||
        self.set.remove(&k);
 | 
			
		||||
    pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                let mut result = Vec::new();
 | 
			
		||||
                
 | 
			
		||||
                let mut iter = hashes_table.iter()?;
 | 
			
		||||
                while let Some(entry) = iter.next() {
 | 
			
		||||
                    let entry = entry?;
 | 
			
		||||
                    let (hash_key, field) = entry.0.value();
 | 
			
		||||
                    if hash_key == key {
 | 
			
		||||
                        result.push(field.to_string());
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                
 | 
			
		||||
                Ok(result)
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(Vec::new()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn keys(self: &Self) -> Vec<String> {
 | 
			
		||||
        self.set.keys().map(|x| x.clone()).collect()
 | 
			
		||||
    pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                let mut result = Vec::new();
 | 
			
		||||
                
 | 
			
		||||
                let mut iter = hashes_table.iter()?;
 | 
			
		||||
                while let Some(entry) = iter.next() {
 | 
			
		||||
                    let entry = entry?;
 | 
			
		||||
                    let (hash_key, _) = entry.0.value();
 | 
			
		||||
                    let value = entry.1.value();
 | 
			
		||||
                    if hash_key == key {
 | 
			
		||||
                        result.push(value.to_string());
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                
 | 
			
		||||
                Ok(result)
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(Vec::new()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn hlen(&self, key: &str) -> Result<u64, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                let mut count = 0u64;
 | 
			
		||||
                
 | 
			
		||||
                let mut iter = hashes_table.iter()?;
 | 
			
		||||
                while let Some(entry) = iter.next() {
 | 
			
		||||
                    let entry = entry?;
 | 
			
		||||
                    let (hash_key, _) = entry.0.value();
 | 
			
		||||
                    if hash_key == key {
 | 
			
		||||
                        count += 1;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                
 | 
			
		||||
                Ok(count)
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(0),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn hmget(&self, key: &str, fields: &[String]) -> Result<Vec<Option<String>>, DBError> {
 | 
			
		||||
        let read_txn = self.db.begin_read()?;
 | 
			
		||||
        
 | 
			
		||||
        let types_table = read_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
        match types_table.get(key)? {
 | 
			
		||||
            Some(type_val) if type_val.value() == "hash" => {
 | 
			
		||||
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
                let mut result = Vec::new();
 | 
			
		||||
                
 | 
			
		||||
                for field in fields {
 | 
			
		||||
                    match hashes_table.get((key, field.as_str()))? {
 | 
			
		||||
                        Some(value) => result.push(Some(value.value().to_string())),
 | 
			
		||||
                        None => result.push(None),
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                
 | 
			
		||||
                Ok(result)
 | 
			
		||||
            }
 | 
			
		||||
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
 | 
			
		||||
            None => Ok(fields.iter().map(|_| None).collect()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
 | 
			
		||||
        let write_txn = self.db.begin_write()?;
 | 
			
		||||
        let mut result = false;
 | 
			
		||||
        
 | 
			
		||||
        {
 | 
			
		||||
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
 | 
			
		||||
            let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
 | 
			
		||||
            
 | 
			
		||||
            // Check if key exists and is of correct type
 | 
			
		||||
            let existing_type = match types_table.get(key)? {
 | 
			
		||||
                Some(type_val) => Some(type_val.value().to_string()),
 | 
			
		||||
                None => None,
 | 
			
		||||
            };
 | 
			
		||||
            
 | 
			
		||||
            match existing_type {
 | 
			
		||||
                Some(ref type_str) if type_str != "hash" => {
 | 
			
		||||
                    return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()));
 | 
			
		||||
                }
 | 
			
		||||
                None => {
 | 
			
		||||
                    // Set type to hash
 | 
			
		||||
                    types_table.insert(key, "hash")?;
 | 
			
		||||
                }
 | 
			
		||||
                _ => {}
 | 
			
		||||
            }
 | 
			
		||||
            
 | 
			
		||||
            // Check if field already exists
 | 
			
		||||
            if hashes_table.get((key, field))?.is_none() {
 | 
			
		||||
                hashes_table.insert((key, field), value)?;
 | 
			
		||||
                result = true;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        write_txn.commit()?;
 | 
			
		||||
        Ok(result)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user