from __future__ import absolute_import, print_function
from copy import deepcopy
import argparse
import torch.utils.data
from torch.backends import cudnn
from torch.autograd import Variable
import torchvision.models as models
import losses
from utils import RandomIdentitySampler, mkdir_if_missing, logging
from torch.optim.lr_scheduler import StepLR
from ImageFolder import *
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
import collections
import random
import pandas as pd
import warnings

# Silence every warning issued through Python's `warnings` module for the whole
# process. NOTE: the "ignore" filter is global; it is not limited to PyTorch.
warnings.filterwarnings("ignore")

# Per the PyTorch docs, this lets cuDNN benchmark several convolution
# algorithms and pick the fastest one.
cudnn.benchmark = True


class DeepInversionFeatureHook:
    """Forward hook comparing batch statistics with a layer's running statistics.

    The hook is registered on ``module`` (the caller in this file attaches it to
    ``nn.BatchNorm2d`` layers). On every forward pass it stores in
    ``self.r_feature`` the sum of the L2 distances between

    * the per-channel variance of the layer input and ``module.running_var``
    * the per-channel mean of the layer input and ``module.running_mean``

    ``r_feature`` is ``None`` until the hooked module has been run once.
    """

    def __init__(self, module):
        # Create the attribute up front so that reading it before the first
        # forward pass yields None instead of raising AttributeError.
        self.r_feature = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Hook to compute DeepInversion's feature distribution regularization.
        feat = input[0]
        nch = feat.shape[1]

        # Statistics are taken per channel (dim 1) over dims 0, 2 and 3.
        mean = feat.mean([0, 2, 3])
        var = feat.permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)

        # Force mean and variance to match between the two distributions.
        running_var = module.running_var.data.type(var.type())
        running_mean = module.running_mean.data.type(var.type())
        self.r_feature = torch.norm(running_var - var, 2) + torch.norm(running_mean - mean, 2)
        # Deliberately returns None: a non-None return value from a forward
        # hook would replace the module's output.

    def close(self):
        """Remove the hook from the module it was registered on."""
        self.hook.remove()


def get_image_prior_losses(inputs_jit):
    """Return the ``(l1, l2)`` total-variation penalties of a 4-D batch.

    Differences are taken between neighbouring entries of the last two axes:
    along the last axis, along the second-to-last axis, and along both
    diagonals. The L2 term is the sum of the norms of the four difference
    tensors; the L1 term is the sum of their mean absolute values.
    """
    neighbour_pairs = (
        (inputs_jit[:, :, :, :-1], inputs_jit[:, :, :, 1:]),      # last axis
        (inputs_jit[:, :, :-1, :], inputs_jit[:, :, 1:, :]),      # second-to-last axis
        (inputs_jit[:, :, 1:, :-1], inputs_jit[:, :, :-1, 1:]),   # anti-diagonal
        (inputs_jit[:, :, :-1, :-1], inputs_jit[:, :, 1:, 1:]),   # main diagonal
    )
    deltas = [left - right for left, right in neighbour_pairs]

    tv_l2 = sum(torch.norm(delta) for delta in deltas)
    # The division by 255 followed by the multiplication by 255 is kept from
    # the original formulation.
    tv_l1 = sum((delta.abs() / 255.0).mean() for delta in deltas) * 255.0
    return tv_l1, tv_l2


def _class_scatter(inputs, start, stop):
    """Return ``(samples, mean, scatter)`` for rows ``start .. stop - 1`` of ``inputs``."""
    # Build the index on the same device as ``inputs`` so the function works
    # for CPU and CUDA tensors alike.
    rows = torch.arange(start, stop, dtype=torch.long, device=inputs.device)
    samples = inputs.index_select(0, rows)
    mean = torch.mean(samples, dim=0)
    centered = samples - mean
    return samples, mean, centered.T @ centered


def _pairwise_mean_distance(means):
    """Return the matrix of Euclidean distances between the rows of ``means``."""
    n = means.size(0)
    squared_norms = torch.pow(means, 2).sum(dim=1, keepdim=True).expand(n, n)
    dist = squared_norms + squared_norms.t()
    dist.addmm_(means, means.t(), beta=1, alpha=-2)
    return dist.clamp(min=1e-12).sqrt()  # for numerical stability


