import os
import gc
import sys
import time
import torch
import wandb
import pathlib
import numpy as np
from anal.pc import AnalPC
from anal.ff import AnalFF
from anal.log import LenLog, LenT
from datetime import datetime
from train.mlbase import MLBase
from os import listdir
from os.path import isfile, join
from tool.args import get_general_args
from anal.pmap_th import get_pmaps_theory, get_p0_thm
from anal.util import get_sigmas, get_etas, transpose_np as tr
from plot.maps import line_plots, scatter_plots, heatmap_plots
from tool.util  import init_wandb, set_seed, get_config, load_pickle
import torch.backends.cudnn as cudnn
# Global PyTorch cuDNN backend flags, set once when this script is loaded.
# benchmark=True asks cuDNN to auto-tune its convolution algorithms
# (see the torch.backends.cudnn documentation).
cudnn.benchmark = True
cudnn.enabled = True


# Version tag of the log layout; main() uses it as a sub-directory of 'log/'
# when it builds args.log_dir.
version = '1.1'


def _parse_result_key(key):
    """Turn a result key 'seed,sigma_w,sigma_b,eta' into a numeric sort tuple."""
    seed, sigma_w, sigma_b, eta = key.split(',')
    return int(seed), float(sigma_w), float(sigma_b), float(eta)


def _find_stored_logs(log_dir, t_target, train):
    """Map cached result file names to their paths below ``log_dir``.

    Sub-directories of ``log_dir`` are named after an integer T. A
    sub-directory is searched when its T equals ``t_target``, or, when
    ``train`` is false, when its T is at least ``t_target``. Entries that are
    not integer-named directories are ignored. Directories are visited in
    ascending T, so when the same file name exists in several of them the
    smallest admissible T wins (previously this depended on the arbitrary
    order returned by ``os.listdir``).
    """
    candidates = []
    for name in listdir(log_dir):
        sub_dir = join(log_dir, name)
        try:
            t_val = int(name)
        except ValueError:
            # Stray entry (hidden file, notes, ...): not a T directory.
            continue
        if not os.path.isdir(sub_dir):
            continue
        if (t_val >= t_target and not train) or t_val == t_target:
            candidates.append((t_val, sub_dir))

    stored = dict()
    for _, sub_dir in sorted(candidates):
        for key_f in listdir(sub_dir):
            log_path = join(sub_dir, key_f)
            if isfile(log_path):
                # Keep the first hit only.
                stored.setdefault(key_f, log_path)
    return stored


def main(args, sigma_ws, sigma_bs, etas):
    """Gather results for every (run, sigma_w, sigma_b, eta) combination.

    Results already present as files under ``args.log_dir`` are loaded with
    ``load_pickle``; the remaining combinations are computed by calling the
    analysis object (``AnalPC`` when ``args.pc`` is truthy, else ``AnalFF``).

    Side effects: sets ``args.log_dir`` to ``log/<version>/<config>`` and
    creates that directory if needed.

    Args:
        args: Parsed options; reads ``pc``, ``n_runs``, ``seed``, ``T`` and
            ``train`` here and is forwarded to the project helpers.
        sigma_ws: Values of sigma_w to cover.
        sigma_bs: Values of sigma_b to cover.
        etas: Values of eta to cover.

    Returns:
        The ``LenT`` object built from the collected arrays, each reshaped to
        ``(n_runs, len(sigma_ws), len(sigma_bs), len(etas), ...)``.
    """
    analysis_cls = AnalPC if args.pc else AnalFF
    n_r, n_w, n_b, n_e = args.n_runs, len(sigma_ws), len(sigma_bs), len(etas)
    anal = analysis_cls(MLBase(args))
    args.log_dir = join('log/', version, get_config(args))
    pathlib.Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    stored_path_dt = _find_stored_logs(args.log_dir, args.T, args.train)

    # One slot per requested combination, keyed 'seed,sigma_w,sigma_b,eta'
    # where seed = args.seed + run index.
    to_get_result_dt = dict()
    for i_run in range(args.n_runs):
        for sigma_w in sigma_ws:
            for sigma_b in sigma_bs:
                for eta in etas:
                    to_get_result_dt[f"{args.seed + i_run},{sigma_w},{sigma_b},{eta}"] = None
    print('to get result dict:', to_get_result_dt.keys())

    # Load whatever is already cached on disk (files written by this project).
    common_set = set(stored_path_dt) & set(to_get_result_dt)
    for key_f in common_set:
        print(f'loading {key_f}')
        to_get_result_dt[key_f] = load_pickle(args, stored_path_dt[key_f])
        gc.collect()

    # Compute the missing combinations, in the order they were requested.
    to_run_lst = [key for key in to_get_result_dt if key not in common_set]
    for key in to_run_lst:
        run_seed, sigma_w, sigma_b, eta = key.split(',')
        # NOTE(review): the first key field already equals args.seed + run
        # index, so args.seed is added a second time here. Left unchanged so
        # new runs stay consistent with previously produced results; confirm
        # whether this is intended.
        set_seed(args.seed + int(run_seed))
        lenlog = anal(int(run_seed), float(sigma_w), float(sigma_b), float(eta))
        to_get_result_dt[key] = lenlog.get_len_lst()

    # Order results numerically by (seed, sigma_w, sigma_b, eta) and look them
    # up by their original keys; rebuilding keys from parsed numbers breaks
    # when a value's text form does not round-trip.
    sorted_lst = [to_get_result_dt[key]
                  for key in sorted(to_get_result_dt, key=_parse_result_key)]
    del to_get_result_dt
    gc.collect()
    print('sorting list completed')

    len_lst_lst = tr(sorted_lst)
    del sorted_lst
    gc.collect()
    print('transposing list completed')

    # Split the flat leading axis back into (run, sigma_w, sigma_b, eta).
    len_np_lst = list()
    for len_lst in len_lst_lst:
        len_np = len_lst.reshape(n_r, n_w, n_b, n_e, *len_lst.shape[1:])
        len_np_lst.append(len_np)
    del len_lst_lst
    gc.collect()
    print('reshaping list completed')

    lenT = LenT(len_np_lst, args, args.n_runs, sigma_ws, sigma_bs, etas)
    del len_np_lst
    gc.collect()
    return lenT


