#
# Copyright 2022- IBM Inc. All rights reserved
# SPDX-License-Identifier: Apache2.0
#

# ==================================================================================================
# IMPORTS
# ==================================================================================================
import csv
import datetime
import time
from copy import copy
from operator import itemgetter
import os
import shutil

import torch
import torch as t
import torch.nn as nn

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

import numpy as np
from dotmap import DotMap
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import progressbar
import tqdm
from .util import Logger
from .model import *
# from lib.util import csv2dict, loadmat
from .torch_blocks import *
from .confusion_support import plot_confusion_support, avg_sim_confusion
import os.path
import pdb
from torchvision import transforms
from .PCA import *
from .RandMix import RandMix
from .sinkhorn_distance import SinkhornDistance, SinkhornDistance_given_cost, SinkhornDistance_another_one_to_multi
from sklearn.linear_model import LogisticRegression

from .dataloader.FSCIL.data_utils import *

import torch.nn.functional as F


class Cutout(object):
    """Randomly zero out square patches of an image tensor.

    Args:
        n_holes: number of patches cut out on every call.
        length: side length (in pixels) of each square patch; patches are
            clipped at the image border.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask containing ``n_holes`` zeroed patches.

        The spatial extent is read from ``img.size(1)`` / ``img.size(2)`` and the
        2-D mask is broadcast over ``img`` with ``expand_as`` (presumably a
        (channels, height, width) tensor — TODO confirm with the transform pipeline).
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2
        keep = np.ones((height, width), np.float32)

        for _ in range(self.n_holes):
            # Patch centre: row is drawn first, then column.
            cy = np.random.randint(height)
            cx = np.random.randint(width)
            top, bottom = np.clip(cy - half, 0, height), np.clip(cy + half, 0, height)
            left, right = np.clip(cx - half, 0, width), np.clip(cx + half, 0, width)
            keep[top:bottom, left:right] = 0.

        return img * torch.from_numpy(keep).expand_as(img)



def mixup_data(x, y, device_idx, alpha=1.0, use_cuda=True):
    '''Mix a batch with a shuffled copy of itself (mixup).

    Args:
        x: input batch; the first dimension is the batch dimension (any rank).
        y: targets aligned with ``x`` along the first dimension.
        device_idx: kept for backward compatibility. The permutation is created
            on ``x.device`` so it always lives on the same device as the tensor
            it indexes.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing
            (``lam == 1``).
        use_cuda: kept for backward compatibility (see ``device_idx``).

    Returns:
        (mixed_x, y_a, y_b, lam) with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y`` and ``y_b = y[perm]``.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    # Build the permutation on x's own device: the former hard-coded
    # ``.cuda(device_idx)`` failed for CPU inputs (cross-device indexing).
    index = torch.randperm(batch_size, device=x.device)

    # ``x[index]`` (rather than ``x[index, :]``) also works for 1-D inputs.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend a loss over the two targets produced by mixup/CutMix.

    ``criterion(pred, y_a)`` is evaluated before ``criterion(pred, y_b)``; the
    result is ``lam * loss_a + (1 - lam) * loss_b``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def rand_bbox(size, lam):
    """Sample a random CutMix box whose unclipped area is ``(1 - lam)`` of the image.

    Args:
        size: shape sequence; ``size[2]`` and ``size[3]`` are the extents of the
            two spatial axes that the returned x/y corners index.
        lam: mixing ratio in [0, 1]; each box side is ``sqrt(1 - lam)`` times the
            corresponding image side.

    Returns:
        (bbx1, bby1, bbx2, bby2): box corners clipped to ``[0, size[2]]`` and
        ``[0, size[3]]``.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # ``np.int`` was only a deprecated alias of the builtin ``int`` and was
    # removed in NumPy 1.24 (AttributeError); the builtin behaves identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre sampled uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def cutmix_data(input_x, input_y, device_idx, beta=1.0):
    """CutMix: paste a random box taken from a shuffled copy of the batch.

    Args:
        input_x: batch tensor sliced as ``input_x[:, :, i, j]``; it is NOT
            modified (the paste happens on a copy).
        input_y: targets aligned with ``input_x`` along the first dimension.
        device_idx: kept for backward compatibility; the permutation is created
            on ``input_x.device``.
        beta: Beta(beta, beta) concentration for the initial mixing ratio.

    Returns:
        (mixed_x, target_a, target_b, lam) where ``target_a = input_y``,
        ``target_b = input_y[perm]`` and ``lam`` is the fraction of pixels that
        still come from the original sample.
    """
    lam = np.random.beta(beta, beta)
    rand_index = torch.randperm(input_x.size(0), device=input_x.device)

    target_a = input_y
    target_b = input_y[rand_index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(input_x.size(), lam)
    # Paste into a copy: callers concatenate the result with the original batch
    # (``torch.cat([inputs_aug, data])``) and treat the second half as clean
    # images. Pasting in place made both halves identical.
    mixed_x = input_x.clone()
    mixed_x[:, :, bbx1:bbx2, bby1:bby2] = input_x[rand_index, :, bbx1:bbx2, bby1:bby2]
    # Adjust lambda to exactly match the pasted pixel ratio.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input_x.size()[-1] * input_x.size()[-2]))

    return mixed_x, target_a, target_b, lam




# def pretrain_baseFSCIL(verbose, **parameters):
#     '''
#     Pre-training on base session
#     '''
#     args = DotMap(parameters)
#     # args.gpu = 4
#     writer = SummaryWriter(args.log_dir)
#
#     # Initialize the dataset generator and the model
#     args = set_up_datasets(args)
#     trainset, train_loader, val_loader = get_base_dataloader(args)
#
#     model = KeyValueNetwork(args)
#
#     model.mode = 'pretrain'
#     # Store all parameters in a variable
#     parameters_list, parameters_table = process_dictionary(parameters)
#     logs_dir = os.path.join(args.log_dir + '/' + 'train_log.txt')
#     print(model)
#     # Print all parameters
#     if verbose:
#         print("Parameters:")
#         for key, value in parameters_list:
#             print("\t{}".format(key).ljust(40) + "{}".format(value))
#             with open(logs_dir, 'a', encoding='utf-8') as f1:
#                 f1.write("\t{}".format(key).ljust(40) + "{}".format(value)+'\n')
#
#
#     criterion = nn.CrossEntropyLoss()
#
#     if args.gpu is not None:
#         t.cuda.set_device(args.gpu)
#         model = model.cuda(args.gpu)
#         criterion = criterion.cuda(args.gpu)
#
#     # for param in model.embedding.parameters():
#     #     param.requires_grad = False
#     # model.classifier.requires_grad=True
#     #
#     # optimizer = t.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
#     #                         lr=args.learning_rate,nesterov=args.SGDnesterov,
#     #                         weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
#     optimizer = t.optim.SGD(model.parameters(), lr=args.learning_rate, nesterov=args.SGDnesterov,
#                             weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
#     scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)
#
#     start_train_iter = 0
#     best_acc1 = 0
#
#
#     for epoch in tqdm.tqdm(range(1, args.max_train_iter), desc='Epoch'):
#         global_count = 0
#
#         losses = AverageMeter('Loss')
#         acc = AverageMeter('Acc@1')
#
#         model.train(True)
#
#         for i, batch in enumerate(train_loader):
#             global_count = global_count + 1
#             data, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
#             # data, data_aug1, data_aug2, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
#             # forward pass
#             optimizer.zero_grad()
#             aug_p = torch.rand(1)
#
#             output = model(data)
#             loss_cls = criterion(output, train_label)
#             proxy = model.classifier
#             features = model.fea_rep
#             loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)
#
#             loss = loss_cls + args.pcl_weight *  loss_pcl
#
#             # Backpropagation
#             loss.backward()
#             optimizer.step()
#
#             losses.update(loss.item(), data.size(0))
#
#
#         scheduler.step()
#
#         # write to tensorboard
#         writer.add_scalar('training_loss/pretrain_CEL', losses.avg, epoch)
#         writer.add_scalar('accuracy/pretrain_train', acc.avg, epoch)
#
#         val_loss, val_acc_mean, _ = validation(model, criterion, val_loader, args)
#         writer.add_scalar('validation_loss/pretrain_CEL', val_loss, epoch)
#         writer.add_scalar('accuracy/pretrain_val', val_acc_mean, epoch)
#
#         is_best = val_acc_mean > best_acc1
#         best_acc1 = max(val_acc_mean, best_acc1)
#
#         with open(logs_dir, 'a', encoding='utf-8') as f1:
#             f1.write(f'epoch: {epoch} current mean ac: {val_acc_mean} best acc: {best_acc1:0.5f}\n')
#
#         print('epoch:', epoch, 'current mean acc', val_acc_mean, 'best acc:', best_acc1)
#         save_checkpoint({
#             'train_iter': epoch + 1,
#             'arch': args.block_architecture,
#             'state_dict': model.state_dict(),
#             'best_acc1': best_acc1,
#             'optimizer': optimizer.state_dict(),
#         }, is_best, savedir=args.log_dir)
#
#     writer.close()


def pretrain_baseFSCIL(verbose, **parameters):
    '''
    Pre-training on the base session.

    Builds the base-session loaders and a ``KeyValueNetwork``, then trains with
    cross-entropy plus ``args.pcl_weight`` times a ``PCLoss`` term. While
    ``epoch < 20`` only clean images are used. Afterwards each batch draws
    ``aug_p ~ U[0, 1)`` and uses:

    - mixup when ``aug_p < 0.33``;
    - CutMix when ``aug_p < 0.66``;
    - otherwise an AugMix-style consistency loss on the two augmented views.

    Validation runs after every epoch. Progress is appended to
    ``<log_dir>/train_log.txt``, and a checkpoint is written with
    ``save_checkpoint``.

    Args:
        verbose: if true, print all parameters and append them to the log file.
        **parameters: experiment configuration, wrapped in a ``DotMap``.
    '''
    args = DotMap(parameters)
    writer = SummaryWriter(args.log_dir)

    # Initialize the dataset generator and the model
    args = set_up_datasets(args)
    _, train_loader, val_loader = get_base_dataloader(args)

    model = KeyValueNetwork(args)
    model.mode = 'pretrain'

    # Store all parameters in a variable
    parameters_list, _ = process_dictionary(parameters)
    logs_dir = os.path.join(args.log_dir, 'train_log.txt')
    print(model)
    # Print all parameters and append them to the training log
    if verbose:
        print("Parameters:")
        with open(logs_dir, 'a', encoding='utf-8') as f1:
            for key, value in parameters_list:
                line = "\t{}".format(key).ljust(40) + "{}".format(value)
                print(line)
                f1.write(line + '\n')

    criterion = nn.CrossEntropyLoss()

    if args.gpu is not None:
        t.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    optimizer = t.optim.SGD(model.parameters(), lr=args.learning_rate, nesterov=args.SGDnesterov,
                            weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
    scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)

    best_acc1 = 0

    for epoch in tqdm.tqdm(range(1, args.max_train_iter), desc='Epoch'):
        losses = AverageMeter('Loss')
        # NOTE(review): ``acc`` is never updated in this loop, so the
        # 'accuracy/pretrain_train' scalar only records its initial value.
        acc = AverageMeter('Acc@1')

        model.train(True)

        for batch in train_loader:
            data, data_aug1, data_aug2, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
            batch_size = data.size(0)
            optimizer.zero_grad()
            # Selects the augmentation branch once the warm-up is over.
            aug_p = torch.rand(1)

            if epoch < 20:
                # Warm-up: clean images only.
                output = model(data)
                loss_cls = criterion(output, train_label)
                features = model.fea_rep
            elif aug_p < 0.66:
                if aug_p < 0.33:
                    inputs_aug, targets_a_aug, targets_b_aug, lam = mixup_data(data, train_label, device_idx=args.gpu)
                else:
                    # Hand CutMix a copy so ``data`` stays clean even if
                    # cutmix_data pastes into its argument in place.
                    inputs_aug, targets_a_aug, targets_b_aug, lam = cutmix_data(data.clone(), train_label, device_idx=args.gpu)

                # Mixed images in the first half, clean images in the second.
                all_y = model(torch.cat([inputs_aug, data]))
                mixed_logits, clean_logits = all_y[:batch_size], all_y[batch_size:]
                loss_cls = (lam * criterion(mixed_logits, targets_a_aug)
                            + (1 - lam) * criterion(mixed_logits, targets_b_aug)
                            + criterion(clean_logits, train_label))
                # PCL is computed on the features of the clean half only.
                features = model.fea_rep[batch_size:]
            else:
                all_y = model(torch.cat([data, data_aug1, data_aug2]))
                logits_clean, logits_aug1, logits_aug2 = torch.split(all_y, batch_size)

                # Cross-entropy is only computed on clean images
                loss_ce = criterion(logits_clean, train_label)

                p_clean = F.softmax(logits_clean, dim=1)
                p_aug1 = F.softmax(logits_aug1, dim=1)
                p_aug2 = F.softmax(logits_aug2, dim=1)

                # Clamp mixture distribution to avoid exploding KL divergence
                p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
                loss_kl = 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                                F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                                F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

                loss_cls = loss_ce + loss_kl
                features = model.fea_rep[:batch_size]

            proxy = model.classifier
            loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)
            loss = loss_cls + args.pcl_weight * loss_pcl

            # Backpropagation
            loss.backward()
            optimizer.step()

            losses.update(loss.item(), batch_size)

        scheduler.step()

        # write to tensorboard
        writer.add_scalar('training_loss/pretrain_CEL', losses.avg, epoch)
        writer.add_scalar('accuracy/pretrain_train', acc.avg, epoch)

        val_loss, val_acc_mean, _ = validation(model, criterion, val_loader, args)
        writer.add_scalar('validation_loss/pretrain_CEL', val_loss, epoch)
        writer.add_scalar('accuracy/pretrain_val', val_acc_mean, epoch)

        is_best = val_acc_mean > best_acc1
        best_acc1 = max(val_acc_mean, best_acc1)

        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'epoch: {epoch} current mean ac: {val_acc_mean} best acc: {best_acc1:0.5f}\n')

        print('epoch:', epoch, 'current mean acc', val_acc_mean, 'best acc:', best_acc1)
        save_checkpoint({
            'train_iter': epoch + 1,
            'arch': args.block_architecture,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best, savedir=args.log_dir)

    writer.close()




def metatrain_baseFSCIL(verbose, **parameters):
    '''
    Meta-training on the base session.

    Loads model/optimizer/scheduler state with ``load_checkpoint`` and trains on
    the loader returned by ``get_base_dataloader2``. Each batch draws
    ``aug_p ~ U[0, 1)`` and uses:

    - mixup when ``aug_p < args.mixup_decision``;
    - CutMix when ``aug_p < args.cutmix_decision``;
    - clean images when ``aug_p < 0.80``;
    - otherwise an AugMix-style consistency loss.

    ``args.pcl_weight`` times a ``PCLoss`` term is always added. Whenever
    validation accuracy reaches a new best, ``model.protoSave`` is run on the
    clean training loader. Its prototype/cov/classlabel outputs are stored in
    that epoch's checkpoint; other epochs store ``None`` for them.

    Args:
        verbose: if true, print all parameters and append them to the log file.
        **parameters: experiment configuration, wrapped in a ``DotMap``.
    '''
    # Argument Preparation
    args = DotMap(parameters)
    writer = SummaryWriter(args.log_dir)

    # Initialize the dataset generator and the model
    args = set_up_datasets(args)
    _, train_loader, val_loader = get_base_dataloader2(args)
    model = KeyValueNetwork(args)
    model.mode = 'pretrain'

    # Store all parameters in a variable
    parameters_list, _ = process_dictionary(parameters)

    logs_dir = os.path.join(args.log_dir, 'metatrain_log.txt')
    # Print all parameters and append them to the meta-training log
    if verbose:
        print("Parameters:")
        with open(logs_dir, 'a', encoding='utf-8') as f1:
            for key, value in parameters_list:
                line = "\t{}".format(key).ljust(40) + "{}".format(value)
                print(line)
                f1.write(line + '\n')

    criterion = nn.CrossEntropyLoss()

    if args.gpu is not None:
        t.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    optimizer = t.optim.SGD(model.parameters(),
                            lr=args.learning_rate, nesterov=args.SGDnesterov,
                            weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
    scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)

    model, optimizer, scheduler, _, best_acc1 = load_checkpoint(model, optimizer, scheduler, args)
    # The best accuracy returned by load_checkpoint is discarded; "best" is
    # tracked from zero in this run.
    best_acc1 = 0

    # Clean (non-augmented) training loader used for prototype extraction. It is
    # built the first time a new best is reached and reused afterwards.
    # NOTE(review): assumes get_clean_dataloader(args) yields an equivalent
    # loader on every call — confirm.
    clean_loader = None

    for i in tqdm.tqdm(range(1, args.max_train_iter), desc='Epoch'):
        losses = AverageMeter('Loss')

        model.train(True)

        for batch in train_loader:
            data, data_aug1, data_aug2, train_label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]
            batch_size = data.size(0)
            optimizer.zero_grad()
            aug_p = torch.rand(1)

            # Mixed-sample augmentations (mixup / CutMix) share one loss below.
            mixed = None
            if aug_p < args.mixup_decision:
                mixed = mixup_data(data, train_label, device_idx=args.gpu)
            elif aug_p < args.cutmix_decision:
                # Hand CutMix a copy so ``data`` stays clean even if
                # cutmix_data pastes into its argument in place.
                mixed = cutmix_data(data.clone(), train_label, device_idx=args.gpu)

            if mixed is not None:
                inputs_aug, targets_a_aug, targets_b_aug, lam = mixed
                # Mixed images in the first half, clean images in the second.
                all_y = model(torch.cat([inputs_aug, data]))
                mixed_logits, clean_logits = all_y[:batch_size], all_y[batch_size:]
                loss_cls = (lam * criterion(mixed_logits, targets_a_aug)
                            + (1 - lam) * criterion(mixed_logits, targets_b_aug)
                            + criterion(clean_logits, train_label))
                # PCL is computed on the features of the clean half only.
                features = model.fea_rep[batch_size:]
            elif aug_p < 0.80:
                # Plain clean-image step.
                output = model(data)
                loss_cls = criterion(output, train_label)
                features = model.fea_rep
            else:
                all_y = model(torch.cat([data, data_aug1, data_aug2]))
                logits_clean, logits_aug1, logits_aug2 = torch.split(all_y, batch_size)

                # Cross-entropy is only computed on clean images
                loss_ce = criterion(logits_clean, train_label)

                p_clean = F.softmax(logits_clean, dim=1)
                p_aug1 = F.softmax(logits_aug1, dim=1)
                p_aug2 = F.softmax(logits_aug2, dim=1)

                # Clamp mixture distribution to avoid exploding KL divergence
                p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
                loss_kl = 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                                F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                                F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

                loss_cls = loss_ce + loss_kl
                features = model.fea_rep[:batch_size]

            proxy = model.classifier
            loss_pcl = PCLoss(num_classes=args.base_class, scale=12)(features, train_label, proxy)
            loss = loss_cls + args.pcl_weight * loss_pcl

            # Backpropagation
            loss.backward()
            optimizer.step()

            losses.update(loss.item(), batch_size)

        scheduler.step()

        val_loss, val_acc_mean, _ = validation(model, criterion, val_loader, args)
        writer.add_scalar('validation_loss/log_loss', val_loss, i)
        writer.add_scalar('accuracy/validation', val_acc_mean, i)

        is_best = val_acc_mean > best_acc1
        best_acc1 = max(val_acc_mean, best_acc1)

        if is_best:
            model.eval()
            if clean_loader is None:
                _, clean_loader, _ = get_clean_dataloader(args)
            prototype, cov, classlabel = model.protoSave(clean_loader)
        else:
            prototype, cov, classlabel = None, None, None

        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'epoch: {i} current mean ac: {val_acc_mean} best acc: {best_acc1:0.5f} loss: {losses.avg}\n')
        print('epoch:', i, 'current mean acc', val_acc_mean, 'best acc:', best_acc1, 'loss', losses.avg)

        save_checkpoint({
            'train_iter': i + 1,
            'arch': args.block_architecture,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'prototype': prototype,
            'cov': cov,
            'classlabel': classlabel,
        }, is_best, savedir=args.log_dir)

    writer.close()