def scatter_loss(inputs, num_instances, bs):
    """Compute per-class scatter statistics of a class-ordered mini-batch.

    ``inputs`` is a 2-D tensor whose rows are grouped by class: rows
    ``[0, num_instances)`` belong to the first class, the next
    ``num_instances`` rows to the second class, and so on. The number of
    classes is ``int(bs / num_instances)`` and must be 2 or 4; for any other
    value an error message is printed and the 2-class result is returned.

    Args:
        inputs: 2-D tensor (rows are samples); may live on any device.
        num_instances: number of consecutive rows per class.
        bs: mini-batch size used to derive the number of classes.

    Returns:
        A 14-tuple ``(dist1, dist2, dist3, dist4, dist_mean, input1, input2,
        input3, input4, mean1, mean2, mean3, mean4, mean_t)`` where, for class
        ``k``, ``inputk`` are its rows, ``meank`` their mean and ``distk`` the
        product ``centered.T @ centered``. ``dist_mean`` holds the pairwise
        Euclidean distances between the class means and ``mean_t`` is the
        mean of the class means. In the 2-class case every ``*3`` / ``*4``
        entry is an empty list.
    """
    num_class_in_minibatch = int(bs / num_instances)

    input1, mean1, dist1 = _class_scatter(inputs, 0, num_instances)
    input2, mean2, dist2 = _class_scatter(inputs, num_instances, 2 * num_instances)

    # Placeholders returned when only two classes are present.
    dist3, dist4 = [], []
    input3, input4 = [], []
    mean3, mean4 = [], []
    class_means = [mean1, mean2]

    if num_class_in_minibatch == 4:
        input3, mean3, dist3 = _class_scatter(inputs, 2 * num_instances, 3 * num_instances)
        input4, mean4, dist4 = _class_scatter(inputs, 3 * num_instances, 4 * num_instances)
        class_means += [mean3, mean4]
    elif num_class_in_minibatch != 2:
        print("error: number of classes in each mini-batch must  be 2 or 4")

    means = torch.stack(class_means, dim=0)
    mean_t = torch.mean(means, dim=0)
    dist_mean = _pairwise_mean_distance(means)

    return dist1, dist2, dist3, dist4, dist_mean, input1, input2, input3, input4, mean1, mean2, mean3, mean4, mean_t


def extract_features(model, data_loader):
    """Run ``model`` over ``data_loader`` and collect L2-normalised embeddings.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode. Each batch is forwarded without gradient tracking, squeezed to
    2-D and normalised along dim 1.

    Args:
        model: module mapping an image batch to one embedding per sample.
        data_loader: iterable yielding ``(imgs, pids)`` batches.

    Returns:
        ``(features, labels)`` as NumPy arrays: ``features`` stacks the
        embeddings row-wise and ``labels`` concatenates the batch labels.
        If ``data_loader`` yields nothing, ``([], [])`` is returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    feature_batches = []
    label_batches = []

    for imgs, pids in data_loader:
        inputs = imgs.to(device)
        with torch.no_grad():
            outputs = torch.squeeze(model(inputs))
            # A batch holding a single sample loses its batch axis in
            # squeeze(); restore it so normalisation along dim 1 still works.
            if outputs.dim() < 2:
                outputs = outputs.reshape(inputs.size(0), -1)
            outputs = F.normalize(outputs, p=2, dim=1)

        feature_batches.append(outputs.cpu().numpy())
        label_batches.append(np.asarray(pids))

    # Accumulating in lists avoids comparing an ndarray with [] on every
    # iteration, which is ambiguous (and an error on recent NumPy releases).
    if not feature_batches:
        return [], []

    return np.vstack(feature_batches), np.hstack(label_batches)


def freeze_model(model):
    """Disable gradient tracking on every parameter of ``model``.

    The model is modified in place and the same object is returned.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model


