Compare commits

..

31 Commits

Author SHA1 Message Date
Maxime Van Hees
892e6e2b90 Implemented EXPIREAT and PEXPIREAT 2025-08-19 16:21:43 +02:00
Maxime Van Hees
b9a9f3e6d6 Implemented DBSIZE 2025-08-19 16:05:25 +02:00
Maxime Van Hees
463000c8f7 Implemented BRPOP and minimal COMMAND DOCS stub, and wired side-aware waiter delivery 2025-08-19 15:52:36 +02:00
Maxime Van Hees
a92c90e9cb implemented HINCRBY/HINCRBYFLOAT + fixed partial-frame handling bug causing sporadic protocol parsing errors when sending or receiving large bulk strings (AGE ciphertext/signature) (happend because TCP segmentation can split a single RESP frame; both client & server assumed a single read would contain the whole frame) 2025-08-19 15:36:07 +02:00
Maxime Van Hees
34808fc1c9 Implemented EXPIRE/PEXPIRE/PERSIST 2025-08-19 11:34:04 +02:00
Maxime Van Hees
b644bf873f implement MGET/MSET and variadic DEL/EXISTS 2025-08-19 11:16:00 +02:00
Maxime Van Hees
a306544a34 implement COMMAND 2025-08-18 16:21:49 +02:00
Maxime Van Hees
afa1033cd6 implement BLPOP 2025-08-18 16:13:46 +02:00
9177fa4091 ... 2025-08-18 12:30:20 +02:00
51ab90c4ad ... 2025-08-16 18:24:46 +02:00
30a09e6d53 ... 2025-08-16 13:58:40 +02:00
542996a0ff ... 2025-08-16 13:33:56 +02:00
63ab39b4b1 ... 2025-08-16 11:22:01 +02:00
ee94d731d7 ... 2025-08-16 11:09:18 +02:00
c7945624bd ... 2025-08-16 10:53:48 +02:00
f8dd304820 it works 2025-08-16 10:41:26 +02:00
5eab3b080c ... 2025-08-16 10:28:28 +02:00
246304b9fa ... 2025-08-16 10:10:24 +02:00
074be114c3 ... 2025-08-16 09:55:34 +02:00
e51af83e45 ... 2025-08-16 09:52:36 +02:00
dbd0635cd9 ... 2025-08-16 09:50:56 +02:00
0000d82799 ... 2025-08-16 09:29:18 +02:00
5502ff4bc5 ... 2025-08-16 09:06:33 +02:00
0511dddd99 ... 2025-08-16 08:50:28 +02:00
bec9b20ec7 ... 2025-08-16 08:41:19 +02:00
ad255a9f51 ... 2025-08-16 08:28:52 +02:00
7bcb673361 ... 2025-08-16 08:25:25 +02:00
0f6e595000 ... 2025-08-16 07:54:55 +02:00
d3e28cafe4 ... 2025-08-16 07:23:20 +02:00
de2be4a785 ... 2025-08-16 07:18:55 +02:00
cd61406d1d ... 2025-08-16 06:58:04 +02:00
43 changed files with 10158 additions and 2 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
* text=auto

12
.gitignore vendored Normal file
View File

@@ -0,0 +1,12 @@
# Generated by Cargo
# will have compiled files and executables
debug/
target/
.vscode/
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
dumb.rdb

2061
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

12
Cargo.toml Normal file
View File

@@ -0,0 +1,12 @@
[workspace]
members = [
"herodb",
"supervisor",
]
resolver = "2"
# You can define shared profiles for all workspace members here
[profile.release]
lto = true
codegen-units = 1
strip = true

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Pin Fang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,2 +0,0 @@
# herodb

28
herodb/Cargo.toml Normal file
View File

@@ -0,0 +1,28 @@
[package]
name = "herodb"
version = "0.0.1"
authors = ["Pin Fang <fpfangpin@hotmail.com>"]
edition = "2021"
[dependencies]
anyhow = "1.0.59"
bytes = "1.3.0"
thiserror = "1.0.32"
tokio = { version = "1.23.0", features = ["full"] }
clap = { version = "4.5.20", features = ["derive"] }
byteorder = "1.4.3"
futures = "0.3"
redb = "2.1.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
bincode = "1.3.3"
chacha20poly1305 = "0.10.1"
rand = "0.8"
sha2 = "0.10"
age = "0.10"
secrecy = "0.8"
ed25519-dalek = "2"
base64 = "0.22"
[dev-dependencies]
redis = { version = "0.24", features = ["aio", "tokio-comp"] }

0
herodb/README.md Normal file
View File

9
herodb/build.sh Executable file
View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -euo pipefail
export SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
echo "I am in $SCRIPT_DIR"
cd "$SCRIPT_DIR"
cargo build

View File

@@ -0,0 +1,71 @@
#!/bin/bash
# Start the herodb server in the background
echo "Starting herodb server..."
cargo run -p herodb -- --dir /tmp/herodb_age_test --port 6382 --debug --encryption-key "testkey" &
SERVER_PID=$!
sleep 2 # Give the server a moment to start
REDIS_CLI="redis-cli -p 6382"
echo "--- Generating and Storing Encryption Keys ---"
# The new AGE commands are 'AGE KEYGEN <name>' etc., based on src/cmd.rs
# This script uses older commands like 'AGE.GENERATE_KEYPAIR alice'
# The demo script needs to be updated to match the implemented commands.
# Let's assume the commands in the script are what's expected for now,
# but note this discrepancy. The new commands are AGE KEYGEN etc.
# The script here uses a different syntax not found in src/cmd.rs like 'AGE.GENERATE_KEYPAIR'.
# For now, I will modify the script to fit the actual implementation.
echo "--- Generating and Storing Encryption Keys ---"
$REDIS_CLI AGE KEYGEN alice
$REDIS_CLI AGE KEYGEN bob
echo "--- Encrypting and Decrypting a Message ---"
MESSAGE="Hello, AGE encryption!"
# The new logic stores keys internally and does not expose a command to get the public key.
# We will encrypt by name.
ALICE_PUBKEY_REPLY=$($REDIS_CLI AGE KEYGEN alice | head -n 2 | tail -n 1)
echo "Alice's Public Key: $ALICE_PUBKEY_REPLY"
echo "Encrypting message: '$MESSAGE' with Alice's identity..."
# AGE.ENCRYPT recipient message. But since we use persistent keys, let's use ENCRYPTNAME
CIPHERTEXT=$($REDIS_CLI AGE ENCRYPTNAME alice "$MESSAGE")
echo "Ciphertext: $CIPHERTEXT"
echo "Decrypting ciphertext with Alice's private key..."
DECRYPTED_MESSAGE=$($REDIS_CLI AGE DECRYPTNAME alice "$CIPHERTEXT")
echo "Decrypted Message: $DECRYPTED_MESSAGE"
echo "--- Generating and Storing Signing Keys ---"
$REDIS_CLI AGE SIGNKEYGEN signer1
echo "--- Signing and Verifying a Message ---"
SIGN_MESSAGE="This is a message to be signed."
# Similar to above, we don't have GET_SIGN_PUBKEY. We will verify by name.
echo "Signing message: '$SIGN_MESSAGE' with signer1's private key..."
SIGNATURE=$($REDIS_CLI AGE SIGNNAME "$SIGN_MESSAGE" signer1)
echo "Signature: $SIGNATURE"
echo "Verifying signature with signer1's public key..."
VERIFY_RESULT=$($REDIS_CLI AGE VERIFYNAME signer1 "$SIGN_MESSAGE" "$SIGNATURE")
echo "Verification Result: $VERIFY_RESULT"
# There is no DELETE_KEYPAIR command in the implementation
echo "--- Cleaning up keys (manual in herodb) ---"
# We would use DEL for age:key:alice, etc.
$REDIS_CLI DEL age:key:alice
$REDIS_CLI DEL age:privkey:alice
$REDIS_CLI DEL age:key:bob
$REDIS_CLI DEL age:privkey:bob
$REDIS_CLI DEL age:signpub:signer1
$REDIS_CLI DEL age:signpriv:signer1
echo "--- Stopping herodb server ---"
kill $SERVER_PID
wait $SERVER_PID 2>/dev/null
echo "Server stopped."
echo "Bash demo complete."

View File

@@ -0,0 +1,83 @@
use std::io::{Read, Write};
use std::net::TcpStream;
// Minimal RESP helpers
fn arr(parts: &[&str]) -> String {
let mut out = format!("*{}\r\n", parts.len());
for p in parts {
out.push_str(&format!("${}\r\n{}\r\n", p.len(), p));
}
out
}
fn read_reply(s: &mut TcpStream) -> String {
let mut buf = [0u8; 65536];
let n = s.read(&mut buf).unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}
fn parse_two_bulk(reply: &str) -> Option<(String,String)> {
let mut lines = reply.split("\r\n");
if lines.next()? != "*2" { return None; }
let _n = lines.next()?;
let a = lines.next()?.to_string();
let _m = lines.next()?;
let b = lines.next()?.to_string();
Some((a,b))
}
fn parse_bulk(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('$') { return None; }
Some(lines.next()?.to_string())
}
fn parse_simple(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('+') { return None; }
Some(hdr[1..].to_string())
}
fn main() {
let mut args = std::env::args().skip(1);
let host = args.next().unwrap_or_else(|| "127.0.0.1".into());
let port = args.next().unwrap_or_else(|| "6379".into());
let addr = format!("{host}:{port}");
println!("Connecting to {addr}...");
let mut s = TcpStream::connect(addr).expect("connect");
// Generate & persist X25519 enc keys under name "alice"
s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap();
let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc");
// Generate & persist Ed25519 signing key under name "signer"
s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap();
let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign");
// Encrypt by name
let msg = "hello from persistent keys";
s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap();
let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64");
println!("ciphertext b64: {}", ct_b64);
// Decrypt by name
s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap();
let pt = parse_bulk(&read_reply(&mut s)).expect("pt");
assert_eq!(pt, msg);
println!("decrypted ok");
// Sign by name
s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap();
let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64");
// Verify by name
s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap();
let ok = parse_simple(&read_reply(&mut s)).expect("verify");
assert_eq!(ok, "1");
println!("signature verified");
// List names
s.write_all(arr(&["age","list"]).as_bytes()).unwrap();
let list = read_reply(&mut s);
println!("LIST -> {list}");
println!("✔ persistent AGE workflow complete.");
}

View File

@@ -0,0 +1,99 @@
### Cargo.toml
```toml
[dependencies]
chacha20poly1305 = { version = "0.10", features = ["xchacha20"] }
rand = "0.8"
sha2 = "0.10"
```
### `crypto_factory.rs`
```rust
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
XChaCha20Poly1305, Key, XNonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
const VERSION: u8 = 1;
const NONCE_LEN: usize = 24;
const TAG_LEN: usize = 16;
#[derive(Debug)]
pub enum CryptoError {
Format, // wrong length / header
Version(u8), // unknown version
Decrypt, // wrong key or corrupted data
}
/// Super-simple factory: new(secret) + encrypt(bytes) + decrypt(bytes)
pub struct CryptoFactory {
key: Key<XChaCha20Poly1305>,
}
impl CryptoFactory {
/// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
/// (If your secret is already 32 bytes, this is still fine.)
pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
let mut h = Sha256::new();
h.update(b"xchacha20poly1305-factory:v1"); // domain separation
h.update(secret.as_ref());
let digest = h.finalize(); // 32 bytes
let key = Key::<XChaCha20Poly1305>::from_slice(&digest).to_owned();
Self { key }
}
/// Output layout: [version:1][nonce:24][ciphertext||tag]
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
let cipher = XChaCha20Poly1305::new(&self.key);
let mut nonce_bytes = [0u8; NONCE_LEN];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = XNonce::from_slice(&nonce_bytes);
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
out.push(VERSION);
out.extend_from_slice(&nonce_bytes);
let ct = cipher.encrypt(nonce, plaintext).expect("encrypt");
out.extend_from_slice(&ct);
out
}
pub fn decrypt(&self, blob: &[u8]) -> Result<Vec<u8>, CryptoError> {
if blob.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(CryptoError::Format);
}
let ver = blob[0];
if ver != VERSION {
return Err(CryptoError::Version(ver));
}
let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
let ct = &blob[1 + NONCE_LEN..];
let cipher = XChaCha20Poly1305::new(&self.key);
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
}
}
```
### Tiny usage example
```rust
fn main() {
let f = CryptoFactory::new(b"super-secret-key-material");
let val = b"\x00\xFFbinary\x01\x02\x03";
let blob = f.encrypt(val);
let roundtrip = f.decrypt(&blob).unwrap();
assert_eq!(roundtrip, val);
}
```
Thats it: `new(secret)`, `encrypt(bytes)`, `decrypt(bytes)`.
You can stash the returned `blob` directly in your storage layer behind Redis.

View File

@@ -0,0 +1,80 @@
========================
CODE SNIPPETS
========================
TITLE: 1PC+C Commit Strategy Vulnerability Example
DESCRIPTION: Illustrates a scenario where a partially committed transaction might appear complete due to the non-cryptographic checksum (XXH3) used in the 1PC+C commit strategy. This requires controlling page flush order, introducing a crash during fsync, and ensuring valid checksums for partially written data.
SOURCE: https://github.com/cberner/redb/blob/master/docs/design.md#_snippet_9
LANGUAGE: rust
CODE:
```
table.insert(malicious_key, malicious_value);
table.insert(good_key, good_value);
txn.commit();
```
LANGUAGE: rust
CODE:
```
table.insert(malicious_key, malicious_value);
txn.commit();
```
----------------------------------------
TITLE: Basic Key-Value Operations in redb
DESCRIPTION: Demonstrates the fundamental usage of redb for creating a database, opening a table, inserting a key-value pair, and retrieving the value within separate read and write transactions.
SOURCE: https://github.com/cberner/redb/blob/master/README.md#_snippet_0
LANGUAGE: rust
CODE:
```
use redb::{Database, Error, ReadableTable, TableDefinition};
const TABLE: TableDefinition<&str, u64> = TableDefinition::new("my_data");
fn main() -> Result<(), Error> {
let db = Database::create("my_db.redb")?;
let write_txn = db.begin_write()?;
{
let mut table = write_txn.open_table(TABLE)?;
table.insert("my_key", &123)?;
}
write_txn.commit()?;
let read_txn = db.begin_read()?;
let table = read_txn.open_table(TABLE)?;
assert_eq!(table.get("my_key")?.unwrap().value(), 123);
Ok(())
}
```
## What *redb* currently supports:
* Simple operations like creating databases, inserting key-value pairs, opening and reading tables ([GitHub][1]).
* No mention of operations such as:
* Iterating over keys with a given prefix.
* Range queries based on string prefixes.
* Specialized prefixfiltered lookups.
## implement range scans as follows
You can implement prefix-like functionality using **range scans** combined with manual checks, similar to using a `BTreeSet` in Rust:
```rust
for key in table.range(prefix..).keys() {
if !key.starts_with(prefix) {
break;
}
// process key
}
```
This pattern iterates keys starting at the prefix, and stops once a key no longer matches the prefix—this works because the keys are sorted ([GitHub][1]).

View File

