heroagent/pkg/servers/redisserver/cmd/redistest/main.go
2025-04-23 04:18:28 +02:00

439 lines
14 KiB
Go

package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"strconv"
"sync"
"time"
"github.com/freeflowuniverse/heroagent/pkg/servers/redisserver"
"github.com/redis/go-redis/v9"
)
// TestResult records the outcome of one named check.
type TestResult struct {
	Name        string // short identifier shown in the report
	Passed      bool   // whether the check succeeded
	Description string // human-readable summary of what was checked
	Error       error  // error observed while running the check; may be nil
}

// TestSuite groups related TestResults under a common name and keeps
// running pass/fail tallies for them.
type TestSuite struct {
	Name   string
	Tests  []TestResult
	Passed int
	Failed int
}

// AddResult appends a result to the suite and increments either the Passed
// or the Failed counter depending on passed.
func (ts *TestSuite) AddResult(name string, passed bool, description string, err error) {
	ts.Tests = append(ts.Tests, TestResult{
		Name:        name,
		Passed:      passed,
		Description: description,
		Error:       err,
	})
	if !passed {
		ts.Failed++
		return
	}
	ts.Passed++
}

// PrintResults writes one line per recorded test to stdout, followed by a
// pass/fail summary. Any non-nil error is printed beneath its test line,
// whether or not the test passed.
func (ts *TestSuite) PrintResults() {
	fmt.Printf("\n=== Test Suite: %s ===\n", ts.Name)
	for i := range ts.Tests {
		tr := &ts.Tests[i]
		label := "❌ FAIL"
		if tr.Passed {
			label = "✅ PASS"
		}
		fmt.Printf("%s: %s - %s\n", label, tr.Name, tr.Description)
		if tr.Error == nil {
			continue
		}
		fmt.Printf(" Error: %v\n", tr.Error)
	}
	fmt.Printf("\nSummary: %d passed, %d failed\n", ts.Passed, ts.Failed)
}
// main starts an in-process Redis-compatible server and runs the test suites
// against it: first over TCP, then over the Unix domain socket when one is
// configured (an empty -unix-socket disables the Unix socket run).
func main() {
	// Parse command line flags
	tcpPort := flag.String("tcp-port", "7777", "Redis server TCP port")
	unixSocket := flag.String("unix-socket", "/tmp/redis-test.sock", "Redis server Unix domain socket path (empty disables Unix socket tests)")
	debug := flag.Bool("debug", false, "Enable debug output")
	dbNum := flag.Int("db", 0, "Redis database number")
	flag.Parse()

	log.Printf("Starting Redis server on TCP port %s and Unix socket %s", *tcpPort, *unixSocket)

	// Remove a socket file left over from a previous run. A single Remove
	// avoids the Stat-then-Remove race; a missing file is not an error.
	if *unixSocket != "" {
		if err := os.Remove(*unixSocket); err == nil {
			log.Printf("Removed existing Unix socket file: %s", *unixSocket)
		} else if !os.IsNotExist(err) {
			log.Printf("Warning: Failed to remove existing Unix socket file: %v", err)
		}
	}

	// Start the server from a goroutine; wg is released once NewServer has
	// returned. This relies on NewServer returning rather than blocking.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		// The returned server handle is not used by this program.
		_ = redisserver.NewServer(redisserver.ServerConfig{TCPPort: *tcpPort, UnixSocketPath: *unixSocket})
		wg.Done()
		// Park this goroutine for the lifetime of the process.
		select {}
	}()
	wg.Wait()

	// Give the server a moment to initialize, especially for Unix socket.
	time.Sleep(1 * time.Second)

	log.Println("Testing TCP connection")
	tcpAddr := fmt.Sprintf("localhost:%s", *tcpPort)
	runTests(tcpAddr, *debug, *dbNum)

	// An empty socket path means no Unix listener was requested, and an empty
	// address would not point at the server started above, so skip the run.
	if *unixSocket == "" {
		log.Println("\nSkipping Unix socket tests: no socket path configured")
		return
	}
	log.Println("\nTesting Unix socket connection")
	runTests(*unixSocket, *debug, *dbNum)
}
// runTests connects to the server at addr (a TCP "host:port" or an absolute
// Unix socket path, dispatched via getNetworkType) using database dbNum. It
// runs every test suite, prints per-suite and overall results, and, when
// debug is set, logs the keys remaining in the database afterwards.
//
// It exits the process via log.Fatalf if the initial PING fails.
func runTests(addr string, debug bool, dbNum int) {
	redisClient := redis.NewClient(&redis.Options{
		Network:      getNetworkType(addr),
		Addr:         addr,
		DB:           dbNum,
		DialTimeout:  5 * time.Second,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 5 * time.Second,
	})
	defer redisClient.Close()
	ctx := context.Background()

	// Check connection
	pong, err := redisClient.Ping(ctx).Result()
	if err != nil {
		log.Fatalf("Failed to connect to Redis: %v", err)
	}
	log.Printf("Connected to Redis: %s", pong)

	// Suites run, and are reported, in this order. They share one database
	// and the pattern and misc suites call FLUSHDB, so order matters.
	suites := []struct {
		suite *TestSuite
		run   func(context.Context, *redis.Client, *TestSuite)
	}{
		{&TestSuite{Name: "Connection Tests"}, runConnectionTests},
		{&TestSuite{Name: "String Operations"}, runStringTests},
		{&TestSuite{Name: "Hash Operations"}, runHashTests},
		{&TestSuite{Name: "List Operations"}, runListTests},
		{&TestSuite{Name: "Pattern Matching"}, runPatternTests},
		{&TestSuite{Name: "Miscellaneous Operations"}, runMiscTests},
	}
	for _, s := range suites {
		s.run(ctx, redisClient, s.suite)
	}

	// Print each suite and accumulate the overall totals in one pass.
	var totalPassed, totalFailed int
	for _, s := range suites {
		s.suite.PrintResults()
		totalPassed += s.suite.Passed
		totalFailed += s.suite.Failed
	}
	totalTests := totalPassed + totalFailed

	// Guard against 0/0 (which would print NaN) if no test was recorded.
	successRate := 0.0
	if totalTests > 0 {
		successRate = float64(totalPassed) / float64(totalTests) * 100
	}

	fmt.Printf("\n=== Overall Test Summary ===\n")
	fmt.Printf("Total Tests: %d\n", totalTests)
	fmt.Printf("Passed: %d\n", totalPassed)
	fmt.Printf("Failed: %d\n", totalFailed)
	fmt.Printf("Success Rate: %.2f%%\n", successRate)

	if debug {
		logRemainingKeys(ctx, redisClient)
	}
}

