#!/usr/bin/env python3
"""Banino wall-distance vs random encoding: untrained baseline."""
import json
import os
import sys
import time
from multiprocessing import Pool, cpu_count

os.environ["PYTHONUNBUFFERED"] = "1"

import numpy as np
import torch
from scipy.stats import spearmanr
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from src.banino_grid_cells import (
    BaninoGridNetwork, GridScorer, generate_trajectories,
)

# Number of seeds evaluated per encoding; main() runs seeds 0 .. N_SEEDS - 1.
N_SEEDS = 10
# Number of trajectories requested from generate_trajectories().
N_TRAJ = 500
# A unit is counted as grid-like when its gridness score exceeds this value.
GRIDNESS_THRESHOLD = 0.37
# Side length of the square environment; the same value bounds x and y.
# NOTE(review): units are not stated in this file (presumably metres) - confirm.
ENV_SIZE = 2.2
# Passed to GridScorer as its ``nbins`` argument.
NBINS = 32

# Output directory for results.json and the temporary shared-positions archive.
CKPT_DIR = os.path.join(os.path.dirname(__file__), "..", "checkpoints", "banino_wall_distance")
# Created by main(); no figure is written by the code visible in this file.
FIG_DIR = os.path.join(os.path.dirname(__file__), "..", "figures")


def wall_distance_encoding(positions, env_size=2.2):
    """Encode 2-D positions as four wall distances scaled by ``env_size``.

    Args:
        positions: 2-D array read as ``positions[:, 0]`` (x) and
            ``positions[:, 1]`` (y).
        env_size: Value the far-wall distances are measured from, and the
            divisor applied to every column.

    Returns:
        Array with one row per position and the four columns
        ``[y, env_size - y, x, env_size - x]``, each divided by ``env_size``.
    """
    xs = positions[:, 0]
    ys = positions[:, 1]
    distance_columns = (ys, env_size - ys, xs, env_size - xs)
    stacked = np.stack(distance_columns, axis=1)
    return stacked / env_size


def random_spatial_encoding(positions, env_size=2.2, seed=42, n_tiles=8, n_features=4):
    """Map positions to fixed random feature vectors on a square tile grid.

    A table of standard-normal values of shape
    ``(n_tiles, n_tiles, n_features)`` is drawn from ``RandomState(seed)``.
    Each position is assigned the feature vector of the tile it falls in;
    coordinates outside ``[0, env_size)`` are clipped to the nearest edge tile.

    Args:
        positions: 2-D array read as ``positions[:, 0]`` (x) and
            ``positions[:, 1]`` (y).
        env_size: Upper edge of the tiled range on both axes.
        seed: Seed for the random tile table.
        n_tiles: Number of tiles along each axis.
        n_features: Length of the feature vector stored per tile.

    Returns:
        Array with one row per position and ``n_features`` columns.
    """
    tile_table = np.random.RandomState(seed).randn(n_tiles, n_tiles, n_features)
    edges = np.linspace(0, env_size, n_tiles + 1)
    last_tile = n_tiles - 1

    def tile_index(coords):
        # digitize is 1-based for in-range values; shift to 0-based and clamp
        # out-of-range coordinates onto the border tiles.
        return np.clip(np.digitize(coords, edges) - 1, 0, last_tile)

    x_tile = tile_index(positions[:, 0])
    y_tile = tile_index(positions[:, 1])
    return tile_table[x_tile, y_tile]


