fixed connection handling issue

This commit is contained in:
Maxime Van Hees
2025-09-10 11:53:01 +02:00
parent e84f7b7e3b
commit 271c6cb0ae
9 changed files with 154 additions and 94 deletions

View File

@@ -5,6 +5,7 @@ use tokio::sync::Mutex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use herodb::server; use herodb::server;
use herodb::server::Server;
use herodb::rpc_server; use herodb::rpc_server;
use clap::Parser; use clap::Parser;
@@ -98,8 +99,7 @@ async fn main() {
let sc = Arc::clone(&server); let sc = Arc::clone(&server);
tokio::spawn(async move { tokio::spawn(async move {
let mut server_guard = sc.lock().await; if let Err(e) = Server::handle(sc, stream).await {
if let Err(e) = server_guard.handle(stream).await {
println!("error: {:?}, will close the connection. Bye", e); println!("error: {:?}, will close the connection. Bye", e);
} }
}); });

View File

@@ -177,11 +177,34 @@ impl RpcServer for RpcServerImpl {
let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_index)); let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_index));
let file_exists = db_path.exists(); let file_exists = db_path.exists();
// Get file size if it exists // If database doesn't exist, return an error
let size_on_disk = if file_exists { if !file_exists && db_index != 0 {
std::fs::metadata(&db_path).ok().map(|m| m.len()) return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Database {} does not exist", db_index),
None::<()>
));
}
// Get file metadata if it exists
let (size_on_disk, created_at) = if file_exists {
if let Ok(metadata) = std::fs::metadata(&db_path) {
let size = Some(metadata.len());
let created = metadata.created()
.unwrap_or(std::time::SystemTime::UNIX_EPOCH)
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
(size, created)
} else {
(None, 0)
}
} else { } else {
None // Database 0 might not have a file yet
(None, std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs())
}; };
Ok(DatabaseInfo { Ok(DatabaseInfo {
@@ -193,10 +216,7 @@ impl RpcServer for RpcServerImpl {
storage_path: Some(self.base_dir.clone()), storage_path: Some(self.base_dir.clone()),
size_on_disk, size_on_disk,
key_count: None, // Would need to open DB to count keys key_count: None, // Would need to open DB to count keys
created_at: std::time::SystemTime::now() created_at,
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
last_access: None, last_access: None,
}) })
} }

View File

@@ -167,7 +167,7 @@ impl Server {
} }
pub async fn handle( pub async fn handle(
&mut self, server: Arc<Mutex<Server>>,
mut stream: tokio::net::TcpStream, mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> { ) -> Result<(), DBError> {
// Accumulate incoming bytes to handle partial RESP frames // Accumulate incoming bytes to handle partial RESP frames
@@ -205,31 +205,49 @@ impl Server {
// Advance the accumulator to the unparsed remainder // Advance the accumulator to the unparsed remainder
acc = remaining.to_string(); acc = remaining.to_string();
if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
} else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
}
// Check if this is a QUIT command before processing // Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit); let is_quit = matches!(cmd, Cmd::Quit);
let res = match cmd.run(self).await { // Lock the server only for command execution
Ok(p) => p, let (res, debug_info) = {
Err(e) => { let mut server_guard = server.lock().await;
if self.option.debug {
eprintln!("[run error] {:?}", e); if server_guard.option.debug {
} println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
Protocol::err(&format!("ERR {}", e.0)) } else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
} }
let res = match cmd.run(&mut server_guard).await {
Ok(p) => p,
Err(e) => {
if server_guard.option.debug {
eprintln!("[run error] {:?}", e);
}
Protocol::err(&format!("ERR {}", e.0))
}
};
let debug_info = if server_guard.option.debug {
Some((format!("queued cmd {:?}", server_guard.queued_cmd), format!("going to send response {}", res.encode())))
} else {
Some((format!("queued cmd {:?}", server_guard.queued_cmd), format!("going to send response {}", res.encode())))
};
(res, debug_info)
}; };
if self.option.debug { // Print debug info outside the lock
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd); if let Some((queued_info, response_info)) = debug_info {
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode()); if let Some((_, response)) = response_info.split_once("going to send response ") {
} else { if queued_info.contains("\x1b[34;1m") {
print!("queued cmd {:?}", self.queued_cmd); println!("\x1b[34;1m{}\x1b[0m", queued_info);
println!("going to send response {}", res.encode()); println!("\x1b[32;1mgoing to send response {}\x1b[0m", response);
} else {
println!("{}", queued_info);
println!("going to send response {}", response);
}
}
} }
_ = stream.write(res.encode().as_bytes()).await?; _ = stream.write(res.encode().as_bytes()).await?;

View File

@@ -1,4 +1,6 @@
use herodb::{server::Server, options::DBOption}; use herodb::{server::Server, options::DBOption};
use std::sync::Arc;
use tokio::sync::Mutex;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -29,17 +31,20 @@ async fn debug_hset_simple() {
encryption_key: None, encryption_key: None,
}; };
let mut server = Server::new(option).await; let server = Arc::new(Mutex::new(Server::new(option).await));
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let server_clone = Arc::clone(&server);
tokio::spawn(async move {
let _ = Server::handle(server_clone, stream).await;
});
} }
} }
}); });