// logRemainingKeys walks the current database with SCAN, logs the total key
// count and the names of at most the first 20 keys. Scan errors are logged
// and end the walk early.
func logRemainingKeys(ctx context.Context, client *redis.Client) {
	const maxShown = 20

	log.Println("\nListing keys in Redis using SCAN:")
	var cursor uint64
	var allKeys []string
	for {
		keys, nextCursor, err := client.Scan(ctx, cursor, "*", 10).Result()
		if err != nil {
			log.Printf("Error scanning keys: %v", err)
			break
		}
		allKeys = append(allKeys, keys...)
		if nextCursor == 0 {
			break
		}
		cursor = nextCursor
	}

	log.Printf("Found %d total keys using SCAN", len(allKeys))
	shown := allKeys
	if len(shown) > maxShown {
		shown = shown[:maxShown]
	}
	for i, k := range shown {
		log.Printf("Key[%d]: %s", i, k)
	}
	if extra := len(allKeys) - maxShown; extra > 0 {
		log.Printf("... and %d more keys", extra)
	}
}
// runConnectionTests records two checks in suite:
//   - Ping: whether PING succeeds.
//   - Info: whether INFO returns a non-empty payload without error.
func runConnectionTests(ctx context.Context, client *redis.Client, suite *TestSuite) {
	pingErr := client.Ping(ctx).Err()
	suite.AddResult("Ping", pingErr == nil, "Basic connectivity test", pingErr)

	info, infoErr := client.Info(ctx).Result()
	infoOK := infoErr == nil && info != ""
	suite.AddResult("Info", infoOK, "Server information retrieval", infoErr)
}
// runStringTests exercises string-key commands and records each outcome in
// suite. The commands covered are SET/GET, SET with expiry, TTL, expiry,
// DEL, EXISTS, INCR and TYPE.
//
// It sleeps for 3 seconds to let a 2-second TTL elapse. Keys it writes are
// under "test:string:" plus "test:counter", and they are left in place.
func runStringTests(ctx context.Context, client *redis.Client, suite *TestSuite) {
	// firstErr returns the first non-nil error, so that a failed setup step
	// is surfaced in the report instead of being silently discarded.
	firstErr := func(errs ...error) error {
		for _, e := range errs {
			if e != nil {
				return e
			}
		}
		return nil
	}

	// SET and GET round-trip.
	err := client.Set(ctx, "test:string:key1", "value1", 0).Err()
	suite.AddResult("Set", err == nil, "Set a string value", err)
	val, err := client.Get(ctx, "test:string:key1").Result()
	suite.AddResult("Get", err == nil && val == "value1", "Get a string value", err)

	// SET with a 2-second expiration, then confirm the TTL is positive.
	const expiringKey = "test:string:expiring"
	err = client.Set(ctx, expiringKey, "expiring-value", 2*time.Second).Err()
	suite.AddResult("SetWithExpiration", err == nil, "Set a string value with expiration", err)
	ttl, err := client.TTL(ctx, expiringKey).Result()
	suite.AddResult("TTL", err == nil && ttl > 0, fmt.Sprintf("Check TTL of expiring key (TTL: %v)", ttl), err)

	// Sleep past the TTL so the key should be gone.
	time.Sleep(3 * time.Second)

	// redis.Nil is the expected outcome here, so it must not be stored as the
	// result's error. PrintResults prints any non-nil error, which would show
	// "Error: redis: nil" under a passing test.
	_, err = client.Get(ctx, expiringKey).Result()
	var expiredErr error
	switch {
	case err == redis.Nil:
		// Expected: the key has expired.
	case err == nil:
		expiredErr = fmt.Errorf("key %q still present after its TTL elapsed", expiringKey)
	default:
		expiredErr = err
	}
	suite.AddResult("ExpiredKey", err == redis.Nil, "Verify expired key is removed", expiredErr)

	// DEL: create the key first. A failed setup fails the check.
	setErr := client.Set(ctx, "test:string:to-delete", "delete-me", 0).Err()
	affected, err := client.Del(ctx, "test:string:to-delete").Result()
	suite.AddResult("Del", setErr == nil && err == nil && affected == 1, "Delete a key", firstErr(setErr, err))

	// EXISTS on a freshly written key.
	setErr = client.Set(ctx, "test:string:exists", "i-exist", 0).Err()
	exists, err := client.Exists(ctx, "test:string:exists").Result()
	suite.AddResult("Exists", setErr == nil && err == nil && exists == 1, "Check if a key exists", firstErr(setErr, err))

	// INCR on a counter that was just deleted, so the first INCR yields 1.
	delErr := client.Del(ctx, "test:counter").Err()
	incr, err := client.Incr(ctx, "test:counter").Result()
	suite.AddResult("Incr", delErr == nil && err == nil && incr == 1, "Increment a counter", firstErr(delErr, err))
	incr, err = client.Incr(ctx, "test:counter").Result()
	suite.AddResult("IncrAgain", err == nil && incr == 2, "Increment a counter again", err)

	// TYPE of a plain string key.
	setErr = client.Set(ctx, "test:string:type", "string-value", 0).Err()
	keyType, err := client.Type(ctx, "test:string:type").Result()
	suite.AddResult("Type", setErr == nil && err == nil && keyType == "string", "Get type of a key", firstErr(setErr, err))
}
// runHashTests exercises hash commands on a single key and records each
// outcome in suite. The commands covered are HSET (single and multiple
// fields), HGET, HGETALL, HKEYS, HLEN, HDEL and TYPE.
func runHashTests(ctx context.Context, client *redis.Client, suite *TestSuite) {
	const key = "test:hash:user1"

	// Start from a clean key, as runListTests does for its list. This stops a
	// leftover key from an earlier run, possibly of a different type, from
	// skewing the results.
	client.Del(ctx, key)

	// HSET and HGET of a single field.
	err := client.HSet(ctx, key, "name", "John").Err()
	suite.AddResult("HSet", err == nil, "Set a hash field", err)
	val, err := client.HGet(ctx, key, "name").Result()
	suite.AddResult("HGet", err == nil && val == "John", "Get a hash field", err)

	// HSET with several field/value pairs at once.
	err = client.HSet(ctx, key, "age", "30", "city", "New York").Err()
	suite.AddResult("HSetMultiple", err == nil, "Set multiple hash fields", err)

	// HGETALL must return exactly the three fields written above, with their values.
	fields, err := client.HGetAll(ctx, key).Result()
	fieldsMatch := len(fields) == 3 &&
		fields["name"] == "John" && fields["age"] == "30" && fields["city"] == "New York"
	suite.AddResult("HGetAll", err == nil && fieldsMatch,
		fmt.Sprintf("Get all hash fields (found %d fields)", len(fields)), err)

	keys, err := client.HKeys(ctx, key).Result()
	suite.AddResult("HKeys", err == nil && len(keys) == 3,
		fmt.Sprintf("Get all hash keys (found %d keys)", len(keys)), err)

	length, err := client.HLen(ctx, key).Result()
	suite.AddResult("HLen", err == nil && length == 3,
		fmt.Sprintf("Get hash length (length: %d)", length), err)

	deleted, err := client.HDel(ctx, key, "city").Result()
	suite.AddResult("HDel", err == nil && deleted == 1, "Delete a hash field", err)

	keyType, err := client.Type(ctx, key).Result()
	suite.AddResult("HashType", err == nil && keyType == "hash", "Get type of a hash key", err)
}
// runListTests exercises list commands against a freshly cleared key and
// records each outcome in suite. The commands covered are LPUSH, RPUSH,
// LLEN, LRANGE, LPOP, RPOP and TYPE.
//
// The expectations below assume that after LPUSH item1 item2 and
// RPUSH item3 item4 the list reads [item2 item1 item3 item4]. LPOP therefore
// yields item2 and RPOP yields item4.
func runListTests(ctx context.Context, client *redis.Client, suite *TestSuite) {
	const listKey = "test:list:items"

	// Drop any list left behind by an earlier run.
	client.Del(ctx, listKey)

	headCount, headErr := client.LPush(ctx, listKey, "item1", "item2").Result()
	suite.AddResult("LPush", headErr == nil && headCount == 2,
		fmt.Sprintf("Push items to the head of a list (count: %d)", headCount), headErr)

	tailCount, tailErr := client.RPush(ctx, listKey, "item3", "item4").Result()
	suite.AddResult("RPush", tailErr == nil && tailCount == 4,
		fmt.Sprintf("Push items to the tail of a list (count: %d)", tailCount), tailErr)

	size, sizeErr := client.LLen(ctx, listKey).Result()
	suite.AddResult("LLen", sizeErr == nil && size == 4,
		fmt.Sprintf("Get list length (length: %d)", size), sizeErr)

	all, rangeErr := client.LRange(ctx, listKey, 0, -1).Result()
	suite.AddResult("LRange", rangeErr == nil && len(all) == 4,
		fmt.Sprintf("Get range of list items (found %d items)", len(all)), rangeErr)

	head, lpopErr := client.LPop(ctx, listKey).Result()
	suite.AddResult("LPop", lpopErr == nil && head == "item2",
		fmt.Sprintf("Pop item from the head of a list (item: %s)", head), lpopErr)

	tail, rpopErr := client.RPop(ctx, listKey).Result()
	suite.AddResult("RPop", rpopErr == nil && tail == "item4",
		fmt.Sprintf("Pop item from the tail of a list (item: %s)", tail), rpopErr)

	kind, typeErr := client.Type(ctx, listKey).Result()
	suite.AddResult("ListType", typeErr == nil && kind == "list", "Get type of a list key", typeErr)
}
// runPatternTests flushes the current database, seeds a known set of keys and
// checks glob matching. It covers KEYS with "*" and "?" patterns and SCAN with
// a MATCH pattern, recording each outcome in suite.
//
// NOTE: FLUSHDB removes every key in the selected database, not just test keys.
func runPatternTests(ctx context.Context, client *redis.Client, suite *TestSuite) {
	// Start from an empty database so the key counts below are exact.
	client.FlushDB(ctx)
	client.Set(ctx, "user:1:name", "Alice", 0)
	client.Set(ctx, "user:1:email", "alice@example.com", 0)
	client.Set(ctx, "user:2:name", "Bob", 0)
	client.Set(ctx, "user:2:email", "bob@example.com", 0)
	client.Set(ctx, "product:1", "Laptop", 0)
	client.Set(ctx, "product:2", "Phone", 0)

	// KEYS with a prefix pattern, then with a narrower one.
	keys, err := client.Keys(ctx, "user:*").Result()
	suite.AddResult("KeysPattern", err == nil && len(keys) == 4,
		fmt.Sprintf("Get keys matching pattern 'user:*' (found %d keys)", len(keys)), err)

	keys, err = client.Keys(ctx, "user:1:*").Result()
	suite.AddResult("KeysSpecificPattern", err == nil && len(keys) == 2,
		fmt.Sprintf("Get keys matching pattern 'user:1:*' (found %d keys)", len(keys)), err)

	// SCAN with MATCH. Exactly one ScanPattern result is recorded, whether or
	// not a SCAN call fails. SCAN may return an element more than once, so
	// distinct keys are counted.
	var (
		cursor  uint64
		scanErr error
	)
	seen := make(map[string]struct{})
	for {
		page, nextCursor, err := client.Scan(ctx, cursor, "product:*", 10).Result()
		if err != nil {
			scanErr = err
			break
		}
		for _, k := range page {
			seen[k] = struct{}{}
		}
		if nextCursor == 0 {
			break
		}
		cursor = nextCursor
	}
	suite.AddResult("ScanPattern", scanErr == nil && len(seen) == 2,
		fmt.Sprintf("Scan keys matching pattern 'product:*' (found %d keys)", len(seen)), scanErr)

	// "?" matches exactly one character, so all of test:wildcard:1..5 match.
	for i := 1; i <= 5; i++ {
		client.Set(ctx, fmt.Sprintf("test:wildcard:%d", i), strconv.Itoa(i), 0)
	}
	keys, err = client.Keys(ctx, "test:wildcard:?").Result()
	suite.AddResult("SingleCharWildcard", err == nil && len(keys) == 5,
		fmt.Sprintf("Get keys matching single character wildcard (found %d keys)", len(keys)), err)
}
// runMiscTests records these checks in suite:
//   - FLUSHDB succeeds.
//   - KEYS * is empty afterwards.
//   - Two sequential SET/GET round-trips succeed.
//   - GET on a missing key yields redis.Nil.
//
// NOTE: FLUSHDB removes every key in the selected database.
func runMiscTests(ctx context.Context, client *redis.Client, suite *TestSuite) {
	err := client.FlushDB(ctx).Err()
	suite.AddResult("FlushDB", err == nil, "Flush the current database", err)

	keys, err := client.Keys(ctx, "*").Result()
	suite.AddResult("KeysAfterFlush", err == nil && len(keys) == 0,
		fmt.Sprintf("Verify no keys after flush (found %d keys)", len(keys)), err)

	// Multiple operations in sequence. Report the first failing step rather
	// than nil so a failure is diagnosable from the report.
	setErr1 := client.Set(ctx, "test:multi:1", "value1", 0).Err()
	setErr2 := client.Set(ctx, "test:multi:2", "value2", 0).Err()
	val1, err1 := client.Get(ctx, "test:multi:1").Result()
	val2, err2 := client.Get(ctx, "test:multi:2").Result()
	var multiErr error
	for _, e := range []error{setErr1, setErr2, err1, err2} {
		if e != nil {
			multiErr = e
			break
		}
	}
	suite.AddResult("MultipleOperations", multiErr == nil && val1 == "value1" && val2 == "value2",
		"Perform multiple operations in sequence", multiErr)

	// A missing key must yield redis.Nil. That outcome is not stored as an
	// error, but any other error, or an unexpected hit, is.
	_, err = client.Get(ctx, "test:nonexistent").Result()
	var missingErr error
	switch {
	case err == redis.Nil:
		// Expected.
	case err == nil:
		missingErr = fmt.Errorf("expected redis.Nil for missing key %q, got a value", "test:nonexistent")
	default:
		missingErr = err
	}
	suite.AddResult("NonExistentKey", err == redis.Nil, "Get a non-existent key", missingErr)
}
// getNetworkType reports which network a client should dial for addr.
// It returns "unix" when addr is an absolute filesystem path (it begins with
// '/'), and "tcp" otherwise (for example "localhost:7777").
//
// An empty addr yields "tcp" instead of panicking on addr[0].
func getNetworkType(addr string) string {
	if addr != "" && addr[0] == '/' {
		return "unix"
	}
	return "tcp"
}