@@ -0,0 +1,150 @@
]
# INFO
**What it does**
Returns server stats in a human-readable text block, optionally filtered by sections. Typical sections: `server`, `clients`, `memory`, `persistence`, `stats`, `replication`, `cpu`, `commandstats`, `latencystats`, `cluster`, `modules`, `keyspace`, `errorstats`. Special args: `all`, `default`, `everything`. The reply is a **Bulk String** with `# <Section>` headers and `key:value` lines. ([Redis][1])
**Syntax**
```
INFO [section [section ...]]
```
**Return (RESP2/RESP3)**: Bulk String. ([Redis][1])
**RESP request/response**
```
# Request: whole default set
*1\r\n$4\r\nINFO\r\n
# Request: a specific section, e.g., clients
*2\r\n$4\r\nINFO\r\n$7\r\nclients\r\n
# Response (prefix shown; body is long)
$1234\r\n# Server\r\nredis_version:7.4.0\r\n...\r\n# Clients\r\nconnected_clients:3\r\n...\r\n
```
(Reply type/format per RESP spec and the INFO page.) ([Redis][2])
---
# Connection “name” (there is **no** top-level `NAME` command)
Redis doesnt have a standalone `NAME` command. Connection names are handled via `CLIENT SETNAME` and retrieved via `CLIENT GETNAME`. ([Redis][3])
## CLIENT SETNAME
Assigns a human label to the current connection (shown in `CLIENT LIST`, logs, etc.). No spaces allowed in the name; empty string clears it. Length is limited by Redis string limits (practically huge). **Reply**: Simple String `OK`. ([Redis][4])
**Syntax**
```
CLIENT SETNAME connection-name
```
**RESP**
```
# Set the name "myapp"
*3\r\n$6\r\nCLIENT\r\n$7\r\nSETNAME\r\n$5\r\nmyapp\r\n
# Reply
+OK\r\n
```
## CLIENT GETNAME
Returns the current connections name or **Null Bulk String** if unset. ([Redis][5])
**Syntax**
```
CLIENT GETNAME
```
**RESP**
```
# Before SETNAME:
*2\r\n$6\r\nCLIENT\r\n$7\r\nGETNAME\r\n
$-1\r\n # nil (no name)
# After SETNAME myapp:
*2\r\n$6\r\nCLIENT\r\n$7\r\nGETNAME\r\n
$5\r\nmyapp\r\n
```
(Null/Bulk String encoding per RESP spec.) ([Redis][2])
---
# CLIENT (container command + key subcommands)
`CLIENT` is a **container**; use subcommands like `CLIENT LIST`, `CLIENT INFO`, `CLIENT ID`, `CLIENT KILL`, `CLIENT TRACKING`, etc. Call `CLIENT HELP` to enumerate them. ([Redis][3])
## CLIENT LIST
Shows all connections as a single **Bulk String**: one line per client with `field=value` pairs (includes `id`, `addr`, `name`, `db`, `user`, `resp`, and more). Filters: `TYPE` and `ID`. **Return**: Bulk String (RESP2/RESP3). ([Redis][6])
**Syntax**
```
CLIENT LIST [TYPE <NORMAL|MASTER|REPLICA|PUBSUB>] [ID client-id ...]
```
**RESP**
```
*2\r\n$6\r\nCLIENT\r\n$4\r\nLIST\r\n
# Reply (single Bulk String; example with one line shown)
$188\r\nid=7 addr=127.0.0.1:60840 laddr=127.0.0.1:6379 fd=8 name=myapp age=12 idle=3 flags=N db=0 ...\r\n
```
## CLIENT INFO
Returns info for **this** connection only (same format/fields as a single line of `CLIENT LIST`). **Return**: Bulk String. Available since 6.2.0. ([Redis][7])
**Syntax**
```
CLIENT INFO
```
**RESP**
```
*2\r\n$6\r\nCLIENT\r\n$4\r\nINFO\r\n
$160\r\nid=7 addr=127.0.0.1:60840 laddr=127.0.0.1:6379 fd=8 name=myapp db=0 user=default resp=2 ...\r\n
```
---
# RESP notes youll need for your parser
* **Requests** are Arrays: `*N\r\n` followed by `N` Bulk Strings for verb/args.
* **Common replies here**: Simple String (`+OK\r\n`), Bulk String (`$<len>\r\n...\r\n`), and **Null Bulk String** (`$-1\r\n`). (These cover `INFO`, `CLIENT LIST/INFO`, `CLIENT GETNAME`, `CLIENT SETNAME`.) ([Redis][2])
---
## Sources (checked)
* INFO command (syntax, sections, behavior). ([Redis][1])
* RESP spec (request/response framing, Bulk/Null Bulk Strings). ([Redis][2])
* CLIENT container + subcommands index. ([Redis][3])
* CLIENT LIST (fields, bulk-string return, filters). ([Redis][6])
* CLIENT INFO (exists since 6.2, reply format). ([Redis][7])
* CLIENT SETNAME (no spaces; clears with empty string; huge length OK). ([Redis][4])
* CLIENT GETNAME (nil if unset). ([Redis][5])
If you want, I can fold this into a tiny Rust “command + RESP” test harness that exercises `INFO`, `CLIENT SETNAME/GETNAME`, `CLIENT LIST`, and `CLIENT INFO` against your in-mem RESP parser.
[1]: https://redis.io/docs/latest/commands/info/ "INFO | Docs"
[2]: https://redis.io/docs/latest/develop/reference/protocol-spec/?utm_source=chatgpt.com "Redis serialization protocol specification | Docs"
[3]: https://redis.io/docs/latest/commands/client/ "CLIENT | Docs"
[4]: https://redis.io/docs/latest/commands/client-setname/?utm_source=chatgpt.com "CLIENT SETNAME | Docs"
[5]: https://redis.io/docs/latest/commands/client-getname/?utm_source=chatgpt.com "CLIENT GETNAME | Docs"
[6]: https://redis.io/docs/latest/commands/client-list/ "CLIENT LIST | Docs"
[7]: https://redis.io/docs/latest/commands/client-info/?utm_source=chatgpt.com "CLIENT INFO | Docs"

View File

@@ -0,0 +1,251 @@
Got it 👍 — lets break this down properly.
Redis has two broad classes youre asking about:
1. **Basic key-space functions** (SET, GET, DEL, EXISTS, etc.)
2. **Iteration commands** (`SCAN`, `SSCAN`, `HSCAN`, `ZSCAN`)
And for each Ill show:
* What it does
* How it works at a high level
* Its **RESP protocol implementation** (the actual wire format).
---
# 1. Basic Key-Space Commands
### `SET key value`
* Stores a string value at a key.
* Overwrites if the key already exists.
**Protocol (RESP2):**
```
*3
$3
SET
$3
foo
$3
bar
```
(client sends: array of 3 bulk strings: `["SET", "foo", "bar"]`)
**Reply:**
```
+OK
```
---
### `GET key`
* Retrieves the string value stored at the key.
* Returns `nil` if key doesnt exist.
**Protocol:**
```
*2
$3
GET
$3
foo
```
**Reply:**
```
$3
bar
```
(or `$-1` for nil)
---
### `DEL key [key ...]`
* Removes one or more keys.
* Returns number of keys actually removed.
**Protocol:**
```
*2
$3
DEL
$3
foo
```
**Reply:**
```
:1
```
(integer reply = number of deleted keys)
---
### `EXISTS key [key ...]`
* Checks if one or more keys exist.
* Returns count of existing keys.
**Protocol:**
```
*2
$6
EXISTS
$3
foo
```
**Reply:**
```
:1
```
---
### `KEYS pattern`
* Returns all keys matching a glob-style pattern.
⚠️ Not efficient in production (O(N)), better to use `SCAN`.
**Protocol:**
```
*2
$4
KEYS
$1
*
```
**Reply:**
```
*2
$3
foo
$3
bar
```
(array of bulk strings with key names)
---
# 2. Iteration Commands (`SCAN` family)
### `SCAN cursor [MATCH pattern] [COUNT n]`
* Iterates the keyspace incrementally.
* Client keeps sending back the cursor from previous call until it returns `0`.
**Protocol example:**
```
*2
$4
SCAN
$1
0
```
**Reply:**
```
*2
$1
0
*2
$3
foo
$3
bar
```
Explanation:
* First element = new cursor (`"0"` means iteration finished).
* Second element = array of keys returned in this batch.
---
### `HSCAN key cursor [MATCH pattern] [COUNT n]`
* Like `SCAN`, but iterates fields of a hash.
**Protocol:**
```
*3
$5
HSCAN
$3
myh
$1
0
```
**Reply:**
```
*2
$1
0
*4
$5
field
$5
value
$5
age
$2
42
```
(Array of alternating field/value pairs)
---
### `SSCAN key cursor [MATCH pattern] [COUNT n]`
* Iterates members of a set.
Protocol and reply structure same as SCAN.
---
### `ZSCAN key cursor [MATCH pattern] [COUNT n]`
* Iterates members of a sorted set with scores.
* Returns alternating `member`, `score`.
---
# Quick Comparison
| Command | Purpose | Return Type |
| -------- | ----------------------------- | --------------------- |
| `SET` | Store a string value | Simple string `+OK` |
| `GET` | Retrieve a string value | Bulk string / nil |
| `DEL` | Delete keys | Integer (count) |
| `EXISTS` | Check existence | Integer (count) |
| `KEYS` | List all matching keys (slow) | Array of bulk strings |
| `SCAN` | Iterate over keys (safe) | `[cursor, array]` |
| `HSCAN` | Iterate over hash fields | `[cursor, array]` |
| `SSCAN` | Iterate over set members | `[cursor, array]` |
| `ZSCAN` | Iterate over sorted set | `[cursor, array]` |
##

View File

@@ -0,0 +1,307 @@
# 🔑 Redis `HSET` and Related Hash Commands
## 1. `HSET`
* **Purpose**: Set the value of one or more fields in a hash.
* **Syntax**:
```bash
HSET key field value [field value ...]
```
* **Return**:
* Integer: number of fields that were newly added.
* **RESP Protocol**:
```
*4
$4
HSET
$3
key
$5
field
$5
value
```
(If multiple field-value pairs: `*6`, `*8`, etc.)
---
## 2. `HSETNX`
* **Purpose**: Set the value of a hash field only if it does **not** exist.
* **Syntax**:
```bash
HSETNX key field value
```
* **Return**:
* `1` if field was set.
* `0` if field already exists.
* **RESP Protocol**:
```
*4
$6
HSETNX
$3
key
$5
field
$5
value
```
---
## 3. `HGET`
* **Purpose**: Get the value of a hash field.
* **Syntax**:
```bash
HGET key field
```
* **Return**:
* Bulk string (value) or `nil` if field does not exist.
* **RESP Protocol**:
```
*3
$4
HGET
$3
key
$5
field
```
---
## 4. `HGETALL`
* **Purpose**: Get all fields and values in a hash.
* **Syntax**:
```bash
HGETALL key
```
* **Return**:
* Array of `[field1, value1, field2, value2, ...]`.
* **RESP Protocol**:
```
*2
$7
HGETALL
$3
key
```
---
## 5. `HMSET` (⚠️ Deprecated, use `HSET`)
* **Purpose**: Set multiple field-value pairs.
* **Syntax**:
```bash
HMSET key field value [field value ...]
```
* **Return**:
* Always `OK`.
* **RESP Protocol**:
```
*6
$5
HMSET
$3
key
$5
field
$5
value
$5
field2
$5
value2
```
---
## 6. `HMGET`
* **Purpose**: Get values of multiple fields.
* **Syntax**:
```bash
HMGET key field [field ...]
```
* **Return**:
* Array of values (bulk strings or nils).
* **RESP Protocol**:
```
*4
$5
HMGET
$3
key
$5
field1
$5
field2
```
---
## 7. `HDEL`
* **Purpose**: Delete one or more fields from a hash.
* **Syntax**:
```bash
HDEL key field [field ...]
```
* **Return**:
* Integer: number of fields removed.
* **RESP Protocol**:
```
*3
$4
HDEL
$3
key
$5
field
```
---
## 8. `HEXISTS`
* **Purpose**: Check if a field exists.
* **Syntax**:
```bash
HEXISTS key field
```
* **Return**:
* `1` if exists, `0` if not.
* **RESP Protocol**:
```
*3
$7
HEXISTS
$3
key
$5
field
```
---
## 9. `HKEYS`
* **Purpose**: Get all field names in a hash.
* **Syntax**:
```bash
HKEYS key
```
* **Return**:
* Array of field names.
* **RESP Protocol**:
```
*2
$5
HKEYS
$3
key
```
---
## 10. `HVALS`
* **Purpose**: Get all values in a hash.
* **Syntax**:
```bash
HVALS key
```
* **Return**:
* Array of values.
* **RESP Protocol**:
```
*2
$5
HVALS
$3
key
```
---
## 11. `HLEN`
* **Purpose**: Get number of fields in a hash.
* **Syntax**:
```bash
HLEN key
```
* **Return**:
* Integer: number of fields.
* **RESP Protocol**:
```
*2
$4
HLEN
$3
key
```
## 12. `HSCAN`
* **Purpose**: Iterate fields/values of a hash (cursor-based scan).
* **Syntax**:
```bash
HSCAN key cursor [MATCH pattern] [COUNT count]
```
* **Return**:
* Array: `[new-cursor, [field1, value1, ...]]`
* **RESP Protocol**:
```
*3
$5
HSCAN
$3
key
$1
0
```

View File

@@ -0,0 +1,259 @@
# 1) Data model & basics
* A **queue** is a List at key `queue:<name>`.
* Common patterns:
* **Producer**: `LPUSH queue item` (or `RPUSH`)
* **Consumer (non-blocking)**: `RPOP queue` (or `LPOP`)
* **Consumer (blocking)**: `BRPOP queue timeout` (or `BLPOP`)
* If a key doesnt exist, its treated as an **empty list**; push **creates** the list; when the **last element is popped, the key is deleted**. ([Redis][1])
---
# 2) Commands to implement (queues via Lists)
## LPUSH / RPUSH
Prepend/append one or more elements. Create the list if it doesnt exist.
**Return**: Integer = new length of the list.
**Syntax**
```
LPUSH key element [element ...]
RPUSH key element [element ...]
```
**RESP (example)**
```
*3\r\n$5\r\nLPUSH\r\n$5\r\nqueue\r\n$5\r\njob-1\r\n
:1\r\n
```
Refs: semantics & multi-arg ordering. ([Redis][1])
### LPUSHX / RPUSHX (optional but useful)
Like LPUSH/RPUSH, **but only if the list exists**.
**Return**: Integer = new length (0 if key didnt exist).
```
LPUSHX key element [element ...]
RPUSHX key element [element ...]
```
Refs: command index. ([Redis][2])
---
## LPOP / RPOP
Remove & return one (default) or **up to COUNT** elements since Redis 6.2.
If the list is empty or missing, **Null** is returned (Null Bulk or Null Array if COUNT>1).
**Return**:
* No COUNT: Bulk String or Null Bulk.
* With COUNT: Array of Bulk Strings (possibly empty) or Null Array if key missing.
**Syntax**
```
LPOP key [count]
RPOP key [count]
```
**RESP (no COUNT)**
```
*2\r\n$4\r\nRPOP\r\n$5\r\nqueue\r\n
$5\r\njob-1\r\n # or $-1\r\n if empty
```
**RESP (COUNT=2)**
```
*3\r\n$4\r\nLPOP\r\n$5\r\nqueue\r\n$1\r\n2\r\n
*2\r\n$5\r\njob-2\r\n$5\r\njob-3\r\n # or *-1\r\n if key missing
```
Refs: LPOP w/ COUNT; general pop semantics. ([Redis][3])
---
## BLPOP / BRPOP (blocking consumers)
Block until an element is available in any of the given lists or until `timeout` (seconds, **double**, `0` = forever).
**Return** on success: **Array \[key, element]**.
**Return** on timeout: **Null Array**.
**Syntax**
```
BLPOP key [key ...] timeout
BRPOP key [key ...] timeout
```
**RESP**
```
*3\r\n$5\r\nBRPOP\r\n$5\r\nqueue\r\n$1\r\n0\r\n # block forever
# Success reply
*2\r\n$5\r\nqueue\r\n$5\r\njob-4\r\n
# Timeout reply
*-1\r\n
```
**Implementation notes**
* If any listed key is non-empty at call time, reply **immediately** from the first non-empty key **by the commands key order**.
* Otherwise, put the client into a **blocked state** (register per-key waiters). On any `LPUSH/RPUSH` to those keys, **wake the earliest waiter** and serve it atomically.
* If timeout expires, return **Null Array** and clear the blocked state.
Refs: timeout semantics and return shape. ([Redis][4])
---
## LMOVE / BLMOVE (atomic move; replaces RPOPLPUSH/BRPOPLPUSH)
Atomically **pop from one side** of `source` and **push to one side** of `destination`.
* Use for **reliable queues** (move to a *processing* list).
* `BLMOVE` blocks like `BLPOP` when `source` is empty.
**Syntax**
```
LMOVE source destination LEFT|RIGHT LEFT|RIGHT
BLMOVE source destination LEFT|RIGHT LEFT|RIGHT timeout
```
**Return**: Bulk String element moved, or Null if `source` empty (LMOVE); `BLMOVE` blocks/Null on timeout.
**RESP (LMOVE RIGHT->LEFT)**
```
*5\r\n$5\r\nLMOVE\r\n$6\r\nsource\r\n$3\r\ndst\r\n$5\r\nRIGHT\r\n$4\r\nLEFT\r\n
$5\r\njob-5\r\n
```
**Notes**
* Prefer `LMOVE/BLMOVE` over deprecated `RPOPLPUSH/BRPOPLPUSH`.
* Pattern: consumer `LMOVE queue processing RIGHT LEFT` → work → `LREM processing 1 <elem>` to ACK; a reaper can requeue stale items.
Refs: LMOVE/BLMOVE behavior and reliable-queue pattern; deprecation of RPOPLPUSH. ([Redis][5])
*(Compat: you can still implement `RPOPLPUSH source dest` and `BRPOPLPUSH source dest timeout`, but mark them deprecated and map to LMOVE/BLMOVE.)* ([Redis][6])
---
## LLEN (length)
Useful for metrics/backpressure.
```
LLEN key
```
**RESP**
```
*2\r\n$4\r\nLLEN\r\n$5\r\nqueue\r\n
:3\r\n
```
Refs: list overview mentioning LLEN. ([Redis][7])
---
## LREM (ack for “reliable” processing)
Remove occurrences of `element` from the list (head→tail scan).
Use `count=1` to ACK a single processed item from `processing`.
```
LREM key count element
```
**RESP**
```
*4\r\n$4\r\nLREM\r\n$9\r\nprocessing\r\n$1\r\n1\r\n$5\r\njob-5\r\n
:1\r\n
```
Refs: reliable pattern mentions LREM to ACK. ([Redis][5])
---
## LTRIM (bounded queues / retention)
Keep only `[start, stop]` range; everything else is dropped.
Use to cap queue length after pushes.
```
LTRIM key start stop
```
**RESP**
```
*4\r\n$5\r\nLTRIM\r\n$5\r\nqueue\r\n$2\r\n0\r\n$3\r\n999\r\n
+OK\r\n
```
Refs: list overview includes LTRIM for retention. ([Redis][7])
---
## LRANGE / LINDEX (debugging / peeking)
* `LRANGE key start stop` → Array of elements (non-destructive).
* `LINDEX key index` → one element or Null.
These arent required for queue semantics, but handy. ([Redis][7])
---
# 3) Errors & types
* Wrong type: `-WRONGTYPE Operation against a key holding the wrong kind of value\r\n`
* Non-existing key:
* Push: creates the list (returns new length).
* Pop (non-blocking): returns **Null**.
* Blocking pop: **Null Array** on timeout. ([Redis][1])
---
# 4) Blocking engine (implementation sketch)
1. **Call time**: scan keys in user order. If a non-empty list is found, pop & reply immediately.
2. **Otherwise**: register the client as **blocked** on those keys with `deadline = now + timeout` (or infinite).
3. **On push to any key**: if waiters exist, **wake one** (FIFO) and serve its pop **atomically** with the push result.
4. **On timer**: for each blocked client whose deadline passed, reply `Null Array` and clear state.
5. **Connection close**: remove from any wait queues.
Refs for timeout/block semantics. ([Redis][4])
---
# 5) Reliable queue pattern (recommended)
* **Consume**: `LMOVE queue processing RIGHT LEFT` (or `BLMOVE ... 0`).
* **Process** the job.
* **ACK**: `LREM processing 1 <job>` when done.
* **Reaper**: auxiliary task that detects stale jobs (e.g., track job IDs + timestamps in a ZSET) and requeues them. (Lists dont include timestamps; pairing with a ZSET is standard practice.)
Refs: LMOVE docs pattern. ([Redis][5])
---
# 6) Minimal test matrix
* Push/pop happy path (both ends), with/without COUNT.
* Blocking pop: immediate availability, block + timeout, wake on push, multiple keys order, FIFO across multiple waiters.
* LMOVE/BLMOVE: RIGHT→LEFT pipeline, block + wake, cross-list atomicity, ACK via LREM.
* Type errors and key deletion on last pop.

