package tasks

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/reposurvey/pipeline/client"
	"github.com/reposurvey/pipeline/config"
	"github.com/reposurvey/pipeline/models"
	"github.com/reposurvey/pipeline/parquet"
)

// enrichPRJob represents a batch of PRs to enrich. Run sends one job per
// non-empty batch read from the raw PR parquet files; a worker consumes it via
// processJobBatch.
type enrichPRJob struct {
	// prs holds the PR metadata rows of this batch; each is enriched and written independently.
	prs []models.PRMetadata
}

// enrichThroughputMonitor tracks and reports enrichment pipeline throughput.
// The PR counters and lastAPICount are guarded by mu; the API request count
// and bytes written are read from client and writer at report time.
type enrichThroughputMonitor struct {
	mu sync.Mutex
	// prsProcessed counts PRs dispatched to the worker pool (added per batch by Run).
	prsProcessed int64
	// prsEnriched counts PRs that were enriched and successfully written.
	prsEnriched int64
	// startTime is the reference point for every per-second rate in report.
	startTime time.Time
	// client supplies the cumulative API request count.
	client *client.GithubClient
	// lastAPICount is the request count seen at the previous report, used for the "+N" delta.
	lastAPICount uint64
	// writer supplies the total number of bytes written so far.
	writer *parquet.ParallelBatchWriter[models.EnrichedPRData]
}

// newEnrichThroughputMonitor returns a monitor whose clock starts now. When
// reporting, it reads the request count from client and the byte total from
// writer.
func newEnrichThroughputMonitor(client *client.GithubClient, writer *parquet.ParallelBatchWriter[models.EnrichedPRData]) *enrichThroughputMonitor {
	m := new(enrichThroughputMonitor)
	m.startTime = time.Now()
	m.client = client
	m.writer = writer
	// lastAPICount stays at its zero value, so the first report's "+N" delta
	// covers every request made before it.
	return m
}

// addPRsProcessed adds count to the number of PRs handed to the worker pool.
// Safe for concurrent use.
func (m *enrichThroughputMonitor) addPRsProcessed(count int) {
	delta := int64(count)
	m.mu.Lock()
	m.prsProcessed += delta
	m.mu.Unlock()
}

// addPRsEnriched adds count to the number of PRs enriched and written.
// Safe for concurrent use.
func (m *enrichThroughputMonitor) addPRsEnriched(count int) {
	delta := int64(count)
	m.mu.Lock()
	m.prsEnriched += delta
	m.mu.Unlock()
}

// report prints one throughput line measured from startTime. The line covers:
//   - PRs dispatched and enriched, with per-second rates;
//   - cumulative API requests, with their rate and the delta since the previous report;
//   - total data written, with its MB/s rate.
//
// It then records the current request count as the baseline for the next
// report's delta. The whole report runs under mu.
func (m *enrichThroughputMonitor) report() {
	m.mu.Lock()
	defer m.mu.Unlock()

	secs := time.Since(m.startTime).Seconds()
	reqs := m.client.GetRequestCount()
	bytesWritten := float64(m.writer.GetTotalSize())

	const mib = 1024 * 1024
	fmt.Printf("[THROUGHPUT] PRs Processed: %d (%.1f/s) | PRs Enriched: %d (%.1f/s) | API Requests: %d (%.1f/s, +%d) | Data Written: %.2f GB (%.2f MB/s) | Elapsed: %.1fs\n",
		m.prsProcessed, float64(m.prsProcessed)/secs,
		m.prsEnriched, float64(m.prsEnriched)/secs,
		reqs, float64(reqs)/secs, reqs-m.lastAPICount,
		bytesWritten/(mib*1024), bytesWritten/mib/secs,
		secs)

	m.lastAPICount = reqs
}

// Task3EnrichPR implements PR enrichment. It reads raw PR metadata from
// parquet, enriches each PR with data fetched through the GitHub client, and
// writes the results with a parallel parquet writer. Create it with
// NewTask3EnrichPR, call Run once, then Close.
type Task3EnrichPR struct {
	cfg    *config.Config
	client *client.GithubClient
	writer *parquet.ParallelBatchWriter[models.EnrichedPRData]

	// Channel for piping jobs from reader to workers.
	// Run closes it, so Run must be called at most once per task.
	jobChan chan enrichPRJob

	// Throughput monitoring
	monitor *enrichThroughputMonitor

	// In-memory map of repo IDs to descriptions.
	// Filled at construction; within this file it is only read afterwards (by enrichPR).
	repoDescMap map[int64]string
}

