#%%
from heapq import nsmallest
from tkinter import font
from turtle import pos
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import defaultdict
import re
import sys
import os
import seaborn as sns
import pandas as pd


import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

sys.path.append("../modules")

from toolbox import *
from model_def import RNN
from data_set import *
from data_set import *

#%%
# ---- Analysis configuration -------------------------------------------------

# How many run sub-folders (0 .. nruns-1) to look for inside each experiment.
nruns = 10

# Accumulators filled by the loading cell: each maps an experiment key to a
# list holding one entry per loaded run.
losses, t_losses = defaultdict(list), defaultdict(list)
local_alignments = defaultdict(list)
local_grads = defaultdict(list)
dws = defaultdict(list)
update_cost, update_cost_stepwise = defaultdict(list), defaultdict(list)

# Timestamp prefix of the results folder to analyse -- replace with your own.
experiment_timestamp = "2024May08-205259"  # EXAMPLE!

experiment_paths = tuple(
    [f"results/{experiment_timestamp}_no_delay_2k_b2b_adaptation_lrs/"]
)

# Learning-rule tags searched for (as substrings) in each experiment folder name.
algorithms = [
    "fed",
    "sam",
    "bp",
    "rflo",
]

# Load every finished run of every experiment into the accumulators above.
# For each experiment folder and run index two torch files are read:
#   <exp>/<run>/AD_<alg>         -> learning curves "lc" and "lcT"
#   <exp>/<run>/AD_<alg>_grads   -> update costs, dws, local grads/alignments
# One entry per run is appended, keyed by the folder name without its first
# five characters (the key format used by the plotting cells).
#
# NOTE: the working directory is switched into each experiment folder (the last
# one stays current), so relative paths used later, e.g. ``plots_path``, resolve
# inside the experiment folder.
_base_dir = os.getcwd()
for experiment_path in experiment_paths:
    # Resolve against the starting directory and load via absolute paths: after
    # chdir, re-prefixing a relative ``experiment_path`` would point at a
    # non-existent nested folder, and several relative paths would compound.
    experiment_dir = os.path.abspath(os.path.join(_base_dir, experiment_path))
    os.chdir(experiment_dir)
    for exp in sorted(os.listdir(experiment_dir)):
        # Skip hidden entries, png figures, names containing '90' and '_fed__'.
        if (
            exp.startswith(".")
            or exp.endswith("png")
            or "90" in exp
            or "_fed__" in exp
        ):
            continue
        print(exp)

        # Exactly one learning-rule tag must occur in the folder name.
        matches = [alg for alg in algorithms if alg in exp]
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one of {algorithms} in {exp!r}, found {matches}"
            )
        learn_alg = matches[0]
        if "fb_density_0" in exp and learn_alg == "fed":
            learn_alg = "fed_t"  # files for this variant are named AD_fed_t*
        key = exp[5:]

        for run in range(nruns):
            stem = os.path.join(experiment_dir, exp, str(run), "AD_" + learn_alg)
            try:
                lc = torch.load(stem, map_location="cpu")
                grads = torch.load(stem + "_grads", map_location="cpu")
            except FileNotFoundError:
                print(f'{exp} [{run}] not done yet!')
                continue
            except Exception as err:  # unreadable file: report it and keep going
                print(f'{exp} [{run}] could not be loaded: {err!r}')
                continue

            # Extract everything before appending, so a run with missing fields
            # never leaves the accumulators with mismatched run counts.
            try:
                run_lc = lc["lc"]
                run_lct = lc["lcT"]
                run_update_cost = torch.stack(grads["update_cost"])
                run_update_cost_stepwise = torch.stack(grads["update_cost_stepwise"])
                run_dws = torch.stack(grads["dws"])
                run_local_grads = grads["local_grads"][0]
            except (KeyError, IndexError, TypeError, RuntimeError) as err:
                print(f'{exp} [{run}] has incomplete results: {err!r}')
                continue

            losses[key].append(run_lc)
            t_losses[key].append(run_lct)
            update_cost[key].append(run_update_cost)
            update_cost_stepwise[key].append(run_update_cost_stepwise)
            dws[key].append(run_dws)
            local_grads[key].append(run_local_grads)

            # Local alignments are only collected for the "fed" and "rflo" rules.
            if learn_alg in ("fed", "rflo"):
                try:
                    local_alignments[key].append(
                        torch.stack([torch.stack(g) for g in grads["local_alignments"]])
                    )
                except (KeyError, TypeError, RuntimeError):
                    print('local alignments missing for ', key)