def initial_train_fun(args, trainloader, dataset_sizes_train, num_class, dictlist):
    """Train a ResNet-18 classifier on the base classes and save its backbone.

    A torchvision ResNet-18 (``pretrained=True``) gets a new ``fc`` layer with
    ``num_class`` outputs and is trained with cross-entropy and SGD for
    ``args.epochs_0`` epochs. The cosine scheduler is stepped once per batch.
    Afterwards the final child of the network (the ``fc`` layer) is dropped
    and the remaining state dict is saved, together with the optimizer state,
    to ``checkpoints/<args.log_dir>/<method>_task_0_<args.epochs>_model.pt``.

    Args:
        args: parsed command-line namespace (reads ``log_dir``, ``lr_0``,
            ``momentum_0``, ``weight_decay_0``, ``lr_step``, ``epochs_0``,
            ``method`` and ``epochs``).
        trainloader: iterable of ``(inputs, labels)`` batches.
        dataset_sizes_train: kept for interface compatibility; not used.
        num_class: number of outputs of the classification layer.
        dictlist: mapping from the labels yielded by ``trainloader`` to the
            class indices used for the cross-entropy targets.
    """
    log_dir = os.path.join('checkpoints', args.log_dir)
    # This function runs before train_fun creates the directory, and
    # torch.save does not create missing directories itself.
    mkdir_if_missing(log_dir)

    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_class)
    model = model.cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_0, momentum=args.momentum_0,
                                weight_decay=args.weight_decay_0)
    exp_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.lr_step)

    for epoch in range(args.epochs_0):
        print(f'Epoch {epoch}/{args.epochs_0 - 1}')
        print('-' * 50)

        model.train()  # Set model to training mode
        running_loss = 0.0
        running_corrects = 0
        samples_seen = 0

        # Iterating the loader directly starts a fresh pass every epoch; no
        # copy of the loader (and of its dataset) is required.
        for inputs, labels in trainloader:
            # Remap the dataset labels to the targets expected by the
            # classifier.
            labels = torch.as_tensor([dictlist[label] for label in labels.tolist()],
                                     dtype=torch.long)
            inputs = inputs.cuda()
            labels = labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
            samples_seen += inputs.size(0)
            exp_lr_scheduler.step()

        # Average over the samples actually processed in this epoch.
        denom = max(samples_seen, 1)
        print(f'Loss: {running_loss / denom:.4f} Acc: {running_corrects / denom:.4f}')

    # Keep everything except the final child (the classification layer).
    model = nn.Sequential(*list(model.children())[:-1])

    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state,
               os.path.join(log_dir, args.method + '_task_' + str(0) + '_%d_model.pt' % args.epochs))


def _resnet18_backbone(pretrained=False):
    """Return torchvision's ResNet-18 without its final child (the ``fc`` layer)."""
    return nn.Sequential(*list(models.resnet18(pretrained=pretrained).children())[:-1])


def _checkpoint_path(log_dir, method, task, epoch):
    """Return the checkpoint file path for ``task`` at ``epoch``."""
    return os.path.join(log_dir, method + '_task_' + str(task) + '_%d_model.pt' % int(epoch))


def _synthetic_targets(prev_label, num_classes_in_minibatch, batch_size_gen):
    """Return the CUDA label tensor for one synthetic batch of previous-task classes."""
    prev_label = list(set(prev_label))
    if num_classes_in_minibatch == 2:
        chosen = [prev_label[0], prev_label[1]]
    elif num_classes_in_minibatch == 4:
        chosen = random.sample(sorted(prev_label), k=4)
    else:
        print("error: number of classes in each mini-batch must  be 2 or 4")
        raise ValueError("number of classes in each mini-batch must be 2 or 4")

    per_class = int(batch_size_gen / num_classes_in_minibatch)
    labels = []
    for cls in chosen:
        labels += [cls] * per_class
    return torch.LongTensor(labels).to('cuda')


