"""Competition Pipeline

This module orchestrates the end-to-end pipeline for processing competitions:
downloading competition data, then the brainstorming, generation, and refactoring
phases, with support for processing several competitions concurrently.
"""

import asyncio
import json
import os
import shutil
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

from components.brainstormer import BrainstormerAgent
from components.executor import ExecutorAgent
from components.refactor import RefactorAgent
from components.verifier import (
    verify_competition_directory,
    verify_refactored_competition_directory
)
from tools.workflowtool import WorkflowTools
from utils.utils import copy_folder


# ============================================================================
# Configuration and Data Classes
# ============================================================================

@dataclass
class PipelineConfig:
    """Pipeline configuration settings.

    Creating an instance has a filesystem side effect: ``workplace_dir`` and its
    ``logs`` subdirectory are created if missing.
    """
    # Root directory under which all per-competition folders and logs are written.
    workplace_dir: Path = field(default_factory=lambda: Path("./workplace").resolve())
    # Number of brainstorm ideas (and therefore generation pipelines) per competition.
    count: int = 3
    # Maximum attempts for each generation / refactor run.
    max_attempts: int = 3
    # Number of competitions processed concurrently.
    max_concurrent: int = 1
    model: str = "gpt-5"
    temperature: float = 1.0
    max_completion_tokens: int = 20000
    # Iteration budget passed to every agent's execute_task().
    max_agent_iterations: int = 30

    def __post_init__(self):
        """Normalize ``workplace_dir`` to a ``Path`` and ensure it and ``logs/`` exist."""
        # Callers may pass a plain string; the rest of the pipeline joins paths with ``/``.
        self.workplace_dir = Path(self.workplace_dir)
        # parents=True so a nested workplace (e.g. "runs/2024/exp1") can be created in one go.
        self.workplace_dir.mkdir(parents=True, exist_ok=True)
        (self.workplace_dir / "logs").mkdir(exist_ok=True)


@dataclass
class TimingInfo:
    """Wall-clock timing for a pipeline operation, measured with ``time.time()``."""
    # Epoch seconds captured when the instance is created.
    start_time: float = field(default_factory=time.time)
    # Mirror of total_duration_seconds, filled in by finalize().
    duration_seconds: float = 0.0
    # Phase-specific durations, assigned by the caller.
    download_duration_seconds: float = 0.0
    brainstorm_duration_seconds: float = 0.0
    # Elapsed time from start_time to the finalize() call.
    total_duration_seconds: float = 0.0
    # Per-attempt records appended by the retrying phases.
    attempt_durations: List[Dict] = field(default_factory=list)

    def finalize(self):
        """Stamp the elapsed time since ``start_time`` into both duration fields."""
        elapsed = time.time() - self.start_time
        self.total_duration_seconds = elapsed
        self.duration_seconds = elapsed


@dataclass
class CompetitionLog:
    """Outcome record for processing a single competition."""
    # Competition identifier (callers in this module pass the short name).
    competition_name: str
    # ISO-8601 local time at which the record was created.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    # Worker / task number that produced the record, if any.
    task_id: Optional[int] = None
    # Overall outcome flag, set by the pipeline.
    success: bool = False
    # Whether the brainstorm phase produced any brainstorming files.
    brainstorm_success: bool = False
    # Per-generation log dicts.
    generations: List[Dict[str, Any]] = field(default_factory=list)
    # Per-refactor log dicts.
    refactors: List[Dict[str, Any]] = field(default_factory=list)
    # Timing for the whole competition.
    timing: TimingInfo = field(default_factory=TimingInfo)
    # Stringified exception if processing aborted.
    error: Optional[str] = None


# ============================================================================
# Helper Functions
# ============================================================================

def load_competitions(
    final_list_path: str = "./mle-list.txt"
) -> List[str]:
    """Load the competition list: one name per line, blank lines ignored.

    Args:
        final_list_path: Path to the list file.

    Returns:
        The stripped, non-empty lines in file order, or ``[]`` (with a warning)
        if the file does not exist.
    """
    path = Path(final_list_path)
    if not path.exists():
        print(f"Warning: Competition list not found at {final_list_path}")
        return []

    # Explicit UTF-8 so the result does not depend on the platform's locale encoding.
    text = path.read_text(encoding="utf-8")
    competitions = [line.strip() for line in text.splitlines() if line.strip()]
    print(f"Loaded {len(competitions)} competitions from {final_list_path}")
    return competitions


