21 Commits

Author SHA1 Message Date
Maxime Van Hees
c470772a13 Merge branch 'management_rpc_server' 2025-09-22 16:26:53 +02:00
Maxime Van Hees
bd34fd092a Persist backend per database id in admin metadata so restarts and lazy opens always use the correct engine (Sled/Redb) 2025-09-22 15:29:58 +02:00
Maxime Van Hees
8e044a64b7 fix incorrect keycount displayed in database info over RPC calls 2025-09-19 14:04:03 +02:00
Maxime Van Hees
87177f4a07 update documentation about 0.db admin db + symmetric encryption + include RPC examples + asymmetric transpart named key instances for encryption and signatures 2025-09-19 11:55:28 +02:00
Maxime Van Hees
151a6ffbfa fixed test 2025-09-19 10:35:08 +02:00
Maxime Van Hees
8ab841f68c Key generation now automatically derives X25519 keys from Ed25519 keys which allows user to transparantly use their key name for encrypting/decrypting and signing/verifying 2025-09-18 22:37:19 +02:00
Maxime Van Hees
8808c0e9d9 Implemented symmetric encryption; new commands are SYM KEYGEN; SYM ENCRYPT; SYM DECRYPT 2025-09-18 11:59:44 +02:00
Maxime Van Hees
c6b277cc9c fixed DEL showing wrong deletion amount + AGE LIST now returns a list of managed keys names without nested arrays or labels 2025-09-18 00:19:40 +02:00
8331ed032b ... 2025-09-17 07:02:44 +02:00
Maxime Van Hees
b8ca73397d implemented 0.db as admin database architecture + updated test file 2025-09-16 16:06:47 +02:00
Maxime Van Hees
1b15806a85 fix invalid values in RPC response about database instance details 2025-09-15 13:45:37 +02:00
Maxime Van Hees
da325a9659 fix bug where meta files where not auto-created upon starting + fix bug where meta json files were actually binary + improved access control to database instances 2025-09-15 10:34:03 +02:00
Maxime Van Hees
bdf363016a WIP: adding access management control to db instances 2025-09-12 17:11:50 +02:00
Maxime Van Hees
8798bc202e Restore working code 2025-09-11 18:33:09 +02:00
Maxime Van Hees
9fa9832605 combined curret main (with sled) and RPC server 2025-09-11 17:23:46 +02:00
Maxime Van Hees
4bb24b38dd fix typo in README 2025-09-11 15:34:03 +02:00
Maxime Van Hees
f3da14b957 Merge branch 'append' 2025-09-11 15:31:47 +02:00
Maxime Van Hees
5ea34b4445 update variable name as 'gen' is a reserved keyword since Rust 2024 edition 2025-09-11 15:25:26 +02:00
Maxime Van Hees
d9a3b711d1 Update tot Rust 2024 edition + update Cargo.toml file 2025-09-11 15:24:28 +02:00
Maxime Van Hees
d931770e90 Fix test suite + update Cargo.toml file 2025-09-09 16:04:31 +02:00
Timur Gordon
a87ec4dbb5 add readme 2025-08-27 15:39:59 +02:00
45 changed files with 4293 additions and 4473 deletions

1450
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,8 @@
[package]
name = "herodb"
version = "0.0.1"
authors = ["Pin Fang <fpfangpin@hotmail.com>"]
edition = "2021"
authors = ["ThreeFold Tech NV"]
edition = "2024"
[dependencies]
anyhow = "1.0.59"
@@ -23,8 +23,9 @@ sha2 = "0.10"
age = "0.10"
secrecy = "0.8"
ed25519-dalek = "2"
x25519-dalek = "2"
base64 = "0.22"
tantivy = "0.25.0"
jsonrpsee = { version = "0.26.0", features = ["http-client", "ws-client", "server", "macros"] }
[dev-dependencies]
redis = { version = "0.24", features = ["aio", "tokio-comp"] }

View File

@@ -17,6 +17,8 @@ The main purpose of HeroDB is to offer a lightweight, embeddable, and Redis-comp
- **Expiration**: Time-to-live (TTL) functionality for keys.
- **Scanning**: Cursor-based iteration for keys and hash fields (`SCAN`, `HSCAN`).
- **AGE Cryptography Commands**: HeroDB-specific extensions for cryptographic operations.
- **Symmetric Encryption**: Stateless symmetric encryption using XChaCha20-Poly1305.
- **Admin Database 0**: Centralized control for database management, access control, and per-database encryption.
## Quick Start
@@ -30,31 +32,14 @@ cargo build --release
### Running HeroDB
You can start HeroDB with different backends and encryption options:
#### Default `redb` Backend
Launch HeroDB with the required `--admin-secret` flag, which encrypts the admin database (DB 0) and authorizes admin access. Optional flags include `--dir` for the database directory, `--port` for the TCP port (default 6379), `--sled` for the sled backend, and `--enable-rpc` to start the JSON-RPC management server on port 8080.
Example:
```bash
./target/release/herodb --dir /tmp/herodb_redb --port 6379
./target/release/herodb --dir /tmp/herodb --admin-secret myadminsecret --port 6379 --enable-rpc
```
#### `sled` Backend
```bash
./target/release/herodb --dir /tmp/herodb_sled --port 6379 --sled
```
#### `redb` with Encryption
```bash
./target/release/herodb --dir /tmp/herodb_encrypted --port 6379 --encrypt --key mysecretkey
```
#### `sled` with Encryption
```bash
./target/release/herodb --dir /tmp/herodb_sled_encrypted --port 6379 --sled --encrypt --key mysecretkey
```
For detailed launch options, see [Basics](docs/basics.md).
## Usage with Redis Clients
@@ -76,10 +61,24 @@ redis-cli -p 6379 SCAN 0 MATCH user:* COUNT 10
# 2) 1) "user:1"
```
## Cryptography
HeroDB supports asymmetric encryption/signatures via AGE commands (X25519 for encryption, Ed25519 for signatures) in stateless or key-managed modes, and symmetric encryption via SYM commands. Keys are persisted in the admin database (DB 0) for managed modes.
For details, see [AGE Cryptography](docs/age.md) and [Basics](docs/basics.md).
## Database Management
Databases are managed via JSON-RPC API, with metadata stored in the encrypted admin database (DB 0). Databases are public by default upon creation; use RPC to set them private, requiring access keys for SELECT operations (read or readwrite based on permissions). This includes per-database encryption keys, access control, and lifecycle management.
For examples, see [JSON-RPC Examples](docs/rpc_examples.md) and [Admin DB 0 Model](docs/admin.md).
## Documentation
For more detailed information on commands, features, and advanced usage, please refer to the documentation:
- [Basics](docs/basics.md)
- [Supported Commands](docs/cmds.md)
- [AGE Cryptography](docs/age.md)
- [AGE Cryptography](docs/age.md)
- [Admin DB 0 Model (access control, per-db encryption)](docs/admin.md)
- [JSON-RPC Examples (management API)](docs/rpc_examples.md)

181
docs/admin.md Normal file
View File

@@ -0,0 +1,181 @@
# Admin Database 0 (`0.db`)
This page explains what the Admin Database `DB 0` is, why HeroDB uses it, and how to work with it as a developer and end-user. Its a practical guide covering how databases are created, listed, secured with access keys, and encrypted using per-database secrets.
## What is `DB 0`?
`DB 0` is the control-plane for a HeroDB instance. It stores metadata for all user databases (`db_id >= 1`) so the server can:
- Know which databases exist (without scanning the filesystem)
- Enforce access control (public/private with access keys)
- Enforce per-database encryption (whether a given database must be opened encrypted and with which write-only key)
`DB 0` itself is always encrypted with the admin secret (the process-level secret provided at startup).
## How `DB 0` is created and secured
- `DB 0` lives at `<base_dir>/0.db`
- It is always encrypted using the `admin secret` provided at process startup (using the `--admin-secret <secret>` CLI flag)
- Only clients that provide the correct admin secret can `SELECT 0` (see “`SELECT` + `KEY`” below)
At startup, the server bootstraps `DB 0` (initializes counters and structures) if its missing.
## Metadata stored in `DB 0`
Keys in `DB 0` (internal layout, but useful to understand how things work):
- `admin:next_id`
- String counter holding the next id to allocate (initialized to `"1"`)
- `admin:dbs`
- A hash acting as a set of existing database ids
- field = id (as string), value = `"1"`
- `meta:db:<id>`
- A hash holding db-level metadata
- field `public` = `"true"` or `"false"` (defaults to `true` if missing)
- `meta:db:<id>:keys`
- A hash mapping access-key hashes to the string `Permission:created_at_seconds`
- Examples: `Read:1713456789` or `ReadWrite:1713456789`
- The plaintext access keys are never stored; only their `SHA-256` hashes are kept
- `meta:db:<id>:enc`
- A string holding the per-database encryption key used to open `<id>.db` encrypted
- This value is write-only from the perspective of the management APIs (its set at creation and never returned)
- `age:key:<name>`
- Base64-encoded X25519 recipient (public encryption key) for named AGE keys
- `age:privkey:<name>`
- Base64-encoded X25519 identity (secret encryption key) for named AGE keys
- `age:signpub:<name>`
- Base64-encoded Ed25519 verify public key for named AGE keys
- `age:signpriv:<name>`
- Base64-encoded Ed25519 signing secret key for named AGE keys
> You dont need to manipulate these keys directly; theyre listed to clarify the model. AGE keys are managed via AGE commands.
## Database lifecycle
1) Create a database (via JSON-RPC)
- The server allocates an id from `admin:next_id`, registers it in `admin:dbs`, and defaults the database to `public=true`
- If you pass an optional `encryption_key` during creation, the server persists it in `meta:db:<id>:enc`. That database will be opened in encrypted mode from then on
2) Open and use a database
- Clients select a database over RESP using `SELECT`
- Authorization and encryption state are enforced using `DB 0` metadata
3) Delete database files
- Removing `<id>.db` removes the physical storage
- `DB 0` remains the source of truth for existence and may be updated by future management methods as the system evolves
## Access control model
- Public database (default)
- Anyone can `SELECT <id>` with no key, and will get `ReadWrite` permission
- Private database
- You must provide an access key when selecting the database
- The server hashes the provided key with `SHA-256` and checks membership in `meta:db:<id>:keys`
- Permissions are `Read` or `ReadWrite` depending on how the key was added
- Admin `DB 0`
- Requires the exact admin secret as the `KEY` argument to `SELECT 0`
- Permission is `ReadWrite` when the secret matches
### How to select databases with optional `KEY`
- Public DB (no key required)
- `SELECT <id>`
- Private DB (access key required)
- `SELECT <id> KEY <plaintext_key>`
- Admin `DB 0` (admin secret required)
- `SELECT 0 KEY <admin_secret>`
Examples (using `redis-cli`):
```bash
# Public database
redis-cli -p $PORT SELECT 1
# → OK
# Private database
redis-cli -p $PORT SELECT 2 KEY my-db2-access-key
# → OK
# Admin DB 0
redis-cli -p $PORT SELECT 0 KEY my-admin-secret
# → OK
```
## Per-database encryption
- At database creation, you can provide an optional per-db encryption key
- If provided, the server persists that key in `DB 0` as `meta:db:<id>:enc`
- When you later open the database, the engine checks whether `meta:db:<id>:enc` exists to decide if it must open `<id>.db` in encrypted mode
- The per-db key is not returned by RPC—it is considered write-only configuration data
Operationally:
- Create with encryption: pass a non-null `encryption_key` to the `createDatabase` RPC
- Open later: simply `SELECT` the database; encryption is transparent to clients
## Management via JSON-RPC
You can manage databases using the management RPC (namespaced `herodb.*`). Typical operations:
- `createDatabase(backend, config, encryption_key?)`
- Allocates a new id, sets optional encryption key
- `listDatabases()`
- Lists database ids and info (including whether storage is currently encrypted)
- `getDatabaseInfo(db_id)`
- Returns details: backend, encrypted flag, size on disk, `key_count`, timestamps, etc.
- `addAccessKey(db_id, key, permissions)`
- Adds a `Read` or `ReadWrite` access key (permissions = `"read"` | `"readwrite"`)
- `listAccessKeys(db_id)`
- Returns hashes and permissions; you can use these hashes to delete keys
- `deleteAccessKey(db_id, key_hash)`
- Removes a key by its hash
- `setDatabasePublic(db_id, public)`
- Toggles public/private
Copyable JSON examples are provided in the [RPC examples documentation](./rpc_examples.md).
## Typical flows
1) Public, unencrypted database
- Create a new database without an encryption key
- Clients can immediately `SELECT <id>` without a key
- You can later make it private and add keys if needed
2) Private, encrypted database
- Create passing an `encryption_key`
- Mark it private (`setDatabasePublic false`) and add access keys
- Clients must use `SELECT <id> KEY <plaintext_access_key>`
- Storage opens in encrypted mode automatically
## Security notes
- Only `SHA-256` hashes of access keys are stored in `DB 0`; keep plaintext keys safe on the client side
- The per-db encryption key is never exposed via the API after it is set
- The admin secret must be kept secure; anyone with it can `SELECT 0` and perform administrative actions
## Troubleshooting
- `ERR invalid access key` when selecting a private db
- Ensure you passed the `KEY` argument: `SELECT <id> KEY <plaintext_key>`
- If you recently added the key, confirm the permissions and that you used the exact plaintext (hash must match)
- `Database X not found`
- The id isnt registered in `DB 0` (`admin:dbs`). Use the management APIs to create or list databases
- Cannot `SELECT 0`
- The `KEY` must be the exact admin secret passed at server startup
## Reference
- Admin metadata lives in `DB 0` (`0.db`) and controls:
- Existence: `admin:dbs`
- Access: `meta:db:<id>.public` and `meta:db:<id>:keys`
- Encryption: `meta:db:<id>:enc`
For command examples and management payloads:
- RESP command basics: `docs/basics.md`
- Supported commands: `docs/cmds.md`
- JSON-RPC examples: `docs/rpc_examples.md`

View File

@@ -1,188 +1,96 @@
# HeroDB AGE usage: Stateless vs KeyManaged
# HeroDB AGE Cryptography
This document explains how to use the AGE cryptography commands exposed by HeroDB over the Redis protocol in two modes:
- Stateless (ephemeral keys; nothing stored on the server)
- Keymanaged (serverpersisted, named keys)
HeroDB provides AGE-based asymmetric encryption and digital signatures over the Redis protocol using X25519 for encryption and Ed25519 for signatures. Keys can be used in stateless (ephemeral) or key-managed (persistent, named) modes.
If you are new to the codebase, the exact tests that exercise these behaviors are:
- [rust.test_07_age_stateless_suite()](herodb/tests/usage_suite.rs:495)
- [rust.test_08_age_persistent_named_suite()](herodb/tests/usage_suite.rs:555)
In key-managed mode, HeroDB uses a unified keypair concept: a single Ed25519 signing key is deterministically derived into X25519 keys for encryption, allowing one keypair to handle both encryption and signatures transparently.
Implementation entry points:
- [herodb/src/age.rs](herodb/src/age.rs)
- Dispatch from [herodb/src/cmd.rs](herodb/src/cmd.rs)
## Cryptographic Algorithms
Note: Database-at-rest encryption flags in the test harness are unrelated to AGE commands; those flags control storage-level encryption of DB files. See the harness near [rust.start_test_server()](herodb/tests/usage_suite.rs:10).
### X25519 (Encryption)
- Elliptic-curve Diffie-Hellman key exchange for symmetric key derivation.
- Used for encrypting/decrypting messages.
## Quick start
### Ed25519 (Signatures)
- EdDSA digital signatures for message authentication.
- Used for signing/verifying messages.
Assuming the server is running on localhost on some $PORT:
### Key Derivation
Ed25519 signing keys are deterministically converted to X25519 keys for encryption. This enables a single keypair to support both operations without additional keys. Derivation uses the Ed25519 secret scalar clamped for X25519.
In named keypairs, Ed25519 keys are stored, and X25519 keys are derived on-demand and cached.
## Stateless Mode (Ephemeral Keys)
No server-side storage; keys are provided with each command.
Available commands:
- `AGE GENENC`: Generate ephemeral X25519 keypair. Returns `[recipient, identity]`.
- `AGE GENSIGN`: Generate ephemeral Ed25519 keypair. Returns `[verify_pub, sign_secret]`.
- `AGE ENCRYPT <recipient> <message>`: Encrypt message. Returns base64 ciphertext.
- `AGE DECRYPT <identity> <ciphertext_b64>`: Decrypt ciphertext. Returns plaintext.
- `AGE SIGN <sign_secret> <message>`: Sign message. Returns base64 signature.
- `AGE VERIFY <verify_pub> <message> <signature_b64>`: Verify signature. Returns 1 (valid) or 0 (invalid).
Example:
```bash
~/code/git.ourworld.tf/herocode/herodb/herodb/build.sh
~/code/git.ourworld.tf/herocode/herodb/target/release/herodb --dir /tmp/data --debug --$PORT 6381 --encryption-key 1234 --encrypt
```
redis-cli AGE GENENC
# → 1) "age1qz..." # recipient (X25519 public)
# 2) "AGE-SECRET-KEY-1..." # identity (X25519 secret)
redis-cli AGE ENCRYPT "age1qz..." "hello"
# → base64_ciphertext
```bash
export PORT=6381
# Generate an ephemeral keypair and encrypt/decrypt a message (stateless mode)
redis-cli -p $PORT AGE GENENC
# → returns an array: [recipient, identity]
redis-cli -p $PORT AGE ENCRYPT <recipient> "hello world"
# → returns ciphertext (base64 in a bulk string)
redis-cli -p $PORT AGE DECRYPT <identity> <ciphertext_b64>
# → returns "hello world"
```
For keymanaged mode, generate a named key once and reference it by name afterwards:
```bash
redis-cli -p $PORT AGE KEYGEN app1
# → persists encryption keypair under name "app1"
redis-cli -p $PORT AGE ENCRYPTNAME app1 "hello"
redis-cli -p $PORT AGE DECRYPTNAME app1 <ciphertext_b64>
```
## Stateless AGE (ephemeral)
Characteristics
- No serverside storage of keys.
- You pass the actual key material with every call.
- Not listable via AGE LIST.
Commands and examples
1) Ephemeral encryption keys
```bash
# Generate an ephemeral encryption keypair
redis-cli -p $PORT AGE GENENC
# Example output (abridged):
# 1) "age1qz..." # recipient (public key) = can be used by others e.g. to verify what I sign
# 2) "AGE-SECRET-KEY-1..." # identity (secret) = is like my private, cannot lose this one
# Encrypt with the recipient public key
redis-cli -p $PORT AGE ENCRYPT "age1qz..." "hello world"
# → returns bulk string payload: base64 ciphertext (encrypted content)
# Decrypt with the identity (secret) in other words your private key
redis-cli -p $PORT AGE DECRYPT "AGE-SECRET-KEY-1..." "<ciphertext_b64>"
# → "hello world"
```
2) Ephemeral signing keys
> ? is this same as my private key
```bash
# Generate an ephemeral signing keypair
redis-cli -p $PORT AGE GENSIGN
# Example output:
# 1) "<verify_pub_b64>"
# 2) "<sign_secret_b64>"
# Sign a message with the secret
redis-cli -p $PORT AGE SIGN "<sign_secret_b64>" "msg"
# → returns "<signature_b64>"
# Verify with the public key
redis-cli -p $PORT AGE VERIFY "<verify_pub_b64>" "msg" "<signature_b64>"
# → 1 (valid) or 0 (invalid)
```
When to use
- You do not want the server to store private keys.
- You already manage key material on the client side.
- You need adhoc operations without persistence.
Reference test: [rust.test_07_age_stateless_suite()](herodb/tests/usage_suite.rs:495)
## Keymanaged AGE (persistent, named)
Characteristics
- Server generates and persists keypairs under a chosen name.
- Clients refer to keys by name; raw secrets are not supplied on each call.
- Keys are discoverable via AGE LIST.
Commands and examples
1) Named encryption keys
```bash
# Create/persist a named encryption keypair
redis-cli -p $PORT AGE KEYGEN app1
# → returns [recipient, identity] but also stores them under name "app1"
> TODO: should not return identity (security, but there can be separate function to export it e.g. AGE EXPORTKEY app1)
# Encrypt using the stored public key
redis-cli -p $PORT AGE ENCRYPTNAME app1 "hello"
# → returns bulk string payload: base64 ciphertext
# Decrypt using the stored secret
redis-cli -p $PORT AGE DECRYPTNAME app1 "<ciphertext_b64>"
redis-cli AGE DECRYPT "AGE-SECRET-KEY-1..." base64_ciphertext
# → "hello"
```
2) Named signing keys
## Key-Managed Mode (Persistent Named Keys)
Keys are stored server-side under names. Supports unified keypairs for both encryption and signatures.
Available commands:
- `AGE KEYGEN <name>`: Generate and store unified keypair. Returns `[recipient, identity]` in age format.
- `AGE SIGNKEYGEN <name>`: Generate and store Ed25519 signing keypair. Returns `[verify_pub, sign_secret]`.
- `AGE ENCRYPTNAME <name> <message>`: Encrypt with named key. Returns base64 ciphertext.
- `AGE DECRYPTNAME <name> <ciphertext_b64>`: Decrypt with named key. Returns plaintext.
- `AGE SIGNNAME <name> <message>`: Sign with named key. Returns base64 signature.
- `AGE VERIFYNAME <name> <message> <signature_b64>`: Verify with named key. Returns 1 or 0.
- `AGE LIST`: List all stored key names. Returns sorted array of names.
### AGE LIST Output
Returns a flat, deduplicated, sorted array of key names (strings). Each name corresponds to a stored keypair, which may include encryption keys (X25519), signing keys (Ed25519), or both.
Output format: `["name1", "name2", ...]`
Example:
```bash
# Create/persist a named signing keypair
redis-cli -p $PORT AGE SIGNKEYGEN app1
# → returns [verify_pub_b64, sign_secret_b64] and stores under name "app1"
> TODO: should not return sign_secret_b64 (for security, but there can be separate function to export it e.g. AGE EXPORTSIGNKEY app1)
# Sign using the stored secret
redis-cli -p $PORT AGE SIGNNAME app1 "msg"
# → returns "<signature_b64>"
# Verify using the stored public key
redis-cli -p $PORT AGE VERIFYNAME app1 "msg" "<signature_b64>"
# → 1 (valid) or 0 (invalid)
redis-cli AGE LIST
# → 1) "<named_keypair_1>"
# 2) "<named_keypair_2>"
```
3) List stored AGE keys
For unified keypairs (from `AGE KEYGEN`), the name handles both encryption (derived X25519) and signatures (stored Ed25519) transparently.
Example with named keys:
```bash
redis-cli -p $PORT AGE LIST
# Example output includes labels such as "encpub" and your key names (e.g., "app1")
redis-cli AGE KEYGEN app1
# → 1) "age1..." # recipient
# 2) "AGE-SECRET-KEY-1..." # identity
redis-cli AGE ENCRYPTNAME app1 "secret message"
# → base64_ciphertext
redis-cli AGE DECRYPTNAME app1 base64_ciphertext
# → "secret message"
redis-cli AGE SIGNNAME app1 "message"
# → base64_signature
redis-cli AGE VERIFYNAME app1 "message" base64_signature
# → 1
```
When to use
- You want centralized key storage/rotation and fewer secrets on the client.
- You need names/labels for workflows and can trust the server with secrets.
- You want discoverability (AGE LIST) and simpler client commands.
## Choosing a Mode
- **Stateless**: For ad-hoc operations without persistence; client manages keys.
- **Key-managed**: For centralized key lifecycle; server stores keys for convenience and discoverability.
Reference test: [rust.test_08_age_persistent_named_suite()](herodb/tests/usage_suite.rs:555)
## Choosing a mode
- Prefer Stateless when:
- Minimizing server trust for secret material is the priority.
- Clients already have a secure mechanism to store/distribute keys.
- Prefer Keymanaged when:
- Centralized lifecycle, naming, and discoverability are beneficial.
- You plan to integrate rotation, ACLs, or auditability on the server side.
## Security notes
- Treat identities and signing secrets as sensitive; avoid logging them.
- For keymanaged mode, ensure server storage (and backups) are protected.
- AGE operations here are applicationlevel crypto and are distinct from database-at-rest encryption configured in the test harness.
## Repository pointers
- Stateless examples in tests: [rust.test_07_age_stateless_suite()](herodb/tests/usage_suite.rs:495)
- Keymanaged examples in tests: [rust.test_08_age_persistent_named_suite()](herodb/tests/usage_suite.rs:555)
- AGE implementation: [herodb/src/age.rs](herodb/src/age.rs)
- Command dispatch: [herodb/src/cmd.rs](herodb/src/cmd.rs)
- Bash demo: [herodb/examples/age_bash_demo.sh](herodb/examples/age_bash_demo.sh)
- Rust persistent demo: [herodb/examples/age_persist_demo.rs](herodb/examples/age_persist_demo.rs)
- Additional notes: [herodb/instructions/encrypt.md](herodb/instructions/encrypt.md)
Implementation: [herodb/src/age.rs](herodb/src/age.rs) <br>
Tests: [herodb/tests/usage_suite.rs](herodb/tests/usage_suite.rs)

View File

@@ -1,4 +1,58 @@
Here's an expanded version of the cmds.md documentation to include the list commands:
# HeroDB Basics
## Launching HeroDB
To launch HeroDB, use the binary with required and optional flags. The `--admin-secret` flag is mandatory, encrypting the admin database (DB 0) and authorizing admin access.
### Launch Flags
- `--dir <path>`: Directory for database files (default: current directory).
- `--port <port>`: TCP port for Redis protocol (default: 6379).
- `--debug`: Enable debug logging.
- `--sled`: Use Sled backend (default: Redb).
- `--enable-rpc`: Start JSON-RPC management server on port 8080.
- `--rpc-port <port>`: Custom RPC port (default: 8080).
- `--admin-secret <secret>`: Required secret for DB 0 encryption and admin access.
Example:
```bash
./target/release/herodb --dir /tmp/herodb --admin-secret mysecret --port 6379 --enable-rpc
```
Deprecated flags (`--encrypt`, `--encryption-key`) are ignored for data DBs; per-database encryption is managed via RPC.
## Admin Database (DB 0)
DB 0 acts as the administrative database instance, storing metadata for all user databases (IDs >= 1). It controls existence, access control, and per-database encryption. DB 0 is always encrypted with the `--admin-secret`.
When creating a new database, DB 0 allocates an ID, registers it, and optionally stores a per-database encryption key (write-only). Databases are public by default; use RPC to set them private, requiring access keys for SELECT (read or readwrite based on permissions). Keys are persisted in DB 0 for managed AGE operations.
Access DB 0 with `SELECT 0 KEY <admin-secret>`.
## Symmetric Encryption
HeroDB supports stateless symmetric encryption via SYM commands, using XChaCha20-Poly1305 AEAD.
Commands:
- `SYM KEYGEN`: Generate 32-byte key. Returns base64-encoded key.
- `SYM ENCRYPT <key_b64> <message>`: Encrypt message. Returns base64 ciphertext.
- `SYM DECRYPT <key_b64> <ciphertext_b64>`: Decrypt. Returns plaintext.
Example:
```bash
redis-cli SYM KEYGEN
# → base64_key
redis-cli SYM ENCRYPT base64_key "secret"
# → base64_ciphertext
redis-cli SYM DECRYPT base64_key base64_ciphertext
# → "secret"
```
## RPC Options
Enable the JSON-RPC server with `--enable-rpc` for database management. Methods include creating databases, managing access keys, and setting encryption. See [JSON-RPC Examples](./rpc_examples.md) for payloads.
# HeroDB Commands
HeroDB implements a subset of Redis commands over the Redis protocol. This document describes the available commands and their usage.
@@ -575,6 +629,29 @@ redis-cli -p $PORT AGE LIST
# 2) "keyname2"
```
## SYM Commands
### SYM KEYGEN
Generate a symmetric encryption key.
```bash
redis-cli -p $PORT SYM KEYGEN
# → base64_encoded_32byte_key
```
### SYM ENCRYPT
Encrypt a message with a symmetric key.
```bash
redis-cli -p $PORT SYM ENCRYPT <key_b64> "message"
# → base64_encoded_ciphertext
```
### SYM DECRYPT
Decrypt a ciphertext with a symmetric key.
```bash
redis-cli -p $PORT SYM DECRYPT <key_b64> <ciphertext_b64>
# → decrypted_message
```
## Server Information Commands
### INFO
@@ -621,3 +698,27 @@ This expanded documentation includes all the list commands that were implemented
10. LINDEX - get element by index
11. LRANGE - get range of elements
## Updated Database Selection and Access Keys
HeroDB uses an `Admin DB 0` to control database existence, access, and encryption. Access to data DBs can be public (no key) or private (requires a key). See detailed model in `docs/admin.md`.
Examples:
```bash
# Public database (no key required)
redis-cli -p $PORT SELECT 1
# → OK
```
```bash
# Private database (requires access key)
redis-cli -p $PORT SELECT 2 KEY my-db2-access-key
# → OK
```
```bash
# Admin DB 0 (requires admin secret)
redis-cli -p $PORT SELECT 0 KEY my-admin-secret
# → OK
```

