# -*- coding: utf-8 -*-
import os
import argparse
import random
import numbers
import numpy as np
import torch
import math
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from metadatas import *
from metamodels import *

from models.classification_heads import ClassificationHead
from models.classification_heads import ClassificationHead_Mixup
from models.R2D2_embedding import R2D2Embedding
from models.protonet_embedding import ProtoNetEmbedding
from models.ResNet12_embedding import resnet12


from models.WRN_embedding import wrn_28_10

from utils import set_gpu, Timer, count_accuracy, count_accuracy_mixup, check_dir, log

import pdb
import time

def one_hot(indices, depth):
    """
    Returns a one-hot tensor.
    This is a PyTorch equivalent of Tensorflow's tf.one_hot.

    Parameters:
      indices:  a (n_batch, m) Tensor or (m) Tensor of integer class indices.
      depth: a scalar. Represents the depth of the one hot dimension.
    Returns: a float (n_batch, m, depth) Tensor or (m, depth) Tensor that
      lives on the same device as `indices`.
    """
    # Allocate on the device of `indices` so scatter_ never mixes devices
    # (and so the helper also works on machines without a GPU).
    encoded_indicies = torch.zeros(indices.size() + torch.Size([depth]),
                                   device=indices.device)
    # scatter_ needs an int64 index with the same rank as the target.
    index = indices.long().unsqueeze(-1)
    # Scatter along the trailing (one-hot) dimension. Using a fixed dim of 1
    # is only correct for 1-D `indices`; for (n_batch, m) input it would
    # write along the `m` axis instead of the `depth` axis.
    encoded_indicies = encoded_indicies.scatter_(-1, index, 1)

    return encoded_indicies

def get_model(options):
    """
    Build the embedding network and classification head(s) selected by
    `options` and move them to the GPU.

    Parameters:
      options: namespace providing `network`, `head`, `dataset` and
        `support_aug`.
    Returns:
      (network, cls_head) normally, or (network, cls_head, cls_head_mixup)
      when `options.support_aug` is set and contains the substring 'mix'.
    Raises:
      ValueError: if the network or head name is not recognised, or if mix
        support augmentation is requested for a head without a mixup variant.
    """
    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        network = R2D2Embedding().cuda()
        network = torch.nn.DataParallel(network)
    elif options.network == 'R2D2_mixup':
        network = R2D2Embedding_mixup().cuda()
        network = torch.nn.DataParallel(network)
    elif options.network == 'ResNet_mixup':
        network = resnet12_mixup(avg_pool=False, drop_rate=0.1, dropblock_size=2).cuda()
    elif options.network == 'ResNet':
        # Only the DropBlock size depends on the dataset.
        if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
            dropblock_size = 5
        else:
            dropblock_size = 2
        network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=dropblock_size).cuda()
        network = torch.nn.DataParallel(network)
    elif options.network == 'WRN':
        network = wrn_28_10().cuda()
        network = torch.nn.DataParallel(network)
    else:
        # Explicit raise instead of assert: asserts are stripped under -O.
        raise ValueError("Cannot recognize the network type: {!r}".format(options.network))

    # Choose the classification head
    if options.head == 'ProtoNet':
        cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
    elif options.head == 'Ridge':
        cls_head = ClassificationHead(base_learner='Ridge').cuda()
    elif options.head == 'R2D2':
        cls_head = R2D2Head().cuda()
    elif options.head == 'SVM':
        cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
    else:
        raise ValueError("Cannot recognize the head type: {!r}".format(options.head))

    # Without mix-style support augmentation only two objects are returned.
    if not (options.support_aug and 'mix' in options.support_aug):
        return (network, cls_head)

    # Mixup variant of the head; only R2D2 and SVM provide one.
    if options.head == 'R2D2':
        cls_head_mixup = R2D2Head_Mixup().cuda()
    elif options.head == 'SVM':
        cls_head_mixup = ClassificationHead_Mixup(base_learner='SVM-CS').cuda()
    else:
        # Previously this fell through and failed on an unbound local.
        raise ValueError(
            "No mixup classification head available for head type: {!r}".format(options.head))

    return (network, cls_head, cls_head_mixup)


def get_datasets(name, phase, args):
    """
    Create the dataset called `name` for the given `phase` and, for the
    'train' and 'final' phases, wrap it with every recognised task-level
    augmentation listed in `args.task_aug` (in order).

    Parameters:
      name: 'miniImageNet', 'CIFAR_FS' or 'FC100'.
      phase: forwarded to the dataset constructor.
      args: namespace providing `feat_aug`, `t_p`, `task_aug`, `num_epoch`.
    Returns: the (possibly wrapped) dataset object.
    """
    if name == 'miniImageNet':
        dataset_cls = MiniImageNet
    elif name == 'CIFAR_FS':
        dataset_cls = CIFAR_FS
    elif name == 'FC100':
        dataset_cls = FC100
    else:
        print ("Cannot recognize the dataset type")
        assert(False)
    dataset = dataset_cls(phase=phase, augment=args.feat_aug, rot90_p=args.t_p)
    print(dataset)

    # Task-level augmentation is only applied while training.
    if not (phase == 'train' or phase == 'final'):
        return dataset

    # Augmentations that share the TaskAug wrapper and differ only by name.
    task_aug_modes = ('Mix', 'CutMix', 'FMix', 'Combine')
    for aug_name in args.task_aug:
        if aug_name == 'Rot90':
            dataset = Rot90(dataset, p=args.t_p, batch_size_down=8e4)
            dataset.batch_num.value += args.num_epoch * 0.
        elif aug_name in task_aug_modes:
            dataset = TaskAug(dataset, aug_name, p=args.t_p, batch_size_down=8e4)
        elif aug_name == 'DropChannel':
            dataset = DropChannels(dataset, p=args.t_p)
        elif aug_name == 'RE':
            dataset = RE(dataset, p=args.t_p)
        elif aug_name == 'Solarize':
            dataset = Solarize(dataset, p=args.t_p)
        else:
            # Unknown names are silently ignored.
            continue
        print(dataset)

    return dataset

