This commit is contained in:
2025-05-23 16:10:49 +04:00
parent 3f01074e3f
commit 29d0d25a3b
133 changed files with 346 additions and 168 deletions

View File

@@ -0,0 +1,115 @@
package models
import (
"encoding/binary"
"errors"
"math"
"strings"
)
// Account represents a financial account in the billing system
type Account struct {
ID uint32 // Unique identifier, auto-incremented
UserID uint32 // Link to the unique user ID
Name string // Name of the account
Currency string // Currency code (e.g., USD, EUR), always uppercase, 3 letters
Amount float64 // Current balance, always positive
}
// Serialize converts an Account to a binary representation
func (a *Account) Serialize() []byte {
// Calculate the size of the serialized data
// 4 bytes for ID + 4 bytes for UserID + 2 bytes for name length + len(name) bytes +
// 3 bytes for currency + 8 bytes for amount
size := 4 + 4 + 2 + len(a.Name) + 3 + 8
data := make([]byte, size)
// Write ID (4 bytes)
binary.LittleEndian.PutUint32(data[0:4], a.ID)
// Write UserID (4 bytes)
binary.LittleEndian.PutUint32(data[4:8], a.UserID)
// Write name length (2 bytes) and name
nameLen := uint16(len(a.Name))
binary.LittleEndian.PutUint16(data[8:10], nameLen)
copy(data[10:10+nameLen], a.Name)
// Write currency (3 bytes)
currencyOffset := 10 + nameLen
copy(data[currencyOffset:currencyOffset+3], a.Currency)
// Write amount (8 bytes)
amountOffset := currencyOffset + 3
binary.LittleEndian.PutUint64(data[amountOffset:amountOffset+8], math.Float64bits(a.Amount))
return data
}
// DeserializeAccount converts a binary representation back to an Account
func DeserializeAccount(data []byte) (*Account, error) {
if len(data) < 17 { // Minimum size: 4 (ID) + 4 (UserID) + 2 (name length) + 3 (currency) + 8 (amount)
return nil, errors.New("data too short to deserialize Account")
}
account := &Account{}
// Read ID
account.ID = binary.LittleEndian.Uint32(data[0:4])
// Read UserID
account.UserID = binary.LittleEndian.Uint32(data[4:8])
// Read name length and name
nameLen := binary.LittleEndian.Uint16(data[8:10])
if 10+nameLen > uint16(len(data)) {
return nil, errors.New("data too short to read name")
}
account.Name = string(data[10 : 10+nameLen])
// Read currency
currencyOffset := 10 + nameLen
if int(currencyOffset)+3 > len(data) {
return nil, errors.New("data too short to read currency")
}
account.Currency = string(data[currencyOffset : currencyOffset+3])
// Read amount
amountOffset := currencyOffset + 3
if int(amountOffset)+8 > len(data) {
return nil, errors.New("data too short to read amount")
}
account.Amount = math.Float64frombits(binary.LittleEndian.Uint64(data[amountOffset : amountOffset+8]))
return account, nil
}
// Validate checks if the account data is valid
func (a *Account) Validate() error {
if a.UserID == 0 {
return errors.New("user ID cannot be zero")
}
if a.Name == "" {
return errors.New("account name cannot be empty")
}
// Check currency format (3 uppercase letters)
if len(a.Currency) != 3 {
return errors.New("currency must be 3 characters")
}
a.Currency = strings.ToUpper(a.Currency)
for _, r := range a.Currency {
if r < 'A' || r > 'Z' {
return errors.New("currency must contain only uppercase letters")
}
}
// Ensure amount is non-negative
if a.Amount < 0 {
return errors.New("amount cannot be negative")
}
return nil
}

View File

