"""Random vs structural encoding: inherited vs learned depth geometry in GRU hidden states."""

import os
import sys
import json
import argparse
import time

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from sklearn.metrics import balanced_accuracy_score

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
from src.gru_policy import GRUPolicy, train_gru_policy
from src.analysis import collect_hidden_states_generic as collect_hidden_states

N_ACTIONS   = 3
N_NODES     = 127         # tree nodes; node ids are 0..N_NODES-1
N_CLASSES   = 5           # mirrors structural obs (5 unique classes)
OBS_SEED    = 0           # fixed seed for class assignment — do not change
HIDDEN_DIM  = 128
N_EPOCHS    = 200
BATCH_SIZE  = 64
MAX_SEQ_LEN = 200         # max actions per training chunk (build_obs_dataset)

CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
FIGURES_DIR = os.path.join(PROJECT_ROOT, "figures")

# Reference metrics for the structural-encoding condition, printed beside this
# script's random-encoding results. They are hard-coded; this script does not
# recompute them. The probe_* entries are plain (not balanced) accuracies.
STRUCTURAL = dict(
    pc1_rho_untrained=0.865,
    pc1_rho_trained=0.890,
    probe_raw_plain=0.150,
    probe_untrained_plain=0.158,
    probe_trained_plain=0.255,
    ll_trained=-1.270,
)


def build_random_class_obs(n_classes=N_CLASSES, seed=OBS_SEED, n_nodes=None):
    """Assign tree nodes to n_classes classes at random; return one-hot obs.

    Class sizes are as equal as possible. Every class gets
    ``n_nodes // n_classes`` nodes, and the first ``n_nodes % n_classes``
    classes get one extra (for 127 nodes and 5 classes: 26, 26, 25, 25, 25).
    The node-to-class mapping is a seeded random permutation, so it is
    uncorrelated with tree depth by construction.

    Args:
        n_classes: Number of observation classes (length of each one-hot).
        seed:      Seed for the permutation. The same seed always yields
                   the same assignment.
        n_nodes:   Number of nodes to assign. Defaults to the module-level
                   ``N_NODES``.

    Returns:
        obs_dict:     {node: Tensor(n_classes)} float one-hot vector per node.
        class_labels: (n_nodes,) int array of class index per node.
    """
    if n_nodes is None:
        n_nodes = N_NODES

    base, extra = divmod(n_nodes, n_classes)
    sizes = [base + (1 if c < extra else 0) for c in range(n_classes)]
    # Sorted labels (0,...,0,1,...,1,...) followed by one seeded shuffle.
    labels = np.repeat(np.arange(n_classes), sizes)
    class_labels = np.random.RandomState(seed).permutation(labels).astype(int)

    # Advanced indexing builds a fresh (n_nodes, n_classes) matrix, so every
    # node's row is its own independent one-hot vector.
    one_hot = torch.eye(n_classes)[torch.as_tensor(class_labels, dtype=torch.long)]
    obs_dict = {node: one_hot[node] for node in range(n_nodes)}
    return obs_dict, class_labels


def build_obs_dataset(trajs, obs_dict, bc_policy, max_len=MAX_SEQ_LEN):
    """Cut trajectories into chunks of at most ``max_len`` actions.

    A chunk covering actions ``actions[lo:hi]`` carries the states
    ``states[lo:hi + 1]``, i.e. up to one more state than actions. Consecutive
    chunks of the same trajectory therefore share their boundary state.
    Trajectories with no actions produce no chunks.

    Args:
        trajs:     Iterable of dicts with "states" and "actions" sequences.
        obs_dict:  Maps a state to its observation tensor.
        bc_policy: Maps a state to its behaviour-cloning target tensor.
        max_len:   Maximum number of actions per chunk.

    Returns:
        List of dicts with keys "obs" and "targets" (per-state tensors
        stacked along dim 0), "actions" (LongTensor) and "states"
        (list of int).
    """
    chunks = []
    for traj in trajs:
        all_states, all_actions = traj["states"], traj["actions"]
        n_steps = len(all_actions)
        for lo in range(0, n_steps, max_len):
            hi = min(lo + max_len, n_steps)
            window_states = all_states[lo:hi + 1]
            window_actions = all_actions[lo:hi]
            chunks.append({
                "obs": torch.stack([obs_dict[st] for st in window_states]),
                "targets": torch.stack([bc_policy[st] for st in window_states]),
                "actions": torch.tensor(
                    [int(act) for act in window_actions], dtype=torch.long
                ),
                "states": [int(st) for st in window_states],
            })
    return chunks


