"""Hidden dimensionality sweep to find critical GRU capacity for spatial map emergence."""

import os
import sys
import argparse
import json
import glob
import time
import subprocess

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put the root first on sys.path so the worker's `src.*` imports resolve.
sys.path.insert(0, PROJECT_ROOT)

CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
FIGURE_DIR = os.path.join(PROJECT_ROOT, "figures")
# One JSON result file per (dim, seed) job, named d{dim}_s{seed}.json.
WORKER_DIR = os.path.join(CHECKPOINT_DIR, "dim_sweep")
# Created at import time (makedirs also creates CHECKPOINT_DIR as the parent
# of WORKER_DIR), so both the driver and worker subprocesses can write.
os.makedirs(FIGURE_DIR, exist_ok=True)
os.makedirs(WORKER_DIR, exist_ok=True)

# GRU hidden sizes swept; the --quick set omits 256.
DIMS_FULL = [4, 8, 16, 32, 64, 128, 256]
DIMS_QUICK = [4, 8, 16, 32, 64, 128]



def run_worker(dim, seed, n_epochs):
    """Train one GRU policy and measure how spatial its hidden states are.

    This is the body of a ``--worker`` subprocess. It loads the Rosenberg
    data, trains a ``GRUPolicy`` with ``hidden_dim=dim`` on hand-built
    structural observations of a 127-node binary tree, and evaluates the
    model on the validation trajectories:

    * ``ll``: the value returned by ``compute_log_likelihood_gru`` on the
      validation set (labelled "LL (bits/dec)" in this script's output).
    * ``probe_acc``: accuracy of a logistic-regression probe that predicts
      the node index from the hidden state. The probe is fit and scored on
      the same validation hidden states, so this is an in-sample accuracy.
    * ``pc1_rho``: |Spearman rho| between the first principal component of
      the hidden states and each node's depth in the tree.

    The result is written atomically to ``WORKER_DIR/d{dim}_s{seed}.json``
    and also returned.

    Args:
        dim: GRU hidden dimension.
        seed: Seed applied to torch and numpy right before model construction.
        n_epochs: Number of training epochs.

    Returns:
        dict with keys ``dim``, ``seed``, ``n_epochs``, ``ll``, ``probe_acc``,
        and ``pc1_rho``. Metric values are plain Python floats.
    """
    # Imported lazily: only worker processes need these heavy/project modules.
    from scipy.stats import spearmanr
    from sklearn.linear_model import LogisticRegression
    from sklearn.decomposition import PCA

    from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
    from src.gru_policy import GRUPolicy, train_gru_policy
    from src.evaluation import compute_log_likelihood_gru
    from src.analysis import collect_hidden_states

    N_STATES = 127
    N_ACTIONS = 3
    BATCH_SIZE = 64
    MAX_SEQ_LEN = 200

    def build_structural_obs(node, n_states=N_STATES):
        """Return a 12-dim feature vector describing ``node`` and its moves.

        Nodes use heap indexing: children of n are 2n+1 and 2n+2, and the
        parent is (n-1)//2. Nodes >= (n_states-1)//2 are leaves. The first 3
        features describe the node itself as (degree/3, is_root, is_leaf).
        Each of the left/right/reverse destinations then adds
        (is_leaf, is_root, degree/3). The ordering differs between the two
        groups; it is kept as-is so inputs match previously trained runs.
        """
        first_leaf = (n_states - 1) // 2
        is_root = 1.0 if node == 0 else 0.0
        is_leaf = 1.0 if node >= first_leaf else 0.0
        degree = 2.0 if node == 0 else (1.0 if node >= first_leaf else 3.0)
        # Moves that would leave the tree stay in place.
        left_dest = 2 * node + 1 if 2 * node + 1 < n_states else node
        right_dest = 2 * node + 2 if 2 * node + 2 < n_states else node
        reverse_dest = (node - 1) // 2 if node > 0 else 0
        features = [degree / 3.0, is_root, is_leaf]
        for dest in [left_dest, right_dest, reverse_dest]:
            d_is_root = 1.0 if dest == 0 else 0.0
            d_is_leaf = 1.0 if dest >= first_leaf else 0.0
            d_degree = 2.0 if dest == 0 else (1.0 if dest >= first_leaf else 3.0)
            features.extend([d_is_leaf, d_is_root, d_degree / 3.0])
        return torch.tensor(features, dtype=torch.float32)

    def get_node_depth(node):
        """Return the number of parent hops from ``node`` to the root (0)."""
        d, n = 0, node
        while n > 0:
            n = (n - 1) // 2
            d += 1
        return d

    def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
        """Split trajectories into chunks of at most ``max_len`` actions.

        Each chunk keeps one extra trailing state (``states[start:end + 1]``).
        This presumably relies on ``traj["states"]`` having one more entry
        than ``traj["actions"]``; confirm against the data loader.
        """
        dataset = []
        for traj in trajs:
            states = traj["states"]
            actions = traj["actions"]
            T = len(actions)
            for start in range(0, T, max_len):
                end = min(start + max_len, T)
                chunk_states = states[start:end + 1]
                chunk_actions = actions[start:end]
                obs_seq = torch.stack([structural_obs[s] for s in chunk_states])
                target_seq = torch.stack([bc_policy[s] for s in chunk_states])
                action_seq = torch.tensor([int(a) for a in chunk_actions], dtype=torch.long)
                state_seq = [int(s) for s in chunk_states]
                dataset.append({
                    "obs": obs_seq, "targets": target_seq,
                    "actions": action_seq, "states": state_seq,
                })
        return dataset

    d = load_rosenberg_everything()
    bc_policy = build_bc_targets(d["train_sa"], n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)
    structural_obs = {s: build_structural_obs(s) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])
    train_data = build_obs_dataset_structural(d["train_trajs"], structural_obs, bc_policy)
    val_data = build_obs_dataset_structural(d["val_trajs"], structural_obs, bc_policy)

    # Seed right before model construction so init and training are reproducible.
    torch.manual_seed(seed)
    np.random.seed(seed)
    policy = GRUPolicy(obs_dim=obs_dim, hidden_dim=dim, n_actions=N_ACTIONS)
    policy, _ = train_gru_policy(
        policy, train_data, n_epochs=n_epochs,
        lr=3e-4, batch_size=BATCH_SIZE, print_every=999,
    )

    ll = compute_log_likelihood_gru(policy, val_data)

    hidden_states, positions, _, _ = collect_hidden_states(policy, val_data)
    depths = np.array([get_node_depth(p) for p in positions])

    # A classifier needs at least two classes to fit.
    n_unique = len(np.unique(positions))
    if n_unique >= 2:
        probe = LogisticRegression(max_iter=1000, solver="lbfgs")
        probe.fit(hidden_states, positions)
        probe_acc = probe.score(hidden_states, positions)
    else:
        probe_acc = 0.0

    # Only PC1 is used; cap components by the hidden size and the sample count.
    n_components = min(5, dim, hidden_states.shape[0])
    pca = PCA(n_components=n_components)
    hidden_pca = pca.fit_transform(hidden_states)
    rho, _ = spearmanr(hidden_pca[:, 0], depths)

    # Cast to plain floats: numpy float32 and similar scalars are not
    # JSON-serializable. The PCA sign is arbitrary, hence abs(rho).
    result = {"dim": dim, "seed": seed, "n_epochs": n_epochs,
              "ll": float(ll), "probe_acc": float(probe_acc),
              "pc1_rho": float(abs(rho))}

    out_path = os.path.join(WORKER_DIR, f"d{dim}_s{seed}.json")
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(result, f)
    # Atomic rename: an interrupted worker never leaves a truncated
    # d*_s*.json behind for the driver to mistake as a finished job.
    os.replace(tmp_path, out_path)

    return result