View File

@@ -122,4 +122,27 @@ redis-cli -p 6379 --rdb dump.rdb
# Import to sled
redis-cli -p 6381 --pipe < dump.rdb
```
## Authentication and Database Selection
HeroDB uses an `Admin DB 0` to govern database existence, access and per-db encryption. Access control is enforced via `Admin DB 0` metadata. See the full model in `docs/admin.md`.
Examples:
```bash
# Public database (no key required)
redis-cli -p $PORT SELECT 1
# → OK
```
```bash
# Private database (requires access key)
redis-cli -p $PORT SELECT 2 KEY my-db2-access-key
# → OK
```
```bash
# Admin DB 0 (requires admin secret)
redis-cli -p $PORT SELECT 0 KEY my-admin-secret
# → OK
```

141
docs/rpc_examples.md Normal file
View File

@@ -0,0 +1,141 @@
# HeroDB JSON-RPC Examples
These examples show full JSON-RPC 2.0 payloads for managing HeroDB via the RPC API (enable with `--enable-rpc`). Methods are named as `hero_<function>`. Params are positional arrays; enum values are strings (e.g., `"Redb"`). Copy-paste into Postman or similar clients.
## Database Management
### Create Database
Creates a new database with optional per-database encryption key (stored write-only in Admin DB 0).
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "hero_createDatabase",
"params": [
"Redb",
{ "name": null, "storage_path": null, "max_size": null, "redis_version": null },
null
]
}
```
With encryption:
```json
{
"jsonrpc": "2.0",
"id": 2,
"method": "hero_createDatabase",
"params": [
"Sled",
{ "name": "secure-db", "storage_path": null, "max_size": null, "redis_version": null },
"my-per-db-encryption-key"
]
}
```
### List Databases
Returns array of database infos (id, backend, encrypted status, size, etc.).
```json
{
"jsonrpc": "2.0",
"id": 3,
"method": "hero_listDatabases",
"params": []
}
```
### Get Database Info
Retrieves detailed info for a specific database.
```json
{
"jsonrpc": "2.0",
"id": 4,
"method": "hero_getDatabaseInfo",
"params": [1]
}
```
### Delete Database
Removes physical database file; metadata remains in Admin DB 0.
```json
{
"jsonrpc": "2.0",
"id": 5,
"method": "hero_deleteDatabase",
"params": [1]
}
```
## Access Control
### Add Access Key
Adds a hashed access key for private databases. Permissions: `"read"` or `"readwrite"`.
```json
{
"jsonrpc": "2.0",
"id": 6,
"method": "hero_addAccessKey",
"params": [2, "my-access-key", "readwrite"]
}
```
### List Access Keys
Returns array of key hashes, permissions, and creation timestamps.
```json
{
"jsonrpc": "2.0",
"id": 7,
"method": "hero_listAccessKeys",
"params": [2]
}
```
### Delete Access Key
Removes key by its SHA-256 hash.
```json
{
"jsonrpc": "2.0",
"id": 8,
"method": "hero_deleteAccessKey",
"params": [2, "0123abcd...keyhash..."]
}
```
### Set Database Public/Private
Toggles public access (default true). Private databases require access keys.
```json
{
"jsonrpc": "2.0",
"id": 9,
"method": "hero_setDatabasePublic",
"params": [2, false]
}
```
## Server Info
### Get Server Stats
Returns stats like total databases and uptime.
```json
{
"jsonrpc": "2.0",
"id": 10,
"method": "hero_getServerStats",
"params": []
}
```
## Notes
- Per-database encryption keys are write-only; set at creation and used transparently.
- Access keys are hashed (SHA-256) for storage; provide plaintext in requests.
- Backend options: `"Redb"` (default) or `"Sled"`.
- Config object fields (name, storage_path, etc.) are optional and currently ignored but positional.

View File

@@ -14,31 +14,25 @@ fn read_reply(s: &mut TcpStream) -> String {
let n = s.read(&mut buf).unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}
fn parse_two_bulk(reply: &str) -> Option<(String, String)> {
fn parse_two_bulk(reply: &str) -> Option<(String,String)> {
let mut lines = reply.split("\r\n");
if lines.next()? != "*2" {
return None;
}
if lines.next()? != "*2" { return None; }
let _n = lines.next()?;
let a = lines.next()?.to_string();
let _m = lines.next()?;
let b = lines.next()?.to_string();
Some((a, b))
Some((a,b))
}
fn parse_bulk(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('$') {
return None;
}
if !hdr.starts_with('$') { return None; }
Some(lines.next()?.to_string())
}
fn parse_simple(reply: &str) -> Option<String> {
let mut lines = reply.split("\r\n");
let hdr = lines.next()?;
if !hdr.starts_with('+') {
return None;
}
if !hdr.starts_with('+') { return None; }
Some(hdr[1..].to_string())
}
@@ -51,45 +45,39 @@ fn main() {
let mut s = TcpStream::connect(addr).expect("connect");
// Generate & persist X25519 enc keys under name "alice"
s.write_all(arr(&["age", "keygen", "alice"]).as_bytes())
.unwrap();
s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap();
let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc");
// Generate & persist Ed25519 signing key under name "signer"
s.write_all(arr(&["age", "signkeygen", "signer"]).as_bytes())
.unwrap();
s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap();
let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign");
// Encrypt by name
let msg = "hello from persistent keys";
s.write_all(arr(&["age", "encryptname", "alice", msg]).as_bytes())
.unwrap();
s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap();
let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64");
println!("ciphertext b64: {}", ct_b64);
// Decrypt by name
s.write_all(arr(&["age", "decryptname", "alice", &ct_b64]).as_bytes())
.unwrap();
s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap();
let pt = parse_bulk(&read_reply(&mut s)).expect("pt");
assert_eq!(pt, msg);
println!("decrypted ok");
// Sign by name
s.write_all(arr(&["age", "signname", "signer", msg]).as_bytes())
.unwrap();
s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap();
let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64");
// Verify by name
s.write_all(arr(&["age", "verifyname", "signer", msg, &sig_b64]).as_bytes())
.unwrap();
s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap();
let ok = parse_simple(&read_reply(&mut s)).expect("verify");
assert_eq!(ok, "1");
println!("signature verified");
// List names
s.write_all(arr(&["age", "list"]).as_bytes()).unwrap();
s.write_all(arr(&["age","list"]).as_bytes()).unwrap();
let list = read_reply(&mut s);
println!("LIST -> {list}");
println!("✔ persistent AGE workflow complete.");
}
}

View File

@@ -1,239 +0,0 @@
#!/bin/bash
# HeroDB Tantivy Search Demo
# This script demonstrates full-text search capabilities using Redis commands
# HeroDB server should be running on port 6381
set -e # Exit on any error
# Configuration
REDIS_HOST="localhost"
REDIS_PORT="6382"
REDIS_CLI="redis-cli -h $REDIS_HOST -p $REDIS_PORT"
# Start the herodb server in the background
echo "Starting herodb server..."
cargo run -p herodb -- --dir /tmp/herodbtest --port ${REDIS_PORT} --debug &
SERVER_PID=$!
echo
sleep 2 # Give the server a moment to start
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Function to print colored output
print_header() {
echo -e "${BLUE}=== $1 ===${NC}"
}
print_success() {
echo -e "${GREEN}$1${NC}"
}
print_info() {
echo -e "${YELLOW} $1${NC}"
}
print_error() {
echo -e "${RED}$1${NC}"
}
# Function to check if HeroDB is running
check_herodb() {
print_info "Checking if HeroDB is running on port $REDIS_PORT..."
if ! $REDIS_CLI ping > /dev/null 2>&1; then
print_error "HeroDB is not running on port $REDIS_PORT"
print_info "Please start HeroDB with: cargo run -- --port $REDIS_PORT"
exit 1
fi
print_success "HeroDB is running and responding"
}
# Function to execute Redis command with error handling
execute_cmd() {
local cmd="$1"
local description="$2"
echo -e "${YELLOW}Command:${NC} $cmd"
if result=$($REDIS_CLI $cmd 2>&1); then
echo -e "${GREEN}Result:${NC} $result"
return 0
else
print_error "Failed: $description"
echo "Error: $result"
return 1
fi
}
# Function to pause for readability
pause() {
echo
read -p "Press Enter to continue..."
echo
}
# Main demo function
main() {
clear
print_header "HeroDB Tantivy Search Demonstration"
echo "This demo shows full-text search capabilities using Redis commands"
echo "HeroDB runs on port $REDIS_PORT (instead of Redis default 6379)"
echo
# Check if HeroDB is running
check_herodb
echo
print_header "Step 1: Create Search Index"
print_info "Creating a product catalog search index with various field types"
# Create search index with schema
execute_cmd "FT.CREATE product_catalog SCHEMA title TEXT description TEXT category TAG price NUMERIC rating NUMERIC location GEO" \
"Creating search index"
print_success "Search index 'product_catalog' created successfully"
pause
print_header "Step 2: Add Sample Products"
print_info "Adding sample products to demonstrate different search scenarios"
# Add sample products using FT.ADD
execute_cmd "FT.ADD product_catalog product:1 1.0 title 'Wireless Bluetooth Headphones' description 'Premium noise-canceling headphones with 30-hour battery life' category 'electronics,audio' price 299.99 rating 4.5 location '-122.4194,37.7749'" "Adding product 1"
execute_cmd "FT.ADD product_catalog product:2 1.0 title 'Organic Coffee Beans' description 'Single-origin Ethiopian coffee beans, medium roast' category 'food,beverages,organic' price 24.99 rating 4.8 location '-74.0060,40.7128'" "Adding product 2"
execute_cmd "FT.ADD product_catalog product:3 1.0 title 'Yoga Mat Premium' description 'Eco-friendly yoga mat with superior grip and cushioning' category 'fitness,wellness,eco-friendly' price 89.99 rating 4.3 location '-118.2437,34.0522'" "Adding product 3"
execute_cmd "FT.ADD product_catalog product:4 1.0 title 'Smart Home Speaker' description 'Voice-controlled smart speaker with AI assistant' category 'electronics,smart-home' price 149.99 rating 4.2 location '-87.6298,41.8781'" "Adding product 4"
execute_cmd "FT.ADD product_catalog product:5 1.0 title 'Organic Green Tea' description 'Premium organic green tea leaves from Japan' category 'food,beverages,organic,tea' price 18.99 rating 4.7 location '139.6503,35.6762'" "Adding product 5"
execute_cmd "FT.ADD product_catalog product:6 1.0 title 'Wireless Gaming Mouse' description 'High-precision gaming mouse with RGB lighting' category 'electronics,gaming' price 79.99 rating 4.4 location '-122.3321,47.6062'" "Adding product 6"
execute_cmd "FT.ADD product_catalog product:7 1.0 title 'Comfortable meditation cushion for mindfulness practice' description 'Meditation cushion with premium materials' category 'wellness,meditation' price 45.99 rating 4.6 location '-122.4194,37.7749'" "Adding product 7"
execute_cmd "FT.ADD product_catalog product:8 1.0 title 'Bluetooth Earbuds' description 'True wireless earbuds with active noise cancellation' category 'electronics,audio' price 199.99 rating 4.1 location '-74.0060,40.7128'" "Adding product 8"
print_success "Added 8 products to the index"
pause
print_header "Step 3: Basic Text Search"
print_info "Searching for 'wireless' products"
execute_cmd "FT.SEARCH product_catalog wireless" "Basic text search"
pause
print_header "Step 4: Search with Filters"
print_info "Searching for 'organic' products"
execute_cmd "FT.SEARCH product_catalog organic" "Filtered search"
pause
print_header "Step 5: Numeric Range Search"
print_info "Searching for 'premium' products"
execute_cmd "FT.SEARCH product_catalog premium" "Text search"
pause
print_header "Step 6: Sorting Results"
print_info "Searching for electronics"
execute_cmd "FT.SEARCH product_catalog electronics" "Category search"
pause
print_header "Step 7: Limiting Results"
print_info "Searching for wireless products with limit"
execute_cmd "FT.SEARCH product_catalog wireless LIMIT 0 3" "Limited results"
pause
print_header "Step 8: Complex Query"
print_info "Finding audio products with noise cancellation"
execute_cmd "FT.SEARCH product_catalog 'noise cancellation'" "Complex query"
pause
print_header "Step 9: Geographic Search"
print_info "Searching for meditation products"
execute_cmd "FT.SEARCH product_catalog meditation" "Text search"
pause
print_header "Step 10: Aggregation Example"
print_info "Getting index information and statistics"
execute_cmd "FT.INFO product_catalog" "Index information"
pause
print_header "Step 11: Search Comparison"
print_info "Comparing Tantivy search vs simple key matching"
echo -e "${YELLOW}Tantivy Full-Text Search:${NC}"
execute_cmd "FT.SEARCH product_catalog 'battery life'" "Full-text search for 'battery life'"
echo
echo -e "${YELLOW}Simple Key Pattern Matching:${NC}"
execute_cmd "KEYS *battery*" "Simple pattern matching for 'battery'"
print_info "Notice how full-text search finds relevant results even when exact words don't match keys"
pause
print_header "Step 12: Fuzzy Search"
print_info "Searching for headphones"
execute_cmd "FT.SEARCH product_catalog headphones" "Text search"
pause
print_header "Step 13: Phrase Search"
print_info "Searching for coffee products"
execute_cmd "FT.SEARCH product_catalog coffee" "Text search"
pause
print_header "Step 14: Boolean Queries"
print_info "Searching for gaming products"
execute_cmd "FT.SEARCH product_catalog gaming" "Text search"
echo
execute_cmd "FT.SEARCH product_catalog tea" "Text search"
pause
print_header "Step 15: Cleanup"
print_info "Removing test data"
# Delete the search index
execute_cmd "FT.DROP product_catalog" "Dropping search index"
# Clean up documents from search index
for i in {1..8}; do
execute_cmd "FT.DEL product_catalog product:$i" "Deleting product:$i from index"
done
print_success "Cleanup completed"
echo
print_header "Demo Summary"
echo "This demonstration showed:"
echo "• Creating search indexes with different field types"
echo "• Adding documents to the search index"
echo "• Basic and advanced text search queries"
echo "• Filtering by categories and numeric ranges"
echo "• Sorting and limiting results"
echo "• Geographic searches"
echo "• Fuzzy matching and phrase searches"
echo "• Boolean query operators"
echo "• Comparison with simple pattern matching"
echo
print_success "HeroDB Tantivy search demo completed successfully!"
echo
print_info "Key advantages of Tantivy full-text search:"
echo " - Relevance scoring and ranking"
echo " - Fuzzy matching and typo tolerance"
echo " - Complex boolean queries"
echo " - Field-specific searches and filters"
echo " - Geographic and numeric range queries"
echo " - Much faster than pattern matching on large datasets"
echo
print_info "To run HeroDB server: cargo run -- --port 6381"
print_info "To connect with redis-cli: redis-cli -h localhost -p 6381"
}
# Run the demo
main "$@"

View File

@@ -1,101 +0,0 @@
#!/bin/bash
# Simple Tantivy Search Integration Test for HeroDB
# This script tests the full-text search functionality we just integrated
set -e
echo "🔍 Testing Tantivy Search Integration..."
# Build the project first
echo "📦 Building HeroDB..."
cargo build --release
# Start the server in the background
echo "🚀 Starting HeroDB server on port 6379..."
cargo run --release -- --port 6379 --dir ./test_data &
SERVER_PID=$!
# Wait for server to start
sleep 3
# Function to cleanup on exit
cleanup() {
echo "🧹 Cleaning up..."
kill $SERVER_PID 2>/dev/null || true
rm -rf ./test_data
exit
}
# Set trap for cleanup
trap cleanup EXIT INT TERM
# Function to execute Redis command
execute_cmd() {
local cmd="$1"
local description="$2"
echo "📝 $description"
echo " Command: $cmd"
if result=$(redis-cli -p 6379 $cmd 2>&1); then
echo " ✅ Result: $result"
echo
return 0
else
echo " ❌ Failed: $result"
echo
return 1
fi
}
echo "🧪 Running Tantivy Search Tests..."
echo
# Test 1: Create a search index
execute_cmd "ft.create books SCHEMA title TEXT description TEXT author TEXT category TAG price NUMERIC" \
"Creating search index 'books'"
# Test 2: Add documents to the index
execute_cmd "ft.add books book1 1.0 title \"The Great Gatsby\" description \"A classic American novel about the Jazz Age\" author \"F. Scott Fitzgerald\" category \"fiction,classic\" price \"12.99\"" \
"Adding first book"
execute_cmd "ft.add books book2 1.0 title \"To Kill a Mockingbird\" description \"A novel about racial injustice in the American South\" author \"Harper Lee\" category \"fiction,classic\" price \"14.99\"" \
"Adding second book"
execute_cmd "ft.add books book3 1.0 title \"Programming Rust\" description \"A comprehensive guide to Rust programming language\" author \"Jim Blandy\" category \"programming,technical\" price \"49.99\"" \
"Adding third book"
execute_cmd "ft.add books book4 1.0 title \"The Rust Programming Language\" description \"The official book on Rust programming\" author \"Steve Klabnik\" category \"programming,technical\" price \"39.99\"" \
"Adding fourth book"
# Test 3: Basic search
execute_cmd "ft.search books Rust" \
"Searching for 'Rust'"
# Test 4: Search with filters
execute_cmd "ft.search books programming FILTER category programming" \
"Searching for 'programming' with category filter"
# Test 5: Search with limit
execute_cmd "ft.search books \"*\" LIMIT 0 2" \
"Getting first 2 documents"
# Test 6: Get index info
execute_cmd "ft.info books" \
"Getting index information"
# Test 7: Delete a document
execute_cmd "ft.del books book1" \
"Deleting book1"
# Test 8: Search again to verify deletion
execute_cmd "ft.search books Gatsby" \
"Searching for deleted book"
# Test 9: Drop the index
execute_cmd "ft.drop books" \
"Dropping the index"
echo "🎉 All tests completed successfully!"
echo "✅ Tantivy search integration is working correctly"

143
run.sh Executable file
View File

@@ -0,0 +1,143 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
# Test script for HeroDB - Redis-compatible database with redb backend
# This script starts the server and runs comprehensive tests
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Configuration
DB_DIR="/tmp/test_db"
PORT=6381
SERVER_PID=""
# Function to print colored output
print_status() {
echo -e "${BLUE}[INFO]${NC} $1"
}
print_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
print_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
print_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
}
# Function to cleanup on exit
cleanup() {
if [ ! -z "$SERVER_PID" ]; then
print_status "Stopping HeroDB server (PID: $SERVER_PID)..."
kill $SERVER_PID 2>/dev/null || true
wait $SERVER_PID 2>/dev/null || true
fi
# Clean up test database
if [ -d "$DB_DIR" ]; then
print_status "Cleaning up test database directory..."
rm -rf "$DB_DIR"
fi
}
# Set trap to cleanup on script exit
trap cleanup EXIT
# Function to wait for server to start
wait_for_server() {
local max_attempts=30
local attempt=1
print_status "Waiting for server to start on port $PORT..."
while [ $attempt -le $max_attempts ]; do
if nc -z localhost $PORT 2>/dev/null; then
print_success "Server is ready!"
return 0
fi
echo -n "."
sleep 1
attempt=$((attempt + 1))
done
print_error "Server failed to start within $max_attempts seconds"
return 1
}
# Function to send Redis command and get response
redis_cmd() {
local cmd="$1"
local expected="$2"
print_status "Testing: $cmd"
local result=$(echo "$cmd" | redis-cli -p $PORT --raw 2>/dev/null || echo "ERROR")
if [ "$expected" != "" ] && [ "$result" != "$expected" ]; then
print_error "Expected: '$expected', Got: '$result'"
return 1
else
print_success "$cmd -> $result"
return 0
fi
}
# Main execution
main() {
print_status "Starting HeroDB"
# Build the project
print_status "Building HeroDB..."
if ! cargo build -p herodb --release; then
print_error "Failed to build HeroDB"
exit 1
fi
# Create test database directory
mkdir -p "$DB_DIR"
# Start the server
print_status "Starting HeroDB server..."
${SCRIPT_DIR}/target/release/herodb --dir "$DB_DIR" --port $PORT &
SERVER_PID=$!
# Wait for server to start
if ! wait_for_server; then
print_error "Failed to start server"
exit 1
fi
}
# Check dependencies
check_dependencies() {
if ! command -v cargo &> /dev/null; then
print_error "cargo is required but not installed"
exit 1
fi
if ! command -v nc &> /dev/null; then
print_warning "netcat (nc) not found - some tests may not work properly"
fi
if ! command -v redis-cli &> /dev/null; then
print_warning "redis-cli not found - using netcat fallback"
fi
}
# Run dependency check and main function
check_dependencies
main "$@"
tail -f /dev/null

View File

@@ -1,4 +1,7 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
echo "🧪 Running HeroDB Redis Compatibility Tests"
echo "=========================================="

481
src/admin_meta.rs Normal file
View File

