"""
visualize.py — Chart generation for the consistency loss experiment.

Produces:
  1. coupling_strength.png  — val_coupling_strength per epoch, all variants
  2. explanation_correctness.png — val_bleu1 + val_rouge_l per epoch, all variants
  3. counterfactual_swap.png — val_swap_influence per epoch, all variants
  4. claim_accuracy.png — val_claim_accuracy per epoch, all variants
  5. losses.png — training loss curves

Design palette (Nexus-inspired, clean technical aesthetic):
  Colors: a high-contrast set with accent blues, warm oranges, teals, and pinks.
  Font: DejaVu Sans (matplotlib default, clean sans-serif)
  Background: #F8F9FB (near-white)
  Grid: subtle dashed lines, alpha=0.5
"""

import os
import math
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np

# ──────────────────────────────────────────────────────────────────────────────
# 1. Palette (Nexus-inspired)
# ──────────────────────────────────────────────────────────────────────────────

# One row per experiment variant: (key, line colour, legend label, marker).
# PALETTE, LABELS and MARKERS are projections of this single table, so the
# three lookups always share the same keys in the same order.
_VARIANT_STYLES = (
    # ── V1 (original) variants ──
    # Nexus blue — primary
    ("consistency_loss", "#2563EB", "Consistency Loss (V1)", "o"),
    # orange — baseline
    ("no_consistency_loss", "#F97316", "No Consistency Loss (V1)", "s"),
    # teal — negative ctrl
    ("claim_only_pooling", "#10B981", "Claim-Only Pooling (V1)", "^"),
    # pink — negative ctrl
    ("random_label_consistency", "#EC4899", "Random Label Consistency (V1)", "D"),
    # ── V2 (stronger ablation) variants ──
    # violet — structural
    ("no_claim_to_claim_attention", "#7C3AED", "No Claim→Claim Attn (V2)", "P"),
    # cyan — strict flow
    ("claims_from_explanation_only", "#0891B2", "Claims from Expl Only (V2)", "X"),
    # amber — surface
    ("surface_bottleneck_consistency", "#B45309", "Surface Bottleneck (V2)", "*"),
    # red — surface + masked LM
    ("surface_bottleneck_no_expl_lm", "#DC2626", "Surface + No Expl LM (V2)", "h"),
)

# variant key -> line colour (hex)
PALETTE = {key: color for key, color, _label, _marker in _VARIANT_STYLES}

# variant key -> human-readable legend label
LABELS = {key: label for key, _color, label, _marker in _VARIANT_STYLES}

# variant key -> matplotlib marker code
MARKERS = {key: marker for key, _color, _label, marker in _VARIANT_STYLES}

# Chart chrome: figure/axes background, grid lines + legend frame, visible spines.
BG_COLOR, GRID_COLOR, SPINE_COLOR = "#F8F9FB", "#C8CDD4", "#8A9099"


def _base_style(fig, axes_list):
    """Apply the shared chart look to *fig* and to every axes in *axes_list*.

    Sets the background colour on the figure and each axes, draws a dashed
    grid, hides the top/right spines, recolours the left/bottom spines, and
    colours tick labels, axis labels and the title.
    """
    text_color = "#444950"
    title_color = "#1A1D20"

    fig.patch.set_facecolor(BG_COLOR)
    for axis in axes_list:
        axis.set_facecolor(BG_COLOR)
        axis.grid(True, color=GRID_COLOR, alpha=0.5, linewidth=0.7, linestyle="--")

        # Open frame: only the left and bottom spines remain visible.
        for side in ("top", "right"):
            axis.spines[side].set_visible(False)
        for side in ("left", "bottom"):
            axis.spines[side].set_color(SPINE_COLOR)

        axis.tick_params(colors=text_color, labelsize=9)
        for axis_label in (axis.xaxis.label, axis.yaxis.label):
            axis_label.set_color(text_color)
        axis.title.set_color(title_color)

def _save(fig, path):
    """Lay out *fig*, write it to *path* at 150 dpi, then close it.

    The figure is closed even if layout or saving raises (for example when
    the target directory does not exist), so a failed save does not leave the
    figure registered with pyplot. The "Saved" message is printed only after
    a successful write; any exception from matplotlib propagates unchanged.
    """
    try:
        fig.tight_layout(pad=1.8)
        fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=BG_COLOR)
    finally:
        # Always release the figure, success or failure.
        plt.close(fig)
    print(f"  Saved: {path}")


# ──────────────────────────────────────────────────────────────────────────────
# 2. Individual chart functions
# ──────────────────────────────────────────────────────────────────────────────