def _generate_synthetic_loader(args, model_gen, criterion, prev_label, num_batches,
                               num_classes_in_minibatch):
    """Optimise random images against ``model_gen`` and wrap them in a DataLoader."""
    hooks = [DeepInversionFeatureHook(module)
             for module in model_gen.modules() if isinstance(module, nn.BatchNorm2d)]
    # The first hooked layer gets an extra multiplier, all others weight 1.
    rescale = [args.first_bn_mul] + [1.] * (len(hooks) - 1)

    model_gen = model_gen.cuda()
    model_gen.eval()
    lim_0, lim_1 = 6, 6  # maximum random shift applied along dims 2 and 3

    all_gen_input = []
    all_gen_labels = []
    for _ in range(num_batches):
        inputs_d = torch.randn((args.batch_size_gen, 3, args.resolution, args.resolution),
                               requires_grad=True, device='cuda', dtype=torch.float)
        optimizer_gen = torch.optim.Adam([inputs_d], lr=args.lr_gen)
        targets = _synthetic_targets(prev_label, num_classes_in_minibatch, args.batch_size_gen)

        for _ in range(args.epoch_gen):
            off1 = random.randint(-lim_0, lim_0)
            off2 = random.randint(-lim_1, lim_1)
            inputs_jit = torch.roll(inputs_d, shifts=(off1, off2), dims=(2, 3))

            # forward with the shifted images
            optimizer_gen.zero_grad()
            model_gen.zero_grad()
            embed_feat_gen = torch.squeeze(model_gen(inputs_jit))
            embed_feat_normal_gen = F.normalize(embed_feat_gen, p=2, dim=1)

            # Triplet loss
            loss_gen, _, _, _ = criterion(embed_feat_normal_gen, targets)

            # R_prior losses
            loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit)

            # R_feature loss
            loss_r_feature = sum(
                [hook.r_feature * rescale[idx] for (idx, hook) in enumerate(hooks)])

            # L2 loss on images: one norm per generated image. The reshape uses
            # the generated batch's own size, not the training batch size.
            loss_l2 = torch.norm(inputs_jit.reshape(inputs_jit.size(0), -1), dim=1).mean()

            # combining losses
            loss_aux = args.tv_l2 * loss_var_l2 + \
                       args.tv_l1 * loss_var_l1 + \
                       args.bn_reg_scale * loss_r_feature + \
                       args.l2 * loss_l2

            loss_gen = args.main_mul * loss_gen + loss_aux
            loss_gen.backward()
            optimizer_gen.step()

        # Store plain tensors: the finished images are data from here on and
        # must not drag their optimisation graph into the training loop.
        all_gen_input.append(inputs_d.detach())
        all_gen_labels.append(targets)

    # model_gen is not run again after generation, so drop the hooks.
    for hook in hooks:
        hook.close()

    trainset = torch.utils.data.TensorDataset(torch.cat(all_gen_input), torch.cat(all_gen_labels))
    return torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_gen, drop_last=True,
                                       num_workers=args.nThreads)


def _retention_loss(args, model, model_old, inputs_synth):
    """Return the scatter + MSE penalty between ``model`` and ``model_old`` on synthetic data."""
    embed_feat_old = torch.squeeze(model_old(inputs_synth))
    embed_feat_old_normal = F.normalize(embed_feat_old, p=2, dim=1)

    embed_feat_synth = torch.squeeze(model(inputs_synth))
    embed_feat_normal_synth = F.normalize(embed_feat_synth, p=2, dim=1)

    stats_new = scatter_loss(embed_feat_normal_synth, args.num_instances_gen, args.batch_size_gen)
    stats_old = scatter_loss(embed_feat_old_normal, args.num_instances_gen, args.batch_size_gen)

    # scatter_loss layout: [0:4] per-class scatter, [5:9] per-class samples,
    # [13] mean of the class means. Entries for classes 3 and 4 are empty
    # lists when the batch holds only two classes.
    four_classes = isinstance(stats_new[3], torch.Tensor) and isinstance(stats_old[3], torch.Tensor)
    num_clusters = 4 if four_classes else 2
    mt, mt_T = stats_new[13], stats_old[13]

    loss_intra_class = 0.0
    loss_inter_class = 0.0
    for sample_c, sample_c_T in zip(stats_new[5:5 + num_clusters], stats_old[5:5 + num_clusters]):
        loss_intra_class += torch.mean(torch.norm(sample_c - sample_c_T, p=2, dim=1))
        loss_inter_class += torch.mean(
            torch.norm((sample_c - mt) - (sample_c_T - mt_T), p=2, dim=1))
    loss_intra_class = loss_intra_class / num_clusters
    loss_inter_class = loss_inter_class / num_clusters

    mse_loss = F.mse_loss(embed_feat_synth, embed_feat_old)

    return args.lambda_scatter * (loss_inter_class + loss_intra_class) + (
            args.lambda_mse * mse_loss)


