"""Plotting utilities for anisotropy experiments."""

import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

try:
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


def ensure_matplotlib():
    """Raise ImportError if matplotlib could not be imported when this module loaded.

    Every plotting function calls this first, so a missing optional dependency
    yields an actionable install hint instead of a NameError on ``plt``.
    """
    if HAS_MATPLOTLIB:
        return
    raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")


def plot_eigenvalue_spectrum(
    model_eigenvalues: Dict[str, np.ndarray],
    output_path: str,
    title: str = "Global Covariance Eigenvalue Spectrum",
    top_k: int = 100,
):
    """
    Plot normalized eigenvalue spectra for multiple models on a log y-axis.

    Each model's eigenvalues are sorted in descending order and divided by
    their sum, and the largest ``top_k`` are drawn. A model whose spectrum is
    empty or sums to a non-positive value cannot be normalized and is skipped.

    Args:
        model_eigenvalues: Dict mapping model name to eigenvalues (any
            array-like; order does not matter)
        output_path: Path to save the plot
        title: Plot title
        top_k: Number of largest eigenvalues to plot
    """
    ensure_matplotlib()

    fig, ax = plt.subplots(figsize=(10, 6))

    colors = cm.tab10(np.linspace(0, 1, len(model_eigenvalues)))

    for (model_name, eigenvalues), color in zip(model_eigenvalues.items(), colors):
        # Accept lists (e.g. results reloaded from JSON) and sort descending so
        # the plotted "top" eigenvalues really are the largest ones; some
        # solvers (np.linalg.eigh, for instance) return them in ascending order.
        spectrum = np.sort(np.asarray(eigenvalues, dtype=float).ravel())[::-1]
        total = spectrum.sum()
        if spectrum.size == 0 or total <= 0:
            # Cannot normalize an empty or degenerate spectrum; skip this model.
            continue
        normalized = spectrum / total
        k = min(top_k, normalized.size)
        ax.plot(np.arange(1, k + 1), normalized[:k], label=model_name, color=color, linewidth=2)

    ax.set_xlabel("Eigenvalue Index", fontsize=12)
    ax.set_ylabel("Normalized Eigenvalue", fontsize=12)
    ax.set_title(title, fontsize=14)
    # Only build a legend when something was drawn (avoids a matplotlib warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="upper right")
    ax.set_yscale("log")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_isotropy_comparison(
    model_metrics: Dict[str, dict],
    output_path: str,
    title: str = "Global Isotropy Comparison",
):
    """
    Draw one bar per model for its global isotropy score, labelled with the value.

    Args:
        model_metrics: Mapping of model name -> metrics dict; only the
            ``"isotropy_score"`` entry of each dict is read.
        output_path: File path the figure is saved to.
        title: Title shown above the chart.
    """
    ensure_matplotlib()

    names = []
    values = []
    for name, metrics in model_metrics.items():
        names.append(name)
        values.append(metrics["isotropy_score"])

    fig, ax = plt.subplots(figsize=(10, 6))

    palette = cm.tab10(np.linspace(0, 1, len(names)))
    rects = ax.bar(names, values, color=palette)

    ax.set_title(title, fontsize=14)
    ax.set_ylabel("Isotropy Score (d_eff / d)", fontsize=12)
    ax.set_ylim(0, 1)

    # Print each score 3 points above the top of its bar.
    for rect, value in zip(rects, values):
        bar_center = rect.get_x() + rect.get_width() / 2
        ax.annotate(
            f"{value:.3f}",
            xy=(bar_center, rect.get_height()),
            xytext=(0, 3),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=10,
        )

    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()