@@ -0,0 +1,481 @@
use std::path::PathBuf;
use std::sync::{Arc, OnceLock, Mutex, RwLock};
use std::collections::HashMap;
use crate::error::DBError;
use crate::options;
use crate::rpc::Permissions;
use crate::storage::Storage;
use crate::storage_sled::SledStorage;
use crate::storage_trait::StorageBackend;
// Key builders
fn k_admin_next_id() -> &'static str {
"admin:next_id"
}
fn k_admin_dbs() -> &'static str {
"admin:dbs"
}
fn k_meta_db(id: u64) -> String {
format!("meta:db:{}", id)
}
fn k_meta_db_keys(id: u64) -> String {
format!("meta:db:{}:keys", id)
}
fn k_meta_db_enc(id: u64) -> String {
format!("meta:db:{}:enc", id)
}
// Global cache of admin DB 0 handles per base_dir to avoid sled/reDB file-lock contention
// and to correctly isolate different test instances with distinct directories.
static ADMIN_STORAGES: OnceLock<RwLock<HashMap<String, Arc<dyn StorageBackend>>>> = OnceLock::new();
// Global registry for data DB storages to avoid double-open across process.
static DATA_STORAGES: OnceLock<RwLock<HashMap<u64, Arc<dyn StorageBackend>>>> = OnceLock::new();
static DATA_INIT_LOCK: Mutex<()> = Mutex::new(());
fn init_admin_storage(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<Arc<dyn StorageBackend>, DBError> {
let db_file = PathBuf::from(base_dir).join("0.db");
if let Some(parent_dir) = db_file.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
let storage: Arc<dyn StorageBackend> = match backend {
options::BackendType::Redb => Arc::new(Storage::new(&db_file, true, Some(admin_secret))?),
options::BackendType::Sled => Arc::new(SledStorage::new(&db_file, true, Some(admin_secret))?),
};
Ok(storage)
}
// Get or initialize a cached handle to admin DB 0 per base_dir (thread-safe, no double-open race)
pub fn open_admin_storage(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<Arc<dyn StorageBackend>, DBError> {
let map = ADMIN_STORAGES.get_or_init(|| RwLock::new(HashMap::new()));
// Fast path
if let Some(st) = map.read().unwrap().get(base_dir) {
return Ok(st.clone());
}
// Slow path with write lock
{
let mut w = map.write().unwrap();
if let Some(st) = w.get(base_dir) {
return Ok(st.clone());
}
// Detect existing 0.db backend by filesystem, if present.
let admin_path = PathBuf::from(base_dir).join("0.db");
let detected = if admin_path.exists() {
if admin_path.is_file() {
Some(options::BackendType::Redb)
} else if admin_path.is_dir() {
Some(options::BackendType::Sled)
} else {
None
}
} else {
None
};
let effective_backend = match detected {
Some(d) if d != backend => {
eprintln!(
"warning: Admin DB 0 at {} appears to be {:?}, but process default is {:?}. Using detected backend.",
admin_path.display(),
d,
backend
);
d
}
Some(d) => d,
None => backend, // First boot: use requested backend to initialize 0.db
};
let st = init_admin_storage(base_dir, effective_backend, admin_secret)?;
w.insert(base_dir.to_string(), st.clone());
Ok(st)
}
}
// Ensure admin structures exist in encrypted DB 0
pub fn ensure_bootstrap(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
// Initialize next id if missing
if !admin.exists(k_admin_next_id())? {
admin.set(k_admin_next_id().to_string(), "1".to_string())?;
}
// admin:dbs is a hash; it's fine if it doesn't exist (hlen -> 0)
Ok(())
}
// Get or initialize a shared handle to a data DB (> 0), avoiding double-open across subsystems
pub fn open_data_storage(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Arc<dyn StorageBackend>, DBError> {
if id == 0 {
return open_admin_storage(base_dir, backend, admin_secret);
}
// Validate existence in admin metadata
if !db_exists(base_dir, backend.clone(), admin_secret, id)? {
return Err(DBError(format!(
"Cannot open database instance {}, as that database instance does not exist.",
id
)));
}
let map = DATA_STORAGES.get_or_init(|| RwLock::new(HashMap::new()));
// Fast path
if let Some(st) = map.read().unwrap().get(&id) {
return Ok(st.clone());
}
// Slow path with init lock
let _guard = DATA_INIT_LOCK.lock().unwrap();
if let Some(st) = map.read().unwrap().get(&id) {
return Ok(st.clone());
}
// Resolve effective backend for this db id:
// 1) Try admin meta "backend" field
// 2) If missing, sniff filesystem (file => Redb, dir => Sled), then persist into admin meta
// 3) Fallback to requested 'backend' (startup default) if nothing else is known
let meta_backend = get_database_backend(base_dir, backend.clone(), admin_secret, id).ok().flatten();
let db_path = PathBuf::from(base_dir).join(format!("{}.db", id));
let sniffed_backend = if db_path.exists() {
if db_path.is_file() {
Some(options::BackendType::Redb)
} else if db_path.is_dir() {
Some(options::BackendType::Sled)
} else {
None
}
} else {
None
};
let effective_backend = meta_backend.clone().or(sniffed_backend).unwrap_or(backend.clone());
// If we had to sniff (i.e., meta missing), persist it for future robustness
if meta_backend.is_none() {
let _ = set_database_backend(base_dir, backend.clone(), admin_secret, id, effective_backend.clone());
}
// Warn if caller-provided backend differs from effective
if effective_backend != backend {
eprintln!(
"notice: Database {} backend resolved to {:?} (caller requested {:?}). Using resolved backend.",
id, effective_backend, backend
);
}
// Determine per-db encryption (from admin meta)
let enc = get_enc_key(base_dir, backend.clone(), admin_secret, id)?;
let should_encrypt = enc.is_some();
// Build database file path and ensure parent dir exists
let db_file = PathBuf::from(base_dir).join(format!("{}.db", id));
if let Some(parent_dir) = db_file.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
})?;
}
// Open storage using the effective backend
let storage: Arc<dyn StorageBackend> = match effective_backend {
options::BackendType::Redb => Arc::new(Storage::new(&db_file, should_encrypt, enc.as_deref())?),
options::BackendType::Sled => Arc::new(SledStorage::new(&db_file, should_encrypt, enc.as_deref())?),
};
// Publish to registry
map.write().unwrap().insert(id, storage.clone());
Ok(storage)
}
// Allocate the next DB id and persist new pointer
pub fn allocate_next_id(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<u64, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let cur = admin
.get(k_admin_next_id())?
.unwrap_or_else(|| "1".to_string());
let id: u64 = cur.parse().unwrap_or(1);
let next = id.checked_add(1).ok_or_else(|| DBError("next_id overflow".into()))?;
admin.set(k_admin_next_id().to_string(), next.to_string())?;
// Register into admin:dbs set/hash
let _ = admin.hset(k_admin_dbs(), vec![(id.to_string(), "1".to_string())])?;
// Default meta for the new db: public true
let meta_key = k_meta_db(id);
let _ = admin.hset(&meta_key, vec![("public".to_string(), "true".to_string())])?;
Ok(id)
}
// Check existence of a db id in admin:dbs
pub fn db_exists(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<bool, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
Ok(admin.hexists(k_admin_dbs(), &id.to_string())?)
}
// Get per-db encryption key, if any
pub fn get_enc_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Option<String>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
admin.get(&k_meta_db_enc(id))
}
// Set per-db encryption key (called during create)
pub fn set_enc_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key: &str,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
admin.set(k_meta_db_enc(id), key.to_string())
}
// Set database public flag
pub fn set_database_public(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
public: bool,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let mk = k_meta_db(id);
let _ = admin.hset(&mk, vec![("public".to_string(), public.to_string())])?;
Ok(())
}
// Persist per-db backend type in admin metadata (module-scope)
pub fn set_database_backend(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
db_backend: options::BackendType,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let mk = k_meta_db(id);
let val = match db_backend {
options::BackendType::Redb => "Redb",
options::BackendType::Sled => "Sled",
};
let _ = admin.hset(&mk, vec![("backend".to_string(), val.to_string())])?;
Ok(())
}
pub fn get_database_backend(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Option<options::BackendType>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let mk = k_meta_db(id);
match admin.hget(&mk, "backend")? {
Some(s) if s == "Redb" => Ok(Some(options::BackendType::Redb)),
Some(s) if s == "Sled" => Ok(Some(options::BackendType::Sled)),
_ => Ok(None),
}
}
// Set database name
pub fn set_database_name(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
name: &str,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let mk = k_meta_db(id);
let _ = admin.hset(&mk, vec![("name".to_string(), name.to_string())])?;
Ok(())
}
// Get database name
pub fn get_database_name(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Option<String>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let mk = k_meta_db(id);
admin.hget(&mk, "name")
}
// Internal: load public flag; default to true when meta missing
fn load_public(
admin: &Arc<dyn StorageBackend>,
id: u64,
) -> Result<bool, DBError> {
let mk = k_meta_db(id);
match admin.hget(&mk, "public")? {
Some(v) => Ok(v == "true"),
None => Ok(true),
}
}
// Add access key for db (value format: "Read:ts" or "ReadWrite:ts")
pub fn add_access_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key_plain: &str,
perms: Permissions,
) -> Result<(), DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let hash = crate::rpc::hash_key(key_plain);
let v = match perms {
Permissions::Read => format!("Read:{}", now_secs()),
Permissions::ReadWrite => format!("ReadWrite:{}", now_secs()),
};
let _ = admin.hset(&k_meta_db_keys(id), vec![(hash, v)])?;
Ok(())
}
// Delete access key by hash
pub fn delete_access_key(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key_hash: &str,
) -> Result<bool, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let n = admin.hdel(&k_meta_db_keys(id), vec![key_hash.to_string()])?;
Ok(n > 0)
}
// List access keys, returning (hash, perms, created_at_secs)
pub fn list_access_keys(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
) -> Result<Vec<(String, Permissions, u64)>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let pairs = admin.hgetall(&k_meta_db_keys(id))?;
let mut out = Vec::new();
for (hash, val) in pairs {
let (perm, ts) = parse_perm_value(&val);
out.push((hash, perm, ts));
}
Ok(out)
}
// Verify access permission for db id with optional key
// Returns:
// - Ok(Some(Permissions)) when access is allowed
// - Ok(None) when not allowed or db missing (caller can distinguish by calling db_exists)
pub fn verify_access(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
id: u64,
key_opt: Option<&str>,
) -> Result<Option<Permissions>, DBError> {
// Admin DB 0: require exact admin_secret
if id == 0 {
if let Some(k) = key_opt {
if k == admin_secret {
return Ok(Some(Permissions::ReadWrite));
}
}
return Ok(None);
}
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
if !admin.hexists(k_admin_dbs(), &id.to_string())? {
return Ok(None);
}
// Public?
if load_public(&admin, id)? {
return Ok(Some(Permissions::ReadWrite));
}
// Private: require key and verify
if let Some(k) = key_opt {
let hash = crate::rpc::hash_key(k);
if let Some(v) = admin.hget(&k_meta_db_keys(id), &hash)? {
let (perm, _ts) = parse_perm_value(&v);
return Ok(Some(perm));
}
}
Ok(None)
}
// Enumerate all db ids
pub fn list_dbs(
base_dir: &str,
backend: options::BackendType,
admin_secret: &str,
) -> Result<Vec<u64>, DBError> {
let admin = open_admin_storage(base_dir, backend, admin_secret)?;
let ids = admin.hkeys(k_admin_dbs())?;
let mut out = Vec::new();
for s in ids {
if let Ok(v) = s.parse() {
out.push(v);
}
}
Ok(out)
}
// Helper: parse permission value "Read:ts" or "ReadWrite:ts"
fn parse_perm_value(v: &str) -> (Permissions, u64) {
let mut parts = v.split(':');
let p = parts.next().unwrap_or("Read");
let ts = parts
.next()
.and_then(|s| s.parse().ok())
.unwrap_or(0u64);
let perm = match p {
"ReadWrite" => Permissions::ReadWrite,
_ => Permissions::Read,
};
(perm, ts)
}
fn now_secs() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}

View File

