#!/usr/bin/env python3
"""Banino et al. (2018) untrained baseline experiment."""
import json
import os
import sys
import time

# PYTHONUNBUFFERED is only consulted when the interpreter starts, so assigning
# it here cannot unbuffer the current process. The assignment is kept so that
# any child process inherits it; the already-open streams are switched to line
# buffering so progress prints show up promptly when output is redirected.
os.environ["PYTHONUNBUFFERED"] = "1"
for _stream in (sys.stdout, sys.stderr):
    # Replacement streams (notebooks, some capture tools) may lack reconfigure().
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(line_buffering=True)

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from src.banino_grid_cells import (
    PlaceCellEnsemble,
    HeadDirectionEnsemble,
    BaninoGridNetwork,
    GridScorer,
    generate_trajectories,
    collect_bottleneck_activations,
)


# Experiment configuration. Each value is forwarded unchanged to the helpers
# imported from src.banino_grid_cells and echoed into results.json by main().
N_SEEDS = 10  # seeds 0..N_SEEDS-1; one freshly initialised network per seed
N_TRAJECTORIES = 1000  # trajectories generated once and shared by all seeds
GRIDNESS_THRESHOLD = 0.37  # units scoring strictly above this count as grid-like
N_PLACE_CELLS = 256  # n_cells of PlaceCellEnsemble / n_place_cells of the network
N_HD_CELLS = 12  # n_cells of HeadDirectionEnsemble / n_hd_cells of the network
PC_SIGMA = 0.01  # `sigma` of PlaceCellEnsemble; presumably in ENV_SIZE units -- TODO confirm
HD_KAPPA = 20  # `kappa` of HeadDirectionEnsemble; presumably a concentration -- TODO confirm
ENV_SIZE = 2.2  # `env_size` for trajectories, place cells and scorer; units not visible here
LSTM_HIDDEN = 128  # n_lstm_hidden of BaninoGridNetwork
N_BOTTLENECK = 512  # n_bottleneck of BaninoGridNetwork (the units that get scored)
DURATION = 15.0  # length of each trajectory, in seconds
DT = 0.02  # `dt` passed to generate_trajectories (simulation time step)
NBINS = 32  # `nbins` of GridScorer; presumably ratemap bins per axis -- TODO confirm

# Output locations, resolved relative to the parent of this script's directory.
# CKPT_DIR receives the per-seed ratemap .npz files, results.json and the
# pooled gridness .npy files; FIG_DIR receives the PNG figures.
CKPT_DIR = os.path.join(os.path.dirname(__file__), "..", "checkpoints", "banino_untrained")
FIG_DIR = os.path.join(os.path.dirname(__file__), "..", "figures")


def ensure_dirs():
    """Create the checkpoint and figure directories if they do not exist."""
    for directory in (CKPT_DIR, FIG_DIR):
        os.makedirs(directory, exist_ok=True)


