# %%
import pickle

import sys

# Make the local latent-diffusion checkout importable; adjust or remove this path for your environment.
sys.path.insert(0, "/mnt/volume/latent-diffusion")

# TODO further clean this up:
# rename result files, update save path, find better names for plots
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from einops import rearrange
import matplotlib

matplotlib.rc_file("../matplotlibrc")

from ntldm.utils.eval_utils import (
    compute_spike_stats_per_neuron,
    counts_to_spike_trains,
    get_temp_corr_summary,
    group_neurons_temp_corr,
)
from ntldm.utils.plotting_utils import (
    cm2inch,
    plot_cross_corr_summary,
    plot_rate_comparisons,
    plot_spiketrain_stats,
    get_group_colors,
)

# %%
# NOTE: these locations are machine-specific (the script was previously pointed
# at "/mnt/volume/" for both of them). Both strings must keep their trailing "/"
# because file names are appended with plain string concatenation below.
data_path = "/home/anonauthor/anonloc1/results/projects/latent-diffusion/DATA_SHARE/"
save_path = "/home/anonauthor/anonloc1/results/projects/latent-diffusion/MONKEY/figures/spike-history-5max/"

# NOTE(security): pickle.load can execute arbitrary code stored in the file.
# Only load dictionaries that were written by trusted runs of this project.
with open(data_path + "test_data_dict.pkl", "rb") as f:
    rec_dict_numpy_test = pickle.load(f)

with open(data_path + "train_data_dict.pkl", "rb") as f:
    rec_dict_numpy_train = pickle.load(f)

# Figures produced further down are written into this directory.
os.makedirs(save_path, exist_ok=True)


# %%
# core functions
def extract_timewise_histories(spikes, rates, history_length):
    """Build per-time-step regression inputs and targets from binned spikes.

    For every time bin ``t`` with ``history_length <= t < num_timesteps`` one
    sample is created per trial. Per neuron, its features are the
    ``history_length`` spike values preceding ``t`` (oldest first), the value
    of ``rates`` at ``t`` and a constant 1 (bias column). Its target is the
    spike value at ``t``.

    Args:
        spikes: tensor of shape (num_trials, num_neurons, num_timesteps).
        rates: 3-D tensor whose first two dimensions match ``spikes`` and
            whose last dimension has at least ``num_timesteps`` entries.
        history_length: number of preceding bins used as features (>= 0).

    Returns:
        A tuple ``(combine_flat, target_flat)`` with shapes
        ``(num_trials * num_steps, num_neurons, history_length + 2)`` and
        ``(num_trials * num_steps, num_neurons)``, where
        ``num_steps = num_timesteps - history_length``. Samples are ordered
        trial-major, i.e. all time steps of trial 0 come first.

    Raises:
        ValueError: if ``history_length`` is negative, if there is no time bin
            left after the history window, or if ``rates`` does not cover
            ``spikes``.
    """
    num_trials, num_neurons, num_timesteps = spikes.shape

    if history_length < 0:
        raise ValueError(f"history_length must be >= 0, got {history_length}")
    if num_timesteps <= history_length:
        raise ValueError(
            f"need more than history_length={history_length} time steps, "
            f"got {num_timesteps}"
        )
    if (
        rates.dim() != 3
        or tuple(rates.shape[:2]) != (num_trials, num_neurons)
        or rates.shape[2] < num_timesteps
    ):
        raise ValueError(
            f"rates with shape {tuple(rates.shape)} does not cover spikes "
            f"with shape {tuple(spikes.shape)}"
        )

    num_steps = num_timesteps - history_length

    # window_index[s, k] = s + k is the k-th history bin of the sample whose
    # target bin is s + history_length.
    window_index = torch.arange(num_steps, device=spikes.device).unsqueeze(
        1
    ) + torch.arange(history_length, device=spikes.device).unsqueeze(0)

    spike_history = spikes[:, :, window_index]  # (b, n, t, history_length)
    target = spikes[:, :, history_length:]  # (b, n, t)
    step_rates = rates[:, :, history_length:num_timesteps].unsqueeze(-1)
    offset = torch.ones_like(step_rates)  # bias column

    combine = torch.cat([spike_history, step_rates, offset], dim=3)

    # Flatten (trial, step) into one sample axis: "b n t d -> (b t) n d".
    combine_flat = combine.permute(0, 2, 1, 3).reshape(
        num_trials * num_steps, num_neurons, history_length + 2
    )
    # Same sample ordering for the targets: "b n t -> (b t) n".
    target_flat = target.permute(0, 2, 1).reshape(
        num_trials * num_steps, num_neurons
    )

    return combine_flat, target_flat