24
herodb/run_tests.sh Executable file
View File

@@ -0,0 +1,24 @@
#!/bin/bash
echo "🧪 Running HeroDB Redis Compatibility Tests"
echo "=========================================="
echo ""
echo "1⃣ Running Simple Redis Tests (4 tests)..."
echo "----------------------------------------------"
cargo test -p herodb --test simple_redis_test -- --nocapture
echo ""
echo "2⃣ Running Comprehensive Redis Integration Tests (13 tests)..."
echo "----------------------------------------------------------------"
cargo test -p herodb --test redis_integration_tests -- --nocapture
cargo test -p herodb --test debug_hset -- --nocapture
cargo test -p herodb --test debug_hset_simple -- --nocapture
echo ""
echo "3⃣ Running All Workspace Tests..."
echo "--------------------------------"
cargo test --workspace -- --nocapture
echo ""
echo "✅ Test execution completed!"

308
herodb/src/age.rs Normal file
View File

@@ -0,0 +1,308 @@
//! age.rs — AGE (rage) helpers + persistent key management for your mini-Redis.
//
// Features:
// - X25519 encryption/decryption (age style)
// - Ed25519 detached signatures + verification
// - Persistent named keys in DB (strings):
// age:key:{name} -> X25519 recipient (public encryption key, "age1...")
// age:privkey:{name} -> X25519 identity (secret encryption key, "AGE-SECRET-KEY-1...")
// age:signpub:{name} -> Ed25519 verify pubkey (public, used to verify signatures)
// age:signpriv:{name} -> Ed25519 signing secret key (private, used to sign)
// - Base64 wrapping for ciphertext/signature binary blobs.
use std::str::FromStr;
use secrecy::ExposeSecret;
use age::{Decryptor, Encryptor};
use age::x25519;
use ed25519_dalek::{Signature, Signer, Verifier, SigningKey, VerifyingKey};
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use crate::protocol::Protocol;
use crate::server::Server;
use crate::error::DBError;
// ---------- Internal helpers ----------
#[derive(Debug)]
pub enum AgeWireError {
ParseKey,
Crypto(String),
Utf8,
SignatureLen,
NotFound(&'static str), // which kind of key was missing
Storage(String),
}
impl AgeWireError {
fn to_protocol(self) -> Protocol {
match self {
AgeWireError::ParseKey => Protocol::err("ERR age: invalid key"),
AgeWireError::Crypto(e) => Protocol::err(&format!("ERR age: {e}")),
AgeWireError::Utf8 => Protocol::err("ERR age: invalid UTF-8 plaintext"),
AgeWireError::SignatureLen => Protocol::err("ERR age: bad signature length"),
AgeWireError::NotFound(w) => Protocol::err(&format!("ERR age: missing {w}")),
AgeWireError::Storage(e) => Protocol::err(&format!("ERR storage: {e}")),
}
}
}
fn parse_recipient(s: &str) -> Result<x25519::Recipient, AgeWireError> {
x25519::Recipient::from_str(s).map_err(|_| AgeWireError::ParseKey)
}
fn parse_identity(s: &str) -> Result<x25519::Identity, AgeWireError> {
x25519::Identity::from_str(s).map_err(|_| AgeWireError::ParseKey)
}
fn parse_ed25519_signing_key(s: &str) -> Result<SigningKey, AgeWireError> {
// Parse base64-encoded signing key
let bytes = B64.decode(s).map_err(|_| AgeWireError::ParseKey)?;
if bytes.len() != 32 {
return Err(AgeWireError::ParseKey);
}
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AgeWireError::ParseKey)?;
Ok(SigningKey::from_bytes(&key_bytes))
}
fn parse_ed25519_verifying_key(s: &str) -> Result<VerifyingKey, AgeWireError> {
// Parse base64-encoded verifying key
let bytes = B64.decode(s).map_err(|_| AgeWireError::ParseKey)?;
if bytes.len() != 32 {
return Err(AgeWireError::ParseKey);
}
let key_bytes: [u8; 32] = bytes.try_into().map_err(|_| AgeWireError::ParseKey)?;
VerifyingKey::from_bytes(&key_bytes).map_err(|_| AgeWireError::ParseKey)
}
// ---------- Stateless crypto helpers (string in/out) ----------
pub fn gen_enc_keypair() -> (String, String) {
let id = x25519::Identity::generate();
let pk = id.to_public();
(pk.to_string(), id.to_string().expose_secret().to_string()) // (recipient, identity)
}
pub fn gen_sign_keypair() -> (String, String) {
use rand::RngCore;
use rand::rngs::OsRng;
// Generate random 32 bytes for the signing key
let mut secret_bytes = [0u8; 32];
OsRng.fill_bytes(&mut secret_bytes);
let signing_key = SigningKey::from_bytes(&secret_bytes);
let verifying_key = signing_key.verifying_key();
// Encode as base64 for storage
let signing_key_b64 = B64.encode(signing_key.to_bytes());
let verifying_key_b64 = B64.encode(verifying_key.to_bytes());
(verifying_key_b64, signing_key_b64) // (verify_pub, signing_secret)
}
/// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext).
pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> {
let recipient = parse_recipient(recipient_str)?;
let enc = Encryptor::with_recipients(vec![Box::new(recipient)])
.expect("failed to create encryptor"); // Handle Option<Encryptor>
let mut out = Vec::new();
{
use std::io::Write;
let mut w = enc.wrap_output(&mut out).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.write_all(msg.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.finish().map_err(|e| AgeWireError::Crypto(e.to_string()))?;
}
Ok(B64.encode(out))
}
/// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String.
pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> {
let id = parse_identity(identity_str)?;
let ct = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
// The decrypt method returns a Result<StreamReader, DecryptError>
let mut r = match dec {
Decryptor::Recipients(d) => d.decrypt(std::iter::once(&id as &dyn age::Identity))
.map_err(|e| AgeWireError::Crypto(e.to_string()))?,
Decryptor::Passphrase(_) => return Err(AgeWireError::Crypto("Expected recipients, got passphrase".to_string())),
};
let mut pt = Vec::new();
use std::io::Read;
r.read_to_end(&mut pt).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
String::from_utf8(pt).map_err(|_| AgeWireError::Utf8)
}
/// Sign bytes of `msg` (detached). Returns base64(signature bytes, 64 bytes).
pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AgeWireError> {
let signing_key = parse_ed25519_signing_key(signing_secret_str)?;
let sig = signing_key.sign(msg.as_bytes());
Ok(B64.encode(sig.to_bytes()))
}
/// Verify detached signature (base64) for `msg` with pubkey.
pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> {
let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
let sig_bytes = B64.decode(sig_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
if sig_bytes.len() != 64 {
return Err(AgeWireError::SignatureLen);
}
let sig = Signature::from_bytes(sig_bytes[..].try_into().unwrap());
Ok(verifying_key.verify(msg.as_bytes(), &sig).is_ok())
}
// ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.get(key).map_err(|e| AgeWireError::Storage(e.0))
}
fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0))
}
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
// ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = gen_enc_keypair();
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = gen_sign_keypair();
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
match encrypt_b64(recipient, message) {
Ok(b64) => Protocol::BulkString(b64),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_decrypt(identity: &str, ct_b64: &str) -> Protocol {
match decrypt_b64(identity, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_sign(secret: &str, message: &str) -> Protocol {
match sign_b64(secret, message) {
Ok(b64sig) => Protocol::BulkString(b64sig),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> Protocol {
match verify_b64(verify_pub, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => e.to_protocol(),
}
}
// ---------- NEW: Persistent, named-key commands ----------
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); }
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); }
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
let recip = match sget(server, &enc_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("recipient (age:key:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match encrypt_b64(&recip, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> Protocol {
let ident = match sget(server, &enc_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("identity (age:privkey:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match decrypt_b64(&ident, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match sign_b64(&sec, message) {
Ok(sig) => Protocol::BulkString(sig),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match verify_b64(&pubk, message, sig_b64) {
Ok(true) => Protocol::SimpleString("1".to_string()),
Ok(false) => Protocol::SimpleString("0".to_string()),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?;
let mut names: Vec<String> = keys.into_iter()
.filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
.collect();
names.sort();
Ok(names)
};
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpriv= match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect()));
Protocol::Array(out)
};
Protocol::Array(vec![
to_arr("encpub", encpub),
to_arr("encpriv", encpriv),
to_arr("signpub", signpub),
to_arr("signpriv", signpriv),
])
}

1519
herodb/src/cmd.rs Normal file

File diff suppressed because it is too large Load Diff

73
herodb/src/crypto.rs Normal file
View File

@@ -0,0 +1,73 @@
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
XChaCha20Poly1305, XNonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
const VERSION: u8 = 1;
const NONCE_LEN: usize = 24;
const TAG_LEN: usize = 16;
#[derive(Debug)]
pub enum CryptoError {
Format, // wrong length / header
Version(u8), // unknown version
Decrypt, // wrong key or corrupted data
}
impl From<CryptoError> for crate::error::DBError {
fn from(e: CryptoError) -> Self {
crate::error::DBError(format!("Crypto error: {:?}", e))
}
}
/// Super-simple factory: new(secret) + encrypt(bytes) + decrypt(bytes)
pub struct CryptoFactory {
key: chacha20poly1305::Key,
}
impl CryptoFactory {
/// Accepts any secret bytes; turns them into a 32-byte key (SHA-256).
pub fn new<S: AsRef<[u8]>>(secret: S) -> Self {
let mut h = Sha256::new();
h.update(b"xchacha20poly1305-factory:v1"); // domain separation
h.update(secret.as_ref());
let digest = h.finalize(); // 32 bytes
let key = chacha20poly1305::Key::from_slice(&digest).to_owned();
Self { key }
}
/// Output layout: [version:1][nonce:24][ciphertext||tag]
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
let cipher = XChaCha20Poly1305::new(&self.key);
let mut nonce_bytes = [0u8; NONCE_LEN];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = XNonce::from_slice(&nonce_bytes);
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
out.push(VERSION);
out.extend_from_slice(&nonce_bytes);
let ct = cipher.encrypt(nonce, plaintext).expect("encrypt");
out.extend_from_slice(&ct);
out
}
pub fn decrypt(&self, blob: &[u8]) -> Result<Vec<u8>, CryptoError> {
if blob.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(CryptoError::Format);
}
let ver = blob[0];
if ver != VERSION {
return Err(CryptoError::Version(ver));
}
let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
let ct = &blob[1 + NONCE_LEN..];
let cipher = XChaCha20Poly1305::new(&self.key);
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
}
}

94
herodb/src/error.rs Normal file
View File