// NewTask3EnrichPR creates a new PR enrichment task.
//
// It loads the repository ID -> description map from every parquet file under
// cfg.FilteredReposDir. It then opens a parallel parquet writer for enriched
// PRs under cfg.EnrichedPRsDir.
//
// The descriptions are loaded first so that a load failure never leaves an
// open writer behind. On success the caller owns the task and must call Close
// to release the writer.
func NewTask3EnrichPR(cfg *config.Config, client *client.GithubClient) (*Task3EnrichPR, error) {
	// Load repository descriptions before creating the writer: if this fails
	// there is nothing to clean up.
	fmt.Printf("[INFO] Loading repository descriptions from %s\n", cfg.FilteredReposDir)
	repoDescMap, err := loadRepoDescriptions(cfg.FilteredReposDir)
	if err != nil {
		return nil, fmt.Errorf("failed to load repo descriptions: %w", err)
	}
	fmt.Printf("[INFO] Loaded %d repository descriptions\n", len(repoDescMap))

	// Create enriched PR writer.
	// Use 32 workers for parallel writing as per benchmark results.
	writer, err := parquet.NewParallelBatchWriter[models.EnrichedPRData](
		cfg.EnrichedPRsDir,
		cfg.PRBatchSize,
		cfg.MaxFileSize,
		cfg.FlushInterval,
		32, // 32 concurrent writers
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create enriched PR writer: %w", err)
	}

	return &Task3EnrichPR{
		cfg:         cfg,
		client:      client,
		writer:      writer,
		jobChan:     make(chan enrichPRJob, cfg.MaxConcurrency),
		monitor:     newEnrichThroughputMonitor(client, writer),
		repoDescMap: repoDescMap,
	}, nil
}

// loadRepoDescriptions reads every PublicRepo record under filteredReposDir
// and returns a map from repo ID to description. When several records share
// an ID, the one read last wins.
func loadRepoDescriptions(filteredReposDir string) (map[int64]string, error) {
	records, err := parquet.ReadAll[models.PublicRepo](filteredReposDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read repos: %w", err)
	}

	descByID := make(map[int64]string, len(records))
	for i := range records {
		descByID[records[i].RepoID] = records[i].Description
	}
	return descByID, nil
}

