import torch
import numpy as np
import random
import visual.resnet as resnet
import torch
import numpy as np
import random
import os
import time
import argparse

import visual.resnet as resnet
import scipy as sp
import scipy.stats
import torch.optim as optim

# Lookup table from a backbone name to the matching entry in visual.resnet.
model_dict = {
    "ResNet12": resnet.ResNet12,
}

def prepare_optimizer(model, args):
    """
    Build the optimizer and learning-rate scheduler selected by ``args``.

    Optimizer: ``args.optim == "Adam"`` gives Adam with learning rate
    ``args.lr``; every other value gives SGD with learning rate ``args.lr``,
    momentum 0.9 and weight decay 1e-4.

    Scheduler (``args.lr_scheduler``):
      * ``'step'``      - StepLR with period ``int(args.step_size)``, gamma 0.1
      * ``'multistep'`` - MultiStepLR whose milestones are the comma-separated
        integers in ``args.step_size``, gamma 0.1
      * ``'cosine'``    - CosineAnnealingLR with ``T_max=args.epoch`` and
        ``eta_min=0``

    :param model: module whose ``parameters()`` are optimized
    :param args: namespace providing ``optim``, ``lr``, ``lr_scheduler`` and,
        depending on the scheduler, ``step_size`` or ``epoch``
    :return: tuple ``(optimizer, lr_scheduler)``
    :raises ValueError: if ``args.lr_scheduler`` is none of the names above
    """
    weights = model.parameters()
    if args.optim != "Adam":
        optimizer = optim.SGD(weights, lr=args.lr, momentum=0.9, weight_decay=1e-4)
    else:
        optimizer = torch.optim.Adam(params=weights, lr=args.lr)

    schedule_name = args.lr_scheduler
    if schedule_name == 'cosine':
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epoch, eta_min=0, last_epoch=-1
        )
    elif schedule_name == 'multistep':
        # step_size is a string such as "20,40,60" in this mode.
        boundaries = list(map(int, args.step_size.split(',')))
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=boundaries, gamma=0.1
        )
    elif schedule_name == 'step':
        period = int(args.step_size)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=period, gamma=0.1)
    else:
        raise ValueError('No Such Scheduler')

    return optimizer, lr_scheduler


def mean_confidence_interval(data, confidence=0.95):
    """
    Return the sample mean of ``data`` and the half-width of its confidence
    interval.

    The half-width is the standard error of the mean multiplied by the
    two-sided Student-t critical value with ``len(data) - 1`` degrees of
    freedom, so the interval is ``m - h`` to ``m + h``.

    :param data: iterable of numeric observations; each item is converted
        with ``np.array`` and promoted to float
    :param confidence: confidence level of the interval (default 0.95)
    :return: tuple ``(m, h)`` of the mean and the interval half-width
    """
    a = [1.0 * np.array(item) for item in data]
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the documented ``ppf`` rather than the private ``_ppf`` hook, which
    # skips argument validation and is not part of SciPy's public API.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return m, h

