package main

import (
	"flag"
	"fmt"
	"math/rand"
	"os"
	"sync"
	"sync/atomic"
	"time"

	"github.com/reposurvey/pipeline/models"
	"github.com/reposurvey/pipeline/parquet"
)

// Shared by -mode=write and -mode=read.
var (
	mode    = flag.String("mode", "write", "Mode: write or read")
	dir     = flag.String("dir", "./bench_data", "Directory for data")
	workers = flag.Int("workers", 1, "Number of concurrent workers")
)

// Read-mode settings.
var (
	inputDir = flag.String("input_dir", "", "Input directory for read mode (defaults to dir)")
	np       = flag.Int64("np", 4, "Number of parallel routines per reader")
)

// Write-mode sizing. record_size is also consulted by the read benchmark
// when it reports throughput.
var (
	totalSizeGB = flag.Float64("size_gb", 1.0, "Total size to write in GB (approx)")
	fileSizeMB  = flag.Int64("file_size_mb", 100, "Size per file in MB")
	recordSize  = flag.Int("record_size", 2048, "Size of text payload in bytes")
)

// main parses command-line flags, echoes the effective configuration, and
// runs the benchmark selected by -mode. An unknown mode is reported after the
// common configuration lines and the process exits with status 1.
func main() {
	flag.Parse()

	// Settings that apply to every mode.
	fmt.Printf("Benchmark Configuration:\n")
	fmt.Printf("  Mode: %s\n", *mode)
	fmt.Printf("  Directory: %s\n", *dir)
	fmt.Printf("  Workers: %d\n", *workers)

	// Print the mode-specific settings, then hand off to that benchmark.
	switch *mode {
	case "write":
		fmt.Printf("  Total Size: %.2f GB\n", *totalSizeGB)
		fmt.Printf("  File Size: %d MB\n", *fileSizeMB)
		fmt.Printf("  Record Size: %d bytes\n", *recordSize)
		runWriteBenchmark()
	case "read":
		fmt.Printf("  NP: %d\n", *np)
		runReadBenchmark()
	default:
		fmt.Println("Invalid mode. Use 'write' or 'read'")
		os.Exit(1)
	}
}

func runWriteBenchmark() {
	// Calculate total records
	totalBytes := int64(*totalSizeGB * 1024 * 1024 * 1024)
	totalRecords := totalBytes / int64(*recordSize)
	recordsPerWorker := totalRecords / int64(*workers)

	fmt.Printf("Target: Writing %d records total (~%d per worker)\n", totalRecords, recordsPerWorker)

	// Pre-generate payload to avoid CPU bottleneck
	payload := randomString(*recordSize)

	var wg sync.WaitGroup
	var totalWritten int64
	startTime := time.Now()

	// Use ParallelBatchWriter
	// Batch size 5000, flush interval 30s
	bw, err := parquet.NewParallelBatchWriter[models.RenderedPRText](
		*dir,
		5000,
		*fileSizeMB*1024*1024,
		30,
		*workers,
	)
	if err != nil {
		fmt.Printf("Failed to create writer: %v\n", err)
		return
	}
	defer bw.Close()

	for i := 0; i < *workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()

			var written int64
			for j := int64(0); j < recordsPerWorker; j++ {
				rec := models.RenderedPRText{
					PRID:     int64(j),
					RepoID:   int64(id),
					RepoName: fmt.Sprintf("user/repo-%d", id),
					Text:     payload,
				}
				if err := bw.Write(rec); err != nil {
					fmt.Printf("Worker %d write error: %v\n", id, err)
					return
				}
				written++
			}
			atomic.AddInt64(&totalWritten, written)
		}(i)
	}

	wg.Wait()
	duration := time.Since(startTime)

	printStats(totalWritten, int64(*recordSize), duration)
}

func runReadBenchmark() {
	readDir := *dir
	if *inputDir != "" {
		readDir = *inputDir
	}

	// Use ParallelBatchReader
	pbr, err := parquet.NewParallelBatchReader[models.RenderedPRText](readDir, 5000, *workers, *np)
	if err != nil {
		fmt.Printf("Failed to create reader: %v\n", err)
		os.Exit(1)
	}
	defer pbr.Close()

	fmt.Printf("Found %d parquet files to read\n", pbr.GetFileCount())

	var totalRead int64
	startTime := time.Now()
	var wg sync.WaitGroup

	// Start multiple consumers to simulate real workload
	// We use same number of consumers as workers
	for i := 0; i < *workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				batch, hasMore, err := pbr.ReadBatch()
				if err != nil {
					fmt.Printf("Read error: %v\n", err)
					break
				}

				if len(batch) > 0 {
					atomic.AddInt64(&totalRead, int64(len(batch)))
				}

				if !hasMore {
					break
				}
			}
		}()
	}

	wg.Wait()
	duration := time.Since(startTime)

	// Estimate record size from first file if possible, or use flag
	// For simplicity using flag recordSize for throughput calc
	printStats(totalRead, int64(*recordSize), duration)
}

// printStats prints elapsed time, record count, data volume and throughput
// for a run that processed count records of recSize bytes each in duration.
// Data volume is count*recSize expressed in units of 1024*1024 bytes
// (labelled "MB"). When duration is not positive, the per-second rates are
// printed as 0 rather than +Inf or NaN.
func printStats(count int64, recSize int64, duration time.Duration) {
	seconds := duration.Seconds()
	totalBytes := count * recSize
	mb := float64(totalBytes) / (1024 * 1024)

	// Float division by zero does not panic in Go; it yields +Inf or NaN,
	// which fmt would print verbatim. Guard the rates explicitly.
	var mbPerSec, recordsPerSec float64
	if seconds > 0 {
		mbPerSec = mb / seconds
		recordsPerSec = float64(count) / seconds
	}

	fmt.Printf("\nBenchmark Results:\n")
	fmt.Printf("  Time: %.2fs\n", seconds)
	fmt.Printf("  Records: %d\n", count)
	fmt.Printf("  Total Data: %.2f MB\n", mb)
	fmt.Printf("  Throughput: %.2f MB/s\n", mbPerSec)
	fmt.Printf("  Records/s: %.2f\n", recordsPerSec)
}

// charset is the alphabet randomString samples from: lowercase and uppercase
// ASCII letters, the digits 0-9, and a space.
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 "

// randomString returns a string of exactly n bytes. Each byte is chosen
// independently and uniformly from charset using the package-level math/rand
// source. It panics if n is negative.
func randomString(n int) string {
	buf := make([]byte, n)
	alphabetLen := len(charset)
	for pos := 0; pos < n; pos++ {
		buf[pos] = charset[rand.Intn(alphabetLen)]
	}
	return string(buf)
}