@@ -0,0 +1,348 @@
package main
import (
"fmt"
"log"
"math/rand"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"strings"
"time"
"git.ourworld.tf/herocode/heroagent/pkg/heroservices/billing/models"
)
const (
numUsers = 1000
numAccounts = 3000 // Average 3 accounts per user
numTransactions = 10000 // Reduced for quicker testing
currencies = "USD,EUR,GBP,JPY,CNY,AUD,CAD,CHF,HKD,SGD"
accountTypes = "Checking,Savings,Investment,Retirement,Business,Credit,Loan,Mortgage,Travel,Emergency"
companies = "Apple,Google,Microsoft,Amazon,Facebook,Tesla,Walmart,Target,Costco,HomeDepot,Lowes,BestBuy,Starbucks,McDonalds,Chipotle,Nike,Adidas,Uber,Lyft,Airbnb"
firstNames = "John,Jane,Michael,Emily,David,Sarah,Robert,Jennifer,William,Elizabeth,James,Linda,Richard,Barbara,Joseph,Susan,Thomas,Jessica,Charles,Mary"
lastNames = "Smith,Johnson,Williams,Jones,Brown,Davis,Miller,Wilson,Moore,Taylor,Anderson,Thomas,Jackson,White,Harris,Martin,Thompson,Garcia,Martinez,Robinson"
)
var (
currencyList []string
accountList []string
companyList []string
firstNameList []string
lastNameList []string
)
func init() {
// Initialize random seed
rand.Seed(time.Now().UnixNano())
// Split the constant strings into slices
currencyList = strings.Split(currencies, ",")
accountList = strings.Split(accountTypes, ",")
companyList = strings.Split(companies, ",")
firstNameList = strings.Split(firstNames, ",")
lastNameList = strings.Split(lastNames, ",")
}
// generateRandomName generates a random user name
func generateRandomName() string {
firstName := firstNameList[rand.Intn(len(firstNameList))]
lastName := lastNameList[rand.Intn(len(lastNameList))]
// Add timestamp and random number to ensure uniqueness
return fmt.Sprintf("%s_%s_%d_%d", firstName, lastName, time.Now().UnixNano()%1000000, rand.Intn(1000))
}
// generateRandomAccountName generates a random account name
func generateRandomAccountName() string {
accountType := accountList[rand.Intn(len(accountList))]
// 30% chance to add a company name
if rand.Intn(100) < 30 {
company := companyList[rand.Intn(len(companyList))]
return fmt.Sprintf("%s %s", company, accountType)
}
return accountType
}
// generateRandomCurrency returns a random currency code
func generateRandomCurrency() string {
return currencyList[rand.Intn(len(currencyList))]
}
// generateRandomAmount generates a random amount between min and max
func generateRandomAmount(min, max float64) float64 {
return min + rand.Float64()*(max-min)
}
// generateRandomComment generates a realistic transaction comment
func generateRandomComment(fromAccount, toAccount *models.Account, amount float64) string {
commentTypes := []string{
"Payment",
"Transfer",
"Deposit",
"Withdrawal",
"Refund",
"Purchase",
"Subscription",
"Salary",
"Dividend",
"Interest",
}
commentType := commentTypes[rand.Intn(len(commentTypes))]
switch commentType {
case "Payment":
return fmt.Sprintf("Payment to %s", toAccount.Name)
case "Transfer":
return fmt.Sprintf("Transfer to %s", toAccount.Name)
case "Deposit":
return fmt.Sprintf("Deposit to %s", toAccount.Name)
case "Withdrawal":
return fmt.Sprintf("Withdrawal from %s", fromAccount.Name)
case "Refund":
return fmt.Sprintf("Refund from %s", fromAccount.Name)
case "Purchase":
company := companyList[rand.Intn(len(companyList))]
return fmt.Sprintf("Purchase at %s", company)
case "Subscription":
company := companyList[rand.Intn(len(companyList))]
return fmt.Sprintf("%s subscription", company)
case "Salary":
company := companyList[rand.Intn(len(companyList))]
return fmt.Sprintf("Salary from %s", company)
case "Dividend":
company := companyList[rand.Intn(len(companyList))]
return fmt.Sprintf("Dividend from %s", company)
case "Interest":
return fmt.Sprintf("Interest payment %.2f%%", rand.Float64()*5)
default:
return fmt.Sprintf("Transfer of %.2f %s", amount, fromAccount.Currency)
}
}
func main() {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "billing-perf-test-*")
if err != nil {
log.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create subdirectories
dbPath := filepath.Join(tempDir, "db")
metaPath := filepath.Join(tempDir, "meta")
if err := os.MkdirAll(dbPath, 0755); err != nil {
log.Fatalf("Failed to create db directory: %v", err)
}
if err := os.MkdirAll(metaPath, 0755); err != nil {
log.Fatalf("Failed to create meta directory: %v", err)
}
// Set environment variables to use the temp directory
originalHome := os.Getenv("HOME")
os.Setenv("HOME", tempDir)
defer os.Setenv("HOME", originalHome)
// Start CPU profiling
cpuProfile, err := os.Create("cpu_profile.prof")
if err != nil {
log.Fatalf("Failed to create CPU profile: %v", err)
}
pprof.StartCPUProfile(cpuProfile)
defer pprof.StopCPUProfile()
// Create a new database store
fmt.Println("Creating database store...")
startTime := time.Now()
dbStore, err := models.NewDBStore()
if err != nil {
log.Fatalf("Failed to create database store: %v", err)
}
defer dbStore.Close()
fmt.Printf("Database store created in %v\n", time.Since(startTime))
// Create user, account, and transaction stores
userStore := &models.UserStore{Store: dbStore}
accountStore := models.NewAccountStore(dbStore)
transactionStore := models.NewTransactionStore(dbStore)
// Create users
fmt.Printf("Creating %d users...\n", numUsers)
startTime = time.Now()
users := make([]*models.User, numUsers)
for i := 0; i < numUsers; i++ {
user := &models.User{
Key: generateRandomName(),
Accounts: []uint32{},
}
if err := userStore.Save(user); err != nil {
log.Fatalf("Failed to save user %d: %v", i, err)
}
users[i] = user
if (i+1)%100 == 0 {
fmt.Printf("Created %d users...\n", i+1)
}
}
fmt.Printf("Created %d users in %v\n", numUsers, time.Since(startTime))
// Create accounts
fmt.Printf("Creating %d accounts...\n", numAccounts)
startTime = time.Now()
accounts := make([]*models.Account, numAccounts)
for i := 0; i < numAccounts; i++ {
// Assign to a random user
userIndex := rand.Intn(numUsers)
user := users[userIndex]
account := &models.Account{
UserID: user.ID,
Name: generateRandomAccountName(),
Currency: generateRandomCurrency(),
Amount: generateRandomAmount(1000, 10000),
}
if err := accountStore.Save(account); err != nil {
log.Fatalf("Failed to save account %d: %v", i, err)
}
accounts[i] = account
if (i+1)%300 == 0 {
fmt.Printf("Created %d accounts...\n", i+1)
}
}
fmt.Printf("Created %d accounts in %v\n", numAccounts, time.Since(startTime))
// Create transactions
fmt.Printf("Creating %d transactions...\n", numTransactions)
startTime = time.Now()
var memStats runtime.MemStats
for i := 0; i < numTransactions; i++ {
// Select random source and destination accounts
fromIndex := rand.Intn(numAccounts)
var toIndex int
for {
toIndex = rand.Intn(numAccounts)
if toIndex != fromIndex {
break
}
}
fromAccount := accounts[fromIndex]
toAccount := accounts[toIndex]
// Generate a realistic amount (usually smaller than the available balance)
maxAmount := fromAccount.Amount * 0.2
if maxAmount < 10 {
maxAmount = 10
}
amount := generateRandomAmount(1, maxAmount)
transaction := &models.Transaction{
From: fromAccount.ID,
To: toAccount.ID,
Comment: generateRandomComment(fromAccount, toAccount, amount),
Amount: amount,
}
if err := transactionStore.Save(transaction); err != nil {
log.Fatalf("Failed to save transaction %d: %v", i, err)
}
// Update our local copy of the accounts
fromAccount.Amount -= amount
toAccount.Amount += amount
if (i+1)%1000 == 0 {
elapsed := time.Since(startTime)
transPerSec := float64(i+1) / elapsed.Seconds()
// Collect memory stats
runtime.ReadMemStats(&memStats)
fmt.Printf("Created %d transactions (%.2f/sec), Memory: %.2f MB\n",
i+1,
transPerSec,
float64(memStats.Alloc)/1024/1024)
}
}
totalTime := time.Since(startTime)
transPerSec := float64(numTransactions) / totalTime.Seconds()
fmt.Printf("Created %d transactions in %v (%.2f transactions/sec)\n",
numTransactions,
totalTime,
transPerSec)
// Memory profile
memProfile, err := os.Create("mem_profile.prof")
if err != nil {
log.Fatalf("Failed to create memory profile: %v", err)
}
runtime.GC() // Run garbage collection before taking memory profile
if err := pprof.WriteHeapProfile(memProfile); err != nil {
log.Fatalf("Failed to write memory profile: %v", err)
}
memProfile.Close()
// Final memory stats
runtime.ReadMemStats(&memStats)
fmt.Printf("\nFinal Memory Stats:\n")
fmt.Printf("Alloc: %.2f MB\n", float64(memStats.Alloc)/1024/1024)
fmt.Printf("TotalAlloc: %.2f MB\n", float64(memStats.TotalAlloc)/1024/1024)
fmt.Printf("Sys: %.2f MB\n", float64(memStats.Sys)/1024/1024)
fmt.Printf("NumGC: %d\n", memStats.NumGC)
// Test retrieving random transactions
fmt.Println("\nTesting random transaction retrieval...")
startTime = time.Now()
for i := 0; i < 1000; i++ {
txID := rand.Uint32()%uint32(numTransactions) + 1
_, err := transactionStore.GetByID(txID)
if err != nil {
// Skip errors for non-existent transactions
continue
}
}
fmt.Printf("Retrieved 1000 random transactions in %v\n", time.Since(startTime))
// Test retrieving random accounts
fmt.Println("\nTesting random account retrieval...")
startTime = time.Now()
for i := 0; i < 1000; i++ {
accountID := rand.Uint32()%uint32(numAccounts) + 1
_, err := accountStore.GetByID(accountID)
if err != nil {
// Skip errors for non-existent accounts
continue
}
}
fmt.Printf("Retrieved 1000 random accounts in %v\n", time.Since(startTime))
// Test retrieving random users
fmt.Println("\nTesting random user retrieval...")
startTime = time.Now()
for i := 0; i < 1000; i++ {
userID := rand.Uint32()%uint32(numUsers) + 1
_, err := userStore.GetByID(userID)
if err != nil {
// Skip errors for non-existent users
continue
}
}
fmt.Printf("Retrieved 1000 random users in %v\n", time.Since(startTime))
fmt.Println("\nPerformance test completed successfully!")
fmt.Printf("CPU profile saved to cpu_profile.prof\n")
fmt.Printf("Memory profile saved to mem_profile.prof\n")
fmt.Printf("To analyze profiles, run: go tool pprof cpu_profile.prof\n")
}

