
#%%
from heapq import nsmallest
from tkinter import font
from turtle import pos
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import defaultdict
import re
import sys
import os
import seaborn as sns
import pandas as pd


import os

# Enumerate CUDA devices in PCI bus order (see issue #152) and expose only
# device 0 to this process.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

# Make the project-local helper modules (toolbox, model_def, data_set) importable.
sys.path.append("../modules")

from toolbox import *
from model_def import RNN
from data_set import *
from data_set import *

#%%


# Number of seeds per experiment; each seed's results live in a
# sub-directory named "0" .. str(nruns - 1).
nruns = 10

# Per-experiment accumulators keyed by experiment name. The first five are
# not read by the cells below; update_cost / update_cost_stepwise are reset
# again before the gradient cell fills them.
losses, t_losses, local_alignments, local_grads, dws = (defaultdict(list) for _ in range(5))
update_cost, update_cost_stepwise = defaultdict(list), defaultdict(list)

# NOTE(review): placeholder value — set to the timestamp prefix of the
# results directory you want to analyse.
experiment_timestamp = "2024May08-205259"

experiment_paths = (f"results/{experiment_timestamp}_no_delay_2k_b2b_adaptation_lrs/",)

plots_path = "plots/"

algorithms = ["fed", "sam", "bp", "rflo"]

# Map "<directory name minus its first 5 characters>" -> directory name.
experiment_keys = defaultdict(list)
for experiment_path in experiment_paths:
    # Side effect relied on by later cells: the working directory becomes the
    # experiment directory, so later relative paths resolve inside it.
    os.chdir(experiment_path)
    for exp in os.listdir():
        # Ignore hidden entries and saved .png figures.
        if exp.startswith(".") or exp.endswith("png"):
            continue
        experiment_keys[exp[5:]] = exp


#%%

# Gradient-based update cost of two learning rules, compared under identical
# rotation / feedback-input conditions.
alg1 = 'rflo'
alg2 = 'bp'

# Reset so that re-running this cell does not append duplicate runs.
update_cost = defaultdict(list)
update_cost_stepwise = defaultdict(list)
for rot in [30, 40, 50, 60]:
    for fb_in in [0, 1]:
        experiment1 = f'rot_phi_{rot}__learning_algorithm_{alg1}__fb_in_{fb_in}__fb_density_1'
        experiment2 = f'rot_phi_{rot}__learning_algorithm_{alg2}__fb_in_{fb_in}__fb_density_1'
        # experiment_keys is a defaultdict: indexing a missing key would
        # silently return [] and only fail later with an unrelated TypeError.
        missing = [e for e in (experiment1, experiment2) if e not in experiment_keys]
        if missing:
            raise KeyError(f"no result directory found for {missing}")
        experiment1_fullname = experiment_keys[experiment1]
        experiment2_fullname = experiment_keys[experiment2]
        print(experiment1_fullname, experiment2_fullname)
        # (experiment key, directory name, algorithm tag used in the file name)
        pairs = (
            (experiment1, experiment1_fullname, alg1),
            (experiment2, experiment2_fullname, alg2),
        )
        for run in range(nruns):
            for experiment, fullname, alg in pairs:
                # Paths are relative to the cwd: the setup cell os.chdir()s into
                # experiment_path, so prefixing experiment_path again would
                # double a relative path and point at a non-existent file.
                grads = torch.load(
                    os.path.join(fullname, str(run), f"AD_{alg}_grads"),
                    map_location="cpu",
                )
                step_dws = torch.stack(grads["dws"])  # stacked once, reused below
                # Norm of the net weight change summed over all steps.
                update_cost[experiment].append(step_dws.sum(dim=0).norm())
                # Sum over steps of each step's norm (taken over dims 1 and 2).
                update_cost_stepwise[experiment].append(step_dws.norm(dim=(1, 2)).sum())


#%%

# Collapse each per-run list of scalar tensors (one entry per run, appended
# in the update-cost cell) into a NumPy array for the plotting cell.
for experiment_name, per_run_costs in list(update_cost.items()):
    update_cost[experiment_name] = torch.stack(per_run_costs).numpy()
    stepwise_costs = update_cost_stepwise[experiment_name]
    update_cost_stepwise[experiment_name] = torch.stack(stepwise_costs).numpy()

#%%


# Plot RFLO update efficiency vs. perturbation size, with feedback input
# (fb_in=1, "RFLO+c") and without it (fb_in=0, "RFLO").
# Efficiency per run = update_cost / update_cost_stepwise
#                    = ||sum_t dW_t|| / sum_t ||dW_t||,
# which the triangle inequality bounds above by 1.
# Points are mean over runs; error bars are the (population) std over runs.
plot_alg = 'rflo'
rotations = [30, 40, 50, 60]
# fb_in -> (legend label, colour, x offset from the rotation's tick)
fb_styles = {
    1: ('RFLO+c', '#3f88c5', -.1),
    0: ('RFLO', '#ffba08', +.1),
}

plt.figure(figsize=(4, 4))

for xpos, rot in enumerate(rotations, start=1):
    for fb_in, (label, color, offset) in fb_styles.items():
        key = f'rot_phi_{rot}__learning_algorithm_{plot_alg}__fb_in_{fb_in}__fb_density_1'
        efficiency = np.asarray(update_cost[key]) / np.asarray(update_cost_stepwise[key])
        plt.errorbar(
            [xpos + offset],
            [np.mean(efficiency)],
            yerr=[np.std(efficiency)],
            fmt='o',
            capsize=5,
            color=color,
            # Only label the first rotation so each series appears once in the legend.
            label=label if xpos == 1 else '_nolegend_',
        )

plt.grid()
plt.xlim(0.5, len(rotations) + 0.5)
plt.ylim(0, 1.05)  # efficiency is bounded above by 1

# remove top and right spines
ax = plt.gca()
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

plt.legend(fontsize=15)
plt.xticks(range(1, len(rotations) + 1), [str(r) for r in rotations], fontsize=15)
plt.yticks(fontsize=15)

plt.ylabel('update efficiency', fontsize=15)
# Raw string: '\d' is an invalid escape sequence in a normal string literal.
plt.xlabel(r'$\degree$ perturbation', fontsize=15)

plt.tight_layout()

os.makedirs(plots_path, exist_ok=True)
plt.savefig(f"{plots_path}final_efficiency_rflo.pdf", dpi=300)