def collect_worker_results(dims, seeds):
    """Gather per-job metrics from the JSON files in ``WORKER_DIR``.

    Only files whose recorded ``dim`` is in ``dims`` and whose ``seed`` is in
    ``seeds`` are included. Results left over from other sweep configurations
    (extra seeds or dims) therefore do not leak into the summary. Files that
    cannot be read or parsed are reported and skipped rather than aborting
    collection.

    Args:
        dims: Hidden dimensions to collect.
        seeds: Seeds to collect.

    Returns:
        ``(results, found)``. ``results`` maps every requested dim to
        ``{"ll": [...], "probe_acc": [...], "pc1_rho": [...]}``, with entries
        appended in sorted file-name order. ``found`` is the number of files
        collected.
    """
    metric_keys = ("ll", "probe_acc", "pc1_rho")
    wanted_dims = set(dims)
    wanted_seeds = set(seeds)
    results = {dim: {key: [] for key in metric_keys} for dim in dims}
    found = 0
    for path in sorted(glob.glob(os.path.join(WORKER_DIR, "d*_s*.json"))):
        try:
            with open(path) as f:
                record = json.load(f)
            dim, seed = record["dim"], record["seed"]
            values = [record[key] for key in metric_keys]
        except (OSError, ValueError, KeyError, TypeError) as err:
            # JSONDecodeError is a ValueError; TypeError covers non-dict JSON.
            print(f"Skipping unreadable result file {path}: {err}", flush=True)
            continue
        if dim not in wanted_dims or seed not in wanted_seeds:
            continue
        for key, value in zip(metric_keys, values):
            results[dim][key].append(value)
        found += 1
    return results, found