def train_FSCIL(verbose=False, **parameters):
    '''
    Main FSCIL evaluation on all sessions.

    Restores weights, optimizer/scheduler state and stored prototypes/covariances
    with ``load_checkpoint2``. Every parameter except ``model.embedding.fc`` is
    then frozen. For each session, the model is adapted on the first batch of
    that session's training loader:

    - ``original_align`` when ``args.retrain_iter == 0``;
    - ``proto_align_v5`` otherwise. The prototypes and covariances it returns
      are carried into the next session, and its per-session accuracy is
      collected.

    When accuracies were collected, their mean is appended to
    ``<log_dir>/test_log.txt`` and printed.

    Args:
        verbose: print all parameters when true.
        **parameters: experiment configuration, wrapped in a ``DotMap``.
    '''
    args = DotMap(parameters)
    args = set_up_datasets(args)

    model = KeyValueNetwork(args, mode="pretrain")

    # Store all parameters in a variable
    parameters_list, _ = process_dictionary(parameters)

    # Print all parameters
    if verbose:
        print("Parameters:")
        for key, value in parameters_list:
            print("\t{}".format(key).ljust(40) + "{}".format(value))

    # Write parameters to file
    if not args.inference_only:
        filename = args.log_dir + '/parameters.csv'
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, 'w') as csv_file:
            csv_writer = csv.writer(csv_file)
            keys, values = zip(*parameters_list)
            csv_writer.writerow(keys)
            csv_writer.writerow(values)

    writer = SummaryWriter(args.log_dir)

    if args.gpu is not None:
        t.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # Set all parameters except the embedding's FC layer to non-trainable
    for param in model.parameters():
        param.requires_grad = False
    for param in model.embedding.fc.parameters():
        param.requires_grad = True

    # NOTE(review): presumably model.classifier is a Parameter/tensor; the line
    # above already froze it if it is registered on the model.
    model.classifier.requires_grad = False

    optimizer = t.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=args.learning_rate, nesterov=args.SGDnesterov,
                            weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
    scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)

    model, optimizer, scheduler, _, best_acc1, prototype, cov, _ = load_checkpoint2(model, optimizer, scheduler, args)

    logs_dir = os.path.join(args.log_dir, 'test_log.txt')
    all_acc = []

    for session in range(args.sessions):
        nways_session = args.base_class + session * args.way

        if session > 0:
            model.mode = 'meta'

        _, train_loader, test_loader, test_loader_new = get_dataloader(args, session)
        # Only the first batch of the session's training loader is used for adaptation.
        batch = next(iter(train_loader))

        if args.retrain_iter == 0:
            original_align(model, batch, optimizer, args, writer, session, nways_session, prototype, cov)
        else:
            # Prototypes/covariances returned here are fed into the next session.
            prototype, cov, acc_each_session = proto_align_v5(
                model, batch, optimizer, args, writer, session, nways_session,
                prototype, cov, test_loader, best_acc1, test_loader_new)
            all_acc.append(acc_each_session)

    # Per-session accuracies are only collected on the proto_align_v5 path; with
    # an empty list np.mean would log 'nan' and emit a RuntimeWarning.
    if all_acc:
        mean_acc = np.mean(all_acc)
        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'Mean Acc for this run is: {mean_acc}\t Each Session Acc is{all_acc} \n')
        print(f'Mean Acc for this run is: {mean_acc}\n Each Session Acc is{all_acc} \n')
    writer.close()


