import argparse
import os
import torch

import numpy as np
import torch
import random
import pdb

import re
import yaml

import shutil

import torch.backends.cudnn as cudnn
import torch.distributed as dist
from configs.get_configuration import get_config

from datetime import datetime

import torch.cuda.amp as amp

class Namespace(object):
    """Attribute-style view over a (possibly nested) dictionary.

    Every key of ``somedict`` becomes an instance attribute. Values that are
    themselves dictionaries are wrapped recursively in ``Namespace``; all
    other values (including lists that contain dictionaries) are stored
    unchanged.
    """

    def __init__(self, somedict):
        attrs = vars(self)
        for name, entry in somedict.items():
            # Keys must be strings whose first character is a letter, '_' or '-'.
            assert isinstance(name, str) and re.match("[A-Za-z_-]", name)
            attrs[name] = Namespace(entry) if isinstance(entry, dict) else entry

    def __getattr__(self, attribute):
        # Only invoked when normal attribute lookup fails, i.e. the key was
        # never provided; raising AttributeError keeps hasattr()/getattr()
        # with a default working as usual.
        raise AttributeError(f'Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!')


def set_deterministic(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed applied to every generator. When ``None`` (the default
            used by callers that do not want reproducibility) nothing is
            seeded and the cuDNN flags are left untouched.
    """
    # Identity check: ``!= None`` would dispatch to a user-defined __ne__.
    if seed is None:
        return

    print(f"Deterministic with seed = {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Deterministic kernels on, autotuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def set_distributed(config):
    """Initialise the NCCL process group and seed RNGs per rank.

    Reads ``config.AMP_OPT_LEVEL``, ``config.LOCAL_RANK`` and ``config.SEED``.
    ``RANK`` and ``WORLD_SIZE`` are taken from the environment when both are
    set; otherwise ``-1`` is passed for both to ``init_process_group``.
    """
    if config.AMP_OPT_LEVEL != "O0":
        assert amp != None, "amp not installed!"

    env = os.environ
    launched_with_env = 'RANK' in env and 'WORLD_SIZE' in env
    if launched_with_env:
        rank, world_size = int(env["RANK"]), int(env['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = world_size = -1

    # Bind this process to its GPU before creating the process group.
    torch.cuda.set_device(config.LOCAL_RANK)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank,
    )
    dist.barrier()

    # Offset the base seed by the global rank so each process differs.
    per_rank_seed = config.SEED + dist.get_rank()
    torch.manual_seed(per_rank_seed)
    np.random.seed(per_rank_seed)
    cudnn.benchmark = True

def _dot_to_p(value):
    """Return ``str(value)`` with every '.' replaced by 'p' (path-friendly)."""
    return str(value).replace('.', 'p')


def _build_arg_parser():
    """Create the command-line parser for the LAPS pre-training script."""
    parser = argparse.ArgumentParser('LAPS pre-training script', add_help=False)
    # basics
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--tag', type=str, default='', help='tag of experiment')

    # distributed training
    parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')

    parser.add_argument('-c', '--config-file', required=True, type=str, metavar="FILE", help="path to yaml file")
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--debug_subset_size', type=int, default=8)
    parser.add_argument('--data_dir', type=str, default='../data')
    parser.add_argument('--log_dir', type=str, default='../logs')
    parser.add_argument('--ckpt_dir', type=str, default='../cache/')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--eval_from', type=str, default=None)
    parser.add_argument('--hide_progress', action='store_true')
    parser.add_argument('--validation', action='store_true', help='Test on the validation set')

    parser.add_argument('--run', type=int, default=0, help='run')
    parser.add_argument('--few_shot', action='store_true')
    parser.add_argument('--data_size_per_task', type=int, default=0)
    # NOTE(review): the default 'resnet18' is not one of the accepted choices;
    # confirm whether the default or the choices list should change.
    parser.add_argument('--backbone', type=str, default='resnet18',
                        choices=['simmim', 'simsiam'])
    # NOTE(review): the run-name logic in get_args also handles 'APD' and
    # 'LAPS', which are not selectable here; presumably they are set through
    # a top-level `cl_model` key in the config file -- confirm.
    parser.add_argument('--cl_model', type=str, default='finetune',
                        choices=['finetune', 'SI', 'DER', 'PNN', 'UnsupNaive'])

    parser.add_argument('-lr', '--init_lr', type=float, default=0.1)
    parser.add_argument('-lrd', '--final_lr_decay', type=float, default=1e-4)
    parser.add_argument('-eps', '--num_epochs', type=int, default=30)
    parser.add_argument('-bs', '--batch_size', type=int, default=128)
    parser.add_argument('-opt', '--optimizer', type=str, default='sgd')
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.0005)
    parser.add_argument('-mm', '--momentum', type=float, default=0.9)
    parser.add_argument('-ltype', '--laps_type', type=str)
    parser.add_argument('-ropt', '--reinit_opt_per_task', action='store_true')
    parser.add_argument('--resume_path', type=str)
    parser.add_argument('--pretrained', type=str)
    parser.add_argument('--eval_task', type=int)
    parser.add_argument('-lmps', '--laps_mask_patch_size', type=int)
    parser.add_argument('-lhyps', '--laps_hyp', type=float)
    parser.add_argument('-lins', '--laps_inner_steps', type=int)
    parser.add_argument('--wandb', action='store_true')
    return parser


def get_args():
    """Parse CLI options, merge the YAML config and prepare the run.

    Top-level keys of the YAML file given by ``--config-file`` are copied
    onto the parsed arguments (nested mappings become ``Namespace`` objects)
    and override CLI values of the same name. Selected CLI options are then
    pushed into the nested config, ``args.name`` is extended with a suffix
    describing the run, and ``args.log_dir`` / ``args.ckpt_dir`` are turned
    into timestamped, run-specific directories.

    Side effects: creates the log and checkpoint directories, copies the
    config file into the log directory and calls ``set_deterministic`` with
    ``args.seed``.

    Returns:
        argparse.Namespace: merged arguments, additionally carrying
        ``aug_kwargs``, ``dataset_kwargs`` and ``dataloader_kwargs``.

    Raises:
        AttributeError: if the config lacks a section or key read here
            (e.g. ``name``, ``seed``, ``train``, ``eval``, ``model``,
            ``dataset``).
        AssertionError: if ``log_dir``, ``data_dir``, ``ckpt_dir`` or
            ``name`` is ``None``.
    """
    args = _build_arg_parser().parse_args()

    with open(args.config_file, 'r') as f:
        # NOTE(review): consider yaml.safe_load if config files can come
        # from untrusted sources.
        config = Namespace(yaml.load(f, Loader=yaml.FullLoader))
    # Config keys win over CLI values of the same name.
    vars(args).update(vars(config))

    if args.debug:
        if args.train:
            args.batch_size = 2
            args.num_epochs = 1
        if args.eval:
            args.eval.batch_size = 2
        args.dataset.num_workers = 0

    assert None not in [args.log_dir, args.data_dir, args.ckpt_dir, args.name]

    # A CLI value overrides the config; when the flag is omitted the config
    # value is kept. The attribute is still created (as None) if the config
    # does not define it, so later reads never fail.
    if args.resume_path is not None or not hasattr(args.model, 'resume'):
        args.model.resume = args.resume_path
    if args.eval_task is not None or not hasattr(args.eval, 'target_task'):
        args.eval.target_task = args.eval_task

    # Training hyper-parameters from the config replace the CLI defaults;
    # float() is applied to the values that are formatted or used numerically.
    if hasattr(args.train, 'base_lr'):
        args.init_lr = float(args.train.base_lr)
    if hasattr(args.train, 'optimizer'):
        args.optimizer = args.train.optimizer
    if hasattr(args.train, 'weight_decay'):
        args.weight_decay = float(args.train.weight_decay)
    if hasattr(args.train, 'momentum'):
        args.momentum = float(args.train.momentum)

    if hasattr(args.train, 'warmup_lr'):
        args.train.warmup_lr = float(args.train.warmup_lr)
    if hasattr(args.train, 'min_lr'):
        args.train.min_lr = float(args.train.min_lr)
        lrd = '_mlr%s' % _dot_to_p(args.train.min_lr)
    else:
        lrd = '_lrd%s' % _dot_to_p(args.final_lr_decay)

    # LAPS overrides: only touch the config when a CLI value was supplied,
    # so configs without a LAPS section keep working.
    if args.laps_type is not None:
        args.hyperparameters.LAPS.type = args.laps_type
    if args.laps_mask_patch_size is not None:
        args.hyperparameters.LAPS.mask_patch_size = args.laps_mask_patch_size
    if args.laps_hyp is not None:
        args.hyperparameters.LAPS.att_hyp = args.laps_hyp
    if args.laps_inner_steps is not None:
        args.hyperparameters.LAPS.inner_steps = args.laps_inner_steps

    # NOTE(review): `pnn_base_widths` is not a CLI option, so it must be a
    # top-level key of the config file whenever cl_model is 'PNN'.
    if args.cl_model == 'PNN' and args.pnn_base_widths != 64:
        cl_model = '_PNN%d' % args.hyperparameters.PNN.exp_width
    elif args.cl_model == 'SI':
        cl_model = '_SI%s' % _dot_to_p(args.hyperparameters.SI.reg_hyp)
    elif args.cl_model == 'APD':
        cl_model = '_APD_sp%s_ret%s' % (
            _dot_to_p(args.hyperparameters.APD.l1_hyp),
            _dot_to_p(args.hyperparameters.APD.ret_hyp),
        )
    elif args.cl_model == 'LAPS':
        cl_model = '_LAPS_%s' % args.hyperparameters.LAPS.type
    else:
        cl_model = '_%s' % args.cl_model

    tag = '_' + args.tag if args.tag != '' else args.tag

    if args.model.type == 'vit':
        args.dataset.image_size = 224

    if args.wandb:
        args.logger.wandb = True

    bs_info = '_lr%s' % _dot_to_p(args.init_lr) + lrd
    opt_info = '_%s_wd%s' % (args.optimizer, _dot_to_p(args.weight_decay))

    if args.reinit_opt_per_task:
        args.name += '_ropt'

    args.name += cl_model + tag + bs_info + opt_info
    # NOTE(review): there is no separator before 'warm'; kept as in the
    # existing naming scheme.
    args.name += 'warm%d_ep%d_bs%d' % (args.train.warmup_epochs, args.num_epochs, args.batch_size)

    if args.few_shot:
        args.name += '_fs%s' % args.data_size_per_task

    args.name += '_run_' + str(args.run)

    # One timestamp for both directories so their prefixes always match.
    stamp = datetime.now().strftime('%m%d%H%M%S_')
    args.log_dir = os.path.join(args.log_dir, 'in-progress_' + stamp + args.name)
    args.ckpt_dir = os.path.join(args.ckpt_dir, stamp + args.name)

    os.makedirs(args.log_dir, exist_ok=True)
    if args.local_rank == 0:
        print(f'creating file {args.log_dir}')
    os.makedirs(args.ckpt_dir, exist_ok=True)
    # Keep a copy of the config next to the logs.
    shutil.copy2(args.config_file, args.log_dir)
    set_deterministic(args.seed)

    args.aug_kwargs = {
        'image_size': args.dataset.image_size,
    }
    args.dataset_kwargs = {
        'dataset': args.dataset.name,
        'data_dir': args.data_dir,
        'download': True,
        'debug_subset_size': args.debug_subset_size if args.debug else None,
    }
    args.dataloader_kwargs = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.dataset.num_workers,
    }
    if args.local_rank == 0:
        print('##########################################################################################################################################')
        print('args-name', args.name)
        print('args-log_dir', args.log_dir)
        print('args-ckpt_dir', args.ckpt_dir)
        print('##########################################################################################################################################')
    return args
