2 Commits

Author SHA1 Message Date
9410176684 ... 2025-08-25 06:00:08 +02:00
ab56fad635 ... 2025-08-23 05:46:38 +02:00
34 changed files with 8204 additions and 3560 deletions

Cargo.lock (generated, 4894 lines changed): diff suppressed because it is too large.

Cargo.toml

@@ -24,7 +24,18 @@ age = "0.10"
secrecy = "0.8" secrecy = "0.8"
ed25519-dalek = "2" ed25519-dalek = "2"
base64 = "0.22" base64 = "0.22"
tantivy = "0.25.0" # Lance vector database dependencies
lance = "0.33"
lance-index = "0.33"
lance-linalg = "0.33"
# Use Arrow version compatible with Lance 0.33
arrow = "55.2"
arrow-array = "55.2"
arrow-schema = "55.2"
parquet = "55.2"
uuid = { version = "1.10", features = ["v4"] }
reqwest = { version = "0.11", features = ["json"] }
image = "0.25"
[dev-dependencies] [dev-dependencies]
redis = { version = "0.24", features = ["aio", "tokio-comp"] } redis = { version = "0.24", features = ["aio", "tokio-comp"] }

docs/lance_vector_db.md (new file, 454 lines)

@@ -0,0 +1,454 @@
# Lance Vector Database Operations
HeroDB includes a powerful vector database integration using Lance, enabling high-performance vector storage, search, and multimodal data management. By default, it uses Ollama for local text embeddings, with support for custom external embedding services.
## Overview
The Lance vector database integration provides:
- **High-performance vector storage** using Lance's columnar format
- **Local Ollama integration** for text embeddings (default, no external dependencies)
- **Custom embedding service support** for advanced use cases
- **Text embedding support** (images via custom services)
- **Vector similarity search** with configurable parameters
- **Scalable indexing** with IVF_PQ (Inverted File with Product Quantization)
- **Redis-compatible command interface**
## Architecture
```
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
│ HeroDB │ │ External │ │ Lance │
│ Redis Server │◄──►│ Embedding │ │ Vector Store │
│ │ │ Service │ │ │
└─────────────────┘ └──────────────────┘ └─────────────────┘
│ │ │
│ │ │
Redis Protocol HTTP API Arrow/Parquet
Commands JSON Requests Columnar Storage
```
### Key Components
1. **Lance Store**: High-performance columnar vector storage
2. **Ollama Integration**: Local embedding service (default)
3. **Custom Embedding Service**: Optional HTTP API for advanced use cases
4. **Redis Command Interface**: Familiar Redis-style commands
5. **Arrow Schema**: Flexible schema definition for metadata
## Configuration
### Default Setup (Ollama)
HeroDB uses Ollama by default for text embeddings. No configuration is required if Ollama is running locally:
```bash
# Install Ollama (if not already installed)
# Visit: https://ollama.ai
# Pull the embedding model
ollama pull nomic-embed-text
# Ollama automatically runs on localhost:11434
# HeroDB will use this by default
```
**Default Configuration:**
- **URL**: `http://localhost:11434`
- **Model**: `nomic-embed-text`
- **Dimensions**: 768 (for nomic-embed-text)
### Custom Embedding Service (Optional)
To use a custom embedding service instead of Ollama:
```bash
# Set custom embedding service URL
redis-cli HSET config:core:aiembed url "http://your-embedding-service:8080/embed"
# Optional: Set authentication if required
redis-cli HSET config:core:aiembed token "your-api-token"
```
### Embedding Service API Contracts
#### Ollama API (Default)
HeroDB calls Ollama using this format:
```bash
POST http://localhost:11434/api/embeddings
Content-Type: application/json
{
  "model": "nomic-embed-text",
  "prompt": "Your text to embed"
}
```
Response:
```json
{
  "embedding": [0.1, 0.2, 0.3, ...]
}
```
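To confirm the default service is reachable before wiring it into HeroDB, you can call this endpoint directly. A minimal sketch in Python, assuming only the `requests` package and a local Ollama with `nomic-embed-text` pulled:

```python
# Sanity-check the default Ollama embedding endpoint described above.
import requests

resp = requests.post(
    "http://localhost:11434/api/embeddings",
    json={"model": "nomic-embed-text", "prompt": "Your text to embed"},
    timeout=30,
)
resp.raise_for_status()
embedding = resp.json()["embedding"]
print(f"Got a {len(embedding)}-dimensional embedding")  # expect 768 for nomic-embed-text
```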
#### Custom Service API
Your custom embedding service should accept POST requests with this JSON format:
```json
{
  "texts": ["text1", "text2"],                  // Optional: array of texts
  "images": ["base64_image1", "base64_image2"], // Optional: base64-encoded images
  "model": "your-model-name"                    // Optional: model specification
}
```
And return responses in this format:
```json
{
  "embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...]], // Array of embedding vectors
  "model": "model-name",                            // Model used
  "usage": {                                        // Optional usage stats
    "tokens": 100,
    "requests": 2
  }
}
```
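For local testing without a real model, a stub that satisfies this contract is easy to stand up. Below is a minimal sketch using Flask; it is a hypothetical test service that returns deterministic hash-based pseudo-embeddings, not semantically meaningful vectors, and the 384 dimension is an arbitrary choice:

```python
# Mock embedding service implementing the request/response contract above.
import hashlib
from flask import Flask, jsonify, request

app = Flask(__name__)
DIM = 384  # arbitrary; match whatever your datasets are created with

def fake_embed(data: bytes) -> list:
    # Stretch SHA-256 digests into DIM floats in [0, 1): deterministic, not semantic.
    out, counter = [], 0
    while len(out) < DIM:
        h = hashlib.sha256(data + counter.to_bytes(4, "big")).digest()
        out.extend(b / 255.0 for b in h)
        counter += 1
    return out[:DIM]

@app.post("/embed")
def embed():
    body = request.get_json(force=True)
    inputs = body.get("texts", []) + body.get("images", [])
    return jsonify({
        "embeddings": [fake_embed(item.encode()) for item in inputs],
        "model": body.get("model", "mock"),
        "usage": {"tokens": sum(len(i) for i in inputs), "requests": len(inputs)},
    })

if __name__ == "__main__":
    app.run(port=8080)
```

Pointing `config:core:aiembed url` at `http://localhost:8080/embed` then lets the LANCE commands below run end to end.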
## Commands Reference
### Dataset Management
#### LANCE CREATE
Create a new vector dataset with specified dimensions and optional schema.
```bash
LANCE CREATE <dataset> DIM <dimension> [SCHEMA field:type ...]
```
**Parameters:**
- `dataset`: Name of the dataset
- `dimension`: Vector dimension (e.g., 384, 768, 1536)
- `field:type`: Optional metadata fields (string, int, float, bool)
**Examples:**
```bash
# Create a simple dataset for 384-dimensional vectors
LANCE CREATE documents DIM 384
# Create dataset with metadata schema
LANCE CREATE products DIM 768 SCHEMA category:string price:float available:bool
```
#### LANCE LIST
List all available datasets.
```bash
LANCE LIST
```
**Returns:** Array of dataset names
#### LANCE INFO
Get information about a specific dataset.
```bash
LANCE INFO <dataset>
```
**Returns:** Dataset metadata including name, version, row count, and schema
#### LANCE DROP
Delete a dataset and all its data.
```bash
LANCE DROP <dataset>
```
### Data Operations
#### LANCE STORE
Store multimodal data (text/images) with automatic embedding generation.
```bash
LANCE STORE <dataset> [TEXT <text>] [IMAGE <base64>] [key value ...]
```
**Parameters:**
- `dataset`: Target dataset name
- `TEXT`: Text content to embed
- `IMAGE`: Base64-encoded image to embed
- `key value`: Metadata key-value pairs
**Examples:**
```bash
# Store text with metadata
LANCE STORE documents TEXT "Machine learning is transforming industries" category "AI" author "John Doe"
# Store image with metadata
LANCE STORE images IMAGE "iVBORw0KGgoAAAANSUhEUgAA..." category "nature" tags "landscape,mountains"
# Store both text and image
LANCE STORE multimodal TEXT "Beautiful sunset" IMAGE "base64data..." location "California"
```
**Returns:** Unique ID of the stored item
### Search Operations
#### LANCE SEARCH
Search using a raw vector.
```bash
LANCE SEARCH <dataset> VECTOR <vector> K <k> [NPROBES <n>] [REFINE <r>]
```
**Parameters:**
- `dataset`: Dataset to search
- `vector`: Comma-separated vector values (e.g., "0.1,0.2,0.3")
- `k`: Number of results to return
- `NPROBES`: Number of partitions to search (optional)
- `REFINE`: Refine factor for better accuracy (optional)
**Example:**
```bash
LANCE SEARCH documents VECTOR "0.1,0.2,0.3,0.4" K 5 NPROBES 10
```
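The `vector` argument is a plain comma-separated string, so client code has to serialize its query embedding before calling the command. A small sketch with the `redis` Python package (the `embedding` values here are placeholders; normally they come from your embedding model):

```python
import redis

r = redis.Redis(host="localhost", port=6379)
embedding = [0.1, 0.2, 0.3, 0.4]  # placeholder query vector
vector_arg = ",".join(f"{x:.6f}" for x in embedding)  # serialize to "0.1,0.2,..."
results = r.execute_command(
    "LANCE", "SEARCH", "documents", "VECTOR", vector_arg, "K", "5", "NPROBES", "10"
)
```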
#### LANCE SEARCH.TEXT
Search using text query (automatically embedded).
```bash
LANCE SEARCH.TEXT <dataset> <query_text> K <k> [NPROBES <n>] [REFINE <r>]
```
**Parameters:**
- `dataset`: Dataset to search
- `query_text`: Text query to search for
- `k`: Number of results to return
- `NPROBES`: Number of partitions to search (optional)
- `REFINE`: Refine factor for better accuracy (optional)
**Example:**
```bash
LANCE SEARCH.TEXT documents "artificial intelligence applications" K 10 NPROBES 20
```
**Returns:** Array of results with distance scores and metadata
### Embedding Operations
#### LANCE EMBED.TEXT
Generate embeddings for text without storing.
```bash
LANCE EMBED.TEXT <text1> [text2] [text3] ...
```
**Example:**
```bash
LANCE EMBED.TEXT "Hello world" "Machine learning" "Vector database"
```
**Returns:** Array of embedding vectors
### Index Management
#### LANCE CREATE.INDEX
Create a vector index for faster search performance.
```bash
LANCE CREATE.INDEX <dataset> <index_type> [PARTITIONS <n>] [SUBVECTORS <n>]
```
**Parameters:**
- `dataset`: Dataset to index
- `index_type`: Index type (currently supports "IVF_PQ")
- `PARTITIONS`: Number of partitions (default: 256)
- `SUBVECTORS`: Number of sub-vectors for PQ (default: 16)
**Example:**
```bash
LANCE CREATE.INDEX documents IVF_PQ PARTITIONS 512 SUBVECTORS 32
```
## Usage Patterns
### 1. Document Search System
```bash
# Setup
LANCE CREATE documents DIM 384 SCHEMA title:string content:string category:string
# Store documents
LANCE STORE documents TEXT "Introduction to machine learning algorithms" title "ML Basics" category "education"
LANCE STORE documents TEXT "Deep learning neural networks explained" title "Deep Learning" category "education"
LANCE STORE documents TEXT "Building scalable web applications" title "Web Dev" category "programming"
# Create index for better performance
LANCE CREATE.INDEX documents IVF_PQ PARTITIONS 256
# Search
LANCE SEARCH.TEXT documents "neural networks" K 5
```
### 2. Image Similarity Search
```bash
# Setup
LANCE CREATE images DIM 512 SCHEMA filename:string tags:string
# Store images (base64 encoded)
LANCE STORE images IMAGE "iVBORw0KGgoAAAANSUhEUgAA..." filename "sunset.jpg" tags "nature,landscape"
LANCE STORE images IMAGE "iVBORw0KGgoAAAANSUhEUgBB..." filename "city.jpg" tags "urban,architecture"
# Search by image: there is no direct image-query search command, so one
# workaround is to store the query image first and reuse its embedding
LANCE STORE temp_search IMAGE "query_image_base64..."
# Then use the returned ID's stored embedding to run LANCE SEARCH
```
### 3. Multimodal Content Management
```bash
# Setup
LANCE CREATE content DIM 768 SCHEMA type:string source:string
# Store mixed content
LANCE STORE content TEXT "Product description for smartphone" type "product" source "catalog"
LANCE STORE content IMAGE "product_image_base64..." type "product_image" source "catalog"
# Search across all content types
LANCE SEARCH.TEXT content "smartphone features" K 10
```
## Performance Considerations
### Vector Dimensions
- **384**: Good for general text (e.g., sentence-transformers)
- **768**: Standard for BERT-like models
- **1536**: OpenAI text-embedding-ada-002
- **Higher dimensions**: Better accuracy but slower search
### Index Configuration
- **More partitions**: Better for larger datasets (>100K vectors); see the sizing sketch below
- **More sub-vectors**: Better compression but slower search
- **NPROBES**: Higher values = better accuracy, slower search
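As a rough starting point for the partition count, a common IVF rule of thumb is to aim near the square root of the expected number of vectors. A sketch of that heuristic (a starting point to benchmark against, not a Lance-specific requirement):

```python
import math

def suggest_partitions(num_vectors: int) -> int:
    # sqrt(n), rounded to the nearest power of two, with a floor of 1
    if num_vectors < 2:
        return 1
    return 2 ** round(math.log2(math.sqrt(num_vectors)))

print(suggest_partitions(100_000))    # 256
print(suggest_partitions(5_000_000))  # 2048
```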
### Best Practices
1. **Create indexes** for datasets with >1000 vectors
2. **Use appropriate dimensions** based on your embedding model
3. **Configure NPROBES** based on accuracy vs speed requirements
4. **Batch operations** when possible for better performance (see the pipeline sketch after this list)
5. **Monitor embedding service** response times and rate limits
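For the batching point above, redis-py pipelines group several `LANCE STORE` calls into one network round trip. A sketch, where the dataset and field names are examples; note that each stored item still triggers its own server-side embedding call, so this saves client round trips, not embedding work:

```python
import redis

r = redis.Redis(host="localhost", port=6379)
docs = [
    ("Intro to machine learning", "education"),
    ("Rust ownership explained", "programming"),
]
# transaction=False just batches the round trip without wrapping in MULTI/EXEC
pipe = r.pipeline(transaction=False)
for text, category in docs:
    pipe.execute_command("LANCE", "STORE", "documents",
                         "TEXT", text, "category", category)
ids = pipe.execute()  # one round trip; returns one ID per document
```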
## Error Handling
Common error scenarios and solutions:
### Embedding Service Errors
```bash
# Error: Embedding service not configured
ERR Embedding service URL not configured. Set it with: HSET config:core:aiembed url <YOUR_EMBEDDING_SERVICE_URL>
# Error: Service unavailable
ERR Embedding service returned error 404 Not Found
```
**Solution:** Ensure the embedding service is running and the configured URL is correct.
### Dataset Errors
```bash
# Error: Dataset doesn't exist
ERR Dataset 'mydata' does not exist
# Error: Dimension mismatch
ERR Vector dimension mismatch: expected 384, got 768
```
**Solution:** Create the dataset first, or make sure query vectors match the dataset's dimension.
### Search Errors
```bash
# Error: Invalid vector format
ERR Invalid vector format
# Error: No index available
ERR No index available for fast search
```
**Solution:** Check the vector format, or create an index.
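Since HeroDB surfaces these as ordinary RESP errors, clients can branch on the message text. A defensive wrapper sketch in Python; the substring matching mirrors the messages above and is illustrative, not a stable API:

```python
import redis
from redis.exceptions import ResponseError

r = redis.Redis(host="localhost", port=6379)

def search_text(dataset: str, query: str, k: int = 5):
    try:
        return r.execute_command("LANCE", "SEARCH.TEXT", dataset, query, "K", str(k))
    except ResponseError as e:
        msg = str(e)
        if "does not exist" in msg:
            raise RuntimeError(f"create dataset {dataset!r} before searching") from e
        if "dimension mismatch" in msg:
            raise RuntimeError("query embedding dimension differs from dataset") from e
        raise  # unknown error: surface it unchanged
```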
## Integration Examples
### With Python
```python
import redis
import json

r = redis.Redis(host='localhost', port=6379)

# Create dataset
r.execute_command('LANCE', 'CREATE', 'docs', 'DIM', '384')

# Store document
result = r.execute_command('LANCE', 'STORE', 'docs',
                           'TEXT', 'Machine learning tutorial',
                           'category', 'education')
print(f"Stored with ID: {result}")

# Search
results = r.execute_command('LANCE', 'SEARCH.TEXT', 'docs',
                            'machine learning', 'K', '5')
print(f"Search results: {results}")
```
### With Node.js
```javascript
const redis = require('redis');
const client = redis.createClient();
// (run inside an async function, after: await client.connect())

// Create dataset
await client.sendCommand(['LANCE', 'CREATE', 'docs', 'DIM', '384']);

// Store document
const id = await client.sendCommand(['LANCE', 'STORE', 'docs',
  'TEXT', 'Deep learning guide',
  'category', 'AI']);

// Search
const results = await client.sendCommand(['LANCE', 'SEARCH.TEXT', 'docs',
  'deep learning', 'K', '10']);
```
## Monitoring and Maintenance
### Health Checks
```bash
# Check if Lance store is available
LANCE LIST
# Check dataset health
LANCE INFO mydataset
# Test embedding service
LANCE EMBED.TEXT "test"
```
### Maintenance Operations
```bash
# Backup: Use standard Redis backup procedures
# The Lance data is stored separately in the data directory
# Cleanup: Remove unused datasets
LANCE DROP old_dataset
# Reindex: Drop and recreate indexes if needed
LANCE DROP dataset_name
LANCE CREATE dataset_name DIM 384
# Re-import data
LANCE CREATE.INDEX dataset_name IVF_PQ
```
This integration provides a powerful foundation for building AI-powered applications with vector search capabilities while maintaining the familiar Redis interface.


@@ -1,6 +1,191 @@
-# HeroDB Tantivy Search Examples
+# HeroDB Examples
-This directory contains examples demonstrating HeroDB's full-text search capabilities powered by Tantivy.
+This directory contains examples demonstrating HeroDB's capabilities, including full-text search powered by Tantivy and vector database operations using Lance.
## Available Examples
1. **[Tantivy Search Demo](#tantivy-search-demo-bash-script)** - Full-text search capabilities
2. **[Lance Vector Database Demo](#lance-vector-database-demo-bash-script)** - Vector database and AI operations
3. **[AGE Encryption Demo](age_bash_demo.sh)** - Cryptographic operations
4. **[Simple Demo](simple_demo.sh)** - Basic Redis operations
---
## Lance Vector Database Demo (Bash Script)
### Overview
The `lance_vector_demo.sh` script provides a comprehensive demonstration of HeroDB's vector database capabilities using Lance. It showcases vector storage, similarity search, multimodal data handling, and AI-powered operations with external embedding services.
### Prerequisites
1. **HeroDB Server**: The server must be running (default port 6379)
2. **Redis CLI**: The `redis-cli` tool must be installed and available in your PATH
3. **Embedding Service** (optional): For full functionality, set up an external embedding service
### Running the Demo
#### Step 1: Start HeroDB Server
```bash
# From the project root directory
cargo run -- --dir ./test_data --port 6379
```
#### Step 2: Run the Demo (in a new terminal)
```bash
# From the project root directory
./examples/lance_vector_demo.sh
```
### What the Demo Covers
The script demonstrates comprehensive vector database operations:
1. **Dataset Management**
- Creating vector datasets with custom dimensions
- Defining schemas with metadata fields
- Listing and inspecting datasets
- Dataset information and statistics
2. **Embedding Operations**
- Text embedding generation via external services
- Multimodal embedding support (text + images)
- Batch embedding operations
3. **Data Storage**
- Storing text documents with automatic embedding
- Storing images with metadata
- Multimodal content storage
- Rich metadata support
4. **Vector Search**
- Similarity search with raw vectors
- Text-based semantic search
- Configurable search parameters (K, NPROBES, REFINE)
- Cross-modal search capabilities
5. **Index Management**
- Creating IVF_PQ indexes for performance
- Custom index parameters
- Performance optimization
6. **Advanced Features**
- Error handling and recovery
- Performance testing concepts
- Monitoring and maintenance
- Cleanup operations
### Key Lance Commands Demonstrated
#### Dataset Management
```bash
# Create vector dataset
LANCE CREATE documents DIM 384
# Create dataset with schema
LANCE CREATE products DIM 768 SCHEMA category:string price:float available:bool
# List datasets
LANCE LIST
# Get dataset information
LANCE INFO documents
```
#### Data Operations
```bash
# Store text with metadata
LANCE STORE documents TEXT "Machine learning tutorial" category "education" author "John Doe"
# Store image with metadata
LANCE STORE images IMAGE "base64_encoded_image..." filename "photo.jpg" tags "nature,landscape"
# Store multimodal content
LANCE STORE content TEXT "Product description" IMAGE "base64_image..." type "product"
```
#### Search Operations
```bash
# Search with raw vector
LANCE SEARCH documents VECTOR "0.1,0.2,0.3,0.4" K 5
# Semantic text search
LANCE SEARCH.TEXT documents "artificial intelligence" K 10 NPROBES 20
# Generate embeddings
LANCE EMBED.TEXT "Hello world" "Machine learning"
```
#### Index Management
```bash
# Create performance index
LANCE CREATE.INDEX documents IVF_PQ PARTITIONS 256 SUBVECTORS 16
# Drop dataset
LANCE DROP old_dataset
```
### Configuration
#### Setting Up Embedding Service
```bash
# Configure embedding service URL
redis-cli HSET config:core:aiembed url "http://your-embedding-service:8080/embed"
# Optional: Set authentication token
redis-cli HSET config:core:aiembed token "your-api-token"
```
#### Embedding Service API
Your embedding service should accept POST requests:
```json
{
  "texts": ["text1", "text2"],
  "images": ["base64_image1", "base64_image2"],
  "model": "your-model-name"
}
```
And return responses:
```json
{
  "embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...]],
  "model": "model-name",
  "usage": {"tokens": 100, "requests": 2}
}
```
### Interactive Features
The demo script includes:
- **Colored output** for better readability
- **Step-by-step execution** with explanations
- **Error handling** demonstrations
- **Automatic cleanup** options
- **Performance testing** concepts
- **Real-world usage** examples
### Use Cases Demonstrated
1. **Document Search System**
- Semantic document retrieval
- Metadata filtering
- Relevance ranking
2. **Image Similarity Search**
- Visual content matching
- Tag-based filtering
- Multimodal queries
3. **Product Recommendations**
- Feature-based similarity
- Category filtering
- Price range queries
4. **Content Management**
- Mixed media storage
- Cross-modal search
- Rich metadata support
---
## Tantivy Search Demo (Bash Script)


@@ -14,31 +14,25 @@ fn read_reply(s: &mut TcpStream) -> String {
     let n = s.read(&mut buf).unwrap();
     String::from_utf8_lossy(&buf[..n]).to_string()
 }
-fn parse_two_bulk(reply: &str) -> Option<(String, String)> {
+fn parse_two_bulk(reply: &str) -> Option<(String,String)> {
     let mut lines = reply.split("\r\n");
-    if lines.next()? != "*2" {
-        return None;
-    }
+    if lines.next()? != "*2" { return None; }
     let _n = lines.next()?;
     let a = lines.next()?.to_string();
     let _m = lines.next()?;
     let b = lines.next()?.to_string();
-    Some((a, b))
+    Some((a,b))
 }
 fn parse_bulk(reply: &str) -> Option<String> {
     let mut lines = reply.split("\r\n");
     let hdr = lines.next()?;
-    if !hdr.starts_with('$') {
-        return None;
-    }
+    if !hdr.starts_with('$') { return None; }
     Some(lines.next()?.to_string())
 }
 fn parse_simple(reply: &str) -> Option<String> {
     let mut lines = reply.split("\r\n");
     let hdr = lines.next()?;
-    if !hdr.starts_with('+') {
-        return None;
-    }
+    if !hdr.starts_with('+') { return None; }
     Some(hdr[1..].to_string())
 }
@@ -51,45 +45,39 @@ fn main() {
     let mut s = TcpStream::connect(addr).expect("connect");
     // Generate & persist X25519 enc keys under name "alice"
-    s.write_all(arr(&["age", "keygen", "alice"]).as_bytes())
-        .unwrap();
+    s.write_all(arr(&["age","keygen","alice"]).as_bytes()).unwrap();
     let (_alice_recip, _alice_ident) = parse_two_bulk(&read_reply(&mut s)).expect("gen enc");
     // Generate & persist Ed25519 signing key under name "signer"
-    s.write_all(arr(&["age", "signkeygen", "signer"]).as_bytes())
-        .unwrap();
+    s.write_all(arr(&["age","signkeygen","signer"]).as_bytes()).unwrap();
     let (_verify, _secret) = parse_two_bulk(&read_reply(&mut s)).expect("gen sign");
     // Encrypt by name
     let msg = "hello from persistent keys";
-    s.write_all(arr(&["age", "encryptname", "alice", msg]).as_bytes())
-        .unwrap();
+    s.write_all(arr(&["age","encryptname","alice", msg]).as_bytes()).unwrap();
     let ct_b64 = parse_bulk(&read_reply(&mut s)).expect("ct b64");
     println!("ciphertext b64: {}", ct_b64);
     // Decrypt by name
-    s.write_all(arr(&["age", "decryptname", "alice", &ct_b64]).as_bytes())
-        .unwrap();
+    s.write_all(arr(&["age","decryptname","alice", &ct_b64]).as_bytes()).unwrap();
     let pt = parse_bulk(&read_reply(&mut s)).expect("pt");
     assert_eq!(pt, msg);
     println!("decrypted ok");
     // Sign by name
-    s.write_all(arr(&["age", "signname", "signer", msg]).as_bytes())
-        .unwrap();
+    s.write_all(arr(&["age","signname","signer", msg]).as_bytes()).unwrap();
     let sig_b64 = parse_bulk(&read_reply(&mut s)).expect("sig b64");
     // Verify by name
-    s.write_all(arr(&["age", "verifyname", "signer", msg, &sig_b64]).as_bytes())
-        .unwrap();
+    s.write_all(arr(&["age","verifyname","signer", msg, &sig_b64]).as_bytes()).unwrap();
     let ok = parse_simple(&read_reply(&mut s)).expect("verify");
     assert_eq!(ok, "1");
     println!("signature verified");
     // List names
-    s.write_all(arr(&["age", "list"]).as_bytes()).unwrap();
+    s.write_all(arr(&["age","list"]).as_bytes()).unwrap();
     let list = read_reply(&mut s);
     println!("LIST -> {list}");
     println!("✔ persistent AGE workflow complete.");
 }

examples/lance_vector_demo.sh (new executable file, 426 lines)

@@ -0,0 +1,426 @@
#!/bin/bash
# Lance Vector Database Demo Script
# This script demonstrates all Lance vector database operations in HeroDB
set -e # Exit on any error
# Configuration
REDIS_HOST="localhost"
REDIS_PORT="6379"
REDIS_CLI="redis-cli -h $REDIS_HOST -p $REDIS_PORT"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Helper functions
log_info() {
echo -e "${BLUE}[INFO]${NC} $1"
}
log_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
log_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1"
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
execute_command() {
local cmd="$1"
local description="$2"
echo
log_info "Executing: $description"
echo "Command: $cmd"
# eval so that quoted arguments inside $cmd (e.g. TEXT 'Hello world') are honored
if result=$(eval "$cmd" 2>&1); then
log_success "Result: $result"
else
log_error "Failed: $result"
return 1
fi
}
# Check if HeroDB is running
check_herodb() {
log_info "Checking if HeroDB is running..."
if ! $REDIS_CLI ping > /dev/null 2>&1; then
log_error "HeroDB is not running. Please start it first:"
echo " cargo run -- --dir ./test_data --port $REDIS_PORT"
exit 1
fi
log_success "HeroDB is running"
}
# Setup embedding service configuration
setup_embedding_service() {
log_info "Setting up embedding service configuration..."
# Note: This is a mock URL for demonstration
# In production, replace with your actual embedding service
execute_command \
"$REDIS_CLI HSET config:core:aiembed url 'http://localhost:8080/embed'" \
"Configure embedding service URL"
# Optional: Set authentication token
# execute_command \
# "$REDIS_CLI HSET config:core:aiembed token 'your-api-token'" \
# "Configure embedding service token"
log_warning "Note: Embedding service at http://localhost:8080/embed is not running."
log_warning "Some operations will fail, but this demonstrates the command structure."
}
# Dataset Management Operations
demo_dataset_management() {
echo
echo "=========================================="
echo " DATASET MANAGEMENT DEMO"
echo "=========================================="
# List datasets (should be empty initially)
execute_command \
"$REDIS_CLI LANCE LIST" \
"List all datasets (initially empty)"
# Create a simple dataset
execute_command \
"$REDIS_CLI LANCE CREATE documents DIM 384" \
"Create a simple document dataset with 384 dimensions"
# Create a dataset with schema
execute_command \
"$REDIS_CLI LANCE CREATE products DIM 768 SCHEMA category:string price:float available:bool description:string" \
"Create products dataset with custom schema"
# Create an image dataset
execute_command \
"$REDIS_CLI LANCE CREATE images DIM 512 SCHEMA filename:string tags:string width:int height:int" \
"Create images dataset for multimodal content"
# List datasets again
execute_command \
"$REDIS_CLI LANCE LIST" \
"List all datasets (should show 3 datasets)"
# Get info about datasets
execute_command \
"$REDIS_CLI LANCE INFO documents" \
"Get information about documents dataset"
execute_command \
"$REDIS_CLI LANCE INFO products" \
"Get information about products dataset"
}
# Embedding Operations
demo_embedding_operations() {
echo
echo "=========================================="
echo " EMBEDDING OPERATIONS DEMO"
echo "=========================================="
log_warning "The following operations will fail because no embedding service is running."
log_warning "This demonstrates the command structure and error handling."
# Try to embed text (will fail without embedding service)
execute_command \
"$REDIS_CLI LANCE EMBED.TEXT 'Hello world'" \
"Generate embedding for single text" || true
# Try to embed multiple texts
execute_command \
"$REDIS_CLI LANCE EMBED.TEXT 'Machine learning' 'Artificial intelligence' 'Deep learning'" \
"Generate embeddings for multiple texts" || true
}
# Data Storage Operations
demo_data_storage() {
echo
echo "=========================================="
echo " DATA STORAGE DEMO"
echo "=========================================="
log_warning "Storage operations will fail without embedding service, but show command structure."
# Store text documents
execute_command \
"$REDIS_CLI LANCE STORE documents TEXT 'Introduction to machine learning algorithms and their applications in modern AI systems' category 'education' author 'John Doe' difficulty 'beginner'" \
"Store a document with text and metadata" || true
execute_command \
"$REDIS_CLI LANCE STORE documents TEXT 'Deep learning neural networks for computer vision tasks' category 'research' author 'Jane Smith' difficulty 'advanced'" \
"Store another document" || true
# Store product information
execute_command \
"$REDIS_CLI LANCE STORE products TEXT 'High-performance laptop with 16GB RAM and SSD storage' category 'electronics' price '1299.99' available 'true'" \
"Store product with text description" || true
# Store image with metadata (using placeholder base64)
execute_command \
"$REDIS_CLI LANCE STORE images IMAGE 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==' filename 'sample.png' tags 'test,demo' width '1' height '1'" \
"Store image with metadata (1x1 pixel PNG)" || true
# Store multimodal content
execute_command \
"$REDIS_CLI LANCE STORE images TEXT 'Beautiful sunset over mountains' IMAGE 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==' filename 'sunset.png' tags 'nature,landscape' location 'California'" \
"Store multimodal content (text + image)" || true
}
# Search Operations
demo_search_operations() {
echo
echo "=========================================="
echo " SEARCH OPERATIONS DEMO"
echo "=========================================="
log_warning "Search operations will fail without data, but show command structure."
# Search with raw vector
execute_command \
"$REDIS_CLI LANCE SEARCH documents VECTOR '0.1,0.2,0.3,0.4,0.5' K 5" \
"Search with raw vector (5 results)" || true
# Search with vector and parameters
execute_command \
"$REDIS_CLI LANCE SEARCH documents VECTOR '0.1,0.2,0.3,0.4,0.5' K 10 NPROBES 20 REFINE 2" \
"Search with vector and advanced parameters" || true
# Text-based search
execute_command \
"$REDIS_CLI LANCE SEARCH.TEXT documents 'machine learning algorithms' K 5" \
"Search using text query" || true
# Text search with parameters
execute_command \
"$REDIS_CLI LANCE SEARCH.TEXT products 'laptop computer' K 3 NPROBES 10" \
"Search products using text with parameters" || true
# Search in image dataset
execute_command \
"$REDIS_CLI LANCE SEARCH.TEXT images 'sunset landscape' K 5" \
"Search images using text description" || true
}
# Index Management Operations
demo_index_management() {
echo
echo "=========================================="
echo " INDEX MANAGEMENT DEMO"
echo "=========================================="
# Create indexes for better search performance
execute_command \
"$REDIS_CLI LANCE CREATE.INDEX documents IVF_PQ" \
"Create default IVF_PQ index for documents"
execute_command \
"$REDIS_CLI LANCE CREATE.INDEX products IVF_PQ PARTITIONS 512 SUBVECTORS 32" \
"Create IVF_PQ index with custom parameters for products"
execute_command \
"$REDIS_CLI LANCE CREATE.INDEX images IVF_PQ PARTITIONS 256 SUBVECTORS 16" \
"Create IVF_PQ index for images dataset"
log_success "Indexes created successfully"
}
# Advanced Usage Examples
demo_advanced_usage() {
echo
echo "=========================================="
echo " ADVANCED USAGE EXAMPLES"
echo "=========================================="
# Create a specialized dataset for semantic search
execute_command \
"$REDIS_CLI LANCE CREATE semantic_search DIM 1536 SCHEMA title:string content:string url:string timestamp:string source:string" \
"Create dataset for semantic search with rich metadata"
# Demonstrate batch operations concept
log_info "Batch operations example (would store multiple items):"
echo " for doc in documents:"
echo " LANCE STORE semantic_search TEXT \"\$doc_content\" title \"\$title\" url \"\$url\""
# Show monitoring commands
log_info "Monitoring and maintenance commands:"
execute_command \
"$REDIS_CLI LANCE LIST" \
"List all datasets for monitoring"
# Show dataset statistics
for dataset in documents products images semantic_search; do
execute_command \
"$REDIS_CLI LANCE INFO $dataset" \
"Get statistics for $dataset" || true
done
}
# Cleanup Operations
demo_cleanup() {
echo
echo "=========================================="
echo " CLEANUP OPERATIONS DEMO"
echo "=========================================="
log_info "Demonstrating cleanup operations..."
# Drop individual datasets
execute_command \
"$REDIS_CLI LANCE DROP semantic_search" \
"Drop semantic_search dataset"
# List remaining datasets
execute_command \
"$REDIS_CLI LANCE LIST" \
"List remaining datasets"
# Ask user if they want to clean up all test data
echo
read -p "Do you want to clean up all test datasets? (y/N): " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
execute_command \
"$REDIS_CLI LANCE DROP documents" \
"Drop documents dataset"
execute_command \
"$REDIS_CLI LANCE DROP products" \
"Drop products dataset"
execute_command \
"$REDIS_CLI LANCE DROP images" \
"Drop images dataset"
execute_command \
"$REDIS_CLI LANCE LIST" \
"Verify all datasets are cleaned up"
log_success "All test datasets cleaned up"
else
log_info "Keeping test datasets for further experimentation"
fi
}
# Error Handling Demo
demo_error_handling() {
echo
echo "=========================================="
echo " ERROR HANDLING DEMO"
echo "=========================================="
log_info "Demonstrating various error conditions..."
# Try to access non-existent dataset
execute_command \
"$REDIS_CLI LANCE INFO nonexistent_dataset" \
"Try to get info for non-existent dataset" || true
# Try to search non-existent dataset
execute_command \
"$REDIS_CLI LANCE SEARCH nonexistent_dataset VECTOR '0.1,0.2' K 5" \
"Try to search non-existent dataset" || true
# Try to drop non-existent dataset
execute_command \
"$REDIS_CLI LANCE DROP nonexistent_dataset" \
"Try to drop non-existent dataset" || true
# Try invalid vector format
execute_command \
"$REDIS_CLI LANCE SEARCH documents VECTOR 'invalid,vector,format' K 5" \
"Try search with invalid vector format" || true
log_info "Error handling demonstration complete"
}
# Performance Testing Demo
demo_performance_testing() {
echo
echo "=========================================="
echo " PERFORMANCE TESTING DEMO"
echo "=========================================="
log_info "Creating performance test dataset..."
execute_command \
"$REDIS_CLI LANCE CREATE perf_test DIM 128 SCHEMA batch_id:string item_id:string" \
"Create performance test dataset"
log_info "Performance testing would involve:"
echo " 1. Bulk loading thousands of vectors"
echo " 2. Creating indexes with different parameters"
echo " 3. Measuring search latency with various K values"
echo " 4. Testing different NPROBES settings"
echo " 5. Monitoring memory usage"
log_info "Example performance test commands:"
echo " # Test search speed with different parameters"
echo " time redis-cli LANCE SEARCH.TEXT perf_test 'query' K 10"
echo " time redis-cli LANCE SEARCH.TEXT perf_test 'query' K 10 NPROBES 50"
echo " time redis-cli LANCE SEARCH.TEXT perf_test 'query' K 100 NPROBES 100"
# Clean up performance test dataset
execute_command \
"$REDIS_CLI LANCE DROP perf_test" \
"Clean up performance test dataset"
}
# Main execution
main() {
echo "=========================================="
echo " LANCE VECTOR DATABASE DEMO SCRIPT"
echo "=========================================="
echo
echo "This script demonstrates all Lance vector database operations."
echo "Note: Some operations will fail without a running embedding service."
echo "This is expected and demonstrates error handling."
echo
# Check prerequisites
check_herodb
# Setup
setup_embedding_service
# Run demos
demo_dataset_management
demo_embedding_operations
demo_data_storage
demo_search_operations
demo_index_management
demo_advanced_usage
demo_error_handling
demo_performance_testing
# Cleanup
demo_cleanup
echo
echo "=========================================="
echo " DEMO COMPLETE"
echo "=========================================="
echo
log_success "Lance vector database demo completed successfully!"
echo
echo "Next steps:"
echo "1. Set up a real embedding service (OpenAI, Hugging Face, etc.)"
echo "2. Update the embedding service URL configuration"
echo "3. Try storing and searching real data"
echo "4. Experiment with different vector dimensions and index parameters"
echo "5. Build your AI-powered application!"
echo
echo "For more information, see docs/lance_vector_db.md"
}
# Run the demo
main "$@"


@@ -1,239 +0,0 @@
#!/bin/bash
# HeroDB Tantivy Search Demo
# This script demonstrates full-text search capabilities using Redis commands
# HeroDB server should be running on port 6381
set -e # Exit on any error
# Configuration
REDIS_HOST="localhost"
REDIS_PORT="6382"
REDIS_CLI="redis-cli -h $REDIS_HOST -p $REDIS_PORT"
# Start the herodb server in the background
echo "Starting herodb server..."
cargo run -p herodb -- --dir /tmp/herodbtest --port ${REDIS_PORT} --debug &
SERVER_PID=$!
echo
sleep 2 # Give the server a moment to start
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Function to print colored output
print_header() {
echo -e "${BLUE}=== $1 ===${NC}"
}
print_success() {
echo -e "${GREEN}$1${NC}"
}
print_info() {
echo -e "${YELLOW} $1${NC}"
}
print_error() {
echo -e "${RED}$1${NC}"
}
# Function to check if HeroDB is running
check_herodb() {
print_info "Checking if HeroDB is running on port $REDIS_PORT..."
if ! $REDIS_CLI ping > /dev/null 2>&1; then
print_error "HeroDB is not running on port $REDIS_PORT"
print_info "Please start HeroDB with: cargo run -- --port $REDIS_PORT"
exit 1
fi
print_success "HeroDB is running and responding"
}
# Function to execute Redis command with error handling
execute_cmd() {
local description="${@: -1}"
set -- "${@:1:$(($#-1))}"
echo -e "${YELLOW}Command:${NC} $(printf '%q ' "$@")"
if result=$($REDIS_CLI "$@" 2>&1); then
echo -e "${GREEN}Result:${NC} $result"
return 0
else
print_error "Failed: $description"
echo "Error: $result"
return 1
fi
}
# Function to pause for readability
pause() {
echo
read -p "Press Enter to continue..."
echo
}
# Main demo function
main() {
clear
print_header "HeroDB Tantivy Search Demonstration"
echo "This demo shows full-text search capabilities using Redis commands"
echo "HeroDB runs on port $REDIS_PORT (instead of Redis default 6379)"
echo
# Check if HeroDB is running
check_herodb
echo
print_header "Step 1: Create Search Index"
print_info "Creating a product catalog search index with various field types"
# Create search index with schema
execute_cmd FT.CREATE product_catalog SCHEMA title TEXT description TEXT category TAG price NUMERIC rating NUMERIC location GEO \
"Creating search index"
print_success "Search index 'product_catalog' created successfully"
pause
print_header "Step 2: Add Sample Products"
print_info "Adding sample products to demonstrate different search scenarios"
# Add sample products using FT.ADD
execute_cmd FT.ADD product_catalog product:1 1.0 title 'Wireless Bluetooth Headphones' description 'Premium noise-canceling headphones with 30-hour battery life' category 'electronics,audio' price 299.99 rating 4.5 location '-122.4194,37.7749' "Adding product 1"
execute_cmd FT.ADD product_catalog product:2 1.0 title 'Organic Coffee Beans' description 'Single-origin Ethiopian coffee beans, medium roast' category 'food,beverages,organic' price 24.99 rating 4.8 location '-74.0060,40.7128' "Adding product 2"
execute_cmd FT.ADD product_catalog product:3 1.0 title 'Yoga Mat Premium' description 'Eco-friendly yoga mat with superior grip and cushioning' category 'fitness,wellness,eco-friendly' price 89.99 rating 4.3 location '-118.2437,34.0522' "Adding product 3"
execute_cmd FT.ADD product_catalog product:4 1.0 title 'Smart Home Speaker' description 'Voice-controlled smart speaker with AI assistant' category 'electronics,smart-home' price 149.99 rating 4.2 location '-87.6298,41.8781' "Adding product 4"
execute_cmd FT.ADD product_catalog product:5 1.0 title 'Organic Green Tea' description 'Premium organic green tea leaves from Japan' category 'food,beverages,organic,tea' price 18.99 rating 4.7 location '139.6503,35.6762' "Adding product 5"
execute_cmd FT.ADD product_catalog product:6 1.0 title 'Wireless Gaming Mouse' description 'High-precision gaming mouse with RGB lighting' category 'electronics,gaming' price 79.99 rating 4.4 location '-122.3321,47.6062' "Adding product 6"
execute_cmd FT.ADD product_catalog product:7 1.0 title 'Comfortable meditation cushion for mindfulness practice' description 'Meditation cushion with premium materials' category 'wellness,meditation' price 45.99 rating 4.6 location '-122.4194,37.7749' "Adding product 7"
execute_cmd FT.ADD product_catalog product:8 1.0 title 'Bluetooth Earbuds' description 'True wireless earbuds with active noise cancellation' category 'electronics,audio' price 199.99 rating 4.1 location '-74.0060,40.7128' "Adding product 8"
print_success "Added 8 products to the index"
pause
print_header "Step 3: Basic Text Search"
print_info "Searching for 'wireless' products"
execute_cmd FT.SEARCH product_catalog wireless "Basic text search"
pause
print_header "Step 4: Search with Filters"
print_info "Searching for 'organic' products"
execute_cmd FT.SEARCH product_catalog organic "Filtered search"
pause
print_header "Step 5: Numeric Range Search"
print_info "Searching for 'premium' products"
execute_cmd FT.SEARCH product_catalog premium "Text search"
pause
print_header "Step 6: Sorting Results"
print_info "Searching for electronics"
execute_cmd FT.SEARCH product_catalog electronics "Category search"
pause
print_header "Step 7: Limiting Results"
print_info "Searching for wireless products with limit"
execute_cmd FT.SEARCH product_catalog wireless LIMIT 0 3 "Limited results"
pause
print_header "Step 8: Complex Query"
print_info "Finding audio products with noise cancellation"
execute_cmd FT.SEARCH product_catalog 'noise cancellation' "Complex query"
pause
print_header "Step 9: Geographic Search"
print_info "Searching for meditation products"
execute_cmd FT.SEARCH product_catalog meditation "Text search"
pause
print_header "Step 10: Aggregation Example"
print_info "Getting index information and statistics"
execute_cmd FT.INFO product_catalog "Index information"
pause
print_header "Step 11: Search Comparison"
print_info "Comparing Tantivy search vs simple key matching"
echo -e "${YELLOW}Tantivy Full-Text Search:${NC}"
execute_cmd FT.SEARCH product_catalog 'battery life' "Full-text search for 'battery life'"
echo
echo -e "${YELLOW}Simple Key Pattern Matching:${NC}"
execute_cmd KEYS *battery* "Simple pattern matching for 'battery'"
print_info "Notice how full-text search finds relevant results even when exact words don't match keys"
pause
print_header "Step 12: Fuzzy Search"
print_info "Searching for headphones"
execute_cmd FT.SEARCH product_catalog headphones "Text search"
pause
print_header "Step 13: Phrase Search"
print_info "Searching for coffee products"
execute_cmd FT.SEARCH product_catalog coffee "Text search"
pause
print_header "Step 14: Boolean Queries"
print_info "Searching for gaming products"
execute_cmd FT.SEARCH product_catalog gaming "Text search"
echo
execute_cmd FT.SEARCH product_catalog tea "Text search"
pause
print_header "Step 15: Cleanup"
print_info "Removing test data"
# Delete the search index
execute_cmd FT.DROP product_catalog "Dropping search index"
# Clean up documents from search index
for i in {1..8}; do
execute_cmd FT.DEL product_catalog product:$i "Deleting product:$i from index"
done
print_success "Cleanup completed"
echo
print_header "Demo Summary"
echo "This demonstration showed:"
echo "• Creating search indexes with different field types"
echo "• Adding documents to the search index"
echo "• Basic and advanced text search queries"
echo "• Filtering by categories and numeric ranges"
echo "• Sorting and limiting results"
echo "• Geographic searches"
echo "• Fuzzy matching and phrase searches"
echo "• Boolean query operators"
echo "• Comparison with simple pattern matching"
echo
print_success "HeroDB Tantivy search demo completed successfully!"
echo
print_info "Key advantages of Tantivy full-text search:"
echo " - Relevance scoring and ranking"
echo " - Fuzzy matching and typo tolerance"
echo " - Complex boolean queries"
echo " - Field-specific searches and filters"
echo " - Geographic and numeric range queries"
echo " - Much faster than pattern matching on large datasets"
echo
print_info "To run HeroDB server: cargo run -- --port 6381"
print_info "To connect with redis-cli: redis-cli -h localhost -p 6381"
}
# Run the demo
main "$@"


@@ -1,101 +0,0 @@
#!/bin/bash
# Simple Tantivy Search Integration Test for HeroDB
# This script tests the full-text search functionality we just integrated
set -e
echo "🔍 Testing Tantivy Search Integration..."
# Build the project first
echo "📦 Building HeroDB..."
cargo build --release
# Start the server in the background
echo "🚀 Starting HeroDB server on port 6379..."
cargo run --release -- --port 6379 --dir ./test_data &
SERVER_PID=$!
# Wait for server to start
sleep 3
# Function to cleanup on exit
cleanup() {
echo "🧹 Cleaning up..."
kill $SERVER_PID 2>/dev/null || true
rm -rf ./test_data
exit
}
# Set trap for cleanup
trap cleanup EXIT INT TERM
# Function to execute Redis command
execute_cmd() {
local cmd="$1"
local description="$2"
echo "📝 $description"
echo " Command: $cmd"
if result=$(redis-cli -p 6379 $cmd 2>&1); then
echo " ✅ Result: $result"
echo
return 0
else
echo " ❌ Failed: $result"
echo
return 1
fi
}
echo "🧪 Running Tantivy Search Tests..."
echo
# Test 1: Create a search index
execute_cmd "ft.create books SCHEMA title TEXT description TEXT author TEXT category TAG price NUMERIC" \
"Creating search index 'books'"
# Test 2: Add documents to the index
execute_cmd "ft.add books book1 1.0 title \"The Great Gatsby\" description \"A classic American novel about the Jazz Age\" author \"F. Scott Fitzgerald\" category \"fiction,classic\" price \"12.99\"" \
"Adding first book"
execute_cmd "ft.add books book2 1.0 title \"To Kill a Mockingbird\" description \"A novel about racial injustice in the American South\" author \"Harper Lee\" category \"fiction,classic\" price \"14.99\"" \
"Adding second book"
execute_cmd "ft.add books book3 1.0 title \"Programming Rust\" description \"A comprehensive guide to Rust programming language\" author \"Jim Blandy\" category \"programming,technical\" price \"49.99\"" \
"Adding third book"
execute_cmd "ft.add books book4 1.0 title \"The Rust Programming Language\" description \"The official book on Rust programming\" author \"Steve Klabnik\" category \"programming,technical\" price \"39.99\"" \
"Adding fourth book"
# Test 3: Basic search
execute_cmd "ft.search books Rust" \
"Searching for 'Rust'"
# Test 4: Search with filters
execute_cmd "ft.search books programming FILTER category programming" \
"Searching for 'programming' with category filter"
# Test 5: Search with limit
execute_cmd "ft.search books \"*\" LIMIT 0 2" \
"Getting first 2 documents"
# Test 6: Get index info
execute_cmd "ft.info books" \
"Getting index information"
# Test 7: Delete a document
execute_cmd "ft.del books book1" \
"Deleting book1"
# Test 8: Search again to verify deletion
execute_cmd "ft.search books Gatsby" \
"Searching for deleted book"
# Test 9: Drop the index
execute_cmd "ft.drop books" \
"Dropping the index"
echo "🎉 All tests completed successfully!"
echo "✅ Tantivy search integration is working correctly"


@@ -12,17 +12,17 @@
 use std::str::FromStr;
-use age::x25519;
-use age::{Decryptor, Encryptor};
 use secrecy::ExposeSecret;
-use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
+use age::{Decryptor, Encryptor};
+use age::x25519;
+use ed25519_dalek::{Signature, Signer, Verifier, SigningKey, VerifyingKey};
 use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
-use crate::error::DBError;
 use crate::protocol::Protocol;
 use crate::server::Server;
+use crate::error::DBError;
 // ---------- Internal helpers ----------
@@ -32,7 +32,7 @@ pub enum AgeWireError {
     Crypto(String),
     Utf8,
     SignatureLen,
     NotFound(&'static str), // which kind of key was missing
     Storage(String),
 }
@@ -83,38 +83,34 @@ pub fn gen_enc_keypair() -> (String, String) {
 }
 pub fn gen_sign_keypair() -> (String, String) {
-    use rand::rngs::OsRng;
     use rand::RngCore;
+    use rand::rngs::OsRng;
     // Generate random 32 bytes for the signing key
     let mut secret_bytes = [0u8; 32];
     OsRng.fill_bytes(&mut secret_bytes);
     let signing_key = SigningKey::from_bytes(&secret_bytes);
     let verifying_key = signing_key.verifying_key();
     // Encode as base64 for storage
     let signing_key_b64 = B64.encode(signing_key.to_bytes());
     let verifying_key_b64 = B64.encode(verifying_key.to_bytes());
     (verifying_key_b64, signing_key_b64) // (verify_pub, signing_secret)
 }
 /// Encrypt `msg` for `recipient_str` (X25519). Returns base64(ciphertext).
 pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireError> {
     let recipient = parse_recipient(recipient_str)?;
-    let enc =
-        Encryptor::with_recipients(vec![Box::new(recipient)]).expect("failed to create encryptor"); // Handle Option<Encryptor>
+    let enc = Encryptor::with_recipients(vec![Box::new(recipient)])
+        .expect("failed to create encryptor"); // Handle Option<Encryptor>
     let mut out = Vec::new();
     {
         use std::io::Write;
-        let mut w = enc
-            .wrap_output(&mut out)
-            .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
-        w.write_all(msg.as_bytes())
-            .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
-        w.finish()
-            .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
+        let mut w = enc.wrap_output(&mut out).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
+        w.write_all(msg.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
+        w.finish().map_err(|e| AgeWireError::Crypto(e.to_string()))?;
     }
     Ok(B64.encode(out))
 }
@@ -122,27 +118,19 @@ pub fn encrypt_b64(recipient_str: &str, msg: &str) -> Result<String, AgeWireErro
 /// Decrypt base64(ciphertext) with `identity_str`. Returns plaintext String.
 pub fn decrypt_b64(identity_str: &str, ct_b64: &str) -> Result<String, AgeWireError> {
     let id = parse_identity(identity_str)?;
-    let ct = B64
-        .decode(ct_b64.as_bytes())
-        .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
+    let ct = B64.decode(ct_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
     let dec = Decryptor::new(&ct[..]).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
     // The decrypt method returns a Result<StreamReader, DecryptError>
     let mut r = match dec {
-        Decryptor::Recipients(d) => d
-            .decrypt(std::iter::once(&id as &dyn age::Identity))
-            .map_err(|e| AgeWireError::Crypto(e.to_string()))?,
-        Decryptor::Passphrase(_) => {
-            return Err(AgeWireError::Crypto(
-                "Expected recipients, got passphrase".to_string(),
-            ))
-        }
+        Decryptor::Recipients(d) => d.decrypt(std::iter::once(&id as &dyn age::Identity))
+            .map_err(|e| AgeWireError::Crypto(e.to_string()))?,
+        Decryptor::Passphrase(_) => return Err(AgeWireError::Crypto("Expected recipients, got passphrase".to_string())),
     };
     let mut pt = Vec::new();
     use std::io::Read;
-    r.read_to_end(&mut pt)
-        .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
+    r.read_to_end(&mut pt).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
     String::from_utf8(pt).map_err(|_| AgeWireError::Utf8)
 }
@@ -156,9 +144,7 @@ pub fn sign_b64(signing_secret_str: &str, msg: &str) -> Result<String, AgeWireEr
 /// Verify detached signature (base64) for `msg` with pubkey.
 pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool, AgeWireError> {
     let verifying_key = parse_ed25519_verifying_key(verify_pub_str)?;
-    let sig_bytes = B64
-        .decode(sig_b64.as_bytes())
-        .map_err(|e| AgeWireError::Crypto(e.to_string()))?;
+    let sig_bytes = B64.decode(sig_b64.as_bytes()).map_err(|e| AgeWireError::Crypto(e.to_string()))?;
     if sig_bytes.len() != 64 {
         return Err(AgeWireError::SignatureLen);
     }
@@ -169,49 +155,30 @@ pub fn verify_b64(verify_pub_str: &str, msg: &str, sig_b64: &str) -> Result<bool
 // ---------- Storage helpers ----------
 fn sget(server: &Server, key: &str) -> Result<Option<String>, AgeWireError> {
-    let st = server
-        .current_storage()
-        .map_err(|e| AgeWireError::Storage(e.0))?;
+    let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
     st.get(key).map_err(|e| AgeWireError::Storage(e.0))
 }
 fn sset(server: &Server, key: &str, val: &str) -> Result<(), AgeWireError> {
-    let st = server
-        .current_storage()
-        .map_err(|e| AgeWireError::Storage(e.0))?;
-    st.set(key.to_string(), val.to_string())
-        .map_err(|e| AgeWireError::Storage(e.0))
+    let st = server.current_storage().map_err(|e| AgeWireError::Storage(e.0))?;
+    st.set(key.to_string(), val.to_string()).map_err(|e| AgeWireError::Storage(e.0))
 }
-fn enc_pub_key_key(name: &str) -> String {
-    format!("age:key:{name}")
-}
-fn enc_priv_key_key(name: &str) -> String {
-    format!("age:privkey:{name}")
-}
-fn sign_pub_key_key(name: &str) -> String {
-    format!("age:signpub:{name}")
-}
-fn sign_priv_key_key(name: &str) -> String {
-    format!("age:signpriv:{name}")
-}
+fn enc_pub_key_key(name: &str) -> String { format!("age:key:{name}") }
+fn enc_priv_key_key(name: &str) -> String { format!("age:privkey:{name}") }
+fn sign_pub_key_key(name: &str) -> String { format!("age:signpub:{name}") }
+fn sign_priv_key_key(name: &str) -> String { format!("age:signpriv:{name}") }
 // ---------- Command handlers (RESP Protocol) ----------
 // Basic (stateless) ones kept for completeness
 pub async fn cmd_age_genenc() -> Protocol {
     let (recip, ident) = gen_enc_keypair();
-    Protocol::Array(vec![
-        Protocol::BulkString(recip),
-        Protocol::BulkString(ident),
-    ])
+    Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
 }
 pub async fn cmd_age_gensign() -> Protocol {
     let (verify, secret) = gen_sign_keypair();
-    Protocol::Array(vec![
-        Protocol::BulkString(verify),
-        Protocol::BulkString(secret),
-    ])
+    Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
 }
 pub async fn cmd_age_encrypt(recipient: &str, message: &str) -> Protocol {
@@ -247,30 +214,16 @@ pub async fn cmd_age_verify(verify_pub: &str, message: &str, sig_b64: &str) -> P
 pub async fn cmd_age_keygen(server: &Server, name: &str) -> Protocol {
     let (recip, ident) = gen_enc_keypair();
-    if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) {
-        return e.to_protocol();
-    }
-    if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) {
-        return e.to_protocol();
-    }
-    Protocol::Array(vec![
-        Protocol::BulkString(recip),
-        Protocol::BulkString(ident),
-    ])
+    if let Err(e) = sset(server, &enc_pub_key_key(name), &recip) { return e.to_protocol(); }
+    if let Err(e) = sset(server, &enc_priv_key_key(name), &ident) { return e.to_protocol(); }
+    Protocol::Array(vec![Protocol::BulkString(recip), Protocol::BulkString(ident)])
 }
 pub async fn cmd_age_signkeygen(server: &Server, name: &str) -> Protocol {
     let (verify, secret) = gen_sign_keypair();
-    if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) {
-        return e.to_protocol();
-    }
-    if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) {
-        return e.to_protocol();
-    }
-    Protocol::Array(vec![
-        Protocol::BulkString(verify),
-        Protocol::BulkString(secret),
-    ])
+    if let Err(e) = sset(server, &sign_pub_key_key(name), &verify) { return e.to_protocol(); }
+    if let Err(e) = sset(server, &sign_priv_key_key(name), &secret) { return e.to_protocol(); }
+    Protocol::Array(vec![Protocol::BulkString(verify), Protocol::BulkString(secret)])
 }
 pub async fn cmd_age_encrypt_name(server: &Server, name: &str, message: &str) -> Protocol {
@@ -300,9 +253,7 @@ pub async fn cmd_age_decrypt_name(server: &Server, name: &str, ct_b64: &str) ->
 pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Protocol {
     let sec = match sget(server, &sign_priv_key_key(name)) {
         Ok(Some(v)) => v,
-        Ok(None) => {
-            return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol()
-        }
+        Ok(None) => return AgeWireError::NotFound("signing secret (age:signpriv:{name})").to_protocol(),
         Err(e) => return e.to_protocol(),
     };
     match sign_b64(&sec, message) {
@@ -311,17 +262,10 @@ pub async fn cmd_age_sign_name(server: &Server, name: &str, message: &str) -> Pr
    }
}
pub async fn cmd_age_verify_name(server: &Server, name: &str, message: &str, sig_b64: &str) -> Protocol {
    let pubk = match sget(server, &sign_pub_key_key(name)) {
        Ok(Some(v)) => v,
        Ok(None) => return AgeWireError::NotFound("verify pubkey (age:signpub:{name})").to_protocol(),
        Err(e) => return e.to_protocol(),
    };
    match verify_b64(&pubk, message, sig_b64) {
@@ -333,43 +277,25 @@ pub async fn cmd_age_verify_name(
pub async fn cmd_age_list(server: &Server) -> Protocol {
    // Returns 4 arrays: ["encpub", <names...>], ["encpriv", ...], ["signpub", ...], ["signpriv", ...]
    let st = match server.current_storage() { Ok(s) => s, Err(e) => return Protocol::err(&e.0) };

    let pull = |pat: &str, prefix: &str| -> Result<Vec<String>, DBError> {
        let keys = st.keys(pat)?;
        let mut names: Vec<String> = keys.into_iter()
            .filter_map(|k| k.strip_prefix(prefix).map(|x| x.to_string()))
            .collect();
        names.sort();
        Ok(names)
    };

    let encpub = match pull("age:key:*", "age:key:") { Ok(v) => v, Err(e) => return Protocol::err(&e.0) };
    let encpriv = match pull("age:privkey:*", "age:privkey:") { Ok(v) => v, Err(e) => return Protocol::err(&e.0) };
    let signpub = match pull("age:signpub:*", "age:signpub:") { Ok(v) => v, Err(e) => return Protocol::err(&e.0) };
    let signpriv = match pull("age:signpriv:*", "age:signpriv:") { Ok(v) => v, Err(e) => return Protocol::err(&e.0) };

    let to_arr = |label: &str, v: Vec<String>| {
        let mut out = vec![Protocol::BulkString(label.to_string())];
        out.push(Protocol::Array(v.into_iter().map(Protocol::BulkString).collect()));
        Protocol::Array(out)
    };
@@ -379,4 +305,4 @@ pub async fn cmd_age_list(server: &Server) -> Protocol {
to_arr("signpub", signpub), to_arr("signpub", signpub),
to_arr("signpriv", signpriv), to_arr("signpriv", signpriv),
]) ])
} }
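For orientation, a sketch of the reply shape assembled above (editor's addition, not part of the diff; the key name is hypothetical):

```rust
// cmd_age_list returns a 4-element array of [label, [names...]] pairs,
// e.g. for a single encryption keypair named "backup":
Protocol::Array(vec![
    Protocol::Array(vec![
        Protocol::BulkString("encpub".to_string()),
        Protocol::Array(vec![Protocol::BulkString("backup".to_string())]),
    ]),
    // "encpriv", "signpub", "signpriv" entries follow the same layout
])
```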

1498 src/cmd.rs
File diff suppressed because it is too large

@@ -11,9 +11,9 @@ const TAG_LEN: usize = 16;
#[derive(Debug)]
pub enum CryptoError {
    Format,       // wrong length / header
    Version(u8),  // unknown version
    Decrypt,      // wrong key or corrupted data
}

impl From<CryptoError> for crate::error::DBError {
@@ -71,4 +71,4 @@ impl CryptoFactory {
        let cipher = XChaCha20Poly1305::new(&self.key);
        cipher.decrypt(nonce, ct).map_err(|_| CryptoError::Decrypt)
    }
}


@@ -1,13 +1,20 @@
use std::num::ParseIntError;
use tokio::sync::mpsc;
use redb;
use bincode;

// todo: more error types
#[derive(Debug)]
pub struct DBError(pub String);

impl std::fmt::Display for DBError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<std::io::Error> for DBError {
    fn from(item: std::io::Error) -> Self {
        DBError(item.to_string().clone())
@@ -91,3 +98,40 @@ impl From<chacha20poly1305::Error> for DBError {
        DBError(item.to_string())
    }
}
// Lance and related dependencies error handling
impl From<lance::Error> for DBError {
    fn from(item: lance::Error) -> Self {
        DBError(item.to_string())
    }
}

impl From<arrow::error::ArrowError> for DBError {
    fn from(item: arrow::error::ArrowError) -> Self {
        DBError(item.to_string())
    }
}

impl From<reqwest::Error> for DBError {
    fn from(item: reqwest::Error) -> Self {
        DBError(item.to_string())
    }
}

impl From<image::ImageError> for DBError {
    fn from(item: image::ImageError) -> Self {
        DBError(item.to_string())
    }
}

impl From<uuid::Error> for DBError {
    fn from(item: uuid::Error) -> Self {
        DBError(item.to_string())
    }
}

impl From<base64::DecodeError> for DBError {
    fn from(item: base64::DecodeError) -> Self {
        DBError(item.to_string())
    }
}
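With these conversions in place, `?` funnels heterogeneous library errors into `DBError`. A minimal sketch (the helper function is hypothetical, for illustration only):

```rust
// uuid::Error converts into DBError through the From impl above.
fn parse_id(s: &str) -> Result<uuid::Uuid, DBError> {
    Ok(uuid::Uuid::parse_str(s)?)
}
```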

609 src/lance_store.rs Normal file

@@ -0,0 +1,609 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use arrow::array::{Float32Array, StringArray, ArrayRef, FixedSizeListArray, Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::{RecordBatch, RecordBatchReader};
use arrow::error::ArrowError;
use lance::dataset::{Dataset, WriteParams, WriteMode};
use lance::index::vector::VectorIndexParams;
use lance_index::vector::pq::PQBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::DatasetIndexExt;
use lance_linalg::distance::MetricType;
use futures::TryStreamExt;
use base64::Engine;
use serde::{Deserialize, Serialize};
use crate::error::DBError;
// Simple RecordBatchReader implementation for Vec<RecordBatch>
struct VecRecordBatchReader {
batches: std::vec::IntoIter<Result<RecordBatch, ArrowError>>,
}
impl VecRecordBatchReader {
fn new(batches: Vec<RecordBatch>) -> Self {
let result_batches = batches.into_iter().map(Ok).collect::<Vec<_>>();
Self {
batches: result_batches.into_iter(),
}
}
}
impl Iterator for VecRecordBatchReader {
type Item = Result<RecordBatch, ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
self.batches.next()
}
}
impl RecordBatchReader for VecRecordBatchReader {
fn schema(&self) -> SchemaRef {
// This is a simplified implementation - in practice you'd want to store the schema
Arc::new(Schema::empty())
}
}
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddingRequest {
texts: Option<Vec<String>>,
images: Option<Vec<String>>, // base64 encoded
model: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddingResponse {
embeddings: Vec<Vec<f32>>,
model: String,
usage: Option<HashMap<String, u32>>,
}
// Ollama-specific request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct OllamaEmbeddingRequest {
model: String,
prompt: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct OllamaEmbeddingResponse {
embedding: Vec<f32>,
}
pub struct LanceStore {
datasets: Arc<RwLock<HashMap<String, Arc<Dataset>>>>,
data_dir: PathBuf,
http_client: reqwest::Client,
}
impl LanceStore {
pub async fn new(data_dir: PathBuf) -> Result<Self, DBError> {
// Create data directory if it doesn't exist
std::fs::create_dir_all(&data_dir)
.map_err(|e| DBError(format!("Failed to create Lance data directory: {}", e)))?;
let http_client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| DBError(format!("Failed to create HTTP client: {}", e)))?;
Ok(Self {
datasets: Arc::new(RwLock::new(HashMap::new())),
data_dir,
http_client,
})
}
/// Get embedding service URL from Redis config, default to local Ollama
async fn get_embedding_url(&self, server: &crate::server::Server) -> Result<String, DBError> {
// Get the embedding URL from Redis config directly from storage
let storage = server.current_storage()?;
match storage.hget("config:core:aiembed", "url")? {
Some(url) => Ok(url),
None => Ok("http://localhost:11434".to_string()), // Default to local Ollama
}
}
/// Check if we're using Ollama (default) or custom embedding service
async fn is_ollama_service(&self, server: &crate::server::Server) -> Result<bool, DBError> {
let url = self.get_embedding_url(server).await?;
Ok(url.contains("localhost:11434") || url.contains("127.0.0.1:11434"))
}
/// Call external embedding service (Ollama or custom)
async fn call_embedding_service(
&self,
server: &crate::server::Server,
texts: Option<Vec<String>>,
images: Option<Vec<String>>,
) -> Result<Vec<Vec<f32>>, DBError> {
let base_url = self.get_embedding_url(server).await?;
let is_ollama = self.is_ollama_service(server).await?;
if is_ollama {
// Use Ollama API format
if let Some(texts) = texts {
let mut embeddings = Vec::new();
for text in texts {
let url = format!("{}/api/embeddings", base_url);
let request = OllamaEmbeddingRequest {
model: "nomic-embed-text".to_string(),
prompt: text,
};
let response = self.http_client
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| DBError(format!("Failed to call Ollama embedding service: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(DBError(format!(
"Ollama embedding service returned error {}: {}",
status, error_text
)));
}
let ollama_response: OllamaEmbeddingResponse = response
.json()
.await
.map_err(|e| DBError(format!("Failed to parse Ollama embedding response: {}", e)))?;
embeddings.push(ollama_response.embedding);
}
Ok(embeddings)
} else if let Some(_images) = images {
// Ollama doesn't support image embeddings with this API yet
Err(DBError("Image embeddings not supported with Ollama. Please configure a custom embedding service.".to_string()))
} else {
Err(DBError("No text or images provided for embedding".to_string()))
}
} else {
// Use custom embedding service API format
let request = EmbeddingRequest {
texts,
images,
model: None, // Let the service use its default
};
let response = self.http_client
.post(&base_url)
.json(&request)
.send()
.await
.map_err(|e| DBError(format!("Failed to call embedding service: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(DBError(format!(
"Embedding service returned error {}: {}",
status, error_text
)));
}
let embedding_response: EmbeddingResponse = response
.json()
.await
.map_err(|e| DBError(format!("Failed to parse embedding response: {}", e)))?;
Ok(embedding_response.embeddings)
}
}
pub async fn embed_text(
&self,
server: &crate::server::Server,
texts: Vec<String>
) -> Result<Vec<Vec<f32>>, DBError> {
if texts.is_empty() {
return Ok(Vec::new());
}
self.call_embedding_service(server, Some(texts), None).await
}
pub async fn embed_image(
&self,
server: &crate::server::Server,
image_bytes: Vec<u8>
) -> Result<Vec<f32>, DBError> {
// Convert image bytes to base64
let base64_image = base64::engine::general_purpose::STANDARD.encode(&image_bytes);
let embeddings = self.call_embedding_service(
server,
None,
Some(vec![base64_image])
).await?;
embeddings.into_iter()
.next()
.ok_or_else(|| DBError("No embedding returned for image".to_string()))
}
pub async fn create_dataset(
&self,
name: &str,
schema: Schema,
) -> Result<(), DBError> {
let dataset_path = self.data_dir.join(format!("{}.lance", name));
// Create empty dataset with schema
let write_params = WriteParams {
mode: WriteMode::Create,
..Default::default()
};
// Create an empty RecordBatch with the schema
let empty_batch = RecordBatch::new_empty(Arc::new(schema));
// Use RecordBatchReader for Lance 0.33
let reader = VecRecordBatchReader::new(vec![empty_batch]);
let dataset = Dataset::write(
reader,
dataset_path.to_str().unwrap(),
Some(write_params)
).await
.map_err(|e| DBError(format!("Failed to create dataset: {}", e)))?;
let mut datasets = self.datasets.write().await;
datasets.insert(name.to_string(), Arc::new(dataset));
Ok(())
}
pub async fn write_vectors(
&self,
dataset_name: &str,
vectors: Vec<Vec<f32>>,
metadata: Option<HashMap<String, Vec<String>>>,
) -> Result<usize, DBError> {
let dataset_path = self.data_dir.join(format!("{}.lance", dataset_name));
// Open or get cached dataset
let _dataset = self.get_or_open_dataset(dataset_name).await?;
// Build RecordBatch
let num_vectors = vectors.len();
if num_vectors == 0 {
return Ok(0);
}
let dim = vectors.first()
.ok_or_else(|| DBError("Empty vectors".to_string()))?
.len();
// Flatten vectors
let flat_vectors: Vec<f32> = vectors.into_iter().flatten().collect();
let values_array = Float32Array::from(flat_vectors);
let field = Arc::new(Field::new("item", DataType::Float32, true));
let vector_array = FixedSizeListArray::try_new(
field,
dim as i32,
Arc::new(values_array),
None
).map_err(|e| DBError(format!("Failed to create vector array: {}", e)))?;
let mut arrays: Vec<ArrayRef> = vec![Arc::new(vector_array)];
let mut fields = vec![Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dim as i32
),
false
)];
// Add metadata columns if provided
if let Some(metadata) = metadata {
for (key, values) in metadata {
if values.len() != num_vectors {
return Err(DBError(format!(
"Metadata field '{}' has {} values but expected {}",
key, values.len(), num_vectors
)));
}
let array = StringArray::from(values);
arrays.push(Arc::new(array));
fields.push(Field::new(&key, DataType::Utf8, true));
}
}
let schema = Arc::new(Schema::new(fields));
let batch = RecordBatch::try_new(schema, arrays)
.map_err(|e| DBError(format!("Failed to create RecordBatch: {}", e)))?;
// Append to dataset
let write_params = WriteParams {
mode: WriteMode::Append,
..Default::default()
};
let reader = VecRecordBatchReader::new(vec![batch]);
Dataset::write(
reader,
dataset_path.to_str().unwrap(),
Some(write_params)
).await
.map_err(|e| DBError(format!("Failed to write to dataset: {}", e)))?;
// Refresh cached dataset
let mut datasets = self.datasets.write().await;
datasets.remove(dataset_name);
Ok(num_vectors)
}
pub async fn search_vectors(
&self,
dataset_name: &str,
query_vector: Vec<f32>,
k: usize,
nprobes: Option<usize>,
_refine_factor: Option<usize>,
) -> Result<Vec<(f32, HashMap<String, String>)>, DBError> {
let dataset = self.get_or_open_dataset(dataset_name).await?;
// Build query
let query_array = Float32Array::from(query_vector.clone());
let mut query = dataset.scan();
query.nearest(
"vector",
&query_array,
k,
).map_err(|e| DBError(format!("Failed to build search query: {}", e)))?;
if let Some(nprobes) = nprobes {
query.nprobs(nprobes);
}
// Note: refine_factor might not be available in this Lance version
// if let Some(refine) = refine_factor {
// query.refine_factor(refine);
// }
// Execute search
let results = query
.try_into_stream()
.await
.map_err(|e| DBError(format!("Failed to execute search: {}", e)))?
.try_collect::<Vec<_>>()
.await
.map_err(|e| DBError(format!("Failed to collect results: {}", e)))?;
// Process results
let mut output = Vec::new();
for batch in results {
// Get distances
let distances = batch
.column_by_name("_distance")
.ok_or_else(|| DBError("No distance column".to_string()))?
.as_any()
.downcast_ref::<Float32Array>()
.ok_or_else(|| DBError("Invalid distance type".to_string()))?;
// Get metadata
for i in 0..batch.num_rows() {
let distance = distances.value(i);
let mut metadata = HashMap::new();
for field in batch.schema().fields() {
if field.name() != "vector" && field.name() != "_distance" {
if let Some(col) = batch.column_by_name(field.name()) {
if let Some(str_array) = col.as_any().downcast_ref::<StringArray>() {
if !str_array.is_null(i) {
metadata.insert(
field.name().to_string(),
str_array.value(i).to_string()
);
}
}
}
}
}
output.push((distance, metadata));
}
}
Ok(output)
}
pub async fn store_multimodal(
&self,
server: &crate::server::Server,
dataset_name: &str,
text: Option<String>,
image_bytes: Option<Vec<u8>>,
metadata: HashMap<String, String>,
) -> Result<String, DBError> {
// Generate ID
let id = uuid::Uuid::new_v4().to_string();
// Generate embeddings using external service
let embedding = if let Some(text) = text.as_ref() {
self.embed_text(server, vec![text.clone()]).await?
.into_iter()
.next()
.ok_or_else(|| DBError("No embedding returned".to_string()))?
} else if let Some(img) = image_bytes.as_ref() {
self.embed_image(server, img.clone()).await?
} else {
return Err(DBError("No text or image provided".to_string()));
};
// Prepare metadata
let mut full_metadata = metadata;
full_metadata.insert("id".to_string(), id.clone());
if let Some(text) = text {
full_metadata.insert("text".to_string(), text);
}
if let Some(img) = image_bytes {
full_metadata.insert("image_base64".to_string(), base64::engine::general_purpose::STANDARD.encode(img));
}
// Convert metadata to column vectors
let mut metadata_cols = HashMap::new();
for (key, value) in full_metadata {
metadata_cols.insert(key, vec![value]);
}
// Write to dataset
self.write_vectors(dataset_name, vec![embedding], Some(metadata_cols)).await?;
Ok(id)
}
pub async fn search_with_text(
&self,
server: &crate::server::Server,
dataset_name: &str,
query_text: String,
k: usize,
nprobes: Option<usize>,
refine_factor: Option<usize>,
) -> Result<Vec<(f32, HashMap<String, String>)>, DBError> {
// Embed the query text using external service
let embeddings = self.embed_text(server, vec![query_text]).await?;
let query_vector = embeddings.into_iter()
.next()
.ok_or_else(|| DBError("No embedding returned for query".to_string()))?;
// Search with the embedding
self.search_vectors(dataset_name, query_vector, k, nprobes, refine_factor).await
}
pub async fn create_index(
&self,
dataset_name: &str,
index_type: &str,
num_partitions: Option<usize>,
num_sub_vectors: Option<usize>,
) -> Result<(), DBError> {
let _dataset = self.get_or_open_dataset(dataset_name).await?;
match index_type.to_uppercase().as_str() {
"IVF_PQ" => {
let ivf_params = IvfBuildParams {
num_partitions: num_partitions.unwrap_or(256),
..Default::default()
};
let pq_params = PQBuildParams {
num_sub_vectors: num_sub_vectors.unwrap_or(16),
..Default::default()
};
let params = VectorIndexParams::with_ivf_pq_params(
MetricType::L2,
ivf_params,
pq_params,
);
// Get a mutable reference to the dataset
let mut dataset_mut = Dataset::open(self.data_dir.join(format!("{}.lance", dataset_name)).to_str().unwrap())
.await
.map_err(|e| DBError(format!("Failed to open dataset for indexing: {}", e)))?;
dataset_mut.create_index(
&["vector"],
lance_index::IndexType::Vector,
None,
&params,
true
).await
.map_err(|e| DBError(format!("Failed to create index: {}", e)))?;
}
_ => return Err(DBError(format!("Unsupported index type: {}", index_type))),
}
Ok(())
}
async fn get_or_open_dataset(&self, name: &str) -> Result<Arc<Dataset>, DBError> {
let mut datasets = self.datasets.write().await;
if let Some(dataset) = datasets.get(name) {
return Ok(dataset.clone());
}
let dataset_path = self.data_dir.join(format!("{}.lance", name));
if !dataset_path.exists() {
return Err(DBError(format!("Dataset '{}' does not exist", name)));
}
let dataset = Dataset::open(dataset_path.to_str().unwrap())
.await
.map_err(|e| DBError(format!("Failed to open dataset: {}", e)))?;
let dataset = Arc::new(dataset);
datasets.insert(name.to_string(), dataset.clone());
Ok(dataset)
}
pub async fn list_datasets(&self) -> Result<Vec<String>, DBError> {
let mut datasets = Vec::new();
let entries = std::fs::read_dir(&self.data_dir)
.map_err(|e| DBError(format!("Failed to read data directory: {}", e)))?;
for entry in entries {
let entry = entry.map_err(|e| DBError(format!("Failed to read entry: {}", e)))?;
let path = entry.path();
if path.is_dir() {
if let Some(name) = path.file_name() {
if let Some(name_str) = name.to_str() {
if name_str.ends_with(".lance") {
let dataset_name = name_str.trim_end_matches(".lance");
datasets.push(dataset_name.to_string());
}
}
}
}
}
Ok(datasets)
}
pub async fn drop_dataset(&self, name: &str) -> Result<(), DBError> {
// Remove from cache
let mut datasets = self.datasets.write().await;
datasets.remove(name);
// Delete from disk
let dataset_path = self.data_dir.join(format!("{}.lance", name));
if dataset_path.exists() {
std::fs::remove_dir_all(dataset_path)
.map_err(|e| DBError(format!("Failed to delete dataset: {}", e)))?;
}
Ok(())
}
pub async fn get_dataset_info(&self, name: &str) -> Result<HashMap<String, String>, DBError> {
let dataset = self.get_or_open_dataset(name).await?;
let mut info = HashMap::new();
info.insert("name".to_string(), name.to_string());
info.insert("version".to_string(), dataset.version().version.to_string());
info.insert("num_rows".to_string(), dataset.count_rows(None).await?.to_string());
// Get schema info
let schema = dataset.schema();
let fields: Vec<String> = schema.fields
.iter()
.map(|f| format!("{}:{}", f.name, f.data_type()))
.collect();
info.insert("schema".to_string(), fields.join(", "));
Ok(info)
}
}
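A minimal usage sketch of the store above (editor's addition, not part of the diff; assumes a Tokio runtime, and the dataset/field names are illustrative):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};

async fn demo(store: &LanceStore) -> Result<(), DBError> {
    // The schema must match what write_vectors builds: a FixedSizeList
    // "vector" column plus Utf8 metadata columns.
    let schema = Schema::new(vec![
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 3),
            false,
        ),
        Field::new("text", DataType::Utf8, true),
    ]);
    store.create_dataset("demo", schema).await?;

    // Two 3-dimensional vectors with a parallel "text" metadata column.
    let mut meta = HashMap::new();
    meta.insert("text".to_string(), vec!["hello".to_string(), "world".to_string()]);
    store.write_vectors("demo", vec![vec![0.1, 0.2, 0.3], vec![0.9, 0.8, 0.7]], Some(meta)).await?;

    // k = 1 nearest neighbor; yields (distance, metadata) pairs.
    for (dist, fields) in store.search_vectors("demo", vec![0.1, 0.2, 0.25], 1, None, None).await? {
        println!("distance={dist} text={:?}", fields.get("text"));
    }
    Ok(())
}
```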


@@ -1,12 +1,11 @@
pub mod age; // NEW
pub mod cmd;
pub mod crypto;
pub mod error;
pub mod lance_store; // Add Lance store module
pub mod options;
pub mod protocol;
pub mod server;
pub mod storage;
pub mod storage_trait; // Add this
pub mod storage_sled; // Add this


@@ -22,6 +22,7 @@ struct Args {
    #[arg(long)]
    debug: bool,

    /// Master encryption key for encrypted databases
    #[arg(long)]
    encryption_key: Option<String>,


@@ -81,21 +81,18 @@ impl Protocol {
    pub fn encode(&self) -> String {
        match self {
            Protocol::SimpleString(s) => format!("+{}\r\n", s),
            Protocol::BulkString(s) => format!("${}\r\n{}\r\n", s.len(), s),
            Protocol::Array(ss) => {
                format!("*{}\r\n", ss.len()) + &ss.iter().map(|x| x.encode()).collect::<String>()
            }
            Protocol::Null => "$-1\r\n".to_string(),
            Protocol::Error(s) => format!("-{}\r\n", s), // proper RESP error
        }
    }

    fn parse_simple_string_sfx(protocol: &str) -> Result<(Self, &str), DBError> {
        match protocol.find("\r\n") {
            Some(x) => Ok((Self::SimpleString(protocol[..x].to_string()), &protocol[x + 2..])),
            _ => Err(DBError(format!(
                "[new simple string] unsupported protocol: {:?}",
                protocol
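A quick sanity check of the encoding rules above (editor's sketch):

```rust
// Protocol::encode output for a two-element array reply.
let reply = Protocol::Array(vec![
    Protocol::SimpleString("OK".to_string()),
    Protocol::BulkString("hi".to_string()),
]);
assert_eq!(reply.encode(), "*2\r\n+OK\r\n$2\r\nhi\r\n");
```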


@@ -1,273 +0,0 @@
use crate::{
error::DBError,
protocol::Protocol,
server::Server,
tantivy_search::{
FieldDef, Filter, FilterType, IndexConfig, NumericType, SearchOptions, TantivySearch,
},
};
use std::collections::HashMap;
use std::sync::Arc;
pub async fn ft_create_cmd(
server: &Server,
index_name: String,
schema: Vec<(String, String, Vec<String>)>,
) -> Result<Protocol, DBError> {
// Parse schema into field definitions
let mut field_definitions = Vec::new();
for (field_name, field_type, options) in schema {
let field_def = match field_type.to_uppercase().as_str() {
"TEXT" => {
let mut weight = 1.0;
let mut sortable = false;
let mut no_index = false;
for opt in &options {
match opt.to_uppercase().as_str() {
"WEIGHT" => {
// Next option should be the weight value
if let Some(idx) = options.iter().position(|x| x == opt) {
if idx + 1 < options.len() {
weight = options[idx + 1].parse().unwrap_or(1.0);
}
}
}
"SORTABLE" => sortable = true,
"NOINDEX" => no_index = true,
_ => {}
}
}
FieldDef::Text {
stored: true,
indexed: !no_index,
tokenized: true,
fast: sortable,
}
}
"NUMERIC" => {
let mut sortable = false;
for opt in &options {
if opt.to_uppercase() == "SORTABLE" {
sortable = true;
}
}
FieldDef::Numeric {
stored: true,
indexed: true,
fast: sortable,
precision: NumericType::F64,
}
}
"TAG" => {
let mut separator = ",".to_string();
let mut case_sensitive = false;
for i in 0..options.len() {
match options[i].to_uppercase().as_str() {
"SEPARATOR" => {
if i + 1 < options.len() {
separator = options[i + 1].clone();
}
}
"CASESENSITIVE" => case_sensitive = true,
_ => {}
}
}
FieldDef::Tag {
stored: true,
separator,
case_sensitive,
}
}
"GEO" => FieldDef::Geo { stored: true },
_ => {
return Err(DBError(format!("Unknown field type: {}", field_type)));
}
};
field_definitions.push((field_name, field_def));
}
// Create the search index
let search_path = server.search_index_path();
let config = IndexConfig::default();
println!(
"Creating search index '{}' at path: {:?}",
index_name, search_path
);
println!("Field definitions: {:?}", field_definitions);
let search_index = TantivySearch::new_with_schema(
search_path,
index_name.clone(),
field_definitions,
Some(config),
)?;
println!("Search index '{}' created successfully", index_name);
// Store in registry
let mut indexes = server.search_indexes.write().unwrap();
indexes.insert(index_name, Arc::new(search_index));
Ok(Protocol::SimpleString("OK".to_string()))
}
pub async fn ft_add_cmd(
server: &Server,
index_name: String,
doc_id: String,
_score: f64,
fields: HashMap<String, String>,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
search_index.add_document_with_fields(&doc_id, fields)?;
Ok(Protocol::SimpleString("OK".to_string()))
}
pub async fn ft_search_cmd(
server: &Server,
index_name: String,
query: String,
filters: Vec<(String, String)>,
limit: Option<usize>,
offset: Option<usize>,
return_fields: Option<Vec<String>>,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// Convert filters to search filters
let search_filters = filters
.into_iter()
.map(|(field, value)| Filter {
field,
filter_type: FilterType::Equals(value),
})
.collect();
let options = SearchOptions {
limit: limit.unwrap_or(10),
offset: offset.unwrap_or(0),
filters: search_filters,
sort_by: None,
return_fields,
highlight: false,
};
let results = search_index.search_with_options(&query, options)?;
// Format results as Redis protocol
let mut response = Vec::new();
// First element is the total count
response.push(Protocol::SimpleString(results.total.to_string()));
// Then each document
for doc in results.documents {
let mut doc_array = Vec::new();
// Add document ID if it exists
if let Some(id) = doc.fields.get("_id") {
doc_array.push(Protocol::BulkString(id.clone()));
}
// Add score
doc_array.push(Protocol::BulkString(doc.score.to_string()));
// Add fields as key-value pairs
for (field_name, field_value) in doc.fields {
if field_name != "_id" {
doc_array.push(Protocol::BulkString(field_name));
doc_array.push(Protocol::BulkString(field_value));
}
}
response.push(Protocol::Array(doc_array));
}
Ok(Protocol::Array(response))
}
pub async fn ft_del_cmd(
server: &Server,
index_name: String,
doc_id: String,
) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let _search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
// For now, return success
// In a full implementation, we'd need to add a delete method to TantivySearch
println!("Deleting document '{}' from index '{}'", doc_id, index_name);
Ok(Protocol::SimpleString("1".to_string()))
}
pub async fn ft_info_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
let indexes = server.search_indexes.read().unwrap();
let search_index = indexes
.get(&index_name)
.ok_or_else(|| DBError(format!("Index '{}' not found", index_name)))?;
let info = search_index.get_info()?;
// Format info as Redis protocol
let mut response = Vec::new();
response.push(Protocol::BulkString("index_name".to_string()));
response.push(Protocol::BulkString(info.name));
response.push(Protocol::BulkString("num_docs".to_string()));
response.push(Protocol::BulkString(info.num_docs.to_string()));
response.push(Protocol::BulkString("num_fields".to_string()));
response.push(Protocol::BulkString(info.fields.len().to_string()));
response.push(Protocol::BulkString("fields".to_string()));
let fields_str = info
.fields
.iter()
.map(|f| format!("{}:{}", f.name, f.field_type))
.collect::<Vec<_>>()
.join(", ");
response.push(Protocol::BulkString(fields_str));
Ok(Protocol::Array(response))
}
pub async fn ft_drop_cmd(server: &Server, index_name: String) -> Result<Protocol, DBError> {
let mut indexes = server.search_indexes.write().unwrap();
if indexes.remove(&index_name).is_some() {
// Also remove the index files from disk
let index_path = server.search_index_path().join(&index_name);
if index_path.exists() {
std::fs::remove_dir_all(index_path)
.map_err(|e| DBError(format!("Failed to remove index files: {}", e)))?;
}
Ok(Protocol::SimpleString("OK".to_string()))
} else {
Err(DBError(format!("Index '{}' not found", index_name)))
}
}


@@ -1,26 +1,24 @@
use core::str;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::sync::{Mutex, oneshot};
use std::sync::atomic::{AtomicU64, Ordering};

use crate::cmd::Cmd;
use crate::error::DBError;
use crate::lance_store::LanceStore;
use crate::options;
use crate::protocol::Protocol;
use crate::storage::Storage;
use crate::storage_sled::SledStorage;
use crate::storage_trait::StorageBackend;
#[derive(Clone)]
pub struct Server {
    pub db_cache: std::sync::Arc<std::sync::RwLock<HashMap<u64, Arc<dyn StorageBackend>>>>,
    pub option: options::DBOption,
    pub client_name: Option<String>,
    pub selected_db: u64, // Changed from usize to u64
@@ -29,6 +27,9 @@ pub struct Server {
    // BLPOP waiter registry: per (db_index, key) FIFO of waiters
    pub list_waiters: Arc<Mutex<HashMap<u64, HashMap<String, Vec<Waiter>>>>>,
    pub waiter_seq: Arc<AtomicU64>,

    // Lance vector store
    pub lance_store: Option<Arc<LanceStore>>,
}

pub struct Waiter {
@@ -45,9 +46,18 @@ pub enum PopSide {
impl Server {
    pub async fn new(option: options::DBOption) -> Self {
        // Initialize Lance store
        let lance_data_dir = std::path::PathBuf::from(&option.dir).join("lance");
        let lance_store = match LanceStore::new(lance_data_dir).await {
            Ok(store) => Some(Arc::new(store)),
            Err(e) => {
                eprintln!("Warning: Failed to initialize Lance store: {}", e.0);
                None
            }
        };

        Server {
            db_cache: Arc::new(std::sync::RwLock::new(HashMap::new())),
            option,
            client_name: None,
            selected_db: 0,
@@ -55,68 +65,67 @@ impl Server {
            list_waiters: Arc::new(Mutex::new(HashMap::new())),
            waiter_seq: Arc::new(AtomicU64::new(1)),
            lance_store,
        }
    }

    pub fn lance_store(&self) -> Result<Arc<LanceStore>, DBError> {
        self.lance_store
            .as_ref()
            .cloned()
            .ok_or_else(|| DBError("Lance store not initialized".to_string()))
    }
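    // Editor's note (sketch, not part of the diff): command handlers are expected
    // to go through this accessor before embedding or searching, e.g.
    //   let store = server.lance_store()?;
    //   let embeddings = store.embed_text(&server, vec!["hello".to_string()]).await?;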
    pub fn current_storage(&self) -> Result<Arc<dyn StorageBackend>, DBError> {
        let mut cache = self.db_cache.write().unwrap();

        if let Some(storage) = cache.get(&self.selected_db) {
            return Ok(storage.clone());
        }

        // Create new database file
        let db_file_path = std::path::PathBuf::from(self.option.dir.clone())
            .join(format!("{}.db", self.selected_db));

        // Ensure the directory exists before creating the database file
        if let Some(parent_dir) = db_file_path.parent() {
            std::fs::create_dir_all(parent_dir).map_err(|e| {
                DBError(format!("Failed to create directory {}: {}", parent_dir.display(), e))
            })?;
        }

        println!("Creating new db file: {}", db_file_path.display());

        let storage: Arc<dyn StorageBackend> = match self.option.backend {
            options::BackendType::Redb => {
                Arc::new(Storage::new(
                    db_file_path,
                    self.should_encrypt_db(self.selected_db),
                    self.option.encryption_key.as_deref()
                )?)
            }
            options::BackendType::Sled => {
                Arc::new(SledStorage::new(
                    db_file_path,
                    self.should_encrypt_db(self.selected_db),
                    self.option.encryption_key.as_deref()
                )?)
            }
        };

        cache.insert(self.selected_db, storage.clone());
        Ok(storage)
    }
    fn should_encrypt_db(&self, db_index: u64) -> bool {
        // DB 0-9 are non-encrypted, DB 10+ are encrypted
        self.option.encrypt && db_index >= 10
    }
    // ----- BLPOP waiter helpers -----

    pub async fn register_waiter(&self, db_index: u64, key: &str, side: PopSide) -> (u64, oneshot::Receiver<(String, String)>) {
        let id = self.waiter_seq.fetch_add(1, Ordering::Relaxed);
        let (tx, rx) = oneshot::channel::<(String, String)>();
@@ -192,7 +201,10 @@ impl Server {
        Ok(())
    }

    pub async fn handle(
        &mut self,
        mut stream: tokio::net::TcpStream,
    ) -> Result<(), DBError> {
        // Accumulate incoming bytes to handle partial RESP frames
        let mut acc = String::new();
        let mut buf = vec![0u8; 8192];
@@ -229,10 +241,7 @@ impl Server {
                acc = remaining.to_string();

                if self.option.debug {
                    println!("\x1b[34;1mgot command: {:?}, protocol: {:?}\x1b[0m", cmd, protocol);
                } else {
                    println!("got command: {:?}, protocol: {:?}", cmd, protocol);
                }


@@ -12,9 +12,9 @@ use crate::error::DBError;
// Re-export modules
mod storage_basic;
mod storage_hset;
mod storage_lists;
mod storage_extra;

// Re-export implementations
// Note: These imports are used by the impl blocks in the submodules
@@ -28,8 +28,7 @@ const STRINGS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("string
const HASHES_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("hashes");
const LISTS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("lists");
const STREAMS_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("streams_meta");
const STREAMS_DATA_TABLE: TableDefinition<(&str, &str), &[u8]> = TableDefinition::new("streams_data");
const ENCRYPTED_TABLE: TableDefinition<&str, u8> = TableDefinition::new("encrypted");
const EXPIRATION_TABLE: TableDefinition<&str, u64> = TableDefinition::new("expiration");
@@ -56,13 +55,9 @@ pub struct Storage {
}

impl Storage {
    pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
        let db = Database::create(path)?;

        // Create tables if they don't exist
        let write_txn = db.begin_write()?;
        {
@@ -76,28 +71,23 @@ impl Storage {
            let _ = write_txn.open_table(EXPIRATION_TABLE)?;
        }
        write_txn.commit()?;

        // Check if database was previously encrypted
        let read_txn = db.begin_read()?;
        let encrypted_table = read_txn.open_table(ENCRYPTED_TABLE)?;
        let was_encrypted = encrypted_table.get("encrypted")?.map(|v| v.value() == 1).unwrap_or(false);
        drop(read_txn);

        let crypto = if should_encrypt || was_encrypted {
            if let Some(key) = master_key {
                Some(CryptoFactory::new(key.as_bytes()))
            } else {
                return Err(DBError("Encryption requested but no master key provided".to_string()));
            }
        } else {
            None
        };

        // If we're enabling encryption for the first time, mark it
        if should_encrypt && !was_encrypted {
            let write_txn = db.begin_write()?;
@@ -107,10 +97,13 @@ impl Storage {
            }
            write_txn.commit()?;
        }

        Ok(Storage {
            db,
            crypto,
        })
    }

    pub fn is_encrypted(&self) -> bool {
        self.crypto.is_some()
    }
@@ -123,7 +116,7 @@ impl Storage {
            Ok(data.to_vec())
        }
    }

    fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
        if let Some(crypto) = &self.crypto {
            Ok(crypto.decrypt(data)?)
@@ -172,22 +165,11 @@ impl StorageBackend for Storage {
        self.get_key_type(key)
    }

    fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
        self.scan(cursor, pattern, count)
    }

    fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
        self.hscan(key, cursor, pattern, count)
    }
@@ -294,7 +276,7 @@ impl StorageBackend for Storage {
    fn is_encrypted(&self) -> bool {
        self.is_encrypted()
    }

    fn info(&self) -> Result<Vec<(String, String)>, DBError> {
        self.info()
    }
@@ -302,4 +284,4 @@ impl StorageBackend for Storage {
    fn clone_arc(&self) -> Arc<dyn StorageBackend> {
        unimplemented!("Storage cloning not yet implemented for redb backend")
    }
}


@@ -1,6 +1,6 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;

impl Storage {
    pub fn flushdb(&self) -> Result<(), DBError> {
@@ -15,17 +15,11 @@ impl Storage {
            let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;

            // inefficient, but there is no other way
            let keys: Vec<String> = types_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
            for key in keys {
                types_table.remove(key.as_str())?;
            }

            let keys: Vec<String> = strings_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
            for key in keys {
                strings_table.remove(key.as_str())?;
            }
@@ -40,35 +34,23 @@ impl Storage {
            for (key, field) in keys {
                hashes_table.remove((key.as_str(), field.as_str()))?;
            }

            let keys: Vec<String> = lists_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
            for key in keys {
                lists_table.remove(key.as_str())?;
            }

            let keys: Vec<String> = streams_meta_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
            for key in keys {
                streams_meta_table.remove(key.as_str())?;
            }

            let keys: Vec<(String,String)> = streams_data_table.iter()?.map(|item| {
                let binding = item.unwrap();
                let (key, field) = binding.0.value();
                (key.to_string(), field.to_string())
            }).collect();
            for (key, field) in keys {
                streams_data_table.remove((key.as_str(), field.as_str()))?;
            }

            let keys: Vec<String> = expiration_table.iter()?.map(|item| item.unwrap().0.value().to_string()).collect();
            for key in keys {
                expiration_table.remove(key.as_str())?;
            }
@@ -80,7 +62,7 @@ impl Storage {
    pub fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(TYPES_TABLE)?;

        // Before returning type, check for expiration
        if let Some(type_val) = table.get(key)? {
            if type_val.value() == "string" {
@@ -101,7 +83,7 @@ impl Storage {
    // ✅ ENCRYPTION APPLIED: Value is encrypted/decrypted
    pub fn get(&self, key: &str) -> Result<Option<String>, DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;

        match types_table.get(key)? {
            Some(type_val) if type_val.value() == "string" => {
@@ -114,7 +96,7 @@ impl Storage {
                        return Ok(None);
                    }
                }

                // Get and decrypt value
                let strings_table = read_txn.open_table(STRINGS_TABLE)?;
                match strings_table.get(key)? {
@@ -133,21 +115,21 @@ impl Storage {
    // ✅ ENCRYPTION APPLIED: Value is encrypted before storage
    pub fn set(&self, key: String, value: String) -> Result<(), DBError> {
        let write_txn = self.db.begin_write()?;
        {
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
            types_table.insert(key.as_str(), "string")?;

            let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
            // Only encrypt the value, not expiration
            let encrypted = self.encrypt_if_needed(value.as_bytes())?;
            strings_table.insert(key.as_str(), encrypted.as_slice())?;

            // Remove any existing expiration since this is a regular SET
            let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
            expiration_table.remove(key.as_str())?;
        }
        write_txn.commit()?;
        Ok(())
    }
@@ -155,42 +137,41 @@ impl Storage {
    // ✅ ENCRYPTION APPLIED: Value is encrypted before storage
    pub fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
        let write_txn = self.db.begin_write()?;
        {
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
            types_table.insert(key.as_str(), "string")?;

            let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
            // Only encrypt the value
            let encrypted = self.encrypt_if_needed(value.as_bytes())?;
            strings_table.insert(key.as_str(), encrypted.as_slice())?;

            // Store expiration separately (unencrypted)
            let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
            let expires_at = expire_ms + now_in_millis();
            expiration_table.insert(key.as_str(), &(expires_at as u64))?;
        }
        write_txn.commit()?;
        Ok(())
    }

    pub fn del(&self, key: String) -> Result<(), DBError> {
        let write_txn = self.db.begin_write()?;
        {
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
            let mut strings_table = write_txn.open_table(STRINGS_TABLE)?;
            let mut hashes_table: redb::Table<(&str, &str), &[u8]> = write_txn.open_table(HASHES_TABLE)?;
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;

            // Remove from type table
            types_table.remove(key.as_str())?;

            // Remove from strings table
            strings_table.remove(key.as_str())?;

            // Remove all hash fields for this key
            let mut to_remove = Vec::new();
            let mut iter = hashes_table.iter()?;
@@ -202,19 +183,19 @@ impl Storage {
                }
            }
            drop(iter);

            for (hash_key, field) in to_remove {
                hashes_table.remove((hash_key.as_str(), field.as_str()))?;
            }

            // Remove from lists table
            lists_table.remove(key.as_str())?;

            // Also remove expiration
            let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
            expiration_table.remove(key.as_str())?;
        }
        write_txn.commit()?;
        Ok(())
    }
@@ -222,7 +203,7 @@ impl Storage {
    pub fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(TYPES_TABLE)?;

        let mut keys = Vec::new();
        let mut iter = table.iter()?;
        while let Some(entry) = iter.next() {
@@ -231,7 +212,7 @@ impl Storage {
                keys.push(key);
            }
        }

        Ok(keys)
    }
}
@@ -261,4 +242,4 @@ impl Storage {
        }
        Ok(count)
    }
}


@@ -1,29 +1,24 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;

impl Storage {
    // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
    pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;
        let strings_table = read_txn.open_table(STRINGS_TABLE)?;

        let mut result = Vec::new();
        let mut current_cursor = 0u64;
        let limit = count.unwrap_or(10) as usize;

        let mut iter = types_table.iter()?;
        while let Some(entry) = iter.next() {
            let entry = entry?;
            let key = entry.0.value().to_string();
            let key_type = entry.1.value().to_string();

            if current_cursor >= cursor {
                // Apply pattern matching if specified
                let matches = if let Some(pat) = pattern {
@@ -31,7 +26,7 @@ impl Storage {
                } else {
                    true
                };

                if matches {
                    // For scan, we return key-value pairs for string types
                    if key_type == "string" {
@@ -46,7 +41,7 @@ impl Storage {
                    // For non-string types, just return the key with type as value
                    result.push((key, key_type));
                }

                if result.len() >= limit {
                    break;
                }
@@ -54,19 +49,15 @@ impl Storage {
            }
            current_cursor += 1;
        }

        let next_cursor = if result.len() < limit { 0 } else { current_cursor };
        Ok((next_cursor, result))
    }

    pub fn ttl(&self, key: &str) -> Result<i64, DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;

        match types_table.get(key)? {
            Some(type_val) if type_val.value() == "string" => {
                let expiration_table = read_txn.open_table(EXPIRATION_TABLE)?;
@@ -84,14 +75,14 @@ impl Storage {
                }
            }
            Some(_) => Ok(-1), // Key exists but is not a string (no expiration support for other types)
            None => Ok(-2),    // Key does not exist
        }
    }

    pub fn exists(&self, key: &str) -> Result<bool, DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;

        match types_table.get(key)? {
            Some(type_val) if type_val.value() == "string" => {
                // Check if string key has expired
@@ -104,7 +95,7 @@ impl Storage {
Ok(true) Ok(true)
} }
Some(_) => Ok(true), // Key exists and is not a string Some(_) => Ok(true), // Key exists and is not a string
None => Ok(false), // Key does not exist None => Ok(false), // Key does not exist
} }
} }
@@ -187,12 +178,8 @@ impl Storage {
.unwrap_or(false); .unwrap_or(false);
if is_string { if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_secs <= 0 { let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
0 expiration_table.insert(key, &((expires_at_ms as u64)))?;
} else {
(ts_secs as u128) * 1000
};
expiration_table.insert(key, &(expires_at_ms as u64))?;
applied = true; applied = true;
} }
} }
@@ -214,7 +201,7 @@ impl Storage {
if is_string { if is_string {
let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?; let mut expiration_table = write_txn.open_table(EXPIRATION_TABLE)?;
let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 }; let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
expiration_table.insert(key, &(expires_at_ms as u64))?; expiration_table.insert(key, &((expires_at_ms as u64)))?;
applied = true; applied = true;
} }
} }
@@ -236,21 +223,21 @@ pub fn glob_match(pattern: &str, text: &str) -> bool {
if pattern == "*" { if pattern == "*" {
return true; return true;
} }
// Simple glob matching - supports * and ? wildcards // Simple glob matching - supports * and ? wildcards
let pattern_chars: Vec<char> = pattern.chars().collect(); let pattern_chars: Vec<char> = pattern.chars().collect();
let text_chars: Vec<char> = text.chars().collect(); let text_chars: Vec<char> = text.chars().collect();
fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool { fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
if pi >= pattern.len() { if pi >= pattern.len() {
return ti >= text.len(); return ti >= text.len();
} }
if ti >= text.len() { if ti >= text.len() {
// Check if remaining pattern is all '*' // Check if remaining pattern is all '*'
return pattern[pi..].iter().all(|&c| c == '*'); return pattern[pi..].iter().all(|&c| c == '*');
} }
match pattern[pi] { match pattern[pi] {
'*' => { '*' => {
// Try matching zero or more characters // Try matching zero or more characters
@@ -275,7 +262,7 @@ pub fn glob_match(pattern: &str, text: &str) -> bool {
} }
} }
} }
match_recursive(&pattern_chars, &text_chars, 0, 0) match_recursive(&pattern_chars, &text_chars, 0, 0)
} }
@@ -296,4 +283,4 @@ mod tests {
assert!(glob_match("*test*", "this_is_a_test_string")); assert!(glob_match("*test*", "this_is_a_test_string"));
assert!(!glob_match("*test*", "this_is_a_string")); assert!(!glob_match("*test*", "this_is_a_string"));
} }
} }
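The `scan` implementation above returns a `(next_cursor, results)` pair and reports a cursor of `0` once the keyspace is exhausted, mirroring Redis `SCAN` semantics. A minimal caller-side sketch of that contract, assuming an already-constructed `Storage` handle; the `user:*` pattern and batch size are illustrative only:

```rust
// Hypothetical driver loop for the cursor contract shown above.
fn scan_all(storage: &Storage) -> Result<Vec<(String, String)>, DBError> {
    let mut out = Vec::new();
    let mut cursor = 0u64;
    loop {
        // Signature as in Storage::scan above; pattern and count are assumptions.
        let (next, batch) = storage.scan(cursor, Some("user:*"), Some(100))?;
        out.extend(batch);
        if next == 0 {
            break; // a zero cursor means iteration is complete
        }
        cursor = next;
    }
    Ok(out)
}
```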
View File

@@ -1,50 +1,44 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;
impl Storage {
    // ✅ ENCRYPTION APPLIED: Values are encrypted before storage
    pub fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut new_fields = 0i64;
        {
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
            let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
            let key_type = {
                let access_guard = types_table.get(key)?;
                access_guard.map(|v| v.value().to_string())
            };
            match key_type.as_deref() {
                Some("hash") | None => { // Proceed if hash or new key
                    // Set the type to hash (only if new key or existing hash)
                    types_table.insert(key, "hash")?;
                    for (field, value) in pairs {
                        // Check if field already exists
                        let exists = hashes_table.get((key, field.as_str()))?.is_some();
                        // Encrypt the value before storing
                        let encrypted = self.encrypt_if_needed(value.as_bytes())?;
                        hashes_table.insert((key, field.as_str()), encrypted.as_slice())?;
                        if !exists {
                            new_fields += 1;
                        }
                    }
                }
                Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            }
        }
        write_txn.commit()?;
        Ok(new_fields)
    }
@@ -53,7 +47,7 @@ impl Storage {
    pub fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;
        let key_type = types_table.get(key)?.map(|v| v.value().to_string());
        match key_type.as_deref() {
@@ -68,9 +62,7 @@ impl Storage {
                    None => Ok(None),
                }
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok(None),
        }
    }
@@ -88,7 +80,7 @@ impl Storage {
            Some("hash") => {
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
                let mut result = Vec::new();
                let mut iter = hashes_table.iter()?;
                while let Some(entry) = iter.next() {
                    let entry = entry?;
@@ -99,12 +91,10 @@ impl Storage {
                        result.push((field.to_string(), value));
                    }
                }
                Ok(result)
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok(Vec::new()),
        }
    }
@@ -112,24 +102,24 @@ impl Storage {
    pub fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut deleted = 0i64;
        // First check if key exists and is a hash
        let key_type = {
            let types_table = write_txn.open_table(TYPES_TABLE)?;
            let access_guard = types_table.get(key)?;
            access_guard.map(|v| v.value().to_string())
        };
        match key_type.as_deref() {
            Some("hash") => {
                let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
                for field in fields {
                    if hashes_table.remove((key, field.as_str()))?.is_some() {
                        deleted += 1;
                    }
                }
                // Check if hash is now empty and remove type if so
                let mut has_fields = false;
                let mut iter = hashes_table.iter()?;
@@ -142,20 +132,16 @@ impl Storage {
                    }
                }
                drop(iter);
                if !has_fields {
                    let mut types_table = write_txn.open_table(TYPES_TABLE)?;
                    types_table.remove(key)?;
                }
            }
            Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => {} // Key does not exist, nothing to delete, return 0 deleted
        }
        write_txn.commit()?;
        Ok(deleted)
    }
@@ -173,9 +159,7 @@ impl Storage {
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
                Ok(hashes_table.get((key, field))?.is_some())
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok(false),
        }
    }
@@ -192,7 +176,7 @@ impl Storage {
            Some("hash") => {
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
                let mut result = Vec::new();
                let mut iter = hashes_table.iter()?;
                while let Some(entry) = iter.next() {
                    let entry = entry?;
@@ -201,12 +185,10 @@ impl Storage {
                        result.push(field.to_string());
                    }
                }
                Ok(result)
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok(Vec::new()),
        }
    }
@@ -224,7 +206,7 @@ impl Storage {
            Some("hash") => {
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
                let mut result = Vec::new();
                let mut iter = hashes_table.iter()?;
                while let Some(entry) = iter.next() {
                    let entry = entry?;
@@ -235,12 +217,10 @@ impl Storage {
                        result.push(value);
                    }
                }
                Ok(result)
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok(Vec::new()),
        }
    }
@@ -257,7 +237,7 @@ impl Storage {
            Some("hash") => {
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
                let mut count = 0i64;
                let mut iter = hashes_table.iter()?;
                while let Some(entry) = iter.next() {
                    let entry = entry?;
@@ -266,12 +246,10 @@ impl Storage {
                        count += 1;
                    }
                }
                Ok(count)
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok(0),
        }
    }
@@ -289,7 +267,7 @@ impl Storage {
            Some("hash") => {
                let hashes_table = read_txn.open_table(HASHES_TABLE)?;
                let mut result = Vec::new();
                for field in fields {
                    match hashes_table.get((key, field.as_str()))? {
                        Some(data) => {
@@ -300,12 +278,10 @@ impl Storage {
                        None => result.push(None),
                    }
                }
                Ok(result)
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok(fields.into_iter().map(|_| None).collect()),
        }
    }
@@ -314,51 +290,39 @@ impl Storage {
    pub fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut result = false;
        {
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
            let mut hashes_table = write_txn.open_table(HASHES_TABLE)?;
            let key_type = {
                let access_guard = types_table.get(key)?;
                access_guard.map(|v| v.value().to_string())
            };
            match key_type.as_deref() {
                Some("hash") | None => { // Proceed if hash or new key
                    // Check if field already exists
                    if hashes_table.get((key, field))?.is_none() {
                        // Set the type to hash (only if new key or existing hash)
                        types_table.insert(key, "hash")?;
                        // Encrypt the value before storing
                        let encrypted = self.encrypt_if_needed(value.as_bytes())?;
                        hashes_table.insert((key, field), encrypted.as_slice())?;
                        result = true;
                    }
                }
                Some(_) => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            }
        }
        write_txn.commit()?;
        Ok(result)
    }
    // ✅ ENCRYPTION APPLIED: Values are decrypted after retrieval
    pub fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;
        let key_type = {
@@ -372,28 +336,28 @@ impl Storage {
                let mut result = Vec::new();
                let mut current_cursor = 0u64;
                let limit = count.unwrap_or(10) as usize;
                let mut iter = hashes_table.iter()?;
                while let Some(entry) = iter.next() {
                    let entry = entry?;
                    let (hash_key, field) = entry.0.value();
                    if hash_key == key {
                        if current_cursor >= cursor {
                            let field_str = field.to_string();
                            // Apply pattern matching if specified
                            let matches = if let Some(pat) = pattern {
                                super::storage_extra::glob_match(pat, &field_str)
                            } else {
                                true
                            };
                            if matches {
                                let decrypted = self.decrypt_if_needed(entry.1.value())?;
                                let value = String::from_utf8(decrypted)?;
                                result.push((field_str, value));
                                if result.len() >= limit {
                                    break;
                                }
@@ -402,18 +366,12 @@ impl Storage {
                        current_cursor += 1;
                    }
                }
                let next_cursor = if result.len() < limit { 0 } else { current_cursor };
                Ok((next_cursor, result))
            }
            Some(_) => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
            None => Ok((0, Vec::new())),
        }
    }
}
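Taken together, these hash methods enforce the `WRONGTYPE` contract on every access and encrypt field values transparently. A small round-trip sketch against the signatures above, assuming an open `Storage` handle; the key and field names are illustrative:

```rust
fn hash_round_trip(storage: &Storage) -> Result<(), DBError> {
    // hset returns the number of newly created fields.
    storage.hset("user:1", vec![
        ("name".to_string(), "alice".to_string()),
        ("role".to_string(), "admin".to_string()),
    ])?;
    assert_eq!(storage.hget("user:1", "name")?, Some("alice".to_string()));
    // hscan follows the same cursor contract as scan: 0 means done.
    let (next, pairs) = storage.hscan("user:1", 0, Some("*"), Some(10))?;
    assert_eq!(next, 0);
    assert_eq!(pairs.len(), 2);
    Ok(())
}
```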
View File

@@ -1,20 +1,20 @@
use redb::{ReadableTable};
use crate::error::DBError;
use super::*;
impl Storage {
    // ✅ ENCRYPTION APPLIED: Elements are encrypted before storage
    pub fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut _length = 0i64;
        {
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
            // Set the type to list
            types_table.insert(key, "list")?;
            // Get current list or create empty one
            let mut list: Vec<String> = match lists_table.get(key)? {
                Some(data) => {
@@ -23,20 +23,20 @@ impl Storage {
                }
                None => Vec::new(),
            };
            // Add elements to the front (left)
            for element in elements.into_iter() {
                list.insert(0, element);
            }
            _length = list.len() as i64;
            // Encrypt and store the updated list
            let serialized = serde_json::to_vec(&list)?;
            let encrypted = self.encrypt_if_needed(&serialized)?;
            lists_table.insert(key, encrypted.as_slice())?;
        }
        write_txn.commit()?;
        Ok(_length)
    }
@@ -45,14 +45,14 @@ impl Storage {
    pub fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut _length = 0i64;
        {
            let mut types_table = write_txn.open_table(TYPES_TABLE)?;
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
            // Set the type to list
            types_table.insert(key, "list")?;
            // Get current list or create empty one
            let mut list: Vec<String> = match lists_table.get(key)? {
                Some(data) => {
@@ -61,17 +61,17 @@ impl Storage {
                }
                None => Vec::new(),
            };
            // Add elements to the end (right)
            list.extend(elements);
            _length = list.len() as i64;
            // Encrypt and store the updated list
            let serialized = serde_json::to_vec(&list)?;
            let encrypted = self.encrypt_if_needed(&serialized)?;
            lists_table.insert(key, encrypted.as_slice())?;
        }
        write_txn.commit()?;
        Ok(_length)
    }
@@ -80,12 +80,12 @@ impl Storage {
    pub fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut result = Vec::new();
        // First check if key exists and is a list, and get the data
        let list_data = {
            let types_table = write_txn.open_table(TYPES_TABLE)?;
            let lists_table = write_txn.open_table(LISTS_TABLE)?;
            let result = match types_table.get(key)? {
                Some(type_val) if type_val.value() == "list" => {
                    if let Some(data) = lists_table.get(key)? {
@@ -100,7 +100,7 @@ impl Storage {
            };
            result
        };
        if let Some(mut list) = list_data {
            let pop_count = std::cmp::min(count as usize, list.len());
            for _ in 0..pop_count {
@@ -108,7 +108,7 @@ impl Storage {
                    result.push(list.remove(0));
                }
            }
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
            if list.is_empty() {
                // Remove the key if list is empty
@@ -122,7 +122,7 @@ impl Storage {
                lists_table.insert(key, encrypted.as_slice())?;
            }
        }
        write_txn.commit()?;
        Ok(result)
    }
@@ -131,12 +131,12 @@ impl Storage {
    pub fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut result = Vec::new();
        // First check if key exists and is a list, and get the data
        let list_data = {
            let types_table = write_txn.open_table(TYPES_TABLE)?;
            let lists_table = write_txn.open_table(LISTS_TABLE)?;
            let result = match types_table.get(key)? {
                Some(type_val) if type_val.value() == "list" => {
                    if let Some(data) = lists_table.get(key)? {
@@ -151,7 +151,7 @@ impl Storage {
            };
            result
        };
        if let Some(mut list) = list_data {
            let pop_count = std::cmp::min(count as usize, list.len());
            for _ in 0..pop_count {
@@ -159,7 +159,7 @@ impl Storage {
                    result.push(list.pop().unwrap());
                }
            }
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
            if list.is_empty() {
                // Remove the key if list is empty
@@ -173,7 +173,7 @@ impl Storage {
                lists_table.insert(key, encrypted.as_slice())?;
            }
        }
        write_txn.commit()?;
        Ok(result)
    }
@@ -181,7 +181,7 @@ impl Storage {
    pub fn llen(&self, key: &str) -> Result<i64, DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;
        match types_table.get(key)? {
            Some(type_val) if type_val.value() == "list" => {
                let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -202,7 +202,7 @@ impl Storage {
    pub fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;
        match types_table.get(key)? {
            Some(type_val) if type_val.value() == "list" => {
                let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -210,13 +210,13 @@ impl Storage {
                    Some(data) => {
                        let decrypted = self.decrypt_if_needed(data.value())?;
                        let list: Vec<String> = serde_json::from_slice(&decrypted)?;
                        let actual_index = if index < 0 {
                            list.len() as i64 + index
                        } else {
                            index
                        };
                        if actual_index >= 0 && (actual_index as usize) < list.len() {
                            Ok(Some(list[actual_index as usize].clone()))
                        } else {
@@ -234,7 +234,7 @@ impl Storage {
    pub fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
        let read_txn = self.db.begin_read()?;
        let types_table = read_txn.open_table(TYPES_TABLE)?;
        match types_table.get(key)? {
            Some(type_val) if type_val.value() == "list" => {
                let lists_table = read_txn.open_table(LISTS_TABLE)?;
@@ -242,30 +242,22 @@ impl Storage {
                    Some(data) => {
                        let decrypted = self.decrypt_if_needed(data.value())?;
                        let list: Vec<String> = serde_json::from_slice(&decrypted)?;
                        if list.is_empty() {
                            return Ok(Vec::new());
                        }
                        let len = list.len() as i64;
                        let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
                        let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
                        if start_idx > stop_idx || start_idx >= len {
                            return Ok(Vec::new());
                        }
                        let start_usize = start_idx as usize;
                        let stop_usize = (stop_idx + 1) as usize;
                        Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
                    }
                    None => Ok(Vec::new()),
@@ -278,12 +270,12 @@ impl Storage {
    // ✅ ENCRYPTION APPLIED: Elements are decrypted after retrieval and encrypted before storage
    pub fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
        let write_txn = self.db.begin_write()?;
        // First check if key exists and is a list, and get the data
        let list_data = {
            let types_table = write_txn.open_table(TYPES_TABLE)?;
            let lists_table = write_txn.open_table(LISTS_TABLE)?;
            let result = match types_table.get(key)? {
                Some(type_val) if type_val.value() == "list" => {
                    if let Some(data) = lists_table.get(key)? {
@@ -298,25 +290,17 @@ impl Storage {
            };
            result
        };
        if let Some(list) = list_data {
            if list.is_empty() {
                write_txn.commit()?;
                return Ok(());
            }
            let len = list.len() as i64;
            let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
            let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
            if start_idx > stop_idx || start_idx >= len {
                // Remove the entire list
@@ -327,7 +311,7 @@ impl Storage {
                let start_usize = start_idx as usize;
                let stop_usize = (stop_idx + 1) as usize;
                let trimmed = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
                if trimmed.is_empty() {
                    lists_table.remove(key)?;
                    let mut types_table = write_txn.open_table(TYPES_TABLE)?;
@@ -340,7 +324,7 @@ impl Storage {
                }
            }
        }
        write_txn.commit()?;
        Ok(())
    }
@@ -349,12 +333,12 @@ impl Storage {
    pub fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
        let write_txn = self.db.begin_write()?;
        let mut removed = 0i64;
        // First check if key exists and is a list, and get the data
        let list_data = {
            let types_table = write_txn.open_table(TYPES_TABLE)?;
            let lists_table = write_txn.open_table(LISTS_TABLE)?;
            let result = match types_table.get(key)? {
                Some(type_val) if type_val.value() == "list" => {
                    if let Some(data) = lists_table.get(key)? {
@@ -369,7 +353,7 @@ impl Storage {
            };
            result
        };
        if let Some(mut list) = list_data {
            if count == 0 {
                // Remove all occurrences
@@ -399,7 +383,7 @@ impl Storage {
                    }
                }
            }
            let mut lists_table = write_txn.open_table(LISTS_TABLE)?;
            if list.is_empty() {
                lists_table.remove(key)?;
@@ -412,8 +396,8 @@ impl Storage {
                lists_table.insert(key, encrypted.as_slice())?;
            }
        }
        write_txn.commit()?;
        Ok(removed)
    }
}
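`lrange` and `ltrim` share the same Redis-style index normalization: negative offsets count from the tail, and out-of-range requests collapse to an empty result. A standalone sketch of that arithmetic, extracted from the code above for clarity:

```rust
// Mirrors the start/stop normalization used by lrange/ltrim above.
fn normalize(len: i64, start: i64, stop: i64) -> Option<(usize, usize)> {
    let start_idx = if start < 0 { std::cmp::max(0, len + start) } else { std::cmp::min(start, len) };
    let stop_idx = if stop < 0 { std::cmp::max(-1, len + stop) } else { std::cmp::min(stop, len - 1) };
    if start_idx > stop_idx || start_idx >= len {
        return None; // empty range
    }
    Some((start_idx as usize, (stop_idx + 1) as usize)) // half-open slice bounds
}

fn main() {
    assert_eq!(normalize(5, 0, -1), Some((0, 5)));  // LRANGE 0 -1: whole list
    assert_eq!(normalize(5, -2, -1), Some((3, 5))); // last two elements
    assert_eq!(normalize(5, 3, 1), None);           // inverted range: empty
}
```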
View File

@@ -1,12 +1,12 @@
// src/storage_sled/mod.rs
use std::path::Path;
use std::sync::Arc;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::error::DBError;
use crate::storage_trait::StorageBackend;
use crate::crypto::CryptoFactory;
#[derive(Serialize, Deserialize, Debug, Clone)]
enum ValueType {
@@ -28,56 +28,44 @@ pub struct SledStorage {
}
impl SledStorage {
    pub fn new(path: impl AsRef<Path>, should_encrypt: bool, master_key: Option<&str>) -> Result<Self, DBError> {
        let db = sled::open(path).map_err(|e| DBError(format!("Failed to open sled: {}", e)))?;
        let types = db.open_tree("types").map_err(|e| DBError(format!("Failed to open types tree: {}", e)))?;
        // Check if database was previously encrypted
        let encrypted_tree = db.open_tree("encrypted").map_err(|e| DBError(e.to_string()))?;
        let was_encrypted = encrypted_tree.get("encrypted")
            .map_err(|e| DBError(e.to_string()))?
            .map(|v| v[0] == 1)
            .unwrap_or(false);
        let crypto = if should_encrypt || was_encrypted {
            if let Some(key) = master_key {
                Some(CryptoFactory::new(key.as_bytes()))
            } else {
                return Err(DBError("Encryption requested but no master key provided".to_string()));
            }
        } else {
            None
        };
        // Mark database as encrypted if enabling encryption
        if should_encrypt && !was_encrypted {
            encrypted_tree.insert("encrypted", &[1u8])
                .map_err(|e| DBError(e.to_string()))?;
            encrypted_tree.flush().map_err(|e| DBError(e.to_string()))?;
        }
        Ok(SledStorage { db, types, crypto })
    }
    fn now_millis() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis()
    }
    fn encrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
        if let Some(crypto) = &self.crypto {
            Ok(crypto.encrypt(data))
@@ -85,7 +73,7 @@ impl SledStorage {
            Ok(data.to_vec())
        }
    }
    fn decrypt_if_needed(&self, data: &[u8]) -> Result<Vec<u8>, DBError> {
        if let Some(crypto) = &self.crypto {
            Ok(crypto.decrypt(data)?)
@@ -93,14 +81,14 @@ impl SledStorage {
            Ok(data.to_vec())
        }
    }
    fn get_storage_value(&self, key: &str) -> Result<Option<StorageValue>, DBError> {
        match self.db.get(key).map_err(|e| DBError(e.to_string()))? {
            Some(encrypted_data) => {
                let decrypted = self.decrypt_if_needed(&encrypted_data)?;
                let storage_val: StorageValue = bincode::deserialize(&decrypted)
                    .map_err(|e| DBError(format!("Deserialization error: {}", e)))?;
                // Check expiration
                if let Some(expires_at) = storage_val.expires_at {
                    if Self::now_millis() > expires_at {
@@ -110,51 +98,47 @@ impl SledStorage {
                        return Ok(None);
                    }
                }
                Ok(Some(storage_val))
            }
            None => Ok(None)
        }
    }
    fn set_storage_value(&self, key: &str, storage_val: StorageValue) -> Result<(), DBError> {
        let data = bincode::serialize(&storage_val)
            .map_err(|e| DBError(format!("Serialization error: {}", e)))?;
        let encrypted = self.encrypt_if_needed(&data)?;
        self.db.insert(key, encrypted).map_err(|e| DBError(e.to_string()))?;
        // Store type info (unencrypted for efficiency)
        let type_str = match &storage_val.value {
            ValueType::String(_) => "string",
            ValueType::Hash(_) => "hash",
            ValueType::List(_) => "list",
        };
        self.types.insert(key, type_str.as_bytes()).map_err(|e| DBError(e.to_string()))?;
        Ok(())
    }
    fn glob_match(pattern: &str, text: &str) -> bool {
        if pattern == "*" {
            return true;
        }
        let pattern_chars: Vec<char> = pattern.chars().collect();
        let text_chars: Vec<char> = text.chars().collect();
        fn match_recursive(pattern: &[char], text: &[char], pi: usize, ti: usize) -> bool {
            if pi >= pattern.len() {
                return ti >= text.len();
            }
            if ti >= text.len() {
                return pattern[pi..].iter().all(|&c| c == '*');
            }
            match pattern[pi] {
                '*' => {
                    for i in ti..=text.len() {
@@ -174,7 +158,7 @@ impl SledStorage {
                }
            }
        }
        match_recursive(&pattern_chars, &text_chars, 0, 0)
    }
}
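The helper pair `get_storage_value`/`set_storage_value` above implements lazy expiration: nothing is proactively deleted, but any read past `expires_at` reports the key as absent. A self-contained sketch of that rule; the `Entry` type here is illustrative, not the crate's `StorageValue`:

```rust
use std::time::{SystemTime, UNIX_EPOCH};

// Illustrative stand-in for a stored record with an optional TTL.
struct Entry {
    value: String,
    expires_at: Option<u128>, // milliseconds since the Unix epoch
}

fn now_millis() -> u128 {
    SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis()
}

// An entry past its deadline is treated as already deleted on read.
fn live(entry: &Entry) -> Option<&str> {
    match entry.expires_at {
        Some(ts) if now_millis() > ts => None,
        _ => Some(entry.value.as_str()),
    }
}
```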
@@ -184,12 +168,12 @@ impl StorageBackend for SledStorage {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::String(s) => Ok(Some(s)), ValueType::String(s) => Ok(Some(s)),
_ => Ok(None), _ => Ok(None)
}, }
None => Ok(None), None => Ok(None)
} }
} }
fn set(&self, key: String, value: String) -> Result<(), DBError> { fn set(&self, key: String, value: String) -> Result<(), DBError> {
let storage_val = StorageValue { let storage_val = StorageValue {
value: ValueType::String(value), value: ValueType::String(value),
@@ -199,7 +183,7 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> { fn setx(&self, key: String, value: String, expire_ms: u128) -> Result<(), DBError> {
let storage_val = StorageValue { let storage_val = StorageValue {
value: ValueType::String(value), value: ValueType::String(value),
@@ -209,27 +193,25 @@ impl StorageBackend for SledStorage {
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn del(&self, key: String) -> Result<(), DBError> { fn del(&self, key: String) -> Result<(), DBError> {
self.db.remove(&key).map_err(|e| DBError(e.to_string()))?; self.db.remove(&key).map_err(|e| DBError(e.to_string()))?;
self.types self.types.remove(&key).map_err(|e| DBError(e.to_string()))?;
.remove(&key)
.map_err(|e| DBError(e.to_string()))?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn exists(&self, key: &str) -> Result<bool, DBError> { fn exists(&self, key: &str) -> Result<bool, DBError> {
// Check with expiration // Check with expiration
Ok(self.get_storage_value(key)?.is_some()) Ok(self.get_storage_value(key)?.is_some())
} }
fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> { fn keys(&self, pattern: &str) -> Result<Vec<String>, DBError> {
let mut keys = Vec::new(); let mut keys = Vec::new();
for item in self.types.iter() { for item in self.types.iter() {
let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?; let (key_bytes, _) = item.map_err(|e| DBError(e.to_string()))?;
let key = String::from_utf8_lossy(&key_bytes).to_string(); let key = String::from_utf8_lossy(&key_bytes).to_string();
// Check if key is expired // Check if key is expired
if self.get_storage_value(&key)?.is_some() { if self.get_storage_value(&key)?.is_some() {
if Self::glob_match(pattern, &key) { if Self::glob_match(pattern, &key) {
@@ -239,29 +221,24 @@ impl StorageBackend for SledStorage {
} }
Ok(keys) Ok(keys)
} }
fn scan( fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
&self,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
let mut result = Vec::new(); let mut result = Vec::new();
let mut current_cursor = 0u64; let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize; let limit = count.unwrap_or(10) as usize;
for item in self.types.iter() { for item in self.types.iter() {
if current_cursor >= cursor { if current_cursor >= cursor {
let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?; let (key_bytes, type_bytes) = item.map_err(|e| DBError(e.to_string()))?;
let key = String::from_utf8_lossy(&key_bytes).to_string(); let key = String::from_utf8_lossy(&key_bytes).to_string();
// Check pattern match // Check pattern match
let matches = if let Some(pat) = pattern { let matches = if let Some(pat) = pattern {
Self::glob_match(pat, &key) Self::glob_match(pat, &key)
} else { } else {
true true
}; };
if matches { if matches {
// Check if key is expired and get value // Check if key is expired and get value
if let Some(storage_val) = self.get_storage_value(&key)? { if let Some(storage_val) = self.get_storage_value(&key)? {
@@ -270,7 +247,7 @@ impl StorageBackend for SledStorage {
_ => String::from_utf8_lossy(&type_bytes).to_string(), _ => String::from_utf8_lossy(&type_bytes).to_string(),
}; };
result.push((key, value)); result.push((key, value));
if result.len() >= limit { if result.len() >= limit {
current_cursor += 1; current_cursor += 1;
break; break;
@@ -280,15 +257,11 @@ impl StorageBackend for SledStorage {
} }
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { let next_cursor = if result.len() < limit { 0 } else { current_cursor };
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
fn dbsize(&self) -> Result<i64, DBError> { fn dbsize(&self) -> Result<i64, DBError> {
let mut count = 0i64; let mut count = 0i64;
for item in self.types.iter() { for item in self.types.iter() {
@@ -300,42 +273,38 @@ impl StorageBackend for SledStorage {
} }
Ok(count) Ok(count)
} }
fn flushdb(&self) -> Result<(), DBError> { fn flushdb(&self) -> Result<(), DBError> {
self.db.clear().map_err(|e| DBError(e.to_string()))?; self.db.clear().map_err(|e| DBError(e.to_string()))?;
self.types.clear().map_err(|e| DBError(e.to_string()))?; self.types.clear().map_err(|e| DBError(e.to_string()))?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(()) Ok(())
} }
fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> { fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError> {
// First check if key exists (handles expiration) // First check if key exists (handles expiration)
if self.get_storage_value(key)?.is_some() { if self.get_storage_value(key)?.is_some() {
match self.types.get(key).map_err(|e| DBError(e.to_string()))? { match self.types.get(key).map_err(|e| DBError(e.to_string()))? {
Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())), Some(data) => Ok(Some(String::from_utf8_lossy(&data).to_string())),
None => Ok(None), None => Ok(None)
} }
} else { } else {
Ok(None) Ok(None)
} }
} }
// Hash operations // Hash operations
fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> { fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::Hash(HashMap::new()), value: ValueType::Hash(HashMap::new()),
expires_at: None, expires_at: None,
}); });
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => { _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
let mut new_fields = 0i64; let mut new_fields = 0i64;
for (field, value) in pairs { for (field, value) in pairs {
if !hash.contains_key(&field) { if !hash.contains_key(&field) {
@@ -343,46 +312,40 @@ impl StorageBackend for SledStorage {
} }
hash.insert(field, value); hash.insert(field, value);
} }
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
Ok(new_fields) Ok(new_fields)
} }
fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> { fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.get(field).cloned()), ValueType::Hash(h) => Ok(h.get(field).cloned()),
_ => Ok(None), _ => Ok(None)
}, }
None => Ok(None), None => Ok(None)
} }
} }
fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> { fn hgetall(&self, key: &str) -> Result<Vec<(String, String)>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.into_iter().collect()), ValueType::Hash(h) => Ok(h.into_iter().collect()),
_ => Ok(Vec::new()), _ => Ok(Vec::new())
}, }
None => Ok(Vec::new()), None => Ok(Vec::new())
} }
} }
fn hscan( fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError> {
&self,
key: &str,
cursor: u64,
pattern: Option<&str>,
count: Option<u64>,
) -> Result<(u64, Vec<(String, String)>), DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => { ValueType::Hash(h) => {
let mut result = Vec::new(); let mut result = Vec::new();
let mut current_cursor = 0u64; let mut current_cursor = 0u64;
let limit = count.unwrap_or(10) as usize; let limit = count.unwrap_or(10) as usize;
for (field, value) in h.iter() { for (field, value) in h.iter() {
if current_cursor >= cursor { if current_cursor >= cursor {
let matches = if let Some(pat) = pattern { let matches = if let Some(pat) = pattern {
@@ -390,7 +353,7 @@ impl StorageBackend for SledStorage {
} else { } else {
true true
}; };
if matches { if matches {
result.push((field.clone(), value.clone())); result.push((field.clone(), value.clone()));
if result.len() >= limit { if result.len() >= limit {
@@ -401,115 +364,107 @@ impl StorageBackend for SledStorage {
} }
current_cursor += 1; current_cursor += 1;
} }
let next_cursor = if result.len() < limit { let next_cursor = if result.len() < limit { 0 } else { current_cursor };
0
} else {
current_cursor
};
Ok((next_cursor, result)) Ok((next_cursor, result))
} }
_ => Err(DBError( _ => Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string()))
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(), }
)), None => Ok((0, Vec::new()))
},
None => Ok((0, Vec::new())),
} }
} }
fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> { fn hdel(&self, key: &str, fields: Vec<String>) -> Result<i64, DBError> {
let mut storage_val = match self.get_storage_value(key)? { let mut storage_val = match self.get_storage_value(key)? {
Some(sv) => sv, Some(sv) => sv,
None => return Ok(0), None => return Ok(0)
}; };
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => return Ok(0), _ => return Ok(0)
}; };
let mut deleted = 0i64; let mut deleted = 0i64;
for field in fields { for field in fields {
if hash.remove(&field).is_some() { if hash.remove(&field).is_some() {
deleted += 1; deleted += 1;
} }
} }
if hash.is_empty() { if hash.is_empty() {
self.del(key.to_string())?; self.del(key.to_string())?;
} else { } else {
self.set_storage_value(key, storage_val)?; self.set_storage_value(key, storage_val)?;
self.db.flush().map_err(|e| DBError(e.to_string()))?; self.db.flush().map_err(|e| DBError(e.to_string()))?;
} }
Ok(deleted) Ok(deleted)
} }
fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> { fn hexists(&self, key: &str, field: &str) -> Result<bool, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.contains_key(field)), ValueType::Hash(h) => Ok(h.contains_key(field)),
_ => Ok(false), _ => Ok(false)
}, }
None => Ok(false), None => Ok(false)
} }
} }
fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> { fn hkeys(&self, key: &str) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.keys().cloned().collect()), ValueType::Hash(h) => Ok(h.keys().cloned().collect()),
_ => Ok(Vec::new()), _ => Ok(Vec::new())
}, }
None => Ok(Vec::new()), None => Ok(Vec::new())
} }
} }
fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> { fn hvals(&self, key: &str) -> Result<Vec<String>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.values().cloned().collect()), ValueType::Hash(h) => Ok(h.values().cloned().collect()),
_ => Ok(Vec::new()), _ => Ok(Vec::new())
}, }
None => Ok(Vec::new()), None => Ok(Vec::new())
} }
} }
fn hlen(&self, key: &str) -> Result<i64, DBError> { fn hlen(&self, key: &str) -> Result<i64, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(h.len() as i64), ValueType::Hash(h) => Ok(h.len() as i64),
_ => Ok(0), _ => Ok(0)
}, }
None => Ok(0), None => Ok(0)
} }
} }
fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> { fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError> {
match self.get_storage_value(key)? { match self.get_storage_value(key)? {
Some(storage_val) => match storage_val.value { Some(storage_val) => match storage_val.value {
ValueType::Hash(h) => Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect()), ValueType::Hash(h) => {
_ => Ok(fields.into_iter().map(|_| None).collect()), Ok(fields.into_iter().map(|f| h.get(&f).cloned()).collect())
}, }
None => Ok(fields.into_iter().map(|_| None).collect()), _ => Ok(fields.into_iter().map(|_| None).collect())
}
None => Ok(fields.into_iter().map(|_| None).collect())
} }
} }
fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> { fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError> {
let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue { let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
value: ValueType::Hash(HashMap::new()), value: ValueType::Hash(HashMap::new()),
expires_at: None, expires_at: None,
}); });
let hash = match &mut storage_val.value { let hash = match &mut storage_val.value {
ValueType::Hash(h) => h, ValueType::Hash(h) => h,
_ => { _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
return Err(DBError(
"WRONGTYPE Operation against a key holding the wrong kind of value".to_string(),
))
}
}; };
if hash.contains_key(field) { if hash.contains_key(field) {
Ok(false) Ok(false)
} else { } else {
@@ -519,66 +474,58 @@ impl StorageBackend for SledStorage {
Ok(true) Ok(true)
} }
} }
// List operations
fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
    let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
        value: ValueType::List(Vec::new()),
        expires_at: None,
    });
    let list = match &mut storage_val.value {
        ValueType::List(l) => l,
        _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
    };
    for element in elements.into_iter().rev() {
        list.insert(0, element);
    }
    let len = list.len() as i64;
    self.set_storage_value(key, storage_val)?;
    self.db.flush().map_err(|e| DBError(e.to_string()))?;
    Ok(len)
}
fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError> {
    let mut storage_val = self.get_storage_value(key)?.unwrap_or(StorageValue {
        value: ValueType::List(Vec::new()),
        expires_at: None,
    });
    let list = match &mut storage_val.value {
        ValueType::List(l) => l,
        _ => return Err(DBError("WRONGTYPE Operation against a key holding the wrong kind of value".to_string())),
    };
    list.extend(elements);
    let len = list.len() as i64;
    self.set_storage_value(key, storage_val)?;
    self.db.flush().map_err(|e| DBError(e.to_string()))?;
    Ok(len)
}
fn lpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(Vec::new())
    };
    let list = match &mut storage_val.value {
        ValueType::List(l) => l,
        _ => return Ok(Vec::new())
    };
    let mut result = Vec::new();
    for _ in 0..count.min(list.len() as u64) {
        if let Some(elem) = list.first() {
@@ -586,55 +533,55 @@ impl StorageBackend for SledStorage {
            list.remove(0);
        }
    }
    if list.is_empty() {
        self.del(key.to_string())?;
    } else {
        self.set_storage_value(key, storage_val)?;
        self.db.flush().map_err(|e| DBError(e.to_string()))?;
    }
    Ok(result)
}
fn rpop(&self, key: &str, count: u64) -> Result<Vec<String>, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(Vec::new())
    };
    let list = match &mut storage_val.value {
        ValueType::List(l) => l,
        _ => return Ok(Vec::new())
    };
    let mut result = Vec::new();
    for _ in 0..count.min(list.len() as u64) {
        if let Some(elem) = list.pop() {
            result.push(elem);
        }
    }
    if list.is_empty() {
        self.del(key.to_string())?;
    } else {
        self.set_storage_value(key, storage_val)?;
        self.db.flush().map_err(|e| DBError(e.to_string()))?;
    }
    Ok(result)
}
fn llen(&self, key: &str) -> Result<i64, DBError> {
    match self.get_storage_value(key)? {
        Some(storage_val) => match storage_val.value {
            ValueType::List(l) => Ok(l.len() as i64),
            _ => Ok(0)
        }
        None => Ok(0)
    }
}
fn lindex(&self, key: &str, index: i64) -> Result<Option<String>, DBError> {
    match self.get_storage_value(key)? {
        Some(storage_val) => match storage_val.value {
@@ -644,19 +591,19 @@ impl StorageBackend for SledStorage {
                } else {
                    index
                };
                if actual_index >= 0 && (actual_index as usize) < list.len() {
                    Ok(Some(list[actual_index as usize].clone()))
                } else {
                    Ok(None)
                }
            }
            _ => Ok(None)
        }
        None => Ok(None)
    }
}
fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError> {
    match self.get_storage_value(key)? {
        Some(storage_val) => match storage_val.value {
@@ -664,68 +611,68 @@ impl StorageBackend for SledStorage {
                if list.is_empty() {
                    return Ok(Vec::new());
                }
                let len = list.len() as i64;
                let start_idx = if start < 0 {
                    std::cmp::max(0, len + start)
                } else {
                    std::cmp::min(start, len)
                };
                let stop_idx = if stop < 0 {
                    std::cmp::max(-1, len + stop)
                } else {
                    std::cmp::min(stop, len - 1)
                };
                if start_idx > stop_idx || start_idx >= len {
                    return Ok(Vec::new());
                }
                let start_usize = start_idx as usize;
                let stop_usize = (stop_idx + 1) as usize;
                Ok(list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec())
            }
            _ => Ok(Vec::new())
        }
        None => Ok(Vec::new())
    }
}
fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(())
    };
    let list = match &mut storage_val.value {
        ValueType::List(l) => l,
        _ => return Ok(())
    };
    if list.is_empty() {
        return Ok(());
    }
    let len = list.len() as i64;
    let start_idx = if start < 0 {
        std::cmp::max(0, len + start)
    } else {
        std::cmp::min(start, len)
    };
    let stop_idx = if stop < 0 {
        std::cmp::max(-1, len + stop)
    } else {
        std::cmp::min(stop, len - 1)
    };
    if start_idx > stop_idx || start_idx >= len {
        self.del(key.to_string())?;
    } else {
        let start_usize = start_idx as usize;
        let stop_usize = (stop_idx + 1) as usize;
        *list = list[start_usize..std::cmp::min(stop_usize, list.len())].to_vec();
        if list.is_empty() {
            self.del(key.to_string())?;
        } else {
@@ -733,23 +680,23 @@ impl StorageBackend for SledStorage {
            self.db.flush().map_err(|e| DBError(e.to_string()))?;
        }
    }
    Ok(())
}
fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(0)
    };
    let list = match &mut storage_val.value {
        ValueType::List(l) => l,
        _ => return Ok(0)
    };
    let mut removed = 0i64;
    if count == 0 {
        // Remove all occurrences
        let original_len = list.len();
@@ -778,17 +725,17 @@ impl StorageBackend for SledStorage {
            }
        }
    }
    if list.is_empty() {
        self.del(key.to_string())?;
    } else {
        self.set_storage_value(key, storage_val)?;
        self.db.flush().map_err(|e| DBError(e.to_string()))?;
    }
    Ok(removed)
}
// Expiration
fn ttl(&self, key: &str) -> Result<i64, DBError> {
    match self.get_storage_value(key)? {
@@ -804,40 +751,40 @@ impl StorageBackend for SledStorage {
                Ok(-1) // Key exists but has no expiration
            }
        }
        None => Ok(-2) // Key does not exist
    }
}
fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(false)
    };
    storage_val.expires_at = Some(Self::now_millis() + (secs as u128) * 1000);
    self.set_storage_value(key, storage_val)?;
    self.db.flush().map_err(|e| DBError(e.to_string()))?;
    Ok(true)
}
fn pexpire_millis(&self, key: &str, ms: u128) -> Result<bool, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(false)
    };
    storage_val.expires_at = Some(Self::now_millis() + ms);
    self.set_storage_value(key, storage_val)?;
    self.db.flush().map_err(|e| DBError(e.to_string()))?;
    Ok(true)
}
fn persist(&self, key: &str) -> Result<bool, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(false)
    };
    if storage_val.expires_at.is_some() {
        storage_val.expires_at = None;
        self.set_storage_value(key, storage_val)?;
@@ -847,41 +794,37 @@ impl StorageBackend for SledStorage {
        Ok(false)
    }
}
fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(false)
    };
    let expires_at_ms: u128 = if ts_secs <= 0 { 0 } else { (ts_secs as u128) * 1000 };
    storage_val.expires_at = Some(expires_at_ms);
    self.set_storage_value(key, storage_val)?;
    self.db.flush().map_err(|e| DBError(e.to_string()))?;
    Ok(true)
}
fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError> {
    let mut storage_val = match self.get_storage_value(key)? {
        Some(sv) => sv,
        None => return Ok(false)
    };
    let expires_at_ms: u128 = if ts_ms <= 0 { 0 } else { ts_ms as u128 };
    storage_val.expires_at = Some(expires_at_ms);
    self.set_storage_value(key, storage_val)?;
    self.db.flush().map_err(|e| DBError(e.to_string()))?;
    Ok(true)
}
fn is_encrypted(&self) -> bool {
    self.crypto.is_some()
}
fn info(&self) -> Result<Vec<(String, String)>, DBError> {
    let dbsize = self.dbsize()?;
    Ok(vec![
@@ -899,4 +842,4 @@ impl StorageBackend for SledStorage {
        crypto: self.crypto.clone(),
    })
}
}


@@ -13,22 +13,11 @@ pub trait StorageBackend: Send + Sync {
    fn dbsize(&self) -> Result<i64, DBError>;
    fn flushdb(&self) -> Result<(), DBError>;
    fn get_key_type(&self, key: &str) -> Result<Option<String>, DBError>;
    // Scanning
    fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>;
    fn hscan(&self, key: &str, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<(String, String)>), DBError>;
    // Hash operations
    fn hset(&self, key: &str, pairs: Vec<(String, String)>) -> Result<i64, DBError>;
    fn hget(&self, key: &str, field: &str) -> Result<Option<String>, DBError>;
@@ -40,7 +29,7 @@ pub trait StorageBackend: Send + Sync {
    fn hlen(&self, key: &str) -> Result<i64, DBError>;
    fn hmget(&self, key: &str, fields: Vec<String>) -> Result<Vec<Option<String>>, DBError>;
    fn hsetnx(&self, key: &str, field: &str, value: &str) -> Result<bool, DBError>;
    // List operations
    fn lpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
    fn rpush(&self, key: &str, elements: Vec<String>) -> Result<i64, DBError>;
@@ -51,7 +40,7 @@ pub trait StorageBackend: Send + Sync {
    fn lrange(&self, key: &str, start: i64, stop: i64) -> Result<Vec<String>, DBError>;
    fn ltrim(&self, key: &str, start: i64, stop: i64) -> Result<(), DBError>;
    fn lrem(&self, key: &str, count: i64, element: &str) -> Result<i64, DBError>;
    // Expiration
    fn ttl(&self, key: &str) -> Result<i64, DBError>;
    fn expire_seconds(&self, key: &str, secs: u64) -> Result<bool, DBError>;
@@ -59,11 +48,11 @@ pub trait StorageBackend: Send + Sync {
    fn persist(&self, key: &str) -> Result<bool, DBError>;
    fn expire_at_seconds(&self, key: &str, ts_secs: i64) -> Result<bool, DBError>;
    fn pexpire_at_millis(&self, key: &str, ts_ms: i64) -> Result<bool, DBError>;
    // Metadata
    fn is_encrypted(&self) -> bool;
    fn info(&self) -> Result<Vec<(String, String)>, DBError>;
    // Clone to Arc for sharing
    fn clone_arc(&self) -> Arc<dyn StorageBackend>;
}
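clone_arc exists because Clone is not object-safe, so a shared `Arc<dyn StorageBackend>` needs an explicit hook to mint new handles. A minimal reproduction of the pattern (the rationale is an inference; the trait itself only declares the method):

```rust
use std::sync::Arc;

trait Backend: Send + Sync {
    fn name(&self) -> &'static str;
    // Clone is not object-safe, so trait objects expose an explicit Arc hook.
    fn clone_arc(&self) -> Arc<dyn Backend>;
}

#[derive(Clone)]
struct Mem;

impl Backend for Mem {
    fn name(&self) -> &'static str { "mem" }
    fn clone_arc(&self) -> Arc<dyn Backend> { Arc::new(self.clone()) }
}

fn main() {
    let b: Arc<dyn Backend> = Arc::new(Mem);
    let b2 = b.clone_arc(); // hand a fresh handle to another connection/thread
    std::thread::spawn(move || println!("{}", b2.name())).join().unwrap();
    println!("{}", b.name());
}
```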


@@ -1,657 +0,0 @@
use crate::error::DBError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use tantivy::{
collector::TopDocs,
directory::MmapDirectory,
query::{BooleanQuery, Occur, Query, QueryParser, TermQuery},
schema::{Field, Schema, TextFieldIndexing, TextOptions, Value, STORED, STRING},
tokenizer::TokenizerManager,
DateTime, Index, IndexReader, IndexWriter, ReloadPolicy, TantivyDocument, Term,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FieldDef {
Text {
stored: bool,
indexed: bool,
tokenized: bool,
fast: bool,
},
Numeric {
stored: bool,
indexed: bool,
fast: bool,
precision: NumericType,
},
Tag {
stored: bool,
separator: String,
case_sensitive: bool,
},
Geo {
stored: bool,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NumericType {
I64,
U64,
F64,
Date,
}
pub struct IndexSchema {
schema: Schema,
fields: HashMap<String, (Field, FieldDef)>,
default_search_fields: Vec<Field>,
}
pub struct TantivySearch {
index: Index,
writer: Arc<RwLock<IndexWriter>>,
reader: IndexReader,
index_schema: IndexSchema,
name: String,
config: IndexConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
pub language: String,
pub stopwords: Vec<String>,
pub stemming: bool,
pub max_doc_count: Option<usize>,
pub default_score: f64,
}
impl Default for IndexConfig {
fn default() -> Self {
IndexConfig {
language: "english".to_string(),
stopwords: vec![],
stemming: true,
max_doc_count: None,
default_score: 1.0,
}
}
}
impl TantivySearch {
pub fn new_with_schema(
base_path: PathBuf,
name: String,
field_definitions: Vec<(String, FieldDef)>,
config: Option<IndexConfig>,
) -> Result<Self, DBError> {
let index_path = base_path.join(&name);
std::fs::create_dir_all(&index_path)
.map_err(|e| DBError(format!("Failed to create index dir: {}", e)))?;
// Build schema from field definitions
let mut schema_builder = Schema::builder();
let mut fields = HashMap::new();
let mut default_search_fields = Vec::new();
// Always add a document ID field
let id_field = schema_builder.add_text_field("_id", STRING | STORED);
fields.insert(
"_id".to_string(),
(
id_field,
FieldDef::Text {
stored: true,
indexed: true,
tokenized: false,
fast: false,
},
),
);
// Add user-defined fields
for (field_name, field_def) in field_definitions {
let field = match &field_def {
FieldDef::Text {
stored,
indexed,
tokenized,
fast: _fast,
} => {
let mut text_options = TextOptions::default();
if *stored {
text_options = text_options.set_stored();
}
if *indexed {
let indexing_options = if *tokenized {
TextFieldIndexing::default()
.set_tokenizer("default")
.set_index_option(
tantivy::schema::IndexRecordOption::WithFreqsAndPositions,
)
} else {
TextFieldIndexing::default()
.set_tokenizer("raw")
.set_index_option(tantivy::schema::IndexRecordOption::Basic)
};
text_options = text_options.set_indexing_options(indexing_options);
let f = schema_builder.add_text_field(&field_name, text_options);
if *tokenized {
default_search_fields.push(f);
}
f
} else {
schema_builder.add_text_field(&field_name, text_options)
}
}
FieldDef::Numeric {
stored,
indexed,
fast,
precision,
} => match precision {
NumericType::I64 => {
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_i64_field(&field_name, opts)
}
NumericType::U64 => {
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_u64_field(&field_name, opts)
}
NumericType::F64 => {
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_f64_field(&field_name, opts)
}
NumericType::Date => {
let mut opts = tantivy::schema::DateOptions::default();
if *stored {
opts = opts.set_stored();
}
if *indexed {
opts = opts.set_indexed();
}
if *fast {
opts = opts.set_fast();
}
schema_builder.add_date_field(&field_name, opts)
}
},
FieldDef::Tag {
stored,
separator: _,
case_sensitive: _,
} => {
let mut text_options = TextOptions::default();
if *stored {
text_options = text_options.set_stored();
}
text_options = text_options.set_indexing_options(
TextFieldIndexing::default()
.set_tokenizer("raw")
.set_index_option(tantivy::schema::IndexRecordOption::Basic),
);
schema_builder.add_text_field(&field_name, text_options)
}
FieldDef::Geo { stored } => {
// For now, store as two f64 fields for lat/lon
let mut opts = tantivy::schema::NumericOptions::default();
if *stored {
opts = opts.set_stored();
}
opts = opts.set_indexed().set_fast();
let lat_field =
schema_builder.add_f64_field(&format!("{}_lat", field_name), opts.clone());
let lon_field =
schema_builder.add_f64_field(&format!("{}_lon", field_name), opts);
fields.insert(
format!("{}_lat", field_name),
(
lat_field,
FieldDef::Numeric {
stored: *stored,
indexed: true,
fast: true,
precision: NumericType::F64,
},
),
);
fields.insert(
format!("{}_lon", field_name),
(
lon_field,
FieldDef::Numeric {
stored: *stored,
indexed: true,
fast: true,
precision: NumericType::F64,
},
),
);
continue; // Skip adding the geo field itself
}
};
fields.insert(field_name.clone(), (field, field_def));
}
let schema = schema_builder.build();
let index_schema = IndexSchema {
schema: schema.clone(),
fields,
default_search_fields,
};
// Create or open index
let dir = MmapDirectory::open(&index_path)
.map_err(|e| DBError(format!("Failed to open index directory: {}", e)))?;
let mut index = Index::open_or_create(dir, schema)
.map_err(|e| DBError(format!("Failed to create index: {}", e)))?;
// Configure tokenizers
let tokenizer_manager = TokenizerManager::default();
index.set_tokenizers(tokenizer_manager);
let writer = index
.writer(15_000_000)
.map_err(|e| DBError(format!("Failed to create index writer: {}", e)))?;
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommitWithDelay)
.try_into()
.map_err(|e| DBError(format!("Failed to create reader: {}", e)))?;
let config = config.unwrap_or_default();
Ok(TantivySearch {
index,
writer: Arc::new(RwLock::new(writer)),
reader,
index_schema,
name,
config,
})
}
pub fn add_document_with_fields(
&self,
doc_id: &str,
fields: HashMap<String, String>,
) -> Result<(), DBError> {
let mut writer = self
.writer
.write()
.map_err(|e| DBError(format!("Failed to acquire writer lock: {}", e)))?;
// Delete existing document with same ID
if let Some((id_field, _)) = self.index_schema.fields.get("_id") {
writer.delete_term(Term::from_field_text(*id_field, doc_id));
}
// Create new document
let mut doc = tantivy::doc!();
// Add document ID
if let Some((id_field, _)) = self.index_schema.fields.get("_id") {
doc.add_text(*id_field, doc_id);
}
// Add other fields based on schema
for (field_name, field_value) in fields {
if let Some((field, field_def)) = self.index_schema.fields.get(&field_name) {
match field_def {
FieldDef::Text { .. } => {
doc.add_text(*field, &field_value);
}
FieldDef::Numeric { precision, .. } => match precision {
NumericType::I64 => {
if let Ok(v) = field_value.parse::<i64>() {
doc.add_i64(*field, v);
}
}
NumericType::U64 => {
if let Ok(v) = field_value.parse::<u64>() {
doc.add_u64(*field, v);
}
}
NumericType::F64 => {
if let Ok(v) = field_value.parse::<f64>() {
doc.add_f64(*field, v);
}
}
NumericType::Date => {
if let Ok(v) = field_value.parse::<i64>() {
doc.add_date(*field, DateTime::from_timestamp_millis(v));
}
}
},
FieldDef::Tag {
separator,
case_sensitive,
..
} => {
let tags = if !case_sensitive {
field_value.to_lowercase()
} else {
field_value.clone()
};
// Store tags as separate terms for efficient filtering
for tag in tags.split(separator.as_str()) {
doc.add_text(*field, tag.trim());
}
}
FieldDef::Geo { .. } => {
// Parse "lat,lon" format
let parts: Vec<&str> = field_value.split(',').collect();
if parts.len() == 2 {
if let (Ok(lat), Ok(lon)) =
(parts[0].parse::<f64>(), parts[1].parse::<f64>())
{
if let Some((lat_field, _)) =
self.index_schema.fields.get(&format!("{}_lat", field_name))
{
doc.add_f64(*lat_field, lat);
}
if let Some((lon_field, _)) =
self.index_schema.fields.get(&format!("{}_lon", field_name))
{
doc.add_f64(*lon_field, lon);
}
}
}
}
}
}
}
writer
.add_document(doc)
.map_err(|e| DBError(format!("Failed to add document: {}", e)))?;
writer
.commit()
.map_err(|e| DBError(format!("Failed to commit: {}", e)))?;
Ok(())
}
pub fn search_with_options(
&self,
query_str: &str,
options: SearchOptions,
) -> Result<SearchResults, DBError> {
let searcher = self.reader.searcher();
// Parse query based on search fields
let query: Box<dyn Query> = if self.index_schema.default_search_fields.is_empty() {
return Err(DBError(
"No searchable fields defined in schema".to_string(),
));
} else {
let query_parser = QueryParser::for_index(
&self.index,
self.index_schema.default_search_fields.clone(),
);
Box::new(
query_parser
.parse_query(query_str)
.map_err(|e| DBError(format!("Failed to parse query: {}", e)))?,
)
};
// Apply filters if any
let final_query = if !options.filters.is_empty() {
let mut clauses: Vec<(Occur, Box<dyn Query>)> = vec![(Occur::Must, query)];
// Add filters
for filter in options.filters {
if let Some((field, _)) = self.index_schema.fields.get(&filter.field) {
match filter.filter_type {
FilterType::Equals(value) => {
let term_query = TermQuery::new(
Term::from_field_text(*field, &value),
tantivy::schema::IndexRecordOption::Basic,
);
clauses.push((Occur::Must, Box::new(term_query)));
}
FilterType::Range { min: _, max: _ } => {
// Would need numeric field handling here
// Simplified for now
}
FilterType::InSet(values) => {
let mut sub_clauses: Vec<(Occur, Box<dyn Query>)> = vec![];
for value in values {
let term_query = TermQuery::new(
Term::from_field_text(*field, &value),
tantivy::schema::IndexRecordOption::Basic,
);
sub_clauses.push((Occur::Should, Box::new(term_query)));
}
clauses.push((Occur::Must, Box::new(BooleanQuery::new(sub_clauses))));
}
}
}
}
Box::new(BooleanQuery::new(clauses))
} else {
query
};
// Execute search
let top_docs = searcher
.search(
&*final_query,
&TopDocs::with_limit(options.limit + options.offset),
)
.map_err(|e| DBError(format!("Search failed: {}", e)))?;
let total_hits = top_docs.len();
let mut documents = Vec::new();
for (score, doc_address) in top_docs.iter().skip(options.offset).take(options.limit) {
let retrieved_doc: TantivyDocument = searcher
.doc(*doc_address)
.map_err(|e| DBError(format!("Failed to retrieve doc: {}", e)))?;
let mut doc_fields = HashMap::new();
// Extract all stored fields
for (field_name, (field, field_def)) in &self.index_schema.fields {
match field_def {
FieldDef::Text { stored, .. } | FieldDef::Tag { stored, .. } => {
if *stored {
if let Some(value) = retrieved_doc.get_first(*field) {
if let Some(text) = value.as_str() {
doc_fields.insert(field_name.clone(), text.to_string());
}
}
}
}
FieldDef::Numeric {
stored, precision, ..
} => {
if *stored {
let value_str = match precision {
NumericType::I64 => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_i64())
.map(|v| v.to_string()),
NumericType::U64 => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_u64())
.map(|v| v.to_string()),
NumericType::F64 => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_f64())
.map(|v| v.to_string()),
NumericType::Date => retrieved_doc
.get_first(*field)
.and_then(|v| v.as_datetime())
.map(|v| v.into_timestamp_millis().to_string()),
};
if let Some(v) = value_str {
doc_fields.insert(field_name.clone(), v);
}
}
}
FieldDef::Geo { stored } => {
if *stored {
let lat_field = self
.index_schema
.fields
.get(&format!("{}_lat", field_name))
.unwrap()
.0;
let lon_field = self
.index_schema
.fields
.get(&format!("{}_lon", field_name))
.unwrap()
.0;
let lat = retrieved_doc.get_first(lat_field).and_then(|v| v.as_f64());
let lon = retrieved_doc.get_first(lon_field).and_then(|v| v.as_f64());
if let (Some(lat), Some(lon)) = (lat, lon) {
doc_fields.insert(field_name.clone(), format!("{},{}", lat, lon));
}
}
}
}
}
documents.push(SearchDocument {
fields: doc_fields,
score: *score,
});
}
Ok(SearchResults {
total: total_hits,
documents,
})
}
pub fn get_info(&self) -> Result<IndexInfo, DBError> {
let searcher = self.reader.searcher();
let num_docs = searcher.num_docs();
let fields_info: Vec<FieldInfo> = self
.index_schema
.fields
.iter()
.map(|(name, (_, def))| FieldInfo {
name: name.clone(),
field_type: format!("{:?}", def),
})
.collect();
Ok(IndexInfo {
name: self.name.clone(),
num_docs,
fields: fields_info,
config: self.config.clone(),
})
}
}
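For reference, a sketch of how this (now deleted) module was driven, using only the signatures defined above; assume it lives alongside them, and note that the paths and field names are invented for illustration:

```rust
use std::collections::HashMap;
use std::path::PathBuf;

// Create an index, add a document, and run a search; error handling kept minimal.
fn demo() -> Result<(), DBError> {
    let fields = vec![
        ("title".to_string(), FieldDef::Text { stored: true, indexed: true, tokenized: true, fast: false }),
        ("tags".to_string(), FieldDef::Tag { stored: true, separator: ",".to_string(), case_sensitive: false }),
    ];
    let search = TantivySearch::new_with_schema(PathBuf::from("/tmp/idx"), "demo".to_string(), fields, None)?;

    let mut doc = HashMap::new();
    doc.insert("title".to_string(), "hello world".to_string());
    doc.insert("tags".to_string(), "greeting,demo".to_string());
    search.add_document_with_fields("doc1", doc)?;

    let results = search.search_with_options("hello", SearchOptions::default())?;
    println!("total hits: {}", results.total);
    Ok(())
}
```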
#[derive(Debug, Clone)]
pub struct SearchOptions {
pub limit: usize,
pub offset: usize,
pub filters: Vec<Filter>,
pub sort_by: Option<String>,
pub return_fields: Option<Vec<String>>,
pub highlight: bool,
}
impl Default for SearchOptions {
fn default() -> Self {
SearchOptions {
limit: 10,
offset: 0,
filters: vec![],
sort_by: None,
return_fields: None,
highlight: false,
}
}
}
#[derive(Debug, Clone)]
pub struct Filter {
pub field: String,
pub filter_type: FilterType,
}
#[derive(Debug, Clone)]
pub enum FilterType {
Equals(String),
Range { min: String, max: String },
InSet(Vec<String>),
}
#[derive(Debug)]
pub struct SearchResults {
pub total: usize,
pub documents: Vec<SearchDocument>,
}
#[derive(Debug)]
pub struct SearchDocument {
pub fields: HashMap<String, String>,
pub score: f32,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct IndexInfo {
pub name: String,
pub num_docs: u64,
pub fields: Vec<FieldInfo>,
pub config: IndexConfig,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FieldInfo {
pub name: String,
pub field_type: String,
}


@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -7,7 +7,7 @@ use tokio::time::sleep;
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
    stream.write_all(command.as_bytes()).await.unwrap();
    let mut buffer = [0; 1024];
    let n = stream.read(&mut buffer).await.unwrap();
    String::from_utf8_lossy(&buffer[..n]).to_string()
}
@@ -19,7 +19,7 @@ async fn debug_hset_simple() {
    let test_dir = "/tmp/herodb_debug_hset";
    let _ = std::fs::remove_dir_all(test_dir);
    std::fs::create_dir_all(test_dir).unwrap();
    let port = 16500;
    let option = DBOption {
        dir: test_dir.to_string(),
@@ -29,49 +29,35 @@ async fn debug_hset_simple() {
        encryption_key: None,
        backend: herodb::options::BackendType::Redb,
    };
    let mut server = Server::new(option).await;
    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(200)).await;
    let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
    // Test simple HSET
    println!("Testing HSET...");
    let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
    println!("HSET response: {}", response);
    assert!(response.contains("1"), "Expected '1' but got: {}", response);
    // Test HGET
    println!("Testing HGET...");
    let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
    println!("HGET response: {}", response);
    assert!(response.contains("value1"), "Expected 'value1' but got: {}", response);
}


@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -7,11 +7,11 @@ use tokio::time::sleep;
#[tokio::test]
async fn debug_hset_return_value() {
    let test_dir = "/tmp/herodb_debug_hset_return";
    // Clean up any existing test data
    let _ = std::fs::remove_dir_all(&test_dir);
    std::fs::create_dir_all(&test_dir).unwrap();
    let option = DBOption {
        dir: test_dir.to_string(),
        port: 16390,
@@ -20,42 +20,38 @@ async fn debug_hset_return_value() {
        encryption_key: None,
        backend: herodb::options::BackendType::Redb,
    };
    let mut server = Server::new(option).await;
    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:16390")
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(200)).await;
    // Connect and test HSET
    let mut stream = TcpStream::connect("127.0.0.1:16390").await.unwrap();
    // Send HSET command
    let cmd = "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n";
    stream.write_all(cmd.as_bytes()).await.unwrap();
    let mut buffer = [0; 1024];
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    println!("HSET response: {}", response);
    println!("Response bytes: {:?}", &buffer[..n]);
    // Check if response contains "1"
    assert!(response.contains("1"), "Expected response to contain '1', got: {}", response);
}


@@ -1,15 +1,12 @@
use herodb::protocol::Protocol;
use herodb::cmd::Cmd;
#[test]
fn test_protocol_parsing() {
    // Test TYPE command parsing
    let type_cmd = "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n";
    println!("Parsing TYPE command: {}", type_cmd.replace("\r\n", "\\r\\n"));
    match Protocol::from(type_cmd) {
        Ok((protocol, _)) => {
            println!("Protocol parsed successfully: {:?}", protocol);
@@ -20,14 +17,11 @@ fn test_protocol_parsing() {
        }
        Err(e) => println!("Protocol parsing failed: {:?}", e),
    }
    // Test HEXISTS command parsing
    let hexists_cmd = "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n";
    println!("\nParsing HEXISTS command: {}", hexists_cmd.replace("\r\n", "\\r\\n"));
    match Protocol::from(hexists_cmd) {
        Ok((protocol, _)) => {
            println!("Protocol parsed successfully: {:?}", protocol);
@@ -38,4 +32,4 @@ fn test_protocol_parsing() {
        }
        Err(e) => println!("Protocol parsing failed: {:?}", e),
    }
}


@@ -81,13 +81,13 @@ fn setup_server() -> (ServerProcessGuard, u16) {
    ])
    .spawn()
    .expect("Failed to start server process");
    // Create a new guard that also owns the test directory path
    let guard = ServerProcessGuard {
        process: child,
        test_dir,
    };
    // Give the server time to build and start (cargo run may compile first)
    std::thread::sleep(Duration::from_millis(2500));
@@ -206,9 +206,7 @@ async fn test_expiration(conn: &mut Connection) {
async fn test_scan_operations(conn: &mut Connection) {
    cleanup_keys(conn).await;
    for i in 0..5 {
        let _: () = conn.set(format!("key{}", i), format!("value{}", i)).unwrap();
    }
    let result: (u64, Vec<String>) = redis::cmd("SCAN")
        .arg(0)
@@ -255,9 +253,7 @@ async fn test_scan_with_count(conn: &mut Connection) {
async fn test_hscan_operations(conn: &mut Connection) {
    cleanup_keys(conn).await;
    for i in 0..3 {
        let _: () = conn.hset("testhash", format!("field{}", i), format!("value{}", i)).unwrap();
    }
    let result: (u64, Vec<String>) = redis::cmd("HSCAN")
        .arg("testhash")
@@ -277,16 +273,8 @@ async fn test_hscan_operations(conn: &mut Connection) {
async fn test_transaction_operations(conn: &mut Connection) {
    cleanup_keys(conn).await;
    let _: () = redis::cmd("MULTI").query(conn).unwrap();
    let _: () = redis::cmd("SET").arg("key1").arg("value1").query(conn).unwrap();
    let _: () = redis::cmd("SET").arg("key2").arg("value2").query(conn).unwrap();
    let _: Vec<String> = redis::cmd("EXEC").query(conn).unwrap();
    let result: String = conn.get("key1").unwrap();
    assert_eq!(result, "value1");
@@ -298,11 +286,7 @@ async fn test_transaction_operations(conn: &mut Connection) {
async fn test_discard_transaction(conn: &mut Connection) {
    cleanup_keys(conn).await;
    let _: () = redis::cmd("MULTI").query(conn).unwrap();
    let _: () = redis::cmd("SET").arg("discard").arg("value").query(conn).unwrap();
    let _: () = redis::cmd("DISCARD").query(conn).unwrap();
    let result: Option<String> = conn.get("discard").unwrap();
    assert_eq!(result, None);
@@ -322,6 +306,7 @@ async fn test_type_command(conn: &mut Connection) {
    cleanup_keys(conn).await;
}
async fn test_info_command(conn: &mut Connection) {
    cleanup_keys(conn).await;
    let result: String = redis::cmd("INFO").query(conn).unwrap();
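The transaction tests above rely on MULTI queueing commands until EXEC applies them atomically, while DISCARD drops the queue. On the wire, the exchange test_transaction_operations performs looks roughly like this (standard Redis reply shapes; the exact bytes this server emits are an assumption, not taken from the tests):

```rust
// Raw RESP view of the MULTI/EXEC flow: each queued SET answers +QUEUED,
// and EXEC replies with an array holding one result per queued command.
fn transaction_frames() -> Vec<(&'static str, &'static str)> {
    vec![
        ("*1\r\n$5\r\nMULTI\r\n",                             "+OK\r\n"),
        ("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n", "+QUEUED\r\n"),
        ("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n", "+QUEUED\r\n"),
        ("*1\r\n$4\r\nEXEC\r\n",                              "*2\r\n+OK\r\n+OK\r\n"),
    ]
}

fn main() {
    for (req, expected) in transaction_frames() {
        println!("-> {:?}\n<- {:?}", req, expected);
    }
}
```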


@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -8,14 +8,14 @@ use tokio::time::sleep;
async fn start_test_server(test_name: &str) -> (Server, u16) {
    use std::sync::atomic::{AtomicU16, Ordering};
    static PORT_COUNTER: AtomicU16 = AtomicU16::new(16379);
    let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
    let test_dir = format!("/tmp/herodb_test_{}", test_name);
    // Clean up and create test directory
    let _ = std::fs::remove_dir_all(&test_dir);
    std::fs::create_dir_all(&test_dir).unwrap();
    let option = DBOption {
        dir: test_dir,
        port,
@@ -24,7 +24,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
        encryption_key: None,
        backend: herodb::options::BackendType::Redb,
    };
    let server = Server::new(option).await;
    (server, port)
}
@@ -47,7 +47,7 @@ async fn connect_to_server(port: u16) -> TcpStream {
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
    stream.write_all(command.as_bytes()).await.unwrap();
    let mut buffer = [0; 1024];
    let n = stream.read(&mut buffer).await.unwrap();
    String::from_utf8_lossy(&buffer[..n]).to_string()
}
@@ -56,22 +56,22 @@ async fn send_command(stream: &mut TcpStream, command: &str) -> String {
#[tokio::test]
async fn test_basic_ping() {
    let (mut server, port) = start_test_server("ping").await;
    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;
    let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
    assert!(response.contains("PONG"));
@@ -80,44 +80,40 @@ async fn test_basic_ping() {
#[tokio::test]
async fn test_string_operations() {
    let (mut server, port) = start_test_server("string").await;
    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;
    // Test SET
    let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
    assert!(response.contains("OK"));
    // Test GET
    let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
    assert!(response.contains("value"));
    // Test GET non-existent key
    let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\nnoexist\r\n").await;
    assert!(response.contains("$-1")); // NULL response
    // Test DEL
    let response = send_command(&mut stream, "*2\r\n$3\r\nDEL\r\n$3\r\nkey\r\n").await;
    assert!(response.contains("1"));
    // Test GET after DEL
    let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
    assert!(response.contains("$-1")); // NULL response
@@ -126,37 +122,33 @@ async fn test_string_operations() {
#[tokio::test]
async fn test_incr_operations() {
    let (mut server, port) = start_test_server("incr").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;
    // Test INCR on non-existent key
    let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
    assert!(response.contains("1"));
    // Test INCR on existing key
    let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$7\r\ncounter\r\n").await;
    assert!(response.contains("2"));
    // Test INCR on string value (should fail)
    send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nhello\r\n").await;
    let response = send_command(&mut stream, "*2\r\n$4\r\nINCR\r\n$6\r\nstring\r\n").await;
    assert!(response.contains("ERR"));
}
@@ -164,83 +156,63 @@ async fn test_incr_operations() {
#[tokio::test]
async fn test_hash_operations() {
    let (mut server, port) = start_test_server("hash").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;
    // Test HSET
    let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
    assert!(response.contains("1")); // 1 new field
    // Test HGET
    let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
    assert!(response.contains("value1"));
    // Test HSET multiple fields
    let response = send_command(&mut stream, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n$6\r\nfield3\r\n$6\r\nvalue3\r\n").await;
    assert!(response.contains("2")); // 2 new fields
    // Test HGETALL
    let response = send_command(&mut stream, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
    assert!(response.contains("field1"));
    assert!(response.contains("value1"));
    assert!(response.contains("field2"));
    assert!(response.contains("value2"));
    // Test HLEN
    let response = send_command(&mut stream, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
    assert!(response.contains("3"));
    // Test HEXISTS
    let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
    assert!(response.contains("1"));
    let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
    assert!(response.contains("0"));
    // Test HDEL
    let response = send_command(&mut stream, "*3\r\n$4\r\nHDEL\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
    assert!(response.contains("1"));
    // Test HKEYS
    let response = send_command(&mut stream, "*2\r\n$5\r\nHKEYS\r\n$4\r\nhash\r\n").await;
    assert!(response.contains("field2"));
    assert!(response.contains("field3"));
    assert!(!response.contains("field1")); // Should be deleted
    // Test HVALS
    let response = send_command(&mut stream, "*2\r\n$5\r\nHVALS\r\n$4\r\nhash\r\n").await;
    assert!(response.contains("value2"));
@@ -250,50 +222,46 @@ async fn test_hash_operations() {
#[tokio::test]
async fn test_expiration() {
    let (mut server, port) = start_test_server("expiration").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;
    // Test SETEX (expire in 1 second)
    let response = send_command(&mut stream, "*5\r\n$3\r\nSET\r\n$6\r\nexpkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$1\r\n1\r\n").await;
    assert!(response.contains("OK"));
    // Test TTL
    let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
    assert!(response.contains("1") || response.contains("0")); // Should be 1 or 0 seconds
    // Test EXISTS
    let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
    assert!(response.contains("1"));
    // Wait for expiration
    sleep(Duration::from_millis(1100)).await;
    // Test GET after expiration
    let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$6\r\nexpkey\r\n").await;
    assert!(response.contains("$-1")); // Should be NULL
    // Test TTL after expiration
    let response = send_command(&mut stream, "*2\r\n$3\r\nTTL\r\n$6\r\nexpkey\r\n").await;
    assert!(response.contains("-2")); // Key doesn't exist
    // Test EXISTS after expiration
    let response = send_command(&mut stream, "*2\r\n$6\r\nEXISTS\r\n$6\r\nexpkey\r\n").await;
    assert!(response.contains("0"));
@@ -302,37 +270,33 @@ async fn test_expiration() {
#[tokio::test]
async fn test_scan_operations() {
    let (mut server, port) = start_test_server("scan").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Set up test data
    for i in 0..5 {
        let cmd = format!("*3\r\n$3\r\nSET\r\n$4\r\nkey{}\r\n$6\r\nvalue{}\r\n", i, i);
        send_command(&mut stream, &cmd).await;
    }

    // Test SCAN
    let response = send_command(&mut stream, "*6\r\n$4\r\nSCAN\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
    assert!(response.contains("key"));

    // Test KEYS
    let response = send_command(&mut stream, "*2\r\n$4\r\nKEYS\r\n$1\r\n*\r\n").await;
    assert!(response.contains("key0"));
@@ -342,32 +306,29 @@ async fn test_scan_operations() {
#[tokio::test]
async fn test_hscan_operations() {
    let (mut server, port) = start_test_server("hscan").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Set up hash data
    for i in 0..3 {
        let cmd = format!("*4\r\n$4\r\nHSET\r\n$8\r\ntesthash\r\n$6\r\nfield{}\r\n$6\r\nvalue{}\r\n", i, i);
        send_command(&mut stream, &cmd).await;
    }

    // Test HSCAN
    let response = send_command(&mut stream, "*7\r\n$5\r\nHSCAN\r\n$8\r\ntesthash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
    assert!(response.contains("field"));
@@ -377,50 +338,42 @@ async fn test_hscan_operations() {
#[tokio::test]
async fn test_transaction_operations() {
    let (mut server, port) = start_test_server("transaction").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Test MULTI
    let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
    assert!(response.contains("OK"));

    // Test queued commands
    let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n").await;
    assert!(response.contains("QUEUED"));
    let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n").await;
    assert!(response.contains("QUEUED"));

    // Test EXEC
    let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
    assert!(response.contains("OK")); // Should contain results of executed commands

    // Verify commands were executed
    let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n").await;
    assert!(response.contains("value1"));
    let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n").await;
    assert!(response.contains("value2"));
}
@@ -428,39 +381,35 @@ async fn test_transaction_operations() {
#[tokio::test]
async fn test_discard_transaction() {
    let (mut server, port) = start_test_server("discard").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Test MULTI
    let response = send_command(&mut stream, "*1\r\n$5\r\nMULTI\r\n").await;
    assert!(response.contains("OK"));

    // Test queued command
    let response = send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$7\r\ndiscard\r\n$5\r\nvalue\r\n").await;
    assert!(response.contains("QUEUED"));

    // Test DISCARD
    let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
    assert!(response.contains("OK"));

    // Verify command was not executed
    let response = send_command(&mut stream, "*2\r\n$3\r\nGET\r\n$7\r\ndiscard\r\n").await;
    assert!(response.contains("$-1")); // Should be NULL
@@ -469,41 +418,33 @@ async fn test_discard_transaction() {
#[tokio::test]
async fn test_type_command() {
    let (mut server, port) = start_test_server("type").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Test string type
    send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
    let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
    assert!(response.contains("string"));

    // Test hash type
    send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
    let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
    assert!(response.contains("hash"));

    // Test non-existent key
    let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
    assert!(response.contains("none"));
@@ -512,38 +453,30 @@ async fn test_type_command() {
#[tokio::test]
async fn test_config_commands() {
    let (mut server, port) = start_test_server("config").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Test CONFIG GET databases
    let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\ndatabases\r\n").await;
    assert!(response.contains("databases"));
    assert!(response.contains("16"));

    // Test CONFIG GET dir
    let response = send_command(&mut stream, "*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$3\r\ndir\r\n").await;
    assert!(response.contains("dir"));
    assert!(response.contains("/tmp/herodb_test_config"));
}
@@ -551,27 +484,27 @@ async fn test_config_commands() {
#[tokio::test]
async fn test_info_command() {
    let (mut server, port) = start_test_server("info").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Test INFO
    let response = send_command(&mut stream, "*1\r\n$4\r\nINFO\r\n").await;
    assert!(response.contains("redis_version"));

    // Test INFO replication
    let response = send_command(&mut stream, "*2\r\n$4\r\nINFO\r\n$11\r\nreplication\r\n").await;
    assert!(response.contains("role:master"));
@@ -580,44 +513,36 @@ async fn test_info_command() {
#[tokio::test]
async fn test_error_handling() {
    let (mut server, port) = start_test_server("error").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Test WRONGTYPE error - try to use hash command on string
    send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
    let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$6\r\nstring\r\n$5\r\nfield\r\n").await;
    assert!(response.contains("WRONGTYPE"));

    // Test unknown command
    let response = send_command(&mut stream, "*1\r\n$7\r\nUNKNOWN\r\n").await;
    assert!(response.contains("unknown cmd") || response.contains("ERR"));

    // Test EXEC without MULTI
    let response = send_command(&mut stream, "*1\r\n$4\r\nEXEC\r\n").await;
    assert!(response.contains("ERR"));

    // Test DISCARD without MULTI
    let response = send_command(&mut stream, "*1\r\n$7\r\nDISCARD\r\n").await;
    assert!(response.contains("ERR"));
@@ -626,37 +551,29 @@ async fn test_error_handling() {
#[tokio::test]
async fn test_list_operations() {
    let (mut server, port) = start_test_server("list").await;
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(100)).await;
    let mut stream = connect_to_server(port).await;

    // Test LPUSH
    let response = send_command(&mut stream, "*4\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n$1\r\nb\r\n").await;
    assert!(response.contains("2")); // 2 elements

    // Test RPUSH
    let response = send_command(&mut stream, "*4\r\n$5\r\nRPUSH\r\n$4\r\nlist\r\n$1\r\nc\r\n$1\r\nd\r\n").await;
    assert!(response.contains("4")); // 4 elements

    // Test LLEN
@@ -664,52 +581,29 @@ async fn test_list_operations() {
assert!(response.contains("4")); assert!(response.contains("4"));
// Test LRANGE // Test LRANGE
let response = send_command( let response = send_command(&mut stream, "*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n").await;
&mut stream, assert_eq!(response, "*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n");
"*4\r\n$6\r\nLRANGE\r\n$4\r\nlist\r\n$1\r\n0\r\n$2\r\n-1\r\n",
)
.await;
assert_eq!(
response,
"*4\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n$1\r\nd\r\n"
);
// Test LINDEX // Test LINDEX
let response = send_command( let response = send_command(&mut stream, "*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n").await;
&mut stream,
"*3\r\n$6\r\nLINDEX\r\n$4\r\nlist\r\n$1\r\n0\r\n",
)
.await;
assert_eq!(response, "$1\r\nb\r\n"); assert_eq!(response, "$1\r\nb\r\n");
// Test LPOP // Test LPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nLPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nb\r\n"); assert_eq!(response, "$1\r\nb\r\n");
// Test RPOP // Test RPOP
let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nRPOP\r\n$4\r\nlist\r\n").await;
assert_eq!(response, "$1\r\nd\r\n"); assert_eq!(response, "$1\r\nd\r\n");
// Test LREM // Test LREM
send_command( send_command(&mut stream, "*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n").await; // list is now a, c, a
&mut stream, let response = send_command(&mut stream, "*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n").await;
"*3\r\n$5\r\nLPUSH\r\n$4\r\nlist\r\n$1\r\na\r\n",
)
.await; // list is now a, c, a
let response = send_command(
&mut stream,
"*4\r\n$4\r\nLREM\r\n$4\r\nlist\r\n$1\r\n1\r\n$1\r\na\r\n",
)
.await;
assert!(response.contains("1")); assert!(response.contains("1"));
// Test LTRIM // Test LTRIM
let response = send_command( let response = send_command(&mut stream, "*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n").await;
&mut stream,
"*4\r\n$5\r\nLTRIM\r\n$4\r\nlist\r\n$1\r\n0\r\n$1\r\n0\r\n",
)
.await;
assert!(response.contains("OK")); assert!(response.contains("OK"));
let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await; let response = send_command(&mut stream, "*2\r\n$4\r\nLLEN\r\n$4\r\nlist\r\n").await;
assert!(response.contains("1")); assert!(response.contains("1"));
} }
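An aside on the raw command strings used throughout these tests: they are RESP (Redis Serialization Protocol) frames, where `*N` announces an array of N bulk strings and each bulk string is sent as `$<byte-length>` followed by the bytes and a CRLF. As a minimal sketch (the helper name `encode_resp` is hypothetical and not part of this repository), the frames could be built programmatically instead of hand-written:

```rust
// Hypothetical helper (illustration only): encode a command as a RESP array
// of bulk strings, e.g. ["SET", "key", "value"] becomes
// "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n".
fn encode_resp(args: &[&str]) -> String {
    let mut out = format!("*{}\r\n", args.len());
    for arg in args {
        // $<byte length>\r\n<bytes>\r\n for each argument
        out.push_str(&format!("${}\r\n{}\r\n", arg.len(), arg));
    }
    out
}
```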

View File

@@ -1,23 +1,23 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;

// Helper function to start a test server with clean data directory
async fn start_test_server(test_name: &str) -> (Server, u16) {
    use std::sync::atomic::{AtomicU16, Ordering};
    static PORT_COUNTER: AtomicU16 = AtomicU16::new(17000);

    // Get a unique port for this test
    let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
    let test_dir = format!("/tmp/herodb_test_{}", test_name);

    // Clean up any existing test data
    let _ = std::fs::remove_dir_all(&test_dir);
    std::fs::create_dir_all(&test_dir).unwrap();

    let option = DBOption {
        dir: test_dir,
        port,
@@ -26,18 +26,16 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
        encryption_key: None,
        backend: herodb::options::BackendType::Redb,
    };
    let server = Server::new(option).await;
    (server, port)
}

// Helper function to send Redis command and get response
async fn send_redis_command(port: u16, command: &str) -> String {
    let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
    stream.write_all(command.as_bytes()).await.unwrap();
    let mut buffer = [0; 1024];
    let n = stream.read(&mut buffer).await.unwrap();
    String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -46,13 +44,13 @@ async fn send_redis_command(port: u16, command: &str) -> String {
#[tokio::test]
async fn test_basic_redis_functionality() {
    let (mut server, port) = start_test_server("basic").await;

    // Start server in background with timeout
    let server_handle = tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        // Accept only a few connections for testing
        for _ in 0..10 {
            if let Ok((stream, _)) = listener.accept().await {
@@ -60,79 +58,68 @@ async fn test_basic_redis_functionality() {
            }
        }
    });
    sleep(Duration::from_millis(100)).await;

    // Test PING
    let response = send_redis_command(port, "*1\r\n$4\r\nPING\r\n").await;
    assert!(response.contains("PONG"));

    // Test SET
    let response = send_redis_command(port, "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await;
    assert!(response.contains("OK"));

    // Test GET
    let response = send_redis_command(port, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n").await;
    assert!(response.contains("value"));

    // Test HSET
    let response = send_redis_command(port, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
    assert!(response.contains("1"));

    // Test HGET
    let response = send_redis_command(port, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$5\r\nfield\r\n").await;
    assert!(response.contains("value"));

    // Test EXISTS
    let response = send_redis_command(port, "*2\r\n$6\r\nEXISTS\r\n$3\r\nkey\r\n").await;
    assert!(response.contains("1"));

    // Test TTL
    let response = send_redis_command(port, "*2\r\n$3\r\nTTL\r\n$3\r\nkey\r\n").await;
    assert!(response.contains("-1")); // No expiration

    // Test TYPE
    let response = send_redis_command(port, "*2\r\n$4\r\nTYPE\r\n$3\r\nkey\r\n").await;
    assert!(response.contains("string"));

    // Test QUIT to close connection gracefully
    let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();
    stream.write_all("*1\r\n$4\r\nQUIT\r\n".as_bytes()).await.unwrap();
    let mut buffer = [0; 1024];
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    assert!(response.contains("OK"));

    // Ensure the stream is closed
    stream.shutdown().await.unwrap();

    // Stop the server
    server_handle.abort();
    println!("✅ All basic Redis functionality tests passed!");
}

#[tokio::test]
async fn test_hash_operations() {
    let (mut server, port) = start_test_server("hash_ops").await;

    // Start server in background with timeout
    let server_handle = tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        // Accept only a few connections for testing
        for _ in 0..5 {
            if let Ok((stream, _)) = listener.accept().await {
@@ -140,57 +127,53 @@ async fn test_hash_operations() {
            }
        }
    });
    sleep(Duration::from_millis(100)).await;

    // Test HSET multiple fields
    let response = send_redis_command(port, "*6\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n$6\r\nfield2\r\n$6\r\nvalue2\r\n").await;
    assert!(response.contains("2")); // 2 new fields

    // Test HGETALL
    let response = send_redis_command(port, "*2\r\n$7\r\nHGETALL\r\n$4\r\nhash\r\n").await;
    assert!(response.contains("field1"));
    assert!(response.contains("value1"));
    assert!(response.contains("field2"));
    assert!(response.contains("value2"));

    // Test HEXISTS
    let response = send_redis_command(port, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
    assert!(response.contains("1"));

    // Test HLEN
    let response = send_redis_command(port, "*2\r\n$4\r\nHLEN\r\n$4\r\nhash\r\n").await;
    assert!(response.contains("2"));

    // Test HSCAN
    let response = send_redis_command(port, "*7\r\n$5\r\nHSCAN\r\n$4\r\nhash\r\n$1\r\n0\r\n$5\r\nMATCH\r\n$1\r\n*\r\n$5\r\nCOUNT\r\n$2\r\n10\r\n").await;
    assert!(response.contains("field1"));
    assert!(response.contains("value1"));
    assert!(response.contains("field2"));
    assert!(response.contains("value2"));

    // Stop the server
    // For hash operations, we don't have a persistent stream, so we'll just abort the server.
    // The server should handle closing its connections.
    server_handle.abort();
    println!("✅ All hash operations tests passed!");
}

#[tokio::test]
async fn test_transaction_operations() {
    let (mut server, port) = start_test_server("transactions").await;

    // Start server in background with timeout
    let server_handle = tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        // Accept only a few connections for testing
        for _ in 0..5 {
            if let Ok((stream, _)) = listener.accept().await {
@@ -198,69 +181,49 @@ async fn test_transaction_operations() {
            }
        }
    });
    sleep(Duration::from_millis(100)).await;

    // Use a single connection for the transaction
    let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).await.unwrap();

    // Test MULTI
    stream.write_all("*1\r\n$5\r\nMULTI\r\n".as_bytes()).await.unwrap();
    let mut buffer = [0; 1024];
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    assert!(response.contains("OK"));

    // Test queued commands
    stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey1\r\n$6\r\nvalue1\r\n".as_bytes()).await.unwrap();
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    assert!(response.contains("QUEUED"));

    stream.write_all("*3\r\n$3\r\nSET\r\n$4\r\nkey2\r\n$6\r\nvalue2\r\n".as_bytes()).await.unwrap();
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    assert!(response.contains("QUEUED"));

    // Test EXEC
    stream.write_all("*1\r\n$4\r\nEXEC\r\n".as_bytes()).await.unwrap();
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    assert!(response.contains("OK")); // Should contain array of OK responses

    // Verify commands were executed
    stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey1\r\n".as_bytes()).await.unwrap();
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    assert!(response.contains("value1"));

    stream.write_all("*2\r\n$3\r\nGET\r\n$4\r\nkey2\r\n".as_bytes()).await.unwrap();
    let n = stream.read(&mut buffer).await.unwrap();
    let response = String::from_utf8_lossy(&buffer[..n]);
    assert!(response.contains("value2"));

    // Stop the server
    server_handle.abort();
    println!("✅ All transaction operations tests passed!");
}
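One caveat worth noting about the helpers above: they issue a single read into a 1024-byte buffer and assume the whole reply arrives in that one chunk, which holds for the small responses in these tests but not in general. A more defensive reader (a sketch under that assumption, not code from this repository) would accumulate until the reply ends in CRLF:

```rust
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;

// Sketch of a more defensive reader than a fixed single 1024-byte read:
// accumulate until the reply ends in CRLF, so a response split across TCP
// segments is not truncated. This is a heuristic; a full client would parse
// the RESP framing to know exactly how many bytes to expect.
async fn read_reply(stream: &mut TcpStream) -> String {
    let mut buf = Vec::new();
    let mut chunk = [0u8; 1024];
    loop {
        let n = stream.read(&mut chunk).await.unwrap();
        buf.extend_from_slice(&chunk[..n]);
        // Stop on EOF or when the buffer ends with a CRLF terminator.
        if n == 0 || buf.ends_with(b"\r\n") {
            break;
        }
    }
    String::from_utf8_lossy(&buf).to_string()
}
```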

View File

@@ -1,4 +1,4 @@
use herodb::{server::Server, options::DBOption};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -8,14 +8,14 @@ use tokio::time::sleep;
async fn start_test_server(test_name: &str) -> (Server, u16) {
    use std::sync::atomic::{AtomicU16, Ordering};
    static PORT_COUNTER: AtomicU16 = AtomicU16::new(16500);
    let port = PORT_COUNTER.fetch_add(1, Ordering::SeqCst);
    let test_dir = format!("/tmp/herodb_simple_test_{}", test_name);

    // Clean up any existing test data
    let _ = std::fs::remove_dir_all(&test_dir);
    std::fs::create_dir_all(&test_dir).unwrap();

    let option = DBOption {
        dir: test_dir,
        port,
@@ -24,7 +24,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
        encryption_key: None,
        backend: herodb::options::BackendType::Redb,
    };
    let server = Server::new(option).await;
    (server, port)
@@ -32,7 +32,7 @@ async fn start_test_server(test_name: &str) -> (Server, u16) {
// Helper function to send command and get response
async fn send_command(stream: &mut TcpStream, command: &str) -> String {
    stream.write_all(command.as_bytes()).await.unwrap();
    let mut buffer = [0; 1024];
    let n = stream.read(&mut buffer).await.unwrap();
    String::from_utf8_lossy(&buffer[..n]).to_string()
@@ -56,22 +56,22 @@ async fn connect_to_server(port: u16) -> TcpStream {
#[tokio::test]
async fn test_basic_ping_simple() {
    let (mut server, port) = start_test_server("ping").await;

    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(200)).await;

    let mut stream = connect_to_server(port).await;
    let response = send_command(&mut stream, "*1\r\n$4\r\nPING\r\n").await;
    assert!(response.contains("PONG"));
@@ -80,43 +80,31 @@ async fn test_basic_ping_simple() {
#[tokio::test]
async fn test_hset_clean_db() {
    let (mut server, port) = start_test_server("hset_clean").await;

    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(200)).await;
    let mut stream = connect_to_server(port).await;

    // Test HSET - should return 1 for new field
    let response = send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;
    println!("HSET response: {}", response);
    assert!(response.contains("1"), "Expected HSET to return 1, got: {}", response);

    // Test HGET
    let response = send_command(&mut stream, "*3\r\n$4\r\nHGET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
    println!("HGET response: {}", response);
    assert!(response.contains("value1"));
}
@@ -124,101 +112,73 @@ async fn test_hset_clean_db() {
#[tokio::test]
async fn test_type_command_simple() {
    let (mut server, port) = start_test_server("type").await;

    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(200)).await;
    let mut stream = connect_to_server(port).await;

    // Test string type
    send_command(&mut stream, "*3\r\n$3\r\nSET\r\n$6\r\nstring\r\n$5\r\nvalue\r\n").await;
    let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$6\r\nstring\r\n").await;
    println!("TYPE string response: {}", response);
    assert!(response.contains("string"));

    // Test hash type
    send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$5\r\nfield\r\n$5\r\nvalue\r\n").await;
    let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$4\r\nhash\r\n").await;
    println!("TYPE hash response: {}", response);
    assert!(response.contains("hash"));

    // Test non-existent key
    let response = send_command(&mut stream, "*2\r\n$4\r\nTYPE\r\n$7\r\nnoexist\r\n").await;
    println!("TYPE noexist response: {}", response);
    assert!(response.contains("none"), "Expected 'none' for non-existent key, got: {}", response);
}

#[tokio::test]
async fn test_hexists_simple() {
    let (mut server, port) = start_test_server("hexists").await;

    // Start server in background
    tokio::spawn(async move {
        let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .unwrap();
        loop {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = server.handle(stream).await;
            }
        }
    });
    sleep(Duration::from_millis(200)).await;
    let mut stream = connect_to_server(port).await;

    // Set up hash
    send_command(&mut stream, "*4\r\n$4\r\nHSET\r\n$4\r\nhash\r\n$6\r\nfield1\r\n$6\r\nvalue1\r\n").await;

    // Test HEXISTS for existing field
    let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$6\r\nfield1\r\n").await;
    println!("HEXISTS existing field response: {}", response);
    assert!(response.contains("1"));

    // Test HEXISTS for non-existent field
    let response = send_command(&mut stream, "*3\r\n$7\r\nHEXISTS\r\n$4\r\nhash\r\n$7\r\nnoexist\r\n").await;
    println!("HEXISTS non-existent field response: {}", response);
    assert!(response.contains("0"), "Expected HEXISTS to return 0 for non-existent field, got: {}", response);
}
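These test files each reserve ports from a static atomic counter so that tests running in parallel do not collide. An alternative sketch (illustrative only, not what these tests do) is to bind port 0 and let the OS choose a free port:

```rust
// Sketch: let the OS assign a free port instead of a static atomic counter.
// Binding "127.0.0.1:0" picks an unused port; local_addr() reports which one.
// Note the small race: the port is released when the listener drops, so the
// server should ideally keep this listener rather than rebind the port later.
async fn free_port() -> u16 {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    listener.local_addr().unwrap().port()
}
```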

View File

@@ -325,11 +325,7 @@ async fn test_03_scan_and_keys() {
    let mut s = connect(port).await;
    for i in 0..5 {
        let _ = send_cmd(&mut s, &["SET", &format!("key{}", i), &format!("value{}", i)]).await;
    }
    let scan = send_cmd(&mut s, &["SCAN", "0", "MATCH", "key*", "COUNT", "10"]).await;
@@ -362,11 +358,7 @@ async fn test_04_hashes_suite() {
assert_contains(&h2, "2", "HSET added 2 new fields"); assert_contains(&h2, "2", "HSET added 2 new fields");
// HMGET // HMGET
let hmg = send_cmd( let hmg = send_cmd(&mut s, &["HMGET", "profile:1", "name", "age", "city", "nope"]).await;
&mut s,
&["HMGET", "profile:1", "name", "age", "city", "nope"],
)
.await;
assert_contains(&hmg, "alice", "HMGET name"); assert_contains(&hmg, "alice", "HMGET name");
assert_contains(&hmg, "30", "HMGET age"); assert_contains(&hmg, "30", "HMGET age");
assert_contains(&hmg, "paris", "HMGET city"); assert_contains(&hmg, "paris", "HMGET city");
@@ -400,11 +392,7 @@ async fn test_04_hashes_suite() {
assert_contains(&hnx1, "1", "HSETNX new field -> 1"); assert_contains(&hnx1, "1", "HSETNX new field -> 1");
// HSCAN // HSCAN
let hscan = send_cmd( let hscan = send_cmd(&mut s, &["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"]).await;
&mut s,
&["HSCAN", "profile:1", "0", "MATCH", "n*", "COUNT", "10"],
)
.await;
assert_contains(&hscan, "name", "HSCAN matches fields starting with n"); assert_contains(&hscan, "name", "HSCAN matches fields starting with n");
assert_contains(&hscan, "nickname", "HSCAN nickname present"); assert_contains(&hscan, "nickname", "HSCAN nickname present");
@@ -436,21 +424,13 @@ async fn test_05_lists_suite_including_blpop() {
    assert_eq_resp(&lidx, "$1\r\nb\r\n", "LINDEX q:jobs 0 should be b");

    let lr = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
    assert_eq_resp(&lr, "*3\r\n$1\r\nb\r\n$1\r\na\r\n$1\r\nc\r\n", "LRANGE q:jobs 0 -1 should be [b,a,c]");

    // LTRIM
    let ltrim = send_cmd(&mut a, &["LTRIM", "q:jobs", "0", "1"]).await;
    assert_contains(&ltrim, "OK", "LTRIM OK");
    let lr_post = send_cmd(&mut a, &["LRANGE", "q:jobs", "0", "-1"]).await;
    assert_eq_resp(&lr_post, "*2\r\n$1\r\nb\r\n$1\r\na\r\n", "After LTRIM, list [b,a]");

    // LREM remove first occurrence of b
    let lrem = send_cmd(&mut a, &["LREM", "q:jobs", "1", "b"]).await;
@@ -464,11 +444,7 @@ async fn test_05_lists_suite_including_blpop() {
    // LPOP with count on empty -> []
    let lpop0 = send_cmd(&mut a, &["LPOP", "q:jobs", "2"]).await;
    assert_eq_resp(&lpop0, "*0\r\n", "LPOP with count on empty returns empty array");

    // BLPOP: block on one client, push from another
    let c1 = connect(port).await;
@@ -537,7 +513,7 @@ async fn test_07_age_stateless_suite() {
    // naive parse for tests
    let mut lines = resp.lines();
    let _ = lines.next(); // *2
    // $len
    let _ = lines.next();
    let recip = lines.next().unwrap_or("").to_string();
    let _ = lines.next();
@@ -572,16 +548,8 @@ async fn test_07_age_stateless_suite() {
    let v_ok = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "msg", &sig_b64]).await;
    assert_contains(&v_ok, "1", "VERIFY should be 1 for valid signature");

    let v_bad = send_cmd(&mut s, &["AGE", "VERIFY", &verify_pub, "tampered", &sig_b64]).await;
    assert_contains(&v_bad, "0", "VERIFY should be 0 for invalid message/signature");
}

#[tokio::test]
@@ -613,7 +581,7 @@ async fn test_08_age_persistent_named_suite() {
        skg
    );
    let sig = send_cmd(&mut s, &["AGE", "SIGNNAME", "app1", "m"]).await;
    let sig_b64 = extract_bulk_payload(&sig).expect("Failed to parse bulk payload from SIGNNAME");
    let v1 = send_cmd(&mut s, &["AGE", "VERIFYNAME", "app1", "m", &sig_b64]).await;
    assert_contains(&v1, "1", "VERIFYNAME valid => 1");
@@ -629,69 +597,60 @@ async fn test_08_age_persistent_named_suite() {
#[tokio::test]
async fn test_10_expire_pexpire_persist() {
    let (server, port) = start_test_server("expire_suite").await;
    spawn_listener(server, port).await;
    sleep(Duration::from_millis(150)).await;
    let mut s = connect(port).await;

    // EXPIRE: seconds
    let _ = send_cmd(&mut s, &["SET", "exp:s", "v"]).await;
    let ex = send_cmd(&mut s, &["EXPIRE", "exp:s", "1"]).await;
    assert_contains(&ex, "1", "EXPIRE exp:s 1 -> 1 (applied)");
    let ttl1 = send_cmd(&mut s, &["TTL", "exp:s"]).await;
    assert!(
        ttl1.contains("1") || ttl1.contains("0"),
        "TTL exp:s should be 1 or 0, got: {}",
        ttl1
    );
    sleep(Duration::from_millis(1100)).await;
    let get_after = send_cmd(&mut s, &["GET", "exp:s"]).await;
    assert_contains(&get_after, "$-1", "GET after expiry should be Null");
    let ttl_after = send_cmd(&mut s, &["TTL", "exp:s"]).await;
    assert_contains(&ttl_after, "-2", "TTL after expiry -> -2");
    let exists_after = send_cmd(&mut s, &["EXISTS", "exp:s"]).await;
    assert_contains(&exists_after, "0", "EXISTS after expiry -> 0");

    // PEXPIRE: milliseconds
    let _ = send_cmd(&mut s, &["SET", "exp:ms", "v"]).await;
    let pex = send_cmd(&mut s, &["PEXPIRE", "exp:ms", "1500"]).await;
    assert_contains(&pex, "1", "PEXPIRE exp:ms 1500 -> 1 (applied)");
    let ttl_ms1 = send_cmd(&mut s, &["TTL", "exp:ms"]).await;
    assert!(
        ttl_ms1.contains("1") || ttl_ms1.contains("0"),
        "TTL exp:ms should be 1 or 0 soon after PEXPIRE, got: {}",
        ttl_ms1
    );
    sleep(Duration::from_millis(1600)).await;
    let exists_ms_after = send_cmd(&mut s, &["EXISTS", "exp:ms"]).await;
    assert_contains(&exists_ms_after, "0", "EXISTS exp:ms after ms expiry -> 0");

    // PERSIST: remove expiration
    let _ = send_cmd(&mut s, &["SET", "exp:persist", "v"]).await;
    let _ = send_cmd(&mut s, &["EXPIRE", "exp:persist", "5"]).await;
    let ttl_pre = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
    assert!(
        ttl_pre.contains("5") || ttl_pre.contains("4") || ttl_pre.contains("3") || ttl_pre.contains("2") || ttl_pre.contains("1") || ttl_pre.contains("0"),
        "TTL exp:persist should be >=0 before persist, got: {}",
        ttl_pre
    );
    let persist1 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
    assert_contains(&persist1, "1", "PERSIST should remove expiration");
    let ttl_post = send_cmd(&mut s, &["TTL", "exp:persist"]).await;
    assert_contains(&ttl_post, "-1", "TTL after PERSIST -> -1 (no expiration)");
    // Second persist should return 0 (nothing to remove)
    let persist2 = send_cmd(&mut s, &["PERSIST", "exp:persist"]).await;
    assert_contains(&persist2, "0", "PERSIST again -> 0 (no expiration to remove)");
}
#[tokio::test]
@@ -704,11 +663,7 @@ async fn test_11_set_with_options() {
    // SET with GET on non-existing key -> returns Null, sets value
    let set_get1 = send_cmd(&mut s, &["SET", "s1", "v1", "GET"]).await;
-    assert_contains(
-        &set_get1,
-        "$-1",
-        "SET s1 v1 GET returns Null when key didn't exist",
-    );
+    assert_contains(&set_get1, "$-1", "SET s1 v1 GET returns Null when key didn't exist");
    let g1 = send_cmd(&mut s, &["GET", "s1"]).await;
    assert_contains(&g1, "v1", "GET s1 after first SET");
@@ -752,42 +707,42 @@ async fn test_11_set_with_options() {
#[tokio::test]
async fn test_09_mget_mset_and_variadic_exists_del() {
    let (server, port) = start_test_server("mget_mset_variadic").await;
    spawn_listener(server, port).await;
    sleep(Duration::from_millis(150)).await;
    let mut s = connect(port).await;
    // MSET multiple keys
    let mset = send_cmd(&mut s, &["MSET", "k1", "v1", "k2", "v2", "k3", "v3"]).await;
    assert_contains(&mset, "OK", "MSET k1 v1 k2 v2 k3 v3 -> OK");
    // MGET should return values and Null for missing
    let mget = send_cmd(&mut s, &["MGET", "k1", "k2", "nope", "k3"]).await;
    // Expect an array with 4 entries; verify payloads
    assert_contains(&mget, "v1", "MGET k1");
    assert_contains(&mget, "v2", "MGET k2");
    assert_contains(&mget, "v3", "MGET k3");
    assert_contains(&mget, "$-1", "MGET missing returns Null");
    // EXISTS variadic: count how many exist
    let exists_multi = send_cmd(&mut s, &["EXISTS", "k1", "nope", "k3"]).await;
    // Server returns SimpleString numeric, e.g. +2
    assert_contains(&exists_multi, "2", "EXISTS k1 nope k3 -> 2");
    // DEL variadic: delete multiple keys, return count deleted
    let del_multi = send_cmd(&mut s, &["DEL", "k1", "k3", "nope"]).await;
    assert_contains(&del_multi, "2", "DEL k1 k3 nope -> 2");
    // Verify deletion
    let exists_after = send_cmd(&mut s, &["EXISTS", "k1", "k3"]).await;
    assert_contains(&exists_after, "0", "EXISTS k1 k3 after DEL -> 0");
    // MGET after deletion should include Nulls for deleted keys
    let mget_after = send_cmd(&mut s, &["MGET", "k1", "k2", "k3"]).await;
    assert_contains(&mget_after, "$-1", "MGET k1 after DEL -> Null");
    assert_contains(&mget_after, "v2", "MGET k2 remains");
    assert_contains(&mget_after, "$-1", "MGET k3 after DEL -> Null");
}
#[tokio::test]
async fn test_12_hash_incr() {
@@ -907,16 +862,9 @@ async fn test_14_expireat_pexpireat() {
    let mut s = connect(port).await;
    // EXPIREAT: seconds since epoch
-    let now_secs = SystemTime::now()
-        .duration_since(UNIX_EPOCH)
-        .unwrap()
-        .as_secs() as i64;
+    let now_secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
    let _ = send_cmd(&mut s, &["SET", "exp:at:s", "v"]).await;
-    let exat = send_cmd(
-        &mut s,
-        &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)],
-    )
-    .await;
+    let exat = send_cmd(&mut s, &["EXPIREAT", "exp:at:s", &format!("{}", now_secs + 1)]).await;
    assert_contains(&exat, "1", "EXPIREAT exp:at:s now+1s -> 1 (applied)");
    let ttl1 = send_cmd(&mut s, &["TTL", "exp:at:s"]).await;
    assert!(
@@ -926,23 +874,12 @@ async fn test_14_expireat_pexpireat() {
    );
    sleep(Duration::from_millis(1200)).await;
    let exists_after_exat = send_cmd(&mut s, &["EXISTS", "exp:at:s"]).await;
-    assert_contains(
-        &exists_after_exat,
-        "0",
-        "EXISTS exp:at:s after EXPIREAT expiry -> 0",
-    );
+    assert_contains(&exists_after_exat, "0", "EXISTS exp:at:s after EXPIREAT expiry -> 0");
    // PEXPIREAT: milliseconds since epoch
-    let now_ms = SystemTime::now()
-        .duration_since(UNIX_EPOCH)
-        .unwrap()
-        .as_millis() as i64;
+    let now_ms = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as i64;
    let _ = send_cmd(&mut s, &["SET", "exp:at:ms", "v"]).await;
-    let pexat = send_cmd(
-        &mut s,
-        &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)],
-    )
-    .await;
+    let pexat = send_cmd(&mut s, &["PEXPIREAT", "exp:at:ms", &format!("{}", now_ms + 450)]).await;
    assert_contains(&pexat, "1", "PEXPIREAT exp:at:ms now+450ms -> 1 (applied)");
    let ttl2 = send_cmd(&mut s, &["TTL", "exp:at:ms"]).await;
    assert!(
@@ -952,9 +889,5 @@ async fn test_14_expireat_pexpireat() {
    );
    sleep(Duration::from_millis(600)).await;
    let exists_after_pexat = send_cmd(&mut s, &["EXISTS", "exp:at:ms"]).await;
-    assert_contains(
-        &exists_after_pexat,
-        "0",
-        "EXISTS exp:at:ms after PEXPIREAT expiry -> 0",
-    );
+    assert_contains(&exists_after_pexat, "0", "EXISTS exp:at:ms after PEXPIREAT expiry -> 0");
}