// @cmd/decontam_humaneval/main.go
package main

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"reflect"
	"runtime"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/reposurvey/pipeline/models"
	"github.com/reposurvey/pipeline/parquet"
	"github.com/reposurvey/pipeline/tokenizer"
	"github.com/xitongsys/parquet-go-source/local"
	"github.com/xitongsys/parquet-go/reader"
)

// Matching parameters shared by the eval index and the dataset scan.
const (
	// ngramSize is the number of tokens in each n-gram window. Eval items and
	// dataset samples are compared on their sets of ngramSize-token windows.
	ngramSize = 13

	// rollingBase is the odd multiplier of the polynomial rolling hash over
	// token ids. All hash arithmetic is done in uint64 and wraps on overflow,
	// i.e. it is computed mod 2^64.
	rollingBase uint64 = 1315423911
)

// humanEvalRecord is one HumanEval task. It is decoded directly from a JSONL
// line (via the json tags below) or converted from a humanEvalParquetRow, and
// TaskID identifies the task throughout the reports.
type humanEvalRecord struct {
	TaskID            string `json:"task_id"`
	Prompt            string `json:"prompt"`
	CanonicalSolution string `json:"canonical_solution"`
	Test              string `json:"test"`
	EntryPoint        string `json:"entry_point"`
}

// humanEvalParquetRow is the read-side row schema for the OpenAI HumanEval
// parquet export (e.g. the HuggingFace dataset file). readHumanEvalParquetFile
// converts each row into a humanEvalRecord.
//
// Every column is declared OPTIONAL BYTE_ARRAY/UTF8 and mapped to *string so
// the struct's definition levels match the file: Arrow writes these string
// columns as OPTIONAL (max_definition_level=1). A nil pointer is a NULL cell.
// The field order and parquet tags define the schema handed to the reader.
// Note: this is used only for reading.
type humanEvalParquetRow struct {
	TaskID            *string `parquet:"name=task_id, type=BYTE_ARRAY, convertedtype=UTF8, repetitiontype=OPTIONAL"`
	Prompt            *string `parquet:"name=prompt, type=BYTE_ARRAY, convertedtype=UTF8, repetitiontype=OPTIONAL"`
	CanonicalSolution *string `parquet:"name=canonical_solution, type=BYTE_ARRAY, convertedtype=UTF8, repetitiontype=OPTIONAL"`
	Test              *string `parquet:"name=test, type=BYTE_ARRAY, convertedtype=UTF8, repetitiontype=OPTIONAL"`
	EntryPoint        *string `parquet:"name=entry_point, type=BYTE_ARRAY, convertedtype=UTF8, repetitiontype=OPTIONAL"`
}

// windowHash returns the polynomial hash of one ngramSize-token window,
// evaluated Horner-style in uint64 (wrapping mod 2^64):
//
//	h = sum_i (uint32(tokens[i]) + 1) * rollingBase^(ngramSize-1-i)
//
// Each id is reinterpreted as uint32 (so a negative id still maps to a stable
// value) and offset by one so that id 0 contributes a non-zero term. The
// rolling updates elsewhere in this file depend on exactly this formula.
func windowHash(tokens *[ngramSize]int32) uint64 {
	var acc uint64
	for _, tok := range tokens {
		term := uint64(uint32(tok)) + 1
		acc = acc*rollingBase + term
	}
	return acc
}

// computePowBase returns rollingBase^(ngramSize-1) mod 2^64. This is the
// weight of the oldest token in a window, i.e. the amount subtracted when the
// rolling hash slides forward by one token.
func computePowBase() uint64 {
	result := uint64(1)
	for remaining := ngramSize - 1; remaining > 0; remaining-- {
		result *= rollingBase
	}
	return result
}

// decodeTokens runs the decoder binary decodeBin and returns its stdout
// verbatim.
//
// The binary is invoked as
//
//	decodeBin --model <model> --token-ids "<id> <id> ..."
//
// with every id rendered in decimal and all ids space-separated into a single
// argument. The child's stderr is forwarded to this process's stderr. Any
// launch failure or non-zero exit is returned unchanged from exec.
//
// NOTE(review): very long id lists may exceed the OS limit on a single argv
// string (on Linux, MAX_ARG_STRLEN, typically 128 KiB) and fail to start.
func decodeTokens(decodeBin, model string, toks []int32) (string, error) {
	ids := make([]string, len(toks))
	for i := range toks {
		ids[i] = fmt.Sprint(toks[i])
	}

	cmd := exec.Command(decodeBin, "--model", model, "--token-ids", strings.Join(ids, " "))
	cmd.Stderr = os.Stderr

	stdout, err := cmd.Output()
	if err != nil {
		return "", err
	}
	return string(stdout), nil
}