@@ -0,0 +1,94 @@
use std::num::ParseIntError;
use tokio::sync::mpsc;
use redb;
use bincode;
// todo: more error types
#[derive(Debug)]
pub struct DBError(pub String);
impl From<std::io::Error> for DBError {
fn from(item: std::io::Error) -> Self {
DBError(item.to_string().clone())
}
}
impl From<ParseIntError> for DBError {
fn from(item: ParseIntError) -> Self {
DBError(item.to_string().clone())
}
}
impl From<std::str::Utf8Error> for DBError {
fn from(item: std::str::Utf8Error) -> Self {
DBError(item.to_string().clone())
}
}
impl From<std::string::FromUtf8Error> for DBError {
fn from(item: std::string::FromUtf8Error) -> Self {
DBError(item.to_string().clone())
}
}
impl From<redb::Error> for DBError {
fn from(item: redb::Error) -> Self {
DBError(item.to_string())
}
}
impl From<redb::DatabaseError> for DBError {
fn from(item: redb::DatabaseError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::TransactionError> for DBError {
fn from(item: redb::TransactionError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::TableError> for DBError {
fn from(item: redb::TableError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::StorageError> for DBError {
fn from(item: redb::StorageError) -> Self {
DBError(item.to_string())
}
}
impl From<redb::CommitError> for DBError {
fn from(item: redb::CommitError) -> Self {
DBError(item.to_string())
}
}
impl From<Box<bincode::ErrorKind>> for DBError {
fn from(item: Box<bincode::ErrorKind>) -> Self {
DBError(item.to_string())
}
}
impl From<tokio::sync::mpsc::error::SendError<()>> for DBError {
fn from(item: mpsc::error::SendError<()>) -> Self {
DBError(item.to_string().clone())
}
}
impl From<serde_json::Error> for DBError {
fn from(item: serde_json::Error) -> Self {
DBError(item.to_string())
}
}
impl From<chacha20poly1305::Error> for DBError {
fn from(item: chacha20poly1305::Error) -> Self {
DBError(item.to_string())
}
}

8
herodb/src/lib.rs Normal file
View File

@@ -0,0 +1,8 @@
pub mod age; // NEW
pub mod cmd;
pub mod crypto;
pub mod error;
pub mod options;
pub mod protocol;
pub mod server;
pub mod storage;

81
herodb/src/main.rs Normal file
View File

@@ -0,0 +1,81 @@
// #![allow(unused_imports)]
use tokio::net::TcpListener;
use herodb::server;
use clap::Parser;
/// Simple program to greet a person
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// The directory of Redis DB file
#[arg(long)]
dir: String,
/// The port of the Redis server, default is 6379 if not specified
#[arg(long)]
port: Option<u16>,
/// Enable debug mode
#[arg(long)]
debug: bool,
/// Master encryption key for encrypted databases
#[arg(long)]
encryption_key: Option<String>,
/// Encrypt the database
#[arg(long)]
encrypt: bool,
}
#[tokio::main]
async fn main() {
// parse args
let args = Args::parse();
// bind port
let port = args.port.unwrap_or(6379);
println!("will listen on port: {}", port);
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// new DB option
let option = herodb::options::DBOption {
dir: args.dir,
port,
debug: args.debug,
encryption_key: args.encryption_key,
encrypt: args.encrypt,
};
// new server
let server = server::Server::new(option).await;
// Add a small delay to ensure the port is ready
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// accept new connections
loop {
let stream = listener.accept().await;
match stream {
Ok((stream, _)) => {
println!("accepted new connection");
let mut sc = server.clone();
tokio::spawn(async move {
if let Err(e) = sc.handle(stream).await {
println!("error: {:?}, will close the connection. Bye", e);
}
});
}
Err(e) => {
println!("error: {}", e);
}
}
}
}

8
herodb/src/options.rs Normal file
View File

@@ -0,0 +1,8 @@
#[derive(Clone)]
pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
pub encrypt: bool,
pub encryption_key: Option<String>, // Master encryption key
}

171
herodb/src/protocol.rs Normal file
View File

@@ -0,0 +1,171 @@
use core::fmt;
use crate::error::DBError;
#[derive(Debug, Clone)]
pub enum Protocol {
SimpleString(String),
BulkString(String),
Null,
Array(Vec<Protocol>),
Error(String), // NEW
}
impl fmt::Display for Protocol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.decode().as_str())
}
}
impl Protocol {
pub fn from(protocol: &str) -> Result<(Self, &str), DBError> {
if protocol.is_empty() {
// Incomplete frame; caller should read more bytes
return Err(DBError("[incomplete] empty".to_string()));
}
let ret = match protocol.chars().nth(0) {
Some('+') => Self::parse_simple_string_sfx(&protocol[1..]),
Some('$') => Self::parse_bulk_string_sfx(&protocol[1..]),
Some('*') => Self::parse_array_sfx(&protocol[1..]),
_ => Err(DBError(format!(
"[from] unsupported protocol: {:?}",
protocol
))),
};
ret
}
pub fn from_vec(array: Vec<&str>) -> Self {
let array = array
.into_iter()
.map(|x| Protocol::BulkString(x.to_string()))
.collect();
Protocol::Array(array)
}
#[inline]
pub fn ok() -> Self {
Protocol::SimpleString("ok".to_string())
}
#[inline]
pub fn err(msg: &str) -> Self {
Protocol::Error(msg.to_string())
}
#[inline]
pub fn write_on_slave_err() -> Self {
Self::err("DISALLOW WRITE ON SLAVE")
}
#[inline]
pub fn psync_on_slave_err() -> Self {
Self::err("PSYNC ON SLAVE IS NOT ALLOWED")
}
#[inline]
pub fn none() -> Self {
Self::SimpleString("none".to_string())
}
pub fn decode(&self) -> String {
match self {
Protocol::SimpleString(s) => s.to_string(),
Protocol::BulkString(s) => s.to_string(),
Protocol::Null => "".to_string(),
Protocol::Array(s) => s.iter().map(|x| x.decode()).collect::<Vec<_>>().join(" "),
Protocol::Error(s) => s.to_string(),
}
}
pub fn encode(&self) -> String {
match self {
Protocol::SimpleString(s) => format!("+{}\r\n", s),
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => {
format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
}
Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
}
}
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
match protocol.find("\r\n") {
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])),
_ => Err(DBError(format!(
"[new simple string] unsupported protocol: {:?}",
protocol
))),
}
}
fn parse_bulk_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
if let Some(len_end) = protocol.find("\r\n") {
let size = Self::parse_usize(&protocol[..len_end])?;
let data_start = len_end + 2;
let data_end = data_start + size;
// If we don't yet have the full bulk payload + trailing CRLF, signal INCOMPLETE
if protocol.len() < data_end + 2 {
return Err(DBError("[incomplete] bulk body".to_string()));
}
if &protocol[data_end..data_end + 2] != "\r\n" {
return Err(DBError("[incomplete] bulk terminator".to_string()));
}
let s = Self::parse_string(&protocol[data_start..data_end])?;
Ok((Protocol::BulkString(s), &protocol[data_end + 2..]))
} else {
// No CRLF after bulk length header yet
Err(DBError("[incomplete] bulk header".to_string()))
}
}
fn parse_array_sfx(s: &str) -> Result<(Self, &str), DBError> {
if let Some(len_end) = s.find("\r\n") {
let array_len = s[..len_end].parse::<usize>()?;
let mut remaining = &s[len_end + 2..];
let mut vec = vec![];
for _ in 0..array_len {
match Protocol::from(remaining) {
Ok((p, rem)) => {
vec.push(p);
remaining = rem;
}
Err(e) => {
// Propagate incomplete so caller can read more bytes
if e.0.starts_with("[incomplete]") {
return Err(e);
} else {
return Err(e);
}
}
}
}
Ok((Protocol::Array(vec), remaining))
} else {
// No CRLF after array header yet
Err(DBError("[incomplete] array header".to_string()))
}
}
fn parse_usize(protocol: &str) -> Result<usize, DBError> {
if protocol.is_empty() {
Err(DBError("Cannot parse usize from empty string".to_string()))
} else {
protocol
.parse::<usize>()
.map_err(|_| DBError(format!("Failed to parse usize from: {}", protocol)))
}
}
fn parse_string(protocol: &str) -> Result<String, DBError> {
if protocol.is_empty() {
// Allow empty strings, but handle appropriately
Ok("".to_string())
} else {
Ok(protocol.to_string())
}
}
}

250
herodb/src/server.rs Normal file
View File

@@ -0,0 +1,250 @@
use core::str;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::sync::{Mutex, oneshot};
use std::sync::atomic::{AtomicU64, Ordering};
use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::storage::Storage;
#[derive(Clone)]
pub struct Server {
pub db_cache: std::sync::Arc<std::sync::RwLock<HashMap<u64, Arc<Storage>>>>,
pub option: options::DBOption,
pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
// BLPOP waiter registry: per (db_index, key) FIFO of waiters
pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
pub waiter_seq: Arc<AtomicU64>,
}
pub struct Waiter {
pub id: u64,
pub side: PopSide,
pub tx: oneshot::Sender<(String, String)>, // (key, element)
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PopSide {
Left,
Right,
}
impl Server {
pub async fn new(option: options::DBOption) -> Self {
Server {
db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())),
option,
client_name: None,
selected_db: 0,
queued_cmd: None,
list_waiters: Arc::new(Mutex::new(HashMap::new())),
waiter_seq: Arc::new(AtomicU64::new(1)),
}
}
pub fn current_storage(&self) -> Result<Arc<Storage>, DBError> {
let mut cache = self.db_cache.write().unwrap();
if let Some(storage) = cache.get(&self.selected_db) {
return Ok(storage.clone());
}
// Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db));
// Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
println!("Creating new db file: {}", db_file_path.display());
let storage = Arc::new(Storage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref()
)?);
cache.insert(self.selected_db, storage.clone());
Ok(storage)
}
fn should_encrypt_db(&self, db_index: u64) -> bool {
// DB 0-9 are non-encrypted, DB 10+ are encrypted
self.option.encrypt && db_index >= 10
}
// ----- BLPOP waiter helpers -----
pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) {
let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel::<(String, String)>();
let mut guard = self.list_waiters.lock().await;
let per_db = guard.entry(db_index).or_insert_with(HashMap::new);
let q = per_db.entry(key.to_string()).or_insert_with(Vec::new);
q.push(Waiter { id, side, tx });
(id, rx)
}
pub async fn unregister_waiter(&self, db_index: u64, key: &str, id: u64) {
let mut guard = self.list_waiters.lock().await;
if let Some(per_db) = guard.get_mut(&db_index) {
if let Some(q) = per_db.get_mut(key) {
q.retain(|w| w.id != id);
if q.is_empty() {
per_db.remove(key);
}
}
if per_db.is_empty() {
guard.remove(&db_index);
}
}
}
// Called after LPUSH/RPUSH to deliver to blocked BLPOP waiters.
pub async fn drain_waiters_after_push(&self, key: &str) -> Result<(), DBError> {
let db_index = self.selected_db;
loop {
// Check if any waiter exists
let maybe_waiter = {
let mut guard = self.list_waiters.lock().await;
if let Some(per_db) = guard.get_mut(&db_index) {
if let Some(q) = per_db.get_mut(key) {
if !q.is_empty() {
// Pop FIFO
Some(q.remove(0))
} else {
None
}
} else {
None
}
} else {
None
}
};
let waiter = if let Some(w) = maybe_waiter { w } else { break };
// Pop one element depending on waiter side
let elems = match waiter.side {
PopSide::Left => self.current_storage()?.lpop(key, 1)?,
PopSide::Right => self.current_storage()?.rpop(key, 1)?,
};
if elems.is_empty() {
// Nothing to deliver; re-register waiter at the front to preserve order
let mut guard = self.list_waiters.lock().await;
let per_db = guard.entry(db_index).or_insert_with(HashMap::new);
let q = per_db.entry(key.to_string()).or_insert_with(Vec::new);
q.insert(0, waiter);
break;
} else {
let elem = elems[0].clone();
// Send to waiter; if receiver dropped, just continue
let _ = waiter.tx.send((key.to_string(), elem));
// Loop to try to satisfy more waiters if more elements remain
continue;
}
}
Ok(())
}
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
// Accumulate incoming bytes to handle partial RESP frames
let mut acc = String::new();
let mut buf = vec![0u8; 8192];
loop {
let n = match stream.read(&mut buf).await {
Ok(0) => {
println!("[handle] connection closed");
return Ok(());
}
Ok(n) => n,
Err(e) => {
println!("[handle] read error: {:?}", e);
return Err(e.into());
}
};
// Append to accumulator. RESP for our usage is ASCII-safe.
acc.push_str(str::from_utf8(&buf[..n])?);
// Try to parse as many complete commands as are available in 'acc'.
loop {
let parsed = Cmd::from(&acc);
let (cmd, protocol, remaining) = match parsed {
Ok((cmd, protocol, remaining)) => (cmd, protocol, remaining),
Err(_e) => {
// Incomplete or invalid frame; assume incomplete and wait for more data.
// This avoids emitting spurious protocol_error for split frames.
break;
}
};
// Advance the accumulator to the unparsed remainder
acc = remaining.to_string();
if self.option.debug {
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
} else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
}
// Check if this is a QUIT command before processing
let is_quit = matches!(cmd, Cmd::Quit);
let res = match cmd.run(self).await {
Ok(p) => p,
Err(e) => {
if self.option.debug {
eprintln!("[run error] {:?}", e);
}
Protocol::err(&format!("ERR {}", e.0))
}
};
if self.option.debug {
println!("\x1b[34;1mqueued cmd {:?}\x1b[0m", self.queued_cmd);
println!("\x1b[32;1mgoing to send response {}\x1b[0m", res.encode());
} else {
print!("queued cmd {:?}", self.queued_cmd);
println!("going to send response {}", res.encode());
}
_ = stream.write(res.encode().as_bytes()).await?;
// If this was a QUIT command, close the connection
if is_quit {
println!("[handle] QUIT command received, closing connection");
return Ok(());
}
// Continue parsing any further complete commands already in 'acc'
if acc.is_empty() {
break;
}
}
}
}
}

126
herodb/src/storage/mod.rs Normal file
View File

@@ -0,0 +1,126 @@
use std::{
path::Path,
time::{SystemTime, UNIX_EPOCH},
};
use redb::{Database, TableDefinition};
use serde::{Deserialize, Serialize};
use crate::crypto::CryptoFactory;
use crate::error::DBError;
// Re-export modules
mod storage_basic;
mod storage_hset;
mod storage_lists;
mod storage_extra;
// Re-export implementations
// Note: These imports are used by the impl blocks in the submodules
// The compiler shows them as unused because they're not directly used in this file
// but they're needed for the Storage struct methods to be available
pub use storage_extra::*;
// Table definitions for different Redis data types
const TYPES_TABLE: TableDefinition<&str, &str> = TableDefinition::new("types");
const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("strings");
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamEntry {
pub fields: Vec<(String, String)>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ListValue {
pub elements: Vec<String>,
}
#[inline]
pub fn now_in_millis() -> u128 {
let start = SystemTime::now();
let duration_since_epoch = start.duration_since(UNIX_EPOCH).unwrap();
duration_since_epoch.as_millis()
}
pub struct Storage {
db: Database,
crypto: Option<CryptoFactory>,
}
impl Storage {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
let db = Database::create(path)?;
// Create tables if they don't exist
let write_txn = db.begin_write()?;
{
let _ = write_txn.open_table(TYPES_TABLE)?;
let _ = write_txn.open_table(STRINGS_TABLE)?;
let _ = write_txn.open_table(HASHES_TABLE)?;
let _ = write_txn.open_table(LISTS_TABLE)?;
let _ = write_txn.open_table(STREAMS_META_TABLE)?;
let _ = write_txn.open_table(STREAMS_DATA_TABLE)?;
let _ = write_txn.open_table(ENCRYPTED_TABLE)?;
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
}
write_txn.commit()?;
// Check if database was previously encrypted
let read_txn = db.begin_read()?;
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
drop(read_txn);
let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes()))
} else {
return Err(DBError("Encryption requested but no master key provided".to_string()));
}
} else {
None
};
// If we're enabling encryption for the first time, mark it
if should_encrypt && !was_encrypted {
let write_txn = db.begin_write()?;
{
let mut encrypted_table = write_txn.open_table(ENCRYPTED_TABLE)?;
encrypted_table.insert("encrypted", &1u8)?;
}
write_txn.commit()?;
}
Ok(Storage {
db,
crypto,
})
}
pub fn is_encrypted(&self) -> bool {
self.crypto.is_some()
}
// Helper methods for encryption
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.encrypt(data))
} else {
Ok(data.to_vec())
}
}
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data)?)
} else {
Ok(data.to_vec())
}
}
}

View File

@@ -0,0 +1,245 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;
impl Storage {
pub fn flushdb(&self) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
let mut streams_meta_table = write_txn.open_table(STREAMS_META_TABLE)?;
let mut streams_data_table = write_txn.open_table(STREAMS_DATA_TABLE)?;
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
types_table.remove(key.as_str())?;
}
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
strings_table.remove(key.as_str())?;
}
let keys: Vec<(String, String)> = hashes_table
.iter()?
.map(|item| {
let binding = item.unwrap();
let (k, f) = binding.0.value();
(k.to_string(), f.to_string())
})
.collect();
for (key, field) in keys {
hashes_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
lists_table.remove(key.as_str())?;
}
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
streams_meta_table.remove(key.as_str())?;
}
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
let binding = item.unwrap();
let (key, field) = binding.0.value();
(key.to_string(), field.to_string())
}).collect();
for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
expiration_table.remove(key.as_str())?;
}
}
write_txn.commit()?;
Ok(())
}
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
// Before returning type, check for expiration
if let Some(type_val) = table.get(key)? {
if type_val.value() == "string" {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
// The key is expired, so it effectively has no type
return Ok(None);
}
}
}
Ok(Some(type_val.value().to_string()))
} else {
Ok(None)
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check expiration first (unencrypted)
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
drop(read_txn);
self.del(key.to_string())?;
return Ok(None);
}
}
// Get and decrypt value
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value, not expiration
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Remove any existing expiration since this is a regular SET
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Store expiration separately (unencrypted)
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = expire_ms + now_in_millis();
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
}
write_txn.commit()?;
Ok(())
}
pub fn del(&self, key: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table
types_table.remove(key.as_str())?;
// Remove from strings table
strings_table.remove(key.as_str())?;
// Remove all hash fields for this key
let mut to_remove = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key.as_str() {
to_remove.push((hash_key.to_string(), field.to_string()));
}
}
drop(iter);
for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?;
}
// Remove from lists table
lists_table.remove(key.as_str())?;
// Also remove expiration
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let mut keys = Vec::new();
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
if pattern == "*" || super::storage_extra::glob_match(pattern, &key) {
keys.push(key);
}
}
Ok(keys)
}
}
impl Storage {
pub fn dbsize(&self) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
let mut count: i64 = 0;
let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let key = entry.0.value();
let ty = entry.1.value();
if ty == "string" {
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
// Skip logically expired string keys
continue;
}
}
}
count += 1;
}
Ok(count)
}
}