View File

@@ -3,6 +3,8 @@ use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::sleep; use tokio::time::sleep;
use std::sync::Arc;
use tokio::sync::Mutex;
#[tokio::test] #[tokio::test]
async fn debug_hset_return_value() { async fn debug_hset_return_value() {
@@ -20,17 +22,20 @@ async fn debug_hset_return_value() {
encryption_key: None, encryption_key: None,
}; };
let mut server = Server::new(option).await; let server = Arc::new(Mutex::new(Server::new(option).await));
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind("127.0.0.1:16390") let listener = tokio::net::TcpListener::bind("127.0.0.1:16390")
.await .await
.unwrap(); .unwrap();
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let server_clone = Arc::clone(&server);
tokio::spawn(async move {
let _ = Server::handle(server_clone, stream).await;
});
} }
} }
}); });

View File

@@ -1,21 +1,23 @@
use herodb::{server::Server, options::DBOption}; use herodb::{server::Server, options::DBOption};
use std::sync::Arc;
use tokio::sync::Mutex;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::sleep; use tokio::time::sleep;
// Helper function to start a test server // Helper function to start a test server
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Arc<Mutex<Server>>, u16) {
use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379); static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name); let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up and create test directory // Clean up and create test directory
let _ = std::fs::remove_dir_all(&test_dir); let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
dir: test_dir, dir: test_dir,
port, port,
@@ -23,8 +25,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false, encrypt: false,
encryption_key: None, encryption_key: None,
}; };
let server = Server::new(option).await; let server = Arc::new(Mutex::new(Server::new(option).await));
(server, port) (server, port)
} }
@@ -54,7 +56,7 @@ async fn send_command(stream: &mut TcpStream, command: &str) -> String {
#[tokio::test] #[tokio::test]
async fn test_basic_ping() { async fn test_basic_ping() {
let (mut server, port) = start_test_server("ping").await; let (server, port) = start_test_server("ping").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -64,7 +66,7 @@ async fn test_basic_ping() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -78,7 +80,7 @@ async fn test_basic_ping() {
#[tokio::test] #[tokio::test]
async fn test_string_operations() { async fn test_string_operations() {
let (mut server, port) = start_test_server("string").await; let (server, port) = start_test_server("string").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -88,7 +90,7 @@ async fn test_string_operations() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -120,7 +122,7 @@ async fn test_string_operations() {
#[tokio::test] #[tokio::test]
async fn test_incr_operations() { async fn test_incr_operations() {
let (mut server, port) = start_test_server("incr").await; let (server, port) = start_test_server("incr").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
@@ -129,7 +131,7 @@ async fn test_incr_operations() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -154,7 +156,7 @@ async fn test_incr_operations() {
#[tokio::test] #[tokio::test]
async fn test_hash_operations() { async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash").await; let (server, port) = start_test_server("hash").await;
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
@@ -163,7 +165,7 @@ async fn test_hash_operations() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -229,7 +231,7 @@ async fn test_expiration() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -277,7 +279,7 @@ async fn test_scan_operations() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -313,7 +315,7 @@ async fn test_hscan_operations() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -345,7 +347,7 @@ async fn test_transaction_operations() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -388,7 +390,7 @@ async fn test_discard_transaction() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -425,7 +427,7 @@ async fn test_type_command() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -460,7 +462,7 @@ async fn test_config_commands() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -491,7 +493,7 @@ async fn test_info_command() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -520,7 +522,7 @@ async fn test_error_handling() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -558,7 +560,7 @@ async fn test_list_operations() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });

View File

@@ -1,23 +1,25 @@
use herodb::{server::Server, options::DBOption}; use herodb::{server::Server, options::DBOption};
use std::sync::Arc;
use tokio::sync::Mutex;
use std::time::Duration; use std::time::Duration;
use tokio::time::sleep; use tokio::time::sleep;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
// Helper function to start a test server with clean data directory // Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Arc<Mutex<Server>>, u16) {
use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000); static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000);
// Get a unique port for this test // Get a unique port for this test
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name); let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up any existing test data // Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir); let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
dir: test_dir, dir: test_dir,
port, port,
@@ -25,8 +27,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false, encrypt: false,
encryption_key: None, encryption_key: None,
}; };
let server = Server::new(option).await; let server = Arc::new(Mutex::new(Server::new(option).await));
(server, port) (server, port)
} }
@@ -42,7 +44,7 @@ async fn send_redis_command(port: u16, command: &str) -> String {
#[tokio::test] #[tokio::test]
async fn test_basic_redis_functionality() { async fn test_basic_redis_functionality() {
let (mut server, port) = start_test_server("basic").await; let (server, port) = start_test_server("basic").await;
// Start server in background with timeout // Start server in background with timeout
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
@@ -53,7 +55,7 @@ async fn test_basic_redis_functionality() {
// Accept only a few connections for testing // Accept only a few connections for testing
for _ in 0..10 { for _ in 0..10 {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -111,7 +113,7 @@ async fn test_basic_redis_functionality() {
#[tokio::test] #[tokio::test]
async fn test_hash_operations() { async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash_ops").await; let (server, port) = start_test_server("hash_ops").await;
// Start server in background with timeout // Start server in background with timeout
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
@@ -122,7 +124,7 @@ async fn test_hash_operations() {
// Accept only a few connections for testing // Accept only a few connections for testing
for _ in 0..5 { for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });
@@ -165,7 +167,7 @@ async fn test_hash_operations() {
#[tokio::test] #[tokio::test]
async fn test_transaction_operations() { async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transactions").await; let (server, port) = start_test_server("transactions").await;
// Start server in background with timeout // Start server in background with timeout
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
@@ -176,7 +178,7 @@ async fn test_transaction_operations() {
// Accept only a few connections for testing // Accept only a few connections for testing
for _ in 0..5 { for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let _ = Server::handle(Arc::clone(&server), stream).await;
} }
} }
}); });

View File

@@ -1,21 +1,23 @@
use herodb::{server::Server, options::DBOption}; use herodb::{server::Server, options::DBOption};
use std::sync::Arc;
use tokio::sync::Mutex;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::sleep; use tokio::time::sleep;
// Helper function to start a test server with clean data directory // Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Arc<Mutex<Server>>, u16) {
use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500); static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_simple_test_{}", test_name); let test_dir = format!("/tmp/herodb_simple_test_{}", test_name);
// Clean up any existing test data // Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir); let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap(); std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption { let option = DBOption {
dir: test_dir, dir: test_dir,
port, port,
@@ -23,8 +25,8 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false, encrypt: false,
encryption_key: None, encryption_key: None,
}; };
let server = Server::new(option).await; let server = Arc::new(Mutex::new(Server::new(option).await));
(server, port) (server, port)
} }
@@ -54,7 +56,7 @@ async fn connect_to_server(port: u16) -> TcpStream {
#[tokio::test] #[tokio::test]
async fn test_basic_ping_simple() { async fn test_basic_ping_simple() {
let (mut server, port) = start_test_server("ping").await; let (server, port) = start_test_server("ping").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -64,7 +66,8 @@ async fn test_basic_ping_simple() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let server_clone = Arc::clone(&server);
let _ = Server::handle(server_clone, stream).await;
} }
} }
}); });
@@ -78,7 +81,7 @@ async fn test_basic_ping_simple() {
#[tokio::test] #[tokio::test]
async fn test_hset_clean_db() { async fn test_hset_clean_db() {
let (mut server, port) = start_test_server("hset_clean").await; let (server, port) = start_test_server("hset_clean").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -88,7 +91,8 @@ async fn test_hset_clean_db() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let server_clone = Arc::clone(&server);
let _ = Server::handle(server_clone, stream).await;
} }
} }
}); });
@@ -110,7 +114,7 @@ async fn test_hset_clean_db() {
#[tokio::test] #[tokio::test]
async fn test_type_command_simple() { async fn test_type_command_simple() {
let (mut server, port) = start_test_server("type").await; let (server, port) = start_test_server("type").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -120,7 +124,8 @@ async fn test_type_command_simple() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let server_clone = Arc::clone(&server);
let _ = Server::handle(server_clone, stream).await;
} }
} }
}); });
@@ -149,7 +154,7 @@ async fn test_type_command_simple() {
#[tokio::test] #[tokio::test]
async fn test_hexists_simple() { async fn test_hexists_simple() {
let (mut server, port) = start_test_server("hexists").await; let (server, port) = start_test_server("hexists").await;
// Start server in background // Start server in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -159,7 +164,8 @@ async fn test_hexists_simple() {
loop { loop {
if let Ok((stream, _)) = listener.accept().await { if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await; let server_clone = Arc::clone(&server);
let _ = Server::handle(server_clone, stream).await;
} }
} }
}); });

