"""
Common evaluation utilities for forward/inverse performance metrics.

This module provides shared functions used across DTLZ, Unifoil, and Gas Turbine
evaluation scripts. It consolidates:
- Forward and inverse pass logic for INN, Diag-CFM, and vanilla CFM models
- Metric computation (MSE, round-trip error, design diversity)
- Summary statistics, ensemble validation, and result saving
- Diversity-vs-epsilon analysis and plotting
"""

import torch
import numpy as np
from pathlib import Path
import json
from typing import Dict, List, Optional, Callable

from uq_diagcfm.solvers import euler_method, euler_method_with_conditioning


# =============================================================================
# INN-specific forward and inverse passes
# =============================================================================


def forward_pass_inn(
    model: torch.nn.Module,
    x: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Run an INN in the forward direction and keep only the label prediction.

    Args:
        model: INN whose ``forward(x)`` returns a ``(y, z)`` pair.
        x: Design parameters tensor (batch_size, P).
        device: Computation device. Not read by this function.

    Returns:
        The ``y`` half of the pair returned by ``model.forward(x)``
        (batch_size, L); the latent half is discarded.
    """
    # Two-element unpacking keeps the requirement that forward() yields a pair.
    predicted_labels, _latent = model.forward(x)
    return predicted_labels


def inverse_pass_inn(
    model: torch.nn.Module,
    y: torch.Tensor,
    num_design_params: int,
    device: torch.device,
    num_samples: int = 10,
) -> torch.Tensor:
    """Generate designs for target labels with an INN: y -> x_gen.

    For every requested sample a fresh latent ``z`` is drawn from a standard
    Gaussian and passed to ``model.inverse(y, z)``.

    Args:
        model: INN exposing ``latent_dim`` and ``inverse(y, z) -> x``.
        y: Target labels tensor (batch_size, L).
        num_design_params: Design dimension P. Not read by this function.
        device: Device on which the latent noise is created.
        num_samples: Number of designs to generate per target.

    Returns:
        Generated designs stacked along a new leading axis
        (num_samples, batch_size, P).
    """
    n_targets = y.shape[0]
    z_dim = model.latent_dim

    # One independent latent draw per generated design set.
    draws = [
        model.inverse(y, torch.randn(n_targets, z_dim, device=device))
        for _ in range(num_samples)
    ]
    return torch.stack(draws, dim=0)


def forward_pass_conditional_inn(
    model: torch.nn.Module,
    x: torch.Tensor,
    c: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Run a conditional INN forward and keep only the label prediction.

    Args:
        model: Conditional INN whose ``forward(x, c)`` returns a ``(y, z)`` pair.
        x: Design parameters tensor (batch_size, P).
        c: Conditioning tensor (batch_size, cond_dim).
        device: Computation device. Not read by this function.

    Returns:
        The ``y`` half of the pair returned by ``model.forward(x, c)``
        (batch_size, L); the latent half is discarded.
    """
    # Two-element unpacking keeps the requirement that forward() yields a pair.
    predicted_labels, _latent = model.forward(x, c)
    return predicted_labels


def inverse_pass_conditional_inn(
    model: torch.nn.Module,
    y: torch.Tensor,
    c: torch.Tensor,
    num_design_params: int,
    device: torch.device,
    num_samples: int = 10,
) -> torch.Tensor:
    """Generate designs with a conditional INN: (y, c) -> x_gen.

    For every requested sample a fresh latent ``z`` is drawn from a standard
    Gaussian and passed to ``model.inverse(y, z, c)``.

    Args:
        model: Conditional INN exposing ``latent_dim`` and
            ``inverse(y, z, c) -> x``.
        y: Target labels tensor (batch_size, L).
        c: Conditioning tensor (batch_size, cond_dim).
        num_design_params: Design dimension P. Not read by this function.
        device: Device on which the latent noise is created.
        num_samples: Number of designs to generate per target.

    Returns:
        Generated designs stacked along a new leading axis
        (num_samples, batch_size, P).
    """
    n_targets = y.shape[0]
    z_dim = model.latent_dim

    # One independent latent draw per generated design set.
    draws = [
        model.inverse(y, torch.randn(n_targets, z_dim, device=device), c)
        for _ in range(num_samples)
    ]
    return torch.stack(draws, dim=0)


# =============================================================================
# Flow matching forward and inverse passes
# =============================================================================


def forward_pass(
    model: torch.nn.Module,
    x: torch.Tensor,
    num_labels: int,
    device: torch.device,
    diag_cfm: bool = True,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
) -> torch.Tensor:
    """
    Predict labels from designs by integrating the flow from t=0 to t=1.

    Diag-CFM: the state at t=0 is ``[0_L, x]`` (L zeros prepended to the
    design) and the state at t=1 is ``[y, z]``.
    Vanilla CFM: the design ``x`` itself is the state at t=0.
    In both cases the first L coordinates of the integrated state are
    returned as the prediction.

    Args:
        model: Trained model handed to the Euler solver.
        x: Design parameters tensor (batch_size, P).
        num_labels: Number of labels L.
        device: Device on which the zero padding is created (Diag-CFM only).
        diag_cfm: Whether this is a Diag-CFM model.
        conditioning: Optional conditioning tensor (batch_size, C). When given,
            the conditioning-aware solver is used.
        steps: Number of integration steps.

    Returns:
        Predicted labels tensor (batch_size, L).
    """
    batch_size = x.shape[0]

    # Build the t=0 state.
    if diag_cfm:
        label_slots = torch.zeros(batch_size, num_labels, device=device)
        initial_state = torch.cat([label_slots, x], dim=1)
    else:
        initial_state = x

    # Integrate t: 0 -> 1 with the solver matching the conditioning setting.
    if conditioning is None:
        final_state = euler_method(
            model, initial_state, start_t=0, end_t=1, steps=steps
        )
    else:
        final_state = euler_method_with_conditioning(
            model, initial_state, conditioning, start_t=0, end_t=1, steps=steps
        )

    # Labels occupy the leading L coordinates of the final state.
    return final_state[:, :num_labels]


def inverse_pass(
    model: torch.nn.Module,
    y: torch.Tensor,
    num_design_params: int,
    num_labels: int,
    device: torch.device,
    num_samples: int = 10,
    diag_cfm: bool = True,
    conditioning: Optional[torch.Tensor] = None,
    steps: int = 30,
    noise_distribution: str = "uniform",
) -> torch.Tensor:
    """
    Generate designs for target labels by integrating the flow from t=1 to t=0.

    Diag-CFM: the t=1 state is ``[y, z]`` with P noise coordinates; after
    backward integration the state is ``[0_L, x]`` and the last P coordinates
    are returned as the design.
    Vanilla CFM: the t=1 state is ``[y, z[:P-L]]`` and the whole integrated
    state is returned as the design.

    Args:
        model: Trained model handed to the Euler solver.
        y: Target labels tensor (batch_size, L).
        num_design_params: Design dimension P.
        num_labels: Number of labels L.
        device: Device on which the noise is created.
        num_samples: Number of designs to generate per target.
        diag_cfm: Whether this is a Diag-CFM model.
        conditioning: Optional conditioning tensor (batch_size, C). When given,
            the conditioning-aware solver is used.
        steps: Number of integration steps.
        noise_distribution: "uniform" draws with ``torch.rand``; any other
            value (documented: "normal") draws with ``torch.randn``.

    Returns:
        Generated designs tensor (num_samples, batch_size, P).
    """
    n_targets = y.shape[0]
    draw_noise = torch.rand if noise_distribution == "uniform" else torch.randn

    designs = []
    for _ in range(num_samples):
        # Always draw P noise columns; vanilla CFM keeps only the first P-L.
        noise = draw_noise(n_targets, num_design_params, device=device)
        if not diag_cfm:
            noise = noise[:, : num_design_params - num_labels]

        # State at t=1: target labels followed by noise.
        terminal_state = torch.cat([y, noise], dim=1)

        # Integrate t: 1 -> 0.
        if conditioning is None:
            start_state = euler_method(
                model, terminal_state, start_t=1, end_t=0, steps=steps
            )
        else:
            start_state = euler_method_with_conditioning(
                model, terminal_state, conditioning, start_t=1, end_t=0, steps=steps
            )

        # Diag-CFM ends at [0_L, x]: drop the leading L label slots.
        designs.append(start_state[:, num_labels:] if diag_cfm else start_state)

    # Stack: (num_samples, batch_size, P)
    return torch.stack(designs, dim=0)


def compute_mse_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """
    Compute mean squared error between true and predicted labels.

    Args:
        y_true: True labels (N, L).
        y_pred: Predicted labels (N, L).

    Returns:
        Dictionary with "forward_mse" (mean over all entries, as a float) and
        "forward_per_label_mse" (mean over axis 0, converted with ``tolist``).
    """
    squared_error = (y_true - y_pred) ** 2
    overall = float(np.mean(squared_error))
    per_label = np.mean(squared_error, axis=0).tolist()
    return {"forward_mse": overall, "forward_per_label_mse": per_label}


def compute_design_diversity(
    x_samples: torch.Tensor,
    roundtrip_errors: Optional[torch.Tensor] = None,
    epsilon: Optional[float] = None,
) -> Dict[str, np.ndarray]:
    """
    Compute design diversity (variance) across multiple samples.

    Diversity for one target is the variance of each design coordinate across
    the generated samples (``torch.var`` default estimator), averaged over the
    P coordinates.

    When both ``roundtrip_errors`` and ``epsilon`` are given, only designs
    whose round-trip error is strictly below ``epsilon`` take part, which
    prevents inflated diversity scores from inaccurate designs.

    A target with fewer than two usable designs gets diversity 0.0 in both the
    filtered and the unfiltered path, because a variance is undefined there
    (``torch.var`` would return NaN and warn).

    Args:
        x_samples: Generated designs (num_samples, batch_size, P).
        roundtrip_errors: Optional per-sample round-trip errors
            (num_samples, batch_size). Only used together with ``epsilon``.
        epsilon: Optional error threshold. Designs with
            roundtrip_error >= epsilon are excluded from the computation.

    Returns:
        Dictionary with:
        - "diversity_var": numpy array (batch_size,) of per-target diversity.
        - "num_valid_samples": numpy array (batch_size,) with the number of
          samples that passed the epsilon filter for each target
          (``num_samples`` for every target when no filtering is applied).
    """
    # Unpacking three values also enforces the expected 3-D layout.
    num_samples, batch_size, _ = x_samples.shape

    # Filtering needs both the errors and the threshold.
    if roundtrip_errors is None or epsilon is None:
        if num_samples < 2:
            # Match the filtered path: too few samples means zero diversity,
            # not NaN.
            diversity = torch.zeros(batch_size, dtype=x_samples.dtype)
        else:
            # Variance over the sample axis, then mean over design coordinates.
            diversity = torch.var(x_samples, dim=0).mean(dim=1)  # (batch_size,)

        return {
            "diversity_var": diversity.cpu().numpy(),
            "num_valid_samples": np.full(batch_size, num_samples),
        }

    diversity_var = []
    num_valid = []
    for b in range(batch_size):
        # Samples for target b whose round-trip error is below the threshold.
        valid_mask = roundtrip_errors[:, b] < epsilon
        valid_count = valid_mask.sum().item()
        num_valid.append(valid_count)

        if valid_count < 2:
            # Not enough valid samples for a variance.
            diversity_var.append(0.0)
            continue

        valid_samples = x_samples[valid_mask, b, :]  # (valid_count, P)
        diversity_var.append(torch.var(valid_samples, dim=0).mean().item())

    return {
        "diversity_var": np.array(diversity_var),
        "num_valid_samples": np.array(num_valid),
    }


def compute_roundtrip_errors(
    y_target: torch.Tensor,
    x_samples: torch.Tensor,
    forward_fn: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
    """
    Compute round-trip errors for the path y -> x_gen -> y_pred.

    Each slice ``x_samples[i]`` is mapped back to labels with ``forward_fn``;
    the squared difference to ``y_target`` is averaged over dim 1, and the
    resulting per-target errors are then averaged over the sample axis.

    Args:
        y_target: Target labels (batch_size, L).
        x_samples: Generated designs (num_samples, batch_size, P).
        forward_fn: Function that maps designs to labels.

    Returns:
        Round-trip error per target, averaged across samples, as a numpy
        array of shape (batch_size,).
    """
    per_draw_errors = [
        torch.mean((y_target - forward_fn(x_samples[i])) ** 2, dim=1)
        for i in range(x_samples.shape[0])
    ]
    # (num_samples, batch_size) -> (batch_size,)
    return torch.stack(per_draw_errors, dim=0).mean(dim=0).cpu().numpy()


def aggregate_diversity_stats(
    diversities_var: np.ndarray,
) -> Dict[str, float]:
    """
    Reduce per-target diversity values to their mean and standard deviation.

    Args:
        diversities_var: Variance-based diversity for each sample.

    Returns:
        Dictionary with "inverse_design_diversity_var_mean" and
        "inverse_design_diversity_var_std" (``np.std`` default), both floats.
    """
    mean_value = float(np.mean(diversities_var))
    std_value = float(np.std(diversities_var))
    return {
        "inverse_design_diversity_var_mean": mean_value,
        "inverse_design_diversity_var_std": std_value,
    }


def get_model_type(run_info: Dict) -> str:
    """
    Classify a run as "INN", "Diag-CFM", or "CFM" from its run_info.

    A run is "INN" when ``run_info["model_type"] == "INN"``. Otherwise the
    "diag_cfm" entry decides: truthy (or missing, which defaults to True)
    means "Diag-CFM", falsy means "CFM".

    Args:
        run_info: Dictionary containing run information.

    Returns:
        Model type string: "INN", "Diag-CFM", or "CFM".
    """
    if run_info.get("model_type", "") == "INN":
        return "INN"
    return "Diag-CFM" if run_info.get("diag_cfm", True) else "CFM"


def validate_ensemble_parameter_counts(run_infos: List[Dict]) -> None:
    """
    Check that models of the same type all report the same parameter count.

    Runs are grouped by ``get_model_type``. Groups with at most one run are
    skipped. For every other group, each run must carry
    "number_of_parameters" and all values must be equal. Problems from all
    groups are collected and reported together in one exception.

    Args:
        run_infos: List of run_info dictionaries from loaded models.

    Raises:
        ValueError: If, within a model type, some run lacks
            "number_of_parameters" or the parameter counts differ.
    """
    by_type: Dict[str, List[Dict]] = {"Diag-CFM": [], "CFM": [], "INN": []}
    for info in run_infos:
        by_type[get_model_type(info)].append(info)

    problems = []
    for model_type, members in by_type.items():
        # Nothing to compare with zero or one model.
        if len(members) < 2:
            continue

        counts = [member.get("number_of_parameters") for member in members]

        if any(count is None for count in counts):
            problems.append(
                f"{model_type}: Some models missing 'number_of_parameters' in run_info"
            )
            continue

        if len(set(counts)) <= 1:
            continue

        # Map each distinct parameter count to the runs that report it.
        runs_by_count = {}
        for member in members:
            runs_by_count.setdefault(
                member.get("number_of_parameters"), []
            ).append(member.get("run_path", "unknown"))

        breakdown = "; ".join(
            f"{count:,} params: {len(runs)} model(s)"
            for count, runs in sorted(runs_by_count.items())
        )
        problems.append(
            f"{model_type}: Found models with different parameter counts ({breakdown})"
        )

    if not problems:
        return

    header = "Ensemble validation failed - models of same type must have identical parameter counts:\n"
    raise ValueError(header + "\n".join(f"  - {p}" for p in problems))


def compute_summary_statistics(
    results: List[Dict], metric_keys: Optional[List[str]] = None
) -> Dict:
    """
    Summarize metrics per model type (Diag-CFM, CFM, INN).

    A result is "INN" when its "model_type" equals "INN"; otherwise a truthy
    or missing "diag_cfm" means "Diag-CFM" and a falsy one means "CFM".
    Model types without any result are left out of the summary, and a metric
    is left out for a type when none of its results contains that key.

    Args:
        results: List of result dictionaries.
        metric_keys: List of metric keys to summarize. If None, uses
            "forward_mse", "inverse_roundtrip_error_mean" and
            "inverse_design_diversity_var_mean".

    Returns:
        Mapping model type -> metric -> {"mean", "std", "n"}.
    """
    keys = metric_keys
    if keys is None:
        keys = [
            "forward_mse",
            "inverse_roundtrip_error_mean",
            "inverse_design_diversity_var_mean",
        ]

    # Bucket results by model type; insertion order fixes the summary order.
    buckets = {"Diag-CFM": [], "CFM": [], "INN": []}
    for entry in results:
        if entry.get("model_type", "") == "INN":
            label = "INN"
        elif entry.get("diag_cfm", True):
            label = "Diag-CFM"
        else:
            label = "CFM"
        buckets[label].append(entry)

    summary = {}
    for label, entries in buckets.items():
        if not entries:
            continue

        stats = {}
        for key in keys:
            observed = [entry[key] for entry in entries if key in entry]
            if not observed:
                continue
            stats[key] = {
                "mean": float(np.mean(observed)),
                "std": float(np.std(observed)),
                "n": len(observed),
            }
        summary[label] = stats

    return summary


def save_results(
    results: List[Dict],
    summary: Dict,
    output_path: Path,
    extra_fields: Optional[Dict] = None,
):
    """
    Save evaluation results to a JSON file.

    The written object has the keys "all_results" and "summary", plus any
    ``extra_fields`` merged in at the top level (an extra field with the same
    name replaces the built-in one).

    Per-model results are filtered one level deep: only entries whose value is
    an int, float, str, bool, list, dict, or None are kept; everything else
    (e.g. tensors or NumPy arrays) is dropped.

    The payload is serialized before the file is opened, so a serialization
    error cannot leave a truncated file behind. Missing parent directories of
    ``output_path`` are created.

    Args:
        results: List of per-model result dictionaries.
        summary: Summary statistics dictionary.
        output_path: Path to save JSON file.
        extra_fields: Optional extra fields to include in output.

    Raises:
        TypeError: If the remaining payload still contains a value that
            ``json`` cannot serialize (for example inside a nested list).
    """
    json_types = (int, float, str, bool, list, dict, type(None))

    # Keep only top-level values of JSON-compatible types.
    clean_results = [
        {k: v for k, v in r.items() if isinstance(v, json_types)}
        for r in results
    ]

    output = {
        "all_results": clean_results,
        "summary": summary,
    }
    if extra_fields:
        output.update(extra_fields)

    # Serialize first: if this raises, no file has been touched yet.
    payload = json.dumps(output, indent=2)

    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as f:
        f.write(payload)

    print(f"\nResults saved to: {output_path}")


def print_summary(summary: Dict, title: str = "SUMMARY STATISTICS"):
    """
    Print summary statistics to stdout in a readable layout.

    The group named "ensemble" is printed as flat ``key: value`` lines with
    six decimals (and skipped entirely when empty). Every other group is
    printed under its own heading; metrics that are dicts containing "mean"
    are shown as ``mean +/- std`` in scientific notation (with the sample
    count when "n" is present), anything else is printed as-is.

    Args:
        summary: Summary dictionary from compute_summary_statistics.
        title: Title for the summary section.
    """
    rule = "=" * 70
    print("\n" + rule)
    print(title)
    print(rule)

    for group_name, metrics in summary.items():
        if group_name != "ensemble":
            print(f"\n{group_name}:")
            print("-" * 40)
            for metric, stats in metrics.items():
                has_moments = isinstance(stats, dict) and "mean" in stats
                if not has_moments:
                    print(f"  {metric}: {stats}")
                    continue
                count_suffix = f" (n={stats['n']})" if "n" in stats else ""
                print(f"  {metric}:")
                print(
                    f"    mean: {stats['mean']:.6e} +/- {stats['std']:.6e}{count_suffix}"
                )
            continue

        # Ensemble metrics are plain numbers; an empty group prints nothing.
        if not metrics:
            continue
        print("\nEnsemble:")
        for k, v in metrics.items():
            print(f"  {k}: {v:.6f}")


# =============================================================================
# Diversity vs Epsilon Analysis
# =============================================================================


def compute_diversity_vs_epsilon(
    x_samples: torch.Tensor,
    roundtrip_errors: torch.Tensor,
    epsilon_values: np.ndarray,
) -> Dict[str, np.ndarray]:
    """
    Sweep the round-trip error threshold and record diversity at each value.

    For every threshold, ``compute_design_diversity`` is run with filtering
    enabled and the mean and standard deviation over targets are stored for
    both the diversity and the number of valid samples.

    Args:
        x_samples: Generated designs (num_samples, batch_size, P).
        roundtrip_errors: Per-sample errors (num_samples, batch_size).
        epsilon_values: Array of epsilon thresholds to evaluate.

    Returns:
        Dictionary with "epsilon" (the input thresholds, passed through) and
        the numpy arrays "diversity_var_mean", "diversity_var_std",
        "num_valid_mean" and "num_valid_std", one entry per threshold.
    """
    div_means, div_stds = [], []
    valid_means, valid_stds = [], []

    for eps in epsilon_values:
        filtered = compute_design_diversity(
            x_samples, roundtrip_errors=roundtrip_errors, epsilon=eps
        )
        per_target_div = filtered["diversity_var"]
        per_target_valid = filtered["num_valid_samples"]

        div_means.append(np.mean(per_target_div))
        div_stds.append(np.std(per_target_div))
        valid_means.append(np.mean(per_target_valid))
        valid_stds.append(np.std(per_target_valid))

    return {
        "epsilon": epsilon_values,
        "diversity_var_mean": np.array(div_means),
        "diversity_var_std": np.array(div_stds),
        "num_valid_mean": np.array(valid_means),
        "num_valid_std": np.array(valid_stds),
    }


def plot_diversity_vs_epsilon_combined(
    results: Dict[str, Dict[str, np.ndarray]],
    output_path: Path,
    num_samples: int = 10,
    title: Optional[str] = None,
):
    """
    Draw diversity vs epsilon for all model types on one figure and save it.

    The figure has two y-axes sharing a log-scaled x-axis: the left one shows
    mean diversity (solid lines), the right one the mean number of valid
    samples (dashed, semi-transparent lines). Model types are drawn in the
    fixed order Diag-CFM, CFM, INN; types missing from ``results`` are skipped.

    Args:
        results: Dictionary mapping model type to a dict with the keys
            "epsilon", "diversity_var_mean" and "num_valid_mean".
        output_path: Full path to save the figure (including .png extension).
        num_samples: Number of samples used; sets the right-axis label and
            its upper limit.
        title: Optional custom title for the plot.
    """
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    palette = {"Diag-CFM": "#1f77b4", "CFM": "#ff7f0e", "INN": "#2ca02c"}
    present = [name for name in ("Diag-CFM", "CFM", "INN") if name in results]

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax2 = ax1.twinx()

    # Left axis: diversity, one solid line per model type.
    diversity_lines = []
    for name in present:
        series = results[name]
        (handle,) = ax1.plot(
            series["epsilon"],
            series["diversity_var_mean"],
            color=palette[name],
            linewidth=2,
            label=name,
        )
        diversity_lines.append(handle)

    # Right axis: valid-sample counts, dashed and lighter.
    for name in present:
        series = results[name]
        ax2.plot(
            series["epsilon"],
            series["num_valid_mean"],
            "--",
            color=palette[name],
            linewidth=1.5,
            alpha=0.4,
        )

    # Left axis styling (diversity).
    ax1.set_xlabel(r"$\varepsilon$ (round-trip error threshold)", fontsize=12)
    ax1.set_ylabel("Design Diversity (variance)", fontsize=12)
    ax1.set_xscale("log")
    ax1.set_ylim(bottom=0)
    ax1.grid(True, alpha=0.3)

    # Right axis styling (valid samples).
    ax2.set_ylabel(
        f"Avg. valid samples (out of {num_samples})", fontsize=12, color="gray"
    )
    ax2.set_ylim(0, num_samples + 0.5)
    ax2.tick_params(axis="y", labelcolor="gray")

    # Legend: per-model lines plus two gray proxies explaining the line styles.
    solid_proxy = Line2D(
        [0], [0], color="gray", linewidth=2, linestyle="-", label="Diversity (solid)"
    )
    dashed_proxy = Line2D(
        [0],
        [0],
        color="gray",
        linewidth=1.5,
        linestyle="--",
        alpha=0.5,
        label="Valid samples (dashed)",
    )
    ax1.legend(
        handles=diversity_lines + [solid_proxy, dashed_proxy],
        loc="lower right",
        fontsize=10,
    )

    heading = title if title else "Diversity vs Round-Trip Error Threshold"
    ax1.set_title(heading, fontsize=14)

    fig.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"Saved: {output_path}")
    plt.close()


def print_diversity_vs_epsilon_description():
    """Write the fixed explanatory text for the diversity-vs-epsilon plot to stdout."""
    # The literal starts and ends with a newline, so the printed block is
    # surrounded by blank lines.
    print(
        """
================================================================================
DIVERSITY VS ROUND-TRIP ERROR THRESHOLD PLOT
================================================================================

This plot shows how design diversity changes as we filter designs by accuracy.

AXES:
- X-axis (log scale): ε (epsilon) - the round-trip error threshold. Only designs
  with round-trip error < ε are included in the diversity computation.
- Left Y-axis: Design Diversity (variance) - mean variance of design parameters
  across generated samples, averaged over all test targets.
- Right Y-axis: Average valid samples - how many of the generated samples (out
  of the total) have round-trip error below ε.

LINES:
- Solid lines: Design diversity for each model type (Diag-CFM, CFM, INN)
- Dashed lines (lighter): Number of valid samples for each model type

INTERPRETATION:
- A model with lower round-trip errors will have its dashed line (valid samples)
  reach the maximum earlier (at lower ε).
- Comparing diversity at a fixed ε gives a fair comparison that accounts for
  accuracy - high diversity from inaccurate designs is not meaningful.
- The plot reveals whether high diversity scores come from genuinely diverse
  accurate designs or from noise in inaccurate designs.
================================================================================
"""
    )
