import os
import numpy as np
import torch
import torch.nn as nn
from umap import UMAP
from pathlib import Path

# Input data folder, results folder, and the recording sessions to process.
# Sessions are numbered 1 through 8 inclusive.
data_dir = "data"
output_dir = "ctis_master_results_0512"
session_ids = list(range(1, 9))


class LowRankRNN(nn.Module):
    """Single-step RNN cell whose recurrent matrix is the low-rank product U @ V.T.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Size of the hidden state.
        rank: Number of columns in the U and V factors (the rank bound of
            the recurrent matrix).
    """

    def __init__(self, input_dim, hidden_dim, rank):
        super().__init__()
        # Creation order (U, V, input layer) is kept so that a seeded run
        # draws the same initial weights as before.
        self.U = nn.Parameter(torch.randn(hidden_dim, rank))
        self.V = nn.Parameter(torch.randn(hidden_dim, rank))
        self.W_input = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.Tanh()

    def forward(self, x, h):
        """Advance the hidden state by one step.

        Args:
            x: Input of shape (input_dim,) or (..., input_dim).
            h: Hidden state of shape (hidden_dim,) or (..., hidden_dim);
                leading dimensions must broadcast with those of ``x``.

        Returns:
            tanh(W_r h + W_input x) with the same leading dimensions as the
            inputs and a trailing dimension of hidden_dim.
        """
        W_r = self.U @ self.V.T
        # Right-multiplying by W_r.T equals W_r @ h for a 1-D state and also
        # treats every leading dimension as a batch. The former `W_r @ h`
        # raised on batched states, or silently mixed samples when the batch
        # size happened to equal hidden_dim.
        recurrent = h @ W_r.T
        return self.activation(recurrent + self.W_input(x))

def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def load_crcns_spike_train(res_file, clu_file, bin_size=0.1, duration=600, sampling_rate=20000):
    """Load spike times and cluster ids, bin them per unit, and z-score each unit.

    Args:
        res_file: Text file of integer spike times, in samples.
        clu_file: Text file of integer cluster ids, one per spike. A file
            with exactly one extra leading entry is accepted; that first
            entry is dropped (presumably a cluster-count header -- confirm
            against the dataset documentation).
        bin_size: Bin width in seconds.
        duration: Length of the binned window in seconds, starting at t=0.
        sampling_rate: Samples per second used to convert spike times.

    Returns:
        Float array of shape (int(duration / bin_size), max cluster id).
        Column ``k`` holds the z-scored spike counts of cluster ``k + 1``.

    Raises:
        ValueError: If the spike and cluster counts still differ after the
            optional leading entry has been dropped.
    """
    # atleast_1d: np.loadtxt returns a 0-d array for a one-line file, which
    # would break len() below.
    spikes = np.atleast_1d(np.loadtxt(res_file, dtype=int))
    clusters = np.atleast_1d(np.loadtxt(clu_file, dtype=int))
    if len(clusters) == len(spikes) + 1:
        clusters = clusters[1:]
    # Explicit raise instead of assert so the check survives `python -O`.
    if len(spikes) != len(clusters):
        raise ValueError(
            f"spike/cluster count mismatch: {len(spikes)} spikes in {res_file} "
            f"vs {len(clusters)} cluster ids in {clu_file}"
        )

    num_bins = int(duration / bin_size)
    # No spikes (or only ids below 1) yields zero unit columns instead of an error.
    n_units = max(int(clusters.max()), 0) if clusters.size else 0
    binned = np.zeros((num_bins, n_units))

    # Truncate toward zero, exactly like int(t / bin_size) did per spike.
    bin_idx = (spikes / sampling_rate / bin_size).astype(int)
    # Cluster ids below 1 have no column. Previously id 0 indexed column -1
    # and was silently added to the last unit's counts.
    # NOTE(review): id 0 is presumably the noise/artifact cluster -- confirm.
    keep = (bin_idx >= 0) & (bin_idx < num_bins) & (clusters >= 1)
    # np.add.at accumulates repeated (bin, unit) pairs; plain fancy-index
    # `+=` would count each pair only once.
    np.add.at(binned, (bin_idx[keep], clusters[keep] - 1), 1)

    # Per-unit z-score; the epsilon keeps silent units from dividing by zero.
    binned = (binned - binned.mean(axis=0)) / (binned.std(axis=0) + 1e-5)
    return binned