def _result_path(dim, seed):
    """Return the JSON path that ``run_worker`` writes for (dim, seed)."""
    return os.path.join(WORKER_DIR, f"d{dim}_s{seed}.json")


def _has_current_result(path, n_epochs):
    """Return True if ``path`` is a readable result trained for ``n_epochs``.

    Result files are keyed only by dim and seed. A file left behind by a run
    with a different epoch count (e.g. an earlier ``--quick`` run) must not be
    mistaken for a finished job of the current sweep.
    """
    try:
        with open(path) as f:
            return json.load(f).get("n_epochs") == n_epochs
    except (OSError, ValueError, AttributeError):
        # Missing, truncated, or non-dict JSON: treat the job as not done.
        return False


def _run_jobs(jobs, n_epochs, max_workers, python):
    """Run (dim, seed) jobs as worker subprocesses, ``max_workers`` at a time.

    Each batch of up to ``max_workers`` workers is started together and
    waited on before the next batch begins. Prints one progress line per
    finished job (or a FAILED line) and returns elapsed wall-clock seconds.
    """
    t0 = time.time()
    total = len(jobs)
    completed = 0
    script = os.path.abspath(__file__)

    for batch_start in range(0, total, max_workers):
        batch = jobs[batch_start : batch_start + max_workers]
        procs = []
        for dim, seed in batch:
            cmd = [
                python, script,
                "--worker", "--dim", str(dim), "--seed", str(seed),
                "--epochs", str(n_epochs),
            ]
            proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                cwd=PROJECT_ROOT,
            )
            procs.append((dim, seed, proc))

        for dim, seed, proc in procs:
            stdout, stderr = proc.communicate()
            completed += 1
            elapsed = time.time() - t0
            remaining = elapsed / completed * (total - completed)

            if proc.returncode != 0:
                err = stderr.decode(errors="replace")[-300:]
                print(f"FAILED dim={dim} seed={seed}: {err}", flush=True)
                continue

            # The worker prints its result dict as JSON on its last stdout line.
            lines = stdout.decode(errors="replace").strip().split("\n")
            try:
                result = json.loads(lines[-1])
            except ValueError:
                print(f"FAILED dim={dim} seed={seed}: could not parse worker "
                      f"output {lines[-1][-300:]!r}", flush=True)
                continue

            print(f"[{completed}/{total}] dim={dim:3d} seed={seed}  "
                  f"LL={result['ll']:.4f}  probe={result['probe_acc']:.3f}  "
                  f"|rho|={result['pc1_rho']:.3f}  (~{remaining/60:.0f}min left)",
                  flush=True)

    return time.time() - t0


