import sys, argparse
import numpy as np
import torch
from torch.nn.functional import relu, avg_pool2d
from buffer import Buffer
# import utils
import datetime
from torch.nn.functional import relu
import torch
import torch.nn as nn
import torch.nn.functional as F
from InfoNCE import tao as TL
#from InfoNCE import classifier as C
from InfoNCE.utils import normalize
from utils import feat_cam_normalized
from InfoNCE.contrastive_learning import get_similarity_matrix,Supervised_NT_xent_old_to_new,Supervised_NT_xent_envs,Supervised_NT_xent_pre,Supervised_NT_xent_n,Supervised_NT_xent_uni
import torch.optim.lr_scheduler as lr_scheduler
#from CSL.shedular import GradualWarmupScheduler
import torch
from apex import amp
import torchvision.transforms as transforms
import  torchvision
from torch.cuda.amp import GradScaler,autocast
import torchvision.transforms as transforms
import  torchvision


# Command-line arguments. Every option is optional; each help string only
# echoes the option's default value.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0, help='(default=%(default)d)')
parser.add_argument('--experiment', default='cifar-100', type=str, required=False, help='(default=%(default)s)')
parser.add_argument('--approach', default='IFO', type=str, required=False, help='(default=%(default)s)')
parser.add_argument('--nepochs', default=25, type=int, required=False, help='(default=%(default)d)')
parser.add_argument('--lr', default=0.02, type=float, required=False, help='(default=%(default)f)')
parser.add_argument('--parameter', type=str, default='', help='(default=%(default)s)')
parser.add_argument('--dataset', type=str, default='cifar', help='(default=%(default)s)')
# NOTE(review): declared type=str while the default is a list, so a value
# supplied on the command line arrives as a plain string, not a list.
parser.add_argument('--input_size', type=str, default=[3, 32, 32], help='(default=%(default)s)')
parser.add_argument('--buffer_size', type=int, default=2000, help='(default=%(default)s)')
# NOTE(review): type=str with a boolean default; any non-empty string given on
# the command line (including "False") is truthy.
parser.add_argument('--gen', type=str, default=True, help='(default=%(default)s)')
parser.add_argument('--n_classes', type=int, default=512, help='(default=%(default)s)')
parser.add_argument('--buffer_batch_size', type=int, default=64, help='(default=%(default)s)')
# Parsed at import time: the rest of the script reads the module-level `args`.
args = parser.parse_args()
import os

# TensorFlow logging variable; presumably set to quiet its INFO/WARNING
# output - verify it is still needed, nothing visible here imports TensorFlow.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Expose only physical GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def rot_inner_all(x):
    """Return the four "inner rotation" variants of every image in a batch.

    Each sample is split into two halves along its last axis. A half is
    "rotated" by turning it 180 degrees in place (both of its spatial axes
    are reversed). The four combinations of rotated / untouched halves are
    stacked along the batch axis.

    The original implementation hard-coded a 3x32x32 sample shape; the sizes
    are now taken from ``x`` itself, and the result is unchanged for 3x32x32
    inputs.

    Args:
        x: 4-D tensor, assumed to be images in (N, C, H, W) order. When the
            last axis has odd length the second half gets the extra column.

    Returns:
        A new tensor of shape (4 * N, C, H, W) laid out along dim 0 as:
            [0:N]    the input unchanged
            [N:2N]   first half unchanged, second half rotated
            [2N:3N]  both halves rotated
            [3N:4N]  first half rotated, second half unchanged
    """
    mid = x.shape[3] // 2
    first, second = x[..., :mid], x[..., mid:]
    # A 180-degree turn over the two spatial axes of each half.
    first_rot = torch.rot90(first, 2, (2, 3))
    second_rot = torch.rot90(second, 2, (2, 3))
    return torch.cat(
        (
            x,
            torch.cat((first, second_rot), dim=3),
            torch.cat((first_rot, second_rot), dim=3),
            torch.cat((first_rot, second), dim=3),
        ),
        dim=0,
    )
def Rotation(x,y):
    """Expand a batch into 16 views and give every view its own label range.

    The four variants produced by ``rot_inner_all`` are each additionally
    rotated by 0, 1, 2 and 3 quarter turns over dims (2, 3), giving
    16 * N samples concatenated along dim 0. Concatenating the quarter-turned
    copies with the unrotated ones requires dims 2 and 3 to be equal.

    Args:
        x: batch tensor accepted by ``rot_inner_all``.
        y: 1-D label tensor with one entry per sample of ``x``.

    Returns:
        (images, labels): ``labels`` is ``y`` repeated 16 times where the
        k-th repetition is shifted by ``1000 * k``; the caller's ``y`` is not
        modified.
    """
    batch = x.shape[0]
    inner = rot_inner_all(x)
    labels = y.repeat(16)
    # View 0 keeps the original labels; every later view gets its own offset.
    for view in range(1, 16):
        labels[view * batch:(view + 1) * batch].add_(1000 * view)
    quarter_turns = [torch.rot90(inner, turns, (2, 3)) for turns in (1, 2, 3)]
    return torch.cat([inner] + quarter_turns, dim=0), labels
# Candidate GPU indices; only the first entry is used in this section.
# NOTE(review): CUDA_VISIBLE_DEVICES is set to "0" earlier in this file, so
# the remaining indices are not visible to this process.
gpus = [0,1, 2, 3,5,6,7]

# Echo the full run configuration so it ends up in the log.
print('=' * 100)
print('Arguments =')
for arg in vars(args):
    print('\t' + arg + ':', getattr(args, arg))
print('=' * 100)
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'GPU  ' + os.environ["CUDA_VISIBLE_DEVICES"])
print('=' * 100)
########################################################################################################################

# Seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Check for CUDA *before* touching the device: previously set_device() ran
# first, so a machine without CUDA raised there and never reached this
# message.
if not torch.cuda.is_available():
    print('[CUDA unavailable]')
    sys.exit()
