"""Data loading and analysis utilities for ImpossibleBench evaluation results."""

import os
import json
import re
import pandas as pd
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
import logging

try:
    from inspect_ai.log import read_eval_log
except ImportError:
    print("Warning: inspect_ai not available. Install with: pip install inspect-ai")
    read_eval_log = None

logger = logging.getLogger(__name__)

def get_dump_of_everything(s):
    from inspect_ai.model._chat_message import ChatMessageSystem, ChatMessageUser, ChatMessageAssistant, ChatMessageTool
    from inspect_ai._util.content import ContentText, ContentReasoning
    dump_of_everything = ''
    for msg in s.messages:
        content = msg.content
        tool_calls = getattr(msg, 'tool_calls', [])
        if isinstance(msg, ChatMessageSystem):
            dump_of_everything+='**SYSTEM MESSAGE:**\n'
        elif isinstance(msg, ChatMessageUser):
            dump_of_everything+='**USER MESSAGE:**\n'
        elif isinstance(msg, ChatMessageAssistant):
            dump_of_everything+='**ASSISTANT MESSAGE:**\n'
        elif isinstance(msg, ChatMessageTool):
            dump_of_everything+='**TOOL OUTPUT:**\n'
        else:
            raise ValueError(f"Unknown message type: {type(msg)}")
        if not isinstance(content, list):
            content = [content]
        for c in content:
            if isinstance(c, ContentText):
                dump_of_everything+=c.text+'\n'
            elif isinstance(c, str):
                dump_of_everything+=c+'\n'
        if tool_calls is None:
            tool_calls = []
        for tool_call in tool_calls:
            dump_of_everything+=f'[TOOL CALL: {tool_call.function}, ARGS: {tool_call.arguments}]\n'
        dump_of_everything+='\n\n'
    return dump_of_everything


@dataclass
class EvalResult:
    """Container for evaluation results with metadata."""
    # File metadata
    file_path: str
    filename: str
    
    # Task metadata
    task_id: str
    model: str
    dataset: str
    variant: str
    agent_type: str
    prompt_id: Optional[str]
    modification: Optional[str]

    # Results
    passed: bool
    first_pass: bool
    score: Optional[float]
    first_score: Optional[float]
    
    # Additional metadata
    timestamp: Optional[str]
    sample_id: Optional[str]
    sample_original_id: Optional[str]
    task_name: Optional[str]
    dump: Optional[str]
    
    def to_dict(self, dump: bool = False) -> Dict:
        """Convert to dictionary for DataFrame creation."""
        return {
            'file_path': self.file_path,
            'filename': self.filename,
            'task_id': self.task_id,
            'model': self.model,
            'dataset': self.dataset,
            'variant': self.variant,
            'agent_type': self.agent_type,
            'prompt_id': self.prompt_id,
            'modification': self.modification,
            'passed': self.passed,
            'first_pass': self.first_pass,
            'score': self.score,
            'first_score': self.first_score,
            'timestamp': self.timestamp,
            'sample_id': self.sample_id,
            'sample_original_id': self.sample_original_id,
            'task_name': self.task_name,
            'dump': self.dump if dump else None,
        }


def parse_task_display_name(task_display_name: str) -> Dict[str, Optional[str]]:
    """Parse metadata from task display name.
    
    Expected patterns:
    - lcbdev_conflicting_canmod_p0_minimal
    - swebench_conflicting_minimal
    - swebench_oneoff_tools
    - lcbhard_original_canmod_p2_tools_iv
    
    Returns:
        Dictionary with parsed metadata: dataset, variant, modification, prompt_id, agent_type, etc.
    """
    parts = task_display_name.split('_')
    
    metadata = {
        'dataset': None,
        'variant': 'original',  # default
        'prompt_id': '',   # iv is also a prompt id
        'agent_type': None,
        'modification': '',
        # 'intervention': False,  # _iv suffix
    }
    
    if not parts:
        return metadata
    
    # First part is usually dataset
    metadata['dataset'] = parts[0]
    
    # Check for _iv suffix (human intervention)
    # if 'iv' in parts:
    #     metadata['intervention'] = True
    #     parts = parts[:-1]  # Remove _iv from processing
    
    # Last part is usually agent type
    for w in ['minimal', 'tools']:
        if w in parts:
            metadata['agent_type'] = w
            # remove from parts
            parts.remove(w)
            break
    
    # Look for variant keywords
    variant_keywords = ['original', 'oneoff', 'inverted', 'conflicting']
    for i, part in enumerate(parts[1:], 1):  # Skip dataset (first part)
        if part in variant_keywords:
            metadata['variant'] = part
            break
    
    # Look for prompt ID (p0, p1, p2, etc.)
    for part in parts:
        if re.match(r'^p\d+$', part):
            metadata['prompt_id'] = part
            break
    if '_be_' in task_display_name:
        metadata['prompt_id'] = 'be'
    if '_iv_' in task_display_name:
        metadata['prompt_id'] = 'iv'
    
    # Look for modification keywords
    modification_keywords = ['reset', 'hide']
    for kw in modification_keywords:
        if kw in task_display_name:
            metadata['modification'] += kw
    
    return metadata


