import os
import numpy as np
from tqdm import tqdm

from utils import get_dir, is_float, load_dict
from utils import save_dict

import matplotlib.pyplot as plt
import seaborn as sns

# Global plotting configuration shared by every figure produced in this module.
# NOTE(review): the rcParams overrides below are presumably placed after
# sns.set_theme on purpose, since the theme call applies its own rc settings -- confirm.
sns.set_theme(style="darkgrid")
# Use the STIX fonts for both math text and regular text.
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'
# Width and height of one panel; multi-panel figures use figsize=(num_panels*WIDTH, HEIGHT).
WIDTH=12
HEIGHT=10
# Base font size; the plotting functions use FONT_SIZE-4 for axis labels, ticks and titles.
FONT_SIZE=55
plt.rcParams.update({'font.size': FONT_SIZE})


################################################################

def collect_errors_exp(path:str, 
                       params_var_str:str,
                       error_type:str="MSE", 
                       info_to_scale_weight_decay:dict=None,
                       save:bool=True)->dict:
    """Aggregate the errors of every run found under ``path``.

    Each entry of ``get_dir(path)`` accepted by ``is_float`` is treated as one
    run whose name is the value of the swept parameter. For every run the file
    ``<path><name>/errors.json`` is loaded and the mean, median and standard
    deviation of its ``error_type`` entry are computed.

    Args:
        path: Directory holding one sub-directory per parameter value. It is
            concatenated as-is, so it is expected to end with a separator.
        params_var_str: Key under which the sorted parameter values are stored
            in the result.
        error_type: Key to read from each ``errors.json``; also used as the
            stem of the saved summary file.
        info_to_scale_weight_decay: Optional dict with the keys
            ``"num_transitions"`` and ``"num_important_states_visited"``. When
            given and ``params_var_str == "weight_decay"``, the parameter
            values are divided by the product of the two. When ``None`` or
            empty, the values are left unscaled.
        save: If true, write the result to ``<path><error_type>.json`` through
            ``save_dict``.

    Returns:
        A dict with the NumPy arrays ``"mean"``, ``"median"``, ``"std"`` and
        ``params_var_str``, all ordered by increasing parameter value.
    """
    # Keep the original directory names. Rebuilding a name with str(float(name))
    # does not round-trip for names such as "1" or "0.10", and the errors file
    # of that run would then not be found.
    run_dirs=sorted((d for d in get_dir(path) if is_float(d)), key=float)
    params_var=np.array([float(d) for d in run_dirs], dtype=float)
    num_params_var=params_var.shape[0]
    errors=dict(mean=np.zeros((num_params_var,)),
                median=np.zeros((num_params_var,)),
                std=np.zeros((num_params_var,)))
    errors[params_var_str]=params_var
    # Callers may pass None when no scaling information is available.
    if params_var_str=="weight_decay" and info_to_scale_weight_decay:
        scale=(info_to_scale_weight_decay["num_transitions"]
               *info_to_scale_weight_decay["num_important_states_visited"])
        errors[params_var_str]=params_var/scale
    for i, run_dir in enumerate(run_dirs):
        data=load_dict(path+run_dir+"/errors.json")[error_type]
        errors["mean"][i]=np.mean(data)
        errors["median"][i]=np.median(data)
        errors["std"][i]=np.std(data)
    if save:
        save_dict(path=path+error_type+".json", data=errors)
    return errors

################################################################