// Run starts the PR enrichment process.
//
// It streams batches of PR metadata from the raw PR parquet files and sends
// each non-empty batch to the worker pool over t.jobChan. A throughput
// monitor reports progress every 30 seconds and prints exactly one final
// report before Run returns.
//
// Every exit path closes the job channel and waits for all workers and the
// monitor to stop. So once Run returns, no goroutine is still writing through
// t.writer and the caller may safely call Close.
//
// Return values:
//   - If ctx is cancelled, ctx.Err() is returned.
//   - If reading a batch fails, outstanding work is cancelled and the wrapped
//     read error is returned.
//
// Run closes t.jobChan, so it must be called at most once per Task3EnrichPR.
func (t *Task3EnrichPR) Run(ctx context.Context) error {
	fmt.Println("[INFO] Starting Task 3: PR Enrichment")

	// Create batch reader for PR metadata
	// Use 32 workers for parallel reading, 2 parallel routines per reader
	reader, err := parquet.NewParallelBatchReader[models.PRMetadata](t.cfg.RawPRsDir, t.cfg.PRBatchSize, 32, 2)
	if err != nil {
		return fmt.Errorf("failed to create PR reader: %w", err)
	}

	fmt.Printf("[INFO] Found %d PR parquet files to process\n", reader.GetFileCount())

	// Start throughput monitor. stopMonitor cancels it and waits until it has
	// printed its final report. It is idempotent: cancel is safe to repeat and
	// receiving from the closed monitorDone returns immediately.
	monitorCtx, monitorCancel := context.WithCancel(ctx)
	monitorDone := make(chan struct{})
	go func() {
		defer close(monitorDone)
		t.runThroughputMonitor(monitorCtx)
	}()
	stopMonitor := func() {
		monitorCancel()
		<-monitorDone
	}
	defer stopMonitor()

	// Start worker pool. It gets its own cancellable context so that a read
	// failure can abort in-flight enrichment instead of running it to completion.
	workerCtx, cancelWorkers := context.WithCancel(ctx)
	defer cancelWorkers()
	workersDone := make(chan error, 1)
	go func() {
		if err := t.runWorkers(workerCtx); err != nil {
			workersDone <- fmt.Errorf("workers error: %w", err)
		}
		close(workersDone)
	}()

	// abort stops the pool early. It cancels outstanding work, signals that no
	// more jobs will arrive, and waits for every worker to exit.
	abort := func() {
		cancelWorkers()
		close(t.jobChan)
		<-workersDone
	}

	// Read and dispatch PR batches
	for {
		select {
		case <-ctx.Done():
			abort()
			return ctx.Err()
		default:
		}

		batch, hasMore, err := reader.ReadBatch()
		if err != nil {
			abort()
			return fmt.Errorf("failed to read PR batch: %w", err)
		}

		if len(batch) > 0 {
			select {
			case t.jobChan <- enrichPRJob{prs: batch}:
			case <-ctx.Done():
				abort()
				return ctx.Err()
			}

			t.monitor.addPRsProcessed(len(batch))
		}

		if !hasMore {
			break
		}
	}

	// Signal workers that no more jobs will be sent, then wait for them.
	close(t.jobChan)
	if err := <-workersDone; err != nil {
		return err
	}

	// Final report (printed once, by the monitor) before announcing completion.
	stopMonitor()
	fmt.Println("[INFO] Task 3 completed successfully")
	return nil
}

// runThroughputMonitor prints a throughput report every 30 seconds until ctx
// is done. It then prints one last report and returns.
func (t *Task3EnrichPR) runThroughputMonitor(ctx context.Context) {
	const interval = 30 * time.Second
	tick := time.NewTicker(interval)
	defer tick.Stop()

	done := ctx.Done()
	for {
		select {
		case <-tick.C:
			t.monitor.report()
		case <-done:
			t.monitor.report()
			return
		}
	}
}

// runWorkers implements the worker pool for enriching PRs.
//
// It starts cfg.MaxConcurrency workers, with a minimum of one. With zero
// workers, Run would block forever sending on t.jobChan.
//
// Each worker consumes jobs from t.jobChan until the channel is closed. A
// failed batch is logged (unless it failed due to context cancellation) and
// the worker moves on; failures are not propagated.
//
// runWorkers returns after every worker has exited. In the current
// implementation the returned error is always nil.
func (t *Task3EnrichPR) runWorkers(ctx context.Context) error {
	workerCount := t.cfg.MaxConcurrency
	if workerCount < 1 {
		workerCount = 1
	}
	fmt.Printf("[INFO] Starting %d workers for PR enrichment\n", workerCount)

	var wg sync.WaitGroup
	wg.Add(workerCount)
	for i := 0; i < workerCount; i++ {
		go func(workerID int) {
			defer wg.Done()
			for job := range t.jobChan {
				if err := t.processJobBatch(ctx, job); err != nil && !errors.Is(err, context.Canceled) {
					fmt.Printf("[ERROR] Worker %d failed to process batch: %v\n", workerID, err)
				}
			}
		}(i)
	}

	// Wait for all workers to drain the channel and exit.
	wg.Wait()
	return nil
}

// processJobBatch enriches and writes every PR in job, one at a time.
//
// Per-PR failures, in enrichment or in writing, are logged and the PR is
// skipped. Only cancellation of ctx stops the batch early; it is checked
// before each PR and returned as ctx.Err(). Each successfully written PR is
// counted on the monitor.
func (t *Task3EnrichPR) processJobBatch(ctx context.Context, job enrichPRJob) error {
	for i := range job.prs {
		if cerr := ctx.Err(); cerr != nil {
			return cerr
		}

		pr := job.prs[i]
		enriched, err := t.enrichPR(ctx, pr)
		if err != nil {
			// Cancellation is expected during shutdown; don't report it per PR.
			if !errors.Is(err, context.Canceled) {
				fmt.Printf("[WARN] Failed to enrich PR %s#%d: %v\n", pr.RepoName, pr.PRNumber, err)
			}
			continue
		}

		if werr := t.writer.Write(*enriched); werr != nil {
			fmt.Printf("[ERROR] Failed to write enriched data for PR %s#%d: %v\n", pr.RepoName, pr.PRNumber, werr)
			continue
		}

		t.monitor.addPRsEnriched(1)
	}
	return nil
}

