package parquet

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"sync/atomic"
	"time"

	"github.com/xitongsys/parquet-go-source/local"
	"github.com/xitongsys/parquet-go/parquet"
	"github.com/xitongsys/parquet-go/writer"
)

// ErrWriterClosed is the sentinel error returned by Write when the writer has
// already been closed. Compare against it with errors.Is.
var ErrWriterClosed = errors.New("writer is closed")

// ParallelBatchWriter writes records of type T to Parquet files using several
// worker goroutines, rotating to a new file as files grow.
//
// Callers hand records to Write; each worker buffers the records it receives,
// flushes them to its own current file, and opens a new numbered file when the
// current one reaches maxFileSize. Close shuts the workers down.
type ParallelBatchWriter[T any] struct {
	// Configuration, set once by NewParallelBatchWriter:
	//   outputDir     - directory in which the part-NNNN.parquet files are created.
	//   batchSize     - number of buffered records at which a worker flushes.
	//   maxFileSize   - on-disk size in bytes, checked at flush time, at or above
	//                   which a worker rotates to a new file.
	//   flushInterval - period of each worker's flush ticker; on a tick a worker
	//                   flushes a non-empty buffer that has not been flushed for
	//                   at least this long.
	//   numWorkers    - number of worker goroutines started.
	outputDir     string
	batchSize     int
	maxFileSize   int64
	flushInterval time.Duration
	numWorkers    int

	// Shared file counter for consecutive numbering. Workers increment it with
	// atomic.AddInt64 when they open a file, so file numbers are not reused
	// across workers.
	fileCounter int64

	// Queue of records waiting to be written. Write sends on it, the workers
	// receive from it, and Close closes it to tell the workers to finish.
	writeChan chan T

	// Worker management. wg tracks the running workers so Close can wait for
	// them. stopChan is allocated by NewParallelBatchWriter; nothing in the
	// visible code closes or reads it.
	wg       sync.WaitGroup
	stopChan chan struct{}

	// Error handling. Workers send failures on errChan (buffered to numWorkers
	// entries); Write and Close receive from it. closed is set by Close and
	// checked by Write.
	errChan chan error
	closed  atomic.Bool

	// Metrics. totalSize accumulates the file-size growth, in bytes, that the
	// workers observe via os.Stat.
	totalSize atomic.Int64
}

// GetTotalSize reports the running total, in bytes, of file growth the workers
// have recorded so far. The value is read atomically and may lag behind what
// is actually on disk, because workers only update it when they stat a file.
func (pbw *ParallelBatchWriter[T]) GetTotalSize() int64 {
	written := pbw.totalSize.Load()
	return written
}

// NewParallelBatchWriter creates a ParallelBatchWriter that writes Parquet
// files into outputDir and starts its worker goroutines.
//
// Parameters:
//   - outputDir: destination directory; created (with parents) if missing.
//   - batchSize: number of records a worker buffers before flushing; must be
//     positive.
//   - maxFileSize: on-disk size in bytes at which a worker rotates to a new
//     file.
//   - flushInterval: period of the time-based flush, in seconds; must be
//     positive.
//   - numWorkers: number of worker goroutines; must be positive.
//
// It returns an error if an argument is out of range or the output directory
// cannot be created. The caller must call Close to stop the workers.
func NewParallelBatchWriter[T any](outputDir string, batchSize int, maxFileSize int64, flushInterval int, numWorkers int) (*ParallelBatchWriter[T], error) {
	// Reject values that would otherwise panic or deadlock later: a
	// non-positive batchSize or negative numWorkers panics in make, zero
	// workers leaves Write with no consumer, and time.NewTicker panics inside
	// each worker goroutine for a non-positive interval.
	if batchSize <= 0 {
		return nil, fmt.Errorf("batch size must be positive, got %d", batchSize)
	}
	if numWorkers <= 0 {
		return nil, fmt.Errorf("number of workers must be positive, got %d", numWorkers)
	}
	if flushInterval <= 0 {
		return nil, fmt.Errorf("flush interval must be positive, got %d", flushInterval)
	}

	// Create output directory if it doesn't exist
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create output directory: %w", err)
	}

	pbw := &ParallelBatchWriter[T]{
		outputDir:     outputDir,
		batchSize:     batchSize,
		maxFileSize:   maxFileSize,
		flushInterval: time.Duration(flushInterval) * time.Second,
		numWorkers:    numWorkers,
		// Room for one full batch per worker before Write starts blocking.
		writeChan: make(chan T, batchSize*numWorkers),
		stopChan:  make(chan struct{}),
		// One slot per worker so a failing worker can report without blocking.
		errChan: make(chan error, numWorkers),
	}

	// Start workers
	for i := 0; i < numWorkers; i++ {
		pbw.wg.Add(1)
		go pbw.runWorker(i)
	}

	return pbw, nil
}