def _plot_metric(df: pd.DataFrame, metric_col: str, title: str,
                 ylabel: str, output_path: str, ylim=None):
    """Plot one metric against epoch, one line per variant, and save it.

    Args:
        df: Frame with ``"variant"`` and ``"epoch"`` columns plus *metric_col*.
        metric_col: Name of the column plotted on the y axis.
        title: Axes title.
        ylabel: Y-axis label.
        output_path: File path handed to ``_save``.
        ylim: Optional ``(bottom, top)`` pair unpacked into ``ax.set_ylim``;
            ``None`` leaves the y range to matplotlib.

    Raises:
        KeyError: If a required column is missing from *df*. The figure is
            closed before the error propagates.
    """
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        for variant in df["variant"].unique():
            sub = df[df["variant"] == variant].sort_values("epoch")
            ax.plot(
                sub["epoch"], sub[metric_col],
                # Variants missing from the style tables fall back to grey,
                # their raw name, and a circle marker.
                color=PALETTE.get(variant, "#666666"),
                label=LABELS.get(variant, variant),
                marker=MARKERS.get(variant, "o"),
                linewidth=2, markersize=5,
                # Roughly ten markers per line, however many epochs there are.
                markevery=max(1, len(sub) // 10),
            )

        ax.set_xlabel("Epoch", fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)
        ax.set_title(title, fontsize=12, fontweight="bold", pad=10)

        # Only build a legend when something was plotted; an empty frame
        # would otherwise yield an empty legend box and a matplotlib complaint.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=8.5, framealpha=0.8, edgecolor=GRID_COLOR)

        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        # Explicit None check: truthiness is ambiguous for array-like limits.
        if ylim is not None:
            ax.set_ylim(*ylim)

        _base_style(fig, [ax])
    except Exception:
        # _save (which normally closes the figure) will not run; close here.
        plt.close(fig)
        raise

    _save(fig, output_path)


def plot_coupling_strength(df: pd.DataFrame, output_dir: str):
    """Write ``coupling_strength.png`` into *output_dir*.

    Plots the ``val_coupling_strength`` column per epoch for every variant,
    with the y axis fixed to ``(0.0, 1.05)``.
    """
    _plot_metric(
        df,
        metric_col="val_coupling_strength",
        title="Coupling Strength — Classifier Accuracy on Explanation Pooling",
        ylabel="Mean Classifier Accuracy (time + space + correct)",
        output_path=os.path.join(output_dir, "coupling_strength.png"),
        ylim=(0.0, 1.05),
    )


def plot_explanation_correctness(df: pd.DataFrame, output_dir: str):
    """Write ``explanation_correctness.png`` into *output_dir*.

    Draws two side-by-side panels, ``val_bleu1`` on the left and
    ``val_rouge_l`` on the right, each with one line per variant against
    epoch and the y axis fixed to ``(0.0, 1.05)``.

    Args:
        df: Frame with ``"variant"``, ``"epoch"``, ``"val_bleu1"`` and
            ``"val_rouge_l"`` columns.
        output_dir: Directory the PNG is written to.

    Raises:
        KeyError: If a required column is missing from *df*. The figure is
            closed before the error propagates.
    """
    # (column, panel title, y-axis label) for the left and right panels.
    panels = (
        ("val_bleu1", "Explanation Correctness — BLEU-1", "BLEU-1 Score"),
        ("val_rouge_l", "Explanation Correctness — ROUGE-L", "ROUGE-L Score"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(13, 4.5))
    try:
        for variant in df["variant"].unique():
            sub = df[df["variant"] == variant].sort_values("epoch")
            # Same line style on both panels; unknown variants fall back to
            # grey, their raw name, and a circle marker.
            line_style = dict(
                color=PALETTE.get(variant, "#666666"),
                label=LABELS.get(variant, variant),
                marker=MARKERS.get(variant, "o"),
                linewidth=2, markersize=5,
                markevery=max(1, len(sub) // 10),
            )
            for ax, (column, _title, _ylabel) in zip(axes, panels):
                ax.plot(sub["epoch"], sub[column], **line_style)

        for ax, (_column, title, ylabel) in zip(axes, panels):
            ax.set_xlabel("Epoch", fontsize=10)
            ax.set_ylabel(ylabel, fontsize=10)
            ax.set_title(title, fontsize=11, fontweight="bold", pad=8)
            # Skip the legend when the frame was empty and nothing was drawn.
            if ax.get_legend_handles_labels()[0]:
                ax.legend(fontsize=8, framealpha=0.8, edgecolor=GRID_COLOR)
            ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
            ax.set_ylim(0.0, 1.05)

        _base_style(fig, list(axes))
    except Exception:
        # _save (which normally closes the figure) will not run; close here.
        plt.close(fig)
        raise

    _save(fig, os.path.join(output_dir, "explanation_correctness.png"))


def plot_counterfactual_swap(df: pd.DataFrame, output_dir: str):
    """Write ``counterfactual_swap.png`` into *output_dir*.

    Plots the ``val_swap_influence`` column per epoch for every variant; the
    y range is left to matplotlib (no ``ylim`` is passed).
    """
    _plot_metric(
        df,
        metric_col="val_swap_influence",
        title="Counterfactual Swap Influence",
        ylabel="Swap Influence Score (own − swapped accuracy)",
        output_path=os.path.join(output_dir, "counterfactual_swap.png"),
    )


def plot_claim_accuracy(df: pd.DataFrame, output_dir: str):
    """Write ``claim_accuracy.png`` into *output_dir*.

    Plots the ``val_claim_accuracy`` column per epoch for every variant, with
    the y axis fixed to ``(0.0, 1.05)``.
    """
    _plot_metric(
        df,
        metric_col="val_claim_accuracy",
        title="Claim Emission Accuracy",
        ylabel="Fraction of Correct Claim Tokens Generated",
        output_path=os.path.join(output_dir, "claim_accuracy.png"),
        ylim=(0.0, 1.05),
    )


def plot_losses(df: pd.DataFrame, output_dir: str):
    """Write ``losses.png`` into *output_dir*.

    Draws two side-by-side panels, ``train_lm_loss`` on the left and
    ``train_total_loss`` on the right, each with one line per variant against
    epoch. The y range is left to matplotlib.

    Args:
        df: Frame with ``"variant"``, ``"epoch"``, ``"train_lm_loss"`` and
            ``"train_total_loss"`` columns.
        output_dir: Directory the PNG is written to.

    Raises:
        KeyError: If a required column is missing from *df*. The figure is
            closed before the error propagates.
    """
    # (column, panel title, y-axis label) for the left and right panels.
    panels = (
        ("train_lm_loss", "Training LM Loss", "LM Cross-Entropy"),
        ("train_total_loss", "Training Total Loss",
         "Total Loss (LM + λ·Consistency)"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(13, 4.5))
    try:
        for variant in df["variant"].unique():
            sub = df[df["variant"] == variant].sort_values("epoch")
            # Same line style on both panels; unknown variants fall back to
            # grey, their raw name, and a circle marker.
            line_style = dict(
                color=PALETTE.get(variant, "#666666"),
                label=LABELS.get(variant, variant),
                marker=MARKERS.get(variant, "o"),
                linewidth=2, markersize=5,
                markevery=max(1, len(sub) // 10),
            )
            for ax, (column, _title, _ylabel) in zip(axes, panels):
                ax.plot(sub["epoch"], sub[column], **line_style)

        for ax, (_column, title, ylabel) in zip(axes, panels):
            ax.set_xlabel("Epoch", fontsize=10)
            ax.set_ylabel(ylabel, fontsize=10)
            ax.set_title(title, fontsize=11, fontweight="bold", pad=8)
            # Skip the legend when the frame was empty and nothing was drawn.
            if ax.get_legend_handles_labels()[0]:
                ax.legend(fontsize=8, framealpha=0.8, edgecolor=GRID_COLOR)
            ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))

        _base_style(fig, list(axes))
    except Exception:
        # _save (which normally closes the figure) will not run; close here.
        plt.close(fig)
        raise

    _save(fig, os.path.join(output_dir, "losses.png"))


# ──────────────────────────────────────────────────────────────────────────────
# 3. Convenience wrapper
# ──────────────────────────────────────────────────────────────────────────────

def generate_all_charts(df: pd.DataFrame, output_dir: str):
    """Create *output_dir* if needed and render every chart into it.

    Charts are produced in a fixed order: coupling strength, explanation
    correctness, counterfactual swap, claim accuracy, then training losses.
    """
    os.makedirs(output_dir, exist_ok=True)
    print("\nGenerating charts...")

    chart_builders = (
        plot_coupling_strength,
        plot_explanation_correctness,
        plot_counterfactual_swap,
        plot_claim_accuracy,
        plot_losses,
    )
    for build_chart in chart_builders:
        build_chart(df, output_dir)

    print("All charts saved.")


if __name__ == "__main__":
    # Smoke test: render every chart from five epochs of synthetic data for
    # the four V1 variants and write the PNGs to ./outputs.
    # pandas is already imported at module level, so it is not re-imported.

    # Starting coupling strength per variant; built once rather than on
    # every inner-loop iteration.
    base_coupling = {
        "consistency_loss": 0.45,
        "no_consistency_loss": 0.35,
        "claim_only_pooling": 0.38,
        "random_label_consistency": 0.33,
    }

    rows = []
    for variant, base in base_coupling.items():
        for ep in range(1, 6):
            rows.append({
                "variant": variant, "epoch": ep,
                "val_coupling_strength": base + ep * 0.02 + np.random.randn() * 0.01,
                "val_bleu1":   0.1 + ep * 0.01 + np.random.randn() * 0.005,
                "val_rouge_l": 0.12 + ep * 0.01,
                "val_swap_influence": 0.05 + ep * 0.01,
                "val_claim_accuracy": 0.2 + ep * 0.03,
                "train_lm_loss": 4.0 - ep * 0.2 + np.random.randn() * 0.05,
                "train_total_loss": 4.5 - ep * 0.2,
            })
    df = pd.DataFrame(rows)
    generate_all_charts(df, "outputs")