// enrichPR enriches a single PR with all required data.
//
// The PR is identified by pr.RepoName ("owner/repo") and pr.PRNumber. enrichPR
// fetches, in order:
//   - the PR details, for the base SHA;
//   - the changed Python files;
//   - the PR's commits, with diffs limited to those files;
//   - the pre-PR contents of the Python files that existed before the PR;
//   - the Python file tree at the base SHA;
//   - the first closing issue, plus the closing-issue count.
//
// The repository description comes from the in-memory map built at
// construction time. It is empty if the repo ID is unknown.
//
// Any fetch failure aborts enrichment of this PR and is returned, wrapped
// with the step that failed.
func (t *Task3EnrichPR) enrichPR(ctx context.Context, pr models.PRMetadata) (*models.EnrichedPRData, error) {
	// Parse owner and repo from the full name. Empty components are rejected
	// because they would produce malformed endpoints such as "/repos//name/...".
	parts := strings.Split(pr.RepoName, "/")
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return nil, fmt.Errorf("invalid repo name format: %s", pr.RepoName)
	}
	owner, repo := parts[0], parts[1]

	// Fetch PR details to get the base SHA.
	prDetails, err := t.fetchPRDetails(ctx, owner, repo, pr.PRNumber)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch PR details: %w", err)
	}
	baseCommitSHA := prDetails.Base.SHA

	// Fetch changed Python files, plus the paths they had before the PR.
	changedPyFiles, pyBaseFiles, err := t.fetchChangedPyFiles(ctx, owner, repo, pr.PRNumber)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch changed files: %w", err)
	}

	// Fetch the PR's commits, keeping only diffs that touch the changed Python
	// files, along with the first commit's parent SHA. That parent is empty
	// when there are no commits or the first commit does not have exactly one
	// parent; in that case fall back to the PR's base SHA.
	commits, firstCommitParentSHA, err := t.fetchCommitsWithDiffs(ctx, owner, repo, pr.PRNumber, changedPyFiles)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch commits: %w", err)
	}
	if firstCommitParentSHA == "" {
		firstCommitParentSHA = baseCommitSHA
	}

	// Fetch file contents at the pre-PR snapshot (only for files that existed before).
	relevantFiles, err := t.fetchFileContents(ctx, owner, repo, firstCommitParentSHA, pyBaseFiles)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch file contents: %w", err)
	}

	// Fetch the Python file tree at baseCommitSHA rather than firstCommitParentSHA.
	// The first commit's parent may come from another repository (presumably
	// a fork) and may not be resolvable in this one.
	fileTree, err := t.fetchFileTree(ctx, owner, repo, baseCommitSHA)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch file tree: %w", err)
	}

	// Fetch related issue (first closing issue only) and total count.
	relatedIssue, relatedIssueCount, err := t.fetchRelatedIssue(ctx, owner, repo, pr.PRNumber)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch related issue: %w", err)
	}
	if relatedIssue == nil {
		// Keep the count consistent with the absence of an issue.
		relatedIssueCount = 0
	}

	return &models.EnrichedPRData{
		PRID:                 pr.PRNumber,
		RepoID:               pr.RepoID,
		RepoName:             pr.RepoName,
		RepoDesc:             t.repoDescMap[pr.RepoID],
		Title:                pr.Title,
		Body:                 pr.Body,
		FirstCommitParentSHA: firstCommitParentSHA,
		Issue:                relatedIssue,
		RelatedIssueCount:    relatedIssueCount,
		ChangedPyFiles:       changedPyFiles,
		RelevantFiles:        relevantFiles,
		Commits:              commits,
		FileTree:             fileTree,
	}, nil
}