// Write queues record for one of the workers to write.
//
// It returns ErrWriterClosed if the writer has been closed, or a worker error
// if one is pending; in both cases record is not queued. When the queue is
// full, Write blocks until a worker makes room or reports an error.
func (pbw *ParallelBatchWriter[T]) Write(record T) (err error) {
	if pbw.closed.Load() {
		return ErrWriterClosed
	}

	// Close may close writeChan between the check above and the send below,
	// and sending on a closed channel panics. That send is the only operation
	// past this point that can panic, so translate it into ErrWriterClosed
	// instead of crashing the caller.
	defer func() {
		if r := recover(); r != nil {
			err = ErrWriterClosed
		}
	}()

	// Surface a pending worker error before queueing more work.
	select {
	case workerErr := <-pbw.errChan:
		return workerErr
	default:
	}

	// Keep watching errChan while waiting for queue space: if the workers
	// stop on an error while the queue is full, a bare send would never
	// return.
	select {
	case pbw.writeChan <- record:
		return nil
	case workerErr := <-pbw.errChan:
		return workerErr
	}
}

// runWorker is the body of one worker goroutine; id only labels its error
// messages.
//
// The worker receives records from writeChan into a private buffer and
// flushes the buffer to its current Parquet file when the buffer reaches
// batchSize, when the flush ticker fires and the buffer has been idle for
// flushInterval, or when writeChan is closed. At flush time it rotates to a
// new file if the current one has reached maxFileSize on disk.
//
// The worker stops after its first error and when writeChan is closed. It
// sends at most one error on errChan over its lifetime, so the send never
// blocks on a channel buffered to numWorkers entries.
func (pbw *ParallelBatchWriter[T]) runWorker(id int) {
	defer pbw.wg.Done()

	// Each worker maintains its own buffer and writer state
	buffer := make([]T, 0, pbw.batchSize)
	var currentWriter *writer.ParquetWriter
	var closeFile func() error // closes the file behind currentWriter
	var currentFilePath string
	var currentFileSize int64
	lastFlush := time.Now()

	flushTicker := time.NewTicker(pbw.flushInterval)
	defer flushTicker.Stop()

	// closeWriter finalizes the current Parquet file, closes it, and clears
	// the writer state. It is a no-op when no file is open.
	closeWriter := func() error {
		if currentWriter == nil {
			return nil
		}

		stopErr := currentWriter.WriteStop()
		// WriteStop writes the footer but leaves the underlying file open, so
		// the file must be closed explicitly or every rotation leaks a
		// descriptor. Close it even when WriteStop failed.
		closeErr := closeFile()

		path, accounted := currentFilePath, currentFileSize
		// Clear the state unconditionally so a later call cannot finalize the
		// same file twice.
		currentWriter = nil
		closeFile = nil
		currentFilePath = ""
		currentFileSize = 0

		if stopErr != nil {
			return fmt.Errorf("worker %d failed to close writer: %w", id, stopErr)
		}

		// Update total size with final file size
		if path != "" {
			if info, err := os.Stat(path); err == nil {
				if delta := info.Size() - accounted; delta > 0 {
					pbw.totalSize.Add(delta)
				}
			}
		}

		if closeErr != nil {
			return fmt.Errorf("worker %d failed to close file %s: %w", id, path, closeErr)
		}
		return nil
	}

	// rotateFile closes the current file, if any, and opens the next
	// consecutively numbered one.
	rotateFile := func() error {
		if err := closeWriter(); err != nil {
			return err
		}

		// Get next file number atomically
		fileNum := atomic.AddInt64(&pbw.fileCounter, 1)
		filename := filepath.Join(pbw.outputDir, fmt.Sprintf("part-%04d.parquet", fileNum))

		fw, err := local.NewLocalFileWriter(filename)
		if err != nil {
			return fmt.Errorf("worker %d failed to create file %s: %w", id, filename, err)
		}

		pw, err := writer.NewParquetWriter(fw, new(T), 4)
		if err != nil {
			// Best effort: the constructor error is the one worth reporting.
			_ = fw.Close()
			return fmt.Errorf("worker %d failed to create parquet writer: %w", id, err)
		}

		// Use ZSTD compression as requested
		pw.CompressionType = parquet.CompressionCodec_ZSTD

		currentWriter = pw
		closeFile = fw.Close
		currentFilePath = filename
		currentFileSize = 0

		return nil
	}

	// flush hands every buffered record to the current Parquet writer,
	// opening or rotating the file first when needed.
	flush := func() error {
		if len(buffer) == 0 {
			return nil
		}

		// Check rotation
		if currentWriter == nil {
			if err := rotateFile(); err != nil {
				return err
			}
		} else if currentFilePath != "" {
			// A failed Stat is tolerated: size accounting and rotation are
			// simply retried on the next flush.
			if fileInfo, err := os.Stat(currentFilePath); err == nil {
				newSize := fileInfo.Size()
				delta := newSize - currentFileSize
				if delta > 0 {
					pbw.totalSize.Add(delta)
				}
				currentFileSize = newSize

				if currentFileSize >= pbw.maxFileSize {
					if err := rotateFile(); err != nil {
						return err
					}
				}
			}
		}

		// Write records
		for _, record := range buffer {
			if err := currentWriter.Write(record); err != nil {
				return fmt.Errorf("worker %d failed to write record: %w", id, err)
			}
		}

		buffer = buffer[:0]
		lastFlush = time.Now()
		return nil
	}

	// fail reports err after releasing the open file. The cleanup is best
	// effort: its own error is dropped because err is the root cause.
	fail := func(err error) {
		_ = closeWriter()
		pbw.errChan <- err
	}

	for {
		select {
		case record, ok := <-pbw.writeChan:
			if !ok {
				// Channel closed: flush what is left, always close the file,
				// and report both failures as a single error so the second
				// send cannot block on a full errChan.
				err := flush()
				if closeErr := closeWriter(); closeErr != nil {
					if err == nil {
						err = closeErr
					} else {
						err = fmt.Errorf("%w (also: %v)", err, closeErr)
					}
				}
				if err != nil {
					pbw.errChan <- err
				}
				return
			}

			buffer = append(buffer, record)
			if len(buffer) >= pbw.batchSize {
				if err := flush(); err != nil {
					fail(err)
					return
				}
			}

		case <-flushTicker.C:
			// Skip the tick if a size-triggered flush ran recently.
			if len(buffer) > 0 && time.Since(lastFlush) >= pbw.flushInterval {
				if err := flush(); err != nil {
					fail(err)
					return
				}
			}
		}
	}
}

// Close shuts the writer down: it marks the writer closed, closes the record
// queue so the workers finish and exit, and waits for all of them. It returns
// one pending worker error if there is one, otherwise nil. Only the first call
// does any work; later calls return nil immediately.
func (pbw *ParallelBatchWriter[T]) Close() error {
	firstCall := pbw.closed.CompareAndSwap(false, true)
	if !firstCall {
		return nil
	}

	// Closing the queue is the workers' signal to finish up.
	close(pbw.writeChan)
	pbw.wg.Wait()

	// Non-blocking receive: pick up one reported error, if any.
	var workerErr error
	select {
	case workerErr = <-pbw.errChan:
	default:
	}
	return workerErr
}
