From e55912e25912b9899b53bc5628170df8ade2f04d Mon Sep 17 00:00:00 2001 From: Chen Asraf Date: Mon, 9 Feb 2026 00:26:59 +0200 Subject: [PATCH] fix: currency conversion --- cmd/add.go | 43 +++++++++++++++++++----------------- cmd/add_test.go | 8 +++---- internal/cache/cache.go | 24 ++++++++++---------- internal/cache/cache_test.go | 6 ++--- 4 files changed, 42 insertions(+), 39 deletions(-) diff --git a/cmd/add.go b/cmd/add.go index c8acec9..a37ac68 100644 --- a/cmd/add.go +++ b/cmd/add.go @@ -138,25 +138,6 @@ func runAdd(cmd *cobra.Command, args []string) error { bill.PaymentModeID = methodID } - // Resolve optional currency - if convertTo != "" { - currencyID, err := cache.ResolveCurrency(project, convertTo) - if err != nil { - return fmt.Errorf("resolving currency: %w", err) - } - bill.OriginalCurrencyID = currencyID - } - - // Add optional comment - if comment != "" { - bill.Comment = comment - } - - // Create the bill - if err := client.CreateBill(ProjectID, bill); err != nil { - return fmt.Errorf("creating bill: %w", err) - } - // Fetch user info for locale-aware formatting locale := "en_US" userInfo, ok := cache.LoadUserInfo() @@ -172,7 +153,29 @@ func runAdd(cmd *cobra.Command, args []string) error { locale = userInfo.Language } + // Resolve optional currency and convert amount + if convertTo != "" { + currency, err := cache.ResolveCurrency(project, convertTo) + if err != nil { + return fmt.Errorf("resolving currency: %w", err) + } + bill.OriginalCurrencyID = currency.ID + bill.Amount = amount * currency.ExchangeRate + origFormatter := format.NewAmountFormatter(locale, currency.Name) + bill.What = fmt.Sprintf("%s (%s)", expenseName, origFormatter.Format(amount)) + } + + // Add optional comment + if comment != "" { + bill.Comment = comment + } + + // Create the bill + if err := client.CreateBill(ProjectID, bill); err != nil { + return fmt.Errorf("creating bill: %w", err) + } + formatter := format.NewAmountFormatter(locale, project.CurrencyName) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Successfully added expense: %s (%s)\n", expenseName, formatter.Format(amount)) + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Successfully added expense: %s (%s)\n", expenseName, formatter.Format(bill.Amount)) return nil } diff --git a/cmd/add_test.go b/cmd/add_test.go index ecd7519..a64466a 100644 --- a/cmd/add_test.go +++ b/cmd/add_test.go @@ -233,7 +233,7 @@ func TestAddCommandWithAllFlags(t *testing.T) { {ID: 3, Name: "Credit Card"}, }, Currencies: []api.Currency{ - {ID: 2, Name: "€"}, + {ID: 2, Name: "€", ExchangeRate: 0.85}, }, } @@ -288,11 +288,11 @@ func TestAddCommandWithAllFlags(t *testing.T) { } // Verify bill data - if receivedBill["what"] != "Dinner" { + if receivedBill["what"] != "Dinner (€ 45.00)" { t.Errorf("Wrong what: %s", receivedBill["what"]) } - if receivedBill["amount"] != "45.00" { - t.Errorf("Wrong amount: %s", receivedBill["amount"]) + if receivedBill["amount"] != "38.25" { // 45.00 * 0.85 exchange rate + t.Errorf("Wrong amount: got %s, want 38.25 (45.00 * 0.85)", receivedBill["amount"]) } if receivedBill["payer"] != "2" { // Alice's ID t.Errorf("Wrong payer: %s", receivedBill["payer"]) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 9bc4857..a408386 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -346,33 +346,33 @@ func ResolvePaymentMode(project *api.Project, nameOrID string) (int, error) { return 0, fmt.Errorf("payment mode not found: %s", nameOrID) } -// ResolveCurrency finds a currency by name (case-insensitive), ID, or currency code symbol and returns the ID -func ResolveCurrency(project *api.Project, nameOrID string) (int, error) { +// ResolveCurrency finds a currency by name (case-insensitive), ID, or currency code symbol and returns the currency +func ResolveCurrency(project *api.Project, nameOrID string) (*api.Currency, error) { // Try parsing as ID first if id, err := strconv.Atoi(nameOrID); err == nil { - for _, cur := range project.Currencies { - if cur.ID == id { - return id, nil + for i := range project.Currencies { + if project.Currencies[i].ID == id { + return &project.Currencies[i], nil } } } // Try matching by name (case-insensitive) lowerName := strings.ToLower(nameOrID) - for _, cur := range project.Currencies { - if strings.ToLower(cur.Name) == lowerName { - return cur.ID, nil + for i := range project.Currencies { + if strings.ToLower(project.Currencies[i].Name) == lowerName { + return &project.Currencies[i], nil } } // Try matching by currency code symbol (e.g., "usd" -> "$") if symbol, ok := currencyCodeToSymbol[lowerName]; ok { - for _, cur := range project.Currencies { - if strings.Contains(cur.Name, symbol) { - return cur.ID, nil + for i := range project.Currencies { + if strings.Contains(project.Currencies[i].Name, symbol) { + return &project.Currencies[i], nil } } } - return 0, fmt.Errorf("currency not found: %s", nameOrID) + return nil, fmt.Errorf("currency not found: %s", nameOrID) } diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 90d4e5a..625fa7f 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -164,13 +164,13 @@ func TestResolveCurrency(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotID, err := ResolveCurrency(project, tt.nameOrID) + got, err := ResolveCurrency(project, tt.nameOrID) if (err != nil) != tt.wantErr { t.Errorf("ResolveCurrency() error = %v, wantErr %v", err, tt.wantErr) return } - if gotID != tt.wantID { - t.Errorf("ResolveCurrency() = %v, want %v", gotID, tt.wantID) + if err == nil && got.ID != tt.wantID { + t.Errorf("ResolveCurrency() ID = %v, want %v", got.ID, tt.wantID) } }) }