torch.cuda.set_device('cuda:{}'.format(gpus[0]))
torch.cuda.manual_seed(args.seed)
import cifar as dataloader
# import owm as approach
# import cnn_owm as network
#from minimodel import net as s_model
from Resnet18 import resnet18 as b_model
from buffer import Buffer as buffer
# imagenet200 import SequentialTinyImagenet as STI
from torch.optim import Adam, SGD  # ,SparseAdam
import torch.nn.functional as F
from copy import deepcopy
import matplotlib.pyplot as plt
def imshow(img):
    """Display a single image tensor in a blocking matplotlib window.

    Args:
        img: tensor whose axes are reordered with (1, 2, 0) before plotting,
            i.e. presumably channels-first (C, H, W). It is mapped through
            ``img / 2 + 0.5`` first - presumably undoing a normalisation to
            [-1, 1]; confirm against the data loader.
    """
    rescaled = img / 2 + 0.5
    channels_last = np.transpose(rescaled.cpu().numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()
def test_model(loder,i,model):
    """Measure top-1 accuracy of ``model`` on one loader, print and return it.

    Args:
        loder: iterable of ``(data, target)`` batches; both elements are
            moved to the GPU with ``.cuda()``.
        i: identifier that only appears in the printed summary line.
        model: object exposing ``eval()`` and ``forward(data)``. The forward
            output is reduced with ``max`` over dim 1, so it is expected to
            be (batch, classes) scores.

    Returns:
        0-dim float tensor holding the accuracy in percent. An empty loader
        yields 0.0 instead of the ZeroDivisionError it used to raise.

    Side effects:
        ``model`` is switched to eval mode and left there; callers that keep
        training must call ``model.train()`` again.
    """
    # Never accumulated; kept only so the log line keeps its existing format.
    test_loss = 0
    correct = torch.zeros((), dtype=torch.long)
    num = 0

    # Eval mode is set once up front rather than on every batch.
    model.eval()
    # Only the arg-max of the outputs is needed, so skip building autograd
    # graphs for the forward passes.
    with torch.no_grad():
        for data, target in loder:
            data, target = data.cuda(), target.cuda()
            scores = model.forward(data)
            top1 = scores.max(1, keepdim=True)[1]
            num += data.size()[0]
            correct += top1.eq(target.view_as(top1)).cpu().sum()

    if num:
        test_accuracy = 100. * correct / num
    else:
        test_accuracy = torch.tensor(0.)
    print(
        'Test set{}: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'
            .format(i,
            test_loss, correct, num,
            test_accuracy, ))
    return test_accuracy
########################################################################################################################
# Basic_model=torch.load('basic_model25.pkl')
print('Load data...')
# NOTE(review): only referenced by commented-out code in the visible part of
# this file; the value equals the 16 views produced by Rotation() - confirm
# before removing.
oop=16
# Loder is indexed per task by the training loop, and each entry provides
# 'train' and 'test' loaders.
data, taskcla, inputsize, Loder, test_loder = dataloader.get_cifar100_10(seed=args.seed)
print('Input size =', inputsize, '\nTask info =', taskcla)
# Sample buffer built from the parsed arguments; the training loop fills it
# through add_reservoir().
buffero = buffer(args).cuda()
# Disabled experiment: a second buffer fed with 16x16 inputs.
#args.input_size=[3,16,16]
#args.buffer_size=2000#750*4#889#2000#0#8000
#buffer_center=buffer(args).cuda()
# resnet18 from the local Resnet18 module; the argument 100 is presumably the
# number of output classes - TODO confirm.
Basic_model = b_model(100).cuda()
# Per-task dict; the training loop resets llabel[i] to an empty list.
llabel = {}
# Adam replaced an earlier SGD(lr=0.02, momentum=0.9) configuration.
Optimizer = Adam(Basic_model.parameters(), lr=0.001, betas=(0.9, 0.99),weight_decay=1e-4)
hflip = TL.HorizontalFlipLayer().cuda()
# Wrap model and optimizer with apex AMP (opt level O1); the training loop
# scales the loss through amp.scale_loss().
Basic_model, Optimizer = amp.initialize(Basic_model, Optimizer,opt_level="O1")
# no_grad only covers the construction of the augmentation layers below; it
# has no effect on later calls to simclr_aug.
with torch.no_grad():
    # Scale range handed to the random resized crop.
    resize_scale = (0.3, 1.0)

    # Augmentation layers. color_jitter is constructed but currently left out
    # of the pipeline (its entry in the Sequential below is commented out).
    color_jitter = TL.ColorJitterLayer(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.1).cuda()

    color_gray = TL.RandomColorGrayLayer(p=0.25).cuda()
    resize_crop = TL.RandomResizedCropLayer(scale=resize_scale, size=[32, 32, 3]).cuda()
    # Both names, simclr_aug and transform, are bound to the same module.
    simclr_aug = transform = torch.nn.Sequential(
       # color_jitter,
        hflip,
        color_gray,
        resize_crop, )
#for n,w in Basic_model.named_parameters():
 #   print(n,w.shape)
import cv2
def add_snow(image, snow_point=14, brightness_coefficient=2.5):
    """Brighten the darkest pixels of an RGB image via its HLS lightness.

    The image is converted to HLS, every pixel whose lightness is strictly
    below ``snow_point`` has that lightness multiplied by
    ``brightness_coefficient`` (capped at 255), and the result is converted
    back to RGB.

    Args:
        image: RGB image array accepted by ``cv2.cvtColor``. The threshold
            and the 255 cap assume 8-bit lightness values - TODO confirm that
            callers pass uint8 images.
        snow_point: lightness threshold; raise it to affect more pixels.
            Defaults to the previously hard-coded 14.
        brightness_coefficient: multiplier applied to the selected lightness
            values. Defaults to the previously hard-coded 2.5.

    Returns:
        The modified image as a uint8 RGB array.
    """
    hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS).astype(np.float64)
    # Basic slicing yields a view, so edits to `lightness` land in `hls`.
    lightness = hls[:, :, 1]
    lightness[lightness < snow_point] *= brightness_coefficient
    # Keep the channel inside the 8-bit range before casting back.
    lightness[lightness > 255] = 255
    return cv2.cvtColor(hls.astype(np.uint8), cv2.COLOR_HLS2RGB)
