import os
import json
import logging
import shutil
from typing import List, Dict, Any, Tuple, Optional
import re
from datetime import datetime


class SimpleConsoleHandler:
    """Minimal console reporter that prints one line per progress event.

    Every line has the shape ``[<problem_id>] <stage>: <text>``; errors and
    successes prefix the message with a cross / check-mark marker.
    """

    @staticmethod
    def _emit(problem_id: str, stage: str, text: str):
        """Write one formatted progress line to stdout."""
        print(f"[{problem_id}] {stage}: {text}")

    @staticmethod
    def info(problem_id: str, stage: str, message: str):
        """Report a plain informational message."""
        SimpleConsoleHandler._emit(problem_id, stage, message)

    @staticmethod
    def error(problem_id: str, stage: str, message: str):
        """Report a failure, prefixed with the ❌ marker."""
        SimpleConsoleHandler._emit(problem_id, stage, f"❌ {message}")

    @staticmethod
    def success(problem_id: str, stage: str, message: str):
        """Report a successful outcome, prefixed with the ✅ marker."""
        SimpleConsoleHandler._emit(problem_id, stage, f"✅ {message}")


class PerformanceLogger:
    """Structured performance statistics logger for a single problem.

    Stages are opened with ``start_stage``, annotated with the ``stage_*``
    recorders and closed with ``finish_stage``. ``finalize`` appends one JSON
    line summarising every stage to ``<log_dir>/performance.jsonl``.
    Recorder calls for a stage that was never started are silently ignored.
    """

    def __init__(self, problem_id: str, log_dir: str):
        """
        Args:
            problem_id: Identifier written into the log entry.
            log_dir: Directory receiving ``performance.jsonl``; created on
                ``finalize`` if missing. An empty string means the current
                working directory.
        """
        self.problem_id = problem_id
        self.log_file = os.path.join(log_dir, 'performance.jsonl')
        self.start_time = datetime.now()
        # stage name -> mutable statistics dict (see start_stage for keys).
        self.stages: Dict[str, Dict[str, Any]] = {}

    def start_stage(self, stage: str):
        """Start (or restart, resetting all counters of) a stage."""
        self.stages[stage] = {
            'start_time': datetime.now(),
            'status': 'running',
            'error_count': 0,
            'retry_count': 0,
            'applied_fixes': False,
            'passed_verification': False,
            'details': []
        }

    def stage_error(self, stage: str, error_msg: str):
        """Record an error for a started stage."""
        if stage in self.stages:
            self.stages[stage]['error_count'] += 1
            self.stages[stage]['details'].append(f"ERROR: {error_msg}")

    def stage_retry(self, stage: str):
        """Record a retry for a started stage."""
        if stage in self.stages:
            self.stages[stage]['retry_count'] += 1

    def stage_fix_applied(self, stage: str, fix_description: str):
        """Record that a fix was applied in a started stage."""
        if stage in self.stages:
            self.stages[stage]['applied_fixes'] = True
            self.stages[stage]['details'].append(f"FIX: {fix_description}")

    def stage_verification_passed(self, stage: str):
        """Record that verification passed for a started stage."""
        if stage in self.stages:
            self.stages[stage]['passed_verification'] = True

    def finish_stage(self, stage: str, status: str = 'success'):
        """Close a started stage, storing its end time, status and duration."""
        if stage in self.stages:
            self.stages[stage]['end_time'] = datetime.now()
            self.stages[stage]['status'] = status
            self.stages[stage]['duration_seconds'] = (
                self.stages[stage]['end_time'] - self.stages[stage]['start_time']
            ).total_seconds()

    def finalize(self) -> Dict[str, Any]:
        """Append a summary line to the JSONL log file.

        Stages that were never finished are reported with their current
        status (``'running'``) and a duration of 0.

        Returns:
            The log entry that was written.
        """
        end_time = datetime.now()
        total_duration = (end_time - self.start_time).total_seconds()

        log_entry = {
            'problem_id': self.problem_id,
            'timestamp': self.start_time.isoformat(),
            'total_duration_seconds': total_duration,
            'stages': {}
        }

        for stage, data in self.stages.items():
            log_entry['stages'][stage] = {
                'status': data['status'],
                'duration_seconds': data.get('duration_seconds', 0),
                'error_count': data['error_count'],
                'retry_count': data['retry_count'],
                'applied_fixes': data['applied_fixes'],
                'passed_verification': data['passed_verification'],
                'details_count': len(data['details'])
            }

        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the log path actually has one (log_dir may be '').
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry, ensure_ascii=False) + '\n')
        return log_entry


def setup_logging(log_file='outputs/pipeline.log'):
    """
    Configure the global ``'pipeline'`` logger to write INFO+ to a file.

    Calling this again replaces (and closes) any handlers attached by a
    previous call, so records are not duplicated and file handles are not
    leaked. The log file is truncated on every call (``mode='w'``).

    Args:
        log_file: Log file path; its parent directory is created if missing.
            A bare filename (no directory part) is written to the current
            working directory.

    Returns:
        Configured logger
    """
    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger('pipeline')
    logger.setLevel(logging.INFO)

    # Detach AND close old handlers: removeHandler alone leaves the
    # underlying file open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    fh = logging.FileHandler(log_file, mode='w', encoding='utf-8')
    fh.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    return logger


