from calendar import c
import os
from tkinter import N
from webbrowser import get
import matplotlib.pyplot as plt
import matplotlib
# Module-wide matplotlib style applied at import time: large tick labels,
# thick lines / axes / ticks, and no top or right spines on every figure.
matplotlib.rcParams.update({
    'ytick.labelsize': 20,
    'xtick.labelsize': 20,
    'lines.linewidth': 4,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.linewidth': 2,
    'xtick.major.width': 2,
    'xtick.minor.width': 2,
    'ytick.major.width': 2,
    'ytick.minor.width': 2,
    # NOTE(review): the family name presumably means "DejaVu Serif" (kept
    # verbatim). Per matplotlib's font lookup, 'font.serif' is only consulted
    # when font.family is the generic 'serif' -- confirm which font is wanted.
    "font.family": "DeJavu Serif",
    "font.serif": ["Times New Roman"],
})
import seaborn as sns
import numpy as np
from anal.util import transpose_np
from anal.pmap_th import get_pmaps_theory, get_pmaps_theory_grid
from plot.util import sigma_pcolor, get_pal, get_color_gradient as get_c
from plot.const import BLUE, RED, DARKBLUE, DARKRED, GREEN, \
                       RUN, SIGMA_W, SIGMA_B
from tqdm import tqdm


