#%%
from heapq import nsmallest
from tkinter import font
from turtle import pos
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import defaultdict
import sys


sys.path.append("../modules")

from toolbox import *
from model_def import RNN
from data_set import *
from data_set import *



#%%

# Prefix prepended to the figure file names handed to plt.savefig further down.
# NOTE(review): this is an absolute path at the filesystem root — confirm that is intended.
plots_path = "/plots/"

# NOTE(review): this module-level name is not read anywhere in the visible script
# (the helpers below use `model.alpha`, and the plotting calls pass `alpha=` as a
# keyword). Presumably a dt / tau ratio — confirm, or remove if unused.
alpha = 0.01 / 0.05


def get_all_inputs(values, ind, model):
    """Stack the six additive terms recorded for timestep ``ind`` into one tensor.

    Parameters
    ----------
    values : mapping
        Must provide ``"r_inputs1"``, ``"r_inputs2"``, ``"r_inputs3"`` and
        ``"hidden_raw"``, each indexable by timestep.
    ind : int
        Timestep to assemble. ``0`` is handled separately (see Returns).
    model : object
        Read for ``model.alpha`` and ``model.feedback.bias``.

    Returns
    -------
    torch.Tensor
        Detached tensor with a new leading axis of length 6.

        For ``ind != 0`` the six slices are, in order: ``alpha * r_inputs1[ind]``,
        ``alpha * r_inputs2[ind]``, ``alpha * r_inputs3[ind]``,
        ``alpha * feedback.bias`` (tiled to the shape of ``r_inputs1[ind]``),
        ``alpha * -hidden_raw[ind - 1]`` and ``hidden_raw[ind - 1]``.

        For ``ind == 0`` every slice is zero except slice 1, which holds the
        unscaled ``r_inputs1[0]``.
    """
    current = values["r_inputs1"][ind]

    if ind == 0:
        # First timestep: only slot 1 carries data, all other slots are zeros.
        blank = torch.zeros_like(current)
        terms = [blank, current, blank, blank, blank, blank]
    else:
        gain = model.alpha
        previous_hidden = values["hidden_raw"][ind - 1]
        n_rows, n_cols = current.shape[0], current.shape[1]
        # The bias vector is expanded to (n_cols, n_rows) and transposed so that
        # it lines up element-wise with r_inputs1[ind].
        tiled_bias = model.feedback.bias.unsqueeze(1).expand(n_cols, n_rows).T
        terms = [
            gain * current,
            gain * values["r_inputs2"][ind],
            gain * values["r_inputs3"][ind],
            gain * tiled_bias,
            gain * -previous_hidden,
            previous_hidden,
        ]

    # Detach every term so the stacked result carries no autograd history.
    return torch.stack([term.detach() for term in terms])


