"""Passive T-Maze validation: IRL→GRU pipeline on a minimal memory task."""

import os
import sys
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put <root>/src at the front of the import path. This has to run before the
# project-local imports below (presumably maxent_irl, gru_policy and tmaze_env
# live there -- verify against the repository layout).
sys.path.insert(0, os.path.join(PROJECT_ROOT, "src"))
# Output directories for saved results and figures; created at import time
# if they do not exist yet.
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
FIGURE_DIR = os.path.join(PROJECT_ROOT, "figures")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(FIGURE_DIR, exist_ok=True)

from maxent_irl import MaxCausalEntIRL, train_irl, SoftValueIteration
from gru_policy import GRUPolicy, train_gru_policy
from tmaze_env import (
    N_ACTIONS,
    n_states,
    encode_state,
    decode_state,
    build_tmaze_transition_tensor,
    generate_expert_demos,
    trajectories_to_sa_pairs,
    trajectories_to_junction_sa_pairs,
    tmaze_trajectories_to_obs_dataset,
    compute_optimal_soft_policy,
)



def adaptive_irl_params(corridor_length):
    """Derive the IRL discount factor and soft-VI sweep count for a corridor.

    Returns a ``(gamma, n_iters)`` tuple. ``gamma`` is
    ``1 - 1 / (corridor_length + 10)``, so it approaches 1 as the corridor
    grows. ``n_iters`` is three sweeps per corridor cell, clamped to the
    range ``[200, 5000]``.
    """
    horizon = corridor_length + 10
    discount = 1.0 - 1.0 / horizon

    sweeps = corridor_length * 3
    if sweeps < 200:
        # Floor: always run at least 200 sweeps.
        sweeps = 200
    elif sweeps > 5000:
        # Ceiling: cap the cost for very long corridors.
        sweeps = 5000
    return discount, sweeps


def run_irl(corridor_length, trajs, n_epochs=300, lr=0.01, print_every=1):
    """Fit a MaxCausalEnt IRL reward model to T-maze demonstrations.

    Fitting uses junction (state, action) pairs only: corridor steps carry
    no information because every action there produces the same transition.

    Args:
        corridor_length: corridor size; also drives gamma and the soft-VI
            iteration count via ``adaptive_irl_params``.
        trajs: expert trajectories passed to
            ``trajectories_to_junction_sa_pairs``.
        n_epochs, lr, print_every: forwarded unchanged to ``train_irl``.

    Returns:
        ``(model, Q, pi, history)``: the trained model, the Q and policy
        outputs of a final gradient-free ``model.soft_vi`` call on the
        learned reward parameters, and the value returned by ``train_irl``.
    """
    discount, vi_sweeps = adaptive_irl_params(corridor_length)
    state_count = n_states(corridor_length)
    transitions = build_tmaze_transition_tensor(corridor_length)
    junction_pairs = trajectories_to_junction_sa_pairs(trajs, corridor_length)

    model = MaxCausalEntIRL(
        state_count,
        transitions,
        gamma=discount,
        n_vi_iters=vi_sweeps,
        l2_reg=0.01,
    )
    history = train_irl(
        model,
        junction_pairs,
        n_epochs=n_epochs,
        lr=lr,
        print_every=print_every,
    )

    # One last soft value-iteration pass on the learned rewards; the value
    # function itself is not needed by callers.
    with torch.no_grad():
        _, q_values, soft_policy = model.soft_vi(model.reward_params)

    return model, q_values, soft_policy, history


