"""
Utility to aggregate metrics from checkpoint evaluations.

This module provides functions to read summary.txt files from model checkpoints
and compute mean and variance across multiple runs.
"""

from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd


def parse_summary_file(summary_path: Path) -> Dict[str, float]:
    """
    Read the numeric ``name: value`` entries of a summary.txt file.

    Blank lines and lines containing one of the header markers
    ('Dataset mean metrics', 'Total samples', 'Converged samples') are
    ignored, as are lines without a colon and entries whose value cannot be
    converted to a float. If a name occurs more than once, the last
    occurrence wins.

    Args:
        summary_path: Path to summary.txt file

    Returns:
        Dictionary mapping metric names to their values
    """
    header_markers = ('Dataset mean metrics', 'Total samples', 'Converged samples')
    parsed = {}

    with open(summary_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text or any(marker in text for marker in header_markers):
                continue

            # Entries look like "  total_energy_mae_mEh: 2.008968"; only the
            # first colon separates the name from the value.
            name, separator, raw_value = text.partition(':')
            if not separator:
                continue

            try:
                parsed[name.strip()] = float(raw_value.strip())
            except ValueError:
                # Non-numeric values are not metrics
                continue

    return parsed


def parse_csv_file(csv_path: Path) -> Dict[str, float]:
    """
    Load a results.csv file and average every numeric column.

    Only columns whose dtype is 'float64' or 'int64' are considered. Any
    error while reading or averaging is reported on stdout and results in an
    empty dictionary instead of an exception.

    Args:
        csv_path: Path to results.csv file

    Returns:
        Dictionary mapping metric names to their mean values
    """
    numeric_dtypes = ['float64', 'int64']

    try:
        frame = pd.read_csv(csv_path)
        return {
            column: float(frame[column].mean())
            for column in frame.columns
            if frame[column].dtype in numeric_dtypes
        }
    except Exception as err:
        print(f'Error reading CSV {csv_path}: {err}')
        return {}


def aggregate_checkpoint_metrics(
    checkpoint_root: str,
    model_name: str,
    checkpoint_indices: Tuple[int, ...],
    use_csv: bool = True,
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """
    Aggregate metrics across multiple checkpoint runs.

    For every checkpoint index the evaluation directory is looked up below
    ``checkpoint_root / model_name`` under one of two folder names, tried in
    this order: ``<base_model>_<idx>`` and ``<idx>``, where ``<base_model>``
    is the part of ``model_name`` before the first '/'. Checkpoints whose
    directory or result file is missing are skipped with a warning on stdout.

    Args:
        checkpoint_root: Root directory containing evaluations (e.g., '../evaluations/1000')
        model_name: Name of the model (e.g., 'NNmGGA_scan_qm5'); may be a nested
                    path such as 'NNmGGA_scan_qm5/ood_basis_set'
        checkpoint_indices: Tuple of checkpoint indices to aggregate (e.g., (18, 20, 22))
        use_csv: If True, read from results.csv for better precision and fall
                 back to summary.txt when it is missing (default: True)

    Returns:
        Tuple of (means, stds) where each is a dict mapping metric names to values.
        A metric is reported if at least one checkpoint provides it; its mean and
        (population, ddof=0) standard deviation are taken over the checkpoints
        that provide it.

    Raises:
        ValueError: If no checkpoint could be read at all.
    """
    model_dir = Path(checkpoint_root) / model_name
    # For nested names the checkpoint folders are prefixed with the base model
    # name only; for plain names this is simply the model name itself.
    base_model = model_name.split('/')[0]

    # Collect metrics from all checkpoints
    all_metrics = []
    for idx in checkpoint_indices:
        candidates = (model_dir / f'{base_model}_{idx}', model_dir / str(idx))
        base_path = next((path for path in candidates if path.exists()), None)
        if base_path is None:
            print(
                f'Warning: Could not find checkpoint directory for {model_name} checkpoint {idx}'
            )
            continue

        csv_path = base_path / 'results.csv'
        summary_path = base_path / 'summary.txt'
        if use_csv and csv_path.exists():
            # CSV holds per-sample values and therefore gives better precision
            metrics = parse_csv_file(csv_path)
        elif summary_path.exists():
            metrics = parse_summary_file(summary_path)
        else:
            expected = 'results.csv or summary.txt' if use_csv else 'summary.txt'
            print(f'Warning: No {expected} found for {model_name} checkpoint {idx}')
            continue

        all_metrics.append(metrics)

    if not all_metrics:
        raise ValueError(
            f'No valid checkpoints found for {model_name} with indices {checkpoint_indices}'
        )

    # Union of metric names in first-seen order, so that a metric missing from
    # the first checkpoint is not silently dropped.
    metric_names = list(
        dict.fromkeys(name for metrics in all_metrics for name in metrics)
    )

    means = {}
    stds = {}
    for metric_name in metric_names:
        values = [m[metric_name] for m in all_metrics if metric_name in m]
        means[metric_name] = float(np.mean(values))
        stds[metric_name] = float(np.std(values))

    return means, stds


def aggregate_model_loss_combinations(
    checkpoint_root: str,
    model_batches: List[List[Tuple[str, Any, str, Tuple[int, ...]]]],
    loss_labels: List[str] | None = None,
    use_csv: bool = True,
) -> Tuple[Dict[str, Dict[str, List[float]]], Dict[str, Dict[str, List[float]]]]:
    """
    Aggregate metrics for all model + loss combinations.

    Args:
        checkpoint_root: Root directory containing evaluations
        model_batches: List of model batches, where each batch is a list of tuples:
                       (model_name, color, label, checkpoint_indices)
        loss_labels: Optional list of loss configuration labels (e.g., ['E_tot', 'E_rho', 'E_xc', 'grad'])
                    If not provided, will use indices 0, 1, 2, 3
        use_csv: If True, read from results.csv for better precision (default: True)

    Returns:
        Tuple of (all_means, all_stds) where:
        - all_means: Dict[metric_name][model_idx] = [loss_0_value, loss_1_value, ...]
        - all_stds: Dict[metric_name][model_idx] = [loss_0_std, loss_1_std, ...]
    """
    if loss_labels is None:
        loss_labels = [f'loss_{i}' for i in range(len(model_batches[0]))]

    # Extract model names from first entry in each batch
    model_names = [
        batch[0][0].split('_')[0] for batch in model_batches
    ]  # e.g., 'NNmGGA', 'XCdiff', 'Skala'

    # Dictionary to store aggregated results
    # Structure: all_means[metric_name][model_name] = [value_for_loss_0, value_for_loss_1, ...]
    all_means = {}
    all_stds = {}

    for model_idx, batch in enumerate(model_batches):
        model_display_name = model_names[model_idx]

        for loss_idx, (model_name, color, label, checkpoint_indices) in enumerate(batch):
            print(
                f'Processing {model_name} with loss config {loss_idx} ({label}), checkpoints {checkpoint_indices}'
            )

            means, stds = aggregate_checkpoint_metrics(
                checkpoint_root, model_name, checkpoint_indices, use_csv=use_csv
            )

            # Add to aggregated results
            for metric_name, mean_value in means.items():
                if metric_name not in all_means:
                    all_means[metric_name] = {mn: [] for mn in model_names}
                    all_stds[metric_name] = {mn: [] for mn in model_names}

                all_means[metric_name][model_display_name].append(mean_value)
                all_stds[metric_name][model_display_name].append(stds[metric_name])

    return all_means, all_stds


def format_for_visualization(
    all_means: Dict[str, Dict[str, List[float]]],
    all_stds: Dict[str, Dict[str, List[float]]],
    metrics_of_interest: List[str],
) -> Tuple[Dict[str, Dict[str, List[float]]], Dict[str, Dict[str, List[float]]]]:
    """
    Select the requested metrics from the aggregated means and stds.

    Metrics that are absent from ``all_means`` are reported with a warning on
    stdout and left out of both results. The selected entries are the same
    objects as in the inputs (no copies are made).

    Args:
        all_means: Output from aggregate_model_loss_combinations
        all_stds: Output from aggregate_model_loss_combinations
        metrics_of_interest: List of metric names to include (e.g., ['total_energy_mae_mEh', ...])

    Returns:
        Tuple of (data, errors) in the format expected by visualization notebooks
    """
    data, errors = {}, {}

    for name in metrics_of_interest:
        if name not in all_means:
            print(f"Warning: Metric '{name}' not found in aggregated data")
            continue

        data[name] = all_means[name]
        errors[name] = all_stds[name]

    return data, errors


def aggregate_b3lyp_checkpoint_metrics(
    checkpoint_root: str,
    model_name: str,
    loss_term: str,
    checkpoint_id: int | Tuple[int, ...],
    use_csv: bool = True,
) -> Tuple[Dict[str, float], Dict[str, float]] | Dict[str, float]:
    """
    Aggregate metrics for B3LYP models with nested loss term structure.

    Args:
        checkpoint_root: Root directory containing evaluations (e.g., '../evaluations/1000')
        model_name: Name of the model (e.g., 'EGXC_b3lyp_qm7')
        loss_term: Loss term subdirectory (e.g., 'baseline', 'gradient', 'grad_and_hessian')
        checkpoint_id: Single checkpoint integer or tuple of checkpoint integers
        use_csv: If True, read from results.csv for better precision (default: True)

    Returns:
        If single checkpoint_id: Dict mapping metric names to values
        If tuple of checkpoint_ids: Tuple of (means, stds) where each is a dict mapping metric names to values
    """
    checkpoint_root_path = Path(checkpoint_root)

    # Normalize loss term names (handle variations)
    loss_term_normalized = loss_term
    if loss_term == 'grad':
        loss_term_normalized = 'gradient'
    elif loss_term == 'grad_and_hess':
        loss_term_normalized = 'grad_and_hessian'

    # Handle single checkpoint vs multiple checkpoints
    if isinstance(checkpoint_id, tuple):
        checkpoint_indices = checkpoint_id
        return_multiple = True
    else:
        checkpoint_indices = (checkpoint_id,)
        return_multiple = False

    # Collect metrics from all checkpoints
    all_metrics = []
    for idx in checkpoint_indices:
        # For baseline, results are directly in the loss_term folder
        if loss_term_normalized == 'baseline':
            base_path = checkpoint_root_path / model_name / loss_term_normalized
        else:
            # For other loss terms, results are in subdirectories
            # Handle naming variations: EGXC_b3lyp_qm7_58 vs NNmGGA2_b3lyp_qm7_58
            base_model_name = model_name.replace('_b3lyp_qm7', '')
            # Try different naming patterns
            checkpoint_folder1 = f'{model_name}_{idx}'
            checkpoint_folder2 = f'{base_model_name}2_b3lyp_qm7_{idx}'
            checkpoint_folder3 = f'{base_model_name}2_b3lyp_qm7_correction_{idx}'

            base_path = None
            for folder_name in [
                checkpoint_folder1,
                checkpoint_folder2,
                checkpoint_folder3,
            ]:
                candidate_path = (
                    checkpoint_root_path / model_name / loss_term_normalized / folder_name
                )
                if candidate_path.exists():
                    base_path = candidate_path
                    break

            if base_path is None:
                # Try alternative loss term names
                for alt_loss_term in [
                    loss_term,
                    'gradient' if loss_term == 'grad' else 'grad',
                    'grad_and_hessian'
                    if loss_term == 'grad_and_hess'
                    else 'grad_and_hess',
                ]:
                    for folder_name in [
                        checkpoint_folder1,
                        checkpoint_folder2,
                        checkpoint_folder3,
                    ]:
                        candidate_path = (
                            checkpoint_root_path
                            / model_name
                            / alt_loss_term
                            / folder_name
                        )
                        if candidate_path.exists():
                            base_path = candidate_path
                            break
                    if base_path:
                        break

            if base_path is None:
                print(
                    f'Warning: Could not find checkpoint directory for {model_name}/{loss_term_normalized} checkpoint {idx}'
                )
                continue

        # Try to read from CSV first for better precision
        if use_csv:
            csv_path = base_path / 'results.csv'
            if csv_path.exists():
                metrics = parse_csv_file(csv_path)
            else:
                # Fall back to summary.txt
                summary_path = base_path / 'summary.txt'
                if summary_path.exists():
                    metrics = parse_summary_file(summary_path)
                else:
                    print(
                        f'Warning: No results.csv or summary.txt found for {model_name}/{loss_term_normalized} checkpoint {idx}'
                    )
                    continue
        else:
            summary_path = base_path / 'summary.txt'
            if summary_path.exists():
                metrics = parse_summary_file(summary_path)
            else:
                print(
                    f'Warning: No summary.txt found for {model_name}/{loss_term_normalized} checkpoint {idx}'
                )
                continue

        all_metrics.append(metrics)

    if not all_metrics:
        raise ValueError(
            f'No valid checkpoints found for {model_name}/{loss_term_normalized} with checkpoint_id {checkpoint_id}'
        )

    # If single checkpoint, return metrics dict directly
    if not return_multiple:
        return all_metrics[0]

    # Compute mean and std for each metric
    metric_names = all_metrics[0].keys()
    means = {}
    stds = {}

    for metric_name in metric_names:
        values = [m[metric_name] for m in all_metrics if metric_name in m]
        means[metric_name] = float(np.mean(values))
        stds[metric_name] = float(np.std(values))

    return means, stds


def aggregate_b3lyp_model_loss_combinations(
    checkpoint_root: str,
    model_batches: List[List[Tuple[str, str, str, int | Tuple[int, ...]]]],
    use_csv: bool = True,
) -> Tuple[Dict[str, Dict[str, List[float]]], Dict[str, Dict[str, List[float]]]]:
    """
    Aggregate metrics for B3LYP model + loss term combinations.

    Args:
        checkpoint_root: Root directory containing evaluations
        model_batches: List of model batches, where each batch is a list of tuples:
                       (model_name, loss_term, label, checkpoint_id)
                       - model_name: e.g., 'EGXC_b3lyp_qm7'
                       - loss_term: e.g., 'baseline', 'gradient', 'grad_and_hessian'
                       - label: Display label for the loss term (e.g., r'$E_\\mathrm{tot}$')
                       - checkpoint_id: Integer or tuple of checkpoint IDs
        use_csv: If True, read from results.csv for better precision (default: True)

    Returns:
        Tuple of (all_means, all_stds) where:
        - all_means: Dict[metric_name][model_display_name] = [loss_0_value, loss_1_value, ...]
        - all_stds: Dict[metric_name][model_display_name] = [loss_0_std, loss_1_std, ...]
    """
    # Extract model display names from first entry in each batch
    # e.g., 'EGXC_b3lyp_qm7' -> 'EGXC', 'NNmGGA_b3lyp_qm7' -> 'NNmGGA'
    model_display_names = []
    for batch in model_batches:
        model_name = batch[0][0]
        # Extract base model name (before _b3lyp)
        if '_b3lyp' in model_name:
            display_name = model_name.split('_b3lyp')[0]
        else:
            # Fallback: take first part
            display_name = model_name.split('_')[0]
        model_display_names.append(display_name)

    # Dictionary to store aggregated results
    # Structure: all_means[metric_name][model_display_name] = [value_for_loss_0, value_for_loss_1, ...]
    all_means = {}
    all_stds = {}

    for model_idx, batch in enumerate(model_batches):
        model_display_name = model_display_names[model_idx]

        for loss_idx, (model_name, loss_term, label, checkpoint_id) in enumerate(batch):
            print(
                f'Processing {model_name}/{loss_term} ({label}), checkpoint {checkpoint_id}'
            )

            result = aggregate_b3lyp_checkpoint_metrics(
                checkpoint_root, model_name, loss_term, checkpoint_id, use_csv=use_csv
            )

            # Handle both single checkpoint (dict) and multiple checkpoints (tuple)
            if isinstance(result, tuple):
                means, stds = result
            else:
                # Single checkpoint - use as mean, std is zero
                means = result
                stds = {k: 0.0 for k in means.keys()}

            # Add to aggregated results
            for metric_name, mean_value in means.items():
                if metric_name not in all_means:
                    all_means[metric_name] = {mn: [] for mn in model_display_names}
                    all_stds[metric_name] = {mn: [] for mn in model_display_names}

                all_means[metric_name][model_display_name].append(mean_value)
                all_stds[metric_name][model_display_name].append(stds[metric_name])

    return all_means, all_stds