def run_single_seed(seed, positions, velocities, head_dirs,
                    pc_ensemble, hd_ensemble, scorer):
    """Score one untrained network under both initial-state conditions.

    For each of the ``position_init`` and ``zero_init`` conditions a fresh
    ``BaninoGridNetwork`` is built right after ``torch.manual_seed(seed)``,
    put in eval mode and run through ``collect_bottleneck_activations``. The
    resulting activations are scored with ``scorer.compute_all_gridness`` and
    the ten best-scoring ratemaps are written to
    ``CKPT_DIR/ratemaps_s{seed}_{condition}.npz``.

    Args:
        seed: Value passed to ``torch.manual_seed`` before each model build
            and used in the output file names.
        positions, velocities, head_dirs: Trajectory data, forwarded
            unchanged to ``collect_bottleneck_activations``.
        pc_ensemble, hd_ensemble: Place-cell and head-direction ensembles,
            forwarded unchanged to ``collect_bottleneck_activations``.
        scorer: Object whose ``compute_all_gridness(positions, activations)``
            returns ``(gridness_scores, ratemaps)``.

    Returns:
        dict with ``"seed"`` plus one entry per condition holding the raw
        gridness scores, the count and percentage of units above
        ``GRIDNESS_THRESHOLD``, summary statistics and the top-10 units.
    """
    results = {"seed": seed}

    for condition, use_init in [("position_init", True), ("zero_init", False)]:
        print(f"\n  Seed {seed}, condition: {condition}")
        # Re-seed before every build so the generator state is the same when
        # each condition's network is constructed.
        torch.manual_seed(seed)
        model = BaninoGridNetwork(
            n_place_cells=N_PLACE_CELLS, n_hd_cells=N_HD_CELLS,
            n_lstm_hidden=LSTM_HIDDEN, n_bottleneck=N_BOTTLENECK,
            dropout_rate=0.5,
        )
        model.eval()

        all_pos, all_acts = collect_bottleneck_activations(
            model, velocities, positions, head_dirs,
            pc_ensemble, hd_ensemble,
            use_position_init=use_init,
        )

        print(f"Computing gridness for {all_acts.shape[1]} units...")
        gridness_scores, ratemaps = scorer.compute_all_gridness(all_pos, all_acts)

        # Guard against NaN scores, should the scorer ever emit them.
        # np.argsort places NaN last, so the reversed order would put NaN
        # units at the front of the "top 10", and the plain mean/max/percentile
        # would all be NaN (which json.dump writes as invalid JSON).
        nan_mask = np.isnan(gridness_scores)
        n_nan = int(nan_mask.sum())
        if n_nan:
            print(f"Warning: {n_nan} units have NaN gridness; they are ranked "
                  f"last and ignored in the summary statistics")

        # NaN compares False, so such units are never counted as grid-like;
        # the percentage stays relative to all scored units.
        n_grid = int(np.sum(gridness_scores > GRIDNESS_THRESHOLD))
        pct_grid = n_grid / len(gridness_scores) * 100

        # Rank with NaN replaced by -inf: identical to the plain argsort when
        # no NaN is present, otherwise NaN units sink to the bottom.
        ranking_key = np.where(nan_mask, -np.inf, gridness_scores)
        top_idx = np.argsort(ranking_key)[::-1][:10]

        results[condition] = {
            "gridness_scores": gridness_scores.tolist(),
            "n_grid_like": n_grid,
            "pct_grid_like": round(pct_grid, 2),
            "mean_gridness": round(float(np.nanmean(gridness_scores)), 4),
            "median_gridness": round(float(np.nanmedian(gridness_scores)), 4),
            "max_gridness": round(float(np.nanmax(gridness_scores)), 4),
            "percentile_95": round(float(np.nanpercentile(gridness_scores, 95)), 4),
            "top10_indices": top_idx.tolist(),
            "top10_gridness": gridness_scores[top_idx].tolist(),
        }

        # Only the top-10 ratemaps are persisted; plot_example_ratemaps reads
        # them back from this file.
        top_ratemaps = [ratemaps[i] for i in top_idx]
        np.savez(
            os.path.join(CKPT_DIR, f"ratemaps_s{seed}_{condition}.npz"),
            ratemaps=np.array(top_ratemaps),
            gridness=gridness_scores[top_idx],
            indices=top_idx,
        )

        print(f"-> {n_grid}/{len(gridness_scores)} grid-like "
              f"({pct_grid:.1f}%), max gridness={np.nanmax(gridness_scores):.3f}")

    return results


