This commit is contained in:
despiegk 2025-08-16 07:23:20 +02:00
parent de2be4a785
commit d3e28cafe4
3 changed files with 167 additions and 0 deletions

View File

@ -28,6 +28,7 @@ pub enum Cmd {
HLen(String), HLen(String),
HMGet(String, Vec<String>), HMGet(String, Vec<String>),
HSetNx(String, String, String), HSetNx(String, String, String),
Scan(u64, Option<String>, Option<u64>), // cursor, pattern, count
Unknow, Unknow,
} }
@ -176,6 +177,43 @@ impl Cmd {
} }
Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone()) Cmd::HSetNx(cmd[1].clone(), cmd[2].clone(), cmd[3].clone())
} }
"scan" => {
if cmd.len() < 2 {
return Err(DBError(format!("wrong number of arguments for SCAN command")));
}
let cursor = cmd[1].parse::<u64>().map_err(|_|
DBError("ERR invalid cursor".to_string()))?;
let mut pattern = None;
let mut count = None;
let mut i = 2;
while i < cmd.len() {
match cmd[i].to_lowercase().as_str() {
"match" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
pattern = Some(cmd[i + 1].clone());
i += 2;
}
"count" => {
if i + 1 >= cmd.len() {
return Err(DBError("ERR syntax error".to_string()));
}
count = Some(cmd[i + 1].parse::<u64>().map_err(|_|
DBError("ERR value is not an integer or out of range".to_string()))?);
i += 2;
}
_ => {
return Err(DBError(format!("ERR syntax error")));
}
}
}
Cmd::Scan(cursor, pattern, count)
}
_ => Cmd::Unknow, _ => Cmd::Unknow,
}, },
protocol.0, protocol.0,
@ -244,6 +282,7 @@ impl Cmd {
Cmd::HLen(key) => hlen_cmd(server, key).await, Cmd::HLen(key) => hlen_cmd(server, key).await,
Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await, Cmd::HMGet(key, fields) => hmget_cmd(server, key, fields).await,
Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await, Cmd::HSetNx(key, field, value) => hsetnx_cmd(server, key, field, value).await,
Cmd::Scan(cursor, pattern, count) => scan_cmd(server, cursor, pattern.as_deref(), count).await,
Cmd::Unknow => Ok(Protocol::err("unknown cmd")), Cmd::Unknow => Ok(Protocol::err("unknown cmd")),
} }
} }
@ -444,3 +483,17 @@ async fn hsetnx_cmd(server: &Server, key: &str, field: &str, value: &str) -> Res
Err(e) => Ok(Protocol::err(&e.0)), Err(e) => Ok(Protocol::err(&e.0)),
} }
} }
async fn scan_cmd(server: &Server, cursor: &u64, pattern: Option<&str>, count: &Option<u64>) -> Result<Protocol, DBError> {
match server.storage.scan(*cursor, pattern, *count) {
Ok((next_cursor, keys)) => {
let mut result = Vec::new();
result.push(Protocol::BulkString(next_cursor.to_string()));
result.push(Protocol::Array(
keys.into_iter().map(Protocol::BulkString).collect(),
));
Ok(Protocol::Array(result))
}
Err(e) => Ok(Protocol::err(&e.0)),
}
}

View File

