"""Generate figures for the radial arm maze experiment."""
import os
import sys
# Put the parent of this script's directory on sys.path so that the
# `src.*` imports further down resolve when the file is run as a script.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from scipy.stats import spearmanr

import src.radial_arm_env as ram
from src.gru_policy import GRUPolicy, train_gru_policy

# Directory every figure in this module is written to; created on import.
FIGDIR = os.path.join(PROJECT_ROOT, "figures")
os.makedirs(FIGDIR, exist_ok=True)

# Shared layout/styling for the single-panel embedding scatter plots.
PANEL_W, PANEL_H = 5.0, 5.0  # passed as figsize (matplotlib interprets as inches)
AXES_RECT = [0.22, 0.08, 0.68, 0.68]  # passed to fig.add_axes: [left, bottom, width, height]
SCATTER_S = 8  # scatter marker size
SCATTER_ALPHA_OBS = 0.6  # alpha used by the observation-class panels
SCATTER_ALPHA_CONT = 0.7  # alpha used by the colormapped (radial / visited) panels
# Palette used by the bar charts.
COLOR_RED = "#d6604d"
COLOR_BLUE = "#4393c3"
COLOR_LIGHT_BLUE = "#a6bddb"
COLOR_DARK_BLUE = "#2c7fb8"

# Observation-class id (as returned by obs_class) -> legend label.
obs_class_labels = {0: "Hub", 1: "Proximal", 2: "Medial", 3: "Tip"}
# Scatter color for each observation-class id (indexed by that id).
class_colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"]


def collect_with_visit_info(policy, val_trajs, obs_enc):
    """Gather per-timestep recurrent outputs and labels over trajectories.

    The policy is switched to eval mode and run without gradient tracking.
    For each trajectory, the observations of its first ``T`` states
    (``T = len(traj['actions'])``) are stacked into a single sequence and
    passed through ``policy.get_recurrent_output``.

    Args:
        policy: Model exposing ``eval()`` and ``get_recurrent_output(obs)``.
        val_trajs: Iterable of mappings with ``'states'``, ``'visited_seq'``
            and ``'actions'`` entries.
        obs_enc: Lookup from a state to its observation tensor.

    Returns:
        Tuple ``(hidden, nodes, visited, arms)`` of NumPy arrays holding one
        entry per collected timestep, in trajectory order.
    """
    hidden_rows, node_ids, visit_marks, arm_ids = [], [], [], []
    policy.eval()
    with torch.no_grad():
        for trajectory in val_trajs:
            states = trajectory['states']
            visited_seq = trajectory['visited_seq']
            n_steps = len(trajectory['actions'])

            # One sequence with a leading batch dimension of size 1.
            encoded = [obs_enc[state] for state in states[:n_steps]]
            batch = torch.stack(encoded).unsqueeze(0)
            hidden = policy.get_recurrent_output(batch).squeeze(0).numpy()

            for step in range(n_steps):
                node = states[step]
                hidden_rows.append(hidden[step])
                node_ids.append(node)
                visit_marks.append(visited_seq[step])
                arm_ids.append(ram.arm_of(node))

    return (np.array(hidden_rows), np.array(node_ids),
            np.array(visit_marks), np.array(arm_ids))


def obs_class(node):
    """Return the observation-class id of ``node``.

    The checks run in a fixed order, so a node matching several of them gets
    the first match: 0 = hub, 1 = proximal (arm position 0), 3 = tip,
    2 = medial (everything else). The ids are the keys of
    ``obs_class_labels``.
    """
    if ram.is_hub(node):
        return 0
    if ram.pos_in_arm(node) == 0:
        return 1
    return 3 if ram.is_tip(node) else 2


def square_limits(h2d, pad_frac=0.04):
    """Compute equal-span x/y axis limits centred on the data's bounding box.

    Args:
        h2d: 2-D array whose first two columns are the x and y coordinates.
        pad_frac: Fraction of the larger data span added as padding on each
            side.

    Returns:
        ``((x_lo, x_hi), (y_lo, y_hi))`` where both intervals have the same
        width: the larger of the two data spans times ``1 + 2 * pad_frac``.
    """
    xs, ys = h2d[:, 0], h2d[:, 1]
    x_lo, x_hi = xs.min(), xs.max()
    y_lo, y_hi = ys.min(), ys.max()

    side = max(x_hi - x_lo, y_hi - y_lo) * (1 + 2 * pad_frac)
    half = side / 2
    x_mid = (x_lo + x_hi) / 2
    y_mid = (y_lo + y_hi) / 2

    x_limits = (x_mid - half, x_mid + half)
    y_limits = (y_mid - half, y_mid + half)
    return x_limits, y_limits