// fetchPRDetails requests /repos/{owner}/{repo}/pulls/{prNumber} and decodes
// the response. Callers use it to obtain the PR's base SHA. Client errors are
// returned unwrapped; decode errors are wrapped.
func (t *Task3EnrichPR) fetchPRDetails(ctx context.Context, owner, repo string, prNumber int64) (*models.PRDetailResponse, error) {
	body, err := t.client.Get(ctx, fmt.Sprintf("/repos/%s/%s/pulls/%d", owner, repo, prNumber))
	if err != nil {
		return nil, err
	}

	detail := new(models.PRDetailResponse)
	if err := json.Unmarshal(body, detail); err != nil {
		return nil, fmt.Errorf("failed to parse PR details: %w", err)
	}
	return detail, nil
}

// fetchRelatedIssue finds the issues the PR closes, via the client's
// "/gql/pull_closing_issues" endpoint. It then loads the first one in full:
// its details from the REST issues endpoint, plus up to 100 comments.
//
// The returned count is the number of closing-issue nodes in the GraphQL
// response. NOTE(review): if the underlying query caps the node count, this
// is a lower bound rather than the true total — confirm.
//
// It returns (nil, 0, nil) in two cases:
//   - the PR has no closing issues;
//   - the issue request fails with client.ErrGone (per the original comment,
//     a 410 meaning issues are disabled).
//
// Any other failure is returned as an error.
func (t *Task3EnrichPR) fetchRelatedIssue(ctx context.Context, owner, repo string, prNumber int64) (*models.RelatedIssue, int32, error) {
	raw, err := t.client.Get(ctx, fmt.Sprintf("/gql/pull_closing_issues/%s/%s/%d", owner, repo, prNumber))
	if err != nil {
		return nil, 0, err
	}

	var resp models.GraphQLClosingIssuesResponse
	if err := json.Unmarshal(raw, &resp); err != nil {
		return nil, 0, fmt.Errorf("failed to parse GraphQL response: %w", err)
	}
	if len(resp.Errors) != 0 {
		return nil, 0, fmt.Errorf("GraphQL error: %s", resp.Errors[0].Message)
	}

	refs := resp.Data.Repository.PullRequest.ClosingIssuesReferences.Nodes
	count := int32(len(refs))
	if count == 0 {
		return nil, 0, nil
	}

	first := refs[0]
	issueOwner, issueRepo := first.Repository.Owner.Login, first.Repository.Name
	issueNumber := first.Number

	issueRaw, err := t.client.Get(ctx, fmt.Sprintf("/repos/%s/%s/issues/%d", issueOwner, issueRepo, issueNumber))
	if err != nil {
		// Gone (issues disabled): report "no related issue" rather than failing.
		if errors.Is(err, client.ErrGone) {
			return nil, 0, nil
		}
		return nil, 0, fmt.Errorf("failed to fetch issue details: %w", err)
	}

	var issue models.IssueResponse
	if err := json.Unmarshal(issueRaw, &issue); err != nil {
		return nil, 0, fmt.Errorf("failed to parse issue: %w", err)
	}

	opened, err := time.Parse(time.RFC3339, issue.CreatedAt)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to parse issue created_at: %w", err)
	}

	comments, err := t.fetchIssueComments(ctx, issueOwner, issueRepo, issueNumber)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to fetch issue comments: %w", err)
	}

	related := &models.RelatedIssue{
		IssueID:   issueNumber,
		RepoID:    first.Repository.DatabaseId,
		Title:     issue.Title,
		Body:      issue.Body,
		Author:    issue.User.Login,
		CreatedAt: opened.UnixMilli(),
		Comments:  comments,
	}
	return related, count, nil
}

// fetchIssueComments returns the first page of comments on an issue, up to
// 100 (per_page=100). Comments whose created_at is not valid RFC 3339 are
// dropped. The result is nil when no comment is kept.
func (t *Task3EnrichPR) fetchIssueComments(ctx context.Context, owner, repo string, issueNumber int64) ([]models.IssueComment, error) {
	raw, err := t.client.Get(ctx, fmt.Sprintf("/repos/%s/%s/issues/%d/comments?per_page=100", owner, repo, issueNumber))
	if err != nil {
		return nil, err
	}

	var page []models.CommentResponse
	if err := json.Unmarshal(raw, &page); err != nil {
		return nil, fmt.Errorf("failed to parse comments: %w", err)
	}

	var out []models.IssueComment
	for i := range page {
		ts, perr := time.Parse(time.RFC3339, page[i].CreatedAt)
		if perr != nil {
			continue // unparseable timestamp: skip this comment
		}
		out = append(out, models.IssueComment{
			Author:    page[i].User.Login,
			CreatedAt: ts.UnixMilli(),
			Body:      page[i].Body,
		})
	}
	return out, nil
}