def plot_gridness_histogram(all_results):
    """Save side-by-side histograms of gridness scores pooled over all seeds.

    Args:
        all_results: Per-seed result dicts; each must provide
            ``[condition]["gridness_scores"]`` for both conditions.

    One panel is drawn per condition, with ``GRIDNESS_THRESHOLD`` marked and
    the number of pooled units above it given in the panel title. The figure
    is written to ``FIG_DIR/banino_gridness_histogram.png``.
    """
    panels = (
        ("position_init", "With position init"),
        ("zero_init", "Zero init"),
    )
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))

    for panel, (condition, title) in zip(axes, panels):
        # Flatten every seed's score list for this condition into one array.
        scores = np.array([
            score
            for result in all_results
            for score in result[condition]["gridness_scores"]
        ])
        total = len(scores)

        panel.hist(scores, bins=60, range=(-1.5, 1.5), color="steelblue",
                   alpha=0.7, edgecolor="white", linewidth=0.5,
                   label=f"Untrained (n={total})")
        panel.axvline(GRIDNESS_THRESHOLD, color="red", linestyle="--",
                      linewidth=1.5,
                      label=f"Threshold = {GRIDNESS_THRESHOLD}")

        n_above = np.sum(scores > GRIDNESS_THRESHOLD)
        pct = n_above / total * 100
        panel.set_title(f"{title}\n{n_above}/{total} above threshold ({pct:.1f}%)")
        panel.set_xlabel("Gridness score")
        panel.set_ylabel("Count")
        panel.legend(fontsize=8)

    fig.suptitle("Untrained Banino LSTM: gridness score distributions", fontsize=12)
    fig.tight_layout()
    out_path = os.path.join(FIG_DIR, "banino_gridness_histogram.png")
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print("Saved banino_gridness_histogram.png")


def plot_example_ratemaps(all_results):
    """Plot the top-5 ratemaps and autocorrelograms of the best seed.

    For each condition the seed with the largest ``max_gridness`` is chosen,
    its ``ratemaps_s{seed}_{condition}.npz`` file is read from ``CKPT_DIR``
    and a 2 x 5 figure (ratemaps on top, autocorrelograms below) is saved to
    ``FIG_DIR/banino_ratemaps_{condition}.png``.

    Args:
        all_results: Per-seed result dicts providing ``"seed"`` and
            ``[condition]["max_gridness"]``.
    """
    scorer = GridScorer(nbins=NBINS, env_size=ENV_SIZE)
    n_show = 5

    for condition in ["position_init", "zero_init"]:
        # max() returns the first maximal entry, i.e. the lowest-index seed on ties.
        best = max(all_results, key=lambda r: r[condition]["max_gridness"])
        seed = best["seed"]

        # np.load returns an NpzFile that holds the archive open; the context
        # manager closes it once the arrays have been read into memory.
        npz_path = os.path.join(CKPT_DIR, f"ratemaps_s{seed}_{condition}.npz")
        with np.load(npz_path) as data:
            ratemaps = data["ratemaps"][:n_show]
            gridness_vals = data["gridness"][:n_show]

        n_available = min(n_show, len(ratemaps))
        fig, axes = plt.subplots(2, n_show, figsize=(14, 6))
        for col in range(n_available):
            rm = ratemaps[col]

            ax = axes[0, col]
            # NaN bins are shown as zero; the scorer still gets the raw map.
            rm_plot = rm.copy()
            rm_plot[np.isnan(rm_plot)] = 0
            ax.imshow(rm_plot.T, origin="lower", cmap="hot", interpolation="nearest")
            ax.set_title(f"g={gridness_vals[col]:.3f}", fontsize=9)
            ax.set_xticks([])
            ax.set_yticks([])

            ax2 = axes[1, col]
            sac = scorer.compute_sac(rm)
            # Drawn with the same transpose as the ratemap above so that both
            # rows of the figure share one orientation.
            ax2.imshow(np.transpose(sac), origin="lower", cmap="jet",
                       interpolation="nearest", vmin=-1, vmax=1)
            ax2.set_xticks([])
            ax2.set_yticks([])

        # Hide panels that have no ratemap instead of leaving empty frames.
        for col in range(n_available, n_show):
            axes[0, col].axis("off")
            axes[1, col].axis("off")

        axes[0, 0].set_ylabel("Ratemap", fontsize=10)
        axes[1, 0].set_ylabel("Autocorrelogram", fontsize=10)
        cond_label = "position init" if condition == "position_init" else "zero init"
        fig.suptitle(
            f"Top-5 untrained units ({cond_label}, seed {seed})", fontsize=12
        )
        fig.tight_layout()
        fname = f"banino_ratemaps_{condition}.png"
        fig.savefig(os.path.join(FIG_DIR, fname), dpi=200)
        plt.close(fig)
        print(f"Saved {fname}")