class CombinedLogisiticRegression(nn.Module):
    """Bank of independent single-output regressions evaluated in one call.

    Holds one weight vector of length ``dim`` per regression (``num_regs`` of
    them, randomly initialised from a standard normal). The forward pass
    multiplies the input element-wise with the weights, sums over the last
    dimension and applies the chosen non-linearity.

    Args:
        num_regs: number of regressions (rows of the weight matrix).
        dim: number of regressors per regression.
        nonlin: ``"softplus"`` or ``"sigmoid"``.

    Raises:
        ValueError: if ``nonlin`` is not one of the supported names.
    """

    def __init__(self, num_regs, dim, nonlin="softplus"):
        super().__init__()
        nonlinearities = {"softplus": F.softplus, "sigmoid": torch.sigmoid}
        # Fail at construction time; otherwise an unknown name would only
        # surface later as a missing attribute inside forward().
        if nonlin not in nonlinearities:
            raise ValueError(
                f"unknown nonlin {nonlin!r}; expected one of {sorted(nonlinearities)}"
            )
        self.linear = nn.Parameter(torch.randn(num_regs, dim))
        self.nonlin = nonlinearities[nonlin]

    def forward(self, x):
        """Return ``nonlin(sum(x * weights, dim=-1))``.

        ``x`` must broadcast against the (num_regs, dim) weight matrix, e.g.
        (batch, num_regs, dim); the last dimension is reduced away.
        """
        return self.nonlin((x * self.linear).sum(dim=-1))


def sample_from_rates_zero_start(log_rates, model, history_length, max_spikes=3):
    """Sample spikes one time bin at a time, feeding back the sampled history.

    The sequence is started from ``history_length`` all-zero bins. At every
    step the model receives, per neuron, the last ``history_length`` sampled
    values, the current entry of ``log_rates`` and a constant 1, concatenated
    along the last dimension.

    Args:
        log_rates: tensor of shape (num_trials, num_neurons, num_timesteps).
        model: callable mapping a (num_trials, num_neurons,
            history_length + 2) tensor to a (num_trials, num_neurons) tensor.
            The output is used as Bernoulli probability when
            ``max_spikes == 1`` and as Poisson rate otherwise.
        history_length: number of previously sampled bins fed to the model.
        max_spikes: upper bound on the sampled value per bin.

    Returns:
        Tensor of shape (num_trials, num_neurons, num_timesteps) with the
        sampled values (the zero padding is stripped).
    """
    num_trials, num_neurons, num_timesteps = log_rates.shape

    # Allocate next to log_rates so the concatenation below also works when
    # the rates do not live on the default device.
    spikes = torch.zeros(
        num_trials,
        num_neurons,
        num_timesteps + history_length,
        device=log_rates.device,
    )
    # Bias column; identical for every step, so build it once.
    offset = torch.ones_like(log_rates[:, :, :1])

    with torch.no_grad():
        for step in range(num_timesteps):
            # Bin `step` of the output lives at index `current` of the padded
            # buffer; its history is the `history_length` bins before it.
            current = step + history_length
            reg_input = torch.cat(
                [spikes[:, :, step:current], log_rates[:, :, [step]], offset],
                dim=2,
            )
            p = model(reg_input)
            if max_spikes == 1:
                spikes[:, :, current] = torch.bernoulli(p)
            else:
                spikes[:, :, current] = torch.poisson(p).clamp(max=max_spikes)

    return spikes[:, :, history_length:]