def save_output(filepath: str, content: str) -> None:
    """
    Save output content to a file, overwriting any existing file.

    Args:
        filepath: File path; missing parent directories are created. A bare
            filename is written to the current working directory.
        content: Text content, written as UTF-8.
    """
    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)


def read_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """
    Read a JSONL file.

    Blank (whitespace-only) lines are skipped. Records lacking an ``'id'``
    key get one set to the string form of their 0-based physical line
    index, counting skipped blank lines.

    Args:
        file_path: JSONL file path (UTF-8, with or without a leading BOM)

    Returns:
        Parsed data list

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    # 'utf-8-sig' strips a leading BOM if present; plain 'utf-8' leaves
    # '\ufeff' in the first line, which strip() keeps and json.loads rejects.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if 'id' not in obj:
                obj['id'] = str(idx)
            data.append(obj)
    return data


def ensure_problem_dirs(problem_id: str) -> Dict[str, str]:
    """
    Create the per-problem output directory tree.

    The tree lives under ``outputs/problems/<problem_id>`` (relative to the
    current working directory) and contains ``logs``, ``artifacts`` and
    ``verification`` subdirectories. Existing directories are reused.

    Args:
        problem_id: Problem identifier (converted with ``str``).

    Returns:
        Mapping with keys ``base``, ``logs``, ``artifacts``, ``verification``
        to their respective paths.
    """
    root = os.path.join('outputs', 'problems', str(problem_id))
    layout = {"base": root}
    for subdir in ('logs', 'artifacts', 'verification'):
        target = os.path.join(root, subdir)
        os.makedirs(target, exist_ok=True)
        layout[subdir] = target
    return layout


def create_simple_logger(name: str, log_file: str) -> logging.Logger:
    """
    Create (or reset) a simple file-only logger (no console output).

    Calling this again with the same ``name`` closes and replaces the
    previously attached handlers, so re-initialisation neither duplicates
    log lines nor leaks open file handles.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Destination file, truncated on every call (``mode='w'``).
            Its parent directory is created if missing.

    Returns:
        The configured logger (INFO level, ``propagate=False``).
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Don't forward records to ancestor loggers' handlers (e.g. a root
    # console handler).
    logger.propagate = False

    # Detach AND close old handlers: removeHandler alone leaves the file open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # FileHandler raises if the parent directory is missing; os.makedirs('')
    # would also raise, so only create a directory when one is given.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    fh = logging.FileHandler(log_file, mode='w', encoding='utf-8')
    fh.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    return logger


def init_problem_loggers(problem_id: str) -> Dict[str, Any]:
    """
    Initialize the per-problem logging bundle.

    Creates the problem's directory tree, one file-only logger per agent
    (named ``problem.<problem_id>.<agent>`` and writing to
    ``<logs>/<agent>.log``), plus a console reporter and a performance
    logger.

    Args:
        problem_id: Problem identifier.

    Returns:
        Dict with one entry per agent logger, followed by ``'console'``,
        ``'performance'`` and ``'dirs'`` (the path mapping).
    """
    dirs = ensure_problem_dirs(problem_id)
    logs_dir = dirs['logs']

    agent_names = ('planner_agent', 'merge_agent', 'rtl_agent', 'debug_agent', 'rag_agent')
    bundle: Dict[str, Any] = {
        agent: create_simple_logger(
            f"problem.{problem_id}.{agent}",
            os.path.join(logs_dir, f"{agent}.log"),
        )
        for agent in agent_names
    }

    # Non-file helpers share the same dict so callers get everything at once.
    bundle['console'] = SimpleConsoleHandler()
    bundle['performance'] = PerformanceLogger(problem_id, logs_dir)
    bundle['dirs'] = dirs

    return bundle


def save_verification_files(problem_id: str, source_dir: str, target_dir: str) -> bool:
    """
    Copy verification files to target directory (best effort).

    An existing ``target_dir`` is removed first, so afterwards it mirrors
    ``source_dir`` exactly. Failures never raise; they are logged as a
    warning on the ``'pipeline'`` logger and reported via the return value.

    Args:
        problem_id: Problem ID (used for log context)
        source_dir: Source directory
        target_dir: Target directory

    Returns:
        Whether the copy succeeded
    """
    try:
        # Use isdir rather than exists: if source_dir were a regular file,
        # the rmtree below would destroy target_dir before copytree failed.
        if not os.path.isdir(source_dir):
            return False

        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        shutil.copytree(source_dir, target_dir)
        return True
    except Exception:
        # Deliberately best-effort, but don't fail silently.
        logging.getLogger('pipeline').warning(
            "[%s] Failed to copy verification files from %s to %s",
            problem_id, source_dir, target_dir, exc_info=True,
        )
        return False