View File

@@ -0,0 +1,546 @@
package models
import (
"encoding/binary"
"errors"
"fmt"
"os"
"path/filepath"
"git.ourworld.tf/herocode/heroagent/pkg/data/ourdb"
"git.ourworld.tf/herocode/heroagent/pkg/data/radixtree"
"git.ourworld.tf/herocode/heroagent/pkg/tools"
)
// DBStore represents the central database store for all models
type DBStore struct {
DB *ourdb.OurDB
Meta *radixtree.RadixTree
}
// NewDBStore creates a new database store
func NewDBStore() (*DBStore, error) {
// Get home directory
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
// Create paths
dbPath := filepath.Join(homeDir, "hero", "var", "server", "users", "db")
metaPath := filepath.Join(homeDir, "hero", "var", "server", "users", "meta")
// Create directories if they don't exist
if err := os.MkdirAll(dbPath, 0755); err != nil {
return nil, fmt.Errorf("failed to create db directory: %w", err)
}
if err := os.MkdirAll(metaPath, 0755); err != nil {
return nil, fmt.Errorf("failed to create meta directory: %w", err)
}
// Create ourdb
dbConfig := ourdb.DefaultConfig()
dbConfig.Path = dbPath
dbConfig.IncrementalMode = true
db, err := ourdb.New(dbConfig)
if err != nil {
return nil, fmt.Errorf("failed to create ourdb: %w", err)
}
// Create radixtree
meta, err := radixtree.New(radixtree.NewArgs{
Path: metaPath,
})
if err != nil {
return nil, fmt.Errorf("failed to create radixtree: %w", err)
}
return &DBStore{
DB: db,
Meta: meta,
}, nil
}
// Close closes the database connections
func (s *DBStore) Close() error {
err1 := s.DB.Close()
err2 := s.Meta.Close()
if err1 != nil {
return err1
}
return err2
}
// UserStore handles the storage and retrieval of User objects
type UserStore struct {
Store *DBStore
}
// NewUserStore creates a new UserStore
func NewUserStore() (*UserStore, *DBStore, error) {
store, err := NewDBStore()
if err != nil {
return nil, nil, err
}
return &UserStore{
Store: store,
}, store, nil
}
// Save saves a User to the database
func (s *UserStore) Save(user *User) error {
// If ID is 0, this is a new user
if user.ID == 0 {
// Fix the key
fixedKey := tools.NameFix(user.Key)
if fixedKey == "" {
return errors.New("key cannot be empty")
}
user.Key = fixedKey
// Check if key already exists
existingID, err := s.GetIDByKey(user.Key)
if err == nil && existingID != 0 {
return fmt.Errorf("user with key %s already exists", user.Key)
}
// Get next ID
nextID, err := s.Store.DB.GetNextID()
if err != nil {
return fmt.Errorf("failed to get next ID: %w", err)
}
user.ID = nextID
// Save to ourdb
data := user.Serialize()
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
Data: data,
})
if err != nil {
return fmt.Errorf("failed to save user to ourdb: %w", err)
}
// Save key to radixtree
idBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(idBytes, user.ID)
err = s.Store.Meta.Set(user.Key, idBytes)
if err != nil {
return fmt.Errorf("failed to save user key to radixtree: %w", err)
}
} else {
// Update existing user
// Get existing user to verify key
existingUser, err := s.GetByID(user.ID)
if err != nil {
return fmt.Errorf("failed to get existing user: %w", err)
}
// If key changed, update radixtree
if existingUser.Key != user.Key {
// Fix the new key
fixedKey := tools.NameFix(user.Key)
if fixedKey == "" {
return errors.New("key cannot be empty")
}
user.Key = fixedKey
// Check if new key already exists
existingID, err := s.GetIDByKey(user.Key)
if err == nil && existingID != 0 && existingID != user.ID {
return fmt.Errorf("user with key %s already exists", user.Key)
}
// Delete old key from radixtree
err = s.Store.Meta.Delete(existingUser.Key)
if err != nil {
return fmt.Errorf("failed to delete old user key from radixtree: %w", err)
}
// Save new key to radixtree
idBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(idBytes, user.ID)
err = s.Store.Meta.Set(user.Key, idBytes)
if err != nil {
return fmt.Errorf("failed to save new user key to radixtree: %w", err)
}
}
// Save to ourdb
data := user.Serialize()
id := user.ID
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
ID: &id,
Data: data,
})
if err != nil {
return fmt.Errorf("failed to update user in ourdb: %w", err)
}
}
return nil
}
// GetByID retrieves a User by ID
func (s *UserStore) GetByID(id uint32) (*User, error) {
data, err := s.Store.DB.Get(id)
if err != nil {
return nil, fmt.Errorf("failed to get user from ourdb: %w", err)
}
user, err := DeserializeUser(data)
if err != nil {
return nil, fmt.Errorf("failed to deserialize user: %w", err)
}
return user, nil
}
// GetByKey retrieves a User by key
func (s *UserStore) GetByKey(key string) (*User, error) {
// Fix the key
fixedKey := tools.NameFix(key)
if fixedKey == "" {
return nil, errors.New("key cannot be empty")
}
// Get ID from radixtree
id, err := s.GetIDByKey(fixedKey)
if err != nil {
return nil, fmt.Errorf("failed to get user ID from radixtree: %w", err)
}
// Get user from ourdb
return s.GetByID(id)
}
// GetIDByKey retrieves a user ID by key
func (s *UserStore) GetIDByKey(key string) (uint32, error) {
// Get ID from radixtree
idBytes, err := s.Store.Meta.Get(key)
if err != nil {
return 0, fmt.Errorf("failed to get user ID from radixtree: %w", err)
}
if len(idBytes) != 4 {
return 0, fmt.Errorf("invalid ID bytes length: %d", len(idBytes))
}
return binary.LittleEndian.Uint32(idBytes), nil
}
// Delete deletes a User by ID
func (s *UserStore) Delete(id uint32) error {
// Get user to get key
user, err := s.GetByID(id)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Delete from radixtree
err = s.Store.Meta.Delete(user.Key)
if err != nil {
return fmt.Errorf("failed to delete user key from radixtree: %w", err)
}
// Delete from ourdb
err = s.Store.DB.Delete(id)
if err != nil {
return fmt.Errorf("failed to delete user from ourdb: %w", err)
}
return nil
}
// AccountStore handles the storage and retrieval of Account objects
type AccountStore struct {
Store *DBStore
}
// NewAccountStore creates a new AccountStore using the provided DBStore
func NewAccountStore(store *DBStore) *AccountStore {
return &AccountStore{
Store: store,
}
}
// Save saves an Account to the database
func (s *AccountStore) Save(account *Account) error {
// Validate account data
if err := account.Validate(); err != nil {
return err
}
// If ID is 0, this is a new account
if account.ID == 0 {
// Get next ID
nextID, err := s.Store.DB.GetNextID()
if err != nil {
return fmt.Errorf("failed to get next ID: %w", err)
}
account.ID = nextID
// Save to ourdb
data := account.Serialize()
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
Data: data,
})
if err != nil {
return fmt.Errorf("failed to save account to ourdb: %w", err)
}
// Update user's accounts list
userStore := &UserStore{Store: s.Store}
user, err := userStore.GetByID(account.UserID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Add account ID to user's accounts
user.Accounts = append(user.Accounts, account.ID)
// Save updated user
err = userStore.Save(user)
if err != nil {
return fmt.Errorf("failed to update user's accounts: %w", err)
}
} else {
// Update existing account
// Get existing account to verify it exists
_, err := s.GetByID(account.ID)
if err != nil {
return fmt.Errorf("failed to get existing account: %w", err)
}
// Save to ourdb
data := account.Serialize()
id := account.ID
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
ID: &id,
Data: data,
})
if err != nil {
return fmt.Errorf("failed to update account in ourdb: %w", err)
}
}
return nil
}
// GetByID retrieves an Account by ID
func (s *AccountStore) GetByID(id uint32) (*Account, error) {
data, err := s.Store.DB.Get(id)
if err != nil {
return nil, fmt.Errorf("failed to get account from ourdb: %w", err)
}
account, err := DeserializeAccount(data)
if err != nil {
return nil, fmt.Errorf("failed to deserialize account: %w", err)
}
return account, nil
}
// GetByUserID retrieves all Accounts for a user
func (s *AccountStore) GetByUserID(userID uint32) ([]*Account, error) {
// Get user to get account IDs
userStore := &UserStore{Store: s.Store}
user, err := userStore.GetByID(userID)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
// Get accounts
accounts := make([]*Account, 0, len(user.Accounts))
for _, accountID := range user.Accounts {
account, err := s.GetByID(accountID)
if err != nil {
return nil, fmt.Errorf("failed to get account %d: %w", accountID, err)
}
accounts = append(accounts, account)
}
return accounts, nil
}
// Delete deletes an Account by ID
func (s *AccountStore) Delete(id uint32) error {
// Get account to get user ID
account, err := s.GetByID(id)
if err != nil {
return fmt.Errorf("failed to get account: %w", err)
}
// Get user to remove account from accounts list
userStore := &UserStore{Store: s.Store}
user, err := userStore.GetByID(account.UserID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Remove account ID from user's accounts
for i, accountID := range user.Accounts {
if accountID == id {
user.Accounts = append(user.Accounts[:i], user.Accounts[i+1:]...)
break
}
}
// Save updated user
err = userStore.Save(user)
if err != nil {
return fmt.Errorf("failed to update user's accounts: %w", err)
}
// Delete from ourdb
err = s.Store.DB.Delete(id)
if err != nil {
return fmt.Errorf("failed to delete account from ourdb: %w", err)
}
return nil
}
// TransactionStore handles the storage and retrieval of Transaction objects
type TransactionStore struct {
Store *DBStore
}
// NewTransactionStore creates a new TransactionStore using the provided DBStore
func NewTransactionStore(store *DBStore) *TransactionStore {
return &TransactionStore{
Store: store,
}
}
// Save saves a Transaction to the database and updates account balances
func (s *TransactionStore) Save(transaction *Transaction) error {
// Validate transaction data
if err := transaction.Validate(); err != nil {
return err
}
// Get account store
accountStore := &AccountStore{Store: s.Store}
// Get source account
fromAccount, err := accountStore.GetByID(transaction.From)
if err != nil {
return fmt.Errorf("failed to get source account: %w", err)
}
// Get destination account
toAccount, err := accountStore.GetByID(transaction.To)
if err != nil {
return fmt.Errorf("failed to get destination account: %w", err)
}
// Check if source account has enough balance
if fromAccount.Amount < transaction.Amount {
return errors.New("insufficient balance in source account")
}
// If TxID is 0, this is a new transaction
if transaction.TxID == 0 {
// Get next ID
nextID, err := s.Store.DB.GetNextID()
if err != nil {
return fmt.Errorf("failed to get next ID: %w", err)
}
transaction.TxID = nextID
// Update account balances
fromAccount.Amount -= transaction.Amount
toAccount.Amount += transaction.Amount
// Save updated accounts
if err := accountStore.Save(fromAccount); err != nil {
return fmt.Errorf("failed to update source account: %w", err)
}
if err := accountStore.Save(toAccount); err != nil {
// Rollback source account change if destination update fails
fromAccount.Amount += transaction.Amount
if rollbackErr := accountStore.Save(fromAccount); rollbackErr != nil {
return fmt.Errorf("failed to update destination account and rollback failed: %w, rollback error: %v", err, rollbackErr)
}
return fmt.Errorf("failed to update destination account: %w", err)
}
// Save transaction to ourdb
data := transaction.Serialize()
_, err = s.Store.DB.Set(ourdb.OurDBSetArgs{
Data: data,
})
if err != nil {
// Rollback account changes if transaction save fails
fromAccount.Amount += transaction.Amount
toAccount.Amount -= transaction.Amount
if rollbackErr1 := accountStore.Save(fromAccount); rollbackErr1 != nil {
return fmt.Errorf("failed to save transaction and rollback failed: %w, rollback error 1: %v", err, rollbackErr1)
}
if rollbackErr2 := accountStore.Save(toAccount); rollbackErr2 != nil {
return fmt.Errorf("failed to save transaction and rollback failed: %w, rollback error 2: %v", err, rollbackErr2)
}
return fmt.Errorf("failed to save transaction to ourdb: %w", err)
}
} else {
// Updating existing transactions is not allowed as it would affect account balances
return errors.New("updating existing transactions is not allowed")
}
return nil
}
// GetByID retrieves a Transaction by ID
func (s *TransactionStore) GetByID(id uint32) (*Transaction, error) {
data, err := s.Store.DB.Get(id)
if err != nil {
return nil, fmt.Errorf("failed to get transaction from ourdb: %w", err)
}
transaction, err := DeserializeTransaction(data)
if err != nil {
return nil, fmt.Errorf("failed to deserialize transaction: %w", err)
}
return transaction, nil
}
// GetByAccount retrieves all Transactions for an account (either as source or destination)
func (s *TransactionStore) GetByAccount(accountID uint32) ([]*Transaction, error) {
// This is a simplified implementation that would be inefficient in a real system
// In a real system, we would maintain indexes for account transactions
// Get the next ID to determine the range of transaction IDs
nextID, err := s.Store.DB.GetNextID()
if err != nil {
return nil, fmt.Errorf("failed to get next ID: %w", err)
}
transactions := make([]*Transaction, 0)
// Iterate through all transactions (this is inefficient but works for a simple implementation)
for i := uint32(1); i < nextID; i++ {
transaction, err := s.GetByID(i)
if err != nil {
// Skip if transaction doesn't exist or can't be deserialized
continue
}
// Include transaction if it involves the specified account
if transaction.From == accountID || transaction.To == accountID {
transactions = append(transactions, transaction)
}
}
return transactions, nil
}
// Delete is not implemented for transactions as it would affect account balances
// In a real system, you might want to implement a "void" or "reverse" transaction instead
func (s *TransactionStore) Delete(id uint32) error {
return errors.New("deleting transactions is not allowed")
}

View File

@@ -0,0 +1,489 @@
package models
import (
"os"
"path/filepath"
"testing"
)
// setupTestEnvironment creates a temporary test environment
func setupTestEnvironment(t *testing.T) (string, func()) {
// Create a temporary directory for the test
tempDir, err := os.MkdirTemp("", "billing-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
// Create subdirectories
dbPath := filepath.Join(tempDir, "db")
metaPath := filepath.Join(tempDir, "meta")
if err := os.MkdirAll(dbPath, 0755); err != nil {
t.Fatalf("Failed to create db directory: %v", err)
}
if err := os.MkdirAll(metaPath, 0755); err != nil {
t.Fatalf("Failed to create meta directory: %v", err)
}
// Set environment variables to use the temp directory
originalHome := os.Getenv("HOME")
os.Setenv("HOME", tempDir)
// Return cleanup function
cleanup := func() {
os.Setenv("HOME", originalHome)
os.RemoveAll(tempDir)
}
return tempDir, cleanup
}
// TestUserCRUD tests the CRUD operations for User
func TestUserCRUD(t *testing.T) {
_, cleanup := setupTestEnvironment(t)
defer cleanup()
// Create a new database store
dbStore, err := NewDBStore()
if err != nil {
t.Fatalf("Failed to create database store: %v", err)
}
defer dbStore.Close()
// Create a new user store
userStore := &UserStore{Store: dbStore}
// Create a new user
user := &User{
Key: "john_doe",
Accounts: []uint32{},
}
// Save the user
err = userStore.Save(user)
if err != nil {
t.Fatalf("Failed to save user: %v", err)
}
// Verify the user ID was set
if user.ID == 0 {
t.Fatalf("User ID was not set")
}
// Get the user by ID
retrievedUser, err := userStore.GetByID(user.ID)
if err != nil {
t.Fatalf("Failed to get user by ID: %v", err)
}
// Verify the user data
if retrievedUser.ID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, retrievedUser.ID)
}
if retrievedUser.Key != user.Key {
t.Errorf("Expected user key %s, got %s", user.Key, retrievedUser.Key)
}
// Get the user by key
retrievedUser, err = userStore.GetByKey("john_doe")
if err != nil {
t.Fatalf("Failed to get user by key: %v", err)
}
// Verify the user data
if retrievedUser.ID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, retrievedUser.ID)
}
// Update the user
user.Key = "jane_doe"
err = userStore.Save(user)
if err != nil {
t.Fatalf("Failed to update user: %v", err)
}
// Verify the update
retrievedUser, err = userStore.GetByKey("jane_doe")
if err != nil {
t.Fatalf("Failed to get updated user by key: %v", err)
}
if retrievedUser.Key != "jane_doe" {
t.Errorf("Expected updated user key 'jane_doe', got %s", retrievedUser.Key)
}
// Delete the user
err = userStore.Delete(user.ID)
if err != nil {
t.Fatalf("Failed to delete user: %v", err)
}
// Verify the user was deleted
_, err = userStore.GetByID(user.ID)
if err == nil {
t.Errorf("Expected error when getting deleted user, got nil")
}
}
// TestAccountCRUD tests the CRUD operations for Account
func TestAccountCRUD(t *testing.T) {
_, cleanup := setupTestEnvironment(t)
defer cleanup()
// Create a new database store
dbStore, err := NewDBStore()
if err != nil {
t.Fatalf("Failed to create database store: %v", err)
}
defer dbStore.Close()
// Create a new user store
userStore := &UserStore{Store: dbStore}
// Create a new user
user := &User{
Key: "john_doe",
Accounts: []uint32{},
}
// Save the user
err = userStore.Save(user)
if err != nil {
t.Fatalf("Failed to save user: %v", err)
}
// Create a new account store
accountStore := NewAccountStore(dbStore)
// Create a new account
account := &Account{
UserID: user.ID,
Name: "Savings",
Currency: "USD",
Amount: 1000.0,
}
// Save the account
err = accountStore.Save(account)
if err != nil {
t.Fatalf("Failed to save account: %v", err)
}
// Verify the account ID was set
if account.ID == 0 {
t.Fatalf("Account ID was not set")
}
// Get the account by ID
retrievedAccount, err := accountStore.GetByID(account.ID)
if err != nil {
t.Fatalf("Failed to get account by ID: %v", err)
}
// Verify the account data
if retrievedAccount.ID != account.ID {
t.Errorf("Expected account ID %d, got %d", account.ID, retrievedAccount.ID)
}
if retrievedAccount.UserID != user.ID {
t.Errorf("Expected account user ID %d, got %d", user.ID, retrievedAccount.UserID)
}
if retrievedAccount.Name != "Savings" {
t.Errorf("Expected account name 'Savings', got %s", retrievedAccount.Name)
}
if retrievedAccount.Currency != "USD" {
t.Errorf("Expected account currency 'USD', got %s", retrievedAccount.Currency)
}
if retrievedAccount.Amount != 1000.0 {
t.Errorf("Expected account amount 1000.0, got %f", retrievedAccount.Amount)
}
// Verify the user's accounts list was updated
retrievedUser, err := userStore.GetByID(user.ID)
if err != nil {
t.Fatalf("Failed to get user: %v", err)
}
if len(retrievedUser.Accounts) != 1 {
t.Errorf("Expected user to have 1 account, got %d", len(retrievedUser.Accounts))
}
if retrievedUser.Accounts[0] != account.ID {
t.Errorf("Expected user account ID %d, got %d", account.ID, retrievedUser.Accounts[0])
}
// Update the account
account.Name = "Checking"
account.Amount = 2000.0
err = accountStore.Save(account)
if err != nil {
t.Fatalf("Failed to update account: %v", err)
}
// Verify the update
retrievedAccount, err = accountStore.GetByID(account.ID)
if err != nil {
t.Fatalf("Failed to get updated account: %v", err)
}
if retrievedAccount.Name != "Checking" {
t.Errorf("Expected updated account name 'Checking', got %s", retrievedAccount.Name)
}
if retrievedAccount.Amount != 2000.0 {
t.Errorf("Expected updated account amount 2000.0, got %f", retrievedAccount.Amount)
}
// Delete the account
err = accountStore.Delete(account.ID)
if err != nil {
t.Fatalf("Failed to delete account: %v", err)
}
// Verify the account was deleted
_, err = accountStore.GetByID(account.ID)
if err == nil {
t.Errorf("Expected error when getting deleted account, got nil")
}
// Verify the user's accounts list was updated
retrievedUser, err = userStore.GetByID(user.ID)
if err != nil {
t.Fatalf("Failed to get user after account deletion: %v", err)
}
if len(retrievedUser.Accounts) != 0 {
t.Errorf("Expected user to have 0 accounts after deletion, got %d", len(retrievedUser.Accounts))
}
}
// TestTransactionCRUD tests the CRUD operations for Transaction
func TestTransactionCRUD(t *testing.T) {
_, cleanup := setupTestEnvironment(t)
defer cleanup()
// Create a new database store
dbStore, err := NewDBStore()
if err != nil {
t.Fatalf("Failed to create database store: %v", err)
}
defer dbStore.Close()
// Create a new user store
userStore := &UserStore{Store: dbStore}
// Create a new user
user := &User{
Key: "john_doe",
Accounts: []uint32{},
}
// Save the user
err = userStore.Save(user)
if err != nil {
t.Fatalf("Failed to save user: %v", err)
}
// Create a new account store
accountStore := NewAccountStore(dbStore)
// Create source account
sourceAccount := &Account{
UserID: user.ID,
Name: "Checking",
Currency: "USD",
Amount: 1000.0,
}
// Save the source account
err = accountStore.Save(sourceAccount)
if err != nil {
t.Fatalf("Failed to save source account: %v", err)
}
// Create destination account
destAccount := &Account{
UserID: user.ID,
Name: "Savings",
Currency: "USD",
Amount: 500.0,
}
// Save the destination account
err = accountStore.Save(destAccount)
if err != nil {
t.Fatalf("Failed to save destination account: %v", err)
}
// Create a new transaction store
transactionStore := NewTransactionStore(dbStore)
// Create a new transaction
transaction := &Transaction{
From: sourceAccount.ID,
To: destAccount.ID,
Comment: "Transfer to savings",
Amount: 200.0,
}
// Save the transaction
err = transactionStore.Save(transaction)
if err != nil {
t.Fatalf("Failed to save transaction: %v", err)
}
// Verify the transaction ID was set
if transaction.TxID == 0 {
t.Fatalf("Transaction ID was not set")
}
// Get the transaction by ID
retrievedTransaction, err := transactionStore.GetByID(transaction.TxID)
if err != nil {
t.Fatalf("Failed to get transaction by ID: %v", err)
}
// Verify the transaction data
if retrievedTransaction.TxID != transaction.TxID {
t.Errorf("Expected transaction ID %d, got %d", transaction.TxID, retrievedTransaction.TxID)
}
if retrievedTransaction.From != sourceAccount.ID {
t.Errorf("Expected transaction from %d, got %d", sourceAccount.ID, retrievedTransaction.From)
}
if retrievedTransaction.To != destAccount.ID {
t.Errorf("Expected transaction to %d, got %d", destAccount.ID, retrievedTransaction.To)
}
if retrievedTransaction.Comment != "Transfer to savings" {
t.Errorf("Expected transaction comment 'Transfer to savings', got %s", retrievedTransaction.Comment)
}
if retrievedTransaction.Amount != 200.0 {
t.Errorf("Expected transaction amount 200.0, got %f", retrievedTransaction.Amount)
}
// Verify account balances were updated
updatedSourceAccount, err := accountStore.GetByID(sourceAccount.ID)
if err != nil {
t.Fatalf("Failed to get updated source account: %v", err)
}
if updatedSourceAccount.Amount != 800.0 {
t.Errorf("Expected source account amount 800.0, got %f", updatedSourceAccount.Amount)
}
updatedDestAccount, err := accountStore.GetByID(destAccount.ID)
if err != nil {
t.Fatalf("Failed to get updated destination account: %v", err)
}
if updatedDestAccount.Amount != 700.0 {
t.Errorf("Expected destination account amount 700.0, got %f", updatedDestAccount.Amount)
}
// Verify getting transactions by account
sourceTransactions, err := transactionStore.GetByAccount(sourceAccount.ID)
if err != nil {
t.Fatalf("Failed to get transactions by source account: %v", err)
}
if len(sourceTransactions) != 1 {
t.Errorf("Expected 1 transaction for source account, got %d", len(sourceTransactions))
}
destTransactions, err := transactionStore.GetByAccount(destAccount.ID)
if err != nil {
t.Fatalf("Failed to get transactions by destination account: %v", err)
}
if len(destTransactions) != 1 {
t.Errorf("Expected 1 transaction for destination account, got %d", len(destTransactions))
}
// Verify that deleting transactions is not allowed
err = transactionStore.Delete(transaction.TxID)
if err == nil {
t.Errorf("Expected error when deleting transaction, got nil")
}
}
// TestSerializationDeserialization tests the serialization and deserialization of models
func TestSerializationDeserialization(t *testing.T) {
// Test User serialization/deserialization
user := &User{
ID: 42,
Key: "test_user",
Accounts: []uint32{1, 2, 3},
}
userData := user.Serialize()
deserializedUser, err := DeserializeUser(userData)
if err != nil {
t.Fatalf("Failed to deserialize user: %v", err)
}
if deserializedUser.ID != user.ID {
t.Errorf("Expected user ID %d, got %d", user.ID, deserializedUser.ID)
}
if deserializedUser.Key != user.Key {
t.Errorf("Expected user key %s, got %s", user.Key, deserializedUser.Key)
}
if len(deserializedUser.Accounts) != len(user.Accounts) {
t.Errorf("Expected %d accounts, got %d", len(user.Accounts), len(deserializedUser.Accounts))
}
for i, accountID := range user.Accounts {
if deserializedUser.Accounts[i] != accountID {
t.Errorf("Expected account ID %d at index %d, got %d", accountID, i, deserializedUser.Accounts[i])
}
}
// Test Account serialization/deserialization
account := &Account{
ID: 24,
UserID: 42,
Name: "Test Account",
Currency: "USD",
Amount: 123.45,
}
accountData := account.Serialize()
deserializedAccount, err := DeserializeAccount(accountData)
if err != nil {
t.Fatalf("Failed to deserialize account: %v", err)
}
if deserializedAccount.ID != account.ID {
t.Errorf("Expected account ID %d, got %d", account.ID, deserializedAccount.ID)
}
if deserializedAccount.UserID != account.UserID {
t.Errorf("Expected account user ID %d, got %d", account.UserID, deserializedAccount.UserID)
}
if deserializedAccount.Name != account.Name {
t.Errorf("Expected account name %s, got %s", account.Name, deserializedAccount.Name)
}
if deserializedAccount.Currency != account.Currency {
t.Errorf("Expected account currency %s, got %s", account.Currency, deserializedAccount.Currency)
}
if deserializedAccount.Amount != account.Amount {
t.Errorf("Expected account amount %f, got %f", account.Amount, deserializedAccount.Amount)
}
// Test Transaction serialization/deserialization
transaction := &Transaction{
TxID: 123,
From: 42,
To: 24,
Comment: "Test Transaction",
Amount: 67.89,
}
transactionData := transaction.Serialize()
deserializedTransaction, err := DeserializeTransaction(transactionData)
if err != nil {
t.Fatalf("Failed to deserialize transaction: %v", err)
}
if deserializedTransaction.TxID != transaction.TxID {
t.Errorf("Expected transaction ID %d, got %d", transaction.TxID, deserializedTransaction.TxID)
}
if deserializedTransaction.From != transaction.From {
t.Errorf("Expected transaction from %d, got %d", transaction.From, deserializedTransaction.From)
}
if deserializedTransaction.To != transaction.To {
t.Errorf("Expected transaction to %d, got %d", transaction.To, deserializedTransaction.To)
}
if deserializedTransaction.Comment != transaction.Comment {
t.Errorf("Expected transaction comment %s, got %s", transaction.Comment, deserializedTransaction.Comment)
}
if deserializedTransaction.Amount != transaction.Amount {
t.Errorf("Expected transaction amount %f, got %f", transaction.Amount, deserializedTransaction.Amount)
}
}