######
num_neurons = 182
ts_length = 140
history_length = 20
max_spikes = 5
######

# Ground-truth spikes, clipped to max_spikes and with the last two dims
# swapped. The "test" spikes are deliberately the same data as the train
# spikes; the diffusion-sampled rates are compared against them later.
# NOTE(review): every tensor here is read from rec_dict_numpy_train, including
# the "test" ones -- confirm this is intended.
train_spikes = (
    torch.from_numpy(np.array(rec_dict_numpy_train["gt_spikes"], dtype=np.float32))
    .clamp(max=max_spikes)
    .transpose(1, 2)
)
test_spikes = (
    torch.from_numpy(np.array(rec_dict_numpy_train["gt_spikes"], dtype=np.float32))
    .clamp(max=max_spikes)
    .transpose(1, 2)
)

# The regression is fitted on the autoencoder rates and evaluated on the
# diffusion rates. This assumes that the diffusion rates are sufficiently
# different from (i.e. not memorised copies of) the autoencoder rates.
train_log_rates = torch.from_numpy(
    np.log(np.array(rec_dict_numpy_train["ae_rates"], dtype=np.float32))
).transpose(1, 2)
test_log_rates = torch.from_numpy(
    np.log(np.array(rec_dict_numpy_train["diffusion_rates"], dtype=np.float32))
).transpose(1, 2)

for prepared in (train_spikes, train_log_rates, test_spikes, test_log_rates):
    print(prepared.shape)

# %%
X_train, y_train = extract_timewise_histories(
    spikes=train_spikes,
    rates=train_log_rates,
    history_length=history_length,
)

# Per neuron: history_length spike-history weights plus one rate weight and
# one bias weight. Sigmoid output for binary spikes, softplus otherwise.
readout_nonlin = "softplus"
if max_spikes == 1:
    readout_nonlin = "sigmoid"
model = CombinedLogisiticRegression(num_neurons, history_length + 2, readout_nonlin)

# The model output is passed to the loss as a rate, not a log-rate.
criterion = nn.PoissonNLLLoss(log_input=False, full=True)
# good/decent results with lr=0.1, weight_decay=0.01 or 0.1 for Poisson
optimizer = optim.AdamW(model.parameters(), lr=0.1, weight_decay=0.01)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
X_train, y_train = X_train.to(device), y_train.to(device)

# %%
# Full-batch training: all neurons' regressions share one loss and one
# optimiser step (not 100% sure it is sensible to train them jointly).
num_epochs = 1000
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    y_pred = model(X_train)
    loss = criterion(y_pred.squeeze(), y_train.float())
    print(f"Epoch {epoch}: train loss: {loss.item()}")

    loss.backward()
    optimizer.step()

# %%
# torch.save(model.state_dict(), save_path + "model_v2.pt")

# %%
# Fitted weights: one red line per regression, black line = median over them.
fitted_weights = model.linear.detach().cpu().numpy()
plt.plot(fitted_weights.T, color="red", alpha=0.3)
plt.plot(np.median(fitted_weights, axis=0), color="black", linewidth=2)
plt.show()

# Sampling is done on the CPU.
model = model.cpu()
hist_spikes = sample_from_rates_zero_start(
    test_log_rates, model, history_length, max_spikes=max_spikes
)

# Poisson baseline: sample directly from the rates, clipped to max_spikes.
test_spikes_numpy = test_spikes.numpy()
pois_spikes_numpy = (
    torch.poisson(test_log_rates.exp()).clamp(max=max_spikes).numpy()
)
hist_spikes_numpy = hist_spikes.numpy()