def parse_filename_metadata(filename: str) -> Dict[str, Optional[str]]:
    """Parse metadata from eval filename.
    
    Eval files have format: {timestamp}_{task_display_name}_{random_id}.eval
    """
    # Remove .eval extension
    basename = filename.replace('.eval', '')
    parts = basename.split('_')
    
    if len(parts) < 3:
        logger.warning(f"Unexpected filename format: {filename}")
        return {'timestamp': None}
    
    # Extract timestamp (first part with date format)
    timestamp_match = re.match(r'(\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}\+\d{2}-\d{2})', basename)
    timestamp = timestamp_match.group(1) if timestamp_match else None
    
    return {'timestamp': timestamp}


def parse_eval_file(file_path: str) -> List[EvalResult]:
    """Parse a single eval file to extract evaluation results for all samples."""
    if read_eval_log is None:
        logger.error("inspect_ai not available. Cannot parse eval files.")
        return []
    
    results = []
    
    try:
        filename = os.path.basename(file_path)
        file_metadata = parse_filename_metadata(filename)
        
        # Read eval log with header only first for basic info
        eval_log = read_eval_log(file_path, header_only=True)
        
        # Extract basic metadata
        task_display_name = eval_log.eval.task_display_name
        model = eval_log.eval.model
        
        # Parse task display name for detailed metadata
        task_metadata = parse_task_display_name(task_display_name)
        
        # Get score information
        if eval_log.results and eval_log.results.scores:
            score_info = eval_log.results.scores[0]
            metrics = score_info.metrics
            
            # Extract overall pass rate
            if 'mean' in metrics:
                overall_score = metrics['mean'].value
            elif 'accuracy' in metrics:
                overall_score = metrics['accuracy'].value
            else:
                overall_score = 0.0
            
            num_scored_samples = score_info.scored_samples if hasattr(score_info, 'scored_samples') else 1
        else:
            overall_score = 0.0
            num_scored_samples = 1
        
        # Now read full log to get individual sample results
        try:
            eval_log_full = read_eval_log(file_path, header_only=False)
            samples = eval_log_full.samples if eval_log_full.samples else []
        except Exception as e:
            import traceback
            traceback.print_exc()
            logger.warning(f"Could not read full log for {filename}: {e}. Using header-only data.")
            samples = []
        
        # If we have individual samples, create results for each
        if samples:
            for i, sample in enumerate(samples):
                # Get individual sample score
                sample_passed = False
                sample_score = 0.0
                sample_first_pass = False
                sample_first_score = 0.0
                sample_dump = get_dump_of_everything(sample)
                
                
                if hasattr(sample, 'scores') and sample.scores:
                    # Look for specific scorer results
                    for scorer_name, score_result in sample.scores.items():
                        score_to_float = lambda v: float({'I': 0, 'C': 1}.get(v, v))
                        if hasattr(score_result, 'value'):
                            sample_score = score_to_float(score_result.value)
                            sample_passed = sample_score > 0.0
                            break
                    scores_got = []
                    for e in sample.events:
                        try:
                            scores_got.append(score_to_float(e.score.value))
                        except:
                            continue
                    # if len(scores_got) > 1:
                    #     print(f'{scores_got=} {sample_score=} {task_metadata.get("agent_type")=}')
                else:
                    # invalid score
                    continue
                
                if sample_passed:
                    if task_metadata.get("agent_type") == "tools":
                        sample_first_pass = scores_got[0] > 0.0 if len(scores_got) > 0 else False
                        sample_first_score = scores_got[0] if len(scores_got) > 0 else 0.0
                    elif task_metadata.get("agent_type") == "minimal":
                        assert hasattr(sample, 'messages') and sample.messages
                        assistant_messages = [m for m in sample.messages if hasattr(m, 'role') and m.role == 'assistant']
                        sample_first_pass = len(assistant_messages) <= 1
                        sample_first_score = sample_score if sample_first_pass else 0.0
                    else:
                        raise NotImplementedError(f'{task_metadata.get("agent_type")=} {task_display_name=}')
                
                # Create result for this sample
                result = EvalResult(
                    file_path=file_path,
                    filename=filename,
                    task_id=f"{task_display_name}_sample_{i}",
                    model=model,
                    dataset=task_metadata.get('dataset', 'unknown'),
                    variant=task_metadata.get('variant', 'original'),
                    agent_type=task_metadata.get('agent_type', 'unknown'),
                    prompt_id=task_metadata.get('prompt_id'),
                    modification=task_metadata.get('modification'),
                    passed=sample_passed,
                    first_pass=sample_first_pass,
                    score=sample_score,
                    first_score=sample_first_score,
                    timestamp=file_metadata.get('timestamp'),
                    sample_id=str(i),
                    sample_original_id=sample.id,
                    task_name=task_display_name,
                    dump=sample_dump,
                )
                results.append(result)
        else:
            print(f'No individual samples for {filename}: {model=}')
        # No individual samples, create single aggregate result
        result = EvalResult(
            file_path=file_path,
            filename=filename,
            task_id=task_display_name,
            model=model,
            dataset=task_metadata.get('dataset', 'unknown'),
            variant=task_metadata.get('variant', 'original'),
            agent_type=task_metadata.get('agent_type', 'unknown'),
            prompt_id=task_metadata.get('prompt_id'),
            modification=task_metadata.get('modification'),
            passed=None,#overall_score > 0.0,
            first_pass=None,#overall_score > 0.0,  # No way to determine first pass from aggregate
            score=None, #overall_score,
            first_score=None, #overall_score,
            timestamp=file_metadata.get('timestamp'),
            sample_id=None,
            sample_original_id=None,
            task_name=task_display_name,
            dump=None,
        )
        results.append(result)
        
        return results
        
    except Exception as e:
        import traceback
        traceback.print_exc()
        logger.error(f"Error parsing eval file {file_path}: {e}")
        return []