def collect_delta_dataset_size_params(path:str,
                                      verbose:bool=True)->dict:
    """Load every ``delta.json`` found under ``path``.

    The entries of ``get_dir(path)`` whose name contains ``"dataset_"`` are the
    dataset directories; the integer after the last underscore is the dataset
    size. The parameter sub-directories are listed once, from the first dataset
    directory, and reused for all the others.

    Args:
        path: Root directory. It is concatenated as-is, so it is expected to
            end with a separator.
        verbose: If true, show a progress bar.

    Returns:
        ``{dataset_size: {float(param_dir): content of delta.json}}``. The
        dict is empty when no dataset directory is found.
    """
    delta={}
    dataset_dirs=[d for d in get_dir(path) if "dataset_" in d]
    if not dataset_dirs:
        # Nothing to collect; indexing the first dataset would fail.
        return delta
    param_dirs=get_dir(path+str(dataset_dirs[0]))
    filename_delta="delta.json"
    # The context manager closes the bar even if a file fails to load.
    with tqdm(total=len(param_dirs)*len(dataset_dirs), disable=not verbose, desc="Collecting data") as pbar:
        for d in dataset_dirs:
            d_size=int(d.rsplit("_", 1)[1])
            delta[d_size]={}
            for p in param_dirs:
                delta[d_size][float(p)]=load_dict(path=path+d+"/"+p+"/"+filename_delta)
                pbar.update(1)
    return delta

################################################################

def collect_msbe_dataset_size_params(path:str,
                                     params_var_str:str,
                                     msbe_exp:bool=True,
                                     training:bool=False,
                                     num_states_visited_for_weight_decay:dict=None,
                                     verbose:bool=True)->dict:
    """Collect MSBE results for every dataset size and parameter value.

    The entries of ``get_dir(path)`` whose name contains ``"dataset_"`` are the
    dataset directories; the integer after the last underscore is the dataset
    size. The parameter sub-directories are listed once, from the first dataset
    directory, and reused for all the others.

    Args:
        path: Root directory. It is concatenated as-is, so it is expected to
            end with a separator.
        params_var_str: Name of the swept parameter, forwarded to
            ``collect_errors_exp``.
        msbe_exp: If true, aggregate the experimental errors with
            ``collect_errors_exp``. Otherwise load the ``msbe_test_th.json``
            file (``msbe_training_th.json`` when ``training`` is true).
        training: Select the training quantities (``"MSBE_train"``) instead of
            the test ones (``"MSBE"``).
        num_states_visited_for_weight_decay: Optional mapping from dataset size
            to a number of visited states. When given, it is forwarded together
            with the dataset size so that ``collect_errors_exp`` can scale the
            weight decay.
        verbose: If true, show a progress bar.

    Returns:
        ``{dataset_size: {float(param_dir): collected dict}}``. The dict is
        empty when no dataset directory is found.
    """
    msbe={}
    dataset_dirs=[d for d in get_dir(path) if "dataset_" in d]
    if not dataset_dirs:
        # Nothing to collect; indexing the first dataset would fail.
        return msbe
    param_dirs=get_dir(path+str(dataset_dirs[0]))
    if training:
        msbe_exp_type="MSBE_train"
        filename_msbe_th="msbe_training_th.json"
    else:
        msbe_exp_type="MSBE"
        filename_msbe_th="msbe_test_th.json"
    # The context manager closes the bar even if a file fails to load.
    with tqdm(total=len(param_dirs)*len(dataset_dirs), disable=not verbose, desc="Collecting data") as pbar:
        for d in dataset_dirs:
            d_size=int(d.rsplit("_", 1)[1])
            msbe[d_size]={}
            # The scaling information only depends on the dataset, so build it
            # once per dataset; it stays None when no state counts were given.
            info_to_scale_weight_decay=None
            if msbe_exp and num_states_visited_for_weight_decay is not None:
                info_to_scale_weight_decay={
                    "num_important_states_visited": num_states_visited_for_weight_decay[d_size],
                    "num_transitions": d_size,
                }
            for p in param_dirs:
                if msbe_exp:
                    msbe[d_size][float(p)]=collect_errors_exp(path=path+d+"/"+p+"/",
                                                              params_var_str=params_var_str, 
                                                              error_type=msbe_exp_type, 
                                                              info_to_scale_weight_decay=info_to_scale_weight_decay)
                else:
                    msbe[d_size][float(p)]=load_dict(path=path+d+"/"+p+"/"+filename_msbe_th)
                pbar.update(1)
    return msbe

    
