package parquet

import (
	"fmt"
	"path/filepath"
	"sync"
	"sync/atomic"

	"github.com/xitongsys/parquet-go-source/local"
	"github.com/xitongsys/parquet-go/reader"
)

// ParallelBatchReader reads a set of parquet files concurrently. Worker
// goroutines claim files one at a time, decode them into slices of T, and
// publish those slices on an internal channel that ReadBatch consumes.
type ParallelBatchReader[T any] struct {
	// fileIndex is the index of the most recently claimed entry in files.
	// It is updated with sync/atomic and is kept as the first field so that
	// it is 64-bit aligned on 32-bit platforms, where the atomic functions
	// require such alignment for 64-bit operands.
	fileIndex int64

	// Configuration, fixed at construction time.
	files      []string // parquet files to read, claimed one at a time by workers
	batchSize  int      // maximum number of rows per batch
	numWorkers int      // number of worker goroutines
	np         int64    // passed through to the parquet reader opened for each file

	// Output channels
	batchChan chan []T   // row batches produced by workers
	errChan   chan error // worker failures

	// Worker management
	wg       sync.WaitGroup // tracks running workers
	stopChan chan struct{}  // closed by Close to ask workers to stop
}

// NewParallelBatchReader creates a reader over every "*.parquet" file directly
// inside inputDir and immediately starts numWorkers goroutines that decode the
// files into batches of at most batchSize rows.
//
// np is forwarded unchanged to the parquet reader opened for each file.
//
// An error is returned when batchSize or numWorkers is not positive, when the
// glob pattern is malformed, or when no parquet file is found. On success the
// caller should either drain ReadBatch until it reports no more data or call
// Close; otherwise workers stay blocked once the batch buffer is full.
func NewParallelBatchReader[T any](inputDir string, batchSize int, numWorkers int, np int64) (*ParallelBatchReader[T], error) {
	// A non-positive batch size would make the per-file read loop spin forever
	// (zero) or panic while allocating the batch (negative).
	if batchSize <= 0 {
		return nil, fmt.Errorf("batch size must be positive, got %d", batchSize)
	}
	// Zero workers would silently produce no data; a negative count would
	// panic when sizing the channels below.
	if numWorkers <= 0 {
		return nil, fmt.Errorf("number of workers must be positive, got %d", numWorkers)
	}

	// Find all parquet files in the input directory
	pattern := filepath.Join(inputDir, "*.parquet")
	files, err := filepath.Glob(pattern)
	if err != nil {
		return nil, fmt.Errorf("failed to glob parquet files: %w", err)
	}

	if len(files) == 0 {
		return nil, fmt.Errorf("no parquet files found in %s", inputDir)
	}

	pbr := &ParallelBatchReader[T]{
		files:      files,
		batchSize:  batchSize,
		numWorkers: numWorkers,
		np:         np,
		fileIndex:  -1, // Start at -1 so first increment gives 0
		batchChan:  make(chan []T, numWorkers*2),
		// Each worker reports at most one error before exiting, so a buffer
		// of numWorkers lets every worker report without blocking.
		errChan:  make(chan error, numWorkers),
		stopChan: make(chan struct{}),
	}

	// Start workers
	for i := 0; i < numWorkers; i++ {
		pbr.wg.Add(1)
		go pbr.runWorker(i)
	}

	// Close the output channels once every worker has exited so that readers
	// can detect the end of the stream.
	go func() {
		pbr.wg.Wait()
		close(pbr.batchChan)
		close(pbr.errChan)
	}()

	return pbr, nil
}