// fetchChangedPyFiles fetches the changed Python files of a PR. It returns
// two lists:
//   - pyFiles: the current path of every changed Python file, whatever its status.
//   - pyBaseFiles: the paths whose pre-PR content should exist before the PR.
//     This is the file's own path for "modified", "removed", "changed" and
//     "unchanged", and the source path (previous_filename) for "renamed" and
//     "copied".
//
// Files with status "added" have no pre-PR version and are left out of
// pyBaseFiles. Any other status is logged as unknown and left out as well.
//
// NOTE(review): only the first page (up to 100 files) is requested, so PRs
// touching more than 100 files are truncated.
func (t *Task3EnrichPR) fetchChangedPyFiles(ctx context.Context, owner, repo string, prNumber int64) ([]string, []string, error) {
	endpoint := fmt.Sprintf("/repos/%s/%s/pulls/%d/files?per_page=100", owner, repo, prNumber)

	data, err := t.client.Get(ctx, endpoint)
	if err != nil {
		return nil, nil, err
	}

	var files []models.PRFile
	if err := json.Unmarshal(data, &files); err != nil {
		return nil, nil, fmt.Errorf("failed to parse files: %w", err)
	}

	var pyFiles []string
	var pyBaseFiles []string
	for _, file := range files {
		if !isPythonFile(file.Filename) {
			continue
		}
		pyFiles = append(pyFiles, file.Filename)

		switch file.Status {
		case "removed", "modified", "changed", "unchanged":
			// The file exists under the same path before the PR.
			pyBaseFiles = append(pyBaseFiles, file.Filename)
		case "renamed", "copied":
			// The pre-PR content lives at the source path.
			pyBaseFiles = append(pyBaseFiles, file.PreviousFilename)
		case "added":
			// Do not add to base files: there is no pre-PR version.
		default:
			log.Printf("[WARN] Unknown file status %s for file %s in PR %s#%d\n", file.Status, file.Filename, owner+"/"+repo, prNumber)
		}
	}

	return pyFiles, pyBaseFiles, nil
}

// fetchFileContents fetches each path in filePaths at revision sha. It returns
// the decoded contents in input order.
//
// Paths that fail with client.ErrNotFound are skipped. Cancellation of ctx
// (checked before each request) and any other error abort the whole call,
// returning nil contents.
func (t *Task3EnrichPR) fetchFileContents(ctx context.Context, owner, repo, sha string, filePaths []string) ([]models.FileContent, error) {
	var out []models.FileContent

	for _, p := range filePaths {
		if cerr := ctx.Err(); cerr != nil {
			return nil, cerr
		}

		body, err := t.fetchFileContent(ctx, owner, repo, p, sha)
		switch {
		case err == nil:
			out = append(out, models.FileContent{Path: p, Content: body})
		case errors.Is(err, client.ErrNotFound):
			// The file is absent at this SHA. This happens when a file is added
			// in the PR and then modified by a later commit: the PR files API
			// lists it as "modified" even though it never existed at base.
		default:
			return nil, fmt.Errorf("failed to fetch content for file %s: %w", p, err)
		}
	}

	return out, nil
}