def _save_obs_panel(emb, obs_classes, fname, xlim, ylim, show_legend=False):
    """Save a 2-D embedding scatter coloured by observation class.

    Args:
        emb: Array whose first two columns are the embedding coordinates.
        obs_classes: Array of class ids (0-3), aligned with the rows of
            ``emb``.
        fname: Output image path.
        xlim, ylim: Axis limits applied to the panel.
        show_legend: If true, draw the class legend anchored to the left
            edge of the axes.
    """
    fig = plt.figure(figsize=(PANEL_W, PANEL_H))
    ax = fig.add_axes(AXES_RECT)

    # One scatter call per class so each gets its own colour and legend entry.
    for cls_id in range(4):
        in_class = obs_classes == cls_id
        ax.scatter(
            emb[in_class, 0], emb[in_class, 1],
            c=class_colors[cls_id],
            label=obs_class_labels[cls_id],
            s=SCATTER_S,
            alpha=SCATTER_ALPHA_OBS,
            edgecolors="none",
        )

    if show_legend:
        ax.legend(fontsize=9, markerscale=3, loc="center right",
                  bbox_to_anchor=(-0.02, 0.5), frameon=True, handletextpad=0.5)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks([])
    ax.set_yticks([])

    fig.savefig(fname, dpi=200, transparent=True)
    plt.close(fig)
    print(f"Saved {fname}")


def _save_radial_panel(emb, radial, fname, xlim, ylim, show_colorbar=False):
    """Save a 2-D embedding scatter coloured by distance from the hub.

    Args:
        emb: Array whose first two columns are the embedding coordinates.
        radial: Per-point values mapped through the viridis colormap, with
            the colour scale fixed to the range 0-3.
        fname: Output image path.
        xlim, ylim: Axis limits applied to the panel.
        show_colorbar: If true, add a colorbar in its own axes on the left,
            with ticks and label on its left side.
    """
    fig = plt.figure(figsize=(PANEL_W, PANEL_H))
    ax = fig.add_axes(AXES_RECT)

    points = ax.scatter(
        emb[:, 0], emb[:, 1],
        c=radial, cmap="viridis", vmin=0, vmax=3,
        s=SCATTER_S, alpha=SCATTER_ALPHA_CONT, edgecolors="none",
    )

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks([])
    ax.set_yticks([])

    if show_colorbar:
        bar_ax = fig.add_axes([0.12, 0.12, 0.03, 0.60])
        fig.colorbar(points, cax=bar_ax, label="Distance from hub")
        bar_ax.yaxis.set_ticks_position("left")
        bar_ax.yaxis.set_label_position("left")

    fig.savefig(fname, dpi=200, transparent=True)
    plt.close(fig)
    print(f"Saved {fname}")


def _save_visited_panel(emb, n_vis, fname, xlim, ylim, show_colorbar=False):
    """Save a 2-D embedding scatter coloured by number of arms visited.

    Args:
        emb: Array whose first two columns are the embedding coordinates.
        n_vis: Per-point values mapped through the plasma colormap; the
            colour range is left for matplotlib to infer from the data.
        fname: Output image path.
        xlim, ylim: Axis limits applied to the panel.
        show_colorbar: If true, add a colorbar in its own axes on the left,
            with ticks and label on its left side.
    """
    fig = plt.figure(figsize=(PANEL_W, PANEL_H))
    ax = fig.add_axes(AXES_RECT)

    points = ax.scatter(
        emb[:, 0], emb[:, 1],
        c=n_vis, cmap="plasma",
        s=SCATTER_S, alpha=SCATTER_ALPHA_CONT, edgecolors="none",
    )

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks([])
    ax.set_yticks([])

    if show_colorbar:
        bar_ax = fig.add_axes([0.12, 0.12, 0.03, 0.60])
        fig.colorbar(points, cax=bar_ax, label="N arms visited")
        bar_ax.yaxis.set_ticks_position("left")
        bar_ax.yaxis.set_label_position("left")

    fig.savefig(fname, dpi=200, transparent=True)
    plt.close(fig)
    print(f"Saved {fname}")