def run_single_job(args):
    """Evaluate one untrained network for a single (encoding, seed) pair.

    The worker loads the shared trajectories, encodes every position with the
    requested encoding, feeds the encodings step by step through a freshly
    initialised LSTM cell plus the model's bottleneck (no training), and then
    analyses the collected bottleneck activations.

    Args:
        args: Tuple ``(enc_name, seed, pos_list_shared_path)``.
            ``enc_name`` selects ``wall_distance_encoding`` when it equals
            ``"wall_distance"``; any other value selects
            ``random_spatial_encoding``. ``seed`` seeds torch, the PCA solver
            and the train/test permutation. ``pos_list_shared_path`` is an
            ``.npz`` archive with arrays ``positions`` (all trajectories
            concatenated) and ``lengths`` (rows per trajectory).

    Returns:
        Dict of JSON-serialisable metrics: variance explained by PC1, Spearman
        correlations of PC1 with x, y and the distance to the nearest wall,
        held-out R^2 of a ridge regression from activations to position, and
        gridness statistics over all units.
    """
    enc_name, seed, pos_list_shared_path = args

    # np.load returns an NpzFile that keeps the archive open; read both arrays
    # inside the context manager so the worker releases its file handle.
    with np.load(pos_list_shared_path) as data:
        all_pos = data["positions"]
        lengths = data["lengths"]
    pos_list = np.split(all_pos, np.cumsum(lengths)[:-1])

    # Pass ENV_SIZE explicitly so the encoders agree with the wall-distance
    # and GridScorer computations below instead of relying on their defaults.
    if enc_name == "wall_distance":
        obs_list = [wall_distance_encoding(p, env_size=ENV_SIZE) for p in pos_list]
    else:
        obs_list = [random_spatial_encoding(p, env_size=ENV_SIZE) for p in pos_list]

    input_size = 4  # both encodings emit four features per position
    hidden_size = 128

    torch.manual_seed(seed)
    model = BaninoGridNetwork()
    model.lstm = torch.nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
    torch.manual_seed(seed)  # re-seed so LSTM weights are deterministic
    torch.nn.init.xavier_uniform_(model.lstm.weight_ih)
    torch.nn.init.xavier_uniform_(model.lstm.weight_hh)
    torch.nn.init.zeros_(model.lstm.bias_ih)
    torch.nn.init.zeros_(model.lstm.bias_hh)
    model.eval()

    all_acts_list = []
    with torch.no_grad():
        for obs in obs_list:
            obs_t = torch.tensor(obs, dtype=torch.float32)
            # The recurrent state is reset at the start of every trajectory.
            h = torch.zeros(1, hidden_size)
            c = torch.zeros(1, hidden_size)
            acts = []
            # unsqueeze(1) makes each iterated step a (1, input_size) batch.
            for step in obs_t.unsqueeze(1):
                h, c = model.lstm(step, (h, c))
                acts.append(model.bottleneck(h))
            all_acts_list.append(torch.cat(acts, dim=0).numpy())
    all_acts = np.concatenate(all_acts_list, axis=0)

    # random_state pins the randomized SVD solver (when scikit-learn selects
    # it) so the PCA-based metrics are reproducible for a given seed.
    pca = PCA(n_components=5, random_state=seed)
    pcs = pca.fit_transform(all_acts)
    pc1 = pcs[:, 0]
    rho_x = float(spearmanr(pc1, all_pos[:, 0])[0])
    rho_y = float(spearmanr(pc1, all_pos[:, 1])[0])
    min_wall = np.minimum(
        np.minimum(all_pos[:, 0], ENV_SIZE - all_pos[:, 0]),
        np.minimum(all_pos[:, 1], ENV_SIZE - all_pos[:, 1]),
    )
    rho_wall = float(spearmanr(pc1, min_wall)[0])

    # Linear position decoding: seeded 80/20 shuffle split, R^2 on held-out rows.
    n = all_acts.shape[0]
    idx = np.random.RandomState(seed).permutation(n)
    split = int(0.8 * n)
    train_idx, test_idx = idx[:split], idx[split:]
    reg = Ridge(alpha=1.0)
    reg.fit(all_acts[train_idx], all_pos[train_idx])
    r2 = float(reg.score(all_acts[test_idx], all_pos[test_idx]))

    scorer = GridScorer(nbins=NBINS, env_size=ENV_SIZE)
    gridness_scores, _ = scorer.compute_all_gridness(all_pos, all_acts)
    gridness_scores = np.asarray(gridness_scores)

    # Use the actual number of scored units as the denominator rather than a
    # hard-coded 512, so the percentage stays right if the model size changes.
    n_units = gridness_scores.size
    n_grid = int(np.sum(gridness_scores > GRIDNESS_THRESHOLD))
    pct_grid = n_grid / n_units * 100
    max_gridness = float(np.max(gridness_scores))

    result = {
        "encoding": enc_name,
        "seed": seed,
        "pc1_var_explained": round(float(pca.explained_variance_ratio_[0]) * 100, 1),
        "pc1_x_rho": round(rho_x, 4),
        "pc1_y_rho": round(rho_y, 4),
        "pc1_wall_rho": round(rho_wall, 4),
        "position_r2": round(r2, 4),
        "n_grid_like": n_grid,
        "pct_grid_like": round(pct_grid, 1),
        "mean_gridness": round(float(np.mean(gridness_scores)), 4),
        "max_gridness": round(max_gridness, 4),
        "pctl_95": round(float(np.percentile(gridness_scores, 95)), 4),
    }

    print(f"[{enc_name} s{seed}] grid={n_grid}/{n_units} ({pct_grid:.1f}%), "
          f"PC1-x={rho_x:.3f}, R2={r2:.4f}, max_g={max_gridness:.3f}",
          flush=True)

    return result


