import os
import gc
import sys
import time
import torch
import wandb
import pathlib
import numpy as np
from anal.pc import AnalPC
from anal.ff import AnalFF
from anal.log import LenLog, LenT
from datetime import datetime
from train.mlbase import MLBase
from os import listdir
from os.path import isfile, join
from tool.args import get_general_args
from anal.pmap_th import get_pmaps_theory, get_p0_thm
from anal.util import get_sigmas, get_etas, transpose_np as tr
from plot.maps import line_plots, scatter_plots, heatmap_plots
from tool.util  import init_wandb, set_seed, get_config, load_pickle
import torch.backends.cudnn as cudnn
# Keep cuDNN on and let it autotune convolution algorithms
# (see torch.backends.cudnn documentation for `benchmark`).
cudnn.enabled = True
cudnn.benchmark = True


# Result/log directories are namespaced by this version (log/<version>/...).
version = '1.1'


def main(args, sigma_ws, sigma_bs, etas):
    """Compute (or load) results for every grid point and bundle them in a LenT.

    The grid is ``range(args.n_runs) x sigma_ws x sigma_bs x etas``. Each point
    is identified by the string key ``"i_run,sigma_w,sigma_b,eta"``. Points
    with a stored pickle are loaded. All other points are computed by calling
    the analysis object (``AnalPC`` if ``args.pc`` else ``AnalFF``). When
    ``args.train`` is set, every point is recomputed.

    Side effects: sets ``args.log_dir`` to ``log/<version>/<config>`` and
    creates that directory.

    Args:
        args: parsed command-line arguments (from ``get_general_args``).
        sigma_ws: values of sigma_w to sweep.
        sigma_bs: values of sigma_b to sweep.
        etas: values of eta to sweep.

    Returns:
        A ``LenT`` built from arrays reshaped to
        ``(args.n_runs, len(sigma_ws), len(sigma_bs), len(etas), ...)``. Each
        axis follows the order of the corresponding value list passed in.
    """
    cAnal = AnalPC if args.pc else AnalFF
    n_r, n_w, n_b, n_e = args.n_runs, len(sigma_ws), len(sigma_bs), len(etas)
    anal = cAnal(MLBase(args))

    args.log_dir = join('log/', version, get_config(args))
    pathlib.Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    # key -> path of a previously pickled result. Loading from disk is
    # currently disabled (commented block below), so this stays empty.
    stored_path_dt = dict()
    # if args.orthogonal_testing != True:
    #     for T in listdir(args.log_dir):
    #         if int(T) >= args.T:
    #             log_dir_T = join(args.log_dir, T)
    #             for key_f in listdir(log_dir_T):
    #                 log_path = join(log_dir_T, key_f)
    #                 if isfile(log_path) and key_f not in stored_path_dt.keys():
    #                     stored_path_dt[key_f] = log_path

    # Keys in the same nested (run, sigma_w, sigma_b, eta) order that the
    # reshape below assumes. Kept as a list so that assembling the results
    # never depends on sorting or re-parsing the string keys.
    grid_keys = [f"{i_run},{sigma_w},{sigma_b},{eta}"
                 for i_run in range(n_r)
                 for sigma_w in sigma_ws
                 for sigma_b in sigma_bs
                 for eta in etas]
    to_get_result_dt = dict.fromkeys(grid_keys)

    # Reuse stored results unless everything is being retrained anyway.
    # In that case loading would be wasted work, because every point is
    # recomputed and overwritten below.
    loaded_keys = set()
    if not args.train:
        for key_f in to_get_result_dt:
            if key_f in stored_path_dt:
                print(f'loading {key_f}')
                to_get_result_dt[key_f] = load_pickle(args, stored_path_dt[key_f])
                loaded_keys.add(key_f)
                gc.collect()

    # Run in grid order (dicts preserve insertion order) rather than over a
    # set. The iteration order of a set of str keys varies with hash
    # randomization, so the run order would change between invocations.
    for key in to_get_result_dt:
        if key in loaded_keys:
            continue
        i_run, sigma_w, sigma_b, eta = key.split(',')
        set_seed(args.seed + int(i_run))
        lenlog = anal(int(i_run), float(sigma_w), float(sigma_b), float(eta))
        # anal.save_model(0, args.util_best_acc)
        to_get_result_dt[key] = lenlog.get_len_lst()

    # Assemble by looking up the original keys in grid order. Parsing the keys
    # back into numbers, sorting, and re-stringifying them is fragile:
    #   - it raises KeyError when a value's text changes in the round trip
    #     (e.g. "1" -> "1.0");
    #   - it breaks the reshape when a value list has duplicates;
    #   - it misaligns the axes against the lists given to LenT whenever those
    #     lists are not ascending.
    ordered_lst = [to_get_result_dt[key] for key in grid_keys]
    # NOTE(review): assumes tr() returns one array per logged quantity with the
    # grid points on axis 0 -- confirm against anal.util.transpose_np.
    len_np_lst = [len_lst.reshape(n_r, n_w, n_b, n_e, *len_lst.shape[1:])
                  for len_lst in tr(ordered_lst)]

    return LenT(len_np_lst, args, n_r, sigma_ws, sigma_bs, etas)


def plot(args, lenlog, sigma_ws, sigma_bs, etas):
    """Draw every figure kind listed in ``args.plots``.

    Recognised entries are 'theory', 'scatter', 'iter', 'layer' and 'heatmap'.
    'theory' draws nothing by itself. It computes the theoretical curves via
    ``get_p0_thm``, which are then handed to ``line_plots`` (used by
    'iter'/'layer').
    """
    requested = args.plots
    if 'theory' in requested:
        lenthm = get_p0_thm(args, sigma_ws, sigma_bs, etas)
    else:
        lenthm = None

    if 'scatter' in requested:
        scatter_plots(lenlog, args, sigma_ws)

    wants_lines = any(kind in requested for kind in ('iter', 'layer'))
    if wants_lines:
        line_plots(args, sigma_ws, sigma_bs, etas, lenlog, lenthm=lenthm)

    if 'heatmap' in requested:
        heatmap_plots(lenlog, args, sigma_ws, sigma_bs)


if __name__ == '__main__':
    # Script entry point: parse arguments, build the sweep grids, run, plot.
    print(datetime.now())
    args = get_general_args()
    init_wandb(args)

    sigma_ws = get_sigmas(args, args.min_val_sw, args.step_val_sw)
    # Hand-picked alternative grids, kept for reference:
    # sigma_ws = [0.1, 0.139, 0.193, 0.268, 0.373, 0.518, 0.72, 1.0, 1.389, 1.931, 2.683, 3.728, 5.179, 7.197, 10.0]
    # if not args.train: sigma_ws = [1.0, 3.728, 7.197]
    if 'heatmap' in args.plots:
        # heatmap_plots consumes both sigma_ws and sigma_bs, so sweep sigma_b too.
        sigma_bs = get_sigmas(args, args.min_val_sb, args.step_val_sb)
    else:
        sigma_bs = [args.sigma_b]
    etas = get_etas(args, args.min_val_e, args.step_val_e)

    len_t = main(args, sigma_ws, sigma_bs, etas)
    plot(args, len_t, sigma_ws, sigma_bs, etas)