// fetchFileContent fetches a single file's content at a specific ref through
// the repository contents endpoint and returns it as a string.
//
// Each path segment is escaped on its own: characters such as '#', '?' and
// spaces get percent-encoded, while the '/' separators stay literal. Escaping
// the whole path with url.PathEscape would instead send the separators as
// %2F, leaving it to the server to decode them back into a directory path.
// The ref is query-escaped for the same reason; this is a no-op for hex SHAs.
//
// Content with encoding "base64" is decoded. encoding/base64 ignores '\r' and
// '\n', so line-wrapped payloads decode directly. Any other encoding is
// returned as-is. NOTE(review): very large files may come back with an empty
// content field — confirm against the API if that matters here.
func (t *Task3EnrichPR) fetchFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
	segments := strings.Split(path, "/")
	for i, seg := range segments {
		segments[i] = url.PathEscape(seg)
	}
	endpoint := fmt.Sprintf("/repos/%s/%s/contents/%s?ref=%s", owner, repo, strings.Join(segments, "/"), url.QueryEscape(ref))

	data, err := t.client.Get(ctx, endpoint)
	if err != nil {
		return "", err
	}

	var fileResp models.FileContentResponse
	if err := json.Unmarshal(data, &fileResp); err != nil {
		return "", fmt.Errorf("failed to parse file content response: %w", err)
	}

	if fileResp.Encoding == "base64" {
		decoded, err := base64.StdEncoding.DecodeString(fileResp.Content)
		if err != nil {
			return "", fmt.Errorf("failed to decode base64 content: %w", err)
		}
		return string(decoded), nil
	}

	return fileResp.Content, nil
}

// fetchCommitsWithDiffs fetches the commits of a PR. For each commit it loads
// the details and keeps only the diffs of files in pyFiles.
//
// Only commits that end up with at least one relevant diff are returned. The
// second result is the parent SHA of the first listed commit, when that
// commit has exactly one parent; otherwise it is "". Cancellation of ctx is
// checked before each commit.
//
// per_page=100 is requested, matching the other list endpoints in this file.
// Without it the server's smaller default page size silently drops later
// commits of long PRs. NOTE(review): only the first page is read, so PRs with
// more than 100 commits are still truncated.
func (t *Task3EnrichPR) fetchCommitsWithDiffs(ctx context.Context, owner, repo string, prNumber int64, pyFiles []string) ([]models.CommitInfo, string, error) {
	endpoint := fmt.Sprintf("/repos/%s/%s/pulls/%d/commits?per_page=100", owner, repo, prNumber)

	data, err := t.client.Get(ctx, endpoint)
	if err != nil {
		return nil, "", err
	}

	var commitsResp []models.CommitResponse
	if err := json.Unmarshal(data, &commitsResp); err != nil {
		return nil, "", fmt.Errorf("failed to parse commits: %w", err)
	}

	var commits []models.CommitInfo
	firstCommitParentSHA := ""

	// Fetch details for each commit
	for i, commitResp := range commitsResp {
		select {
		case <-ctx.Done():
			return nil, "", ctx.Err()
		default:
		}

		commitInfo, err := t.fetchCommitDetails(ctx, owner, repo, commitResp.SHA, pyFiles)
		if err != nil {
			return nil, "", fmt.Errorf("failed to fetch commit details for %s: %w", commitResp.SHA, err)
		}

		// Record the parent only for a single-parent first commit.
		if i == 0 && len(commitInfo.Parents) == 1 {
			firstCommitParentSHA = commitInfo.Parents[0]
		}

		// Only include commits that have relevant diffs
		if len(commitInfo.Diffs) > 0 {
			commits = append(commits, *commitInfo)
		}
	}

	return commits, firstCommitParentSHA, nil
}

// fetchCommitDetails loads a single commit. It returns:
//   - its SHA and parent SHAs;
//   - the author name and author date (as Unix milliseconds);
//   - the message;
//   - the patches of the files listed in pyFiles, in the order the commit
//     reports them.
//
// The call fails if the author date is not RFC 3339.
func (t *Task3EnrichPR) fetchCommitDetails(ctx context.Context, owner, repo, sha string, pyFiles []string) (*models.CommitInfo, error) {
	raw, err := t.client.Get(ctx, fmt.Sprintf("/repos/%s/%s/commits/%s", owner, repo, sha))
	if err != nil {
		return nil, err
	}

	var detail models.CommitDetailResponse
	if err := json.Unmarshal(raw, &detail); err != nil {
		return nil, fmt.Errorf("failed to parse commit details: %w", err)
	}

	authored, err := time.Parse(time.RFC3339, detail.Commit.Author.Date)
	if err != nil {
		return nil, fmt.Errorf("failed to parse commit timestamp: %w", err)
	}

	wanted := make(map[string]struct{}, len(pyFiles))
	for _, p := range pyFiles {
		wanted[p] = struct{}{}
	}

	var diffs []models.DiffPatch
	for _, f := range detail.Files {
		if _, ok := wanted[f.Filename]; ok {
			diffs = append(diffs, models.DiffPatch{Path: f.Filename, Patch: f.Patch})
		}
	}

	parents := make([]string, 0, len(detail.Parents))
	for _, p := range detail.Parents {
		parents = append(parents, p.SHA)
	}

	return &models.CommitInfo{
		SHA:       sha,
		Parents:   parents,
		Author:    detail.Commit.Author.Name,
		Timestamp: authored.UnixMilli(),
		Diffs:     diffs,
		Message:   detail.Commit.Message,
	}, nil
}