@@ -12,17 +12,19 @@
use std::str::FromStr;
use age::x25519;
use age::{Decryptor, Encryptor};
use secrecy::ExposeSecret;
use age::{Decryptor, Encryptor};
use age::x25519;
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use ed25519_dalek::{Signature, Signer, Verifier, SigningKey, VerifyingKey};
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use std::collections::HashSet;
use std::convert::TryInto;
use crate::error::DBError;
use crate::protocol::Protocol;
use crate::server::Server;
use crate::error::DBError;
// ---------- Internal helpers ----------
@@ -32,7 +34,7 @@ pub enum AgeWireError {
Crypto(String),
Utf8,
SignatureLen,
NotFound(&'static str), // which kind of key was missing
NotFound(&'static str), // which kind of key was missing
Storage(String),
}
@@ -74,6 +76,125 @@ fn parse_ed25519_verifying_key(s: &str) -> Result<VerifyingKey, AgeWireError> {
VerifyingKey::from_bytes(&key_bytes).map_err(|_| AgeWireError::ParseKey)
}
// ---------- Derivation + Raw X25519 (Ed25519 -> X25519) ----------
//
// We deterministically derive an X25519 keypair from an Ed25519 SigningKey.
// We persist the X25519 public/secret as base64-encoded 32-byte raw values
// (no "age1..."/"AGE-SECRET-KEY-1..." formatting). Name-based encrypt/decrypt
// uses these raw values directly via x25519-dalek + ChaCha20Poly1305.
use chacha20poly1305::{aead::{Aead, KeyInit}, ChaCha20Poly1305, Key, Nonce};
use sha2::{Digest, Sha256};
use x25519_dalek::{PublicKey as XPublicKey, StaticSecret as XStaticSecret};
fn derive_x25519_raw_from_ed25519(sk: &SigningKey) -> ([u8; 32], [u8; 32]) {
// X25519 secret scalar (clamped) from Ed25519 secret
let scalar: [u8; 32] = sk.to_scalar_bytes();
// Build X25519 secret/public using dalek
let xsec = XStaticSecret::from(scalar);
let xpub = XPublicKey::from(&xsec);
(xpub.to_bytes(), xsec.to_bytes())
}
fn derive_x25519_raw_b64_from_ed25519(sk: &SigningKey) -> (String, String) {
let (xpub, xsec) = derive_x25519_raw_from_ed25519(sk);
(B64.encode(xpub), B64.encode(xsec))
}
// Helper: detect whether a stored key looks like an age-formatted string
fn looks_like_age_format(s: &str) -> bool {
s.starts_with("age1") || s.starts_with("AGE-SECRET-KEY-1")
}
// Our container format for name-based raw X25519 encryption:
// bytes = "HDBX1" (5) || eph_pub(32) || nonce(12) || ciphertext(..)
// Entire blob is base64-encoded for transport.
const HDBX1_MAGIC: &[u8; 5] = b"HDBX1";
fn encrypt_b64_with_x25519_raw(recip_pub_b64: &str, msg: &str) -> Result<String, AgeWireError> {
use rand::RngCore;
use rand::rngs::OsRng;
// Parse recipient public key (raw 32 bytes, base64)
let recip_pub_bytes = B64.decode(recip_pub_b64).map_err(|_| AgeWireError::ParseKey)?;
if recip_pub_bytes.len() != 32 { return Err(AgeWireError::ParseKey); }
let recip_pub_arr: [u8; 32] = recip_pub_bytes.as_slice().try_into().map_err(|_| AgeWireError::ParseKey)?;
let recip_pub: XPublicKey = XPublicKey::from(recip_pub_arr);
// Generate ephemeral X25519 keypair
let mut eph_sec_bytes = [0u8; 32];
OsRng.fill_bytes(&mut eph_sec_bytes);
let eph_sec = XStaticSecret::from(eph_sec_bytes);
let eph_pub = XPublicKey::from(&eph_sec);
// ECDH
let shared = eph_sec.diffie_hellman(&recip_pub);
// Derive symmetric key via SHA-256 over context + shared + parties
let mut hasher = Sha256::default();
hasher.update(b"herodb-x25519-v1");
hasher.update(shared.as_bytes());
hasher.update(eph_pub.as_bytes());
hasher.update(recip_pub.as_bytes());
let key_bytes = hasher.finalize();
let key = Key::from_slice(&key_bytes[..32]);
// Nonce (12 bytes)
let mut nonce_bytes = [0u8; 12];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
// Encrypt
let cipher = ChaCha20Poly1305::new(key);
let ct = cipher.encrypt(nonce, msg.as_bytes())
.map_err(|e| AgeWireError::Crypto(format!("encrypt: {e}")))?;
// Assemble container
let mut out = Vec::with_capacity(5 + 32 + 12 + ct.len());
out.extend_from_slice(HDBX1_MAGIC);
out.extend_from_slice(eph_pub.as_bytes());
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(&ct);
Ok(B64.encode(out))
}
fn decrypt_b64_with_x25519_raw(identity_sec_b64: &str, ct_b64: &str) -> Result<String, AgeWireError> {
// Parse X25519 secret (raw 32 bytes, base64)
let sec_bytes = B64.decode(identity_sec_b64).map_err(|_| AgeWireError::ParseKey)?;
if sec_bytes.len() != 32 { return Err(AgeWireError::ParseKey); }
let sec_arr: [u8; 32] = sec_bytes.as_slice().try_into().map_err(|_| AgeWireError::ParseKey)?;
let xsec = XStaticSecret::from(sec_arr);
let xpub = XPublicKey::from(&xsec); // self public
// Decode container
let blob = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
if blob.len() < 5 + 32 + 12 { return Err(AgeWireError::Crypto("ciphertext too short".to_string())); }
if &blob[..5] != HDBX1_MAGIC { return Err(AgeWireError::Crypto("bad header".to_string())); }
let eph_pub_arr: [u8; 32] = blob[5..5+32].try_into().map_err(|_| AgeWireError::Crypto("bad eph pub".to_string()))?;
let eph_pub = XPublicKey::from(eph_pub_arr);
let nonce_bytes: [u8; 12] = blob[5+32..5+32+12].try_into().unwrap();
let ct = &blob[5+32+12..];
// Recompute shared + key
let shared = xsec.diffie_hellman(&eph_pub);
let mut hasher = Sha256::default();
hasher.update(b"herodb-x25519-v1");
hasher.update(shared.as_bytes());
hasher.update(eph_pub.as_bytes());
hasher.update(xpub.as_bytes());
let key_bytes = hasher.finalize();
let key = Key::from_slice(&key_bytes[..32]);
// Decrypt
let cipher = ChaCha20Poly1305::new(key);
let nonce = Nonce::from_slice(&nonce_bytes);
let pt = cipher.decrypt(nonce, ct)
.map_err(|e| AgeWireError::Crypto(format!("decrypt: {e}")))?;
String::from_utf8(pt).map_err(|_| AgeWireError::Utf8)
}
// ---------- Stateless crypto helpers (string in/out) ----------
pub fn gen_enc_keypair() -> (String, String) {
@@ -83,38 +204,34 @@ pub fn gen_enc_keypair() -> (String, String) {
}
pub fn gen_sign_keypair() -> (String, String) {
use rand::rngs::OsRng;
use rand::RngCore;
use rand::rngs::OsRng;
// Generate random 32 bytes for the signing key
let mut secret_bytes = [0u8; 32];
OsRng.fill_bytes(&mut secret_bytes);
let signing_key = SigningKey::from_bytes(&secret_bytes);
let verifying_key = signing_key.verifying_key();
// Encode as base64 for storage
let signing_key_b64 = B64.encode(signing_key.to_bytes());
let verifying_key_b64 = B64.encode(verifying_key.to_bytes());
(verifying_key_b64, signing_key_b64) // (verify_pub, signing_secret)
}
/// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext).
pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> {
let recipient = parse_recipient(recipient_str)?;
let enc =
Encryptor::with_recipients(vec![Box::new(recipient)]).expect("failed to create encryptor"); // Handle Option<Encryptor>
let enc = Encryptor::with_recipients(vec![Box::new(recipient)])
.expect("failed to create encryptor"); // Handle Option<Encryptor>
let mut out = Vec::new();
{
use std::io::Write;
let mut w = enc
.wrap_output(&mut out)
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.write_all(msg.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.finish()
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
let mut w = enc.wrap_output(&mut out).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.write_all(msg.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
w.finish().map_err(|e| AgeWireError::Crypto(e.to_string()))?;
}
Ok(B64.encode(out))
}
@@ -122,27 +239,19 @@ pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireErro
/// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String.
pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> {
let id = parse_identity(identity_str)?;
let ct = B64
.decode(ct_b64.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
let ct = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
// The decrypt method returns a Result<StreamReader, DecryptError>
let mut r = match dec {
Decryptor::Recipients(d) => d
.decrypt(std::iter::once(&id as &dyn age::Identity))
Decryptor::Recipients(d) => d.decrypt(std::iter::once(&id as &dyn age::Identity))
.map_err(|e| AgeWireError::Crypto(e.to_string()))?,
Decryptor::Passphrase(_) => {
return Err(AgeWireError::Crypto(
"Expected recipients, got passphrase".to_string(),
))
}
Decryptor::Passphrase(_) => return Err(AgeWireError::Crypto("Expected recipients, got passphrase".to_string())),
};
let mut pt = Vec::new();
use std::io::Read;
r.read_to_end(&mut pt)
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
r.read_to_end(&mut pt).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
String::from_utf8(pt).map_err(|_| AgeWireError::Utf8)
}
@@ -156,9 +265,7 @@ pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AgeWireEr
/// Verify detached signature (base64) for `msg` with pubkey.
pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> {
let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
let sig_bytes = B64
.decode(sig_b64.as_bytes())
.map_err(|e| AgeWireError::Crypto(e.to_string()))?;
let sig_bytes = B64.decode(sig_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
if sig_bytes.len() != 64 {
return Err(AgeWireError::SignatureLen);
}
@@ -169,49 +276,30 @@ pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool
// ---------- Storage helpers ----------
fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
let st = server
.current_storage()
.map_err(|e| AgeWireError::Storage(e.0))?;
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.get(key).map_err(|e| AgeWireError::Storage(e.0))
}
fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
let st = server
.current_storage()
.map_err(|e| AgeWireError::Storage(e.0))?;
st.set(key.to_string(), val.to_string())
.map_err(|e| AgeWireError::Storage(e.0))
let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0))
}
fn enc_pub_key_key(name: &str) -> String {
format!("age:key:{name}")
}
fn enc_priv_key_key(name: &str) -> String {
format!("age:privkey:{name}")
}
fn sign_pub_key_key(name: &str) -> String {
format!("age:signpub:{name}")
}
fn sign_priv_key_key(name: &str) -> String {
format!("age:signpriv:{name}")
}
fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
// ---------- Command handlers (RESP Protocol) ----------
// Basic (stateless) ones kept for completeness
pub async fn cmd_age_genenc() -> Protocol {
let (recip, ident) = gen_enc_keypair();
Protocol::Array(vec![
Protocol::BulkString(recip),
Protocol::BulkString(ident),
])
Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
}
pub async fn cmd_age_gensign() -> Protocol {
let (verify, secret) = gen_sign_keypair();
Protocol::Array(vec![
Protocol::BulkString(verify),
Protocol::BulkString(secret),
])
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
@@ -243,66 +331,159 @@ pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> P
}
}
// ---------- NEW: unified stateless generator (Ed25519 + derived X25519 raw) ----------
//
// Returns 4-tuple:
// [ verify_pub_b64 (32B), signpriv_b64 (32B), x25519_pub_b64 (32B), x25519_sec_b64 (32B) ]
// No persistence (stateless).
pub async fn cmd_age_genkey() -> Protocol {
use rand::RngCore;
use rand::rngs::OsRng;
let mut secret_bytes = [0u8; 32];
OsRng.fill_bytes(&mut secret_bytes);
let signing_key = SigningKey::from_bytes(&secret_bytes);
let verifying_key = signing_key.verifying_key();
let verify_b64 = B64.encode(verifying_key.to_bytes());
let sign_b64 = B64.encode(signing_key.to_bytes());
let (xpub_b64, xsec_b64) = derive_x25519_raw_b64_from_ed25519(&signing_key);
Protocol::Array(vec![
Protocol::BulkString(verify_b64),
Protocol::BulkString(sign_b64),
Protocol::BulkString(xpub_b64),
Protocol::BulkString(xsec_b64),
])
}
// ---------- NEW: Persistent, named-key commands ----------
pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
let (recip, ident) = gen_enc_keypair();
if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) {
return e.to_protocol();
}
if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) {
return e.to_protocol();
}
use rand::RngCore;
use rand::rngs::OsRng;
// Generate Ed25519 keypair
let mut secret_bytes = [0u8; 32];
OsRng.fill_bytes(&mut secret_bytes);
let signing_key = SigningKey::from_bytes(&secret_bytes);
let verifying_key = signing_key.verifying_key();
// Encode Ed25519 as base64 (32 bytes)
let verify_b64 = B64.encode(verifying_key.to_bytes());
let sign_b64 = B64.encode(signing_key.to_bytes());
// Derive X25519 raw (32-byte) keys and encode as base64
let (xpub_b64, xsec_b64) = derive_x25519_raw_b64_from_ed25519(&signing_key);
// Decode to create age-formatted strings
let xpub_bytes = B64.decode(&xpub_b64).unwrap();
let xsec_bytes = B64.decode(&xsec_b64).unwrap();
let xpub_arr: [u8; 32] = xpub_bytes.as_slice().try_into().unwrap();
let xsec_arr: [u8; 32] = xsec_bytes.as_slice().try_into().unwrap();
let recip_str = format!("age1{}", B64.encode(xpub_arr));
let ident_str = format!("AGE-SECRET-KEY-1{}", B64.encode(xsec_arr));
// Persist Ed25519 and derived X25519 (key-managed mode)
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify_b64) { return e.to_protocol(); }
if let Err(e) = sset(server, &sign_priv_key_key(name), &sign_b64) { return e.to_protocol(); }
if let Err(e) = sset(server, &enc_pub_key_key(name), &xpub_b64) { return e.to_protocol(); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &xsec_b64) { return e.to_protocol(); }
// Return [recipient, identity] in age format
Protocol::Array(vec![
Protocol::BulkString(recip),
Protocol::BulkString(ident),
Protocol::BulkString(recip_str),
Protocol::BulkString(ident_str),
])
}
pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
let (verify, secret) = gen_sign_keypair();
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) {
return e.to_protocol();
}
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) {
return e.to_protocol();
}
Protocol::Array(vec![
Protocol::BulkString(verify),
Protocol::BulkString(secret),
])
if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); }
if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); }
Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
}
pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
let recip = match sget(server, &enc_pub_key_key(name)) {
// Load stored recipient (could be raw b64 32-byte or "age1..." from legacy)
let recip_or_b64 = match sget(server, &enc_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("recipient (age:key:{name})").to_protocol(),
Ok(None) => {
// Derive from stored Ed25519 if present, then persist
match sget(server, &sign_priv_key_key(name)) {
Ok(Some(sign_b64)) => {
let sk = match parse_ed25519_signing_key(&sign_b64) {
Ok(k) => k,
Err(e) => return e.to_protocol(),
};
let (xpub_b64, xsec_b64) = derive_x25519_raw_b64_from_ed25519(&sk);
if let Err(e) = sset(server, &enc_pub_key_key(name), &xpub_b64) { return e.to_protocol(); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &xsec_b64) { return e.to_protocol(); }
xpub_b64
}
Ok(None) => return AgeWireError::NotFound("recipient (age:key:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
}
}
Err(e) => return e.to_protocol(),
};
match encrypt_b64(&recip, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => e.to_protocol(),
if looks_like_age_format(&recip_or_b64) {
match encrypt_b64(&recip_or_b64, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => e.to_protocol(),
}
} else {
match encrypt_b64_with_x25519_raw(&recip_or_b64, message) {
Ok(ct) => Protocol::BulkString(ct),
Err(e) => e.to_protocol(),
}
}
}
pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) -> Protocol {
let ident = match sget(server, &enc_priv_key_key(name)) {
// Load stored identity (could be raw b64 32-byte or "AGE-SECRET-KEY-1..." from legacy)
let ident_or_b64 = match sget(server, &enc_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => return AgeWireError::NotFound("identity (age:privkey:{name})").to_protocol(),
Ok(None) => {
// Derive from stored Ed25519 if present, then persist
match sget(server, &sign_priv_key_key(name)) {
Ok(Some(sign_b64)) => {
let sk = match parse_ed25519_signing_key(&sign_b64) {
Ok(k) => k,
Err(e) => return e.to_protocol(),
};
let (xpub_b64, xsec_b64) = derive_x25519_raw_b64_from_ed25519(&sk);
if let Err(e) = sset(server, &enc_pub_key_key(name), &xpub_b64) { return e.to_protocol(); }
if let Err(e) = sset(server, &enc_priv_key_key(name), &xsec_b64) { return e.to_protocol(); }
xsec_b64
}
Ok(None) => return AgeWireError::NotFound("identity (age:privkey:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
}
}
Err(e) => return e.to_protocol(),
};
match decrypt_b64(&ident, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
if looks_like_age_format(&ident_or_b64) {
match decrypt_b64(&ident_or_b64, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
} else {
match decrypt_b64_with_x25519_raw(&ident_or_b64, ct_b64) {
Ok(pt) => Protocol::BulkString(pt),
Err(e) => e.to_protocol(),
}
}
}
pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
let sec = match sget(server, &sign_priv_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => {
return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol()
}
Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match sign_b64(&sec, message) {
@@ -311,17 +492,10 @@ pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Pr
}
}
pub async fn cmd_age_verify_name(
server: &Server,
name: &str,
message: &str,
sig_b64: &str,
) -> Protocol {
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
let pubk = match sget(server, &sign_pub_key_key(name)) {
Ok(Some(v)) => v,
Ok(None) => {
return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol()
}
Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(),
Err(e) => return e.to_protocol(),
};
match verify_b64(&pubk, message, sig_b64) {
@@ -332,11 +506,8 @@ pub async fn cmd_age_verify_name(
}
pub async fn cmd_age_list(server: &Server) -> Protocol {
// Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
let st = match server.current_storage() {
Ok(s) => s,
Err(e) => return Protocol::err(&e.0),
};
// Return a flat, deduplicated, sorted list of managed key names (no labels)
let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };
let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
let keys = st.keys(pat)?;
@@ -348,35 +519,18 @@ pub async fn cmd_age_list(server: &Server) -> Protocol {
Ok(names)
};
let encpub = match pull("age:key:*", "age:key:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let encpriv = match pull("age:privkey:*", "age:privkey:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let signpub = match pull("age:signpub:*", "age:signpub:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let signpriv = match pull("age:signpriv:*", "age:signpriv:") {
Ok(v) => v,
Err(e) => return Protocol::err(&e.0),
};
let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let signpriv = match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e)=> return Protocol::err(&e.0) };
let to_arr = |label: &str, v: Vec<String>| {
let mut out = vec![Protocol::BulkString(label.to_string())];
out.push(Protocol::Array(
v.into_iter().map(Protocol::BulkString).collect(),
));
Protocol::Array(out)
};
let mut set: HashSet<String> = HashSet::new();
for n in encpub.into_iter().chain(encpriv).chain(signpub).chain(signpriv) {
set.insert(n);
}
Protocol::Array(vec![
to_arr("encpub", encpub),
to_arr("encpriv", encpriv),
to_arr("signpub", signpub),
to_arr("signpriv", signpriv),
])
}
let mut names: Vec<String> = set.into_iter().collect();
names.sort();
Protocol::Array(names.into_iter().map(Protocol::BulkString).collect())
}

1138
src/cmd.rs

File diff suppressed because it is too large Load Diff

View File

@@ -11,9 +11,9 @@ const TAG_LEN: usize = 16;
#[derive(Debug)]
pub enum CryptoError {
Format, // wrong length / header
Version(u8), // unknown version
Decrypt, // wrong key or corrupted data
Format, // wrong length / header
Version(u8), // unknown version
Decrypt, // wrong key or corrupted data
}
impl From<CryptoError> for crate::error::DBError {
@@ -71,4 +71,4 @@ impl CryptoFactory {
let cipher = XChaCha20Poly1305::new(&self.key);
cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
}
}
}

View File

@@ -1,8 +1,9 @@
use std::num::ParseIntError;
use bincode;
use redb;
use tokio::sync::mpsc;
use redb;
use bincode;
// todo: more error types
#[derive(Debug)]

View File

@@ -1,12 +1,14 @@
pub mod age; // NEW
pub mod age;
pub mod sym;
pub mod cmd;
pub mod crypto;
pub mod error;
pub mod options;
pub mod protocol;
pub mod search_cmd; // Add this
pub mod rpc;
pub mod rpc_server;
pub mod server;
pub mod storage;
pub mod storage_sled; // Add this
pub mod storage_trait; // Add this
pub mod tantivy_search;
pub mod storage_trait;
pub mod storage_sled;
pub mod admin_meta;

View File

@@ -3,6 +3,7 @@
use tokio::net::TcpListener;
use herodb::server;
use herodb::rpc_server;
use clap::Parser;
@@ -22,17 +23,29 @@ struct Args {
#[arg(long)]
debug: bool,
/// Master encryption key for encrypted databases
/// Master encryption key for encrypted databases (deprecated; ignored for data DBs)
#[arg(long)]
encryption_key: Option<String>,
/// Encrypt the database
/// Encrypt the database (deprecated; ignored for data DBs)
#[arg(long)]
encrypt: bool,
/// Enable RPC management server
#[arg(long)]
enable_rpc: bool,
/// RPC server port (default: 8080)
#[arg(long, default_value = "8080")]
rpc_port: u16,
/// Use the sled backend
#[arg(long)]
sled: bool,
/// Admin secret used to encrypt DB 0 and authorize admin access (required)
#[arg(long)]
admin_secret: String,
}
#[tokio::main]
@@ -47,9 +60,19 @@ async fn main() {
.await
.unwrap();
// deprecation warnings for legacy flags
if args.encrypt || args.encryption_key.is_some() {
eprintln!("warning: --encrypt and --encryption-key are deprecated and ignored for data DBs. Admin DB 0 is always encrypted with --admin-secret.");
}
// basic validation for admin secret
if args.admin_secret.trim().is_empty() {
eprintln!("error: --admin-secret must not be empty");
std::process::exit(2);
}
// new DB option
let option = herodb::options::DBOption {
dir: args.dir,
dir: args.dir.clone(),
port,
debug: args.debug,
encryption_key: args.encryption_key,
@@ -59,14 +82,42 @@ async fn main() {
} else {
herodb::options::BackendType::Redb
},
admin_secret: args.admin_secret.clone(),
};
let backend = option.backend.clone();
// Bootstrap admin DB 0 before opening any server storage
if let Err(e) = herodb::admin_meta::ensure_bootstrap(&args.dir, backend.clone(), &args.admin_secret) {
eprintln!("Failed to bootstrap admin DB 0: {}", e.0);
std::process::exit(2);
}
// new server
let server = server::Server::new(option).await;
// Add a small delay to ensure the port is ready
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Start RPC server if enabled
let _rpc_handle = if args.enable_rpc {
let rpc_addr = format!("127.0.0.1:{}", args.rpc_port).parse().unwrap();
let base_dir = args.dir.clone();
match rpc_server::start_rpc_server(rpc_addr, base_dir, backend, args.admin_secret.clone()).await {
Ok(handle) => {
println!("RPC management server started on port {}", args.rpc_port);
Some(handle)
}
Err(e) => {
eprintln!("Failed to start RPC server: {}", e);
None
}
}
} else {
None
};
// accept new connections
loop {
let stream = listener.accept().await;

View File

@@ -1,4 +1,4 @@
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendType {
Redb,
Sled,
@@ -9,7 +9,11 @@ pub struct DBOption {
pub dir: String,
pub port: u16,
pub debug: bool,
// Deprecated for data DBs; retained for backward-compat on CLI parsing
pub encrypt: bool,
// Deprecated for data DBs; retained for backward-compat on CLI parsing
pub encryption_key: Option<String>,
pub backend: BackendType,
// New: required admin secret, used to encrypt DB 0 and authorize admin operations
pub admin_secret: String,
}

View File

@@ -81,21 +81,18 @@ impl Protocol {
pub fn encode(&self) -> String {
match self {
Protocol::SimpleString(s) => format!("+{}\r\n", s),
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => {
Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
Protocol::Array(ss) => {
format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
}
Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
Protocol::Null => "$-1\r\n".to_string(),
Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
}
}
fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
match protocol.find("\r\n") {
Some(x) => Ok((
Self::SimpleString(protocol[..x].to_string()),
&protocol[x + 2..],
)),
Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])),
_ => Err(DBError(format!(
"[new simple string] unsupported protocol: {:?}",
protocol

472
src/rpc.rs Normal file
View File

@@ -0,0 +1,472 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use jsonrpsee::{core::RpcResult, proc_macros::rpc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::server::Server;
use crate::options::DBOption;
use crate::admin_meta;
/// Database backend types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BackendType {
Redb,
Sled,
// Future: InMemory, Custom(String)
}
/// Database configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub name: Option<String>,
pub storage_path: Option<String>,
pub max_size: Option<u64>,
pub redis_version: Option<String>,
}
/// Database information returned by metadata queries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseInfo {
pub id: u64,
pub name: Option<String>,
pub backend: BackendType,
pub encrypted: bool,
pub redis_version: Option<String>,
pub storage_path: Option<String>,
pub size_on_disk: Option<u64>,
pub key_count: Option<u64>,
pub created_at: u64,
pub last_access: Option<u64>,
}
/// Access permissions for database keys
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Permissions {
Read,
ReadWrite,
}
/// Access key information returned by RPC
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessKeyInfo {
pub hash: String,
pub permissions: Permissions,
pub created_at: u64,
}
/// Hash a plaintext key using SHA-256
pub fn hash_key(key: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(key.as_bytes());
format!("{:x}", hasher.finalize())
}
/// RPC trait for HeroDB management
#[rpc(server, client, namespace = "herodb")]
pub trait Rpc {
/// Create a new database with specified configuration
#[method(name = "createDatabase")]
async fn create_database(
&self,
backend: BackendType,
config: DatabaseConfig,
encryption_key: Option<String>,
) -> RpcResult<u64>;
/// Set encryption for an existing database (write-only key)
#[method(name = "setEncryption")]
async fn set_encryption(&self, db_id: u64, encryption_key: String) -> RpcResult<bool>;
/// List all managed databases
#[method(name = "listDatabases")]
async fn list_databases(&self) -> RpcResult<Vec<DatabaseInfo>>;
/// Get detailed information about a specific database
#[method(name = "getDatabaseInfo")]
async fn get_database_info(&self, db_id: u64) -> RpcResult<DatabaseInfo>;
/// Delete a database
#[method(name = "deleteDatabase")]
async fn delete_database(&self, db_id: u64) -> RpcResult<bool>;
/// Get server statistics
#[method(name = "getServerStats")]
async fn get_server_stats(&self) -> RpcResult<HashMap<String, serde_json::Value>>;
/// Add an access key to a database
#[method(name = "addAccessKey")]
async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult<bool>;
/// Delete an access key from a database
#[method(name = "deleteAccessKey")]
async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult<bool>;
/// List all access keys for a database
#[method(name = "listAccessKeys")]
async fn list_access_keys(&self, db_id: u64) -> RpcResult<Vec<AccessKeyInfo>>;
/// Set database public/private status
#[method(name = "setDatabasePublic")]
async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult<bool>;
}
/// RPC Server implementation
pub struct RpcServerImpl {
/// Base directory for database files
base_dir: String,
/// Managed database servers
servers: Arc<RwLock<HashMap<u64, Arc<Server>>>>,
/// Default backend type
backend: crate::options::BackendType,
/// Admin secret used to encrypt DB 0 and authorize admin access
admin_secret: String,
}
impl RpcServerImpl {
/// Create a new RPC server instance
pub fn new(base_dir: String, backend: crate::options::BackendType, admin_secret: String) -> Self {
Self {
base_dir,
servers: Arc::new(RwLock::new(HashMap::new())),
backend,
admin_secret,
}
}
/// Get or create a server instance for the given database ID
async fn get_or_create_server(&self, db_id: u64) -> Result<Arc<Server>, jsonrpsee::types::ErrorObjectOwned> {
// Check if server already exists
{
let servers = self.servers.read().await;
if let Some(server) = servers.get(&db_id) {
return Ok(server.clone());
}
}
// Validate existence via admin DB 0 (metadata), not filesystem presence
let exists = admin_meta::db_exists(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
if !exists {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
format!("Database {} not found", db_id),
None::<()>
));
}
// Resolve effective backend for this db from admin meta or filesystem; fallback to default
let meta_backend = admin_meta::get_database_backend(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id)
.ok()
.flatten();
let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id));
let sniffed_backend = if db_path.exists() {
if db_path.is_file() {
Some(crate::options::BackendType::Redb)
} else if db_path.is_dir() {
Some(crate::options::BackendType::Sled)
} else {
None
}
} else {
None
};
let effective_backend = meta_backend.clone().or(sniffed_backend).unwrap_or(self.backend.clone());
if effective_backend != self.backend {
eprintln!(
"notice: get_or_create_server: db {} backend resolved to {:?} (server default {:?})",
db_id, effective_backend, self.backend
);
}
// If we had to sniff (no meta), persist the resolved backend
if meta_backend.is_none() {
let _ = admin_meta::set_database_backend(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, effective_backend.clone());
}
// Create server instance with resolved backend
let db_option = DBOption {
dir: self.base_dir.clone(),
port: 0, // Not used for RPC-managed databases
debug: false,
encryption_key: None,
encrypt: false,
backend: effective_backend,
admin_secret: self.admin_secret.clone(),
};
let mut server = Server::new(db_option).await;
// Set the selected database to the db_id
server.selected_db = db_id;
// Lazily open/create physical storage according to admin meta (per-db encryption)
let _ = server.current_storage();
// Store the server
let mut servers = self.servers.write().await;
servers.insert(db_id, Arc::new(server.clone()));
Ok(Arc::new(server))
}
/// Discover existing database IDs from admin DB 0
async fn discover_databases(&self) -> Vec<u64> {
admin_meta::list_dbs(&self.base_dir, self.backend.clone(), &self.admin_secret)
.unwrap_or_default()
}
/// Build database file path for given server/db_id
fn db_file_path(&self, server: &Server, db_id: u64) -> std::path::PathBuf {
std::path::PathBuf::from(&server.option.dir).join(format!("{}.db", db_id))
}
/// Recursively compute size on disk for the database path
fn compute_size_on_disk(&self, path: &std::path::Path) -> Option<u64> {
fn dir_size(p: &std::path::Path) -> u64 {
if p.is_file() {
std::fs::metadata(p).map(|m| m.len()).unwrap_or(0)
} else if p.is_dir() {
let mut total = 0u64;
if let Ok(read) = std::fs::read_dir(p) {
for entry in read.flatten() {
total += dir_size(&entry.path());
}
}
total
} else {
0
}
}
Some(dir_size(path))
}
/// Extract created and last access times (secs) from a path, with fallbacks
fn get_file_times_secs(path: &std::path::Path) -> (u64, Option<u64>) {
let now = std::time::SystemTime::now();
let created = std::fs::metadata(path)
.and_then(|m| m.created().or_else(|_| m.modified()))
.unwrap_or(now)
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let last_access = std::fs::metadata(path)
.and_then(|m| m.accessed())
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok().map(|d| d.as_secs()));
(created, last_access)
}
/// Compose a DatabaseInfo by probing storage and filesystem, with admin meta for access key count
async fn build_database_info(&self, db_id: u64, server: &Server) -> DatabaseInfo {
// Probe storage to determine encryption state
let storage = server.current_storage().ok();
let encrypted = storage.as_ref().map(|s| s.is_encrypted()).unwrap_or(server.option.encrypt);
// Get actual key count from storage
let key_count = storage.as_ref()
.and_then(|s| s.dbsize().ok())
.map(|count| count as u64);
// Get database name from admin meta
let name = admin_meta::get_database_name(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id)
.ok()
.flatten();
// Compute size on disk and timestamps from the DB file path
let db_path = self.db_file_path(server, db_id);
let size_on_disk = self.compute_size_on_disk(&db_path);
let (created_at, last_access) = Self::get_file_times_secs(&db_path);
let backend = match server.option.backend {
crate::options::BackendType::Redb => BackendType::Redb,
crate::options::BackendType::Sled => BackendType::Sled,
};
DatabaseInfo {
id: db_id,
name,
backend,
encrypted,
redis_version: Some("7.0".to_string()),
storage_path: Some(server.option.dir.clone()),
size_on_disk,
key_count,
created_at,
last_access,
}
}
}
#[jsonrpsee::core::async_trait]
impl RpcServer for RpcServerImpl {
async fn create_database(
&self,
backend: BackendType,
config: DatabaseConfig,
encryption_key: Option<String>,
) -> RpcResult<u64> {
// Allocate new ID via admin DB 0
let db_id = admin_meta::allocate_next_id(&self.base_dir, self.backend.clone(), &self.admin_secret)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
// Persist per-db encryption key in admin DB 0 if provided
if let Some(ref key) = encryption_key {
admin_meta::set_enc_key(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, key)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
}
// Persist database name if provided
if let Some(ref name) = config.name {
admin_meta::set_database_name(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, name)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
}
// Ensure base dir exists
if let Err(e) = std::fs::create_dir_all(&self.base_dir) {
return Err(jsonrpsee::types::ErrorObjectOwned::owned(-32000, format!("Failed to ensure base dir: {}", e), None::<()>));
}
// Map RPC backend to options backend and persist it in admin meta for this db id
let opt_backend = match backend {
BackendType::Redb => crate::options::BackendType::Redb,
BackendType::Sled => crate::options::BackendType::Sled,
};
admin_meta::set_database_backend(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, opt_backend.clone())
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
// Create server instance using base_dir, chosen backend and admin secret
let option = DBOption {
dir: self.base_dir.clone(),
port: 0, // Not used for RPC-managed databases
debug: false,
encryption_key: None, // per-db key is stored in admin DB 0
encrypt: false, // encryption decided per-db at open time
backend: opt_backend,
admin_secret: self.admin_secret.clone(),
};
let mut server = Server::new(option).await;
server.selected_db = db_id;
// Initialize storage to create physical <id>.db with proper encryption from admin meta
let _ = server.current_storage();
// Store the server in cache
let mut servers = self.servers.write().await;
servers.insert(db_id, Arc::new(server));
Ok(db_id)
}
async fn set_encryption(&self, _db_id: u64, _encryption_key: String) -> RpcResult<bool> {
// For now, return false as encryption can only be set during creation
let _servers = self.servers.read().await;
// TODO: Implement encryption setting for existing databases
Ok(false)
}
async fn list_databases(&self) -> RpcResult<Vec<DatabaseInfo>> {
let db_ids = self.discover_databases().await;
let mut result = Vec::new();
for db_id in db_ids {
if let Ok(server) = self.get_or_create_server(db_id).await {
// Build accurate info from storage/meta/fs
let info = self.build_database_info(db_id, &server).await;
result.push(info);
}
}
Ok(result)
}
async fn get_database_info(&self, db_id: u64) -> RpcResult<DatabaseInfo> {
let server = self.get_or_create_server(db_id).await?;
// Build accurate info from storage/meta/fs
let info = self.build_database_info(db_id, &server).await;
Ok(info)
}
async fn delete_database(&self, db_id: u64) -> RpcResult<bool> {
let mut servers = self.servers.write().await;
if let Some(_server) = servers.remove(&db_id) {
// Clean up database files
let db_path = std::path::PathBuf::from(&self.base_dir).join(format!("{}.db", db_id));
if db_path.exists() {
if db_path.is_dir() {
std::fs::remove_dir_all(&db_path).ok();
} else {
std::fs::remove_file(&db_path).ok();
}
}
Ok(true)
} else {
Ok(false)
}
}
async fn get_server_stats(&self) -> RpcResult<HashMap<String, serde_json::Value>> {
let db_ids = self.discover_databases().await;
let mut stats = HashMap::new();
stats.insert("total_databases".to_string(), serde_json::json!(db_ids.len()));
stats.insert("uptime".to_string(), serde_json::json!(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
));
Ok(stats)
}
async fn add_access_key(&self, db_id: u64, key: String, permissions: String) -> RpcResult<bool> {
let perms = match permissions.to_lowercase().as_str() {
"read" => Permissions::Read,
"readwrite" => Permissions::ReadWrite,
_ => return Err(jsonrpsee::types::ErrorObjectOwned::owned(
-32000,
"Invalid permissions: use 'read' or 'readwrite'",
None::<()>
)),
};
admin_meta::add_access_key(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, &key, perms)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
Ok(true)
}
async fn delete_access_key(&self, db_id: u64, key_hash: String) -> RpcResult<bool> {
let ok = admin_meta::delete_access_key(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, &key_hash)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
Ok(ok)
}
async fn list_access_keys(&self, db_id: u64) -> RpcResult<Vec<AccessKeyInfo>> {
let pairs = admin_meta::list_access_keys(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
let keys: Vec<AccessKeyInfo> = pairs.into_iter().map(|(hash, perm, ts)| AccessKeyInfo {
hash,
permissions: perm,
created_at: ts,
}).collect();
Ok(keys)
}
async fn set_database_public(&self, db_id: u64, public: bool) -> RpcResult<bool> {
admin_meta::set_database_public(&self.base_dir, self.backend.clone(), &self.admin_secret, db_id, public)
.map_err(|e| jsonrpsee::types::ErrorObjectOwned::owned(-32000, e.0, None::<()>))?;
Ok(true)
}
}

49
src/rpc_server.rs Normal file
View File

@@ -0,0 +1,49 @@
use std::net::SocketAddr;
use jsonrpsee::server::{ServerBuilder, ServerHandle};
use jsonrpsee::RpcModule;
use crate::rpc::{RpcServer, RpcServerImpl};
/// Start the RPC server on the specified address
pub async fn start_rpc_server(addr: SocketAddr, base_dir: String, backend: crate::options::BackendType, admin_secret: String) -> Result<ServerHandle, Box<dyn std::error::Error + Send + Sync>> {
// Create the RPC server implementation
let rpc_impl = RpcServerImpl::new(base_dir, backend, admin_secret);
// Create the RPC module
let mut module = RpcModule::new(());
module.merge(RpcServer::into_rpc(rpc_impl))?;
// Build the server with both HTTP and WebSocket support
let server = ServerBuilder::default()
.build(addr)
.await?;
// Start the server
let handle = server.start(module);
println!("RPC server started on {}", addr);
Ok(handle)
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[tokio::test]
async fn test_rpc_server_startup() {
let addr = "127.0.0.1:0".parse().unwrap(); // Use port 0 for auto-assignment
let base_dir = "/tmp/test_rpc".to_string();
let backend = crate::options::BackendType::Redb; // Default for test
let handle = start_rpc_server(addr, base_dir, backend, "test-admin".to_string()).await.unwrap();
// Give the server a moment to start
tokio::time::sleep(Duration::from_millis(100)).await;
// Stop the server
handle.stop().unwrap();
handle.stopped().await;
}
}

View File

@@ -1,273 +0,0 @@
use crate::{
error::DBError,
protocol::Protocol,
server::Server,
tantivy_search::{
FieldDef, Filter, FilterType, IndexConfig, NumericType, SearchOptions, TantivySearch,
},
};
use std::collections::HashMap;
use std::sync::Arc;
pub async fn ft_create_cmd(
server: &Server,
index_name: String,
schema: Vec<(String, String, Vec<String>)>,
) -> Result<Protocol, DBError> {
// Parse schema into field definitions
let mut field_definitions = Vec::new();
for (field_name, field_type, options) in schema {
let field_def = match field_type.to_uppercase().as_str() {
"TEXT" => {
let mut weight = 1.0;
let mut sortable = false;
let mut no_index = false;
for opt in &options {
match opt.to_uppercase().as_str() {
"WEIGHT" => {
// Next option should be the weight value
if let Some(idx) = options.iter().position(|x| x == opt) {
if idx + 1 < options.len() {
weight = options[idx + 1].parse().unwrap_or(1.0);
}
}
}
"SORTABLE" => sortable = true,
"NOINDEX" => no_index = true,
_ => {}
}
}
FieldDef::Text {
stored: true,
indexed: !no_index,
tokenized: true,
fast: sortable,
}
}
"NUMERIC" => {
let mut sortable = false;
for opt in &options {
if opt.to_uppercase() == "SORTABLE" {
sortable = true;
}
}
FieldDef::Numeric {
stored: true,
indexed: true,
fast: sortable,
precision: NumericType::F64,
}
}
"TAG" => {
let mut separator = ",".to_string();
let mut case_sensitive = false;
for i in 0..options.len() {
match options[i].to_uppercase().as_str() {
"SEPARATOR" => {
if i + 1 < options.len() {
separator = options[i + 1].clone();
}
}
"CASESENSITIVE" => case_sensitive = true,
_ => {}
}
}
FieldDef::Tag {
stored: true,
separator,
case_sensitive,
}
}
"GEO" => FieldDef::Geo { stored: true },
_ => {
return Err(DBError(format!("Unknown field type: {}", field_type)));
}
};
field_definitions.push((field_name, field_def));
}
// Create the search index
let search_path = server.search_index_path();
let config = IndexConfig::default();
println!(
"Creating search index '{}' at path: {:?}",
index_name, search_path
);
println!("Field definitions: {:?}", field_definitions);
let search_index = TantivySearch::new_with_schema(
search_path,
index_name.clone(),
field_definitions,
Some(config),
)?;
println!("Search index '{}' created successfully", index_name);
// Store in registry
let mut indexes = server.search_indexes.write().unwrap();
indexes.insert(index_name, Arc::new(search_index));
Ok(Protocol::SimpleString("OK".to_string()))
}
pub async fn ft_add_cmd(
server: &Server,
index_name: String,
doc_id: String,
_score: f64,
fields: HashMap<String, String>,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
search_index.add_document_with_fields(&doc_id, fields)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
pub async fn ft_search_cmd(
server: &Server,
index_name: String,
query: String,
filters: Vec<(String, String)>,
limit: Option<usize>,
offset: Option<usize>,
return_fields: Option<Vec<String>>,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// Convert filters to search filters
let search_filters = filters
.into_iter()
.map(|(field, value)| Filter {
field,
filter_type: FilterType::Equals(value),
})
.collect();
let options = SearchOptions {
limit: limit.unwrap_or(10),
offset: offset.unwrap_or(0),
filters: search_filters,
sort_by: None,
return_fields,
highlight: false,
};
let results = search_index.search_with_options(&query, options)?;
// Format results as Redis protocol
let mut response = Vec::new();
// First element is the total count
response.push(Protocol::SimpleString(results.total.to_string()));
// Then each document
for doc in results.documents {
let mut doc_array = Vec::new();
// Add document ID if it exists
if let Some(id) = doc.fields.get("_id") {
doc_array.push(Protocol::BulkString(id.clone()));
}
// Add score
doc_array.push(Protocol::BulkString(doc.score.to_string()));
// Add fields as key-value pairs
for (field_name, field_value) in doc.fields {
if field_name != "_id" {
doc_array.push(Protocol::BulkString(field_name));
doc_array.push(Protocol::BulkString(field_value));
}
}
response.push(Protocol::Array(doc_array));
}
Ok(Protocol::Array(response))
}
pub async fn ft_del_cmd(
server: &Server,
index_name: String,
doc_id: String,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let _search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// For now, return success
// In a full implementation, we'd need to add a delete method to TantivySearch
println!("Deleting document '{}' from index '{}'", doc_id, index_name);
Ok(Protocol::SimpleString("1".to_string()))
}
pub async fn ft_info_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
let info = search_index.get_info()?;
// Format info as Redis protocol
let mut response = Vec::new();
response.push(Protocol::BulkString("index_name".to_string()));
response.push(Protocol::BulkString(info.name));
response.push(Protocol::BulkString("num_docs".to_string()));
response.push(Protocol::BulkString(info.num_docs.to_string()));
response.push(Protocol::BulkString("num_fields".to_string()));
response.push(Protocol::BulkString(info.fields.len().to_string()));
response.push(Protocol::BulkString("fields".to_string()));
let fields_str = info
.fields
.iter()
.map(|f| format!("{}:{}", f.name, f.field_type))
.collect::<Vec<_>>()
.join(", ");
response.push(Protocol::BulkString(fields_str));
Ok(Protocol::Array(response))
}
pub async fn ft_drop_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
let mut indexes = server.search_indexes.write().unwrap();
if indexes.remove(&index_name).is_some() {
// Also remove the index files from disk
let index_path = server.search_index_path().join(&index_name);
if index_path.exists() {
std::fs::remove_dir_all(index_path)
.map_err(|e| DBError(format!("Failed to remove index files: {}", e)))?;
}
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Err(DBError(format!("Index '{}' not found", index_name)))
}
}

View File

@@ -1,10 +1,9 @@
use core::str;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::sync::{oneshot, Mutex};
use tokio::sync::{Mutex, oneshot};
use std::sync::atomic::{AtomicU64, Ordering};
@@ -12,19 +11,17 @@ use crate::cmd::Cmd;
use crate::error::DBError;
use crate::options;
use crate::protocol::Protocol;
use crate::storage::Storage;
use crate::storage_sled::SledStorage;
use crate::storage_trait::StorageBackend;
use crate::tantivy_search::TantivySearch;
use crate::admin_meta;
#[derive(Clone)]
pub struct Server {
pub db_cache: Arc<RwLock<HashMap<u64, Arc<dyn StorageBackend>>>>,
pub search_indexes: Arc<RwLock<HashMap<String, Arc<TantivySearch>>>>,
pub db_cache: std::sync::Arc<std::sync::RwLock<HashMap<u64, Arc<dyn StorageBackend>>>>,
pub option: options::DBOption,
pub client_name: Option<String>,
pub selected_db: u64, // Changed from usize to u64
pub queued_cmd: Option<Vec<(Cmd, Protocol)>>,
pub current_permissions: Option<crate::rpc::Permissions>,
// BLPOP waiter registry: per (db_index, key) FIFO of waiters
pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
@@ -46,12 +43,12 @@ pub enum PopSide {
impl Server {
pub async fn new(option: options::DBOption) -> Self {
Server {
db_cache: Arc::new(RwLock::new(HashMap::new())),
search_indexes: Arc::new(RwLock::new(HashMap::new())),
db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())),
option,
client_name: None,
selected_db: 0,
queued_cmd: None,
current_permissions: None,
list_waiters: Arc::new(Mutex::new(HashMap::new())),
waiter_seq: Arc::new(AtomicU64::new(1)),
@@ -65,58 +62,42 @@ impl Server {
return Ok(storage.clone());
}
// Create new database file
let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
.join(format!("{}.db", self.selected_db));
// Ensure the directory exists before creating the database file
if let Some(parent_dir) = db_file_path.parent() {
std::fs::create_dir_all(parent_dir).map_err(|e| {
DBError(format!(
"Failed to create directory {}: {}",
parent_dir.display(),
e
))
})?;
}
println!("Creating new db file: {}", db_file_path.display());
let storage: Arc<dyn StorageBackend> = match self.option.backend {
options::BackendType::Redb => Arc::new(Storage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref(),
)?),
options::BackendType::Sled => Arc::new(SledStorage::new(
db_file_path,
self.should_encrypt_db(self.selected_db),
self.option.encryption_key.as_deref(),
)?),
// Use process-wide shared handles to avoid sled/reDB double-open lock contention.
let storage = if self.selected_db == 0 {
// Admin DB 0: always via singleton
admin_meta::open_admin_storage(
&self.option.dir,
self.option.backend.clone(),
&self.option.admin_secret,
)?
} else {
// Data DBs: via global registry keyed by id
admin_meta::open_data_storage(
&self.option.dir,
self.option.backend.clone(),
&self.option.admin_secret,
self.selected_db,
)?
};
cache.insert(self.selected_db, storage.clone());
Ok(storage)
}
fn should_encrypt_db(&self, db_index: u64) -> bool {
// DB 0-9 are non-encrypted, DB 10+ are encrypted
self.option.encrypt && db_index >= 10
/// Check if current permissions allow read operations
pub fn has_read_permission(&self) -> bool {
matches!(self.current_permissions, Some(crate::rpc::Permissions::Read) | Some(crate::rpc::Permissions::ReadWrite))
}
// Add method to get search index path
pub fn search_index_path(&self) -> std::path::PathBuf {
std::path::PathBuf::from(&self.option.dir).join("search_indexes")
/// Check if current permissions allow write operations
pub fn has_write_permission(&self) -> bool {
matches!(self.current_permissions, Some(crate::rpc::Permissions::ReadWrite))
}
// ----- BLPOP waiter helpers -----
pub async fn register_waiter(
&self,
db_index: u64,
key: &str,
side: PopSide,
) -> (u64, oneshot::Receiver<(String, String)>) {
pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) {
let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel::<(String, String)>();
@@ -192,7 +173,10 @@ impl Server {
Ok(())
}
pub async fn handle(&mut self, mut stream: tokio::net::TcpStream) -> Result<(), DBError> {
pub async fn handle(
&mut self,
mut stream: tokio::net::TcpStream,
) -> Result<(), DBError> {
// Accumulate incoming bytes to handle partial RESP frames
let mut acc = String::new();
let mut buf = vec![0u8; 8192];
@@ -229,10 +213,7 @@ impl Server {
acc = remaining.to_string();
if self.option.debug {
println!(
"\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m",
cmd, protocol
);
println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
} else {
println!("got command: {:?}, protocol: {:?}", cmd, protocol);
}

View File

@@ -12,9 +12,9 @@ use crate::error::DBError;
// Re-export modules
mod storage_basic;
mod storage_extra;
mod storage_hset;
mod storage_lists;
mod storage_extra;
// Re-export implementations
// Note: These imports are used by the impl blocks in the submodules
@@ -28,8 +28,7 @@ const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("string
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> =
TableDefinition::new("streams_data");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
@@ -56,13 +55,9 @@ pub struct Storage {
}
impl Storage {
pub fn new(
path: impl AsRef<Path>,
should_encrypt: bool,
master_key: Option<&str>,
) -> Result<Self, DBError> {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
let db = Database::create(path)?;
// Create tables if they don't exist
let write_txn = db.begin_write()?;
{
@@ -76,28 +71,23 @@ impl Storage {
let _ = write_txn.open_table(EXPIRATION_TABLE)?;
}
write_txn.commit()?;
// Check if database was previously encrypted
let read_txn = db.begin_read()?;
let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
let was_encrypted = encrypted_table
.get("encrypted")?
.map(|v| v.value() == 1)
.unwrap_or(false);
let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
drop(read_txn);
let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes()))
} else {
return Err(DBError(
"Encryption requested but no master key provided".to_string(),
));
return Err(DBError("Encryption requested but no master key provided".to_string()));
}
} else {
None
};
// If we're enabling encryption for the first time, mark it
if should_encrypt && !was_encrypted {
let write_txn = db.begin_write()?;
@@ -107,10 +97,13 @@ impl Storage {
}
write_txn.commit()?;
}
Ok(Storage { db, crypto })
Ok(Storage {
db,
crypto,
})
}
pub fn is_encrypted(&self) -> bool {
self.crypto.is_some()
}
@@ -123,7 +116,7 @@ impl Storage {
Ok(data.to_vec())
}
}
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data)?)
@@ -172,22 +165,11 @@ impl StorageBackend for Storage {
self.get_key_type(key)
}
fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
self.scan(cursor, pattern, count)
}
fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
self.hscan(key, cursor, pattern, count)
}
@@ -294,7 +276,7 @@ impl StorageBackend for Storage {
fn is_encrypted(&self) -> bool {
self.is_encrypted()
}
fn info(&self) -> Result<Vec<(String, String)>, DBError> {
self.info()
}
@@ -302,4 +284,4 @@ impl StorageBackend for Storage {
fn clone_arc(&self) -> Arc<dyn StorageBackend> {
unimplemented!("Storage cloning not yet implemented for redb backend")
}
}
}

View File

@@ -1,6 +1,6 @@
use super::*;
use redb::{ReadableTable};
use crate::error::DBError;
use redb::ReadableTable;
use super::*;
impl Storage {
pub fn flushdb(&self) -> Result<(), DBError> {
@@ -15,17 +15,11 @@ impl Storage {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
// inefficient, but there is no other way
let keys: Vec<String> = types_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
types_table.remove(key.as_str())?;
}
let keys: Vec<String> = strings_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
strings_table.remove(key.as_str())?;
}
@@ -40,35 +34,23 @@ impl Storage {
for (key, field) in keys {
hashes_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = lists_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
lists_table.remove(key.as_str())?;
}
let keys: Vec<String> = streams_meta_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
streams_meta_table.remove(key.as_str())?;
}
let keys: Vec<(String, String)> = streams_data_table
.iter()?
.map(|item| {
let binding = item.unwrap();
let (key, field) = binding.0.value();
(key.to_string(), field.to_string())
})
.collect();
let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
let binding = item.unwrap();
let (key, field) = binding.0.value();
(key.to_string(), field.to_string())
}).collect();
for (key, field) in keys {
streams_data_table.remove((key.as_str(), field.as_str()))?;
}
let keys: Vec<String> = expiration_table
.iter()?
.map(|item| item.unwrap().0.value().to_string())
.collect();
let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
for key in keys {
expiration_table.remove(key.as_str())?;
}
@@ -80,7 +62,7 @@ impl Storage {
pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
// Before returning type, check for expiration
if let Some(type_val) = table.get(key)? {
if type_val.value() == "string" {
@@ -101,7 +83,7 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted
pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
@@ -114,7 +96,7 @@ impl Storage {
return Ok(None);
}
}
// Get and decrypt value
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
match strings_table.get(key)? {
@@ -133,21 +115,21 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value, not expiration
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Remove any existing expiration since this is a regular SET
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
@@ -155,42 +137,41 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Value is encrypted before storage
pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.insert(key.as_str(), "string")?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
// Only encrypt the value
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
strings_table.insert(key.as_str(), encrypted.as_slice())?;
// Store expiration separately (unencrypted)
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at = expire_ms + now_in_millis();
expiration_table.insert(key.as_str(), &(expires_at as u64))?;
}
write_txn.commit()?;
Ok(())
}
pub fn del(&self, key: String) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> =
write_txn.open_table(HASHES_TABLE)?;
let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Remove from type table
types_table.remove(key.as_str())?;
// Remove from strings table
strings_table.remove(key.as_str())?;
// Remove all hash fields for this key
let mut to_remove = Vec::new();
let mut iter = hashes_table.iter()?;
@@ -202,19 +183,19 @@ impl Storage {
}
}
drop(iter);
for (hash_key, field) in to_remove {
hashes_table.remove((hash_key.as_str(), field.as_str()))?;
}
// Remove from lists table
lists_table.remove(key.as_str())?;
// Also remove expiration
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
expiration_table.remove(key.as_str())?;
}
write_txn.commit()?;
Ok(())
}
@@ -222,7 +203,7 @@ impl Storage {
pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let mut keys = Vec::new();
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
@@ -231,7 +212,7 @@ impl Storage {
keys.push(key);
}
}
Ok(keys)
}
}
@@ -261,4 +242,4 @@ impl Storage {
}
Ok(count)
}
}
}