def get_betas(values, model_fb, n_steps=126):
    """Split the readout drive at every timestep into per-input-group parts.

    For each timestep the six terms returned by ``get_all_inputs`` are summed.
    Each group's share of that sum is multiplied by the rectified sum
    (``relu``) and projected through ``model_fb.output.weight``.

    Groups, by index into the stacked terms of ``get_all_inputs``:

    * ``x``  – term 0
    * ``r``  – term 1
    * ``fb`` – term 2
    * ``b``  – term 3
    * ``r2`` – terms 4 and 5 together

    Parameters
    ----------
    values : mapping
        Recorded quantities, forwarded unchanged to ``get_all_inputs``.
    model_fb : object
        Read for ``model_fb.output.weight`` here, and passed on to
        ``get_all_inputs``.
    n_steps : int, optional
        Number of timesteps to process, starting at 0. Defaults to 126, the
        value that used to be hard-coded.

    Returns
    -------
    tuple of torch.Tensor
        ``(xs, rs, fbs, bs, rs2)``; each is the per-timestep result stacked
        along a new leading axis of length ``n_steps``.

    Notes
    -----
    Wherever the summed input is exactly zero the share is ``0 / 0`` and the
    result contains NaN; nothing here filters those values.
    """
    readout_weight = model_fb.output.weight
    # Build the masks on the device the readout lives on instead of forcing the
    # default CUDA device, so the function also works on CPU-resident models.
    device = readout_weight.device

    x_mask = torch.tensor([1, 0, 0, 0, 0, 0], device=device)
    r_mask = torch.tensor([0, 1, 0, 0, 0, 0], device=device)
    r2_mask = torch.tensor([0, 0, 0, 0, 1, 1], device=device)
    fb_mask = torch.tensor([0, 0, 1, 0, 0, 0], device=device)
    b_mask = torch.tensor([0, 0, 0, 1, 0, 0], device=device)

    fbs, rs, xs, bs, rs2 = [], [], [], [], []

    for i in range(n_steps):
        all_inputs = get_all_inputs(values, i, model_fb)
        all_inputs_sum = torch.sum(all_inputs, dim=0)
        rectified_sum = torch.relu(all_inputs_sum)

        # Reshape each length-6 mask to (6, 1, ..., 1) so it broadcasts along the
        # component axis (axis 0). This replaces the former double `.T` on a 3-D
        # tensor, which PyTorch deprecates for tensors whose ndim is not 2.
        mask_shape = (-1,) + (1,) * (all_inputs.dim() - 1)

        # `/`, `*` and `@` share one precedence level and associate left to
        # right: ((group_sum / total) * relu(total)) @ W.T
        fb, r, x, b, r2 = [
            (all_inputs * mask.view(mask_shape)).sum(dim=0)
            / all_inputs_sum
            * rectified_sum
            @ readout_weight.T
            for mask in (fb_mask, r_mask, x_mask, b_mask, r2_mask)
        ]
        fbs.append(fb)
        rs.append(r)
        xs.append(x)
        bs.append(b)
        rs2.append(r2)

    fbs, rs, xs, bs, rs2 = [torch.stack(val) for val in (fbs, rs, xs, bs, rs2)]
    return xs, rs, fbs, bs, rs2

#%%

# For every saved training run: load the checkpoint, rebuild the RNN, run it on
# five rotated center-out test sets, decompose the readout drive with get_betas
# and collect everything in `inputs`.

experiment_timestamp = "2024May08-205259" # EXAMPLE! replace with the timestamp of the experiment to analyse

inputs = []