// fetchFileTree returns the paths of all Python files (as judged by
// isPythonFile) in the repository tree at sha. It uses the recursive git
// trees endpoint.
//
// If that request fails with client.ErrPayloadTooLarge (a 413, per the
// original comment), it falls back to fetchFileTreeTopLevel, which lists only
// the root level.
//
// NOTE(review): very large trees may be returned truncated by the API; this
// function does not check for that.
func (t *Task3EnrichPR) fetchFileTree(ctx context.Context, owner, repo, sha string) ([]string, error) {
	raw, err := t.client.Get(ctx, fmt.Sprintf("/repos/%s/%s/git/trees/%s?recursive=1", owner, repo, sha))
	if err != nil {
		if errors.Is(err, client.ErrPayloadTooLarge) {
			return t.fetchFileTreeTopLevel(ctx, owner, repo, sha, "")
		}
		return nil, err
	}

	var tree models.GitTreeResponse
	if err := json.Unmarshal(raw, &tree); err != nil {
		return nil, fmt.Errorf("failed to parse git tree response: %w", err)
	}

	// Keep only file ("blob") entries that are Python sources outside __pycache__.
	var paths []string
	for _, e := range tree.Tree {
		if e.Type == "blob" && isPythonFile(e.Path) {
			paths = append(paths, e.Path)
		}
	}
	return paths, nil
}

// fetchFileTreeTopLevel is the fallback for trees too large to list
// recursively. It fetches the non-recursive tree at sha and returns the
// root-level Python files plus the names of root-level directories, excluding
// "__pycache__".
//
// The prefix parameter is currently unused.
func (t *Task3EnrichPR) fetchFileTreeTopLevel(ctx context.Context, owner, repo, sha, prefix string) ([]string, error) {
	raw, err := t.client.Get(ctx, fmt.Sprintf("/repos/%s/%s/git/trees/%s", owner, repo, sha))
	if err != nil {
		return nil, err
	}

	var tree models.GitTreeResponse
	if err := json.Unmarshal(raw, &tree); err != nil {
		return nil, fmt.Errorf("failed to parse git tree response: %w", err)
	}

	var names []string
	for _, e := range tree.Tree {
		isPyBlob := e.Type == "blob" && isPythonFile(e.Path)
		isDir := e.Type == "tree" && e.Path != "__pycache__"
		if isPyBlob || isDir {
			names = append(names, e.Path)
		}
	}
	return names, nil
}

// isPythonFile reports whether path names a Python source file. Two conditions
// must hold:
//   - the path does not contain "__pycache__" anywhere;
//   - its extension, i.e. the text from the last '.' in the whole path, is
//     exactly ".py", ".pyi", ".pyx" or ".pyw". The match is case-sensitive.
//
// The extension check is a switch rather than a map literal, so nothing is
// allocated per call. That matters because this runs once for every entry of
// potentially very large repository trees.
func isPythonFile(path string) bool {
	// If the file is in pycache, skip it
	if strings.Contains(path, "__pycache__") {
		return false
	}

	dot := strings.LastIndexByte(path, '.')
	if dot < 0 {
		return false
	}

	switch path[dot:] {
	case ".py", ".pyi", ".pyx", ".pyw":
		return true
	default:
		return false
	}
}

// Close closes the enriched PR writer, presumably flushing any buffered rows.
// Workers write through this writer, so call Close only after Run has
// returned.
func (t *Task3EnrichPR) Close() error {
	err := t.writer.Close()
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to close enriched PR writer: %w", err)
}