def plot_local_anisotropy_comparison(
    model_metrics: Dict[str, Dict[str, float]],
    output_path: str,
    title: str = "Local Anisotropy by Neighborhood Type",
):
    """
    Grouped bar chart comparing local anisotropy across models and neighborhood types.

    Neighborhood types are collected from all models in first-seen order. A
    model lacking a type gets a missing (NaN, undrawn) bar for it.

    Args:
        model_metrics: Dict mapping model name to dict with neighborhood type -> anisotropy
                      e.g., {"clip": {"random": 0.1, "semantic": 0.3, "adversarial": 0.5}}
        output_path: Path to save the plot
        title: Plot title

    Raises:
        ValueError: if ``model_metrics`` is empty.
    """
    ensure_matplotlib()

    if not model_metrics:
        raise ValueError("model_metrics is empty; nothing to plot")

    models = list(model_metrics.keys())
    # Union of types over every model, preserving first-seen order.
    neighborhood_types = list(
        dict.fromkeys(ntype for metrics in model_metrics.values() for ntype in metrics)
    )
    n_types = len(neighborhood_types)

    x = np.arange(len(models))
    # Keep the original 0.25 bar width for up to three types, but shrink it so a
    # whole group never spans more than 0.8 of the tick spacing (no overlap).
    width = min(0.25, 0.8 / max(n_types, 1))

    base_colors = ["#4C72B0", "#55A868", "#C44E52"]  # Blue, Green, Red
    if n_types <= len(base_colors):
        colors = base_colors[:n_types]
    else:
        # More types than hand-picked colors: use tab10 so bars stay distinguishable.
        colors = [cm.tab10(i % 10) for i in range(n_types)]

    fig, ax = plt.subplots(figsize=(12, 6))

    for i, ntype in enumerate(neighborhood_types):
        values = [model_metrics[m].get(ntype, np.nan) for m in models]
        # Center the group of n_types bars on each model's tick.
        offset = (i - n_types / 2 + 0.5) * width
        ax.bar(x + offset, values, width, label=ntype, color=colors[i])

    ax.set_xlabel("Model", fontsize=12)
    ax.set_ylabel("Local Anisotropy (λ₁ / Σλ)", fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=45, ha="right")
    ax.legend(title="Neighborhood Type")
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_isotropy_vs_accuracy(
    model_data: Dict[str, Tuple[float, float]],
    output_path: str,
    title: str = "Global Isotropy vs Linear Probe Accuracy",
    xlabel: str = "Isotropy Score",
    ylabel: str = "Linear Probe Accuracy",
):
    """
    Scatter plot of isotropy vs accuracy, with a least-squares trend line.

    The trend line is drawn only when at least three models have finite
    values and their x values are not all identical.

    Args:
        model_data: Dict mapping model name to (isotropy, accuracy) tuple
        output_path: Path to save the plot
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
    """
    ensure_matplotlib()

    fig, ax = plt.subplots(figsize=(8, 6))

    colors = cm.tab10(np.linspace(0, 1, len(model_data)))

    for (model_name, (iso, acc)), color in zip(model_data.items(), colors):
        ax.scatter(iso, acc, s=150, c=[color], label=model_name, edgecolors="black", linewidth=1)

    # Linear trend over the finite points. With two points the fit is exact
    # (uninformative), and identical x values make polyfit rank-deficient.
    isos = np.array([iso for iso, _ in model_data.values()], dtype=float)
    accs = np.array([acc for _, acc in model_data.values()], dtype=float)
    finite = np.isfinite(isos) & np.isfinite(accs)
    isos, accs = isos[finite], accs[finite]
    if isos.size > 2 and np.ptp(isos) > 0:
        slope, intercept = np.polyfit(isos, accs, 1)
        x_line = np.linspace(isos.min(), isos.max(), 100)
        ax.plot(x_line, slope * x_line + intercept, "--", color="gray", alpha=0.5, label="Trend")

    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    # Build the legend after every artist exists so the "Trend" entry is included.
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_anisotropy_vs_gain(
    model_data: Dict[str, Tuple[float, float]],
    output_path: str,
    title: str = "Local Anisotropy vs Conditional Readout Gain",
):
    """
    Scatter each model's local anisotropy against its conditional readout gain.

    Gains are multiplied by 100 and the y-axis is labelled in percent, so they
    are presumably supplied as fractions (confirm with the caller). A thin
    horizontal line marks zero gain.

    Args:
        model_data: Mapping of model name -> (anisotropy, gain) pair.
        output_path: File path the figure is saved to.
        title: Title shown above the plot.
    """
    ensure_matplotlib()

    fig, ax = plt.subplots(figsize=(8, 6))

    palette = cm.tab10(np.linspace(0, 1, len(model_data)))

    for color, (name, point) in zip(palette, model_data.items()):
        aniso, gain = point
        gain_pct = gain * 100
        ax.scatter(
            aniso,
            gain_pct,
            s=150,
            c=[color],
            label=name,
            edgecolors="black",
            linewidth=1,
        )

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Local Anisotropy (adversarial)", fontsize=12)
    ax.set_ylabel("Conditional Readout Gain (%)", fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")
    ax.axhline(y=0, color="black", linestyle="-", linewidth=0.5)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()


def plot_per_sample_analysis(
    anisotropy: np.ndarray,
    baseline_correct: np.ndarray,
    conditional_correct: np.ndarray,
    output_path: str,
    title: str = "Per-Sample Analysis: Anisotropy vs Conditional Gain",
    n_bins: int = 10,
):
    """
    Binned analysis of per-sample anisotropy vs conditional readout improvement.

    Samples are grouped into ``n_bins`` equal-width anisotropy bins. Each bar
    shows the fraction of that bin's samples that "flipped" (baseline wrong,
    conditional right). Bins with no samples are left blank rather than drawn
    as a 0% rate.

    Args:
        anisotropy: (N,) array-like of per-sample anisotropy values
        baseline_correct: (N,) binary array-like of baseline correctness
        conditional_correct: (N,) binary array-like of conditional correctness
        output_path: Path to save the plot
        title: Plot title
        n_bins: Number of bins for anisotropy

    Raises:
        ValueError: if the inputs are empty, differ in length, or ``n_bins < 1``.
    """
    ensure_matplotlib()

    bin_edges, bin_centers, flip_rates = _binned_flip_rates(
        anisotropy, baseline_correct, conditional_correct, n_bins
    )

    fig, ax = plt.subplots(figsize=(8, 6))

    ax.bar(
        bin_centers,
        flip_rates,
        width=(bin_edges[1] - bin_edges[0]) * 0.8,
        color="#55A868",
        edgecolor="black",
    )

    ax.set_xlabel("Local Anisotropy", fontsize=12)
    ax.set_ylabel("Flip Rate (baseline wrong → conditional right)", fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def _binned_flip_rates(anisotropy, baseline_correct, conditional_correct, n_bins):
    """Return (bin_edges, bin_centers, flip_rates); empty bins get a NaN rate."""
    aniso = np.asarray(anisotropy, dtype=float).ravel()
    baseline = np.asarray(baseline_correct).astype(bool).ravel()
    conditional = np.asarray(conditional_correct).astype(bool).ravel()

    if aniso.size == 0:
        raise ValueError("anisotropy is empty; nothing to plot")
    if not (aniso.size == baseline.size == conditional.size):
        raise ValueError(
            "anisotropy, baseline_correct and conditional_correct must have the same length"
        )
    if n_bins < 1:
        raise ValueError("n_bins must be at least 1")

    # A sample "flips" when the baseline was wrong and the conditional readout right.
    flipped = ~baseline & conditional

    lo, hi = aniso.min(), aniso.max()
    if lo == hi:
        # All samples share one value: widen the range so the bins (and bars)
        # have non-zero width instead of collapsing to invisible slivers.
        pad = 0.05 * abs(lo) if lo != 0 else 0.05
        lo, hi = lo - pad, hi + pad

    edges = np.linspace(lo, hi, n_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    # digitize puts the maximum value past the last edge; clip folds it into the last bin.
    indices = np.clip(np.digitize(aniso, edges) - 1, 0, n_bins - 1)

    counts = np.bincount(indices, minlength=n_bins)
    flips = np.bincount(indices, weights=flipped.astype(float), minlength=n_bins)
    rates = np.full(n_bins, np.nan)
    np.divide(flips, counts, out=rates, where=counts > 0)
    return edges, centers, rates


def _lookup(mapping, *keys):
    """Follow ``keys`` through nested containers; return None if any level is missing."""
    node = mapping
    for key in keys:
        try:
            node = node[key]
        except (KeyError, TypeError):
            return None
    return node


def plot_all_results(
    results: dict,
    output_dir: str,
):
    """
    Generate all summary plots (A-E) from experiment results.

    A plot is skipped when no model provides the metrics it needs. A model
    lacking a particular metric is left out of that plot (plots D and E), or
    shown as a missing bar (plot C). It is never plotted with a made-up 0.

    Args:
        results: Full results dictionary from run_experiment, mapping model
            name to that model's metrics dict
        output_dir: Directory to save plots (created if it does not exist)
    """
    ensure_matplotlib()

    os.makedirs(output_dir, exist_ok=True)

    # Plot A: Eigenvalue spectra
    eigenvalues = {
        model: data["global_isotropy"]["eigenvalues"]
        for model, data in results.items()
        if "global_isotropy" in data
    }
    if eigenvalues:
        plot_eigenvalue_spectrum(
            eigenvalues,
            os.path.join(output_dir, "plot_a_eigenvalue_spectrum.png"),
        )

    # Plot B: Isotropy comparison
    isotropy_metrics = {
        model: {"isotropy_score": data["global_isotropy"]["isotropy_score"]}
        for model, data in results.items()
        if "global_isotropy" in data
    }
    if isotropy_metrics:
        plot_isotropy_comparison(
            isotropy_metrics,
            os.path.join(output_dir, "plot_b_isotropy_comparison.png"),
        )

    # Plot C: Local anisotropy comparison
    local_aniso = {}
    for model, data in results.items():
        if "local_anisotropy" not in data:
            continue
        per_type = {}
        for ntype in ("random", "semantic", "adversarial"):
            value = _lookup(data, "local_anisotropy", ntype, "mean_anisotropy")
            # NaN renders as a missing bar instead of a fabricated zero.
            per_type[ntype] = float("nan") if value is None else value
        local_aniso[model] = per_type
    if local_aniso:
        plot_local_anisotropy_comparison(
            local_aniso,
            os.path.join(output_dir, "plot_c_local_anisotropy.png"),
        )

    # Plot D: Isotropy vs linear probe accuracy
    iso_vs_acc = {}
    for model, data in results.items():
        if "global_isotropy" not in data:
            continue
        accuracy = _lookup(data, "linear_probe", "accuracy")
        if accuracy is None:
            # No probe accuracy recorded: skip rather than plot at 0 and skew the trend.
            continue
        iso_vs_acc[model] = (data["global_isotropy"]["isotropy_score"], accuracy)
    if iso_vs_acc:
        plot_isotropy_vs_accuracy(
            iso_vs_acc,
            os.path.join(output_dir, "plot_d_isotropy_vs_accuracy.png"),
        )

    # Plot E: Local anisotropy vs conditional gain
    aniso_vs_gain = {}
    for model, data in results.items():
        aniso = _lookup(data, "local_anisotropy", "adversarial", "mean_anisotropy")
        gain = _lookup(data, "conditional_readout", "gain")
        if aniso is None or gain is None:
            continue
        aniso_vs_gain[model] = (aniso, gain)
    if aniso_vs_gain:
        plot_anisotropy_vs_gain(
            aniso_vs_gain,
            os.path.join(output_dir, "plot_e_anisotropy_vs_gain.png"),
        )

    print(f"Plots saved to {output_dir}")
