...
This commit is contained in:
115
pkg2_dont_use/heroservices/billing/models/account.go
Normal file
115
pkg2_dont_use/heroservices/billing/models/account.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Account represents a financial account in the billing system
|
||||
type Account struct {
|
||||
ID uint32 // Unique identifier, auto-incremented
|
||||
UserID uint32 // Link to the unique user ID
|
||||
Name string // Name of the account
|
||||
Currency string // Currency code (e.g., USD, EUR), always uppercase, 3 letters
|
||||
Amount float64 // Current balance, always positive
|
||||
}
|
||||
|
||||
// Serialize converts an Account to a binary representation
|
||||
func (a *Account) Serialize() []byte {
|
||||
// Calculate the size of the serialized data
|
||||
// 4 bytes for ID + 4 bytes for UserID + 2 bytes for name length + len(name) bytes +
|
||||
// 3 bytes for currency + 8 bytes for amount
|
||||
size := 4 + 4 + 2 + len(a.Name) + 3 + 8
|
||||
data := make([]byte, size)
|
||||
|
||||
// Write ID (4 bytes)
|
||||
binary.LittleEndian.PutUint32(data[0:4], a.ID)
|
||||
|
||||
// Write UserID (4 bytes)
|
||||
binary.LittleEndian.PutUint32(data[4:8], a.UserID)
|
||||
|
||||
// Write name length (2 bytes) and name
|
||||
nameLen := uint16(len(a.Name))
|
||||
binary.LittleEndian.PutUint16(data[8:10], nameLen)
|
||||
copy(data[10:10+nameLen], a.Name)
|
||||
|
||||
// Write currency (3 bytes)
|
||||
currencyOffset := 10 + nameLen
|
||||
copy(data[currencyOffset:currencyOffset+3], a.Currency)
|
||||
|
||||
// Write amount (8 bytes)
|
||||
amountOffset := currencyOffset + 3
|
||||
binary.LittleEndian.PutUint64(data[amountOffset:amountOffset+8], math.Float64bits(a.Amount))
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// DeserializeAccount converts a binary representation back to an Account
|
||||
func DeserializeAccount(data []byte) (*Account, error) {
|
||||
if len(data) < 17 { // Minimum size: 4 (ID) + 4 (UserID) + 2 (name length) + 3 (currency) + 8 (amount)
|
||||
return nil, errors.New("data too short to deserialize Account")
|
||||
}
|
||||
|
||||
account := &Account{}
|
||||
|
||||
// Read ID
|
||||
account.ID = binary.LittleEndian.Uint32(data[0:4])
|
||||
|
||||
// Read UserID
|
||||
account.UserID = binary.LittleEndian.Uint32(data[4:8])
|
||||
|
||||
// Read name length and name
|
||||
nameLen := binary.LittleEndian.Uint16(data[8:10])
|
||||
if 10+nameLen > uint16(len(data)) {
|
||||
return nil, errors.New("data too short to read name")
|
||||
}
|
||||
account.Name = string(data[10 : 10+nameLen])
|
||||
|
||||
// Read currency
|
||||
currencyOffset := 10 + nameLen
|
||||
if int(currencyOffset)+3 > len(data) {
|
||||
return nil, errors.New("data too short to read currency")
|
||||
}
|
||||
account.Currency = string(data[currencyOffset : currencyOffset+3])
|
||||
|
||||
// Read amount
|
||||
amountOffset := currencyOffset + 3
|
||||
if int(amountOffset)+8 > len(data) {
|
||||
return nil, errors.New("data too short to read amount")
|
||||
}
|
||||
account.Amount = math.Float64frombits(binary.LittleEndian.Uint64(data[amountOffset : amountOffset+8]))
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// Validate checks if the account data is valid
|
||||
func (a *Account) Validate() error {
|
||||
if a.UserID == 0 {
|
||||
return errors.New("user ID cannot be zero")
|
||||
}
|
||||
|
||||
if a.Name == "" {
|
||||
return errors.New("account name cannot be empty")
|
||||
}
|
||||
|
||||
// Check currency format (3 uppercase letters)
|
||||
if len(a.Currency) != 3 {
|
||||
return errors.New("currency must be 3 characters")
|
||||
}
|
||||
|
||||
a.Currency = strings.ToUpper(a.Currency)
|
||||
for _, r := range a.Currency {
|
||||
if r < 'A' || r > 'Z' {
|
||||
return errors.New("currency must contain only uppercase letters")
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure amount is non-negative
|
||||
if a.Amount < 0 {
|
||||
return errors.New("amount cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
348
pkg2_dont_use/heroservices/billing/models/cmd/main.go
Normal file
348
pkg2_dont_use/heroservices/billing/models/cmd/main.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.ourworld.tf/herocode/heroagent/pkg/heroservices/billing/models"
|
||||
)
|
||||
|
||||
const (
|
||||
numUsers = 1000
|
||||
numAccounts = 3000 // Average 3 accounts per user
|
||||
numTransactions = 10000 // Reduced for quicker testing
|
||||
currencies = "USD,EUR,GBP,JPY,CNY,AUD,CAD,CHF,HKD,SGD"
|
||||
accountTypes = "Checking,Savings,Investment,Retirement,Business,Credit,Loan,Mortgage,Travel,Emergency"
|
||||
companies = "Apple,Google,Microsoft,Amazon,Facebook,Tesla,Walmart,Target,Costco,HomeDepot,Lowes,BestBuy,Starbucks,McDonalds,Chipotle,Nike,Adidas,Uber,Lyft,Airbnb"
|
||||
firstNames = "John,Jane,Michael,Emily,David,Sarah,Robert,Jennifer,William,Elizabeth,James,Linda,Richard,Barbara,Joseph,Susan,Thomas,Jessica,Charles,Mary"
|
||||
lastNames = "Smith,Johnson,Williams,Jones,Brown,Davis,Miller,Wilson,Moore,Taylor,Anderson,Thomas,Jackson,White,Harris,Martin,Thompson,Garcia,Martinez,Robinson"
|
||||
)
|
||||
|
||||
var (
|
||||
currencyList []string
|
||||
accountList []string
|
||||
companyList []string
|
||||
firstNameList []string
|
||||
lastNameList []string
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize random seed
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Split the constant strings into slices
|
||||
currencyList = strings.Split(currencies, ",")
|
||||
accountList = strings.Split(accountTypes, ",")
|
||||
companyList = strings.Split(companies, ",")
|
||||
firstNameList = strings.Split(firstNames, ",")
|
||||
lastNameList = strings.Split(lastNames, ",")
|
||||
}
|
||||
|
||||
// generateRandomName generates a random user name
|
||||
func generateRandomName() string {
|
||||
firstName := firstNameList[rand.Intn(len(firstNameList))]
|
||||
lastName := lastNameList[rand.Intn(len(lastNameList))]
|
||||
// Add timestamp and random number to ensure uniqueness
|
||||
return fmt.Sprintf("%s_%s_%d_%d", firstName, lastName, time.Now().UnixNano()%1000000, rand.Intn(1000))
|
||||
}
|
||||
|
||||
// generateRandomAccountName generates a random account name
|
||||
func generateRandomAccountName() string {
|
||||
accountType := accountList[rand.Intn(len(accountList))]
|
||||
|
||||
// 30% chance to add a company name
|
||||
if rand.Intn(100) < 30 {
|
||||
company := companyList[rand.Intn(len(companyList))]
|
||||
return fmt.Sprintf("%s %s", company, accountType)
|
||||
}
|
||||
|
||||
return accountType
|
||||
}
|
||||
|
||||
// generateRandomCurrency returns a random currency code
|
||||
func generateRandomCurrency() string {
|
||||
return currencyList[rand.Intn(len(currencyList))]
|
||||
}
|
||||
|
||||
// generateRandomAmount generates a random amount between min and max
|
||||
func generateRandomAmount(min, max float64) float64 {
|
||||
return min + rand.Float64()*(max-min)
|
||||
}
|
||||
|
||||
// generateRandomComment generates a realistic transaction comment
|
||||
func generateRandomComment(fromAccount, toAccount *models.Account, amount float64) string {
|
||||
commentTypes := []string{
|
||||
"Payment",
|
||||
"Transfer",
|
||||
"Deposit",
|
||||
"Withdrawal",
|
||||
"Refund",
|
||||
"Purchase",
|
||||
"Subscription",
|
||||
"Salary",
|
||||
"Dividend",
|
||||
"Interest",
|
||||
}
|
||||
|
||||
commentType := commentTypes[rand.Intn(len(commentTypes))]
|
||||
|
||||
switch commentType {
|
||||
case "Payment":
|
||||
return fmt.Sprintf("Payment to %s", toAccount.Name)
|
||||
case "Transfer":
|
||||
return fmt.Sprintf("Transfer to %s", toAccount.Name)
|
||||
case "Deposit":
|
||||
return fmt.Sprintf("Deposit to %s", toAccount.Name)
|
||||
case "Withdrawal":
|
||||
return fmt.Sprintf("Withdrawal from %s", fromAccount.Name)
|
||||
case "Refund":
|
||||
return fmt.Sprintf("Refund from %s", fromAccount.Name)
|
||||
case "Purchase":
|
||||
company := companyList[rand.Intn(len(companyList))]
|
||||
return fmt.Sprintf("Purchase at %s", company)
|
||||
case "Subscription":
|
||||
company := companyList[rand.Intn(len(companyList))]
|
||||
return fmt.Sprintf("%s subscription", company)
|
||||
case "Salary":
|
||||
company := companyList[rand.Intn(len(companyList))]
|
||||
return fmt.Sprintf("Salary from %s", company)
|
||||
case "Dividend":
|
||||
company := companyList[rand.Intn(len(companyList))]
|
||||
return fmt.Sprintf("Dividend from %s", company)
|
||||
case "Interest":
|
||||
return fmt.Sprintf("Interest payment %.2f%%", rand.Float64()*5)
|
||||
default:
|
||||
return fmt.Sprintf("Transfer of %.2f %s", amount, fromAccount.Currency)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Create a temporary directory for the test
|
||||
tempDir, err := os.MkdirTemp("", "billing-perf-test-*")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create subdirectories
|
||||
dbPath := filepath.Join(tempDir, "db")
|
||||
metaPath := filepath.Join(tempDir, "meta")
|
||||
|
||||
if err := os.MkdirAll(dbPath, 0755); err != nil {
|
||||
log.Fatalf("Failed to create db directory: %v", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(metaPath, 0755); err != nil {
|
||||
log.Fatalf("Failed to create meta directory: %v", err)
|
||||
}
|
||||
|
||||
// Set environment variables to use the temp directory
|
||||
originalHome := os.Getenv("HOME")
|
||||
os.Setenv("HOME", tempDir)
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
|
||||
// Start CPU profiling
|
||||
cpuProfile, err := os.Create("cpu_profile.prof")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create CPU profile: %v", err)
|
||||
}
|
||||
pprof.StartCPUProfile(cpuProfile)
|
||||
defer pprof.StopCPUProfile()
|
||||
|
||||
// Create a new database store
|
||||
fmt.Println("Creating database store...")
|
||||
startTime := time.Now()
|
||||
dbStore, err := models.NewDBStore()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create database store: %v", err)
|
||||
}
|
||||
defer dbStore.Close()
|
||||
fmt.Printf("Database store created in %v\n", time.Since(startTime))
|
||||
|
||||
// Create user, account, and transaction stores
|
||||
userStore := &models.UserStore{Store: dbStore}
|
||||
accountStore := models.NewAccountStore(dbStore)
|
||||
transactionStore := models.NewTransactionStore(dbStore)
|
||||
|
||||
// Create users
|
||||
fmt.Printf("Creating %d users...\n", numUsers)
|
||||
startTime = time.Now()
|
||||
users := make([]*models.User, numUsers)
|
||||
for i := 0; i < numUsers; i++ {
|
||||
user := &models.User{
|
||||
Key: generateRandomName(),
|
||||
Accounts: []uint32{},
|
||||
}
|
||||
|
||||
if err := userStore.Save(user); err != nil {
|
||||
log.Fatalf("Failed to save user %d: %v", i, err)
|
||||
}
|
||||
|
||||
users[i] = user
|
||||
|
||||
if (i+1)%100 == 0 {
|
||||
fmt.Printf("Created %d users...\n", i+1)
|
||||
}
|
||||
}
|
||||
fmt.Printf("Created %d users in %v\n", numUsers, time.Since(startTime))
|
||||
|
||||
// Create accounts
|
||||
fmt.Printf("Creating %d accounts...\n", numAccounts)
|
||||
startTime = time.Now()
|
||||
accounts := make([]*models.Account, numAccounts)
|
||||
for i := 0; i < numAccounts; i++ {
|
||||
// Assign to a random user
|
||||
userIndex := rand.Intn(numUsers)
|
||||
user := users[userIndex]
|
||||
|
||||
account := &models.Account{
|
||||
UserID: user.ID,
|
||||
Name: generateRandomAccountName(),
|
||||
Currency: generateRandomCurrency(),
|
||||
Amount: generateRandomAmount(1000, 10000),
|
||||
}
|
||||
|
||||
if err := accountStore.Save(account); err != nil {
|
||||
log.Fatalf("Failed to save account %d: %v", i, err)
|
||||
}
|
||||
|
||||
accounts[i] = account
|
||||
|
||||
if (i+1)%300 == 0 {
|
||||
fmt.Printf("Created %d accounts...\n", i+1)
|
||||
}
|
||||
}
|
||||
fmt.Printf("Created %d accounts in %v\n", numAccounts, time.Since(startTime))
|
||||
|
||||
// Create transactions
|
||||
fmt.Printf("Creating %d transactions...\n", numTransactions)
|
||||
startTime = time.Now()
|
||||
var memStats runtime.MemStats
|
||||
|
||||
for i := 0; i < numTransactions; i++ {
|
||||
// Select random source and destination accounts
|
||||
fromIndex := rand.Intn(numAccounts)
|
||||
var toIndex int
|
||||
for {
|
||||
toIndex = rand.Intn(numAccounts)
|
||||
if toIndex != fromIndex {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fromAccount := accounts[fromIndex]
|
||||
toAccount := accounts[toIndex]
|
||||
|
||||
// Generate a realistic amount (usually smaller than the available balance)
|
||||
maxAmount := fromAccount.Amount * 0.2
|
||||
if maxAmount < 10 {
|
||||
maxAmount = 10
|
||||
}
|
||||
amount := generateRandomAmount(1, maxAmount)
|
||||
|
||||
transaction := &models.Transaction{
|
||||
From: fromAccount.ID,
|
||||
To: toAccount.ID,
|
||||
Comment: generateRandomComment(fromAccount, toAccount, amount),
|
||||
Amount: amount,
|
||||
}
|
||||
|
||||
if err := transactionStore.Save(transaction); err != nil {
|
||||
log.Fatalf("Failed to save transaction %d: %v", i, err)
|
||||
}
|
||||
|
||||
// Update our local copy of the accounts
|
||||
fromAccount.Amount -= amount
|
||||
toAccount.Amount += amount
|
||||
|
||||
if (i+1)%1000 == 0 {
|
||||
elapsed := time.Since(startTime)
|
||||
transPerSec := float64(i+1) / elapsed.Seconds()
|
||||
|
||||
// Collect memory stats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
fmt.Printf("Created %d transactions (%.2f/sec), Memory: %.2f MB\n",
|
||||
i+1,
|
||||
transPerSec,
|
||||
float64(memStats.Alloc)/1024/1024)
|
||||
}
|
||||
}
|
||||
|
||||
totalTime := time.Since(startTime)
|
||||
transPerSec := float64(numTransactions) / totalTime.Seconds()
|
||||
fmt.Printf("Created %d transactions in %v (%.2f transactions/sec)\n",
|
||||
numTransactions,
|
||||
totalTime,
|
||||
transPerSec)
|
||||
|
||||
// Memory profile
|
||||
memProfile, err := os.Create("mem_profile.prof")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create memory profile: %v", err)
|
||||
}
|
||||
runtime.GC() // Run garbage collection before taking memory profile
|
||||
if err := pprof.WriteHeapProfile(memProfile); err != nil {
|
||||
log.Fatalf("Failed to write memory profile: %v", err)
|
||||
}
|
||||
memProfile.Close()
|
||||
|
||||
// Final memory stats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
fmt.Printf("\nFinal Memory Stats:\n")
|
||||
fmt.Printf("Alloc: %.2f MB\n", float64(memStats.Alloc)/1024/1024)
|
||||
fmt.Printf("TotalAlloc: %.2f MB\n", float64(memStats.TotalAlloc)/1024/1024)
|
||||
fmt.Printf("Sys: %.2f MB\n", float64(memStats.Sys)/1024/1024)
|
||||
fmt.Printf("NumGC: %d\n", memStats.NumGC)
|
||||
|
||||
// Test retrieving random transactions
|
||||
fmt.Println("\nTesting random transaction retrieval...")
|
||||
startTime = time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
txID := rand.Uint32()%uint32(numTransactions) + 1
|
||||
_, err := transactionStore.GetByID(txID)
|
||||
if err != nil {
|
||||
// Skip errors for non-existent transactions
|
||||
continue
|
||||
}
|
||||
}
|
||||
fmt.Printf("Retrieved 1000 random transactions in %v\n", time.Since(startTime))
|
||||
|
||||
// Test retrieving random accounts
|
||||
fmt.Println("\nTesting random account retrieval...")
|
||||
startTime = time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
accountID := rand.Uint32()%uint32(numAccounts) + 1
|
||||
_, err := accountStore.GetByID(accountID)
|
||||
if err != nil {
|
||||
// Skip errors for non-existent accounts
|
||||
continue
|
||||
}
|
||||
}
|
||||
fmt.Printf("Retrieved 1000 random accounts in %v\n", time.Since(startTime))
|
||||
|
||||
// Test retrieving random users
|
||||
fmt.Println("\nTesting random user retrieval...")
|
||||
startTime = time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
userID := rand.Uint32()%uint32(numUsers) + 1
|
||||
_, err := userStore.GetByID(userID)
|
||||
if err != nil {
|
||||
// Skip errors for non-existent users
|
||||
continue
|
||||
}
|
||||
}
|
||||
fmt.Printf("Retrieved 1000 random users in %v\n", time.Since(startTime))
|
||||
|
||||
fmt.Println("\nPerformance test completed successfully!")
|
||||
fmt.Printf("CPU profile saved to cpu_profile.prof\n")
|
||||
fmt.Printf("Memory profile saved to mem_profile.prof\n")
|
||||
fmt.Printf("To analyze profiles, run: go tool pprof cpu_profile.prof\n")
|
||||
}
|
546
pkg2_dont_use/heroservices/billing/models/db.go
Normal file
546
pkg2_dont_use/heroservices/billing/models/db.go
Normal file
@@ -0,0 +1,546 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"git.ourworld.tf/herocode/heroagent/pkg/data/ourdb"
|
||||
"git.ourworld.tf/herocode/heroagent/pkg/data/radixtree"
|
||||
"git.ourworld.tf/herocode/heroagent/pkg/tools"
|
||||
)
|
||||
|
||||
// DBStore represents the central database store for all models
|
||||
type DBStore struct {
|
||||
DB *ourdb.OurDB
|
||||
Meta *radixtree.RadixTree
|
||||
}
|
||||
|
||||
// NewDBStore creates a new database store
|
||||
func NewDBStore() (*DBStore, error) {
|
||||
// Get home directory
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
// Create paths
|
||||
dbPath := filepath.Join(homeDir, "hero", "var", "server", "users", "db")
|
||||
metaPath := filepath.Join(homeDir, "hero", "var", "server", "users", "meta")
|
||||
|
||||
// Create directories if they don't exist
|
||||
if err := os.MkdirAll(dbPath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create db directory: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(metaPath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create meta directory: %w", err)
|
||||
}
|
||||
|
||||
// Create ourdb
|
||||
dbConfig := ourdb.DefaultConfig()
|
||||
dbConfig.Path = dbPath
|
||||
dbConfig.IncrementalMode = true
|
||||
db, err := ourdb.New(dbConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ourdb: %w", err)
|
||||
}
|
||||
|
||||
// Create radixtree
|
||||
meta, err := radixtree.New(radixtree.NewArgs{
|
||||
Path: metaPath,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create radixtree: %w", err)
|
||||
}
|
||||
|
||||
return &DBStore{
|
||||
DB: db,
|
||||
Meta: meta,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the database connections
|
||||
func (s *DBStore) Close() error {
|
||||
err1 := s.DB.Close()
|
||||
err2 := s.Meta.Close()
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
// UserStore handles the storage and retrieval of User objects
|
||||
type UserStore struct {
|
||||
Store *DBStore
|
||||
}
|
||||
|
||||
// NewUserStore creates a new UserStore
|
||||
func NewUserStore() (*UserStore, *DBStore, error) {
|
||||
store, err := NewDBStore()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &UserStore{
|
||||
Store: store,
|
||||
}, store, nil
|
||||
}
|
||||
|
||||
// Save saves a User to the database
|
||||
func (s *UserStore) Save(user *User) error {
|
||||
// If ID is 0, this is a new user
|
||||
if user.ID == 0 {
|
||||
// Fix the key
|
||||
fixedKey := tools.NameFix(user.Key)
|
||||
if fixedKey == "" {
|
||||
return errors.New("key cannot be empty")
|
||||
}
|
||||
user.Key = fixedKey
|
||||
|
||||
// Check if key already exists
|
||||
existingID, err := s.GetIDByKey(user.Key)
|
||||
if err == nil && existingID != 0 {
|
||||
return fmt.Errorf("user with key %s already exists", user.Key)
|
||||
}
|
||||
|
||||
// Get next ID
|
||||
nextID, err := s.Store.DB.GetNextID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get next ID: %w", err)
|
||||
}
|
||||
user.ID = nextID
|
||||
|
||||
// Save to ourdb
|
||||
data := user.Serialize()
|
||||
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save user to ourdb: %w", err)
|
||||
}
|
||||
|
||||
// Save key to radixtree
|
||||
idBytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(idBytes, user.ID)
|
||||
err = s.Store.Meta.Set(user.Key, idBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save user key to radixtree: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Update existing user
|
||||
// Get existing user to verify key
|
||||
existingUser, err := s.GetByID(user.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing user: %w", err)
|
||||
}
|
||||
|
||||
// If key changed, update radixtree
|
||||
if existingUser.Key != user.Key {
|
||||
// Fix the new key
|
||||
fixedKey := tools.NameFix(user.Key)
|
||||
if fixedKey == "" {
|
||||
return errors.New("key cannot be empty")
|
||||
}
|
||||
user.Key = fixedKey
|
||||
|
||||
// Check if new key already exists
|
||||
existingID, err := s.GetIDByKey(user.Key)
|
||||
if err == nil && existingID != 0 && existingID != user.ID {
|
||||
return fmt.Errorf("user with key %s already exists", user.Key)
|
||||
}
|
||||
|
||||
// Delete old key from radixtree
|
||||
err = s.Store.Meta.Delete(existingUser.Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old user key from radixtree: %w", err)
|
||||
}
|
||||
|
||||
// Save new key to radixtree
|
||||
idBytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(idBytes, user.ID)
|
||||
err = s.Store.Meta.Set(user.Key, idBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save new user key to radixtree: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save to ourdb
|
||||
data := user.Serialize()
|
||||
id := user.ID
|
||||
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
|
||||
ID: &id,
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user in ourdb: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a User by ID
|
||||
func (s *UserStore) GetByID(id uint32) (*User, error) {
|
||||
data, err := s.Store.DB.Get(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user from ourdb: %w", err)
|
||||
}
|
||||
|
||||
user, err := DeserializeUser(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to deserialize user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetByKey retrieves a User by key
|
||||
func (s *UserStore) GetByKey(key string) (*User, error) {
|
||||
// Fix the key
|
||||
fixedKey := tools.NameFix(key)
|
||||
if fixedKey == "" {
|
||||
return nil, errors.New("key cannot be empty")
|
||||
}
|
||||
|
||||
// Get ID from radixtree
|
||||
id, err := s.GetIDByKey(fixedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user ID from radixtree: %w", err)
|
||||
}
|
||||
|
||||
// Get user from ourdb
|
||||
return s.GetByID(id)
|
||||
}
|
||||
|
||||
// GetIDByKey retrieves a user ID by key
|
||||
func (s *UserStore) GetIDByKey(key string) (uint32, error) {
|
||||
// Get ID from radixtree
|
||||
idBytes, err := s.Store.Meta.Get(key)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get user ID from radixtree: %w", err)
|
||||
}
|
||||
|
||||
if len(idBytes) != 4 {
|
||||
return 0, fmt.Errorf("invalid ID bytes length: %d", len(idBytes))
|
||||
}
|
||||
|
||||
return binary.LittleEndian.Uint32(idBytes), nil
|
||||
}
|
||||
|
||||
// Delete deletes a User by ID
|
||||
func (s *UserStore) Delete(id uint32) error {
|
||||
// Get user to get key
|
||||
user, err := s.GetByID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Delete from radixtree
|
||||
err = s.Store.Meta.Delete(user.Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user key from radixtree: %w", err)
|
||||
}
|
||||
|
||||
// Delete from ourdb
|
||||
err = s.Store.DB.Delete(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user from ourdb: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountStore handles the storage and retrieval of Account objects
|
||||
type AccountStore struct {
|
||||
Store *DBStore
|
||||
}
|
||||
|
||||
// NewAccountStore creates a new AccountStore using the provided DBStore
|
||||
func NewAccountStore(store *DBStore) *AccountStore {
|
||||
return &AccountStore{
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Save saves an Account to the database
|
||||
func (s *AccountStore) Save(account *Account) error {
|
||||
// Validate account data
|
||||
if err := account.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If ID is 0, this is a new account
|
||||
if account.ID == 0 {
|
||||
// Get next ID
|
||||
nextID, err := s.Store.DB.GetNextID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get next ID: %w", err)
|
||||
}
|
||||
account.ID = nextID
|
||||
|
||||
// Save to ourdb
|
||||
data := account.Serialize()
|
||||
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save account to ourdb: %w", err)
|
||||
}
|
||||
|
||||
// Update user's accounts list
|
||||
userStore := &UserStore{Store: s.Store}
|
||||
user, err := userStore.GetByID(account.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Add account ID to user's accounts
|
||||
user.Accounts = append(user.Accounts, account.ID)
|
||||
|
||||
// Save updated user
|
||||
err = userStore.Save(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user's accounts: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Update existing account
|
||||
// Get existing account to verify it exists
|
||||
_, err := s.GetByID(account.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing account: %w", err)
|
||||
}
|
||||
|
||||
// Save to ourdb
|
||||
data := account.Serialize()
|
||||
id := account.ID
|
||||
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
|
||||
ID: &id,
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update account in ourdb: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves an Account by ID
|
||||
func (s *AccountStore) GetByID(id uint32) (*Account, error) {
|
||||
data, err := s.Store.DB.Get(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account from ourdb: %w", err)
|
||||
}
|
||||
|
||||
account, err := DeserializeAccount(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to deserialize account: %w", err)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// GetByUserID retrieves all Accounts for a user
|
||||
func (s *AccountStore) GetByUserID(userID uint32) ([]*Account, error) {
|
||||
// Get user to get account IDs
|
||||
userStore := &UserStore{Store: s.Store}
|
||||
user, err := userStore.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Get accounts
|
||||
accounts := make([]*Account, 0, len(user.Accounts))
|
||||
for _, accountID := range user.Accounts {
|
||||
account, err := s.GetByID(accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account %d: %w", accountID, err)
|
||||
}
|
||||
accounts = append(accounts, account)
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// Delete deletes an Account by ID
|
||||
func (s *AccountStore) Delete(id uint32) error {
|
||||
// Get account to get user ID
|
||||
account, err := s.GetByID(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %w", err)
|
||||
}
|
||||
|
||||
// Get user to remove account from accounts list
|
||||
userStore := &UserStore{Store: s.Store}
|
||||
user, err := userStore.GetByID(account.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Remove account ID from user's accounts
|
||||
for i, accountID := range user.Accounts {
|
||||
if accountID == id {
|
||||
user.Accounts = append(user.Accounts[:i], user.Accounts[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Save updated user
|
||||
err = userStore.Save(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user's accounts: %w", err)
|
||||
}
|
||||
|
||||
// Delete from ourdb
|
||||
err = s.Store.DB.Delete(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete account from ourdb: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TransactionStore handles the storage and retrieval of Transaction objects
|
||||
type TransactionStore struct {
|
||||
Store *DBStore
|
||||
}
|
||||
|
||||
// NewTransactionStore creates a new TransactionStore using the provided DBStore
|
||||
func NewTransactionStore(store *DBStore) *TransactionStore {
|
||||
return &TransactionStore{
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Save saves a Transaction to the database and updates account balances
|
||||
func (s *TransactionStore) Save(transaction *Transaction) error {
|
||||
// Validate transaction data
|
||||
if err := transaction.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get account store
|
||||
accountStore := &AccountStore{Store: s.Store}
|
||||
|
||||
// Get source account
|
||||
fromAccount, err := accountStore.GetByID(transaction.From)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get source account: %w", err)
|
||||
}
|
||||
|
||||
// Get destination account
|
||||
toAccount, err := accountStore.GetByID(transaction.To)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get destination account: %w", err)
|
||||
}
|
||||
|
||||
// Check if source account has enough balance
|
||||
if fromAccount.Amount < transaction.Amount {
|
||||
return errors.New("insufficient balance in source account")
|
||||
}
|
||||
|
||||
// If TxID is 0, this is a new transaction
|
||||
if transaction.TxID == 0 {
|
||||
// Get next ID
|
||||
nextID, err := s.Store.DB.GetNextID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get next ID: %w", err)
|
||||
}
|
||||
transaction.TxID = nextID
|
||||
|
||||
// Update account balances
|
||||
fromAccount.Amount -= transaction.Amount
|
||||
toAccount.Amount += transaction.Amount
|
||||
|
||||
// Save updated accounts
|
||||
if err := accountStore.Save(fromAccount); err != nil {
|
||||
return fmt.Errorf("failed to update source account: %w", err)
|
||||
}
|
||||
|
||||
if err := accountStore.Save(toAccount); err != nil {
|
||||
// Rollback source account change if destination update fails
|
||||
fromAccount.Amount += transaction.Amount
|
||||
if rollbackErr := accountStore.Save(fromAccount); rollbackErr != nil {
|
||||
return fmt.Errorf("failed to update destination account and rollback failed: %w, rollback error: %v", err, rollbackErr)
|
||||
}
|
||||
return fmt.Errorf("failed to update destination account: %w", err)
|
||||
}
|
||||
|
||||
// Save transaction to ourdb
|
||||
data := transaction.Serialize()
|
||||
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
// Rollback account changes if transaction save fails
|
||||
fromAccount.Amount += transaction.Amount
|
||||
toAccount.Amount -= transaction.Amount
|
||||
if rollbackErr1 := accountStore.Save(fromAccount); rollbackErr1 != nil {
|
||||
return fmt.Errorf("failed to save transaction and rollback failed: %w, rollback error 1: %v", err, rollbackErr1)
|
||||
}
|
||||
if rollbackErr2 := accountStore.Save(toAccount); rollbackErr2 != nil {
|
||||
return fmt.Errorf("failed to save transaction and rollback failed: %w, rollback error 2: %v", err, rollbackErr2)
|
||||
}
|
||||
return fmt.Errorf("failed to save transaction to ourdb: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Updating existing transactions is not allowed as it would affect account balances
|
||||
return errors.New("updating existing transactions is not allowed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a Transaction by ID
|
||||
func (s *TransactionStore) GetByID(id uint32) (*Transaction, error) {
|
||||
data, err := s.Store.DB.Get(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get transaction from ourdb: %w", err)
|
||||
}
|
||||
|
||||
transaction, err := DeserializeTransaction(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to deserialize transaction: %w", err)
|
||||
}
|
||||
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
// GetByAccount retrieves all Transactions for an account (either as source or destination)
|
||||
func (s *TransactionStore) GetByAccount(accountID uint32) ([]*Transaction, error) {
|
||||
// This is a simplified implementation that would be inefficient in a real system
|
||||
// In a real system, we would maintain indexes for account transactions
|
||||
|
||||
// Get the next ID to determine the range of transaction IDs
|
||||
nextID, err := s.Store.DB.GetNextID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get next ID: %w", err)
|
||||
}
|
||||
|
||||
transactions := make([]*Transaction, 0)
|
||||
|
||||
// Iterate through all transactions (this is inefficient but works for a simple implementation)
|
||||
for i := uint32(1); i < nextID; i++ {
|
||||
transaction, err := s.GetByID(i)
|
||||
if err != nil {
|
||||
// Skip if transaction doesn't exist or can't be deserialized
|
||||
continue
|
||||
}
|
||||
|
||||
// Include transaction if it involves the specified account
|
||||
if transaction.From == accountID || transaction.To == accountID {
|
||||
transactions = append(transactions, transaction)
|
||||
}
|
||||
}
|
||||
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// Delete is not implemented for transactions as it would affect account balances
|
||||
// In a real system, you might want to implement a "void" or "reverse" transaction instead
|
||||
func (s *TransactionStore) Delete(id uint32) error {
|
||||
return errors.New("deleting transactions is not allowed")
|
||||
}
|
489
pkg2_dont_use/heroservices/billing/models/models_test.go
Normal file
489
pkg2_dont_use/heroservices/billing/models/models_test.go
Normal file
@@ -0,0 +1,489 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setupTestEnvironment creates a temporary test environment
|
||||
func setupTestEnvironment(t *testing.T) (string, func()) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir, err := os.MkdirTemp("", "billing-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
|
||||
// Create subdirectories
|
||||
dbPath := filepath.Join(tempDir, "db")
|
||||
metaPath := filepath.Join(tempDir, "meta")
|
||||
|
||||
if err := os.MkdirAll(dbPath, 0755); err != nil {
|
||||
t.Fatalf("Failed to create db directory: %v", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(metaPath, 0755); err != nil {
|
||||
t.Fatalf("Failed to create meta directory: %v", err)
|
||||
}
|
||||
|
||||
// Set environment variables to use the temp directory
|
||||
originalHome := os.Getenv("HOME")
|
||||
os.Setenv("HOME", tempDir)
|
||||
|
||||
// Return cleanup function
|
||||
cleanup := func() {
|
||||
os.Setenv("HOME", originalHome)
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return tempDir, cleanup
|
||||
}
|
||||
|
||||
// TestUserCRUD tests the CRUD operations for User
|
||||
func TestUserCRUD(t *testing.T) {
|
||||
_, cleanup := setupTestEnvironment(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a new database store
|
||||
dbStore, err := NewDBStore()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database store: %v", err)
|
||||
}
|
||||
defer dbStore.Close()
|
||||
|
||||
// Create a new user store
|
||||
userStore := &UserStore{Store: dbStore}
|
||||
|
||||
// Create a new user
|
||||
user := &User{
|
||||
Key: "john_doe",
|
||||
Accounts: []uint32{},
|
||||
}
|
||||
|
||||
// Save the user
|
||||
err = userStore.Save(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save user: %v", err)
|
||||
}
|
||||
|
||||
// Verify the user ID was set
|
||||
if user.ID == 0 {
|
||||
t.Fatalf("User ID was not set")
|
||||
}
|
||||
|
||||
// Get the user by ID
|
||||
retrievedUser, err := userStore.GetByID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user by ID: %v", err)
|
||||
}
|
||||
|
||||
// Verify the user data
|
||||
if retrievedUser.ID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, retrievedUser.ID)
|
||||
}
|
||||
if retrievedUser.Key != user.Key {
|
||||
t.Errorf("Expected user key %s, got %s", user.Key, retrievedUser.Key)
|
||||
}
|
||||
|
||||
// Get the user by key
|
||||
retrievedUser, err = userStore.GetByKey("john_doe")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user by key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the user data
|
||||
if retrievedUser.ID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, retrievedUser.ID)
|
||||
}
|
||||
|
||||
// Update the user
|
||||
user.Key = "jane_doe"
|
||||
err = userStore.Save(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update user: %v", err)
|
||||
}
|
||||
|
||||
// Verify the update
|
||||
retrievedUser, err = userStore.GetByKey("jane_doe")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated user by key: %v", err)
|
||||
}
|
||||
if retrievedUser.Key != "jane_doe" {
|
||||
t.Errorf("Expected updated user key 'jane_doe', got %s", retrievedUser.Key)
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
err = userStore.Delete(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
// Verify the user was deleted
|
||||
_, err = userStore.GetByID(user.ID)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when getting deleted user, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccountCRUD tests the CRUD operations for Account
|
||||
func TestAccountCRUD(t *testing.T) {
|
||||
_, cleanup := setupTestEnvironment(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a new database store
|
||||
dbStore, err := NewDBStore()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database store: %v", err)
|
||||
}
|
||||
defer dbStore.Close()
|
||||
|
||||
// Create a new user store
|
||||
userStore := &UserStore{Store: dbStore}
|
||||
|
||||
// Create a new user
|
||||
user := &User{
|
||||
Key: "john_doe",
|
||||
Accounts: []uint32{},
|
||||
}
|
||||
|
||||
// Save the user
|
||||
err = userStore.Save(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save user: %v", err)
|
||||
}
|
||||
|
||||
// Create a new account store
|
||||
accountStore := NewAccountStore(dbStore)
|
||||
|
||||
// Create a new account
|
||||
account := &Account{
|
||||
UserID: user.ID,
|
||||
Name: "Savings",
|
||||
Currency: "USD",
|
||||
Amount: 1000.0,
|
||||
}
|
||||
|
||||
// Save the account
|
||||
err = accountStore.Save(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save account: %v", err)
|
||||
}
|
||||
|
||||
// Verify the account ID was set
|
||||
if account.ID == 0 {
|
||||
t.Fatalf("Account ID was not set")
|
||||
}
|
||||
|
||||
// Get the account by ID
|
||||
retrievedAccount, err := accountStore.GetByID(account.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account by ID: %v", err)
|
||||
}
|
||||
|
||||
// Verify the account data
|
||||
if retrievedAccount.ID != account.ID {
|
||||
t.Errorf("Expected account ID %d, got %d", account.ID, retrievedAccount.ID)
|
||||
}
|
||||
if retrievedAccount.UserID != user.ID {
|
||||
t.Errorf("Expected account user ID %d, got %d", user.ID, retrievedAccount.UserID)
|
||||
}
|
||||
if retrievedAccount.Name != "Savings" {
|
||||
t.Errorf("Expected account name 'Savings', got %s", retrievedAccount.Name)
|
||||
}
|
||||
if retrievedAccount.Currency != "USD" {
|
||||
t.Errorf("Expected account currency 'USD', got %s", retrievedAccount.Currency)
|
||||
}
|
||||
if retrievedAccount.Amount != 1000.0 {
|
||||
t.Errorf("Expected account amount 1000.0, got %f", retrievedAccount.Amount)
|
||||
}
|
||||
|
||||
// Verify the user's accounts list was updated
|
||||
retrievedUser, err := userStore.GetByID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user: %v", err)
|
||||
}
|
||||
if len(retrievedUser.Accounts) != 1 {
|
||||
t.Errorf("Expected user to have 1 account, got %d", len(retrievedUser.Accounts))
|
||||
}
|
||||
if retrievedUser.Accounts[0] != account.ID {
|
||||
t.Errorf("Expected user account ID %d, got %d", account.ID, retrievedUser.Accounts[0])
|
||||
}
|
||||
|
||||
// Update the account
|
||||
account.Name = "Checking"
|
||||
account.Amount = 2000.0
|
||||
err = accountStore.Save(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update account: %v", err)
|
||||
}
|
||||
|
||||
// Verify the update
|
||||
retrievedAccount, err = accountStore.GetByID(account.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated account: %v", err)
|
||||
}
|
||||
if retrievedAccount.Name != "Checking" {
|
||||
t.Errorf("Expected updated account name 'Checking', got %s", retrievedAccount.Name)
|
||||
}
|
||||
if retrievedAccount.Amount != 2000.0 {
|
||||
t.Errorf("Expected updated account amount 2000.0, got %f", retrievedAccount.Amount)
|
||||
}
|
||||
|
||||
// Delete the account
|
||||
err = accountStore.Delete(account.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete account: %v", err)
|
||||
}
|
||||
|
||||
// Verify the account was deleted
|
||||
_, err = accountStore.GetByID(account.ID)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when getting deleted account, got nil")
|
||||
}
|
||||
|
||||
// Verify the user's accounts list was updated
|
||||
retrievedUser, err = userStore.GetByID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user after account deletion: %v", err)
|
||||
}
|
||||
if len(retrievedUser.Accounts) != 0 {
|
||||
t.Errorf("Expected user to have 0 accounts after deletion, got %d", len(retrievedUser.Accounts))
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransactionCRUD tests the CRUD operations for Transaction
|
||||
func TestTransactionCRUD(t *testing.T) {
|
||||
_, cleanup := setupTestEnvironment(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a new database store
|
||||
dbStore, err := NewDBStore()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database store: %v", err)
|
||||
}
|
||||
defer dbStore.Close()
|
||||
|
||||
// Create a new user store
|
||||
userStore := &UserStore{Store: dbStore}
|
||||
|
||||
// Create a new user
|
||||
user := &User{
|
||||
Key: "john_doe",
|
||||
Accounts: []uint32{},
|
||||
}
|
||||
|
||||
// Save the user
|
||||
err = userStore.Save(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save user: %v", err)
|
||||
}
|
||||
|
||||
// Create a new account store
|
||||
accountStore := NewAccountStore(dbStore)
|
||||
|
||||
// Create source account
|
||||
sourceAccount := &Account{
|
||||
UserID: user.ID,
|
||||
Name: "Checking",
|
||||
Currency: "USD",
|
||||
Amount: 1000.0,
|
||||
}
|
||||
|
||||
// Save the source account
|
||||
err = accountStore.Save(sourceAccount)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save source account: %v", err)
|
||||
}
|
||||
|
||||
// Create destination account
|
||||
destAccount := &Account{
|
||||
UserID: user.ID,
|
||||
Name: "Savings",
|
||||
Currency: "USD",
|
||||
Amount: 500.0,
|
||||
}
|
||||
|
||||
// Save the destination account
|
||||
err = accountStore.Save(destAccount)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save destination account: %v", err)
|
||||
}
|
||||
|
||||
// Create a new transaction store
|
||||
transactionStore := NewTransactionStore(dbStore)
|
||||
|
||||
// Create a new transaction
|
||||
transaction := &Transaction{
|
||||
From: sourceAccount.ID,
|
||||
To: destAccount.ID,
|
||||
Comment: "Transfer to savings",
|
||||
Amount: 200.0,
|
||||
}
|
||||
|
||||
// Save the transaction
|
||||
err = transactionStore.Save(transaction)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save transaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify the transaction ID was set
|
||||
if transaction.TxID == 0 {
|
||||
t.Fatalf("Transaction ID was not set")
|
||||
}
|
||||
|
||||
// Get the transaction by ID
|
||||
retrievedTransaction, err := transactionStore.GetByID(transaction.TxID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get transaction by ID: %v", err)
|
||||
}
|
||||
|
||||
// Verify the transaction data
|
||||
if retrievedTransaction.TxID != transaction.TxID {
|
||||
t.Errorf("Expected transaction ID %d, got %d", transaction.TxID, retrievedTransaction.TxID)
|
||||
}
|
||||
if retrievedTransaction.From != sourceAccount.ID {
|
||||
t.Errorf("Expected transaction from %d, got %d", sourceAccount.ID, retrievedTransaction.From)
|
||||
}
|
||||
if retrievedTransaction.To != destAccount.ID {
|
||||
t.Errorf("Expected transaction to %d, got %d", destAccount.ID, retrievedTransaction.To)
|
||||
}
|
||||
if retrievedTransaction.Comment != "Transfer to savings" {
|
||||
t.Errorf("Expected transaction comment 'Transfer to savings', got %s", retrievedTransaction.Comment)
|
||||
}
|
||||
if retrievedTransaction.Amount != 200.0 {
|
||||
t.Errorf("Expected transaction amount 200.0, got %f", retrievedTransaction.Amount)
|
||||
}
|
||||
|
||||
// Verify account balances were updated
|
||||
updatedSourceAccount, err := accountStore.GetByID(sourceAccount.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated source account: %v", err)
|
||||
}
|
||||
if updatedSourceAccount.Amount != 800.0 {
|
||||
t.Errorf("Expected source account amount 800.0, got %f", updatedSourceAccount.Amount)
|
||||
}
|
||||
|
||||
updatedDestAccount, err := accountStore.GetByID(destAccount.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated destination account: %v", err)
|
||||
}
|
||||
if updatedDestAccount.Amount != 700.0 {
|
||||
t.Errorf("Expected destination account amount 700.0, got %f", updatedDestAccount.Amount)
|
||||
}
|
||||
|
||||
// Verify getting transactions by account
|
||||
sourceTransactions, err := transactionStore.GetByAccount(sourceAccount.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get transactions by source account: %v", err)
|
||||
}
|
||||
if len(sourceTransactions) != 1 {
|
||||
t.Errorf("Expected 1 transaction for source account, got %d", len(sourceTransactions))
|
||||
}
|
||||
|
||||
destTransactions, err := transactionStore.GetByAccount(destAccount.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get transactions by destination account: %v", err)
|
||||
}
|
||||
if len(destTransactions) != 1 {
|
||||
t.Errorf("Expected 1 transaction for destination account, got %d", len(destTransactions))
|
||||
}
|
||||
|
||||
// Verify that deleting transactions is not allowed
|
||||
err = transactionStore.Delete(transaction.TxID)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when deleting transaction, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializationDeserialization tests the serialization and deserialization of models
|
||||
func TestSerializationDeserialization(t *testing.T) {
|
||||
// Test User serialization/deserialization
|
||||
user := &User{
|
||||
ID: 42,
|
||||
Key: "test_user",
|
||||
Accounts: []uint32{1, 2, 3},
|
||||
}
|
||||
|
||||
userData := user.Serialize()
|
||||
deserializedUser, err := DeserializeUser(userData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize user: %v", err)
|
||||
}
|
||||
|
||||
if deserializedUser.ID != user.ID {
|
||||
t.Errorf("Expected user ID %d, got %d", user.ID, deserializedUser.ID)
|
||||
}
|
||||
if deserializedUser.Key != user.Key {
|
||||
t.Errorf("Expected user key %s, got %s", user.Key, deserializedUser.Key)
|
||||
}
|
||||
if len(deserializedUser.Accounts) != len(user.Accounts) {
|
||||
t.Errorf("Expected %d accounts, got %d", len(user.Accounts), len(deserializedUser.Accounts))
|
||||
}
|
||||
for i, accountID := range user.Accounts {
|
||||
if deserializedUser.Accounts[i] != accountID {
|
||||
t.Errorf("Expected account ID %d at index %d, got %d", accountID, i, deserializedUser.Accounts[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Test Account serialization/deserialization
|
||||
account := &Account{
|
||||
ID: 24,
|
||||
UserID: 42,
|
||||
Name: "Test Account",
|
||||
Currency: "USD",
|
||||
Amount: 123.45,
|
||||
}
|
||||
|
||||
accountData := account.Serialize()
|
||||
deserializedAccount, err := DeserializeAccount(accountData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize account: %v", err)
|
||||
}
|
||||
|
||||
if deserializedAccount.ID != account.ID {
|
||||
t.Errorf("Expected account ID %d, got %d", account.ID, deserializedAccount.ID)
|
||||
}
|
||||
if deserializedAccount.UserID != account.UserID {
|
||||
t.Errorf("Expected account user ID %d, got %d", account.UserID, deserializedAccount.UserID)
|
||||
}
|
||||
if deserializedAccount.Name != account.Name {
|
||||
t.Errorf("Expected account name %s, got %s", account.Name, deserializedAccount.Name)
|
||||
}
|
||||
if deserializedAccount.Currency != account.Currency {
|
||||
t.Errorf("Expected account currency %s, got %s", account.Currency, deserializedAccount.Currency)
|
||||
}
|
||||
if deserializedAccount.Amount != account.Amount {
|
||||
t.Errorf("Expected account amount %f, got %f", account.Amount, deserializedAccount.Amount)
|
||||
}
|
||||
|
||||
// Test Transaction serialization/deserialization
|
||||
transaction := &Transaction{
|
||||
TxID: 123,
|
||||
From: 42,
|
||||
To: 24,
|
||||
Comment: "Test Transaction",
|
||||
Amount: 67.89,
|
||||
}
|
||||
|
||||
transactionData := transaction.Serialize()
|
||||
deserializedTransaction, err := DeserializeTransaction(transactionData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize transaction: %v", err)
|
||||
}
|
||||
|
||||
if deserializedTransaction.TxID != transaction.TxID {
|
||||
t.Errorf("Expected transaction ID %d, got %d", transaction.TxID, deserializedTransaction.TxID)
|
||||
}
|
||||
if deserializedTransaction.From != transaction.From {
|
||||
t.Errorf("Expected transaction from %d, got %d", transaction.From, deserializedTransaction.From)
|
||||
}
|
||||
if deserializedTransaction.To != transaction.To {
|
||||
t.Errorf("Expected transaction to %d, got %d", transaction.To, deserializedTransaction.To)
|
||||
}
|
||||
if deserializedTransaction.Comment != transaction.Comment {
|
||||
t.Errorf("Expected transaction comment %s, got %s", transaction.Comment, deserializedTransaction.Comment)
|
||||
}
|
||||
if deserializedTransaction.Amount != transaction.Amount {
|
||||
t.Errorf("Expected transaction amount %f, got %f", transaction.Amount, deserializedTransaction.Amount)
|
||||
}
|
||||
}
|
100
pkg2_dont_use/heroservices/billing/models/transaction.go
Normal file
100
pkg2_dont_use/heroservices/billing/models/transaction.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math"
|
||||
)
|
||||
|
||||
// Transaction represents a financial transaction in the billing system
|
||||
type Transaction struct {
|
||||
TxID uint32 // Transaction ID, auto-incremented
|
||||
From uint32 // Source account ID
|
||||
To uint32 // Destination account ID
|
||||
Comment string // Transaction comment
|
||||
Amount float64 // Transaction amount
|
||||
}
|
||||
|
||||
// Serialize converts a Transaction to a binary representation
|
||||
func (t *Transaction) Serialize() []byte {
|
||||
// Calculate the size of the serialized data
|
||||
// 4 bytes for TxID + 4 bytes for From + 4 bytes for To +
|
||||
// 2 bytes for comment length + len(comment) bytes + 8 bytes for amount
|
||||
size := 4 + 4 + 4 + 2 + len(t.Comment) + 8
|
||||
data := make([]byte, size)
|
||||
|
||||
// Write TxID (4 bytes)
|
||||
binary.LittleEndian.PutUint32(data[0:4], t.TxID)
|
||||
|
||||
// Write From (4 bytes)
|
||||
binary.LittleEndian.PutUint32(data[4:8], t.From)
|
||||
|
||||
// Write To (4 bytes)
|
||||
binary.LittleEndian.PutUint32(data[8:12], t.To)
|
||||
|
||||
// Write comment length (2 bytes) and comment
|
||||
commentLen := uint16(len(t.Comment))
|
||||
binary.LittleEndian.PutUint16(data[12:14], commentLen)
|
||||
copy(data[14:14+commentLen], t.Comment)
|
||||
|
||||
// Write amount (8 bytes)
|
||||
amountOffset := 14 + int(commentLen)
|
||||
binary.LittleEndian.PutUint64(data[amountOffset:amountOffset+8], math.Float64bits(t.Amount))
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// DeserializeTransaction converts a binary representation back to a Transaction
|
||||
func DeserializeTransaction(data []byte) (*Transaction, error) {
|
||||
if len(data) < 22 { // Minimum size: 4 (TxID) + 4 (From) + 4 (To) + 2 (comment length) + 8 (amount)
|
||||
return nil, errors.New("data too short to deserialize Transaction")
|
||||
}
|
||||
|
||||
transaction := &Transaction{}
|
||||
|
||||
// Read TxID
|
||||
transaction.TxID = binary.LittleEndian.Uint32(data[0:4])
|
||||
|
||||
// Read From
|
||||
transaction.From = binary.LittleEndian.Uint32(data[4:8])
|
||||
|
||||
// Read To
|
||||
transaction.To = binary.LittleEndian.Uint32(data[8:12])
|
||||
|
||||
// Read comment length and comment
|
||||
commentLen := binary.LittleEndian.Uint16(data[12:14])
|
||||
if 14+commentLen > uint16(len(data)) {
|
||||
return nil, errors.New("data too short to read comment")
|
||||
}
|
||||
transaction.Comment = string(data[14 : 14+commentLen])
|
||||
|
||||
// Read amount
|
||||
amountOffset := 14 + commentLen
|
||||
if int(amountOffset)+8 > len(data) {
|
||||
return nil, errors.New("data too short to read amount")
|
||||
}
|
||||
transaction.Amount = math.Float64frombits(binary.LittleEndian.Uint64(data[amountOffset : amountOffset+8]))
|
||||
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
// Validate checks if the transaction data is valid
|
||||
func (t *Transaction) Validate() error {
|
||||
if t.From == 0 {
|
||||
return errors.New("source account ID cannot be zero")
|
||||
}
|
||||
|
||||
if t.To == 0 {
|
||||
return errors.New("destination account ID cannot be zero")
|
||||
}
|
||||
|
||||
if t.From == t.To {
|
||||
return errors.New("source and destination accounts cannot be the same")
|
||||
}
|
||||
|
||||
if t.Amount <= 0 {
|
||||
return errors.New("amount must be positive")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
78
pkg2_dont_use/heroservices/billing/models/user.go
Normal file
78
pkg2_dont_use/heroservices/billing/models/user.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// User represents a user in the billing system
|
||||
type User struct {
|
||||
ID uint32 // Unique identifier, auto-incremented
|
||||
Key string // Unique key for the user
|
||||
Accounts []uint32 // Links to the accounts a user has
|
||||
}
|
||||
|
||||
// Serialize converts a User to a binary representation
|
||||
func (u *User) Serialize() []byte {
|
||||
// Calculate the size of the serialized data
|
||||
// 4 bytes for ID + 2 bytes for key length + len(key) bytes + 2 bytes for accounts length + 4 bytes per account
|
||||
size := 4 + 2 + len(u.Key) + 2 + (4 * len(u.Accounts))
|
||||
data := make([]byte, size)
|
||||
|
||||
// Write ID (4 bytes)
|
||||
binary.LittleEndian.PutUint32(data[0:4], u.ID)
|
||||
|
||||
// Write key length (2 bytes) and key
|
||||
keyLen := uint16(len(u.Key))
|
||||
binary.LittleEndian.PutUint16(data[4:6], keyLen)
|
||||
copy(data[6:6+keyLen], u.Key)
|
||||
|
||||
// Write accounts length (2 bytes) and accounts
|
||||
accountsOffset := 6 + keyLen
|
||||
accountsLen := uint16(len(u.Accounts))
|
||||
binary.LittleEndian.PutUint16(data[accountsOffset:accountsOffset+2], accountsLen)
|
||||
|
||||
for i, accountID := range u.Accounts {
|
||||
offset := int(accountsOffset) + 2 + (4 * i)
|
||||
binary.LittleEndian.PutUint32(data[offset:offset+4], accountID)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// DeserializeUser converts a binary representation back to a User
|
||||
func DeserializeUser(data []byte) (*User, error) {
|
||||
if len(data) < 8 { // Minimum size: 4 (ID) + 2 (key length) + 2 (accounts length)
|
||||
return nil, errors.New("data too short to deserialize User")
|
||||
}
|
||||
|
||||
user := &User{}
|
||||
|
||||
// Read ID
|
||||
user.ID = binary.LittleEndian.Uint32(data[0:4])
|
||||
|
||||
// Read key length and key
|
||||
keyLen := binary.LittleEndian.Uint16(data[4:6])
|
||||
if 6+keyLen > uint16(len(data)) {
|
||||
return nil, errors.New("data too short to read key")
|
||||
}
|
||||
user.Key = string(data[6 : 6+keyLen])
|
||||
|
||||
// Read accounts length and accounts
|
||||
accountsOffset := 6 + keyLen
|
||||
if accountsOffset+2 > uint16(len(data)) {
|
||||
return nil, errors.New("data too short to read accounts length")
|
||||
}
|
||||
accountsLen := binary.LittleEndian.Uint16(data[accountsOffset : accountsOffset+2])
|
||||
|
||||
user.Accounts = make([]uint32, accountsLen)
|
||||
for i := uint16(0); i < accountsLen; i++ {
|
||||
offset := accountsOffset + 2 + (4 * i)
|
||||
if offset+4 > uint16(len(data)) {
|
||||
return nil, errors.New("data too short to read account ID")
|
||||
}
|
||||
user.Accounts[i] = binary.LittleEndian.Uint32(data[offset : offset+4])
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
98
pkg2_dont_use/heroservices/openaiproxy/cmd/main.go
Normal file
98
pkg2_dont_use/heroservices/openaiproxy/cmd/main.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
openaiproxy "git.ourworld.tf/herocode/heroagent/pkg/heroservices/openaiproxy"
|
||||
"github.com/openai/openai-go/option"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Start the server in a goroutine
|
||||
go runServerMode()
|
||||
|
||||
// Wait a moment for the server to start
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Test the proxy with a client
|
||||
testProxyWithClient()
|
||||
|
||||
// Keep the main function running
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("Shutting down...")
|
||||
}
|
||||
|
||||
// testProxyWithClient tests the proxy using the OpenAI Go client
|
||||
func testProxyWithClient() {
|
||||
log.Println("Testing proxy with OpenAI Go client...")
|
||||
|
||||
// Create a client that points to our proxy
|
||||
// Note: The server is using "/ai" as the prefix for all routes
|
||||
client := openaiproxy.NewClient(
|
||||
option.WithAPIKey("test-key"), // This is our test key, not a real OpenAI key
|
||||
option.WithBaseURL("http://localhost:8080/ai"), // Use the /ai prefix to match the server routes
|
||||
)
|
||||
|
||||
// Create a completion request
|
||||
chatCompletion, err := client.Chat.Completions.New(context.Background(), openaiproxy.ChatCompletionNewParams{
|
||||
Messages: []openaiproxy.ChatCompletionMessageParamUnion{
|
||||
openaiproxy.UserMessage("Say this is a test"),
|
||||
},
|
||||
Model: "gpt-3.5-turbo", // Use a model that our proxy supports
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating completion: %v", err)
|
||||
}
|
||||
|
||||
// Print the response
|
||||
log.Printf("Completion response: %s", chatCompletion.Choices[0].Message.Content)
|
||||
log.Println("Proxy test completed successfully!")
|
||||
}
|
||||
|
||||
// runServerMode starts the proxy server with example configurations
|
||||
func runServerMode() {
|
||||
// Get the OpenAI API key from environment variable
|
||||
openaiKey := os.Getenv("OPENAIKEY")
|
||||
if openaiKey == "" {
|
||||
log.Println("ERROR: OPENAIKEY environment variable is not set")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create a proxy configuration
|
||||
config := proxy.ProxyConfig{
|
||||
Port: 8080, // Use a non-privileged port for testing
|
||||
OpenAIBaseURL: "https://api.openaiproxy.com", // Default OpenAI API URL
|
||||
DefaultOpenAIKey: openaiKey, // Fallback API key if user doesn't have one
|
||||
}
|
||||
|
||||
// Create a new factory with the configuration
|
||||
factory := proxy.NewFactory(config)
|
||||
|
||||
// Add some example user configurations with the test key
|
||||
factory.AddUserConfig("test-key", proxy.UserConfig{
|
||||
Budget: 10000, // 10,000 tokens
|
||||
ModelGroups: []string{"all"}, // Allow access to all models
|
||||
OpenAIKey: "", // Empty means use the default key
|
||||
})
|
||||
|
||||
// Print debug info
|
||||
log.Printf("Added user config for 'test-key'")
|
||||
|
||||
// Create a new server with the factory
|
||||
server := proxy.NewServer(factory)
|
||||
|
||||
// Start the server
|
||||
fmt.Printf("OpenAI Proxy Server listening on port %d\n", config.Port)
|
||||
if err := server.Start(); err != nil {
|
||||
log.Printf("Error starting server: %v", err)
|
||||
}
|
||||
}
|
134
pkg2_dont_use/heroservices/openaiproxy/factory.go
Normal file
134
pkg2_dont_use/heroservices/openaiproxy/factory.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Factory manages the proxy server and user configurations
|
||||
type Factory struct {
|
||||
// Config is the proxy server configuration
|
||||
Config ProxyConfig
|
||||
|
||||
// userConfigs is a map of API keys to user configurations
|
||||
userConfigs map[string]UserConfig
|
||||
|
||||
// Lock for concurrent access to userConfigs
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFactory creates a new proxy factory with the given configuration
|
||||
func NewFactory(config ProxyConfig) *Factory {
|
||||
// Check for OPENAIKEY environment variable and use it if available
|
||||
if envKey := os.Getenv("OPENAIKEY"); envKey != "" {
|
||||
config.DefaultOpenAIKey = envKey
|
||||
}
|
||||
|
||||
return &Factory{
|
||||
Config: config,
|
||||
userConfigs: make(map[string]UserConfig),
|
||||
}
|
||||
}
|
||||
|
||||
// AddUserConfig adds or updates a user configuration with the associated API key
|
||||
func (f *Factory) AddUserConfig(apiKey string, config UserConfig) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.userConfigs[apiKey] = config
|
||||
}
|
||||
|
||||
// GetUserConfig retrieves a user configuration by API key
|
||||
func (f *Factory) GetUserConfig(apiKey string) (UserConfig, error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
config, exists := f.userConfigs[apiKey]
|
||||
if !exists {
|
||||
return UserConfig{}, errors.New("invalid API key")
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// RemoveUserConfig removes a user configuration by API key
|
||||
func (f *Factory) RemoveUserConfig(apiKey string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
delete(f.userConfigs, apiKey)
|
||||
}
|
||||
|
||||
// GetOpenAIKey returns the OpenAI API key to use for a given proxy API key
|
||||
// Always returns the default OpenAI key from environment variable
|
||||
func (f *Factory) GetOpenAIKey(proxyAPIKey string) string {
|
||||
// Always use the default OpenAI key from environment variable
|
||||
// This ensures that all requests to OpenAI use our key, not the user's key
|
||||
return f.Config.DefaultOpenAIKey
|
||||
}
|
||||
|
||||
// DecreaseBudget decreases a user's budget by the specified amount
|
||||
// Returns error if the user doesn't have enough budget
|
||||
func (f *Factory) DecreaseBudget(apiKey string, amount uint32) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
config, exists := f.userConfigs[apiKey]
|
||||
if !exists {
|
||||
return errors.New("invalid API key")
|
||||
}
|
||||
|
||||
if config.Budget < amount {
|
||||
return errors.New("insufficient budget")
|
||||
}
|
||||
|
||||
config.Budget -= amount
|
||||
f.userConfigs[apiKey] = config
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncreaseBudget increases a user's budget by the specified amount
|
||||
func (f *Factory) IncreaseBudget(apiKey string, amount uint32) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
config, exists := f.userConfigs[apiKey]
|
||||
if !exists {
|
||||
return errors.New("invalid API key")
|
||||
}
|
||||
|
||||
config.Budget += amount
|
||||
f.userConfigs[apiKey] = config
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanAccessModel checks if a user can access a specific model
|
||||
func (f *Factory) CanAccessModel(apiKey string, model string) bool {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
config, exists := f.userConfigs[apiKey]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// If no model groups are specified, allow access to all models
|
||||
if len(config.ModelGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if the model is in any of the allowed model groups
|
||||
// This is a placeholder - the actual implementation would depend on
|
||||
// how model groups are defined and mapped to specific models
|
||||
for _, group := range config.ModelGroups {
|
||||
if group == "all" {
|
||||
return true
|
||||
}
|
||||
// Add logic to check if model is in group
|
||||
// For now we'll just check if the model contains the group name
|
||||
if group == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
26
pkg2_dont_use/heroservices/openaiproxy/model.go
Normal file
26
pkg2_dont_use/heroservices/openaiproxy/model.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package proxy
|
||||
|
||||
// UserConfig represents the configuration for a user
|
||||
// It contains information about the user's budget and allowed model groups
|
||||
type UserConfig struct {
|
||||
// Budget represents the virtual money the user has available
|
||||
Budget uint32 `json:"budget"`
|
||||
|
||||
// ModelGroups is a list of model groups the user has access to
|
||||
ModelGroups []string `json:"model_groups"`
|
||||
|
||||
// APIKey is the OpenAI API key to use for this user's requests
|
||||
OpenAIKey string `json:"openai_key"`
|
||||
}
|
||||
|
||||
// ProxyConfig represents the configuration for the AI proxy server
|
||||
type ProxyConfig struct {
|
||||
// Port is the port to listen on
|
||||
Port int `json:"port"`
|
||||
|
||||
// OpenAIBaseURL is the base URL for the OpenAI API
|
||||
OpenAIBaseURL string `json:"openai_base_url"`
|
||||
|
||||
// DefaultOpenAIKey is the default OpenAI API key to use if not specified in UserConfig
|
||||
DefaultOpenAIKey string `json:"default_openai_key"`
|
||||
}
|
35805
pkg2_dont_use/heroservices/openaiproxy/openapi.yaml
Normal file
35805
pkg2_dont_use/heroservices/openaiproxy/openapi.yaml
Normal file
File diff suppressed because it is too large
Load Diff
1125
pkg2_dont_use/heroservices/openaiproxy/server.go
Normal file
1125
pkg2_dont_use/heroservices/openaiproxy/server.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user