#%%
from heapq import nsmallest
from tkinter import font
from turtle import pos
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import defaultdict
import sys


sys.path.append("../modules")

from toolbox import *
from model_def import RNN
from data_set import *
from data_set import *



#%%

# Directory (including the trailing slash) that figures are written into;
# file names are appended directly, e.g. f"{plots_path}final_jacobians.pdf".
# Relative so that figures are not written into the filesystem root.
plots_path = "./plots/"


#%%

experiment_timestamp = "2024May08-205259" # EXAMPLE!

# Timesteps t for which d hidden[t] / d hidden[t-1] is evaluated
# (range(2, 125) -> 123 values; the figure cell plots one point per value).
JAC_T_START, JAC_T_STOP = 2, 125


def _prepare_rotated(dataset, params, angle_deg):
    """Build the center-out-reach test set rotated by ``angle_deg`` degrees.

    Reseeds NumPy with 0 before the call (presumably so that every rotation
    draws the same trials) and writes the angle, converted to radians, into
    ``params["model"]["rot_phi"]`` -- the ``params`` dict is mutated in place.

    Returns whatever ``dataset.prepare_pytorch`` returns; callers unpack it as
    ``(target, stimulus, pert, tids, stim_ref)``.
    """
    np.random.seed(0)
    params["model"]["rot_phi"] = angle_deg / 180 * np.pi
    return dataset.prepare_pytorch(
        params, "center-out-reach_rotated", test_set=True
    )


def _jacobian_norm_trace(hidden, n_units, t_start=JAC_T_START, t_stop=JAC_T_STOP):
    """Return the Frobenius norm of the one-step hidden-state Jacobian over time.

    For each t in ``[t_start, t_stop)`` the Jacobian is assembled row by row:
    one ``torch.autograd.grad`` call per unit ``h`` yields
    d hidden[t][0, h] / d hidden[t-1], of which index 0 along the first axis
    is kept (assumes hidden[t] is (batch, n_units) -- TODO confirm).
    ``retain_graph=True`` lets the same graph serve every row and timestep.

    Args:
        hidden: per-timestep hidden states returned by the model; each entry
            must still be attached to the autograd graph.
        n_units: number of hidden units (rows of each Jacobian).
        t_start, t_stop: half-open range of timesteps to evaluate.

    Returns:
        Detached CPU tensor of shape ``(t_stop - t_start,)``.
    """
    per_step = []
    for t in range(t_start, t_stop):
        rows = [
            torch.autograd.grad(hidden[t][0, h], hidden[t - 1], retain_graph=True)[0][0]
            for h in range(n_units)
        ]
        per_step.append(torch.stack(rows))
    # (T, rows, cols) -> (T,): Frobenius norm over the two Jacobian axes.
    return torch.norm(torch.stack(per_step), dim=(1, 2)).cpu().detach()


jacobians = []

