use core::str; use std::collections::HashMap; use std::sync::Arc; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::sync::{Mutex, oneshot}; use std::sync::atomic::{AtomicU64, Ordering}; use crate::cmd::Cmd; use crate::error::DBError; use crate::options; use crate::protocol::Protocol; use crate::storage::Storage; use crate::storage_sled::SledStorage; use crate::storage_trait::StorageBackend; #[derive(Clone)] pub struct Server { pub db_cache: std::sync::Arc>>>, pub option: options::DBOption, pub client_name: Option, pub selected_db: u64, // Changed from usize to u64 pub queued_cmd: Option>, // BLPOP waiter registry: per (db_index, key) FIFO of waiters pub list_waiters: Arc>>>>, pub waiter_seq: Arc, } pub struct Waiter { pub id: u64, pub side: PopSide, pub tx: oneshot::Sender<(String, String)>, // (key, element) } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum PopSide { Left, Right, } impl Server { pub async fn new(option: options::DBOption) -> Self { Server { db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())), option, client_name: None, selected_db: 0, queued_cmd: None, list_waiters: Arc::new(Mutex::new(HashMap::new())), waiter_seq: Arc::new(AtomicU64::new(1)), } } pub fn current_storage(&self) -> Result, DBError> { let mut cache = self.db_cache.write().unwrap(); if let Some(storage) = cache.get(&self.selected_db) { return Ok(storage.clone()); } // Create new database file let db_file_path = std::path::PathBuf::from(self.option.dir.clone()) .join(format!("{}.db", self.selected_db)); // Ensure the directory exists before creating the database file if let Some(parent_dir) = db_file_path.parent() { std::fs::create_dir_all(parent_dir).map_err(|e| { DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e)) })?; } println!("Creating new db file: {}", db_file_path.display()); let storage: Arc = match self.option.backend { options::BackendType::Redb => { Arc::new(Storage::new( db_file_path, self.should_encrypt_db(self.selected_db), self.option.encryption_key.as_deref() )?) } options::BackendType::Sled => { Arc::new(SledStorage::new( db_file_path, self.should_encrypt_db(self.selected_db), self.option.encryption_key.as_deref() )?) } }; cache.insert(self.selected_db, storage.clone()); Ok(storage) } fn should_encrypt_db(&self, db_index: u64) -> bool { // DB 0-9 are non-encrypted, DB 10+ are encrypted self.option.encrypt && db_index >= 10 } // ----- BLPOP waiter helpers ----- pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) { let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed); let (tx, rx) = oneshot::channel::<(String, String)>(); let mut guard = self.list_waiters.lock().await; let per_db = guard.entry(db_index).or_insert_with(HashMap::new); let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); q.push(Waiter { id, side, tx }); (id, rx) } pub async fn unregister_waiter(&self, db_index: u64, key: &str, id: u64) { let mut guard = self.list_waiters.lock().await; if let Some(per_db) = guard.get_mut(&db_index) { if let Some(q) = per_db.get_mut(key) { q.retain(|w| w.id != id); if q.is_empty() { per_db.remove(key); } } if per_db.is_empty() { guard.remove(&db_index); } } } // Called after LPUSH/RPUSH to deliver to blocked BLPOP waiters. pub async fn drain_waiters_after_push(&self, key: &str) -> Result<(), DBError> { let db_index = self.selected_db; loop { // Check if any waiter exists let maybe_waiter = { let mut guard = self.list_waiters.lock().await; if let Some(per_db) = guard.get_mut(&db_index) { if let Some(q) = per_db.get_mut(key) { if !q.is_empty() { // Pop FIFO Some(q.remove(0)) } else { None } } else { None } } else { None } }; let waiter = if let Some(w) = maybe_waiter { w } else { break }; // Pop one element depending on waiter side let elems = match waiter.side { PopSide::Left => self.current_storage()?.lpop(key, 1)?, PopSide::Right => self.current_storage()?.rpop(key, 1)?, }; if elems.is_empty() { // Nothing to deliver; re-register waiter at the front to preserve order let mut guard = self.list_waiters.lock().await; let per_db = guard.entry(db_index).or_insert_with(HashMap::new); let q = per_db.entry(key.to_string()).or_insert_with(Vec::new); q.insert(0, waiter); break; } else { let elem = elems[0].clone(); // Send to waiter; if receiver dropped, just continue let _ = waiter.tx.send((key.to_string(), elem)); // Loop to try to satisfy more waiters if more elements remain continue; } } Ok(()) } pub async fn handle( &mut self, mut stream: tokio::net::TcpStream, ) -> Result<(), DBError> { // Accumulate incoming bytes to handle partial RESP frames let mut acc = String::new(); let mut buf = vec![0u8; 8192]; loop { let n = match stream.read(&mut buf).await { Ok(0) => { println!("[handle] connection closed"); return Ok(()); } Ok(n) => n, Err(e) => { println!("[handle] read error: {:?}", e); return Err(e.into()); } }; // Append to accumulator. RESP for our usage is ASCII-safe. acc.push_str(str::from_utf8(&buf[..n])?); // Try to parse as many complete commands as are available in 'acc'. loop { let parsed = Cmd::from(&acc); let (cmd, protocol, remaining) = match parsed { Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining), Err(_e) => { // Incomplete or invalid frame; assume incomplete and wait for more data. // This avoids emitting spurious protocol_error for split frames. break; } }; // Advance the accumulator to the unparsed remainder acc = remaining.to_string(); if self.option.debug { println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol); } else { println!("got command: {:?}, protocol: {:?}", cmd, protocol); } // Check if this is a QUIT command before processing let is_quit = matches!(cmd, Cmd::Quit); let res = match cmd.run(self).await { Ok(p) => p, Err(e) => { if self.option.debug { eprintln!("[run error] {:?}", e); } Protocol::err(&format!("ERR {}", e.0)) } }; if self.option.debug { println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); } else { print!("queued cmd {:?}", self.queued_cmd); println!("going to send response {}", res.encode()); } _ = stream.write(res.encode().as_bytes()).await?; // If this was a QUIT command, close the connection if is_quit { println!("[handle] QUIT command received, closing connection"); return Ok(()); } // Continue parsing any further complete commands already in 'acc' if acc.is_empty() { break; } } } } }