for counts in (test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy):
    print(np.max(counts))

# First trial of each data set as an image.
for counts in (test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy):
    plt.imshow(counts[0])
    plt.show()

# Swap the last two axes back for the evaluation utilities below.
test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy = [
    counts.transpose(0, 2, 1)
    for counts in (test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy)
]

for counts in (test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy):
    print(counts.shape)

# %%
# Sampling frequency handed to the utilities (presumably 5 ms bins -- confirm).
fps = 1000 / 5

# Convert the three arrays (gt, Poisson baseline, spike-history model).
spike_trains_test, spike_trains_pois, spike_trains_hist = [
    counts_to_spike_trains(counts, fps=fps)
    for counts in (test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy)
]

# %%
# Per-neuron statistics; axis 0 of each array is passed as the number of
# samples and axis 2 as the number of neurons.
spike_stats_test, spike_stats_pois, spike_stats_hist = [
    compute_spike_stats_per_neuron(
        trains,
        n_samples=counts.shape[0],
        n_neurons=counts.shape[2],
        mean_output=False,
    )
    for trains, counts in (
        (spike_trains_test, test_spikes_numpy),
        (spike_trains_pois, pois_spikes_numpy),
        (spike_trains_hist, hist_spikes_numpy),
    )
]

# Each comparison is written to its own file. Both calls used to share one
# save_path, so the gt-vs-hist figure silently overwrote the gt-vs-pois one.
plot_spiketrain_stats(
    spike_stats_test,
    spike_stats_pois,
    figsize=cm2inch(30, 10),
    color="red",
    labels=["gt", "pois"],
    save=True,
    save_path=save_path + "compute_spike_stats_per_neuron_gt_pois",
)

# The gt-vs-hist figure keeps the historical file name: it is the one that
# ended up on disk under that name, so existing references stay valid.
plot_spiketrain_stats(
    spike_stats_test,
    spike_stats_hist,
    figsize=cm2inch(30, 10),
    color="green",
    labels=["gt", "hist"],
    save=True,
    save_path=save_path + "compute_spike_stats_per_neuron_gt_ae",
)


# %%
# Compare the ground truth against both sampled data sets in one figure.
figsize = cm2inch((15, 15))
sampled_sets = (
    ("pois", "red", pois_spikes_numpy),
    ("hist", "green", hist_spikes_numpy),
)
plot_rate_comparisons(
    test_spikes_numpy,
    [counts for _, _, counts in sampled_sets],
    fps=fps,  # same 1000 / 5 value used for the spike-train conversion
    mode="neur",
    figsize=figsize,
    colors=[colour for _, colour, _ in sampled_sets],
    save=True,
    labels=[tag for tag, _, _ in sampled_sets],
    xlabel="firing rate (Hz)",
    save_path=save_path + "mean_spike_comparison_gt_ae_diffusion",
)

# %%
# The correlation utilities are called with batch_first=False, so the first
# two axes of every array are swapped before it is passed in.
corr_options = dict(nlags=30, mode="biased", batch_first=False)

# Neuron groups are derived from the ground truth only and reused for every
# data set so that the summaries are comparable.
groups = group_neurons_temp_corr(test_spikes_numpy.transpose(1, 0, 2), num_groups=4)

cross_corr_groups_test, auto_corr_groups_test = get_temp_corr_summary(
    test_spikes_numpy.transpose(1, 0, 2), groups, **corr_options
)
cross_corr_groups_pois, auto_corr_groups_pois = get_temp_corr_summary(
    pois_spikes_numpy.transpose(1, 0, 2), groups, **corr_options
)
cross_corr_groups_hist, auto_corr_groups_hist = get_temp_corr_summary(
    hist_spikes_numpy.transpose(1, 0, 2), groups, **corr_options
)


# %%
save = True