def set_seed(seed):
    """
    Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then turns off cuDNN benchmarking and turns on its deterministic mode.

    :param seed: integer seed handed to each generator
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def load_pretrain_model(model, pretrain_model_file):
    """
    Load pretrained weights into ``model.feature`` from a checkpoint file.

    Reads the ``'state'`` entry of the checkpoint, keeps only the entries
    whose key contains ``'feature'``, removes the first 8 characters of each
    kept key, and loads the result into ``model.feature``.

    :param model: object exposing a ``feature`` sub-module
    :param pretrain_model_file: path of the checkpoint readable by ``torch.load``
    """
    state = torch.load(pretrain_model_file)['state']
    feature_weights = {}
    for name, tensor in state.items():
        if 'feature' in name:
            # 8 == len("feature."); presumably the kept keys start with that prefix.
            feature_weights[name[8:]] = tensor
    model.feature.load_state_dict(feature_weights)

def load_pretrain_model_tie(model, pretrain_model_file):
    """
    Load pretrained encoder weights into ``model.feature``.

    Works like ``load_pretrain_model`` but reads the ``'params'`` entry of the
    checkpoint and keeps the entries whose key contains ``'encoder'``; the
    first 8 characters of each kept key are removed before loading.

    :param model: object exposing a ``feature`` sub-module
    :param pretrain_model_file: path of the checkpoint readable by ``torch.load``
    """
    params = torch.load(pretrain_model_file)['params']
    # 8 == len("encoder."); presumably the kept keys start with that prefix.
    encoder_items = ((name[8:], tensor) for name, tensor in params.items() if 'encoder' in name)
    encoder_weights = dict(encoder_items)
    model.feature.load_state_dict(encoder_weights)

def one_hot(y, num_class):
    """
    Return the one-hot encoding of the class indices in ``y``.

    :param y: 1-D integer tensor of class indices (``scatter_`` requires int64
        indices smaller than ``num_class``)
    :param num_class: number of columns of the encoding
    :return: float tensor of shape ``(len(y), num_class)`` with a 1 at
        ``[i, y[i]]`` and 0 elsewhere, allocated on the same device as ``y``
    """
    # Allocate on y's device: scatter_ needs the index and the destination on
    # the same device, so a default-device buffer breaks for CUDA labels.
    encoded = torch.zeros((len(y), num_class), device=y.device)
    return encoded.scatter_(1, y.unsqueeze(1), 1)

def load_model(model, model_file):
    """
    Restore all of ``model``'s weights from a checkpoint file.

    The checkpoint's ``'state'`` entry is copied into a plain dict and passed
    to ``model.load_state_dict``.

    :param model: module to load the weights into
    :param model_file: path of the checkpoint readable by ``torch.load``
    """
    checkpoint = torch.load(model_file)
    weights = dict(checkpoint['state'].items())
    model.load_state_dict(weights)

def setup_run():
    """
    Parse the command line and prepare the run environment.

    Steps performed:
      * parse arguments with ``parse_parms``
      * when ``model_dir`` is empty, set it to a timestamped sub-directory of
        ``checkpoint_dir``; create the directory if it does not exist
      * configure the visible GPUs through ``set_gpu`` and store the count in
        ``params.num_gpu``; ``params.device_ids`` is ``None`` when
        ``params.gpu == '-1'``, otherwise ``list(range(params.num_gpu))``
      * set ``params.num_class`` from ``params.dataset``

    :return: the populated argument namespace
    :raises ValueError: if ``params.dataset`` has no known class count
    """
    params = parse_parms()
    if params.model_dir == '':
        params.model_dir  = os.path.join(params.checkpoint_dir, time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(time.time())))
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(params.model_dir, exist_ok=True)
    print('save_dir'+ str(params.model_dir))

    torch.set_printoptions(linewidth=100)
    params.num_gpu = set_gpu(params)
    params.device_ids = None if params.gpu == '-1' else list(range(params.num_gpu))

    if params.dataset == 'miniImageNet':
        params.num_class = 64
    elif params.dataset == 'CUB':
        params.num_class = 100
    elif params.dataset == 'tieredImageNet':
        params.num_class = 351
    elif params.dataset == 'FC100':
        params.num_class = 60
    else:
        # The exception was previously constructed but never raised, which
        # let an unsupported dataset continue without ``num_class`` set.
        raise ValueError('dataset error')

    return params

def set_gpu(args):
    """
    Select the GPUs to use and return how many there are.

    When ``args.gpu`` is ``'-1'`` the existing ``CUDA_VISIBLE_DEVICES``
    environment variable is read and left unchanged. Otherwise ``args.gpu``
    (comma-separated device ids) is written to ``CUDA_VISIBLE_DEVICES`` and
    ``CUDA_DEVICE_ORDER`` is set to ``PCI_BUS_ID``.

    :param args: namespace with a string attribute ``gpu``
    :return: number of device ids in the selected list
    :raises KeyError: if ``args.gpu == '-1'`` and ``CUDA_VISIBLE_DEVICES`` is unset
    """
    if args.gpu != '-1':
        gpu_list = list(map(int, args.gpu.split(',')))
        print('use gpu:', gpu_list)
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        return len(gpu_list)

    visible = os.environ['CUDA_VISIBLE_DEVICES']
    return len(list(map(int, visible.split(','))))

def parse_parms():
    """
    Define and parse the command-line options of the training script.

    Options are grouped into dataset, few-shot episode, MAML, optimization,
    directory, GPU and margin settings. Arguments are read from ``sys.argv``.

    :return: ``argparse.Namespace`` holding every option
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-backbone', type=str, default='ResNet12')

    # dataset
    parser.add_argument('-image_size', type=int, default=84)
    # FC100 is listed because setup_run defines a class count for it.
    parser.add_argument('-dataset', default='miniImageNet',
                        choices=['miniImageNet', 'tieredImageNet', 'CUB', 'FC100'])
    parser.add_argument('-data_dir', type=str, default='../datasets/miniImageNet')
    parser.add_argument('-seed', type=int, default=1)
    parser.add_argument('-epoch', type=int, default=80)

    # few shot setting
    parser.add_argument('-train_n_episode', type=int, default=300)
    parser.add_argument('-val_n_episode', type=int, default=200)
    parser.add_argument('-n_shot', type=int, default=1)
    parser.add_argument('-n_way', type=int, default=5)
    parser.add_argument('-n_query', type=int, default=16)
    parser.add_argument('-save_freq', type=int, default=50)

    # MAML
    parser.add_argument('-update_num', type=int, default=25)
    parser.add_argument('-update_lr', type=float, default=0.5)
    parser.add_argument('-temperature_inner', type=float, default=1)
    parser.add_argument('-temperature_outer', type=float, default=0.1)

    # optimization parameters
    parser.add_argument('-optim', type=str, choices=["Adam", "SGD"], default="SGD")
    parser.add_argument('-lr_scheduler', type=str, default='cosine', choices=['multistep', 'step', 'cosine'])
    parser.add_argument('-gamma', type=float, default=0.05)
    parser.add_argument('-lr', type=float, default=1e-3)
    parser.add_argument('-lr_mul', type=float, default=10)
    # Kept as a string: the 'multistep' scheduler expects comma-separated milestones.
    parser.add_argument('-step_size', type=str, default='20')
    parser.add_argument('-train_aug', action='store_true')
    # dir
    parser.add_argument('-checkpoint_dir', type=str, default='./checkpoint')
    parser.add_argument('-pretrain_model_dir', type=str, default='./model_pt')
    # '--model_name' is the original spelling; the single-dash alias matches
    # the style of every other option. Both store into ``model_name``.
    parser.add_argument('--model_name', '-model_name', type=str, default='val_best_model.tar')
    parser.add_argument('-model_dir', type=str, default='')
    # gpu
    parser.add_argument('-gpu', default='0')
    # margin
    parser.add_argument('-margin', type=float, default=0)
    params = parser.parse_args()

    return params