import numpy as np
import argparse
import random
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import TwoSlopeNorm
import src.utils as ut
import numpy as np
from matplotlib.cm import ScalarMappable
from cmcrameri import cm
from src.utils import calculate_gain, str2bool, none_or_type

# Seed shared by the `random` and NumPy generators in the script body.
seed = 42
# Report floating-point underflow as a warning instead of ignoring it.
np.seterr(under="warn")

import matplotlib
# Render all figure text through LaTeX, with the AMS math packages loaded.
matplotlib.rc('text', usetex=True)
matplotlib.rc('text.latex', preamble=r'\usepackage{amsmath} \usepackage{amsfonts}')

# Prefer Times New Roman whenever a serif font is requested ...
plt.rcParams['font.serif'] = "Times New Roman"
# ... and make serif the default family for every text element.
plt.rcParams['font.family'] = "serif"
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

# NOTE: no figure-level layout call belongs here. `plt.subplots_adjust` acts on
# the *current* figure, so calling it at import time only creates an unused
# implicit figure; the figure that is saved is laid out where it is created.


def parse_list(arg):
    """Parse a comma-separated list of integers, e.g. ``"[100,200]"``.

    Surrounding whitespace and square brackets are optional, and empty
    items (as produced by ``"[]"`` or a trailing comma) are skipped.

    Parameters
    ----------
    arg : str
        Text such as ``"[1, 2, 3]"``, ``"1,2,3"`` or ``"1000"``.

    Returns
    -------
    list of int
        The parsed integers in input order; empty if ``arg`` holds no items.

    Raises
    ------
    ValueError
        If a non-empty item is not a valid integer literal.
    """
    # Strip whitespace first so that brackets padded by spaces are removed too.
    body = arg.strip().strip('[]')
    return [int(item) for item in body.split(',') if item.strip()]


