Restore working code

This commit is contained in:
Maxime Van Hees
2025-09-11 18:33:09 +02:00
parent 9fa9832605
commit 8798bc202e
3 changed files with 72 additions and 70 deletions

View File

@@ -71,8 +71,13 @@ async fn main() {
}, },
}; };
let backend = option.backend.clone();
// new server // new server
let server = server::Server::new(option.clone()).await; let mut server = server::Server::new(option).await;
// Initialize the default database storage
let _ = server.current_storage();
// Add a small delay to ensure the port is ready // Add a small delay to ensure the port is ready
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
@@ -82,7 +87,7 @@ async fn main() {
let rpc_addr = format!("127.0.0.1:{}", args.rpc_port).parse().unwrap(); let rpc_addr = format!("127.0.0.1:{}", args.rpc_port).parse().unwrap();
let base_dir = args.dir.clone(); let base_dir = args.dir.clone();
match rpc_server::start_rpc_server(rpc_addr, base_dir, option).await { match rpc_server::start_rpc_server(rpc_addr, base_dir, backend).await {
Ok(handle) => { Ok(handle) => {
println!("RPC management server started on port {}", args.rpc_port); println!("RPC management server started on port {}", args.rpc_port);
Some(handle) Some(handle)

View File

@@ -78,20 +78,23 @@ pub struct RpcServerImpl {
base_dir: String, base_dir: String,
/// Managed database servers /// Managed database servers
servers: Arc<RwLock<HashMap<u64, Arc<Server>>>>, servers: Arc<RwLock<HashMap<u64, Arc<Server>>>>,
/// Next database ID to assign /// Next unencrypted database ID to assign
next_db_id: Arc<RwLock<u64>>, next_unencrypted_id: Arc<RwLock<u64>>,
/// Database options (backend, encryption, etc.) /// Next encrypted database ID to assign
db_option: DBOption, next_encrypted_id: Arc<RwLock<u64>>,
/// Default backend type
backend: crate::options::BackendType,
} }
impl RpcServerImpl { impl RpcServerImpl {
/// Create a new RPC server instance /// Create a new RPC server instance
pub fn new(base_dir: String, db_option: DBOption) -> Self { pub fn new(base_dir: String, backend: crate::options::BackendType) -> Self {
Self { Self {
base_dir, base_dir,
servers: Arc::new(RwLock::new(HashMap::new())), servers: Arc::new(RwLock::new(HashMap::new())),
next_db_id: Arc::new(RwLock::new(0)), next_unencrypted_id: Arc::new(RwLock::new(0)),
db_option, next_encrypted_id: Arc::new(RwLock::new(10)),
backend,
} }
} }
@@ -105,31 +108,30 @@ impl RpcServerImpl {
} }
} }
// Check if database file exists (either direct or in RPC subdirectory) // Check if database file exists
let direct_db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id)); let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id));
let rpc_db_path = std::path::PathBuf::from(&self.base_dir) if !db_path.exists() {
.join(format!("rpc_db_{}", db_id))
.join("0.db");
let (db_path, db_dir) = if direct_db_path.exists() {
// Main server database
(direct_db_path, self.base_dir.clone())
} else if rpc_db_path.exists() {
// RPC database
(rpc_db_path, std::path::PathBuf::from(&self.base_dir).join(format!("rpc_db_{}", db_id)).to_string_lossy().to_string())
} else {
return Err(jsonrpsee::types::ErrorObjectOwned::owned( return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000, -32000,
format!("Database {} not found", db_id), format!("Database {} not found", db_id),
None::<()> None::<()>
)); ));
}
// Create server instance with default options
let db_option = DBOption {
dir: self.base_dir.clone(),
port: 0, // Not used for RPC-managed databases
debug: false,
encryption_key: None,
encrypt: false,
backend: self.backend.clone(),
}; };
// Create server instance let mut server = Server::new(db_option).await;
let mut db_option = self.db_option.clone();
db_option.dir = db_dir;
let server = Server::new(db_option).await; // Set the selected database to the db_id for proper file naming
server.selected_db = db_id;
// Store the server // Store the server
let mut servers = self.servers.write().await; let mut servers = self.servers.write().await;
@@ -138,29 +140,14 @@ impl RpcServerImpl {
Ok(Arc::new(server)) Ok(Arc::new(server))
} }
/// Discover existing database files in the base directory and RPC subdirectories /// Discover existing database files in the base directory
async fn discover_databases(&self) -> Vec<u64> { async fn discover_databases(&self) -> Vec<u64> {
let mut db_ids = Vec::new(); let mut db_ids = Vec::new();
if let Ok(entries) = std::fs::read_dir(&self.base_dir) { if let Ok(entries) = std::fs::read_dir(&self.base_dir) {
for entry in entries.flatten() { for entry in entries.flatten() {
let path = entry.path(); if let Ok(file_name) = entry.file_name().into_string() {
// Check if it's a database file (ends with .db)
// Check if it's a directory starting with "rpc_db_"
if path.is_dir() {
if let Some(dir_name) = path.file_name().and_then(|n| n.to_str()) {
if dir_name.starts_with("rpc_db_") {
// Extract database ID from directory name (e.g., "rpc_db_1" -> 1)
if let Some(id_str) = dir_name.strip_prefix("rpc_db_") {
if let Ok(db_id) = id_str.parse::<u64>() {
db_ids.push(db_id);
}
}
}
}
}
// Also check for direct .db files (for main server databases)
else if let Some(file_name) = entry.file_name().to_str() {
if file_name.ends_with(".db") { if file_name.ends_with(".db") {
// Extract database ID from filename (e.g., "11.db" -> 11) // Extract database ID from filename (e.g., "11.db" -> 11)
if let Some(id_str) = file_name.strip_suffix(".db") { if let Some(id_str) = file_name.strip_suffix(".db") {
@@ -177,11 +164,18 @@ impl RpcServerImpl {
} }
/// Get the next available database ID /// Get the next available database ID
async fn get_next_db_id(&self) -> u64 { async fn get_next_db_id(&self, is_encrypted: bool) -> u64 {
let mut id = self.next_db_id.write().await; if is_encrypted {
let current_id = *id; let mut id = self.next_encrypted_id.write().await;
*id += 1; let current_id = *id;
current_id *id += 1;
current_id
} else {
let mut id = self.next_unencrypted_id.write().await;
let current_id = *id;
*id += 1;
current_id
}
} }
} }
@@ -193,14 +187,14 @@ impl RpcServer for RpcServerImpl {
config: DatabaseConfig, config: DatabaseConfig,
encryption_key: Option<String>, encryption_key: Option<String>,
) -> RpcResult<u64> { ) -> RpcResult<u64> {
let db_id = self.get_next_db_id().await; let db_id = self.get_next_db_id(encryption_key.is_some()).await;
// Handle both Redb and Sled backends // Handle both Redb and Sled backends
match backend { match backend {
BackendType::Redb | BackendType::Sled => { BackendType::Redb | BackendType::Sled => {
// Always create RPC databases in subdirectories to avoid conflicts with main server // Create database directory
let db_dir = if let Some(path) = &config.storage_path { let db_dir = if let Some(path) = &config.storage_path {
std::path::PathBuf::from(path).join(format!("rpc_db_{}", db_id)) std::path::PathBuf::from(path)
} else { } else {
std::path::PathBuf::from(&self.base_dir).join(format!("rpc_db_{}", db_id)) std::path::PathBuf::from(&self.base_dir).join(format!("rpc_db_{}", db_id))
}; };
@@ -228,7 +222,13 @@ impl RpcServer for RpcServerImpl {
}; };
// Create server instance // Create server instance
let server = Server::new(option).await; let mut server = Server::new(option).await;
// Set the selected database to the db_id for proper file naming
server.selected_db = db_id;
// Initialize the storage to create the database file
let _ = server.current_storage();
// Store the server // Store the server
let mut servers = self.servers.write().await; let mut servers = self.servers.write().await;
@@ -309,9 +309,16 @@ impl RpcServer for RpcServerImpl {
async fn delete_database(&self, db_id: u64) -> RpcResult<bool> { async fn delete_database(&self, db_id: u64) -> RpcResult<bool> {
let mut servers = self.servers.write().await; let mut servers = self.servers.write().await;
if let Some(server) = servers.remove(&db_id) { if let Some(_server) = servers.remove(&db_id) {
// TODO: Clean up database files // Clean up database files
let _ = server; let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id));
if db_path.exists() {
if db_path.is_dir() {
std::fs::remove_dir_all(&db_path).ok();
} else {
std::fs::remove_file(&db_path).ok();
}
}
Ok(true) Ok(true)
} else { } else {
Ok(false) Ok(false)

View File

@@ -3,12 +3,11 @@ use jsonrpsee::server::{ServerBuilder, ServerHandle};
use jsonrpsee::RpcModule; use jsonrpsee::RpcModule;
use crate::rpc::{RpcServer, RpcServerImpl}; use crate::rpc::{RpcServer, RpcServerImpl};
use crate::options::DBOption;
/// Start the RPC server on the specified address /// Start the RPC server on the specified address
pub async fn start_rpc_server(addr: SocketAddr, base_dir: String, db_option: DBOption) -> Result<ServerHandle, Box<dyn std::error::Error + Send + Sync>> { pub async fn start_rpc_server(addr: SocketAddr, base_dir: String, backend: crate::options::BackendType) -> Result<ServerHandle, Box<dyn std::error::Error + Send + Sync>> {
// Create the RPC server implementation // Create the RPC server implementation
let rpc_impl = RpcServerImpl::new(base_dir, db_option); let rpc_impl = RpcServerImpl::new(base_dir, backend);
// Create the RPC module // Create the RPC module
let mut module = RpcModule::new(()); let mut module = RpcModule::new(());
@@ -36,18 +35,9 @@ mod tests {
async fn test_rpc_server_startup() { async fn test_rpc_server_startup() {
let addr = "127.0.0.1:0".parse().unwrap(); // Use port 0 for auto-assignment let addr = "127.0.0.1:0".parse().unwrap(); // Use port 0 for auto-assignment
let base_dir = "/tmp/test_rpc".to_string(); let base_dir = "/tmp/test_rpc".to_string();
let backend = crate::options::BackendType::Redb; // Default for test
// Create a dummy DBOption for testing let handle = start_rpc_server(addr, base_dir, backend).await.unwrap();
let db_option = crate::options::DBOption {
dir: base_dir.clone(),
port: 0,
debug: false,
encryption_key: None,
encrypt: false,
backend: crate::options::BackendType::Redb,
};
let handle = start_rpc_server(addr, base_dir, db_option).await.unwrap();
// Give the server a moment to start // Give the server a moment to start
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;