def run_rnn(model, input_seq):
    """Drive ``model`` over ``input_seq`` from a zero state and record every hidden state.

    Args:
        model: Module exposing a ``U`` parameter whose first dimension is the
            hidden size, and callable as ``model(x, h) -> h_next``.
        input_seq: Iterable of per-step input vectors (array-likes or tensors).

    Returns:
        NumPy array with one row per step holding the hidden state after
        that step. An empty ``input_seq`` gives an array of shape
        (0, hidden_dim) instead of raising.
    """
    device = next(model.parameters()).device
    h = torch.zeros(model.U.shape[0], device=device)
    hidden_states = []
    with torch.no_grad():
        for v in input_seq:
            # as_tensor accepts arrays, lists and tensors alike and avoids the
            # copy-construct warning torch.tensor emits for tensor inputs.
            v_tensor = torch.as_tensor(v, dtype=torch.float32, device=device)
            h = model(v_tensor, h)
            # Keep states on the model's device; one transfer happens at the
            # end instead of a host copy on every step.
            hidden_states.append(h)
    if not hidden_states:
        return h.new_empty((0, h.shape[0])).cpu().numpy()
    return torch.stack(hidden_states).cpu().numpy()

def add_jitter(data, scale=1e-4, seed=42):
    """Return ``data`` plus zero-mean Gaussian noise of standard deviation ``scale``.

    Args:
        data: Array-like of numbers; it is not modified.
        scale: Standard deviation of the added noise.
        seed: Seed for the noise generator. The default of 42 reproduces the
            previous fixed-seed behaviour, so repeated calls with the same
            shape add the same noise.

    Returns:
        New array with the same shape as ``data``.
    """
    data = np.asarray(data)  # lets lists and other array-likes through
    rng = np.random.default_rng(seed)
    return data + rng.normal(0, scale, data.shape)

def apply_random_lesion(model, level=0.3):
    """Zero a random subset of the entries of ``model.U`` in place.

    Each entry is zeroed independently when a uniform draw from [0, 1)
    falls below ``level``. Nothing is returned.
    """
    weights = model.U
    with torch.no_grad():
        lesioned = torch.rand_like(weights) < level
        weights.masked_fill_(lesioned, 0.0)

# For every session that already has a results folder, run a freshly
# initialised, lesioned low-rank RNN over the binned spikes and save a 3-D
# UMAP embedding of its hidden states next to the other session results.
for session_id in session_ids:
    session_path = os.path.join(output_dir, f"session_{session_id}")
    if not os.path.exists(session_path):
        # Sessions without a results folder are skipped.
        continue

    print(f"Generating umap_lesion.npy for session {session_id}")
    res_file = os.path.join(data_dir, f"ec013.544.res.{session_id}")
    clu_file = os.path.join(data_dir, f"ec013.544.clu.{session_id}")
    binned = load_crcns_spike_train(res_file, clu_file)
    binned = binned[::5]  # keep every 5th time bin

    model = LowRankRNN(input_dim=binned.shape[1], hidden_dim=64, rank=4)
    model = model.to(get_device())
    apply_random_lesion(model, level=0.3)

    # The RNN is driven with every 20th row of the already-thinned bins.
    hidden_lesion = run_rnn(model, binned[::20])
    jittered_states = add_jitter(hidden_lesion.astype(np.float32))
    embed_lesion = UMAP(n_components=3).fit_transform(jittered_states)

    lesion_path = os.path.join(session_path, "umap_lesion.npy")
    np.save(lesion_path, embed_lesion)
    print(f"Saved to {lesion_path}")