class DataLoader:
    """Data loader for ImpossibleBench evaluation results."""
    
    def __init__(self, n_workers: int = 4):
        """Initialize DataLoader.
        
        Args:
            n_workers: Number of worker processes for parallel loading
        """
        self.n_workers = n_workers
        self.results: List[EvalResult] = []
    
    def load_folder(self, folder_path: str, pattern: str = "*.eval", show_progress: bool = True, limit: Optional[int] = None) -> 'DataLoader':
        """Load evaluation results from a folder.
        
        Args:
            folder_path: Path to folder containing eval files
            pattern: Glob pattern for files to load
            
        Returns:
            Self for method chaining
        """
        folder_path = Path(folder_path)
        if not folder_path.exists():
            raise FileNotFoundError(f"Folder does not exist: {folder_path}")
        
        # Find all matching files
        eval_files = list(folder_path.glob(pattern))
        if limit:
            eval_files = eval_files[:limit]
        
        if not eval_files:
            logger.warning(f"No files found matching pattern {pattern} in {folder_path}")
            return self
        
        logger.info(f"Loading {len(eval_files)} files from {folder_path}")
        # Process files in parallel
        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            futures = {executor.submit(parse_eval_file, str(f)): f for f in eval_files}
            
            if show_progress:
                from tqdm import tqdm
                futures_iter = tqdm(as_completed(futures), total=len(futures), desc="Loading files")
            else:
                futures_iter = as_completed(futures)
                
            for future in futures_iter:
                file_results = future.result()  # This returns a list of EvalResult
                if file_results:  # file_results is a list
                    self.results.extend(file_results)
        
        logger.info(f"Successfully loaded {len(self.results)} evaluation results")
        return self
    
    def load_folders(self, folder_paths: List[str], pattern: str = "*.eval") -> 'DataLoader':
        """Load evaluation results from multiple folders.
        
        Args:
            folder_paths: List of folder paths
            pattern: Glob pattern for files to load
            
        Returns:
            Self for method chaining
        """
        for folder_path in folder_paths:
            self.load_folder(folder_path, pattern)
        return self
    
    def to_df(self) -> pd.DataFrame:
        """Convert loaded results to pandas DataFrame.
        
        Returns:
            DataFrame with evaluation results and metadata
        """
        if not self.results:
            logger.warning("No results loaded. DataFrame will be empty.")
            return pd.DataFrame()
        
        # Convert results to dictionaries
        data = [result.to_dict() for result in self.results]
        df = pd.DataFrame(data)
        
        # Add computed columns only for aggregated rows (sample_id=None)
        df['pass_rate'] = df.groupby(['model', 'dataset', 'variant', 'agent_type', 'prompt_id', 'modification'])['passed'].transform(lambda x: x.mean(skipna=True))
        df['first_pass_rate'] = df.groupby(['model', 'dataset', 'variant', 'agent_type', 'prompt_id', 'modification'])['first_pass'].transform(lambda x: x.mean(skipna=True))
        df['count'] = df.groupby(['model', 'dataset', 'variant', 'agent_type', 'prompt_id', 'modification'])['passed'].transform(lambda x: x.count())
        print('A')

        # Calculate standard errors for pass rates
        def calculate_standard_error(scores):
            """Calculate standard error for a group of binary scores."""
            # remove nans
            scores = scores.dropna()
            if len(scores) == 0:
                return float('nan')
            if len(scores) <= 1:
                return 0.0
            # For binary data, use sample standard deviation
            sample_std = scores.std(ddof=1)
            return sample_std / (len(scores) ** 0.5)
        
        # Group by the same keys and calculate standard errors
        pass_rate_err = df.groupby(['model', 'dataset', 'variant', 'agent_type', 'prompt_id', 'modification'])['passed'].transform(calculate_standard_error)
        first_pass_rate_err = df.groupby(['model', 'dataset', 'variant', 'agent_type', 'prompt_id', 'modification'])['first_pass'].transform(calculate_standard_error)
        print('B')
        df['pass_rate_err'] = pass_rate_err
        df['first_pass_rate_err'] = first_pass_rate_err
        print('C')
        # Make the pass rate and first pass rate nan for non-aggregated rows
        aggregated_mask = df['sample_id'].isna()
        df.loc[~aggregated_mask, 'pass_rate'] = float('nan')
        df.loc[~aggregated_mask, 'first_pass_rate'] = float('nan')
        df.loc[~aggregated_mask, 'pass_rate_err'] = float('nan')
        df.loc[~aggregated_mask, 'first_pass_rate_err'] = float('nan')
        return df
    
    def get_summary(self) -> Dict[str, float]:
        """Get summary statistics.
        
        Returns:
            Dictionary with summary statistics
        """
        if not self.results:
            return {}
        
        df = self.to_df()
        
        return {
            'total_evaluations': len(df),
            'overall_pass_rate': df['passed'].mean(),
            'overall_first_pass_rate': df['first_pass'].mean(),
            'unique_models': df['model'].nunique(),
            'unique_datasets': df['dataset'].nunique(),
            'unique_variants': df['variant'].nunique(),
        }
    
    def filter_by(self, **kwargs) -> 'DataLoader':
        """Filter results by metadata fields.
        
        Args:
            **kwargs: Field-value pairs to filter by
            
        Returns:
            New DataLoader with filtered results
        """
        filtered_results = []
        
        for result in self.results:
            match = True
            for field, value in kwargs.items():
                result_value = getattr(result, field, None)
                if result_value != value:
                    match = False
                    break
            
            if match:
                filtered_results.append(result)
        
        # Create new DataLoader with filtered results
        new_loader = DataLoader(self.n_workers)
        new_loader.results = filtered_results
        return new_loader