def setup_directories(competition_name: str, workplace_dir: Path) -> Dict[str, Path]:
    """Create the per-phase folders for a competition and return their resolved paths.

    The competition's short name (text after the last ``/``) is used as the leaf
    folder under ``original``, ``brainstorm``, ``generation`` and ``refact``; the
    returned dict is keyed by those same phase names.
    """
    leaf = competition_name.rsplit("/", 1)[-1]

    layout: Dict[str, Path] = {}
    for phase in ("original", "brainstorm", "generation", "refact"):
        target = workplace_dir / phase / leaf
        target.mkdir(parents=True, exist_ok=True)
        layout[phase] = target.resolve()
    return layout


def reset_directory(target_dir: Path, source_dir: Path, additional_files: Dict[str, Path] = None):
    """Recreate ``target_dir`` from scratch as a copy of ``source_dir``.

    Args:
        target_dir: Directory to wipe (if present) and repopulate.
        source_dir: Directory whose contents are copied in via ``copy_folder``.
        additional_files: Optional ``{dest_name: source_path}`` extras copied into
            ``target_dir``; sources that do not exist are skipped.
    """
    # Wipe first so nothing from a previous attempt survives.
    if target_dir.exists():
        shutil.rmtree(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    copy_folder(source_dir, target_dir)

    for dest_name, source_path in (additional_files or {}).items():
        if not source_path.exists():
            continue
        shutil.copy2(source_path, target_dir / dest_name)


# ============================================================================
# Agent Execution
# ============================================================================

class AgentExecutor:
    """Creates agents from the pipeline config and runs them with uniform logging."""

    def __init__(self, config: PipelineConfig):
        self.config = config
        # Shared helper for non-agent workflow steps (used for the data download).
        self.workflow_tools = WorkflowTools()

    async def run_agent(
        self,
        agent,
        max_iterations: int,
        task_label: str
    ) -> Dict[str, Any]:
        """Run ``agent.execute_task`` and log its duration (and cost when reported).

        Args:
            agent: Object exposing ``async execute_task(max_iterations=...)``.
            max_iterations: Iteration budget forwarded to the agent.
            task_label: Prefix for progress messages.

        Returns:
            The agent's response unchanged (it may or may not be a dict).

        Raises:
            Whatever ``execute_task`` raises, after printing a failure message.
        """
        print(f"{task_label} starting...")
        start_time = time.time()

        # Guard only the agent call: a problem while *logging* a successful run
        # must not be reported (and re-raised) as an agent failure.
        try:
            response = await agent.execute_task(max_iterations=max_iterations)
        except Exception as e:
            print(f"{task_label} failed: {e}")
            raise

        duration = time.time() - start_time
        cost_text = ""
        if isinstance(response, dict):
            cost = response.get("total_cost", 0)
            try:
                cost_text = f" (cost: ${cost:.6f})"
            except (TypeError, ValueError):
                # Non-numeric cost (e.g. None): report the duration only.
                cost_text = ""
        print(f"{task_label} completed in {duration:.2f}s{cost_text}")
        return response

    def create_agent(self, agent_class, **kwargs):
        """Instantiate ``agent_class`` with the configured model and chat settings.

        Extra keyword arguments are forwarded to the agent constructor unchanged.
        """
        chat_config = {
            "temperature": self.config.temperature,
            "max_completion_tokens": self.config.max_completion_tokens
        }
        return agent_class(
            model=self.config.model,
            chat_config=chat_config,
            **kwargs
        )


# ============================================================================
# Pipeline Phases
# ============================================================================

class PipelinePhases:
    """Runs the brainstorm, generation and refactor phases for one competition.

    Generation and refactor each retry up to ``config.max_attempts`` times. An
    attempt succeeds when the matching verifier's report does not contain
    ``"Verification failed"``. An attempt that raises counts as a failed attempt.
    """

    def __init__(self, executor: AgentExecutor, config: PipelineConfig):
        self.executor = executor
        self.config = config

    async def brainstorm_phase(
        self,
        dirs: Dict[str, Path],
        competition_name: str,
        task_id: int
    ) -> Tuple[bool, int, float]:
        """Run the brainstormer once over a copy of the original competition data.

        Returns:
            ``(success, actual_count, duration)``. ``actual_count`` is how many of
            ``brainstorming_1.md`` … ``brainstorming_{count}.md`` exist afterwards.
            ``success`` is ``actual_count > 0``.

        Raises:
            Whatever the brainstormer agent raises (re-raised by ``run_agent``).
        """
        print(f"[Task {task_id}] Running brainstormer...")
        start_time = time.time()

        # The brainstormer works on its own copy of the original data.
        copy_folder(dirs['original'], dirs['brainstorm'])

        agent = self.executor.create_agent(
            BrainstormerAgent,
            competition_name=competition_name.split("/")[-1],
            working_directory=str(dirs['brainstorm']),
            count=self.config.count
        )

        await self.executor.run_agent(
            agent,
            self.config.max_agent_iterations,
            f"[Task {task_id}] Brainstormer"
        )

        actual_count = sum(
            1 for i in range(1, self.config.count + 1)
            if (dirs['brainstorm'] / f"brainstorming_{i}.md").exists()
        )

        duration = time.time() - start_time
        print(f"[Task {task_id}] Generated {actual_count}/{self.config.count} brainstorm files in {duration:.2f}s")

        return actual_count > 0, actual_count, duration

    async def generation_phase(
        self,
        generation_id: int,
        dirs: Dict[str, Path],
        competition_name: str,
        task_id: int
    ) -> Dict[str, Any]:
        """Generate ``generation{generation_id}`` from a brainstorm file, with retries.

        Returns:
            A log dict with ``generation_id``, ``attempts``, ``success``,
            ``successful_attempt``, ``iterations_used``, ``total_cost`` and
            ``timing``. ``total_cost`` is the successful attempt's cost only.
        """
        generation_dir = dirs['generation'] / f"generation{generation_id}"
        brainstorm_file = dirs['brainstorm'] / f"brainstorming_{generation_id}.md"
        short_name = competition_name.split("/")[-1]

        def prepare(attempt: int) -> None:
            # The first attempt copies into the (possibly pre-existing) directory.
            # Retries start from a wiped, fresh copy of the original data.
            if attempt > 1:
                reset_directory(generation_dir, dirs['original'])
            else:
                generation_dir.mkdir(parents=True, exist_ok=True)
                copy_folder(dirs['original'], generation_dir)
            if brainstorm_file.exists():
                shutil.copy2(brainstorm_file, generation_dir / "brainstorming.md")

        return await self._run_with_retries(
            self._new_attempt_log("generation_id", generation_id),
            prepare,
            lambda: self.executor.create_agent(
                ExecutorAgent,
                competition_name=short_name,
                working_directory=str(generation_dir)
            ),
            lambda: verify_competition_directory(str(generation_dir)),
            f"[Task {task_id}] Generation{generation_id}",
        )

    async def refactor_phase(
        self,
        refactor_id: int,
        dirs: Dict[str, Path],
        generation_dir: Path,
        task_id: int
    ) -> Dict[str, Any]:
        """Refactor a verified generation into ``refact{refactor_id}``, with retries.

        Returns:
            A log dict shaped like ``generation_phase``'s, keyed by ``refactor_id``.
        """
        refact_dir = dirs['refact'] / f"refact{refactor_id}"

        return await self._run_with_retries(
            self._new_attempt_log("refactor_id", refactor_id),
            lambda attempt: self._prepare_refactor_directory(
                refact_dir, dirs['original'], generation_dir, attempt
            ),
            lambda: self.executor.create_agent(
                RefactorAgent,
                working_directory=str(refact_dir)
            ),
            lambda: verify_refactored_competition_directory(str(refact_dir)),
            f"[Task {task_id}] Refact{refactor_id}",
        )

    def _new_attempt_log(self, id_key: str, item_id: int) -> Dict[str, Any]:
        """Return a fresh phase log dict whose first key is ``id_key``."""
        return {
            id_key: item_id,
            "attempts": 0,
            "success": False,
            "successful_attempt": None,
            "iterations_used": None,
            "total_cost": 0.0,
            "timing": TimingInfo()
        }

    async def _run_with_retries(self, log, prepare, create_agent, verify, task_label):
        """Run prepare → agent → verify until one attempt succeeds or attempts run out.

        Args:
            log: Phase log dict (see ``_new_attempt_log``), updated in place.
            prepare: ``prepare(attempt)`` sets up the working directory.
            create_agent: Zero-argument factory returning a fresh agent.
            verify: Zero-argument callable returning the verifier's report.
            task_label: Prefix for progress messages.

        Returns:
            ``log``, with its ``timing`` finalized.
        """
        for attempt in range(1, self.config.max_attempts + 1):
            attempt_start = time.time()
            log["attempts"] = attempt
            response = None
            error = None

            try:
                prepare(attempt)
                response = await self.executor.run_agent(
                    create_agent(),
                    self.config.max_agent_iterations,
                    task_label
                )
                success = "Verification failed" not in verify()
            except Exception as e:
                # A crash is just a failed attempt. The remaining retries still run,
                # and sibling generation/refactor pipelines are not aborted.
                error = str(e)
                success = False
                print(f"{task_label} attempt {attempt}/{self.config.max_attempts} errored: {e}")

            attempt_record = {
                "attempt": attempt,
                "duration_seconds": time.time() - attempt_start,
                "success": success
            }
            if error is not None:
                attempt_record["error"] = error
            log["timing"].attempt_durations.append(attempt_record)

            if success:
                # run_agent may return a non-dict response; only read metrics from dicts.
                metrics = response if isinstance(response, dict) else {}
                log.update({
                    "success": True,
                    "successful_attempt": attempt,
                    "iterations_used": metrics.get("iterations_used"),
                    "total_cost": metrics.get("total_cost", 0.0)
                })
                break

        log["timing"].finalize()
        return log

    def _prepare_refactor_directory(
        self,
        refact_dir: Path,
        original_dir: Path,
        generation_dir: Path,
        attempt: int
    ):
        """Populate ``refact_dir`` for a refactor attempt.

        The directory gets the generation's description/metric/prepare/test files
        (those that exist), a ``raw/`` copy of the original data, and ``./samples``
        (relative to the current working directory) when that folder exists.
        On retries the directory is wiped first.
        """
        if attempt > 1 and refact_dir.exists():
            shutil.rmtree(refact_dir)

        refact_dir.mkdir(parents=True, exist_ok=True)

        for file_name in ["description.txt", "metric.py", "prepare.py", "test.py"]:
            source = generation_dir / file_name
            if source.exists():
                shutil.copy2(source, refact_dir / file_name)

        raw_dir = refact_dir / "raw"
        raw_dir.mkdir(parents=True, exist_ok=True)
        copy_folder(original_dir, raw_dir)

        samples_source = Path("./samples").resolve()
        if samples_source.exists():
            samples_dest = refact_dir / "samples"
            if samples_dest.exists():
                shutil.rmtree(samples_dest)
            shutil.copytree(samples_source, samples_dest)


# ============================================================================
# Main Pipeline
# ============================================================================

class CompetitionPipeline:
    """Main pipeline orchestrator.

    For each competition: download the data, brainstorm, then run one
    generation → refactor pipeline per brainstorm file concurrently. A JSON log
    is written to ``<workplace>/logs/<short_name>.json``.
    """

    def __init__(self, config: PipelineConfig = None):
        self.config = config or PipelineConfig()
        self.executor = AgentExecutor(self.config)
        self.phases = PipelinePhases(self.executor, self.config)

    async def process_competition(
        self,
        competition_name: str,
        task_id: int = None
    ) -> CompetitionLog:
        """Process a single competition through all phases.

        Exceptions raised by the phases are caught and stored in ``log.error``.
        The log is always finalized, saved and summarized. ``log.success`` is True
        only if brainstorming produced files and at least one refactor passed
        verification.
        """
        short_name = competition_name.split("/")[-1]
        log = CompetitionLog(competition_name=short_name, task_id=task_id)

        print(f"\n{'='*60}")
        print(f"[Task {task_id}] Processing: {competition_name}")
        print(f"{'='*60}")

        try:
            dirs = setup_directories(competition_name, self.config.workplace_dir)

            await self._download_competition(competition_name, dirs['original'], log)

            success, _, duration = await self.phases.brainstorm_phase(
                dirs, competition_name, task_id
            )
            log.brainstorm_success = success
            log.timing.brainstorm_duration_seconds = duration

            if not success:
                print(f"[Task {task_id}] No brainstorm files generated, skipping...")
                return log

            # Scan every possible brainstorm index, not only the first `actual_count`.
            # The brainstormer may leave gaps (e.g. files 1 and 3 but not 2).
            await self._process_generation_refactor(
                dirs, competition_name, self.config.count, task_id, log
            )

            # Success means at least one end-to-end output passed verification.
            log.success = log.brainstorm_success and any(
                ref.get("success") for ref in log.refactors
            )

        except Exception as e:
            print(f"[Task {task_id}] Error processing {competition_name}: {e}")
            log.error = str(e)

        finally:
            log.timing.finalize()
            self._save_log(log, short_name)
            self._print_summary(log, task_id)

        return log

    async def _download_competition(
        self,
        competition_name: str,
        target_dir: Path,
        log: CompetitionLog
    ):
        """Download competition data into ``target_dir`` and record the duration.

        The download call is synchronous. It runs in the default executor so it
        does not block other workers' coroutines on the event loop.
        """
        print(f"[Task {log.task_id}] Downloading competition data...")
        start_time = time.time()

        loop = asyncio.get_running_loop()
        # NOTE(review): assumes download_competition_data is safe to call from a
        # worker thread (and concurrently when max_concurrent > 1) — confirm.
        await loop.run_in_executor(
            None,
            self.executor.workflow_tools.download_competition_data,
            competition_name,
            str(target_dir)
        )

        log.timing.download_duration_seconds = time.time() - start_time
        print(f"[Task {log.task_id}] Download completed in {log.timing.download_duration_seconds:.2f}s")

    async def _process_generation_refactor(
        self,
        dirs: Dict[str, Path],
        competition_name: str,
        count: int,
        task_id: int,
        log: CompetitionLog
    ):
        """Run one generation → refactor pipeline per existing brainstorm file, concurrently.

        Args:
            count: Highest brainstorm index to look for
                (``brainstorming_1.md`` … ``brainstorming_{count}.md``).

        A sub-pipeline that raises is recorded as a failed generation entry with an
        ``error`` field. It does not abort the other pipelines.
        """
        indices = [
            i for i in range(1, count + 1)
            if (dirs['brainstorm'] / f"brainstorming_{i}.md").exists()
        ]

        # return_exceptions=True: one crashing pipeline must not discard the results
        # of, or orphan, the others that are still running.
        results = await asyncio.gather(
            *(self._run_single_pipeline(i, dirs, competition_name, task_id) for i in indices),
            return_exceptions=True
        )

        for index, result in zip(indices, results):
            if isinstance(result, BaseException):
                print(f"[Task {task_id}] Pipeline {index} crashed: {result}")
                log.generations.append({
                    "generation_id": index,
                    "success": False,
                    "error": str(result)
                })
                continue
            gen_log, ref_log = result
            log.generations.append(gen_log)
            if ref_log:
                log.refactors.append(ref_log)

    async def _run_single_pipeline(
        self,
        index: int,
        dirs: Dict[str, Path],
        competition_name: str,
        task_id: int
    ) -> Tuple[Dict, Optional[Dict]]:
        """Run generation ``index`` and, if it succeeded, its refactor.

        Returns:
            ``(generation_log, refactor_log_or_None)``.
        """
        gen_log = await self.phases.generation_phase(
            index, dirs, competition_name, task_id
        )

        if not gen_log["success"]:
            print(f"[Task {task_id}] Generation {index} failed, skipping refactor")
            return gen_log, None

        generation_dir = dirs['generation'] / f"generation{index}"
        ref_log = await self.phases.refactor_phase(
            index, dirs, generation_dir, task_id
        )

        return gen_log, ref_log

    def _save_log(self, log: CompetitionLog, short_name: str):
        """Write ``log`` as pretty-printed UTF-8 JSON to ``<workplace>/logs/<short_name>.json``."""
        log_file = self.config.workplace_dir / "logs" / f"{short_name}.json"

        # asdict() recurses into nested dataclasses, including the TimingInfo
        # objects held inside the generation/refactor dicts.
        log_dict = asdict(log)

        with open(log_file, 'w', encoding='utf-8') as f:
            json.dump(log_dict, f, indent=2, ensure_ascii=False)

        print(f"\nLog saved to: {log_file}")

    def _print_summary(self, log: CompetitionLog, task_id: int):
        """Print competition processing summary."""
        print(f"\n[Task {task_id}] Summary for {log.competition_name}:")
        print(f"- Success: {log.success}")
        print(f"- Brainstorm: {log.brainstorm_success}")
        print(f"- Generations: {len(log.generations)}")
        print(f"- Refactors: {len(log.refactors)}")
        print(f"- Total time: {log.timing.total_duration_seconds:.2f}s")


class ConcurrentProcessor:
    """Processes many competitions with a fixed pool of asyncio worker coroutines."""

    def __init__(self, pipeline: CompetitionPipeline):
        self.pipeline = pipeline

    async def process_all(
        self,
        competitions: List[str]
    ) -> List[CompetitionLog]:
        """Process every competition and return one log per competition.

        Logs are returned in completion order, not input order. At least one worker
        is started even if ``max_concurrent`` is configured as zero or negative.
        """
        queue = asyncio.Queue()
        logs = []
        logs_lock = asyncio.Lock()

        # The queue is fully populated before any worker starts. Workers rely on
        # this to treat "empty" as "done".
        for comp in competitions:
            queue.put_nowait(comp)

        worker_count = max(1, self.pipeline.config.max_concurrent)
        workers = [
            self._worker(i + 1, queue, logs, logs_lock)
            for i in range(worker_count)
        ]

        await asyncio.gather(*workers)

        return logs

    async def _worker(
        self,
        worker_id: int,
        queue: asyncio.Queue,
        logs: List,
        logs_lock: asyncio.Lock
    ):
        """Take competitions from ``queue`` until it is empty, appending a log for each.

        A competition whose processing raises still gets a ``CompetitionLog``
        with ``error`` set.
        """
        while True:
            # Nothing is added after start-up, so an empty queue means all work
            # is claimed. Avoids an idle timeout and the wait_for/get cancel race.
            try:
                competition = queue.get_nowait()
            except asyncio.QueueEmpty:
                print(f"[Worker {worker_id}] No more tasks, exiting...")
                break

            print(f"\n[Worker {worker_id}] Processing: {competition}")
            print(f"[Worker {worker_id}] Queue size: {queue.qsize()}")

            try:
                log = await self.pipeline.process_competition(competition, worker_id)

                async with logs_lock:
                    logs.append(log)
                    print(f"[Worker {worker_id}] Completed: {len(logs)} total")

            except Exception as e:
                print(f"[Worker {worker_id}] Failed: {e}")

                error_log = CompetitionLog(
                    competition_name=competition.split("/")[-1],
                    task_id=worker_id,
                    error=str(e)
                )

                async with logs_lock:
                    logs.append(error_log)

            finally:
                queue.task_done()


# ============================================================================
# Main Entry Point
# ============================================================================

async def main():
    """Process every competition listed in ``./mle-list.txt`` and write a run summary.

    Builds the pipeline from a default ``PipelineConfig`` and processes the
    competitions with a ``ConcurrentProcessor`` worker pool. The summary is
    written to ``all_competitions_summary.json`` in ``<workplace>/logs``.
    """
    config = PipelineConfig()
    pipeline = CompetitionPipeline(config)
    processor = ConcurrentProcessor(pipeline)

    competitions = load_competitions()
    if not competitions:
        print("No competitions found. Exiting...")
        return

    print(f"\nProcessing {len(competitions)} competitions")
    print("Configuration:")
    print(f"  - Workers: {config.max_concurrent}")
    print(f"  - Generations per competition: {config.count}")
    print(f"  - Max attempts: {config.max_attempts}")
    print(f"  - Workplace: {config.workplace_dir}")

    start_time = datetime.now()
    logs = await processor.process_all(competitions)
    end_time = datetime.now()

    # asdict() leaves workplace_dir as a Path, which json.dump cannot serialize.
    config_dict = asdict(config)
    config_dict["workplace_dir"] = str(config.workplace_dir)

    total = len(logs)
    successful = sum(1 for log in logs if log.success)
    summary = {
        "timestamp": datetime.now().isoformat(),
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
        "total_duration_seconds": (end_time - start_time).total_seconds(),
        "total_competitions": total,
        "successful": successful,
        "config": config_dict,
        "competitions": [asdict(log) for log in logs]
    }

    summary_file = config.workplace_dir / "logs" / "all_competitions_summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    # Guard the percentage: logs can be empty (e.g. if no worker processed anything).
    success_pct = successful / total * 100 if total else 0.0

    print(f"\n{'='*80}")
    print("Pipeline completed!")
    print(f"Summary saved to: {summary_file}")
    print(f"Total: {total} competitions in {summary['total_duration_seconds']:.1f}s")
    print(f"Success rate: {successful}/{total} ({success_pct:.1f}%)")
    print(f"{'='*80}")


if __name__ == "__main__":
    # Script entry point: run the async pipeline on a fresh event loop via asyncio.run().
    asyncio.run(main())