def main():
    """Command-line entry point.

    Modes:
      * ``--worker``: train and evaluate a single (dim, seed), then print its
        result as JSON (used internally by the sweep driver).
      * ``--collect-only``: summarise and plot existing result files.
      * default: launch worker subprocesses for every (dim, seed) job without
        a current result, then save, summarise, and plot the combined results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true",
                        help="Quick test: 1 seed, 100 epochs, skip dim=256")
    parser.add_argument("--workers", type=int, default=5,
                        help="Max parallel workers")
    parser.add_argument("--collect-only", action="store_true",
                        help="Skip training, just collect existing results and plot")
    parser.add_argument("--worker", action="store_true",
                        help="Internal: run as subprocess worker")
    parser.add_argument("--dim", type=int)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--epochs", type=int, default=None,
                        help="Training epochs (default: 100 with --quick, else 200)")
    parser.add_argument("--python", default=sys.executable,
                        help="Python interpreter used to launch worker subprocesses "
                             "(default: the interpreter running this script)")
    args = parser.parse_args()

    if args.worker:
        if args.dim is None or args.seed is None:
            parser.error("--worker requires --dim and --seed")
        worker_epochs = 200 if args.epochs is None else args.epochs
        result = run_worker(args.dim, args.seed, worker_epochs)
        print(json.dumps(result))
        return

    if args.workers < 1:
        parser.error("--workers must be at least 1")

    dims = DIMS_QUICK if args.quick else DIMS_FULL
    seeds = [0] if args.quick else [0, 1, 2, 3, 4]
    if args.epochs is not None:
        n_epochs = args.epochs
    else:
        n_epochs = 100 if args.quick else 200
    max_workers = args.workers

    if args.collect_only:
        print("Collecting existing results \n")
        results, found = collect_worker_results(dims, seeds)
        print(f"Found {found} result files\n")
        print_summary(results, dims)
        plot_dim_sweep(results, dims)
        return

    jobs = []
    skipped = 0
    stale = 0
    for dim in dims:
        for seed in seeds:
            out_path = _result_path(dim, seed)
            if _has_current_result(out_path, n_epochs):
                skipped += 1
            else:
                if os.path.exists(out_path):
                    stale += 1
                jobs.append((dim, seed))

    total = len(jobs)
    print(f"Dims: {dims}")
    print(f"Seeds: {seeds}")
    print(f"Epochs: {n_epochs}")
    print(f"Workers: {max_workers}")
    if stale:
        print(f"Re-running {stale} job(s) whose existing result was unreadable "
              f"or trained for a different number of epochs")
    print(f"Jobs to run: {total} (skipping {skipped} already done)\n")

    if total == 0:
        print("All jobs already complete. Collecting results...\n")
        results, found = collect_worker_results(dims, seeds)
        print_summary(results, dims)
        plot_dim_sweep(results, dims)
        return

    total_time = _run_jobs(jobs, n_epochs, max_workers, args.python)
    print(f"\nTraining time: {total_time/60:.1f} min\n")

    results, found = collect_worker_results(dims, seeds)

    save_path = os.path.join(CHECKPOINT_DIR, "dim_sweep_results.pt")
    torch.save({"dims": dims, "seeds": seeds, "results": results}, save_path)
    print(f"Combined results saved: {save_path}\n")

    print_summary(results, dims)
    plot_dim_sweep(results, dims)


def print_summary(results, dims):
    """Print a fixed-width table of per-dimension metrics, then a blank line.

    A dimension with exactly one run shows the raw values. A dimension with
    several runs shows ``mean +/- std`` (population std) for LL, probe
    accuracy, and |PC1-depth|. Dimensions missing from ``results``, or with
    no recorded LL values, are printed as "(no data)".
    """
    print(f"{'dim':>5s}  {'LL (bits/dec)':>22s}  {'Probe Acc':>18s}  {'|PC1-depth|':>18s}")
    for dim in dims:
        runs = results.get(dim, {"ll": []})
        if len(runs["ll"]) == 0:
            print(f"{dim:5d}  (no data)")
            continue

        ll_vals, probe_vals, rho_vals = (
            np.array(runs[metric]) for metric in ("ll", "probe_acc", "pc1_rho")
        )

        if len(ll_vals) == 1:
            print(f"{dim:5d}  {ll_vals[0]:22.4f}  {probe_vals[0]:18.3f}  {rho_vals[0]:18.3f}")
            continue

        cells = [
            f"{vals.mean():7.{digits}f} +/- {vals.std():.{digits}f}"
            for vals, digits in ((ll_vals, 4), (probe_vals, 3), (rho_vals, 3))
        ]
        print(f"{dim:5d}  " + "  ".join(cells))
    print()


def plot_dim_sweep(results, dims):
    """Save one single-panel plot per metric: mean +/- std over seeds vs. dim.

    Writes three separate PNG files to ``FIGURE_DIR``: ``dim_sweep_ll.png``,
    ``dim_sweep_probe.png``, and ``dim_sweep_pc1.png``. The x axis is the
    hidden dimension on a log2 scale. Error bars and a shaded band show the
    population std across seeds; the band is drawn only when some std is
    positive. Dimensions with no recorded runs are omitted. If none remain,
    a message is printed and nothing is written.

    Args:
        results: Mapping dim -> {"ll": [...], "probe_acc": [...],
            "pc1_rho": [...]}, as produced by ``collect_worker_results``.
        dims: Dimensions to plot, in x-axis order.
    """
    plot_dims = [d for d in dims if d in results and len(results[d]["ll"]) > 0]
    if not plot_dims:
        print("No data to plot.")
        return

    metrics = [
        ("ll", "LL (bits/dec)", "dim_sweep_ll.png"),
        ("probe_acc", "127-way Probe Accuracy", "dim_sweep_probe.png"),
        ("pc1_rho", "|PC1-depth| Spearman $\\rho$", "dim_sweep_pc1.png"),
    ]

    x = np.array(plot_dims)

    for key, ylabel, fname in metrics:
        means = np.array([np.mean(results[d][key]) for d in plot_dims])
        # A single seed has no spread; report 0 explicitly.
        stds = np.array([
            np.std(results[d][key]) if len(results[d][key]) > 1 else 0
            for d in plot_dims
        ])

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.errorbar(x, means, yerr=stds, fmt="o-", color="#2c7fb8",
                    capsize=4, markersize=7, linewidth=2, elinewidth=1.5)
        if np.any(stds > 0):
            ax.fill_between(x, means - stds, means + stds, alpha=0.15, color="#2c7fb8")

        ax.set_xscale("log", base=2)
        ax.set_xticks(plot_dims)
        ax.set_xticklabels([str(d) for d in plot_dims])
        ax.set_xlabel("Hidden Dimension", fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3)

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        save_path = os.path.join(FIGURE_DIR, fname)
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved: {save_path}")


if __name__ == "__main__":
    # Also the entry point for worker subprocesses, which the sweep driver
    # launches by re-invoking this script with --worker.
    main()