for i in range(20): # across all 20 seeds
    print(i)
    # Directory of run `i`; the long middle segment is the hyper-parameter tag of the experiment.
    tm_savname = f"results/{experiment_timestamp}_no_delay_2k_b2b_large_batch/0003_fb_freeze_False__delay_0__go_to_peak_50__vel_10.0__error_detach_False__nlayers_1__dataset_name_Reaching__task_random_pushed__fb_density_1/{i}"
    dataT = torch.load(tm_savname + "/phase0_training")
    params = dataT["params"]

    # The test sets are built once, from the params of run 0, and reused for all runs.
    # NumPy's RNG is re-seeded with 0 before every prepare_pytorch call and
    # params["model"]["rot_phi"] is overwritten in place with the rotation in radians.
    #
    # NOTE(review): the numeric suffixes of the variables below do not all match
    # the rotation actually set:
    #   *_0 -> 0 deg, *_30 -> 30 deg, *_60 -> 40 deg, *_90 -> 50 deg, *_120 -> 60 deg.
    if i == 0:
        dataset = Reaching()
        np.random.seed(0)
        params["model"]["rot_phi"] = 0 / 180 * np.pi
        target_0, stimulus_0, pert_0, tids_0, stim_ref_0 = dataset.prepare_pytorch(
            params, "center-out-reach_rotated", test_set=True
        )

        np.random.seed(0)
        params["model"]["rot_phi"] = 30 / 180 * np.pi
        target_30, stimulus_30, pert_30, tids_30, stim_ref_30 = dataset.prepare_pytorch(
            params, "center-out-reach_rotated", test_set=True
        )

        # 40 degrees (stored under the *_60 names)
        np.random.seed(0)
        params["model"]["rot_phi"] = 40 / 180 * np.pi
        target_60, stimulus_60, pert_60, tids_60, stim_ref_60 = dataset.prepare_pytorch(
            params, "center-out-reach_rotated", test_set=True
        )

        # 50 degrees (stored under the *_90 names)
        np.random.seed(0)
        params["model"]["rot_phi"] = 50 / 180 * np.pi
        (
            target_90,
            stimulus_90,
            pert_90,
            tids_90,
            stim_ref_90,
        ) = dataset.prepare_pytorch(params, "center-out-reach_rotated", test_set=True)

        # 60 degrees (stored under the *_120 names)
        np.random.seed(0)
        params["model"]["rot_phi"] = 60 / 180 * np.pi
        (
            target_120,
            stimulus_120,
            pert_120,
            tids_120,
            stim_ref_120,
        ) = dataset.prepare_pytorch(params, "center-out-reach_rotated", test_set=True)


    # Rebuild the network from the saved hyper-parameters and restore its weights.
    model_fb = RNN(
        params["model"]["input_dim"],
        params["model"]["output_dim"],
        params["model"]["n"],
        torch.cuda.FloatTensor,
        params["model"]["dt"],
        params["model"]["tau"],
        fb_delay=params["model"]["fb_delay"],
        fb_density=params["model"]["fb_density"],
    )
    model_fb = model_fb.cuda()
    model_fb.load_state_dict(dataT["model_state_dict"])
    if params["data"]["dataset_name"] == "Reaching":
        model_fb.pos_err = True
    # NOTE(review): set to True here although the run directory name contains
    # "error_detach_False" — confirm this override is intended for the analysis.
    model_fb.error_detach = True


    # Forward passes with analysis=True; the third return value is what get_betas consumes.
    # NOTE(review): fb_in=True is passed explicitly only for the 0 and 30 degree
    # sets; the remaining three calls rely on whatever default the model defines
    # for fb_in — confirm the default is the intended setting.
    output_0, hidden_0, extras_0 = model_fb(
        stimulus_0, pert_0, stim_ref_0, analysis=True,fb_in=True, 
    )

    output_30, hidden_30, extras_30 = model_fb(
        stimulus_30, pert_30, stim_ref_30, analysis=True,fb_in=True
    )

    # NOTE(review): the results of this fb_in=False pass are not used anywhere
    # in the visible script.
    output_30_nofb, hidden_30_nofb, extras_30_nofb = model_fb(
        stimulus_30, pert_30, stim_ref_30, analysis=True, fb_in=False
    )

    output_60, hidden_60, extras_60 = model_fb(
        stimulus_60, pert_60, stim_ref_60, analysis=True
    )
    output_90, hidden_90, extras_90 = model_fb(
        stimulus_90, pert_90, stim_ref_90, analysis=True
    )
    output_120, hidden_120, extras_120 = model_fb(
        stimulus_120, pert_120, stim_ref_120, analysis=True
    )


    # One get_betas decomposition per rotation condition.
    # NOTE(review): the inputs_* suffixes (45/90/135/180) are not the rotation
    # angles; in order these hold the 0, 30, 40, 50 and 60 degree conditions.
    inputs_0 = get_betas(extras_0, model_fb)
    inputs_45 = get_betas(extras_30, model_fb)
    inputs_90 = get_betas(extras_60, model_fb)
    inputs_135 = get_betas(extras_90, model_fb)
    inputs_180 = get_betas(extras_120, model_fb)

    # Per run: axis 0 = rotation condition (order above), axis 1 = the five
    # tensors returned by get_betas (xs, rs, fbs, bs, rs2), axis 2 = timestep.
    inputs.append(
        torch.stack(
            [
                torch.stack(inputs_0),
                torch.stack(inputs_45),
                torch.stack(inputs_90),
                torch.stack(inputs_135),
                torch.stack(inputs_180),
            ]
        ).detach()
    )
    
# Final layout: a new leading run axis is added in front of the per-run axes; moved to CPU for plotting.
inputs = torch.stack(inputs).detach().cpu()


#%%

############## FIGURE 2A ################
# Time course of the feedback term's share of the readout drive, averaged across
# runs, with a standard-error band, for three rotation conditions.

# Rotation angle (degrees) of each entry along axis 1 of `inputs`, in the order
# the loading cell sets rot_phi and stacks the conditions.
condition_angles_deg = (0, 30, 40, 50, 60)