def _embed_tsne(h):
    """Run PCA (up to 30 components) followed by 2-D t-SNE on hidden states.

    Args:
        h: 2-D array of shape ``(n_samples, n_features)``.

    Returns:
        The ``(n_samples, 2)`` t-SNE embedding of the PCA-reduced data.

    The PCA width and the t-SNE perplexity are clamped to what the input
    size allows, so small inputs no longer make scikit-learn raise. For
    inputs with more than 40 samples and at least 30 samples/features the
    settings are unchanged (30 components, perplexity 40, seed 42).
    """
    n_samples, n_features = h.shape[0], h.shape[1]

    # PCA cannot extract more components than min(n_samples, n_features).
    n_components = min(30, n_samples, n_features)
    pca = PCA(n_components=n_components).fit(h)
    h_pca = pca.transform(h)

    # scikit-learn rejects perplexity >= n_samples, so cap it just below.
    perplexity = min(40, max(1, n_samples - 1))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    return tsne.fit_transform(h_pca)


def plot_tsne_comparison(h_untrained, h_trained, nodes, visited):
    """Save untrained vs trained t-SNE panels (paper Figure 3 style).

    Six PNGs are written to ``FIGDIR``: for each of the untrained and
    trained embeddings, one panel coloured by observation class, one by
    radial distance, and one by number of arms visited. Each embedding gets
    its own square axis limits.

    Args:
        h_untrained: Hidden states from the untrained policy.
        h_trained: Hidden states from the trained policy.
        nodes: Per-timestep node ids, aligned with the hidden-state rows.
        visited: Per-timestep visit values; the number of '1' digits in
            ``bin(v)`` is used as the visited-arm count (presumably an
            integer bitmask -- confirm against the environment).
    """
    def _fig_path(basename):
        return os.path.join(FIGDIR, basename)

    radial = np.array([ram.radial_distance(node) for node in nodes])
    obs_classes = np.array([obs_class(node) for node in nodes])
    n_vis = np.array([bin(mask).count('1') for mask in visited])

    print("Computing untrained t-SNE...")
    untrained_emb = _embed_tsne(h_untrained)
    print("Computing trained t-SNE...")
    trained_emb = _embed_tsne(h_trained)

    untrained_xlim, untrained_ylim = square_limits(untrained_emb)
    trained_xlim, trained_ylim = square_limits(trained_emb)

    # Observation-class panels (legend only on the untrained one).
    _save_obs_panel(untrained_emb, obs_classes,
                    _fig_path("radial_arm_untrained_obs.png"),
                    untrained_xlim, untrained_ylim, show_legend=True)
    _save_obs_panel(trained_emb, obs_classes,
                    _fig_path("radial_arm_trained_obs.png"),
                    trained_xlim, trained_ylim)

    # Radial-distance panels (colorbar only on the untrained one).
    _save_radial_panel(untrained_emb, radial,
                       _fig_path("radial_arm_untrained_radial.png"),
                       untrained_xlim, untrained_ylim, show_colorbar=True)
    _save_radial_panel(trained_emb, radial,
                       _fig_path("radial_arm_trained_radial.png"),
                       trained_xlim, trained_ylim)

    # Arms-visited panels (colorbar only on the trained one).
    _save_visited_panel(untrained_emb, n_vis,
                        _fig_path("radial_arm_untrained_visited.png"),
                        untrained_xlim, untrained_ylim)
    _save_visited_panel(trained_emb, n_vis,
                        _fig_path("radial_arm_trained_visited.png"),
                        trained_xlim, trained_ylim, show_colorbar=True)