def mixup_data(x, y, lam, use_cuda=False):
    '''Blend each sample of `x` with a randomly chosen partner sample.

    A random permutation of the batch (first dimension of `x`) picks the
    partners; the permutation is moved to the GPU when `use_cuda` is true.

    Returns (mixed_x, y_a, y_b): the blended inputs `lam * x + (1 - lam) *
    x[perm]`, the original targets and the partners' targets. `lam` itself
    is not returned.
    '''
    perm = torch.randperm(x.size()[0])
    if use_cuda:
        perm = perm.cuda()

    partner_x = x[perm, :]
    blended = lam * x + (1 - lam) * partner_x
    return blended, y, y[perm]

def mixup_criterion(opt, pred, y_a, y_b, lam):
    """
    Worst-case mixed cross-entropy over the augmented candidates.

    Per-sample losses against `y_a` and `y_b` are blended with `lam`,
    arranged as (opt.parall_num, opt.m, samples), and reduced with a max
    over the `opt.m` dimension.

    Parameters:
      opt: namespace providing `train_way`, `parall_num` and `m`.
      pred: logits; reshaped to (-1, opt.train_way).
      y_a, y_b: integer targets; flattened to match the rows of `pred`.
      lam: tensor with opt.parall_num * opt.m mixing weights for `y_a`;
        it is moved to the device of the losses.
    Returns:
      (loss, loc): the mean of the per-sample maxima and the index along the
      `opt.m` dimension that attained each maximum.
    """
    b = opt.parall_num
    m = opt.m

    log_prob = F.log_softmax(pred.reshape(-1, opt.train_way), dim=-1)
    loss_a = F.nll_loss(log_prob, y_a.reshape(-1), reduction='none').view(b, m, -1)
    loss_b = F.nll_loss(log_prob, y_b.reshape(-1), reduction='none').view(b, m, -1)

    # Follow the device of the losses instead of forcing CUDA, so the
    # criterion works wherever `pred` lives.
    weight_a = lam.reshape(b, m, 1).to(loss_a.device)
    weight_b = (1 - lam).reshape(b, m, 1).to(loss_a.device)
    loss = loss_a * weight_a + loss_b * weight_b

    # Keep, for every sample, the candidate with the largest loss.
    loss, loc = loss.max(dim=1)

    return loss.mean(), loc

def get_worst_case(pred, y_a, y_b, loc, lam, opt):
    """
    Gather the predictions, both label sets and the lambda of the augmented
    candidate selected by `loc`.

    NOTE(review): the only call visible in this file is commented out, so the
    expected shapes below are assumptions to be confirmed against a real
    caller.

    Parameters:
      pred: predictions; dims 0 and 1 are swapped and the result is flattened
        to (opt.m, -1) — presumably (episodes, opt.m, ...), TODO confirm.
      y_a, y_b: label tensors, rearranged the same way as `pred`.
      loc: integer indices of the selected candidate — presumably the second
        return value of mixup_criterion, TODO confirm.
      lam: mixing weights, reshaped to (opt.episodes_per_batch, opt.m).
      opt: namespace providing `train_way`, `m` and `episodes_per_batch`.
    Returns:
      (pred_worst, ya_worst, yb_worst, lam_worst) with the first three
      reshaped to start with opt.episodes_per_batch and lam_worst flattened.
    """

    # Repeat every selected index once per class so it can index the logits.
    loc_p = loc.unsqueeze(-1).repeat(1, 1, opt.train_way)
    # Put the candidate dimension first and flatten everything else.
    pred_t = pred.transpose(0,1).reshape(opt.m, -1)
    ya_t = y_a.transpose(0,1).reshape(opt.m, -1)
    yb_t = y_b.transpose(0,1).reshape(opt.m, -1)
    lam_t = lam.reshape(opt.episodes_per_batch,opt.m)

    # For every flattened column, pick the row (candidate) named by loc.
    pred_worst = pred_t[loc_p.reshape(-1), torch.arange(pred_t.size(1))]
    ya_worst = ya_t[loc.reshape(-1), torch.arange(ya_t.size(1))]
    yb_worst = yb_t[loc.reshape(-1), torch.arange(yb_t.size(1))]
    # NOTE(review): rows are indexed with arange(opt.m) and columns with the
    # flattened loc; the two index tensors must have broadcast-compatible
    # lengths for this to work — confirm this is the intended selection.
    lam_worst = lam_t[torch.arange(lam_t.size(1)), loc.reshape(-1)]

    pred_worst = pred_worst.reshape(opt.episodes_per_batch,-1,opt.train_way)
    ya_worst = ya_worst.reshape(opt.episodes_per_batch,-1)
    yb_worst = yb_worst.reshape(opt.episodes_per_batch,-1)
    
    return pred_worst, ya_worst, yb_worst, lam_worst.reshape(-1)
     