def _overlay_corr_figure(layers, file_stem):
    """Draw several correlation summaries on one axis and optionally save it.

    ``layers`` is a sequence of ``(corr_groups, tag, cmap, extra_kwargs)``
    tuples, drawn in order; ``tag`` is used both as ``name`` and ``labels``.
    When the module-level ``save`` flag is set and ``save_path`` is not None,
    the figure is written as ``<save_path><file_stem>.png`` and ``.pdf``.
    Returns the ``(fig, ax)`` pair.
    """
    overlay_fig, overlay_ax = plt.subplots(1, 1, figsize=cm2inch(12, 10))
    for corr_groups, tag, cmap, extra_kwargs in layers:
        plot_cross_corr_summary(
            corr_groups,
            name=tag,
            figsize=cm2inch(6, 4),
            cmap=cmap,
            ax_corr=overlay_ax,
            labels=tag,
            ncol=2,
            **extra_kwargs,
        )
    if save and save_path is not None:
        for extension in ("png", "pdf"):
            overlay_fig.savefig(f"{save_path}{file_stem}.{extension}")
    return overlay_fig, overlay_ax


auto_corr_labels = {"title": "auto-corr", "ylabel": "auto-corr"}

# cross-correlation: ground truth vs. Poisson baseline
fig, ax = _overlay_corr_figure(
    [
        (cross_corr_groups_test, "gt", "Greys", {}),
        (cross_corr_groups_pois, "pois", "Reds", {}),
    ],
    "cross_corr_all",
)

# cross-correlation: ground truth vs. spike-history model
fig, ax = _overlay_corr_figure(
    [
        (cross_corr_groups_test, "gt", "Greys", {}),
        (cross_corr_groups_hist, "hist", "Greens", {}),
    ],
    "cross_corr_gt_diff",
)

# auto-correlation: ground truth vs. Poisson baseline
fig, ax = _overlay_corr_figure(
    [
        (auto_corr_groups_test, "gt", "Greys", auto_corr_labels),
        (auto_corr_groups_pois, "pois", "Reds", auto_corr_labels),
    ],
    "auto_corr_all",
)

# auto-correlation: ground truth vs. spike-history model
# (for this figure only the hist layer passes title/ylabel)
fig, ax = _overlay_corr_figure(
    [
        (auto_corr_groups_test, "gt", "Greys", {}),
        (auto_corr_groups_hist, "hist", "Greens", auto_corr_labels),
    ],
    "auto_corr_gt_diff",
)


# %%
# SPECIAL PLOT
def history_plot_cross_corr_summary(
    cross_corr_groups,
    binWidth=None,
    cmap="Reds",
    figsize=(6, 4),
    linestyle="-",
    ax_corr=None,
    labels="group",
    title="cross-corr",
    xlabel="lag",
    ms=2,
    lw=1.5,
):
    """Plot one correlation curve per group against lag, hiding the zero lag.

    Args:
        cross_corr_groups: sequence of 1-D curves, one per group. The curves
            are taken to be centred, i.e. the lag runs from ``-nlag`` to
            ``+nlag`` with ``nlag = (len(curve) - 1) / 2``. They are not
            modified.
        binWidth: if given, lags are multiplied by it for the x axis.
        cmap: colormap name passed to ``get_group_colors``.
        figsize: figure size, only used when ``ax_corr`` is None.
        linestyle: matplotlib format string for the curves.
        ax_corr: axis to draw on; a new figure is created when None.
        labels: legend prefix; curve ``i`` is labelled ``f"{labels} {i + 1}"``.
        title: axis title.
        xlabel: x-axis label.
        ms: marker size.
        lw: line width.

    Returns:
        The axis that was drawn on.
    """
    num_groups = len(cross_corr_groups)
    g_colors = get_group_colors(num_groups, cmap)

    if ax_corr is None:
        _, ax_corr = plt.subplots(figsize=figsize)
    ax_corr.set_title(title)

    nlag = int((len(cross_corr_groups[0]) - 1) / 2)
    x_ticks = np.arange(-nlag, nlag + 1, 1)

    # Express the lags in units of the bin width when one is given.
    if binWidth is not None:
        x_ticks = x_ticks * binWidth

    for ind, group_corr in enumerate(cross_corr_groups):
        # Blank the zero-lag bin on a float copy: writing NaN into the input
        # would silently alter the caller's data (and fail for integer arrays).
        curve = np.array(group_corr, dtype=float)
        curve[nlag] = np.nan
        ax_corr.plot(
            x_ticks,
            curve,
            linestyle,
            lw=lw,
            ms=ms,
            # colours are assigned in reverse order of the group index
            color=g_colors[num_groups - ind - 1],
            label=f"{labels} {ind + 1}",
            solid_capstyle="round",
        )
    ax_corr.set_xlabel(xlabel)

    return ax_corr