def plot_probes(save_path):
    """Save a grouped bar chart comparing untrained and trained GRU probes.

    The plotted means and standard deviations are hard-coded in this
    function rather than computed here. Metric pairs whose means differ by
    more than 0.05 are annotated with the signed gap above the bars. A
    dashed horizontal line is drawn at 1/25.

    Args:
        save_path: Output image path.
    """
    metrics = ['|PC1\u2013radial \u03c1|', 'Node probe\n(25-way)', 'Visit bit\nprobe',
               'N-visited\n(9-way)']
    untrained_mean = [0.82, 0.272, 0.618, 0.245]
    trained_mean = [0.862, 0.274, 0.729, 0.529]
    trained_sd = [0.002, 0.002, 0.011, 0.065]
    untrained_sd = [0.01, 0.003, 0.004, 0.004]

    positions = np.arange(len(metrics))
    bar_w = 0.35

    fig, ax = plt.subplots(figsize=(7, 5))
    series = (
        (positions - bar_w/2, untrained_mean, untrained_sd,
         'Untrained GRU', COLOR_LIGHT_BLUE),
        (positions + bar_w/2, trained_mean, trained_sd,
         'Trained GRU', COLOR_DARK_BLUE),
    )
    for centers, heights, errors, name, color in series:
        ax.bar(centers, heights, bar_w, yerr=errors,
               label=name, color=color, capsize=3)

    ax.set_ylabel('Accuracy / Correlation', fontsize=12)
    ax.set_xticks(positions)
    ax.set_xticklabels(metrics, fontsize=10)
    ax.legend(fontsize=10)
    ax.set_ylim(0, 1.0)
    ax.axhline(y=1/25, color='gray', linestyle='--', alpha=0.5)

    per_metric = zip(positions, untrained_mean, trained_mean,
                     untrained_sd, trained_sd)
    for pos, u_mean, t_mean, u_sd, t_sd in per_metric:
        gap = t_mean - u_mean
        if not abs(gap) > 0.05:
            continue
        # Place the label just above the taller bar plus the larger error bar.
        y_top = max(t_mean, u_mean) + max(t_sd, u_sd) + 0.02
        if gap > 0:
            text = f'+{gap:.3f}'
        else:
            text = f'{gap:.3f}'
        ax.annotate(text, xy=(pos, y_top), ha='center', fontsize=9,
                    color=COLOR_RED)

    plt.tight_layout()
    fig.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {save_path}")


def plot_ablation(save_prefix):
    """PCA ablation: LL and revisit rate as separate panels.

    Writes ``{save_prefix}_ll.png`` and ``{save_prefix}_revisit.png``. The
    plotted values are hard-coded in this function. The first condition
    ('None') is the un-ablated baseline, and each panel draws a dashed
    reference line at that baseline value.

    Args:
        save_prefix: Path prefix (without suffix) for the two output images.
    """
    labels = ['None', 'PC1', 'PCs 2\u20134', 'PCs 2\u20138', 'PCs 1\u20134']
    lls = [-1.020, -2.336, -1.188, -1.187, -2.840]
    revisit_rates = [0.717, 0.921, 0.604, 0.607, 0.801]
    colors = [COLOR_BLUE, COLOR_RED, COLOR_BLUE, COLOR_BLUE, COLOR_RED]

    _save_ablation_panel(labels, lls, colors,
                         'Log-likelihood (bits/dec)',
                         f"{save_prefix}_ll.png")
    _save_ablation_panel(labels, revisit_rates, colors,
                         'Revisit rate',
                         f"{save_prefix}_revisit.png",
                         ylim=(0, 1.0))


