diff --git a/src/main.rs b/src/main.rs index 34d7fcf..6b88c33 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,8 +71,13 @@ async fn main() { }, }; + let backend = option.backend.clone(); + // new server - let server = server::Server::new(option.clone()).await; + let mut server = server::Server::new(option).await; + + // Initialize the default database storage + let _ = server.current_storage(); // Add a small delay to ensure the port is ready tokio::time::sleep(std::time::Duration::from_millis(100)).await; @@ -82,7 +87,7 @@ async fn main() { let rpc_addr = format!("127.0.0.1:{}", args.rpc_port).parse().unwrap(); let base_dir = args.dir.clone(); - match rpc_server::start_rpc_server(rpc_addr, base_dir, option).await { + match rpc_server::start_rpc_server(rpc_addr, base_dir, backend).await { Ok(handle) => { println!("RPC management server started on port {}", args.rpc_port); Some(handle) diff --git a/src/rpc.rs b/src/rpc.rs index 4715767..b791bec 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -78,20 +78,23 @@ pub struct RpcServerImpl { base_dir: String, /// Managed database servers servers: Arc>>>, - /// Next database ID to assign - next_db_id: Arc>, - /// Database options (backend, encryption, etc.) - db_option: DBOption, + /// Next unencrypted database ID to assign + next_unencrypted_id: Arc>, + /// Next encrypted database ID to assign + next_encrypted_id: Arc>, + /// Default backend type + backend: crate::options::BackendType, } impl RpcServerImpl { /// Create a new RPC server instance - pub fn new(base_dir: String, db_option: DBOption) -> Self { + pub fn new(base_dir: String, backend: crate::options::BackendType) -> Self { Self { base_dir, servers: Arc::new(RwLock::new(HashMap::new())), - next_db_id: Arc::new(RwLock::new(0)), - db_option, + next_unencrypted_id: Arc::new(RwLock::new(0)), + next_encrypted_id: Arc::new(RwLock::new(10)), + backend, } } @@ -105,31 +108,30 @@ impl RpcServerImpl { } } - // Check if database file exists (either direct or in RPC subdirectory) - let direct_db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id)); - let rpc_db_path = std::path::PathBuf::from(&self.base_dir) - .join(format!("rpc_db_{}", db_id)) - .join("0.db"); - - let (db_path, db_dir) = if direct_db_path.exists() { - // Main server database - (direct_db_path, self.base_dir.clone()) - } else if rpc_db_path.exists() { - // RPC database - (rpc_db_path, std::path::PathBuf::from(&self.base_dir).join(format!("rpc_db_{}", db_id)).to_string_lossy().to_string()) - } else { + // Check if database file exists + let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id)); + if !db_path.exists() { return Err(jsonrpsee::types::ErrorObjectOwned::owned( -32000, format!("Database {} not found", db_id), None::<()> )); + } + + // Create server instance with default options + let db_option = DBOption { + dir: self.base_dir.clone(), + port: 0, // Not used for RPC-managed databases + debug: false, + encryption_key: None, + encrypt: false, + backend: self.backend.clone(), }; - // Create server instance - let mut db_option = self.db_option.clone(); - db_option.dir = db_dir; + let mut server = Server::new(db_option).await; - let server = Server::new(db_option).await; + // Set the selected database to the db_id for proper file naming + server.selected_db = db_id; // Store the server let mut servers = self.servers.write().await; @@ -138,29 +140,14 @@ impl RpcServerImpl { Ok(Arc::new(server)) } - /// Discover existing database files in the base directory and RPC subdirectories + /// Discover existing database files in the base directory async fn discover_databases(&self) -> Vec { let mut db_ids = Vec::new(); if let Ok(entries) = std::fs::read_dir(&self.base_dir) { for entry in entries.flatten() { - let path = entry.path(); - - // Check if it's a directory starting with "rpc_db_" - if path.is_dir() { - if let Some(dir_name) = path.file_name().and_then(|n| n.to_str()) { - if dir_name.starts_with("rpc_db_") { - // Extract database ID from directory name (e.g., "rpc_db_1" -> 1) - if let Some(id_str) = dir_name.strip_prefix("rpc_db_") { - if let Ok(db_id) = id_str.parse::() { - db_ids.push(db_id); - } - } - } - } - } - // Also check for direct .db files (for main server databases) - else if let Some(file_name) = entry.file_name().to_str() { + if let Ok(file_name) = entry.file_name().into_string() { + // Check if it's a database file (ends with .db) if file_name.ends_with(".db") { // Extract database ID from filename (e.g., "11.db" -> 11) if let Some(id_str) = file_name.strip_suffix(".db") { @@ -177,11 +164,18 @@ impl RpcServerImpl { } /// Get the next available database ID - async fn get_next_db_id(&self) -> u64 { - let mut id = self.next_db_id.write().await; - let current_id = *id; - *id += 1; - current_id + async fn get_next_db_id(&self, is_encrypted: bool) -> u64 { + if is_encrypted { + let mut id = self.next_encrypted_id.write().await; + let current_id = *id; + *id += 1; + current_id + } else { + let mut id = self.next_unencrypted_id.write().await; + let current_id = *id; + *id += 1; + current_id + } } } @@ -193,14 +187,14 @@ impl RpcServer for RpcServerImpl { config: DatabaseConfig, encryption_key: Option, ) -> RpcResult { - let db_id = self.get_next_db_id().await; + let db_id = self.get_next_db_id(encryption_key.is_some()).await; // Handle both Redb and Sled backends match backend { BackendType::Redb | BackendType::Sled => { - // Always create RPC databases in subdirectories to avoid conflicts with main server + // Create database directory let db_dir = if let Some(path) = &config.storage_path { - std::path::PathBuf::from(path).join(format!("rpc_db_{}", db_id)) + std::path::PathBuf::from(path) } else { std::path::PathBuf::from(&self.base_dir).join(format!("rpc_db_{}", db_id)) }; @@ -228,7 +222,13 @@ impl RpcServer for RpcServerImpl { }; // Create server instance - let server = Server::new(option).await; + let mut server = Server::new(option).await; + + // Set the selected database to the db_id for proper file naming + server.selected_db = db_id; + + // Initialize the storage to create the database file + let _ = server.current_storage(); // Store the server let mut servers = self.servers.write().await; @@ -309,9 +309,16 @@ impl RpcServer for RpcServerImpl { async fn delete_database(&self, db_id: u64) -> RpcResult { let mut servers = self.servers.write().await; - if let Some(server) = servers.remove(&db_id) { - // TODO: Clean up database files - let _ = server; + if let Some(_server) = servers.remove(&db_id) { + // Clean up database files + let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id)); + if db_path.exists() { + if db_path.is_dir() { + std::fs::remove_dir_all(&db_path).ok(); + } else { + std::fs::remove_file(&db_path).ok(); + } + } Ok(true) } else { Ok(false) diff --git a/src/rpc_server.rs b/src/rpc_server.rs index a8e9048..88ab432 100644 --- a/src/rpc_server.rs +++ b/src/rpc_server.rs @@ -3,12 +3,11 @@ use jsonrpsee::server::{ServerBuilder, ServerHandle}; use jsonrpsee::RpcModule; use crate::rpc::{RpcServer, RpcServerImpl}; -use crate::options::DBOption; /// Start the RPC server on the specified address -pub async fn start_rpc_server(addr: SocketAddr, base_dir: String, db_option: DBOption) -> Result> { +pub async fn start_rpc_server(addr: SocketAddr, base_dir: String, backend: crate::options::BackendType) -> Result> { // Create the RPC server implementation - let rpc_impl = RpcServerImpl::new(base_dir, db_option); + let rpc_impl = RpcServerImpl::new(base_dir, backend); // Create the RPC module let mut module = RpcModule::new(()); @@ -36,18 +35,9 @@ mod tests { async fn test_rpc_server_startup() { let addr = "127.0.0.1:0".parse().unwrap(); // Use port 0 for auto-assignment let base_dir = "/tmp/test_rpc".to_string(); + let backend = crate::options::BackendType::Redb; // Default for test - // Create a dummy DBOption for testing - let db_option = crate::options::DBOption { - dir: base_dir.clone(), - port: 0, - debug: false, - encryption_key: None, - encrypt: false, - backend: crate::options::BackendType::Redb, - }; - - let handle = start_rpc_server(addr, base_dir, db_option).await.unwrap(); + let handle = start_rpc_server(addr, base_dir, backend).await.unwrap(); // Give the server a moment to start tokio::time::sleep(Duration::from_millis(100)).await;