View File

@@ -0,0 +1,278 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let key = entry.0.value().to_string();
let key_type = entry.1.value().to_string();
if current_cursor >= cursor {
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
glob_match(pat, &key)
} else {
true
};
if matches {
// For scan, we return key-value pairs for string types
if key_type == "string" {
if let Some(data) = strings_table.get(key.as_str())? {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
result.push((key, value));
} else {
result.push((key, String::new()));
}
} else {
// For non-string types, just return the key with type as value
result.push((key, key_type));
}
if result.len() >= limit {
break;
}
}
}
current_cursor += 1;
}
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
match expiration_table.get(key)? {
Some(expires_at) => {
let now = now_in_millis();
let expires_at_ms = expires_at.value() as u128;
if now >= expires_at_ms {
Ok(-2) // Key has expired
} else {
Ok(((expires_at_ms - now) / 1000) as i64) // TTL in seconds
}
}
None => Ok(-1), // Key exists but has no expiration
}
}
Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types)
None => Ok(-2), // Key does not exist
}
}
pub fn exists(&self, key: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check if string key has expired
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
if let Some(expires_at) = expiration_table.get(key)? {
if now_in_millis() > expires_at.value() as u128 {
return Ok(false); // Key has expired
}
}
Ok(true)
}
Some(_) => Ok(true), // Key exists and is not a string
None => Ok(false), // Key does not exist
}
}
// -------- Expiration helpers (string keys only, consistent with TTL/EXISTS) --------
// Set expiry in seconds; returns true if applied (key exists and is string), false otherwise
pub fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
// Determine eligibility first to avoid holding borrows across commit
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = now_in_millis() + (secs as u128) * 1000;
expiration_table.insert(key, &(expires_at as u64))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
// Set expiry in milliseconds; returns true if applied (key exists and is string), false otherwise
pub fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = now_in_millis() + ms;
expiration_table.insert(key, &(expires_at as u64))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
// Remove expiry if present; returns true if removed, false otherwise
pub fn persist(&self, key: &str) -> Result<bool, DBError> {
let mut removed = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
if expiration_table.remove(key)?.is_some() {
removed = true;
}
}
}
write_txn.commit()?;
Ok(removed)
}
// Absolute EXPIREAT in seconds since epoch
// Returns true if applied (key exists and is string), false otherwise
pub fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
expiration_table.insert(key, &((expires_at_ms as u64)))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
// Absolute PEXPIREAT in milliseconds since epoch
// Returns true if applied (key exists and is string), false otherwise
pub fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
let mut applied = false;
let write_txn = self.db.begin_write()?;
{
let types_table = write_txn.open_table(TYPES_TABLE)?;
let is_string = types_table
.get(key)?
.map(|v| v.value() == "string")
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
expiration_table.insert(key, &((expires_at_ms as u64)))?;
applied = true;
}
}
write_txn.commit()?;
Ok(applied)
}
}
// Utility function for glob pattern matching
pub fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" {
return true;
}
// Simple glob matching - supports * and ? wildcards
let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() {
return ti >= text.len();
}
if ti >= text.len() {
// Check if remaining pattern is all '*'
return pattern[pi..].iter().all(|&c| c == '*');
}
match pattern[pi] {
'*' => {
// Try matching zero or more characters
for i in ti..=text.len() {
if match_recursive(pattern, text, pi + 1, i) {
return true;
}
}
false
}
'?' => {
// Match exactly one character
match_recursive(pattern, text, pi + 1, ti + 1)
}
c => {
// Match exact character
if text[ti] == c {
match_recursive(pattern, text, pi + 1, ti + 1)
} else {
false
}
}
}
}
match_recursive(&pattern_chars, &text_chars, 0, 0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_glob_match() {
assert!(glob_match("*", "anything"));
assert!(glob_match("hello", "hello"));
assert!(!glob_match("hello", "world"));
assert!(glob_match("h*o", "hello"));
assert!(glob_match("h*o", "ho"));
assert!(!glob_match("h*o", "hi"));
assert!(glob_match("h?llo", "hello"));
assert!(!glob_match("h?llo", "hllo"));
assert!(glob_match("*test*", "this_is_a_test_string"));
assert!(!glob_match("*test*", "this_is_a_string"));
}
}

View File

@@ -0,0 +1,377 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are encrypted before storage
pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_fields = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") | None => { // Proceed if hash or new key
// Set the type to hash (only if new key or existing hash)
types_table.insert(key, "hash")?;
for (field, value) in pairs {
// Check if field already exists
let exists = hashes_table.get((key, field.as_str()))?.is_some();
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
if !exists {
new_fields += 1;
}
}
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
}
}
write_txn.commit()?;
Ok(new_fields)
}
// ✅ ENCRYPTION APPLIED: Value is decrypted after retrieval
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = types_table.get(key)?.map(|v| v.value().to_string());
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
match hashes_table.get((key, field))? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
Ok(Some(value))
}
None => Ok(None),
}
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push((field.to_string(), value));
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut deleted = 0i64;
// First check if key exists and is a hash
let key_type = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1;
}
}
// Check if hash is now empty and remove type if so
let mut has_fields = false;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
has_fields = true;
break;
}
}
drop(iter);
if !has_fields {
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
}
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => {} // Key does not exist, nothing to delete, return 0 deleted
}
write_txn.commit()?;
Ok(deleted)
}
pub fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some())
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(false),
}
}
pub fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
result.push(field.to_string());
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
// ✅ ENCRYPTION APPLIED: All values are decrypted after retrieval
pub fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push(value);
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
pub fn hlen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0i64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, _) = entry.0.value();
if hash_key == key {
count += 1;
}
}
Ok(count)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(0),
}
}
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let value = String::from_utf8(decrypted)?;
result.push(Some(value));
}
None => result.push(None),
}
}
Ok(result)
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(fields.into_iter().map(|_| None).collect()),
}
}
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = false;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") | None => { // Proceed if hash or new key
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
// Set the type to hash (only if new key or existing hash)
types_table.insert(key, "hash")?;
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field), encrypted.as_slice())?;
result = true;
}
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
}
}
write_txn.commit()?;
Ok(result)
}
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
if current_cursor >= cursor {
let field_str = field.to_string();
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
super::storage_extra::glob_match(pat, &field_str)
} else {
true
};
if matches {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push((field_str, value));
if result.len() >= limit {
break;
}
}
}
current_cursor += 1;
}
}
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok((0, Vec::new())),
}
}
}

View File

@@ -0,0 +1,403 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;
impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
serde_json::from_slice(&decrypted)?
}
None => Vec::new(),
};
// Add elements to the front (left)
for element in elements.into_iter() {
list.insert(0, element);
}
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
serde_json::from_slice(&decrypted)?
}
None => Vec::new(),
};
// Add elements to the end (right)
list.extend(elements);
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
if !list.is_empty() {
result.push(list.remove(0));
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
if !list.is_empty() {
result.push(list.pop().unwrap());
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
pub fn llen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Ok(list.len() as i64)
}
None => Ok(0),
}
}
_ => Ok(0),
}
}
// ✅ ENCRYPTION APPLIED: Element is decrypted after retrieval
pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
let actual_index = if index < 0 {
list.len() as i64 + index
} else {
index
};
if actual_index >= 0 && (actual_index as usize) < list.len() {
Ok(Some(list[actual_index as usize].clone()))
} else {
Ok(None)
}
}
None => Ok(None),
}
}
_ => Ok(None),
}
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval
pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
match lists_table.get(key)? {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
if list.is_empty() {
return Ok(Vec::new());
}
let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new());
}
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
}
None => Ok(Vec::new()),
}
}
_ => Ok(Vec::new()),
}
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(list) = list_data {
if list.is_empty() {
write_txn.commit()?;
return Ok(());
}
let len = list.len() as i64;
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if start_idx > stop_idx || start_idx >= len {
// Remove the entire list
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
if trimmed.is_empty() {
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the trimmed list
let serialized = serde_json::to_vec(&trimmed)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
}
write_txn.commit()?;
Ok(())
}
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut removed = 0i64;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
Some(list)
} else {
None
}
}
_ => None,
};
result
};
if let Some(mut list) = list_data {
if count == 0 {
// Remove all occurrences
let original_len = list.len();
list.retain(|x| x != element);
removed = (original_len - list.len()) as i64;
} else if count > 0 {
// Remove first count occurrences
let mut to_remove = count as usize;
list.retain(|x| {
if x == element && to_remove > 0 {
to_remove -= 1;
removed += 1;
false
} else {
true
}
});
} else {
// Remove last |count| occurrences
let mut to_remove = (-count) as usize;
for i in (0..list.len()).rev() {
if list[i] == element && to_remove > 0 {
list.remove(i);
to_remove -= 1;
removed += 1;
}
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
} else {
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(removed)
}
}

355
herodb/test_herodb.sh Executable file
View File

@@ -0,0 +1,355 @@
#!/bin/bash
# Test script for HeroDB - Redis-compatible database with redb backend
# This script starts the server and runs comprehensive tests
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Configuration
DB_DIR="./test_db"
PORT=6381
SERVER_PID=""
# Function to print colored output
print_status() {
echo -e "${BLUE}[INFO]${NC} $1"
}
print_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
print_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
}
# Function to cleanup on exit
cleanup() {
if [ ! -z "$SERVER_PID" ]; then
print_status "Stopping HeroDB server (PID: $SERVER_PID)..."
kill $SERVER_PID 2>/dev/null || true
wait $SERVER_PID 2>/dev/null || true
fi
# Clean up test database
if [ -d "$DB_DIR" ]; then
print_status "Cleaning up test database directory..."
rm -rf "$DB_DIR"
fi
}
# Set trap to cleanup on script exit
trap cleanup EXIT
# Function to wait for server to start
wait_for_server() {
local max_attempts=30
local attempt=1
print_status "Waiting for server to start on port $PORT..."
while [ $attempt -le $max_attempts ]; do
if nc -z localhost $PORT 2>/dev/null; then
print_success "Server is ready!"
return 0
fi
echo -n "."
sleep 1
attempt=$((attempt + 1))
done
print_error "Server failed to start within $max_attempts seconds"
return 1
}
# Function to send Redis command and get response
redis_cmd() {
local cmd="$1"
local expected="$2"
print_status "Testing: $cmd"
local result=$(echo "$cmd" | redis-cli -p $PORT --raw 2>/dev/null || echo "ERROR")
if [ "$expected" != "" ] && [ "$result" != "$expected" ]; then
print_error "Expected: '$expected', Got: '$result'"
return 1
else
print_success "$cmd -> $result"
return 0
fi
}
# Function to test basic string operations
test_string_operations() {
print_status "=== Testing String Operations ==="
redis_cmd "PING" "PONG"
redis_cmd "SET mykey hello" "OK"
redis_cmd "GET mykey" "hello"
redis_cmd "SET counter 1" "OK"
redis_cmd "INCR counter" "2"
redis_cmd "INCR counter" "3"
redis_cmd "GET counter" "3"
redis_cmd "DEL mykey" "1"
redis_cmd "GET mykey" ""
redis_cmd "TYPE counter" "string"
redis_cmd "TYPE nonexistent" "none"
}
# Function to test hash operations
test_hash_operations() {
print_status "=== Testing Hash Operations ==="
# HSET and HGET
redis_cmd "HSET user:1 name John" "1"
redis_cmd "HSET user:1 age 30 city NYC" "2"
redis_cmd "HGET user:1 name" "John"
redis_cmd "HGET user:1 age" "30"
redis_cmd "HGET user:1 nonexistent" ""
# HGETALL
print_status "Testing HGETALL user:1"
redis_cmd "HGETALL user:1" ""
# HEXISTS
redis_cmd "HEXISTS user:1 name" "1"
redis_cmd "HEXISTS user:1 nonexistent" "0"
# HKEYS
print_status "Testing HKEYS user:1"
redis_cmd "HKEYS user:1" ""
# HVALS
print_status "Testing HVALS user:1"
redis_cmd "HVALS user:1" ""
# HLEN
redis_cmd "HLEN user:1" "3"
# HMGET
print_status "Testing HMGET user:1 name age"
redis_cmd "HMGET user:1 name age" ""
# HSETNX
redis_cmd "HSETNX user:1 name Jane" "0" # Should not set, field exists
redis_cmd "HSETNX user:1 email john@example.com" "1" # Should set, new field
redis_cmd "HGET user:1 email" "john@example.com"
# HDEL
redis_cmd "HDEL user:1 age city" "2"
redis_cmd "HLEN user:1" "2"
redis_cmd "HEXISTS user:1 age" "0"
# Test type checking
redis_cmd "SET stringkey value" "OK"
print_status "Testing WRONGTYPE error on string key"
redis_cmd "HGET stringkey field" "" # Should return WRONGTYPE error
}
# Function to test configuration commands
test_config_operations() {
print_status "=== Testing Configuration Operations ==="
print_status "Testing CONFIG GET dir"
redis_cmd "CONFIG GET dir" ""
print_status "Testing CONFIG GET dbfilename"
redis_cmd "CONFIG GET dbfilename" ""
}
# Function to test transaction operations
test_transaction_operations() {
print_status "=== Testing Transaction Operations ==="
redis_cmd "MULTI" "OK"
redis_cmd "SET tx_key1 value1" "QUEUED"
redis_cmd "SET tx_key2 value2" "QUEUED"
redis_cmd "INCR counter" "QUEUED"
print_status "Testing EXEC"
redis_cmd "EXEC" ""
redis_cmd "GET tx_key1" "value1"
redis_cmd "GET tx_key2" "value2"
# Test DISCARD
redis_cmd "MULTI" "OK"
redis_cmd "SET discard_key value" "QUEUED"
redis_cmd "DISCARD" "OK"
redis_cmd "GET discard_key" ""
}
# Function to test keys operations
test_keys_operations() {
print_status "=== Testing Keys Operations ==="
print_status "Testing KEYS *"
redis_cmd "KEYS *" ""
}
# Function to test info operations
test_info_operations() {
print_status "=== Testing Info Operations ==="
print_status "Testing INFO"
redis_cmd "INFO" ""
print_status "Testing INFO replication"
redis_cmd "INFO replication" ""
}
# Function to test expiration
test_expiration() {
print_status "=== Testing Expiration ==="
redis_cmd "SET expire_key value" "OK"
redis_cmd "SET expire_px_key value PX 1000" "OK" # 1 second
redis_cmd "SET expire_ex_key value EX 1" "OK" # 1 second
redis_cmd "GET expire_key" "value"
redis_cmd "GET expire_px_key" "value"
redis_cmd "GET expire_ex_key" "value"
print_status "Waiting 2 seconds for expiration..."
sleep 2
redis_cmd "GET expire_key" "value" # Should still exist
redis_cmd "GET expire_px_key" "" # Should be expired
redis_cmd "GET expire_ex_key" "" # Should be expired
}
# Function to test SCAN operations
test_scan_operations() {
print_status "=== Testing SCAN Operations ==="
# Set up test data for scanning
redis_cmd "SET scan_test1 value1" "OK"
redis_cmd "SET scan_test2 value2" "OK"
redis_cmd "SET scan_test3 value3" "OK"
redis_cmd "SET other_key other_value" "OK"
redis_cmd "HSET scan_hash field1 value1" "1"
# Test basic SCAN
print_status "Testing basic SCAN with cursor 0"
redis_cmd "SCAN 0" ""
# Test SCAN with MATCH pattern
print_status "Testing SCAN with MATCH pattern"
redis_cmd "SCAN 0 MATCH scan_test*" ""
# Test SCAN with COUNT
print_status "Testing SCAN with COUNT 2"
redis_cmd "SCAN 0 COUNT 2" ""
# Test SCAN with both MATCH and COUNT
print_status "Testing SCAN with MATCH and COUNT"
redis_cmd "SCAN 0 MATCH scan_* COUNT 1" ""
# Test SCAN continuation with more keys
print_status "Setting up more keys for continuation test"
redis_cmd "SET scan_key1 val1" "OK"
redis_cmd "SET scan_key2 val2" "OK"
redis_cmd "SET scan_key3 val3" "OK"
redis_cmd "SET scan_key4 val4" "OK"
redis_cmd "SET scan_key5 val5" "OK"
print_status "Testing SCAN with small COUNT for pagination"
redis_cmd "SCAN 0 COUNT 3" ""
# Clean up SCAN test data
print_status "Cleaning up SCAN test data"
redis_cmd "DEL scan_test1" "1"
redis_cmd "DEL scan_test2" "1"
redis_cmd "DEL scan_test3" "1"
redis_cmd "DEL other_key" "1"
redis_cmd "DEL scan_hash" "1"
redis_cmd "DEL scan_key1" "1"
redis_cmd "DEL scan_key2" "1"
redis_cmd "DEL scan_key3" "1"
redis_cmd "DEL scan_key4" "1"
redis_cmd "DEL scan_key5" "1"
}
# Main execution
main() {
print_status "Starting HeroDB comprehensive test suite..."
# Build the project
print_status "Building HeroDB..."
if ! cargo build -p herodb --release; then
print_error "Failed to build HeroDB"
exit 1
fi
# Create test database directory
mkdir -p "$DB_DIR"
# Start the server
print_status "Starting HeroDB server..."
./target/release/herodb --dir "$DB_DIR" --port $PORT &
SERVER_PID=$!
# Wait for server to start
if ! wait_for_server; then
print_error "Failed to start server"
exit 1
fi
# Run tests
local failed_tests=0
test_string_operations || failed_tests=$((failed_tests + 1))
test_hash_operations || failed_tests=$((failed_tests + 1))
test_config_operations || failed_tests=$((failed_tests + 1))
test_transaction_operations || failed_tests=$((failed_tests + 1))
test_keys_operations || failed_tests=$((failed_tests + 1))
test_info_operations || failed_tests=$((failed_tests + 1))
test_expiration || failed_tests=$((failed_tests + 1))
test_scan_operations || failed_tests=$((failed_tests + 1))
# Summary
echo
print_status "=== Test Summary ==="
if [ $failed_tests -eq 0 ]; then
print_success "All tests completed! Some may have warnings due to protocol differences."
print_success "HeroDB is working with persistent redb storage!"
else
print_warning "$failed_tests test categories had issues"
print_warning "Check the output above for details"
fi
print_status "Database file created at: $DB_DIR/herodb.redb"
print_status "Server logs and any errors are shown above"
}
# Check dependencies
check_dependencies() {
if ! command -v cargo &> /dev/null; then
print_error "cargo is required but not installed"
exit 1
fi
if ! command -v nc &> /dev/null; then
print_warning "netcat (nc) not found - some tests may not work properly"
fi
if ! command -v redis-cli &> /dev/null; then
print_warning "redis-cli not found - using netcat fallback"
fi
}
# Run dependency check and main function
check_dependencies
main "$@"