def plot(args, lenlog, sigma_ws, sigma_bs, etas, lenlog2=None):
    """Produce every figure family named in ``args.plots``.

    Recognised names are 'theory', 'scatter', 'iter', 'layer' and 'heatmap'.
    'theory' only computes the reference passed on to the line plots; 'iter'
    and 'layer' both trigger the line plots. When ``args.debug`` is truthy an
    ipdb breakpoint is opened before anything is drawn.

    Args:
        args: Options object; ``plots`` and ``debug`` are read here and the
            object is forwarded to the plotting helpers.
        lenlog: Result object to plot.
        sigma_ws: sigma_w values used for the results.
        sigma_bs: sigma_b values used for the results.
        etas: eta values used for the results.
        lenlog2: Optional second result object, forwarded to the line plots.
    """
    requested = args.plots
    if args.debug:
        import ipdb
        ipdb.set_trace()

    # The theory reference is optional; the line plots receive None without it.
    theory = get_p0_thm(args, sigma_ws, sigma_bs, etas) if 'theory' in requested else None

    if 'scatter' in requested:
        scatter_plots(lenlog, args, sigma_ws)

    if any(kind in requested for kind in ('iter', 'layer')):
        line_plots(args, sigma_ws, sigma_bs, etas, lenlog, lenlog2, theory)

    if 'heatmap' in requested:
        heatmap_plots(lenlog, args, sigma_ws, sigma_bs)


if __name__ == '__main__':
    # Script entry point: parse options, build the parameter grids, collect
    # the results with main() and draw them with plot().
    print(datetime.now())
    args = get_general_args()
    init_wandb(args)
    sigma_ws = get_sigmas(args, args.min_val_sw, args.step_val_sw)
    # sigma_b is only swept when a heatmap is requested; otherwise the single
    # configured value is used.
    if 'heatmap' in args.plots:
        sigma_bs = get_sigmas(args, args.min_val_sb, args.step_val_sb)
    else:
        sigma_bs = [args.sigma_b]
    etas = get_etas(args, args.min_val_e, args.step_val_e)
    if args.pos:
        # Hard-coded grids for named experiments override the ones above.
        if 'iter_sb' in args.exp:
            sigma_ws = [1.0]
            sigma_bs = [0.185, 1.0, 5.4]
        elif 'fig2' in args.exp or 'fig5' in args.exp:
            sigma_ws = [0.185, 1.0, 5.4]
            etas = [0.05]
        # NOTE(review): this branch used to also test 'fig5', which could
        # never match because the branch above already handles it. The dead
        # condition was dropped; confirm which grid 'fig5' should really use.
        elif 'fig3' in args.exp:
            sigma_ws = [0.185, 0.43, 1.0]
            etas = [0.05]
        elif 'explode' in args.exp:
            sigma_ws = [1.0, 2.0, 4.0, 8.0]
            etas = [0.01, 0.05, 0.2, 0.45]
    lenlog = main(args, sigma_ws, sigma_bs, etas)
    # double plot
    lenlog2 = None
    # args.method = 'pcd' # flag
    # args.w_reg = True; args.b_reg = True; args.z_reg = True; args.reg_coef = 10.0
    # lenlog2 = main(args, sigma_ws, sigma_bs, etas)
    plot(args, lenlog, sigma_ws, sigma_bs, etas, lenlog2)