@ -445,4 +445,65 @@ impl Storage {
write_txn.commit()?; write_txn.commit()?;
Ok(result) Ok(result)
} }
pub fn scan(&self, cursor: u64, pattern: Option<&str>, count: Option<u64>) -> Result<(u64, Vec<String>), DBError> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(TYPES_TABLE)?;
let count = count.unwrap_or(10); // Default count is 10
let mut keys = Vec::new();
let mut current_cursor = 0u64;
let mut returned_keys = 0u64;
let mut iter = table.iter()?;
while let Some(entry) = iter.next() {
let key = entry?.0.value().to_string();
// Skip keys until we reach the cursor position
if current_cursor < cursor {
current_cursor += 1;
continue;
}
// Check if key matches pattern
let matches = match pattern {
Some(pat) => {
if pat == "*" {
true
} else if pat.contains('*') {
// Simple glob pattern matching
let pattern_parts: Vec<&str> = pat.split('*').collect();
if pattern_parts.len() == 2 {
let prefix = pattern_parts[0];
let suffix = pattern_parts[1];
key.starts_with(prefix) && key.ends_with(suffix)
} else {
key.contains(&pat.replace('*', ""))
}
} else {
key.contains(pat)
}
}
None => true,
};
if matches {
keys.push(key);
returned_keys += 1;
// Stop if we've returned enough keys
if returned_keys >= count {
current_cursor += 1;
break;
}
}
current_cursor += 1;
}
// If we've reached the end of iteration, return cursor 0 to indicate completion
let next_cursor = if returned_keys < count { 0 } else { current_cursor };
Ok((next_cursor, keys))
}
} }

View File

@ -230,6 +230,58 @@ test_expiration() {
redis_cmd "GET expire_ex_key" "" # Should be expired redis_cmd "GET expire_ex_key" "" # Should be expired
} }
# Function to test SCAN operations
test_scan_operations() {
print_status "=== Testing SCAN Operations ==="
# Set up test data for scanning
redis_cmd "SET scan_test1 value1" "OK"
redis_cmd "SET scan_test2 value2" "OK"
redis_cmd "SET scan_test3 value3" "OK"
redis_cmd "SET other_key other_value" "OK"
redis_cmd "HSET scan_hash field1 value1" "1"
# Test basic SCAN
print_status "Testing basic SCAN with cursor 0"
redis_cmd "SCAN 0" ""
# Test SCAN with MATCH pattern
print_status "Testing SCAN with MATCH pattern"
redis_cmd "SCAN 0 MATCH scan_test*" ""
# Test SCAN with COUNT
print_status "Testing SCAN with COUNT 2"
redis_cmd "SCAN 0 COUNT 2" ""
# Test SCAN with both MATCH and COUNT
print_status "Testing SCAN with MATCH and COUNT"
redis_cmd "SCAN 0 MATCH scan_* COUNT 1" ""
# Test SCAN continuation with more keys
print_status "Setting up more keys for continuation test"
redis_cmd "SET scan_key1 val1" "OK"
redis_cmd "SET scan_key2 val2" "OK"
redis_cmd "SET scan_key3 val3" "OK"
redis_cmd "SET scan_key4 val4" "OK"
redis_cmd "SET scan_key5 val5" "OK"
print_status "Testing SCAN with small COUNT for pagination"
redis_cmd "SCAN 0 COUNT 3" ""
# Clean up SCAN test data
print_status "Cleaning up SCAN test data"
redis_cmd "DEL scan_test1" "1"
redis_cmd "DEL scan_test2" "1"
redis_cmd "DEL scan_test3" "1"
redis_cmd "DEL other_key" "1"
redis_cmd "DEL scan_hash" "1"
redis_cmd "DEL scan_key1" "1"
redis_cmd "DEL scan_key2" "1"
redis_cmd "DEL scan_key3" "1"
redis_cmd "DEL scan_key4" "1"
redis_cmd "DEL scan_key5" "1"
}
# Main execution # Main execution
main() { main() {
print_status "Starting HeroDB comprehensive test suite..." print_status "Starting HeroDB comprehensive test suite..."
@ -265,6 +317,7 @@ main() {
test_keys_operations || failed_tests=$((failed_tests + 1)) test_keys_operations || failed_tests=$((failed_tests + 1))
test_info_operations || failed_tests=$((failed_tests + 1)) test_info_operations || failed_tests=$((failed_tests + 1))
test_expiration || failed_tests=$((failed_tests + 1)) test_expiration || failed_tests=$((failed_tests + 1))
test_scan_operations || failed_tests=$((failed_tests + 1))
# Summary # Summary
echo echo