From bdf363016afc327b23b640ca845eba292ce341e6 Mon Sep 17 00:00:00 2001 From: Maxime Van Hees Date: Fri, 12 Sep 2025 17:11:50 +0200 Subject: [PATCH] WIP: adding access management control to db instances --- src/cmd.rs | 75 +++++++++++-- src/rpc.rs | 294 +++++++++++++++++++++++++++++++++++++++++++++++++- src/server.rs | 12 +++ 3 files changed, 373 insertions(+), 8 deletions(-) diff --git a/src/cmd.rs b/src/cmd.rs index 176ed2f..e3bcdda 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -6,7 +6,7 @@ use futures::future::select_all; pub enum Cmd { Ping, Echo(String), - Select(u64), // Changed from u16 to u64 + Select(u64, Option), // db_index, optional_key Get(String), Set(String, String), SetPx(String, String, u128), @@ -98,11 +98,18 @@ impl Cmd { Ok(( match cmd[0].to_lowercase().as_str() { "select" => { - if cmd.len() != 2 { + if cmd.len() < 2 || cmd.len() > 4 { return Err(DBError("wrong number of arguments for SELECT".to_string())); } let idx = cmd[1].parse::().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?; - Cmd::Select(idx) + let key = if cmd.len() == 4 && cmd[2].to_lowercase() == "key" { + Some(cmd[3].clone()) + } else if cmd.len() == 2 { + None + } else { + return Err(DBError("ERR syntax error".to_string())); + }; + Cmd::Select(idx, key) } "echo" => Cmd::Echo(cmd[1].clone()), "ping" => Cmd::Ping, @@ -642,7 +649,7 @@ impl Cmd { } match self { - Cmd::Select(db) => select_cmd(server, db).await, + Cmd::Select(db, key) => select_cmd(server, db, key).await, Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())), Cmd::Echo(s) => Ok(Protocol::BulkString(s)), Cmd::Get(k) => get_cmd(server, &k).await, @@ -736,7 +743,14 @@ impl Cmd { pub fn to_protocol(self) -> Protocol { match self { - Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]), + Cmd::Select(db, key) => { + let mut arr = vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]; + if let Some(k) = key { + arr.push(Protocol::BulkString("key".to_string())); + arr.push(Protocol::BulkString(k)); + } + Protocol::Array(arr) + } Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]), Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]), Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]), @@ -753,9 +767,44 @@ async fn flushdb_cmd(server: &mut Server) -> Result { } } -async fn select_cmd(server: &mut Server, db: u64) -> Result { - // Test if we can access the database (this will create it if needed) +async fn select_cmd(server: &mut Server, db: u64, key: Option) -> Result { + // Load database metadata + let meta = match crate::rpc::RpcServerImpl::load_meta_static(&server.option.dir, db).await { + Ok(m) => m, + Err(_) => { + // If meta doesn't exist, create default + let default_meta = crate::rpc::DatabaseMeta { + public: true, + keys: std::collections::HashMap::new(), + }; + if let Err(_) = crate::rpc::RpcServerImpl::save_meta_static(&server.option.dir, db, &default_meta).await { + return Ok(Protocol::err("ERR failed to initialize database metadata")); + } + default_meta + } + }; + + // Check access permissions + let permissions = if meta.public { + // Public database - full access + Some(crate::rpc::Permissions::ReadWrite) + } else if let Some(key_str) = key { + // Private database - check key + let hash = crate::rpc::hash_key(&key_str); + if let Some(access_key) = meta.keys.get(&hash) { + Some(access_key.permissions.clone()) + } else { + return Ok(Protocol::err("ERR invalid access key")); + } + } else { + return Ok(Protocol::err("ERR access key required for private database")); + }; + + // Set selected database and permissions server.selected_db = db; + server.current_permissions = permissions; + + // Test if we can access the database (this will create it if needed) match server.current_storage() { Ok(_) => Ok(Protocol::SimpleString("OK".to_string())), Err(e) => Ok(Protocol::err(&e.0)), @@ -1003,6 +1052,9 @@ async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul } async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result { + if !server.has_write_permission() { + return Ok(Protocol::err("ERR write permission denied")); + } match server.current_storage()?.lpush(key, elements.to_vec()) { Ok(len) => { // Attempt to deliver to any blocked BLPOP waiters @@ -1134,6 +1186,9 @@ async fn type_cmd(server: &Server, k: &String) -> Result { } async fn del_cmd(server: &Server, k: &str) -> Result { + if !server.has_write_permission() { + return Ok(Protocol::err("ERR write permission denied")); + } server.current_storage()?.del(k.to_string())?; Ok(Protocol::SimpleString("1".to_string())) } @@ -1159,6 +1214,9 @@ async fn set_px_cmd( } async fn set_cmd(server: &Server, k: &str, v: &str) -> Result { + if !server.has_write_permission() { + return Ok(Protocol::err("ERR write permission denied")); + } server.current_storage()?.set(k.to_string(), v.to_string())?; Ok(Protocol::SimpleString("OK".to_string())) } @@ -1273,6 +1331,9 @@ async fn get_cmd(server: &Server, k: &str) -> Result { // Hash command implementations async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result { + if !server.has_write_permission() { + return Ok(Protocol::err("ERR write permission denied")); + } let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?; Ok(Protocol::SimpleString(new_fields.to_string())) } diff --git a/src/rpc.rs b/src/rpc.rs index b791bec..afbf34c 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use tokio::sync::RwLock; use jsonrpsee::{core::RpcResult, proc_macros::rpc}; use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; use crate::server::Server; use crate::options::DBOption; @@ -39,6 +40,43 @@ pub struct DatabaseInfo { pub last_access: Option, } +/// Access permissions for database keys +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum Permissions { + Read, + ReadWrite, +} + +/// Access key information +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessKey { + pub hash: String, + pub permissions: Permissions, + pub created_at: u64, +} + +/// Database metadata containing access keys +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DatabaseMeta { + pub public: bool, + pub keys: HashMap, +} + +/// Access key information returned by RPC +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessKeyInfo { + pub hash: String, + pub permissions: Permissions, + pub created_at: u64, +} + +/// Hash a plaintext key using SHA-256 +pub fn hash_key(key: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(key.as_bytes()); + format!("{:x}", hasher.finalize()) +} + /// RPC trait for HeroDB management #[rpc(server, client, namespace = "herodb")] pub trait Rpc { @@ -70,6 +108,22 @@ pub trait Rpc { /// Get server statistics #[method(name = "getServerStats")] async fn get_server_stats(&self) -> RpcResult>; + + /// Add an access key to a database + #[method(name = "addAccessKey")] + async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult; + + /// Delete an access key from a database + #[method(name = "deleteAccessKey")] + async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult; + + /// List all access keys for a database + #[method(name = "listAccessKeys")] + async fn list_access_keys(&self, db_id: u64) -> RpcResult>; + + /// Set database public/private status + #[method(name = "setDatabasePublic")] + async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult; } /// RPC Server implementation @@ -84,6 +138,8 @@ pub struct RpcServerImpl { next_encrypted_id: Arc>, /// Default backend type backend: crate::options::BackendType, + /// Encryption keys for databases + encryption_keys: Arc>>>, } impl RpcServerImpl { @@ -95,6 +151,7 @@ impl RpcServerImpl { next_unencrypted_id: Arc::new(RwLock::new(0)), next_encrypted_id: Arc::new(RwLock::new(10)), backend, + encryption_keys: Arc::new(RwLock::new(HashMap::new())), } } @@ -177,6 +234,166 @@ impl RpcServerImpl { current_id } } + + /// Load database metadata from file (static version) + pub async fn load_meta_static(base_dir: &str, db_id: u64) -> Result { + let meta_path = std::path::PathBuf::from(base_dir).join(format!("{}_meta.json", db_id)); + + // If meta file doesn't exist, return default + if !meta_path.exists() { + return Ok(DatabaseMeta { + public: true, + keys: HashMap::new(), + }); + } + + // Read file + let content = std::fs::read(&meta_path) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to read meta file: {}", e), + None::<()> + ))?; + + let json_str = String::from_utf8(content) + .map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Invalid UTF-8 in meta file", + None::<()> + ))?; + + serde_json::from_str(&json_str) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to parse meta JSON: {}", e), + None::<()> + )) + } + + /// Load database metadata from file + async fn load_meta(&self, db_id: u64) -> Result { + let meta_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}_meta.json", db_id)); + + // If meta file doesn't exist, create default + if !meta_path.exists() { + let default_meta = DatabaseMeta { + public: true, + keys: HashMap::new(), + }; + self.save_meta(db_id, &default_meta).await?; + return Ok(default_meta); + } + + // Read and potentially decrypt + let content = std::fs::read(&meta_path) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to read meta file: {}", e), + None::<()> + ))?; + + let json_str = if db_id >= 10 { + // Encrypted database, decrypt meta + if let Some(key) = self.encryption_keys.read().await.get(&db_id).and_then(|k| k.as_ref()) { + use crate::crypto::CryptoFactory; + let crypto = CryptoFactory::new(key.as_bytes()); + String::from_utf8(crypto.decrypt(&content) + .map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Failed to decrypt meta file", + None::<()> + ))?) + .map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Invalid UTF-8 in decrypted meta", + None::<()> + ))? + } else { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Encryption key not found for encrypted database", + None::<()> + )); + } + } else { + String::from_utf8(content) + .map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Invalid UTF-8 in meta file", + None::<()> + ))? + }; + + serde_json::from_str(&json_str) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to parse meta JSON: {}", e), + None::<()> + )) + } + + /// Save database metadata to file (static version) + pub async fn save_meta_static(base_dir: &str, db_id: u64, meta: &DatabaseMeta) -> Result<(), jsonrpsee::types::ErrorObjectOwned> { + let meta_path = std::path::PathBuf::from(base_dir).join(format!("{}_meta.json", db_id)); + + let json_str = serde_json::to_string(meta) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to serialize meta: {}", e), + None::<()> + ))?; + + std::fs::write(&meta_path, json_str) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to write meta file: {}", e), + None::<()> + ))?; + + Ok(()) + } + + /// Save database metadata to file + async fn save_meta(&self, db_id: u64, meta: &DatabaseMeta) -> Result<(), jsonrpsee::types::ErrorObjectOwned> { + let meta_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}_meta.json", db_id)); + + let json_str = serde_json::to_string(meta) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to serialize meta: {}", e), + None::<()> + ))?; + + if db_id >= 10 { + // Encrypted database, encrypt meta + if let Some(key) = self.encryption_keys.read().await.get(&db_id).and_then(|k| k.as_ref()) { + use crate::crypto::CryptoFactory; + let crypto = CryptoFactory::new(key.as_bytes()); + let encrypted = crypto.encrypt(json_str.as_bytes()); + std::fs::write(&meta_path, encrypted) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to write encrypted meta file: {}", e), + None::<()> + ))?; + } else { + return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Encryption key not found for encrypted database", + None::<()> + )); + } + } else { + std::fs::write(&meta_path, json_str) + .map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + format!("Failed to write meta file: {}", e), + None::<()> + ))?; + } + + Ok(()) + } } #[jsonrpsee::core::async_trait] @@ -213,7 +430,7 @@ impl RpcServer for RpcServerImpl { dir: db_dir.to_string_lossy().to_string(), port: 0, // Not used for RPC-managed databases debug: false, - encryption_key, + encryption_key: encryption_key.clone(), encrypt, backend: match backend { BackendType::Redb => crate::options::BackendType::Redb, @@ -230,6 +447,19 @@ impl RpcServer for RpcServerImpl { // Initialize the storage to create the database file let _ = server.current_storage(); + // Store the encryption key + { + let mut keys = self.encryption_keys.write().await; + keys.insert(db_id, encryption_key.clone()); + } + + // Initialize meta file + let meta = DatabaseMeta { + public: true, + keys: HashMap::new(), + }; + self.save_meta(db_id, &meta).await?; + // Store the server let mut servers = self.servers.write().await; servers.insert(db_id, Arc::new(server)); @@ -339,4 +569,66 @@ impl RpcServer for RpcServerImpl { Ok(stats) } + + async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult { + let mut meta = self.load_meta(db_id).await?; + + let perms = match permissions.to_lowercase().as_str() { + "read" => Permissions::Read, + "readwrite" => Permissions::ReadWrite, + _ => return Err(jsonrpsee::types::ErrorObjectOwned::owned( + -32000, + "Invalid permissions: use 'read' or 'readwrite'", + None::<()> + )), + }; + + let hash = hash_key(&key); + let access_key = AccessKey { + hash: hash.clone(), + permissions: perms, + created_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + meta.keys.insert(hash, access_key); + self.save_meta(db_id, &meta).await?; + Ok(true) + } + + async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult { + let mut meta = self.load_meta(db_id).await?; + + if meta.keys.remove(&key_hash).is_some() { + // If no keys left, make database public + if meta.keys.is_empty() { + meta.public = true; + } + self.save_meta(db_id, &meta).await?; + Ok(true) + } else { + Ok(false) + } + } + + async fn list_access_keys(&self, db_id: u64) -> RpcResult> { + let meta = self.load_meta(db_id).await?; + let keys: Vec = meta.keys.values() + .map(|k| AccessKeyInfo { + hash: k.hash.clone(), + permissions: k.permissions.clone(), + created_at: k.created_at, + }) + .collect(); + Ok(keys) + } + + async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult { + let mut meta = self.load_meta(db_id).await?; + meta.public = public; + self.save_meta(db_id, &meta).await?; + Ok(true) + } } \ No newline at end of file diff --git a/src/server.rs b/src/server.rs index a6e43e2..63864c6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -22,6 +22,7 @@ pub struct Server { pub client_name: Option, pub selected_db: u64, // Changed from usize to u64 pub queued_cmd: Option>, + pub current_permissions: Option, // BLPOP waiter registry: per (db_index, key) FIFO of waiters pub list_waiters: Arc>>>>, @@ -48,6 +49,7 @@ impl Server { client_name: None, selected_db: 0, queued_cmd: None, + current_permissions: None, list_waiters: Arc::new(Mutex::new(HashMap::new())), waiter_seq: Arc::new(AtomicU64::new(1)), @@ -101,6 +103,16 @@ impl Server { self.option.encrypt && db_index >= 10 } + /// Check if current permissions allow read operations + pub fn has_read_permission(&self) -> bool { + matches!(self.current_permissions, Some(crate::rpc::Permissions::Read) | Some(crate::rpc::Permissions::ReadWrite)) + } + + /// Check if current permissions allow write operations + pub fn has_write_permission(&self) -> bool { + matches!(self.current_permissions, Some(crate::rpc::Permissions::ReadWrite)) + } + // ----- BLPOP waiter helpers ----- pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) {