# utils/evaluation.py
"""
Evaluation utilities.

Helpers for result analysis, statistics, and report generation.
"""

import json
import re
from typing import List, Dict, Any, Optional
from collections import defaultdict, Counter

from .common import parse_multi_choice_response
from config.settings import DATASET_CONFIG


# Upper-case ASCII letters "A".."Z" that are accepted as option labels.
LETTER_SET = {chr(code) for code in range(ord("A"), ord("Z") + 1)}


def _ground_truth_letter(gt_index: int) -> Optional[str]:
    """Convert an index to a letter."""
    try:
        i = int(gt_index)
        if 0 <= i < 26:
            return chr(ord("A") + i)  # 0->A, 1->B, ...
    except Exception:
        pass
    return None


def _parse_prediction(pred: Any) -> tuple[Optional[str], Any]:
    """
    Parse a prediction and extract the option letter.

    Supported inputs:
        - str: tried, in order, with parse_multi_choice_response, then as a
          single letter after strip/upper, then as the first standalone
          letter found in the upper-cased text.
        - int / float: treated as a zero-based option index (0 -> "A").
          Values that cannot be converted to an int (NaN, infinity) yield
          no letter instead of raising.
        - dict: the value under 'final_answer', 'final' or 'answer' (first
          usable one wins) is parsed recursively.

    Returns:
        (option_letter, original_prediction); option_letter is None when no
        letter could be extracted.
    """
    if pred is None:
        return None, pred

    # String input
    if isinstance(pred, str):
        # Try multi-choice parsing
        parsed = parse_multi_choice_response(pred)
        if parsed is not None:
            return parsed, pred

        # Try extracting a single letter
        clean = pred.strip().upper()
        if len(clean) == 1 and clean in LETTER_SET:
            return clean, pred

        # Regex fallback: first standalone letter of the upper-cased text.
        # NOTE: because the text is upper-cased first, one-letter words such
        # as "a" or "I" also match here.
        matches = re.findall(r'\b([A-Z])\b', pred.upper())
        if matches:
            return matches[0], pred

    # Numeric input (index -> letter)
    elif isinstance(pred, (int, float)):
        # int() raises ValueError for NaN and OverflowError for infinity; a
        # single malformed prediction must not abort the whole evaluation.
        try:
            index = int(pred)
        except (ValueError, OverflowError):
            return None, pred
        return _ground_truth_letter(index), pred

    # Dict input (may contain final_answer)
    elif isinstance(pred, dict):
        final = None
        for key in ('final_answer', 'final', 'answer'):
            final = pred.get(key)
            # A numeric 0 is a valid answer (index of option "A") even though
            # it is falsy, so it must not be skipped like a missing value.
            is_zero_index = (
                isinstance(final, (int, float))
                and not isinstance(final, bool)
                and final == 0
            )
            if final or is_zero_index:
                break
        # When no key held a usable value, `final` is whatever 'answer'
        # holds (possibly None), which matches the previous fallback.
        if final is not None:
            return _parse_prediction(final)[0], pred

    return None, pred


def _is_correct(pred: Any, gt_index: int) -> bool:
    """
    Tell whether a prediction selects the ground-truth option.

    Both sides are reduced to an option letter and compared; if either side
    cannot be reduced to a letter the prediction counts as wrong.

    Args:
        pred: Prediction (any type).
        gt_index: Ground-truth answer index.

    Returns:
        bool: True only when both letters exist and are equal.
    """
    if gt_index is None:
        return False

    predicted = _parse_prediction(pred)[0]
    expected = _ground_truth_letter(gt_index)

    if predicted is None or expected is None:
        return False
    return predicted == expected