def self_mix(data):
    """
    Overwrite one random patch of every image in `data` with another patch
    taken from the same image.

    `data` is indexed as (N, C, dim2, dim3). A box of roughly half the image
    size is centred on a random point and clipped to the image; a second box
    of the same size with a different top-left corner is drawn as the source.
    When the box is square the source patch is additionally rotated by a
    random multiple of 90 degrees. The x-range of the box (derived from the
    last dimension's size) is applied to dim 2 and the y-range to dim 3.

    `data` is modified in place and also returned.
    """
    shape = data.size()
    W, H = shape[-1], shape[-2]
    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    half_w = (W // 2) // 2
    half_h = (H // 2) // 2

    bbx1 = np.clip(cx - half_w, 0, W)
    bby1 = np.clip(cy - half_h, 0, H)
    bbx2 = np.clip(cx + half_w, 0, W)
    bby2 = np.clip(cy + half_h, 0, H)

    span_x = bbx2 - bbx1
    span_y = bby2 - bby1

    # Draw source corners until one differs from the destination corner.
    src_x = np.random.randint(0, W - span_x)
    src_y = np.random.randint(0, H - span_y)
    while src_x == bbx1 and src_y == bby1:
        src_x = np.random.randint(0, W - span_x)
        src_y = np.random.randint(0, H - span_y)

    # A rotation only keeps the patch shape when the box is square.
    k = random.sample([0, 1, 2, 3], 1)[0] if span_x == span_y else 0

    source_patch = data[:, :, src_x:src_x + span_x, src_y:src_y + span_y]
    data[:, :, bbx1:bbx2, bby1:bby2] = torch.rot90(source_patch, k, [2,3])

    return data


def rand_bbox(size, lam=0.5):
    """
    Sample a random box whose area is about (1 - lam) of the image.

    Parameters:
      size: shape whose last entry is taken as W and second-to-last as H.
      lam: fraction of the image that should stay outside the box.
    Returns:
      (bbx1, bby1, bbx2, bby2): box corners clipped to [0, W] and [0, H];
      the box is centred on a uniformly drawn point, so clipping can make it
      smaller than requested.
    """
    W = size[-1]
    H = size[-2]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int: the `np.int` alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def flip(x, dim):
    """Return a copy of `x` with the order of entries along `dim` reversed."""
    return torch.flip(x, [dim])

def build_grid(source_size, target_size, device=None):
    """
    Build a square sampling grid of `target_size` x `target_size` points.

    Coordinates run linearly from -k to k with k = target_size / source_size.
    Entry [0, i, j] holds (coord[j], coord[i]).

    Parameters:
      source_size: size of the image the grid will sample from.
      target_size: number of grid points per side.
      device: optional target device. When omitted the grid is moved to the
        GPU with `.cuda()`, which is the historical behaviour.
    Returns: a (1, target_size, target_size, 2) float tensor.
    """
    k = float(target_size)/float(source_size)
    axis = torch.linspace(-k, k, target_size)
    # direct[i, j, 0] == axis[j]; its transpose supplies axis[i].
    direct = axis.unsqueeze(0).repeat(target_size, 1).unsqueeze(-1)
    full = torch.cat([direct, direct.transpose(1, 0)], dim=2).unsqueeze(0)

    if device is None:
        return full.cuda()
    return full.to(device)

def random_crop_grid(x,grid):
    """
    Replicate `grid` for every sample of `x` and add an independent random
    offset per sample to each of the two grid coordinates.

    The offsets are integers drawn from [0, delta) with
    delta = x.size(-1) - grid.size(1), divided by x.size(2).

    Parameters:
      x: batch of images, indexed as (N, C, H, W).
      grid: base grid, indexed as (1, rows, cols, 2).
    Returns: an (N, rows, cols, 2) grid on the same device as `x`; the
      `grid` argument itself is not modified.
    """
    n_samples = x.size(0)
    delta = x.size(-1)-grid.size(1)
    # Follow the device of `x` rather than forcing CUDA, so the grid can be
    # used with the batch it was built for wherever that batch lives.
    grid = grid.repeat(n_samples,1,1,1).to(x.device)

    def _random_shift():
        # One integer offset per sample, broadcast over the whole grid.
        offsets = torch.empty(n_samples, dtype=torch.float32, device=x.device).random_(0, delta)
        return offsets.view(n_samples, 1, 1).expand(-1, grid.size(1), grid.size(2)) / x.size(2)

    #Add random shifts by x
    grid[:,:,:,0] = grid[:,:,:,0] + _random_shift()
    #Add random shifts by y
    grid[:,:,:,1] = grid[:,:,:,1] + _random_shift()

    return grid

def random_cropping(batch, t):
    """
    Take one random `t` x `t` crop from every image in `batch`.

    All leading dimensions of `batch` are merged, the images are padded by
    4 pixels on each side, and the crop is sampled with nearest-neighbour
    grid sampling from a randomly shifted grid.

    Returns: a tensor of crops with the leading dimensions merged into one.
    """
    image_dims = list(batch.shape[-3:])
    padded = F.pad(batch.view([-1] + image_dims), (4,4,4,4))
    # Grid of t x t points, then one random shift per image.
    base_grid = build_grid(padded.size(-1), t)
    shifted_grid = random_crop_grid(padded, base_grid)

    return F.grid_sample(padded, shifted_grid, mode='nearest')

def combine_labels(data, labels, train_way):
    """
    Merge class `train_way + c` into class `c` for every c < train_way.

    For each class the samples of both classes are pooled and a random half
    of the pool replaces the samples of class `c` (written into `data` in
    place). Only the samples whose label is below `train_way` are returned.

    Returns: (data, labels) restricted to labels < train_way.
    """
    for cls in range(train_way):
        own = labels == cls
        partner = labels == train_way + cls
        pooled = torch.cat((data[own], data[partner]))
        chosen = torch.randperm(len(pooled))[:len(pooled)//2]
        data[own] = pooled[chosen]

    kept = labels < train_way
    return data[kept], labels[kept]


def large_rotation(data, labels, opt):
    """
    Rotate all images of each class of each episode by a random multiple of
    90 degrees (one draw per episode and class).

    `data` is modified in place and also returned. `labels[i]` selects the
    samples of episode `i`; classes run over range(opt.train_way).
    """
    for episode_idx, episode in enumerate(data):
        episode_labels = labels[episode_idx]
        for cls in range(opt.train_way):
            quarter_turns = random.randint(0, 3)
            members = episode_labels == cls
            # `episode` is a view, so this writes through to `data`.
            episode[members] = torch.rot90(episode[members], quarter_turns, [2, 3])

    return data

def random_erase(data):
    """
    Apply random erasing independently to every image in `data`.

    All leading dimensions are merged so each (C, H, W) image is handled on
    its own (erase probability 0.5, area fraction 0.02-0.33, aspect ratio
    0.3-3.3, fill value 0). The result is a new tensor with the shape of
    `data`.
    """
    images = data.view([-1] + list(data.shape[-3:]))
    erased = torch.zeros_like(images)
    for idx, image in enumerate(images):
        erased[idx] = random_erase_img(image, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0)
    return erased.view(data.shape)

def get_params(img, scale, ratio, value=0):
    """
    Draw the parameters of a random erasing rectangle for `img`.

    Parameters:
      img: tensor image of shape (C, H, W).
      scale: (min, max) fraction of the image area to erase.
      ratio: (min, max) aspect ratio of the erased rectangle.
      value: fill specification — a number (used as is), a string (the
        region is filled with standard-normal noise), or a list/tuple of
        per-channel values.
    Returns:
      (i, j, h, w, v): top, left, height, width and fill value. If no
      rectangle strictly smaller than the image is found in 10 attempts,
      (0, 0, H, W, img) is returned, i.e. the image itself as the fill.
    Raises:
      TypeError: if `value` is none of the supported kinds.
    """
    # Validate up front; otherwise the fill value would be left undefined.
    # Plain `str` replaces torch._six.string_classes, which newer torch
    # releases no longer provide.
    if not isinstance(value, (numbers.Number, str, list, tuple)):
        raise TypeError(
            "value must be a number, a string, a list or a tuple, got {}".format(type(value).__name__))

    img_c, img_h, img_w = img.shape
    area = img_h * img_w

    for _ in range(10):
        erase_area = random.uniform(scale[0], scale[1]) * area
        aspect_ratio = random.uniform(ratio[0], ratio[1])

        h = int(round(math.sqrt(erase_area * aspect_ratio)))
        w = int(round(math.sqrt(erase_area / aspect_ratio)))

        # Retry when the rectangle does not fit strictly inside the image.
        if not (h < img_h and w < img_w):
            continue

        i = random.randint(0, img_h - h)
        j = random.randint(0, img_w - w)
        if isinstance(value, numbers.Number):
            v = value
        elif isinstance(value, str):
            v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
        else:
            v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
        return i, j, h, w, v

    # Return original image
    return 0, 0, img_h, img_w, img

def random_erase_img(img, p, scale, ratio, value):
    """
    Erase a random rectangle of `img` with probability `p`.

    Args:
        img (Tensor): Tensor image of size (C, H, W) to be erased.
        p (float): probability of erasing.
        scale, ratio, value: forwarded to get_params.

    Returns:
        img (Tensor): the erased copy, or `img` itself when nothing is erased.
    """
    if not (random.uniform(0, 1) < p):
        return img

    top, left, height, width, fill = get_params(img, scale=scale, ratio=ratio, value=value)
    # Last argument is `inplace`: the input image is left untouched.
    return TF.erase(img, top, left, height, width, fill, False)


def drop_channel(data, opt):
    """
    Randomly zero whole channels of every image in `data` using 2-D dropout
    with probability 0.5.

    All leading dimensions are merged before the dropout and restored
    afterwards. `opt` is accepted but not used.
    """
    image_dims = list(data.shape[-3:])
    flat_images = data.view([-1] + image_dims)
    # A freshly built module is in training mode, so dropout is active.
    channel_dropout = torch.nn.Dropout2d(p=0.5)
    dropped = channel_dropout(flat_images)

    return dropped.view(data.shape)


def shot_aug(data_support, labels_support, n_support, method, opt):
    """
    Enlarge the support set along dim 1 with one shot-level augmentation.

    Methods:
      "fliplr"      - append copies flipped along the last dimension.
      "flip_ver"    - append copies flipped along the second-to-last dimension.
      "random_crop" - replace the support set by two independent random
                      32-pixel crops of it, viewed with the original image
                      dimensions.
    Any other method leaves all inputs unchanged.

    For a recognised method the labels are duplicated along dim 1 and
    `n_support` is multiplied by `opt.s_du`.

    Returns: (data_support, labels_support, n_support).
    """
    if method == "fliplr" or method == "flip_ver":
        n_support = opt.s_du * n_support
        flip_dim = -1 if method == "fliplr" else -2
        mirrored = flip(data_support, flip_dim)
        data_support = torch.cat((data_support, mirrored), dim = 1)
        labels_support = torch.cat((labels_support, labels_support), dim = 1)
    elif method == "random_crop":
        n_support = opt.s_du * n_support
        n_episodes = data_support.shape[0]
        per_episode_shape = [n_episodes, -1] + list(data_support.shape[-3:])
        first_crop = random_cropping(data_support, 32)
        second_crop = random_cropping(data_support, 32)
        data_support = torch.cat((first_crop.view(per_episode_shape),
                                  second_crop.view(per_episode_shape)), dim = 1)
        labels_support = torch.cat((labels_support, labels_support), dim = 1)

    return data_support, labels_support, n_support

def data_aug(data_support, labels_support, data_query, labels_query, r, method, m, b, opt):
    """
    Apply one named augmentation `method` to a group of episodes.

    Method names combine these building blocks (joined with " + "):
      qcm - swap a random box between shuffled query images and record both
            label sets plus the resulting area-based lambda;
      qre / sre - random erasing on the query / support images;
      tlr - rotate all support and query images of a class by the same
            randomly drawn multiple of 90 degrees;
      qsm / ssm - self_mix on the query / support images.
    Any unrecognised method returns the inputs unchanged with lambda 1.

    Parameters:
      data_support, data_query: image tensors indexed by episode first.
      labels_support, labels_query: matching label tensors.
      r, m: accepted but not used in this function.
      method: augmentation name, see above.
      b: ignored; it is overwritten with opt.parall_num below.
      opt: namespace providing `parall_num` and `train_way`.
    Returns:
      (new_data_s, labels_support, new_data_q, label_a, label_b, lambdas)
      where `lambdas` is a tensor with one entry per episode: the weight of
      label_a (1.0 for methods that do not mix labels).

    NOTE(review): several branches return or mutate the input tensors
    themselves rather than copies (e.g. `new_data_s = data_support`, and
    self_mix works in place on the slice it is given) — confirm callers do
    not rely on the inputs staying untouched.
    """
    #b = opt.episodes_per_batch
    # The `b` argument is discarded in favour of opt.parall_num.
    b = opt.parall_num
    label_a, label_b = torch.zeros_like(labels_query), torch.zeros_like(labels_query)
    # Output buffers start as zeros; branches either fill them or rebind the
    # names to other tensors.
    new_data_s = torch.zeros_like(data_support)
    new_data_q = torch.zeros_like(data_query)
    ls = []
    # Drawn but never read afterwards (still advances NumPy's RNG).
    p = np.random.rand(1)

    if method == "qcm":
        for ii in range(b):
            lll = np.random.beta(2., 2.)
            rand_index = torch.randperm(data_query[ii].size()[0]).cuda()
            label_a[ii] = labels_query[ii]
            label_b[ii] = labels_query[ii][rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(data_query[ii].size(), lll)
            # Copy the episode, then paste the box from the shuffled images.
            new_data_q[ii] = data_query[ii]
            new_data_q[ii][:, :, bbx1:bbx2, bby1:bby2] = data_query[ii][rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lll = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data_query[ii].size()[-1] * data_query[ii].size()[-2]))  
            ls.append(lll)
            new_data_s = data_support
        #new_data, label_a, label_b = map(Variable, (new_data, label_a, label_b))
    elif method == "qre":
        new_data_q = random_erase(data_query)
        new_data_s = data_support
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b
    elif method == "sre":
        new_data_s = random_erase(data_support)
        new_data_q = data_query
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b
    elif method == "tlr":
        for ii in range(b):
            for j in range(opt.train_way):
                # k == 0 (no rotation) is three times as likely as each of 1, 2, 3.
                k = random.sample([0, 0, 0, 1, 2, 3], 1)[0]
                new_data_s[ii][labels_support[ii] == j] = torch.rot90(data_support[ii][labels_support[ii] == j], k, [2, 3])
                new_data_q[ii][labels_query[ii] == j] = torch.rot90(data_query[ii][labels_query[ii] == j], k, [2, 3])
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b
    elif method == "qsm":
        for ii in range(b):
            new_data_q[ii] = self_mix(data_query[ii])
        new_data_s = data_support
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b
    elif method == "ssm":
        for ii in range(b):
            new_data_s[ii] = self_mix(data_support[ii])
        new_data_q = data_query
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b
    elif method == "qcm + tlr":
        for ii in range(b):
            # First the shared per-class rotation ...
            for j in range(opt.train_way):
                k = random.sample([0, 0, 0, 1, 2, 3], 1)[0]
                new_data_s[ii][labels_support[ii] == j] = torch.rot90(data_support[ii][labels_support[ii] == j], k, [2, 3])
                new_data_q[ii][labels_query[ii] == j] = torch.rot90(data_query[ii][labels_query[ii] == j], k, [2, 3])

            # ... then the box swap on the already rotated queries.
            lll = np.random.beta(2., 2.)
            rand_index = torch.randperm(new_data_q[ii].size()[0]).cuda()
            label_a[ii] = labels_query[ii]
            label_b[ii] = labels_query[ii][rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(new_data_q[ii].size(), lll)
            new_data_q[ii][:, :, bbx1:bbx2, bby1:bby2] = new_data_q[ii][rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lll = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (new_data_q[ii].size()[-1] * new_data_q[ii].size()[-2]))
            ls.append(lll)
    elif method == "qsm + tlr":
        for ii in range(b):
            for j in range(opt.train_way):
                k = random.sample([0, 0, 0, 1, 2, 3], 1)[0]
                new_data_s[ii][labels_support[ii] == j] = torch.rot90(data_support[ii][labels_support[ii] == j], k, [2, 3]) 
                new_data_q[ii][labels_query[ii] == j] = torch.rot90(data_query[ii][labels_query[ii] == j], k, [2, 3])   

            # NOTE(review): this overwrites the rotated queries written just
            # above with self_mix of the *unrotated* queries — confirm intent.
            new_data_q[ii] = self_mix(data_query[ii])
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b
    elif method == "qre + tlr":
        for ii in range(b):
            for j in range(opt.train_way):
                k = random.sample([0, 0, 0, 1, 2, 3], 1)[0]
                new_data_s[ii][labels_support[ii] == j] = torch.rot90(data_support[ii][labels_support[ii] == j], k, [2, 3]) 
                new_data_q[ii][labels_query[ii] == j] = torch.rot90(data_query[ii][labels_query[ii] == j], k, [2, 3])   

        new_data_q = random_erase(new_data_q)
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b
    elif method == "qcm + qre":
        new_data_q = random_erase(data_query)
        new_data_s = data_support
        for ii in range(b):
            lll = np.random.beta(2., 2.)
            rand_index = torch.randperm(new_data_q[ii].size()[0]).cuda()
            label_a[ii] = labels_query[ii]
            label_b[ii] = labels_query[ii][rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(new_data_q[ii].size(), lll)
            new_data_q[ii][:, :, bbx1:bbx2, bby1:bby2] = new_data_q[ii][rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lll = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (new_data_q[ii].size()[-1] * new_data_q[ii].size()[-2]))
            ls.append(lll)
    elif method == "qcm + qsm":
        new_data_s = data_support
        for ii in range(b):
            lll = np.random.beta(2., 2.)
            rand_index = torch.randperm(data_query[ii].size()[0]).cuda()
            label_a[ii] = labels_query[ii]
            label_b[ii] = labels_query[ii][rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(data_query[ii].size(), lll)
            # NOTE(review): unlike "qcm", the episode is not copied into
            # new_data_q first, so everything outside the box stays zero.
            new_data_q[ii][:, :, bbx1:bbx2, bby1:bby2] = data_query[ii][rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lll = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (new_data_q[ii].size()[-1] * new_data_q[ii].size()[-2]))
            ls.append(lll)
            new_data_q[ii] = self_mix(new_data_q[ii])
    elif method == "qcm + ssm":
        for ii in range(b):
            new_data_s[ii] = self_mix(data_support[ii])
            lll = np.random.beta(2., 2.)
            rand_index = torch.randperm(data_query[ii].size()[0]).cuda()
            label_a[ii] = labels_query[ii]
            label_b[ii] = labels_query[ii][rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(data_query[ii].size(), lll)
            # NOTE(review): as in "qcm + qsm", only the box region of the
            # zero-initialised new_data_q is written here.
            new_data_q[ii][:, :, bbx1:bbx2, bby1:bby2] = data_query[ii][rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lll = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (new_data_q[ii].size()[-1] * new_data_q[ii].size()[-2]))
            ls.append(lll)
    #elif method == "sre + qre":
    #    new_data_q = random_erase(data_query)
    #    new_data_s = random_erase(data_support)
    #    label_a, label_b = labels_query, labels_query
    #    ls = [1.] * b
    elif method == "qcm + sre":
        for ii in range(b):
            lll = np.random.beta(2., 2.)
            rand_index = torch.randperm(data_query[ii].size()[0]).cuda()
            label_a[ii] = labels_query[ii]
            label_b[ii] = labels_query[ii][rand_index]
            bbx1, bby1, bbx2, bby2 = rand_bbox(data_query[ii].size(), lll)
            new_data_q[ii] = data_query[ii]
            new_data_q[ii][:, :, bbx1:bbx2, bby1:bby2] = data_query[ii][rand_index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lll = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data_query[ii].size()[-1] * data_query[ii].size()[-2]))
            ls.append(lll)
        new_data_s = random_erase(data_support)
    else:
        # Unknown method: pass everything through unchanged.
        new_data_s = data_support
        new_data_q = data_query
        label_a, label_b = labels_query, labels_query
        ls = [1.] * b

    return new_data_s, labels_support, new_data_q, label_a, label_b, torch.tensor(ls)


def data_aug_maxup(data_support, labels_support, data_query, labels_query, r, methods, opt):
    """
    Build `opt.m` augmented candidates of the current group of episodes.

    For every candidate one method is drawn at random from `methods` and
    applied through data_aug. The candidates are collected in CPU buffers
    and the candidate dimension is then swapped with the episode dimension.

    Parameters:
      data_support, labels_support, data_query, labels_query: the episodes.
      r: forwarded to data_aug.
      methods: sequence of augmentation names to sample from.
      opt: namespace providing `train_way`, `train_query`, `train_shot`,
        `s_du`, `parall_num` and `m`.
    Returns:
      (new_data_s, label_s, new_data_q, label_a, label_b, lambdas), each
      indexed as (episode, candidate, ...) except `lambdas`, which is
      flattened to one dimension.
    """
    n_query = opt.train_way * opt.train_query
    n_support = opt.train_way * opt.train_shot * opt.s_du
    n_episodes = opt.parall_num
    n_candidates = opt.m

    support_img_dims = list(data_support.shape[-3:])
    query_img_dims = list(data_query.shape[-3:])

    # Buffers are candidate-major while they are being filled.
    label_a = torch.zeros(n_candidates, n_episodes, n_query).long()
    label_b = torch.zeros(n_candidates, n_episodes, n_query).long()
    label_s = torch.zeros(n_candidates, n_episodes, n_support).long()
    new_data_q = torch.zeros([n_candidates, n_episodes, n_query] + query_img_dims)
    new_data_s = torch.zeros([n_candidates, n_episodes, n_support] + support_img_dims)
    ls = torch.zeros(n_candidates, n_episodes)

    for cand in range(n_candidates):
        chosen_method = random.sample(methods, 1)[0]
        augmented = data_aug(data_support, labels_support, data_query, labels_query,
                             r, chosen_method, cand, n_episodes, opt)
        (new_data_s[cand], label_s[cand], new_data_q[cand],
         label_a[cand], label_b[cand], ls[cand]) = augmented

    # Switch every buffer to episode-major order, in place.
    for buf in (new_data_s, new_data_q, label_s, label_a, label_b, ls):
        buf.transpose_(0, 1)

    return new_data_s, label_s, new_data_q, label_a, label_b, ls.reshape(-1)



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-epoch', type=int, default=60,
                            help='number of training epochs')
    parser.add_argument('--save-epoch', type=int, default=10,
                            help='frequency of model saving')
    parser.add_argument('--train-shot', type=int, default=5,
                            help='number of support examples per training class')
    parser.add_argument('--val-shot', type=int, default=5,
                            help='number of support examples per validation class')
    parser.add_argument('--train-query', type=int, default=6,
                            help='number of query examples per training class')
    parser.add_argument('--val-episode', type=int, default=2000,
                            help='number of episodes per validation')
    parser.add_argument('--val-query', type=int, default=15,
                            help='number of query examples per validation class')
    parser.add_argument('--train-way', type=int, default=5,
                            help='number of classes in one training episode')
    parser.add_argument('--sample-way', type=int, default=5,
                            help='number of classes in one training episode')
    parser.add_argument('--test-way', type=int, default=5,
                            help='number of classes in one test (or validation) episode')
    parser.add_argument('--save-path', default='./experiments/exp_1')
    parser.add_argument('--gpu', default='0, 1, 2, 3')
    parser.add_argument('--network', type=str, default='ProtoNet',
                            help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
    parser.add_argument('--head', type=str, default='ProtoNet',
                            help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
    parser.add_argument('--dataset', type=str, default='miniImageNet',
                            help='choose which classification head to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
    parser.add_argument('--episodes-per-batch', type=int, default=8,
                            help='number of episodes per batch')
    parser.add_argument('--num-per-batch', type=int, default=1000,
                            help='number of episodes per batch')
    parser.add_argument('--m', type=int, default=4,
                            help='number of episodes per batch')
    parser.add_argument('--parall_num', type=int, default=8,
                            help='number of episodes per batch')
    parser.add_argument('--eps', type=float, default=0.0,
                            help='epsilon of label smoothing')

    parser.add_argument('--load', default=None,
                            help='path of the checkpoint file')
    parser.add_argument('--pretrain', type=str, default=None,
                            help='path of the checkpoint file')

    ## Data Augmentation
    parser.add_argument('--feat_aug', '-faug', default='norm', type=str,
                        help='If use feature level augmentation.')
    parser.add_argument('--task_aug', '-taug', default=[], nargs='+', type=str,
                        help='If use task level data augmentation.')
    parser.add_argument('--support_aug', '-saug', default=None, type=str,
                        help='If use support level data augmentation.')
    parser.add_argument('--shot_aug', '-shotaug', default=[], nargs='+', type=str,
                        help='If use shot level data augmentation.')
    parser.add_argument('--query_aug', '-qaug', default=None, type=str,
                        help='If use query level data augmentation.')
    parser.add_argument('--t_p', '-tp', default=1, type=float,
                        help='The possibility of sampling categories or images with rot90.')
    parser.add_argument('--s_p', '-sp', default=1, type=float,
                        help='The possibility of sampling categories or images with rot90.')
    parser.add_argument('--s_du', '-sdu', default=1, type=int,
                        help='The possibility of sampling categories or images with rot90.')
    parser.add_argument('--q_p', '-qp', default=1, type=float,
                        help='The possibility of sampling categories or images with rot90.')
    parser.add_argument('--rot_degree', default=30, type=int,
                        help='Degree for random rotation.')

    opt = parser.parse_args()
    
    trainset = get_datasets(opt.dataset, 'train', opt)
    valset = get_datasets(opt.dataset, 'val', opt)
  
    epoch_size = opt.episodes_per_batch * opt.num_per_batch

    dloader_train = FewShotDataloader(trainset, kway=opt.sample_way, kshot=opt.train_shot, kquery=opt.train_query,
                                    batch_size=opt.episodes_per_batch, num_workers=4, epoch_size=epoch_size, shuffle=True)
    dloader_val = FewShotDataloader(valset, kway=opt.train_way, kshot=opt.val_shot, kquery=opt.val_query,
                                  batch_size=1, num_workers=1, epoch_size=2000, shuffle=False, fixed=False)

    set_gpu(opt.gpu)
    check_dir('./experiments/')
    check_dir(opt.save_path)
    
    log_file_path = os.path.join(opt.save_path, "train_log.txt")
    log(log_file_path, str(vars(opt)))

    if opt.support_aug and "mix" in opt.support_aug:
        (embedding_net, cls_head, cls_head_mixup) = get_model(opt)
        embedding_net.cuda()
        cls_head.cuda()
        cls_head_mixup.cuda()
    else:
        (embedding_net, cls_head) = get_model(opt)
        embedding_net.cuda()
        cls_head.cuda()

    # Load saved model checkpoints
    if opt.pretrain is not None:
        saved_models = torch.load(opt.pretrain)
        embedding_net.load_state_dict(saved_models['state'])
        #cls_head.load_state_dict(saved_models['head'])

    if opt.load:
        # Load saved model checkpoints
        saved_models = torch.load(opt.load)
        embedding_net.load_state_dict(saved_models['embedding'])
        cls_head.load_state_dict(saved_models['head'])
    
    optimizer = torch.optim.SGD([{'params': embedding_net.parameters()}, 
                                 {'params': cls_head.parameters()}], lr=0.1, momentum=0.9, \
                                          weight_decay=5e-4, nesterov=True)
    
    lambda_epoch = lambda e: 1.0 if e < 20 else (0.06 if e < 40 else 0.012 if e < 50 else (0.0024 if e < 60 else (0.001)))
    #lambda_epoch = lambda e: 0.06 if e < 10 else (0.012 if e < 20 else 0.0024 if e < 30 else (0.0024 if e < 60 else (0.001)))
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch, last_epoch=-1)

    max_val_acc = 0.0

    timer = Timer()
    x_entropy = torch.nn.CrossEntropyLoss()

    da_pool = ['qcm', 'qre', 'sre', 'tlr', 'qcm + tlr', 'qre + tlr']
    #da_pool = ['qcm', 'qre', 'sre', 'tlr', 'qcm + tlr', 'qre + tlr', 'qsm', 'ssm', "qcm + qsm", "qcm + ssm"]
    #da_pool = ['qcm', 'qre', 'sre', 'tlr', 'qcm + tlr', 'qre + tlr', 'qre + sre', 'qcm + qre', 'qcm + sre', 'qsm + tlr', 'qsm', 'ssm', "qcm + qsm", "qcm + ssm"]
    
    for epoch in range(1, opt.num_epoch + 1):
        # Train on the training split
        
        # Fetch the current epoch's learning rate
        epoch_learning_rate = 0.1
        for param_group in optimizer.param_groups:
            epoch_learning_rate = param_group['lr']
            
        log(log_file_path, 'Train Epoch: {}\tLearning Rate: {:.4f}'.format(
                            epoch, epoch_learning_rate))
        
        _, _ = [x.train() for x in (embedding_net, cls_head)]
        
        train_accuracies = []
        train_s_accuracies = []
        # NOTE(review): nothing in this section appends to train_losses (the
        # append further down is commented out), so it stays empty here.
        train_losses = []

        # Training pass. Each batch yields exactly one optimizer step:
        # gradients are cleared once, accumulated chunk by chunk in the inner
        # while-loop (opt.parall_num items per chunk), and applied by
        # optimizer.step() after the loop.
        for i, batch in enumerate(dloader_train(epoch), 1):
            optimizer.zero_grad()
            # Length of the first batch field; presumably the number of
            # episodes in the batch -- TODO confirm against the data loader.
            batch_size = len(batch[0])
            # j is the start offset of the current chunk inside the batch.
            j = 0
            # Running total of the scaled chunk losses; only used for logging.
            loss_sum = 0.
            while j < batch_size:
                # Slice the current chunk out of every batch field. The batch
                # unpacks into six fields; the third and sixth are discarded.
                # With "random_crop" among the shot augmentations the slices
                # are moved to the GPU right away, otherwise they are left on
                # their current device at this point.
                if "random_crop" in opt.shot_aug:
                    data_support, labels_support, _, data_query, labels_query, _ = [x[j:j + opt.parall_num].cuda() for x in batch]
                else:
                    data_support, labels_support, _, data_query, labels_query, _ = [x[j:j + opt.parall_num] for x in batch]

                # Number of support / query samples per episode (way * shot,
                # way * query); train_n_support may be changed by shot_aug below.
                train_n_support = opt.train_way * opt.train_shot 
                train_n_query = opt.train_way * opt.train_query 
                # NOTE(review): rs and rq are not read anywhere in this section.
                rs, rq = 0., 0.

                ## data augmentation for shots (increasing num of shots for support)
                # Every method listed in opt.shot_aug is applied in turn; each
                # call returns the updated support data, labels and count.
                for shot_method in opt.shot_aug:
                    data_support, labels_support, train_n_support = shot_aug(data_support, labels_support, train_n_support, shot_method, opt)

                # Build the augmented support/query sets. label_a, label_b and
                # ls are consumed by mixup_criterion and count_accuracy_mixup
                # below; presumably the two mixed label sets and the mixing
                # coefficients -- TODO confirm against data_aug_maxup.
                new_data_s, label_s, new_data_q, label_a, label_b, ls = data_aug_maxup(data_support, labels_support, data_query, labels_query, opt.q_p, da_pool, opt)

                # Make sure everything used from here on lives on the GPU.
                new_data_s, new_data_q, label_a, label_b, ls, labels_support = new_data_s.cuda(), new_data_q.cuda(), label_a.cuda(), label_b.cuda(), ls.cuda(), labels_support.cuda()
                label_s = label_s.cuda()


                ## get embedding
                # Collapse all leading dimensions so the embedding network
                # receives a flat batch whose items keep the last three
                # dimensions of the input, then regroup the embeddings as
                # (parall_num * m, samples per episode, feature).
                # NOTE(review): the regrouping hard-codes opt.parall_num, so it
                # presumably requires batch_size to be a multiple of
                # opt.parall_num (a shorter final chunk would not fit) -- confirm.
                emb_support = embedding_net(new_data_s.reshape([-1] + list(new_data_s.shape[-3:])))
                #emb_support = emb_support.reshape(opt.episodes_per_batch*opt.m, train_n_support, -1)
                emb_support = emb_support.reshape(opt.parall_num*opt.m, train_n_support, -1)
            
                emb_query = embedding_net(new_data_q.reshape([-1] + list(new_data_q.shape[-3:])))
                #emb_query = emb_query.reshape(opt.episodes_per_batch*opt.m, train_n_query, -1)
                emb_query = emb_query.reshape(opt.parall_num*opt.m, train_n_query, -1)
            
                ## get logits for query embedding
                logit_query = cls_head(emb_query, emb_support, label_s, opt.train_way, opt.train_shot)
                ## get loss for the outer loop
                # The logits are regrouped to (parall_num, m * train_n_query, -1)
                # for the criterion.
                # NOTE(review): loc is only referenced by the commented-out
                # get_worst_case call below and is otherwise unused here.
                loss, loc = mixup_criterion(opt, logit_query.reshape(opt.parall_num, opt.m*train_n_query, -1), label_a, label_b, ls)
                # Weight this chunk by its share of the batch before
                # back-propagating, so the gradients accumulated over all
                # chunks are scaled to the whole batch.
                loss *= float(opt.parall_num) / batch_size
                loss.backward(retain_graph=False)

                ## pick the worst case
                #logit_w, l_a_w, l_b_w, lam_w = get_worst_case(logit_query, label_query_a, label_query_b, loc, lqs, opt)
                acc = count_accuracy_mixup(logit_query, label_a.reshape(opt.parall_num * opt.m, -1), label_b.reshape(opt.parall_num * opt.m, -1), ls)
                #acc = count_accuracy_mixup(logit_query.view(opt.episodes_per_batch * opt.m, train_n_query, -1), label_a.view(opt.episodes_per_batch * opt.m, -1), label_b.view(opt.episodes_per_batch * opt.m, -1), qs)
                # Advance to the next chunk and add the (already scaled) chunk
                # loss to the per-batch total.
                j += opt.parall_num
                loss_sum += loss.data
            
                ## get accuracies
                # One accuracy entry is recorded per chunk, not per batch.
                train_accuracies.append(acc.item())
                #train_losses.append(loss.item())

            # Progress report every 100 batches: accumulated batch loss, mean
            # of all entries in train_accuracies, and (in parentheses) the
            # accuracy of the last chunk processed.
            if (i % 100 == 0):
                train_acc_avg = np.mean(np.array(train_accuracies))
                # NOTE(review): train_acc_avg_s is never used, and
                # train_s_accuracies is not defined in this section -- confirm
                # it exists earlier in the function, otherwise this line fails.
                train_acc_avg_s = np.mean(np.array(train_s_accuracies))
                log(log_file_path, 'Train Epoch: {}\tBatch: [{}/{}]\tLoss: {:.4f}\tAccuracy: {:.2f} % ({:.2f} %)'.format(
                            #epoch, i, len(dloader_train), loss.item(), train_acc_avg, acc))
                            epoch, i, len(dloader_train), loss_sum, train_acc_avg, acc))
            
            #optimizer.zero_grad()
            #loss.backward()
            # Apply the gradients accumulated over all chunks of this batch.
            optimizer.step()

        # Evaluate on the validation split
        # Put both the embedding network and the classification head into
        # evaluation mode (the list comprehension is used only for its side
        # effect of calling .eval() on each module).
        # NOTE(review): no torch.no_grad() context is visible around the
        # forward passes below -- confirm whether gradient tracking during
        # validation is intended.
        _, _ = [x.eval() for x in (embedding_net, cls_head)]

        # Per-batch accuracy and loss values, averaged after the loop.
        val_accuracies = []
        val_losses = []
        
        for i, batch in enumerate(dloader_val(epoch), 1):
            #data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
            # Move every batch field to the GPU; the batch unpacks into six
            # fields, of which the third and sixth are discarded.
            data_support, labels_support, _, data_query, labels_query, _ = [x.cuda() for x in batch]

            # Number of support / query samples per validation episode.
            test_n_support = opt.test_way * opt.val_shot
            test_n_query = opt.test_way * opt.val_query
 
            # The shot augmentations listed in opt.shot_aug are also applied
            # to the validation support set, which may change test_n_support.
            for method in opt.shot_aug:
                data_support, labels_support, test_n_support = shot_aug(data_support, labels_support, test_n_support, method, opt)

            # Flatten to a plain batch for the embedding network, then regroup
            # as (1, samples, feature): the leading 1 is hard-coded, so this
            # presumably expects a single episode per validation batch --
            # TODO confirm against the validation loader settings.
            emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
            emb_support = emb_support.reshape(1, test_n_support, -1)
            emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
            emb_query = emb_query.reshape(1, test_n_query, -1)

            # Only element 0 of the head's result is kept.
            # NOTE(review): whether [0] selects the first item of a tuple or
            # the first episode of a tensor depends on cls_head -- confirm.
            logit_query = cls_head(emb_query, emb_support, labels_support, opt.test_way, opt.val_shot)[0]

            # Loss and accuracy over the logits flattened to (-1, test_way)
            # against the flattened query labels.
            loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
            acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))

            val_accuracies.append(acc.item())
            val_losses.append(loss.item())
            
        # Mean accuracy and the half-width of its 95% confidence interval
        # (1.96 * std / sqrt(n)).
        # NOTE(review): n is taken from opt.val_episode rather than
        # len(val_accuracies); this assumes the loader yields exactly
        # opt.val_episode batches -- confirm.
        val_acc_avg = np.mean(np.array(val_accuracies))
        val_acc_ci95 = 1.96 * np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode)

        val_loss_avg = np.mean(np.array(val_losses))

        # Advance the learning-rate schedule after training and validation
        # for this epoch have finished.
        lr_scheduler.step()

        # Keep the weights with the best mean validation accuracy seen so far
        # in best_model.pth; the log line is tagged "(Best)" when a new best
        # is reached. Both branches log the same validation statistics.
        if val_acc_avg > max_val_acc:
            max_val_acc = val_acc_avg
            torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()},\
                       os.path.join(opt.save_path, 'best_model.pth'))
            log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)'\
                  .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
        else:
            log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %'\
                  .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))

        # Always overwrite last_epoch.pth with the current weights.
        torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
                   , os.path.join(opt.save_path, 'last_epoch.pth'))

        # Additionally keep a numbered snapshot every opt.save_epoch epochs.
        if epoch % opt.save_epoch == 0:
            torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
                       , os.path.join(opt.save_path, 'epoch_{}.pth'.format(epoch)))

        # Log timing for the run. timer.measure() is called once without an
        # argument and once with the fraction of epochs completed; presumably
        # these are the elapsed time and the projected total -- TODO confirm
        # against utils.Timer.
        log(log_file_path, 'Elapsed Time: {}/{}\n'.format(timer.measure(), timer.measure(epoch / float(opt.num_epoch))))