// readAllFromURL performs a GET on url, bound to ctx, and returns the whole
// response body.
//
// Only 200 OK is accepted. For any other status, the error is
// "http <code>: <body>", where up to the first 4 KiB of the body is included
// with surrounding whitespace trimmed.
//
// NOTE(review): http.DefaultClient has no timeout, so only ctx bounds this
// request.
func readAllFromURL(ctx context.Context, url string) ([]byte, error) {
	req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if reqErr != nil {
		return nil, fmt.Errorf("create request: %w", reqErr)
	}

	resp, doErr := http.DefaultClient.Do(req)
	if doErr != nil {
		return nil, fmt.Errorf("http get: %w", doErr)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		return io.ReadAll(resp.Body)
	}

	snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
	detail := strings.TrimSpace(string(snippet))
	return nil, fmt.Errorf("http %d: %s", resp.StatusCode, detail)
}

// readHumanEval loads the HumanEval records from exactly one source, chosen in
// this order:
//   - url non-empty: fetch it over HTTP and parse the body as JSONL (path is
//     ignored);
//   - path empty: error;
//   - path ending in ".parquet" (case-insensitive): read it as a parquet file;
//   - any other path: open it and parse it as JSONL.
func readHumanEval(ctx context.Context, path string, url string) ([]humanEvalRecord, error) {
	switch {
	case url != "":
		body, err := readAllFromURL(ctx, url)
		if err != nil {
			return nil, err
		}
		return parseHumanEvalJSONL(bytes.NewReader(body))

	case path == "":
		return nil, fmt.Errorf("humaneval path is empty")

	case strings.HasSuffix(strings.ToLower(path), ".parquet"):
		// e.g. a HuggingFace dataset export.
		return readHumanEvalParquetFile(path)

	default:
		file, err := os.Open(path)
		if err != nil {
			return nil, fmt.Errorf("open humaneval: %w", err)
		}
		defer file.Close()
		return parseHumanEvalJSONL(file)
	}
}

// parseHumanEvalJSONL decodes one humanEvalRecord per non-blank line of r.
//
// Surrounding whitespace (including a trailing '\r') is trimmed from each
// line, and blank lines are skipped. Every record must carry a non-empty
// task_id.
//
// Errors name the 1-based line number of the offending input, so a bad
// dataset file is easy to locate. A line longer than 64 MiB fails with the
// scanner's error.
func parseHumanEvalJSONL(r io.Reader) ([]humanEvalRecord, error) {
	scanner := bufio.NewScanner(r)
	// HumanEval lines can be large (tests). Start with a 1 MiB buffer and let it
	// grow up to 64 MiB per line.
	scanner.Buffer(make([]byte, 1024*1024), 64*1024*1024)

	var out []humanEvalRecord
	lineNo := 0
	for scanner.Scan() {
		lineNo++
		// scanner.Bytes is only valid until the next Scan. json.Unmarshal copies
		// what it keeps, so it is safe to decode from it directly, avoiding a
		// string round-trip.
		line := bytes.TrimSpace(scanner.Bytes())
		if len(line) == 0 {
			continue
		}
		var rec humanEvalRecord
		if err := json.Unmarshal(line, &rec); err != nil {
			return nil, fmt.Errorf("parse humaneval jsonl line %d: %w", lineNo, err)
		}
		if rec.TaskID == "" {
			return nil, fmt.Errorf("humaneval record on line %d missing task_id", lineNo)
		}
		out = append(out, rec)
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("scan humaneval: %w", err)
	}
	return out, nil
}

// readHumanEvalParquetFile loads every row of a HumanEval parquet file into
// humanEvalRecords, in file order.
//
// NULL string cells become "". A row whose task_id is NULL or empty is
// rejected (reporting its 0-based row index), matching the JSONL loader,
// because task IDs key every entry of the reports.
func readHumanEvalParquetFile(path string) ([]humanEvalRecord, error) {
	fr, err := local.NewLocalFileReader(path)
	if err != nil {
		return nil, fmt.Errorf("open parquet: %w", err)
	}
	defer fr.Close()

	// Use a single read thread to avoid parquet-go concurrency bugs on some Arrow-produced files.
	pr, err := reader.NewParquetReader(fr, new(humanEvalParquetRow), 1)
	if err != nil {
		return nil, fmt.Errorf("create parquet reader: %w", err)
	}
	defer pr.ReadStop()

	rows := make([]humanEvalParquetRow, int(pr.GetNumRows()))
	if err := pr.Read(&rows); err != nil {
		return nil, fmt.Errorf("read parquet rows: %w", err)
	}

	// deref maps a NULL cell (nil pointer) to "".
	deref := func(p *string) string {
		if p == nil {
			return ""
		}
		return *p
	}

	out := make([]humanEvalRecord, 0, len(rows))
	for i := range rows {
		row := &rows[i]
		rec := humanEvalRecord{
			TaskID:            deref(row.TaskID),
			Prompt:            deref(row.Prompt),
			CanonicalSolution: deref(row.CanonicalSolution),
			Test:              deref(row.Test),
			EntryPoint:        deref(row.EntryPoint),
		}
		if rec.TaskID == "" {
			return nil, fmt.Errorf("humaneval parquet row %d missing task_id", i)
		}
		out = append(out, rec)
	}
	return out, nil
}

