"""Training dynamics: track how spatial representation emerges over training epochs."""

import os
import sys
import json
import argparse
import time

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
from src.gru_policy import GRUPolicy
from src.analysis import collect_hidden_states
from src.evaluation import compute_log_likelihood_gru

# --- Problem / model dimensions -------------------------------------------
# Nodes are indexed heap-style (children of n are 2n+1 and 2n+2; see
# build_structural_obs), so 127 states form a complete binary tree.
N_STATES = 127
N_ACTIONS = 3
HIDDEN_DIM = 128
BATCH_SIZE = 64
# Default chunk length (in actions) used by build_obs_dataset_structural.
MAX_SEQ_LEN = 200
# Epochs at which a model snapshot is kept; main() drops entries above the
# maximum epoch of the current run. Epoch 0 is the untrained model.
CHECKPOINT_EPOCHS = [0, 1, 2, 5, 10, 20, 50, 100, 150, 200]
SEEDS = [0, 1, 2, 3, 4]

# --- Reference lines drawn in plot_main_figure ----------------------------
# NOTE(review): RAW_OBS_BASELINE, MLP_LL_BASELINE and UNTRAINED_PC1_RHO are
# hard-coded numbers whose provenance is not visible in this file; presumably
# they come from earlier experiments -- confirm before reusing them.
RAW_OBS_BASELINE = 15.0     # % probe accuracy with raw observation only
CHANCE_BASELINE = 100.0 / N_STATES  # 0.79%
MLP_LL_BASELINE = -1.281    # bits/dec
RANDOM_LL_BASELINE = -np.log2(3)  # -1.585 bits/dec
UNTRAINED_PC1_RHO = 0.865   # PC1-depth rho at epoch 0

# Display names and colours for the class ids 0..4 returned by get_obs_class;
# the colour list is indexed by class id.
OBS_CLASS_LABELS = {0: "Root", 1: "Depth 1", 2: "Depths 2-4", 3: "Depth 5", 4: "Leaves"}
OBS_CLASS_COLORS = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00"]


def build_structural_obs(node, n_states=127):
    """Encode a tree node and its three move destinations as a 12-d vector.

    Nodes are heap-indexed: the children of ``node`` are ``2*node + 1`` and
    ``2*node + 2`` and its parent is ``(node - 1) // 2``. A move that would
    leave the tree (a child index >= ``n_states``) maps back to ``node``
    itself; the parent of the root is the root.

    Layout of the returned tensor:
        [0:3]   own node:      degree/3, is_root, is_leaf
        [3:6]   left child:    is_leaf, is_root, degree/3
        [6:9]   right child:   is_leaf, is_root, degree/3
        [9:12]  parent:        is_leaf, is_root, degree/3

    Note that the destination triples use a different field order than the
    own-node triple.

    Args:
        node: Index of the node to describe.
        n_states: Total number of nodes in the tree.

    Returns:
        A ``torch.float32`` tensor with 12 entries.
    """
    leaf_start = (n_states - 1) // 2

    def describe(n):
        # (is_root, is_leaf, degree): the root has degree 2, leaves 1,
        # every other node 3. The root check takes precedence.
        root_flag = 1.0 if n == 0 else 0.0
        leaf_flag = 1.0 if n >= leaf_start else 0.0
        if n == 0:
            deg = 2.0
        elif n >= leaf_start:
            deg = 1.0
        else:
            deg = 3.0
        return root_flag, leaf_flag, deg

    left_child = 2 * node + 1
    right_child = 2 * node + 2
    destinations = (
        left_child if left_child < n_states else node,
        right_child if right_child < n_states else node,
        (node - 1) // 2 if node > 0 else 0,
    )

    own_root, own_leaf, own_degree = describe(node)
    vec = [own_degree / 3.0, own_root, own_leaf]
    for dest in destinations:
        dest_root, dest_leaf, dest_degree = describe(dest)
        vec += [dest_leaf, dest_root, dest_degree / 3.0]
    return torch.tensor(vec, dtype=torch.float32)