def get_args():
    """Declare the command-line options and parse them from ``sys.argv``.

    The ``none_or_type`` and ``str2bool`` converters come from ``src.utils``;
    their exact behaviour is not visible here (presumably they accept the
    literal ``None`` / textual booleans -- confirm in ``src.utils``).

    Returns
    -------
    dict
        Mapping from option name to its parsed value, i.e. the run config.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--model',
        type=str,
        default="MLP",
        choices=["MLP", "RESMLP", "CNN", "VIT", "LargeVIT"],
        help="Model to use, 0 for MLP.",
    )
    cli.add_argument(
        '--max_depth',
        type=none_or_type(int),
        default=100,
        help="Number MLP hidden layers, i.e. depth of the network.",
    )
    # The default is a bare int while parsed values go through parse_list;
    # the caller is responsible for handling both forms.
    cli.add_argument(
        '--widths',
        type=none_or_type(parse_list),
        default=1000,
        help="Widths to plot",
    )
    cli.add_argument(
        '--act_func',
        type=none_or_type(str),
        default="ReLU",
        help="Activation function.",
    )
    cli.add_argument(
        '--act_func2',
        type=none_or_type(str),
        default="Linear",
        help="Activation function.",
    )
    cli.add_argument(
        '--data',
        type=str,
        default="bfmnist",
        choices=["random", "bfmnist", "bcifar","cifar10", "cifar100", "tinyimagenet"],
    )
    cli.add_argument(
        '--use_faster_attention',
        type=str2bool,
        default=True,
        help="Use faster attention for VIT.",
    )
    cli.add_argument(
        '--plot_mode',
        type=str,
        default="semilogy",
        help="Plot type from pyplot, e.g. plot, semilog etc",
    )
    return vars(cli.parse_args())

if __name__ == "__main__":
    # Script body: load the stored per-layer statistics for the requested
    # run(s), plot them against depth for every sigma_w^2 value in three
    # panels (global / favoured class / unfavoured class) and save the figure.
    cfg = get_args()
    model = cfg["model"]
    max_depth = cfg["max_depth"]

    # `--widths` is parsed into a list, but its default is a bare int.
    # Accept both so a user-supplied list is not nested inside another list.
    raw_widths = cfg["widths"]
    widths = list(raw_widths) if isinstance(raw_widths, (list, tuple)) else [raw_widths]
    if not widths:
        raise SystemExit("--widths must contain at least one width")
    n_widths = len(widths)

    act_func = cfg["act_func"]
    act_func2 = cfg["act_func2"]
    plot_mode = cfg["plot_mode"]
    random.seed(seed)
    np.random.seed(seed)

    # Common stem of every input/output file name; the second activation is
    # only part of the name when it is not the plain "Linear" one.
    act_tag = act_func if act_func2 == "Linear" else f"{act_func}+{act_func2}"
    run_tag = f"{cfg['model']}_{cfg['data']}_{act_tag}_D{max_depth}"

    # Load the experimental values for the first width; its run config fixes
    # the hyper-parameter grid assumed for all other widths.
    run_cfg, buf_prime = ut.read_data(f"data/init/{run_tag}_W{widths[0]}.h5")

    n_classes = buf_prime.shape[1] - 1
    n_w, n_b = run_cfg["n_w"], run_cfg["n_b"]
    Vw_range = [run_cfg["Vw_min"], run_cfg["Vw_max"]]
    effective_depth = buf_prime.shape[2]

    # Gather every width into one array indexed as
    # (statistic, class, layer, width, Vb index, Vw index).
    # NOTE(review): the leading 8 is hard-coded; presumably it matches the
    # first axis of the stored buffers -- confirm against the data writer.
    buf = np.zeros((8, n_classes + 1, effective_depth, n_widths, n_b, n_w))
    buf[:, :, :, 0, :, :] = buf_prime
    for ww, width in enumerate(widths[1:], start=1):
        _, buf_prime = ut.read_data(f"data/init/{run_tag}_W{width}.h5")
        buf[:, :, :, ww, :, :] = buf_prime

    cmap = cm.roma_r
    if cfg["model"] in ["MLP", "RESMLP", "CNN"]:
        # Rescale the sigma_w^2 axis by the activation-dependent gain.
        gain = calculate_gain(act_func, act_func2, Vb=0.1)
        Vw_range = gain * np.array(Vw_range)

    # Hyper-parameter values on the sigma_w^2 axis, and a colour norm centred
    # on the middle of that range.
    Vw_vec = np.linspace(Vw_range[0], Vw_range[1], n_w)
    Vw_mean = (Vw_range[0] + Vw_range[1]) / 2
    norm = TwoSlopeNorm(vmin=Vw_range[0], vcenter=Vw_mean, vmax=Vw_range[1])

    # Set up the figure: one row of three panels sharing both axes.
    fig_diag, ax_diag = plt.subplots(1, 3, figsize=(7 * 3, 7), sharey=True, sharex=True)
    depths = np.arange(1, effective_depth + 1)

    # Raw string: "\m" and "\;" are LaTeX, not Python escape sequences.
    fig_diag.supxlabel(r"$\mathrm{layer}\; l$", fontsize=20)

    # Figure title.
    if act_func is not None:
        if act_func2 == "Linear":
            fig_diag.suptitle(f"{act_func} {model}", fontsize=25, fontweight='bold')
        else:
            second = str(act_func2)
            if second.startswith("Avg"):
                t2 = "AveragePool"
            elif second.startswith("Max"):
                t2 = "MaxPool"
            else:
                # Any other second activation is shown under its own name
                # instead of failing on an unbound label.
                t2 = second
            fig_diag.suptitle(f"{act_func}+{t2} {model}", fontsize=25, fontweight='bold')

    # (panel title, index along the class axis of `buf`) for each axes.
    panels = (("Global", 0), ("Favoured class", 1), ("Unfavoured class", -1))

    # Only the last width is drawn, at the first sigma_b^2 grid point
    # (the original note says sigma_b^2 = 0.1 -- confirm against run_cfg).
    width_ind = n_widths - 1
    Vb_ind = 0
    for Vw_ind, Vw in enumerate(Vw_vec):
        color = cmap(norm(Vw))
        for ax, (_, class_ind) in zip(ax_diag, panels):
            # Row 2 of the statistic axis was labelled "means" in the original
            # script (rows 1 and 3 apparently being 5%/95% quantiles) --
            # confirm against the data writer.
            getattr(ax, plot_mode)(
                depths,
                buf[2, class_ind, :, width_ind, Vb_ind, Vw_ind],
                ls="-", c=color, lw=2,
            )

    # Titles and grid are set once per panel. `grid()` without arguments
    # toggles visibility, so it must not be called once per plotted curve.
    for ax, (title, _) in zip(ax_diag, panels):
        ax.set_title(title, fontsize=20)
        ax.grid(True)

    # Colour bar mapping line colour to sigma_w^2, attached to the last panel.
    divider = make_axes_locatable(ax_diag[-1])
    cax = divider.append_axes("right", size="5%", pad=0.1)
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])
    cb = fig_diag.colorbar(sm, cax=cax)
    cb.ax.set_title("$\\sigma_w^2$", fontsize=15)

    fig_diag.supylabel("$\\tilde{q}_{ab}^{(l)}$", fontsize=20, rotation=0, x=0)
    fig_diag.tight_layout()

    # Save the figure, named after the run and the largest requested width.
    fig_diag.savefig(f"figures/init/grads_{run_tag}_W{max(widths)}.png", dpi=300)

    