// buildHumanEvalText returns the text of rec that should be tokenized, as
// selected by fields (case-insensitive, surrounding whitespace ignored):
//
//	"prompt"                                                   -> Prompt
//	"prompt+canonical_solution" / "prompt+solution"            -> Prompt, CanonicalSolution
//	"all" / "prompt+canonical_solution+test" / "prompt+solution+test"
//	                                                           -> Prompt, CanonicalSolution, Test
//
// Multiple parts are joined with a single "\n". Any other mode is an error.
func buildHumanEvalText(rec humanEvalRecord, fields string) (string, error) {
	mode := strings.ToLower(strings.TrimSpace(fields))

	var parts []string
	switch mode {
	case "prompt":
		parts = []string{rec.Prompt}
	case "prompt+canonical_solution", "prompt+solution":
		parts = []string{rec.Prompt, rec.CanonicalSolution}
	case "all", "prompt+canonical_solution+test", "prompt+solution+test":
		parts = []string{rec.Prompt, rec.CanonicalSolution, rec.Test}
	default:
		return "", fmt.Errorf("unknown fields mode %q (use prompt|prompt+solution|all)", mode)
	}
	return strings.Join(parts, "\n"), nil
}

// tokenizeHumanEval tokenizes the text selected by fields (see
// buildHumanEvalText) for every HumanEval record using the Rust tokenizer
// worker. It returns the token ids and the exact texts that were tokenized,
// both indexed like he.
//
// Requests are sent in chunks of requestBatchSize items (default 64 when
// <= 0). Each item's slice index is carried as its PRID so responses can be
// routed back regardless of order. After all requests are sent, the worker's
// stdin is closed, and responses are read until every item has been answered
// exactly once. Discarded items, and items returned without token ids, get a
// nil token slice. On success the worker is closed.
//
// Errors: an unknown fields mode; any send, receive, or close failure; a
// non-"success" response; or a result whose PRID is out of range or was
// already answered. ctx is checked between request batches.
//
// NOTE(review): error returns do not close the worker. The caller in this file
// exits right after. Whether worker.Close() is safe before CloseStdin has not
// been verified, so no cleanup is attempted here.
func tokenizeHumanEval(ctx context.Context, worker *tokenizer.RustTokenizerClient, he []humanEvalRecord, fields string, requestBatchSize int) ([][]int32, []string, error) {
	if requestBatchSize <= 0 {
		requestBatchSize = 64
	}

	heTexts := make([]string, len(he))
	texts := make([]tokenizer.PRText, 0, len(he))
	for i, rec := range he {
		text, err := buildHumanEvalText(rec, fields)
		if err != nil {
			return nil, nil, err
		}
		heTexts[i] = text
		texts = append(texts, tokenizer.PRText{
			RepoID:   0,
			RepoName: "humaneval",
			PRID:     int64(i),
			Text:     text,
		})
	}

	// Send in batches to avoid max IPC message size.
	// NOTE(review): maxTokens is the int32 maximum, presumably so the worker
	// applies no token cap to HumanEval texts. Confirm against the worker.
	maxTokens := int32(math.MaxInt32)
	for i := 0; i < len(texts); i += requestBatchSize {
		if err := ctx.Err(); err != nil {
			return nil, nil, err
		}
		end := i + requestBatchSize
		if end > len(texts) {
			end = len(texts)
		}
		if err := worker.SendRequest(texts[i:end], maxTokens); err != nil {
			return nil, nil, fmt.Errorf("send tokenize request: %w", err)
		}
	}

	if err := worker.CloseStdin(); err != nil {
		return nil, nil, fmt.Errorf("close tokenizer stdin: %w", err)
	}

	// Collect exactly one result per record. A duplicate PRID is an error.
	// Merely decrementing the counter for it would end the loop while another
	// record was never answered, leaving that record with nil tokens.
	out := make([][]int32, len(he))
	received := make([]bool, len(he))
	for pending := len(he); pending > 0; {
		resp, err := worker.ReceiveResponse()
		if err != nil {
			return nil, nil, fmt.Errorf("receive tokenize response: %w", err)
		}
		if resp.Status != "success" {
			msg := "unknown error"
			if resp.Error != nil {
				msg = *resp.Error
			}
			return nil, nil, fmt.Errorf("tokenize error: %s", msg)
		}
		for _, r := range resp.Results {
			// Range-check in int64 before converting, so no value can wrap.
			if r.PRID < 0 || r.PRID >= int64(len(out)) {
				return nil, nil, fmt.Errorf("tokenizer returned out-of-range PRID %d", r.PRID)
			}
			idx := int(r.PRID)
			if received[idx] {
				return nil, nil, fmt.Errorf("tokenizer returned duplicate result for PRID %d", r.PRID)
			}
			received[idx] = true
			// Discarded items (or items without token ids) keep their nil slice.
			if !r.Discarded && r.TokenIDs != nil {
				out[idx] = r.TokenIDs
			}
			pending--
		}
	}

	if err := worker.Close(); err != nil {
		return nil, nil, fmt.Errorf("close tokenizer worker: %w", err)
	}

	return out, heTexts, nil
}