def original_align(model, data, optimizer, args, writer, session, nways_session, prototype, cov):
    '''
    Rebuild, nudge and (for incremental sessions) realign the prototypes with feature replay.

    Stages:
      1. Feed the session's data through ``model.update_feat_replay`` (eval mode, no grad).
      2. Rebuild the prototype memory from the replayed features.
      3. Nudge the first ``nways_session`` prototypes; optionally bipolarize them.
      4. If ``session > 0``, retrain ``model.embedding.fc`` for ``args.retrain_iter`` steps,
         minimising ``myCosineLoss`` between the support features and the nudged prototypes.
      5. Rebuild the prototype memory from the replayed features again.
      6. Optional HRR compression of the prototype memory (``args.em_compression == "hrr"``).

    Args:
        model: model exposing the feature-replay / prototype-memory API used below.
        data: pair ``(inputs, labels)`` for this session, wrapped in ``myRetrainDataset``.
        optimizer: optimizer stepped during the FC retraining of stage 4.
        args: run configuration (``gpu``, ``batch_size_training``, ``retrain_act``,
            ``retrain_iter``, ``bipolarize_prototypes``, ``em_compression``, ...).
        writer: TensorBoard writer; receives the retraining loss and is passed to nudging.
        session: current session index; 0 is the base session.
        nways_session: number of leading prototype slots (classes) processed this session.
        prototype, cov: unused; kept so the signature matches the ``proto_align_*`` variants.
    '''
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)

    # Stage 1: Compute feature representation of new data.
    # (A separate model.embedding(x) call used to run here and its result was discarded.)
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 3: Nudging
    model.nudge_prototypes(nways_session, writer, session, args.gpu)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Retrain the FC layer (incremental sessions only)
    model.embedding.fc.train()
    if session > 0:
        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])

            # Backpropagation
            loss.backward()
            optimizer.step()

            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: Fill up prototypes again
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)





def proto_align_v6(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Calibrate novel prototypes against the base-class prototypes, then retrain the FC.

    Stages:
      1. Feed the session's data through ``model.update_feat_replay`` (eval mode, no grad).
      2. Rebuild the prototype memory from the replayed features (optionally bipolarized).
      3. Incremental sessions only (``session > 0``):
         a. Binarize the whole prototype memory with ``torch.sign``.
         b. Run ``SinkhornDistance`` between the first ``args.base_class`` rows of the memory
            (before binarization) and the binarized novel rows; only the returned ``Pi`` is used.
         c. For each novel class, build a Gaussian with ``distribution_calibration_v2``, draw
            300 candidates and keep the one whose summed cosine similarity to the base-class
            rows is largest.
         d. Write the result back into ``model.key_mem``, which also replaces the base-class rows
            with their sign-binarized values. Then nudge the memory and retrain
            ``model.embedding.fc`` for ``args.retrain_iter`` steps with ``myCosineLoss``.
      4. Rebuild the prototype memory again; optional HRR compression.

    Args:
        model: model exposing the feature-replay / prototype-memory API used below.
        data: pair ``(inputs, labels)`` for this session, wrapped in ``myRetrainDataset``.
        optimizer: optimizer stepped during FC retraining.
        args: run configuration.
        writer: TensorBoard writer; receives the retraining loss and is passed to nudging.
        session: current session index; 0 is the base session.
        nways_session: number of classes handled this session; recomputed as
            ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: ignored; it is replaced by the binarized base-class rows of
            ``model.key_mem`` before use.
        base_cov: per-class covariance matrices; the first ``args.base_class`` entries are used.
    """
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: Compute feature representation of new data
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    model.embedding.fc.train()

    if session > 0:
        nways_session = args.base_class + session * args.way

        # Base-class rows of the memory before sign-binarization: the reference for
        # calibration (Sinkhorn) and for picking the calibrated candidate.
        base_proxy = model.key_mem.data[:model.args.base_class]
        base_torch = model.key_mem.data[:args.base_class]
        c_proto = torch.sign(model.key_mem.data)

        # Only the second output is used; column j of Pi belongs to novel class base_class + j.
        _, Pi, _ = sinkhorn(base_torch, c_proto[args.base_class:nways_session])

        c_proto = c_proto.cpu().numpy()
        base_prototype = c_proto[:args.base_class]

        for i in range(args.base_class, nways_session):
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i - args.base_class], base_prototype,
                                                    base_cov[:args.base_class], n_lsamples=args.way)
            candidates = np.random.multivariate_normal(mean=mean, cov=cov, size=300)
            candidates_gpu = torch.from_numpy(candidates).float().cuda(args.gpu)
            similar = cosine_similarity_multi(candidates_gpu, base_proxy, rep=args.representation)
            # NOTE(review): argmin of the negated sum keeps the candidate MOST similar to the
            # base-class rows; confirm this (rather than the least similar) is intended.
            best = torch.argmin(-similar.sum(dim=1))
            c_proto[i] = candidates[best.cpu().numpy()]

        # Replaces the whole memory, so the base-class rows become their sign-binarized values too.
        model.key_mem.data = torch.from_numpy(c_proto).float().cuda(args.gpu)

        # Stage 3: Nudging
        model.nudge_prototypes(nways_session, writer, session, args.gpu)

        # Stage 4: Retrain the FC layer
        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            # Backpropagation
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: Fill up prototypes again
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)




def proto_align_v5(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov, test_loader, best_acc1, test_loader_new):
    """Train a logistic-regression head on Gaussian-replayed features and evaluate it.

    Uses the saved per-class means (``base_prototype``) and covariances (``base_cov``).

    Base session (``session == 0``): the prototype memory is rebuilt, and the source model plus
    ``best_acc1`` are appended to ``<log_dir>/test_log.txt``.

    Incremental session (``session > 0``):
      1. ``SinkhornDistance`` is run between the first ``args.base_class`` saved means and this
         session's novel prototypes. ``distribution_calibration_v2`` then turns each novel
         prototype into a Gaussian ``(mean, cov)``.
      2. ``args.sample_num`` features are drawn per base class and per novel class from these
         Gaussians. They are concatenated with the embedded real samples of ``data``, taken
         from every batch.
      3. A ``MultiClassLogisticRegression`` head is trained for 1000 SGD steps with weighted
         cross-entropy: weight 0.6 for base classes and 1.0 for novel classes.
      4. The head's base-class rows are blended with ``model.classifier`` (0.4 / 0.6). The head
         is then evaluated on ``test_loader``.

    Args:
        model: model exposing the embedding / feature-replay / prototype-memory API used below.
        data: pair ``(inputs, labels)`` for this session, wrapped in ``myRetrainDataset``.
        optimizer: unused; a dedicated SGD optimizer is created for the new head.
        args: run configuration.
        writer: unused; kept for signature compatibility.
        session: current session index; 0 is the base session.
        nways_session: number of classes handled this session. It sizes the new head and is
            recomputed as ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: saved per-class means, at least ``args.base_class`` rows.
        base_cov: saved per-class covariance matrices, at least ``args.base_class`` entries.
        test_loader: loader of test batches used to evaluate the new head.
        best_acc1: accuracy logged and returned for the base session.
        test_loader_new: unused; the evaluation on novel classes only is disabled.

    Returns:
        ``(all_proto, all_cov, acc)``.
        For ``session == 0`` this is ``(base_prototype, base_cov, best_acc1)``.
        Otherwise the calibrated means / covariances of this session's novel classes are
        appended to the inputs, and ``acc`` is the head's average top-1 accuracy on
        ``test_loader``.
    """
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    classifier = MultiClassLogisticRegression(input_dim=args.dim_features, output_dim=nways_session).cuda(args.gpu)
    acc = AverageMeter('Acc@1', ':6.2f')

    # Stage 1: Compute feature representation of new data.
    # The embedded real samples of EVERY batch are kept (previously only the last batch
    # survived the loop) because they all join the head's training set below.
    real_features, real_labels = [], []
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            if session > 0:
                real_features.append(model.embedding(x))
                real_labels.append(target)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    if session == 0:
        logs_dir = os.path.join(args.log_dir + '/' + 'test_log.txt')
        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'\nSource Model:{args.resume}\n')
            f1.write(f'{best_acc1}\t{best_acc1}\n')

    if session > 0:
        nways_session = args.base_class + session * args.way
        n_novel = nways_session - args.base_class

        c_proto = model.key_mem.data
        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        # Only the second output is used; column j of Pi belongs to novel class base_class + j.
        _, Pi, _ = sinkhorn(base_torch[:args.base_class], c_proto[args.base_class:nways_session])
        c_proto = c_proto.cpu().numpy()

        # Replay features for the base classes from their saved Gaussians (class order kept).
        sample_num_old = args.sample_num
        sampled_feature_old = np.array([
            np.random.multivariate_normal(mean=base_prototype[idx], cov=base_cov[idx], size=sample_num_old)
            for idx in range(args.base_class)
        ]).reshape(args.base_class * sample_num_old, -1)
        sampled_label_old = np.repeat(np.arange(args.base_class), sample_num_old)

        # Calibrate each novel class and replay features from its Gaussian.
        sample_num = args.sample_num
        sampled_feature_new, novel_means, novel_covs = [], [], []
        for i in range(args.base_class, nways_session):
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i - args.base_class],
                                                    base_prototype[:args.base_class], base_cov[:args.base_class],
                                                    n_lsamples=args.way)
            sampled_feature_new.append(np.random.multivariate_normal(mean=mean, cov=cov, size=sample_num))
            novel_means.append(mean)
            novel_covs.append(cov)
        sampled_feature_new = np.array(sampled_feature_new).reshape(n_novel * sample_num, -1)
        sampled_label_new = np.repeat(np.arange(args.base_class, nways_session), sample_num)

        # Concatenate once instead of re-copying the (large) covariance stack per class.
        prototype_saver, cov_saver = base_prototype, base_cov
        if novel_means:
            prototype_saver = np.concatenate([base_prototype, np.stack(novel_means)])
            cov_saver = np.concatenate([base_cov, np.stack(novel_covs)])

        sampled_feature_all = np.concatenate([sampled_feature_old, sampled_feature_new], axis=0)
        sampled_label_all = np.concatenate([sampled_label_old, sampled_label_new], axis=0)
        sampled_feature_all = torch.from_numpy(sampled_feature_all).cuda(args.gpu).float()
        sampled_feature_all = torch.cat([sampled_feature_all, *real_features], dim=0)
        sampled_label_all = torch.from_numpy(sampled_label_all).cuda(args.gpu)
        sampled_label_all = torch.cat([sampled_label_all, *real_labels], dim=0)

        # Train the new head. The loss weights are loop-invariant, so they are built once.
        num_epochs = 1000
        head_optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, nesterov=args.SGDnesterov,
                                         weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
        classifier.train()
        class_weights = torch.cat([0.6 * torch.ones(args.base_class).cuda(args.gpu),
                                   torch.ones(n_novel).cuda(args.gpu)])
        weighted_ce = nn.CrossEntropyLoss(weight=class_weights)
        for _ in range(num_epochs):
            outputs = classifier(sampled_feature_all)
            loss = weighted_ce(outputs, sampled_label_all)
            head_optimizer.zero_grad()
            loss.backward()
            head_optimizer.step()

        # Blend the head's base-class rows with the model's own classifier.
        classifier.linear.data[:args.base_class] = (0.4 * classifier.linear.data[:args.base_class]
                                                    + 0.6 * model.classifier.data)

        # Evaluate the head on top of the frozen embedding.
        model.eval()
        classifier.eval()
        with t.no_grad():
            for batch in test_loader:
                images, targets = [item.cuda(args.gpu, non_blocking=True) for item in batch]
                predicts = classifier(model.embedding(images))
                accuracy = top1accuracy(predicts.argmax(dim=1), targets)
                acc.update(accuracy.item(), images.size(0))

        print("Session {:} Testing Acc: {:.2f}%".format(session, acc.avg))

    # Stage 5: prototype refill is disabled in this variant.
    model.eval()

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)

    if session == 0:
        return base_prototype, base_cov, best_acc1
    return prototype_saver, cov_saver, acc.avg


def proto_align_v4(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Calibrate novel prototypes against all previously seen prototypes, then retrain the FC.

    For an incremental session (``session > 0``):
      - The whole prototype memory is binarized with ``torch.sign``.
      - ``SinkhornDistance`` is run between the old-class rows (``[:oways_session]``, before
        binarization) and this session's binarized novel rows. This drives
        ``distribution_calibration_v2`` for each novel class.
      - 300 candidates are drawn from each calibrated Gaussian. The one with the largest summed
        cosine similarity to the old-class rows is written back.
      - The memory is then nudged and ``model.embedding.fc`` is retrained with ``myCosineLoss``.

    Args:
        model: model exposing the feature-replay / prototype-memory API used below.
        data: pair ``(inputs, labels)`` for this session, wrapped in ``myRetrainDataset``.
        optimizer: optimizer stepped during FC retraining.
        args: run configuration.
        writer: TensorBoard writer; receives the retraining loss and is passed to nudging.
        session: current session index; 0 is the base session.
        nways_session: number of classes handled this session; recomputed as
            ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: ignored; it is replaced by the binarized old-class rows of ``model.key_mem``.
        base_cov: covariance matrices of previously seen classes, passed as-is to
            ``distribution_calibration_v2``. Presumably there are ``oways_session`` entries;
            confirm at the call site.

    Returns:
        ``base_cov`` for the base session. Otherwise, ``base_cov`` with the calibrated
        covariance of each novel class of this session appended.
    """
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: Compute feature representation of new data
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Retrain the FC layer
    model.embedding.fc.train()

    all_cov = base_cov
    if session > 0:
        nways_session = args.base_class + session * args.way
        oways_session = args.base_class + (session - 1) * args.way

        # Old-class rows before binarization: Sinkhorn source and candidate-selection reference.
        old_protos = model.key_mem.data[:oways_session]
        c_proto = torch.sign(model.key_mem.data)

        # Only the second output is used; column j of Pi belongs to novel class oways_session + j.
        _, Pi, _ = sinkhorn(old_protos, c_proto[oways_session:nways_session])

        c_proto = c_proto.cpu().numpy()
        old_protos_binarized = c_proto[:oways_session]

        novel_covs = []
        for i in range(oways_session, nways_session):
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i - oways_session], old_protos_binarized,
                                                    base_cov, n_lsamples=args.way)
            candidates = np.random.multivariate_normal(mean=mean, cov=cov, size=300)
            candidates_gpu = torch.from_numpy(candidates).float().cuda(args.gpu)
            similar = cosine_similarity_multi(candidates_gpu, old_protos, rep=args.representation)
            # NOTE(review): argmin of the negated sum keeps the candidate MOST similar to the
            # old-class rows; confirm this (rather than the least similar) is intended.
            best = torch.argmin(-similar.sum(dim=1))
            c_proto[i] = candidates[best.cpu().numpy()]
            novel_covs.append(cov)

        # Concatenate once instead of re-copying the (large) covariance stack per class.
        if novel_covs:
            all_cov = np.concatenate([base_cov, np.stack(novel_covs)])

        # Replaces the whole memory, so old-class rows become their sign-binarized values too.
        model.key_mem.data = torch.from_numpy(c_proto).float().cuda(args.gpu)
        model.nudge_prototypes(nways_session, writer, session, args.gpu)

        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            # Backpropagation
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: Fill up prototypes again
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)

    return all_cov



