package tasks

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/reposurvey/pipeline/client"
	"github.com/reposurvey/pipeline/config"
	"github.com/reposurvey/pipeline/models"
	"github.com/reposurvey/pipeline/parquet"
	"github.com/xitongsys/parquet-go-source/local"
	"github.com/xitongsys/parquet-go/reader"
)

// prJob is the unit of work handed from a producer to a consumer: the merged
// PRs kept from a single page of one repository's PR list response.
type prJob struct {
	repo models.PublicRepo   // repository the PRs belong to
	prs  []models.PRListItem // PRs from that page whose MergedAt is non-empty
	page int                 // page number of the list request (starts at 1); used in log messages
}

// prThroughputMonitor accumulates pipeline counters and prints them, together
// with per-second rates and the GitHub client's request count, on demand.
// All methods take mu, so a single monitor can be shared between goroutines.
type prThroughputMonitor struct {
	mu             sync.Mutex           // held by every method while it reads or updates the fields below
	reposProcessed int64                // incremented once per addRepoProcessed call
	prsProcessed   int64                // running total passed to addPRsProcessed
	prsFiltered    int64                // running total passed to addPRsFiltered
	startTime      time.Time            // set at construction; rates are measured from this instant
	client         *client.GithubClient // queried for its request count in report
	lastAPICount   uint64               // request count seen by the previous report (0 before the first)
}

// newPRThroughputMonitor returns a monitor whose clock starts now and whose
// counters (including the last-seen API request count) are all zero.
func newPRThroughputMonitor(client *client.GithubClient) *prThroughputMonitor {
	monitor := new(prThroughputMonitor)
	monitor.startTime = time.Now()
	monitor.client = client
	return monitor
}

// addRepoProcessed records that one more repository has been handled.
func (m *prThroughputMonitor) addRepoProcessed() {
	m.mu.Lock()
	m.reposProcessed++
	m.mu.Unlock()
}

// addPRsProcessed adds count to the running total of processed PRs.
func (m *prThroughputMonitor) addPRsProcessed(count int) {
	delta := int64(count)

	m.mu.Lock()
	m.prsProcessed += delta
	m.mu.Unlock()
}

// addPRsFiltered adds count to the running total of filtered PRs.
func (m *prThroughputMonitor) addPRsFiltered(count int) {
	delta := int64(count)

	m.mu.Lock()
	m.prsFiltered += delta
	m.mu.Unlock()
}

// report prints one [THROUGHPUT] line to stdout: each counter with its average
// rate since startTime, the client's total request count with its average rate,
// and the number of requests issued since the previous report. It then
// remembers the current request count for the next call.
func (m *prThroughputMonitor) report() {
	m.mu.Lock()
	defer m.mu.Unlock()

	seconds := time.Since(m.startTime).Seconds()
	perSecond := func(total float64) float64 { return total / seconds }

	// Snapshot the client's request counter and work out the delta before
	// overwriting the stored value.
	apiTotal := m.client.GetRequestCount()
	apiDelta := apiTotal - m.lastAPICount
	m.lastAPICount = apiTotal

	fmt.Printf("[THROUGHPUT] Repos: %d (%.1f/s) | PRs Processed: %d (%.1f/s) | PRs Filtered: %d (%.1f/s) | API Requests: %d (%.1f/s, +%d) | Elapsed: %.1fs\n",
		m.reposProcessed, perSecond(float64(m.reposProcessed)),
		m.prsProcessed, perSecond(float64(m.prsProcessed)),
		m.prsFiltered, perSecond(float64(m.prsFiltered)),
		apiTotal, perSecond(float64(apiTotal)), apiDelta,
		seconds)
}

// Task2PRIngestion implements PR metadata fetching and filtering.
// Producers list each repository's PRs page by page and queue the merged ones;
// consumers fetch the changed files of every queued PR, apply the configured
// file-count filters and write the surviving records through prWriter.
type Task2PRIngestion struct {
	// Pipeline configuration, the GitHub API client used for all requests,
	// and the parquet writer that receives accepted PR records.
	cfg      *config.Config
	client   *client.GithubClient
	prWriter *parquet.BatchWriter[models.PRMetadata]

	// Buffered channel carrying PR jobs from producers to consumers.
	// Run closes it once every producer has returned.
	prJobChan chan prJob

	// Size of the fixed producer pool; each producer handles many repositories,
	// pulling them one at a time from a shared channel.
	numProducers int

	// Throughput monitoring
	monitor *prThroughputMonitor
}