View File

@@ -1,29 +1,24 @@
use super::*;
use redb::{ReadableTable};
use crate::error::DBError;
use redb::ReadableTable;
use super::*;
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let strings_table = read_txn.open_table(STRINGS_TABLE)?;
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = types_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let key = entry.0.value().to_string();
let key_type = entry.1.value().to_string();
if current_cursor >= cursor {
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
@@ -31,7 +26,7 @@ impl Storage {
} else {
true
};
if matches {
// For scan, we return key-value pairs for string types
if key_type == "string" {
@@ -46,7 +41,7 @@ impl Storage {
// For non-string types, just return the key with type as value
result.push((key, key_type));
}
if result.len() >= limit {
break;
}
@@ -54,19 +49,15 @@ impl Storage {
}
current_cursor += 1;
}
let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
@@ -84,14 +75,14 @@ impl Storage {
}
}
Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types)
None => Ok(-2), // Key does not exist
None => Ok(-2), // Key does not exist
}
}
pub fn exists(&self, key: &str) -> Result<bool, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "string" => {
// Check if string key has expired
@@ -104,7 +95,7 @@ impl Storage {
Ok(true)
}
Some(_) => Ok(true), // Key exists and is not a string
None => Ok(false), // Key does not exist
None => Ok(false), // Key does not exist
}
}
@@ -187,12 +178,8 @@ impl Storage {
.unwrap_or(false);
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_secs <= 0 {
0
} else {
(ts_secs as u128) * 1000
};
expiration_table.insert(key, &(expires_at_ms as u64))?;
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
expiration_table.insert(key, &((expires_at_ms as u64)))?;
applied = true;
}
}
@@ -214,7 +201,7 @@ impl Storage {
if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
expiration_table.insert(key, &(expires_at_ms as u64))?;
expiration_table.insert(key, &((expires_at_ms as u64)))?;
applied = true;
}
}
@@ -236,21 +223,21 @@ pub fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" {
return true;
}
// Simple glob matching - supports * and ? wildcards
let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() {
return ti >= text.len();
}
if ti >= text.len() {
// Check if remaining pattern is all '*'
return pattern[pi..].iter().all(|&c| c == '*');
}
match pattern[pi] {
'*' => {
// Try matching zero or more characters
@@ -275,7 +262,7 @@ pub fn glob_match(pattern: &str, text: &str) -> bool {
}
}
}
match_recursive(&pattern_chars, &text_chars, 0, 0)
}
@@ -296,4 +283,4 @@ mod tests {
assert!(glob_match("*test*", "this_is_a_test_string"));
assert!(!glob_match("*test*", "this_is_a_string"));
}
}
}

View File