def _save_ablation_panel(labels, values, colors, ylabel, out, ylim=None):
    """Draw one ablation bar panel with a dashed baseline line and save it.

    Args:
        labels: Tick label per ablation condition.
        values: Bar height per condition; ``values[0]`` is used as the
            baseline for the dashed reference line.
        colors: Bar color per condition.
        ylabel: Y-axis label.
        out: Output image path.
        ylim: Optional ``(low, high)`` y-axis limits; autoscaled when None.
    """
    x = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(5, 4.5))
    ax.bar(x, values, color=colors, width=0.65)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=10)
    # Take the baseline from the data so it cannot drift from the bar values.
    ax.axhline(y=values[0], color='gray', linestyle='--', alpha=0.5)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel('Ablation', fontsize=12)
    plt.tight_layout()
    fig.savefig(out, dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {out}")


def plot_2x2(save_path):
    """2x2 encoding swap table as a grouped bar chart.

    Four conditions (structural/random encoding x untrained/trained) are
    shown on the x axis, each with three bars: |PC1-radial rho|, node probe
    accuracy and visit-bit probe accuracy. The plotted values are hard-coded
    in this function.

    Args:
        save_path: Output image path.
    """
    conditions = ['Structural\nUntrained', 'Structural\nTrained',
                  'Random\nUntrained', 'Random\nTrained']
    pc1_rho = [0.82, 0.862, 0.353, 0.474]
    node_probe = [0.272, 0.274, 0.566, 0.761]
    visit_bit = [0.618, 0.729, 0.645, 0.699]

    positions = np.arange(len(conditions))
    bar_w = 0.25

    fig, ax = plt.subplots(figsize=(7, 5))

    # (bar centres, heights, legend label, colour), drawn left to right.
    groups = (
        (positions - bar_w, pc1_rho, '|PC1\u2013radial \u03c1|', COLOR_BLUE),
        (positions, node_probe, 'Node probe (25-way)', COLOR_RED),
        (positions + bar_w, visit_bit, 'Visit bit probe', COLOR_DARK_BLUE),
    )
    for centers, heights, name, color in groups:
        ax.bar(centers, heights, bar_w, label=name, color=color)

    ax.set_ylabel('Accuracy / Correlation', fontsize=12)
    ax.set_xticks(positions)
    ax.set_xticklabels(conditions, fontsize=10)
    ax.legend(fontsize=9)
    ax.set_ylim(0, 1.0)

    plt.tight_layout()
    fig.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved {save_path}")


# On-disk cache of hidden states and labels; read and written by load_or_compute.
CACHE_PATH = os.path.join(PROJECT_ROOT, "checkpoints", "radial_arm_plot_cache.pt")


def load_or_compute():
    """Return hidden states and per-timestep labels, using the cache if present.

    If ``CACHE_PATH`` exists its contents are returned directly. Otherwise
    trajectories are generated, hidden states are collected from a freshly
    initialised (untrained) GRU policy and from a trained one, and the
    result is written to ``CACHE_PATH``. Both policies are constructed right
    after seeding torch and NumPy with 0.

    Returns:
        Tuple ``(h_untrained, h_trained, nodes, visited, arms)``.
    """
    if os.path.exists(CACHE_PATH):
        print(f"Loading cached data from {CACHE_PATH}...")
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only load cache files produced by this script. The cache holds
        # NumPy arrays, which is presumably why the restricted loader is
        # not used here.
        cache = torch.load(CACHE_PATH, weights_only=False)
        return (cache["h_untrained"], cache["h_trained"],
                cache["nodes"], cache["visited"], cache["arms"])

    print("No cache found, generating data and training...")
    trajs = ram.generate_foraging_trajectories(
        n_trajs=1500, max_steps=300, optimal_prob=0.8, seed=42)
    train_trajs, val_trajs = ram.split_trajectories(trajs)
    obs_enc = ram.structural_obs_encoding()
    train_data = ram.build_obs_dataset(train_trajs, obs_enc)

    torch.manual_seed(0); np.random.seed(0)
    untrained_policy = GRUPolicy(obs_dim=12, hidden_dim=128, n_actions=ram.N_ACTIONS)
    h_untrained, nodes, visited, arms = collect_with_visit_info(
        untrained_policy, val_trajs, obs_enc)

    torch.manual_seed(0); np.random.seed(0)
    policy = GRUPolicy(obs_dim=12, hidden_dim=128, n_actions=ram.N_ACTIONS)
    policy, _ = train_gru_policy(policy, train_data, n_epochs=50,
                                  print_every=50, batch_size=64)
    # Labels come from the same val_trajs as the untrained pass, so only the
    # hidden states are kept from this second pass.
    h_trained, _, _, _ = collect_with_visit_info(policy, val_trajs, obs_enc)

    os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
    torch.save({
        "h_untrained": h_untrained,
        "h_trained": h_trained,
        "nodes": nodes,
        "visited": visited,
        "arms": arms,
    }, CACHE_PATH)
    print(f"Saved cache to {CACHE_PATH}")

    return h_untrained, h_trained, nodes, visited, arms


def main():
    """Load (or compute) the experiment data and write every figure.

    Figures are produced in a fixed order: t-SNE comparison panels, probe
    bar chart, PCA ablation panels, then the 2x2 encoding-swap chart.
    """
    print("Generating radial arm maze figures...", flush=True)

    h_untrained, h_trained, nodes, visited, _arms = load_or_compute()

    # (progress message, zero-argument callable that renders the figure)
    steps = (
        ("\nUntrained vs trained t-SNE panels...",
         lambda: plot_tsne_comparison(h_untrained, h_trained, nodes, visited)),
        ("Probe bar chart...",
         lambda: plot_probes(os.path.join(FIGDIR, "radial_arm_probes.png"))),
        ("PCA ablation...",
         lambda: plot_ablation(os.path.join(FIGDIR, "radial_arm_ablation"))),
        ("2x2 encoding swap...",
         lambda: plot_2x2(os.path.join(FIGDIR, "radial_arm_2x2.png"))),
    )
    for message, render in steps:
        print(message)
        render()

    print("\nDone!")


# Generate all figures only when executed as a script, not on import.
if __name__ == "__main__":
    main()
