use core::str; use std::collections::HashMap; use std::sync::Arc; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use crate::cmd::Cmd; use crate::error::DBError; use crate::options; use crate::protocol::Protocol; use crate::storage::Storage; #[derive(Clone)] pub struct Server { pub db_cache: std::sync::Arc>>>, pub option: options::DBOption, pub client_name: Option, pub selected_db: u64, // Changed from usize to u64 pub queued_cmd: Option>, } impl Server { pub async fn new(option: options::DBOption) -> Self { Server { db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())), option, client_name: None, selected_db: 0, queued_cmd: None, } } pub fn current_storage(&self) -> Result, DBError> { let mut cache = self.db_cache.write().unwrap(); if let Some(storage) = cache.get(&self.selected_db) { return Ok(storage.clone()); } // Create new database file let db_file_path = std::path::PathBuf::from(self.option.dir.clone()) .join(format!("{}.db", self.selected_db)); // Ensure the directory exists before creating the database file if let Some(parent_dir) = db_file_path.parent() { std::fs::create_dir_all(parent_dir).map_err(|e| { DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e)) })?; } println!("Creating new db file: {}", db_file_path.display()); let storage = Arc::new(Storage::new( db_file_path, self.should_encrypt_db(self.selected_db), self.option.encryption_key.as_deref() )?); cache.insert(self.selected_db, storage.clone()); Ok(storage) } fn should_encrypt_db(&self, db_index: u64) -> bool { // DB 0-9 are non-encrypted, DB 10+ are encrypted self.option.encrypt && db_index >= 10 } pub async fn handle( &mut self, mut stream: tokio::net::TcpStream, ) -> Result<(), DBError> { let mut buf = [0; 512]; loop { let len = match stream.read(&mut buf).await { Ok(0) => { println!("[handle] connection closed"); return Ok(()); } Ok(len) => len, Err(e) => { println!("[handle] read error: {:?}", e); return Err(e.into()); } }; let mut s = str::from_utf8(&buf[..len])?; while !s.is_empty() { let (cmd, protocol, remaining) = match Cmd::from(s) { Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), Err(e) => { println!("\x1b[31;1mprotocol error: {:?}\x1b[0m", e); (Cmd::Unknow("protocol_error".to_string()), Protocol::err(&format!("protocol error: {}", e.0)), "") } }; s = remaining; if self.option.debug { println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); } else { println!("got command: {:?}, protocol: {:?}", cmd, protocol); } // Check if this is a QUIT command before processing let is_quit = matches!(cmd, Cmd::Quit); let res = match cmd.run(self).await { Ok(p) => p, Err(e) => { if self.option.debug { eprintln!("[run error] {:?}", e); } Protocol::err(&format!("ERR {}", e.0)) } }; if self.option.debug { println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); } else { print!("queued cmd {:?}", self.queued_cmd); println!("going to send response {}", res.encode()); } _ = stream.write(res.encode().as_bytes()).await?; // If this was a QUIT command, close the connection if is_quit { println!("[handle] QUIT command received, closing connection"); return Ok(()); } } } } }