View File

@@ -0,0 +1,62 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn debug_hset_simple() {
// Clean up any existing test database
let test_dir = "/tmp/herodb_debug_hset";
let _ = std::fs::remove_dir_all(test_dir);
std::fs::create_dir_all(test_dir).unwrap();
let port = 16500;
let option = DBOption {
dir: test_dir.to_string(),
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Test simple HSET
println!("Testing HSET...");
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected '1' but got: {}", response);
// Test HGET
println!("Testing HGET...");
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HGET response: {}", response);
assert!(response.contains("value1"), "Expected 'value1' but got: {}", response);
}

View File

@@ -0,0 +1,56 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
#[tokio::test]
async fn debug_hset_return_value() {
let test_dir = "/tmp/herodb_debug_hset_return";
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir.to_string(),
port: 16390,
debug: false,
encrypt: false,
encryption_key: None,
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind("127.0.0.1:16390")
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
// Connect and test HSET
let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap();
// Send HSET command
let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n";
stream.write_all(cmd.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
println!("HSET response: {}", response);
println!("Response bytes: {:?}", &buffer[..n]);
// Check if response contains "1"
assert!(response.contains("1"), "Expected response to contain '1', got: {}", response);
}

View File

@@ -0,0 +1,35 @@
use herodb::protocol::Protocol;
use herodb::cmd::Cmd;
#[test]
fn test_protocol_parsing() {
// Test TYPE command parsing
let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n";
println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(type_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
match Cmd::from(type_cmd) {
Ok((cmd, _, _)) => println!("Command parsed successfully: {:?}", cmd),
Err(e) => println!("Command parsing failed: {:?}", e),
}
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
// Test HEXISTS command parsing
let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n";
println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(hexists_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
match Cmd::from(hexists_cmd) {
Ok((cmd, _, _)) => println!("Command parsed successfully: {:?}", cmd),
Err(e) => println!("Command parsing failed: {:?}", e),
}
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
}

View File

@@ -0,0 +1,317 @@
use redis::{Client, Commands, Connection, RedisResult};
use std::process::{Child, Command};
use std::time::Duration;
use tokio::time::sleep;
// Helper function to get Redis connection, retrying until successful
fn get_redis_connection(port: u16) -> Connection {
let connection_info = format!("redis://127.0.0.1:{}", port);
let client = Client::open(connection_info).unwrap();
let mut attempts = 0;
loop {
match client.get_connection() {
Ok(mut conn) => {
if redis::cmd("PING").query::<String>(&mut conn).is_ok() {
return conn;
}
}
Err(e) => {
if attempts >= 120 {
panic!(
"Failed to connect to Redis server after 120 attempts: {}",
e
);
}
}
}
attempts += 1;
std::thread::sleep(Duration::from_millis(100));
}
}
// A guard to ensure the server process is killed when it goes out of scope
struct ServerProcessGuard {
process: Child,
test_dir: String,
}
impl Drop for ServerProcessGuard {
fn drop(&mut self) {
println!("Killing server process (pid: {})...", self.process.id());
if let Err(e) = self.process.kill() {
eprintln!("Failed to kill server process: {}", e);
}
match self.process.wait() {
Ok(status) => println!("Server process exited with: {}", status),
Err(e) => eprintln!("Failed to wait on server process: {}", e),
}
// Clean up the specific test directory
println!("Cleaning up test directory: {}", self.test_dir);
if let Err(e) = std::fs::remove_dir_all(&self.test_dir) {
eprintln!("Failed to clean up test directory: {}", e);
}
}
}
// Helper to set up the server and return a connection
fn setup_server() -> (ServerProcessGuard, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16400);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", port);
// Clean up previous test data
if std::path::Path::new(&test_dir).exists() {
let _ = std::fs::remove_dir_all(&test_dir);
}
std::fs::create_dir_all(&test_dir).unwrap();
// Start the server in a subprocess
let child = Command::new("cargo")
.args(&[
"run",
"--",
"--dir",
&test_dir,
"--port",
&port.to_string(),
"--debug",
])
.spawn()
.expect("Failed to start server process");
// Create a new guard that also owns the test directory path
let guard = ServerProcessGuard {
process: child,
test_dir,
};
// Give the server time to build and start (cargo run may compile first)
std::thread::sleep(Duration::from_millis(2500));
(guard, port)
}
async fn cleanup_keys(conn: &mut Connection) {
let keys: Vec<String> = redis::cmd("KEYS").arg("*").query(conn).unwrap();
if !keys.is_empty() {
for key in keys {
let _: () = redis::cmd("DEL").arg(key).query(conn).unwrap();
}
}
}
#[tokio::test]
async fn all_tests() {
let (_server_guard, port) = setup_server();
let mut conn = get_redis_connection(port);
// Run all tests using the same connection
test_basic_ping(&mut conn).await;
test_string_operations(&mut conn).await;
test_incr_operations(&mut conn).await;
test_hash_operations(&mut conn).await;
test_expiration(&mut conn).await;
test_scan_operations(&mut conn).await;
test_scan_with_count(&mut conn).await;
test_hscan_operations(&mut conn).await;
test_transaction_operations(&mut conn).await;
test_discard_transaction(&mut conn).await;
test_type_command(&mut conn).await;
test_info_command(&mut conn).await;
}
async fn test_basic_ping(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: String = redis::cmd("PING").query(conn).unwrap();
assert_eq!(result, "PONG");
cleanup_keys(conn).await;
}
async fn test_string_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("key", "value").unwrap();
let result: String = conn.get("key").unwrap();
assert_eq!(result, "value");
let result: Option<String> = conn.get("noexist").unwrap();
assert_eq!(result, None);
let deleted: i32 = conn.del("key").unwrap();
assert_eq!(deleted, 1);
let result: Option<String> = conn.get("key").unwrap();
assert_eq!(result, None);
cleanup_keys(conn).await;
}
async fn test_incr_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: i32 = redis::cmd("INCR").arg("counter").query(conn).unwrap();
assert_eq!(result, 1);
let result: i32 = redis::cmd("INCR").arg("counter").query(conn).unwrap();
assert_eq!(result, 2);
let _: () = conn.set("string", "hello").unwrap();
let result: RedisResult<i32> = redis::cmd("INCR").arg("string").query(conn);
assert!(result.is_err());
cleanup_keys(conn).await;
}
async fn test_hash_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: i32 = conn.hset("hash", "field1", "value1").unwrap();
assert_eq!(result, 1);
let result: String = conn.hget("hash", "field1").unwrap();
assert_eq!(result, "value1");
let _: () = conn.hset("hash", "field2", "value2").unwrap();
let _: () = conn.hset("hash", "field3", "value3").unwrap();
let result: std::collections::HashMap<String, String> = conn.hgetall("hash").unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result.get("field1").unwrap(), "value1");
assert_eq!(result.get("field2").unwrap(), "value2");
assert_eq!(result.get("field3").unwrap(), "value3");
let result: i32 = conn.hlen("hash").unwrap();
assert_eq!(result, 3);
let result: bool = conn.hexists("hash", "field1").unwrap();
assert_eq!(result, true);
let result: bool = conn.hexists("hash", "noexist").unwrap();
assert_eq!(result, false);
let result: i32 = conn.hdel("hash", "field1").unwrap();
assert_eq!(result, 1);
let mut result: Vec<String> = conn.hkeys("hash").unwrap();
result.sort();
assert_eq!(result, vec!["field2", "field3"]);
let mut result: Vec<String> = conn.hvals("hash").unwrap();
result.sort();
assert_eq!(result, vec!["value2", "value3"]);
cleanup_keys(conn).await;
}
async fn test_expiration(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set_ex("expkey", "value", 1).unwrap();
let result: i32 = conn.ttl("expkey").unwrap();
assert!(result == 1 || result == 0);
let result: bool = conn.exists("expkey").unwrap();
assert_eq!(result, true);
sleep(Duration::from_millis(1100)).await;
let result: Option<String> = conn.get("expkey").unwrap();
assert_eq!(result, None);
let result: i32 = conn.ttl("expkey").unwrap();
assert_eq!(result, -2);
let result: bool = conn.exists("expkey").unwrap();
assert_eq!(result, false);
cleanup_keys(conn).await;
}
async fn test_scan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..5 {
let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("SCAN")
.arg(0)
.arg("MATCH")
.arg("key*")
.arg("COUNT")
.arg(10)
.query(conn)
.unwrap();
let (cursor, keys) = result;
assert_eq!(cursor, 0);
assert_eq!(keys.len(), 5);
cleanup_keys(conn).await;
}
async fn test_scan_with_count(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..15 {
let _: () = conn.set(format!("scan_key{}", i), i).unwrap();
}
let mut cursor = 0;
let mut all_keys = std::collections::HashSet::new();
loop {
let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg("scan_key*")
.arg("COUNT")
.arg(5)
.query(conn)
.unwrap();
for key in keys {
all_keys.insert(key);
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
assert_eq!(all_keys.len(), 15);
cleanup_keys(conn).await;
}
async fn test_hscan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..3 {
let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("HSCAN")
.arg("testhash")
.arg(0)
.arg("MATCH")
.arg("*")
.arg("COUNT")
.arg(10)
.query(conn)
.unwrap();
let (cursor, fields) = result;
assert_eq!(cursor, 0);
assert_eq!(fields.len(), 6);
cleanup_keys(conn).await;
}
async fn test_transaction_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap();
let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap();
let result: String = conn.get("key1").unwrap();
assert_eq!(result, "value1");
let result: String = conn.get("key2").unwrap();
assert_eq!(result, "value2");
cleanup_keys(conn).await;
}
async fn test_discard_transaction(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap();
let _: () = redis::cmd("DISCARD").query(conn).unwrap();
let result: Option<String> = conn.get("discard").unwrap();
assert_eq!(result, None);
cleanup_keys(conn).await;
}
async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = conn.set("string", "value").unwrap();
let result: String = redis::cmd("TYPE").arg("string").query(conn).unwrap();
assert_eq!(result, "string");
let _: () = conn.hset("hash", "field", "value").unwrap();
let result: String = redis::cmd("TYPE").arg("hash").query(conn).unwrap();
assert_eq!(result, "hash");
let result: String = redis::cmd("TYPE").arg("noexist").query(conn).unwrap();
assert_eq!(result, "none");
cleanup_keys(conn).await;
}
async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: String = redis::cmd("INFO").query(conn).unwrap();
assert!(result.contains("redis_version"));
let result: String = redis::cmd("INFO").arg("replication").query(conn).unwrap();
assert!(result.contains("role:master"));
cleanup_keys(conn).await;
}

608
herodb/tests/redis_tests.rs Normal file
View File

@@ -0,0 +1,608 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up and create test directory
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: true,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect to test server: {}", e),
}
}
}
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn test_basic_ping() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
}
#[tokio::test]
async fn test_string_operations() {
let (mut server, port) = start_test_server("string").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SET
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test GET non-existent key
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("$-1")); // NULL response
// Test DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test GET after DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("$-1")); // NULL response
}
#[tokio::test]
async fn test_incr_operations() {
let (mut server, port) = start_test_server("incr").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INCR on non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("1"));
// Test INCR on existing key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("2"));
// Test INCR on string value (should fail)
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await;
assert!(response.contains("ERR"));
}
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test HSET
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("1")); // 1 new field
// Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("value1"));
// Test HSET multiple fields
let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("3"));
// Test HEXISTS
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("0"));
// Test HDEL
let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HKEYS
let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field2"));
assert!(response.contains("field3"));
assert!(!response.contains("field1")); // Should be deleted
// Test HVALS
let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("value2"));
assert!(response.contains("value3"));
}
#[tokio::test]
async fn test_expiration() {
let (mut server, port) = start_test_server("expiration").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SETEX (expire in 1 second)
let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await;
assert!(response.contains("OK"));
// Test TTL
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds
// Test EXISTS
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1"));
// Wait for expiration
sleep(Duration::from_millis(1100)).await;
// Test GET after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
// Test TTL after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("-2")); // Key doesn't exist
// Test EXISTS after expiration
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("0"));
}
#[tokio::test]
async fn test_scan_operations() {
let (mut server, port) = start_test_server("scan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up test data
for i in 0..5 {
let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test SCAN
let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("key"));
// Test KEYS
let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await;
assert!(response.contains("key0"));
assert!(response.contains("key1"));
}
#[tokio::test]
async fn test_hscan_operations() {
let (mut server, port) = start_test_server("hscan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up hash data
for i in 0..3 {
let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test HSCAN
let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field"));
assert!(response.contains("value"));
}
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transaction").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued commands
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("QUEUED"));
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("QUEUED"));
// Test EXEC
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("OK")); // Should contain results of executed commands
// Verify commands were executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
assert!(response.contains("value1"));
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await;
assert!(response.contains("value2"));
}
#[tokio::test]
async fn test_discard_transaction() {
let (mut server, port) = start_test_server("discard").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued command
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("QUEUED"));
// Test DISCARD
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("OK"));
// Verify command was not executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
}
#[tokio::test]
async fn test_type_command() {
let (mut server, port) = start_test_server("type").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
assert!(response.contains("string"));
// Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("none"));
}
#[tokio::test]
async fn test_config_commands() {
let (mut server, port) = start_test_server("config").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test CONFIG GET databases
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await;
assert!(response.contains("databases"));
assert!(response.contains("16"));
// Test CONFIG GET dir
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await;
assert!(response.contains("dir"));
assert!(response.contains("/tmp/herodb_test_config"));
}
#[tokio::test]
async fn test_info_command() {
let (mut server, port) = start_test_server("info").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INFO
let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await;
assert!(response.contains("redis_version"));
// Test INFO replication
let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await;
assert!(response.contains("role:master"));
}
#[tokio::test]
async fn test_error_handling() {
let (mut server, port) = start_test_server("error").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test WRONGTYPE error - try to use hash command on string
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await;
assert!(response.contains("WRONGTYPE"));
// Test unknown command
let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await;
assert!(response.contains("unknown cmd") || response.contains("ERR"));
// Test EXEC without MULTI
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("ERR"));
// Test DISCARD without MULTI
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("ERR"));
}
#[tokio::test]
async fn test_list_operations() {
let (mut server, port) = start_test_server("list").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test LPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await;
assert!(response.contains("2")); // 2 elements
// Test RPUSH
let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await;
assert!(response.contains("4")); // 4 elements
// Test LLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("4"));
// Test LRANGE
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
// Test LINDEX
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nd\r\n");
// Test LREM
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a
let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await;
assert!(response.contains("1"));
// Test LTRIM
let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await;
assert!(response.contains("OK"));
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("1"));
}