def train_fun(args, train_loader, current_task, old_labels, current_labels):
    """Train the embedding network on one task and checkpoint it.

    For task 0 the backbone is either a fresh ResNet-18 (ImageNet-pretrained
    only when ``args.data == 'cifar100'``) or, for method ``'MLRPTM'``, the
    checkpoint written for task 0 with epoch ``args.epochs``. For later tasks
    the checkpoint of the previous task is loaded. With ``'MLRPTM'`` and
    ``current_task > 0`` synthetic images of the previous task are generated
    first and an additional penalty ties the new model to a frozen copy of the
    old one.

    Side effects: redirects ``sys.stdout`` to a logger writing
    ``checkpoints/<args.log_dir>/log.txt``, saves checkpoints every
    ``args.save_step`` epochs and reads/writes ``means_<args.data>.csv`` with
    the per-class mean embeddings.

    Args:
        args: parsed command-line namespace.
        train_loader: loader yielding ``(inputs, labels)`` batches of the
            current task; it must support ``len()``.
        current_task: index of the task being trained (0 is the base task).
        old_labels: indexable by task; ``old_labels[current_task - 1]`` gives
            the labels used as targets for the synthetic data.
        current_labels: labels whose mean embeddings are stored in the CSV.

    Raises:
        ValueError: if synthetic data must be generated and
            ``args.BatchSize / args.num_instances`` is neither 2 nor 4.
    """
    log_dir = os.path.join('checkpoints', args.log_dir)
    mkdir_if_missing(log_dir)
    sys.stdout = logging.Logger(os.path.join(log_dir, 'log.txt'))
    num_classes_in_minibatch = int(args.BatchSize / args.num_instances)

    is_mlrptm = args.method == 'MLRPTM'
    uses_noisy_embeddings = args.method in ('MLRPTM', 'NoisyFine_tuning')
    replays_old_task = is_mlrptm and current_task > 0

    model_old = None
    model_gen = None
    if current_task == 0:
        if not is_mlrptm:
            model = _resnet18_backbone(pretrained=(args.data == 'cifar100'))
        else:
            model = _resnet18_backbone()
            state = torch.load(_checkpoint_path(log_dir, args.method, current_task, args.epochs))
            model.load_state_dict(state['state_dict'])
    else:
        model = _resnet18_backbone()
        state = torch.load(_checkpoint_path(log_dir, args.method, current_task - 1, args.epochs))
        model.load_state_dict(state['state_dict'])
        if args.method not in ('Fine_tuning', 'NoisyFine_tuning'):
            # Frozen reference copy of the previous model.
            model_old = deepcopy(model)
            model_old.eval()
            model_old = freeze_model(model_old)
            model_old = model_old.cuda()
            # Separate copy used to synthesise previous-task images.
            model_gen = deepcopy(model)
            model_gen.eval()

    model = model.cuda()
    model.eval()

    criterion = losses.create(args.loss_m, margin=args.margin, num_instances=args.num_instances).cuda()
    criterion_task = losses.create(args.loss_confusion).cuda()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # NOTE: the scheduler is stepped once per batch in the loop below.
    scheduler = StepLR(optimizer, step_size=200, gamma=0.1)

    selected_means = []
    train_loader_synth = None
    if replays_old_task:
        filename_csv = f'means_{args.data}.csv'

        # Take (up to) ten random stored class means and remove them from the
        # CSV file.
        means_df = pd.read_csv(filename_csv)
        random_rows = np.random.choice(len(means_df), size=min(10, len(means_df)), replace=False)
        selected_means = means_df.iloc[random_rows].values
        means_df = means_df.drop(random_rows)
        means_df.to_csv(filename_csv, index=False)
    else:
        # Placeholder loader so the training loop can always zip two loaders;
        # its batches are never used in a loss on this code path.
        inputs_dummy = torch.randn((len(train_loader) * args.batch_size_gen, 3, 8, 8),
                                   device='cuda', dtype=torch.float)
        targets_dummy = torch.LongTensor([0] * len(train_loader) * args.batch_size_gen).to('cuda')
        trainset_dummy = torch.utils.data.TensorDataset(inputs_dummy, targets_dummy)
        train_loader_synth = torch.utils.data.DataLoader(trainset_dummy, batch_size=args.batch_size_gen,
                                                         drop_last=True, num_workers=args.nThreads)

    num_noisy_rows = int((3 * args.BatchSize) / 4)

    for epoch in range(args.start, args.epochs + 1):

        running_loss = 0.0

        # Generate the synthetic data once, on the first epoch that is run, so
        # that resuming with -start > 0 does not hit an undefined loader.
        if epoch == args.start and replays_old_task:
            print(50 * '#')
            print("Synthetic data generating...")
            train_loader_synth = _generate_synthetic_loader(
                args, model_gen, criterion, old_labels[current_task - 1],
                len(train_loader), num_classes_in_minibatch)

        for jj, ((inputs, labels), (inputs_synth, _)) in enumerate(zip(train_loader, train_loader_synth), 0):

            generated_negative_samples = list(selected_means) if replays_old_task else []

            inputs_synth = inputs_synth.cuda()
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            embed_feat = torch.squeeze(model(inputs))
            embed_feat_normalize = F.normalize(embed_feat.clone(), p=2, dim=1)

            if uses_noisy_embeddings:
                # Add Gaussian noise to three quarters of the rows, chosen at
                # random without replacement.
                noisy_rows = np.random.permutation(args.BatchSize)[:num_noisy_rows]
                row_idx = torch.as_tensor(noisy_rows, dtype=torch.long,
                                          device=embed_feat_normalize.device)
                noise = args.Noise_Power * torch.randn(len(noisy_rows), args.dim, device='cuda')
                embed_feat_noisy = embed_feat_normalize.clone()
                embed_feat_noisy[row_idx] = embed_feat_noisy[row_idx] + noise

            # Zero-valued default so every method/task combination has a
            # defined auxiliary loss (previously missing for
            # 'NoisyFine_tuning' on incremental tasks).
            loss_aug = 0 * torch.sum(embed_feat)

            if replays_old_task:
                loss_task = 0.0
                if args.data != 'cifar10':
                    loss_task = criterion_task(embed_feat_normalize, labels,
                                               generated_negative_samples, args.method)

                loss_aug = loss_aug + (_retention_loss(args, model, model_old, inputs_synth)
                                       + (args.lambda_task * loss_task))

            if uses_noisy_embeddings:
                loss, inter_, dist_ap, dist_an = criterion(embed_feat_noisy, labels)
            else:
                loss, inter_, dist_ap, dist_an = criterion(embed_feat_normalize, labels)

            loss = loss + loss_aug

            loss.backward()
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()
            if epoch == 0 and jj == 0:
                print(50 * '#')
                print('Training...')

        print('[Epoch %05d]\t Total Loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))

        if epoch % args.save_step == 0:
            state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, _checkpoint_path(log_dir, args.method, current_task, epoch))

        if epoch == args.epochs:
            train_embeddings_cl, train_labels_cl = extract_features(
                model, train_loader)

            # Mean embedding of every class of the current task.
            mean_data_csv = []
            for label in current_labels:
                ind_cl = np.where(label == train_labels_cl)[0]
                embeddings_tmp = train_embeddings_cl[ind_cl]
                mean_data_csv.append(np.mean(embeddings_tmp, axis=0).flatten())

            columns = ['Feature {}'.format(i + 1) for i in range(len(mean_data_csv[0]))]
            if current_task == 0:
                # The base task starts a new CSV file.
                filename_csv = f'means_{args.data}.csv'
                pd.DataFrame(mean_data_csv, columns=columns).to_csv(filename_csv, index=False)
            elif replays_old_task:
                # Append the new class means below the remaining stored rows.
                dfm = pd.DataFrame(mean_data_csv, columns=columns)
                means_df = pd.concat([means_df, dfm], ignore_index=True)
                means_df.to_csv(filename_csv, index=False)




