import os
import sys
import git
import math
import yaml
import numpy as np
import torch
import argparse
from datetime import datetime
from torchvision import models as torchvision_models
from data.dl_getter import DATASETS, n_cls, sh, input_range
from tool.util import set_seed, set_exp, set_fig_dir


# Alphabetically sorted names of every lower-case, non-dunder, callable
# attribute exposed by torchvision.models (presumably the model constructors).
torchvision_archs = sorted(
    attr_name
    for attr_name, attr_value in vars(torchvision_models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)


def get_general_args(parser=None, is_nb=False):
    """Build the complete experiment configuration.

    Runs the two configuration stages back to back: ``pre_processing_args``
    registers and parses the command-line options, and ``post_processing_args``
    merges the per-dataset YAML config and derives the remaining fields.

    Args:
        parser: optional ``argparse.ArgumentParser`` to extend; a new one is
            created when ``None``.
        is_nb: forwarded unchanged to ``pre_processing_args``.

    Returns:
        The fully populated ``argparse.Namespace``.
    """
    return post_processing_args(pre_processing_args(parser, is_nb))

def pre_processing_args(parser=None, is_nb=False):
    """Register every command-line option and parse ``sys.argv``.

    Args:
        parser: optional ``argparse.ArgumentParser`` to add the options to; a
            new parser named 'Evaluation' is created when ``None``.
        is_nb: accepted for API compatibility but not used here; arguments
            are always parsed from ``sys.argv``.
            NOTE(review): presumably meant for notebook use, where ``sys.argv``
            holds kernel arguments — confirm the intended behavior.

    Returns:
        The parsed ``argparse.Namespace`` (before ``post_processing_args``).
    """
    if parser is None:
        parser = argparse.ArgumentParser('Evaluation')
    parser.add_argument('--version', default='1.0', type=str, help='version')
    parser.add_argument('--proj_name', default='bml_pc', type=str, help='project name')
    parser.add_argument('--exp', default='', type=str, help='experiment name')
    parser.add_argument('--tags', type=str, default='', metavar='N',
                        help='')
    parser.add_argument('--output_dir', default=os.path.expanduser("~/exp.log/"),
                        help='Path to save logs and chkpts')
    parser.add_argument('--exp_load', default='', type=str, help='experiment name')
    parser.add_argument('--wandb_dir', default=os.path.expanduser("~/wandb.log/"),
                        help='Path to save logs and chkpts')
    parser.add_argument('--wandb_entity', default="bml_pc",
                        help='wandb entity')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='debug mode or not')
    parser.add_argument('--nosha', action='store_true', default=False,
                        help='do not record the git commit sha')

    # special option for debug
    parser.add_argument('--orthogonal_testing', action='store_true', default=False, help='orthogonal initializing of weight')

    # method
    parser.add_argument('--method', default='pc', type=str, choices=['ce',
                        'at', 'ae', 'acls', 'pc', 'pcd', 'pcl', 'pcls', 'cls'], help="method name")
    parser.add_argument('--arch', default='fc', type=str, choices=['ae',
                        'resnet', 'fc', 'cnn'], help='Architecture')
    parser.add_argument("--chkpt_key", default="teacher", type=str, help='Key to use in the chkpt (example: "teacher")')
    parser.add_argument('--epochs', default=20, type=int, help='Number of epochs of training.')
    parser.add_argument("--optimizer", choices=["adam", "sgd"], default="adam")
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate.""")
    parser.add_argument("--lr_decay_epochs", nargs="+", type=int, default=[60, 120, 160],
                        help="decay learning rate by decay_rate at these epochs")
    parser.add_argument("--lr_decay_rate", type=float, default=.2,
                        help="learning rate decay multiplier")
    parser.add_argument("--warmup_iters", type=int, default=1000,
                        help="number of iters to linearly increase learning rate, if -1 then no warmup")
    parser.add_argument("--warmup_epochs", type=int, default=5,
                        help="number of epochs to linearly increase learning rate, if -1 then no warmup")
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')
    parser.add_argument('--save_freq', default=5, type=int, help='save frequency')
    # fig 1a: 20, fig 1b: 100, fig 2b: 500, fig 4: 256 (unofficial)
    parser.add_argument('--bsz', default=128, type=int, help='batch-size')
    parser.add_argument('--bsz_vl', default=1024, type=int, help='validation batch-size')
    parser.add_argument('--dataset', default='random', type=str, help='data option for stability analysis',
                        choices=['mnist', 'cifar10', 'fmnist', 'svhn', 'random', 'o2c', 'complex'])

    # ae
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluation mode')
    parser.add_argument('--pred_enc_dec', action='store_true', default=False,
                        help='prediction based on encoding and decoding')
    # NOTE(review): help text duplicates --pred_enc_dec; confirm what update_x controls.
    parser.add_argument('--update_x', action='store_true', default=False,
                        help='prediction based on encoding and decoding')
    # pc
    parser.add_argument('--cls_mth', default='none', type=str, choices=['z2y', 'y2z', 'emb', 'none'],
                        help='Classification method')
    parser.add_argument('--vl_corrupt_ns', action='store_true', default=False,
                        help='corrupt validation set')

    # stability analysis
    parser.add_argument('--act', default='linear', type=str, help='activation function',
                        choices=('relu', 'selu', 'tanh', 'leaky_relu', 'silu', 'linear', 'logsig'))
    parser.add_argument('--T', default=10, type=int, help='Number of inner steps')  # 200, 350 for stability analysis
    parser.add_argument('--T_vl', default=300, type=int, help='Number of inner steps for validation')  # 200, 350 for stability analysis
    parser.add_argument('--eta', default=0.1, type=float, help='Step size for inner steps')
    parser.add_argument('--eta_r', default=10., type=float, help='min value of run conditions eta')
    parser.add_argument('--sigma_w', default=1.0, type=float, help='gain for xavier initialization')
    parser.add_argument('--sigma_b', default=0.1, type=float,
                        help='std of bias for initialization, 0.05 for grad anal, 0.3 for expr anal')
    parser.add_argument('--run_cond', default='sigma_w', type=str,
                        choices=['sigma_w', 'T', 'eta', 'iter'],
                        help='varying condition for stability analysis')
    parser.add_argument('--n_layers', default=4, type=int, help='Number of layers')
    parser.add_argument('--latent_dim', default=100, type=int, help='Dimension of latent space')
    # `tensorized` defaults to True; a plain store_true flag could never turn it
    # off, so --not_tensorized writes False into the same destination.
    parser.add_argument('--tensorized', action='store_true', default=True,
                        help='use tensorized version to compute length (default: on)')
    parser.add_argument('--not_tensorized', dest='tensorized', action='store_false', default=True,
                        help='disable the tensorized version used to compute length')
    # experimental conditions
    parser.add_argument('--crop', default=0, type=int, help='crop x-axis during plotting')
    parser.add_argument('--label', action='store_true', default=False, help='optionally plot label and title')
    parser.add_argument('--min_val_T', default=50, type=int, help='min value run conditions T')
    parser.add_argument('--min_val_e', default=0.01, type=float, help='min value run conditions eta')
    parser.add_argument('--min_val_sw', default=1.0, type=float, help='min value of run conditions sigma_w')
    parser.add_argument('--min_val_sb', default=0.2, type=float, help='min value of run conditions sigma_b')
    parser.add_argument('--step_val_e', default=0.04, type=float, help='step value of run conditions eta')
    parser.add_argument('--step_val_sw', default=2.0, type=float, help='step value of run conditions sigma_w')
    parser.add_argument('--step_val_sb', default=0.2, type=float, help='step value of run conditions sigma_b')
    parser.add_argument('--step_exp', action='store_true', default=False, help='exponential increase of run conditions')
    parser.add_argument('--n_conds', default=1, type=int, help='number of conditions')  # 41
    parser.add_argument('--n_conds_e', default=3, type=int, help='number of conditions')  # 41
    parser.add_argument('--n_runs', default=1, type=int, help='number of runs for stability analysis')  # 7
    parser.add_argument('--heatmap', default=False, action='store_true', help='plot heatmap')
    parser.add_argument('--scatter', default=False, action='store_true', help='plot scatter')
    parser.add_argument('--theory', default=False, action='store_true', help='plot theory')
    parser.add_argument('--o2c_idx', default=1, type=int, help='index for number of layers used for generating o2c')
    parser.add_argument('--v_check', action='store_true', default=False, help='check for stability analysis')
    parser.add_argument('--not_iter_log', action='store_true', default=False, help='disable logging during z update iterations')
    # store_false: `flip` is True unless this flag is passed.
    parser.add_argument('--flip', action='store_false', default=True,
                        help='disable flip (flip is on by default: activate ahead of forward)')
    parser.add_argument('--exp_opt', default='pc', type=str, help='experiment option for stability analysis',
                        choices=['torch', 'keras', 'pc', 'None'])
    parser.add_argument('--train', action='store_true', default=False, help='train model and do stability analysis')
    # `not_use_y` defaults to True; --use_y is the only way to switch it off.
    parser.add_argument('--not_use_y', action='store_true', default=True,
                        help='not using y during validation (default: on)')
    parser.add_argument('--use_y', dest='not_use_y', action='store_false', default=True,
                        help='use y during validation')
    parser.add_argument('--last_cls', action='store_true', default=False, help='CE loss for last layer when doing cls task')
    # store_false: `verbose` is True unless this flag is passed.
    parser.add_argument('--verbose', action='store_false', default=True,
                        help='disable verbose output (verbose is on by default)')
    parser.add_argument('--do_drop', action='store_true', default=False, help='do dropout')
    parser.add_argument('--dropout', default=0.1, type=float, help='dropout rate')
    parser.add_argument('--skip_con', action='store_true', default=False, help='use skip connection')
    parser.add_argument('--not_use_db', action='store_true', default=False, help='not using db')
    parser.add_argument('--z_hard_norm', action='store_true', default=False, help='hard normalization of z')
    parser.add_argument('--step_eta', action='store_true', default=False, help='step eta')
    parser.add_argument('--comp_eta', action='store_true', default=False,
                        help='with --train, disable step_eta (step_eta = not comp_eta)')
    parser.add_argument('--step_inf_eta', default=2.0, type=float, help='step value of eta during inference')
    parser.add_argument('--eta_cnt', default=5, type=int, help='count eta')
    parser.add_argument('--inf_tol', default=1e-6, type=float, help='tolerance value of eta during inference')
    parser.add_argument('--step_T', action='store_true', default=False, help='step T')
    parser.add_argument('--loss_sum', action='store_true', default=False, help='sum over the batch of losses')
    parser.add_argument('--z_init', default='gaussian', type=str, help='initialization of z',
                        choices=['use_db', 'gaussian', 'ff'])
    parser.add_argument('--param_init', default='normal', type=str, help='initialization of parameters',
                        choices=['normal', 'uniform'])
    parser.add_argument('--act_first', action='store_true', default=False, help='activation first')
    parser.add_argument('--smooth_label', action='store_true', default=False, help='label smoothing')
    parser.add_argument('--alpha_sm', default=0.005, type=float, help='smoothing ratio')
    parser.add_argument('--gain', default=1.0, type=float, help='gain ratio')
    parser.add_argument('--test_sw', action='store_true', default=False, help='test sigma_w')
    # stability method
    parser.add_argument('--not_prg', action='store_true', default=False, help='not using progress')
    parser.add_argument('--skip_bias', action='store_true', default=False, help='skip bias')
    parser.add_argument('--z_norm', action='store_true', default=False, help='normalization of z')
    parser.add_argument('--vec_min', default=0.9, type=float, help='min clip of z')
    parser.add_argument('--vec_max', default=1.0, type=float, help='max clip of z')
    parser.add_argument('--alpha_mean', action='store_true', default=False, help='mean over batch during normalizations')
    parser.add_argument('--w_norm', action='store_true', default=False, help='normalization of w')
    parser.add_argument('--w_orth', action='store_true', default=False, help='orthogonalization of w grad')
    parser.add_argument('--b_norm', action='store_true', default=False, help='normalization of b')
    parser.add_argument('--z_reg', action='store_true', default=False, help='loss regularized by z')
    parser.add_argument('--w_reg', action='store_true', default=False, help='loss regularized by w')
    parser.add_argument('--b_reg', action='store_true', default=False, help='loss regularized by b')
    parser.add_argument('--reg_coef', default=1.0, type=float, help='regularization coefficient')
    parser.add_argument('--pos', action='store_true', default=False, help='True: plot all sigmas')
    parser.add_argument('--len_all', action='store_true', default=False, help='True: logging thru all epoch, False: only last epoch')

    args = parser.parse_args()

    return args

def _load_dataset_config(args):
    """Copy every entry of ``config/<dataset>.yaml`` onto ``args`` as attributes.

    The YAML must be a two-level mapping (section -> {param_name: value}); the
    section names are ignored and every inner key overwrites/creates the
    attribute of the same name. The path is relative to the current working
    directory. An empty file means "no overrides".
    """
    filename_dt = 'config/{}.yaml'.format(args.dataset)
    with open(filename_dt, 'r', encoding='utf-8') as f:
        config_dt = yaml.safe_load(f)
    # yaml.safe_load returns None for an empty document.
    for common_type in (config_dt or {}).values():
        for param_name, param_value in common_type.items():
            setattr(args, param_name, param_value)


def _set_layer_dims(args):
    """Set ``sh``, ``ds`` (per-layer widths) and ``lyr`` (layer types) on ``args``.

    For 'mnist', 'cifar10', 'svhn' and 'fmnist' the first width comes from the
    dataset shape and the last from its class count (classification enabled).
    Any other dataset uses ``latent_dim`` for every layer with classification
    disabled. Other architectures on the image datasets leave ``ds`` untouched.
    """
    arch = args.arch
    if args.dataset in ['mnist', 'cifar10', 'svhn', 'fmnist']:
        args.input_range = input_range[args.dataset]
        args.sh = sh[args.dataset]
        args.cls = True
        args.num_labels = args.n_cls = n_cls[args.dataset]
        if arch == 'fc':
            hidden = [args.latent_dim] * (args.n_layers - 2)
            args.ds = [int(np.prod(args.sh))] + hidden + [n_cls[args.dataset]]
            args.lyr = ['fc'] * (args.n_layers - 1)
        elif arch == 'cnn':
            assert args.n_layers == len(args.lyr) + 1, \
                'n_layers must equal len(lyr) + 1'
            assert len(args.eks) == len([l for l in args.lyr if l == 'cnn']), \
                'need one eks entry per cnn layer in lyr'
            # Spatial size after each conv layer, using the standard
            # (d + 2*p - k) // s + 1 formula; presumably eps/eks/ess are the
            # per-layer padding/kernel/stride and sh[1] the input height.
            ds = [args.sh[1]]
            for i in range(len(args.eks)):
                ds.append((ds[-1] + args.eps[i] * 2 - args.eks[i]) // args.ess[i] + 1)
            # Flattened width = d*d spatial positions times channel count
            # (ecs[0] is the input channel count prepended by the caller).
            ds = [(d ** 2) * args.ecs[i] for i, d in enumerate(ds)]
            hidden = [args.latent_dim] * (args.n_layers - len(args.eks) - 2)
            args.ds = ds + hidden + [n_cls[args.dataset]]
    else:
        assert args.T in [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000], \
            'unsupported T for stability analysis: {}'.format(args.T)
        args.sh = [args.latent_dim]
        args.cls = False
        args.ds = [args.latent_dim] * args.n_layers
        args.lyr = ['fc'] * (args.n_layers - 1)
        args.ds_o = [args.latent_dim] * args.o2c_idx


def post_processing_args(args):
    """Derive the full experiment configuration from parsed CLI arguments.

    Steps, in order (later steps see the results of earlier ones):
      1. training-mode adjustments (``not_iter_log``, ``step_eta``,
         ``skip_theory``; ``epochs`` forced to 1 when not training);
      2. merge ``config/<dataset>.yaml`` onto ``args`` (may override any CLI value);
      3. derive input channels, method/architecture flags and layer widths;
      4. plotting options, warmup length and learning-rate decay milestones;
      5. runtime fields (device, timestamps, ...), then ``set_exp``,
         ``set_seed`` and ``set_fig_dir``.

    Args:
        args: namespace returned by ``pre_processing_args``; mutated in place.

    Returns:
        The same ``args`` object.

    Raises:
        FileNotFoundError: if ``config/<dataset>.yaml`` does not exist.
        AssertionError: if the layer configuration is inconsistent.
    """
    if args.train:
        args.not_iter_log = True
        args.step_eta = not args.comp_eta
    args.skip_theory = not args.theory
    if not args.train:
        args.epochs = 1

    _load_dataset_config(args)

    args.in_ch = 3 if args.dataset in ['cifar10', 'svhn', 'stl10', 'cifar10H', 'cifar100', 'celeba', 'imagenet'] else 1
    args.reloaded = False

    mth = args.method
    args.pc = mth in ['pc', 'pcls', 'pcd', 'pcl']
    args.pcd = mth == 'pcd'
    args.pcl = mth == 'pcl'
    args.ae = mth in ['ae', 'acls']
    args.rec = mth in ['ae', 'pc']  # acls, pcls might also be added
    args.cls = mth in ['cls', 'acls', 'pcls']  # may be overridden by _set_layer_dims

    args.cnn = args.arch in ['cnn', 'resnet']
    if args.arch == 'cnn':
        # ecs/dcs come from the YAML config; add the image channels at the ends.
        args.ecs = [args.in_ch] + args.ecs
        args.dcs = args.dcs + [args.in_ch]

    # stability analysis
    _set_layer_dims(args)
    args.name = 'PC' if args.pc else 'FF'
    args.run_min = args.min_val_sw
    args.step_run = args.step_val_sw
    args.stability = True
    args.prg = not args.not_prg
    # options: 'scatter', 'layer', 'iter', 'heatmap', 'theory'
    if args.heatmap:
        args.plots = ['heatmap']
    elif args.scatter:
        args.plots = ['scatter']
    else:
        args.plots = ['layer', 'iter']
    if args.theory:
        args.plots.append('theory')
    args.warmup_epochs = int(args.epochs * 0.03)
    # Decay at 30% / 60% / 80% of training. Truncate the product (not the epoch
    # count) so milestones are ints, matching --lr_decay_epochs' type=int.
    args.lr_decay_epochs = [int(args.epochs * frac) for frac in (0.3, 0.6, 0.8)]

    # added args.device, args.print_freq, args.head_eval_clip, args.at
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.print_freq = 100
    args.head_eval_clip = None
    # Substring test kept from the original; among the --method choices only
    # 'at' contains it.
    args.at = 'at' in args.method

    # Git sha recording is currently disabled.
    args.sha = ''
    args.st_time = datetime.today().strftime("%y%m%d_%H%M_%S")
    set_exp(args)
    if args.debug:
        # NOTE(review): applied after warmup/decay epochs were computed, so
        # those still reflect the pre-debug epoch count — confirm intended.
        args.epochs = 1
    args.load_path = None
    set_seed(args.seed)
    set_fig_dir(args)
    return args