View File

@@ -2,12 +2,14 @@ use herodb::{options::DBOption, server::Server};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
use std::sync::Arc;
use tokio::sync::Mutex;
// ========================= // =========================
// Helpers // Helpers
// ========================= // =========================
async fn start_test_server(test_name: &str) -> (Server, u16) { async fn start_test_server(test_name: &str) -> (Arc<Mutex<Server>>, u16) {
use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17100); static PORT_COUNTER: AtomicU16 = AtomicU16::new(17100);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst); let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
@@ -24,11 +26,11 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encryption_key: None, encryption_key: None,
}; };
let server = Server::new(option).await; let server = Arc::new(Mutex::new(Server::new(option).await));
(server, port) (server, port)
} }
async fn spawn_listener(server: Server, port: u16) { async fn spawn_listener(server: Arc<Mutex<Server>>, port: u16) {
tokio::spawn(async move { tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await .await
@@ -36,9 +38,9 @@ async fn spawn_listener(server: Server, port: u16) {
loop { loop {
match listener.accept().await { match listener.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
let mut s_clone = server.clone(); let server_clone = Arc::clone(&server);
tokio::spawn(async move { tokio::spawn(async move {
let _ = s_clone.handle(stream).await; let _ = Server::handle(server_clone, stream).await;
}); });
} }
Err(_e) => break, Err(_e) => break,