#%%
# Convert each experiment's per-run lists into arrays with a leading run axis:
# learning curves, update costs and dws via np.stack; local alignments and
# local grads via torch.stack. Entries already converted are left alone, so the
# cell can be re-run. Empty lists are skipped because both stack functions
# reject an empty sequence (the defaultdicts yield [] for keys with no data).
_numpy_stores = (losses, t_losses, update_cost, update_cost_stepwise, dws)
_torch_stores = (local_alignments, local_grads)

for key in losses.keys():
    print(key)
    for store in _numpy_stores:
        if isinstance(store[key], list) and len(store[key]) > 0:
            store[key] = np.stack(store[key])
    for store in _torch_stores:
        if isinstance(store[key], list) and len(store[key]) > 0:
            store[key] = torch.stack(store[key])

#%%


# Output folder for figures. It is relative, so it resolves against the current
# working directory, which the loading cell changes into the experiment folder.
plots_path = "plots/"

#%%

###### FIGURE3A ###########
# Learning curves for one learning rule: mean +/- std across runs of
# ``losses`` (label 'feedback ON') and ``t_losses`` (label 'feedback OFF')
# for the fb_in_1 experiment, on a log scale.

rot = 30    # rotation value encoded in the experiment key (rot_phi_<rot>)
xlim = 400  # number of trials shown
step = 50   # spacing of x-axis ticks

learn_alg = 'rflo'

color1 = "#264653"
color2 = "#219ebc"

key_fb_in_1 = f'rot_phi_{rot}__learning_algorithm_{learn_alg}__fb_in_1__fb_density_1'
key_fb_in_0 = f'rot_phi_{rot}__learning_algorithm_{learn_alg}__fb_in_0__fb_density_1'

train_mean = losses[key_fb_in_1].mean(axis=0)
train_std = losses[key_fb_in_1].std(axis=0)

train_mean2 = losses[key_fb_in_0].mean(axis=0)
train_std2 = losses[key_fb_in_0].std(axis=0)

test_mean = t_losses[key_fb_in_1].mean(axis=0)
test_std = t_losses[key_fb_in_1].std(axis=0)

# Final-trial summaries. Not drawn in this figure.
# NOTE(review): kept as module-level values in case later cells use them.
train_final, train_final_std = train_mean[-1], train_std[-1]
test_final, test_final_std = test_mean[-1], test_std[-1]
train_final2, train_final_std2 = train_mean2[-1], train_std2[-1]

# Clip the window to the available data so that runs shorter than xlim do not
# make the x and y arrays disagree in length.
n_show = min(xlim, len(train_mean), len(test_mean))
x = range(n_show)

plt.plot(x, train_mean[:n_show], color=color1, lw=2.5, label='feedback ON')
plt.fill_between(
    x,
    train_mean[:n_show] - train_std[:n_show],
    train_mean[:n_show] + train_std[:n_show],
    color=color1,
    alpha=0.25,
)

plt.plot(x, test_mean[:n_show], color=color2, lw=2.5, label='feedback OFF', linestyle='--')
plt.fill_between(
    x,
    test_mean[:n_show] - test_std[:n_show],
    test_mean[:n_show] + test_std[:n_show],
    color=color2,
    alpha=0.25,
)

# Raw string: "\m" is not a valid Python escape sequence.
plt.ylabel(r"$\mathcal{L}_{MSE}$", fontsize=20)
plt.legend(fontsize=15, loc='upper right')

plt.semilogy()

ax = plt.gca()
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Small negative left margin so the first trial is not clipped by the axis.
plt.xlim(-0.1, xlim)
plt.xticks(list(range(0, xlim + 1, step)), fontsize=15)
plt.xlabel("# trials", fontsize=15)
plt.yticks(fontsize=15)

plt.grid()
plt.tight_layout()
os.makedirs(plots_path, exist_ok=True)  # savefig does not create missing folders
plt.savefig(f"{plots_path}final_adaptation_loss_{rot}_{learn_alg}.pdf", dpi=300)
plt.show()