"""Banino et al. (2018) untrained baseline: architecture, place/HD cell ensembles, trajectory generation, and gridness scoring."""

import numpy as np
import torch
import torch.nn as nn
from scipy.ndimage import rotate as ndimage_rotate
from scipy.signal import fftconvolve
from scipy.special import softmax



class PlaceCellEnsemble:
    """Population of place cells with Gaussian tuning (Banino eq. 1).

    Cell centres are drawn uniformly over ``[0, env_size]`` in both axes from
    a ``RandomState(seed)``. ``sigma`` is the Gaussian tuning width, in the
    same units as the positions.
    """

    def __init__(self, n_cells=256, sigma=0.01, env_size=2.2, seed=42):
        self.n_cells = n_cells
        self.sigma = sigma
        generator = np.random.RandomState(seed)
        self.centers = generator.uniform(0, env_size, size=(n_cells, 2))

    def encode(self, positions):
        """Encode (N, 2) positions as (N, n_cells) softmax-normalised activations.

        Each row is a softmax over ``-||x - c_i||^2 / (2 * sigma^2)``, so it
        sums to one across cells.
        """
        offsets = positions[:, np.newaxis, :] - self.centers[np.newaxis, :, :]
        squared_distance = (offsets ** 2).sum(axis=-1)
        variance = self.sigma ** 2
        return softmax(-squared_distance / (2 * variance), axis=-1)


class HeadDirectionEnsemble:
    """Population of head-direction cells with von Mises tuning (Banino eq. 2).

    Preferred directions are evenly spaced over ``[-pi, pi)``; ``kappa`` is
    the von Mises concentration.
    """

    def __init__(self, n_cells=12, kappa=20):
        self.n_cells = n_cells
        self.kappa = kappa
        self.centers = np.linspace(-np.pi, np.pi, n_cells, endpoint=False)

    def encode(self, angles):
        """Encode (N,) angles (radians) as (N, n_cells) softmax-normalised activations.

        Each row is a softmax over ``kappa * cos(angle - centre_i)``.
        """
        angular_offset = angles[:, np.newaxis] - self.centers[np.newaxis, :]
        tuning = self.kappa * np.cos(angular_offset)
        return softmax(tuning, axis=-1)