def get_node_depth(node):
    """Return the depth of a heap-indexed binary-tree node (root is 0).

    The depth is the number of parent hops ``n -> (n - 1) // 2`` needed to
    reach index 0. Any index that is not positive yields 0.
    """
    hops = 0
    cursor = node
    while True:
        if not cursor > 0:
            return hops
        cursor = (cursor - 1) // 2
        hops += 1


def get_obs_class(node, n_states=127):
    """Map a heap-indexed node to one of five coarse position classes.

    Classes (ids match ``OBS_CLASS_LABELS``):
        0  node 0
        1  nodes up to 2
        2  nodes up to 30
        3  remaining nodes below the first leaf index
        4  nodes at or above the first leaf index

    Only the leaf boundary depends on ``n_states``; the cut-offs 2 and 30
    are fixed numbers that line up with depths 1 and 2-4 of the default
    127-node tree.
    """
    if node == 0:
        return 0
    if node <= 2:
        return 1
    if node <= 30:
        return 2
    leaf_start = (n_states - 1) // 2
    return 3 if node < leaf_start else 4


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Cut trajectories into chunks of at most ``max_len`` actions.

    Each chunk covers actions ``[start, end)`` and states ``[start, end]``,
    so the state slice is taken one element past the action slice and
    consecutive chunks of one trajectory share their boundary state.
    Trajectories without actions contribute no chunk.

    Args:
        trajs: Iterable of dicts with ``"states"`` and ``"actions"``
            sequences.
        structural_obs: Lookup from state to an observation tensor.
        bc_policy: Lookup from state to a target tensor (the entries are
            stacked with ``torch.stack``).
        max_len: Maximum number of actions per chunk.

    Returns:
        List of dicts with keys ``"obs"`` and ``"targets"`` (stacked
        tensors, one row per chunk state), ``"actions"`` (long tensor) and
        ``"states"`` (list of ints).
    """
    def make_chunk(states, actions):
        return {
            "obs": torch.stack([structural_obs[s] for s in states]),
            "targets": torch.stack([bc_policy[s] for s in states]),
            "actions": torch.tensor([int(a) for a in actions], dtype=torch.long),
            "states": [int(s) for s in states],
        }

    chunks = []
    for traj in trajs:
        traj_states = traj["states"]
        traj_actions = traj["actions"]
        n_actions = len(traj_actions)
        for lo in range(0, n_actions, max_len):
            hi = min(lo + max_len, n_actions)
            # States run one past the actions: include the state reached
            # after the chunk's last action.
            chunks.append(make_chunk(traj_states[lo:hi + 1], traj_actions[lo:hi]))
    return chunks


def _snapshot_state_dict(policy):
    """Return a copy of ``policy.state_dict()`` with every tensor cloned."""
    return {name: tensor.clone() for name, tensor in policy.state_dict().items()}


def train_gru_with_checkpoints(policy, dataset, checkpoint_epochs, n_epochs=200,
                               lr=3e-4, batch_size=64):
    """Train GRU and save state_dict snapshots at specified epochs.

    Optimises a soft-target cross-entropy, ``-sum(target * log_softmax)``,
    averaged over all unpadded positions of a batch. Uses Adam with a
    per-epoch cosine-annealed learning rate and gradient-norm clipping at 1.0.
    Batches are shuffled with NumPy's global RNG, so seed ``np.random``
    beforehand for reproducible runs.

    Args:
        policy: Module called as ``policy(obs)`` on a padded batch-first
            tensor; the first element of its return value is used as logits.
        dataset: Sequence of dicts with ``"obs"`` and ``"targets"`` tensors
            of equal length along the first dimension.
        checkpoint_epochs: Epochs to snapshot. ``0`` means the weights
            before any update. Epochs outside ``0..n_epochs`` are silently
            never produced.
        n_epochs: Number of passes over ``dataset``.
        lr: Initial Adam learning rate.
        batch_size: Number of sequences per batch.

    Returns:
        checkpoints: dict mapping epoch -> state_dict copy
        loss_history: list of per-epoch average CE loss (natural-log units)

    Raises:
        ValueError: If ``dataset`` is empty while ``n_epochs`` is positive
            (previously this surfaced as a ``ZeroDivisionError`` when
            averaging the epoch loss).
    """
    from torch.nn.utils.rnn import pad_sequence

    if n_epochs > 0 and len(dataset) == 0:
        raise ValueError("dataset is empty; cannot train for any epoch")

    # Set membership keeps the per-epoch check O(1) and accepts any iterable.
    wanted_epochs = set(checkpoint_epochs)

    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    checkpoints = {}
    loss_history = []

    if 0 in wanted_epochs:
        checkpoints[0] = _snapshot_state_dict(policy)

    for epoch in range(1, n_epochs + 1):
        policy.train()
        indices = np.random.permutation(len(dataset))
        epoch_loss, n_steps = 0.0, 0

        for i in range(0, len(indices), batch_size):
            batch = [dataset[j] for j in indices[i:i + batch_size]]

            obs = pad_sequence([b["obs"] for b in batch], batch_first=True)
            targets = pad_sequence([b["targets"] for b in batch], batch_first=True)
            lengths = [len(b["obs"]) for b in batch]

            logits, _ = policy(obs)
            log_probs = torch.log_softmax(logits, dim=-1)

            # Sum the loss over the valid prefix of each sequence only, so
            # padded positions never contribute. The accumulation order is
            # kept as-is so seeded runs stay numerically reproducible.
            loss = 0.0
            count = 0
            for row, seq_len in enumerate(lengths):
                ce = -(targets[row, :seq_len] * log_probs[row, :seq_len]).sum(dim=-1)
                loss = loss + ce.sum()
                count += seq_len
            loss = loss / count

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
            optimizer.step()

            # Re-weight by position count so the epoch figure is a
            # per-position mean rather than a mean of batch means.
            epoch_loss += loss.item() * count
            n_steps += count

        scheduler.step()
        loss_history.append(epoch_loss / n_steps)

        if epoch in wanted_epochs:
            checkpoints[epoch] = _snapshot_state_dict(policy)

    return checkpoints, loss_history


def analyze_checkpoint(policy, val_data, train_data, obs_classes_map, depths_map):
    """Evaluate one checkpoint: likelihoods, linear probe, PCA, geometry.

    Steps, all on hidden states collected from ``val_data``:
      * log-likelihood on ``val_data`` and ``train_data``;
      * a logistic-regression probe decoding position from hidden state,
        fitted on the first 80% of the (sorted) unique trajectory ids and
        scored on the rest, overall and per observation class;
      * |Spearman rho| between the first principal component and depth;
      * within/between-class distance ratio.

    If either side of the probe split has fewer than two distinct
    positions, the probe is skipped and all probe accuracies are 0.0. A
    class with no test samples gets ``np.nan``.

    Args:
        policy: Model to evaluate; switched to eval mode here.
        val_data: Dataset used for hidden states, probe and validation LL.
        train_data: Dataset used for the training LL only.
        obs_classes_map: Lookup from position to observation-class id.
        depths_map: Lookup from position to depth.

    Returns:
        Dict with keys ``probe_acc``, ``pc1_rho``, ``per_class_acc``,
        ``train_ll``, ``val_ll``, ``distance_ratio`` plus the raw arrays
        ``hidden_states``, ``positions``, ``obs_classes``, ``depths`` and
        ``hidden_pca``.
    """
    policy.eval()

    val_ll = compute_log_likelihood_gru(policy, val_data)
    train_ll = compute_log_likelihood_gru(policy, train_data)

    hidden_states, positions, _timesteps, traj_ids = collect_hidden_states(policy, val_data)
    depths = np.array([depths_map[p] for p in positions])
    obs_classes = np.array([obs_classes_map[p] for p in positions])

    # Split by trajectory id so a trajectory is never on both sides.
    all_tids = np.unique(traj_ids)
    fit_tids = set(all_tids[:int(0.8 * len(all_tids))])
    fit_mask = np.array([tid in fit_tids for tid in traj_ids])
    held_mask = ~fit_mask

    y_fit = positions[fit_mask]
    y_held = positions[held_mask]
    probe_feasible = len(np.unique(y_fit)) >= 2 and len(np.unique(y_held)) >= 2

    if not probe_feasible:
        probe_acc = 0.0
        per_class_acc = {c: 0.0 for c in range(5)}
    else:
        h_fit = hidden_states[fit_mask]
        h_held = hidden_states[held_mask]
        held_classes = obs_classes[held_mask]

        clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
        clf.fit(h_fit, y_fit)
        probe_acc = clf.score(h_held, y_held)

        per_class_acc = {}
        for cls_id in range(5):
            members = held_classes == cls_id
            if members.sum() > 0:
                per_class_acc[cls_id] = clf.score(h_held[members], y_held[members])
            else:
                per_class_acc[cls_id] = np.nan

    # PCA cannot have more components than samples or features.
    n_rows, n_dims = hidden_states.shape[0], hidden_states.shape[1]
    pca = PCA(n_components=min(10, n_dims, n_rows))
    hidden_pca = pca.fit_transform(hidden_states)
    # The sign of a principal component is arbitrary, hence the abs().
    rho, _ = spearmanr(hidden_pca[:, 0], depths)
    pc1_rho = abs(rho)

    distance_ratio = compute_distance_ratio(hidden_states, obs_classes)

    return {
        "probe_acc": probe_acc,
        "pc1_rho": pc1_rho,
        "per_class_acc": per_class_acc,
        "train_ll": train_ll,
        "val_ll": val_ll,
        "distance_ratio": distance_ratio,
        "hidden_states": hidden_states,
        "positions": positions,
        "obs_classes": obs_classes,
        "depths": depths,
        "hidden_pca": hidden_pca,
    }


def compute_distance_ratio(hidden_states, obs_classes, *, max_samples=2000,
                           n_pairs=5000, seed=42):
    """Compute within-class / between-class mean Euclidean distance ratio.

    Low ratio = good cluster separation (within-class points closer than
    between-class points). For efficiency the points are subsampled and the
    two means are estimated from randomly drawn pairs. With the default
    keyword values the random stream, and therefore the result, is the same
    as in the earlier hard-coded version.

    Args:
        hidden_states: Array with one row per sample.
        obs_classes: Array of class labels aligned with ``hidden_states``.
        max_samples: Upper bound on the number of rows kept (sampled
            without replacement).
        n_pairs: Number of random pairs of distinct rows to draw.
        seed: Seed of the private ``RandomState``; the global NumPy RNG is
            not touched.

    Returns:
        The ratio, or ``np.nan`` when it is undefined: fewer than two
        samples, or no sampled pair of one of the two kinds.
    """
    rng = np.random.RandomState(seed)
    n_total = len(hidden_states)
    n_samples = min(max_samples, n_total)
    if n_samples < 2:
        # A pair of distinct points cannot be drawn; rng.choice would raise.
        return np.nan

    idx = rng.choice(n_total, n_samples, replace=False)
    h_sub = hidden_states[idx]
    c_sub = obs_classes[idx]

    within_dists = []
    between_dists = []

    # Pairs are drawn one at a time (rather than vectorised) to keep the
    # random stream, and hence the reported numbers, stable across versions.
    for _ in range(n_pairs):
        i, j = rng.choice(n_samples, 2, replace=False)
        dist = np.linalg.norm(h_sub[i] - h_sub[j])
        if c_sub[i] == c_sub[j]:
            within_dists.append(dist)
        else:
            between_dists.append(dist)

    if not within_dists or not between_dists:
        return np.nan

    return np.mean(within_dists) / np.mean(between_dists)


def plot_main_figure(results, epochs, seeds, save_dir):
    """2x2 panel figure: probe acc, PC1-rho, per-class acc, LL.

    Panels (a), (b) and (d) show the mean over seeds with a +/- one standard
    deviation band (NumPy's default population std). Panel (c) shows the
    per-class mean only; NaN entries, which mark a class that had no test
    samples for a seed, are ignored instead of blanking the point for all
    seeds. A point where every seed is NaN stays NaN (NumPy emits a
    ``RuntimeWarning`` for it). All x axes are logarithmic, so epoch 0 is
    drawn at x = 0.5.

    Args:
        results: Dict keyed by seed; each value holds per-epoch lists under
            ``"probe_acc"``, ``"pc1_rho"``, ``"train_ll"``, ``"val_ll"`` and
            a dict ``"per_class_acc"`` keyed by class id 0..4.
        epochs: Checkpoint epochs, aligned with the per-epoch lists.
        seeds: Seeds to aggregate over.
        save_dir: Existing directory receiving
            ``training_dynamics_main.pdf`` and ``.png``.
    """
    def stack(metric):
        # (n_seeds, n_epochs) array for one scalar metric.
        return np.array([results[s][metric] for s in seeds])

    def finish_axis(ax, ylabel, title, **legend_kwargs):
        # Styling shared by all four panels.
        ax.set_xscale("log")
        ax.set_xlabel("Training Epoch", fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)
        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.legend(**legend_kwargs)
        ax.grid(True, alpha=0.3)

    fig, axes = plt.subplots(2, 2, figsize=(10, 8))

    probe_accs = stack("probe_acc")
    pc1_rhos = stack("pc1_rho")
    train_lls = stack("train_ll")
    val_lls = stack("val_ll")

    mean_probe = probe_accs.mean(axis=0) * 100
    std_probe = probe_accs.std(axis=0) * 100
    mean_pc1 = pc1_rhos.mean(axis=0)
    std_pc1 = pc1_rhos.std(axis=0)
    mean_train_ll = train_lls.mean(axis=0)
    std_train_ll = train_lls.std(axis=0)
    mean_val_ll = val_lls.mean(axis=0)
    std_val_ll = val_lls.std(axis=0)

    x = np.array(epochs)
    x_plot = np.where(x == 0, 0.5, x)  # shift epoch 0 slightly for log scale

    # (a) overall probe accuracy
    ax = axes[0, 0]
    ax.plot(x_plot, mean_probe, "o-", color="#2c7bb6", linewidth=2, markersize=5)
    ax.fill_between(x_plot, mean_probe - std_probe, mean_probe + std_probe,
                    alpha=0.2, color="#2c7bb6")
    ax.axhline(y=RAW_OBS_BASELINE, color="gray", linestyle="--", linewidth=1,
               label=f"Obs. baseline ({RAW_OBS_BASELINE:.0f}%)")
    ax.axhline(y=CHANCE_BASELINE, color="lightgray", linestyle=":", linewidth=1,
               label=f"Chance ({CHANCE_BASELINE:.1f}%)")
    finish_axis(ax, "Node Decoding Accuracy (%)", "(a) Linear Probe Accuracy",
                fontsize=8, loc="lower right")

    # (b) PC1 vs depth correlation
    ax = axes[0, 1]
    ax.plot(x_plot, mean_pc1, "s-", color="#d7191c", linewidth=2, markersize=5)
    ax.fill_between(x_plot, mean_pc1 - std_pc1, mean_pc1 + std_pc1,
                    alpha=0.2, color="#d7191c")
    ax.axhline(y=UNTRAINED_PC1_RHO, color="gray", linestyle="--", linewidth=1,
               label=f"Untrained baseline ({UNTRAINED_PC1_RHO:.3f})")
    finish_axis(ax, "|Spearman $\\rho$| (PC1 vs Depth)", "(b) PC1-Depth Correlation",
                fontsize=8, loc="lower right")
    lo = max(0, (mean_pc1 - std_pc1).min() - 0.05)
    ax.set_ylim(lo, 1.0)

    # (c) per-class probe accuracy
    ax = axes[1, 0]
    for cls_id in range(5):
        per_class = np.array([results[s]["per_class_acc"][cls_id] for s in seeds])
        # nanmean: NaN is a "no test samples" placeholder, not a measurement.
        mean_cls = np.nanmean(per_class, axis=0) * 100
        ax.plot(x_plot, mean_cls, "o-", color=OBS_CLASS_COLORS[cls_id],
                linewidth=1.5, markersize=4, label=OBS_CLASS_LABELS[cls_id])
    finish_axis(ax, "Node Decoding Accuracy (%)", "(c) Per-Class Probe Accuracy",
                fontsize=7, loc="upper left", ncol=2)

    # (d) train / validation log-likelihood
    ax = axes[1, 1]
    ax.plot(x_plot, mean_train_ll, "o-", color="#2166ac", linewidth=2, markersize=5,
            label="Train LL")
    ax.fill_between(x_plot, mean_train_ll - std_train_ll, mean_train_ll + std_train_ll,
                    alpha=0.15, color="#2166ac")
    ax.plot(x_plot, mean_val_ll, "s-", color="#b2182b", linewidth=2, markersize=5,
            label="Val LL")
    ax.fill_between(x_plot, mean_val_ll - std_val_ll, mean_val_ll + std_val_ll,
                    alpha=0.15, color="#b2182b")
    ax.axhline(y=MLP_LL_BASELINE, color="gray", linestyle="--", linewidth=1,
               label=f"MLP baseline ({MLP_LL_BASELINE:.3f})")
    ax.axhline(y=RANDOM_LL_BASELINE, color="lightgray", linestyle=":", linewidth=1,
               label=f"Random ({RANDOM_LL_BASELINE:.3f})")
    finish_axis(ax, "Log-Likelihood (bits/dec)", "(d) Learning Curves",
                fontsize=7, loc="lower right")

    plt.tight_layout()
    for ext in ["pdf", "png"]:
        path = os.path.join(save_dir, f"training_dynamics_main.{ext}")
        fig.savefig(path, dpi=300, bbox_inches="tight")
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    print("Saved training_dynamics_main.pdf/png", flush=True)


def plot_geometry_figure(results, epochs, seeds, save_dir):
    """Scatter PC1 vs PC2 at up to three epochs, coloured by obs class.

    Epochs 1, 50 and 200 are used when all three are checkpointed.
    Otherwise the first, middle and last checkpoint with epoch > 0 are
    taken, or all of them if fewer than three exist; with none, the figure
    is skipped. Only the first seed in ``seeds`` is plotted.

    Args:
        results: Dict keyed by seed holding per-checkpoint lists
            ``"hidden_pca_snapshots"`` and ``"obs_classes_snapshots"``.
        epochs: List of checkpoint epochs aligned with those lists.
        seeds: Seeds that were run; ``seeds[0]`` is used.
        save_dir: Existing directory receiving
            ``training_dynamics_geometry.pdf`` and ``.png``.
    """
    snapshot_epochs = [e for e in (1, 50, 200) if e in epochs]
    if len(snapshot_epochs) < 3:
        trained = [e for e in epochs if e > 0]
        if not trained:
            print("Skipping geometry figure (no trained checkpoints)", flush=True)
            return
        if len(trained) >= 3:
            snapshot_epochs = [trained[0], trained[len(trained) // 2], trained[-1]]
        else:
            snapshot_epochs = trained

    n_panels = len(snapshot_epochs)
    # squeeze=False always yields a 2-D axes array, also for a single panel.
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(4.7 * n_panels, 4.5),
                                  squeeze=False)
    panel_axes = axes_grid[0]
    run = results[seeds[0]]

    for panel, epoch in enumerate(snapshot_epochs):
        ax = panel_axes[panel]
        slot = epochs.index(epoch)
        coords = run["hidden_pca_snapshots"][slot]
        classes = run["obs_classes_snapshots"][slot]

        for cls_id in range(5):
            members = classes == cls_id
            ax.scatter(coords[members, 0], coords[members, 1],
                       c=OBS_CLASS_COLORS[cls_id], s=4, alpha=0.5,
                       label=OBS_CLASS_LABELS[cls_id], edgecolors="none")

        ax.set_title(f"Epoch {epoch}", fontsize=11, fontweight="bold")
        ax.set_xlabel("PC1", fontsize=9)
        ax.set_ylabel("PC2", fontsize=9)
        ax.tick_params(labelsize=8)
        # One legend is enough; put it on the right-most panel.
        if panel == n_panels - 1:
            ax.legend(fontsize=7, markerscale=3, loc="best")

    plt.tight_layout()
    for ext in ("pdf", "png"):
        fig.savefig(os.path.join(save_dir, f"training_dynamics_geometry.{ext}"),
                    dpi=300, bbox_inches="tight")
    plt.close()
    print("Saved training_dynamics_geometry.pdf/png", flush=True)


def main():
    """Run the full experiment: train, analyse checkpoints, save, plot.

    For every seed a fresh GRU policy is trained with snapshots at the
    configured checkpoint epochs, and each snapshot is analysed with
    ``analyze_checkpoint``. Scalar metrics are written to
    ``checkpoints/training_dynamics_results.json`` and the figures to
    ``figures/`` (both relative to the current working directory).

    ``--quick`` restricts the run to seed 0 and at most 50 epochs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true",
                        help="Quick test: 1 seed, max 50 epochs")
    args = parser.parse_args()

    if args.quick:
        seeds, max_epoch = [0], 50
    else:
        seeds, max_epoch = SEEDS, 200
    checkpoint_epochs = [e for e in CHECKPOINT_EPOCHS if e <= max_epoch]

    print("Training Dynamics Analysis \n", flush=True)
    print(f"Seeds: {seeds}", flush=True)
    print(f"Checkpoint epochs: {checkpoint_epochs}", flush=True)
    print(f"Max epoch: {max_epoch}\n", flush=True)

    t0 = time.time()

    # ---- data ------------------------------------------------------------
    print("Loading Rosenberg data...", flush=True)
    d = load_rosenberg_everything()
    bc_policy = build_bc_targets(d["train_sa"], n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)
    structural_obs = {s: build_structural_obs(s) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])

    train_data = build_obs_dataset_structural(d["train_trajs"], structural_obs, bc_policy)
    val_data = build_obs_dataset_structural(d["val_trajs"], structural_obs, bc_policy)

    depths_map = {s: get_node_depth(s) for s in range(N_STATES)}
    obs_classes_map = {s: get_obs_class(s) for s in range(N_STATES)}

    print(f"Train: {len(train_data)} chunks, Val: {len(val_data)} chunks", flush=True)
    print(f"Obs dim: {obs_dim}, Hidden dim: {HIDDEN_DIM}\n", flush=True)

    for out_dir in ("checkpoints", "figures"):
        os.makedirs(out_dir, exist_ok=True)

    # Metrics copied one-to-one from analyze_checkpoint's result dict.
    scalar_metrics = ("probe_acc", "pc1_rho", "train_ll", "val_ll", "distance_ratio")

    def new_policy():
        return GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)

    # ---- train + analyse, one seed at a time -------------------------------
    all_results = {}
    for seed in seeds:
        print(f"Seed {seed} ", flush=True)
        torch.manual_seed(seed)
        np.random.seed(seed)

        policy = new_policy()

        print(f"Training GRU ({max_epoch} epochs, saving {len(checkpoint_epochs)} checkpoints)...",
              flush=True)
        ckpts, _loss_history = train_gru_with_checkpoints(
            policy, train_data, checkpoint_epochs,
            n_epochs=max_epoch, lr=3e-4, batch_size=BATCH_SIZE,
        )

        seed_results = {
            "probe_acc": [],
            "pc1_rho": [],
            "per_class_acc": {cls_id: [] for cls_id in range(5)},
            "train_ll": [],
            "val_ll": [],
            "distance_ratio": [],
            "hidden_pca_snapshots": [],
            "obs_classes_snapshots": [],
        }

        for ep in checkpoint_epochs:
            print(f"Analyzing epoch {ep}...", flush=True)
            # Load the snapshot into a fresh model; the trained one is left alone.
            eval_policy = new_policy()
            eval_policy.load_state_dict(ckpts[ep])

            metrics = analyze_checkpoint(eval_policy, val_data, train_data,
                                         obs_classes_map, depths_map)

            for key in scalar_metrics:
                seed_results[key].append(metrics[key])
            for cls_id, series in seed_results["per_class_acc"].items():
                series.append(metrics["per_class_acc"].get(cls_id, np.nan))

            # Only the first two principal components are needed for plotting.
            seed_results["hidden_pca_snapshots"].append(metrics["hidden_pca"][:, :2])
            seed_results["obs_classes_snapshots"].append(metrics["obs_classes"])

            print(f"probe={metrics['probe_acc']*100:.1f}%  "
                  f"|PC1-depth|={metrics['pc1_rho']:.3f}  "
                  f"val_ll={metrics['val_ll']:.4f}  "
                  f"dist_ratio={metrics['distance_ratio']:.3f}", flush=True)

        all_results[seed] = seed_results

        elapsed = time.time() - t0
        print(f"Seed {seed} done ({elapsed/60:.1f} min elapsed)\n", flush=True)

    # ---- numerical results -> JSON -----------------------------------------
    save_data = {
        "epochs": checkpoint_epochs,
        "seeds": seeds,
        "probe_accuracy": {},
        "pc1_depth_rho": {},
        "per_class_accuracy": {},
        "train_ll": {},
        "val_ll": {},
        "distance_ratio": {},
    }
    json_to_result_key = (
        ("probe_accuracy", "probe_acc"),
        ("pc1_depth_rho", "pc1_rho"),
        ("train_ll", "train_ll"),
        ("val_ll", "val_ll"),
        ("distance_ratio", "distance_ratio"),
    )
    class_names = ("root", "depth1", "depths2-4", "depth5", "leaves")
    for seed in seeds:
        sk = f"seed_{seed}"
        per_seed = all_results[seed]
        for json_key, result_key in json_to_result_key:
            save_data[json_key][sk] = per_seed[result_key]
        save_data["per_class_accuracy"][sk] = {
            name: per_seed["per_class_acc"][cls_id]
            for cls_id, name in enumerate(class_names)
        }

    def to_native(obj):
        # Recursively turn NumPy scalars into plain Python numbers so that
        # json.dump accepts them; anything else is passed through untouched.
        if isinstance(obj, dict):
            return {k: to_native(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [to_native(v) for v in obj]
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        return obj

    save_data = to_native(save_data)

    json_path = "checkpoints/training_dynamics_results.json"
    with open(json_path, "w") as f:
        json.dump(save_data, f, indent=2)
    print(f"Saved numerical results to {json_path}", flush=True)

    # ---- figures + console summary -----------------------------------------
    plot_main_figure(all_results, checkpoint_epochs, seeds, "figures")
    plot_geometry_figure(all_results, checkpoint_epochs, seeds, "figures")

    total_time = time.time() - t0
    probe_accs = np.array([all_results[s]["probe_acc"] for s in seeds])
    pc1_rhos = np.array([all_results[s]["pc1_rho"] for s in seeds])

    for i, ep in enumerate(checkpoint_epochs):
        probe_col = probe_accs[:, i]
        rho_col = pc1_rhos[:, i]
        pa_mean = probe_col.mean() * 100
        pa_std = probe_col.std() * 100
        pc_mean = rho_col.mean()
        pc_std = rho_col.std()
        print(f"epoch {ep}: probe {pa_mean:.1f} +/- {pa_std:.1f}%, |PC1-depth| {pc_mean:.3f} +/- {pc_std:.3f}",
              flush=True)

    print(f"done in {total_time/60:.1f} min", flush=True)


if __name__ == "__main__":
    main()