@@ -1,50 +1,44 @@
use super::*;
use redb::{ReadableTable};
use crate::error::DBError;
use redb::ReadableTable;
use super::*;
impl Storage {
// ✅ ENCRYPTION APPLIED: Values are encrypted before storage
pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut new_fields = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") | None => {
// Proceed if hash or new key
Some("hash") | None => { // Proceed if hash or new key
// Set the type to hash (only if new key or existing hash)
types_table.insert(key, "hash")?;
for (field, value) in pairs {
// Check if field already exists
let exists = hashes_table.get((key, field.as_str()))?.is_some();
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
if !exists {
new_fields += 1;
}
}
}
Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value"
.to_string(),
))
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
}
}
write_txn.commit()?;
Ok(new_fields)
}
@@ -53,7 +47,7 @@ impl Storage {
pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = types_table.get(key)?.map(|v| v.value().to_string());
match key_type.as_deref() {
@@ -68,9 +62,7 @@ impl Storage {
None => Ok(None),
}
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(None),
}
}
@@ -88,7 +80,7 @@ impl Storage {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
@@ -99,12 +91,10 @@ impl Storage {
result.push((field.to_string(), value));
}
}
Ok(result)
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
@@ -112,24 +102,24 @@ impl Storage {
pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut deleted = 0i64;
// First check if key exists and is a hash
let key_type = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") => {
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
for field in fields {
if hashes_table.remove((key, field.as_str()))?.is_some() {
deleted += 1;
}
}
// Check if hash is now empty and remove type if so
let mut has_fields = false;
let mut iter = hashes_table.iter()?;
@@ -142,20 +132,16 @@ impl Storage {
}
}
drop(iter);
if !has_fields {
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
types_table.remove(key)?;
}
}
Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => {} // Key does not exist, nothing to delete, return 0 deleted
}
write_txn.commit()?;
Ok(deleted)
}
@@ -173,9 +159,7 @@ impl Storage {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
Ok(hashes_table.get((key, field))?.is_some())
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(false),
}
}
@@ -192,7 +176,7 @@ impl Storage {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
@@ -201,12 +185,10 @@ impl Storage {
result.push(field.to_string());
}
}
Ok(result)
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
@@ -224,7 +206,7 @@ impl Storage {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
@@ -235,12 +217,10 @@ impl Storage {
result.push(value);
}
}
Ok(result)
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(Vec::new()),
}
}
@@ -257,7 +237,7 @@ impl Storage {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut count = 0i64;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
@@ -266,12 +246,10 @@ impl Storage {
count += 1;
}
}
Ok(count)
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(0),
}
}
@@ -289,7 +267,7 @@ impl Storage {
Some("hash") => {
let hashes_table = read_txn.open_table(HASHES_TABLE)?;
let mut result = Vec::new();
for field in fields {
match hashes_table.get((key, field.as_str()))? {
Some(data) => {
@@ -300,12 +278,10 @@ impl Storage {
None => result.push(None),
}
}
Ok(result)
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok(fields.into_iter().map(|_| None).collect()),
}
}
@@ -314,51 +290,39 @@ impl Storage {
pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = false;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
let key_type = {
let access_guard = types_table.get(key)?;
access_guard.map(|v| v.value().to_string())
};
match key_type.as_deref() {
Some("hash") | None => {
// Proceed if hash or new key
Some("hash") | None => { // Proceed if hash or new key
// Check if field already exists
if hashes_table.get((key, field))?.is_none() {
// Set the type to hash (only if new key or existing hash)
types_table.insert(key, "hash")?;
// Encrypt the value before storing
let encrypted = self.encrypt_if_needed(value.as_bytes())?;
hashes_table.insert((key, field), encrypted.as_slice())?;
result = true;
}
}
Some(_) => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value"
.to_string(),
))
}
Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
}
}
write_txn.commit()?;
Ok(result)
}
// ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
pub fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
let key_type = {
@@ -372,28 +336,28 @@ impl Storage {
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
let mut iter = hashes_table.iter()?;
while let Some(entry) = iter.next() {
let entry = entry?;
let (hash_key, field) = entry.0.value();
if hash_key == key {
if current_cursor >= cursor {
let field_str = field.to_string();
// Apply pattern matching if specified
let matches = if let Some(pat) = pattern {
super::storage_extra::glob_match(pat, &field_str)
} else {
true
};
if matches {
let decrypted = self.decrypt_if_needed(entry.1.value())?;
let value = String::from_utf8(decrypted)?;
result.push((field_str, value));
if result.len() >= limit {
break;
}
@@ -402,18 +366,12 @@ impl Storage {
current_cursor += 1;
}
}
let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
Some(_) => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
None => Ok((0, Vec::new())),
}
}
}
}

View File

@@ -1,20 +1,20 @@
use super::*;
use redb::{ReadableTable};
use crate::error::DBError;
use redb::ReadableTable;
use super::*;
impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
@@ -23,20 +23,20 @@ impl Storage {
}
None => Vec::new(),
};
// Add elements to the front (left)
for element in elements.into_iter() {
list.insert(0, element);
}
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
@@ -45,14 +45,14 @@ impl Storage {
pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut _length = 0i64;
{
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
// Set the type to list
types_table.insert(key, "list")?;
// Get current list or create empty one
let mut list: Vec<String> = match lists_table.get(key)? {
Some(data) => {
@@ -61,17 +61,17 @@ impl Storage {
}
None => Vec::new(),
};
// Add elements to the end (right)
list.extend(elements);
_length = list.len() as i64;
// Encrypt and store the updated list
let serialized = serde_json::to_vec(&list)?;
let encrypted = self.encrypt_if_needed(&serialized)?;
lists_table.insert(key, encrypted.as_slice())?;
}
write_txn.commit()?;
Ok(_length)
}
@@ -80,12 +80,12 @@ impl Storage {
pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
@@ -100,7 +100,7 @@ impl Storage {
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
@@ -108,7 +108,7 @@ impl Storage {
result.push(list.remove(0));
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
@@ -122,7 +122,7 @@ impl Storage {
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
@@ -131,12 +131,12 @@ impl Storage {
pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let write_txn = self.db.begin_write()?;
let mut result = Vec::new();
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
@@ -151,7 +151,7 @@ impl Storage {
};
result
};
if let Some(mut list) = list_data {
let pop_count = std::cmp::min(count as usize, list.len());
for _ in 0..pop_count {
@@ -159,7 +159,7 @@ impl Storage {
result.push(list.pop().unwrap());
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
// Remove the key if list is empty
@@ -173,7 +173,7 @@ impl Storage {
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(result)
}
@@ -181,7 +181,7 @@ impl Storage {
pub fn llen(&self, key: &str) -> Result<i64, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -202,7 +202,7 @@ impl Storage {
pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -210,13 +210,13 @@ impl Storage {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
let actual_index = if index < 0 {
list.len() as i64 + index
} else {
index
};
if actual_index >= 0 && (actual_index as usize) < list.len() {
Ok(Some(list[actual_index as usize].clone()))
} else {
@@ -234,7 +234,7 @@ impl Storage {
pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
let read_txn = self.db.begin_read()?;
let types_table = read_txn.open_table(TYPES_TABLE)?;
match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -242,30 +242,22 @@ impl Storage {
Some(data) => {
let decrypted = self.decrypt_if_needed(data.value())?;
let list: Vec<String> = serde_json::from_slice(&decrypted)?;
if list.is_empty() {
return Ok(Vec::new());
}
let len = list.len() as i64;
let start_idx = if start < 0 {
std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new());
}
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
}
None => Ok(Vec::new()),
@@ -278,12 +270,12 @@ impl Storage {
// ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let write_txn = self.db.begin_write()?;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
@@ -298,25 +290,17 @@ impl Storage {
};
result
};
if let Some(list) = list_data {
if list.is_empty() {
write_txn.commit()?;
return Ok(());
}
let len = list.len() as i64;
let start_idx = if start < 0 {
std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if start_idx > stop_idx || start_idx >= len {
// Remove the entire list
@@ -327,7 +311,7 @@ impl Storage {
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
if trimmed.is_empty() {
lists_table.remove(key)?;
let mut types_table = write_txn.open_table(TYPES_TABLE)?;
@@ -340,7 +324,7 @@ impl Storage {
}
}
}
write_txn.commit()?;
Ok(())
}
@@ -349,12 +333,12 @@ impl Storage {
pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let write_txn = self.db.begin_write()?;
let mut removed = 0i64;
// First check if key exists and is a list, and get the data
let list_data = {
let types_table = write_txn.open_table(TYPES_TABLE)?;
let lists_table = write_txn.open_table(LISTS_TABLE)?;
let result = match types_table.get(key)? {
Some(type_val) if type_val.value() == "list" => {
if let Some(data) = lists_table.get(key)? {
@@ -369,7 +353,7 @@ impl Storage {
};
result
};
if let Some(mut list) = list_data {
if count == 0 {
// Remove all occurrences
@@ -399,7 +383,7 @@ impl Storage {
}
}
}
let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
if list.is_empty() {
lists_table.remove(key)?;
@@ -412,8 +396,8 @@ impl Storage {
lists_table.insert(key, encrypted.as_slice())?;
}
}
write_txn.commit()?;
Ok(removed)
}
}
}

View File

@@ -1,12 +1,12 @@
// src/storage_sled/mod.rs
use crate::crypto::CryptoFactory;
use crate::error::DBError;
use crate::storage_trait::StorageBackend;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::error::DBError;
use crate::storage_trait::StorageBackend;
use crate::crypto::CryptoFactory;
#[derive(Serialize, Deserialize, Debug, Clone)]
enum ValueType {
@@ -28,56 +28,44 @@ pub struct SledStorage {
}
impl SledStorage {
pub fn new(
path: impl AsRef<Path>,
should_encrypt: bool,
master_key: Option<&str>,
) -> Result<Self, DBError> {
pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?;
let types = db
.open_tree("types")
.map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?;
let types = db.open_tree("types").map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?;
// Check if database was previously encrypted
let encrypted_tree = db
.open_tree("encrypted")
.map_err(|e| DBError(e.to_string()))?;
let was_encrypted = encrypted_tree
.get("encrypted")
let encrypted_tree = db.open_tree("encrypted").map_err(|e| DBError(e.to_string()))?;
let was_encrypted = encrypted_tree.get("encrypted")
.map_err(|e| DBError(e.to_string()))?
.map(|v| v[0] == 1)
.unwrap_or(false);
let crypto = if should_encrypt || was_encrypted {
if let Some(key) = master_key {
Some(CryptoFactory::new(key.as_bytes()))
} else {
return Err(DBError(
"Encryption requested but no master key provided".to_string(),
));
return Err(DBError("Encryption requested but no master key provided".to_string()));
}
} else {
None
};
// Mark database as encrypted if enabling encryption
if should_encrypt && !was_encrypted {
encrypted_tree
.insert("encrypted", &[1u8])
encrypted_tree.insert("encrypted", &[1u8])
.map_err(|e| DBError(e.to_string()))?;
encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?;
}
Ok(SledStorage { db, types, crypto })
}
fn now_millis() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis()
}
fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.encrypt(data))
@@ -85,7 +73,7 @@ impl SledStorage {
Ok(data.to_vec())
}
}
fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
if let Some(crypto) = &self.crypto {
Ok(crypto.decrypt(data)?)
@@ -93,14 +81,14 @@ impl SledStorage {
Ok(data.to_vec())
}
}
fn get_storage_value(&self, key: &str) -> Result<Option<StorageValue>, DBError> {
match self.db.get(key).map_err(|e| DBError(e.to_string()))? {
Some(encrypted_data) => {
let decrypted = self.decrypt_if_needed(&encrypted_data)?;
let storage_val: StorageValue = bincode::deserialize(&decrypted)
.map_err(|e| DBError(format!("Deserialization error: {}", e)))?;
// Check expiration
if let Some(expires_at) = storage_val.expires_at {
if Self::now_millis() > expires_at {
@@ -110,51 +98,47 @@ impl SledStorage {
return Ok(None);
}
}
Ok(Some(storage_val))
}
None => Ok(None),
None => Ok(None)
}
}
fn set_storage_value(&self, key: &str, storage_val: StorageValue) -> Result<(), DBError> {
let data = bincode::serialize(&storage_val)
.map_err(|e| DBError(format!("Serialization error: {}", e)))?;
let encrypted = self.encrypt_if_needed(&data)?;
self.db
.insert(key, encrypted)
.map_err(|e| DBError(e.to_string()))?;
self.db.insert(key, encrypted).map_err(|e| DBError(e.to_string()))?;
// Store type info (unencrypted for efficiency)
let type_str = match &storage_val.value {
ValueType::String(_) => "string",
ValueType::Hash(_) => "hash",
ValueType::List(_) => "list",
};
self.types
.insert(key, type_str.as_bytes())
.map_err(|e| DBError(e.to_string()))?;
self.types.insert(key, type_str.as_bytes()).map_err(|e| DBError(e.to_string()))?;
Ok(())
}
fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" {
return true;
}
let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() {
return ti >= text.len();
}
if ti >= text.len() {
return pattern[pi..].iter().all(|&c| c == '*');
}
match pattern[pi] {
'*' => {
for i in ti..=text.len() {
@@ -174,7 +158,7 @@ impl SledStorage {
}
}
}
match_recursive(&pattern_chars, &text_chars, 0, 0)
}
}
@@ -184,12 +168,12 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::String(s) => Ok(Some(s)),
_ => Ok(None),
},
None => Ok(None),
_ => Ok(None)
}
None => Ok(None)
}
}
fn set(&self, key: String, value: String) -> Result<(), DBError> {
let storage_val = StorageValue {
value: ValueType::String(value),
@@ -199,7 +183,7 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(())
}
fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let storage_val = StorageValue {
value: ValueType::String(value),
@@ -209,27 +193,25 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(())
}
fn del(&self, key: String) -> Result<(), DBError> {
self.db.remove(&key).map_err(|e| DBError(e.to_string()))?;
self.types
.remove(&key)
.map_err(|e| DBError(e.to_string()))?;
self.types.remove(&key).map_err(|e| DBError(e.to_string()))?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(())
}
fn exists(&self, key: &str) -> Result<bool, DBError> {
// Check with expiration
Ok(self.get_storage_value(key)?.is_some())
}
fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let mut keys = Vec::new();
for item in self.types.iter() {
let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?;
let key = String::from_utf8_lossy(&key_bytes).to_string();
// Check if key is expired
if self.get_storage_value(&key)?.is_some() {
if Self::glob_match(pattern, &key) {
@@ -239,29 +221,24 @@ impl StorageBackend for SledStorage {
}
Ok(keys)
}
fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
for item in self.types.iter() {
if current_cursor >= cursor {
let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?;
let key = String::from_utf8_lossy(&key_bytes).to_string();
// Check pattern match
let matches = if let Some(pat) = pattern {
Self::glob_match(pat, &key)
} else {
true
};
if matches {
// Check if key is expired and get value
if let Some(storage_val) = self.get_storage_value(&key)? {
@@ -270,7 +247,7 @@ impl StorageBackend for SledStorage {
_ => String::from_utf8_lossy(&type_bytes).to_string(),
};
result.push((key, value));
if result.len() >= limit {
current_cursor += 1;
break;
@@ -280,15 +257,11 @@ impl StorageBackend for SledStorage {
}
current_cursor += 1;
}
let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
fn dbsize(&self) -> Result<i64, DBError> {
let mut count = 0i64;
for item in self.types.iter() {
@@ -300,42 +273,38 @@ impl StorageBackend for SledStorage {
}
Ok(count)
}
fn flushdb(&self) -> Result<(), DBError> {
self.db.clear().map_err(|e| DBError(e.to_string()))?;
self.types.clear().map_err(|e| DBError(e.to_string()))?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(())
}
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
// First check if key exists (handles expiration)
if self.get_storage_value(key)?.is_some() {
match self.types.get(key).map_err(|e| DBError(e.to_string()))? {
Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())),
None => Ok(None),
None => Ok(None)
}
} else {
Ok(None)
}
}
// Hash operations
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::Hash(HashMap::new()),
expires_at: None,
});
let hash = match &mut storage_val.value {
ValueType::Hash(h) => h,
_ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
};
let mut new_fields = 0i64;
for (field, value) in pairs {
if !hash.contains_key(&field) {
@@ -343,46 +312,40 @@ impl StorageBackend for SledStorage {
}
hash.insert(field, value);
}
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(new_fields)
}
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.get(field).cloned()),
_ => Ok(None),
},
None => Ok(None),
_ => Ok(None)
}
None => Ok(None)
}
}
fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.into_iter().collect()),
_ => Ok(Vec::new()),
},
None => Ok(Vec::new()),
_ => Ok(Vec::new())
}
None => Ok(Vec::new())
}
}
fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => {
let mut result = Vec::new();
let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize;
for (field, value) in h.iter() {
if current_cursor >= cursor {
let matches = if let Some(pat) = pattern {
@@ -390,7 +353,7 @@ impl StorageBackend for SledStorage {
} else {
true
};
if matches {
result.push((field.clone(), value.clone()));
if result.len() >= limit {
@@ -401,115 +364,107 @@ impl StorageBackend for SledStorage {
}
current_cursor += 1;
}
let next_cursor = if result.len() < limit {
0
} else {
current_cursor
};
let next_cursor = if result.len() < limit { 0 } else { current_cursor };
Ok((next_cursor, result))
}
_ => Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
)),
},
None => Ok((0, Vec::new())),
_ => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()))
}
None => Ok((0, Vec::new()))
}
}
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(0),
None => return Ok(0)
};
let hash = match &mut storage_val.value {
ValueType::Hash(h) => h,
_ => return Ok(0),
_ => return Ok(0)
};
let mut deleted = 0i64;
for field in fields {
if hash.remove(&field).is_some() {
deleted += 1;
}
}
if hash.is_empty() {
self.del(key.to_string())?;
} else {
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
}
Ok(deleted)
}
fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.contains_key(field)),
_ => Ok(false),
},
None => Ok(false),
_ => Ok(false)
}
None => Ok(false)
}
}
fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.keys().cloned().collect()),
_ => Ok(Vec::new()),
},
None => Ok(Vec::new()),
_ => Ok(Vec::new())
}
None => Ok(Vec::new())
}
}
fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.values().cloned().collect()),
_ => Ok(Vec::new()),
},
None => Ok(Vec::new()),
_ => Ok(Vec::new())
}
None => Ok(Vec::new())
}
}
fn hlen(&self, key: &str) -> Result<i64, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.len() as i64),
_ => Ok(0),
},
None => Ok(0),
_ => Ok(0)
}
None => Ok(0)
}
}
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()),
_ => Ok(fields.into_iter().map(|_| None).collect()),
},
None => Ok(fields.into_iter().map(|_| None).collect()),
ValueType::Hash(h) => {
Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect())
}
_ => Ok(fields.into_iter().map(|_| None).collect())
}
None => Ok(fields.into_iter().map(|_| None).collect())
}
}
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::Hash(HashMap::new()),
expires_at: None,
});
let hash = match &mut storage_val.value {
ValueType::Hash(h) => h,
_ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
};
if hash.contains_key(field) {
Ok(false)
} else {
@@ -519,66 +474,58 @@ impl StorageBackend for SledStorage {
Ok(true)
}
}
// List operations
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::List(Vec::new()),
expires_at: None,
});
let list = match &mut storage_val.value {
ValueType::List(l) => l,
_ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
};
for element in elements.into_iter().rev() {
list.insert(0, element);
}
let len = list.len() as i64;
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(len)
}
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::List(Vec::new()),
expires_at: None,
});
let list = match &mut storage_val.value {
ValueType::List(l) => l,
_ => {
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
_ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
};
list.extend(elements);
let len = list.len() as i64;
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(len)
}
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(Vec::new()),
None => return Ok(Vec::new())
};
let list = match &mut storage_val.value {
ValueType::List(l) => l,
_ => return Ok(Vec::new()),
_ => return Ok(Vec::new())
};
let mut result = Vec::new();
for _ in 0..count.min(list.len() as u64) {
if let Some(elem) = list.first() {
@@ -586,55 +533,55 @@ impl StorageBackend for SledStorage {
list.remove(0);
}
}
if list.is_empty() {
self.del(key.to_string())?;
} else {
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
}
Ok(result)
}
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(Vec::new()),
None => return Ok(Vec::new())
};
let list = match &mut storage_val.value {
ValueType::List(l) => l,
_ => return Ok(Vec::new()),
_ => return Ok(Vec::new())
};
let mut result = Vec::new();
for _ in 0..count.min(list.len() as u64) {
if let Some(elem) = list.pop() {
result.push(elem);
}
}
if list.is_empty() {
self.del(key.to_string())?;
} else {
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
}
Ok(result)
}
fn llen(&self, key: &str) -> Result<i64, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
ValueType::List(l) => Ok(l.len() as i64),
_ => Ok(0),
},
None => Ok(0),
_ => Ok(0)
}
None => Ok(0)
}
}
fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
@@ -644,19 +591,19 @@ impl StorageBackend for SledStorage {
} else {
index
};
if actual_index >= 0 && (actual_index as usize) < list.len() {
Ok(Some(list[actual_index as usize].clone()))
} else {
Ok(None)
}
}
_ => Ok(None),
},
None => Ok(None),
_ => Ok(None)
}
None => Ok(None)
}
}
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value {
@@ -664,68 +611,68 @@ impl StorageBackend for SledStorage {
if list.is_empty() {
return Ok(Vec::new());
}
let len = list.len() as i64;
let start_idx = if start < 0 {
std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
let start_idx = if start < 0 {
std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
if start_idx > stop_idx || start_idx >= len {
return Ok(Vec::new());
}
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
}
_ => Ok(Vec::new()),
},
None => Ok(Vec::new()),
_ => Ok(Vec::new())
}
None => Ok(Vec::new())
}
}
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(()),
None => return Ok(())
};
let list = match &mut storage_val.value {
ValueType::List(l) => l,
_ => return Ok(()),
_ => return Ok(())
};
if list.is_empty() {
return Ok(());
}
let len = list.len() as i64;
let start_idx = if start < 0 {
std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
let start_idx = if start < 0 {
std::cmp::max(0, len + start)
} else {
std::cmp::min(start, len)
};
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
let stop_idx = if stop < 0 {
std::cmp::max(-1, len + stop)
} else {
std::cmp::min(stop, len - 1)
};
if start_idx > stop_idx || start_idx >= len {
self.del(key.to_string())?;
} else {
let start_usize = start_idx as usize;
let stop_usize = (stop_idx + 1) as usize;
*list = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
if list.is_empty() {
self.del(key.to_string())?;
} else {
@@ -733,23 +680,23 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?;
}
}
Ok(())
}
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(0),
None => return Ok(0)
};
let list = match &mut storage_val.value {
ValueType::List(l) => l,
_ => return Ok(0),
_ => return Ok(0)
};
let mut removed = 0i64;
if count == 0 {
// Remove all occurrences
let original_len = list.len();
@@ -778,17 +725,17 @@ impl StorageBackend for SledStorage {
}
}
}
if list.is_empty() {
self.del(key.to_string())?;
} else {
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
}
Ok(removed)
}
// Expiration
fn ttl(&self, key: &str) -> Result<i64, DBError> {
match self.get_storage_value(key)? {
@@ -804,40 +751,40 @@ impl StorageBackend for SledStorage {
Ok(-1) // Key exists but has no expiration
}
}
None => Ok(-2), // Key does not exist
None => Ok(-2) // Key does not exist
}
}
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(false),
None => return Ok(false)
};
storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000);
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true)
}
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(false),
None => return Ok(false)
};
storage_val.expires_at = Some(Self::now_millis() + ms);
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true)
}
fn persist(&self, key: &str) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(false),
None => return Ok(false)
};
if storage_val.expires_at.is_some() {
storage_val.expires_at = None;
self.set_storage_value(key, storage_val)?;
@@ -847,41 +794,37 @@ impl StorageBackend for SledStorage {
Ok(false)
}
}
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(false),
};
let expires_at_ms: u128 = if ts_secs <= 0 {
0
} else {
(ts_secs as u128) * 1000
None => return Ok(false)
};
let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
storage_val.expires_at = Some(expires_at_ms);
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true)
}
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv,
None => return Ok(false),
None => return Ok(false)
};
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
storage_val.expires_at = Some(expires_at_ms);
self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(true)
}
fn is_encrypted(&self) -> bool {
self.crypto.is_some()
}
fn info(&self) -> Result<Vec<(String, String)>, DBError> {
let dbsize = self.dbsize()?;
Ok(vec![
@@ -899,4 +842,4 @@ impl StorageBackend for SledStorage {
crypto: self.crypto.clone(),
})
}
}
}

View File