// buildEvalGramIndex builds an inverted index: gram_hash -> eval indices containing it.
// It uses UNIQUE grams per eval instance (set semantics).
func buildEvalGramIndex(tokensByTask [][]int32) (map[uint64][]int, []int) {
	index := make(map[uint64][]int)
	evalGramCounts := make([]int, len(tokensByTask))

	for taskIdx, toks := range tokensByTask {
		if len(toks) < ngramSize {
			evalGramCounts[taskIdx] = 0
			continue
		}

		seen := make(map[uint64]struct{}, len(toks)-ngramSize+1)

		// First window.
		var win [ngramSize]int32
		copy(win[:], toks[:ngramSize])
		h := windowHash(&win)
		seen[h] = struct{}{}

		// Rolling.
		pow := computePowBase()
		for start := 1; start+ngramSize <= len(toks); start++ {
			xOld := uint64(uint32(toks[start-1])) + 1
			xNew := uint64(uint32(toks[start+ngramSize-1])) + 1
			h = (h-xOld*pow)*rollingBase + xNew
			seen[h] = struct{}{}
		}

		evalGramCounts[taskIdx] = len(seen)
		for gram := range seen {
			index[gram] = append(index[gram], taskIdx)
		}
	}

	return index, evalGramCounts
}

// uniqueSampleGramHashes hashes every ngramSize-token window of toks and
// returns each distinct hash exactly once, in unspecified order.
//
// pow must equal rollingBase^(ngramSize-1) (see computePowBase); it is the
// weight of the token leaving the window when the rolling hash slides. Returns
// nil when toks is shorter than one window.
func uniqueSampleGramHashes(toks []int32, pow uint64) []uint64 {
	windows := len(toks) - ngramSize + 1
	if windows <= 0 {
		return nil
	}
	set := make(map[uint64]struct{}, windows)

	var head [ngramSize]int32
	copy(head[:], toks)
	cur := windowHash(&head)
	set[cur] = struct{}{}

	// lo is the token leaving the window, hi the token entering it.
	for lo, hi := 0, ngramSize; hi < len(toks); lo, hi = lo+1, hi+1 {
		drop := uint64(uint32(toks[lo])) + 1
		add := uint64(uint32(toks[hi])) + 1
		cur = (cur-drop*pow)*rollingBase + add
		set[cur] = struct{}{}
	}

	hashes := make([]uint64, 0, len(set))
	for h := range set {
		hashes = append(hashes, h)
	}
	return hashes
}

// bestMatch describes the dataset sample with the highest leakage ratio seen
// so far for one eval item. It is embedded in both JSON reports, and
// encoding/json emits the fields in declaration order, so field order here is
// the report's key order.
//
// Ratio is Overlap / EvalGrams. Overlap is the number of the eval item's
// unique n-grams that also occur in the sample. EvalGrams is the eval item's
// unique n-gram count. SampleText holds the sample's text when the dataset
// row carries one; otherwise it may be filled later by decoding
// SampleTokenIDs. SampleTokenCount is always recorded on a match.
type bestMatch struct {
	RepoID     int64   `json:"repo_id"`
	RepoName   string  `json:"repo_name"`
	PRID       int64   `json:"pr_id"`
	ByteSize   int64   `json:"byte_size"`
	Ratio      float64 `json:"ratio"`
	Overlap    int     `json:"overlap"`
	EvalGrams  int     `json:"eval_grams"`
	SampleText string  `json:"sample_text,omitempty"`

	// Sample length in tokens; informative when SampleText is not available.
	SampleTokenCount int `json:"sample_token_count"`

	// Keep token ids only in-memory so we can decode the best match later.
	SampleTokenIDs []int32 `json:"-"`
}

// evalState is the per-eval-item scan result, shared by all scan workers.
//
// mu guards max (the best leakage ratio seen so far) and best (the sample that
// produced it). flagged is set once, by compare-and-swap, the first time the
// best ratio reaches the leakage threshold. It is also read lock-free by the
// progress reporter. Must not be copied after first use (contains a Mutex).
type evalState struct {
	mu      sync.Mutex
	max     float64
	best    bestMatch
	flagged atomic.Bool
}

// extractOptionalText returns the row's original text if the row type carries
// one, as a non-empty string field named "Text" or, failing that,
// "DecodedText". The second result reports whether such text was found.
//
// Reflection is used so this compiles whether or not
// models.TokenizedPRData declares either field.
// NOTE(review): this runs for every row that shares a gram with an eval item.
// Caching the field lookup could matter if that is frequent.
func extractOptionalText(row models.TokenizedPRData) (string, bool) {
	rv := reflect.ValueOf(row)
	if rv.Kind() != reflect.Struct {
		return "", false
	}
	// Priority order: "Text" wins over "DecodedText".
	for _, name := range [...]string{"Text", "DecodedText"} {
		field := rv.FieldByName(name)
		if !field.IsValid() || field.Kind() != reflect.String {
			continue
		}
		if s := field.String(); s != "" {
			return s, true
		}
	}
	return "", false
}