// NewTask2PRIngestion builds a PR ingestion task around cfg and client.
//
// It opens the parquet batch writer for PR records first and returns a wrapped
// error if that fails. The producer pool is sized at one tenth of
// cfg.MaxConcurrency, but never fewer than 10; the job channel is buffered to
// twice cfg.MaxConcurrency.
func NewTask2PRIngestion(cfg *config.Config, client *client.GithubClient) (*Task2PRIngestion, error) {
	writer, err := parquet.NewBatchWriter[models.PRMetadata](
		cfg.RawPRsDir, cfg.BatchSize, cfg.MaxFileSize, cfg.FlushInterval,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create PR writer: %w", err)
	}

	// A tenth of the concurrency budget goes to producers, with a floor of 10.
	const minProducers = 10
	producers := cfg.MaxConcurrency / 10
	if producers < minProducers {
		producers = minProducers
	}

	task := &Task2PRIngestion{
		cfg:          cfg,
		client:       client,
		prWriter:     writer,
		prJobChan:    make(chan prJob, cfg.MaxConcurrency*2),
		numProducers: producers,
		monitor:      newPRThroughputMonitor(client),
	}
	return task, nil
}

// Run executes the whole PR ingestion pipeline and blocks until every
// goroutine it started has finished.
//
// It loads the repositories produced by Task 1, starts the throughput monitor,
// the consumer pool and the producer pool, feeds the repositories to the
// producers, and then waits for producers, consumers and the monitor in that
// order. The first error reported by a producer (or, failing that, by the
// consumers) is returned; nil means the run completed.
//
// Run closes t.prJobChan, so it must be called at most once per
// Task2PRIngestion; a second call would panic on the double close.
func (t *Task2PRIngestion) Run(ctx context.Context) error {
	fmt.Println("[INFO] Starting Task 2: PR Metadata Ingestion")

	// Read filtered repos from Task 1 output
	repos, err := t.loadFilteredRepos()
	if err != nil {
		return fmt.Errorf("failed to load filtered repos: %w", err)
	}

	fmt.Printf("[INFO] Loaded %d repositories to process\n", len(repos))

	// Start the throughput monitor. monitorDone lets Run wait for it so the
	// final report is printed before Run returns rather than racing with it.
	monitorCtx, monitorCancel := context.WithCancel(ctx)
	defer monitorCancel()
	monitorDone := make(chan struct{})
	go func() {
		defer close(monitorDone)
		t.runThroughputMonitor(monitorCtx)
	}()

	// Start consumer workers (PR file processing). The channel is buffered so
	// the goroutine never blocks on the send, and is always closed on exit.
	consumersDone := make(chan error, 1)
	go func() {
		defer close(consumersDone)
		if err := t.runConsumers(ctx); err != nil {
			consumersDone <- fmt.Errorf("consumers error: %w", err)
		}
	}()

	// Start producer workers (PR list fetching). Each producer sends at most
	// one error, so a buffer of numProducers guarantees no producer blocks.
	var wg sync.WaitGroup
	producerErrors := make(chan error, t.numProducers)
	repoChan := make(chan models.PublicRepo, t.numProducers*2)

	for i := 0; i < t.numProducers; i++ {
		wg.Add(1)
		go func(producerID int) {
			defer wg.Done()
			if err := t.runProducer(ctx, producerID, repoChan); err != nil {
				producerErrors <- fmt.Errorf("producer %d error: %w", producerID, err)
			}
		}(i)
	}

	// Feed repositories to producers; stop early if the context is cancelled.
	go func() {
		defer close(repoChan)
		for _, repo := range repos {
			select {
			case repoChan <- repo:
			case <-ctx.Done():
				return
			}
		}
	}()

	// Once all producers have returned, no more jobs or errors can arrive.
	go func() {
		wg.Wait()
		close(t.prJobChan) // Signal consumers that no more PR jobs will be sent
		close(producerErrors)
	}()

	// Drain every producer error (keeping the first) instead of returning on
	// the first one: the range only ends after all producers have exited, so
	// Run never returns while a producer is still running.
	var runErr error
	for err := range producerErrors {
		if err != nil && runErr == nil {
			runErr = err
		}
	}

	// Always wait for the consumers, even after a producer error, so that no
	// worker is still writing to prWriter when the caller gets control back.
	if err := <-consumersDone; err != nil && runErr == nil {
		runErr = err
	}

	// Stop the monitor and wait for it. runThroughputMonitor prints one last
	// report when its context is cancelled, so Run does not print another.
	monitorCancel()
	<-monitorDone

	if runErr != nil {
		return runErr
	}

	fmt.Println("[INFO] Task 2 completed successfully")
	return nil
}