def _summarize_encoding(enc_results):
    """Aggregate the per-seed result dicts of one encoding into summary stats."""
    pcts = [r["pct_grid_like"] for r in enc_results]
    r2s = [r["position_r2"] for r in enc_results]
    # The sign of a principal component is arbitrary, so aggregate |rho|.
    abs_pc1_x = np.abs([r["pc1_x_rho"] for r in enc_results])
    max_gs = [r["max_gridness"] for r in enc_results]
    return {
        "mean_pct_grid": round(float(np.mean(pcts)), 2),
        "std_pct_grid": round(float(np.std(pcts)), 2),
        "mean_r2": round(float(np.mean(r2s)), 4),
        "std_r2": round(float(np.std(r2s)), 4),
        "mean_pc1_x_rho": round(float(np.mean(abs_pc1_x)), 4),
        "std_pc1_x_rho": round(float(np.std(abs_pc1_x)), 4),
        "mean_max_gridness": round(float(np.mean(max_gs)), 4),
        "per_seed": enc_results,
    }


def _format_summary_line(name, stats):
    """Render the one-line console summary for one encoding."""
    return (f"{name}: {stats['mean_pct_grid']:.1f}% +/- {stats['std_pct_grid']:.1f}% grid-like, "
            f"R2={stats['mean_r2']:.4f}, |PC1-x|={stats['mean_pc1_x_rho']:.4f}, "
            f"max_g={stats['mean_max_gridness']:.4f}")


def main():
    """Run every (encoding, seed) job in parallel and write ``results.json``.

    Trajectories are generated once, written to a temporary ``.npz`` archive
    in ``CKPT_DIR`` for the worker processes to read, and the archive is
    removed afterwards even when a worker fails. Per-encoding summaries are
    printed and saved to ``CKPT_DIR/results.json``.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    os.makedirs(FIG_DIR, exist_ok=True)
    t0 = time.time()
    encodings = ("wall_distance", "random")

    print(f"Generating {N_TRAJ} trajectories...", flush=True)
    pos_list, _, _ = generate_trajectories(N_TRAJ, seed=0)

    shared_path = os.path.join(CKPT_DIR, "_shared_pos.npz")
    all_pos_concat = np.concatenate(pos_list, axis=0)
    traj_lengths = np.array([p.shape[0] for p in pos_list])

    jobs = [(enc, seed, shared_path) for enc in encodings for seed in range(N_SEEDS)]
    # Never start more workers than there are cores or jobs; keep at least one
    # so Pool() does not reject a zero-sized pool.
    n_workers = max(1, min(5, cpu_count(), len(jobs)))

    try:
        np.savez(shared_path, positions=all_pos_concat, lengths=traj_lengths)
        print("saved shared data", flush=True)

        print(f"\nLaunching {len(jobs)} jobs on {n_workers} workers...", flush=True)
        with Pool(n_workers) as pool:
            results = pool.map(run_single_job, jobs)
    finally:
        # Remove the temporary archive even when a worker raised, so failed
        # runs do not leave stale data behind in the checkpoint directory.
        try:
            os.remove(shared_path)
        except FileNotFoundError:
            pass

    summary = {
        enc: _summarize_encoding([r for r in results if r["encoding"] == enc])
        for enc in encodings
    }

    elapsed = time.time() - t0

    for i, enc in enumerate(encodings):
        prefix = "\n" if i == 0 else ""
        print(prefix + _format_summary_line(enc, summary[enc]), flush=True)
    print(f"done in {elapsed/60:.1f} min", flush=True)

    output = {
        "experiment": "banino_wall_distance_encoding_swap",
        "n_seeds": N_SEEDS,
        "n_trajectories": N_TRAJ,
        "gridness_threshold": GRIDNESS_THRESHOLD,
        "summary": summary,
        "elapsed_minutes": round(elapsed / 60, 1),
    }
    json_path = os.path.join(CKPT_DIR, "results.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nSaved to {json_path}", flush=True)


# Entry-point guard: main() starts a multiprocessing Pool, so it must only run
# when this file is executed as a script, not when worker processes import it
# (which happens under the "spawn" start method).
if __name__ == "__main__":
    main()