def generate_trajectory(env_size=2.2, duration=15.0, dt=0.02,
                        sigma_v=0.13, mu_phi=0.0, sigma_phi_deg=330.0,
                        perimeter_dist=0.03, velocity_reduction=0.25,
                        angle_change_deg=90.0, seed=None):
    """Generate a single rat-like random-walk trajectory in a square arena.

    Each step draws a forward speed from a Rayleigh distribution (scale
    ``sigma_v``) and a rotational velocity from a normal distribution (mean
    ``mu_phi``, std ``sigma_phi_deg`` converted to radians). When the agent is
    within ``perimeter_dist`` of any wall, the speed is multiplied by
    ``velocity_reduction``. The rotational velocity is then replaced by
    ``+/- angle_change_deg`` (in radians), signed so the heading turns toward
    the arena centre. The heading is advanced by ``dphi * dt`` and wrapped to
    [-pi, pi). The position is clipped to ``[0, env_size]``.

    Args:
        env_size: side length of the square arena.
        duration: trajectory duration; the step count is floor(duration / dt),
            tolerant to floating-point round-off.
        dt: integration time step; must be positive.
        sigma_v: Rayleigh scale of the forward speed.
        mu_phi, sigma_phi_deg: mean (radians) and std (degrees) of the
            rotational velocity.
        perimeter_dist: width of the wall-avoidance border.
        velocity_reduction: speed multiplier inside the border.
        angle_change_deg: magnitude of the rotational velocity used inside
            the border.
        seed: seed for ``np.random.RandomState``.

    Returns:
        positions: (n_steps, 2) positions *after* each step. The start
            position is not included.
        velocities: (n_steps, 3) rows ``[speed, sin(dphi), cos(dphi)]``.
        head_dirs: (n_steps,) heading after each step, in radians.

    Raises:
        ValueError: if ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    rng = np.random.RandomState(seed)
    # Floor with a small tolerance. A plain int() truncation turns quotients
    # like 0.3 / 0.1 == 2.9999999999999996 into 2 and silently drops a step.
    n_steps = int(np.floor(duration / dt + 1e-9))
    sigma_phi = np.deg2rad(sigma_phi_deg)
    angle_change = np.deg2rad(angle_change_deg)

    # Loop invariants: arena centre (same on both axes) and border thresholds.
    centre_coord = env_size / 2
    lower = perimeter_dist
    upper = env_size - perimeter_dist

    positions = np.zeros((n_steps, 2))
    velocities = np.zeros((n_steps, 3))
    head_dirs = np.zeros(n_steps)

    pos = rng.uniform(perimeter_dist, env_size - perimeter_dist, size=2)
    hd = rng.uniform(-np.pi, np.pi)

    for t in range(n_steps):
        # Draw order (speed, then rotation) is kept fixed for reproducibility.
        v = rng.rayleigh(sigma_v)
        dphi = rng.normal(mu_phi, sigma_phi)

        near_wall = (
            pos[0] < lower or pos[0] > upper or
            pos[1] < lower or pos[1] > upper
        )
        if near_wall:
            v *= velocity_reduction
            to_centre = np.arctan2(centre_coord - pos[1], centre_coord - pos[0])
            # Signed heading error wrapped to [-pi, pi), so the turn takes the
            # short way round toward the centre.
            angle_diff = to_centre - hd
            angle_diff = (angle_diff + np.pi) % (2 * np.pi) - np.pi
            dphi = np.sign(angle_diff) * angle_change

        hd = hd + dphi * dt
        hd = (hd + np.pi) % (2 * np.pi) - np.pi

        dx = v * np.cos(hd) * dt
        dy = v * np.sin(hd) * dt
        new_pos = np.clip(pos + np.array([dx, dy]), 0.0, env_size)

        positions[t] = new_pos
        velocities[t] = [v, np.sin(dphi), np.cos(dphi)]
        head_dirs[t] = hd
        pos = new_pos

    return positions, velocities, head_dirs


def generate_trajectories(n_trajectories, seed=0, **kwargs):
    """Generate ``n_trajectories`` trajectories with consecutive seeds.

    Trajectory ``k`` comes from ``generate_trajectory(seed=seed + k, **kwargs)``.
    Any extra keyword arguments are forwarded unchanged.

    Returns:
        Three lists (positions, velocities, head directions), each holding one
        array per trajectory in generation order.
    """
    runs = [generate_trajectory(seed=seed + k, **kwargs)
            for k in range(n_trajectories)]
    positions = [run[0] for run in runs]
    velocities = [run[1] for run in runs]
    head_dirs = [run[2] for run in runs]
    return positions, velocities, head_dirs



class BaninoGridNetwork(nn.Module):
    """Banino et al. supervised LSTM with place/HD cell outputs.

    Components:
      * an ``LSTMCell`` with 3-D input;
      * a linear map from place/HD cell activations to the initial LSTM
        state (h_0, c_0);
      * a linear bottleneck on the LSTM hidden state, with dropout
        (``dropout_rate``) applied in training mode only;
      * linear place-cell and head-direction read-outs from the bottleneck.
    """

    def __init__(self, n_place_cells=256, n_hd_cells=12,
                 n_lstm_hidden=128, n_bottleneck=512, dropout_rate=0.5):
        super().__init__()
        self.n_lstm_hidden = n_lstm_hidden
        self.n_bottleneck = n_bottleneck

        self.lstm = nn.LSTMCell(input_size=3, hidden_size=n_lstm_hidden)

        # Initial state = W_pc @ pc + W_hd @ hd, separately for cell and hidden.
        self.init_cell_from_pc = nn.Linear(n_place_cells, n_lstm_hidden, bias=False)
        self.init_cell_from_hd = nn.Linear(n_hd_cells, n_lstm_hidden, bias=False)
        self.init_hidden_from_pc = nn.Linear(n_place_cells, n_lstm_hidden, bias=False)
        self.init_hidden_from_hd = nn.Linear(n_hd_cells, n_lstm_hidden, bias=False)

        self.bottleneck = nn.Linear(n_lstm_hidden, n_bottleneck, bias=True)
        self.dropout = nn.Dropout(dropout_rate)

        self.pc_output = nn.Linear(n_bottleneck, n_place_cells)
        self.hd_output = nn.Linear(n_bottleneck, n_hd_cells)

    def get_initial_state(self, pc_init, hd_init):
        """Compute initial LSTM (h_0, c_0) from place/HD cell activations.

        Args:
            pc_init: (..., n_place_cells) place-cell activations.
            hd_init: (..., n_hd_cells) head-direction activations.

        Returns:
            (h_0, c_0), each (..., n_lstm_hidden).
        """
        h_0 = self.init_hidden_from_pc(pc_init) + self.init_hidden_from_hd(hd_init)
        c_0 = self.init_cell_from_pc(pc_init) + self.init_cell_from_hd(hd_init)
        return h_0, c_0

    def forward_sequence(self, velocity_seq, h_0, c_0):
        """Run the LSTM over a velocity sequence and return bottleneck activations.

        Args:
            velocity_seq: (T, 3) tensor. Row ``t`` is fed to the LSTM cell as
                a batch of one (``velocity_seq[t:t+1]``).
            h_0, c_0: initial LSTM state; presumably (1, n_lstm_hidden) to
                match the batch of one. TODO: confirm for batched use.

        Returns:
            (T, n_bottleneck) tensor of bottleneck activations. Dropout is
            applied to them; ``nn.Dropout`` is the identity in eval mode, so
            call ``.eval()`` for deterministic activations.
        """
        h, c = h_0, c_0
        bottleneck_acts = []

        for t in range(velocity_seq.shape[0]):
            h, c = self.lstm(velocity_seq[t:t + 1], (h, c))
            # Previously the dropout module was constructed but never used.
            bottleneck_acts.append(self.dropout(self.bottleneck(h)))

        return torch.cat(bottleneck_acts, dim=0)  # (T, n_bottleneck)



class GridScorer:
    """Spatial ratemaps, spatial autocorrelograms (SACs) and gridness scores.

    Space is discretised into an ``nbins`` x ``nbins`` grid spanning
    ``[0, env_size]`` along both axes.
    """

    def __init__(self, nbins=32, env_size=2.2):
        self.nbins = nbins
        self.env_size = env_size

    def compute_ratemap(self, positions, activations):
        """Average ``activations`` within spatial bins.

        Args:
            positions: (N, 2) x/y coordinates.
            activations: (N,) values, one per position.

        Returns:
            (nbins, nbins) array indexed ``[x_bin, y_bin]``. Bins with no
            samples are NaN.
        """
        n = self.nbins
        edges = np.linspace(0, self.env_size, n + 1)
        # digitize is 1-based; clipping folds samples on/outside the outer
        # edges into the border bins.
        x_bin = np.clip(np.digitize(positions[:, 0], edges) - 1, 0, n - 1)
        y_bin = np.clip(np.digitize(positions[:, 1], edges) - 1, 0, n - 1)

        totals = np.zeros((n, n))
        visits = np.zeros((n, n))
        np.add.at(totals, (x_bin, y_bin), activations)
        np.add.at(visits, (x_bin, y_bin), 1)

        occupied = visits > 0
        result = np.full((n, n), np.nan)
        result[occupied] = totals[occupied] / visits[occupied]
        return result

    def compute_sac(self, ratemap):
        """Compute the spatial autocorrelogram via FFT convolution.

        NaN bins are treated as missing. The map is mean-centred over valid
        bins, and each lag's summed product is divided by the number of
        overlapping valid bins (at least 1). The result is scaled so its
        maximum is 1 when that maximum is positive. For an (H, W) input the
        output is (2H-1, 2W-1).
        """
        filled = ratemap.copy()
        valid = ~np.isnan(filled)
        filled[~valid] = 0.0

        if valid.sum() > 0:
            filled[valid] -= filled[valid].mean()

        weights = valid.astype(float)
        lagged_products = fftconvolve(filled, filled[::-1, ::-1], mode='full')
        overlap = fftconvolve(weights, weights[::-1, ::-1], mode='full')
        overlap = np.maximum(overlap, 1)

        sac = lagged_products / overlap
        peak = sac.max()
        return sac / peak if peak > 0 else sac

    @staticmethod
    def _annulus_corr(reference, rotated):
        """Correlate SAC values with rotated values inside one annulus.

        Returns -1.0 when fewer than 20 pixel pairs are free of NaN, 0.0 when
        either side is constant, and otherwise the Pearson coefficient.
        """
        both_valid = ~np.isnan(reference) & ~np.isnan(rotated)
        if both_valid.sum() < 20:
            return -1.0
        a = reference[both_valid]
        b = rotated[both_valid]
        if a.std() == 0 or b.std() == 0:
            return 0.0
        return np.corrcoef(a, b)[0, 1]

    def compute_gridness(self, sac):
        """Best gridness over a family of annuli (Sargolini 2006 / Banino style).

        For each annulus, gridness = min(r60, r120) - max(r30, r90, r150). Here
        r_a is the correlation between the SAC and its copy rotated by ``a``
        degrees. Inner radii start at 3 bins (step 2). Outer radii are at
        least 4 bins larger (step 2) and stay below the centre-to-edge
        distance. Annuli with fewer than 30 pixels are skipped. The maximum
        is returned; -2.0 means no annulus qualified.
        """
        rotations = (30, 60, 90, 120, 150)
        center = np.array(sac.shape) // 2
        cy, cx = center[0], center[1]
        rows, cols = np.ogrid[:sac.shape[0], :sac.shape[1]]
        radius = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)

        # Rotate once per angle; pixels rotated in from outside the array are
        # NaN and are excluded per correlation.
        rotated = {
            deg: ndimage_rotate(sac, deg, reshape=False, order=1,
                                mode='constant', cval=np.nan)
            for deg in rotations
        }

        best = -2.0
        smallest_inner = 3
        largest_outer = min(cy, cx)
        inner_stop = max(smallest_inner + 1, largest_outer - 4)

        for r_in in range(smallest_inner, inner_stop, 2):
            for r_out in range(r_in + 4, largest_outer, 2):
                ring = (radius >= r_in) & (radius <= r_out)
                if ring.sum() < 30:
                    continue

                reference = sac[ring]
                r = {
                    deg: self._annulus_corr(reference, rotated[deg][ring])
                    for deg in rotations
                }
                score = min(r[60], r[120]) - max(r[30], r[90], r[150])
                best = max(best, score)

        return best

    def compute_all_gridness(self, positions, activations_matrix):
        """Score every column (unit) of ``activations_matrix``.

        Args:
            positions: (N, 2) positions shared by all units.
            activations_matrix: (N, n_units) activations.

        Returns:
            (scores, ratemaps): a (n_units,) array of gridness scores and a
            list with one ratemap per unit. Progress is printed every 100
            units.
        """
        n_units = activations_matrix.shape[1]
        scores = np.zeros(n_units)
        ratemaps = []

        for unit in range(n_units):
            ratemap = self.compute_ratemap(positions, activations_matrix[:, unit])
            ratemaps.append(ratemap)
            scores[unit] = self.compute_gridness(self.compute_sac(ratemap))
            if (unit + 1) % 100 == 0:
                print(f"scored {unit + 1}/{n_units} units", flush=True)

        return scores, ratemaps



@torch.no_grad()
def collect_bottleneck_activations(model, velocities_list, positions_list,
                                   head_dirs_list, pc_ensemble, hd_ensemble,
                                   use_position_init=True):
    """Run untrained model on trajectories, collect bottleneck activations.

    The model is switched to eval mode (disabling dropout) for the duration of
    the call. Its previous train/eval mode is restored afterwards, even if an
    exception is raised.

    Args:
        model: network exposing ``get_initial_state``, ``forward_sequence``,
            ``n_lstm_hidden`` and ``n_bottleneck``.
        velocities_list: per-trajectory velocity arrays, presumably (T_i, 3).
        positions_list: per-trajectory position arrays, presumably (T_i, 2).
        head_dirs_list: per-trajectory head-direction arrays, (T_i,).
        pc_ensemble, hd_ensemble: objects with an ``encode`` method. They are
            only used when ``use_position_init`` is true.
        use_position_init: if true, build (h_0, c_0) from the place/HD encoding
            of each trajectory's first sample; otherwise start from zeros.

    Returns:
        (positions, activations) concatenated over trajectories. Positions
        have shape (sum T_i, 2) and activations (sum T_i, n_bottleneck). With
        no trajectories, empty (0, 2) and (0, n_bottleneck) arrays are
        returned instead of raising.
    """
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device

        all_pos = []
        all_acts = []

        for i, (vel, pos, hd) in enumerate(
            zip(velocities_list, positions_list, head_dirs_list)
        ):
            vel_t = torch.tensor(vel, dtype=torch.float32, device=device)

            if use_position_init:
                # NOTE(review): if trajectories come from generate_trajectory,
                # pos[0]/hd[0] are the state *after* vel[0] was applied. The
                # initial state would then be one step ahead of the inputs;
                # confirm whether this is intended.
                pc_act = pc_ensemble.encode(pos[0:1])  # (1, n_place_cells)
                hd_act = hd_ensemble.encode(hd[0:1])   # (1, n_hd_cells)
                pc_t = torch.tensor(pc_act, dtype=torch.float32, device=device)
                hd_t = torch.tensor(hd_act, dtype=torch.float32, device=device)
                h_0, c_0 = model.get_initial_state(pc_t, hd_t)
            else:
                h_0 = torch.zeros(1, model.n_lstm_hidden, device=device)
                c_0 = torch.zeros(1, model.n_lstm_hidden, device=device)

            acts = model.forward_sequence(vel_t, h_0, c_0)  # (T, n_bottleneck)
            all_pos.append(pos)
            all_acts.append(acts.cpu().numpy())

            if (i + 1) % 200 == 0:
                print(f"processed {i + 1}/{len(velocities_list)} trajectories",
                      flush=True)
    finally:
        # Do not leak eval mode into the caller (e.g. a training loop).
        model.train(was_training)

    if not all_pos:
        # np.concatenate([]) raises; return correctly-shaped empty arrays.
        return (np.zeros((0, 2)),
                np.zeros((0, model.n_bottleneck), dtype=np.float32))

    return np.concatenate(all_pos, axis=0), np.concatenate(all_acts, axis=0)