// runThroughputMonitor prints a throughput report every 30 seconds until ctx
// is done, then prints one last report and returns. It is meant to run in its
// own goroutine.
func (t *Task2PRIngestion) runThroughputMonitor(ctx context.Context) {
	const reportEvery = 30 * time.Second

	tick := time.NewTicker(reportEvery)
	defer tick.Stop()

	done := ctx.Done()
	for {
		select {
		case <-tick.C:
			t.monitor.report()
		case <-done:
			// Closing report so the totals at shutdown are visible.
			t.monitor.report()
			return
		}
	}
}

// runProducer is the body of one producer goroutine. It receives repositories
// from repoChan until the channel is closed, lists each repository's PRs via
// listRepositoryPRs (which queues page-sized batches for the consumers), and
// bumps the repository counter after every attempt.
//
// A failure on one repository is logged (unless it is a context cancellation)
// and does not stop the producer. The only error returned is ctx.Err(), when
// the context is found to be done before starting a repository.
func (t *Task2PRIngestion) runProducer(ctx context.Context, producerID int, repoChan <-chan models.PublicRepo) error {
	// Announce the pool once, from the first producer only.
	if producerID == 0 {
		fmt.Printf("[INFO] Starting %d producers for PR list fetching\n", t.numProducers)
	}

	for repo := range repoChan {
		if err := ctx.Err(); err != nil {
			return err
		}

		listErr := t.listRepositoryPRs(ctx, repo)
		if listErr != nil && !errors.Is(listErr, context.Canceled) {
			fmt.Printf("[ERROR] Producer %d failed to list PRs for repo %s: %v\n",
				producerID, repo.FullName, listErr)
		}

		// Counted whether or not listing succeeded; failed repos are skipped,
		// not retried.
		t.monitor.addRepoProcessed()
	}

	return nil
}

// listRepositoryPRs pages through the closed PRs of repo (100 per page, oldest
// first) and, for every page that contains at least one merged PR, sends a
// prJob with those merged PRs to t.prJobChan.
//
// It returns nil when an empty page is reached or when the API answers
// client.ErrNotFound, ctx.Err() when the context is done, and a wrapped error
// for any other request or JSON decoding failure.
func (t *Task2PRIngestion) listRepositoryPRs(ctx context.Context, repo models.PublicRepo) error {
	const perPage = 100

	for page := 1; ; page++ {
		if err := ctx.Err(); err != nil {
			return err
		}

		// Sorting by creation time ascending keeps page contents stable:
		// creation time never changes, so responses cached by a reverse proxy
		// stay mostly correct.
		endpoint := fmt.Sprintf("/repos/%s/pulls?state=closed&sort=created&direction=asc&per_page=%d&page=%d",
			repo.FullName, perPage, page)

		data, err := t.client.Get(ctx, endpoint)
		switch {
		case err == nil:
			// fall out of the switch and decode the page
		case errors.Is(err, client.ErrNotFound):
			// Deleted or private repository: nothing to list, not an error.
			return nil
		default:
			return fmt.Errorf("failed to fetch PRs: %w", err)
		}

		var listed []models.PRListItem
		if err := json.Unmarshal(data, &listed); err != nil {
			return fmt.Errorf("failed to parse PRs: %w", err)
		}

		// An empty page means the previous one was the last.
		if len(listed) == 0 {
			return nil
		}

		// Keep only PRs that were actually merged.
		var merged []models.PRListItem
		for i := range listed {
			if listed[i].MergedAt == "" {
				continue
			}
			merged = append(merged, listed[i])
		}
		if len(merged) == 0 {
			continue
		}

		// One job per page; give up if the context ends while the queue is full.
		select {
		case t.prJobChan <- prJob{repo: repo, prs: merged, page: page}:
		case <-ctx.Done():
			return ctx.Err()
		}

		t.monitor.addPRsProcessed(len(merged))
	}
}

