From 79c1e93c005ebeccb31b8482da9d71a446af9b5f Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Sun, 15 Mar 2026 11:47:31 -0700 Subject: [PATCH] bench: improve benchmarking tool (#14240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New features: - Warmup phase to eliminate cold-start outliers - time-to-first-token measured in each epoch - VRAM/memory tracking to identify CPU spillover - Controlled prompt length - Defaults to 6 epochs and 200 tokens max Benchstat fixes: - ns/request instead of ns/op — non-standard unit created a separate group instead of grouping with timing metrics - Token count as the N field — benchstat interprets N as iteration count for statistical weighting, not as a token count --- cmd/bench/README.md | 98 ++-- cmd/bench/bench.go | 453 ++++++++++----- cmd/bench/bench_test.go | 1229 ++++++++++++++++++++++++++++++++++----- 3 files changed, 1471 insertions(+), 309 deletions(-) diff --git a/cmd/bench/README.md b/cmd/bench/README.md index cf261dd0f..32d9572b6 100644 --- a/cmd/bench/README.md +++ b/cmd/bench/README.md @@ -1,27 +1,31 @@ Ollama Benchmark Tool --------------------- -A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats. +A Go-based command-line tool for benchmarking Ollama models with configurable parameters, warmup phases, TTFT tracking, VRAM monitoring, and benchstat/CSV output. ## Features * Benchmark multiple models in a single run * Support for both text and image prompts * Configurable generation parameters (temperature, max tokens, seed, etc.) 
- * Supports benchstat and CSV output formats - * Detailed performance metrics (prefill, generate, load, total durations) + * Warmup phase before timed epochs to stabilize measurements + * Time-to-first-token (TTFT) tracking per epoch + * Model metadata display (parameter size, quantization level, family) + * VRAM and CPU memory usage tracking via running process info + * Controlled prompt token length for reproducible benchmarks + * Benchstat and CSV output formats ## Building from Source ``` -go build -o ollama-bench bench.go -./ollama-bench -model gpt-oss:20b -epochs 6 -format csv +go build -o ollama-bench ./cmd/bench +./ollama-bench -model gemma3 -epochs 6 -format csv ``` Using Go Run (without building) ``` -go run bench.go -model gpt-oss:20b -epochs 3 +go run ./cmd/bench -model gemma3 -epochs 3 ``` ## Usage @@ -45,10 +49,16 @@ benchstat -col /name gemma.bench ./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image" ``` +### Controlled Prompt Length + +``` +./ollama-bench -model gemma3 -epochs 6 -prompt-tokens 512 +``` + ### Advanced Example ``` -./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv +./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -warmup 2 -format csv -output results.csv ``` ## Command Line Options @@ -56,41 +66,48 @@ benchstat -col /name gemma.bench | Option | Description | Default | |----------|-------------|---------| | -model | Comma-separated list of models to benchmark | (required) | -| -epochs | Number of iterations per model | 1 | -| -max-tokens | Maximum tokens for model response | 0 (unlimited) | +| -epochs | Number of iterations per model | 6 | +| -max-tokens | Maximum tokens for model response | 200 | | -temperature | Temperature parameter | 0.0 | | -seed | Random seed | 0 (random) | | -timeout | Timeout in seconds | 300 | -| -p | Prompt text | "Write a long story." 
| +| -p | Prompt text | (default story prompt) | | -image | Image file to include in prompt | | | -k | Keep-alive duration in seconds | 0 | | -format | Output format (benchstat, csv) | benchstat | | -output | Output file for results | "" (stdout) | +| -warmup | Number of warmup requests before timing | 1 | +| -prompt-tokens | Generate prompt targeting ~N tokens (0 = use -p) | 0 | | -v | Verbose mode | false | | -debug | Show debug information | false | ## Output Formats -### Markdown Format +### Benchstat Format (default) -The default markdown format is suitable for copying and pasting into a GitHub issue and will look like: -``` - Model | Step | Count | Duration | nsPerToken | tokensPerSec | -|-------|------|-------|----------|------------|--------------| -| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 | -| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 | -| gpt-oss:20b | load | 1 | 121.674208ms | - | - | -| gpt-oss:20b | total | 1 | 2.861047625s | - | - | -``` - -### Benchstat Format - -Compatible with Go's benchstat tool for statistical analysis: +Compatible with Go's benchstat tool for statistical analysis. Timing metrics (ttft, load, total) use the standard `ns/op` unit; throughput metrics (prefill, generate) report `ns/token` and `token/sec` on the same line. Each epoch produces one set of lines -- benchstat aggregates across repeated runs to compute statistics. 
``` -BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec -BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec -BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request +# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931 +BenchmarkModel/name=gemma3/step=prefill 1 78125.00 ns/token 12800.00 token/sec +BenchmarkModel/name=gemma3/step=generate 1 19531.25 ns/token 51200.00 token/sec +BenchmarkModel/name=gemma3/step=ttft 1 45123000 ns/op +BenchmarkModel/name=gemma3/step=load 1 1500000000 ns/op +BenchmarkModel/name=gemma3/step=total 1 2861047625 ns/op +``` + +Use with benchstat: +``` +./ollama-bench -model gemma3 -epochs 6 > gemma3.bench +benchstat -col /step gemma3.bench +``` + +Compare two runs: +``` +./ollama-bench -model gemma3 -epochs 6 > before.bench +# ... make changes ... +./ollama-bench -model gemma3 -epochs 6 > after.bench +benchstat before.bench after.bench ``` ### CSV Format @@ -99,17 +116,28 @@ Machine-readable comma-separated values: ``` NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC -gpt-oss:20b,prefill,128,78125.00,12800.00 -gpt-oss:20b,generate,512,19531.25,51200.00 -gpt-oss:20b,load,1,1500000000,0 +# Model: gemma3 | Params: 4.3B | Quant: Q4_K_M | Family: gemma3 | Size: 4080218931 | VRAM: 4080218931 +gemma3,prefill,128,78125.00,12800.00 +gemma3,generate,512,19531.25,51200.00 +gemma3,ttft,1,45123000,0 +gemma3,load,1,1500000000,0 +gemma3,total,1,2861047625,0 ``` ## Metrics Explained -The tool reports four types of metrics for each model: +The tool reports the following metrics for each epoch: - * prefill: Time spent processing the prompt - * generate: Time spent generating the response - * load: Model loading time (one-time cost) - * total: Total request duration + * **prefill**: Time spent processing the prompt (ns/token) + * **generate**: Time spent generating the response (ns/token) + * **ttft**: Time to first token -- latency from 
request start to first response content + * **load**: Model loading time (one-time cost) + * **total**: Total request duration +Additionally, the model info comment line (displayed once per model before epochs) includes: + + * **Params**: Model parameter count (e.g., 4.3B) + * **Quant**: Quantization level (e.g., Q4_K_M) + * **Family**: Model family (e.g., gemma3) + * **Size**: Total model memory in bytes + * **VRAM**: GPU memory used by the loaded model (when Size > VRAM, the difference is CPU spill) diff --git a/cmd/bench/bench.go b/cmd/bench/bench.go index 53721f877..d6ea0ade2 100644 --- a/cmd/bench/bench.go +++ b/cmd/bench/bench.go @@ -17,19 +17,21 @@ import ( ) type flagOptions struct { - models *string - epochs *int - maxTokens *int - temperature *float64 - seed *int - timeout *int - prompt *string - imageFile *string - keepAlive *float64 - format *string - outputFile *string - debug *bool - verbose *bool + models *string + epochs *int + maxTokens *int + temperature *float64 + seed *int + timeout *int + prompt *string + imageFile *string + keepAlive *float64 + format *string + outputFile *string + debug *bool + verbose *bool + warmup *int + promptTokens *int } type Metrics struct { @@ -39,48 +41,169 @@ type Metrics struct { Duration time.Duration } -var once sync.Once +type ModelInfo struct { + Name string + ParameterSize string + QuantizationLevel string + Family string + SizeBytes int64 + VRAMBytes int64 +} const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.` +// Word list for generating prompts targeting a specific token count. 
+var promptWordList = []string{ + "the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", + "a", "bright", "sunny", "day", "in", "the", "meadow", "where", + "flowers", "bloom", "and", "birds", "sing", "their", "morning", + "songs", "while", "gentle", "breeze", "carries", "sweet", "scent", + "of", "pine", "trees", "across", "rolling", "hills", "toward", + "distant", "mountains", "covered", "with", "fresh", "snow", + "beneath", "clear", "blue", "sky", "children", "play", "near", + "old", "stone", "bridge", "that", "crosses", "winding", "river", +} + +func generatePromptForTokenCount(targetTokens int, epoch int) string { + // ~1.3 tokens per word heuristic + targetWords := int(float64(targetTokens) / 1.3) + if targetWords < 1 { + targetWords = 1 + } + + // Vary the starting offset by epoch to defeat KV cache prefix matching + offset := epoch * 7 // stride by a prime to get good distribution + n := len(promptWordList) + words := make([]string, targetWords) + for i := range words { + words[i] = promptWordList[((i+offset)%n+n)%n] + } + return strings.Join(words, " ") +} + +func buildGenerateRequest(model string, fOpt flagOptions, imgData api.ImageData, epoch int) *api.GenerateRequest { + options := make(map[string]interface{}) + if *fOpt.maxTokens > 0 { + options["num_predict"] = *fOpt.maxTokens + } + options["temperature"] = *fOpt.temperature + if fOpt.seed != nil && *fOpt.seed > 0 { + options["seed"] = *fOpt.seed + } + + var keepAliveDuration *api.Duration + if *fOpt.keepAlive > 0 { + duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))} + keepAliveDuration = &duration + } + + prompt := *fOpt.prompt + if *fOpt.promptTokens > 0 { + prompt = generatePromptForTokenCount(*fOpt.promptTokens, epoch) + } else { + // Vary the prompt per epoch to defeat KV cache prefix matching + prompt = fmt.Sprintf("[%d] %s", epoch, prompt) + } + + req := &api.GenerateRequest{ + Model: model, + Prompt: prompt, + Raw: true, + Options: options, + 
KeepAlive: keepAliveDuration, + } + + if imgData != nil { + req.Images = []api.ImageData{imgData} + } + + return req +} + +func fetchModelInfo(ctx context.Context, client *api.Client, model string) ModelInfo { + info := ModelInfo{Name: model} + resp, err := client.Show(ctx, &api.ShowRequest{Model: model}) + if err != nil { + fmt.Fprintf(os.Stderr, "WARNING: Could not fetch model info for '%s': %v\n", model, err) + return info + } + info.ParameterSize = resp.Details.ParameterSize + info.QuantizationLevel = resp.Details.QuantizationLevel + info.Family = resp.Details.Family + return info +} + +func fetchMemoryUsage(ctx context.Context, client *api.Client, model string) (size, vram int64) { + resp, err := client.ListRunning(ctx) + if err != nil { + if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + fmt.Fprintf(os.Stderr, "WARNING: Could not fetch memory usage: %v\n", err) + } + return 0, 0 + } + for _, m := range resp.Models { + if m.Name == model || m.Model == model { + return m.Size, m.SizeVRAM + } + } + // Try prefix match (model names may include :latest or tags) + for _, m := range resp.Models { + if strings.HasPrefix(m.Name, model) || strings.HasPrefix(m.Model, model) { + return m.Size, m.SizeVRAM + } + } + return 0, 0 +} + +func outputFormatHeader(w io.Writer, format string, verbose bool) { + switch format { + case "benchstat": + if verbose { + fmt.Fprintf(w, "goos: %s\n", runtime.GOOS) + fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH) + } + case "csv": + headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"} + fmt.Fprintln(w, strings.Join(headings, ",")) + } +} + +func outputModelInfo(w io.Writer, format string, info ModelInfo) { + params := cmp.Or(info.ParameterSize, "unknown") + quant := cmp.Or(info.QuantizationLevel, "unknown") + family := cmp.Or(info.Family, "unknown") + + memStr := "" + if info.SizeBytes > 0 { + memStr = fmt.Sprintf(" | Size: %d | VRAM: %d", info.SizeBytes, info.VRAMBytes) + } + fmt.Fprintf(w, "# Model: %s | 
Params: %s | Quant: %s | Family: %s%s\n", + info.Name, params, quant, family, memStr) +} + func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) { switch format { case "benchstat": - if verbose { - printHeader := func() { - fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS) - fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH) - } - once.Do(printHeader) - } for _, m := range metrics { if m.Step == "generate" || m.Step == "prefill" { if m.Count > 0 { nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count) tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9 - - fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n", - m.Model, m.Step, m.Count, nsPerToken, tokensPerSec) + fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %.2f ns/token %.2f token/sec\n", + m.Model, m.Step, nsPerToken, tokensPerSec) } else { - fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n", - m.Model, m.Step, m.Count) + fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 0 ns/token 0 token/sec\n", + m.Model, m.Step) } + } else if m.Step == "ttft" { + fmt.Fprintf(w, "BenchmarkModel/name=%s/step=ttft 1 %d ns/op\n", + m.Model, m.Duration.Nanoseconds()) } else { - var suffix string - if m.Step == "load" { - suffix = "/step=load" - } - fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n", - m.Model, suffix, m.Duration.Nanoseconds()) + fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s 1 %d ns/op\n", + m.Model, m.Step, m.Duration.Nanoseconds()) } } case "csv": - printHeader := func() { - headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"} - fmt.Fprintln(w, strings.Join(headings, ",")) - } - once.Do(printHeader) - for _, m := range metrics { if m.Step == "generate" || m.Step == "prefill" { var nsPerToken float64 @@ -94,39 +217,14 @@ func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, 
m.Duration.Nanoseconds()) } } - case "markdown": - printHeader := func() { - fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |") - fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|") - } - once.Do(printHeader) - - for _, m := range metrics { - var nsPerToken, tokensPerSec float64 - var nsPerTokenStr, tokensPerSecStr string - - if m.Step == "generate" || m.Step == "prefill" { - nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count) - tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9 - nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken) - tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec) - } else { - nsPerTokenStr = "-" - tokensPerSecStr = "-" - } - - fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n", - m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr) - } default: fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format) } } -func BenchmarkChat(fOpt flagOptions) error { +func BenchmarkModel(fOpt flagOptions) error { models := strings.Split(*fOpt.models, ",") - // todo - add multi-image support var imgData api.ImageData var err error if *fOpt.imageFile != "" { @@ -158,71 +256,124 @@ func BenchmarkChat(fOpt flagOptions) error { out = f } + outputFormatHeader(out, *fOpt.format, *fOpt.verbose) + + // Log prompt-tokens info in debug mode + if *fOpt.debug && *fOpt.promptTokens > 0 { + prompt := generatePromptForTokenCount(*fOpt.promptTokens, 0) + wordCount := len(strings.Fields(prompt)) + fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (%d words, varied per epoch)\n", *fOpt.promptTokens, wordCount) + } + for _, model := range models { - for range *fOpt.epochs { - options := make(map[string]interface{}) - if *fOpt.maxTokens > 0 { - options["num_predict"] = *fOpt.maxTokens - } - options["temperature"] = *fOpt.temperature - if fOpt.seed != nil && *fOpt.seed > 0 { - options["seed"] = *fOpt.seed - } - - var keepAliveDuration 
*api.Duration - if *fOpt.keepAlive > 0 { - duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))} - keepAliveDuration = &duration - } - - req := &api.ChatRequest{ - Model: model, - Messages: []api.Message{ - { - Role: "user", - Content: *fOpt.prompt, - }, - }, - Options: options, - KeepAlive: keepAliveDuration, - } - - if imgData != nil { - req.Messages[0].Images = []api.ImageData{imgData} - } - - var responseMetrics *api.Metrics + // Fetch model info + infoCtx, infoCancel := context.WithTimeout(context.Background(), 10*time.Second) + info := fetchModelInfo(infoCtx, client, model) + infoCancel() + // Warmup phase (uses negative epoch numbers to avoid colliding with timed epochs) + for i := range *fOpt.warmup { + req := buildGenerateRequest(model, fOpt, imgData, -(i + 1)) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second) - defer cancel() - err = client.Chat(ctx, req, func(resp api.ChatResponse) error { - if *fOpt.debug { - fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content)) - } - - if resp.Done { - responseMetrics = &resp.Metrics - } + err = client.Generate(ctx, req, func(resp api.GenerateResponse) error { return nil }) - - if *fOpt.debug { - fmt.Fprintln(os.Stderr) - } + cancel() if err != nil { - if ctx.Err() == context.DeadlineExceeded { - fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1) - continue + fmt.Fprintf(os.Stderr, "WARNING: Warmup %d/%d for %s failed: %v\n", i+1, *fOpt.warmup, model, err) + } else if *fOpt.debug { + fmt.Fprintf(os.Stderr, "Warmup %d/%d for %s complete\n", i+1, *fOpt.warmup, model) + } + } + + // Fetch memory usage once after warmup (model is loaded and stable) + memCtx, memCancel := context.WithTimeout(context.Background(), 5*time.Second) + info.SizeBytes, info.VRAMBytes = fetchMemoryUsage(memCtx, client, model) + memCancel() + + outputModelInfo(out, *fOpt.format, info) + + // 
Timed epoch loop + shortCount := 0 + for epoch := range *fOpt.epochs { + var responseMetrics *api.Metrics + var ttft time.Duration + short := false + + // Retry loop: if the model hits a stop token before max-tokens, + // retry with a different prompt (up to maxRetries times). + const maxRetries = 3 + for attempt := range maxRetries + 1 { + responseMetrics = nil + ttft = 0 + var ttftOnce sync.Once + + req := buildGenerateRequest(model, fOpt, imgData, epoch+attempt*1000) + requestStart := time.Now() + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second) + + err = client.Generate(ctx, req, func(resp api.GenerateResponse) error { + if *fOpt.debug { + fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Thinking, resp.Response)) + } + + // Capture TTFT on first content + ttftOnce.Do(func() { + if resp.Response != "" || resp.Thinking != "" { + ttft = time.Since(requestStart) + } + }) + + if resp.Done { + responseMetrics = &resp.Metrics + } + return nil + }) + cancel() + + if *fOpt.debug { + fmt.Fprintln(os.Stderr) } - fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err) + + if err != nil { + if ctx.Err() == context.DeadlineExceeded { + fmt.Fprintf(os.Stderr, "ERROR: Request timed out with model '%s' after %vs\n", model, *fOpt.timeout) + } else { + fmt.Fprintf(os.Stderr, "ERROR: Couldn't generate with model '%s': %v\n", model, err) + } + break + } + + if responseMetrics == nil { + fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model) + break + } + + // Check if the response was shorter than requested + short = *fOpt.maxTokens > 0 && responseMetrics.EvalCount < *fOpt.maxTokens + if !short || attempt == maxRetries { + break + } + + if *fOpt.debug { + fmt.Fprintf(os.Stderr, "Short response (%d/%d tokens), retrying with different prompt (attempt %d/%d)\n", + responseMetrics.EvalCount, *fOpt.maxTokens, attempt+1, maxRetries) + } + } + + if err != nil || responseMetrics == nil { continue } - 
if responseMetrics == nil { - fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model) - continue + if short { + shortCount++ + if *fOpt.debug { + fmt.Fprintf(os.Stderr, "WARNING: Short response (%d/%d tokens) after %d retries for epoch %d\n", + responseMetrics.EvalCount, *fOpt.maxTokens, maxRetries, epoch+1) + } } metrics := []Metrics{ @@ -238,6 +389,12 @@ func BenchmarkChat(fOpt flagOptions) error { Count: responseMetrics.EvalCount, Duration: responseMetrics.EvalDuration, }, + { + Model: model, + Step: "ttft", + Count: 1, + Duration: ttft, + }, { Model: model, Step: "load", @@ -254,15 +411,42 @@ func BenchmarkChat(fOpt flagOptions) error { OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose) + if *fOpt.debug && *fOpt.promptTokens > 0 { + fmt.Fprintf(os.Stderr, "Generated prompt targeting ~%d tokens (actual: %d)\n", + *fOpt.promptTokens, responseMetrics.PromptEvalCount) + } + if *fOpt.keepAlive > 0 { time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond) } } + + if shortCount > 0 { + fmt.Fprintf(os.Stderr, "WARNING: %d/%d epochs for '%s' had short responses (<%d tokens). 
Generation metrics may be unreliable.\n", + shortCount, *fOpt.epochs, model, *fOpt.maxTokens) + } + + // Unload model before moving to the next one + unloadModel(client, model, *fOpt.timeout) } return nil } +func unloadModel(client *api.Client, model string, timeout int) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + zero := api.Duration{Duration: 0} + req := &api.GenerateRequest{ + Model: model, + KeepAlive: &zero, + } + _ = client.Generate(ctx, req, func(resp api.GenerateResponse) error { + return nil + }) +} + func readImage(filePath string) (api.ImageData, error) { file, err := os.Open(filePath) if err != nil { @@ -280,19 +464,21 @@ func readImage(filePath string) (api.ImageData, error) { func main() { fOpt := flagOptions{ - models: flag.String("model", "", "Model to benchmark"), - epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"), - maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"), - temperature: flag.Float64("temperature", 0, "Temperature parameter"), - seed: flag.Int("seed", 0, "Random seed"), - timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"), - prompt: flag.String("p", DefaultPrompt, "Prompt to use"), - imageFile: flag.String("image", "", "Filename for an image to include"), - keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"), - format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"), - outputFile: flag.String("output", "", "Output file for results (stdout if empty)"), - verbose: flag.Bool("v", false, "Show system information"), - debug: flag.Bool("debug", false, "Show debug information"), + models: flag.String("model", "", "Model to benchmark"), + epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"), + maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"), + temperature: flag.Float64("temperature", 0, "Temperature 
parameter"), + seed: flag.Int("seed", 0, "Random seed"), + timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"), + prompt: flag.String("p", DefaultPrompt, "Prompt to use"), + imageFile: flag.String("image", "", "Filename for an image to include"), + keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"), + format: flag.String("format", "benchstat", "Output format [benchstat|csv]"), + outputFile: flag.String("output", "", "Output file for results (stdout if empty)"), + verbose: flag.Bool("v", false, "Show system information"), + debug: flag.Bool("debug", false, "Show debug information"), + warmup: flag.Int("warmup", 1, "Number of warmup requests before timing"), + promptTokens: flag.Int("prompt-tokens", 0, "Generate prompt targeting ~N tokens (0 = use -p prompt)"), } flag.Usage = func() { @@ -302,11 +488,12 @@ func main() { fmt.Fprintf(os.Stderr, "Options:\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nExamples:\n") - fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n") + fmt.Fprintf(os.Stderr, " bench -model gemma3,llama3 -epochs 6\n") + fmt.Fprintf(os.Stderr, " bench -model gemma3 -epochs 6 -prompt-tokens 512 -format csv\n") } flag.Parse() - if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) { + if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) { fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format) os.Exit(1) } @@ -317,5 +504,5 @@ func main() { return } - BenchmarkChat(fOpt) + BenchmarkModel(fOpt) } diff --git a/cmd/bench/bench_test.go b/cmd/bench/bench_test.go index bcd282d79..3c4c67f13 100644 --- a/cmd/bench/bench_test.go +++ b/cmd/bench/bench_test.go @@ -19,29 +19,33 @@ func createTestFlagOptions() flagOptions { models := "test-model" format := "benchstat" epochs := 1 - maxTokens := 100 + maxTokens := 50 temperature := 0.7 seed := 42 timeout := 30 prompt := "test prompt" imageFile := "" - keepAlive := 5.0 + keepAlive := 0.0 verbose := false debug 
:= false + warmup := 0 + promptTokens := 0 return flagOptions{ - models: &models, - format: &format, - epochs: &epochs, - maxTokens: &maxTokens, - temperature: &temperature, - seed: &seed, - timeout: &timeout, - prompt: &prompt, - imageFile: &imageFile, - keepAlive: &keepAlive, - verbose: &verbose, - debug: &debug, + models: &models, + format: &format, + epochs: &epochs, + maxTokens: &maxTokens, + temperature: &temperature, + seed: &seed, + timeout: &timeout, + prompt: &prompt, + imageFile: &imageFile, + keepAlive: &keepAlive, + verbose: &verbose, + debug: &debug, + warmup: &warmup, + promptTokens: &promptTokens, } } @@ -65,58 +69,85 @@ func captureOutput(f func()) string { return buf.String() } -func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server { +type mockServerOptions struct { + generateResponses []api.GenerateResponse + showResponse *api.ShowResponse + psResponse *api.ProcessResponse +} + +func createMockOllamaServer(t *testing.T, opts mockServerOptions) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/api/chat" { - t.Errorf("Expected path /api/chat, got %s", r.URL.Path) - http.Error(w, "Not found", http.StatusNotFound) - return - } - - if r.Method != "POST" { - t.Errorf("Expected POST method, got %s", r.Method) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - for _, resp := range responses { - jsonData, err := json.Marshal(resp) - if err != nil { - t.Errorf("Failed to marshal response: %v", err) + switch r.URL.Path { + case "/api/generate": + if r.Method != "POST" { + t.Errorf("Expected POST method for /api/generate, got %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } - w.Write(jsonData) - w.Write([]byte("\n")) - if f, ok := w.(http.Flusher); ok { - f.Flush() + + 
w.WriteHeader(http.StatusOK) + for _, resp := range opts.generateResponses { + jsonData, err := json.Marshal(resp) + if err != nil { + t.Errorf("Failed to marshal response: %v", err) + return + } + w.Write(jsonData) + w.Write([]byte("\n")) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + time.Sleep(10 * time.Millisecond) } - time.Sleep(10 * time.Millisecond) // Simulate some delay + + case "/api/show": + if opts.showResponse != nil { + json.NewEncoder(w).Encode(opts.showResponse) + } else { + json.NewEncoder(w).Encode(api.ShowResponse{ + Details: api.ModelDetails{ + ParameterSize: "4.3B", + QuantizationLevel: "Q4_K_M", + Family: "testfamily", + }, + }) + } + + case "/api/ps": + if opts.psResponse != nil { + json.NewEncoder(w).Encode(opts.psResponse) + } else { + json.NewEncoder(w).Encode(api.ProcessResponse{ + Models: []api.ProcessModelResponse{ + { + Name: "test-model", + Model: "test-model", + Size: 4080218931, // ~3.80 GB total + SizeVRAM: 4080218931, // ~3.80 GB on GPU + }, + }, + }) + } + + default: + http.Error(w, "Not found", http.StatusNotFound) } })) } -func TestBenchmarkChat_Success(t *testing.T) { - fOpt := createTestFlagOptions() - - mockResponses := []api.ChatResponse{ +func defaultGenerateResponses() []api.GenerateResponse { + return []api.GenerateResponse{ { - Model: "test-model", - Message: api.Message{ - Role: "assistant", - Content: "test response part 1", - }, - Done: false, + Model: "test-model", + Response: "test response part 1", + Done: false, }, { - Model: "test-model", - Message: api.Message{ - Role: "assistant", - Content: "test response part 2", - }, - Done: true, + Model: "test-model", + Response: "test response part 2", + Done: true, Metrics: api.Metrics{ PromptEvalCount: 10, PromptEvalDuration: 100 * time.Millisecond, @@ -127,14 +158,20 @@ func TestBenchmarkChat_Success(t *testing.T) { }, }, } +} - server := createMockOllamaServer(t, mockResponses) +func TestBenchmarkModel_Success(t *testing.T) { + fOpt := createTestFlagOptions() 
+ + server := createMockOllamaServer(t, mockServerOptions{ + generateResponses: defaultGenerateResponses(), + }) defer server.Close() t.Setenv("OLLAMA_HOST", server.URL) output := captureOutput(func() { - err := BenchmarkChat(fOpt) + err := BenchmarkModel(fOpt) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -149,9 +186,12 @@ func TestBenchmarkChat_Success(t *testing.T) { if !strings.Contains(output, "ns/token") { t.Errorf("Expected output to contain ns/token metric, got: %s", output) } + if !strings.Contains(output, "BenchmarkModel/name=test-model/step=ttft") { + t.Errorf("Expected output to contain ttft metrics, got: %s", output) + } } -func TestBenchmarkChat_ServerError(t *testing.T) { +func TestBenchmarkModel_ServerError(t *testing.T) { fOpt := createTestFlagOptions() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -162,34 +202,36 @@ func TestBenchmarkChat_ServerError(t *testing.T) { t.Setenv("OLLAMA_HOST", server.URL) output := captureOutput(func() { - err := BenchmarkChat(fOpt) + err := BenchmarkModel(fOpt) if err != nil { t.Errorf("Expected error to be handled internally, got returned error: %v", err) } }) - if !strings.Contains(output, "ERROR: Couldn't chat with model") { - t.Errorf("Expected error message about chat failure, got: %s", output) + if !strings.Contains(output, "ERROR: Couldn't generate with model") { + t.Errorf("Expected error message about generate failure, got: %s", output) } } -func TestBenchmarkChat_Timeout(t *testing.T) { +func TestBenchmarkModel_Timeout(t *testing.T) { fOpt := createTestFlagOptions() - shortTimeout := 1 // Very short timeout + shortTimeout := 1 fOpt.timeout = &shortTimeout server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/show" || r.URL.Path == "/api/ps" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{}) + return + } // Simulate a long delay 
that will cause timeout time.Sleep(2 * time.Second) w.Header().Set("Content-Type", "application/json") - response := api.ChatResponse{ - Model: "test-model", - Message: api.Message{ - Role: "assistant", - Content: "test response", - }, - Done: true, + response := api.GenerateResponse{ + Model: "test-model", + Response: "test response", + Done: true, Metrics: api.Metrics{ PromptEvalCount: 10, PromptEvalDuration: 100 * time.Millisecond, @@ -207,38 +249,35 @@ func TestBenchmarkChat_Timeout(t *testing.T) { t.Setenv("OLLAMA_HOST", server.URL) output := captureOutput(func() { - err := BenchmarkChat(fOpt) + err := BenchmarkModel(fOpt) if err != nil { t.Errorf("Expected timeout to be handled internally, got returned error: %v", err) } }) - if !strings.Contains(output, "ERROR: Chat request timed out") { + if !strings.Contains(output, "ERROR: Request timed out") { t.Errorf("Expected timeout error message, got: %s", output) } } -func TestBenchmarkChat_NoMetrics(t *testing.T) { +func TestBenchmarkModel_NoMetrics(t *testing.T) { fOpt := createTestFlagOptions() - mockResponses := []api.ChatResponse{ - { - Model: "test-model", - Message: api.Message{ - Role: "assistant", - Content: "test response", + server := createMockOllamaServer(t, mockServerOptions{ + generateResponses: []api.GenerateResponse{ + { + Model: "test-model", + Response: "test response", + Done: false, // Never sends Done=true }, - Done: false, // Never sends Done=true }, - } - - server := createMockOllamaServer(t, mockResponses) + }) defer server.Close() t.Setenv("OLLAMA_HOST", server.URL) output := captureOutput(func() { - err := BenchmarkChat(fOpt) + err := BenchmarkModel(fOpt) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -249,56 +288,74 @@ func TestBenchmarkChat_NoMetrics(t *testing.T) { } } -func TestBenchmarkChat_MultipleModels(t *testing.T) { +func TestBenchmarkModel_MultipleModels(t *testing.T) { fOpt := createTestFlagOptions() models := "model1,model2" epochs := 2 fOpt.models = &models 
fOpt.epochs = &epochs - callCount := 0 + generateCallCount := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - w.Header().Set("Content-Type", "application/json") - var req api.ChatRequest - body, _ := io.ReadAll(r.Body) - json.Unmarshal(body, &req) + switch r.URL.Path { + case "/api/generate": + var req api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) - response := api.ChatResponse{ - Model: req.Model, - Message: api.Message{ - Role: "assistant", - Content: "test response for " + req.Model, - }, - Done: true, - Metrics: api.Metrics{ - PromptEvalCount: 10, - PromptEvalDuration: 100 * time.Millisecond, - EvalCount: 50, - EvalDuration: 500 * time.Millisecond, - TotalDuration: 600 * time.Millisecond, - LoadDuration: 50 * time.Millisecond, - }, + // Don't count unload requests (empty prompt with KeepAlive) + if req.Prompt != "" { + generateCallCount++ + } + + response := api.GenerateResponse{ + Model: req.Model, + Response: "test response for " + req.Model, + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{ + Details: api.ModelDetails{ + ParameterSize: "7B", + QuantizationLevel: "Q4_0", + Family: "llama", + }, + }) + + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + + default: + http.Error(w, "Not found", http.StatusNotFound) } - jsonData, _ := json.Marshal(response) - w.Write(jsonData) })) defer server.Close() t.Setenv("OLLAMA_HOST", server.URL) output := captureOutput(func() { - err := BenchmarkChat(fOpt) + err := BenchmarkModel(fOpt) if err != nil { t.Errorf("Expected no error, got %v", err) } }) - // Should be called 4 times (2 
models × 2 epochs) - if callCount != 4 { - t.Errorf("Expected 4 API calls, got %d", callCount) + // Should be called 4 times (2 models x 2 epochs), not counting unload requests + if generateCallCount != 4 { + t.Errorf("Expected 4 API calls, got %d", generateCallCount) } if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") { @@ -306,7 +363,7 @@ func TestBenchmarkChat_MultipleModels(t *testing.T) { } } -func TestBenchmarkChat_WithImage(t *testing.T) { +func TestBenchmarkModel_WithImage(t *testing.T) { fOpt := createTestFlagOptions() tmpfile, err := os.CreateTemp(t.TempDir(), "testimage") @@ -325,41 +382,51 @@ func TestBenchmarkChat_WithImage(t *testing.T) { fOpt.imageFile = &tmpfileName server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify the request contains image data - var req api.ChatRequest - body, _ := io.ReadAll(r.Body) - json.Unmarshal(body, &req) - - if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 { - t.Error("Expected request to contain images") - } - w.Header().Set("Content-Type", "application/json") - response := api.ChatResponse{ - Model: "test-model", - Message: api.Message{ - Role: "assistant", - Content: "test response with image", - }, - Done: true, - Metrics: api.Metrics{ - PromptEvalCount: 10, - PromptEvalDuration: 100 * time.Millisecond, - EvalCount: 50, - EvalDuration: 500 * time.Millisecond, - TotalDuration: 600 * time.Millisecond, - LoadDuration: 50 * time.Millisecond, - }, + + switch r.URL.Path { + case "/api/generate": + var req api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + // Only check for images on real requests, not unload requests + if req.Prompt != "" && len(req.Images) == 0 { + t.Error("Expected request to contain images") + } + + response := api.GenerateResponse{ + Model: "test-model", + Response: "test response with image", + Done: true, + Metrics: api.Metrics{ + 
PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + + default: + http.Error(w, "Not found", http.StatusNotFound) } - jsonData, _ := json.Marshal(response) - w.Write(jsonData) })) defer server.Close() t.Setenv("OLLAMA_HOST", server.URL) output := captureOutput(func() { - err := BenchmarkChat(fOpt) + err := BenchmarkModel(fOpt) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -370,13 +437,13 @@ func TestBenchmarkChat_WithImage(t *testing.T) { } } -func TestBenchmarkChat_ImageError(t *testing.T) { +func TestBenchmarkModel_ImageError(t *testing.T) { randFileName := func() string { const charset = "abcdefghijklmnopqrstuvwxyz0123456789" const length = 8 result := make([]byte, length) - rand.Read(result) // Fill with random bytes + rand.Read(result) for i := range result { result[i] = charset[result[i]%byte(len(charset))] @@ -390,7 +457,7 @@ func TestBenchmarkChat_ImageError(t *testing.T) { fOpt.imageFile = &imageFile output := captureOutput(func() { - err := BenchmarkChat(fOpt) + err := BenchmarkModel(fOpt) if err == nil { t.Error("Expected error from image reading, got nil") } @@ -461,3 +528,883 @@ func TestOptionsMapCreation(t *testing.T) { t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"]) } } + +// --- Feature tests --- + +func TestBenchmarkModel_Warmup(t *testing.T) { + fOpt := createTestFlagOptions() + warmup := 2 + fOpt.warmup = &warmup + debug := true + fOpt.debug = &debug + + generateCallCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case 
"/api/generate": + var req api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + // Don't count unload requests + if req.Prompt != "" { + generateCallCount++ + } + + response := api.GenerateResponse{ + Model: "test-model", + Response: "response", + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + } + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + // 2 warmup + 1 epoch = 3 total generate calls (not counting unload) + if generateCallCount != 3 { + t.Errorf("Expected 3 generate calls (2 warmup + 1 epoch), got %d", generateCallCount) + } + + if !strings.Contains(output, "Warmup 1/2 for test-model complete") { + t.Errorf("Expected warmup debug output, got: %s", output) + } + if !strings.Contains(output, "Warmup 2/2 for test-model complete") { + t.Errorf("Expected warmup debug output for 2/2, got: %s", output) + } +} + +func TestBenchmarkModel_TTFT(t *testing.T) { + fOpt := createTestFlagOptions() + + server := createMockOllamaServer(t, mockServerOptions{ + generateResponses: defaultGenerateResponses(), + }) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + if !strings.Contains(output, "step=ttft") { + t.Errorf("Expected TTFT metric in output, got: %s", output) + } +} + +func TestBenchmarkModel_ModelInfo(t *testing.T) { + fOpt := 
createTestFlagOptions() + + server := createMockOllamaServer(t, mockServerOptions{ + generateResponses: defaultGenerateResponses(), + showResponse: &api.ShowResponse{ + Details: api.ModelDetails{ + ParameterSize: "4.3B", + QuantizationLevel: "Q4_K_M", + Family: "gemma3", + }, + }, + }) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + if !strings.Contains(output, "Params: 4.3B") { + t.Errorf("Expected model info with parameter size, got: %s", output) + } + if !strings.Contains(output, "Quant: Q4_K_M") { + t.Errorf("Expected model info with quant level, got: %s", output) + } + if !strings.Contains(output, "Family: gemma3") { + t.Errorf("Expected model info with family, got: %s", output) + } +} + +func TestBenchmarkModel_VRAM(t *testing.T) { + fOpt := createTestFlagOptions() + + server := createMockOllamaServer(t, mockServerOptions{ + generateResponses: defaultGenerateResponses(), + psResponse: &api.ProcessResponse{ + Models: []api.ProcessModelResponse{ + { + Name: "test-model", + Model: "test-model", + Size: 4080218931, + SizeVRAM: 4080218931, + }, + }, + }, + }) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + // VRAM should appear in model info header + if !strings.Contains(output, "VRAM: 4080218931") { + t.Errorf("Expected VRAM in model info header, got: %s", output) + } +} + +func TestBenchmarkModel_PromptTokens(t *testing.T) { + fOpt := createTestFlagOptions() + promptTokens := 100 + fOpt.promptTokens = &promptTokens + + var receivedPrompt string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/generate": + var req 
api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + // Only capture prompt from real requests, not unload requests + if req.Prompt != "" { + receivedPrompt = req.Prompt + } + + response := api.GenerateResponse{ + Model: "test-model", + Response: "response", + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 85, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + } + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + // With ~100 tokens / 1.3 = ~76 words + wordCount := len(strings.Fields(receivedPrompt)) + if wordCount < 50 || wordCount > 120 { + t.Errorf("Expected generated prompt with ~76 words, got %d words", wordCount) + } + + // Prompt should not be the default prompt + if receivedPrompt == DefaultPrompt { + t.Error("Expected generated prompt, but got default prompt") + } +} + +func TestBenchmarkModel_RawMode(t *testing.T) { + fOpt := createTestFlagOptions() + + receivedRaw := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/generate": + var req api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + // Only check raw on real requests, not unload requests + if req.Prompt != "" { + receivedRaw = req.Raw + } + + response := api.GenerateResponse{ + Model: "test-model", + Response: "response", + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 
100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + } + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + if !receivedRaw { + t.Error("Expected raw mode to be enabled in generate request") + } +} + +func TestBenchmarkModel_PromptVariesPerEpoch(t *testing.T) { + fOpt := createTestFlagOptions() + epochs := 3 + fOpt.epochs = &epochs + + var receivedPrompts []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/generate": + var req api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + // Only track prompts from real requests, not unload requests + if req.Prompt != "" { + receivedPrompts = append(receivedPrompts, req.Prompt) + } + + response := api.GenerateResponse{ + Model: "test-model", + Response: "response", + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 50, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + } + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", 
err) + } + }) + + if len(receivedPrompts) != 3 { + t.Fatalf("Expected 3 prompts, got %d", len(receivedPrompts)) + } + + // Each epoch should have a different prompt to defeat KV cache + for i := range receivedPrompts { + for j := i + 1; j < len(receivedPrompts); j++ { + if receivedPrompts[i] == receivedPrompts[j] { + t.Errorf("Expected different prompts for epoch %d and %d, both got: %s", i, j, receivedPrompts[i]) + } + } + } +} + +func TestBenchmarkModel_ShortResponseRetry(t *testing.T) { + fOpt := createTestFlagOptions() + maxTokens := 100 + fOpt.maxTokens = &maxTokens + + generateCallCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/generate": + var req api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + if req.Prompt == "" { + // Unload request + response := api.GenerateResponse{Done: true} + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + return + } + + generateCallCount++ + + // First 3 attempts return short responses, 4th returns full + evalCount := 20 + if generateCallCount == 4 { + evalCount = 100 + } + + response := api.GenerateResponse{ + Model: "test-model", + Response: "response", + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: evalCount, + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + } + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + // 1 epoch: 3 short retries + 1 
successful = 4 generate calls + if generateCallCount != 4 { + t.Errorf("Expected 4 generate calls (3 retries + 1 success), got %d", generateCallCount) + } +} + +func TestBenchmarkModel_ShortResponseWarning(t *testing.T) { + fOpt := createTestFlagOptions() + maxTokens := 100 + fOpt.maxTokens = &maxTokens + + // Always return short responses to trigger the warning + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/generate": + response := api.GenerateResponse{ + Model: "test-model", + Response: "response", + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 20, // Always short + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + } + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + // Should still produce metrics (uses best attempt) + if !strings.Contains(output, "BenchmarkModel/name=test-model") { + t.Errorf("Expected benchmark output even with short responses, got: %s", output) + } + + // Should warn about short responses + if !strings.Contains(output, "WARNING") || !strings.Contains(output, "short responses") { + t.Errorf("Expected warning about short responses, got: %s", output) + } +} + +func TestBenchmarkModel_NoRetryWhenMaxTokensZero(t *testing.T) { + fOpt := createTestFlagOptions() + maxTokens := 0 + fOpt.maxTokens = &maxTokens + + generateCallCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/api/generate": + var req api.GenerateRequest + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + if req.Prompt != "" { + generateCallCount++ + } + + response := api.GenerateResponse{ + Model: "test-model", + Response: "response", + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 10, + PromptEvalDuration: 100 * time.Millisecond, + EvalCount: 5, // Very short, but maxTokens=0 so no retry + EvalDuration: 500 * time.Millisecond, + TotalDuration: 600 * time.Millisecond, + LoadDuration: 50 * time.Millisecond, + }, + } + jsonData, _ := json.Marshal(response) + w.Write(jsonData) + + case "/api/show": + json.NewEncoder(w).Encode(api.ShowResponse{}) + case "/api/ps": + json.NewEncoder(w).Encode(api.ProcessResponse{}) + } + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + // With maxTokens=0, no retries should happen: exactly 1 call for 1 epoch + if generateCallCount != 1 { + t.Errorf("Expected 1 generate call (no retries when maxTokens=0), got %d", generateCallCount) + } +} + +func TestBenchmarkModel_CSVFormat(t *testing.T) { + fOpt := createTestFlagOptions() + format := "csv" + fOpt.format = &format + + server := createMockOllamaServer(t, mockServerOptions{ + generateResponses: defaultGenerateResponses(), + }) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + output := captureOutput(func() { + err := BenchmarkModel(fOpt) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + + if !strings.Contains(output, "NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC") { + t.Errorf("Expected CSV header, got: %s", output) + } + if !strings.Contains(output, "test-model,prefill,") { + t.Errorf("Expected CSV prefill row, got: %s", output) + } + if !strings.Contains(output, "test-model,ttft,") 
{ + t.Errorf("Expected CSV ttft row, got: %s", output) + } +} + +// --- Unit tests for helper functions --- + +func TestGeneratePromptForTokenCount(t *testing.T) { + prompt := generatePromptForTokenCount(100, 0) + wordCount := len(strings.Fields(prompt)) + + // 100 / 1.3 ≈ 76 words + if wordCount < 50 || wordCount > 100 { + t.Errorf("Expected ~76 words, got %d", wordCount) + } +} + +func TestGeneratePromptForTokenCount_Small(t *testing.T) { + prompt := generatePromptForTokenCount(1, 0) + wordCount := len(strings.Fields(prompt)) + if wordCount != 1 { + t.Errorf("Expected 1 word, got %d", wordCount) + } +} + +func TestGeneratePromptForTokenCount_VariesByEpoch(t *testing.T) { + p0 := generatePromptForTokenCount(100, 0) + p1 := generatePromptForTokenCount(100, 1) + p2 := generatePromptForTokenCount(100, 2) + + if p0 == p1 || p1 == p2 || p0 == p2 { + t.Error("Expected different prompts for different epochs") + } + + // All should have same word count + w0 := len(strings.Fields(p0)) + w1 := len(strings.Fields(p1)) + w2 := len(strings.Fields(p2)) + if w0 != w1 || w1 != w2 { + t.Errorf("Expected same word count across epochs, got %d, %d, %d", w0, w1, w2) + } +} + +func TestBuildGenerateRequest(t *testing.T) { + fOpt := createTestFlagOptions() + req := buildGenerateRequest("test-model", fOpt, nil, 0) + + if req.Model != "test-model" { + t.Errorf("Expected model 'test-model', got '%s'", req.Model) + } + if !req.Raw { + t.Error("Expected raw mode to be true") + } + if !strings.Contains(req.Prompt, "test prompt") { + t.Errorf("Expected prompt to contain 'test prompt', got '%s'", req.Prompt) + } +} + +func TestBuildGenerateRequest_WithPromptTokens(t *testing.T) { + fOpt := createTestFlagOptions() + promptTokens := 200 + fOpt.promptTokens = &promptTokens + + req := buildGenerateRequest("test-model", fOpt, nil, 0) + // Should not contain the original prompt + if strings.Contains(req.Prompt, "test prompt") { + t.Error("Expected generated prompt when promptTokens is set") + } + + 
wordCount := len(strings.Fields(req.Prompt)) + if wordCount < 100 || wordCount > 200 { + t.Errorf("Expected ~153 words for 200 tokens, got %d", wordCount) + } +} + +func TestBuildGenerateRequest_WithImage(t *testing.T) { + fOpt := createTestFlagOptions() + imgData := api.ImageData([]byte("fake image")) + + req := buildGenerateRequest("test-model", fOpt, imgData, 0) + if len(req.Images) != 1 { + t.Errorf("Expected 1 image, got %d", len(req.Images)) + } +} + +func TestBuildGenerateRequest_VariesByEpoch(t *testing.T) { + fOpt := createTestFlagOptions() + + req0 := buildGenerateRequest("test-model", fOpt, nil, 0) + req1 := buildGenerateRequest("test-model", fOpt, nil, 1) + + if req0.Prompt == req1.Prompt { + t.Error("Expected different prompts for different epochs") + } +} + +func TestOutputMetrics_Benchstat(t *testing.T) { + var buf bytes.Buffer + metrics := []Metrics{ + {Model: "m1", Step: "prefill", Count: 10, Duration: 100 * time.Millisecond}, + {Model: "m1", Step: "generate", Count: 50, Duration: 500 * time.Millisecond}, + {Model: "m1", Step: "ttft", Count: 1, Duration: 50 * time.Millisecond}, + {Model: "m1", Step: "load", Count: 1, Duration: 50 * time.Millisecond}, + {Model: "m1", Step: "total", Count: 1, Duration: 600 * time.Millisecond}, + } + + OutputMetrics(&buf, "benchstat", metrics, false) + output := buf.String() + + if !strings.Contains(output, "step=prefill") { + t.Errorf("Expected prefill metric, got: %s", output) + } + if !strings.Contains(output, "step=generate") { + t.Errorf("Expected generate metric, got: %s", output) + } + if !strings.Contains(output, "step=ttft") { + t.Errorf("Expected ttft metric, got: %s", output) + } + if !strings.Contains(output, "step=load") { + t.Errorf("Expected load metric, got: %s", output) + } + // Verify dual value/unit pairs for throughput lines (ns/token + token/sec) + if !strings.Contains(output, "token/sec") { + t.Errorf("Expected token/sec metric for throughput lines, got: %s", output) + } + for _, line := range 
strings.Split(strings.TrimSpace(output), "\n") { + if !strings.HasPrefix(line, "Benchmark") { + continue + } + if strings.Contains(line, "ns/token") && !strings.Contains(line, "token/sec") { + t.Errorf("Expected both ns/token and token/sec on throughput line, got: %s", line) + } + } +} + +func TestOutputMetrics_BenchstatFormat(t *testing.T) { + var buf bytes.Buffer + metrics := []Metrics{ + {Model: "m1", Step: "prefill", Count: 10, Duration: 100 * time.Millisecond}, + {Model: "m1", Step: "load", Count: 1, Duration: 50 * time.Millisecond}, + } + + OutputMetrics(&buf, "benchstat", metrics, false) + output := buf.String() + + // Load and total should use ns/op (standard Go benchmark unit) + if !strings.Contains(output, "ns/op") { + t.Errorf("Expected ns/op unit for load/total, got: %s", output) + } + // Prefill/generate should use ns/token + if !strings.Contains(output, "ns/token") { + t.Errorf("Expected ns/token unit for prefill, got: %s", output) + } +} + +func TestOutputModelInfo(t *testing.T) { + info := ModelInfo{ + Name: "gemma3", + ParameterSize: "4.3B", + QuantizationLevel: "Q4_K_M", + Family: "gemma3", + SizeBytes: 4080218931, + VRAMBytes: 4080218931, // Fully on GPU + } + + t.Run("benchstat", func(t *testing.T) { + var buf bytes.Buffer + outputModelInfo(&buf, "benchstat", info) + output := buf.String() + if !strings.Contains(output, "Size: 4080218931") { + t.Errorf("Expected benchstat comment with Size, got: %s", output) + } + if !strings.Contains(output, "VRAM: 4080218931") { + t.Errorf("Expected benchstat comment with VRAM, got: %s", output) + } + }) + + t.Run("csv", func(t *testing.T) { + var buf bytes.Buffer + outputModelInfo(&buf, "csv", info) + output := buf.String() + if !strings.Contains(output, "Size: 4080218931") { + t.Errorf("Expected csv comment with Size, got: %s", output) + } + if !strings.Contains(output, "VRAM: 4080218931") { + t.Errorf("Expected csv comment with VRAM, got: %s", output) + } + }) + + t.Run("no_memory_info", func(t *testing.T) 
{ + infoNoMem := ModelInfo{ + Name: "gemma3", + ParameterSize: "4.3B", + QuantizationLevel: "Q4_K_M", + Family: "gemma3", + } + var buf bytes.Buffer + outputModelInfo(&buf, "benchstat", infoNoMem) + output := buf.String() + if strings.Contains(output, "VRAM") { + t.Errorf("Expected no VRAM in header when SizeBytes is 0, got: %s", output) + } + }) +} + +func TestOutputModelInfo_Unknown(t *testing.T) { + info := ModelInfo{Name: "test"} + + var buf bytes.Buffer + outputModelInfo(&buf, "benchstat", info) + output := buf.String() + + if !strings.Contains(output, "unknown") { + t.Errorf("Expected 'unknown' for missing fields, got: %s", output) + } +} + +func TestFetchMemoryUsage_PrefixMatch(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.ProcessResponse{ + Models: []api.ProcessModelResponse{ + { + Name: "gemma3:latest", + Model: "gemma3:latest", + Size: 20000000, + SizeVRAM: 12345678, + }, + }, + }) + })) + defer server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + client, err := api.ClientFromEnvironment() + if err != nil { + t.Fatal(err) + } + + size, vram := fetchMemoryUsage(t.Context(), client, "gemma3") + if vram != 12345678 { + t.Errorf("Expected VRAM 12345678 via prefix match, got %d", vram) + } + if size != 20000000 { + t.Errorf("Expected Size 20000000 via prefix match, got %d", size) + } +} + +func TestFetchMemoryUsage_CPUSpill(t *testing.T) { + totalSize := int64(8000000000) // 8 GB total + vramSize := int64(5000000000) // 5 GB on GPU, 3 GB spilled to CPU + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(api.ProcessResponse{ + Models: []api.ProcessModelResponse{ + { + Name: "big-model", + Model: "big-model", + Size: totalSize, + SizeVRAM: vramSize, + }, + }, + }) + })) + defer 
server.Close() + + t.Setenv("OLLAMA_HOST", server.URL) + + client, err := api.ClientFromEnvironment() + if err != nil { + t.Fatal(err) + } + + size, vram := fetchMemoryUsage(t.Context(), client, "big-model") + if size != totalSize { + t.Errorf("Expected total size %d, got %d", totalSize, size) + } + if vram != vramSize { + t.Errorf("Expected VRAM size %d, got %d", vramSize, vram) + } + cpuSize := size - vram + if cpuSize != 3000000000 { + t.Errorf("Expected CPU spill of 3000000000, got %d", cpuSize) + } +} + +func TestOutputFormatHeader(t *testing.T) { + t.Run("benchstat_verbose", func(t *testing.T) { + var buf bytes.Buffer + outputFormatHeader(&buf, "benchstat", true) + output := buf.String() + if !strings.Contains(output, "goos:") { + t.Errorf("Expected goos in verbose benchstat header, got: %s", output) + } + if !strings.Contains(output, "goarch:") { + t.Errorf("Expected goarch in verbose benchstat header, got: %s", output) + } + }) + + t.Run("benchstat_not_verbose", func(t *testing.T) { + var buf bytes.Buffer + outputFormatHeader(&buf, "benchstat", false) + output := buf.String() + if output != "" { + t.Errorf("Expected empty output for non-verbose benchstat, got: %s", output) + } + }) + + t.Run("csv", func(t *testing.T) { + var buf bytes.Buffer + outputFormatHeader(&buf, "csv", false) + output := buf.String() + if !strings.Contains(output, "NAME,STEP,COUNT") { + t.Errorf("Expected CSV header, got: %s", output) + } + }) +}