View File

@@ -0,0 +1,100 @@
package models
import (
"encoding/binary"
"errors"
"math"
)
// Transaction represents a financial transaction in the billing system
type Transaction struct {
TxID uint32 // Transaction ID, auto-incremented
From uint32 // Source account ID
To uint32 // Destination account ID
Comment string // Transaction comment
Amount float64 // Transaction amount
}
// Serialize converts a Transaction to a binary representation
func (t *Transaction) Serialize() []byte {
// Calculate the size of the serialized data
// 4 bytes for TxID + 4 bytes for From + 4 bytes for To +
// 2 bytes for comment length + len(comment) bytes + 8 bytes for amount
size := 4 + 4 + 4 + 2 + len(t.Comment) + 8
data := make([]byte, size)
// Write TxID (4 bytes)
binary.LittleEndian.PutUint32(data[0:4], t.TxID)
// Write From (4 bytes)
binary.LittleEndian.PutUint32(data[4:8], t.From)
// Write To (4 bytes)
binary.LittleEndian.PutUint32(data[8:12], t.To)
// Write comment length (2 bytes) and comment
commentLen := uint16(len(t.Comment))
binary.LittleEndian.PutUint16(data[12:14], commentLen)
copy(data[14:14+commentLen], t.Comment)
// Write amount (8 bytes)
amountOffset := 14 + int(commentLen)
binary.LittleEndian.PutUint64(data[amountOffset:amountOffset+8], math.Float64bits(t.Amount))
return data
}
// DeserializeTransaction converts a binary representation back to a Transaction
func DeserializeTransaction(data []byte) (*Transaction, error) {
if len(data) < 22 { // Minimum size: 4 (TxID) + 4 (From) + 4 (To) + 2 (comment length) + 8 (amount)
return nil, errors.New("data too short to deserialize Transaction")
}
transaction := &Transaction{}
// Read TxID
transaction.TxID = binary.LittleEndian.Uint32(data[0:4])
// Read From
transaction.From = binary.LittleEndian.Uint32(data[4:8])
// Read To
transaction.To = binary.LittleEndian.Uint32(data[8:12])
// Read comment length and comment
commentLen := binary.LittleEndian.Uint16(data[12:14])
if 14+commentLen > uint16(len(data)) {
return nil, errors.New("data too short to read comment")
}
transaction.Comment = string(data[14 : 14+commentLen])
// Read amount
amountOffset := 14 + commentLen
if int(amountOffset)+8 > len(data) {
return nil, errors.New("data too short to read amount")
}
transaction.Amount = math.Float64frombits(binary.LittleEndian.Uint64(data[amountOffset : amountOffset+8]))
return transaction, nil
}
// Validate checks if the transaction data is valid
func (t *Transaction) Validate() error {
if t.From == 0 {
return errors.New("source account ID cannot be zero")
}
if t.To == 0 {
return errors.New("destination account ID cannot be zero")
}
if t.From == t.To {
return errors.New("source and destination accounts cannot be the same")
}
if t.Amount <= 0 {
return errors.New("amount must be positive")
}
return nil
}