################################################################
    
def _msbe_axis_label(axis:str)->str:
    """Return the x-axis label of the swept parameter, or its raw name if unknown."""
    known={"ratio": r'$N\,/\,m$',
           "weight_decay": r'$\lambda$',
           "gamma": r'$\gamma$'}
    return known.get(axis, axis)


def _msbe_panel_title(kind:str, value)->str:
    """Return the title of one panel, or None when ``kind`` is not a known parameter."""
    if kind=="weight_decay":
        return r"$\lambda=$"+f"{value:.1e}"+"\n"
    if kind=="gamma":
        return r"$\gamma=$"+f"{value:.2f}"+"\n"
    if kind=="ratio":
        return r"$N~/~m=$"+str(round(value, 2))+"\n"
    return None


def _msbe_panel_limits(limits, j:int):
    """Return the (low, high) pair of panel ``j`` from a shared pair or a sequence of pairs."""
    # Look at the first element rather than at the length: a list holding one
    # or two per-panel pairs must not be mistaken for a single (low, high) pair.
    if len(limits)>0 and isinstance(limits[0], (list, tuple, np.ndarray)):
        return limits[j]
    return limits


def plot_msbe_test_and_train_dataset_size_params(path:str, 
                                                 msbe:dict,
                                                 params_var:dict,
                                                 subfig:bool=False,
                                                 xlim:list=(None, None),
                                                 ylim:list=(None, None),
                                                 median:bool=False,
                                                 verbose:bool=True)->None:
    """Plot theoretical and experimental train/test MSBE and save the figures.

    For every dataset size and every panel parameter value, four series are
    drawn on a log-scaled y axis: the theoretical train and test curves (solid
    lines) and the experimental train and test values (crosses). A dashed
    vertical line marks ``x = 1``.

    Args:
        path: Output prefix for the PNG files; concatenated as-is.
        msbe: Nested dict ``msbe[split][kind][dataset_size][param]`` with
            ``split`` in ``{"train", "test"}`` and ``kind`` in
            ``{"th", "exp"}``. Every leaf holds the x values under
            ``params_var["axis"]``. The ``"th"`` leaves hold
            ``"msbe_training"`` / ``"msbe_test"`` and the ``"exp"`` leaves hold
            ``"mean"`` / ``"median"``.
        params_var: Dict with ``"axis"`` (name of the x-axis parameter) and
            ``"subfig"`` (name of the parameter that varies across panels).
        subfig: If true, draw one figure per dataset size with one panel per
            parameter value. Otherwise save one figure per panel.
        xlim: ``(low, high)`` x limits applied to every panel.
        ylim: Either one ``(low, high)`` pair shared by every panel or a
            sequence of such pairs, one per panel.
        median: Plot the experimental median instead of the mean.
        verbose: If true, show a progress bar.

    Every figure is closed once it has been saved, so that figures do not
    accumulate across the loops.
    """
    dataset_size=np.sort(list(msbe["test"]["exp"].keys()))
    num_dataset_size=dataset_size.shape[0]
    if num_dataset_size==0:
        # Nothing to plot.
        return
    params_var_subfig=np.sort(list(msbe["test"]["exp"][dataset_size[0]].keys()))
    num_subfig=params_var_subfig.shape[0]
    msbe_test_str=r"$MSBE$"
    msbe_test="msbe_test"
    msbe_train_str=r"$\widehat{MSBE}$"
    msbe_train="msbe_training"
    msbe_exp="median" if median else "mean"
    axis=params_var["axis"]
    xlabel=_msbe_axis_label(axis)
    with tqdm(total=num_dataset_size, disable=not verbose, desc="Creating Figures") as pbar:
        for size in dataset_size:
            fig=None
            if subfig:
                fig=plt.figure(figsize=(num_subfig*WIDTH, HEIGHT))
            for j, param in enumerate(params_var_subfig):
                if subfig:
                    plt.subplot(1, num_subfig, j+1)
                else:
                    fig=plt.figure(figsize=(10, 6))
                train_th=msbe["train"]["th"][size][param]
                train_exp=msbe["train"]["exp"][size][param]
                test_th=msbe["test"]["th"][size][param]
                test_exp=msbe["test"]["exp"][size][param]
                plt.plot(np.array(train_th[axis]), train_th[msbe_train],
                         "b-", label=msbe_train_str+" (th)")
                plt.plot(np.array(train_exp[axis]), train_exp[msbe_exp], "bx")
                plt.plot(np.array(test_th[axis]), test_th[msbe_test],
                         "r-", label=msbe_test_str+" (th)")
                plt.plot(np.array(test_exp[axis]), test_exp[msbe_exp], "rx")
                plt.axvline(x =1.0, color = 'black', ls='--')
                plt.yscale("log")
                if axis=="weight_decay":
                    plt.xscale("log")
                plt.xlabel(xlabel, fontsize=FONT_SIZE-4)
                # In a multi-panel figure only the first panel carries the y label.
                if not subfig or j==0:
                    plt.ylabel(msbe_test_str, fontsize=FONT_SIZE-4)
                plt.xlim(*xlim)
                plt.ylim(*_msbe_panel_limits(ylim, j))
                plt.tick_params(axis='both', labelsize=FONT_SIZE-4)
                title=_msbe_panel_title(params_var["subfig"], param)
                if title is not None:
                    plt.title(title, fontsize=FONT_SIZE-4, style="italic")
                plt.tight_layout()
                if not subfig:
                    # NOTE: the "datatset" spelling is kept so that the names of
                    # the generated files do not change.
                    plt.savefig(path+"msbe_datatset_"+str(size)+"_"+params_var["subfig"]+"_"+str(param)+".png", dpi=200)
                    plt.close(fig)
            if subfig:
                plt.savefig(path+"msbe_datatset_"+str(size)+".png", dpi=200)
                plt.close(fig)
            pbar.update(1)
     