// runConsumers starts the consumer pool and blocks until every worker has
// exited, which happens once t.prJobChan is closed and drained.
//
// The pool size is cfg.MaxConcurrency minus the number of producers, with a
// floor of 100. Each worker passes jobs to processPRJob; a failing job is
// logged (unless the failure is a context cancellation) and the worker moves
// on to the next job. Individual job failures are never propagated, so the
// returned error is always nil; the error return is kept for the caller's
// existing handling.
func (t *Task2PRIngestion) runConsumers(ctx context.Context) error {
	// Create worker pool (use remaining concurrency for consumers)
	workerCount := t.cfg.MaxConcurrency - t.numProducers
	if workerCount < 100 {
		workerCount = 100
	}

	fmt.Printf("[INFO] Starting %d consumers for PR file processing\n", workerCount)

	var wg sync.WaitGroup
	wg.Add(workerCount)

	for i := 0; i < workerCount; i++ {
		go func(workerID int) {
			defer wg.Done()
			for job := range t.prJobChan {
				err := t.processPRJob(ctx, job)
				// Don't log context canceled errors during shutdown
				if err == nil || errors.Is(err, context.Canceled) {
					continue
				}
				fmt.Printf("[ERROR] Consumer %d failed to process PR batch for repo %s (page %d): %v\n",
					workerID, job.repo.FullName, job.page, err)
			}
		}(i)
	}

	// Wait for all workers to complete
	wg.Wait()
	return nil
}

// processPRJob handles one batch of merged PRs from a single repository.
//
// For every PR it fetches the first page (up to 100 entries) of changed files,
// keeps only files with an allowed Python / Markdown / reStructuredText
// extension, and aggregates counts and diff statistics over those files. PRs
// whose ".py" file count is outside [cfg.MinPyFiles, cfg.MaxPyFiles] or whose
// kept-file count exceeds cfg.MaxTotalFiles are discarded; the rest are written
// through t.prWriter and counted in the monitor.
//
// The Extensions field of a written record lists each distinct extension once,
// in the order it first appears in the API response, so the output is the same
// on every run for the same response.
//
// Per-PR failures (request, JSON decoding, timestamp parsing, write) are logged
// and skipped. The only error returned is ctx.Err(), when the context is found
// to be done before starting a PR.
func (t *Task2PRIngestion) processPRJob(ctx context.Context, job prJob) error {
	// Define allowed extensions
	allowedExtensions := map[string]bool{
		".py":       true,
		".pyi":      true,
		".pyx":      true,
		".pyw":      true,
		".md":       true,
		".markdown": true,
		".mdown":    true,
		".mkd":      true,
		".mkdn":     true,
		".rst":      true,
		".rest":     true,
	}

	// Process each PR in the batch
	for _, pr := range job.prs {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		// Fetch PR files (only first page as per spec)
		endpoint := fmt.Sprintf("/repos/%s/pulls/%d/files?per_page=100", job.repo.FullName, pr.Number)

		data, err := t.client.Get(ctx, endpoint)
		if err != nil {
			// 422 responses (repository missing relevant data) and context
			// cancellation are skipped silently; anything else is logged.
			// Either way the remaining PRs are still processed.
			silent := errors.Is(err, client.ErrUnprocessableEntity) || errors.Is(err, context.Canceled)
			if !silent {
				fmt.Printf("[WARN] Failed to fetch files for PR %s#%d: %v\n",
					job.repo.FullName, pr.Number, err)
			}
			continue
		}

		var files []models.PRFile
		if err := json.Unmarshal(data, &files); err != nil {
			fmt.Printf("[WARN] Failed to parse files for PR %s#%d: %v\n",
				job.repo.FullName, pr.Number, err)
			continue
		}

		// Single pass: filter by extension and aggregate metrics together.
		var (
			totalFiles     int
			pyFilesCount   int
			totalAdditions int
			totalDeletions int
			totalChanges   int
			extensions     []string
		)
		seenExtensions := make(map[string]bool)

		for _, file := range files {
			ext := filepath.Ext(file.Filename)
			if !allowedExtensions[ext] {
				continue
			}
			totalFiles++

			// Count Python files (".py" only; .pyi/.pyx/.pyw are kept but not counted here)
			if strings.HasSuffix(file.Filename, ".py") {
				pyFilesCount++
			}

			// Record each extension once, in first-seen order. Every allowed
			// extension is non-empty, so no empty-string check is needed.
			if !seenExtensions[ext] {
				seenExtensions[ext] = true
				extensions = append(extensions, ext)
			}

			// Aggregate diff stats
			totalAdditions += file.Additions
			totalDeletions += file.Deletions
			totalChanges += file.Changes
		}

		// Apply filtering rules
		if pyFilesCount > t.cfg.MaxPyFiles {
			continue // Discard
		}

		if pyFilesCount < t.cfg.MinPyFiles {
			continue // Discard
		}

		if totalFiles > t.cfg.MaxTotalFiles {
			continue // Discard
		}

		// Parse merged_at timestamp
		mergedAt, err := time.Parse(time.RFC3339, pr.MergedAt)
		if err != nil {
			fmt.Printf("[WARN] Failed to parse merged_at for PR %s#%d: %v\n",
				job.repo.FullName, pr.Number, err)
			continue
		}

		// Create PR metadata record
		prMetadata := models.PRMetadata{
			RepoID:         job.repo.RepoID,
			RepoName:       job.repo.FullName,
			PRNumber:       pr.Number,
			Title:          pr.Title,
			Body:           pr.Body,
			Author:         pr.User.Login,
			AuthorType:     pr.User.Type,
			MergedAt:       mergedAt.UnixMilli(),
			TotalFiles:     int32(totalFiles),
			PyFilesCount:   int32(pyFilesCount),
			Extensions:     extensions,
			TotalAdditions: int32(totalAdditions),
			TotalDeletions: int32(totalDeletions),
			TotalChanges:   int32(totalChanges),
		}

		// Write to output
		if err := t.prWriter.Write(prMetadata); err != nil {
			fmt.Printf("[ERROR] Failed to write PR metadata for %s#%d: %v\n",
				job.repo.FullName, pr.Number, err)
			continue
		}

		// Update filtered count
		t.monitor.addPRsFiltered(1)
	}

	return nil
}