# Best accuracy recorded so far for each task; the training loop appends an
# entry after a task is first tested and raises it when a later test beats it.
Max_acc=[]
transformto = transforms.ToPILImage()
print('=' * 100)
# Disabled manual check of add_snow() on a single sample:
#xx=x[0].reshape(32,32,3).numpy()
#bb=torch.tensor(add_snow(xx)).reshape(3,32,32).float()
#imshow(bb)
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'GPU  ' + os.environ["CUDA_VISIBLE_DEVICES"])
print('=' * 100)
# Labels encountered so far; the training loop appends every unseen label.
class_holder=[]
# NOTE(review): the two constants below are not referenced in the visible
# part of the file (the loop uses a literal 10 per task) - confirm usage.
buffer_per_class =7
num_class_per_task=10
for run in range(1):
   # rank=torch.randperm(len(Loder))
    rank = torch.arange(0, 10)
   # rank = torch.tensor([0,1,2,3,4,5,6,7,8,9])
    for i in range(len(Loder)):

        print(i)
        task_id=i
        mem_x_whole=None
        if buffero.is_empty():
            train_loader = Loder[rank[i].item()]['train']
            #Optimizer = Adam(Basic_model.parameters(), lr=0.02, momentum=0.9)
            for epoch in range(1):
                Basic_model.train()
                num_d=0
                for batch_idx, (x, y) in enumerate(train_loader):
                    num_d+=x.shape[0]
                #    import pdb
                 #   pdb.set_trace()
                    if num_d%5000==0:
                        print(num_d,num_d/10000)
                  #  if batch_idx>3:
                   #     continue
                    llabel[i] = []

                    Y = deepcopy(y)
                    for j in range(len(Y)):
                        if Y[j] not in class_holder:
                            class_holder.append(Y[j].detach())

                    # print("idx",batch_idx)
                    # buffero.add_reservoir(x=x, y=y, logits=None, t=i)
                   # optimizer[i].zero_grad()
                    Optimizer.zero_grad()
                    # if args.cuda:
                    x, y = x.cuda(), y.cuda()
                    # x, y = Variable(x), Variable(y)
                    #imshow(torchvision.utils.make_grid(x[2]))
                    x = x.requires_grad_()
                    if (batch_idx + 1) % 10 == 0:
                        mem_x_whole, mem_y_whole, _ = buffero.onlysample(2000, task_id)
                        y_all_loose = mem_y_whole.contiguous().view(-1, 1)
                        from kmeans_pytorch import kmeans
                        import pdb

                        index = torch.arange(max(mem_y_whole) + 1)
                        y_loose = index.contiguous().view(-1, 1)
                      #  y_all_loose = mem_y_whole.contiguous().view(-1, 1)
                        Mask = (torch.eq(y_loose.cuda(), y_all_loose.cuda().t()).float())
                        with torch.no_grad():
                            hidden_whole = Basic_model.f_train(mem_x_whole)
                        clusters_id = {}
                        clusters = []
                        #  clusters_prev = []
                        pseudo_x = []
                        y_same = []

                        for iv in range(min(mem_y_whole),max(mem_y_whole) + 1):

                            # clusters[iv]
                            num_clusters = 3
                            # num_clusters = 3
                            # pdb.set_trace()
                            hidden_iv = hidden_whole[Mask[iv].bool()]

                            # kmeans
                           # if num_clusters < x.shape[0]:
                            cluster_ids_x, cluster_centers = kmeans(X=hidden_iv, num_clusters=num_clusters,
                                                                        distance='euclidean',
                                                                        device=torch.device('cuda:0'))
                            y_same.append(torch.tensor([iv, iv, iv]).cuda())
                            pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 0])
                            pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 1])
                            pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 2])
                            clusters.append(cluster_centers)
                       # import pdb
                       # pdb.set_trace()
                        clusters = normalize(torch.cat([clusters[i] for i in range(len(clusters))], dim=0))
                        y_same = torch.cat([y_same[i] for i in range(len(y_same))], dim=0)
                        y_same_con = y_same.contiguous().view(-1, 1)
                        print(task_id,batch_idx)

                    if mem_x_whole is not None:
                        y_new=y.contiguous().view(-1, 1)
                        Mask_same = (torch.eq(y_new.cuda(), y_same_con.cuda().t())).float()
                       # _, index_class = torch.topk(matrix_clusters.cuda() * Mask_same.cuda(), dim=1, k=1)
                       # mask=(torch.eq(y_new.cuda(),y_all_loose.t().cuda())).float()
                        matrix=torch.matmul(normalize(Basic_model.f_train(x)),normalize(clusters).t().cuda())
                        _,index=torch.topk(matrix.cuda()*Mask_same.cuda(),1,dim=1)

                        x_conflict=[]#mem_x_whole[index].squeeze(1)
                        y_conflict=[]
                        for tt in range(x.shape[0]):
                            for ii in range(3):
                                if index[tt]==3*y[tt]+ii:
                                    continue
                                else:
                                    idx=torch.randperm(len(pseudo_x[3*y[tt]+ii])).cuda()
                                    x_conflict.append(pseudo_x[3*y[tt]+ii][idx[0]])
                                    y_conflict.append(y[tt])


                        x_conflict=torch.cat([x_conflict[i].unsqueeze(0) for i in range(len(x_conflict))],dim=0)
                        #import pdb

                        #pdb.set_trace()
                        y_conflict = torch.stack(y_conflict).cuda()#[y_conflict[i].item() for i in range(len(y_conflict))])
                    #  import pdb
                        #pdb.set_trace()
                       # import pdb
                       # pdb.set_trace()
                      #  y_conflict=mem_y_whole[index]
                        images1, rot_sim_labels = Rotation(x, y)
                        #images1, rot_sim_labels = Rotation(torch.cat([x,x_conflict],dim=0), torch.cat([y,y_conflict]))
                    else:
                        images1, rot_sim_labels = Rotation(x, y)
                   # import pdb
                    #pdb.set_trace()


                   # imshow(torchvision.utils.make_grid(x[2]))
                    #positive_data2 = x.data
                    #images1, images2 = hflip(positive_data2.repeat(2, 1, 1, 1)).chunk(2)
                   # images1, rot_sim_labels = Rotation(x, y)

                    # images2 = torch.cat([Rotation(x, r) for r in range(4)])#.detach()  # 4B

                    # images1 = torch.cat([Rotation(x, r) for r in range(4)])
                    # images2 = torch.cat([Rotation(x, r) for r in range(4)])

                    images_pair = torch.cat([images1, simclr_aug(images1)], dim=0)

              #  labels1 = y.cuda()
                    # print("LLLL",labels1.shape)
               #     rot_sim_labels = torch.cat([labels1 + 100 * i for i in range(oop)], dim=0)
                #    Rot_sim_labels = torch.cat([labels1 + 0 * i for i in range(oop)], dim=0)
              #      rot_sim_labels1 = torch.cat([torch.arange(0,labels1.shape[0]).cuda()+ 10*i for i in range(4)], dim=0)
                    #print(labels1)
                    #print(Rot_sim_labels)
                    # print("RRRR1",rot_sim_labels)da()  # this label is actually used for masking (4 rotations)
                    rot_sim_labels = rot_sim_labels.cuda()
                  #  images_pair = simclr_aug(images_pair)  # simclr augment
                    # print("X",input_x.shape,images_pair.shape)
                  #  for n,w in Basic_model.named_parameters():
                   #     print(n,w.shape)
                    feature_map,outputs_aux = Basic_model(images_pair, is_simclr=True)
                   # outputs_aux1 = Basic_model(images_pair, is_simclr1=True)# , penultimate=True)

                    simclr = normalize(outputs_aux)  # normalize
                    feature_map_out = normalize(feature_map[:images_pair.shape[0]])
                   # num1 = feature_map_out.shape[1] // simclr.shape[1]
                    num1 = feature_map_out.shape[1] - simclr.shape[1]
                    id1 = torch.randperm(num1)[0]
                   # id1_2 = torch.randperm(num1)[1]
                    size = simclr.shape[1]
                 #   sim_matrix = torch.zeros((simclr.shape[0], simclr.shape[0])).cuda()
                    #sim_matrix_r = torch.zeros((simclr_r.shape[0], simclr_r.shape[0])).cuda()
                    #   sim_matrix_r_pre = torch.zeros((pre_r.shape[0], pre_r.shape[0])).cuda()

                    #for index in range(num1):
                        # pdb.set_trace()
                    sim_matrix = torch.matmul(simclr, feature_map_out[:, id1 :id1+ 1 * size].t())

                    sim_matrix += 1 * get_similarity_matrix(simclr)  # *(1-torch.eye(simclr.shape[0]).cuda())#+0.5*get_similarity_matrix(feature_map_out)

                    loss_sim1 = Supervised_NT_xent_n(sim_matrix, labels=rot_sim_labels,
                                                   temperature=0.07)
    #                loss_sim2 = Supervised_NT_xent(sim_matrix, labels=rot_sim_labels1,
     #                                          temperature=0.07)

                    #hid = model.return_hidden(input_x)
                    lo1 = 1 * loss_sim1 #+0*loss_sim2
                    #lo1.backward()
                    #y_pred = Basic_model.forward(simclr_aug(x))
                    #y_pred = S_model[i].forward(out)
                    #                y_pred=S_model[i](x)

                   # if task_id==0:
                    with torch.no_grad():
                        # center_x,center_y,_=buffer_center.sample(int(buffer_batch_size*0.5), exclude_task=None)
                        index_idx = torch.from_numpy(np.random.choice(x.size(0), x.shape[0], replace=True))

                        mix_new = x[index_idx]
                         #mix_x=x[index_new]#[0].unsqueeze(0).repeat(center_x.shape[0],1,1,1)
                        #mix_new[:, :, 4:28, 4:28] = F.interpolate(x, size=24)
                        mix_new[:, :, 4:28, 4:28] = F.interpolate(x, size=24)
                        if mem_x_whole is not None:
                            idy = torch.arange(x.shape[0]).cuda()
                            import copy

                            #idy = torch.randperm(x.shape[0]).cuda()
                            buffer_new_x = deepcopy(x_conflict[2 * idy])
                            buffer_new_x[:,:, 4: 28, 4: 28] = F.interpolate(x, size=24)
                            buffer_new_x_2 = deepcopy(x_conflict[2 * idy + 1])
                            buffer_new_x_2[:, :, 4: 28, 4: 28] = F.interpolate(x, size=24)
                    beta=[]
                    for _ in range(4):

                        beta.append(np.random.beta(1, 1))
                    #import pdb
                    #pdb.set_trace()
                    with torch.no_grad():

                        mask=feat_cam_normalized(Basic_model,x,y)
                       # mask=mask>0.2
                    #def feat_cam_normalized(model, x, y):
                    #import pdb
                    #pdb.set_trace()
                    alpha=0.25
                    if mem_x_whole is not None:
                        num_envs=8
                        flip_x = torch.cat((x, mix_new,buffer_new_x,buffer_new_x_2,
                                            torch.flip((mask > alpha) * (
                                                        beta[3] * x[:, torch.randperm(3)] + (1 - beta[3]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                            torch.flip((mask > alpha) * (
                                                        beta[2] * x[:, torch.randperm(3)] + (1 - beta[2]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                            torch.flip((mask > alpha) * (
                                                        beta[1] * x[:, torch.randperm(3)] + (1 - beta[1]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                            torch.flip((mask > alpha) * (
                                                        beta[0] * x[:, torch.randperm(3)] + (1 - beta[0]) * x) + (
                                                                   mask < alpha) * x, (3,))


                                                                   ),

                                           dim=0)

                    else:
                        num_envs=6

                        flip_x = torch.cat((x, mix_new,

                                        torch.flip((mask>alpha)*(beta[3] * x[:, torch.randperm(3)] + (1 - beta[3]) * x)+(mask<alpha)*x, (3,)),
                                        torch.flip((mask>alpha)*(beta[2] * x[:, torch.randperm(3)] + (1 - beta[2]) * x)+(mask<alpha)*x, (3,)),
                                        torch.flip((mask>alpha)*(beta[1] * x[:, torch.randperm(3)] + (1 - beta[1]) * x)+(mask<alpha)*x, (3,)),
                                        torch.flip((mask>alpha)*(beta[0] * x[:, torch.randperm(3)] + (1 - beta[0]) * x)+(mask<alpha)*x, (3,))),
                                       dim=0)
                    # torch.cat((x, flip3(x), flip3_1(x), flip3_1(flip3(x))), dim=0)
                    y_pred = Basic_model.forward(simclr_aug(flip_x))
                    import pdb#
                    # torch.cat([Basic_model.f_train(flip_x[rr*x.shape[0]:(rr+1)*x.shape[0]]).mean() for rr in range(4)])
              #      pdb.set_trace()
                    #  import pdb
                    # pdb.set_trace()

                    sample_matrix = torch.matmul(normalize(Basic_model.f_train(simclr_aug(flip_x))),
                                                 normalize(Basic_model.f_train(flip_x)).t())


                    loss_envs = Supervised_NT_xent_envs(sample_matrix, labels=torch.arange(y.shape[0]).repeat(num_envs),
                                                        temperature=0.07)

                    # y_pred = Basic_model.linear(hidden_pred)
                    # import pdb
                    # pdb.set_trace()
                    loss = 1 * F.cross_entropy(y_pred, y.repeat(num_envs)) + 1 * lo1 + 1*loss_envs



                    with amp.scale_loss(loss, Optimizer) as scaled_loss:
                        scaled_loss.backward()

#                    print(i, epoch, loss)
                    #optimizer[i].step()
                    Optimizer.step()

                    #if batch_idx%2==0:
                    buffero.add_reservoir(x=x.detach(), y=y.detach(), logits=None, t=i)

            Previous_model = deepcopy(Basic_model)
            for j in range(i + 1):
                print("ori", rank[j].item())
                a = test_model(Loder[rank[j].item()]['test'], j,Basic_model)
                if j == i:
                    Max_acc.append(a)
                if a > Max_acc[j]:
                    Max_acc[j] = a


        else:
            past=[]
            past_y=[]
            train_loader = Loder[rank[i].item()]['train']
            active_one=False
            active_three = False
            active_five= False
            #optimizer.append(Adam(S_model[i].parameters(), lr=0.001, betas=(0.9, 0.99)))  # ,momentum=0.9))
            for epoch in range(1):
            #    S_model[i].train()

                num_d=0
                Basic_model.train()
              #  with torch.no_grad():
               #     mem_x,mem_y,bt=buffero.sample(args.buffer_size, exclude_task=None)
                #    pre_hidden=normalize(Previous_model.f_train(mem_x))
            #    x_o=[]
                #for x,y in Loder[rank[j].item()]['test']:x_o.append(x)
             #   x_0=torch.cat([x_o[i] for i in range(len(x_o))],dim=0)

                for batch_idx, (x, y) in enumerate(train_loader):
                    num_d+=x.shape[0]
                    if num_d%5000==0:
                        print(num_d,num_d/10000)



                    Y = deepcopy(y)
                    for j in range(len(Y)):
                        if Y[j] not in class_holder:
                            class_holder.append(Y[j].detach())
                    task_id=i

                    # print("idx",batch_idx)
                    # buffero.add_reservoir(x=x, y=y, logits=None, t=i)
             #       optimizer[i].zero_grad()
                    Optimizer.zero_grad()
                    # if args.cuda:
                    x, y = x.cuda(), y.cuda()
                    # x, y = Variable(x), Variable(y)
                    x = x.requires_grad_()
                    #if task_id<9:
                        #if ((batch_idx + 1) ==200)|((batch_idx>200)&((batch_idx + 1)%100==0)):
                    if (batch_idx + 1) % 10 == 0:

                  #  import pdb
                   # pdb.set_trace()
                        #print(task_id,batch_idx)
                        #if task_id==9:
                         #   if batch_idx<100:
                          #      continue
                            #import pdb
                            #pdb.set_trace()
                        if (buffero.y_idx()>=10*task_id).sum()<10:
                            mem_x_whole=None
                            continue
                        else:

                            mem_x_whole, mem_y_whole, _ = buffero.onlysample(2000, task_id)
                            y_all_loose = mem_y_whole.contiguous().view(-1, 1)
                            y_check=torch.arange(min(mem_y_whole),max(mem_y_whole)+1)
                            y_check=y_check.contiguous().view(-1, 1)
                            with torch.no_grad():
                                hidden_whole = Basic_model.f_train(mem_x_whole)
                            #import pdb
                            #pdb.set_trace()
                            mask_check=(torch.eq(y_check.cuda(), y_all_loose.cuda().t()).float())
                            mask_check=mask_check.sum(1)
                            if mask_check.shape[0]>(mask_check>2).sum():
                                active_five=False
                                active_three=False
                                if mask_check.shape[0] == (mask_check > 0).sum():
                                    active_one=True
                                else:
                                    mem_x_whole=None
                                    continue
                            else:
                                active_one = False
                                from kmeans_pytorch import kmeans
                                import pdb

                                index = torch.arange(max(mem_y_whole) + 1)
                                y_loose = index.contiguous().view(-1, 1)
                                #  y_all_loose = mem_y_whole.contiguous().view(-1, 1)
                                Mask = (torch.eq(y_loose.cuda(), y_all_loose.cuda().t()).float())
                                with torch.no_grad():
                                    hidden_whole = Basic_model.f_train(mem_x_whole)
                                clusters_id = {}
                                clusters = []
                                #  clusters_prev = []
                                pseudo_x = []
                                y_same = []
                                '''
                                if mask_check.shape[0]==(mask_check>4).sum():
                                    active_five=False
                                    active_three = True
                                    for iv in range(min(mem_y_whole), max(mem_y_whole) + 1):
                                        # clusters[iv]
                                        num_clusters = 3
                                        # num_clusters = 3
                                        # pdb.set_trace()
                                        hidden_iv = hidden_whole[Mask[iv].bool()]

                                        # kmeans
                                        # if num_clusters < x.shape[0]:
                                        cluster_ids_x, cluster_centers = kmeans(X=hidden_iv, num_clusters=num_clusters,
                                                                                distance='euclidean',
                                                                                device=torch.device('cuda:0'))
                                        y_same.append(torch.tensor([iv, iv, iv,iv,iv]).cuda())
                                        pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 0])
                                        pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 1])
                                        pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 2])
                                       # pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 3])
                                       # pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 4])
                                        clusters.append(cluster_centers)
                                    # import pdb
                                    # pdb.set_trace()
                                    clusters = normalize(torch.cat([clusters[i] for i in range(len(clusters))], dim=0))
                                    y_same = torch.cat([y_same[i] for i in range(len(y_same))], dim=0)
                                    y_same_con = y_same.contiguous().view(-1, 1)
                                    print(task_id, batch_idx)
                                else:
                                '''
                                if mask_check.shape[0] == (mask_check > 2).sum():

                                    active_three=True
                                    active_five = False
                                    for iv in range(min(mem_y_whole), max(mem_y_whole) + 1):
                                        # clusters[iv]
                                        num_clusters = 3
                                        # num_clusters = 3
                                        # pdb.set_trace()
                                        hidden_iv = hidden_whole[Mask[iv].bool()]

                                        # kmeans
                                        # if num_clusters < x.shape[0]:
                                        cluster_ids_x, cluster_centers = kmeans(X=hidden_iv, num_clusters=num_clusters,
                                                                                distance='euclidean',
                                                                                device=torch.device('cuda:0'))
                                        y_same.append(torch.tensor([iv, iv, iv]).cuda())
                                        pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 0])
                                        pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 1])
                                        pseudo_x.append(mem_x_whole[Mask[iv].bool()][cluster_ids_x == 2])
                                        clusters.append(cluster_centers)
                                    # import pdb
                                    # pdb.set_trace()
                                    clusters = normalize(torch.cat([clusters[i] for i in range(len(clusters))], dim=0))
                                    y_same = torch.cat([y_same[i] for i in range(len(y_same))], dim=0)
                                    y_same_con = y_same.contiguous().view(-1, 1)
                                    print(task_id, batch_idx)


                        #def check_chain(x):




                    # Build the rotated training views (`images1`, `rot_sim_labels`) for the current batch.
                    # When memory samples exist (`mem_x_whole` is not None), extra same-label "conflict"
                    # samples are appended to `x` before `Rotation` is applied; the `active_*` flags choose
                    # how those samples are picked. Without memory, only `x` itself is rotated.
                    # NOTE(review): if `mem_x_whole` is not None but no `active_*` flag is set, this
                    # statement assigns neither `images1` nor `rot_sim_labels` — confirm that cannot happen.
                    if mem_x_whole is not None:
                        if active_one :
                            # Strategy 1: one conflict sample per input, taken directly from memory.
                            with torch.no_grad():
                                y_new = y.contiguous().view(-1, 1)
                                # mask[b, m] == 1 where memory label m equals the label of input b.
                                mask = (torch.eq(y_new.cuda(), y_all_loose.t().cuda())).float()
                                # Dot products between normalized batch features and normalized memory features.
                                matrix = torch.matmul(normalize(Basic_model.f_train(x)), normalize(hidden_whole).t())
                                # Top-1 of (1 - similarity) restricted by the label mask, i.e. the
                                # same-label memory entry with the lowest similarity score.
                                _, index = torch.topk((1 - matrix) * mask, 1, dim=1)
                            x_conflict = mem_x_whole[index].squeeze(1)
                            images1, rot_sim_labels = Rotation(torch.cat([x, x_conflict], dim=0), y.repeat(2))
                        if active_three:
                            # Strategy 2: use the per-class cluster centroids (3 clusters per class).
                            y_new = y.contiguous().view(-1, 1)
                            # Mask_same[b, c] == 1 where centroid c carries the label of input b.
                            Mask_same = (torch.eq(y_new.cuda(), y_same_con.cuda().t())).float()
                            # _, index_class = torch.topk(matrix_clusters.cuda() * Mask_same.cuda(), dim=1, k=1)
                            # mask=(torch.eq(y_new.cuda(),y_all_loose.t().cuda())).float()
                            matrix = torch.matmul(normalize(Basic_model.f_train(x)), normalize(clusters).t().cuda())
                            # Index of the highest masked score per input (the best-matching centroid).
                            # NOTE(review): masked-out entries are 0, so a negative same-class score can
                            # lose against a centroid of another class — verify this is acceptable.
                            _, index = torch.topk(matrix.cuda() * Mask_same.cuda(), 1, dim=1)

                            x_conflict = []  # mem_x_whole[index].squeeze(1)
                            y_conflict = []
                            # For every input, visit the 3 clusters of its class, skip the one flagged as
                            # its best match, and draw one random member image from each remaining cluster.
                            # NOTE(review): the skip test uses the absolute label (`3 * y[tt] + ii`) while the
                            # `pseudo_x` lookup uses a task-relative label (`y[tt] - 10 * task_id`); these
                            # only agree when the cluster list starts at label 0 — confirm the intended offset.
                            # NOTE(review): `idx[0]` raises if the chosen cluster has no members.
                            for tt in range(x.shape[0]):
                                for ii in range(3):
                                    if index[tt] == 3 * y[tt] + ii:
                                        continue
                                    else:
                                        idx = torch.randperm(len(pseudo_x[3 * (y[tt]-(10*task_id)) + ii])).cuda()
                                        x_conflict.append(pseudo_x[3 * (y[tt]-(10*task_id)) + ii][idx[0]])
                                        y_conflict.append(y[tt])

                            # Stack the collected images into a batch tensor.
                            x_conflict = torch.cat([x_conflict[i].unsqueeze(0) for i in range(len(x_conflict))], dim=0)

                            y_conflict = torch.stack(y_conflict).cuda()  # [y_conflict[i].item() for i in range(len(y_conflict))])

                            images1, rot_sim_labels = Rotation(torch.cat([x, x_conflict], dim=0), torch.cat([y,y_conflict]))
                            #images1, rot_sim_labels = Rotation(x, y)
                        if active_five:
                            # Strategy 3: same procedure as above with 5 clusters per class.
                            # NOTE(review): this indexes `pseudo_x`/`clusters` with a stride of 5, so it
                            # requires those lists to have been built with 5 clusters per class.
                            y_new = y.contiguous().view(-1, 1)
                            Mask_same = (torch.eq(y_new.cuda(), y_same_con.cuda().t())).float()
                            # _, index_class = torch.topk(matrix_clusters.cuda() * Mask_same.cuda(), dim=1, k=1)
                            # mask=(torch.eq(y_new.cuda(),y_all_loose.t().cuda())).float()
                            matrix = torch.matmul(normalize(Basic_model.f_train(x)), normalize(clusters).t().cuda())
                            _, index = torch.topk(matrix.cuda() * Mask_same.cuda(), 1, dim=1)

                            x_conflict = []  # mem_x_whole[index].squeeze(1)
                            y_conflict = []
                            for tt in range(x.shape[0]):
                                for ii in range(5):
                                    if index[tt] == 5 * y[tt] + ii:
                                        continue
                                    else:
                                        idx = torch.randperm(len(pseudo_x[5 * (y[tt] - (10 * task_id)) + ii])).cuda()
                                        x_conflict.append(pseudo_x[5 * (y[tt] - (10 * task_id)) + ii][idx[0]])
                                        y_conflict.append(y[tt])

                            x_conflict = torch.cat([x_conflict[i].unsqueeze(0) for i in range(len(x_conflict))], dim=0)

                            y_conflict = torch.stack(
                                y_conflict).cuda()  # [y_conflict[i].item() for i in range(len(y_conflict))])

                            images1, rot_sim_labels = Rotation(torch.cat([x, x_conflict], dim=0),
                                                               torch.cat([y, y_conflict]))
                            # images1, rot_sim_labels = Rotation(x, y)
                    else:
                            # No memory samples available: rotate the raw batch only.
                            images1, rot_sim_labels = Rotation(x, y)

                    # Draw a replay batch from the buffer (at most 64 samples, fewer while the number of
                    # seen classes times `buffer_per_class` is still below 64).
                    buffer_batch_size = min(64,buffer_per_class*len(class_holder))
                    mem_x, mem_y,_= buffero.sample(int(buffer_batch_size), exclude_task=None)

                  #  import pdb
                   # pdb.set_trace()


                    #mem_x=torch.cat([mem_x,mix_x],dim=0)
                    #mem_y=torch.cat([mem_y,center_y])
                    # Mark the replay images as requiring gradients (in-place flag on the tensor).
                    mem_x = mem_x.requires_grad_()

               #     images1, rot_sim_labels = Rotation(x, y)  # torch.cat([Rotation(x, r) for r in range(4)])
                    # images2 = Rotation(x)#torch.cat([Rotation(x, r) for r in range(4)])#.detach()
                    # images2.requires_grad=False

                    # Rotated views and matching labels for the replay batch.
                    images1_r, rot_sim_labels_r = Rotation(mem_x,
                                                           mem_y)  # torch.cat([Rotation(mem_x, r) for r in range(4)])
                    # images2_r = Rotation(mem_x)#torch.cat([Rotation(mem_x, r) for r in range(4)])#.detach()
                    # images2_r.requires_grad=False
                    # Pair every rotated view with a `simclr_aug`-augmented copy (original first, augmented second).
                    images_pair = torch.cat([images1, simclr_aug(images1)], dim=0)
                    images_pair_r = torch.cat([images1_r, simclr_aug(images1_r)], dim=0)

                    # Single forward pass over new-data pairs followed by replay pairs; the outputs are
                    # split back below at `images_pair.shape[0]`.
                    t =torch.cat((images_pair,images_pair_r),dim=0)
                    feature_map, u = Basic_model.forward(t, is_simclr=True)
                    # Outputs of the previous-task model on the (un-augmented) rotated replay views.
                    # NOTE(review): this call is not wrapped in `torch.no_grad()` — confirm whether
                    # gradients through `Previous_model` are intended.
                    pre_u_feature, pre_u = Previous_model.forward(images1_r, is_simclr=True)
                    feature_map_out = normalize(feature_map[:images_pair.shape[0]])
                    feature_map_out_r = normalize(feature_map[images_pair.shape[0]:])
                    pre_feature_map_out_r = pre_u_feature

                    images_out = u[:images_pair.shape[0]]
                    images_out_r = u[images_pair.shape[0]:]
                    pre_u = normalize(pre_u)#torch.cat((images_out_r,pre_u),dim=0)

                    # Normalized second outputs (`u`) for new data and replay data.
                    simclr = normalize(images_out)
                    simclr_r = normalize(images_out_r)
                    # NOTE(review): `simclr_pre` is not read anywhere in the visible part of the loop.
                    simclr_pre = normalize(pre_feature_map_out_r)
                   # simclr_now=normalize(u[images_pair.shape[0]:images_pair.shape[0]+images1_r.shape[0]])

#                    rot_sim_labels = torch.cat([y.cuda()+ 100 * i for i in range(oop)],dim=0)
            #        rot_sim_labels1 = torch.cat([torch.arange(0,y.shape[0]).cuda() +10*i for i in range(4)], dim=0)
             #       rot_sim_labels_r1 = torch.cat([torch.arange(0,mem_y.shape[0]).cuda()+10*i for i in range(4)], dim=0)
 #                   rot_sim_labels_r = torch.cat([mem_y.cuda()+ 100 * i for i in range(oop)],dim=0)

                    # Contrastive terms: compare the `u` embeddings against a randomly positioned window
                    # of the (wider) normalized feature map, plus the embeddings' own similarity matrix.
                    # `num1` is the width difference, so any start offset in [0, num1) keeps the
                    # `size`-wide window inside the feature map.
                    num1 = feature_map_out.shape[1] - simclr.shape[1]
                    id1 = torch.randperm(num1)[0]
                    # NOTE(review): `id1_2` and `id2_2` are not used in the visible code, but each
                    # `torch.randperm` call still advances the RNG; `[1]` also requires num1 >= 2.
                    id1_2 = torch.randperm(num1)[1]
                    id2=torch.randperm(num1)[0]
                    id2_2 = torch.randperm(num1)[0]
                  #  id3 = torch.randperm(num1)[0]
                    size = simclr.shape[1]

                    # Embedding-vs-feature-window products for new data and for replay data.
                    sim_matrix = torch.matmul(simclr, feature_map_out[:, id1:id1 + size].t())
                    sim_matrix_r = torch.matmul(simclr_r,
                                                     feature_map_out_r[:, id2:id2 + size].t())

                    # pdb.set_trace()
                    # Add the result of `get_similarity_matrix` on the embeddings themselves.
                    sim_matrix += 1 * get_similarity_matrix(
                        simclr)  # *(1-torch.eye(simclr.shape[0]).cuda())#+0.5*get_similarity_matrix(feature_map_out)
                    sim_matrix_r += 1 * get_similarity_matrix(simclr_r)
                    # Current replay embeddings (first half only, i.e. the un-augmented views) against
                    # the previous model's normalized embeddings of the same views.
                    sim_matrix_r_pre = torch.matmul(simclr_r[:images1_r.shape[0]],pre_u.t())
                  #  sim_matrix /= (num1 + 1)
                 #   sim_matrix_r /= (num1 + 1)
               #
             #       loss_sim_mix1=Supervised_NT_xent(sim_matrix_mix,labels=rot_sim_labels_mix1,temperature=0.07)
                    loss_sim_r =Supervised_NT_xent_uni(sim_matrix_r,labels=rot_sim_labels_r,temperature=0.07)
             #       loss_sim_mix2= Supervised_NT_xent(sim_matrix_mix,labels=rot_sim_labels_mix2,temperature=0.07)
                    loss_sim_pre = Supervised_NT_xent_pre(sim_matrix_r_pre, labels=rot_sim_labels_r, temperature=0.07)
                    loss_sim = Supervised_NT_xent_n(sim_matrix, labels=rot_sim_labels, temperature=0.07)

                    # Weighted sum of the three contrastive terms (replay term weighted 2x).
                    lo1 =2 * loss_sim_r+1*loss_sim+1*loss_sim_pre#+loss_sup1#+0*loss_sim_r1+0*loss_sim1#+0*loss_sim_mix1+0*loss_sim_mix2#+ 1 * loss_sup1#+loss_sim_kd
                   # mem_x = torch.cat((torch.flip(mem_x, (3,)), mem_x), dim=0)
                   # mem_y = mem_y.repeat(2)
                    # Outputs of the current and the previous model on augmented replay images; they are
                    # compared with an MSE term when the total loss is assembled.
                    # NOTE(review): `simclr_aug` is called separately for each model, so if it is random
                    # the two models see different augmentations — confirm this is intended.
                    y_label = Basic_model.forward(simclr_aug(mem_x))  # [:,:10*(task_id+1)]
                    y_label_pre = Previous_model(simclr_aug(mem_x))
                    # Build "pasted" images without tracking gradients: a 24x24 resized copy of one
                    # image is written into rows/columns 4..27 of another image.
                    with torch.no_grad():
                        # center_x,center_y,_=buffer_center.sample(int(buffer_batch_size*0.5), exclude_task=None)
                        # Random (with replacement) replay images, one per new sample, and vice versa.
                        index_idx = torch.from_numpy(np.random.choice(mem_x.size(0), x.shape[0], replace=True))
                        index_idx2 = torch.from_numpy(np.random.choice(x.size(0), mem_x.shape[0], replace=True))

                        # Tensor-indexed selections are copies, so the in-place writes below do not
                        # modify `mem_x` or `x`.
                        mix_new = mem_x[index_idx]
                        mix_buf = x[index_idx2]
                        # mix_x=x[index_new]#[0].unsqueeze(0).repeat(center_x.shape[0],1,1,1)
                        # mix_new: new image pasted onto a replay background; mix_buf: the reverse.
                        mix_new[:, :, 4:28, 4:28] = F.interpolate(x, size=24)
                        mix_buf[:, :, 4:28, 4:28] = F.interpolate(mem_x, size=24)
                        if mem_x_whole is not None:
                            # Same pasting, using the conflict samples chosen earlier as backgrounds.
                            if active_one:
                                buffer_new_x = deepcopy(x_conflict)
                                buffer_new_x[:, :, 4: 28, 4: 28] = F.interpolate(x, size=24)
                            if active_three:
                                # `x_conflict` holds entries at 2*b and 2*b+1 for input b
                                # (assumes exactly two conflict samples were collected per input).
                                idy = torch.arange(x.shape[0]).cuda()
                                #idy=torch.randperm(x.shape[0]).cuda()
                                buffer_new_x=deepcopy(x_conflict[2*idy])
                                buffer_new_x[:,:, 4: 28, 4: 28] = F.interpolate(x, size=24)
                                buffer_new_x_2 = deepcopy(x_conflict[2 * idy+1])
                                buffer_new_x_2[:, :, 4: 28, 4: 28] = F.interpolate(x, size=24)
                            if active_five:
                                # NOTE(review): unlike the three-cluster branch this uses a random
                                # permutation, so backgrounds are not aligned with the pasted input.
                                idy=torch.randperm(x.shape[0]).cuda()
                                buffer_new_x=deepcopy(x_conflict[4*idy])
                                buffer_new_x[:,:, 4: 28, 4: 28] = F.interpolate(x, size=24)
                                buffer_new_x_2 = deepcopy(x_conflict[4 * idy+1])
                                buffer_new_x_2[:, :, 4: 28, 4: 28] = F.interpolate(x, size=24)
                                buffer_new_x_3 = deepcopy(x_conflict[4 * idy + 2])
                                buffer_new_x_3[:, :, 4: 28, 4: 28] = F.interpolate(x, size=24)
                                buffer_new_x_x = deepcopy(x_conflict[4 * idy + 3])
                                buffer_new_x_x[:, :, 4: 28, 4: 28] = F.interpolate(x, size=24)


                    # Assemble the "environment" batches `flip_x` (new data) and `flip_mem_x` (replay data):
                    # several altered versions of the same images concatenated along the batch axis.
                    # `num_envs` records how many versions `flip_x` holds; `flip_mem_x` always holds 6.

                    # Four mixing coefficients drawn from Beta(1, 1), one per colour-mixed version.
                    beta = []
                    for _ in range(4):
                        beta.append(np.random.beta(1, 1))

                    # Masks from `feat_cam_normalized`, computed without gradients; compared against
                    # `alpha` below to decide which pixels get the channel mix.
                    with torch.no_grad():
                        mask = feat_cam_normalized(Basic_model, x, y)
                        mask_mem = feat_cam_normalized(Basic_model, mem_x, mem_y)
                        # mask=mask>0.2
                        # def feat_cam_normalized(model, x, y):
                        # import pdb
                        # pdb.set_trace()
                    alpha = 0.25

                    # Each colour-mixed version blends the image with a random permutation of its 3
                    # channels (weight beta[k]) where the mask condition holds, keeps the original pixels
                    # elsewhere, and is then flipped along dim 3.
                    # NOTE(review): pixels where the mask equals `alpha` exactly satisfy neither
                    # comparison and therefore become 0.
                    if mem_x_whole is not None:
                        if active_one:
                            # x, mix_new, one conflict paste, four colour-mixed flips -> 7 versions.
                            num_envs = 7
                            flip_x = torch.cat((x, mix_new, buffer_new_x,

                                                torch.flip((mask > alpha) * (
                                                        beta[3] * x[:, torch.randperm(3)] + (1 - beta[3]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                                torch.flip((mask > alpha) * (
                                                        beta[2] * x[:, torch.randperm(3)] + (1 - beta[2]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                                torch.flip((mask > alpha) * (
                                                        beta[1] * x[:, torch.randperm(3)] + (1 - beta[1]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                                torch.flip((mask > alpha) * (
                                                        beta[0] * x[:, torch.randperm(3)] + (1 - beta[0]) * x) + (
                                                                   mask < alpha) * x, (3,))),
                                           dim=0)
                        if active_three:
                            # x, mix_new, two conflict pastes, four colour-mixed flips -> 8 versions.
                            num_envs = 8
                            flip_x = torch.cat((x, mix_new, buffer_new_x,buffer_new_x_2,

                                                torch.flip((mask > alpha) * (
                                                        beta[3] * x[:, torch.randperm(3)] + (1 - beta[3]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                                torch.flip((mask > alpha) * (
                                                        beta[2] * x[:, torch.randperm(3)] + (1 - beta[2]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                                torch.flip((mask > alpha) * (
                                                        beta[1] * x[:, torch.randperm(3)] + (1 - beta[1]) * x) + (
                                                                   mask < alpha) * x, (3,)),
                                                torch.flip((mask > alpha) * (
                                                        beta[0] * x[:, torch.randperm(3)] + (1 - beta[0]) * x) + (
                                                                   mask < alpha) * x, (3,))),
                                           dim=0)
                        if active_five:
                            # x, mix_new, four conflict pastes, four colour-mixed flips -> 10 versions.
                            # NOTE(review): the comparisons are swapped here (`mask < alpha` selects the
                            # mixed pixels) relative to every other branch — confirm this is deliberate.
                            num_envs = 10
                            flip_x = torch.cat((x, mix_new, buffer_new_x,buffer_new_x_2,buffer_new_x_3,buffer_new_x_x,

                                            torch.flip((mask < alpha) * (
                                                    beta[3] * x[:, torch.randperm(3)] + (1 - beta[3]) * x) + (
                                                               mask > alpha) * x, (3,)),
                                            torch.flip((mask < alpha) * (
                                                    beta[2] * x[:, torch.randperm(3)] + (1 - beta[2]) * x) + (
                                                               mask > alpha) * x, (3,)),
                                            torch.flip((mask < alpha) * (
                                                    beta[1] * x[:, torch.randperm(3)] + (1 - beta[1]) * x) + (
                                                               mask > alpha) * x, (3,)),
                                            torch.flip((mask < alpha) * (
                                                    beta[0] * x[:, torch.randperm(3)] + (1 - beta[0]) * x) + (
                                                               mask > alpha) * x, (3,))),
                                           dim=0)

                    else:
                        # No memory samples: x, mix_new and the four colour-mixed flips -> 6 versions.
                        num_envs = 6

                        flip_x = torch.cat((x, mix_new,

                                            torch.flip((mask > alpha) * (
                                                    beta[3] * x[:, torch.randperm(3)] + (1 - beta[3]) * x) + (
                                                               mask < alpha) * x, (3,)),
                                            torch.flip((mask > alpha) * (
                                                    beta[2] * x[:, torch.randperm(3)] + (1 - beta[2]) * x) + (
                                                               mask < alpha) * x, (3,)),
                                            torch.flip((mask > alpha) * (
                                                    beta[1] * x[:, torch.randperm(3)] + (1 - beta[1]) * x) + (
                                                               mask < alpha) * x, (3,)),
                                            torch.flip((mask > alpha) * (
                                                    beta[0] * x[:, torch.randperm(3)] + (1 - beta[0]) * x) + (
                                                               mask < alpha) * x, (3,))),
                                           dim=0)
                    # torch.cat((x, flip3(x), flip3_1(x), flip3_1(flip3(x))), dim=0)
                    #     flip_mem_x = torch.cat((mem_x, mem_x[:, [0, 2, 1]], mem_x[:, [1, 2, 0]], mem_x[:, [1, 0, 2]]),
                    #                           dim=0)  # torch.cat((mem_x, flip3(mem_x), flip3_1(mem_x), flip3_1(flip3(mem_x))), dim=0)
                    # Replay counterpart: mem_x, mix_buf and four colour-mixed versions (6 in total).
                    # Note that the third colour-mixed version (beta[1]) is not flipped.
                    flip_mem_x = torch.cat((mem_x, mix_buf,
                                            torch.flip((mask_mem > alpha) * (beta[3] * mem_x[:, torch.randperm(3)] + (1 - beta[3]) * mem_x)+(mask_mem < alpha)*mem_x,
                                                       (3,)),
                                            torch.flip((mask_mem > alpha) * (beta[0] * mem_x[:, torch.randperm(3)] + (1 - beta[0]) * mem_x)+(mask_mem < alpha)*mem_x,
                                                       (3,)),
                                            (mask_mem > alpha) * (beta[1] * mem_x[:, torch.randperm(3)] + (1 - beta[1]) * mem_x)+(mask_mem < alpha)*mem_x,
                                            torch.flip((mask_mem > alpha) * (beta[2] * mem_x[:, torch.randperm(3)] + (1 - beta[2]) * mem_x)+(mask_mem < alpha)*mem_x,
                                                       (3,))),
                                           dim=0)
                    # envs_matrix = torch.matmul(normalize(Basic_model(flip_x,is_simclr=True)),
                    #                           normalize(Basic_model(simclr_aug(flip_x),is_simclr=True)).t())
                    # Environment-consistency term: products between normalized features of the
                    # augmented environment batch and of the un-augmented one, for new and replay data.
                    envs_matrix = torch.matmul(normalize(Basic_model.f_train(simclr_aug(flip_x))),
                                               normalize(Basic_model.f_train(flip_x)).t())
                    envs_matrix_mem = torch.matmul(normalize(Basic_model.f_train(simclr_aug(flip_mem_x))),
                                                   normalize(Basic_model.f_train(flip_mem_x)).t())

                    # envs_matrix_mem = torch.matmul(normalize(Basic_model(flip_mem_x,is_simclr=True)),
                    #                           normalize(Basic_model(simclr_aug(flip_mem_x),is_simclr=True)).t())
                    # Labels are per-sample indices tiled once per version, so all versions of one
                    # source image share a label. The replay term is weighted 2x and tiles 6 times,
                    # matching the 6 components of `flip_mem_x`.
                    loss_envs = Supervised_NT_xent_envs(envs_matrix, labels=torch.arange(y.shape[0]).repeat(num_envs),
                                                        temperature=0.07) + 2 * Supervised_NT_xent_envs(envs_matrix_mem,
                                                                                                        labels=torch.arange(
                                                                                                            mem_y.shape[
                                                                                                                0]).repeat(
                                                                                                            6),
                                                                                                        temperature=0.07)

                    # Classification outputs for augmented replay environments and for new environments.
                    y_pred_mem = Basic_model(simclr_aug(flip_mem_x))
                    # NOTE(review): `new_hidden` and `loss_all` are not read in the visible code.
                    new_hidden = Basic_model.f_train(flip_x)
                    new_logits = Basic_model(flip_x)
                    # Restrict new-data logits to the label range present in this batch.
                    pred = new_logits[:, min(y):max(y) + 1]


                    loss_all=0
                    # From here on `y` and `mem_y` are tiled to match the environment batches
                    # (`y` now has num_envs * batch entries; its first block is the original labels).
                    y = y.repeat(num_envs)
                    mem_y = mem_y.repeat(6)
                    # Shift labels so they index into the sliced logits `pred`.
                    min_y = y - min(y)
                    loss_new = F.cross_entropy(pred, min_y)
                 #   if len(past)==2:
                  #      loss_all=F.cross_entropy(Basic_model.linear(past[1]),past_y[1])+F.cross_entropy(Basic_model.linear(past[0])[:,min(past_y[0]):max(past_y[0])+1],past_y[0]-min(past_y[0]))


                  #  pdb.set_trace()
                    # Total loss: 2x replay cross-entropy + environment term + new-data cross-entropy
                    # + contrastive terms (`lo1`) + MSE between previous and current model outputs on
                    # the first `10 * task_id` columns.
                    # NOTE(review): with task_id == 0 those slices are empty — presumably this path only
                    # runs for task_id >= 1; verify.
                    loss =2 * F.cross_entropy(y_pred_mem, mem_y)+1*loss_envs+1*loss_new + lo1 + 1 * F.mse_loss(y_label_pre[:, :10 * task_id],
                                                                                       y_label[:,
                                                                                       :10 * task_id])  # +1*F.mse_loss(y_feature, logits)+ 1*F.mse_loss(y_feature, Pre_y_feature)
                    #+1*F.mse_loss(y_feature, logits)+ 1*F.mse_loss(y_feature, Pre_y_feature)
                 #   print(loss_relation)

                    # Backward pass through apex AMP's loss scaling, then apply the optimizer update.
                    with amp.scale_loss(loss, Optimizer) as scaled_loss:
                        scaled_loss.backward()

                    Optimizer.step()

                    # `y` was tiled `num_envs` times earlier in this iteration for the classification
                    # loss, so only its first `x.shape[0]` entries are the labels of the raw batch `x`.
                    # Hand the buffer exactly those labels so images and labels have matching lengths.
                    #if batch_idx % 2 == 0:
                    buffero.add_reservoir(x=x.detach(), y=y[:x.shape[0]].detach(), logits=None, t=i)
                    #else:
                     #   buffer_center.add_reservoir(x=F.interpolate(x.detach(), size=16), y=y.detach(), logits=None, t=i)

           # Basic_model.del_pseudo_dim()
            # Freeze a copy of the model just trained; it serves as `Previous_model` for the next task.
            Previous_model = deepcopy(Basic_model)
            # Evaluate on every task seen so far (0..i), print the mean accuracy and keep the best
            # accuracy ever observed for each task in `Max_acc`.
            # The accumulator is deliberately not called `sum`, which would shadow the builtin.
            acc_total = 0
            for j in range(i + 1):
                print("ori", rank[j].item())
                a = test_model(Loder[rank[j].item()]['test'], j, Basic_model)
                acc_total += a
                if j == i:
                    # First evaluation of the newest task: start its best-accuracy record.
                    Max_acc.append(a)
                if a > Max_acc[j]:
                    Max_acc[j] = a
            print("avg", acc_total / (i + 1))

    # Final report: timestamp / GPU banner, accuracy of the final model on `test_loder`,
    # and the sum of the best per-task accuracies collected in `Max_acc`.
    print('=' * 100)
    # `.get` keeps an unset CUDA_VISIBLE_DEVICES from raising KeyError at the very end of a run.
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          'GPU  ' + os.environ.get("CUDA_VISIBLE_DEVICES", ""))
    print('=' * 100)
    test_loss = 0  # never accumulated; printed as 0 to keep the report format unchanged
    correct = 0
    num = 0
    # Switch to eval mode once, and disable autograd so inference does not build a graph.
    Basic_model.eval()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loder):
            data, target = data.cuda(), target.cuda()
            pred = Basic_model.forward(data)
            # Index of the highest logit per sample, kept as a column for comparison with `target`.
            Pred = pred.max(1, keepdim=True)[1]
            num += data.size()[0]
            correct += Pred.eq(target.view_as(Pred)).cpu().sum()

    # Guard against an empty loader instead of dividing by zero.
    test_accuracy = 100. * correct / num if num else 0.0  # len(data_loader.dataset)
    print(
        'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'
            .format(
            test_loss, correct, num,
            test_accuracy, ))
    # Sum of the best accuracy recorded for each task.
    summ = 0
    for best_acc in Max_acc:
        summ += best_acc
    print("total", summ)