def probe_train_test(X, y, traj_ids):
    """Fit a linear probe on 80% of trajectories and score it on the rest.

    The split is by trajectory, not by sample, so time steps from one
    trajectory never appear in both train and test. Trajectory ids are taken
    in sorted order, and the first ``int(0.8 * n_traj)`` of them train the
    probe.

    Args:
        X:        (N, D) feature array.
        y:        (N,) label array.
        traj_ids: (N,) trajectory id of each row.

    Returns:
        Balanced accuracy of a LogisticRegression (C=1.0, lbfgs) on the
        held-out trajectories.

    Raises:
        ValueError: if fewer than two distinct trajectories are given, since
            no train/test split is then possible.
    """
    traj_ids = np.asarray(traj_ids)
    unique_tids = np.unique(traj_ids)
    if len(unique_tids) < 2:
        raise ValueError(
            "probe_train_test needs at least 2 distinct trajectories to form "
            f"a train/test split; got {len(unique_tids)}"
        )
    # For n_traj >= 2, int(0.8 * n_traj) lies in [1, n_traj - 1], so both
    # the train and test splits are non-empty.
    n_train = int(0.8 * len(unique_tids))
    train_mask = np.isin(traj_ids, unique_tids[:n_train])
    test_mask = ~train_mask

    clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
    clf.fit(X[train_mask], y[train_mask])
    return balanced_accuracy_score(y[test_mask], clf.predict(X[test_mask]))


def make_pca_scatter(hidden_states, depths, title, save_path, n_max=8000):
    """Project hidden states onto their top two PCs and save a depth-coloured scatter.

    Args:
        hidden_states: 2-D array of hidden-state vectors, one row per sample.
        depths:        Per-row tree depth, used as the colour value.
        title:         Accepted for call-site symmetry. The figure currently
                       does not draw a title.
        save_path:     Output image path (written at 120 dpi).
        n_max:         If there are more rows than this, a fixed-seed random
                       subset of n_max rows (without replacement) is plotted.
    """
    n_points = len(hidden_states)
    if n_points > n_max:
        # Fixed seed so repeated runs plot the same subsample.
        keep = np.random.RandomState(0).choice(n_points, n_max, replace=False)
        hidden_states, depths = hidden_states[keep], depths[keep]

    projector = PCA(n_components=2)
    xy = projector.fit_transform(hidden_states)
    pc1_pct, pc2_pct = projector.explained_variance_ratio_[:2] * 100

    fig, ax = plt.subplots(figsize=(6, 5))
    # Fixed colour range 0..6 keeps colours comparable across figures.
    points = ax.scatter(
        xy[:, 0], xy[:, 1], c=depths, cmap="plasma",
        s=4, alpha=0.5, vmin=0, vmax=6, rasterized=True,
    )
    plt.colorbar(points, ax=ax, label="Tree depth")
    ax.set_xlabel(f"PC1 ({pc1_pct:.1f}% var)")
    ax.set_ylabel(f"PC2 ({pc2_pct:.1f}% var)")

    fig.tight_layout()
    fig.savefig(save_path, dpi=120)
    plt.close(fig)
    print(f"Saved: {save_path}", flush=True)


# Seed-0 hidden states written by a full run and read back by --plot-only.
HIDDEN_STATES_PATH = os.path.join(CHECKPOINT_DIR, "random_enc_seed0_hidden.npz")


def _regenerate_plots_from_cache():
    """Redraw both PCA scatter figures from the cached seed-0 hidden states.

    Exits the process with an error message (status 1) if the cache file
    written by a full run does not exist.
    """
    if not os.path.exists(HIDDEN_STATES_PATH):
        sys.exit(f"ERROR: {HIDDEN_STATES_PATH} not found. "
                 "Run without --plot-only first.")
    # An .npz NpzFile holds its file handle open until closed. Read the
    # arrays inside the context manager so the handle is released.
    with np.load(HIDDEN_STATES_PATH) as data:
        h_trained = data["h_trained"]
        h_untrained = data["h_untrained"]
        depths = data["depths"]

    os.makedirs(FIGURES_DIR, exist_ok=True)
    print("Regenerating PCA scatter plots from saved hidden states...", flush=True)
    make_pca_scatter(
        h_trained, depths,
        title="",
        save_path=os.path.join(FIGURES_DIR, "random_enc_pca_trained.png"),
    )
    make_pca_scatter(
        h_untrained, depths,
        title="",
        save_path=os.path.join(FIGURES_DIR, "random_enc_pca_untrained.png"),
    )