// loadFilteredRepos collects the repositories stored in every *.parquet file
// directly under cfg.FilteredReposDir (the Task 1 output).
//
// It fails if the glob itself fails or matches no file. A file that cannot be
// read is reported with a warning and skipped, so the result may be partial.
func (t *Task2PRIngestion) loadFilteredRepos() ([]models.PublicRepo, error) {
	dir := t.cfg.FilteredReposDir

	paths, err := filepath.Glob(filepath.Join(dir, "*.parquet"))
	if err != nil {
		return nil, fmt.Errorf("failed to glob parquet files: %w", err)
	}
	if len(paths) == 0 {
		return nil, fmt.Errorf("no parquet files found in %s", dir)
	}

	var collected []models.PublicRepo
	for _, path := range paths {
		batch, readErr := t.readParquetFile(path)
		if readErr != nil {
			// Best effort: one unreadable file does not abort the load.
			fmt.Printf("[WARN] Failed to read %s: %v\n", path, readErr)
			continue
		}
		collected = append(collected, batch...)
	}

	return collected, nil
}

// readParquetFile opens filename, decodes its rows as models.PublicRepo values
// and returns them. Errors from opening the file, creating the parquet reader
// or reading the rows are returned wrapped.
func (t *Task2PRIngestion) readParquetFile(filename string) ([]models.PublicRepo, error) {
	src, err := local.NewLocalFileReader(filename)
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}
	defer src.Close()

	rdr, err := reader.NewParquetReader(src, new(models.PublicRepo), 4)
	if err != nil {
		return nil, fmt.Errorf("failed to create parquet reader: %w", err)
	}
	defer rdr.ReadStop()

	// Size the destination to the row count the reader reports, then read
	// into it in one call.
	rows := make([]models.PublicRepo, int(rdr.GetNumRows()))
	if err := rdr.Read(&rows); err != nil {
		return nil, fmt.Errorf("failed to read rows: %w", err)
	}

	return rows, nil
}

// Close closes the PR writer held by the task and returns a wrapped error if
// that fails.
func (t *Task2PRIngestion) Close() error {
	err := t.prWriter.Close()
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to close PR writer: %w", err)
}