// ReadBatch returns the next available batch of records.
//
// It returns (batch, true, nil) when a batch was received, (nil, false, err)
// when a worker reported a failure, and (nil, false, nil) once all workers
// have finished and every buffered batch has been delivered. Batches produced
// by different workers arrive in no particular order.
func (pbr *ParallelBatchReader[T]) ReadBatch() ([]T, bool, error) {
	errChan := pbr.errChan

	for {
		select {
		case err, ok := <-errChan:
			if !ok {
				// The error channel is closed, which makes this case
				// permanently ready. Disable it (a nil channel never
				// selects) so buffered batches are still delivered
				// instead of being dropped at random.
				errChan = nil
				continue
			}
			return nil, false, err

		case batch, ok := <-pbr.batchChan:
			if ok {
				return batch, true, nil
			}

			// The batch channel is closed only after every worker has
			// exited, so any outstanding error is already buffered.
			// Surface it rather than reporting a clean end of data.
			select {
			case err, ok := <-pbr.errChan:
				if ok {
					return nil, false, err
				}
			default:
			}
			return nil, false, nil // Channel closed, no more data
		}
	}
}

// runWorker claims files from the shared list one at a time and processes
// them until the list is exhausted, a file fails, or the reader is stopped.
// id is only used to label errors. The worker signals completion through the
// reader's WaitGroup.
func (pbr *ParallelBatchReader[T]) runWorker(id int) {
	defer pbr.wg.Done()

	for {
		// processFile returns nil when it is interrupted by the stop signal,
		// so check here as well; otherwise a stopped worker would go on to
		// claim and open every remaining file.
		select {
		case <-pbr.stopChan:
			return
		default:
		}

		// Atomically claim the next file index so no two workers read the
		// same file.
		idx := atomic.AddInt64(&pbr.fileIndex, 1)
		if idx >= int64(len(pbr.files)) {
			return // No more files
		}

		filename := pbr.files[idx]

		if err := pbr.processFile(filename); err != nil {
			// Report the failure unless the reader is being shut down, then
			// stop this worker.
			select {
			case pbr.errChan <- fmt.Errorf("worker %d failed to process file %s: %w", id, filename, err):
			case <-pbr.stopChan:
			}
			return
		}
	}
}

// processFile opens the parquet file at filename, decodes its rows in chunks
// of at most batchSize, and publishes each chunk on the batch channel. It
// returns nil both when the file has been fully read and when the stop signal
// interrupts the work; a non-nil error means opening or reading failed.
func (pbr *ParallelBatchReader[T]) processFile(filename string) error {
	file, err := local.NewLocalFileReader(filename)
	if err != nil {
		return fmt.Errorf("failed to open file: %w", err)
	}
	defer file.Close()

	parquetReader, err := reader.NewParquetReader(file, new(T), pbr.np)
	if err != nil {
		return fmt.Errorf("failed to create parquet reader: %w", err)
	}
	defer parquetReader.ReadStop()

	// Count down the rows still to be read instead of tracking a cursor.
	remaining := int(parquetReader.GetNumRows())
	for remaining > 0 {
		// Bail out before doing more decoding if a stop was requested.
		select {
		case <-pbr.stopChan:
			return nil
		default:
		}

		// The final chunk of a file may be shorter than a full batch.
		size := pbr.batchSize
		if size > remaining {
			size = remaining
		}

		rows := make([]T, size)
		if err := parquetReader.Read(&rows); err != nil {
			return fmt.Errorf("failed to read rows: %w", err)
		}

		// Hand the chunk to the consumer, giving up if a stop is requested
		// while the channel is full.
		select {
		case pbr.batchChan <- rows:
		case <-pbr.stopChan:
			return nil
		}

		remaining -= size
	}

	return nil
}

// GetFileCount reports how many parquet files are in the reader's file list.
func (pbr *ParallelBatchReader[T]) GetFileCount() int {
	total := len(pbr.files)
	return total
}

// Close signals all workers to stop by closing the stop channel. It returns
// immediately and does not wait for the workers to exit. Calling Close again
// after it has returned is a no-op; it is not safe to call Close from several
// goroutines at the same time.
func (pbr *ParallelBatchReader[T]) Close() {
	select {
	case <-pbr.stopChan:
		// Already closed; closing again would panic.
	default:
		close(pbr.stopChan)
	}
}