def _fmt_stat(vals, pct=False):
    """Format per-seed values as 'mean +/- std', or the bare value for one seed.

    ``pct=True`` scales by 100 and appends '%'. std is the population std
    (numpy default, ddof=0).
    """
    arr = np.asarray(vals)
    scale = 100 if pct else 1
    suffix = "%" if pct else ""
    if len(arr) == 1:
        return f"{arr[0] * scale:.3f}{suffix}"
    return f"{arr.mean() * scale:.3f} +/- {arr.std() * scale:.3f}{suffix}"


def _evaluate_seed(seed, obs_dim, train_data, val_data, depth_array, obs_matrix):
    """Build an untrained and a trained GRU for one seed and measure both.

    For each model, collects hidden states on ``val_data`` and computes:
    - PC1 variance explained,
    - |Spearman rho| between PC1 and node depth,
    - balanced probe accuracy for decoding the node id.

    A probe on the raw observations and the trained model's validation
    log-likelihood are also computed.

    Returns:
        (metrics, h_untrained, h_trained, node_depths). ``metrics`` maps
        result names (the keys of main's ``results`` dict) to this seed's
        values.
    """
    from src.evaluation import compute_log_likelihood_gru

    print(f"Seed {seed} ", flush=True)
    torch.manual_seed(seed)
    np.random.seed(seed)

    untrained = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM,
                          n_actions=N_ACTIONS)
    untrained.eval()
    h_u, pos_u, _, _ = collect_hidden_states(untrained, val_data)

    # Re-seed so the model to be trained is constructed from the same torch
    # RNG state as the untrained baseline (presumably identical init weights).
    torch.manual_seed(seed)
    trained = GRUPolicy(obs_dim=obs_dim, hidden_dim=HIDDEN_DIM,
                        n_actions=N_ACTIONS)
    trained, _ = train_gru_policy(
        trained, train_data,
        n_epochs=N_EPOCHS, lr=3e-4, batch_size=BATCH_SIZE,
        print_every=N_EPOCHS + 1,  # > n_epochs: presumably silences progress logs
    )
    trained.eval()
    h_t, pos_t, _, traj_ids = collect_hidden_states(trained, val_data)

    # Both models must see the same validation positions for the paired
    # comparison below to be meaningful.
    if not np.array_equal(pos_u, pos_t):
        raise RuntimeError("Position arrays mismatch")
    positions = pos_t

    ll = compute_log_likelihood_gru(trained, val_data)
    print(f"Trained GRU val LL = {ll:.4f} bits/dec", flush=True)

    n_pca = min(50, HIDDEN_DIM)  # only PC1 is read below
    pca_u = PCA(n_components=n_pca)
    pc1_u = pca_u.fit_transform(h_u)[:, 0]
    var_u = pca_u.explained_variance_ratio_[0]

    pca_t = PCA(n_components=n_pca)
    pc1_t = pca_t.fit_transform(h_t)[:, 0]
    var_t = pca_t.explained_variance_ratio_[0]

    node_depths = depth_array[positions]
    rho_u, _ = spearmanr(pc1_u, node_depths)
    rho_t, _ = spearmanr(pc1_t, node_depths)
    # The sign of a principal component is arbitrary, so only |rho| is kept.
    rho_u, rho_t = abs(rho_u), abs(rho_t)

    print(f"Untrained: |PC1–depth| rho={rho_u:.3f}  "
          f"PC1 var={var_u*100:.1f}%", flush=True)
    print(f"Trained:   |PC1–depth| rho={rho_t:.3f}  "
          f"PC1 var={var_t*100:.1f}%", flush=True)

    # Probes decode node identity (positions) from each representation.
    raw_acc = probe_train_test(obs_matrix[positions], positions, traj_ids)
    unt_acc = probe_train_test(h_u, positions, traj_ids)
    trn_acc = probe_train_test(h_t, positions, traj_ids)
    print(f"Probe (bal.) — raw obs: {raw_acc*100:.1f}%  "
          f"untrained: {unt_acc*100:.1f}%  "
          f"trained: {trn_acc*100:.1f}%", flush=True)

    metrics = {
        "probe_raw_obs":           raw_acc,
        "probe_untrained":         unt_acc,
        "probe_trained":           trn_acc,
        "pc1_rho_untrained":       rho_u,
        "pc1_rho_trained":         rho_t,
        "var_explained_untrained": var_u,
        "var_explained_trained":   var_t,
        "ll_trained":              ll,
    }
    return metrics, h_u, h_t, node_depths