@@ -13,22 +13,11 @@ pub trait StorageBackend: Send + Sync {
fn dbsize(&self) -> Result<i64, DBError>;
fn flushdb(&self) -> Result<(), DBError>;
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError>;
// Scanning
fn scan(
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError>;
fn hscan(
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError>;
fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>;
fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>;
// Hash operations
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError>;
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError>;
@@ -40,7 +29,7 @@ pub trait StorageBackend: Send + Sync {
fn hlen(&self, key: &str) -> Result<i64, DBError>;
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError>;
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError>;
// List operations
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
@@ -51,7 +40,7 @@ pub trait StorageBackend: Send + Sync {
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError>;
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError>;
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError>;
// Expiration
fn ttl(&self, key: &str) -> Result<i64, DBError>;
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError>;
@@ -59,11 +48,11 @@ pub trait StorageBackend: Send + Sync {
fn persist(&self, key: &str) -> Result<bool, DBError>;
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError>;
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError>;
// Metadata
fn is_encrypted(&self) -> bool;
fn info(&self) -> Result<Vec<(String, String)>, DBError>;
// Clone to Arc for sharing
fn clone_arc(&self) -> Arc<dyn StorageBackend>;
}
}

123
src/sym.rs Normal file
View File

@@ -0,0 +1,123 @@
//! sym.rs — Stateless symmetric encryption (Phase 1)
//!
//! Commands implemented (RESP):
//! - SYM KEYGEN
//! - SYM ENCRYPT <key_b64> <message>
//! - SYM DECRYPT <key_b64> <ciphertext_b64>
//!
//! Notes:
//! - Raw key: exactly 32 bytes, provided as Base64 in commands.
//! - Cipher: XChaCha20-Poly1305 (AEAD) without AAD in Phase 1
//! - Ciphertext binary layout: [version:1][nonce:24][ciphertext||tag]
//! - Encoding for wire I/O: Base64
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
XChaCha20Poly1305, XNonce,
};
use rand::RngCore;
use crate::protocol::Protocol;
const VERSION: u8 = 1;
const NONCE_LEN: usize = 24;
const TAG_LEN: usize = 16;
#[derive(Debug)]
pub enum SymWireError {
InvalidKey,
BadEncoding,
BadFormat,
BadVersion(u8),
Crypto,
}
impl SymWireError {
fn to_protocol(self) -> Protocol {
match self {
SymWireError::InvalidKey => Protocol::err("ERR sym: invalid key"),
SymWireError::BadEncoding => Protocol::err("ERR sym: bad encoding"),
SymWireError::BadFormat => Protocol::err("ERR sym: bad format"),
SymWireError::BadVersion(v) => Protocol::err(&format!("ERR sym: unsupported version {}", v)),
SymWireError::Crypto => Protocol::err("ERR sym: auth failed"),
}
}
}
fn decode_key_b64(s: &str) -> Result<chacha20poly1305::Key, SymWireError> {
let bytes = B64.decode(s.as_bytes()).map_err(|_| SymWireError::BadEncoding)?;
if bytes.len() != 32 {
return Err(SymWireError::InvalidKey);
}
Ok(chacha20poly1305::Key::from_slice(&bytes).to_owned())
}
fn encrypt_blob(key: &chacha20poly1305::Key, plaintext: &[u8]) -> Result<Vec<u8>, SymWireError> {
let cipher = XChaCha20Poly1305::new(key);
let mut nonce_bytes = [0u8; NONCE_LEN];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = XNonce::from_slice(&nonce_bytes);
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
out.push(VERSION);
out.extend_from_slice(&nonce_bytes);
let ct = cipher.encrypt(nonce, plaintext).map_err(|_| SymWireError::Crypto)?;
out.extend_from_slice(&ct);
Ok(out)
}
fn decrypt_blob(key: &chacha20poly1305::Key, blob: &[u8]) -> Result<Vec<u8>, SymWireError> {
if blob.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(SymWireError::BadFormat);
}
let ver = blob[0];
if ver != VERSION {
return Err(SymWireError::BadVersion(ver));
}
let nonce = XNonce::from_slice(&blob[1..1 + NONCE_LEN]);
let ct = &blob[1 + NONCE_LEN..];
let cipher = XChaCha20Poly1305::new(key);
cipher.decrypt(nonce, ct).map_err(|_| SymWireError::Crypto)
}
// ---------- Command handlers (RESP) ----------
pub async fn cmd_sym_keygen() -> Protocol {
let mut key_bytes = [0u8; 32];
OsRng.fill_bytes(&mut key_bytes);
let key_b64 = B64.encode(key_bytes);
Protocol::BulkString(key_b64)
}
pub async fn cmd_sym_encrypt(key_b64: &str, message: &str) -> Protocol {
let key = match decode_key_b64(key_b64) {
Ok(k) => k,
Err(e) => return e.to_protocol(),
};
match encrypt_blob(&key, message.as_bytes()) {
Ok(blob) => Protocol::BulkString(B64.encode(blob)),
Err(e) => e.to_protocol(),
}
}
pub async fn cmd_sym_decrypt(key_b64: &str, ct_b64: &str) -> Protocol {
let key = match decode_key_b64(key_b64) {
Ok(k) => k,
Err(e) => return e.to_protocol(),
};
let blob = match B64.decode(ct_b64.as_bytes()) {
Ok(b) => b,
Err(_) => return SymWireError::BadEncoding.to_protocol(),
};
match decrypt_blob(&key, &blob) {
Ok(pt) => match String::from_utf8(pt) {
Ok(s) => Protocol::BulkString(s),
Err(_) => Protocol::err("ERR sym: invalid UTF-8 plaintext"),
},
Err(e) => e.to_protocol(),
}
}

View File

@@ -1,657 +0,0 @@
use crate::error::DBError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use tantivy::{
collector::TopDocs,
directory::MmapDirectory,
query::{BooleanQuery, Occur, Query, QueryParser, TermQuery},
schema::{Field, Schema, TextFieldIndexing, TextOptions, Value, STORED, STRING},
tokenizer::TokenizerManager,
DateTime, Index, IndexReader, IndexWriter, ReloadPolicy, TantivyDocument, Term,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FieldDef {
Text {
stored: bool,
indexed: bool,
tokenized: bool,
fast: bool,
},
Numeric {
stored: bool,
indexed: bool,
fast: bool,
precision: NumericType,
},
Tag {
stored: bool,
separator: String,
case_sensitive: bool,
},
Geo {
stored: bool,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NumericType {
I64,
U64,
F64,
Date,
}
pub struct IndexSchema {
schema: Schema,
fields: HashMap<String, (Field, FieldDef)>,
default_search_fields: Vec<Field>,
}
pub struct TantivySearch {
index: Index,
writer: Arc<RwLock<IndexWriter>>,
reader: IndexReader,
index_schema: IndexSchema,
name: String,
config: IndexConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
pub language: String,
pub stopwords: Vec<String>,
pub stemming: bool,
pub max_doc_count: Option<usize>,
pub default_score: f64,
}
impl Default for IndexConfig {
fn default() -> Self {
IndexConfig {
language: "english".to_string(),
stopwords: vec![],
stemming: true,
max_doc_count: None,
default_score: 1.0,
}
}
}
impl TantivySearch {
pub fn new_with_schema(
base_path: PathBuf,
name: String,
field_definitions: Vec<(String, FieldDef)>,
config: Option<IndexConfig>,
) -> Result<Self, DBError> {
let index_path = base_path.join(&name);
std::fs::create_dir_all(&index_path)
.map_err(|e| DBError(format!("Failed to create index dir: {}", e)))?;
// Build schema from field definitions
let mut schema_builder = Schema::builder();
let mut fields = HashMap::new();
let mut default_search_fields = Vec::new();
// Always add a document ID field
let id_field = schema_builder.add_text_field("_id", STRING | STORED);
fields.insert(
"_id".to_string(),
(
id_field,
FieldDef::Text {
stored: true,
indexed: true,
tokenized: false,
fast: false,
},
),
);
// Add user-defined fields
for (field_name, field_def) in field_definitions {
let field = match &field_def {
FieldDef::Text {
stored,
indexed,
tokenized,
fast: _fast,
} => {
let mut text_options = TextOptions::default();
if *stored {
text_options = text_options.set_stored();
}
if *indexed {
let indexing_options = if *tokenized {
TextFieldIndexing::default()
.set_tokenizer("default")
.set_index_option(
tantivy::schema::IndexRecordOption::WithFreqsAndPositions,
)
} else {
TextFieldIndexing::default()
.set_tokenizer("raw")
.set_index_option(tantivy::schema::IndexRecordOption::Basic)
};
text_options = text_options.set_indexing_options(indexing_options);
let f = schema_builder.add_text_field(&field_name, text_options);
if *tokenized {
default_search_fields.push(f);
}
f
} else {
schema_builder.add_text_field(&field_name, text_options)
}
}
FieldDef::Numeric {
stored,
indexed,
fast,
precision,
} => match precision {
NumericType::I64 => {
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_i64_field(&field_name, opts)
}
NumericType::U64 => {
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_u64_field(&field_name, opts)
}
NumericType::F64 => {
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_f64_field(&field_name, opts)
}
NumericType::Date => {
let mut opts = tantivy::schema::DateOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_date_field(&field_name, opts)
}
},
FieldDef::Tag {
stored,
separator: _,
case_sensitive: _,
} => {
let mut text_options = TextOptions::default();
if *stored {
text_options = text_options.set_stored();
}
text_options = text_options.set_indexing_options(
TextFieldIndexing::default()
.set_tokenizer("raw")
.set_index_option(tantivy::schema::IndexRecordOption::Basic),
);
schema_builder.add_text_field(&field_name, text_options)
}
FieldDef::Geo { stored } => {
// For now, store as two f64 fields for lat/lon
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
opts = opts.set_indexed().set_fast();
let lat_field =
schema_builder.add_f64_field(&format!("{}_lat", field_name), opts.clone());
let lon_field =
schema_builder.add_f64_field(&format!("{}_lon", field_name), opts);
fields.insert(
format!("{}_lat", field_name),
(
lat_field,
FieldDef::Numeric {
stored: *stored,
indexed: true,
fast: true,
precision: NumericType::F64,
},
),
);
fields.insert(
format!("{}_lon", field_name),
(
lon_field,
FieldDef::Numeric {
stored: *stored,
indexed: true,
fast: true,
precision: NumericType::F64,
},
),
);
continue; // Skip adding the geo field itself
}
};
fields.insert(field_name.clone(), (field, field_def));
}
let schema = schema_builder.build();
let index_schema = IndexSchema {
schema: schema.clone(),
fields,
default_search_fields,
};
// Create or open index
let dir = MmapDirectory::open(&index_path)
.map_err(|e| DBError(format!("Failed to open index directory: {}", e)))?;
let mut index = Index::open_or_create(dir, schema)
.map_err(|e| DBError(format!("Failed to create index: {}", e)))?;
// Configure tokenizers
let tokenizer_manager = TokenizerManager::default();
index.set_tokenizers(tokenizer_manager);
let writer = index
.writer(15_000_000)
.map_err(|e| DBError(format!("Failed to create index writer: {}", e)))?;
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommitWithDelay)
.try_into()
.map_err(|e| DBError(format!("Failed to create reader: {}", e)))?;
let config = config.unwrap_or_default();
Ok(TantivySearch {
index,
writer: Arc::new(RwLock::new(writer)),
reader,
index_schema,
name,
config,
})
}
pub fn add_document_with_fields(
&self,
doc_id: &str,
fields: HashMap<String, String>,
) -> Result<(), DBError> {
let mut writer = self
.writer
.write()
.map_err(|e| DBError(format!("Failed to acquire writer lock: {}", e)))?;
// Delete existing document with same ID
if let Some((id_field, _)) = self.index_schema.fields.get("_id") {
writer.delete_term(Term::from_field_text(*id_field, doc_id));
}
// Create new document
let mut doc = tantivy::doc!();
// Add document ID
if let Some((id_field, _)) = self.index_schema.fields.get("_id") {
doc.add_text(*id_field, doc_id);
}
// Add other fields based on schema
for (field_name, field_value) in fields {
if let Some((field, field_def)) = self.index_schema.fields.get(&field_name) {
match field_def {
FieldDef::Text { .. } => {
doc.add_text(*field, &field_value);
}
FieldDef::Numeric { precision, .. } => match precision {
NumericType::I64 => {
if let Ok(v) = field_value.parse::<i64>() {
doc.add_i64(*field, v);
}
}
NumericType::U64 => {
if let Ok(v) = field_value.parse::<u64>() {
doc.add_u64(*field, v);
}
}
NumericType::F64 => {
if let Ok(v) = field_value.parse::<f64>() {
doc.add_f64(*field, v);
}
}
NumericType::Date => {
if let Ok(v) = field_value.parse::<i64>() {
doc.add_date(*field, DateTime::from_timestamp_millis(v));
}
}
},
FieldDef::Tag {
separator,
case_sensitive,
..
} => {
let tags = if !case_sensitive {
field_value.to_lowercase()
} else {
field_value.clone()
};
// Store tags as separate terms for efficient filtering
for tag in tags.split(separator.as_str()) {
doc.add_text(*field, tag.trim());
}
}
FieldDef::Geo { .. } => {
// Parse "lat,lon" format
let parts: Vec<&str> = field_value.split(',').collect();
if parts.len() == 2 {
if let (Ok(lat), Ok(lon)) =
(parts[0].parse::<f64>(), parts[1].parse::<f64>())
{
if let Some((lat_field, _)) =
self.index_schema.fields.get(&format!("{}_lat", field_name))
{
doc.add_f64(*lat_field, lat);
}
if let Some((lon_field, _)) =
self.index_schema.fields.get(&format!("{}_lon", field_name))
{
doc.add_f64(*lon_field, lon);
}
}
}
}
}
}
}
writer
.add_document(doc)
.map_err(|e| DBError(format!("Failed to add document: {}", e)))?;
writer
.commit()
.map_err(|e| DBError(format!("Failed to commit: {}", e)))?;
Ok(())
}
pub fn search_with_options(
&self,
query_str: &str,
options: SearchOptions,
) -> Result<SearchResults, DBError> {
let searcher = self.reader.searcher();
// Parse query based on search fields
let query: Box<dyn Query> = if self.index_schema.default_search_fields.is_empty() {
return Err(DBError(
"No searchable fields defined in schema".to_string(),
));
} else {
let query_parser = QueryParser::for_index(
&self.index,
self.index_schema.default_search_fields.clone(),
);
Box::new(
query_parser
.parse_query(query_str)
.map_err(|e| DBError(format!("Failed to parse query: {}", e)))?,
)
};
// Apply filters if any
let final_query = if !options.filters.is_empty() {
let mut clauses: Vec<(Occur, Box<dyn Query>)> = vec![(Occur::Must, query)];
// Add filters
for filter in options.filters {
if let Some((field, _)) = self.index_schema.fields.get(&filter.field) {
match filter.filter_type {
FilterType::Equals(value) => {
let term_query = TermQuery::new(
Term::from_field_text(*field, &value),
tantivy::schema::IndexRecordOption::Basic,
);
clauses.push((Occur::Must, Box::new(term_query)));
}
FilterType::Range { min: _, max: _ } => {
// Would need numeric field handling here
// Simplified for now
}
FilterType::InSet(values) => {
let mut sub_clauses: Vec<(Occur, Box<dyn Query>)> = vec![];
for value in values {
let term_query = TermQuery::new(
Term::from_field_text(*field, &value),
tantivy::schema::IndexRecordOption::Basic,
);
sub_clauses.push((Occur::Should, Box::new(term_query)));
}
clauses.push((Occur::Must, Box::new(BooleanQuery::new(sub_clauses))));
}
}
}
}
Box::new(BooleanQuery::new(clauses))
} else {
query
};
// Execute search
let top_docs = searcher
.search(
&*final_query,
&TopDocs::with_limit(options.limit + options.offset),
)
.map_err(|e| DBError(format!("Search failed: {}", e)))?;
let total_hits = top_docs.len();
let mut documents = Vec::new();
for (score, doc_address) in top_docs.iter().skip(options.offset).take(options.limit) {
let retrieved_doc: TantivyDocument = searcher
.doc(*doc_address)
.map_err(|e| DBError(format!("Failed to retrieve doc: {}", e)))?;
let mut doc_fields = HashMap::new();
// Extract all stored fields
for (field_name, (field, field_def)) in &self.index_schema.fields {
match field_def {
FieldDef::Text { stored, .. } | FieldDef::Tag { stored, .. } => {
if *stored {
if let Some(value) = retrieved_doc.get_first(*field) {
if let Some(text) = value.as_str() {
doc_fields.insert(field_name.clone(), text.to_string());
}
}
}
}
FieldDef::Numeric {
stored, precision, ..
} => {
if *stored {
let value_str = match precision {
NumericType::I64 => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_i64())
.map(|v| v.to_string()),
NumericType::U64 => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_u64())
.map(|v| v.to_string()),
NumericType::F64 => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_f64())
.map(|v| v.to_string()),
NumericType::Date => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_datetime())
.map(|v| v.into_timestamp_millis().to_string()),
};
if let Some(v) = value_str {
doc_fields.insert(field_name.clone(), v);
}
}
}
FieldDef::Geo { stored } => {
if *stored {
let lat_field = self
.index_schema
.fields
.get(&format!("{}_lat", field_name))
.unwrap()
.0;
let lon_field = self
.index_schema
.fields
.get(&format!("{}_lon", field_name))
.unwrap()
.0;
let lat = retrieved_doc.get_first(lat_field).and_then(|v| v.as_f64());
let lon = retrieved_doc.get_first(lon_field).and_then(|v| v.as_f64());
if let (Some(lat), Some(lon)) = (lat, lon) {
doc_fields.insert(field_name.clone(), format!("{},{}", lat, lon));
}
}
}
}
}
documents.push(SearchDocument {
fields: doc_fields,
score: *score,
});
}
Ok(SearchResults {
total: total_hits,
documents,
})
}
pub fn get_info(&self) -> Result<IndexInfo, DBError> {
let searcher = self.reader.searcher();
let num_docs = searcher.num_docs();
let fields_info: Vec<FieldInfo> = self
.index_schema
.fields
.iter()
.map(|(name, (_, def))| FieldInfo {
name: name.clone(),
field_type: format!("{:?}", def),
})
.collect();
Ok(IndexInfo {
name: self.name.clone(),
num_docs,
fields: fields_info,
config: self.config.clone(),
})
}
}
#[derive(Debug, Clone)]
pub struct SearchOptions {
pub limit: usize,
pub offset: usize,
pub filters: Vec<Filter>,
pub sort_by: Option<String>,
pub return_fields: Option<Vec<String>>,
pub highlight: bool,
}
impl Default for SearchOptions {
fn default() -> Self {
SearchOptions {
limit: 10,
offset: 0,
filters: vec![],
sort_by: None,
return_fields: None,
highlight: false,
}
}
}
#[derive(Debug, Clone)]
pub struct Filter {
pub field: String,
pub filter_type: FilterType,
}
#[derive(Debug, Clone)]
pub enum FilterType {
Equals(String),
Range { min: String, max: String },
InSet(Vec<String>),
}
#[derive(Debug)]
pub struct SearchResults {
pub total: usize,
pub documents: Vec<SearchDocument>,
}
#[derive(Debug)]
pub struct SearchDocument {
pub fields: HashMap<String, String>,
pub score: f32,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct IndexInfo {
pub name: String,
pub num_docs: u64,
pub fields: Vec<FieldInfo>,
pub config: IndexConfig,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FieldInfo {
pub name: String,
pub field_type: String,
}

View File

@@ -1,10 +1,11 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
# Test script for HeroDB - Redis-compatible database with redb backend
# This script starts the server and runs comprehensive tests
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'

View File

@@ -1,4 +1,4 @@
use herodb::{options::DBOption, server::Server};
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -7,7 +7,7 @@ use tokio::time::sleep;
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -19,7 +19,7 @@ async fn debug_hset_simple() {
let test_dir = "/tmp/herodb_debug_hset";
let _ = std::fs::remove_dir_all(test_dir);
std::fs::create_dir_all(test_dir).unwrap();
let port = 16500;
let option = DBOption {
dir: test_dir.to_string(),
@@ -28,50 +28,43 @@ async fn debug_hset_simple() {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Acquire ReadWrite permissions on this connection
let resp = send_command(
&mut stream,
"*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n",
).await;
assert!(resp.contains("OK"), "Failed SELECT handshake: {}", resp);
// Test simple HSET
println!("Testing HSET...");
let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected '1' but got: {}", response);
// Test HGET
println!("Testing HGET...");
let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HGET response: {}", response);
assert!(
response.contains("value1"),
"Expected 'value1' but got: {}",
response
);
}
assert!(response.contains("value1"), "Expected 'value1' but got: {}", response);
}

View File

@@ -1,4 +1,4 @@
use herodb::{options::DBOption, server::Server};
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -7,11 +7,11 @@ use tokio::time::sleep;
#[tokio::test]
async fn debug_hset_return_value() {
let test_dir = "/tmp/herodb_debug_hset_return";
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir.to_string(),
port: 16390,
@@ -19,43 +19,47 @@ async fn debug_hset_return_value() {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let mut server = Server::new(option).await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind("127.0.0.1:16390")
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
// Connect and test HSET
let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap();
// Acquire ReadWrite permissions for this new connection
let handshake = "*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n";
stream.write_all(handshake.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let resp = String::from_utf8_lossy(&buffer[..n]);
assert!(resp.contains("OK"), "Failed SELECT handshake: {}", resp);
// Send HSET command
let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n";
stream.write_all(cmd.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
println!("HSET response: {}", response);
println!("Response bytes: {:?}", &buffer[..n]);
// Check if response contains "1"
assert!(
response.contains("1"),
"Expected response to contain '1', got: {}",
response
);
}
assert!(response.contains("1"), "Expected response to contain '1', got: {}", response);
}

View File

@@ -1,15 +1,12 @@
use herodb::cmd::Cmd;
use herodb::protocol::Protocol;
use herodb::cmd::Cmd;
#[test]
fn test_protocol_parsing() {
// Test TYPE command parsing
let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n";
println!(
"Parsing TYPE command: {}",
type_cmd.replace("\r\n", "\\r\\n")
);
println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(type_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
@@ -20,14 +17,11 @@ fn test_protocol_parsing() {
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
// Test HEXISTS command parsing
let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n";
println!(
"\nParsing HEXISTS command: {}",
hexists_cmd.replace("\r\n", "\\r\\n")
);
println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n"));
match Protocol::from(hexists_cmd) {
Ok((protocol, _)) => {
println!("Protocol parsed successfully: {:?}", protocol);
@@ -38,4 +32,4 @@ fn test_protocol_parsing() {
}
Err(e) => println!("Protocol parsing failed: {:?}", e),
}
}
}

View File

@@ -12,7 +12,15 @@ fn get_redis_connection(port: u16) -> Connection {
match client.get_connection() {
Ok(mut conn) => {
if redis::cmd("PING").query::<String>(&mut conn).is_ok() {
return conn;
// Acquire ReadWrite permissions on this connection
let sel: RedisResult<String> = redis::cmd("SELECT")
.arg(0)
.arg("KEY")
.arg("test-admin")
.query(&mut conn);
if sel.is_ok() {
return conn;
}
}
}
Err(e) => {
@@ -78,16 +86,18 @@ fn setup_server() -> (ServerProcessGuard, u16) {
"--port",
&port.to_string(),
"--debug",
"--admin-secret",
"test-admin",
])
.spawn()
.expect("Failed to start server process");
// Create a new guard that also owns the test directory path
let guard = ServerProcessGuard {
process: child,
test_dir,
};
// Give the server time to build and start (cargo run may compile first)
std::thread::sleep(Duration::from_millis(2500));
@@ -206,9 +216,7 @@ async fn test_expiration(conn: &mut Connection) {
async fn test_scan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..5 {
let _: () = conn
.set(format!("key{}", i), format!("value{}", i))
.unwrap();
let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("SCAN")
.arg(0)
@@ -255,9 +263,7 @@ async fn test_scan_with_count(conn: &mut Connection) {
async fn test_hscan_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
for i in 0..3 {
let _: () = conn
.hset("testhash", format!("field{}", i), format!("value{}", i))
.unwrap();
let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap();
}
let result: (u64, Vec<String>) = redis::cmd("HSCAN")
.arg("testhash")
@@ -277,16 +283,8 @@ async fn test_hscan_operations(conn: &mut Connection) {
async fn test_transaction_operations(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET")
.arg("key1")
.arg("value1")
.query(conn)
.unwrap();
let _: () = redis::cmd("SET")
.arg("key2")
.arg("value2")
.query(conn)
.unwrap();
let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap();
let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap();
let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap();
let result: String = conn.get("key1").unwrap();
assert_eq!(result, "value1");
@@ -298,11 +296,7 @@ async fn test_transaction_operations(conn: &mut Connection) {
async fn test_discard_transaction(conn: &mut Connection) {
cleanup_keys(conn).await;
let _: () = redis::cmd("MULTI").query(conn).unwrap();
let _: () = redis::cmd("SET")
.arg("discard")
.arg("value")
.query(conn)
.unwrap();
let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap();
let _: () = redis::cmd("DISCARD").query(conn).unwrap();
let result: Option<String> = conn.get("discard").unwrap();
assert_eq!(result, None);
@@ -322,6 +316,7 @@ async fn test_type_command(conn: &mut Connection) {
cleanup_keys(conn).await;
}
async fn test_info_command(conn: &mut Connection) {
cleanup_keys(conn).await;
let result: String = redis::cmd("INFO").query(conn).unwrap();

View File

@@ -1,4 +1,4 @@
use herodb::{options::DBOption, server::Server};
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -8,14 +8,14 @@ use tokio::time::sleep;
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up and create test directory
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
@@ -23,18 +23,29 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to connect to the test server
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Ok(mut stream) => {
// Obtain ReadWrite permissions for this connection by selecting DB 0 with admin key
let resp = send_command(
&mut stream,
"*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n",
).await;
if !resp.contains("OK") {
panic!("Failed to acquire write permissions via SELECT 0 KEY test-admin: {}", resp);
}
return stream;
}
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
@@ -47,7 +58,7 @@ async fn connect_to_server(port: u16) -> TcpStream {
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -56,22 +67,22 @@ async fn send_command(stream: &mut TcpStream, command: &str) -> String {
#[tokio::test]
async fn test_basic_ping() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
@@ -80,44 +91,40 @@ async fn test_basic_ping() {
#[tokio::test]
async fn test_string_operations() {
let (mut server, port) = start_test_server("string").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SET
let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test GET non-existent key
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("$-1")); // NULL response
// Test DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test GET after DEL
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("$-1")); // NULL response
@@ -126,37 +133,33 @@ async fn test_string_operations() {
#[tokio::test]
async fn test_incr_operations() {
let (mut server, port) = start_test_server("incr").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INCR on non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("1"));
// Test INCR on existing key
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
assert!(response.contains("2"));
// Test INCR on string value (should fail)
send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n",
)
.await;
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await;
assert!(response.contains("ERR"));
}
@@ -164,83 +167,63 @@ async fn test_incr_operations() {
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test HSET
let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("1")); // 1 new field
// Test HGET
let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("value1"));
// Test HSET multiple fields
let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HLEN
let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("3"));
// Test HEXISTS
let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("0"));
// Test HDEL
let response = send_command(
&mut stream,
"*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HKEYS
let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field2"));
assert!(response.contains("field3"));
assert!(!response.contains("field1")); // Should be deleted
// Test HVALS
let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await;
assert!(response.contains("value2"));
@@ -250,50 +233,46 @@ async fn test_hash_operations() {
#[tokio::test]
async fn test_expiration() {
let (mut server, port) = start_test_server("expiration").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test SETEX (expire in 1 second)
let response = send_command(
&mut stream,
"*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n",
)
.await;
let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await;
assert!(response.contains("OK"));
// Test TTL
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds
// Test EXISTS
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("1"));
// Wait for expiration
sleep(Duration::from_millis(1100)).await;
// Test GET after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
// Test TTL after expiration
let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("-2")); // Key doesn't exist
// Test EXISTS after expiration
let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
assert!(response.contains("0"));
@@ -302,37 +281,33 @@ async fn test_expiration() {
#[tokio::test]
async fn test_scan_operations() {
let (mut server, port) = start_test_server("scan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up test data
for i in 0..5 {
let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test SCAN
let response = send_command(
&mut stream,
"*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n",
)
.await;
let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("key"));
// Test KEYS
let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await;
assert!(response.contains("key0"));
@@ -342,32 +317,29 @@ async fn test_scan_operations() {
#[tokio::test]
async fn test_hscan_operations() {
let (mut server, port) = start_test_server("hscan").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Set up hash data
for i in 0..3 {
let cmd = format!(
"*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n",
i, i
);
let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i);
send_command(&mut stream, &cmd).await;
}
// Test HSCAN
let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field"));
@@ -377,50 +349,42 @@ async fn test_hscan_operations() {
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transaction").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued commands
let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await;
assert!(response.contains("QUEUED"));
let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("QUEUED"));
// Test EXEC
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("OK")); // Should contain results of executed commands
// Verify commands were executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
assert!(response.contains("value1"));
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await;
assert!(response.contains("value2"));
}
@@ -428,39 +392,35 @@ async fn test_transaction_operations() {
#[tokio::test]
async fn test_discard_transaction() {
let (mut server, port) = start_test_server("discard").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test MULTI
let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
assert!(response.contains("OK"));
// Test queued command
let response = send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("QUEUED"));
// Test DISCARD
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("OK"));
// Verify command was not executed
let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await;
assert!(response.contains("$-1")); // Should be NULL
@@ -469,41 +429,33 @@ async fn test_discard_transaction() {
#[tokio::test]
async fn test_type_command() {
let (mut server, port) = start_test_server("type").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
assert!(response.contains("string"));
// Test hash type
send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
assert!(response.contains("none"));
@@ -512,38 +464,30 @@ async fn test_type_command() {
#[tokio::test]
async fn test_config_commands() {
let (mut server, port) = start_test_server("config").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test CONFIG GET databases
let response = send_command(
&mut stream,
"*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await;
assert!(response.contains("databases"));
assert!(response.contains("16"));
// Test CONFIG GET dir
let response = send_command(
&mut stream,
"*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await;
assert!(response.contains("dir"));
assert!(response.contains("/tmp/herodb_test_config"));
}
@@ -551,27 +495,27 @@ async fn test_config_commands() {
#[tokio::test]
async fn test_info_command() {
let (mut server, port) = start_test_server("info").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test INFO
let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await;
assert!(response.contains("redis_version"));
// Test INFO replication
let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await;
assert!(response.contains("role:master"));
@@ -580,44 +524,36 @@ async fn test_info_command() {
#[tokio::test]
async fn test_error_handling() {
let (mut server, port) = start_test_server("error").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test WRONGTYPE error - try to use hash command on string
send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n",
)
.await;
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await;
assert!(response.contains("WRONGTYPE"));
// Test unknown command
let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await;
assert!(response.contains("unknown cmd") || response.contains("ERR"));
// Test EXEC without MULTI
let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
assert!(response.contains("ERR"));
// Test DISCARD without MULTI
let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
assert!(response.contains("ERR"));
@@ -626,37 +562,29 @@ async fn test_error_handling() {
#[tokio::test]
async fn test_list_operations() {
let (mut server, port) = start_test_server("list").await;
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(100)).await;
let mut stream = connect_to_server(port).await;
// Test LPUSH
let response = send_command(
&mut stream,
"*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n",
)
.await;
let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await;
assert!(response.contains("2")); // 2 elements
// Test RPUSH
let response = send_command(
&mut stream,
"*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n",
)
.await;
let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await;
assert!(response.contains("4")); // 4 elements
// Test LLEN
@@ -664,52 +592,29 @@ async fn test_list_operations() {
assert!(response.contains("4"));
// Test LRANGE
let response = send_command(
&mut stream,
"*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n",
)
.await;
assert_eq!(
response,
"*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"
);
let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
// Test LINDEX
let response = send_command(
&mut stream,
"*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nd\r\n");
// Test LREM
send_command(
&mut stream,
"*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n",
)
.await; // list is now a, c, a
let response = send_command(
&mut stream,
"*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n",
)
.await;
send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a
let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await;
assert!(response.contains("1"));
// Test LTRIM
let response = send_command(
&mut stream,
"*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n",
)
.await;
let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await;
assert!(response.contains("OK"));
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("1"));
}
}

85
tests/rpc_tests.rs Normal file
View File

@@ -0,0 +1,85 @@
use herodb::rpc::{BackendType, DatabaseConfig};
use herodb::admin_meta;
use herodb::options::BackendType as OptionsBackendType;
#[tokio::test]
async fn test_rpc_server_basic() {
// This test would require starting the RPC server in a separate thread
// For now, we'll just test that the types compile correctly
// Test serialization of types
let backend = BackendType::Redb;
let config = DatabaseConfig {
name: Some("test_db".to_string()),
storage_path: Some("/tmp/test".to_string()),
max_size: Some(1024 * 1024),
redis_version: Some("7.0".to_string()),
};
let backend_json = serde_json::to_string(&backend).unwrap();
let config_json = serde_json::to_string(&config).unwrap();
assert_eq!(backend_json, "\"Redb\"");
assert!(config_json.contains("test_db"));
}
#[tokio::test]
async fn test_database_config_serialization() {
let config = DatabaseConfig {
name: Some("my_db".to_string()),
storage_path: None,
max_size: Some(1000000),
redis_version: Some("7.0".to_string()),
};
let json = serde_json::to_value(&config).unwrap();
assert_eq!(json["name"], "my_db");
assert_eq!(json["max_size"], 1000000);
assert_eq!(json["redis_version"], "7.0");
}
#[tokio::test]
async fn test_backend_type_serialization() {
// Test that both Redb and Sled backends serialize correctly
let redb_backend = BackendType::Redb;
let sled_backend = BackendType::Sled;
let redb_json = serde_json::to_string(&redb_backend).unwrap();
let sled_json = serde_json::to_string(&sled_backend).unwrap();
assert_eq!(redb_json, "\"Redb\"");
assert_eq!(sled_json, "\"Sled\"");
// Test deserialization
let redb_deserialized: BackendType = serde_json::from_str(&redb_json).unwrap();
let sled_deserialized: BackendType = serde_json::from_str(&sled_json).unwrap();
assert!(matches!(redb_deserialized, BackendType::Redb));
assert!(matches!(sled_deserialized, BackendType::Sled));
}
#[tokio::test]
async fn test_database_name_persistence() {
let base_dir = "/tmp/test_db_name_persistence";
let admin_secret = "test-admin-secret";
let backend = OptionsBackendType::Redb;
let db_id = 1;
let test_name = "test-database-name";
// Clean up any existing test data
let _ = std::fs::remove_dir_all(base_dir);
// Set the database name
admin_meta::set_database_name(base_dir, backend.clone(), admin_secret, db_id, test_name)
.expect("Failed to set database name");
// Retrieve the database name
let retrieved_name = admin_meta::get_database_name(base_dir, backend, admin_secret, db_id)
.expect("Failed to get database name");
// Verify the name matches
assert_eq!(retrieved_name, Some(test_name.to_string()));
// Clean up
let _ = std::fs::remove_dir_all(base_dir);
}

View File

@@ -1,23 +1,23 @@
use herodb::{options::DBOption, server::Server};
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::time::sleep;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000);
// Get a unique port for this test
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
@@ -25,20 +25,26 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
(server, port)
}
// Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
stream.write_all(command.as_bytes()).await.unwrap();
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Acquire ReadWrite permissions on this new connection
let handshake = "*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n";
stream.write_all(handshake.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let _ = stream.read(&mut buffer).await.unwrap(); // Read and ignore the OK for handshake
// Now send the intended command
stream.write_all(command.as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
@@ -46,13 +52,13 @@ async fn send_redis_command(port: u16, command: &str) -> String {
#[tokio::test]
async fn test_basic_redis_functionality() {
let (mut server, port) = start_test_server("basic").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..10 {
if let Ok((stream, _)) = listener.accept().await {
@@ -60,79 +66,68 @@ async fn test_basic_redis_functionality() {
}
}
});
sleep(Duration::from_millis(100)).await;
// Test PING
let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
// Test SET
let response =
send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("OK"));
// Test GET
let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
assert!(response.contains("value"));
// Test HSET
let response = send_redis_command(
port,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
assert!(response.contains("1"));
// Test HGET
let response =
send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
assert!(response.contains("value"));
// Test EXISTS
let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await;
assert!(response.contains("1"));
// Test TTL
let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await;
assert!(response.contains("-1")); // No expiration
// Test TYPE
let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await;
assert!(response.contains("string"));
// Test QUIT to close connection gracefully
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
stream
.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes())
.await
.unwrap();
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
stream.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK"));
// Ensure the stream is closed
stream.shutdown().await.unwrap();
// Stop the server
server_handle.abort();
println!("✅ All basic Redis functionality tests passed!");
}
#[tokio::test]
async fn test_hash_operations() {
let (mut server, port) = start_test_server("hash_ops").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
@@ -140,57 +135,53 @@ async fn test_hash_operations() {
}
}
});
sleep(Duration::from_millis(100)).await;
// Test HSET multiple fields
let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await;
assert!(response.contains("2")); // 2 new fields
// Test HGETALL
let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Test HEXISTS
let response = send_redis_command(
port,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
assert!(response.contains("1"));
// Test HLEN
let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
assert!(response.contains("2"));
// Test HSCAN
let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
assert!(response.contains("field1"));
assert!(response.contains("value1"));
assert!(response.contains("field2"));
assert!(response.contains("value2"));
// Stop the server
// For hash operations, we don't have a persistent stream, so we'll just abort the server.
// The server should handle closing its connections.
server_handle.abort();
println!("✅ All hash operations tests passed!");
}
#[tokio::test]
async fn test_transaction_operations() {
let (mut server, port) = start_test_server("transactions").await;
// Start server in background with timeout
let server_handle = tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Accept only a few connections for testing
for _ in 0..5 {
if let Ok((stream, _)) = listener.accept().await {
@@ -198,69 +189,56 @@ async fn test_transaction_operations() {
}
}
});
sleep(Duration::from_millis(100)).await;
// Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.unwrap();
// Test MULTI
stream
.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes())
.await
.unwrap();
// Use a single connection for the transaction
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
// Acquire write permissions for this connection
let handshake = "*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n";
stream.write_all(handshake.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
let resp = String::from_utf8_lossy(&buffer[..n]);
assert!(resp.contains("OK"));
// Test MULTI
stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK"));
// Test queued commands
stream
.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes())
.await
.unwrap();
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
stream
.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes())
.await
.unwrap();
stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("QUEUED"));
// Test EXEC
stream
.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes())
.await
.unwrap();
stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("OK")); // Should contain array of OK responses
// Verify commands were executed
stream
.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes())
.await
.unwrap();
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value1"));
stream
.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes())
.await
.unwrap();
stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes()).await.unwrap();
let n = stream.read(&mut buffer).await.unwrap();
let response = String::from_utf8_lossy(&buffer[..n]);
assert!(response.contains("value2"));
// Stop the server
server_handle.abort();
println!("✅ All transaction operations tests passed!");
}
}

View File

@@ -1,4 +1,4 @@
use herodb::{options::DBOption, server::Server};
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -8,14 +8,14 @@ use tokio::time::sleep;
async fn start_test_server(test_name: &str) -> (Server, u16) {
use std::sync::atomic::{AtomicU16, Ordering};
static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500);
let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
let test_dir = format!("/tmp/herodb_simple_test_{}", test_name);
// Clean up any existing test data
let _ = std::fs::remove_dir_all(&test_dir);
std::fs::create_dir_all(&test_dir).unwrap();
let option = DBOption {
dir: test_dir,
port,
@@ -23,8 +23,9 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
(server, port)
}
@@ -32,18 +33,28 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
stream.write_all(command.as_bytes()).await.unwrap();
let mut buffer = [0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
String::from_utf8_lossy(&buffer[..n]).to_string()
}
// Helper function to connect to the test server
// Helper function to connect to the test server
async fn connect_to_server(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(stream) => return stream,
Ok(mut stream) => {
// Acquire ReadWrite permissions for this connection
let resp = send_command(
&mut stream,
"*4\r\n$6\r\nSELECT\r\n$1\r\n0\r\n$3\r\nKEY\r\n$10\r\ntest-admin\r\n",
).await;
if !resp.contains("OK") {
panic!("Failed to acquire write permissions via SELECT 0 KEY test-admin: {}", resp);
}
return stream;
}
Err(_) if attempts < 10 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
@@ -56,22 +67,22 @@ async fn connect_to_server(port: u16) -> TcpStream {
#[tokio::test]
async fn test_basic_ping_simple() {
let (mut server, port) = start_test_server("ping").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
assert!(response.contains("PONG"));
@@ -80,43 +91,38 @@ async fn test_basic_ping_simple() {
#[tokio::test]
async fn test_hset_clean_db() {
let (mut server, port) = start_test_server("hset_clean").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test HSET - should return 1 for new field
let response = send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
println!("HSET response: {}", response);
assert!(
response.contains("1"),
"Expected HSET to return 1, got: {}",
response
);
// Ensure clean DB state (admin DB 0 may be shared due to global singleton)
let flush = send_command(&mut stream, "*1\r\n$7\r\nFLUSHDB\r\n").await;
assert!(flush.contains("OK"), "Failed to FLUSHDB: {}", flush);
// Test HSET - should return 1 for new field (use a unique key name to avoid collisions)
let key = "hash_clean";
let hset_cmd = format!("*4\r\n$4\r\nHSET\r\n${}\r\n{}\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n", key.len(), key);
let response = send_command(&mut stream, &hset_cmd).await;
println!("HSET response: {}", response);
assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response);
// Test HGET
let response = send_command(
&mut stream,
"*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
let hget_cmd = format!("*3\r\n$4\r\nHGET\r\n${}\r\n{}\r\n$6\r\nfield1\r\n", key.len(), key);
let response = send_command(&mut stream, &hget_cmd).await;
println!("HGET response: {}", response);
assert!(response.contains("value1"));
}
@@ -124,101 +130,73 @@ async fn test_hset_clean_db() {
#[tokio::test]
async fn test_type_command_simple() {
let (mut server, port) = start_test_server("type").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Test string type
send_command(
&mut stream,
"*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n",
)
.await;
send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
println!("TYPE string response: {}", response);
assert!(response.contains("string"));
// Test hash type
send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n",
)
.await;
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
println!("TYPE hash response: {}", response);
assert!(response.contains("hash"));
// Test non-existent key
let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
println!("TYPE noexist response: {}", response);
assert!(
response.contains("none"),
"Expected 'none' for non-existent key, got: {}",
response
);
assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response);
}
#[tokio::test]
async fn test_hexists_simple() {
let (mut server, port) = start_test_server("hexists").await;
// Start server in background
tokio::spawn(async move {
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = server.handle(stream).await;
}
}
});
sleep(Duration::from_millis(200)).await;
let mut stream = connect_to_server(port).await;
// Set up hash
send_command(
&mut stream,
"*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n",
)
.await;
send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
// Test HEXISTS for existing field
let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
println!("HEXISTS existing field response: {}", response);
assert!(response.contains("1"));
// Test HEXISTS for non-existent field
let response = send_command(
&mut stream,
"*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n",
)
.await;
let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
println!("HEXISTS non-existent field response: {}", response);
assert!(
response.contains("0"),
"Expected HEXISTS to return 0 for non-existent field, got: {}",
response
);
}
assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response);
}

View File

@@ -23,6 +23,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
encrypt: false,
encryption_key: None,
backend: herodb::options::BackendType::Redb,
admin_secret: "test-admin".to_string(),
};
let server = Server::new(option).await;
@@ -61,7 +62,17 @@ async fn connect(port: u16) -> TcpStream {
let mut attempts = 0;
loop {
match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(s) => return s,
Ok(mut s) => {
// Acquire ReadWrite permissions for this connection using admin DB 0
let resp = send_cmd(&mut s, &["SELECT", "0", "KEY", "test-admin"]).await;
assert_contains(&resp, "OK", "SELECT 0 KEY test-admin handshake");
// Ensure clean slate per test on DB 0
let fl = send_cmd(&mut s, &["FLUSHDB"]).await;
assert_contains(&fl, "OK", "FLUSHDB after handshake");
return s;
}
Err(_) if attempts < 30 => {
attempts += 1;
sleep(Duration::from_millis(100)).await;
@@ -246,9 +257,9 @@ async fn test_01_connection_and_info() {
let getname = send_cmd(&mut s, &["CLIENT", "GETNAME"]).await;
assert_contains(&getname, "myapp", "CLIENT GETNAME");
// SELECT db
let sel = send_cmd(&mut s, &["SELECT", "0"]).await;
assert_contains(&sel, "OK", "SELECT 0");
// SELECT db (requires key on DB 0)
let sel = send_cmd(&mut s, &["SELECT", "0", "KEY", "test-admin"]).await;
assert_contains(&sel, "OK", "SELECT 0 with key");
// QUIT should close connection after sending OK
let quit = send_cmd(&mut s, &["QUIT"]).await;
@@ -279,7 +290,11 @@ async fn test_02_strings_and_expiry() {
let ex0 = send_cmd(&mut s, &["EXISTS", "user:1"]).await;
assert_contains(&ex0, "0", "EXISTS after DEL");
// DEL non-existent should return 0
let del0 = send_cmd(&mut s, &["DEL", "user:1"]).await;
assert_contains(&del0, "0", "DEL user:1 when not exists -> 0");
// INCR behavior
let i1 = send_cmd(&mut s, &["INCR", "count"]).await;
assert_contains(&i1, "1", "INCR new key -> 1");
@@ -325,11 +340,7 @@ async fn test_03_scan_and_keys() {
let mut s = connect(port).await;
for i in 0..5 {
let _ = send_cmd(
&mut s,
&["SET", &format!("key{}", i), &format!("value{}", i)],
)
.await;
let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await;
}
let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await;
@@ -362,11 +373,7 @@ async fn test_04_hashes_suite() {
assert_contains(&h2, "2", "HSET added 2 new fields");
// HMGET
let hmg = send_cmd(
&mut s,
&["HMGET", "profile:1", "name", "age", "city", "nope"],
)
.await;
let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await;
assert_contains(&hmg, "alice", "HMGET name");
assert_contains(&hmg, "30", "HMGET age");
assert_contains(&hmg, "paris", "HMGET city");
@@ -400,11 +407,7 @@ async fn test_04_hashes_suite() {
assert_contains(&hnx1, "1", "HSETNX new field -> 1");
// HSCAN
let hscan = send_cmd(
&mut s,
&["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"],
)
.await;
let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await;
assert_contains(&hscan, "name", "HSCAN matches fields starting with n");
assert_contains(&hscan, "nickname", "HSCAN nickname present");
@@ -436,21 +439,13 @@ async fn test_05_lists_suite_including_blpop() {
assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b");
let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(
&lr,
"*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n",
"LRANGE q:jobs 0 -1 should be [b,a,c]",
);
assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]");
// LTRIM
let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await;
assert_contains(&ltrim, "OK", "LTRIM OK");
let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
assert_eq_resp(
&lr_post,
"*2\r\n$1\r\nb\r\n$1\r\na\r\n",
"After LTRIM, list [b,a]",
);
assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]");
// LREM remove first occurrence of b
let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await;
@@ -464,11 +459,7 @@ async fn test_05_lists_suite_including_blpop() {
// LPOP with count on empty -> []
let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await;
assert_eq_resp(
&lpop0,
"*0\r\n",
"LPOP with count on empty returns empty array",
);
assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array");
// BLPOP: block on one client, push from another
let c1 = connect(port).await;
@@ -525,11 +516,11 @@ async fn test_07_age_stateless_suite() {
let mut s = connect(port).await;
// GENENC -> [recipient, identity]
let gen = send_cmd(&mut s, &["AGE", "GENENC"]).await;
let genenc = send_cmd(&mut s, &["AGE", "GENENC"]).await;
assert!(
gen.starts_with("*2\r\n$"),
genenc.starts_with("*2\r\n$"),
"AGE GENENC should return array [recipient, identity], got:\n{}",
gen
genenc
);
// Parse simple RESP array of two bulk strings to extract keys
@@ -537,14 +528,14 @@ async fn test_07_age_stateless_suite() {
// naive parse for tests
let mut lines = resp.lines();
let _ = lines.next(); // *2
// $len
// $len
let _ = lines.next();
let recip = lines.next().unwrap_or("").to_string();
let _ = lines.next();
let ident = lines.next().unwrap_or("").to_string();
(recip, ident)
}
let (recipient, identity) = parse_two_bulk_array(&gen);
let (recipient, identity) = parse_two_bulk_array(&genenc);
assert!(
recipient.starts_with("age1") && identity.starts_with("AGE-SECRET-KEY-1"),
"Unexpected AGE key formats.\nrecipient: {}\nidentity: {}",
@@ -572,16 +563,8 @@ async fn test_07_age_stateless_suite() {
let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await;
assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature");
let v_bad = send_cmd(
&mut s,
&["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64],
)
.await;
assert_contains(
&v_bad,
"0",
"VERIFY should be 0 for invalid message/signature",
);
let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await;
assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature");
}
#[tokio::test]
@@ -613,7 +596,7 @@ async fn test_08_age_persistent_named_suite() {
skg
);
let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"]).await;
let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"] ).await;
let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME");
let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await;
assert_contains(&v1, "1", "VERIFYNAME valid => 1");
@@ -623,75 +606,66 @@ async fn test_08_age_persistent_named_suite() {
// AGE LIST
let lst = send_cmd(&mut s, &["AGE", "LIST"]).await;
assert_contains(&lst, "encpub", "AGE LIST label encpub");
// After flattening, LIST returns a flat array of managed key names
assert_contains(&lst, "app1", "AGE LIST includes app1");
}
#[tokio::test]
async fn test_10_expire_pexpire_persist() {
let (server, port) = start_test_server("expire_suite").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let (server, port) = start_test_server("expire_suite").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
let mut s = connect(port).await;
// EXPIRE: seconds
let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await;
let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await;
assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert!(
ttl1.contains("1") || ttl1.contains("0"),
"TTL exp:s should be 1 or 0, got: {}",
ttl1
);
sleep(Duration::from_millis(1100)).await;
let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await;
assert_contains(&get_after, "$-1", "GET after expiry should be Null");
let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert_contains(&ttl_after, "-2", "TTL after expiry -> -2");
let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await;
assert_contains(&exists_after, "0", "EXISTS after expiry -> 0");
// EXPIRE: seconds
let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await;
let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await;
assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert!(
ttl1.contains("1") || ttl1.contains("0"),
"TTL exp:s should be 1 or 0, got: {}",
ttl1
);
sleep(Duration::from_millis(1100)).await;
let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await;
assert_contains(&get_after, "$-1", "GET after expiry should be Null");
let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await;
assert_contains(&ttl_after, "-2", "TTL after expiry -> -2");
let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await;
assert_contains(&exists_after, "0", "EXISTS after expiry -> 0");
// PEXPIRE: milliseconds
let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await;
let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await;
assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)");
let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await;
assert!(
ttl_ms1.contains("1") || ttl_ms1.contains("0"),
"TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}",
ttl_ms1
);
sleep(Duration::from_millis(1600)).await;
let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await;
assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0");
// PEXPIRE: milliseconds
let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await;
let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await;
assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)");
let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await;
assert!(
ttl_ms1.contains("1") || ttl_ms1.contains("0"),
"TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}",
ttl_ms1
);
sleep(Duration::from_millis(1600)).await;
let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await;
assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0");
// PERSIST: remove expiration
let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await;
let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await;
let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert!(
ttl_pre.contains("5")
|| ttl_pre.contains("4")
|| ttl_pre.contains("3")
|| ttl_pre.contains("2")
|| ttl_pre.contains("1")
|| ttl_pre.contains("0"),
"TTL exp:persist should be >=0 before persist, got: {}",
ttl_pre
);
let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist1, "1", "PERSIST should remove expiration");
let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)");
// Second persist should return 0 (nothing to remove)
let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(
&persist2,
"0",
"PERSIST again -> 0 (no expiration to remove)",
);
// PERSIST: remove expiration
let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await;
let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await;
let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert!(
ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"),
"TTL exp:persist should be >=0 before persist, got: {}",
ttl_pre
);
let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist1, "1", "PERSIST should remove expiration");
let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)");
// Second persist should return 0 (nothing to remove)
let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)");
}
#[tokio::test]
@@ -704,11 +678,7 @@ async fn test_11_set_with_options() {
// SET with GET on non-existing key -> returns Null, sets value
let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await;
assert_contains(
&set_get1,
"$-1",
"SET s1 v1 GET returns Null when key didn't exist",
);
assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist");
let g1 = send_cmd(&mut s, &["GET", "s1"]).await;
assert_contains(&g1, "v1", "GET s1 after first SET");
@@ -752,42 +722,42 @@ async fn test_11_set_with_options() {
#[tokio::test]
async fn test_09_mget_mset_and_variadic_exists_del() {
let (server, port) = start_test_server("mget_mset_variadic").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let (server, port) = start_test_server("mget_mset_variadic").await;
spawn_listener(server, port).await;
sleep(Duration::from_millis(150)).await;
let mut s = connect(port).await;
let mut s = connect(port).await;
// MSET multiple keys
let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await;
assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK");
// MSET multiple keys
let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await;
assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK");
// MGET should return values and Null for missing
let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await;
// Expect an array with 4 entries; verify payloads
assert_contains(&mget, "v1", "MGET k1");
assert_contains(&mget, "v2", "MGET k2");
assert_contains(&mget, "v3", "MGET k3");
assert_contains(&mget, "$-1", "MGET missing returns Null");
// MGET should return values and Null for missing
let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await;
// Expect an array with 4 entries; verify payloads
assert_contains(&mget, "v1", "MGET k1");
assert_contains(&mget, "v2", "MGET k2");
assert_contains(&mget, "v3", "MGET k3");
assert_contains(&mget, "$-1", "MGET missing returns Null");
// EXISTS variadic: count how many exist
let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await;
// Server returns SimpleString numeric, e.g. +2
assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2");
// EXISTS variadic: count how many exist
let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await;
// Server returns SimpleString numeric, e.g. +2
assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2");
// DEL variadic: delete multiple keys, return count deleted
let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await;
assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2");
// DEL variadic: delete multiple keys, return count deleted
let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await;
assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2");
// Verify deletion
let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await;
assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0");
// Verify deletion
let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await;
assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0");
// MGET after deletion should include Nulls for deleted keys
let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await;
assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null");
assert_contains(&mget_after, "v2", "MGET k2 remains");
assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null");
// MGET after deletion should include Nulls for deleted keys
let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await;
assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null");
assert_contains(&mget_after, "v2", "MGET k2 remains");
assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null");
}
#[tokio::test]
async fn test_12_hash_incr() {
@@ -907,16 +877,9 @@ async fn test_14_expireat_pexpireat() {
let mut s = connect(port).await;
// EXPIREAT: seconds since epoch
let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await;
let exat = send_cmd(
&mut s,
&["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)],
)
.await;
let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await;
assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)");
let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await;
assert!(
@@ -926,23 +889,12 @@ async fn test_14_expireat_pexpireat() {
);
sleep(Duration::from_millis(1200)).await;
let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await;
assert_contains(
&exists_after_exat,
"0",
"EXISTS exp:at:s after EXPIREAT expiry -> 0",
);
assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0");
// PEXPIREAT: milliseconds since epoch
let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as i64;
let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64;
let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await;
let pexat = send_cmd(
&mut s,
&["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)],
)
.await;
let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await;
assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)");
let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await;
assert!(
@@ -952,9 +904,5 @@ async fn test_14_expireat_pexpireat() {
);
sleep(Duration::from_millis(600)).await;
let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await;
assert_contains(
&exists_after_pexat,
"0",
"EXISTS exp:at:ms after PEXPIREAT expiry -> 0",
);
}
assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0");
}