def evaluate_detailed_results(detailed_results: List[Dict]) -> Dict[str, Any]:
    """
    Evaluate detailed results and compute summary statistics.

    Entries that are not dicts are skipped and do not count as tasks.

    Args:
        detailed_results: A list of per-sample result dicts. Keys read from
            each item: 'ground_truth', 'prediction', 'iterations', 'times'
            (with 'sampling_time', 'agent_inference_time',
            'end_to_end_time') and, optionally, 'tokens' (with 'prompt',
            'output', 'total'). None is treated as an empty list.

    Returns:
        Dict: A dict with 'accuracy_summary', 'timing_summary' and
        'iteration_summary' sections, plus 'token_usage' when at least one
        item carries a 'tokens' key. 'total_tasks' is the number of dict
        items actually evaluated (0 for an empty run); averages and
        percentages are 0 in that case.
    """
    if detailed_results is None:
        detailed_results = []

    correct_count = 0
    total_iterations = 0
    iter_hist = defaultdict(int)
    extra_hist = defaultdict(int)

    sum_sampling_time = 0.0
    sum_agent_time = 0.0
    sum_e2e_time = 0.0

    # Token accounting (simplified)
    prompt_tokens = output_tokens = total_tokens = 0
    has_token_stats = False

    valid_items = 0
    for item in detailed_results:
        if not isinstance(item, dict):
            continue
        valid_items += 1

        # Correctness
        gt = item.get('ground_truth')
        pred = item.get('prediction')
        if _is_correct(pred, gt):
            correct_count += 1

        # Iteration stats: a missing, None or 0 value counts as one iteration.
        iters = int(item.get('iterations', 1) or 1)
        total_iterations += iters
        iter_hist[iters] += 1
        extra_hist[max(0, iters - 1)] += 1

        # Timing stats: missing or None entries contribute 0.
        times = item.get('times') or {}
        sum_sampling_time += float(times.get('sampling_time', 0) or 0)
        sum_agent_time += float(times.get('agent_inference_time', 0) or 0)
        sum_e2e_time += float(times.get('end_to_end_time', 0) or 0)

        # Token stats (normalized)
        if 'tokens' in item:
            has_token_stats = True
            tk = item['tokens'] or {}
            prompt_tokens += int(tk.get('prompt', 0) or 0)
            output_tokens += int(tk.get('output', 0) or 0)
            total_tokens += int(tk.get('total', 0) or 0)

    # Divide by at least 1 so an empty run yields zeros instead of raising
    # ZeroDivisionError; the reported task count stays the true count.
    denominator = max(1, valid_items)
    accuracy = (correct_count / denominator) * 100.0
    avg_iterations = total_iterations / denominator
    iterating_task_count = sum(c for k, c in extra_hist.items() if k >= 1)
    iterating_task_ratio = (iterating_task_count / denominator) * 100.0

    evaluation_result = {
        'accuracy_summary': {
            'correct_count': correct_count,
            'total_tasks': valid_items,
            'accuracy_percent': accuracy
        },
        'timing_summary': {
            'total_sampling_time': sum_sampling_time,
            'total_agent_inference_time': sum_agent_time,
            'total_end_to_end_time': sum_e2e_time,
            'average_sampling_time': sum_sampling_time / denominator,
            'average_agent_inference_time': sum_agent_time / denominator,
            'average_end_to_end_time': sum_e2e_time / denominator
        },
        'iteration_summary': {
            'total_iterations': total_iterations,
            'average_iterations_per_task': avg_iterations,
            'iterating_task_count': iterating_task_count,
            'iterating_task_ratio_percent': iterating_task_ratio,
            'iterations_histogram': {f'iter_{k}': v for k, v in sorted(iter_hist.items())},
            'extra_iteration_histogram': {f'extra_{k}': v for k, v in sorted(extra_hist.items())},
        }
    }

    # Add token stats if present
    if has_token_stats:
        evaluation_result['token_usage'] = {
            'prompt_tokens': prompt_tokens,
            'output_tokens': output_tokens,
            'total_tokens': total_tokens
        }

    return evaluation_result


def load_dataset(dataset_path: str, dataset_name: str) -> List[Dict]:
    """
    Read the label file of a configured dataset and return its JSON content.

    The label file name comes from DATASET_CONFIG[dataset_name]["label_file"]
    and is resolved relative to dataset_path.

    Args:
        dataset_path: Dataset root path.
        dataset_name: Dataset name key in config.

    Returns:
        List[Dict]: The parsed JSON content of the label file.

    Raises:
        ValueError: The dataset is not configured or has no label_file.
        FileNotFoundError: The resolved label file does not exist.
    """
    import os

    if dataset_name in DATASET_CONFIG:
        dataset_cfg = DATASET_CONFIG[dataset_name]
    else:
        raise ValueError(
            f"Unknown dataset: {dataset_name}. Configure it in config.settings.DATASET_CONFIG."
        )

    label_file = dataset_cfg.get("label_file")
    if not label_file:
        raise ValueError(f"Dataset {dataset_name} does not define label_file")

    label_path = os.path.join(dataset_path, label_file)
    if os.path.exists(label_path):
        with open(label_path, 'r', encoding='utf-8') as handle:
            tasks = json.load(handle)
        return tasks

    raise FileNotFoundError(f"Dataset label file not found: {label_path}")


def generate_summary_report(experiment_config: Dict, performance: Dict, timing: Dict, 
                           token_usage: Optional[Dict] = None) -> Dict[str, Any]:
    """
    Assemble the summary report payload from its already-computed parts.

    Args:
        experiment_config: Experiment configuration.
        performance: Performance metrics.
        timing: Timing metrics.
        token_usage: Optional token usage statistics; a 'resource_usage'
            section is added only when this is non-empty.

    Returns:
        Dict: Summary report dict.
    """
    # An empty or missing token_usage contributes no section at all.
    resource_section = {'resource_usage': {'token_usage': token_usage}} if token_usage else {}

    return {
        'experiment_config': experiment_config,
        'performance_metrics': performance,
        'timing_statistics': timing,
        **resource_section,
    }