# Axis 2 of `inputs` follows get_betas' return order (xs, rs, fbs, bs, rs2).
# Per component: pick index 0 on the axis after time, take the absolute value and
# average over the remaining last axis -> (run, condition, timestep).
# NOTE(review): only index 0 of that axis is used — presumably a single trial; confirm.
x1, x2, x3, x4, x5 = (
    abs(inputs[:, :5, component, :, 0]).mean(axis=(3)) for component in range(5)
)

# Feedback (third component) as a percentage of the summed absolute contributions.
fb_percent = x3 / (x1 + x2 + x3 + x4 + x5) * 100
n_runs = fb_percent.shape[0]
timesteps = range(fb_percent.shape[-1])

fig, ax1 = plt.subplots(1, 1, sharex=True, figsize=(5, 4))

# Each curve is looked up by its angle, so label and data cannot drift apart.
# (Previously the curve labelled 60 degrees plotted condition index 3, which the
# loading cell sets to 50 degrees.)
for angle_deg, colour in ((0, '#00a6fb'), (30, '#0582ca'), (60, '#006494')):
    curve = fb_percent[:, condition_angles_deg.index(angle_deg)]  # (run, timestep)
    mean = torch.nanmean(curve, axis=0)
    # Standard error across runs: the divisor is the actual number of runs in
    # `inputs` rather than a hard-coded 10.
    sem = torch.std(curve, axis=0) / np.sqrt(n_runs)

    # Raw f-string: "\d" in a normal string literal is an invalid escape sequence.
    plt.plot(mean, label=rf'{angle_deg}$\degree$', lw=2, color=colour)
    plt.fill_between(timesteps, mean - sem, mean + sem, alpha=0.25, color=colour)


leg = plt.legend(fontsize=15)

# change the line width for the legend
for line in leg.get_lines():
    line.set_linewidth(4.0)

plt.grid()
plt.ylabel('% readout contribution',fontsize=15) 
plt.xlabel('timesteps',fontsize=15)

plt.yticks(fontsize=15)
plt.xticks(fontsize=15)

# remove top and right spines
plt.gca().spines["top"].set_visible(False)
plt.gca().spines["right"].set_visible(False)


plt.tight_layout()
plt.savefig(f"{plots_path}final_inputs.pdf", dpi=300)
plt.show()

#%%

############## FIGURE 2B ################
# Mean feedback share of the readout drive per rotation condition
# (mean and standard deviation across runs).

# Axis 2 of `inputs` follows get_betas' return order (xs, rs, fbs, bs, rs2).
# Per component: absolute value, then NaN-ignoring mean over the two trailing
# axes of the slice -> (run, condition).
x1, x2, x3, x4, x5 = (
    torch.nanmean(abs(inputs[:, :5, component, :, 0]), axis=(2, 3))
    for component in range(5)
)

# Feedback (third component) as a percentage of the summed absolute contributions.
fb_share = x3 / (x1 + x2 + x3 + x4 + x5) * 100

fig, ax1 = plt.subplots(1, 1, sharex=True, figsize=(4, 4))

ax1.errorbar(
    range(5),
    torch.nanmean(fb_share, axis=0),
    yerr=torch.std(fb_share, axis=0),
    label=r"feedback",
    color='#006494',
    fmt='o',
    markersize=8,
    capsize=5,
)
ax1.grid()

ax1.set_ylabel('mean feedback \n  % readout contribution', fontsize=15, labelpad=10)
ax1.set_xlabel(r'$\degree$ perturbation', fontsize=15)

# Hide the top and right frame lines.
for side in ("top", "right"):
    ax1.spines[side].set_visible(False)

plt.yticks(fontsize=15)
# Tick positions are condition indices; labels are the rotation angles in degrees.
plt.xticks([0, 1, 2, 3, 4], ['0', '30', '40', '50', '60'], fontsize=15)
ax1.set_xlim(-.5, 4.5)
ax1.set_ylim(0, 5.5)

# plt.legend(fontsize=15)

plt.tight_layout()
plt.savefig(f"{plots_path}final_mean_inputs.pdf", dpi=300)
plt.show()

# %%

 