def line_plots(args, ws, bs, ets, lenT1, lenT2=None, lenthm=None):
    """
    Save line/bar plots of every length statistic, sweeping one axis at a time.

    For each statistic from ``lenT1.get_stat_plot()`` (named by its position
    in ``stat_name``) and each combination of eta, horizon ``cur_T`` (entries
    of ``Ts`` up to ``args.T``), sigma_w, sigma_b and a few representative
    layers, this calls ``plot_xs`` once per view. Output sub-directories are
    ``'line_plots_' + '<x-axis>_<legend>'`` (e.g. ``iter_l``: x = iteration,
    one curve per layer).

    Each statistic ``stat`` is a pair: ``stat[0]`` is passed to ``plot_xs`` as
    the central values (``ms``) and ``stat[1]`` as the spread (``ss``). Both
    are indexed ``[i_w, i_b, i_et, iteration, layer]`` in this function.

    Args:
        args: run configuration; this function reads T, n_layers, act, exp,
            train, len_all and z_init (plot_xs reads further fields).
        ws, bs, ets: swept sigma_w, sigma_b and eta values, in the order of
            the first three stat axes.
        lenT1: provides ``get_stat_plot()`` returning the 7-tuple unpacked below.
        lenT2: optional second result object with the same layout; its values
            are overlaid as ``ms2``/``ss2`` in the 'lyr_w' and 'tr_iter_w' views.
        lenthm: optional theory values indexed via ``stat_dt``; overlaid as
            ``ps`` for the 'z', 'd', 'w_grad' and 'b_grad' statistics.

    Original shorthand notes:
    (legend, x)
    stat
        zd lst
        theory
        wb lst
    nw, nb, T, L
    nw, nb, E, L-1
    """
    stat_name = ['z', 'zh', 'h', 'd', 'loss_inf', 'dz_ratio', 'w_grad', 'b_grad', 'w', 'loss_tr']
    # Position of each statistic inside `lenthm` (theory predictions).
    stat_dt={'z':0, 'd':1, 'w_grad':2, 'b_grad':3}
    # Position of each statistic inside `stat_tr_lst` (training-time stats).
    stat_tr_dt={'z':0, 'd':1, 'w_grad':2, 'b_grad':3, 'w':4, 'loss_tr':5}
    thm_name = ['p', 'pt', 'ph', 'q']  # NOTE(review): not used in this function
    # Candidate horizons; truncated below so that the last entry is args.T.
    Ts = [1,10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000]
    T = args.T
    idx_T = Ts.index(T)  # raises ValueError if args.T is not one of the listed horizons
    Ts = Ts[:idx_T+1]
    # Layers drawn as separate curves in the (iteration, layer) view.
    ls = np.arange(5,args.n_layers+1, 5) if args.act=='linear' and 'fig6' in args.exp else \
        np.array([13,14,15,16,26,27,28,29]) if args.n_layers>10 else np.arange(2,args.n_layers+1) 
    # Representative early / middle / late layer indices for per-layer views.
    early_l = args.n_layers//6
    mid_l = args.n_layers//2 - 1
    late_l = (args.n_layers//6)*5
    short_ls = [early_l, mid_l-1, mid_l, mid_l+1, late_l] # [i for i in range(args.n_layers)]
    # Non-linear, untrained runs additionally show two of the last layers.
    if args.act != 'linear' and not args.train: short_ls = short_ls + [args.n_layers-3, args.n_layers-2]
    zd_stat_lst, wb_stat_lst, wb_inf_stat_lst, zd, wb, wb_inf, zd_tr_stat_lst = lenT1.get_stat_plot()
    stat_tr_lst = zd_tr_stat_lst + wb_stat_lst
    # BLUE->RED colour gradients, one colour per legend entry of each view.
    pal_lst = [get_c(BLUE, RED, len(lst)) for lst in \
               [zd_stat_lst[:-1], wb_stat_lst[:-1], ws, bs, ets, Ts, ls]]
    pal_zd, pal_wb, pal_ws, pal_bs, pal_ets, pal_Ts, pal_ls = pal_lst
    plot_dir = 'line_plots_'
    # shape: n_sws, n_sbs, n_ets, n_iters, n_ls(-1)
    if lenT2 is not None:
        zd_stat_lst2, wb_stat_lst2, wb_inf_stat_lst2, _, _, _, zd_tr_stat_lst2 = lenT2.get_stat_plot()
        stat_lst2 = zd_stat_lst2 + wb_inf_stat_lst2
        stat_tr_lst2 = zd_tr_stat_lst2 + wb_stat_lst2
    # i_stat indexes stat_name; assumes the concatenated list lines up with it.
    for i_stat, stat in enumerate(tqdm(zd_stat_lst + wb_inf_stat_lst)):
        # if stat_name[i_stat] not in ['d', 'w_grad', 'z', 'dz_ratio']: continue # fast track
        pre_name = f'{stat_name[i_stat]}/{stat_name[i_stat]}'
        # Theory overlay only exists for the statistics listed in stat_dt.
        with_thm = stat_name[i_stat] in ['z', 'd', 'w_grad', 'b_grad'] \
                   and lenthm is not None
        # for et, T, w, b, l
        for i_et, et in enumerate(ets):
            for i_T, cur_T in enumerate(Ts):
                for i_w, w in enumerate(ws): # (sigma_b, layer)
                    plot_name = pre_name + f'_sw:{w}_et:{et}_T:{cur_T}'
                    ms, ss = stat[0][i_w,:,i_et,cur_T-1], stat[1][i_w,:,i_et,cur_T-1]
                    ps = lenthm[stat_dt[stat_name[i_stat]]][i_w,:,i_et,cur_T-1] if with_thm else None
                    plot_xs(args, plot_dir+'lyr_b', plot_name, bs, pal_bs, ms, ss, ps)
                    for i_b, b in enumerate(bs):
                        if i_T == 0: # (layer, iteration)
                            # For 'z' (i_stat == 0) the layer legend is shifted down by
                            # one for this call and restored right after; presumably z has
                            # one more layer entry than the other stats -- TODO confirm.
                            ls = ls-1 if i_stat == 0 else ls
                            plot_name = pre_name + f'_sw:{w}_sb:{b}_et:{et}'
                            ms = stat[0][i_w,i_b,i_et].T; ss = stat[1][i_w,i_b,i_et].T
                            ps = lenthm[stat_dt[stat_name[i_stat]]][i_w,i_b,i_et].T if with_thm else None
                            plot_xs(args, plot_dir+'iter_l', plot_name, ls, pal_ls, ms, ss, ps) # ms[3::5], ss[3::5]
                            ls = ls+1 if i_stat == 0 else ls
                        if i_w == 0: # (sigma_w, layer)
                            plot_name = pre_name + f'_sb:{b}_et:{et}_T:{cur_T}'
                            ms = stat[0][:,i_b,i_et,cur_T-1]; ss = stat[1][:,i_b,i_et,cur_T-1]
                            ps = (lenthm[stat_dt[stat_name[i_stat]]][:,i_b,i_et,cur_T-1] if with_thm else None)
                            ms2 = None; ss2 = None
                            if lenT2 is not None:
                                stat2 = stat_lst2[i_stat]
                                ms2 = stat2[0][:,i_b,i_et,cur_T-1]; ss2 = stat2[1][:,i_b,i_et,cur_T-1]
                            plot_xs(args, plot_dir+'lyr_w', plot_name, ws, pal_ws, ms, ss, ps, ms2, ss2)
                        for i_l, l in enumerate(short_ls):
                            # lb converts the array index l into the layer number shown in
                            # file names (offset 1 for 'z', 2 otherwise) -- TODO confirm why.
                            lb = 1 if stat_name[i_stat] == 'z' else 2
                            # Only 'z' keeps the last layer index.
                            if i_stat != 0 and l == args.n_layers-1: continue
                            if i_stat == 0 and l != args.n_layers-1: # (zd wb lens, iteration)
                                if i_T == 0:
                                    plot_name = f'whole/whole(zd)_sw:{w}_sb:{b}_et:{et}_l:{l+lb}'
                                    plot_xs(args, plot_dir+'iter_a', plot_name, stat_name[:4],
                                        pal_zd, zd[0][:,i_w,i_b,i_et,:,l], zd[1][:,i_w,i_b,i_et,:,l])
                                    plot_name = f'whole/whole(wb)_sw:{w}_sb:{b}_et:{et}_l:{l+lb}'
                                    plot_xs(args, plot_dir+'iter_a', plot_name, stat_name[6:-1],
                                        pal_wb, wb_inf[0][:,i_w,i_b,i_et,:,l], wb_inf[1][:,i_w,i_b,i_et,:,l])
                                if i_l == 0: # (zd wb lens, layer, last iter)
                                    plot_name = f'whole/whole(zd)_sw:{w}_sb:{b}_et:{et}_T:{cur_T}'
                                    ms = zd[0][:,i_w,i_b,i_et,cur_T-1,:] # pointing z, zh, h, d
                                    ss = zd[1][:,i_w,i_b,i_et,cur_T-1,:]
                                    plot_xs(args, plot_dir+'lyr_a', plot_name, stat_name[:4],
                                                pal_zd, ms, ss)
                                    plot_name = f'whole/whole(wb)_sw:{w}_sb:{b}_et:{et}_T:{cur_T}'
                                    # wb is always read at its last iteration (-1), not cur_T.
                                    ms = wb[0][:,i_w,i_b,i_et,-1,:]
                                    ss = wb[1][:,i_w,i_b,i_et,-1,:]
                                    plot_xs(args, plot_dir+'lyr_a', plot_name, stat_name[6:-1],
                                                pal_wb, ms, ss)
                                if i_w == 0: # (zd wb lens, sigma_w, last iter)
                                    plot_name = f'whole/whole(zd)_sb:{b}_et:{et}_l:{l+lb}_T:{cur_T}'
                                    # Rows [0,1,3,4] of zd+wb, labelled z, zh, d, w_grad below.
                                    ms = np.concatenate((zd[0][:,:,i_b,i_et,cur_T-1,l], wb[0][:,:,i_b,i_et,-1,l]))[[0,1,3,4]]
                                    ss = np.concatenate((zd[1][:,:,i_b,i_et,cur_T-1,l], wb[1][:,:,i_b,i_et,-1,l]))[[0,1,3,4]]
                                    plot_xs(args, plot_dir+'sw_a', plot_name, stat_name[0:2]+stat_name[3:4]+stat_name[6:7],
                                            pal_zd, ms, ss, is_line=False, xl=ws)
                                    plot_name = f'whole/whole(wb)_sb:{b}_et:{et}_l:{l+lb}_T:{cur_T}'
                                    ms = wb[0][:,:,i_b,i_et,-1,l]
                                    ss = wb[1][:,:,i_b,i_et,-1,l]
                                    plot_xs(args, plot_dir+'sw_a', plot_name, stat_name[6:-1],
                                            pal_zd, ms, ss, is_line=False, xl=ws)
                                if i_b == 0: # (zd wb lens, sigma_b, last iter)
                                    plot_name = f'whole/whole(zd)_sw:{w}_et:{et}_l:{l+lb}_T:{cur_T}'
                                    ms = zd[0][:,i_w,:,i_et,cur_T-1,l]
                                    ss = zd[1][:,i_w,:,i_et,cur_T-1,l]
                                    plot_xs(args, plot_dir+'sb_a', plot_name, stat_name[:4],
                                            pal_zd, ms, ss, is_line=False, xl=bs)
                                    plot_name = f'whole/whole(wb)_sw:{w}_et:{et}_l:{l+lb}_T:{cur_T}'
                                    ms = wb[0][:,i_w,:,i_et,-1,l]
                                    ss = wb[1][:,i_w,:,i_et,-1,l]
                                    plot_xs(args, plot_dir+'sb_a', plot_name, stat_name[6:-1],
                                            pal_zd, ms, ss, is_line=False, xl=bs)
                            # l = l-1 if i_stat == 0 else l
                            if i_w == 0: # (sigma_w, iteration) 
                                # NOTE(review): the name uses args.T, not cur_T, so every
                                # horizon writes the same file; the one left on disk is the
                                # last Ts entry, which equals args.T.
                                plot_name = pre_name + f'_sb:{b}_et:{et}_T:{args.T}_l:{l+lb}'
                                ms = stat[0][:,i_b,i_et,:cur_T,l]; ss = stat[1][:,i_b,i_et,:cur_T,l]
                                ps = lenthm[stat_dt[stat_name[i_stat]]][:,i_b,i_et,:cur_T,l] if with_thm else None
                                plot_xs(args, plot_dir+'iter_w', plot_name, ws, pal_ws,
                                        ms, ss, ps)
                                # Training-time statistics are plotted once (first horizon only).
                                if i_T==0 and args.train and args.len_all and \
                                   stat_name[i_stat] in ['z','d','w','w_grad']:
                                    idx = stat_tr_dt[stat_name[i_stat]]
                                    ms = stat_tr_lst[idx][0][:,i_b,i_et,:,l]
                                    ss = stat_tr_lst[idx][1][:,i_b,i_et,:,l]
                                    ps = None; ms2 = None; ss2 = None
                                    if lenT2 is not None:
                                        ms2 = stat_tr_lst2[idx][0][:,i_b,i_et,:,l]
                                        ss2 = stat_tr_lst2[idx][1][:,i_b,i_et,:,l]
                                    plot_xs(args, plot_dir+'tr_iter_w', plot_name, ws, pal_ws,
                                            ms, ss, ps, ms2, ss2)
                                if i_et == 0: # (eta, sigma_w) : for b, 3l, T
                                    plot_name = pre_name + f'_sb:{b}_l:{l+lb}_T:{cur_T}'
                                    ms = stat[0][:,i_b,:,cur_T-1,l].T; ss = stat[1][:,i_b,:,cur_T-1,l].T
                                    ps = lenthm[stat_dt[stat_name[i_stat]]][:,i_b,:,cur_T-1,l].T if with_thm else None
                                    plot_xs(args, plot_dir+'sw_eta', plot_name, ets, pal_ets,
                                            ms, ss, ps, is_line=False, xl=ws)
                            if i_b == 0:
                                if i_T == 0: # (sigma_b, iteration)
                                    plot_name = pre_name + f'_sw:{w}_et:{et}_l:{l+lb}'
                                    ms = stat[0][i_w,:,i_et,:,l]; ss = stat[1][i_w,:,i_et,:,l]
                                    ps = lenthm[stat_dt[stat_name[i_stat]]][i_w,:,i_et,:,l] if with_thm else None
                                    plot_xs(args, plot_dir+'iter_b', plot_name, bs, pal_bs,
                                            ms, ss, ps)
                                if i_et == 0: # (eta, sigma_b) for w, 3l, T
                                    plot_name = pre_name + f'_sw:{w}_l:{l+lb}_T:{cur_T}'
                                    ms = stat[0][i_w,:,:,cur_T-1,l].T; ss = stat[1][i_w,:,:,cur_T-1,l].T
                                    ps = lenthm[stat_dt[stat_name[i_stat]]][i_w,:,:,cur_T-1,l].T if with_thm else None
                                    plot_xs(args, plot_dir+'sb_eta', plot_name, ets, pal_ets,
                                            ms, ss, ps, is_line=False, xl=bs)
                            if i_et == 0:
                                if i_l == 0: # (eta, l) for w, b, T
                                    plot_name = pre_name + f'_sw:{w}_sb:{b}_T:{cur_T}'
                                    ms = stat[0][i_w,i_b,:,cur_T-1]; ss = stat[1][i_w,i_b,:,cur_T-1]
                                    ps = lenthm[stat_dt[stat_name[i_stat]]][i_w,i_b,:,cur_T-1] if with_thm else None
                                    plot_xs(args, plot_dir+'lyr_eta', plot_name, ets, pal_ets,
                                            ms, ss, ps)
                                if i_T == 0: # (eta, iter) for w, b, 3l
                                    plot_name = pre_name + f'_sw:{w}_sb:{b}_l:{l+lb}'
                                    ms = stat[0][i_w,i_b,:,:,l]; ss = stat[1][i_w,i_b,:,:,l]
                                    ps = lenthm[stat_dt[stat_name[i_stat]]][i_w,i_b,:,:,l] if with_thm else None
                                    plot_xs(args, plot_dir+'iter_eta', plot_name, ets, pal_ets,
                                            ms, ss, ps)
                            if i_l == 0 and i_T == 0: #(T, l) for w, b, eta
                                # 0-based iteration index of every horizon (the comprehension
                                # variable does not rebind the outer T in Python 3).
                                Tpt = [T-1 for T in Ts]
                                plot_name = pre_name + f'_sw:{w}_sb:{b}_et:{et}'
                                ms = stat[0][i_w,i_b,i_et,Tpt,:]; ss = stat[1][i_w,i_b,i_et,Tpt,:]
                                plot_xs(args, plot_dir+'lyr_iter', plot_name, Ts, pal_Ts,
                                        ms, ss, None)
                                if args.z_init == 'ff' and stat_name[i_stat] in ['d', 'w_grad']:
                                    # Fixed early iterations for the 'ff' initialisation.
                                    Tpt = [1, 10, 20, 30, 40, 50]
                                    pal_Tpt = get_c(BLUE, RED, len(Tpt))
                                    ms = stat[0][i_w,i_b,i_et,Tpt,:]; ss = stat[1][i_w,i_b,i_et,Tpt,:]
                                    # NOTE(review): plot_dir + '_ff_lyr_iter' yields a double
                                    # underscore ('line_plots__ff_lyr_iter') in the directory name.
                                    plot_xs(args, plot_dir+'_ff_lyr_iter', plot_name, Tpt, pal_Tpt,
                                            ms, ss, None)


def lineplot_label(plot_dir, legend):
    """
    Return the mathtext legend label that ``lineplot`` uses for ``legend``.

    The last character of ``plot_dir`` names the legend variable:
    'w' -> sigma_w, 'b' -> sigma_b, 'r' -> T, 'a' -> no prefix; any other
    character is used verbatim. Underscores in the legend value are escaped.
    """
    prefix = (f'${plot_dir[-1]}=$'.replace('$a=$', '').replace('r=', 'T=')
              .replace('w=', r'\sigma_w=').replace('b=', r'\sigma_b='))
    return prefix + f'${legend}$'.replace('_', r'\_')


def lineplot(plot_dir, legends, x, ms, ss, pal, marker='-'):
    """
    Draw one line per row of ``ms`` with an optional +/- ``ss`` band.

    plot_dir: output sub-directory; its last character selects the legend
        prefix (see ``lineplot_label``).
    legends: legend value per row of ``ms``.
    x: x-axis values shared by every row.
    ms, ss: per-row centre values and band half-widths. ``ss`` may be None,
        in which case only the lines are drawn, without legend labels.
    pal: one colour per row (rows beyond ``len(pal)`` are not drawn).
    marker: matplotlib format string for the lines.

    Rows containing any exact zero are skipped.

    Raises:
        IndexError: if ``legends`` has no entry for a row that is drawn.
    """
    # zip() cannot take None, so substitute a per-row None band.
    bands = ss if ss is not None else [None] * len(ms)
    for i, (m, s, c) in enumerate(zip(ms, bands, pal)):
        if np.any(m == 0):
            continue
        # Resolve the label before drawing so a legend/row mismatch fails fast
        # (previously this dropped into an interactive debugger).
        try:
            legend = legends[i]
        except IndexError as err:
            raise IndexError(
                f'lineplot: no legend entry for row {i} '
                f'({len(legends)} legends for {len(ms)} rows)') from err
        label = lineplot_label(plot_dir, legend)
        plt.plot(x, m.T, marker, markersize=3, color=c, alpha=1)
        if s is not None:
            plt.fill_between(x, m - s, m + s, alpha=0.2, color=c,
                             label=label)


def lineplot_ps(x, ps, pal, marker='o', step=10):
    """
    Overlay theory predictions as sparse markers.

    x: x-axis values; ps: one row of predictions per colour in ``pal``.
    Only every ``step``-th point is drawn.
    marker: matplotlib format string for the markers (this argument used to
        be ignored in favour of a hard-coded 'o', which is still the default).
    """
    for p, c in zip(ps, pal):
        plt.plot(x[::step], p.T[::step], marker, markersize=6, color=c, alpha=0.8)


def plot_xs(args, plot_dir, plot_name, legends, pal, ms, ss, ps=None,
            ms2=None, ss2=None, is_line=True, xl=None):
    """
    Draw ``ms`` (+/- ``ss``) against 1..ms.shape[-1] and save it as a PNG.

    args: run configuration; reads ``label`` and ``fig_dir``.
    plot_dir: sub-directory under ``args.fig_dir``. Its last character names
        the legend variable; with ``args.label`` its second-to-last
        '_'-separated token is the x-axis label.
    plot_name: '<subdir>/<file stem>' of the saved PNG.
    legends, pal: legend value and colour per row of ``ms``.
    ms, ss: (n_rows, n_x) centre values and spreads (``ss`` may be None).
    ps: optional theory values per row; if its last dimension differs from
        that of ``ms`` its first column is dropped.
    ms2, ss2: optional second series; when given the first series is drawn in
        red (legend suffix '_pcn') and the second in blue ('_spcn').
    is_line: line plot with shaded bands if True; otherwise grouped bars with
        error bars, ``xl`` giving the x tick labels.

    Y axis: 0-100 percentage scale if 'dz_ratio' is in ``plot_name`` (values
    are multiplied by 100), linear if ``plot_dir`` ends in 'r', log otherwise.
    The input arrays are never modified.
    """
    is_ratio = 'dz_ratio' in plot_name
    plt.figure(figsize=(5, 3))
    if is_ratio:
        plt.ylim(0, 100)
        # Scale copies: the former in-place `*= 100` / `/= 100` round trip
        # perturbed the caller's data, failed for integer arrays and never
        # scaled the secondary series.
        ms = ms * 100
        ss = None if ss is None else ss * 100
        ms2 = None if ms2 is None else ms2 * 100
        ss2 = None if ss2 is None else ss2 * 100
    elif plot_dir[-1] != 'r':
        plt.yscale('log')
    x = np.arange(1, ms.shape[-1] + 1)
    if ps is not None and ps.shape[-1] != ms.shape[-1]:
        ps = ps[:, 1:]

    def barplot(bms, bss, bpal, blegends, bps):
        # Grouped bars centred on each x; theory bars (bps) follow the data bars.
        n_bars = len(bms)
        width = 0.8 / n_bars
        x_offset = x - width * (n_bars - 1) / 2
        errs = bss if bss is not None else [None] * n_bars
        for i, (m, s, c) in enumerate(zip(bms, errs, bpal)):
            if np.any(m == 0):
                continue  # a skipped row does not advance the offset
            plt.bar(x_offset, m.T, width=width, color=c)
            label = (f'${plot_dir[-1]}=$'.replace('$a=$', '').replace('r=', 'T=')
                     .replace('w=', r'\sigma_w=').replace('b=', r'\sigma_b=')
                     .replace('t=', r'\eta='))
            label = label + f'${blegends[i]}$'.replace('_', r'\_')
            if s is not None:
                plt.errorbar(x_offset, m, s, fmt='o', alpha=0.5, color=c,
                             label=label)
            x_offset += width
        if bps is not None:
            for p, c in zip(bps, bpal):
                plt.bar(x_offset, p.T, width=width, color=c, alpha=0.6)
                x_offset += width
        plt.xticks(x, xl)

    if ms2 is None:
        legends1, pal1 = legends, pal
    else:
        legends1 = [str(lgnd) + '_pcn' for lgnd in legends]  # flag
        pal1 = get_c(RED, RED, len(legends))
    if is_line:
        lineplot(plot_dir, legends1, x, ms, ss, pal1, marker='-')
    else:
        # Theory bars are drawn exactly once, alongside the primary series.
        barplot(ms, ss, pal1, legends1, ps)
    if ms2 is not None:
        legends2 = [str(lgnd) + '_spcn' for lgnd in legends]  # flag
        pal2 = get_c(BLUE, BLUE, len(legends))
        if is_line:
            lineplot(plot_dir, legends2, x, ms2, ss2, pal2, marker='-')
        else:
            barplot(ms2, ss2, pal2, legends2, None)
    if ps is not None and is_line:
        lineplot_ps(x, ps, pal)

    plt.tick_params(axis='x', labelsize=20)
    plt.tick_params(axis='y', labelsize=20)
    if is_line:
        plt.xlim(1, ms.shape[-1])
    else:
        plt.xlim(0, ms.shape[-1] + 1)
    # optionally label x-axis, y-axis, title, legend
    if args.label:
        plt.xlabel(plot_dir.split('_')[-2], fontsize=20)
        stem_parts = plot_name.split('/')[1].split('_')
        if is_ratio:
            yl = 'd/z percentage (%)'
        else:
            prefix = 'len of ' if 'loss' not in plot_name else ''
            yl = prefix + stem_parts[0]
        plt.ylabel(yl, fontsize=20)
        title = stem_parts[-3:] if 'whole' in plot_name else stem_parts[-2:]
        title = (' '.join(title).replace('sw:', r'$\sigma_w$:')
                 .replace('sb:', r'$\sigma_b$:')
                 .replace('et:', r'$\eta$:'))
        plt.title(title, fontsize=20)
        if len(legends) < 7:
            plt.legend()
        else:
            plt.legend(bbox_to_anchor=(1.1, 1))
    out_dir = f'{args.fig_dir}/{plot_dir}/'
    os.makedirs(out_dir + plot_name.split('/')[0], exist_ok=True)
    plt.savefig(os.path.join(out_dir, f'{plot_name}.png'), bbox_inches='tight')
    plt.cla()
    plt.clf()
    plt.close()



def old_plot_xs(args, plot_dir, plot_name, legends, pal, ms, ss, ps=None,
            ms2=None, ss2=None, is_line=True, xl=None):
    """
    Legacy variant of ``plot_xs``: one series, log y-axis unless dz_ratio.

    ms: n_m, n_iteration (e.g. n_layer, etc.); ss: matching spreads or None.
    If plot_dir ends in 'l' (legend is layer): figsize = (10, 6), else (5, 3).
    If plotting dz_ratio: y-axis is a 0-100 percentage (values x100).
    If is_line is False: grouped bar plot with error bars, ``xl`` as tick labels.
    ps: optional theory values; first column dropped if its length differs.
    ms2, ss2: accepted for signature compatibility but ignored.
    The input arrays are never modified.
    """
    is_ratio = 'dz_ratio' in plot_name
    plt.figure(figsize=(10, 6) if plot_dir[-1] == 'l' else (5, 3))
    if is_ratio:
        plt.ylim(0, 100)
        # Scale copies instead of mutating (and later un-scaling) the caller's arrays.
        ms = ms * 100
        ss = None if ss is None else ss * 100
    else:
        plt.yscale('log')
    x = np.arange(1, ms.shape[-1] + 1)
    errs = ss if ss is not None else [None] * len(ms)
    if ps is not None and ps.shape[-1] != ms.shape[-1]:
        ps = ps[:, 1:]

    def draw_lines():
        for i, (m, s, c) in enumerate(zip(ms, errs, pal)):
            if np.any(m == 0):
                continue
            plt.plot(x, m.T, '-', markersize=3, color=c, alpha=1)
            label = (f'${plot_dir[-1]}=$'.replace('$a=$', '')
                     .replace('w=', r'\sigma_w=').replace('b=', r'\sigma_b='))
            label = label + f'${legends[i]}$'.replace('_', r'\_')
            if s is not None:
                plt.fill_between(x, m - s, m + s, alpha=0.2, color=c,
                                 label=label)
        if ps is not None:
            for p, c in zip(ps, pal):
                plt.plot(x, p.T, 'o', markersize=5, color=c, alpha=0.8)

    def draw_bars():
        n_bars = len(ms)
        width = 0.8 / n_bars
        x_offset = x - width * (n_bars - 1) / 2
        for i, (m, s, c) in enumerate(zip(ms, errs, pal)):
            if np.any(m == 0):
                continue
            plt.bar(x_offset, m.T, width=width, color=c)
            # 't=' (a one-character prefix) matches plot_xs; the former 'et='
            # pattern could never occur in this label.
            label = (f'${plot_dir[-1]}=$'.replace('$a=$', '')
                     .replace('w=', r'\sigma_w=').replace('b=', r'\sigma_b=')
                     .replace('t=', r'\eta='))
            label = label + f'${legends[i]}$'.replace('_', r'\_')
            if s is not None:
                plt.errorbar(x_offset, m, s, fmt='o', alpha=0.5, color=c,
                             label=label)
            x_offset += width
        if ps is not None:
            for p, c in zip(ps, pal):
                plt.bar(x_offset, p.T, width=width, color=c, alpha=0.6)
                x_offset += width
        plt.xticks(x, xl)

    if is_line:
        draw_lines()
    else:
        draw_bars()
    plt.tick_params(axis='x', labelsize=20)
    plt.tick_params(axis='y', labelsize=20)
    if is_line:
        plt.xlim(1, ms.shape[-1])
    else:
        plt.xlim(0, ms.shape[-1] + 1)
    # optionally label x-axis, y-axis, title, legend
    if args.label:
        plt.xlabel(plot_dir.split('_')[-2], fontsize=20)
        stem_parts = plot_name.split('/')[1].split('_')
        if is_ratio:
            yl = 'd/z percentage (%)'
        else:
            prefix = 'len of ' if 'loss' not in plot_name else ''
            yl = prefix + stem_parts[0]
        plt.ylabel(yl, fontsize=20)
        title = stem_parts[-3:] if 'whole' in plot_name else stem_parts[-2:]
        title = (' '.join(title).replace('sw:', r'$\sigma_w$:')
                 .replace('sb:', r'$\sigma_b$:')
                 .replace('et:', r'$\eta$:'))
        plt.title(title, fontsize=20)
        if len(legends) < 7:
            plt.legend()
        else:
            plt.legend(bbox_to_anchor=(1.1, 1))
    out_dir = f'{args.fig_dir}/{plot_dir}/'
    os.makedirs(out_dir + plot_name.split('/')[0], exist_ok=True)
    plt.savefig(os.path.join(out_dir, f'{plot_name}.png'), bbox_inches='tight')
    plt.cla()
    plt.clf()
    plt.close()


def flatten_len_log(lens):
    """
    Reorder a length log to ``(nw, T, L, samples)``.

    ``lens`` is laid out as ``(n_runs, nw, nb, net, T, L, bsz)``. The result
    keeps (nw, T, L) as leading axes and merges (nb, net, n_runs, bsz), in
    that order, into one trailing sample axis.
    """
    arr = lens.transpose(1, 4, 5, 2, 3, 0, 6)
    n_samples = int(np.prod(arr.shape[3:]))
    return np.reshape(arr, arr.shape[:3] + (n_samples,))


def transition_starts(T, span=161, stride=8):
    """
    Return the iterations ``t`` whose pairs ``(t, t + stride)`` are scattered.

    Covers roughly the last ``span`` iterations of a ``T``-iteration run. The
    start is clamped at 0 and the stop keeps ``t + stride <= T - 1``, so runs
    shorter than ``span`` no longer pair iterations via negative-index wrap.
    """
    return range(max(T - span, 0), T - stride, stride)


def z_len_alpha(i_w, t, T):
    """
    Marker opacity for the z-length scatter at iteration ``t`` of ``T``.

    Grows by a factor 1.02 per iteration towards 0.05 at ``t == T``; the
    condition with index 1 always uses 0.05. Written as ``1.02 ** (t - T)``
    because ``1.02 ** t / 1.02 ** T`` overflows once T exceeds ~35k.
    """
    if i_w == 1:
        return 0.05
    return 0.002 + 0.048 * 1.02 ** (t - T)


def d_len_alpha(t, T):
    """Marker opacity for the d-length scatter: grows linearly with ``t / T``."""
    return 0.001 + 0.004 * (t / T)


def _scatter_len_figure(args, run_cond_vs, pal, lens, l, n_pts, alpha_of,
                        out_dir, fname, stride=8):
    """
    Scatter ``lens[i_w, t, l]`` (x) against ``lens[i_w, t + stride, l]`` (y).

    One colour per condition in ``run_cond_vs``; the first ``n_pts`` samples
    of each iteration are drawn with opacity ``alpha_of(i_w, t)``. The figure
    is saved to ``os.path.join(out_dir, fname)``.
    """
    plt.figure(figsize=(5, 3))
    for i_w in range(args.n_conds):
        # Empty scatter: legend entry in the condition's colour at full opacity.
        plt.scatter([], [], label=rf'$\sigma_w$={run_cond_vs[i_w]}', color=pal[i_w])
        for t in transition_starts(args.T, stride=stride):
            plt.scatter(lens[i_w, t, l, :n_pts], lens[i_w, t + stride, l, :n_pts],
                        s=20, color=pal[i_w], alpha=alpha_of(i_w, t))
    # Dashed identity line: length unchanged between the two iterations.
    plt.plot((1e-6, 1e+8), (1e-6, 1e+8), '--', color='k', zorder=900, linewidth=4)
    plt.xlim(1e-6, 1e+8)
    plt.ylim(1e-6, 1e+8)
    plt.tick_params(axis='x')
    plt.tick_params(axis='y')
    plt.xscale('log')
    plt.yscale('log')
    if args.label:
        plt.legend()
        plt.xlabel(f'${{length}}^t$', fontsize=20)
        plt.ylabel(f'${{length}}^{{t+1}}$', fontsize=20)
        plt.title(rf'$\eta$: {args.eta} $\sigma_b$: {args.sigma_b} l:{l+1}', fontsize=20)
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, fname), bbox_inches='tight')
    plt.cla()
    plt.clf()
    plt.close()