# %%
fig, ax = plt.subplots(1, 2, figsize=cm2inch(8.5, 2.5))
save = True

# Each panel overlays the first two ground-truth groups (grey) with the first
# two groups of one sampled data set. Right panel: spike-history model,
# left panel: Poisson baseline.
fig3_panels = (
    (ax[1], auto_corr_groups_hist, "Greens", "with spike history"),
    (ax[0], auto_corr_groups_pois, "Reds", "Poisson"),
)
for panel_ax, sampled_groups, sampled_cmap, panel_title in fig3_panels:
    history_plot_cross_corr_summary(
        auto_corr_groups_test[:2],
        figsize=cm2inch(5, 3),
        cmap="Greys",
        ax_corr=panel_ax,
        labels="gt ",
    )
    history_plot_cross_corr_summary(
        sampled_groups[:2],
        figsize=cm2inch(5, 3),
        cmap=sampled_cmap,
        ax_corr=panel_ax,
        labels="diff ",
        title=panel_title,
    )

# ax[0].set_yticks([-0.0005, 0.0005])
# ax[0].set_yticklabels(["-0.0005", "0.0005"])
# Only the left panel carries y tick labels and the y-axis label.
ax[1].set_yticklabels([])
ax[0].set_ylabel("auto-corr")
for panel_ax in (ax[0], ax[1]):
    panel_ax.legend(fontsize=4, loc="lower right")

if save and save_path is not None:
    for extension in ("png", "pdf"):
        fig.savefig(f"{save_path}Fig_3_auto_corr_gt_diff.{extension}")

# %%
# Draw several independent samples from the spike-history model and store
# them together with the ground truth, a Poisson sample and the rates.
history_spikes = []

for i in range(5):
    hist_spikes = sample_from_rates_zero_start(
        test_log_rates, model, history_length, max_spikes=max_spikes
    )

    # The Poisson baseline is re-drawn on every pass; only the draw from the
    # last pass ends up in the saved dictionary.
    test_spikes_numpy = test_spikes.numpy()
    pois_spikes_numpy = (
        torch.poisson(test_log_rates.exp()).clamp(max=max_spikes).numpy()
    )
    hist_spikes_numpy = hist_spikes.numpy()

    # First trial of each data set as an image.
    for counts in (test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy):
        plt.imshow(counts[0])
        plt.show()

    # Swap the last two axes before storing.
    test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy = [
        counts.transpose(0, 2, 1)
        for counts in (test_spikes_numpy, pois_spikes_numpy, hist_spikes_numpy)
    ]

    history_spikes.append(hist_spikes_numpy)

history_sampled_dict = dict(
    gt=test_spikes_numpy,
    pois=pois_spikes_numpy,
    rate=test_log_rates.exp().numpy().transpose(0, 2, 1),
    hist=history_spikes,  # list with one array per pass
)
# The pickle is written next to the input data (data_path), not save_path.
with open(data_path + "history_sampled_dict.pkl", "wb") as f:
    pickle.dump(history_sampled_dict, f)

# %%