# Per-dataset settings: (use RandomResizedCrop(224), normalisation mean,
# normalisation std, number of classes).
_DATASET_SPECS = {
    'cifar100': (True, (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), 100),
    'cifar10': (True, (0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768), 10),
    'mini-imagenet-100': (False, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 100),
    'tiny-imagenet-200': (True, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), 200),
}


def _build_arg_parser():
    """Return the command-line parser of the training script."""
    parser = argparse.ArgumentParser(description='MLRPTM Training')

    # hyper-parameters
    parser.add_argument('-lr', type=float, default=1e-6, help="learning rate")
    parser.add_argument('-lambda_scatter', type=float, default=0.4, help="Scatter loss Coefficient")
    parser.add_argument('-lambda_mse', type=float, default=0.2, help="MSE loss Coefficient")
    parser.add_argument('-lambda_task', type=float, default=0.2, help="inter-task confusion loss Coefficient")
    parser.add_argument('-margin', type=float, default=0.0, help="margin for metric loss")
    parser.add_argument('-BatchSize', '-b', default=64, type=int, metavar='N', help='mini-batch size Default: 64')
    parser.add_argument('-num_instances', default=16, type=int, metavar='n',
                        help=' number of samples from one class in mini-batch')
    parser.add_argument('-Noise_Power', type=float, default=0.01, help="Variance of noise")
    parser.add_argument('-dim', default=512, type=int, metavar='n', help='dimension of embedding space')

    # generator hyper-parameters
    parser.add_argument('--jitter', default=30, type=int, help='jittering factor')
    parser.add_argument('--bn_reg_scale', type=float, default=0.05,
                        help='coefficient for feature distribution regularization')
    parser.add_argument('--first_bn_mul', type=float, default=10.0,
                        help='additional multiplier on first bn layer of R_feature')
    parser.add_argument('--tv_l1', type=float, default=0.0, help='coefficient for total variation L1 loss')
    parser.add_argument('--tv_l2', type=float, default=0.001, help='coefficient for total variation L2 loss')
    parser.add_argument('--lr_gen', type=float, default=0.1, help='learning rate for optimization')
    parser.add_argument('--l2', type=float, default=0.00001, help='l2 loss on the image')
    parser.add_argument('--main_mul', type=float, default=1.0, help='coefficient for the main loss in optimization')
    parser.add_argument('--resolution', type=int, default=224, help='resolution of image')
    parser.add_argument('--batch_size_gen', type=int, default=16, metavar='N', help='generator batch size')
    parser.add_argument('--epoch_gen', type=int, default=80, help='epochs for generating synthetic images')
    parser.add_argument('-num_instances_gen', default=4, type=int, metavar='n',
                        help=' number of samples from one class in generated mini-batch')

    # data & network
    parser.add_argument('-data', default='tiny-imagenet-200', help='path to Data Set')
    parser.add_argument('-loss_m', default='triplet', help='loss for training network')
    parser.add_argument('-loss_confusion', default='NPairLoss', help='task-confusion regulizer')
    parser.add_argument('-epochs', default=50, type=int, metavar='N', help='epochs for training process')
    parser.add_argument('-seed', default=1993, type=int, metavar='N', help='seeds for training process')
    parser.add_argument('-save_step', default=50, type=int, metavar='N', help='number of epochs to save model')
    parser.add_argument('-lr_step', default=200, type=int, metavar='N', help='scheduler step')
    parser.add_argument('-start', default=0, type=int, help='resume epoch')

    # basic parameter
    parser.add_argument('-log_dir', default='Tiny-imagenet-200', help='path that the trained models save')
    parser.add_argument('--nThreads', '-j', default=0, type=int, metavar='N',
                        help='number of data loading threads (default: 0)')
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=2e-4)
    parser.add_argument("-gpu", type=str, default='0', help='which gpu to choose')
    parser.add_argument("-method", type=str, default='MLRPTM',
                        help='MLRPTM (our method) or Fine_tuning')
    parser.add_argument('-task', default=20, type=int, help='number of tasks')
    parser.add_argument('-base', default=100, type=int, help='number of classes in non_incremental_state')

    # Non-incremental train parameters
    parser.add_argument('--momentum_0', type=float, default=0.9)
    parser.add_argument('--weight-decay_0', type=float, default=5e-4)
    parser.add_argument('-lr_0', type=float, default=0.001,
                        help="learning rate of non_incremental_state")
    parser.add_argument('-BatchSize_0', default=256, type=int, metavar='N',
                        help='mini-batch size Default: 256')
    parser.add_argument('-epochs_0', default=100, type=int, metavar='N', help='epochs for non_incremental_state'
                                                                              'training process')
    return parser