def scatter_plots(lenlog, args, run_cond_vs):
    '''
    Scatter each layer's z and d lengths at iteration t against iteration t+8.

    Shapes (example from the original notes):
    lenlog.z_len, lenlog.d_len: (n_runs, nw, nb, net, T, L, bsz)
    after flatten_len_log:      (nw, T, L, nb*net*n_runs*bsz)

    Side effect: sets ``args.n_conds = len(run_cond_vs)``.
    For every layer index l in 1..n_layers-2, writes
    ``<fig_dir>/scatter_plots_len_w/z_l:<l+1>.png`` (first 128 samples) and
    ``.../d_l:<l+1>.png`` (first 32 samples).
    '''
    p_iter = flatten_len_log(lenlog.z_len)
    q_iter = flatten_len_log(lenlog.d_len)
    args.n_conds = len(run_cond_vs)
    pal = get_c(BLUE, RED, args.n_conds)
    out_dir = f'{args.fig_dir}/scatter_plots_len_w'
    for l in range(1, args.n_layers-1):
        _scatter_len_figure(args, run_cond_vs, pal, p_iter, l, 128,
                            lambda i_w, t: z_len_alpha(i_w, t, args.T),
                            out_dir, f'z_l:{l+1}.png')
        _scatter_len_figure(args, run_cond_vs, pal, q_iter, l, 32,
                            lambda i_w, t: d_len_alpha(t, args.T),
                            out_dir, f'd_l:{l+1}.png')