def plot_bar_comparison(summary, trained_mean=23.0, trained_std=2.8):
    """Save a bar chart of the grid-like percentage, untrained vs trained.

    Args:
        summary: Mapping with ``"position_init"`` and ``"zero_init"`` entries,
            each providing ``"mean_pct"`` and ``"std_pct"``.
        trained_mean: Percentage drawn for the trained reference bar. The
            default is the value this script attributes to Banino et al. 2018.
        trained_std: Error-bar size for the trained reference bar; same source.

    The figure is written to ``FIG_DIR/banino_bar_comparison.png``.
    """
    labels = ["Untrained\n(position init)", "Untrained\n(zero init)",
              "Trained\n(Banino 2018)"]
    means = [
        summary["position_init"]["mean_pct"],
        summary["zero_init"]["mean_pct"],
        trained_mean,
    ]
    stds = [
        summary["position_init"]["std_pct"],
        summary["zero_init"]["std_pct"],
        trained_std,
    ]
    colors = ["#4C72B0", "#55A868", "#C44E52"]

    fig, ax = plt.subplots(figsize=(5, 4.5))
    x = np.arange(len(labels))
    bars = ax.bar(x, means, yerr=stds, capsize=5, color=colors, edgecolor="white",
                  width=0.6)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    # Take the threshold from the shared constant so the label cannot drift
    # from the value actually used for counting.
    ax.set_ylabel(f"% units with gridness > {GRIDNESS_THRESHOLD}")
    ax.set_title("Grid-like units: untrained vs trained")
    # Headroom above the tallest bar for its error bar and value label.
    ax.set_ylim(0, max(means) * 1.5 + 5)

    for bar, m, s in zip(bars, means, stds):
        # Place each value label just above the top of the error bar.
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + s + 0.5,
                f"{m:.1f}%", ha="center", va="bottom", fontsize=10)

    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, "banino_bar_comparison.png"), dpi=200)
    plt.close(fig)
    print("Saved banino_bar_comparison.png")