View File

@@ -0,0 +1,228 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::time::sleep;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000);
// Get a unique port for this test
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: true,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
#[tokio::test]
async fn test_basic_redis_functionality() {
let (mut server, port) = start_test_server("basic").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..10 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Test PING
let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
// Test SET
let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test HSET
let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("1"));
// Test HGET
let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
assert!(response.contains("value"));
// Test EXISTS
let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test TTL
let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("-1")); // No expiration
// Test TYPE
let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await;
assert!(response.contains("string"));
// Test QUIT to close connection gracefully
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
stream.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK"));
// Ensure the stream is closed
stream.shutdown().await.unwrap();
// Stop the server
server_handle.abort();
println!("✅ All basic Redis functionality tests passed!");
}
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash_ops").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Test HSET multiple fields
let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HEXISTS
let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HLEN
let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("2"));
// Test HSCAN
let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Stop the server
// For hash operations, we don't have a persistent stream, so we'll just abort the server.
// The server should handle closing its connections.
server_handle.abort();
println!("✅ All hash operations tests passed!");
}
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transactions").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
// Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Test MULTI
stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK"));
// Test queued commands
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
// Test EXEC
stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); // Should contain array of OK responses
// Verify commands were executed
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value1"));
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value2"));
// Stop the server
server_handle.abort();
println!("✅ All transaction operations tests passed!");
}

View File

@@ -0,0 +1,183 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_simple_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect to test server: {}", e),
}
}
}
#[tokio::test]
async fn test_basic_ping_simple() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
}
#[tokio::test]
async fn test_hset_clean_db() {
let (mut server, port) = start_test_server("hset_clean").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test HSET - should return 1 for new field
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response);
// Test HGET
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HGET response: {}", response);
assert!(response.contains("value1"));
}
#[tokio::test]
async fn test_type_command_simple() {
let (mut server, port) = start_test_server("type").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
println!("TYPE string response: {}", response);
assert!(response.contains("string"));
// Test hash type
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
println!("TYPE hash response: {}", response);
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
println!("TYPE noexist response: {}", response);
assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response);
}
#[tokio::test]
async fn test_hexists_simple() {
let (mut server, port) = start_test_server("hexists").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Set up hash
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
// Test HEXISTS for existing field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HEXISTS existing field response: {}", response);
assert!(response.contains("1"));
// Test HEXISTS for non-existent field
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
println!("HEXISTS non-existent field response: {}", response);
assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response);
}

892
herodb/tests/usage_suite.rs Normal file
View File