def _dataset_setup(name):
    """Return ``(transform_train, traindir, num_classes)`` for a supported dataset name."""
    use_crop, mean, std, n_classes = _DATASET_SPECS[name]
    steps = [transforms.RandomResizedCrop(224)] if use_crop else []
    steps += [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(steps), os.path.join('DataSet' + '/' + name, 'train'), n_classes


if __name__ == '__main__':

    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject unknown datasets up front instead of failing later on an
    # undefined variable.
    if args.data not in _DATASET_SPECS:
        parser.error("unsupported -data value '{}'; expected one of: {}".format(
            args.data, ', '.join(sorted(_DATASET_SPECS))))

    transform_train, traindir, num_classes = _dataset_setup(args.data)
    label_map = list(range(0, num_classes))

    num_task = args.task
    # With a single task there are no incremental classes to distribute.
    num_class_per_task = int((num_classes - args.base) / (num_task - 1)) if num_task > 1 else 0

    np.random.seed(args.seed)
    random_perm = np.random.permutation(num_classes)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    class_indexes = []
    for i in range(num_task):
        # Task 0 takes the first `base` classes of the permutation; each
        # later task takes the next `num_class_per_task` classes.
        if i == 0:
            class_index = random_perm[:args.base]
        else:
            first = args.base + (i - 1) * num_class_per_task
            class_index = random_perm[first:first + num_class_per_task]

        trainfolder = ImageFolder(
            traindir, transform_train, index=class_index)
        train_loader = torch.utils.data.DataLoader(
            trainfolder, batch_size=args.BatchSize,
            sampler=RandomIdentitySampler(
                trainfolder, batch_size=args.BatchSize, num_instances=args.num_instances),
            drop_last=True, num_workers=args.nThreads)
        class_indexes.append(class_index)

        if i == 0 and args.method == 'MLRPTM':
            # MLRPTM first trains a classifier on the base classes with a
            # plain shuffled loader, then runs the regular task-0 training.
            num_class = len(class_index)

            trainfolder = ImageFolder(
                traindir, transform_train, index=class_index)
            dataset_sizes_train = len(trainfolder)

            train_loader_0 = torch.utils.data.DataLoader(trainfolder, batch_size=args.BatchSize_0, shuffle=True,
                                                         drop_last=True, num_workers=args.nThreads)

            dictlist = dict(zip(class_index, label_map))
            print(50 * '#')
            print("Start of Non-incremental Task")
            print("base classes number = {}".format(num_class))
            print("Training-set size = " + str(len(train_loader_0.dataset)))
            initial_train_fun(args, train_loader_0, dataset_sizes_train, num_class, dictlist)
            train_fun(args, train_loader, i, class_indexes, class_index)

        else:
            print(50 * '#')
            if i == 0:
                print("Start of Non-incremental Task")
            else:
                print("Start of task: {}".format(i))
            print("new classes number = {}".format(num_class_per_task))
            print("Training-set size = " + str(len(train_loader.dataset)))
            train_fun(args, train_loader, i, class_indexes, class_index)