def main():
    """Run the untrained-baseline experiment end to end.

    Generates one shared set of trajectories, scores ``N_SEEDS`` untrained
    networks under both conditions, prints a summary next to the trained
    reference, writes ``results.json`` and the pooled gridness arrays to
    ``CKPT_DIR``, and finally renders the three figures.
    """
    conditions = ("position_init", "zero_init")
    display_names = {"position_init": "Position init", "zero_init": "Zero init"}

    ensure_dirs()
    started = time.time()

    print(f"Generating {N_TRAJECTORIES} trajectories "
          f"({DURATION}s each, dt={DT})...")
    positions_list, velocities_list, hd_list = generate_trajectories(
        N_TRAJECTORIES, seed=0,
        env_size=ENV_SIZE, duration=DURATION, dt=DT,
    )
    steps_per_trajectory = positions_list[0].shape[0]
    print(f"{len(positions_list)} trajectories, "
          f"{steps_per_trajectory} steps each")

    stacked = np.concatenate(positions_list, axis=0)
    xs, ys = stacked[:, 0], stacked[:, 1]
    print(f"Position range: x=[{xs.min():.3f}, "
          f"{xs.max():.3f}], "
          f"y=[{ys.min():.3f}, {ys.max():.3f}]")

    pc_ensemble = PlaceCellEnsemble(n_cells=N_PLACE_CELLS, sigma=PC_SIGMA,
                                    env_size=ENV_SIZE, seed=42)
    hd_ensemble = HeadDirectionEnsemble(n_cells=N_HD_CELLS, kappa=HD_KAPPA)

    scorer = GridScorer(nbins=NBINS, env_size=ENV_SIZE)

    # Every seed is evaluated on the same trajectories and ensembles.
    all_results = []
    for seed in range(N_SEEDS):
        print(f"\nSeed {seed}/{N_SEEDS - 1}")
        all_results.append(
            run_single_seed(
                seed, positions_list, velocities_list, hd_list,
                pc_ensemble, hd_ensemble, scorer,
            )
        )

    # Pool the scores of all seeds once per condition; the pooled arrays feed
    # both the summary statistics and the .npy files saved further down.
    pooled_by_condition = {
        condition: np.array([
            score
            for result in all_results
            for score in result[condition]["gridness_scores"]
        ])
        for condition in conditions
    }

    def rounded(value, digits):
        return round(float(value), digits)

    summary = {}
    for condition in conditions:
        per_seed_pct = [result[condition]["pct_grid_like"] for result in all_results]
        pooled = pooled_by_condition[condition]
        summary[condition] = {
            "mean_pct": rounded(np.mean(per_seed_pct), 2),
            "std_pct": rounded(np.std(per_seed_pct), 2),
            "per_seed_pct": per_seed_pct,
            "pooled_mean_gridness": rounded(np.mean(pooled), 4),
            "pooled_median_gridness": rounded(np.median(pooled), 4),
            "pooled_percentile_95": rounded(np.percentile(pooled, 95), 4),
            "pooled_percentile_99": rounded(np.percentile(pooled, 99), 4),
            "pooled_n_above_threshold": int(np.sum(pooled > GRIDNESS_THRESHOLD)),
            "pooled_total": len(pooled),
        }

    # Timing stops here, so saving results and rendering figures is excluded.
    elapsed = time.time() - started
    print(f"\nTrained reference (Banino 2018): 23.0% +/- 2.8%, threshold={GRIDNESS_THRESHOLD}")

    for condition in conditions:
        stats = summary[condition]
        label = display_names[condition]
        print(f"{label}: {stats['mean_pct']:.1f}% +/- {stats['std_pct']:.1f}% grid-like, "
              f"pooled mean={stats['pooled_mean_gridness']:.4f}, "
              f"p95={stats['pooled_percentile_95']:.4f}, "
              f"{stats['pooled_n_above_threshold']}/{stats['pooled_total']} above threshold")

    def without_scores(entry):
        # The raw per-unit scores are left out of the JSON; they are saved
        # separately as pooled .npy arrays.
        return {key: value for key, value in entry.items()
                if key != "gridness_scores"}

    architecture = {
        "lstm_hidden": LSTM_HIDDEN,
        "bottleneck": N_BOTTLENECK,
        "n_place_cells": N_PLACE_CELLS,
        "n_hd_cells": N_HD_CELLS,
        "pc_sigma": PC_SIGMA,
        "hd_kappa": HD_KAPPA,
        "env_size": ENV_SIZE,
    }
    trajectory_params = {
        "n_trajectories": N_TRAJECTORIES,
        "duration_s": DURATION,
        "dt": DT,
        "nbins": NBINS,
    }
    trained_reference = {
        "source": "Banino et al. 2018 Nature",
        "pct_grid_like": "23% +/- 2.8% (100 retrains)",
        "gridness_threshold": GRIDNESS_THRESHOLD,
        "n_grid_units_single_run": "129/512 (25.2%)",
    }
    per_seed = [
        {
            "seed": result["seed"],
            "position_init": without_scores(result["position_init"]),
            "zero_init": without_scores(result["zero_init"]),
        }
        for result in all_results
    ]
    output = {
        "experiment": "banino_untrained_baseline",
        "n_seeds": N_SEEDS,
        "architecture": architecture,
        "trajectory_params": trajectory_params,
        "trained_reference": trained_reference,
        "summary": summary,
        "per_seed": per_seed,
        "elapsed_minutes": round(elapsed / 60, 1),
    }

    json_path = os.path.join(CKPT_DIR, "results.json")
    with open(json_path, "w") as f:
        json.dump(output, f, indent=2)
    print(f"Saved results to {json_path}")

    for condition in conditions:
        np.save(
            os.path.join(CKPT_DIR, f"gridness_pooled_{condition}.npy"),
            pooled_by_condition[condition],
        )

    print("\nGenerating figures...")
    plot_gridness_histogram(all_results)
    plot_example_ratemaps(all_results)
    plot_bar_comparison(summary)

    print("\nDone.")

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