################################################################

def plot_delta_ratio(path:str, 
                     delta:dict)->None:
    """Plot delta against the ratio for every dataset size and save the figure.

    One panel is drawn per parameter value (the keys of the inner dicts, shown
    as lambda in the panel titles), with one curve per dataset size on a
    log-scaled y axis. The figure is written to ``<path>delta_ratio.png``.

    Args:
        path: Output prefix for the PNG file; concatenated as-is.
        delta: Nested dict ``delta[dataset_size][param]`` whose leaves hold the
            x values under ``"ratio"`` and the y values under ``"delta"``.
    """
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    dataset_size=np.sort(list(delta.keys()))
    if dataset_size.shape[0]==0:
        # Nothing to plot.
        return
    params_var_subfig=np.sort(list(delta[dataset_size[0]].keys()))
    num_subfig=params_var_subfig.shape[0]
    delta_str=r"$\delta$"
    xlabel=r'$N\,/\,m$'
    plt.figure(figsize=(num_subfig*WIDTH, HEIGHT))
    for j, param in enumerate(params_var_subfig):
        plt.subplot(1, num_subfig, j+1)
        for i, size in enumerate(dataset_size):
            curve=delta[size][param]
            plt.plot(np.array(curve["ratio"]), 
                     curve["delta"],
                     # Wrap around the colour cycle when there are more dataset
                     # sizes than colours.
                     color=colors[i % len(colors)],
                     label=r"$n = $"+str(size))
        plt.yscale("log")
        plt.xlabel(xlabel, fontsize=FONT_SIZE-4)
        # Only the first panel carries the y label.
        if j==0:
            plt.ylabel(delta_str, fontsize=FONT_SIZE-4)
        plt.tick_params(axis='both', labelsize=FONT_SIZE-4)
        plt.title(r"$\lambda=$"+f"{param:.1e}"+"\n", fontsize=FONT_SIZE-4, style="italic")
        plt.tight_layout()
    plt.savefig(path+"delta_ratio.png", dpi=200)