#%%

from tkinter import font
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import defaultdict
import sys

import os


sys.path.append("../modules")

from toolbox import *
from model_def import RNN
from data_set import *
from toolbox import *
from data_set import *


# %%

# Accumulators over trained models; each entry is already stacked over the
# evaluated trials of one model.
all_alignments = []
all_outputs = []

plots_path = "plots/"
experiment_timestamp = "2024May08-205259"  # EXAMPLE! Replace with your own run's timestamp.

n_runs = 10  # sub-directories 0..n_runs-1 of the experiment folder are loaded
n_examples = 10  # test trials evaluated one at a time for every loaded model
fb_offset = 0  # extra shift (in steps) of the feedback index relative to the gradient index

# Element-wise squared error, reduced explicitly with .mean() below. The module
# is stateless, so a single instance serves every trial.
criterion = torch.nn.MSELoss(reduction="none")

for run_idx in range(n_runs):
    print('----')
    print(run_idx)
    print('----')
    tm_savname = f"results/{experiment_timestamp}_no_delay_2k_b2b_large_batch/0003_fb_freeze_False__delay_0__go_to_peak_50__vel_10.0__error_detach_False__nlayers_1__dataset_name_Reaching__task_random_pushed__fb_density_1/{run_idx}"
    dataT = torch.load(tm_savname + "/phase0_training")
    params = dataT["params"]

    if run_idx == 0:
        # The rotated test set is built only once, from the first run's
        # parameters, and then reused unchanged for every other run.
        dataset = Reaching()
        np.random.seed(0)  # seeds NumPy's global RNG before the test set is drawn
        params["model"]["rot_phi"] = 30 / 180 * np.pi  # 30 degrees, in radians
        target_30, stimulus_30, pert_30, tids_30, stim_ref_30 = dataset.prepare_pytorch(
            params, "center-out-reach_rotated", test_set=True
        )

    # Rebuild the network from the stored hyper-parameters, then restore weights.
    model_cfg = params["model"]
    model_fb = RNN(
        model_cfg["input_dim"],
        model_cfg["output_dim"],
        model_cfg["n"],
        torch.cuda.FloatTensor,
        model_cfg["dt"],
        model_cfg["tau"],
        fb_delay=model_cfg["fb_delay"],
        fb_density=model_cfg["fb_density"],
    )
    model_fb = model_fb.cuda()
    model_fb.load_state_dict(dataT["model_state_dict"])
    if params["data"]["dataset_name"] == "Reaching":
        model_fb.pos_err = True
    model_fb.error_detach = True

    alignments_j = []
    outputs_j = []

    for j in range(n_examples):  # examples
        print('____')
        print(j)
        print('____')

        # Slice (rather than index) the second axis so the model still
        # receives a length-1 axis for the single selected trial.
        trial = slice(j, j + 1)
        output_30, _, extras_30 = model_fb(
            stimulus_30[:, trial, :],
            pert_30[:, trial, :],
            stim_ref_30[:, trial, :],
            analysis=True,
            fb_in=True,
        )
        outputs_j.append(output_30.cpu().detach().numpy())

        hidden_30 = extras_30["hidden_raw"]
        # The target is all zeros, so the loss is the mean squared output.
        total_loss = criterion(output_30, output_30 * 0).mean()

        # One backward pass yields d(loss)/d(h) for every hidden state after
        # the first. Only scalar cosine values are read from these gradients,
        # so neither a retained graph nor a second-order graph is required.
        grad_h = torch.autograd.grad(total_loss, hidden_30[1:])

        # presumably the feedback-driven input at each step — TODO confirm in RNN
        fb_signals = extras_30["r_inputs3"][1:]

        # Pair the negative gradient at step t with the feedback entry at
        # step t + fb_offset; the last gradient has no partner and is skipped.
        n_pairs = len(grad_h) - 1 - fb_offset
        alignments = []
        for t in range(n_pairs):
            true_g = -grad_h[t]  # descent direction, hence the sign flip
            fb = fb_signals[t + fb_offset]
            cos_sim = torch.nn.functional.cosine_similarity(
                true_g.flatten(), fb.flatten(), dim=0
            )
            alignments.append(cos_sim.item())

        alignments_j.append(np.stack(alignments))

    all_alignments.append(np.stack(alignments_j))
    all_outputs.append(np.stack(outputs_j))

#%%

# Turn the per-run lists into dense arrays with the run index as leading axis.
# Alignments end up as (runs, examples, timesteps); outputs as
# (runs, examples, *single-trial output shape).
all_alignments = np.stack(all_alignments, axis=0)
all_outputs = np.stack(all_outputs, axis=0)


# %%

fig, ax1 = plt.subplots(1, 1, sharex=True, figsize=(5, 4))

# Convert cosine similarities to angles in degrees, once. Clipping guards
# against values that drift marginally outside [-1, 1] through round-off;
# arccos would turn those into NaN and poison the mean and std below.
cos_sim = np.clip(np.asarray(all_alignments), -1.0, 1.0)
angles_deg = np.degrees(np.arccos(cos_sim))

# Statistics over the two leading axes (runs and examples), per timestep.
mean_angle = angles_deg.mean(axis=(0, 1))
# NOTE(review): the divisor was a literal sqrt(10). Both leading axes have
# length 10 in the current experiment; it is taken here to be the number of
# runs (axis 0) — confirm this is the intended sample size for the error band.
n_for_sem = angles_deg.shape[0]
sem_angle = angles_deg.std(axis=(0, 1)) / np.sqrt(n_for_sem)

# x positions follow the data length instead of a fixed 124 steps, so the
# shaded band always matches the line.
timesteps = np.arange(mean_angle.shape[0])

line_color = '#006494'
ax1.plot(timesteps, mean_angle, lw=2, color=line_color)
ax1.fill_between(
    timesteps,
    mean_angle - sem_angle,
    mean_angle + sem_angle,
    alpha=0.2,
    color=line_color,
)

# 90 degrees marks orthogonality between the two vectors.
ax1.axhline(90, color='k', linestyle='--')
ax1.set_ylim(0, 120)

ax1.grid()
ax1.set_ylabel(r'$ \frac{\partial \mathcal{L}}{\partial a^t} \measuredangle W^{fb} \epsilon^{t-1}$', fontsize=20)
ax1.set_xlabel(r'timesteps $(t)$', fontsize=15)

ax1.tick_params(axis='both', labelsize=15)

# remove spines on top and right
for side in ('top', 'right'):
    ax1.spines[side].set_visible(False)

plt.tight_layout()

# Create the output directory on first use so savefig cannot fail on a
# fresh checkout.
out_file = f"{plots_path}final_first_grad.pdf"
out_dir = os.path.dirname(out_file)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
plt.savefig(out_file, dpi=300)

# %%