def _print_comparison(results, rho_class_depth, n_unique):
    """Print the structural-vs-random comparison table.

    Structural-encoding values come from the hard-coded STRUCTURAL dict.
    Random-encoding values are the per-seed ``results`` summarised by
    _fmt_stat.
    """
    ref = STRUCTURAL
    gain_struct = (ref["probe_trained_plain"] - ref["probe_raw_plain"]) * 100
    gain_rand = (np.mean(results["probe_trained"])
                 - np.mean(results["probe_raw_obs"])) * 100
    rows = [
        ("PC1-depth |rho|  (untrained)",
         f"{ref['pc1_rho_untrained']:.3f}",
         _fmt_stat(results["pc1_rho_untrained"])),
        ("PC1-depth |rho|  (trained)",
         f"{ref['pc1_rho_trained']:.3f}",
         _fmt_stat(results["pc1_rho_trained"])),
        ("PC1 var explained (untrained)",
         "—",
         _fmt_stat(results["var_explained_untrained"], pct=True)),
        ("PC1 var explained (trained)",
         "—",
         _fmt_stat(results["var_explained_trained"], pct=True)),
        ("Probe raw obs  (balanced acc)",
         f"{ref['probe_raw_plain']*100:.1f}% (plain)",
         _fmt_stat(results["probe_raw_obs"], pct=True)),
        ("Probe untrained GRU  (bal.)",
         f"{ref['probe_untrained_plain']*100:.1f}% (plain)",
         _fmt_stat(results["probe_untrained"], pct=True)),
        ("Probe trained GRU  (bal.)",
         f"{ref['probe_trained_plain']*100:.1f}% (plain)",
         _fmt_stat(results["probe_trained"], pct=True)),
        ("Learned gain (trained - raw obs)",
         f"{gain_struct:.1f}%",
         f"{gain_rand:.1f}%"),
        ("Val LL trained (bits/dec)",
         f"{ref['ll_trained']:.3f}",
         _fmt_stat(results["ll_trained"])),
        ("rho(obs class, depth)",
         "~0.7 (strong)",
         f"{rho_class_depth:+.3f}"),
        ("Aliasing ratio",
         "25.4:1",
         f"{N_NODES/n_unique:.1f}:1"),
    ]

    print("\n2x2: Inherited vs Learned (PC1-depth rho) ", flush=True)
    # Header uses the same column widths as the rows so the columns line up.
    print(f"\n{'':36} {'Structural enc':>16}  {'Random enc':>16}", flush=True)
    for label, struct_val, rand_val in rows:
        print(f"{label:<36} {struct_val:>16}  {rand_val:>16}", flush=True)
    print("\nNote: structural enc probe uses plain accuracy (paper standard).", flush=True)
    print("Random enc probe uses balanced accuracy.", flush=True)


