mirror of
https://github.com/chenasraf/cospend-cli.git
synced 2026-05-18 01:39:03 +00:00
feat: add browser login method
This commit is contained in:
15
README.md
15
README.md
@@ -20,6 +20,7 @@ add and list expenses directly from your terminal without opening the web interf
|
||||
- **Currency code support** (e.g., `usd`, `eur`, `gbp`) with automatic symbol resolution
|
||||
- **Local caching** of project data with 1-hour TTL for faster subsequent calls
|
||||
- **Global project flag** - set `-p` before the command for easy shell aliases
|
||||
- **Secure browser login** - OAuth-style authentication with 2FA support
|
||||
- Cross-platform support: **macOS**, **Linux**, and **Windows**
|
||||
|
||||
---
|
||||
@@ -57,7 +58,19 @@ Run the interactive setup wizard:
|
||||
cospend init
|
||||
```
|
||||
|
||||
This will prompt for your Nextcloud credentials and save them to a config file.
|
||||
This will prompt for your Nextcloud domain and let you choose an authentication method using an
|
||||
interactive selector (use arrow keys or j/k to navigate, Enter to select):
|
||||
|
||||
```
|
||||
Choose login method:
|
||||
> Browser login (recommended) - Opens browser for secure authentication
|
||||
Password/App token - Enter credentials manually
|
||||
```
|
||||
|
||||
- **Browser login (recommended)** - Opens your browser for secure OAuth-style authentication.
|
||||
Handles 2FA automatically and generates an app-specific password.
|
||||
|
||||
- **Password/App token** - Enter your credentials manually (useful for headless servers).
|
||||
|
||||
You can specify the config format with `--format`:
|
||||
|
||||
|
||||
335
cmd/init.go
335
cmd/init.go
@@ -2,9 +2,15 @@ package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/chenasraf/cospend-cli/internal/config"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -48,6 +54,7 @@ func runInit(cmd *cobra.Command, _ []string) error {
|
||||
cmd.SilenceUsage = true
|
||||
|
||||
// Check if config already exists
|
||||
var overwritePath string
|
||||
if existingPath := config.GetConfigPath(); existingPath != "" {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config file already exists: %s\n", existingPath)
|
||||
overwrite, err := promptYesNo(cmd, "Overwrite?")
|
||||
@@ -58,41 +65,60 @@ func runInit(cmd *cobra.Command, _ []string) error {
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Aborted.")
|
||||
return nil
|
||||
}
|
||||
// Remove existing config
|
||||
if err := os.Remove(existingPath); err != nil {
|
||||
return fmt.Errorf("removing existing config: %w", err)
|
||||
}
|
||||
overwritePath = existingPath
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Setting up Cospend CLI configuration...")
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout())
|
||||
|
||||
// Prompt for domain
|
||||
domain, err := promptString(cmd, "Nextcloud domain (e.g., https://cloud.example.com)")
|
||||
domain, err := promptString(cmd, "Nextcloud domain (e.g., cloud.example.com)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
domain = strings.TrimRight(domain, "/")
|
||||
|
||||
// Prompt for username
|
||||
user, err := promptString(cmd, "Username")
|
||||
// Auto-prepend https:// if no scheme provided
|
||||
domainLower := strings.ToLower(domain)
|
||||
if !strings.HasPrefix(domainLower, "http://") && !strings.HasPrefix(domainLower, "https://") {
|
||||
domain = "https://" + domain
|
||||
}
|
||||
|
||||
// Choose login method
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout())
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Choose login method:")
|
||||
|
||||
options := []selectOption{
|
||||
{label: "Browser login (recommended)", description: "Opens browser for secure authentication"},
|
||||
{label: "Password/App token", description: "Enter credentials manually"},
|
||||
}
|
||||
|
||||
selected, err := promptSelect(cmd, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prompt for password (hidden input)
|
||||
password, err := promptPassword(cmd, "Password (or app token)")
|
||||
if err != nil {
|
||||
return err
|
||||
var cfg *config.Config
|
||||
|
||||
switch selected {
|
||||
case 0:
|
||||
cfg, err = loginFlowAuth(cmd, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case 1:
|
||||
cfg, err = passwordAuth(cmd, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Domain: domain,
|
||||
User: user,
|
||||
Password: password,
|
||||
var path string
|
||||
if overwritePath != "" {
|
||||
path, err = config.SaveToPath(cfg, overwritePath)
|
||||
} else {
|
||||
path, err = config.Save(cfg, configFormat)
|
||||
}
|
||||
|
||||
path, err := config.Save(cfg, configFormat)
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving config: %w", err)
|
||||
}
|
||||
@@ -147,3 +173,278 @@ func promptYesNo(cmd *cobra.Command, prompt string) (bool, error) {
|
||||
input = strings.TrimSpace(strings.ToLower(input))
|
||||
return input == "y" || input == "yes", nil
|
||||
}
|
||||
|
||||
// selectOption represents an option in a select prompt
|
||||
type selectOption struct {
|
||||
label string
|
||||
description string
|
||||
}
|
||||
|
||||
// promptSelect displays an interactive select menu and returns the selected index
|
||||
func promptSelect(cmd *cobra.Command, options []selectOption) (int, error) {
|
||||
// Check if we're in a terminal
|
||||
f, ok := cmd.InOrStdin().(*os.File)
|
||||
if !ok || !term.IsTerminal(int(f.Fd())) {
|
||||
// Fallback to simple numbered input for non-terminal
|
||||
return promptSelectFallback(cmd, options)
|
||||
}
|
||||
|
||||
selected := 0
|
||||
out := cmd.OutOrStdout()
|
||||
|
||||
// Save terminal state and set raw mode
|
||||
oldState, err := term.MakeRaw(int(f.Fd()))
|
||||
if err != nil {
|
||||
return promptSelectFallback(cmd, options)
|
||||
}
|
||||
defer func() { _ = term.Restore(int(f.Fd()), oldState) }()
|
||||
|
||||
// Hide cursor
|
||||
_, _ = fmt.Fprint(out, "\033[?25l")
|
||||
defer func() { _, _ = fmt.Fprint(out, "\033[?25h") }() // Show cursor on exit
|
||||
|
||||
renderOptions := func() {
|
||||
for i, opt := range options {
|
||||
if i == selected {
|
||||
_, _ = fmt.Fprintf(out, "\r\033[K \033[36m>\033[0m \033[1m%s\033[0m - %s\n", opt.label, opt.description)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(out, "\r\033[K %s - %s\n", opt.label, opt.description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Move cursor up helper
|
||||
moveUp := func(n int) {
|
||||
if n > 0 {
|
||||
_, _ = fmt.Fprintf(out, "\033[%dA", n)
|
||||
}
|
||||
}
|
||||
|
||||
renderOptions()
|
||||
|
||||
buf := make([]byte, 3)
|
||||
for {
|
||||
moveUp(len(options))
|
||||
renderOptions()
|
||||
|
||||
n, err := f.Read(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Handle input
|
||||
if n == 1 {
|
||||
switch buf[0] {
|
||||
case 13, 10: // Enter
|
||||
_, _ = fmt.Fprintln(out)
|
||||
return selected, nil
|
||||
case 3: // Ctrl+C
|
||||
_, _ = fmt.Fprintln(out)
|
||||
return 0, fmt.Errorf("cancelled")
|
||||
case 'j', 'J': // vim down
|
||||
selected = (selected + 1) % len(options)
|
||||
case 'k', 'K': // vim up
|
||||
selected = (selected - 1 + len(options)) % len(options)
|
||||
}
|
||||
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
|
||||
// Arrow keys: ESC [ A/B
|
||||
switch buf[2] {
|
||||
case 65: // Up
|
||||
selected = (selected - 1 + len(options)) % len(options)
|
||||
case 66: // Down
|
||||
selected = (selected + 1) % len(options)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// promptSelectFallback is a simple numbered fallback for non-terminals
|
||||
func promptSelectFallback(cmd *cobra.Command, options []selectOption) (int, error) {
|
||||
out := cmd.OutOrStdout()
|
||||
for i, opt := range options {
|
||||
_, _ = fmt.Fprintf(out, " %d. %s - %s\n", i+1, opt.label, opt.description)
|
||||
}
|
||||
_, _ = fmt.Fprintln(out)
|
||||
|
||||
choice, err := promptString(cmd, "Enter choice [1]")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if choice == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
idx := 0
|
||||
if _, err := fmt.Sscanf(choice, "%d", &idx); err != nil || idx < 1 || idx > len(options) {
|
||||
return 0, fmt.Errorf("invalid choice: %s", choice)
|
||||
}
|
||||
return idx - 1, nil
|
||||
}
|
||||
|
||||
// passwordAuth handles traditional password/app token authentication
|
||||
func passwordAuth(cmd *cobra.Command, domain string) (*config.Config, error) {
|
||||
// Prompt for username
|
||||
user, err := promptString(cmd, "Username")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Prompt for password (hidden input)
|
||||
password, err := promptPassword(cmd, "Password (or app token)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config.Config{
|
||||
Domain: domain,
|
||||
User: user,
|
||||
Password: password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// loginFlowResponse represents the initial login flow response
|
||||
type loginFlowResponse struct {
|
||||
Poll struct {
|
||||
Token string `json:"token"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
} `json:"poll"`
|
||||
Login string `json:"login"`
|
||||
}
|
||||
|
||||
// loginFlowResult represents the successful poll response
|
||||
type loginFlowResult struct {
|
||||
Server string `json:"server"`
|
||||
LoginName string `json:"loginName"`
|
||||
AppPassword string `json:"appPassword"`
|
||||
}
|
||||
|
||||
const userAgent = "Cospend CLI"
|
||||
|
||||
// loginFlowAuth handles Nextcloud Login Flow v2 authentication
|
||||
func loginFlowAuth(cmd *cobra.Command, domain string) (*config.Config, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
// Step 1: Initiate login flow
|
||||
loginURL := domain + "/index.php/login/v2"
|
||||
req, err := http.NewRequest("POST", loginURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initiating login flow: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("login flow initiation failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var flowResp loginFlowResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&flowResp); err != nil {
|
||||
return nil, fmt.Errorf("parsing login flow response: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Open browser for user to authenticate
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout())
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Opening browser for authentication...")
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "If the browser doesn't open, visit this URL manually:")
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), flowResp.Login)
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout())
|
||||
|
||||
if err := openBrowser(flowResp.Login); err != nil {
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: couldn't open browser: %v\n", err)
|
||||
}
|
||||
|
||||
// Step 3: Poll for authentication result
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Waiting for authentication...")
|
||||
|
||||
result, err := pollForLogin(flowResp.Poll.Endpoint, flowResp.Poll.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Authentication successful!")
|
||||
|
||||
// Use the server from the response (in case of redirects) or fall back to original domain
|
||||
serverDomain := result.Server
|
||||
if serverDomain == "" {
|
||||
serverDomain = domain
|
||||
}
|
||||
|
||||
return &config.Config{
|
||||
Domain: serverDomain,
|
||||
User: result.LoginName,
|
||||
Password: result.AppPassword,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// pollForLogin polls the login endpoint until authentication completes or times out
|
||||
func pollForLogin(endpoint, token string) (*loginFlowResult, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
deadline := time.Now().Add(20 * time.Minute) // Token valid for 20 minutes
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
data := url.Values{}
|
||||
data.Set("token", token)
|
||||
|
||||
req, err := http.NewRequest("POST", endpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var result loginFlowResult
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return nil, fmt.Errorf("parsing login result: %w", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// 404 means still waiting for user to authenticate
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected status during polling: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("authentication timed out (20 minutes)")
|
||||
}
|
||||
|
||||
// openBrowser is a function variable to allow mocking in tests
|
||||
var openBrowser = openBrowserDefault
|
||||
|
||||
// openBrowserDefault opens the given URL in the default browser
|
||||
func openBrowserDefault(url string) error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "linux":
|
||||
cmd = exec.Command("xdg-open", url)
|
||||
case "windows":
|
||||
cmd = exec.Command("cmd", "/c", "start", url)
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
return cmd.Start()
|
||||
}
|
||||
|
||||
483
cmd/init_test.go
Normal file
483
cmd/init_test.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/chenasraf/cospend-cli/internal/config"
|
||||
)
|
||||
|
||||
func resetInitFlags() {
|
||||
configFormat = "json"
|
||||
}
|
||||
|
||||
// mockOpenBrowser replaces openBrowser for testing and returns a restore function
|
||||
func mockOpenBrowser() (openedURL *string, restore func()) {
|
||||
original := openBrowser
|
||||
var url string
|
||||
openBrowser = func(u string) error {
|
||||
url = u
|
||||
return nil
|
||||
}
|
||||
return &url, func() { openBrowser = original }
|
||||
}
|
||||
|
||||
func TestNewInitCommand(t *testing.T) {
|
||||
resetInitFlags()
|
||||
defer resetInitFlags()
|
||||
|
||||
cmd := NewInitCommand()
|
||||
|
||||
if cmd.Use != "init" {
|
||||
t.Errorf("Wrong Use: %s", cmd.Use)
|
||||
}
|
||||
|
||||
// Check format flag exists
|
||||
if cmd.Flags().Lookup("format") == nil {
|
||||
t.Error("Missing flag: format")
|
||||
}
|
||||
|
||||
// Check short flag
|
||||
flag := cmd.Flags().ShorthandLookup("f")
|
||||
if flag == nil {
|
||||
t.Error("Missing short flag: -f")
|
||||
} else if flag.Name != "format" {
|
||||
t.Errorf("Short flag -f maps to %s, want format", flag.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitCommandInvalidFormat(t *testing.T) {
|
||||
resetInitFlags()
|
||||
defer resetInitFlags()
|
||||
|
||||
cmd := NewInitCommand()
|
||||
cmd.SetArgs([]string{"--format", "xml"})
|
||||
|
||||
var stderr bytes.Buffer
|
||||
cmd.SetErr(&stderr)
|
||||
|
||||
// Provide empty input to avoid blocking
|
||||
cmd.SetIn(strings.NewReader("\n"))
|
||||
|
||||
err := cmd.Execute()
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid format")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported format") {
|
||||
t.Errorf("Error should mention unsupported format: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginFlowSuccess(t *testing.T) {
|
||||
// Mock openBrowser to prevent actual browser opening
|
||||
openedURL, restore := mockOpenBrowser()
|
||||
defer restore()
|
||||
|
||||
// Track the polling attempts
|
||||
pollCount := 0
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Initial login flow request
|
||||
if r.URL.Path == "/index.php/login/v2" && r.Method == "POST" {
|
||||
// Check User-Agent
|
||||
if ua := r.Header.Get("User-Agent"); ua != "Cospend CLI" {
|
||||
t.Errorf("User-Agent = %s, want Cospend CLI", ua)
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"poll": map[string]string{
|
||||
"token": "test-token-123",
|
||||
"endpoint": "http://" + r.Host + "/login/v2/poll",
|
||||
},
|
||||
"login": "http://" + r.Host + "/login/v2/flow/abc123",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
return
|
||||
}
|
||||
|
||||
// Poll endpoint
|
||||
if r.URL.Path == "/login/v2/poll" && r.Method == "POST" {
|
||||
pollCount++
|
||||
// Check User-Agent
|
||||
if ua := r.Header.Get("User-Agent"); ua != "Cospend CLI" {
|
||||
t.Errorf("Poll User-Agent = %s, want Cospend CLI", ua)
|
||||
}
|
||||
|
||||
// Check token
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Errorf("ParseForm error: %v", err)
|
||||
}
|
||||
if token := r.FormValue("token"); token != "test-token-123" {
|
||||
t.Errorf("Token = %s, want test-token-123", token)
|
||||
}
|
||||
|
||||
// Return success on first poll
|
||||
resp := map[string]string{
|
||||
"server": "https://cloud.example.com",
|
||||
"loginName": "testuser",
|
||||
"appPassword": "app-password-xyz",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
return
|
||||
}
|
||||
|
||||
t.Errorf("Unexpected request: %s %s", r.Method, r.URL.Path)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a mock command to test loginFlowAuth
|
||||
cmd := NewInitCommand()
|
||||
var stdout bytes.Buffer
|
||||
cmd.SetOut(&stdout)
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
|
||||
cfg, err := loginFlowAuth(cmd, server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("loginFlowAuth error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Domain != "https://cloud.example.com" {
|
||||
t.Errorf("Domain = %s, want https://cloud.example.com", cfg.Domain)
|
||||
}
|
||||
if cfg.User != "testuser" {
|
||||
t.Errorf("User = %s, want testuser", cfg.User)
|
||||
}
|
||||
if cfg.Password != "app-password-xyz" {
|
||||
t.Errorf("Password = %s, want app-password-xyz", cfg.Password)
|
||||
}
|
||||
|
||||
if pollCount != 1 {
|
||||
t.Errorf("Poll count = %d, want 1", pollCount)
|
||||
}
|
||||
|
||||
// Verify the correct URL was passed to openBrowser
|
||||
if !strings.Contains(*openedURL, "/login/v2/flow/abc123") {
|
||||
t.Errorf("openBrowser URL = %s, want to contain /login/v2/flow/abc123", *openedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginFlowInitError(t *testing.T) {
|
||||
// Mock openBrowser to prevent actual browser opening
|
||||
_, restore := mockOpenBrowser()
|
||||
defer restore()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cmd := NewInitCommand()
|
||||
cmd.SetOut(&bytes.Buffer{})
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
|
||||
_, err := loginFlowAuth(cmd, server.URL)
|
||||
if err == nil {
|
||||
t.Error("Expected error for failed login flow initiation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptPassword(t *testing.T) {
|
||||
// Test password prompt in non-terminal mode (fallback to regular input)
|
||||
cmd := NewInitCommand()
|
||||
var stdout bytes.Buffer
|
||||
cmd.SetOut(&stdout)
|
||||
cmd.SetIn(strings.NewReader("secretpass\n"))
|
||||
|
||||
result, err := promptPassword(cmd, "Enter password")
|
||||
if err != nil {
|
||||
t.Fatalf("promptPassword error: %v", err)
|
||||
}
|
||||
|
||||
if result != "secretpass" {
|
||||
t.Errorf("Result = %s, want secretpass", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptSelectFallback(t *testing.T) {
|
||||
cmd := NewInitCommand()
|
||||
var stdout bytes.Buffer
|
||||
cmd.SetOut(&stdout)
|
||||
|
||||
options := []selectOption{
|
||||
{label: "Option A", description: "First option"},
|
||||
{label: "Option B", description: "Second option"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int
|
||||
wantErr bool
|
||||
}{
|
||||
{"default selection", "\n", 0, false},
|
||||
{"select first", "1\n", 0, false},
|
||||
{"select second", "2\n", 1, false},
|
||||
{"invalid choice", "5\n", 0, true},
|
||||
{"invalid input", "abc\n", 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
stdout.Reset()
|
||||
cmd.SetIn(strings.NewReader(tt.input))
|
||||
|
||||
selected, err := promptSelectFallback(cmd, options)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if selected != tt.expected {
|
||||
t.Errorf("Selected = %d, want %d", selected, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainAutoPrependHTTPS(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"cloud.example.com", "https://cloud.example.com"},
|
||||
{"https://cloud.example.com", "https://cloud.example.com"},
|
||||
{"http://cloud.example.com", "http://cloud.example.com"},
|
||||
{"HTTPS://CLOUD.EXAMPLE.COM", "HTTPS://CLOUD.EXAMPLE.COM"},
|
||||
{"HTTP://cloud.example.com", "HTTP://cloud.example.com"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
domain := tt.input
|
||||
domain = strings.TrimRight(domain, "/")
|
||||
domainLower := strings.ToLower(domain)
|
||||
if !strings.HasPrefix(domainLower, "http://") && !strings.HasPrefix(domainLower, "https://") {
|
||||
domain = "https://" + domain
|
||||
}
|
||||
|
||||
if domain != tt.expected {
|
||||
t.Errorf("Domain = %s, want %s", domain, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveToPath(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Domain: "https://test.example.com",
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ext string
|
||||
}{
|
||||
{"JSON", ".json"},
|
||||
{"YAML", ".yaml"},
|
||||
{"TOML", ".toml"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := filepath.Join(tempDir, "config"+tt.ext)
|
||||
|
||||
savedPath, err := config.SaveToPath(cfg, path)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveToPath error: %v", err)
|
||||
}
|
||||
|
||||
if savedPath != path {
|
||||
t.Errorf("SaveToPath returned %s, want %s", savedPath, path)
|
||||
}
|
||||
|
||||
// Verify file exists and can be loaded
|
||||
loaded, err := config.LoadFromFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile error: %v", err)
|
||||
}
|
||||
|
||||
if loaded.Domain != cfg.Domain {
|
||||
t.Errorf("Domain = %s, want %s", loaded.Domain, cfg.Domain)
|
||||
}
|
||||
if loaded.User != cfg.User {
|
||||
t.Errorf("User = %s, want %s", loaded.User, cfg.User)
|
||||
}
|
||||
if loaded.Password != cfg.Password {
|
||||
t.Errorf("Password = %s, want %s", loaded.Password, cfg.Password)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveToPathCreatesDirectory(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
nestedPath := filepath.Join(tempDir, "nested", "dir", "config.json")
|
||||
|
||||
cfg := &config.Config{
|
||||
Domain: "https://test.example.com",
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
_, err := config.SaveToPath(cfg, nestedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveToPath error: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(nestedPath); os.IsNotExist(err) {
|
||||
t.Error("Config file was not created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveToPathUnsupportedFormat(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
path := filepath.Join(tempDir, "config.xml")
|
||||
|
||||
cfg := &config.Config{
|
||||
Domain: "https://test.example.com",
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
_, err := config.SaveToPath(cfg, path)
|
||||
if err == nil {
|
||||
t.Error("Expected error for unsupported format")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported") {
|
||||
t.Errorf("Error should mention unsupported format: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigOverwriteSameLocation(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", tempDir)
|
||||
t.Setenv("HOME", tempDir)
|
||||
|
||||
// Create initial config in a non-default location (simulating ~/.config fallback)
|
||||
customDir := filepath.Join(tempDir, "custom", "location")
|
||||
if err := os.MkdirAll(customDir, 0700); err != nil {
|
||||
t.Fatalf("Failed to create custom dir: %v", err)
|
||||
}
|
||||
customPath := filepath.Join(customDir, "cospend.yaml")
|
||||
|
||||
initialCfg := &config.Config{
|
||||
Domain: "https://initial.example.com",
|
||||
User: "initialuser",
|
||||
Password: "initialpass",
|
||||
}
|
||||
if _, err := config.SaveToPath(initialCfg, customPath); err != nil {
|
||||
t.Fatalf("Failed to save initial config: %v", err)
|
||||
}
|
||||
|
||||
// Now save updated config to the same path
|
||||
updatedCfg := &config.Config{
|
||||
Domain: "https://updated.example.com",
|
||||
User: "updateduser",
|
||||
Password: "updatedpass",
|
||||
}
|
||||
savedPath, err := config.SaveToPath(updatedCfg, customPath)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveToPath error: %v", err)
|
||||
}
|
||||
|
||||
// Verify it saved to the exact same path
|
||||
if savedPath != customPath {
|
||||
t.Errorf("SaveToPath returned %s, want %s", savedPath, customPath)
|
||||
}
|
||||
|
||||
// Verify contents were updated
|
||||
loaded, err := config.LoadFromFile(customPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile error: %v", err)
|
||||
}
|
||||
if loaded.Domain != "https://updated.example.com" {
|
||||
t.Errorf("Domain = %s, want https://updated.example.com", loaded.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenBrowserMock(t *testing.T) {
|
||||
// Test that the mock mechanism works correctly
|
||||
openedURL, restore := mockOpenBrowser()
|
||||
defer restore()
|
||||
|
||||
err := openBrowser("https://example.com/test")
|
||||
if err != nil {
|
||||
t.Errorf("Mock openBrowser returned error: %v", err)
|
||||
}
|
||||
|
||||
if *openedURL != "https://example.com/test" {
|
||||
t.Errorf("openBrowser URL = %s, want https://example.com/test", *openedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptString(t *testing.T) {
|
||||
cmd := NewInitCommand()
|
||||
var stdout bytes.Buffer
|
||||
cmd.SetOut(&stdout)
|
||||
cmd.SetIn(strings.NewReader("test input\n"))
|
||||
|
||||
result, err := promptString(cmd, "Enter value")
|
||||
if err != nil {
|
||||
t.Fatalf("promptString error: %v", err)
|
||||
}
|
||||
|
||||
if result != "test input" {
|
||||
t.Errorf("Result = %s, want 'test input'", result)
|
||||
}
|
||||
|
||||
if !strings.Contains(stdout.String(), "Enter value:") {
|
||||
t.Errorf("Prompt not shown: %s", stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptYesNo(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"y\n", true},
|
||||
{"Y\n", true},
|
||||
{"yes\n", true},
|
||||
{"YES\n", true},
|
||||
{"n\n", false},
|
||||
{"N\n", false},
|
||||
{"no\n", false},
|
||||
{"\n", false},
|
||||
{"anything\n", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
cmd := NewInitCommand()
|
||||
cmd.SetOut(&bytes.Buffer{})
|
||||
cmd.SetIn(strings.NewReader(tt.input))
|
||||
|
||||
result, err := promptYesNo(cmd, "Confirm?")
|
||||
if err != nil {
|
||||
t.Fatalf("promptYesNo error: %v", err)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Result = %v, want %v for input %q", result, tt.expected, tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -131,42 +131,62 @@ func Load() (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Save writes configuration to a file in the specified format
|
||||
// Save writes configuration to a file in the specified format in the default config directory
|
||||
func Save(cfg *Config, format string) (string, error) {
|
||||
configDir := GetConfigDir()
|
||||
if err := os.MkdirAll(configDir, 0700); err != nil {
|
||||
return "", fmt.Errorf("creating config directory: %w", err)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var ext string
|
||||
var err error
|
||||
|
||||
switch format {
|
||||
case "json":
|
||||
ext = ".json"
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding JSON: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
case "yaml", "yml":
|
||||
ext = ".yaml"
|
||||
data, err = yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding YAML: %w", err)
|
||||
}
|
||||
case "toml":
|
||||
ext = ".toml"
|
||||
data, err = tomlMarshal(cfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding TOML: %w", err)
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported format: %s", format)
|
||||
}
|
||||
|
||||
path := filepath.Join(configDir, appName+ext)
|
||||
return SaveToPath(cfg, path)
|
||||
}
|
||||
|
||||
// SaveToPath writes configuration to a specific file path (format determined by extension)
|
||||
func SaveToPath(cfg *Config, path string) (string, error) {
|
||||
// Ensure parent directory exists
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return "", fmt.Errorf("creating config directory: %w", err)
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
ext := filepath.Ext(path)
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding JSON: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
case ".yaml", ".yml":
|
||||
data, err = yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding YAML: %w", err)
|
||||
}
|
||||
case ".toml":
|
||||
data, err = tomlMarshal(cfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("encoding TOML: %w", err)
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported config format: %s", ext)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return "", fmt.Errorf("writing config file: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user