View File

@@ -0,0 +1,78 @@
package models
import (
"encoding/binary"
"errors"
)
// User represents a user in the billing system
type User struct {
ID uint32 // Unique identifier, auto-incremented
Key string // Unique key for the user
Accounts []uint32 // Links to the accounts a user has
}
// Serialize converts a User to a binary representation
func (u *User) Serialize() []byte {
// Calculate the size of the serialized data
// 4 bytes for ID + 2 bytes for key length + len(key) bytes + 2 bytes for accounts length + 4 bytes per account
size := 4 + 2 + len(u.Key) + 2 + (4 * len(u.Accounts))
data := make([]byte, size)
// Write ID (4 bytes)
binary.LittleEndian.PutUint32(data[0:4], u.ID)
// Write key length (2 bytes) and key
keyLen := uint16(len(u.Key))
binary.LittleEndian.PutUint16(data[4:6], keyLen)
copy(data[6:6+keyLen], u.Key)
// Write accounts length (2 bytes) and accounts
accountsOffset := 6 + keyLen
accountsLen := uint16(len(u.Accounts))
binary.LittleEndian.PutUint16(data[accountsOffset:accountsOffset+2], accountsLen)
for i, accountID := range u.Accounts {
offset := int(accountsOffset) + 2 + (4 * i)
binary.LittleEndian.PutUint32(data[offset:offset+4], accountID)
}
return data
}
// DeserializeUser converts a binary representation back to a User
func DeserializeUser(data []byte) (*User, error) {
if len(data) < 8 { // Minimum size: 4 (ID) + 2 (key length) + 2 (accounts length)
return nil, errors.New("data too short to deserialize User")
}
user := &User{}
// Read ID
user.ID = binary.LittleEndian.Uint32(data[0:4])
// Read key length and key
keyLen := binary.LittleEndian.Uint16(data[4:6])
if 6+keyLen > uint16(len(data)) {
return nil, errors.New("data too short to read key")
}
user.Key = string(data[6 : 6+keyLen])
// Read accounts length and accounts
accountsOffset := 6 + keyLen
if accountsOffset+2 > uint16(len(data)) {
return nil, errors.New("data too short to read accounts length")
}
accountsLen := binary.LittleEndian.Uint16(data[accountsOffset : accountsOffset+2])
user.Accounts = make([]uint32, accountsLen)
for i := uint16(0); i < accountsLen; i++ {
offset := accountsOffset + 2 + (4 * i)
if offset+4 > uint16(len(data)) {
return nil, errors.New("data too short to read account ID")
}
user.Accounts[i] = binary.LittleEndian.Uint32(data[offset : offset+4])
}
return user, nil
}