def proto_align_v3(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Calibrate novel prototypes from the saved per-class means and covariances, then retrain the FC.

    For an incremental session (``session > 0``):
      - ``SinkhornDistance`` is run between the saved means (``base_prototype``) and this
        session's novel prototypes. This drives ``distribution_calibration_v2`` for each novel
        class.
      - 300 candidates are drawn from each calibrated Gaussian. The one with the largest summed
        cosine similarity to the first ``args.base_class`` memory rows replaces the novel
        prototype.
      - The memory is then nudged and ``model.embedding.fc`` is retrained with ``myCosineLoss``.

    Args:
        model: model exposing the feature-replay / prototype-memory API used below.
        data: pair ``(inputs, labels)`` for this session, wrapped in ``myRetrainDataset``.
        optimizer: optimizer stepped during FC retraining.
        args: run configuration.
        writer: TensorBoard writer; receives the retraining loss and is passed to nudging.
        session: current session index; 0 is the base session.
        nways_session: number of classes handled this session; recomputed as
            ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: saved per-class means (NumPy array) of previously seen classes.
        base_cov: saved per-class covariance matrices of previously seen classes.

    Returns:
        ``(all_proto, all_cov)``. For the base session these are the inputs unchanged.
        Otherwise the calibrated mean and covariance of each novel class are appended in class
        order.
    """
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: Compute feature representation of new data
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Retrain the FC layer
    model.embedding.fc.train()

    all_proto, all_cov = base_prototype, base_cov
    if session > 0:
        nways_session = args.base_class + session * args.way
        oways_session = args.base_class + (session - 1) * args.way

        # Base-class memory rows: reference for picking the calibrated candidate.
        base_proxy = model.key_mem.data[:model.args.base_class]
        c_proto = model.key_mem.data
        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)

        # Only the second output is used; column j of Pi belongs to novel class oways_session + j.
        _, Pi, _ = sinkhorn(base_torch, c_proto[oways_session:nways_session])

        c_proto = c_proto.cpu().numpy()

        novel_means, novel_covs = [], []
        for i in range(oways_session, nways_session):
            # NOTE(review): n_lsamples is way * shot here, but way elsewhere; confirm intent.
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i - oways_session], base_prototype, base_cov,
                                                    n_lsamples=args.way * args.shot)
            candidates = np.random.multivariate_normal(mean=mean, cov=cov, size=300)
            candidates_gpu = torch.from_numpy(candidates).float().cuda(args.gpu)
            similar = cosine_similarity_multi(candidates_gpu, base_proxy, rep=args.representation)
            # NOTE(review): argmin of the negated sum keeps the candidate MOST similar to the
            # base-class rows; confirm this (rather than the least similar) is intended.
            best = torch.argmin(-similar.sum(dim=1))
            c_proto[i] = candidates[best.cpu().numpy()]
            novel_means.append(mean)
            novel_covs.append(cov)

        # Concatenate once instead of re-copying the growing arrays for every class.
        if novel_means:
            all_proto = np.concatenate([base_prototype, np.stack(novel_means)])
            all_cov = np.concatenate([base_cov, np.stack(novel_covs)])

        model.key_mem.data = torch.from_numpy(c_proto).float().cuda(args.gpu)
        model.nudge_prototypes(nways_session, writer, session, args.gpu)

        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            # Backpropagation
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: Fill up prototypes again
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)

    return all_proto, all_cov





def proto_align_v2(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Calibrate novel prototypes via two Sinkhorn transports to the base means, then retrain the FC.

    For an incremental session (``session > 0``):
      - The prototype memory is passed through ``Tanh10x``.
      - ``SinkhornDistance`` is run between the saved means (``base_prototype``) and this
        session's squashed novel prototypes, giving ``Pi``.
      - Per novel class, ``args.sample_num`` candidates are drawn from the Gaussian given by
        ``distribution_calibration_v2``.
      - A second Sinkhorn between the saved means and the candidates gives ``Pi2``.
      - The new prototype is ``Pi[:, j] @ (Pi2 @ candidates)``.
      - The whole (squashed) memory is written back and nudged, and ``model.embedding.fc`` is
        retrained with ``myCosineLoss``.

    Args:
        model: model exposing the feature-replay / prototype-memory API used below.
        data: pair ``(inputs, labels)`` for this session, wrapped in ``myRetrainDataset``.
        optimizer: optimizer stepped during FC retraining.
        args: run configuration.
        writer: TensorBoard writer; receives the retraining loss and is passed to nudging.
        session: current session index; 0 is the base session.
        nways_session: number of classes handled this session; recomputed as
            ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: saved per-class means (NumPy array).
        base_cov: saved per-class covariance matrices.
    """
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
    # Separate instance for the base-means-vs-candidates transport below.
    sinkhorn_multi = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: Compute feature representation of new data.
    # (A separate model.embedding(x) call used to run here and its result was discarded.)
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Retrain the FC layer
    model.embedding.fc.train()

    if session > 0:
        nways_session = args.base_class + session * args.way
        oways_session = args.base_class + (session - 1) * args.way

        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        c_proto = Tanh10x()(model.key_mem.data)

        # Only the second output is used; column j of Pi belongs to novel class oways_session + j.
        _, Pi, _ = sinkhorn(base_torch, c_proto[oways_session:nways_session])

        c_proto = c_proto.cpu().numpy()
        for i in range(oways_session, nways_session):
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i - oways_session], base_prototype, base_cov,
                                                    n_lsamples=args.way)
            candidates = np.random.multivariate_normal(mean=mean, cov=cov, size=args.sample_num)
            candidates_gpu = torch.from_numpy(candidates).float().cuda(args.gpu)

            _, Pi2, _ = sinkhorn_multi(base_torch, candidates_gpu)

            # Pi2 @ candidates maps the candidates onto the base classes; Pi's column for this
            # class then combines those per-base-class vectors into one prototype.
            new_proto = torch.matmul(Pi[:, i - oways_session], torch.matmul(Pi2, candidates_gpu))
            c_proto[i] = new_proto.cpu().numpy()

        # Replaces the whole memory, so the base rows keep their Tanh10x-squashed values.
        model.key_mem.data = torch.from_numpy(c_proto).float().cuda(args.gpu)

        model.nudge_prototypes(nways_session, writer, session, args.gpu)

        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            # Backpropagation
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: Fill up prototypes again
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: Optional EM compression
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)