// scanTokenizedDataset streams every row of the tokenized parquet dataset in
// datasetDir and computes each row's unique ngramSize-gram hashes. For every
// eval item sharing at least one gram with the row, it records in
// states[evalIdx] the row with the highest leakage ratio (shared unique grams
// divided by the eval item's unique gram count).
//
// Parameters:
//   - heTaskIDs: task IDs indexed like states; used only in [FLAG] lines.
//   - index, evalGramCounts: the output of buildEvalGramIndex for the eval set.
//   - states: per-eval state, updated concurrently under states[i].mu.
//   - leakageThreshold: the first time an eval's best ratio reaches it,
//     states[i].flagged is set and a [FLAG] line is printed to stderr.
//   - scanWorkers: number of row-processing goroutines, also passed to the
//     batch reader; values <= 0 mean runtime.NumCPU().
//   - progressInterval: period of the stderr progress line; <= 0 means 5s.
//
// It returns the number of rows scanned and the total count of unique sample
// grams hashed. On context cancellation or a read error, the counts reached
// so far are returned together with that error.
func scanTokenizedDataset(
	ctx context.Context,
	datasetDir string,
	heTaskIDs []string,
	index map[uint64][]int,
	evalGramCounts []int,
	states []evalState,
	leakageThreshold float64,
	scanWorkers int,
	progressInterval time.Duration,
) (int64, int64, error) {
	// Normalize the knobs before anything is sized from them. A negative
	// scanWorkers would make the jobs channel below panic, and the batch reader
	// should see the same worker count the scan actually uses.
	if scanWorkers <= 0 {
		scanWorkers = runtime.NumCPU()
	}
	if progressInterval <= 0 {
		progressInterval = 5 * time.Second
	}

	// Named batchReader so it does not shadow the imported parquet-go "reader"
	// package. NOTE(review): the meaning of 256 and 2 (presumably rows per batch
	// and prefetch depth) is defined by parquet.NewParallelBatchReader. Confirm.
	batchReader, err := parquet.NewParallelBatchReader[models.TokenizedPRData](datasetDir, 256, scanWorkers, 2)
	if err != nil {
		return 0, 0, err
	}
	defer batchReader.Close()

	jobs := make(chan []models.TokenizedPRData, scanWorkers*2)
	var wg sync.WaitGroup

	var rowsScanned atomic.Int64
	var gramsScanned atomic.Int64
	var bytesScanned atomic.Int64

	progressStop := make(chan struct{})
	progressDone := make(chan struct{})
	start := time.Now()

	// Progress reporter: rewrites one stderr status line (via '\r') every
	// progressInterval, until stopped or ctx is cancelled.
	go func() {
		defer close(progressDone)
		ticker := time.NewTicker(progressInterval)
		defer ticker.Stop()

		var lastRows int64
		var lastGrams int64
		var lastBytes int64
		lastAt := start

		for {
			select {
			case <-progressStop:
				fmt.Fprintln(os.Stderr)
				return
			case <-ctx.Done():
				fmt.Fprintln(os.Stderr)
				return
			case <-ticker.C:
				now := time.Now()
				elapsed := now.Sub(start)
				dt := now.Sub(lastAt).Seconds()
				if dt <= 0 {
					dt = 1
				}

				rows := rowsScanned.Load()
				grams := gramsScanned.Load()
				b := bytesScanned.Load()

				rowsPerSec := float64(rows-lastRows) / dt
				gramsPerSec := float64(grams-lastGrams) / dt
				mbPerSec := float64(b-lastBytes) / (1024 * 1024) / dt

				flagged := 0
				for i := range states {
					if states[i].flagged.Load() {
						flagged++
					}
				}

				fmt.Fprintf(os.Stderr,
					"[PROGRESS] rows=%d (%.0f/s) uniq_grams=%d (%.0f/s) data=%.2fGB (%.2fMB/s) flagged=%d/%d elapsed=%s\r",
					rows,
					rowsPerSec,
					grams,
					gramsPerSec,
					float64(b)/(1024*1024*1024),
					mbPerSec,
					flagged,
					len(states),
					elapsed.Truncate(time.Second),
				)

				lastRows = rows
				lastGrams = grams
				lastBytes = b
				lastAt = now
			}
		}
	}()

	pow := computePowBase()

	workerFn := func() {
		defer wg.Done()

		// Reuse maps per worker to reduce allocations.
		overlaps := make(map[int]int, 256)

		for batch := range jobs {
			for _, row := range batch {
				select {
				case <-ctx.Done():
					return
				default:
				}

				toks := row.TokenIDs
				rowsScanned.Add(1)
				bytesScanned.Add(int64(row.ByteSize))

				// Unique grams in sample.
				grams := uniqueSampleGramHashes(toks, pow)
				gramsScanned.Add(int64(len(grams)))
				if len(grams) == 0 {
					continue
				}

				// overlaps[evalIdx] = number of UNIQUE eval grams hit by this sample.
				// Both sides are sets, so each hit is counted once.
				for k := range overlaps {
					delete(overlaps, k)
				}
				for _, g := range grams {
					evals, ok := index[g]
					if !ok {
						continue
					}
					for _, evalIdx := range evals {
						overlaps[evalIdx]++
					}
				}
				if len(overlaps) == 0 {
					continue
				}

				sampleText, _ := extractOptionalText(row)

				// Update per-eval maxima.
				for evalIdx, overlap := range overlaps {
					den := evalGramCounts[evalIdx]
					if den <= 0 {
						continue
					}
					ratio := float64(overlap) / float64(den)

					st := &states[evalIdx]
					updated := false

					st.mu.Lock()
					if ratio > st.max {
						st.max = ratio
						st.best = bestMatch{
							RepoID:           row.RepoID,
							RepoName:         row.RepoName,
							PRID:             row.PRID,
							ByteSize:         int64(row.ByteSize),
							Ratio:            ratio,
							Overlap:          overlap,
							EvalGrams:        den,
							SampleText:       sampleText,
							SampleTokenCount: len(toks),
							SampleTokenIDs:   append([]int32(nil), toks...), // copy for later decoding
						}
						updated = true
					}
					st.mu.Unlock()

					// First time it crosses the threshold -> print.
					// (This flags "contaminated" when leakage_ratio >= threshold.)
					if updated && ratio >= leakageThreshold {
						if st.flagged.CompareAndSwap(false, true) {
							id := "<unknown>"
							if evalIdx >= 0 && evalIdx < len(heTaskIDs) {
								id = heTaskIDs[evalIdx]
							}
							fmt.Fprintf(os.Stderr,
								"\n[FLAG] %s leakage_ratio=%.4f overlap=%d/%d matched in %s#%d (repo_id=%d)\n",
								id,
								ratio,
								overlap,
								den,
								row.RepoName,
								row.PRID,
								row.RepoID,
							)
						}
					}
				}
			}
		}
	}

	for i := 0; i < scanWorkers; i++ {
		wg.Add(1)
		go workerFn()
	}

	// shutdown stops the workers and the progress reporter, waits for both, and
	// returns the final counters with cause (nil on success). Call exactly once.
	shutdown := func(cause error) (int64, int64, error) {
		close(jobs)
		wg.Wait()
		close(progressStop)
		<-progressDone
		return rowsScanned.Load(), gramsScanned.Load(), cause
	}

	for {
		if err := ctx.Err(); err != nil {
			return shutdown(err)
		}

		batch, hasMore, err := batchReader.ReadBatch()
		if err != nil {
			return shutdown(err)
		}
		if len(batch) > 0 {
			select {
			case jobs <- batch:
			case <-ctx.Done():
				return shutdown(ctx.Err())
			}
		}
		if !hasMore {
			return shutdown(nil)
		}
	}
}