def main():
    """Run the random-encoding experiment, or only redraw its figures.

    Command-line flags:
        --quick      evaluate seed 0 only instead of seeds 0-4.
        --plot-only  skip training and redraw the PCA figures from the
                     hidden states cached by a previous full run.

    A full run writes:
    - PCA scatter figures to FIGURES_DIR,
    - the seed-0 hidden states to HIDDEN_STATES_PATH,
    - a JSON summary to CHECKPOINT_DIR.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true",
                        help="Run 1 seed only (fast sanity check)")
    parser.add_argument("--plot-only", action="store_true",
                        help="Skip training; regenerate figures from saved hidden states")
    args = parser.parse_args()

    if args.plot_only:
        _regenerate_plots_from_cache()
        return

    seeds = [0] if args.quick else [0, 1, 2, 3, 4]

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(FIGURES_DIR, exist_ok=True)

    t0 = time.time()
    print("Random Encoding Experiment (127-node binary tree) \n", flush=True)

    d = load_rosenberg_everything()
    # NOTE(review): build_bc_targets' signature is not visible in this file;
    # passing the loaded data dict is an assumption. Confirm against
    # src.rosenberg_data.
    bc_policy = build_bc_targets(d)
    obs_dict, class_labels = build_random_class_obs()

    depths_raw = d["depths"]
    if isinstance(depths_raw, dict):
        depth_array = np.array([depths_raw[n] for n in range(N_NODES)])
    else:
        depth_array = np.array(depths_raw)

    obs_dim = N_CLASSES
    n_unique = len({tuple(v.tolist()) for v in obs_dict.values()})
    # Class ids are arbitrary labels. This rank correlation is only a sanity
    # check that the random assignment carries no depth ordering.
    rho_class_depth, _ = spearmanr(class_labels, depth_array)

    print(f"obs_dim={obs_dim}, unique obs classes: {n_unique} "
          f"(aliasing ratio {N_NODES/n_unique:.1f}:1)", flush=True)
    print(f"Spearman rho(obs class, depth) = {rho_class_depth:.3f} "
          f"(expected ~0 by construction)", flush=True)
    print(f"Class sizes: { {c: int((class_labels==c).sum()) for c in range(N_CLASSES)} }",
          flush=True)

    train_data = build_obs_dataset(d["train_trajs"], obs_dict, bc_policy)
    val_data = build_obs_dataset(d["val_trajs"], obs_dict, bc_policy)

    # Row s is obs_dict[s], so obs_matrix[positions] gives each step's raw obs.
    obs_matrix = np.stack([obs_dict[s].numpy() for s in range(N_NODES)])

    random_ll = -np.log2(N_ACTIONS)
    # Probes decode node identity, so chance balanced accuracy is 1/N_NODES.
    chance_bal = 1.0 / N_NODES
    print(f"\nRandom LL baseline: {random_ll:.4f} bits/dec", flush=True)
    print(f"Chance (balanced acc): {chance_bal:.4f} ({chance_bal*100:.2f}%)\n",
          flush=True)

    results = {
        "probe_raw_obs":           [],
        "probe_untrained":         [],
        "probe_trained":           [],
        "pc1_rho_untrained":       [],
        "pc1_rho_trained":         [],
        "var_explained_untrained": [],
        "var_explained_trained":   [],
        "ll_trained":              [],
    }
    pca_plot_data = {}

    for seed in seeds:
        metrics, h_u, h_t, node_depths = _evaluate_seed(
            seed, obs_dim, train_data, val_data, depth_array, obs_matrix,
        )
        for key, value in metrics.items():
            results[key].append(value)

        if seed == 0:
            pca_plot_data = {
                "h_untrained": h_u,
                "h_trained":   h_t,
                "depths":      node_depths,
            }
            np.savez(HIDDEN_STATES_PATH,
                     h_untrained=h_u, h_trained=h_t, depths=node_depths)
            print(f"Saved hidden states: {HIDDEN_STATES_PATH}", flush=True)

    make_pca_scatter(
        pca_plot_data["h_trained"], pca_plot_data["depths"],
        title="Trained GRU — random enc (seed 0)",
        save_path=os.path.join(FIGURES_DIR, "random_enc_pca_trained.png"),
    )
    make_pca_scatter(
        pca_plot_data["h_untrained"], pca_plot_data["depths"],
        title="Untrained GRU — random enc (seed 0)",
        save_path=os.path.join(FIGURES_DIR, "random_enc_pca_untrained.png"),
    )

    _print_comparison(results, rho_class_depth, n_unique)

    save = {
        "seeds": seeds,
        "n_classes": N_CLASSES,
        "obs_seed": OBS_SEED,
        "obs_class_assignment": class_labels.tolist(),
        "rho_obs_class_depth": float(rho_class_depth),
        "results": {k: [float(x) for x in v] for k, v in results.items()},
        "means":   {k: float(np.mean(v)) for k, v in results.items()},
    }
    out_path = os.path.join(CHECKPOINT_DIR, "random_encoding_results.json")
    with open(out_path, "w") as f:
        json.dump(save, f, indent=2)
    print(f"\nSaved: {out_path}", flush=True)
    print(f"Total time: {(time.time()-t0)/60:.1f} min", flush=True)


if __name__ == "__main__":
    main()