for i in range(20): # across all 20 seeds
    print(i)
    tm_savname = f"results/{experiment_timestamp}_no_delay_2k_b2b_large_batch/0003_fb_freeze_False__delay_0__go_to_peak_50__vel_10.0__error_detach_False__nlayers_1__dataset_name_Reaching__task_random_pushed__fb_density_1/{i}"
    dataT = torch.load(tm_savname + "/phase0_training")
    params = dataT["params"]

    if i == 0:
        # The test sets are built once, from seed 0's params, and reused for
        # every seed.
        # NOTE: the variable suffixes are labels, not rotation angles:
        #   *_0 = 0 deg, *_30 = 30 deg, *_60 = 40 deg, *_90 = 50 deg,
        #   *_120 = 60 deg.
        dataset = Reaching()
        target_0, stimulus_0, pert_0, tids_0, stim_ref_0 = _prepare_rotated(
            dataset, params, 0
        )
        target_30, stimulus_30, pert_30, tids_30, stim_ref_30 = _prepare_rotated(
            dataset, params, 30
        )
        target_60, stimulus_60, pert_60, tids_60, stim_ref_60 = _prepare_rotated(
            dataset, params, 40
        )
        target_90, stimulus_90, pert_90, tids_90, stim_ref_90 = _prepare_rotated(
            dataset, params, 50
        )
        (
            target_120,
            stimulus_120,
            pert_120,
            tids_120,
            stim_ref_120,
        ) = _prepare_rotated(dataset, params, 60)

    model_fb = RNN(
        params["model"]["input_dim"],
        params["model"]["output_dim"],
        params["model"]["n"],
        torch.cuda.FloatTensor,
        params["model"]["dt"],
        params["model"]["tau"],
        fb_delay=params["model"]["fb_delay"],
        fb_density=params["model"]["fb_density"],
    )
    model_fb = model_fb.cuda()
    model_fb.load_state_dict(dataT["model_state_dict"])
    if params["data"]["dataset_name"] == "Reaching":
        model_fb.pos_err = True
    # NOTE(review): the run directory was trained with error_detach_False, but
    # the analysis forces error_detach = True; confirm this is intended for the
    # Jacobian computation below.
    model_fb.error_detach = True

    output_0, hidden_0, extras_0 = model_fb(
        stimulus_0, pert_0, stim_ref_0, analysis=True, fb_in=True
    )
    output_30, hidden_30, extras_30 = model_fb(
        stimulus_30, pert_30, stim_ref_30, analysis=True, fb_in=True
    )
    # Same 30 deg input, with the feedback input switched off.
    output_30_nofb, hidden_30_nofb, extras_30_nofb = model_fb(
        stimulus_30, pert_30, stim_ref_30, analysis=True, fb_in=False
    )
    # The following calls rely on the model's default fb_in.
    output_60, hidden_60, extras_60 = model_fb(
        stimulus_60, pert_60, stim_ref_60, analysis=True
    )
    output_90, hidden_90, extras_90 = model_fb(
        stimulus_90, pert_90, stim_ref_90, analysis=True
    )
    # Not used by the Jacobian analysis below.
    output_120, hidden_120, extras_120 = model_fb(
        stimulus_120, pert_120, stim_ref_120, analysis=True
    )

    # Jacobian-norm traces with and without feedback.
    n_units = params["model"]["n"]
    js_30 = _jacobian_norm_trace(hidden_30, n_units)
    js_0 = _jacobian_norm_trace(hidden_0, n_units)
    js_60 = _jacobian_norm_trace(hidden_60, n_units)
    js_90 = _jacobian_norm_trace(hidden_90, n_units)
    js_nofb = _jacobian_norm_trace(hidden_30_nofb, n_units)

    # Condition rows: [30 deg no fb, 0 deg, 30 deg, 40 deg, 50 deg].
    jacobians.append(torch.stack([js_nofb, js_0, js_30, js_60, js_90]))


# (n_seeds, 5 conditions, JAC_T_STOP - JAC_T_START timesteps)
jacobians = torch.stack(jacobians).detach().cpu()


#%%

############### FIGURE 4A ####################
# Across-seed mean +/- 1 std of the per-timestep Jacobian norm, comparing the
# no-feedback run with a feedback-on condition. ``jacobians`` is indexed
# [seed, condition, timestep] with condition rows
# 0 = 30 deg no fb, 1 = 0 deg, 2 = 30 deg, 3 = 40 deg, 4 = 50 deg.
import os


def _plot_mean_std(curves, label, color=None, line_alpha=None):
    """Plot the mean over axis 0 of ``curves`` with a shaded +/- 1 std band.

    ``curves`` is expected to be (seeds, timesteps). The band reuses the
    plotted line's colour, so line and band always match even when ``color``
    is left to matplotlib's colour cycle.
    """
    mean = curves.mean(axis=0)
    std = curves.std(axis=0)
    (line,) = plt.plot(mean, lw=2, label=label, color=color, alpha=line_alpha)
    plt.fill_between(
        range(curves.shape[-1]),
        mean - std,
        mean + std,
        alpha=0.25,
        color=line.get_color(),
    )


_plot_mean_std(jacobians[:, 0, :], label='nofb')

alpha = [1, 0.75, 0.5]  # line alpha for condition rows 2, 3, 4 (indexed pert - 2)

# NOTE(review): row 3 is the 40 deg condition while 'nofb' is the 30 deg
# condition; row 2 would be the matched 30 deg comparison. Kept as in the
# original -- confirm which condition Figure 4A should show.
for pert in [3]:
    _plot_mean_std(
        jacobians[:, pert, :],
        label='fb',
        color='tab:orange',
        line_alpha=alpha[pert - 2],
    )


plt.grid()
plt.ylabel('Jacobian norm', fontsize=15)
plt.xlabel('timesteps', fontsize=15)

ax = plt.gca()
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

plt.yticks(fontsize=15)
plt.xticks(fontsize=15)

plt.legend(fontsize=15)

plt.tight_layout()
out_file = f"{plots_path}final_jacobians.pdf"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(out_file) or ".", exist_ok=True)
plt.savefig(out_file, dpi=300)
plt.show()


#%%