def heatmap_tick_values(max_v):
    """Return the ticks 0, max_v/4, max_v/2, 3*max_v/4, max_v rounded to 2 decimals."""
    return np.round(np.array([0, 0.25, 0.5, 0.75, 1.0]) * max_v, 2)


def heatmap_plots(lenlog, args, sigma_ws, sigma_bs):
    '''
    Draw one sigma_b x sigma_w heat map per statistic and layer.

    For each of z, d, w_grad and b_grad and each layer index l in
    1..n_layers-2, plots the theory grid from ``get_pmaps_theory_grid``
    (default) or, when ``args.skip_theory`` is set, the measured values, and
    saves ``<fig_dir>/heatmaps_sbsw/<stat>/<stat>_l:<l+1>.png``.

    Shapes (from the original notes -- confirm against the producers):
    pmaps_grid: (n_ps, nw, nb, L)
    qmaps_grid: (nw, nb, L)   (not used here)
    ps_result entries: presumably (nw, nb, L), taken at eta index 0 and the
        last iteration of the ``[0]`` component of each statistic.
    '''
    zd_stat, wb_stat, _, _ = lenlog.get_stat()
    ps_result = [zd_stat[0][0][:,:,0,-1], zd_stat[3][0][:,:,0,-1],
                 wb_stat[0][0][:,:,0,-1], wb_stat[1][0][:,:,0,-1]]
    # TODO: check dimensions and mean over runs and bsz
    pconds = ['z', 'd', 'w_grad', 'b_grad']
    max_sw = max(sigma_ws)
    max_sb = max(sigma_bs)
    xticks = heatmap_tick_values(max_sb)
    yticks = heatmap_tick_values(max_sw)
    if not args.skip_theory:
        pmaps_grid, _ = get_pmaps_theory_grid(args, sigma_ws, sigma_bs)
    for i, (pcond, p_result) in enumerate(zip(pconds, ps_result)):
        out_dir = f'{args.fig_dir}/heatmaps_sbsw/{pcond}'
        for l in range(1, args.n_layers-1):
            plt.figure(figsize=(10, 6))
            if not args.skip_theory:
                # NOTE(review): assumes theory row i corresponds to pconds[i]; confirm
                # against the ordering returned by get_pmaps_theory_grid.
                sigma_pcolor(pmaps_grid[i,:,:,l], sigma_ws, sigma_bs)
            else:
                sigma_pcolor(p_result[:,:,l], sigma_ws, sigma_bs)
            if args.label:
                plt.title(rf'{pcond}, T:{args.T}, $\eta$:{args.min_val_e}, l:{l+1}', fontsize=40)
                plt.xlabel(r'$\sigma_b$', fontsize=40)
                plt.ylabel(r'$\sigma_w$', fontsize=40)
            # TODO: change ticks according to the range of sigmas
            plt.xticks(xticks, fontsize=40)
            plt.yticks(yticks, fontsize=40)
            plt.xlim(0, max_sb)
            plt.ylim(0, max_sw)
            os.makedirs(out_dir, exist_ok=True)
            plt.savefig(os.path.join(out_dir, f'{pcond}_l:{l+1}.png'),
                        bbox_inches='tight')
            plt.cla()
            plt.clf()
            plt.close()
