package redisserver import ( "log" "sort" "strconv" "strings" "time" "github.com/tidwall/redcon" ) func (s *Server) startRedisServer(addr string, networkType string) { networkDesc := "TCP" netType := "tcp" if networkType == "unix" { networkDesc = "Unix socket" netType = "unix" } log.Printf("Starting Redis-like server on %s (%s)", addr, networkDesc) // Use ListenAndServeNetwork to support both TCP and Unix sockets err := redcon.ListenAndServeNetwork(netType, addr, func(conn redcon.Conn, cmd redcon.Command) { // Every command is expected to have at least one argument (the command name). if len(cmd.Args) == 0 { conn.WriteError("ERR empty command") return } command := strings.ToLower(string(cmd.Args[0])) switch command { case "ping": conn.WriteString("PONG") case "set": // Usage: SET key value [EX seconds] if len(cmd.Args) < 3 { conn.WriteError("ERR wrong number of arguments for 'set' command") return } key := string(cmd.Args[1]) value := string(cmd.Args[2]) duration := time.Duration(0) // Check for an expiration option (only EX is supported here). if len(cmd.Args) > 3 { if strings.ToLower(string(cmd.Args[3])) == "ex" && len(cmd.Args) > 4 { seconds, err := strconv.Atoi(string(cmd.Args[4])) if err != nil { conn.WriteError("ERR invalid expire time") return } duration = time.Duration(seconds) * time.Second } } s.set(key, value, duration) conn.WriteString("OK") case "get": if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'get' command") return } key := string(cmd.Args[1]) v, ok := s.get(key) if !ok { conn.WriteNull() return } // Only string type is returned by GET. switch val := v.(type) { case string: conn.WriteBulkString(val) default: conn.WriteError("WRONGTYPE Operation against a key holding the wrong kind of value") } case "del": if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'del' command") return } count := 0 for i := 1; i < len(cmd.Args); i++ { key := string(cmd.Args[i]) count += s.del(key) } conn.WriteInt(count) case "keys": if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'keys' command") return } pattern := string(cmd.Args[1]) keys := s.keys(pattern) conn.WriteArray(len(keys)) for _, k := range keys { conn.WriteBulkString(k) } case "hset": // Usage: HSET key field value [field value ...] if len(cmd.Args) < 4 || len(cmd.Args)%2 != 0 { conn.WriteError("ERR wrong number of arguments for 'hset' command") return } key := string(cmd.Args[1]) // Process multiple field-value pairs totalAdded := 0 for i := 2; i < len(cmd.Args); i += 2 { field := string(cmd.Args[i]) value := string(cmd.Args[i+1]) added := s.hset(key, field, value) totalAdded += added } conn.WriteInt(totalAdded) case "hget": // Usage: HGET key field if len(cmd.Args) < 3 { conn.WriteError("ERR wrong number of arguments for 'hget' command") return } key := string(cmd.Args[1]) field := string(cmd.Args[2]) v, ok := s.hget(key, field) if !ok { conn.WriteNull() return } conn.WriteBulkString(v) case "hdel": // Usage: HDEL key field [field ...] if len(cmd.Args) < 3 { conn.WriteError("ERR wrong number of arguments for 'hdel' command") return } key := string(cmd.Args[1]) fields := make([]string, 0, len(cmd.Args)-2) for i := 2; i < len(cmd.Args); i++ { fields = append(fields, string(cmd.Args[i])) } removed := s.hdel(key, fields) conn.WriteInt(removed) case "hkeys": // Usage: HKEYS key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'hkeys' command") return } key := string(cmd.Args[1]) fields := s.hkeys(key) conn.WriteArray(len(fields)) for _, field := range fields { conn.WriteBulkString(field) } case "hlen": // Usage: HLEN key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'hlen' command") return } key := string(cmd.Args[1]) length := s.hlen(key) conn.WriteInt(length) case "hgetall": // Usage: HGETALL key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'hgetall' command") return } key := string(cmd.Args[1]) hash, ok := s.getHash(key) if !ok { // Return empty array if key doesn't exist or is not a hash conn.WriteArray(0) return } // Write field-value pairs conn.WriteArray(len(hash) * 2) // Each field has a corresponding value // Sort fields for consistent output fields := make([]string, 0, len(hash)) for field := range hash { fields = append(fields, field) } sort.Strings(fields) for _, field := range fields { conn.WriteBulkString(field) conn.WriteBulkString(hash[field]) } case "flushdb": // Usage: FLUSHDB s.mu.Lock() s.data = make(map[string]*entry) s.mu.Unlock() conn.WriteString("OK") case "incr": if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'incr' command") return } key := string(cmd.Args[1]) newVal, err := s.incr(key) if err != nil { conn.WriteError("ERR " + err.Error()) return } conn.WriteInt64(newVal) case "info": // Return basic information about the server info := s.getInfo() conn.WriteBulkString(info) case "type": // Usage: TYPE key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'type' command") return } key := string(cmd.Args[1]) keyType := s.getType(key) conn.WriteBulkString(keyType) case "ttl": // Usage: TTL key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'ttl' command") return } key := string(cmd.Args[1]) ttl := s.getTTL(key) conn.WriteInt64(ttl) case "exists": // Usage: EXISTS key [key ...] if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'exists' command") return } keys := make([]string, 0, len(cmd.Args)-1) for i := 1; i < len(cmd.Args); i++ { keys = append(keys, string(cmd.Args[i])) } count := s.exists(keys) conn.WriteInt(count) case "expire": // Usage: EXPIRE key seconds if len(cmd.Args) < 3 { conn.WriteError("ERR wrong number of arguments for 'expire' command") return } key := string(cmd.Args[1]) seconds, err := strconv.ParseInt(string(cmd.Args[2]), 10, 64) if err != nil { conn.WriteError("ERR value is not an integer or out of range") return } success := s.expire(key, time.Duration(seconds)*time.Second) if success { conn.WriteInt(1) } else { conn.WriteInt(0) } case "scan": // Usage: SCAN cursor [MATCH pattern] [COUNT count] if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'scan' command") return } cursor := string(cmd.Args[1]) cursorInt, err := strconv.Atoi(cursor) if err != nil { conn.WriteError("ERR invalid cursor") return } // Default values pattern := "*" count := 10 // Parse optional arguments for i := 2; i < len(cmd.Args); i++ { arg := strings.ToLower(string(cmd.Args[i])) if arg == "match" && i+1 < len(cmd.Args) { pattern = string(cmd.Args[i+1]) i++ } else if arg == "count" && i+1 < len(cmd.Args) { count, err = strconv.Atoi(string(cmd.Args[i+1])) if err != nil { conn.WriteError("ERR value is not an integer or out of range") return } i++ } } // Get matching keys nextCursor, keys := s.scan(cursorInt, pattern, count) // Write response conn.WriteArray(2) conn.WriteBulkString(strconv.Itoa(nextCursor)) conn.WriteArray(len(keys)) for _, key := range keys { conn.WriteBulkString(key) } case "hscan": // Usage: HSCAN key cursor [MATCH pattern] [COUNT count] if len(cmd.Args) < 3 { conn.WriteError("ERR wrong number of arguments for 'hscan' command") return } key := string(cmd.Args[1]) cursor := string(cmd.Args[2]) cursorInt, err := strconv.Atoi(cursor) if err != nil { conn.WriteError("ERR invalid cursor") return } // Default values pattern := "*" count := 10 // Parse optional arguments for i := 3; i < len(cmd.Args); i++ { arg := strings.ToLower(string(cmd.Args[i])) if arg == "match" && i+1 < len(cmd.Args) { pattern = string(cmd.Args[i+1]) i++ } else if arg == "count" && i+1 < len(cmd.Args) { count, err = strconv.Atoi(string(cmd.Args[i+1])) if err != nil { conn.WriteError("ERR value is not an integer or out of range") return } i++ } } // Get matching fields and values nextCursor, fields, values := s.hscan(key, cursorInt, pattern, count) // Write response conn.WriteArray(2) conn.WriteBulkString(strconv.Itoa(nextCursor)) // Write field-value pairs conn.WriteArray(len(fields) * 2) // Each field has a corresponding value for i := 0; i < len(fields); i++ { conn.WriteBulkString(fields[i]) conn.WriteBulkString(values[i]) } case "lpush": // Usage: LPUSH key value [value ...] if len(cmd.Args) < 3 { conn.WriteError("ERR wrong number of arguments for 'lpush' command") return } key := string(cmd.Args[1]) values := make([]string, len(cmd.Args)-2) for i := 2; i < len(cmd.Args); i++ { values[i-2] = string(cmd.Args[i]) } length := s.lpush(key, values) conn.WriteInt(length) case "rpush": // Usage: RPUSH key value [value ...] if len(cmd.Args) < 3 { conn.WriteError("ERR wrong number of arguments for 'rpush' command") return } key := string(cmd.Args[1]) values := make([]string, len(cmd.Args)-2) for i := 2; i < len(cmd.Args); i++ { values[i-2] = string(cmd.Args[i]) } length := s.rpush(key, values) conn.WriteInt(length) case "lpop": // Usage: LPOP key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'lpop' command") return } key := string(cmd.Args[1]) val, ok := s.lpop(key) if !ok { conn.WriteNull() return } conn.WriteBulkString(val) case "rpop": // Usage: RPOP key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'rpop' command") return } key := string(cmd.Args[1]) val, ok := s.rpop(key) if !ok { conn.WriteNull() return } conn.WriteBulkString(val) case "llen": // Usage: LLEN key if len(cmd.Args) < 2 { conn.WriteError("ERR wrong number of arguments for 'llen' command") return } key := string(cmd.Args[1]) length := s.llen(key) conn.WriteInt(length) case "lrange": // Usage: LRANGE key start stop if len(cmd.Args) < 4 { conn.WriteError("ERR wrong number of arguments for 'lrange' command") return } key := string(cmd.Args[1]) start, err := strconv.Atoi(string(cmd.Args[2])) if err != nil { conn.WriteError("ERR value is not an integer or out of range") return } stop, err := strconv.Atoi(string(cmd.Args[3])) if err != nil { conn.WriteError("ERR value is not an integer or out of range") return } values := s.lrange(key, start, stop) conn.WriteArray(len(values)) for _, val := range values { conn.WriteBulkString(val) } default: conn.WriteError("ERR unknown command '" + command + "'") } }, // Accept connection: always allow. func(conn redcon.Conn) bool { return true }, // On connection close. func(conn redcon.Conn, err error) {}, ) if err != nil { log.Printf("Error starting Redis server: %v", err) } }