// leakageRankRow is one entry of the report's "leakage_ratios_ranked" list:
// an eval item, its best leakage ratio, its unique gram count, and the sample
// that produced that ratio. Field order is the JSON key order.
type leakageRankRow struct {
	TaskID        string    `json:"task_id"`
	LeakageRatio  float64   `json:"leakage_ratio"`
	EvalGramCount int       `json:"eval_gram_count"`
	BestMatch     bestMatch `json:"best_match"`
}

// main runs the HumanEval decontamination pipeline:
//  1. load HumanEval (JSONL file, parquet file, or JSONL over HTTP);
//  2. tokenize the selected fields with the Rust tokenizer worker;
//  3. index the unique ngramSize-grams of every eval item;
//  4. scan the tokenized dataset, keeping each eval item's best-matching sample;
//  5. write the ranked JSON report to -out, and the decoded best matches to
//     "<out>.matches.json".
//
// A missing HumanEval source exits with status 2. Any other failure is
// printed to stderr and exits with status 1.
func main() {
	var (
		datasetDir        string
		humanEvalPath     string
		humanEvalURL      string
		outPath           string
		tokenizerModel    string
		rustWorkerPath    string
		decodeBin         string
		rustWorkers       int
		rustThreadsPerTok int
		fields            string
		scanWorkers       int
		requestBatchSize  int
		progressSec       int
		leakageThreshold  float64
	)

	flag.StringVar(&datasetDir, "dataset_dir", "data/tokenized_dataset", "Directory containing tokenized parquet dataset (models.TokenizedPRData)")
	flag.StringVar(&humanEvalPath, "humaneval_path", "", "Path to HumanEval .jsonl or .parquet")
	flag.StringVar(&humanEvalURL, "humaneval_url", "", "URL to HumanEval .jsonl (e.g., http://localhost:1081/humaneval.jsonl)")
	flag.StringVar(&outPath, "out", "data/humaneval_decontam_report.json", "Output JSON report path")
	flag.StringVar(&tokenizerModel, "tokenizer_model", os.Getenv("TOKENIZER_MODEL"), "Tokenizer model name or path to tokenizer.json")
	flag.StringVar(&rustWorkerPath, "rust_worker_path", "./tokenizer/rust_worker/target/release/tokenizer_worker", "Path to Rust tokenizer worker binary")
	flag.StringVar(&decodeBin, "decode_bin", "./tokenizer/rust_worker/target/release/tokenizer_decode", "Decoder binary path")
	flag.IntVar(&rustWorkers, "rust_workers", 128, "Number of Rust worker threads")
	flag.IntVar(&rustThreadsPerTok, "rust_threads_per_tokenizer", 16, "Threads per tokenizer instance")
	flag.StringVar(&fields, "humaneval_fields", "all", "Which HumanEval fields to tokenize: prompt|prompt+solution|all")
	flag.IntVar(&scanWorkers, "scan_workers", runtime.NumCPU(), "Concurrency for scanning parquet dataset")
	flag.IntVar(&requestBatchSize, "tokenize_batch", 64, "How many HumanEval items per tokenizer request")
	flag.IntVar(&progressSec, "progress_sec", 30, "Progress refresh interval in seconds while scanning dataset")
	flag.Float64Var(&leakageThreshold, "leakage_threshold", 0.10, "Flag contaminated if leakage_ratio >= threshold")
	flag.Parse()

	if tokenizerModel == "" {
		tokenizerModel = "Qwen/Qwen2.5-Coder-32B-Instruct"
	}
	if humanEvalPath == "" && humanEvalURL == "" {
		fmt.Fprintln(os.Stderr, "must provide -humaneval_path or -humaneval_url")
		os.Exit(2)
	}

	ctx := context.Background()
	start := time.Now()

	fmt.Printf("[INFO] Reading HumanEval...\n")
	he, err := readHumanEval(ctx, humanEvalPath, humanEvalURL)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[ERROR] read humaneval: %v\n", err)
		os.Exit(1)
	}
	fmt.Printf("[INFO] HumanEval loaded: %d items\n", len(he))

	fmt.Printf("[INFO] Starting Rust tokenizer worker...\n")
	worker, err := tokenizer.NewRustTokenizerClient(tokenizer.RustWorkerConfig{
		WorkerPath:          rustWorkerPath,
		Model:               tokenizerModel,
		Workers:             rustWorkers,
		ThreadsPerTokenizer: rustThreadsPerTok,
		MaxMessageSize:      200 * 1024 * 1024,
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "[ERROR] start tokenizer worker: %v\n", err)
		os.Exit(1)
	}

	fmt.Printf("[INFO] Tokenizing HumanEval (%s)...\n", fields)
	heTokens, heTexts, err := tokenizeHumanEval(ctx, worker, he, fields, requestBatchSize)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[ERROR] tokenize humaneval: %v\n", err)
		os.Exit(1)
	}

	fmt.Printf("[INFO] Building %d-gram inverted index (set grams per eval)...\n", ngramSize)
	index, evalGramCounts := buildEvalGramIndex(heTokens)
	fmt.Printf("[INFO] Index built: %d distinct gram hash keys\n", len(index))

	heTaskIDs := make([]string, len(he))
	for i := range he {
		heTaskIDs[i] = he[i].TaskID
	}

	// Per-eval state. max starts at its zero value (0). best is seeded with the
	// eval's gram count so items that never match still report it.
	states := make([]evalState, len(he))
	for i := range states {
		states[i].best = bestMatch{Ratio: 0, EvalGrams: evalGramCounts[i]}
	}

	fmt.Printf("[INFO] Scanning tokenized dataset: %s\n", datasetDir)
	rows, uniqGrams, err := scanTokenizedDataset(
		ctx,
		datasetDir,
		heTaskIDs,
		index,
		evalGramCounts,
		states,
		leakageThreshold,
		scanWorkers,
		time.Duration(progressSec)*time.Second,
	)
	if err != nil {
		if !errors.Is(err, context.Canceled) {
			fmt.Fprintf(os.Stderr, "[ERROR] scan dataset: %v\n", err)
		}
		os.Exit(1)
	}

	// Build ranked leakage list and contaminated/decontaminated partitions.
	// The partitions are non-nil so an empty one is encoded as [] rather than
	// null in the report.
	ranked := make([]leakageRankRow, 0, len(he))
	contaminatedIDs := make([]string, 0)
	decontaminatedIDs := make([]string, 0)

	for i := range he {
		st := &states[i]
		st.mu.Lock()
		r := st.max
		bm := st.best
		st.mu.Unlock()

		ranked = append(ranked, leakageRankRow{
			TaskID:        he[i].TaskID,
			LeakageRatio:  r,
			EvalGramCount: evalGramCounts[i],
			BestMatch:     bm,
		})

		if r >= leakageThreshold {
			contaminatedIDs = append(contaminatedIDs, he[i].TaskID)
		} else {
			decontaminatedIDs = append(decontaminatedIDs, he[i].TaskID)
		}
	}

	// Highest leakage first; ties broken by task ID for a stable report.
	sort.Slice(ranked, func(i, j int) bool {
		if ranked[i].LeakageRatio == ranked[j].LeakageRatio {
			return ranked[i].TaskID < ranked[j].TaskID
		}
		return ranked[i].LeakageRatio > ranked[j].LeakageRatio
	})

	report := map[string]any{
		"dataset_dir":           datasetDir,
		"humaneval_path":        humanEvalPath,
		"humaneval_url":         humanEvalURL,
		"humaneval_fields":      fields,
		"ngram_size":            ngramSize,
		"humaneval_items":       len(he),
		"index_hash_keys":       len(index),
		"rows_scanned":          rows,
		"unique_grams_scanned":  uniqGrams,
		"leakage_threshold":     leakageThreshold,
		"contaminated_count":    len(contaminatedIDs),
		"decontaminated_count":  len(decontaminatedIDs),
		"contaminated":          contaminatedIDs,
		"decontaminated":        decontaminatedIDs,
		"leakage_ratios_ranked": ranked,
		"elapsed_sec":           time.Since(start).Seconds(),
	}

	// Write report JSON.
	if err := os.MkdirAll(filepath.Dir(outPath), 0o755); err != nil {
		fmt.Fprintf(os.Stderr, "[ERROR] mkdir out dir: %v\n", err)
		os.Exit(1)
	}
	if err := writeJSONFile(outPath, "out", report); err != nil {
		fmt.Fprintf(os.Stderr, "[ERROR] %v\n", err)
		os.Exit(1)
	}

	// Write {out}.matches.json: decoded content of the highest match of each instance.
	type matchOut struct {
		TaskID       string    `json:"task_id"`
		Fields       string    `json:"humaneval_fields"`
		EvalText     string    `json:"eval_text"`
		LeakageRatio float64   `json:"leakage_ratio"`
		EvalGrams    int       `json:"eval_gram_count"`
		BestMatch    bestMatch `json:"best_match"`
		DecodeError  string    `json:"decode_error,omitempty"`
	}

	fmt.Printf("[INFO] Decoding best matches via %s ...\n", decodeBin)
	matches := make([]matchOut, 0, len(he))
	for i := range he {
		st := &states[i]
		st.mu.Lock()
		r := st.max
		bm := st.best
		st.mu.Unlock()

		// Only samples without stored text need the external decoder. A decode
		// failure is recorded per item rather than aborting the run.
		var decErr string
		if bm.SampleText == "" && len(bm.SampleTokenIDs) > 0 {
			txt, err := decodeTokens(decodeBin, tokenizerModel, bm.SampleTokenIDs)
			if err != nil {
				decErr = err.Error()
			} else {
				bm.SampleText = txt
			}
		}

		// Ensure we don't leak huge token ids into JSON (json:"-" already, but be explicit).
		bm.SampleTokenIDs = nil

		matches = append(matches, matchOut{
			TaskID:       he[i].TaskID,
			Fields:       fields,
			EvalText:     heTexts[i],
			LeakageRatio: r,
			EvalGrams:    evalGramCounts[i],
			BestMatch:    bm,
			DecodeError:  decErr,
		})
	}

	matchesPath := outPath + ".matches.json"
	if err := writeJSONFile(matchesPath, "matches out", matches); err != nil {
		fmt.Fprintf(os.Stderr, "[ERROR] %v\n", err)
		os.Exit(1)
	}

	fmt.Printf("[INFO] Done. contaminated=%d decontaminated=%d rows=%d uniq_grams=%d elapsed=%s\n",
		len(contaminatedIDs),
		len(decontaminatedIDs),
		rows,
		uniqGrams,
		time.Since(start).Truncate(time.Millisecond),
	)
	fmt.Printf("[INFO] Report written to %s\n", outPath)
	fmt.Printf("[INFO] Matches written to %s\n", matchesPath)
}

// writeJSONFile creates (or truncates) path and writes v to it as JSON with a
// two-space indent and a trailing newline. The file is closed on every path,
// and a close error is reported because it can indicate an incomplete write.
//
// label names the file in error messages, which read "create <label>: ...",
// "write <label>: ..." or "close <label>: ...".
func writeJSONFile(path, label string, v any) error {
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", label, err)
	}
	enc := json.NewEncoder(f)
	enc.SetIndent("", "  ")
	if err := enc.Encode(v); err != nil {
		_ = f.Close() // the encode error is the one worth reporting
		return fmt.Errorf("write %s: %w", label, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s: %w", label, err)
	}
	return nil
}