def train_tmaze_gru(Q_soft, trajs, corridor_length, hidden_dim=16,
                    n_epochs=200, lr=3e-4, batch_size=64, seed=42,
                    print_every=10, curriculum=True):
    """Train a GRU on T-maze with junction-only loss.

    The full observation sequence is fed through the GRU (so it can build up
    hidden state encoding the cue), but the soft-target cross-entropy is
    computed at the junction timestep only.

    When ``curriculum`` is true and ``corridor_length > 20``, training starts
    with a 5-step prefix for 30 warm-up epochs and then ramps linearly to the
    full corridor over the following ``0.3 * n_epochs`` epochs. A shorter
    prefix is produced by dropping the corridor steps right after the cue
    observation, so the junction always sits at index ``prefix_len`` of the
    truncated sequence.

    Args:
        Q_soft: soft Q-values handed to ``tmaze_trajectories_to_obs_dataset``
            (presumably used to derive the per-step ``"targets"``).
        trajs: expert trajectories to convert into the observation dataset.
        corridor_length: corridor size; the junction is at this timestep.
        hidden_dim: GRU hidden size.
        n_epochs: number of passes over the training split.
        lr: Adam learning rate.
        batch_size: number of episodes per optimisation step.
        seed: seed applied to both torch and numpy global RNGs.
        print_every: log every this many epochs; ``0`` disables logging.
        curriculum: enable the prefix-length curriculum described above.

    Returns:
        ``(policy, history, train_data, val_data)``. ``history["loss"]``
        holds the mean junction cross-entropy of each epoch; the data
        splits are the first 80% / last 20% of the dataset, in order.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset = tmaze_trajectories_to_obs_dataset(trajs, Q_soft, corridor_length)
    n_train = int(0.8 * len(dataset))
    train_data = dataset[:n_train]
    val_data = dataset[n_train:]

    policy = GRUPolicy(obs_dim=3, hidden_dim=hidden_dim,
                       n_actions=N_ACTIONS, dropout=0.1)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    history = {"loss": []}

    # Curriculum settings are constant for the whole run.
    use_curriculum = curriculum and corridor_length > 20
    min_prefix = 5
    warmup = 30

    for epoch in range(n_epochs):
        policy.train()
        indices = np.random.permutation(len(train_data))
        epoch_loss, n_samples = 0.0, 0

        if not use_curriculum:
            prefix_len = corridor_length
        elif epoch < warmup:
            prefix_len = min_prefix
        else:
            ramp_progress = min((epoch - warmup) / (n_epochs * 0.3), 1.0)
            prefix_len = int(
                min_prefix + ramp_progress * (corridor_length - min_prefix))

        # Number of corridor steps dropped after the cue; fixed per epoch.
        skip = corridor_length - prefix_len
        # After truncation the junction lands at index prefix_len.
        junc_idx = prefix_len

        for start in range(0, len(indices), batch_size):
            batch = [train_data[j] for j in indices[start:start + batch_size]]

            if skip > 0:
                obs_seqs = [
                    torch.cat([
                        b["obs"][0:1],           # cue observation
                        b["obs"][skip + 1:],     # corridor tail + junction + terminal
                    ], dim=0)
                    for b in batch
                ]
                target_seqs = [
                    torch.cat([
                        b["targets"][0:1],
                        b["targets"][skip + 1:],
                    ], dim=0)
                    for b in batch
                ]
            else:
                obs_seqs = [b["obs"] for b in batch]
                target_seqs = [b["targets"] for b in batch]

            obs = torch.nn.utils.rnn.pad_sequence(obs_seqs, batch_first=True)
            targets = torch.nn.utils.rnn.pad_sequence(target_seqs,
                                                      batch_first=True)

            logits, _ = policy(obs)

            if junc_idx >= logits.shape[1]:
                # No sequence in this batch reaches the junction. There is
                # nothing to learn from, and calling backward() on a plain
                # float zero would raise, so skip the optimisation step.
                continue

            # Soft-target cross-entropy at the junction for the whole batch.
            log_probs = torch.log_softmax(logits[:, junc_idx], dim=-1)
            per_sample_ce = -(targets[:, junc_idx] * log_probs).sum(dim=-1)
            count = per_sample_ce.shape[0]
            loss = per_sample_ce.mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=5.0)
            optimizer.step()

            epoch_loss += loss.item() * count
            n_samples += count

        avg_loss = epoch_loss / max(n_samples, 1)
        history["loss"].append(avg_loss)

        # Guard against print_every == 0, which would otherwise divide by zero.
        if print_every and epoch % print_every == 0:
            extra = f" (prefix={prefix_len})" if use_curriculum else ""
            print(f"Epoch {epoch}: junction CE = {avg_loss:.4f}{extra}")

    return policy, history, train_data, val_data


def evaluate_junction_accuracy(policy, dataset, corridor_length):
    """Return the fraction of episodes whose junction action is predicted.

    The junction is the timestep with 0-based index ``corridor_length``.
    Episodes whose ``"actions"`` sequence is too short to contain that
    timestep are ignored. The prediction is the argmax of the policy logits
    at the junction. Returns ``0.0`` when no episode could be scored.
    """
    policy.eval()
    t_junc = corridor_length
    hits = 0
    scored = 0

    with torch.no_grad():
        for episode in dataset:
            actions = episode["actions"]
            if len(actions) <= t_junc:
                # Episode ends before the junction; nothing to score.
                continue
            step_logits, _ = policy(episode["obs"].unsqueeze(0))
            predicted = step_logits[0, t_junc].argmax().item()
            expected = actions[t_junc].item()
            hits += int(predicted == expected)
            scored += 1

    if scored == 0:
        return 0.0
    return hits / scored


def collect_junction_hidden_states(policy, dataset, corridor_length):
    """Gather the GRU recurrent output at the junction for every episode.

    Episodes whose ``"states"`` sequence does not reach timestep
    ``corridor_length`` are skipped.

    Returns:
        hidden_states: array stacking one junction vector per kept episode
            (``(N, hidden_dim)`` when at least one episode is kept)
        cues: ``(N,)`` array with the cue decoded from each junction state
    """
    policy.eval()
    t_junc = corridor_length
    junction_vectors = []
    cue_labels = []

    with torch.no_grad():
        for episode in dataset:
            states = episode["states"]
            if len(states) <= t_junc:
                continue
            # Recurrent outputs for a batch of one: (1, T, hidden_dim).
            recurrent = policy.get_recurrent_output(episode["obs"].unsqueeze(0))
            vector = recurrent[0, t_junc].numpy().copy()
            _, cue = decode_state(states[t_junc])
            junction_vectors.append(vector)
            cue_labels.append(cue)

    return np.array(junction_vectors), np.array(cue_labels)



def run_main_experiment(corridor_length=100, n_episodes=1000):
    """Run the IRL -> GRU pipeline once and print diagnostics along the way.

    Stages, in order: generate expert demonstrations, fit MaxCausalEnt IRL,
    report recovered rewards and the junction soft policy, train the GRU on
    the soft Q-values, score junction accuracy on both splits, and project
    the junction hidden states with PCA to measure cue separation.

    Returns:
        dict with the corridor length, IRL and GRU histories, recovered
        rewards, detached soft Q-values and policy, train/val junction
        accuracy, PCA projection with cue labels and explained variance,
        the separation ratio, and the wall-clock time of both training
        stages.
    """
    cl = corridor_length

    demos = generate_expert_demos(cl, n_episodes=n_episodes)
    cue_labels = [demo["cue"] for demo in demos]
    n_cue0 = sum(1 for c in cue_labels if c == 0)
    n_cue1 = sum(1 for c in cue_labels if c == 1)
    print(f"{n_episodes} episodes, cue=0: {n_cue0}, cue=1: {n_cue1}")
    print(f"Augmented state space: {n_states(cl)} states")
    print(f"Trajectory length: {len(demos[0]['actions'])} steps\n")

    # --- Stage 1: inverse RL -------------------------------------------
    print("Running MaxCausalEnt IRL...")
    gamma, n_iters = adaptive_irl_params(cl)
    print(f"gamma={gamma:.5f}, n_vi_iters={n_iters}")
    irl_start = time.time()
    irl_model, Q_soft, pi_soft, irl_history = run_irl(cl, demos)
    irl_time = time.time() - irl_start
    print(f"IRL completed in {irl_time:.1f}s")

    rewards = irl_model.reward_params.detach().numpy()
    print("\nRecovered rewards at key states:")
    for cue in (0, 1):
        # Positions: start, junction, left terminal, right terminal.
        r_start, r_junc, r_left, r_right = (
            rewards[encode_state(pos, cue)]
            for pos in (0, cl, cl + 1, cl + 2)
        )
        print(f"cue={cue}: start={r_start:.3f}, junction={r_junc:.3f}, "
              f"left_term={r_left:.3f}, right_term={r_right:.3f}")

    print("\nSoft policy at junction (should favor correct terminal):")
    for cue, correct in ((0, "left"), (1, "right")):
        probs = pi_soft[encode_state(cl, cue)].detach().numpy()
        print(f"cue={cue} (correct={correct}): "
              f"fwd={probs[0]:.3f}, left={probs[1]:.3f}, right={probs[2]:.3f}")

    # --- Stage 2: GRU distillation -------------------------------------
    gru_start = time.time()
    policy, gru_history, train_data, val_data = train_tmaze_gru(
        Q_soft, demos, cl, hidden_dim=32, n_epochs=200, lr=1e-3,
    )
    gru_time = time.time() - gru_start
    print(f"GRU training completed in {gru_time:.1f}s")

    train_junc_acc = evaluate_junction_accuracy(policy, train_data, cl)
    val_junc_acc = evaluate_junction_accuracy(policy, val_data, cl)
    print(f"Junction accuracy — train: {train_junc_acc:.1%}, val: {val_junc_acc:.1%}")

    # --- Stage 3: hidden-state analysis --------------------------------
    hidden_states, cues = collect_junction_hidden_states(
        policy, train_data + val_data, cl)
    print(f"Collected {len(hidden_states)} junction hidden states")

    pca = PCA(n_components=min(2, hidden_states.shape[1]))
    hidden_pca = pca.fit_transform(hidden_states)

    is_cue0 = cues == 0
    is_cue1 = cues == 1
    centroid_gap = (hidden_pca[is_cue0].mean(axis=0)
                    - hidden_pca[is_cue1].mean(axis=0))
    centroid_dist = np.linalg.norm(centroid_gap)
    # Within-class spread is measured along PC1 only.
    pc1_spread = np.std(hidden_pca[is_cue0, 0]) + np.std(hidden_pca[is_cue1, 0])
    separation = centroid_dist / (pc1_spread + 1e-8)
    print(f"PC1 variance explained: {pca.explained_variance_ratio_[0]:.1%}")
    print(f"Centroid distance: {centroid_dist:.3f}")
    print(f"Separation ratio (centroid_dist / sum_of_spreads): {separation:.2f}")

    summary = {
        "corridor_length": corridor_length,
        "irl_history": irl_history,
        "gru_history": gru_history,
        "rewards": rewards,
        "Q_soft": Q_soft.detach(),
        "pi_soft": pi_soft.detach(),
        "val_junction_acc": val_junc_acc,
        "train_junction_acc": train_junc_acc,
        "hidden_pca": hidden_pca,
        "cues": cues,
        "pca_explained_variance": pca.explained_variance_ratio_,
        "separation_ratio": separation,
        "irl_time": irl_time,
        "gru_time": gru_time,
    }
    return summary



def run_scaling_experiment(corridor_lengths=(10, 50, 100, 500, 1000, 1500),
                           n_episodes=1000, n_seeds=3):
    """Run the pipeline across multiple corridor lengths with multiple seeds.

    Uses the analytical soft policy (known rewards + single soft VI pass)
    instead of full IRL to keep runtime tractable for long corridors.  The
    scaling experiment tests GRU memory capacity, not IRL convergence.

    Args:
        corridor_lengths: corridor sizes to evaluate.
        n_episodes: expert episodes generated per (length, seed) pair.
        n_seeds: number of independent GRU training runs per length.

    Returns:
        dict mapping each corridor length to
        ``{"accs": [...], "mean": ..., "std": ...}`` of validation junction
        accuracy over the seeds.

    Raises:
        AssertionError: if the analytical soft policy does not prefer the
            correct turn at the junction for some cue.
    """
    print("(Using analytical soft policy — bypasses IRL for speed)\n")

    n_epochs = 500  # identical for every length and seed
    results = {}
    for cl in corridor_lengths:
        print(f"\ncorridor_length={cl} ")
        seed_accs = []

        print(f"Computing analytical soft policy (n_states={n_states(cl)})...",
              end=" ", flush=True)
        t0 = time.time()
        Q_opt, pi_opt = compute_optimal_soft_policy(cl)
        print(f"done ({time.time() - t0:.1f}s)")

        # Sanity-check the teacher before spending time on training. This
        # raises explicitly rather than using `assert`, so the check is not
        # stripped when Python runs with -O.
        for cue in range(2):
            s_junc = encode_state(cl, cue)
            probs = pi_opt[s_junc].detach().numpy()
            correct_a = 1 if cue == 0 else 2
            wrong_a = 3 - correct_a  # the opposite turn
            if not probs[correct_a] > probs[wrong_a]:
                raise AssertionError(
                    f"Soft policy incorrect at junction for cue={cue}: {probs}")

        for seed in range(n_seeds):
            print(f"seed {seed}...", end=" ", flush=True)
            t0 = time.time()

            trajs = generate_expert_demos(cl, n_episodes=n_episodes,
                                          seed=42 + seed * 1000)

            policy, _, _, val_data = train_tmaze_gru(
                Q_opt, trajs, cl, hidden_dim=32, n_epochs=n_epochs,
                lr=1e-3, batch_size=64, seed=seed, print_every=10,
            )

            acc = evaluate_junction_accuracy(policy, val_data, cl)
            seed_accs.append(acc)
            elapsed = time.time() - t0
            print(f"junction_acc={acc:.1%} ({elapsed:.0f}s)")

        results[cl] = {
            "accs": seed_accs,
            "mean": np.mean(seed_accs),
            "std": np.std(seed_accs),
        }
        print(f"to mean={results[cl]['mean']:.1%} +/- {results[cl]['std']:.1%}")

    return results



def plot_pca(hidden_pca, cues, explained_var, save_path):
    """Save a two-panel figure of junction hidden states split by cue.

    Left panel: scatter of PC1 vs PC2. Right panel: density histogram of
    PC1. When the projection has a single component, the scatter draws the
    points on a zero baseline instead of indexing a missing second column.

    Args:
        hidden_pca: ``(N, k)`` array of PCA-projected hidden states, k >= 1.
        cues: ``(N,)`` array of cue values (0 or 1).
        explained_var: explained-variance ratios, one per component.
        save_path: destination path of the PNG.
    """
    cue_styles = [
        (0, "Cue 0 (go left)", "#2196F3"),
        (1, "Cue 1 (go right)", "#F44336"),
    ]
    has_pc2 = hidden_pca.shape[1] > 1
    pc1_label = f"PC1 ({explained_var[0]:.0%} var.)"

    fig, (ax_scatter, ax_hist) = plt.subplots(1, 2, figsize=(12, 4.5))

    for cue_val, label, color in cue_styles:
        mask = cues == cue_val
        pc1 = hidden_pca[mask, 0]
        # Fall back to a flat baseline when only one component exists.
        pc2 = hidden_pca[mask, 1] if has_pc2 else np.zeros_like(pc1)
        ax_scatter.scatter(pc1, pc2, c=color, label=label, alpha=0.5, s=20,
                           edgecolors="none")
        ax_hist.hist(pc1, bins=30, alpha=0.6, color=color, label=label,
                     density=True, edgecolor="white", linewidth=0.5)

    ax_scatter.set_xlabel(pc1_label)
    ax_scatter.set_ylabel(
        f"PC2 ({explained_var[1]:.0%} var.)" if len(explained_var) > 1
        else "PC2")
    ax_scatter.legend(loc="best", fontsize=9)

    ax_hist.set_xlabel(pc1_label)
    ax_hist.set_ylabel("Density")
    ax_hist.legend(loc="best", fontsize=9)

    fig.tight_layout()
    fig.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved: {save_path}")


def plot_corridor_scaling(scaling_results, save_path):
    """Plot mean junction accuracy (+/- std) against corridor length.

    ``scaling_results`` maps corridor length to a dict with ``"mean"`` and
    ``"std"`` entries. The x axis is logarithmic with one tick per corridor
    length, and a dashed line marks 50% chance. The figure is written to
    ``save_path``.
    """
    lengths = sorted(scaling_results.keys())
    means, stds = [], []
    for length in lengths:
        entry = scaling_results[length]
        means.append(entry["mean"])
        stds.append(entry["std"])

    band_low = [mean - std for mean, std in zip(means, stds)]
    band_high = [mean + std for mean, std in zip(means, stds)]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.errorbar(
        lengths, means, yerr=stds,
        marker="o", capsize=5, linewidth=2, markersize=7,
        color="#1976D2", ecolor="#90CAF9", capthick=1.5,
    )
    ax.fill_between(lengths, band_low, band_high, alpha=0.15, color="#1976D2")

    # Reference line for a policy that guesses at the junction.
    ax.axhline(0.5, color="gray", linestyle="--", linewidth=1,
               label="Chance (50%)")

    ax.set_xscale("log")
    ax.set_xlabel("Corridor Length", fontsize=13)
    ax.set_ylabel("Junction Accuracy", fontsize=13)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xticks(lengths)
    ax.set_xticklabels(list(map(str, lengths)))
    ax.legend(loc="lower left", fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {save_path}")



# Seeds for the repeated GRU training runs in main(); one run per entry.
SEEDS = [0, 1, 2, 3, 4]


def main():
    """Run IRL once, train the GRU for every seed in SEEDS, and save results.

    IRL is fitted a single time on 1000 demonstrations of a length-100
    corridor. The GRU is then trained once per seed and scored on junction
    accuracy. The run for the primary seed also provides the PCA projection
    of junction hidden states used for the figure. Summary statistics and
    the projection are written to ``tmaze_results.pt`` in CHECKPOINT_DIR.

    If the primary seed is not part of SEEDS, no projection exists; the
    figure is skipped and the checkpoint stores ``None`` for the PCA fields.
    """
    corridor_length = 100
    n_episodes = 1000
    primary_seed = 2  # the run whose hidden states are visualised

    print(f"Seeds: {SEEDS}, corridor_length={corridor_length}\n")

    trajs = generate_expert_demos(corridor_length, n_episodes=n_episodes)
    print(f"{n_episodes} episodes")

    print("Running MaxCausalEnt IRL...")
    t0 = time.time()
    # Only the soft Q-values are needed downstream.
    _, Q_soft, _, _ = run_irl(corridor_length, trajs)
    print(f"IRL completed in {time.time() - t0:.1f}s")

    seed_accs = []
    primary_pca = None
    primary_cues = None
    primary_explained_var = None

    for seed in SEEDS:
        print(f"\nSeed {seed} ")
        t0 = time.time()

        policy, _, train_data, val_data = train_tmaze_gru(
            Q_soft, trajs, corridor_length, hidden_dim=32,
            n_epochs=200, lr=1e-3, seed=seed, print_every=50,
        )

        val_acc = evaluate_junction_accuracy(policy, val_data, corridor_length)
        seed_accs.append(val_acc)
        print(f"Junction accuracy: {val_acc:.1%} ({time.time() - t0:.0f}s)")

        if seed == primary_seed:
            all_data = train_data + val_data
            hidden_states, cues = collect_junction_hidden_states(
                policy, all_data, corridor_length)
            pca = PCA(n_components=min(2, hidden_states.shape[1]))
            hidden_pca = pca.fit_transform(hidden_states)
            primary_pca = hidden_pca
            primary_cues = cues
            primary_explained_var = pca.explained_variance_ratio_

            mask0, mask1 = cues == 0, cues == 1
            centroid_dist = np.linalg.norm(
                hidden_pca[mask0].mean(axis=0) - hidden_pca[mask1].mean(axis=0))
            print(f"PCA separation: {centroid_dist:.3f}")

    mean_acc = np.mean(seed_accs)
    std_acc = np.std(seed_accs)
    print(f"\nResults ({len(SEEDS)} seeds) ")
    print(f"Junction accuracy: {mean_acc:.1%} +/- {std_acc:.1%}")
    print(f"Per-seed: {[f'{a:.1%}' for a in seed_accs]}")

    if primary_pca is None:
        # The primary seed never ran, so there is nothing to plot; passing
        # None to the plotting routine would fail.
        print(f"\nSkipping PCA figure: seed {primary_seed} is not in SEEDS")
    else:
        print("\nGenerating PCA figure...")
        plot_pca(
            primary_pca, primary_cues, primary_explained_var,
            os.path.join(FIGURE_DIR, "tmaze_pca.png"),
        )

    checkpoint = {
        "corridor_length": corridor_length,
        "seeds": SEEDS,
        "seed_accs": seed_accs,
        "mean_acc": mean_acc,
        "std_acc": std_acc,
        "hidden_pca": primary_pca,
        "cues": primary_cues,
        "pca_explained_variance": primary_explained_var,
    }
    save_path = os.path.join(CHECKPOINT_DIR, "tmaze_results.pt")
    torch.save(checkpoint, save_path)
    print(f"Checkpoint saved: {save_path}")



# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