def proto_align_final(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Rebuild prototypes from feature replay and, in incremental sessions, OT-calibrate the newest
    session's prototypes against the Tanh10x-squashed base prototypes before retraining the FC head.

    Stages:
      1. Push the session's samples into the model's feature-replay buffer.
      2. Recompute prototypes (``model.key_mem``) from the replay buffer.
      3. Optionally bipolarize the prototypes.
      4. (session > 0 only) Calibrate the current session's prototypes with Sinkhorn transport plans,
         nudge them, then retrain the FC for ``args.retrain_iter`` steps towards them.
      5. Recompute prototypes from the replay buffer again.
      6. Optional HRR superposition of the explicit memory.

    Args:
        model: model exposing the feature-replay / prototype API used below.
        data: ``(inputs, targets)`` pair wrapped by ``myRetrainDataset``.
        optimizer: optimizer stepping the FC parameters during retraining.
        args: run configuration namespace.
        writer: TensorBoard writer receiving the retraining loss.
        session: index of the current session (0 = base session).
        nways_session: number of classes seen so far; recomputed from ``args`` when ``session > 0``.
        base_prototype: ignored on input in incremental sessions -- it is overwritten with the
            Tanh10x-squashed base rows of ``model.key_mem`` (kept for interface compatibility).
        base_cov: per-base-class covariances handed to ``distribution_calibration_v2``.
    """
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
    sinkhorn_multi = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: store the new data in the replay buffer (no separate forward pass is needed here).
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: prototypes from the replay buffer.
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 3 (nudging before calibration) is intentionally disabled in this variant.
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: retrain the FC head.
    model.embedding.fc.train()

    if session > 0:
        nways_session = args.base_class + session * args.way
        oways_session = args.base_class + (session - 1) * args.way

        base_torch = Tanh10x()(model.key_mem.data[:args.base_class])
        base_prototype = base_torch.cpu().numpy()

        # NOTE: every row of key_mem is squashed by Tanh10x here, and the squashed values are what
        # end up written back into key_mem below -- not only the calibrated rows.
        c_proto = Tanh10x()(model.key_mem.data)

        # Transport plan between base prototypes and this session's new prototypes;
        # column j is presumably the base-class weighting of new class j (TODO confirm).
        _, Pi, _ = sinkhorn(base_torch, c_proto[oways_session:nways_session])

        c_proto = c_proto.cpu().numpy()
        for i in range(oways_session, nways_session):
            col = i - oways_session
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, col], base_prototype, base_cov,
                                                   n_lsamples=args.way)

            samples = np.random.multivariate_normal(mean=mean, cov=cov, size=args.sample_num)
            samples_t = torch.from_numpy(samples).float().cuda(args.gpu)

            # Transport the drawn samples onto the base prototypes, then mix the resulting
            # per-base-class points with this new class's column of Pi.
            _, Pi2, _ = sinkhorn_multi(base_torch, samples_t)
            new_temp = torch.matmul(Pi[:, col], torch.matmul(Pi2, samples_t))

            c_proto[i] = new_temp.cpu().numpy()

        model.key_mem.data = torch.from_numpy(c_proto).float().cuda(args.gpu)
        model.nudge_prototypes(nways_session, writer, session, args.gpu)

        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: fill up prototypes again from the replay buffer.
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: optional explicit-memory compression.
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)



def _feat_replay_top1(model, classifier, loader, gpu):
    """Return an ``AverageMeter`` holding the top-1 accuracy of ``classifier`` on ``model.embedding`` features.

    Each batch of ``loader`` must unpack into ``(inputs, labels)``; both are moved to device ``gpu``.
    """
    meter = AverageMeter('Acc@1', ':6.2f')
    with t.no_grad():
        for batch in loader:
            inputs, labels = [tensor.cuda(gpu, non_blocking=True) for tensor in batch]
            predicts = classifier(model.embedding(inputs))
            accuracy = top1accuracy(predicts.argmax(dim=1), labels)
            meter.update(accuracy.item(), inputs.size(0))
    return meter


def _feat_replay_session_testset(args, sess):
    """Build the test split restricted to the classes introduced in session ``sess`` (0 = base classes)."""
    all_classes = np.arange(args.num_classes)
    if sess == 0:
        classes = all_classes[:args.base_class]
    else:
        classes = all_classes[(args.base_class + (sess - 1) * args.way):(args.base_class + sess * args.way)]

    if args.dataset == 'cifar100':
        return args.Dataset.CIFAR100(root=args.data_folder, train=False, index=classes, base_sess=False)
    if args.dataset == 'mini_imagenet':
        return args.Dataset.MiniImageNet(root=args.data_folder, train=False, index=classes)
    return args.Dataset.CUB200(root=args.data_folder, train=False, index=classes)


def feat_replay(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov, test_loader, best_acc1):
    """Train a fresh linear classifier on features sampled from saved / calibrated class Gaussians.

    Session 0: append the source-model header and ``best_acc1`` to ``<log_dir>/test_log.txt``.

    Session > 0:
      * rebuild prototypes from feature replay;
      * compute a Sinkhorn transport plan between base prototypes and all novel prototypes so far,
        and calibrate a Gaussian (mean, cov) per novel class with ``distribution_calibration_v2``;
      * draw 400 features per base class from ``(base prototype, base_cov)`` and 1000 per novel class
        from its calibrated Gaussian, and append the real embeddings of this session's samples;
      * train a new ``MultiClassLogisticRegression`` with full-batch SGD for 10000 steps;
      * evaluate on ``test_loader`` and on each session's class split, print and log the accuracies.

    Args:
        model: model exposing the feature-replay / prototype API and ``embedding``.
        data: ``(inputs, targets)`` pair wrapped by ``myRetrainDataset``.
        optimizer: unused -- a fresh SGD optimizer over the classifier is created instead.
        args: run configuration namespace.
        writer: unused.
        session: index of the current session (0 = base session).
        nways_session: number of classes seen so far (recomputed from ``args`` when ``session > 0``).
        base_prototype: base-class means returned as-is in session 0.
        base_cov: per-base-class covariances.
        test_loader: loader over the test set of all classes seen so far.
        best_acc1: base-session accuracy, logged and returned in session 0.

    Returns:
        ``(all_proto, all_cov, acc)``: in session 0 ``(base_prototype, base_cov, best_acc1)``; otherwise the
        base means/covariances taken from ``model.key_mem`` with one calibrated entry appended per novel
        class, and the accuracy on ``test_loader``.
    """
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
    classifier = MultiClassLogisticRegression(input_dim=args.dim_features, output_dim=nways_session).cuda(args.gpu)
    logs_dir = os.path.join(args.log_dir, 'test_log.txt')

    # Stage 1: embed the new data and store it in the replay buffer. Keep the embeddings of EVERY
    # batch -- they are appended to the synthetic training set below.
    model.eval()
    real_features, real_labels = [], []
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            real_features.append(model.embedding(x))
            real_labels.append(target)
            model.update_feat_replay(x, target)

    # Stage 2: prototypes from the replay buffer.
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    if session == 0:
        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'\nSource Model:{args.resume}\n')
            f1.write(f'{best_acc1}\t{best_acc1}\n')

    if session > 0:
        nways_session = args.base_class + session * args.way

        c_proto = model.key_mem.data
        base_torch = c_proto[:args.base_class]
        base_prototype = base_torch.cpu().numpy()

        # Transport plan between base prototypes and *all* novel prototypes seen so far.
        _, Pi, _ = sinkhorn(base_torch, c_proto[args.base_class:nways_session])
        c_proto = c_proto.cpu().numpy()

        # Synthetic features for the base classes, drawn from the stored statistics.
        sample_num_old = 400
        sampled_feature_old = np.array([
            np.random.multivariate_normal(mean=base_prototype[idx], cov=base_cov[idx], size=sample_num_old)
            for idx in range(args.base_class)
        ]).reshape(args.base_class * sample_num_old, -1)
        sampled_label_old = np.repeat(np.arange(args.base_class, dtype=np.int64), sample_num_old)

        # Synthetic features for the novel classes, drawn from their calibrated Gaussians.
        sample_num = 1000
        cov_saver = base_cov
        prototype_saver = base_prototype
        sampled_feature_new = []
        for i in range(args.base_class, nways_session):
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i - args.base_class],
                                                   base_prototype[:args.base_class], base_cov[:args.base_class],
                                                   n_lsamples=args.way)
            sampled_feature_new.append(np.random.multivariate_normal(mean=mean, cov=cov, size=sample_num))
            cov_saver = np.concatenate([cov_saver, np.expand_dims(cov, 0)])
            prototype_saver = np.concatenate([prototype_saver, np.expand_dims(mean, 0)])
        sampled_feature_new = np.array(sampled_feature_new).reshape((nways_session - args.base_class) * sample_num, -1)
        sampled_label_new = np.repeat(np.arange(args.base_class, nways_session, dtype=np.int64), sample_num)

        sampled_feature_all = torch.from_numpy(
            np.concatenate([sampled_feature_old, sampled_feature_new], axis=0)).cuda(args.gpu).float()
        # int64 labels regardless of platform default int size, as CrossEntropyLoss expects.
        sampled_label_all = torch.from_numpy(
            np.concatenate([sampled_label_old, sampled_label_new], axis=0)).cuda(args.gpu)
        if real_features:
            sampled_feature_all = torch.cat([sampled_feature_all, *real_features], dim=0)
            sampled_label_all = torch.cat([sampled_label_all, *(lbl.long() for lbl in real_labels)], dim=0)

        # Full-batch training of the linear classifier.
        num_epochs = 10000
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, nesterov=args.SGDnesterov,
                                    weight_decay=args.SGDweight_decay, momentum=args.SGDmomentum)
        ce_loss = nn.CrossEntropyLoss()
        classifier.train()
        for _ in range(num_epochs):
            loss = ce_loss(classifier(sampled_feature_all), sampled_label_all)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        classifier.eval()

        acc_each_session = _feat_replay_top1(model, classifier, test_loader, args.gpu).avg
        print("Session {:} Testing Acc: {:.2f}%".format(session, acc_each_session))

        # Per-session breakdown: accuracy on the classes introduced in each session so far.
        acc_up2now = []
        for sess in range(session + 1):
            testloader2 = torch.utils.data.DataLoader(dataset=_feat_replay_session_testset(args, sess),
                                                      batch_size=args.batch_size_inference,
                                                      shuffle=False, pin_memory=True)
            acc_up2now.append(_feat_replay_top1(model, classifier, testloader2, args.gpu).avg)
        print(acc_up2now)

        with open(logs_dir, 'a', encoding='utf-8') as f1:
            f1.write(f'{acc_up2now}\t{acc_each_session}\n')

    # Stage 5: prototypes are intentionally not refilled in this variant.
    model.eval()

    # Stage 6: optional explicit-memory compression.
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)

    if session == 0:
        return base_prototype, base_cov, best_acc1
    return prototype_saver, cov_saver, acc_each_session


def proto_align_v2_for_plot(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    '''
    Alignment of the FC head using a cosine loss (``myCosineLoss``) and feature replay.

    In incremental sessions each new-class prototype is replaced by one of 300 samples drawn from its
    calibrated Gaussian (``distribution_calibration_v2``): the sample whose summed cosine similarity
    to the base prototypes is largest. The prototypes are then nudged and the FC is retrained towards them.

    Args:
        model: model exposing the feature-replay / prototype API used below.
        data: ``(inputs, targets)`` pair wrapped by ``myRetrainDataset``.
        optimizer: optimizer stepping the FC parameters during retraining.
        args: run configuration namespace.
        writer: TensorBoard writer receiving the retraining loss.
        session: index of the current session (0 = base session).
        nways_session: number of classes seen so far; recomputed from ``args`` when ``session > 0``.
        base_prototype: NumPy array of saved base-class means.
        base_cov: per-base-class covariances handed to ``distribution_calibration_v2``.
    '''
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: store the new data in the replay buffer.
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)

    # Stage 2: prototypes from the replay buffer.
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 3 (nudging before calibration) is intentionally disabled in this variant.
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: retrain the FC head.
    model.embedding.fc.train()

    if session > 0:
        nways_session = args.base_class + session * args.way
        oways_session = args.base_class + (session - 1) * args.way

        base_proxy = model.key_mem.data[:model.args.base_class]
        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        c_proto = model.key_mem.data

        # Transport plan between the saved base means and this session's new prototypes.
        _, Pi, _ = sinkhorn(base_torch, c_proto[oways_session:nways_session])

        c_proto = c_proto.cpu().numpy()
        for i in range(oways_session, nways_session):
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i - oways_session], base_prototype, base_cov,
                                                   n_lsamples=args.way * args.shot)

            candidates = np.random.multivariate_normal(mean=mean, cov=cov, size=300)
            candidates_t = torch.from_numpy(candidates).float().cuda(args.gpu)

            # Keep the candidate most similar (summed cosine similarity) to the base prototypes.
            similarity = cosine_similarity_multi(candidates_t, base_proxy, rep=args.representation)
            best = torch.argmax(similarity.sum(dim=1)).item()
            c_proto[i] = candidates[best]

        model.key_mem.data = torch.from_numpy(c_proto).float().cuda(args.gpu)
        model.nudge_prototypes(nways_session, writer, session, args.gpu)

        for epoch in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = criterion(support[:nways_session], model.key_mem.data[:nways_session])
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)

    # Stage 5: fill up prototypes again from the replay buffer.
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: optional explicit-memory compression.
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)



# def proto_align_v2(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
#     '''
#     Alignment of FC using MSE Loss and feature replay
#     '''
#
#     losses = AverageMeter('Loss')
#
#
#     criterion = myCosineLoss(args.retrain_act)
#
#
#     dataset = myRetrainDataset(data[0], data[1])
#     dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
#     sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)
#
#
#     # Stage 1: Compute feature representation of new data
#     model.eval()
#     with t.no_grad():
#         for x, target in dataloader:
#             x = x.cuda(args.gpu, non_blocking=True)
#             target = target.cuda(args.gpu, non_blocking=True)
#             model.update_feat_replay(x, target)
#
#
#     # Stage 2: Compute prototype based on GAAM
#     feat, label = model.get_feat_replay()
#
#     model.reset_prototypes(args)
#     model.update_prototypes_feat(feat, label, nways_session)
#
#     # old_proxy = model.key_mem.data[:model.args.base_class + (session - 1) * args.way]
#
#
#     # Stage 3: Nuddging
#     # model.nudge_prototypes(nways_session, writer, session, args.gpu)
#
#     # Bipolarize prototypes in Mode 2
#     if args.bipolarize_prototypes:
#         model.bipolarize_prototypes()
#
#     # Stage 4: Update Retraining the FC
#
#     model.embedding.fc.train()
#
#
#     if session > 0:
#         nways_session = args.base_class + session * args.way
#         oways_session = args.base_class + (session - 1) * args.way
#
#
#         # c_proto = model.key_mem.data
#
#
#
#         base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
#         # base_torch = model.key_mem.data[:args.base_class]
#
#         # c_proto = Tanh10x()(model.key_mem.data)
#         # c_proto = torch.sign(model.key_mem.data)
#
#         c_proto = model.key_mem.data[:,0,:]
#         # print(c_proto.shape)
#
#         # base_torch = Tanh10x()(base_torch)
#         # c_proto = Tanh10x()(model.key_mem.data)
#         # c_proto[oways_session:nways_session] = Tanh10x()(c_proto[oways_session:nways_session])
#         cost, Pi, C = sinkhorn(base_torch, c_proto[oways_session:nways_session])
#
#
#         c_proto = c_proto.cpu().numpy()
#
#         # c_proto2 = c_proto
#         # for j in range(oways_session):
#         #     proto_temp = np.random.multivariate_normal(mean=base_prototype, cov=base_cov, size=1)
#         #     c_proto2[j] = proto_temp
#
#
#         for i in range(oways_session, nways_session):
#             mean, cov = distribution_calibration_dan(c_proto[i], Pi[:, i - oways_session], base_prototype, base_cov,
#                                                      n_lsamples=args.way * args.shot)
#
#             proto_temp = np.random.multivariate_normal(mean=mean, cov=cov, size=1)
#
#             c_proto[i] = proto_temp
#
#
#         c_proto = torch.from_numpy(c_proto).float().cuda(args.gpu)
#
#         # model.key_mem.data = Tanh10x()(c_proto)
#
#         model.key_mem.data[:args.base_class,1,:] = model.key_mem.data[:args.base_class,0,:]
#         model.key_mem.data[oways_session:nways_session,1,:] = c_proto[oways_session:nways_session]
#         # model.key_mem.data[:oways_session,1,:] = c_proto2[:oways_session]
#
#         # print(model.key_mem.data[oways_session-2:nways_session,:,:1])
#         # for epoch in range(args.retrain_iter):
#         #
#         #     optimizer.zero_grad()
#         #     support = model.get_support_feat(feat)
#         #
#         #
#         #     loss_cls = criterion(support[:nways_session], model.key_mem.data[:nways_session])
#         #
#         #     loss = 1.0 * loss_cls
#         #     # loss =  loss_n_cls
#         #
#         #
#         #
#         #     # Backpropagation
#         #     loss.backward()
#         #     optimizer.step()
#         #
#         #     writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)
#
#     # plot_confusion_support(model.key_mem.data[:nways_session].cpu(),
#     #                        savepath="{:}/relu2{:}".format(args.log_dir, str(session)))
#
#
#
#
#     # Stage 5: Fill up prototypes again
#     model.eval()
#     model.reset_prototypes(args)
#     model.update_prototypes_feat(feat, label, nways_session)
#
#
#     # Stage 6: Optional EM compression
#     if args.em_compression == "hrr":
#         model.hrr_superposition(nways_session, args.em_compression_nsup)


def proto_align_v3(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov):
    """Feature-replay FC alignment that also accumulates per-class Gaussian statistics.

    The saved base-class means/covariances (``base_prototype`` / ``base_cov``) are used to calibrate each
    class introduced in the current session. Each new prototype becomes the one of 300 calibrated samples
    with the largest summed cosine similarity to the base prototypes, after which the FC is retrained.

    Returns:
        ``(all_proto, all_cov)``: the inputs unchanged in session 0; otherwise ``base_prototype`` /
        ``base_cov`` with the calibrated mean / covariance of every new class of this session appended.
    """
    criterion = myCosineLoss(args.retrain_act)
    replay_loader = DataLoader(dataset=myRetrainDataset(data[0], data[1]), batch_size=args.batch_size_training)
    ot_solver = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: push the session's samples into the replay buffer.
    model.eval()
    with t.no_grad():
        for inputs, labels in replay_loader:
            model.update_feat_replay(inputs.cuda(args.gpu, non_blocking=True),
                                     labels.cuda(args.gpu, non_blocking=True))

    # Stage 2: prototypes from the replay buffer.
    feat, label = model.get_feat_replay()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 3: optional bipolarization (nudging before calibration is disabled here).
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: retrain the FC head.
    model.embedding.fc.train()

    all_proto, all_cov = base_prototype, base_cov
    if session > 0:
        nways_session = args.base_class + session * args.way
        oways_session = nways_session - args.way

        base_proxy = model.key_mem.data[:model.args.base_class]
        prototypes = model.key_mem.data
        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        _, Pi, _ = ot_solver(base_torch, prototypes[oways_session:nways_session])
        prototypes = prototypes.cpu().numpy()

        for cls_idx in range(oways_session, nways_session):
            mean, cov = distribution_calibration_v2(prototypes[cls_idx], Pi[:, cls_idx - oways_session],
                                                   base_prototype, base_cov,
                                                   n_lsamples=args.way * args.shot)
            draws = np.random.multivariate_normal(mean=mean, cov=cov, size=300)
            score = cosine_similarity_multi(torch.from_numpy(draws).float().cuda(args.gpu), base_proxy,
                                            rep=args.representation).sum(dim=1)
            prototypes[cls_idx] = draws[torch.argmin(-score).cpu().numpy()]

            # The calibrated statistics (not the selected draw) are what get accumulated.
            all_cov = np.concatenate([all_cov, np.expand_dims(cov, 0)])
            all_proto = np.concatenate([all_proto, np.expand_dims(mean, 0)])

        model.key_mem.data = torch.from_numpy(prototypes).float().cuda(args.gpu)
        model.nudge_prototypes(nways_session, writer, session, args.gpu)

        for step in range(args.retrain_iter):
            optimizer.zero_grad()
            support = model.get_support_feat(feat)
            loss = 1.0 * criterion(support[:nways_session], model.key_mem.data[:nways_session])
            loss.backward()
            optimizer.step()
            writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), step)

    # Stage 5: fill up prototypes again from the replay buffer.
    model.eval()
    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)

    # Stage 6: optional explicit-memory compression.
    if args.em_compression == "hrr":
        model.hrr_superposition(nways_session, args.em_compression_nsup)

    return all_proto, all_cov





def proto_sample(model, data, optimizer, args, writer, session, nways_session, base_prototype, base_cov, classifier_weights):
    """Rebuild class prototypes for ``session`` and, for incremental sessions, resample the new ones.

    Steps, as implemented below:
      1. Run ``model.update_feat_replay`` on every batch of ``myRetrainDataset(data[0], data[1])``
         under ``no_grad`` to collect features of the session data.
      2. Reset the prototypes and refill them from those features
         (``model.reset_prototypes`` / ``model.update_prototypes_feat``); optionally bipolarize them.
      3. Only when ``session > 0``:
         - overwrite the first ``oways_session`` rows of ``model.key_mem.data`` with
           ``classifier_weights``;
         - compute a Sinkhorn plan ``Pi`` between ``base_prototype`` and the prototypes of the
           classes added in this session (rows ``oways_session:nways_session``);
         - for each new class ``i``, build a calibrated Gaussian with ``distribution_calibration_v2``,
           draw 1000 samples, compute a second Sinkhorn plan ``Pi2`` between the base prototypes
           and those samples, and replace the prototype with
           ``Pi[:, i - oways_session] @ (Pi2 @ samples)``;
         - append every new prototype to ``classifier_weights``, write all prototypes back into
           ``model.key_mem.data`` and call ``model.nudge_prototypes``.

    Args:
        model: model exposing the feature-replay / prototype API used below.
        data: ``data[0]`` / ``data[1]`` are wrapped in ``myRetrainDataset`` as inputs / targets.
        optimizer: currently unused (the retraining loop below is commented out).
        args: run configuration; this function reads ``gpu``, ``way``, ``base_class``,
            ``batch_size_training``, ``retrain_act`` and ``bipolarize_prototypes``.
        writer: forwarded to ``model.nudge_prototypes``.
        session: incremental session index; ``0`` skips step 3.
        nways_session: class count used to fill prototypes; recomputed as
            ``args.base_class + session * args.way`` when ``session > 0``.
        base_prototype: NumPy array of base-class prototypes (converted with ``torch.from_numpy``).
        base_cov: base-class covariances forwarded to ``distribution_calibration_v2``.
        classifier_weights: torch tensor of stored class weights. NOTE(review): it is assigned
            into ``key_mem.data[:oways_session]``, so presumably it has exactly ``oways_session``
            rows when ``session > 0`` - confirm with callers.

    Returns:
        ``classifier_weights`` with one row appended per new class when ``session > 0``;
        the argument is returned unchanged when ``session == 0``.
    """

    # ``losses`` and ``criterion`` are unused while the retraining loop below is commented out.
    losses = AverageMeter('Loss')
    criterion = myCosineLoss(args.retrain_act)
    dataset = myRetrainDataset(data[0], data[1])
    dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size_training)
    # Plan between base prototypes and this session's new-class prototypes.
    sinkhorn = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Same configuration; used for the per-class plan between base prototypes and sampled points.
    sinkhorn_multi = SinkhornDistance(eps=0.01, max_iter=200, args=args, reduction=None).cuda(args.gpu)

    # Stage 1: Compute feature representation of new data
    model.eval()
    with t.no_grad():
        for x, target in dataloader:
            x = x.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            model.update_feat_replay(x, target)


    # Stage 2: Compute prototype based on GAAM
    feat, label = model.get_feat_replay()


    model.reset_prototypes(args)
    model.update_prototypes_feat(feat, label, nways_session)



    # old_proxy = model.key_mem.data[:model.args.base_class + (session - 1) * args.way]
    # Stage 3: Nudging (currently disabled here; nudging happens inside the session > 0 branch)
    # model.nudge_prototypes(nways_session, writer, session, args.gpu)

    # Bipolarize prototypes in Mode 2
    if args.bipolarize_prototypes:
        model.bipolarize_prototypes()

    # Stage 4: Update Retraining the FC

    # model.embedding.fc.train()


    if session > 0:
        # Class counts after (nways) and before (oways) this session.
        nways_session = args.base_class + session * args.way
        oways_session = args.base_class + (session - 1) * args.way


        # Restore the stored weights of all previously seen classes.
        model.key_mem.data[:oways_session] = classifier_weights

        # c_proto = model.key_mem.data

        # NOTE(review): ``old_proxy`` and ``base_proxy`` are never used below.
        old_proxy = model.key_mem.data[:model.args.base_class + (session - 1) * args.way]

        base_proxy = model.key_mem.data[:model.args.base_class]



        # NOTE(review): keeps base_prototype's NumPy dtype (no ``.float()``), unlike
        # ``proto_temp3`` below - confirm SinkhornDistance accepts the mixed dtypes.
        base_torch = torch.from_numpy(base_prototype).cuda(args.gpu)
        # base_torch = model.key_mem.data[:args.base_class]

        # c_proto = Tanh10x()(model.key_mem.data)
        # c_proto = torch.sign(model.key_mem.data)
        c_proto = model.key_mem.data

        # Pi is later indexed as Pi[:, j]: column j belongs to the j-th new class of this session.
        cost, Pi, C = sinkhorn(base_torch, c_proto[oways_session:nways_session])


        c_proto = c_proto.cpu().numpy()
        for i in range(oways_session,nways_session):

            # NOTE(review): the analogous call earlier in this file passes
            # n_lsamples=args.way * args.shot; here only args.way is used - confirm intent.
            mean, cov = distribution_calibration_v2(c_proto[i], Pi[:, i-oways_session], base_prototype, base_cov,
                                                     n_lsamples=args.way)

            # Draw 1000 points from the calibrated Gaussian of new class i.
            proto_temp = np.random.multivariate_normal(mean=mean, cov=cov, size=1000)

            proto_temp3 = torch.from_numpy(proto_temp).float().cuda(args.gpu)

            cost2, Pi2, C2 = sinkhorn_multi(base_torch,proto_temp3)

            # Transport the samples onto the base prototypes (Pi2 @ samples), then combine them
            # with this class's column of the first plan to get a single prototype vector.
            new_temp = torch.matmul(Pi[:, i-oways_session], torch.matmul(Pi2, proto_temp3))

            classifier_weights = torch.cat([classifier_weights,new_temp.unsqueeze(0)], dim=0)

            c_proto[i] = new_temp.cpu().numpy()

        c_proto = torch.from_numpy(c_proto).float().cuda(args.gpu)

        # model.key_mem.data = Tanh10x()(c_proto)
        model.key_mem.data = c_proto

        model.nudge_prototypes(nways_session, writer, session, args.gpu)


        # for epoch in range(args.retrain_iter):
        #
        #     optimizer.zero_grad()
        #     support = model.get_support_feat(feat)
        #     loss_cls = criterion(support[:nways_session], model.key_mem.data[:nways_session])
        #
        #     loss = 1.0 * loss_cls
        #     # Backpropagation
        #     loss.backward()
        #     optimizer.step()
        #     writer.add_scalar('retraining/loss_sess{:}'.format(session), loss.item(), epoch)


    # Stage 5: Fill up prototypes again

    # if session == 0:
    #     model.key_mem.data = base_proxy


    model.eval()

    return classifier_weights
    # Unreachable (after the return); kept from an earlier version of Stage 5:
    # model.reset_prototypes(args)
    # model.update_prototypes_feat(feat, label, nways_session)









def distribution_calibration(query, base_means, base_cov, k,alpha=0.21):
    """Calibrate a Gaussian for ``query`` from its ``k`` nearest base classes.

    The calibrated mean is the average of the ``k`` nearest base means (Euclidean
    distance) together with ``query`` itself. The calibrated covariance is the
    average of the matching base covariances plus ``alpha``. ``alpha`` is a scalar
    added to every entry, not only the diagonal.

    Args:
        query: 1-D feature vector of the novel class.
        base_means: sequence/array of base-class means, one row per class.
        base_cov: sequence/array of base-class covariance matrices, aligned with ``base_means``.
        k: number of nearest base classes to use; values ``>= len(base_means)``
            use every base class (``np.argpartition`` rejects such ``k``).
        alpha: constant added to the averaged covariance.

    Returns:
        ``(calibrated_mean, calibrated_cov)``.
    """
    means = np.asarray(base_means)
    covs = np.asarray(base_cov)

    # Euclidean distance from the query to every base mean, computed in one vectorized pass.
    dist = np.linalg.norm(query - means, axis=1)
    if k < len(dist):
        index = np.argpartition(dist, k)[:k]
    else:
        index = np.arange(len(dist))

    mean = np.concatenate([means[index], query[np.newaxis, :]])

    calibrated_mean = np.mean(mean, axis=0)
    calibrated_cov = np.mean(covs[index], axis=0) + alpha

    return calibrated_mean, calibrated_cov




def distribution_calibration_v2(prototype, probabi, base_means, base_cov, n_lsamples, alpha=0.21, lambd=0.3, k=10):
    """Calibrate a Gaussian for a novel class from weighted base-class statistics.

    With weights ``w = n_lsamples * probabi`` (one per base class)::

        mean = (1 - lambd) * sum_i w_i * base_means[i] + lambd * prototype
        cov  = sum_i w_i * base_cov[i] + alpha

    ``alpha`` is a scalar added to every covariance entry, not only the diagonal.

    Args:
        prototype: the novel class's own prototype; must broadcast against a row of
            ``base_means``.
        probabi: one weight per base class. A torch tensor (any device, with or without
            autograd history) or anything ``np.asarray`` accepts.
        base_means: ``(n_base, dim)`` array or a sequence of ``dim``-length vectors.
        base_cov: ``(n_base, dim, dim)`` covariances aligned with ``base_means``.
        n_lsamples: scalar multiplier applied to ``probabi``.
        alpha: constant added to every covariance entry.
        lambd: interpolation weight given to ``prototype`` in the mean.
        k: unused; kept for signature compatibility.

    Returns:
        ``(calibrated_mean, calibrated_cov)`` of shapes ``(dim,)`` and ``(dim, dim)``.
    """
    if hasattr(probabi, 'detach'):
        # torch tensor: drop autograd history and move to host memory before NumPy conversion.
        probabi = probabi.detach().cpu().numpy()
    weights = n_lsamples * np.asarray(probabi).reshape(-1)
    means = np.asarray(base_means)
    covs = np.asarray(base_cov)

    # tensordot contracts the base-class axis directly. This avoids building an explicit
    # (n_base, dim, dim) repeated-weight array just to compute a weighted sum.
    calibrated_mean = (1 - lambd) * np.tensordot(weights, means, axes=1) + lambd * prototype
    calibrated_cov = np.tensordot(weights, covs, axes=1) + alpha
    return calibrated_mean, calibrated_cov






def validation(model,criterion,dataloader, args,nways_session=None):
    """Evaluate ``model`` on ``dataloader`` with integer class labels.

    Every batch is moved to ``cuda:<args.gpu>``. Loss and top-1 accuracy (percent)
    are averaged over samples.

    Args:
        model: called as ``model(data)``; its output goes to ``criterion`` and is arg-maxed over dim 1.
        criterion: loss function called as ``criterion(output, label)``.
        dataloader: yields ``(data, label)`` pairs.
        args: only ``args.gpu`` is read.
        nways_session: accepted for API compatibility; it currently has no effect.

    Returns:
        ``(avg_loss, avg_top1_accuracy, fig)``; ``fig`` is always ``None``.
    """
    losses = AverageMeter('Loss', ':.4e')
    acc = AverageMeter('Acc@1', ':6.2f')

    model.eval()
    with t.no_grad():
        for batch in dataloader:
            data, label = [_.cuda(args.gpu, non_blocking=True) for _ in batch]

            output = model(data)
            loss = criterion(output, label)

            accuracy = top1accuracy(output.argmax(dim=1), label)
            losses.update(loss.item(), data.size(0))
            acc.update(accuracy.item(), data.size(0))

    # Similarity-confusion plotting is disabled. The per-batch accumulation into
    # avg_sim_confusion was commented out, and the former guard
    # ``nways_session is (not None)`` parsed as ``nways_session is True``, so it never
    # produced a figure for an integer class count. Return None explicitly instead of
    # building an accumulator that is never filled.
    fig = None
    return losses.avg, acc.avg, fig

def validation_onehot(model,criterion,dataloader, args, num_classes):
    """Evaluate ``model`` on ``dataloader`` against one-hot encoded targets.

    Integer labels are expanded with ``F.one_hot(..., num_classes)`` and cast to float
    before being passed to ``criterion``. Accuracy comes from ``process_result`` and is
    scaled to percent.

    Returns:
        ``(avg_loss, avg_accuracy_percent)``, both weighted by batch size.
    """
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc@1', ':6.2f')

    model.eval()

    with t.no_grad():
        for batch in dataloader:
            inputs, targets = [tensor.cuda(args.gpu, non_blocking=True) for tensor in batch]
            targets_onehot = F.one_hot(targets, num_classes = num_classes).float()

            logits = model(inputs)
            batch_loss = criterion(logits, targets_onehot)
            # process_result returns (pred, pred_cert, actual, actual_cert, accuracy).
            batch_acc = process_result(logits, targets_onehot)[-1]

            batch_size = inputs.size(0)
            loss_meter.update(batch_loss.item(), batch_size)
            acc_meter.update(batch_acc.item() * 100, batch_size)

    return loss_meter.avg, acc_meter.avg

# --------------------------------------------------------------------------------------------------
# Interpretation
# --------------------------------------------------------------------------------------------------
def process_result(predictions, actual):
    """Decode class scores and one-hot targets, and score them.

    Args:
        predictions: scores with classes along dim 1.
        actual: one-hot (or score-like) targets with classes along dim 1.

    Returns:
        ``(predicted_labels, predicted_certainties, actual_labels, actual_certainties, accuracy)``.
        ``accuracy`` is the fraction of matching labels as a 1-element tensor. Both
        certainty values are the placeholder ``0`` (not implemented yet).
    """
    pred_idx = predictions.argmax(dim=1)
    true_idx = actual.argmax(dim=1)

    hits = (pred_idx == true_idx).float()
    accuracy = hits.mean(0, keepdim=True)

    # TBD: real uncertainty estimates; 0 keeps the 5-tuple return shape stable.
    return pred_idx, 0, true_idx, 0, accuracy


def process_dictionary(dict):
    """Return a mapping as sorted ``(key, value)`` pairs and as a two-column table.

    The parameter name shadows the ``dict`` builtin; it is kept for backward compatibility.

    Args:
        dict: mapping to render; keys must be mutually sortable.

    Returns:
        dict_list: ``sorted(dict.items())``, i.e. pairs ordered by key.
        dict_table: NumPy array of shape ``(len(dict), 2)``. Column 0 holds the keys and
            column 1 holds ``repr(value)``. An empty mapping yields an empty ``(0, 2)``
            table instead of raising.
    """
    dict_list = sorted(dict.items())
    if not dict_list:
        # zip(*[]) yields nothing to unpack into (keys, values); return an empty table instead.
        return dict_list, np.empty((0, 2), dtype=str)

    keys, values = zip(*dict_list)
    value_reprs = [repr(value) for value in values]
    dict_table = np.vstack((np.array(keys), np.array(value_reprs))).T

    return dict_list, dict_table

# --------------------------------------------------------------------------------------------------
# Summaries
# --------------------------------------------------------------------------------------------------
class AverageMeter(object):
    """Track the most recent value and its count-weighted running mean.

    ``name`` labels the meter in ``str()`` output. ``fmt`` is a format spec
    (e.g. ``':.4e'``) applied to both the latest value and the average.
    """
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',savedir=''):
    """Serialize ``state`` with ``torch.save`` and optionally copy it as the best model.

    Args:
        state: object to serialize (typically a dict of state_dicts and counters).
        is_best: when true, also copy the file to ``<savedir>/model_best.pth.tar``.
        filename: name of the checkpoint file inside ``savedir``.
        savedir: target directory. The empty default means the current working directory;
            the former ``savedir + '/' + filename`` concatenation pointed it at the
            filesystem root instead.
    """
    checkpoint_path = os.path.join(savedir, filename)
    t.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(savedir, 'model_best.pth.tar'))



def load_checkpoint(model,optimizer,scheduler,args):
    """Restore training state, preferring a resumable run over a pretrained model.

    Lookup order:
      1. ``<args.log_dir>/checkpoint.pth.tar`` restores model, optimizer and scheduler
         state plus ``train_iter`` and ``best_acc1``.
      2. ``<args.resume>/model_best.pth.tar`` restores only the model weights;
         ``start_train_iter`` and ``best_acc1`` restart at 0.
      3. If neither file exists, nothing is loaded and both counters are 0.

    Tensors are mapped onto ``cuda:<args.gpu>`` when ``args.gpu`` is not None.

    Returns:
        ``(model, optimizer, scheduler, start_train_iter, best_acc1)``
    """

    def _load(path):
        # Map tensors onto the configured single GPU, or use torch.load's default placement.
        if args.gpu is None:
            return t.load(path)
        return t.load(path, map_location='cuda:{}'.format(args.gpu))

    resume = args.log_dir + '/checkpoint.pth.tar'
    if os.path.isfile(resume):
        # First priority: resume the run stored in log_dir (full training state).
        print("=> loading checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        start_train_iter = int(checkpoint['train_iter'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])

        print("=> loaded checkpoint '{}' (train_iter {})"
              .format(args.log_dir, checkpoint['train_iter']))
        print('previous acc', best_acc1)

    elif os.path.isfile(args.resume + '/model_best.pth.tar'):
        # Second priority: start from a pretrained model.
        # No scheduler and no optimizer loading here.
        resume = args.resume + '/model_best.pth.tar'
        print("=> loading pretrain checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        start_train_iter = 0
        best_acc1 = 0
        model.load_state_dict(checkpoint['state_dict'])
        print('previous best acc', checkpoint['best_acc1'])
        # Report the file actually loaded (previously this printed args.log_dir).
        print("=> loaded pretrained checkpoint '{}' (train_iter {})"
              .format(resume, checkpoint['train_iter']))

    else:
        start_train_iter = 0
        best_acc1 = 0
        print("=> no checkpoint found at '{}'".format(args.log_dir))
        print("=> no pretrain checkpoint found at '{}'".format(args.resume))

    return model, optimizer, scheduler, start_train_iter, best_acc1




def load_checkpoint2(model,optimizer,scheduler,args):
    """Restore training state and, from a pretrained checkpoint, stored prototype statistics.

    Lookup order:
      1. ``<args.log_dir>/checkpoint.pth.tar`` restores model, optimizer and scheduler state.
         ``start_train_iter`` and the returned accuracy come from the checkpoint, and
         ``prototype``/``cov``/``classlabel`` are returned as ``None``.
      2. ``<args.resume>/model_best.pth.tar`` restores only the model weights.
         ``start_train_iter`` is 0 and the returned accuracy is the checkpoint's ``best_acc1``.
         ``prototype``/``cov``/``classlabel`` are read from the checkpoint, with ``None`` for
         any key it lacks.
      3. If neither file exists, nothing is loaded; the counters are 0 and the statistics ``None``.

    Tensors are mapped onto ``cuda:<args.gpu>`` when ``args.gpu`` is not None.

    Returns:
        ``(model, optimizer, scheduler, start_train_iter, best_acc, prototype, cov, classlabel)``
    """

    def _load(path):
        # Map tensors onto the configured single GPU, or use torch.load's default placement.
        if args.gpu is None:
            return t.load(path)
        return t.load(path, map_location='cuda:{}'.format(args.gpu))

    prototype, cov, classlabel = None, None, None

    resume = args.log_dir + '/checkpoint.pth.tar'
    if os.path.isfile(resume):
        # First priority: resume the run stored in log_dir (full training state).
        print("=> loading checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        start_train_iter = int(checkpoint['train_iter'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_acc1 = checkpoint['best_acc1']
        best_acc2 = best_acc1
        model.load_state_dict(checkpoint['state_dict'])

        print("=> loaded checkpoint '{}' (train_iter {})"
              .format(args.log_dir, checkpoint['train_iter']))
        print('previous acc', best_acc1)

    elif os.path.isfile(args.resume + '/model_best.pth.tar'):
        # Second priority: start from a pretrained model.
        # No scheduler and no optimizer loading here.
        resume = args.resume + '/model_best.pth.tar'
        print("=> loading pretrain checkpoint '{}'".format(resume))
        checkpoint = _load(resume)
        start_train_iter = 0
        model.load_state_dict(checkpoint['state_dict'])
        best_acc2 = checkpoint['best_acc1']
        print('previous best acc', best_acc2)
        # Report the file actually loaded (previously this printed args.log_dir).
        print("=> loaded pretrained checkpoint '{}' (train_iter {})"
              .format(resume, checkpoint['train_iter']))

        # Checkpoints saved without prototype statistics fall back to None, as the other branches do.
        prototype = checkpoint.get('prototype')
        cov = checkpoint.get('cov')
        classlabel = checkpoint.get('classlabel')

    else:
        start_train_iter = 0
        best_acc2 = 0
        print("=> no checkpoint found at '{}'".format(args.log_dir))
        print("=> no pretrain checkpoint found at '{}'".format(args.resume))

    return model, optimizer, scheduler, start_train_iter, best_acc2, prototype, cov, classlabel

# --------------------------------------------------------------------------------------------------
# Some PyTorch helper functions (might be removed from this file at some point)
# --------------------------------------------------------------------------------------------------





def convert_toonehot(label):
    """Convert integer class indices to a compact one-hot float matrix.

    Columns for classes that never occur in ``label`` are dropped, so the result
    has shape ``(B, num_ways)``, where ``num_ways`` is the number of distinct labels
    present. The result is a CPU ``torch.FloatTensor``.
    """
    onehot = F.one_hot(label)
    present = onehot.sum(dim=0) != 0
    return onehot[:, present].type(t.FloatTensor)

def top1accuracy(pred, target):
    """Return top-1 accuracy in percent.

    ``pred`` holds predicted class indices and ``target`` the true ones. The result is
    ``100 * (#equal positions) / target.size(0)`` as a float tensor.
    """
    n_correct = (pred == target).float().sum(0)
    scale = 100.0 / target.size(0)
    return n_correct * scale


class myRetrainDataset(Dataset):
    """Map-style dataset that pairs ``x[idx]`` with ``y[idx]``.

    The dataset length is taken from ``y``.
    """

    def __init__(self, x,y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target


class MultiClassLogisticRegression(nn.Module):
    """Bias-free linear classifier that returns ``x @ W.T`` for a learnable ``W``.

    Despite the name, ``forward`` returns raw scores; no softmax or sigmoid is applied.

    Args:
        input_dim: size of the last dimension of the inputs.
        output_dim: number of output scores (rows of ``W``).
    """
    def __init__(self, input_dim, output_dim):
        # Local import: ``math`` is not in this file's explicit import list and would otherwise
        # only be reachable through one of the wildcard imports.
        import math

        super(MultiClassLogisticRegression, self).__init__()
        # Weight of shape (output_dim, input_dim), the layout F.linear expects.
        self.linear = nn.Parameter(t.empty(output_dim, input_dim, dtype=t.float32))
        nn.init.kaiming_uniform_(self.linear, mode='fan_out', a=math.sqrt(5))

    def forward(self, x):
        """Return scores ``x @ self.linear.T`` (no bias)."""
        similarity = F.linear(x, self.linear)
        return similarity


