9 Commits

Author | SHA1 | Message | Date
Maxime Van Hees | bdf363016a | WIP: adding access management control to db instances | 2025-09-12 17:11:50 +02:00
Maxime Van Hees | 8798bc202e | Restore working code | 2025-09-11 18:33:09 +02:00
Maxime Van Hees | 9fa9832605 | combined curret main (with sled) and RPC server | 2025-09-11 17:23:46 +02:00
Maxime Van Hees | 4bb24b38dd | fix typo in README | 2025-09-11 15:34:03 +02:00
Maxime Van Hees | f3da14b957 | Merge branch 'append' | 2025-09-11 15:31:47 +02:00
Maxime Van Hees | 5ea34b4445 | update variable name as 'gen' is a reserved keyword since Rust 2024 edition | 2025-09-11 15:25:26 +02:00
Maxime Van Hees | d9a3b711d1 | Update tot Rust 2024 edition + update Cargo.toml file | 2025-09-11 15:24:28 +02:00
Maxime Van Hees | d931770e90 | Fix test suite + update Cargo.toml file | 2025-09-09 16:04:31 +02:00
Timur Gordon | a87ec4dbb5 | add readme | 2025-08-27 15:39:59 +02:00
11 changed files with 1778 additions and 36 deletions

Cargo.lock (generated, 926 changed lines): diff suppressed because it is too large.

@@ -1,8 +1,8 @@
 [package]
 name = "herodb"
 version = "0.0.1"
-authors = ["Pin Fang <fpfangpin@hotmail.com>"]
-edition = "2021"
+authors = ["ThreeFold Tech NV"]
+edition = "2024"
 
 [dependencies]
 anyhow = "1.0.59"
@@ -24,6 +24,7 @@ age = "0.10"
 secrecy = "0.8"
 ed25519-dalek = "2"
 base64 = "0.22"
+jsonrpsee = { version = "0.26.0", features = ["http-client", "ws-client", "server", "macros"] }
 
 [dev-dependencies]
 redis = { version = "0.24", features = ["aio", "tokio-comp"] }


@@ -47,13 +47,13 @@ You can start HeroDB with different backends and encryption options:
 #### `redb` with Encryption
 ```bash
-./target/release/herodb --dir /tmp/herodb_encrypted --port 6379 --encrypt --key mysecretkey
+./target/release/herodb --dir /tmp/herodb_encrypted --port 6379 --encrypt --encryption_key mysecretkey
 ```
 
 #### `sled` with Encryption
 ```bash
-./target/release/herodb --dir /tmp/herodb_sled_encrypted --port 6379 --sled --encrypt --key mysecretkey
+./target/release/herodb --dir /tmp/herodb_sled_encrypted --port 6379 --sled --encrypt --encryption_key mysecretkey
 ```
 
 ## Usage with Redis Clients
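A quick smoke test after starting either encrypted variant, using the same flags as above (redis-cli is only an example client; any RESP-speaking client should behave the same):

```bash
# Start an encrypted redb instance in the background, then verify it answers RESP commands
./target/release/herodb --dir /tmp/herodb_encrypted --port 6379 --encrypt --encryption_key mysecretkey &
redis-cli -p 6379 PING
```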


@@ -6,7 +6,7 @@ use futures::future::select_all;
 pub enum Cmd {
     Ping,
     Echo(String),
-    Select(u64), // Changed from u16 to u64
+    Select(u64, Option<String>), // db_index, optional_key
     Get(String),
     Set(String, String),
     SetPx(String, String, u128),
@@ -98,11 +98,18 @@
         Ok((
             match cmd[0].to_lowercase().as_str() {
                 "select" => {
-                    if cmd.len() != 2 {
+                    if cmd.len() < 2 || cmd.len() > 4 {
                         return Err(DBError("wrong number of arguments for SELECT".to_string()));
                     }
                     let idx = cmd[1].parse::<u64>().map_err(|_| DBError("ERR DB index is not an integer".to_string()))?;
-                    Cmd::Select(idx)
+                    let key = if cmd.len() == 4 && cmd[2].to_lowercase() == "key" {
+                        Some(cmd[3].clone())
+                    } else if cmd.len() == 2 {
+                        None
+                    } else {
+                        return Err(DBError("ERR syntax error".to_string()));
+                    };
+                    Cmd::Select(idx, key)
                 }
                 "echo" => Cmd::Echo(cmd[1].clone()),
                 "ping" => Cmd::Ping,
@@ -642,7 +649,7 @@ impl Cmd {
         }
 
         match self {
-            Cmd::Select(db) => select_cmd(server, db).await,
+            Cmd::Select(db, key) => select_cmd(server, db, key).await,
             Cmd::Ping => Ok(Protocol::SimpleString("PONG".to_string())),
             Cmd::Echo(s) => Ok(Protocol::BulkString(s)),
             Cmd::Get(k) => get_cmd(server, &k).await,
@@ -736,7 +743,14 @@
     pub fn to_protocol(self) -> Protocol {
         match self {
-            Cmd::Select(db) => Protocol::Array(vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())]),
+            Cmd::Select(db, key) => {
+                let mut arr = vec![Protocol::BulkString("select".to_string()), Protocol::BulkString(db.to_string())];
+                if let Some(k) = key {
+                    arr.push(Protocol::BulkString("key".to_string()));
+                    arr.push(Protocol::BulkString(k));
+                }
+                Protocol::Array(arr)
+            }
             Cmd::Ping => Protocol::Array(vec![Protocol::BulkString("ping".to_string())]),
             Cmd::Echo(s) => Protocol::Array(vec![Protocol::BulkString("echo".to_string()), Protocol::BulkString(s)]),
             Cmd::Get(k) => Protocol::Array(vec![Protocol::BulkString("get".to_string()), Protocol::BulkString(k)]),
@@ -753,9 +767,44 @@ async fn flushdb_cmd(server: &mut Server) -> Result<Protocol, DBError> {
     }
 }
 
-async fn select_cmd(server: &mut Server, db: u64) -> Result<Protocol, DBError> {
-    // Test if we can access the database (this will create it if needed)
+async fn select_cmd(server: &mut Server, db: u64, key: Option<String>) -> Result<Protocol, DBError> {
+    // Load database metadata
+    let meta = match crate::rpc::RpcServerImpl::load_meta_static(&server.option.dir, db).await {
+        Ok(m) => m,
+        Err(_) => {
+            // If meta doesn't exist, create default
+            let default_meta = crate::rpc::DatabaseMeta {
+                public: true,
+                keys: std::collections::HashMap::new(),
+            };
+            if let Err(_) = crate::rpc::RpcServerImpl::save_meta_static(&server.option.dir, db, &default_meta).await {
+                return Ok(Protocol::err("ERR failed to initialize database metadata"));
+            }
+            default_meta
+        }
+    };
+
+    // Check access permissions
+    let permissions = if meta.public {
+        // Public database - full access
+        Some(crate::rpc::Permissions::ReadWrite)
+    } else if let Some(key_str) = key {
+        // Private database - check key
+        let hash = crate::rpc::hash_key(&key_str);
+        if let Some(access_key) = meta.keys.get(&hash) {
+            Some(access_key.permissions.clone())
+        } else {
+            return Ok(Protocol::err("ERR invalid access key"));
+        }
+    } else {
+        return Ok(Protocol::err("ERR access key required for private database"));
+    };
+
+    // Set selected database and permissions
     server.selected_db = db;
+    server.current_permissions = permissions;
+
+    // Test if we can access the database (this will create it if needed)
     match server.current_storage() {
         Ok(_) => Ok(Protocol::SimpleString("OK".to_string())),
         Err(e) => Ok(Protocol::err(&e.0)),
@@ -1003,6 +1052,9 @@ async fn brpop_cmd(server: &Server, keys: &[String], timeout_secs: f64) -> Resul
     }
 }
 
 async fn lpush_cmd(server: &Server, key: &str, elements: &[String]) -> Result<Protocol, DBError> {
+    if !server.has_write_permission() {
+        return Ok(Protocol::err("ERR write permission denied"));
+    }
     match server.current_storage()?.lpush(key, elements.to_vec()) {
         Ok(len) => {
             // Attempt to deliver to any blocked BLPOP waiters
@@ -1134,6 +1186,9 @@ async fn type_cmd(server: &Server, k: &String) -> Result<Protocol, DBError> {
     }
 }
 
 async fn del_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
+    if !server.has_write_permission() {
+        return Ok(Protocol::err("ERR write permission denied"));
+    }
     server.current_storage()?.del(k.to_string())?;
     Ok(Protocol::SimpleString("1".to_string()))
 }
@@ -1159,6 +1214,9 @@ async fn set_px_cmd(
     )
 }
 
 async fn set_cmd(server: &Server, k: &str, v: &str) -> Result<Protocol, DBError> {
+    if !server.has_write_permission() {
+        return Ok(Protocol::err("ERR write permission denied"));
+    }
     server.current_storage()?.set(k.to_string(), v.to_string())?;
     Ok(Protocol::SimpleString("OK".to_string()))
 }
@@ -1273,6 +1331,9 @@ async fn get_cmd(server: &Server, k: &str) -> Result<Protocol, DBError> {
 
 // Hash command implementations
 async fn hset_cmd(server: &Server, key: &str, pairs: &[(String, String)]) -> Result<Protocol, DBError> {
+    if !server.has_write_permission() {
+        return Ok(Protocol::err("ERR write permission denied"));
+    }
     let new_fields = server.current_storage()?.hset(key, pairs.to_vec())?;
     Ok(Protocol::SimpleString(new_fields.to_string()))
 }
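With this change, SELECT accepts an optional `KEY <plaintext>` argument so a private database can be opened with an access key. A minimal sketch of the intended client flow, assuming a private database with id 11 and an already-registered key (the values are illustrative, not taken from the patch):

```bash
# Public database: plain SELECT continues to work and grants ReadWrite
redis-cli -p 6379 SELECT 0

# Private database: supply the plaintext key; the server hashes it with SHA-256
# and looks the hash up in the database's <id>_meta.json access-key list
redis-cli -p 6379 SELECT 11 KEY my-plaintext-access-key
```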


@@ -4,6 +4,8 @@ pub mod crypto;
 pub mod error;
 pub mod options;
 pub mod protocol;
+pub mod rpc;
+pub mod rpc_server;
 pub mod server;
 pub mod storage;
 pub mod storage_trait; // Add this


@@ -3,6 +3,7 @@
 use tokio::net::TcpListener;
 use herodb::server;
+use herodb::rpc_server;
 
 use clap::Parser;
@@ -31,6 +32,14 @@ struct Args {
     #[arg(long)]
     encrypt: bool,
 
+    /// Enable RPC management server
+    #[arg(long)]
+    enable_rpc: bool,
+
+    /// RPC server port (default: 8080)
+    #[arg(long, default_value = "8080")]
+    rpc_port: u16,
+
     /// Use the sled backend
     #[arg(long)]
     sled: bool,
@@ -50,7 +59,7 @@ async fn main() {
     // new DB option
     let option = herodb::options::DBOption {
-        dir: args.dir,
+        dir: args.dir.clone(),
         port,
         debug: args.debug,
         encryption_key: args.encryption_key,
@@ -62,12 +71,36 @@
         },
     };
 
+    let backend = option.backend.clone();
+
     // new server
-    let server = server::Server::new(option).await;
+    let mut server = server::Server::new(option).await;
+
+    // Initialize the default database storage
+    let _ = server.current_storage();
 
     // Add a small delay to ensure the port is ready
     tokio::time::sleep(std::time::Duration::from_millis(100)).await;
 
+    // Start RPC server if enabled
+    let rpc_handle = if args.enable_rpc {
+        let rpc_addr = format!("127.0.0.1:{}", args.rpc_port).parse().unwrap();
+        let base_dir = args.dir.clone();
+        match rpc_server::start_rpc_server(rpc_addr, base_dir, backend).await {
+            Ok(handle) => {
+                println!("RPC management server started on port {}", args.rpc_port);
+                Some(handle)
+            }
+            Err(e) => {
+                eprintln!("Failed to start RPC server: {}", e);
+                None
+            }
+        }
+    } else {
+        None
+    };
+
     // accept new connections
     loop {
         let stream = listener.accept().await;
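Assuming the new flags keep their field names (clap may expose them with hyphens instead of underscores, e.g. `--enable-rpc`), starting HeroDB together with the management endpoint could look like this sketch:

```bash
# Redis-compatible server on 6379, JSON-RPC management server on 8080
./target/release/herodb --dir /tmp/herodb_data --port 6379 --enable_rpc --rpc_port 8080
```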

src/rpc.rs (new file, 634 lines)

@@ -0,0 +1,634 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use jsonrpsee::{core::RpcResult, proc_macros::rpc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::server::Server;
use crate::options::DBOption;
/// Database backend types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BackendType {
Redb,
Sled,
// Future: InMemory, Custom(String)
}
/// Database configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub name: Option<String>,
pub storage_path: Option<String>,
pub max_size: Option<u64>,
pub redis_version: Option<String>,
}
/// Database information returned by metadata queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseInfo {
pub id: u64,
pub name: Option<String>,
pub backend: BackendType,
pub encrypted: bool,
pub redis_version: Option<String>,
pub storage_path: Option<String>,
pub size_on_disk: Option<u64>,
pub key_count: Option<u64>,
pub created_at: u64,
pub last_access: Option<u64>,
}
/// Access permissions for database keys
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Permissions {
Read,
ReadWrite,
}
/// Access key information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessKey {
pub hash: String,
pub permissions: Permissions,
pub created_at: u64,
}
/// Database metadata containing access keys
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseMeta {
pub public: bool,
pub keys: HashMap<String, AccessKey>,
}
/// Access key information returned by RPC
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessKeyInfo {
pub hash: String,
pub permissions: Permissions,
pub created_at: u64,
}
/// Hash a plaintext key using SHA-256
pub fn hash_key(key: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(key.as_bytes());
format!("{:x}", hasher.finalize())
}
/// RPC trait for HeroDB management
#[rpc(server, client, namespace = "herodb")]
pub trait Rpc {
/// Create a new database with specified configuration
#[method(name = "createDatabase")]
async fn create_database(
&self,
backend: BackendType,
config: DatabaseConfig,
encryption_key: Option<String>,
) -> RpcResult<u64>;
/// Set encryption for an existing database (write-only key)
#[method(name = "setEncryption")]
async fn set_encryption(&self, db_id: u64, encryption_key: String) -> RpcResult<bool>;
/// List all managed databases
#[method(name = "listDatabases")]
async fn list_databases(&self) -> RpcResult<Vec<DatabaseInfo>>;
/// Get detailed information about a specific database
#[method(name = "getDatabaseInfo")]
async fn get_database_info(&self, db_id: u64) -> RpcResult<DatabaseInfo>;
/// Delete a database
#[method(name = "deleteDatabase")]
async fn delete_database(&self, db_id: u64) -> RpcResult<bool>;
/// Get server statistics
#[method(name = "getServerStats")]
async fn get_server_stats(&self) -> RpcResult<HashMap<String, serde_json::Value>>;
/// Add an access key to a database
#[method(name = "addAccessKey")]
async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult<bool>;
/// Delete an access key from a database
#[method(name = "deleteAccessKey")]
async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult<bool>;
/// List all access keys for a database
#[method(name = "listAccessKeys")]
async fn list_access_keys(&self, db_id: u64) -> RpcResult<Vec<AccessKeyInfo>>;
/// Set database public/private status
#[method(name = "setDatabasePublic")]
async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult<bool>;
}
/// RPC Server implementation
pub struct RpcServerImpl {
/// Base directory for database files
base_dir: String,
/// Managed database servers
servers: Arc<RwLock<HashMap<u64, Arc<Server>>>>,
/// Next unencrypted database ID to assign
next_unencrypted_id: Arc<RwLock<u64>>,
/// Next encrypted database ID to assign
next_encrypted_id: Arc<RwLock<u64>>,
/// Default backend type
backend: crate::options::BackendType,
/// Encryption keys for databases
encryption_keys: Arc<RwLock<HashMap<u64, Option<String>>>>,
}
impl RpcServerImpl {
/// Create a new RPC server instance
pub fn new(base_dir: String, backend: crate::options::BackendType) -> Self {
Self {
base_dir,
servers: Arc::new(RwLock::new(HashMap::new())),
next_unencrypted_id: Arc::new(RwLock::new(0)),
next_encrypted_id: Arc::new(RwLock::new(10)),
backend,
encryption_keys: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Get or create a server instance for the given database ID
async fn get_or_create_server(&self, db_id: u64) -> Result<Arc<Server>, jsonrpsee::types::ErrorObjectOwned> {
// Check if server already exists
{
let servers = self.servers.read().await;
if let Some(server) = servers.get(&db_id) {
return Ok(server.clone());
}
}
// Check if database file exists
let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id));
if !db_path.exists() {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Database {} not found", db_id),
None::<()>
));
}
// Create server instance with default options
let db_option = DBOption {
dir: self.base_dir.clone(),
port: 0, // Not used for RPC-managed databases
debug: false,
encryption_key: None,
encrypt: false,
backend: self.backend.clone(),
};
let mut server = Server::new(db_option).await;
// Set the selected database to the db_id for proper file naming
server.selected_db = db_id;
// Store the server
let mut servers = self.servers.write().await;
servers.insert(db_id, Arc::new(server.clone()));
Ok(Arc::new(server))
}
/// Discover existing database files in the base directory
async fn discover_databases(&self) -> Vec<u64> {
let mut db_ids = Vec::new();
if let Ok(entries) = std::fs::read_dir(&self.base_dir) {
for entry in entries.flatten() {
if let Ok(file_name) = entry.file_name().into_string() {
// Check if it's a database file (ends with .db)
if file_name.ends_with(".db") {
// Extract database ID from filename (e.g., "11.db" -> 11)
if let Some(id_str) = file_name.strip_suffix(".db") {
if let Ok(db_id) = id_str.parse::<u64>() {
db_ids.push(db_id);
}
}
}
}
}
}
db_ids
}
/// Get the next available database ID
async fn get_next_db_id(&self, is_encrypted: bool) -> u64 {
if is_encrypted {
let mut id = self.next_encrypted_id.write().await;
let current_id = *id;
*id += 1;
current_id
} else {
let mut id = self.next_unencrypted_id.write().await;
let current_id = *id;
*id += 1;
current_id
}
}
/// Load database metadata from file (static version)
pub async fn load_meta_static(base_dir: &str, db_id: u64) -> Result<DatabaseMeta, jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(base_dir).join(format!("{}_meta.json", db_id));
// If meta file doesn't exist, return default
if !meta_path.exists() {
return Ok(DatabaseMeta {
public: true,
keys: HashMap::new(),
});
}
// Read file
let content = std::fs::read(&meta_path)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to read meta file: {}", e),
None::<()>
))?;
let json_str = String::from_utf8(content)
.map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Invalid UTF-8 in meta file",
None::<()>
))?;
serde_json::from_str(&json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to parse meta JSON: {}", e),
None::<()>
))
}
/// Load database metadata from file
async fn load_meta(&self, db_id: u64) -> Result<DatabaseMeta, jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}_meta.json", db_id));
// If meta file doesn't exist, create default
if !meta_path.exists() {
let default_meta = DatabaseMeta {
public: true,
keys: HashMap::new(),
};
self.save_meta(db_id, &default_meta).await?;
return Ok(default_meta);
}
// Read and potentially decrypt
let content = std::fs::read(&meta_path)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to read meta file: {}", e),
None::<()>
))?;
let json_str = if db_id >= 10 {
// Encrypted database, decrypt meta
if let Some(key) = self.encryption_keys.read().await.get(&db_id).and_then(|k| k.as_ref()) {
use crate::crypto::CryptoFactory;
let crypto = CryptoFactory::new(key.as_bytes());
String::from_utf8(crypto.decrypt(&content)
.map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Failed to decrypt meta file",
None::<()>
))?)
.map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Invalid UTF-8 in decrypted meta",
None::<()>
))?
} else {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Encryption key not found for encrypted database",
None::<()>
));
}
} else {
String::from_utf8(content)
.map_err(|_| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Invalid UTF-8 in meta file",
None::<()>
))?
};
serde_json::from_str(&json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to parse meta JSON: {}", e),
None::<()>
))
}
/// Save database metadata to file (static version)
pub async fn save_meta_static(base_dir: &str, db_id: u64, meta: &DatabaseMeta) -> Result<(), jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(base_dir).join(format!("{}_meta.json", db_id));
let json_str = serde_json::to_string(meta)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to serialize meta: {}", e),
None::<()>
))?;
std::fs::write(&meta_path, json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to write meta file: {}", e),
None::<()>
))?;
Ok(())
}
/// Save database metadata to file
async fn save_meta(&self, db_id: u64, meta: &DatabaseMeta) -> Result<(), jsonrpsee::types::ErrorObjectOwned> {
let meta_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}_meta.json", db_id));
let json_str = serde_json::to_string(meta)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to serialize meta: {}", e),
None::<()>
))?;
if db_id >= 10 {
// Encrypted database, encrypt meta
if let Some(key) = self.encryption_keys.read().await.get(&db_id).and_then(|k| k.as_ref()) {
use crate::crypto::CryptoFactory;
let crypto = CryptoFactory::new(key.as_bytes());
let encrypted = crypto.encrypt(json_str.as_bytes());
std::fs::write(&meta_path, encrypted)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to write encrypted meta file: {}", e),
None::<()>
))?;
} else {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Encryption key not found for encrypted database",
None::<()>
));
}
} else {
std::fs::write(&meta_path, json_str)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to write meta file: {}", e),
None::<()>
))?;
}
Ok(())
}
}
#[jsonrpsee::core::async_trait]
impl RpcServer for RpcServerImpl {
async fn create_database(
&self,
backend: BackendType,
config: DatabaseConfig,
encryption_key: Option<String>,
) -> RpcResult<u64> {
let db_id = self.get_next_db_id(encryption_key.is_some()).await;
// Handle both Redb and Sled backends
match backend {
BackendType::Redb | BackendType::Sled => {
// Create database directory
let db_dir = if let Some(path) = &config.storage_path {
std::path::PathBuf::from(path)
} else {
std::path::PathBuf::from(&self.base_dir).join(format!("rpc_db_{}", db_id))
};
// Ensure directory exists
std::fs::create_dir_all(&db_dir)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Failed to create directory: {}", e),
None::<()>
))?;
// Create DB options
let encrypt = encryption_key.is_some();
let option = DBOption {
dir: db_dir.to_string_lossy().to_string(),
port: 0, // Not used for RPC-managed databases
debug: false,
encryption_key: encryption_key.clone(),
encrypt,
backend: match backend {
BackendType::Redb => crate::options::BackendType::Redb,
BackendType::Sled => crate::options::BackendType::Sled,
},
};
// Create server instance
let mut server = Server::new(option).await;
// Set the selected database to the db_id for proper file naming
server.selected_db = db_id;
// Initialize the storage to create the database file
let _ = server.current_storage();
// Store the encryption key
{
let mut keys = self.encryption_keys.write().await;
keys.insert(db_id, encryption_key.clone());
}
// Initialize meta file
let meta = DatabaseMeta {
public: true,
keys: HashMap::new(),
};
self.save_meta(db_id, &meta).await?;
// Store the server
let mut servers = self.servers.write().await;
servers.insert(db_id, Arc::new(server));
Ok(db_id)
}
}
}
async fn set_encryption(&self, db_id: u64, _encryption_key: String) -> RpcResult<bool> {
// Note: In a real implementation, we'd need to modify the existing database
// For now, return false as encryption can only be set during creation
let _servers = self.servers.read().await;
// TODO: Implement encryption setting for existing databases
Ok(false)
}
async fn list_databases(&self) -> RpcResult<Vec<DatabaseInfo>> {
let db_ids = self.discover_databases().await;
let mut result = Vec::new();
for db_id in db_ids {
// Try to get or create server for this database
if let Ok(server) = self.get_or_create_server(db_id).await {
let backend = match server.option.backend {
crate::options::BackendType::Redb => BackendType::Redb,
crate::options::BackendType::Sled => BackendType::Sled,
};
let info = DatabaseInfo {
id: db_id,
name: None, // TODO: Store name in server metadata
backend,
encrypted: server.option.encrypt,
redis_version: Some("7.0".to_string()), // Default Redis compatibility
storage_path: Some(server.option.dir.clone()),
size_on_disk: None, // TODO: Calculate actual size
key_count: None, // TODO: Get key count from storage
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
last_access: None,
};
result.push(info);
}
}
Ok(result)
}
async fn get_database_info(&self, db_id: u64) -> RpcResult<DatabaseInfo> {
let server = self.get_or_create_server(db_id).await?;
let backend = match server.option.backend {
crate::options::BackendType::Redb => BackendType::Redb,
crate::options::BackendType::Sled => BackendType::Sled,
};
Ok(DatabaseInfo {
id: db_id,
name: None,
backend,
encrypted: server.option.encrypt,
redis_version: Some("7.0".to_string()),
storage_path: Some(server.option.dir.clone()),
size_on_disk: None,
key_count: None,
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
last_access: None,
})
}
async fn delete_database(&self, db_id: u64) -> RpcResult<bool> {
let mut servers = self.servers.write().await;
if let Some(_server) = servers.remove(&db_id) {
// Clean up database files
let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id));
if db_path.exists() {
if db_path.is_dir() {
std::fs::remove_dir_all(&db_path).ok();
} else {
std::fs::remove_file(&db_path).ok();
}
}
Ok(true)
} else {
Ok(false)
}
}
async fn get_server_stats(&self) -> RpcResult<HashMap<String, serde_json::Value>> {
let db_ids = self.discover_databases().await;
let mut stats = HashMap::new();
stats.insert("total_databases".to_string(), serde_json::json!(db_ids.len()));
stats.insert("uptime".to_string(), serde_json::json!(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
));
Ok(stats)
}
async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult<bool> {
let mut meta = self.load_meta(db_id).await?;
let perms = match permissions.to_lowercase().as_str() {
"read" => Permissions::Read,
"readwrite" => Permissions::ReadWrite,
_ => return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Invalid permissions: use 'read' or 'readwrite'",
None::<()>
)),
};
let hash = hash_key(&key);
let access_key = AccessKey {
hash: hash.clone(),
permissions: perms,
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
};
meta.keys.insert(hash, access_key);
self.save_meta(db_id, &meta).await?;
Ok(true)
}
async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult<bool> {
let mut meta = self.load_meta(db_id).await?;
if meta.keys.remove(&key_hash).is_some() {
// If no keys left, make database public
if meta.keys.is_empty() {
meta.public = true;
}
self.save_meta(db_id, &meta).await?;
Ok(true)
} else {
Ok(false)
}
}
async fn list_access_keys(&self, db_id: u64) -> RpcResult<Vec<AccessKeyInfo>> {
let meta = self.load_meta(db_id).await?;
let keys: Vec<AccessKeyInfo> = meta.keys.values()
.map(|k| AccessKeyInfo {
hash: k.hash.clone(),
permissions: k.permissions.clone(),
created_at: k.created_at,
})
.collect();
Ok(keys)
}
async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult<bool> {
let mut meta = self.load_meta(db_id).await?;
meta.public = public;
self.save_meta(db_id, &meta).await?;
Ok(true)
}
}
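As a rough illustration of how this API could be driven over HTTP once the server runs with the RPC endpoint enabled: with jsonrpsee's namespacing the methods are expected to be exposed as `herodb_<name>`, and the values below (port, database id, key string) are assumptions for the example, not output from the patch.

```bash
# Create an unencrypted redb-backed database
curl -s http://127.0.0.1:8080 -H 'Content-Type: application/json' -d '{
  "jsonrpc": "2.0", "id": 1,
  "method": "herodb_createDatabase",
  "params": ["Redb", {"name": "test_db", "storage_path": null, "max_size": null, "redis_version": "7.0"}, null]
}'

# Grant a read-only access key to database 0; the plaintext is stored only as a SHA-256 hash
curl -s http://127.0.0.1:8080 -H 'Content-Type: application/json' -d '{
  "jsonrpc": "2.0", "id": 2,
  "method": "herodb_addAccessKey",
  "params": [0, "my-plaintext-access-key", "read"]
}'
```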

src/rpc_server.rs (new file, 49 lines)

@@ -0,0 +1,49 @@
use std::net::SocketAddr;
use jsonrpsee::server::{ServerBuilder, ServerHandle};
use jsonrpsee::RpcModule;
use crate::rpc::{RpcServer, RpcServerImpl};
/// Start the RPC server on the specified address
pub async fn start_rpc_server(addr: SocketAddr, base_dir: String, backend: crate::options::BackendType) -> Result<ServerHandle, Box<dyn std::error::Error + Send + Sync>> {
// Create the RPC server implementation
let rpc_impl = RpcServerImpl::new(base_dir, backend);
// Create the RPC module
let mut module = RpcModule::new(());
module.merge(RpcServer::into_rpc(rpc_impl))?;
// Build the server with both HTTP and WebSocket support
let server = ServerBuilder::default()
.build(addr)
.await?;
// Start the server
let handle = server.start(module);
println!("RPC server started on {}", addr);
Ok(handle)
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[tokio::test]
async fn test_rpc_server_startup() {
let addr = "127.0.0.1:0".parse().unwrap(); // Use port 0 for auto-assignment
let base_dir = "/tmp/test_rpc".to_string();
let backend = crate::options::BackendType::Redb; // Default for test
let handle = start_rpc_server(addr, base_dir, backend).await.unwrap();
// Give the server a moment to start
tokio::time::sleep(Duration::from_millis(100)).await;
// Stop the server
handle.stop().unwrap();
handle.stopped().await;
}
}


@@ -22,6 +22,7 @@ pub struct Server {
     pub client_name: Option<String>,
     pub selected_db: u64, // Changed from usize to u64
     pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
+    pub current_permissions: Option<crate::rpc::Permissions>,
 
     // BLPOP waiter registry: per (db_index, key) FIFO of waiters
     pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
@@ -48,6 +49,7 @@ impl Server {
             client_name: None,
             selected_db: 0,
             queued_cmd: None,
+            current_permissions: None,
 
             list_waiters: Arc::new(Mutex::new(HashMap::new())),
             waiter_seq: Arc::new(AtomicU64::new(1)),
@@ -101,6 +103,16 @@ impl Server {
         self.option.encrypt && db_index >= 10
     }
 
+    /// Check if current permissions allow read operations
+    pub fn has_read_permission(&self) -> bool {
+        matches!(self.current_permissions, Some(crate::rpc::Permissions::Read) | Some(crate::rpc::Permissions::ReadWrite))
+    }
+
+    /// Check if current permissions allow write operations
+    pub fn has_write_permission(&self) -> bool {
+        matches!(self.current_permissions, Some(crate::rpc::Permissions::ReadWrite))
+    }
+
     // ----- BLPOP waiter helpers -----
 
     pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) {

tests/rpc_tests.rs (new file, 62 lines)

@@ -0,0 +1,62 @@
use std::net::SocketAddr;
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::core::client::ClientT;
use serde_json::json;
use herodb::rpc::{RpcClient, BackendType, DatabaseConfig};
#[tokio::test]
async fn test_rpc_server_basic() {
// This test would require starting the RPC server in a separate thread
// For now, we'll just test that the types compile correctly
// Test serialization of types
let backend = BackendType::Redb;
let config = DatabaseConfig {
name: Some("test_db".to_string()),
storage_path: Some("/tmp/test".to_string()),
max_size: Some(1024 * 1024),
redis_version: Some("7.0".to_string()),
};
let backend_json = serde_json::to_string(&backend).unwrap();
let config_json = serde_json::to_string(&config).unwrap();
assert_eq!(backend_json, "\"Redb\"");
assert!(config_json.contains("test_db"));
}
#[tokio::test]
async fn test_database_config_serialization() {
let config = DatabaseConfig {
name: Some("my_db".to_string()),
storage_path: None,
max_size: Some(1000000),
redis_version: Some("7.0".to_string()),
};
let json = serde_json::to_value(&config).unwrap();
assert_eq!(json["name"], "my_db");
assert_eq!(json["max_size"], 1000000);
assert_eq!(json["redis_version"], "7.0");
}
#[tokio::test]
async fn test_backend_type_serialization() {
// Test that both Redb and Sled backends serialize correctly
let redb_backend = BackendType::Redb;
let sled_backend = BackendType::Sled;
let redb_json = serde_json::to_string(&redb_backend).unwrap();
let sled_json = serde_json::to_string(&sled_backend).unwrap();
assert_eq!(redb_json, "\"Redb\"");
assert_eq!(sled_json, "\"Sled\"");
// Test deserialization
let redb_deserialized: BackendType = serde_json::from_str(&redb_json).unwrap();
let sled_deserialized: BackendType = serde_json::from_str(&sled_json).unwrap();
assert!(matches!(redb_deserialized, BackendType::Redb));
assert!(matches!(sled_deserialized, BackendType::Sled));
}
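To exercise just this file locally (standard cargo invocation, nothing patch-specific):

```bash
cargo test --test rpc_tests
```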


@@ -501,11 +501,11 @@ async fn test_07_age_stateless_suite() {
     let mut s = connect(port).await;
 
     // GENENC -> [recipient, identity]
-    let gen = send_cmd(&mut s, &["AGE", "GENENC"]).await;
+    let genenc = send_cmd(&mut s, &["AGE", "GENENC"]).await;
     assert!(
-        gen.starts_with("*2\r\n$"),
+        genenc.starts_with("*2\r\n$"),
         "AGE GENENC should return array [recipient, identity], got:\n{}",
-        gen
+        genenc
     );
 
     // Parse simple RESP array of two bulk strings to extract keys
@@ -520,7 +520,7 @@
         let ident = lines.next().unwrap_or("").to_string();
         (recip, ident)
     }
 
-    let (recipient, identity) = parse_two_bulk_array(&gen);
+    let (recipient, identity) = parse_two_bulk_array(&genenc);
     assert!(
         recipient.starts_with("age1") && identity.starts_with("AGE-SECRET-KEY-1"),
         "Unexpected AGE key formats.\nrecipient: {}\nidentity: {}",