@@ -0,0 +1,892 @@
use herodb::{options::DBOption, server::Server};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::{sleep, Duration};
// =========================
// Helpers
// =========================
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17100);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_usage_suite_{}", test_name);
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
debug: false,
encrypt: false,
encryption_key: None,
};
let server = Server::new(option).await;
(server, port)
}
async fn spawn_listener(server: Server, port: u16) {
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("bind listener");
loop {
match listener.accept().await {
Ok((stream, _)) => {
let mut s_clone = server.clone();
tokio::spawn(async move {
let _ = s_clone.handle(stream).await;
});
}
Err(_e) => break,
}
}
});
}
/// Build RESP array for args ["PING"] -> "*1\r\n$4\r\nPING\r\n"
fn build_resp(args: &[&str]) -> String {
let mut s = format!("*{}\r\n", args.len());
for a in args {
s.push_str(&format!("${}\r\n{}\r\n", a.len(), a));
}
s
}
async fn connect(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(s) => return s,
Err(_) if attempts < 30 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
}
Err(e) => panic!("Failed to connect: {}", e),
}
}
}
fn find_crlf(buf: &[u8], start: usize) -> Option<usize> {
let mut i = start;
while i + 1 < buf.len() {
if buf[i] == b'\r' && buf[i + 1] == b'\n' {
return Some(i);
}
i += 1;
}
None
}
fn parse_number_i64(buf: &[u8], start: usize, end: usize) -> Option<i64> {
let s = std::str::from_utf8(&buf[start..end]).ok()?;
s.parse::<i64>().ok()
}
// Return number of bytes that make up a complete RESP element starting at 'i', or None if incomplete.
fn parse_elem(buf: &[u8], i: usize) -> Option<usize> {
if i >= buf.len() {
return None;
}
match buf[i] {
b'+' | b'-' | b':' => {
let end = find_crlf(buf, i + 1)?;
Some(end + 2 - i)
}
b'$' => {
let hdr_end = find_crlf(buf, i + 1)?;
let n = parse_number_i64(buf, i + 1, hdr_end)?;
if n < 0 {
// Null bulk string: only header
Some(hdr_end + 2 - i)
} else {
let need = hdr_end + 2 + (n as usize) + 2;
if need <= buf.len() {
Some(need - i)
} else {
None
}
}
}
b'*' => {
let hdr_end = find_crlf(buf, i + 1)?;
let n = parse_number_i64(buf, i + 1, hdr_end)?;
if n < 0 {
// Null array: only header
Some(hdr_end + 2 - i)
} else {
let mut j = hdr_end + 2;
for _ in 0..(n as usize) {
let consumed = parse_elem(buf, j)?;
j += consumed;
}
Some(j - i)
}
}
_ => None,
}
}
fn resp_frame_len(buf: &[u8]) -> Option<usize> {
parse_elem(buf, 0)
}
async fn read_full_resp(stream: &mut TcpStream) -> String {
let mut buf: Vec<u8> = Vec::with_capacity(8192);
let mut tmp = vec![0u8; 4096];
loop {
if let Some(total) = resp_frame_len(&buf) {
if buf.len() >= total {
return String::from_utf8_lossy(&buf[..total]).to_string();
}
}
match tokio::time::timeout(Duration::from_secs(2), stream.read(&mut tmp)).await {
Ok(Ok(n)) => {
if n == 0 {
if let Some(total) = resp_frame_len(&buf) {
if buf.len() >= total {
return String::from_utf8_lossy(&buf[..total]).to_string();
}
}
return String::from_utf8_lossy(&buf).to_string();
}
buf.extend_from_slice(&tmp[..n]);
}
Ok(Err(e)) => panic!("read error: {}", e),
Err(_) => panic!("timeout waiting for reply"),
}
if buf.len() > 8 * 1024 * 1024 {
panic!("reply too large");
}
}
}
async fn send_cmd(stream: &mut TcpStream, args: &[&str]) -> String {
let req = build_resp(args);
stream.write_all(req.as_bytes()).await.unwrap();
read_full_resp(stream).await
}
// Assert helpers with clearer output
fn assert_contains(haystack: &str, needle: &str, ctx: &str) {
assert!(
haystack.contains(needle),
"ASSERT CONTAINS failed: '{}' not found in response.\nContext: {}\nResponse:\n{}",
needle,
ctx,
haystack
);
}
fn assert_eq_resp(actual: &str, expected: &str, ctx: &str) {
assert!(
actual == expected,
"ASSERT EQUAL failed.\nContext: {}\nExpected:\n{:?}\nActual:\n{:?}",
ctx,
expected,
actual
);
}
/// Extract the payload of a single RESP Bulk String reply.
/// Example input:
/// "$5\r\nhello\r\n" -> Some("hello".to_string())
fn extract_bulk_payload(resp: &str) -> Option<String> {
// find first CRLF after "$len"
let first = resp.find("\r\n")?;
let after = &resp[(first + 2)..];
// find next CRLF ending payload
let second = after.find("\r\n")?;
Some(after[..second].to_string())
}
// =========================
// Test suites
// =========================
#[tokio::test]
async fn test_01_connection_and_info() {
let (server, port) = start_test_server("conn_info").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// redis-cli may send COMMAND DOCS, our server replies empty array; harmless.
let pong = send_cmd(&mut s, &["PING"]).await;
assert_contains(&pong, "PONG", "PING should return PONG");
let echo = send_cmd(&mut s, &["ECHO", "hello"]).await;
assert_contains(&echo, "hello", "ECHO hello");
// INFO (general)
let info = send_cmd(&mut s, &["INFO"]).await;
assert_contains(&info, "redis_version", "INFO should include redis_version");
// INFO REPLICATION (static stub)
let repl = send_cmd(&mut s, &["INFO", "replication"]).await;
assert_contains(&repl, "role:master", "INFO replication role");
// CONFIG GET subset
let cfg = send_cmd(&mut s, &["CONFIG", "GET", "databases"]).await;
assert_contains(&cfg, "databases", "CONFIG GET databases");
assert_contains(&cfg, "16", "CONFIG GET databases value");
// CLIENT name
let setname = send_cmd(&mut s, &["CLIENT", "SETNAME", "myapp"]).await;
assert_contains(&setname, "OK", "CLIENT SETNAME");
let getname = send_cmd(&mut s, &["CLIENT", "GETNAME"]).await;
assert_contains(&getname, "myapp", "CLIENT GETNAME");
// SELECT db
let sel = send_cmd(&mut s, &["SELECT", "0"]).await;
assert_contains(&sel, "OK", "SELECT 0");
// QUIT should close connection after sending OK
let quit = send_cmd(&mut s, &["QUIT"]).await;
assert_contains(&quit, "OK", "QUIT should return OK");
}
#[tokio::test]
async fn test_02_strings_and_expiry() {
let (server, port) = start_test_server("strings").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// SET / GET
let set = send_cmd(&mut s, &["SET", "user:1", "alice"]).await;
assert_contains(&set, "OK", "SET user:1 alice");
let get = send_cmd(&mut s, &["GET", "user:1"]).await;
assert_contains(&get, "alice", "GET user:1");
// EXISTS / DEL
let ex1 = send_cmd(&mut s, &["EXISTS", "user:1"]).await;
assert_contains(&ex1, "1", "EXISTS user:1");
let del = send_cmd(&mut s, &["DEL", "user:1"]).await;
assert_contains(&del, "1", "DEL user:1");
let ex0 = send_cmd(&mut s, &["EXISTS", "user:1"]).await;
assert_contains(&ex0, "0", "EXISTS after DEL");
// INCR behavior
let i1 = send_cmd(&mut s, &["INCR", "count"]).await;
assert_contains(&i1, "1", "INCR new key -> 1");
let i2 = send_cmd(&mut s, &["INCR", "count"]).await;
assert_contains(&i2, "2", "INCR existing -> 2");
let _ = send_cmd(&mut s, &["SET", "notnum", "abc"]).await;
let ierr = send_cmd(&mut s, &["INCR", "notnum"]).await;
assert_contains(&ierr, "ERR", "INCR on non-numeric should ERR");
// Expiration via SET EX
let setex = send_cmd(&mut s, &["SET", "tmp:1", "boom", "EX", "1"]).await;
assert_contains(&setex, "OK", "SET tmp:1 EX 1");
let g_immediate = send_cmd(&mut s, &["GET", "tmp:1"]).await;
assert_contains(&g_immediate, "boom", "GET tmp:1 immediately");
let ttl = send_cmd(&mut s, &["TTL", "tmp:1"]).await;
// Implementation returns a SimpleString, accept any numeric content
assert!(
ttl.contains("1") || ttl.contains("0"),
"TTL should be 1 or 0, got: {}",
ttl
);
sleep(Duration::from_millis(1100)).await;
let g_after = send_cmd(&mut s, &["GET", "tmp:1"]).await;
assert_contains(&g_after, "$-1", "GET tmp:1 after expiry -> Null");
// TYPE
let _ = send_cmd(&mut s, &["SET", "t", "v"]).await;
let ty = send_cmd(&mut s, &["TYPE", "t"]).await;
assert_contains(&ty, "string", "TYPE string key");
let ty_none = send_cmd(&mut s, &["TYPE", "noexist"]).await;
assert_contains(&ty_none, "none", "TYPE nonexistent");
}
#[tokio::test]
async fn test_03_scan_and_keys() {
let (server, port) = start_test_server("scan").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
for i in 0..5 {
let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await;
}
let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await;
assert_contains(&scan, "key0", "SCAN should return keys with MATCH");
assert_contains(&scan, "key4", "SCAN should return last key");
let keys = send_cmd(&mut s, &["KEYS", "*"]).await;
assert_contains(&keys, "key0", "KEYS * includes key0");
assert_contains(&keys, "key4", "KEYS * includes key4");
}
#[tokio::test]
async fn test_04_hashes_suite() {
let (server, port) = start_test_server("hashes").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// HSET (single, returns number of new fields)
let h1 = send_cmd(&mut s, &["HSET", "profile:1", "name", "alice"]).await;
assert_contains(&h1, "1", "HSET new field -> 1");
// HGET
let hg = send_cmd(&mut s, &["HGET", "profile:1", "name"]).await;
assert_contains(&hg, "alice", "HGET existing field");
// HSET multiple
let h2 = send_cmd(&mut s, &["HSET", "profile:1", "age", "30", "city", "paris"]).await;
assert_contains(&h2, "2", "HSET added 2 new fields");
// HMGET
let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await;
assert_contains(&hmg, "alice", "HMGET name");
assert_contains(&hmg, "30", "HMGET age");
assert_contains(&hmg, "paris", "HMGET city");
assert_contains(&hmg, "$-1", "HMGET non-existent -> Null");
// HGETALL
let hga = send_cmd(&mut s, &["HGETALL", "profile:1"]).await;
assert_contains(&hga, "name", "HGETALL contains name");
assert_contains(&hga, "alice", "HGETALL contains alice");
// HLEN
let hlen = send_cmd(&mut s, &["HLEN", "profile:1"]).await;
assert_contains(&hlen, "3", "HLEN is 3");
// HEXISTS
let hex1 = send_cmd(&mut s, &["HEXISTS", "profile:1", "age"]).await;
assert_contains(&hex1, "1", "HEXISTS age true");
let hex0 = send_cmd(&mut s, &["HEXISTS", "profile:1", "nope"]).await;
assert_contains(&hex0, "0", "HEXISTS nope false");
// HKEYS / HVALS
let hkeys = send_cmd(&mut s, &["HKEYS", "profile:1"]).await;
assert_contains(&hkeys, "name", "HKEYS includes name");
let hvals = send_cmd(&mut s, &["HVALS", "profile:1"]).await;
assert_contains(&hvals, "alice", "HVALS includes alice");
// HSETNX
let hnx0 = send_cmd(&mut s, &["HSETNX", "profile:1", "name", "bob"]).await;
assert_contains(&hnx0, "0", "HSETNX existing field -> 0");
let hnx1 = send_cmd(&mut s, &["HSETNX", "profile:1", "nickname", "ali"]).await;
assert_contains(&hnx1, "1", "HSETNX new field -> 1");
// HSCAN
let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await;
assert_contains(&hscan, "name", "HSCAN matches fields starting with n");
assert_contains(&hscan, "nickname", "HSCAN nickname present");
// HDEL
let hdel = send_cmd(&mut s, &["HDEL", "profile:1", "city", "age"]).await;
assert_contains(&hdel, "2", "HDEL removed two fields");
}
#[tokio::test]
async fn test_05_lists_suite_including_blpop() {
let (server, port) = start_test_server("lists").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut a = connect(port).await;
// LPUSH / RPUSH / LLEN
let lp = send_cmd(&mut a, &["LPUSH", "q:jobs", "a", "b"]).await;
assert_contains(&lp, "2", "LPUSH added 2, length 2");
let rp = send_cmd(&mut a, &["RPUSH", "q:jobs", "c"]).await;
assert_contains(&rp, "3", "RPUSH now length 3");
let llen = send_cmd(&mut a, &["LLEN", "q:jobs"]).await;
assert_contains(&llen, "3", "LLEN 3");
// LINDEX / LRANGE
let lidx = send_cmd(&mut a, &["LINDEX", "q:jobs", "0"]).await;
assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b");
let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]");
// LTRIM
let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await;
assert_contains(&ltrim, "OK", "LTRIM OK");
let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]");
// LREM remove first occurrence of b
let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await;
assert_contains(&lrem, "1", "LREM removed 1");
// LPOP and RPOP
let lpop1 = send_cmd(&mut a, &["LPOP", "q:jobs"]).await;
assert_contains(&lpop1, "$1\r\na\r\n", "LPOP returns a");
let rpop_empty = send_cmd(&mut a, &["RPOP", "q:jobs"]).await; // empty now
assert_contains(&rpop_empty, "$-1", "RPOP on empty -> Null");
// LPOP with count on empty -> []
let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await;
assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array");
// BLPOP: block on one client, push from another
let c1 = connect(port).await;
let mut c2 = connect(port).await;
// Start BLPOP on c1
let blpop_task = tokio::spawn(async move {
let mut c1_local = c1;
send_cmd(&mut c1_local, &["BLPOP", "q:block", "5"]).await
});
// Give it time to register waiter
sleep(Duration::from_millis(150)).await;
// Push from c2 to wake BLPOP
let _ = send_cmd(&mut c2, &["LPUSH", "q:block", "x"]).await;
// Await BLPOP result
let blpop_res = blpop_task.await.expect("BLPOP task join");
assert_contains(&blpop_res, "q:block", "BLPOP returned key");
assert_contains(&blpop_res, "x", "BLPOP returned element");
}
#[tokio::test]
async fn test_06_flushdb_suite() {
let (server, port) = start_test_server("flushdb").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
let _ = send_cmd(&mut s, &["SET", "k1", "v1"]).await;
let _ = send_cmd(&mut s, &["HSET", "h1", "f", "v"]).await;
let _ = send_cmd(&mut s, &["LPUSH", "l1", "a"]).await;
let keys_before = send_cmd(&mut s, &["KEYS", "*"]).await;
assert_contains(&keys_before, "k1", "have string key before FLUSHDB");
assert_contains(&keys_before, "h1", "have hash key before FLUSHDB");
assert_contains(&keys_before, "l1", "have list key before FLUSHDB");
let fl = send_cmd(&mut s, &["FLUSHDB"]).await;
assert_contains(&fl, "OK", "FLUSHDB OK");
let keys_after = send_cmd(&mut s, &["KEYS", "*"]).await;
assert_eq_resp(&keys_after, "*0\r\n", "DB should be empty after FLUSHDB");
}
#[tokio::test]
async fn test_07_age_stateless_suite() {
let (server, port) = start_test_server("age_stateless").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// GENENC -> [recipient, identity]
let gen = send_cmd(&mut s, &["AGE", "GENENC"]).await;
assert!(
gen.starts_with("*2\r\n$"),
"AGE GENENC should return array [recipient, identity], got:\n{}",
gen
);
// Parse simple RESP array of two bulk strings to extract keys
fn parse_two_bulk_array(resp: &str) -> (String, String) {
// naive parse for tests
let mut lines = resp.lines();
let _ = lines.next(); // *2
// $len
let _ = lines.next();
let recip = lines.next().unwrap_or("").to_string();
let _ = lines.next();
let ident = lines.next().unwrap_or("").to_string();
(recip, ident)
}
let (recipient, identity) = parse_two_bulk_array(&gen);
assert!(
recipient.starts_with("age1") && identity.starts_with("AGE-SECRET-KEY-1"),
"Unexpected AGE key formats.\nrecipient: {}\nidentity: {}",
recipient,
identity
);
// ENCRYPT / DECRYPT
let ct = send_cmd(&mut s, &["AGE", "ENCRYPT", &recipient, "hello world"]).await;
let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPT");
let pt = send_cmd(&mut s, &["AGE", "DECRYPT", &identity, &ct_b64]).await;
assert_contains(&pt, "hello world", "AGE DECRYPT round-trip");
// GENSIGN -> [verify_pub_b64, sign_secret_b64]
let gensign = send_cmd(&mut s, &["AGE", "GENSIGN"]).await;
let (verify_pub, sign_secret) = parse_two_bulk_array(&gensign);
assert!(
!verify_pub.is_empty() && !sign_secret.is_empty(),
"GENSIGN returned empty keys"
);
// SIGN / VERIFY
let sig = send_cmd(&mut s, &["AGE", "SIGN", &sign_secret, "msg"]).await;
let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGN");
let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await;
assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature");
let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await;
assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature");
}
#[tokio::test]
async fn test_08_age_persistent_named_suite() {
let (server, port) = start_test_server("age_persistent").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// KEYGEN + ENCRYPTNAME/DECRYPTNAME
let kg = send_cmd(&mut s, &["AGE", "KEYGEN", "app1"]).await;
assert!(
kg.starts_with("*2\r\n"),
"AGE KEYGEN should return [recipient, identity], got:\n{}",
kg
);
let ct = send_cmd(&mut s, &["AGE", "ENCRYPTNAME", "app1", "hello"]).await;
let ct_b64 = extract_bulk_payload(&ct).expect("Failed to parse bulk payload from ENCRYPTNAME");
let pt = send_cmd(&mut s, &["AGE", "DECRYPTNAME", "app1", &ct_b64]).await;
assert_contains(&pt, "hello", "DECRYPTNAME round-trip");
// SIGNKEYGEN + SIGNNAME/VERIFYNAME
let skg = send_cmd(&mut s, &["AGE", "SIGNKEYGEN", "app1"]).await;
assert!(
skg.starts_with("*2\r\n"),
"AGE SIGNKEYGEN should return [verify_pub, sign_secret], got:\n{}",
skg
);
let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"] ).await;
let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME");
let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await;
assert_contains(&v1, "1", "VERIFYNAME valid => 1");
let v0 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "bad", &sig_b64]).await;
assert_contains(&v0, "0", "VERIFYNAME invalid => 0");
// AGE LIST
let lst = send_cmd(&mut s, &["AGE", "LIST"]).await;
assert_contains(&lst, "encpub", "AGE LIST label encpub");
assert_contains(&lst, "app1", "AGE LIST includes app1");
}
#[tokio::test]
async fn test_10_expire_pexpire_persist() {
let (server, port) = start_test_server("expire_suite").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// EXPIRE: seconds
let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await;
let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await;
assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert!(
ttl1.contains("1") || ttl1.contains("0"),
"TTL exp:s should be 1 or 0, got: {}",
ttl1
);
sleep(Duration::from_millis(1100)).await;
let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await;
assert_contains(&get_after, "$-1", "GET after expiry should be Null");
let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert_contains(&ttl_after, "-2", "TTL after expiry -> -2");
let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await;
assert_contains(&exists_after, "0", "EXISTS after expiry -> 0");
// PEXPIRE: milliseconds
let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await;
let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await;
assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)");
let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await;
assert!(
ttl_ms1.contains("1") || ttl_ms1.contains("0"),
"TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}",
ttl_ms1
);
sleep(Duration::from_millis(1600)).await;
let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await;
assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0");
// PERSIST: remove expiration
let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await;
let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await;
let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert!(
ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"),
"TTL exp:persist should be >=0 before persist, got: {}",
ttl_pre
);
let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist1, "1", "PERSIST should remove expiration");
let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)");
// Second persist should return 0 (nothing to remove)
let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)");
}
#[tokio::test]
async fn test_11_set_with_options() {
let (server, port) = start_test_server("set_opts").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// SET with GET on non-existing key -> returns Null, sets value
let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await;
assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist");
let g1 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g1, "v1", "GET s1 after first SET");
// SET with GET should return old value, then set to new
let set_get2 = send_cmd(&mut s, &["SET", "s1", "v2", "GET"]).await;
assert_contains(&set_get2, "v1", "SET s1 v2 GET returns previous value v1");
let g2 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g2, "v2", "GET s1 now v2");
// NX prevents update when key exists; with GET should return Null and not change
let set_nx = send_cmd(&mut s, &["SET", "s1", "v3", "NX", "GET"]).await;
assert_contains(&set_nx, "$-1", "SET s1 v3 NX GET returns Null when not set");
let g3 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g3, "v2", "GET s1 remains v2 after NX prevented write");
// NX allows set when key does not exist
let set_nx2 = send_cmd(&mut s, &["SET", "s2", "v10", "NX"]).await;
assert_contains(&set_nx2, "OK", "SET s2 v10 NX -> OK for new key");
let g4 = send_cmd(&mut s, &["GET", "s2"]).await;
assert_contains(&g4, "v10", "GET s2 is v10");
// XX requires existing key; with GET returns old value and sets new
let set_xx = send_cmd(&mut s, &["SET", "s2", "v11", "XX", "GET"]).await;
assert_contains(&set_xx, "v10", "SET s2 v11 XX GET returns previous v10");
let g5 = send_cmd(&mut s, &["GET", "s2"]).await;
assert_contains(&g5, "v11", "GET s2 is now v11");
// PX expiration path via SET options
let set_px = send_cmd(&mut s, &["SET", "s3", "vpx", "PX", "500"]).await;
assert_contains(&set_px, "OK", "SET s3 vpx PX 500 -> OK");
let ttl_px1 = send_cmd(&mut s, &["TTL", "s3"]).await;
assert!(
ttl_px1.contains("0") || ttl_px1.contains("1"),
"TTL s3 immediately after PX should be 1 or 0, got: {}",
ttl_px1
);
sleep(Duration::from_millis(650)).await;
let g6 = send_cmd(&mut s, &["GET", "s3"]).await;
assert_contains(&g6, "$-1", "GET s3 after PX expiry -> Null");
}
#[tokio::test]
async fn test_09_mget_mset_and_variadic_exists_del() {
let (server, port) = start_test_server("mget_mset_variadic").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// MSET multiple keys
let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await;
assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK");
// MGET should return values and Null for missing
let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await;
// Expect an array with 4 entries; verify payloads
assert_contains(&mget, "v1", "MGET k1");
assert_contains(&mget, "v2", "MGET k2");
assert_contains(&mget, "v3", "MGET k3");
assert_contains(&mget, "$-1", "MGET missing returns Null");
// EXISTS variadic: count how many exist
let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await;
// Server returns SimpleString numeric, e.g. +2
assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2");
// DEL variadic: delete multiple keys, return count deleted
let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await;
assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2");
// Verify deletion
let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await;
assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0");
// MGET after deletion should include Nulls for deleted keys
let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await;
assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null");
assert_contains(&mget_after, "v2", "MGET k2 remains");
assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null");
}
#[tokio::test]
async fn test_12_hash_incr() {
let (server, port) = start_test_server("hash_incr").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// Integer increments
let _ = send_cmd(&mut s, &["HSET", "hinc", "a", "1"]).await;
let r1 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "2"]).await;
assert_contains(&r1, "3", "HINCRBY hinc a 2 -&gt; 3");
let r2 = send_cmd(&mut s, &["HINCRBY", "hinc", "a", "-1"]).await;
assert_contains(&r2, "2", "HINCRBY hinc a -1 -&gt; 2");
let r3 = send_cmd(&mut s, &["HINCRBY", "hinc", "b", "5"]).await;
assert_contains(&r3, "5", "HINCRBY hinc b 5 -&gt; 5");
// HINCRBY error on non-integer field
let _ = send_cmd(&mut s, &["HSET", "hinc", "s", "x"]).await;
let r_err = send_cmd(&mut s, &["HINCRBY", "hinc", "s", "1"]).await;
assert_contains(&r_err, "ERR", "HINCRBY on non-integer field should ERR");
// Float increments
let r4 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "1.5"]).await;
assert_contains(&r4, "1.5", "HINCRBYFLOAT hinc f 1.5 -&gt; 1.5");
let r5 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "f", "2.5"]).await;
// Could be "4", "4.0", or "4.000000", accept "4" substring
assert_contains(&r5, "4", "HINCRBYFLOAT hinc f 2.5 -&gt; 4");
// HINCRBYFLOAT error on non-float field
let _ = send_cmd(&mut s, &["HSET", "hinc", "notf", "abc"]).await;
let r6 = send_cmd(&mut s, &["HINCRBYFLOAT", "hinc", "notf", "1"]).await;
assert_contains(&r6, "ERR", "HINCRBYFLOAT on non-float field should ERR");
}
#[tokio::test]
async fn test_05b_brpop_suite() {
let (server, port) = start_test_server("lists_brpop").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut a = connect(port).await;
// RPUSH some initial data, BRPOP should take from the right
let _ = send_cmd(&mut a, &["RPUSH", "q:rjobs", "1", "2"]).await;
let br_nonblock = send_cmd(&mut a, &["BRPOP", "q:rjobs", "0"]).await;
// Should pop the rightmost element "2"
assert_contains(&br_nonblock, "q:rjobs", "BRPOP returns key");
assert_contains(&br_nonblock, "2", "BRPOP returns rightmost element");
// Now test blocking BRPOP: start blocked client, then RPUSH from another client
let c1 = connect(port).await;
let mut c2 = connect(port).await;
// Start BRPOP on c1
let brpop_task = tokio::spawn(async move {
let mut c1_local = c1;
send_cmd(&mut c1_local, &["BRPOP", "q:blockr", "5"]).await
});
// Give it time to register waiter
sleep(Duration::from_millis(150)).await;
// Push from right to wake BRPOP
let _ = send_cmd(&mut c2, &["RPUSH", "q:blockr", "X"]).await;
// Await BRPOP result
let brpop_res = brpop_task.await.expect("BRPOP task join");
assert_contains(&brpop_res, "q:blockr", "BRPOP returned key");
assert_contains(&brpop_res, "X", "BRPOP returned element");
}
#[tokio::test]
async fn test_13_dbsize() {
let (server, port) = start_test_server("dbsize").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// Initially empty
let n0 = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n0, "0", "DBSIZE initial should be 0");
// Add a string, a hash, and a list -> dbsize = 3
let _ = send_cmd(&mut s, &["SET", "s", "v"]).await;
let _ = send_cmd(&mut s, &["HSET", "h", "f", "v"]).await;
let _ = send_cmd(&mut s, &["LPUSH", "l", "a", "b"]).await;
let n3 = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n3, "3", "DBSIZE after adding s,h,l should be 3");
// Expire the string and wait, dbsize should drop to 2
let _ = send_cmd(&mut s, &["PEXPIRE", "s", "400"]).await;
sleep(Duration::from_millis(500)).await;
let n2 = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n2, "2", "DBSIZE after string expiry should be 2");
// Delete remaining keys and confirm 0
let _ = send_cmd(&mut s, &["DEL", "h"]).await;
let _ = send_cmd(&mut s, &["DEL", "l"]).await;
let n_final = send_cmd(&mut s, &["DBSIZE"]).await;
assert_contains(&n_final, "0", "DBSIZE after deleting all keys should be 0");
}
#[tokio::test]
async fn test_14_expireat_pexpireat() {
use std::time::{SystemTime, UNIX_EPOCH};
let (server, port) = start_test_server("expireat_suite").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
// EXPIREAT: seconds since epoch
let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await;
let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await;
assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await;
assert!(
ttl1.contains("1") || ttl1.contains("0"),
"TTL exp:at:s should be 1 or 0 shortly after EXPIREAT, got: {}",
ttl1
);
sleep(Duration::from_millis(1200)).await;
let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await;
assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0");
// PEXPIREAT: milliseconds since epoch
let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await;
let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await;
assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)");
let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await;
assert!(
ttl2.contains("0") || ttl2.contains("1"),
"TTL exp:at:ms should be 0..1 soon after PEXPIREAT, got: {}",
ttl2
);
sleep(Duration::from_millis(600)).await;
let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await;
assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0");
}

9
supervisor/Cargo.toml Normal file
View File

@@ -0,0 +1,9 @@
[package]
name = "supervisor"
version = "0.1.0"
edition = "2021"
[dependencies]
# The supervisor will eventually depend on the herodb crate.
# We can add this dependency now.
# herodb = { path = "../herodb" }

4
supervisor/src/main.rs Normal file
View File

@@ -0,0 +1,4 @@
fn main() {
println!("Hello from the supervisor crate!");
// Supervisor logic will be implemented here.
}