#Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle

#Sklearn
import sklearn
from sklearn.manifold import TSNE

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

#robustdg
from utils.helper import *
from utils.match_function import *

import random
import torchvision
import datetime

# Default list used for --test_domains when the flag is not given.
dom_test=[0,90] #rotmnistspur
#dom_test=[0,90]
#[-1]
#["0"]   

# Input Parsing
# Every experiment setting is exposed as a command-line flag; the parsed
# namespace is stored in the module-level `args` at the end of this section.
# NOTE(review): parse_args() runs at import time, so importing this module
# consumes sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='fashion_mnist', #'fashion_mnist'
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match', #default=erm_match
                    help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18', 
                    help='Architecture of the model to be trained')
# NOTE(review): argparse applies type=str only to values given on the command
# line; the list defaults for the two domain flags below contain ints. Confirm
# which type the downstream code expects (MnistRotated compares domains with
# ints for 'rot_mnist_spur').
parser.add_argument('--train_domains', nargs='+', type=str, default=[0], #need to be modified
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=dom_test, #to be modified
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10, 
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1, #used to be 1
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default= 224, 
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224, 
                    help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default= 1, 
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')#fc layer for classification
parser.add_argument('--match_layer', type=str, default='logit_match', 
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2', 
                    help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250, 
                    help='Representation dimension for contrsative learning')
parser.add_argument('--pre_trained',type=int, default=0, 
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1, 
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
# Optimisation settings
parser.add_argument('--opt', type=str, default='sgd', 
                    help='Optimizer Choice: sgd; adam') 
parser.add_argument('--weight_decay', type=float, default=5e-4,
                   help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=0.01, 
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=16, 
                    help='Batch size foe training the model')
parser.add_argument('--epochs', type=int, default=50, ###30
                    help='Total number of epochs for training the model')
# Loss penalty weights
parser.add_argument('--penalty_s', type=int, default=-1, 
                    help='Epoch threshold over which Matching Loss to be optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0, 
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0, 
                    help='Penalty weight for Augmentation in Hybrid approach loss')
# The DANN class in this file reads penalty_ws as the weight of its adversarial term.
parser.add_argument('--penalty_ws', type=float, default=0.1, 
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0, 
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05, 
                    help='Temperature hyper param for NTXent contrastive loss ')
# Match-strategy settings
parser.add_argument('--match_flag', type=int, default=0, 
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0, 
                    help='0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5, 
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--n_runs', type=int, default=1, ###default 3
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1, 
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18', 
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match', 
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1, 
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01, 
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5, 
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0, 
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
parser.add_argument('--retain', type=float, default=0, 
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
# Index used to build the "cuda:<id>" device string in the driver code below.
parser.add_argument('--cuda_device', type=int, default=2, 
                    help='Select the cuda device by id among the avaliable devices' )
parser.add_argument('--os_env', type=int, default=0, 
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )


#Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0, 
                    help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0, 
                    help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the dp setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1, 
                    help='0: Infinite Epsilon; 1: Finite Epsilion')


#MMD, DANN
# The DANN class in this file reads d_steps_per_g_step, grad_penalty and conditional.
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)

#fish
parser.add_argument('--meta_lr', type=float, default=0.01)
parser.add_argument('--meta_steps', type=int, default=5)


#Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2, 
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)


#Differentiate between resnet, lenet, domainbed cases of mnist
# mnist_case is also part of the dataset directory name built by MnistRotated.
parser.add_argument('--mnist_case', type=str, default='resnet18', 
                    help='MNIST Dataset Case: resnet18; lenet, domainbed')
parser.add_argument('--mnist_aug', type=int, default=0, 
                    help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')


#Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1, 
                    help='Multiple random matches')


# Evaluation specific
parser.add_argument('--test_metric', type=str, default='match_score', 
                    help='Evaluation Metrics: acc; match_score, t_sne, mia')
parser.add_argument('--acc_data_case', type=str, default='test', 
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10, 
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0, 
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val', 
                    help='Dataset Train/Val/Test for the match score evaluation metric')

# Parsed settings shared by the rest of the script.
args = parser.parse_args()



class BaseDataLoader(data_utils.Dataset):
    """Base dataset holding samples with label, domain, index and object id.

    The constructor only records the configuration and creates empty
    containers; subclasses are expected to replace ``data``, ``labels``,
    ``domains``, ``indices`` and ``objects`` with index-aligned containers.

    Args:
        args: Parsed experiment settings (stored, not interpreted here).
        list_domains: Domains this dataset is built from.
        root: Path suffix appended to ``'data/datasets'`` to form ``self.root``.
        transform: Optional callable applied to each sample in ``__getitem__``.
        data_case: Name of the split (e.g. ``'train'``); stored for subclasses.
    """

    def __init__(self, args, list_domains, root, transform=None, data_case='train'):
        self.args = args
        self.list_domains = list_domains

        self.root = 'data/datasets' + root
        self.transform = transform
        self.data_case = data_case

        self.base_domain_size = 0
        self.list_size = []
        self.data = []
        self.labels = []
        self.domains = []
        self.indices = []
        self.objects = []

    def __len__(self):
        # len() works for the empty-list placeholder set in __init__ as well as
        # for the tensors/arrays subclasses store (``.shape[0]`` did not).
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(x, y, d, idx, objs)`` for ``index``; only ``x`` is transformed."""
        x = self.data[index]
        y = self.labels[index]
        d = self.domains[index]
        idx = self.indices[index]
        objs = self.objects[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d, idx, objs

    def get_size(self):
        """Return the number of samples (same value as ``len(self)``)."""
        return len(self.labels)
     

class MnistRotated(BaseDataLoader):
    """MNIST-style dataset assembled from per-domain ``.pt`` files on disk.

    For every entry of ``list_domains`` two files are read with ``torch.load``:
    ``<prefix>_org_data.pt`` (images) and ``<prefix>_label.pt`` (labels), where
    ``<prefix>`` is
    ``<root><dataset_name>_<mnist_case>/<data_case>/seed_<mnist_subset>_domain_<domain>``
    and, for ``dataset_name == 'rot_mnist_spur'``, the domain id is first
    decoded into a rotation angle and a colour (``..._domain_<degree>_color_<color>``).

    Args:
        args: Parsed settings; ``dataset_name`` and ``mnist_case`` are read.
        list_domains: Domains to load, in the order they are concatenated.
        mnist_subset: Seed id used in the file names.
        root: Path suffix appended to ``'data/datasets'``.
        transform: Optional per-sample transform.
        data_case: Split sub-directory name.
        download: Stored on the instance; not used by this class.
    """

    def __init__(self, args, list_domains, mnist_subset, root, transform=None, data_case='train',  download=True):

        super().__init__(args, list_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.download = download

        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _domain_file_prefix(self, data_dir, domain):
        """Return the common path prefix of the image/label files of ``domain``."""
        seed_prefix = data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset) + '_domain_'
        if self.args.dataset_name != 'rot_mnist_spur':
            return seed_prefix + str(domain)

        # rot_mnist_spur: a domain id encodes (rotation, colour). Ids below 0
        # and above 77 map to the uncoloured 0 / 90 degree domains; ids in
        # between index 13 rotations x 6 colours.
        rotations = list(range(15, 76, 5))
        colors = ['red', 'blue', 'green', 'brown', 'pink', 'yellow']
        if domain < 0:
            degree, color = 0, 'white'
        elif domain > 77:
            degree, color = 90, 'white'
        else:
            degree = rotations[int(domain / 6)]
            color = colors[domain % 6]
        return seed_prefix + str(degree) + '_color_' + color

    def _get_data(self):
        """Load and concatenate all domains.

        Returns:
            Tuple ``(images, labels, domains, indices, objects)`` where
            ``labels`` and ``domains`` are one-hot encoded, ``indices`` restarts
            at 0 for every domain, and ``objects`` is a copy of ``indices``.
            Also sets ``self.training_list_size`` to the per-domain sample counts.
        """
        data_dir = self.root + self.args.dataset_name + '_' + self.args.mnist_case + '/'

        list_img = []
        list_labels = []
        list_size = []
        for domain in self.list_domains:
            load_dir = self._domain_file_prefix(data_dir, domain)

            mnist_imgs = torch.load(load_dir + '_org_data.pt')
            mnist_labels = torch.load(load_dir + '_label.pt')

            list_img.append(mnist_imgs)
            list_labels.append(mnist_labels)
            list_size.append(mnist_imgs.shape[0])

        # Merge the data of all domains.
        data_imgs = torch.cat(list_img)
        data_labels = torch.cat(list_labels)
        # Per-domain running index 0..size-1, concatenated. Built with
        # np.concatenate so that domains of different sizes are supported
        # (np.array on ragged lists raises on recent NumPy versions).
        data_indices = np.concatenate([np.arange(size) for size in list_size])
        self.training_list_size = list_size

        # For Rotated MNIST the objects are the same as the data indices.
        data_objects = copy.deepcopy(data_indices)

        # Domain label = position of the domain in list_domains.
        data_domains = torch.zeros(data_labels.size())
        domain_start = 0
        for idx in range(len(self.list_domains)):
            curr_domain_size = self.training_list_size[idx]
            data_domains[domain_start: domain_start + curr_domain_size] += idx
            domain_start += curr_domain_size

        # Class labels to one-hot (10 classes are hard-coded here).
        y = torch.eye(10)
        data_labels = y[data_labels]

        # Domain labels to one-hot.
        d = torch.eye(len(self.list_domains))
        data_domains = d[data_domains.long()]

        # If shape (B,H,W) change it to (B,C,H,W) with C=1
        if len(data_imgs.shape) == 3:
            data_imgs = data_imgs.unsqueeze(1)

        return data_imgs, data_labels, data_domains, data_indices, data_objects


def get_dataloader_level_1(args, domains, data_case, kwargs):
    """Build a shuffled DataLoader over ``MnistRotated`` for ``domains``.

    The dataset is always created for MNIST subset 0 under ``'/mnist/'`` and
    batched in groups of 50; ``kwargs`` is forwarded to the DataLoader.

    Returns:
        dict with the keys ``data_loader``, ``data_obj``, ``total_domains``,
        ``domain_list``, ``base_domain_size`` and ``domain_size_list``.
    """
    data_obj = MnistRotated(args, domains, 0, '/mnist/', data_case=data_case)
    loader = data_utils.DataLoader(data_obj, batch_size=50, shuffle=True, **kwargs)

    return {
        'data_loader': loader,
        'data_obj': data_obj,
        'total_domains': len(domains),
        'domain_list': domains,
        'base_domain_size': data_obj.base_domain_size,
        'domain_size_list': data_obj.training_list_size,
    }
class Identity(nn.Module):
    """Pass-through layer that still advertises ``in_features``.

    ``n_inputs`` is only recorded as ``in_features``; ``forward`` returns its
    argument unchanged.
    """

    def __init__(self, n_inputs):
        super().__init__()
        self.in_features = n_inputs

    def forward(self, x):
        return x
class Model_erm(nn.Module):
    """Two-stage network: ``classifier(featurizer(x))``.

    ``feat`` is stored as ``self.featurizer`` and ``fc`` as ``self.classifier``.
    """

    def __init__(self, feat, fc):
        super().__init__()
        self.featurizer = feat
        self.classifier = fc

    def forward(self, input_data):
        # Features first, then the classification head on top of them.
        return self.classifier(self.featurizer(input_data))
class BaseAlgo():
    """Shared setup for training algorithms: model, optimizer and bookkeeping.

    Args:
        args: Parsed settings; reads ``out_classes``, ``img_c``, ``opt``,
            ``lr`` and ``weight_decay``.
        train_dataset: dict with the keys ``data_loader``, ``domain_list``,
            ``total_domains``, ``base_domain_size`` and ``domain_size_list``.
        base_res_dir: Directory the model weights are written to.
        cuda: Device the model is moved to.
    """

    def __init__(self, args, train_dataset,base_res_dir, cuda):

        self.args = args
        self.train_dataset = train_dataset['data_loader']

        self.train_domains = train_dataset['domain_list']
        self.total_domains = train_dataset['total_domains']
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir = base_res_dir
        self.run = 0
        self.cuda = cuda

        self.phi = self.get_model()
        self.opt = self.get_opt()
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=25)

        self.train_acc = []

    def get_resnet(self,classes, fc_layer, num_ch, pre_trained):
        """Return a torchvision ResNet-18 adapted to this task.

        Args:
            classes: Output size of the final linear layer when ``fc_layer`` is truthy.
            fc_layer: Truthy -> replace ``fc`` with a fresh ``nn.Linear``;
                falsy -> replace ``fc`` with ``Identity`` so features are returned.
            num_ch: If 1, the first convolution is replaced by a 1-channel one.
            pre_trained: Forwarded positionally to ``torchvision.models.resnet18``.
        """
        model = torchvision.models.resnet18(pre_trained)

        n_inputs = model.fc.in_features
        n_outputs = classes

        if fc_layer:
            model.fc = nn.Linear(n_inputs, n_outputs)
        else:
            print('Here')
            model.fc = Identity(n_inputs)

        if num_ch == 1:
            model.conv1 = nn.Conv2d(1, 64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)
        return model

    def get_model(self):
        """Build ``Model_erm`` (ResNet-18 features + linear head) on ``self.cuda``.

        The network is always created without pretrained weights.
        """
        feat = self.get_resnet(self.args.out_classes, 0, self.args.img_c, 0)
        # The head is the same nn.Linear(in_features, out_classes) that a second
        # ResNet's ``fc`` would be; building it directly avoids constructing a
        # whole extra network just to discard everything but its last layer.
        fc = nn.Linear(feat.fc.in_features, self.args.out_classes)
        phi = Model_erm(feat, fc)
        phi = phi.to(self.cuda)
        return phi

    def save_model(self):
        """Store the weights of the model as ``<base_res_dir>/Model.pth``."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def get_opt(self):
        """Create the optimizer selected by ``args.opt`` over trainable parameters.

        Raises:
            ValueError: If ``args.opt`` is neither ``'sgd'`` nor ``'adam'``.
        """
        trainable = [p for p in self.phi.parameters() if p.requires_grad]
        if self.args.opt == 'sgd':
            return optim.SGD([
                         {'params': trainable},
                ], lr=self.args.lr, weight_decay=self.args.weight_decay, momentum=0.9, nesterov=True)
        if self.args.opt == 'adam':
            return optim.Adam([
                        {'params': trainable},
                ], lr=self.args.lr)
        raise ValueError("Unknown optimizer %r; expected 'sgd' or 'adam'" % (self.args.opt,))

class Erm(BaseAlgo):
    """Plain empirical-risk-minimisation training with cross-entropy loss."""

    def __init__(self, args, train_dataset,  base_res_dir,  cuda):
        super().__init__(args, train_dataset, base_res_dir, cuda)

    def train(self):
        """Train ``self.phi`` for ``args.epochs`` epochs, then save its weights.

        Per epoch the summed batch loss and the training accuracy are printed,
        and the accuracy (in percent) is appended to ``self.train_acc``.
        NOTE(review): ``self.scheduler`` is never stepped in this loop.
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0

        for epoch in range(self.args.epochs):
            epoch_loss = 0
            n_correct = 0.0
            n_seen = 0

            for images, onehot_labels, _, _, _ in self.train_dataset:
                self.opt.zero_grad()

                images = images.to(self.cuda)
                targets = torch.argmax(onehot_labels, dim=1).to(self.cuda)

                # Forward pass and loss
                logits = self.phi(images)
                batch_loss = F.cross_entropy(logits, targets.long())
                epoch_loss += float(batch_loss)

                # Backward pass and parameter update
                batch_loss.backward(retain_graph=False)
                self.opt.step()

                del batch_loss
                torch.cuda.empty_cache()

                # Accuracy uses the logits computed before the update.
                n_correct += torch.sum(torch.argmax(logits, dim=1) == targets).item()
                n_seen += targets.shape[0]

            print('Train Loss Basic : ',  epoch_loss )
            print('Train Acc Env : ', 100*n_correct/n_seen )
            print('Done Training for epoch: ', epoch)

            # Train Dataset Accuracy
            self.train_acc.append(100*n_correct/n_seen)

        # Save the model's weights post training
        self.save_model()
class DANN():
    """Domain-adversarial training of a featurizer, classifier and discriminator.

    The model comes from ``models.net_for_dann`` and must expose
    ``featurizer``, ``classifier``, ``discriminator`` and ``class_embeddings``.

    Args:
        args: Parsed settings; reads ``conditional``, ``grad_penalty``,
            ``penalty_ws`` (weight of the adversarial term), ``d_steps_per_g_step``,
            ``epochs``, ``model_name``, ``out_classes``, ``img_c``,
            ``pre_trained`` and ``os_env``.
        train_dataset, val_dataset, test_dataset: dicts with a ``data_loader``
            entry; the train dict also provides ``domain_list``,
            ``total_domains``, ``base_domain_size`` and ``domain_size_list``.
        indecies: Indexable by batch position; training batches whose entry
            equals 0 are skipped.
        base_res_dir: Directory for the saved weights and accuracy arrays.
        cuda: Device used for the model and the batches.
    """

    def __init__(self, args, train_dataset,indecies, val_dataset, test_dataset, base_res_dir, cuda):

        self.args = args
        self.train_dataset = train_dataset['data_loader']
        self.val_dataset = val_dataset['data_loader']
        self.test_dataset = test_dataset['data_loader']

        self.train_domains = train_dataset['domain_list']
        self.total_domains = train_dataset['total_domains']
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir = base_res_dir
        self.run = 0
        self.cuda = cuda

        self.indecies = indecies

        self.phi = self.get_model()

        self.final_acc = []
        self.val_acc = []
        self.train_acc = []

        self.conditional = bool(self.args.conditional)
        self.class_balance = False

        self.featurizer = self.phi.featurizer
        self.classifier = self.phi.classifier
        self.discriminator = self.phi.discriminator
        self.class_embeddings = self.phi.class_embeddings

        self.grad_penalty = self.args.grad_penalty
        self.lambda_ = self.args.penalty_ws
        self.d_steps_per_g_step = self.args.d_steps_per_g_step
        self.initial_lr = 0.01

        # Optimizers: one for the adversary (discriminator + class embeddings),
        # one for the predictor (featurizer + classifier).
        self.disc_opt = torch.optim.SGD(
            (list(self.discriminator.parameters()) +
                list(self.class_embeddings.parameters())),
            lr=self.initial_lr,
            weight_decay=5e-4)

        self.gen_opt = torch.optim.SGD(
            (list(self.featurizer.parameters()) +
                list(self.classifier.parameters())),
            lr=self.initial_lr,
            weight_decay=5e-4)

    def get_model(self):
        """Build the DANN network via ``net_for_dann`` and move it to ``self.cuda``."""
        from models import net_for_dann
        phi = net_for_dann.get_model_for_DANN(self.args.model_name, self.args.out_classes,
                        self.args.img_c, self.args.pre_trained, self.args.os_env)

        phi = phi.to(self.cuda)
        return phi

    def save_model(self):
        """Store the model weights and the validation/test accuracy histories."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')
        np.save( self.base_res_dir + '/Val_Acc' + '.npy', np.array(self.val_acc) )
        np.save( self.base_res_dir + '/Test_Acc' +  '.npy', np.array(self.final_acc))

    def get_test_accuracy(self, case):
        """Return the accuracy (in percent) of ``self.phi`` on a held-out split.

        Args:
            case: ``'val'`` or ``'test'``.

        Raises:
            ValueError: If ``case`` is neither ``'val'`` nor ``'test'``.
        """
        if case == 'val':
            dataset = self.val_dataset
        elif case == 'test':
            dataset = self.test_dataset
        else:
            raise ValueError("case must be 'val' or 'test', got %r" % (case,))

        test_acc = 0.0
        test_size = 0
        with torch.no_grad():
            for x_e, y_e, d_e, idx_e, obj_e in dataset:
                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                # Forward Pass
                out = self.phi(x_e)

                test_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                test_size += y_e.shape[0]

        print(' Accuracy: ', case, 100*test_acc/test_size )
        return 100*test_acc/test_size

    def train(self):
        """Run adversarial training and keep the weights of the best validation epoch.

        Whether an epoch updates the discriminator or the featurizer/classifier
        is decided from the epoch number and ``d_steps_per_g_step``.
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0
        d_steps_per_g = self.d_steps_per_g_step

        for epoch in range(self.args.epochs):

            penalty_erm = 0
            penalty_dann = 0
            train_acc = 0.0
            train_size = 0
            # Constant within an epoch, so evaluated once here.
            update_discriminator = epoch % (1 + d_steps_per_g) < d_steps_per_g

            # Batch iteration over single epoch
            for batch_idx, (x_e, y_e, d_e, idx_e, obj_e) in enumerate(self.train_dataset):
                # Batches flagged with 0 are left out of training.
                if self.indecies[batch_idx] == 0:
                    continue

                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                d_e = torch.argmax(d_e, dim=1).to(self.cuda)

                all_x = x_e
                all_y = y_e
                all_z = self.featurizer(all_x)
                # Conditional variant: the discriminator also sees the class
                # through an additive class embedding.
                if self.conditional:
                    disc_input = all_z + self.class_embeddings(all_y)
                else:
                    disc_input = all_z
                disc_out = self.discriminator(disc_input)
                disc_labels = d_e
                if self.class_balance:
                    y_counts = F.one_hot(all_y).sum(dim=0)
                    weights = 1. / (y_counts[all_y] * y_counts.shape[0]).float()
                    disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
                    disc_loss = (weights * disc_loss).sum()
                else:
                    disc_loss = F.cross_entropy(disc_out, disc_labels)

                # Gen Loss: classify well while making domains indistinguishable.
                all_preds = self.classifier(all_z)
                classifier_loss = F.cross_entropy(all_preds, all_y)
                gen_loss = (classifier_loss +
                            (self.lambda_ * -disc_loss))

                penalty_erm += float(classifier_loss)
                penalty_dann += float(disc_loss)

                if update_discriminator:
                    self.disc_opt.zero_grad()
                    disc_loss.backward()
                    self.disc_opt.step()
                else:
                    self.disc_opt.zero_grad()
                    self.gen_opt.zero_grad()
                    gen_loss.backward()
                    self.gen_opt.step()

                del classifier_loss
                del gen_loss
                del disc_loss
                torch.cuda.empty_cache()

                # Forward pass with the updated weights, only for the accuracy
                # statistic, so no autograd graph is needed.
                with torch.no_grad():
                    features = self.featurizer(x_e)
                    out = self.classifier(features)
                train_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                train_size += y_e.shape[0]

            print('Train Loss Basic : ',  penalty_erm, penalty_dann )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)

            # Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )

            # Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )

            # Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )

            # Save the model if this is the best epoch so far by validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc = self.val_acc[-1]
                self.max_epoch = epoch
                self.save_model()

            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])
def build_similary_matrix(cov_function, items):
    """
    Build the symmetric pairwise matrix of ``cov_function`` over ``items``.

    ``cov_function`` is evaluated once per unordered pair (including each item
    with itself) and the value is mirrored across the diagonal.
    """
    size = len(items)
    matrix = np.zeros((size, size))
    for row, left in enumerate(items):
        for col in range(row, size):
            value = cov_function(left, items[col])
            matrix[row, col] = value
            matrix[col, row] = value
    return matrix

def get_similar_cov(inx):
    """Select a pairwise function by number.

    1 -> ``get_l1_similar``, 2 -> ``get_l2_similar``, 3 -> ``get_cos_similar``;
    any other value -> the function returned by ``exp_quadratic()`` with its
    default ``sigma``.
    """
    numbered = (
        (1, get_l1_similar),
        (2, get_l2_similar),
        (3, get_cos_similar),
    )
    for number, function in numbered:
        if inx == number:
            return function
    return exp_quadratic()

def get_l2_similar(v1,v2):
    """Return the default ``np.linalg.norm`` (Euclidean for vectors) of ``v1 - v2``."""
    difference = v1 - v2
    return np.linalg.norm(difference)

def get_l1_similar(v1,v2):
    """Return ``np.linalg.norm`` with ``ord=1`` of ``v1 - v2`` (sum of absolute differences for vectors)."""
    difference = v1 - v2
    return np.linalg.norm(difference, ord=1)

def get_cos_similar(v1, v2):
    """Return the cosine similarity of ``v1`` and ``v2`` rescaled to ``[0, 1]``.

    The result is ``0.5 + 0.5 * cos``; when the product of the two norms is
    zero, 0 is returned instead.
    """
    dot_product = float(np.dot(v1, v2))  # vector dot product
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)  # product of the norms
    if norm_product == 0:
        return 0
    return 0.5 + 0.5 * (dot_product / norm_product)

def exp_quadratic(sigma=0.1):
    """
    Return an exponential quadratic covariance function.

    The returned function maps two points to
    ``exp(-0.5 * sum((p1 - p2)**2) / sigma**2)``.
    """
    def kernel(p1, p2):
        squared_distance = ((p1 - p2) ** 2).sum()
        return np.exp(-0.5 * squared_distance / sigma ** 2)
    return kernel
    
def sample_k(items, L, k, max_nb_iterations=1000, rng=np.random):
    """
    Sample a size-k subset from a DPP defined by the similarity matrix L.

    Starting from a random subset of size k, each iteration proposes swapping
    one selected item ``u`` for one unselected item ``v`` and accepts the swap
    with probability ``min(1, det(L_{Y+v}) / det(L_{Y+u}))``, where ``Y`` is the
    current subset without ``u``. The determinant ratio is evaluated as the
    ratio of the two Schur complements with respect to ``L_Y``.

    Args:
        items: Sequence of items; only its length is used.
        L: Square similarity matrix (numpy array) aligned with ``items``.
        k: Number of items to select.
        max_nb_iterations: Number of swap proposals.
        rng: Object providing ``choice`` and ``uniform`` (e.g. ``np.random``
            or a ``RandomState``).

    Returns:
        Boolean numpy array of length ``len(items)`` with exactly ``k`` True
        entries marking the selected items.
    """
    n_items = len(items)
    initial = rng.choice(range(n_items), size=k, replace=False)
    X = np.zeros(n_items, dtype=bool)
    for i in initial:
        X[i] = True

    # With nothing selected or everything selected there is no swap to propose.
    if not X.any() or X.all():
        return X

    all_indices = np.arange(n_items)
    for _ in range(max_nb_iterations):
        u = rng.choice(all_indices[X])
        v = rng.choice(all_indices[~X])
        Y = X.copy()
        Y[u] = False
        L_Y = L[Y, :]
        L_Y = L_Y[:, Y]
        L_Y_inv = np.linalg.inv(L_Y)

        c_v = L[v:v+1, v:v+1]
        b_v = L[Y, v:v+1]
        c_u = L[u:u+1, u:u+1]
        b_u = L[Y, u:u+1]

        # Schur complements of L_Y in L_{Y+v} and L_{Y+u}. The whole numerator
        # must be divided by the whole denominator.
        gain_v = c_v - np.dot(np.dot(b_v.T, L_Y_inv), b_v)
        gain_u = c_u - np.dot(np.dot(b_u.T, L_Y_inv.T), b_u)
        p = min(1, gain_v / gain_u)
        if rng.uniform() <= p:
            Y[v] = True
            X = Y
    return X

# Timestamp identifying this run (':' replaced so it is usable in file names).
runId = datetime.datetime.now().isoformat().replace(':', '_')
# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
# (A torch.device object is always truthy, so it cannot be used as the test.)
cuda= torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
if cuda.type == 'cuda':
    kwargs = {'num_workers': 1, 'pin_memory': False} 
else:
    kwargs= {}
final_accuracy_target_val=[]
final_accuracy_source_val=[]
# Candidate sets of five training domains each.
# NOTE(review): the "_rm" / "_fm" suffixes presumably stand for rot_mnist and
# fashion_mnist — confirm; only domains_l1_fm is iterated below.
domains_l1_rm=[[17, 39, 51, 62, 68],[15, 26, 37, 42, 70], [24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70],[22, 28, 35, 45, 70], [15, 16, 22, 33, 69], [25, 47, 64, 68, 73],[40, 41, 50, 63, 74], [16, 19, 23, 47, 65], [28, 40, 49, 53, 55], [20, 54, 55, 64, 70], [25, 35, 39, 60, 68],[24, 45, 48, 52, 75], [19, 31, 35, 54, 73], [37, 49, 61, 68, 74], [22, 32, 33, 72, 75]]
# Remember to save the finished runs back.
domains_l1_fm=[ [15, 16, 18, 24, 57], [29, 33, 59, 69, 70], [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43], [20, 33, 62, 67, 74], [16, 42, 51, 53, 74],[22, 52, 58, 64, 65], [28, 43, 54, 64, 72], [16, 35, 45, 68, 72], [19, 21, 39, 57, 68], [25, 38, 39, 61, 65],[25, 38, 53, 65, 70],[17, 22, 37, 43, 46],[15, 20, 31, 45, 50], [25, 33, 38, 47, 50], [20, 37, 38, 64, 66], [17, 38, 47, 61, 69], [16, 29, 53, 55, 74],[24, 51, 62, 67, 68]]
#
# Train and evaluate one DANN model per training-domain subset.
for domains in domains_l1_fm:
    print(domains)

    # Results for this subset are written under
    # results_dann/l2/<dataset_name>_150/<domains>.
    res_dir = 'results_dann/l2/'
    base_res_dir = res_dir + args.dataset_name + '_150/' + str(domains)
    # exist_ok avoids the separate exists-then-create check.
    os.makedirs(base_res_dir, exist_ok=True)

    # `indecies` is loaded from disk further down. A disabled earlier version
    # of this script computed it in place: train Erm on dataset1, average the
    # featurizer output of every batch, build a similarity matrix with
    # exp_quadratic(0.1), call sample_k with k=110 and record the positions
    # where the returned mask was True.

    dataset1 = get_dataloader_level_1(args, domains, 'train', kwargs)
    val_dataset = get_dataloader(args, 0, domains, 'val', 0, kwargs)
    # (A disabled 'dannin' variant loaded a 'val' split of the train domains here instead.)
    test_dataset = get_dataloader(args, 0, args.test_domains, 'test', 0, kwargs)

    # Precomputed selection for this domain subset.
    base_load_index_dir = '/home/skyfall/rdd/rdd/indecies_150_no_shuffle_fm'
    residual_dir = '/' + str(domains) + '/indecies.npy'
    # rot_mnist alternative:
    # base_load_index_dir='/home/skyfall/rdd/results/two-level/fish/rot_mnist/'
    # residual_dir=str(domains)+'indecies.npy'
    indecies = np.load(base_load_index_dir + residual_dir)

    # Number of entries equal to 1 (i.e. selected), counted without a Python loop.
    count = np.count_nonzero(indecies == 1)
    print(count)

    train_method = DANN(
        args, dataset1, indecies, val_dataset,
        test_dataset, base_res_dir,
        cuda,
    )
    train_method.train()

    # Persist the accuracy curves recorded by the training method.
    np.save(base_res_dir + '/Val_Acc' + '.npy', np.array(train_method.val_acc))
    np.save(base_res_dir + '/Test_Acc' + '.npy', np.array(train_method.final_acc))

    if args.method_name != 'matchdg_ctr':
        # Target validation: best value on the test-accuracy curve.
        final_acc = np.max(train_method.final_acc)
        final_accuracy_target_val.append(final_acc)
        print(final_accuracy_target_val)

        # Source validation: test accuracy at the entry with the best
        # validation accuracy.
        idx = np.argmax(train_method.val_acc)
        final_acc = train_method.final_acc[idx]
        final_accuracy_source_val.append(final_acc)
        print(final_accuracy_source_val)

        print('final-acc')
        print(train_method.final_acc)
        print('val acc')
        print(train_method.val_acc)



# Summary over all domain subsets: mean and standard deviation of the
# per-subset test accuracies under both model-selection rules.
print('\n')
print('Done for the Model..')
for label, accuracies in (
    ('Final Test Accuracy (Source Validation)', final_accuracy_source_val),
    ('Final Test Accuracy (Target Validation)', final_accuracy_target_val),
):
    print(label, np.mean(accuracies), np.std(accuracies))
print('\n')
        
#python level-2-dann.py



#157
# final-acc
# [9.7, 34.325, 34.45, 45.675, 45.925, 58.525, 59.65, 66.5, 66.175, 63.35, 63.075, 75.475, 75.35, 74.075, 74.125, 76.675, 76.475, 75.075, 74.95, 80.55, 80.525, 80.3, 80.45, 81.325, 81.375, 83.25, 83.125, 82.875, 83.0, 83.1, 83.225, 82.425, 82.6, 80.025, 80.225, 80.525, 80.45, 73.225, 73.35, 84.175, 83.975, 83.525, 83.5, 85.1, 85.225, 85.75, 85.7, 83.625, 83.6, 84.75]
# val acc
# [10.75, 54.05, 54.6, 80.25, 80.25, 89.65, 89.75, 90.85, 90.25, 86.55, 86.4, 94.15, 94.35, 94.95, 95.0, 95.95, 96.05, 95.8, 96.2, 95.0, 94.8, 96.45, 96.25, 96.8, 96.9, 96.9, 96.7, 96.75, 96.75, 97.65, 97.45, 97.6, 97.4, 96.2, 96.45, 97.25, 97.3, 93.8, 94.05, 97.35, 97.4, 96.8, 96.8, 97.8, 97.7, 97.45, 97.55, 97.3, 97.3, 97.65]

# final-acc
# [8.5, 27.5, 27.475, 37.95, 38.25, 45.5, 45.425, 48.825, 49.0, 51.15, 51.45, 52.675, 53.05, 55.1, 55.025, 55.85, 55.875, 56.75, 56.65, 56.85, 56.425, 60.675, 61.025, 57.975, 57.975, 60.325, 60.425, 60.65, 60.95, 61.775, 61.9, 62.575, 62.9, 61.0, 61.125, 62.025, 61.75, 61.25, 61.075, 60.175, 60.425, 62.675, 63.1, 63.95, 64.05, 60.625, 61.025, 60.025, 60.075, 62.825]
# val acc
# [8.35, 63.3, 62.9, 88.9, 89.35, 92.95, 92.65, 95.0, 94.8, 95.7, 95.5, 96.0, 96.05, 96.5, 96.45, 96.85, 97.1, 97.45, 97.5, 97.7, 97.8, 97.75, 98.05, 97.8, 98.05, 98.15, 97.9, 97.85, 98.15, 98.2, 98.15, 97.8, 98.0, 98.25, 98.35, 98.15, 98.15, 98.15, 98.25, 97.95, 98.05, 98.35, 98.25, 97.55, 97.65, 97.75, 98.05, 97.85, 97.85, 98.35]

# final-acc
# [9.05, 28.725, 29.125, 37.65, 37.575, 44.4, 44.2, 52.7, 52.875, 55.4, 55.375, 59.675, 59.425, 54.625, 54.725, 60.675, 60.625, 60.375, 60.85, 66.75, 66.575, 63.55, 63.475, 65.7, 65.525, 65.85, 65.45, 62.65, 62.925, 69.175, 69.1, 71.35, 71.325, 67.2, 67.6, 68.4, 68.35, 72.275, 72.175, 65.8, 65.925, 67.525, 67.65, 63.6, 63.8, 69.775, 69.675, 69.875, 70.1, 70.9]
# val acc
# [10.25, 61.7, 61.15, 80.5, 80.4, 81.6, 82.25, 94.15, 93.7, 95.25, 95.25, 95.9, 95.8, 93.8, 93.6, 96.4, 96.65, 97.0, 97.0, 97.25, 96.95, 97.35, 97.25, 97.8, 97.4, 97.45, 97.6, 97.0, 96.95, 97.95, 98.0, 97.7, 97.75, 97.95, 97.85, 97.85, 97.95, 97.9, 98.0, 97.3, 97.3, 97.95, 97.85, 96.4, 96.7, 98.1, 97.9, 97.85, 97.85, 97.85]

# final-acc
# [10.225, 27.9, 27.775, 43.85, 43.9, 53.525, 53.4, 56.025, 55.925, 60.375, 60.2, 65.225, 65.175, 73.0, 73.2, 71.65, 71.65, 78.75, 79.225, 73.175, 72.875, 76.875, 77.05, 77.125, 76.9, 77.55, 77.425, 75.8, 76.175, 75.5, 75.725, 76.05, 75.675, 80.6, 80.275, 78.3, 78.725, 78.725, 78.925, 80.25, 80.125, 79.575, 79.525, 78.925, 78.9, 79.025, 78.95, 81.775, 82.15, 80.575]
# val acc
# [9.9, 48.85, 49.2, 75.05, 74.6, 90.7, 90.8, 88.25, 88.25, 93.9, 93.85, 93.75, 93.75, 96.55, 96.75, 96.7, 96.6, 96.8, 96.85, 96.65, 96.85, 97.25, 97.25, 97.55, 97.6, 97.8, 98.0, 98.25, 97.95, 97.65, 97.65, 97.55, 97.7, 98.3, 98.35, 98.1, 98.25, 96.15, 96.3, 97.8, 97.8, 98.4, 98.25, 98.35, 98.2, 98.2, 98.25, 97.85, 97.7, 98.3]

# final-acc
# [10.125, 31.325, 31.275, 42.75, 42.55, 53.85, 53.95, 58.075, 58.3, 63.625, 63.6, 55.25, 55.65, 59.1, 59.475, 69.65, 69.275, 74.625, 74.725, 72.575, 72.25, 72.65, 72.175, 74.475, 74.6, 75.0, 74.875, 77.275, 77.325, 68.875, 69.225, 75.9, 75.8, 76.075, 75.9, 74.0, 74.25, 78.225, 78.1, 78.35, 78.15, 77.075, 77.3, 79.125, 79.05, 79.025, 79.275, 81.85, 81.95, 76.35]
# val acc
# [8.5, 53.45, 52.9, 71.05, 71.75, 90.55, 90.3, 90.75, 90.5, 94.75, 94.6, 93.0, 92.95, 94.3, 94.2, 94.35, 94.6, 96.55, 96.5, 96.95, 96.75, 96.9, 96.6, 97.15, 97.15, 95.8, 96.25, 97.25, 97.15, 93.4, 93.55, 97.65, 97.6, 97.4, 97.5, 97.5, 97.6, 97.85, 97.85, 97.6, 97.65, 97.75, 97.85, 97.65, 97.8, 97.75, 97.5, 97.1, 97.1, 96.5]
# [85.1, 61.125, 69.775, 79.575, 78.225]

#135
#line
# [10.141666666666667, 24.325000000000003, 24.28333333333333, 41.40833333333333, 41.35, 50.94166666666666, 51.25, 57.28333333333333, 57.21666666666667, 62.06666666666666, 
# 62.18333333333334, 63.975, 63.949999999999996, 63.300000000000004, 63.11666666666667, 67.69166666666668, 67.46666666666667, 67.50833333333333, 67.63333333333333, 70.29166666666667, 70.175, 71.33333333333333, 71.34166666666667, 72.94166666666666, 73.0, 71.47500000000001, 71.80833333333334, 76.025, 75.91666666666667, 73.825, 73.89166666666667, 74.425, 74.23333333333333, 74.38333333333334, 74.34166666666665, 74.44166666666666, 74.44166666666666, 72.55833333333332, 72.425, 72.925, 72.85000000000001, 74.8, 
# 74.91666666666667, 75.68333333333332, 75.83333333333333, 75.05, 75.96666666666667, 76.65833333333333, 76.64166666666667, 76.8]

# final-acc
# [8.425, 31.725, 31.725, 42.85, 42.825, 50.425, 50.175, 65.025, 65.3, 69.35, 69.15, 66.95, 67.125, 77.15, 77.25, 80.15, 80.15, 75.45, 75.35, 73.75, 73.575, 78.25, 78.275, 79.7, 79.5, 80.65, 81.25, 83.575, 83.425, 82.65, 82.5, 78.675, 78.525, 85.45, 85.4, 83.35, 83.35, 81.175, 81.275, 83.425, 83.6, 85.1, 84.925, 85.05, 85.1, 79.925, 79.85, 83.975, 83.975, 84.6]
# val acc
# [8.9, 46.3, 47.25, 72.6, 72.65, 85.1, 84.7, 88.45, 88.5, 92.95, 92.95, 93.7, 93.35, 94.9, 95.0, 94.75, 94.8, 95.95, 95.9, 95.95, 96.35, 96.25, 96.1, 95.7, 95.8, 96.15, 96.5, 96.4, 96.35, 96.9, 96.65, 95.2, 95.25, 96.95, 97.1, 96.7, 96.7, 97.35, 97.3, 97.4, 97.35, 97.0, 97.1, 97.3, 97.25, 97.25, 97.25, 97.55, 97.45, 97.3]

# final-acc
# [9.9, 28.3, 28.225, 33.3, 33.25, 44.375, 44.425, 49.4, 49.575, 53.0, 52.9, 55.375, 55.4, 56.725, 56.575, 56.6, 56.375, 59.925, 60.1, 59.775, 59.475, 52.075, 51.9, 58.875, 58.775, 62.0, 61.925, 63.975, 64.25, 61.225, 61.275, 62.925, 63.1, 63.525, 63.675, 66.15, 66.175, 63.175, 63.2, 63.15, 63.25, 62.05, 62.225, 63.65, 63.7, 60.55, 60.725, 62.55, 62.7, 63.075]
# val acc
# [10.1, 57.1, 57.25, 81.0, 80.95, 91.65, 91.6, 94.8, 94.85, 95.55, 95.6, 95.85, 95.9, 96.0, 95.9, 96.5, 96.95, 97.35, 97.2, 96.75, 96.7, 95.95, 95.8, 97.8, 97.6, 98.0, 97.95, 97.8, 97.6, 97.9, 97.75, 97.7, 97.75, 97.95, 97.9, 98.2, 98.2, 97.75, 97.95, 98.05, 98.15, 98.1, 98.25, 98.15, 98.1, 97.8, 97.8, 98.25, 98.3, 98.2]

# final-acc
# [8.7, 24.5, 24.5, 37.65, 37.6, 39.275, 39.475, 49.4, 48.95, 51.25, 51.325, 51.55, 51.575, 56.3, 56.35, 56.475, 56.425, 56.6, 56.85, 59.425, 59.275, 57.925, 57.65, 59.15, 59.375, 59.0, 59.1, 61.5, 61.3, 63.025, 63.275, 66.85, 66.65, 61.6, 61.925, 64.2, 64.3, 59.975, 59.55, 65.025, 65.475, 65.225, 65.175, 64.825, 64.875, 65.1, 65.075, 65.975, 65.725, 63.55]
# val acc
# [9.55, 45.65, 46.0, 79.55, 79.8, 86.1, 86.15, 93.85, 93.9, 95.25, 94.9, 94.95, 95.15, 95.95, 95.9, 96.25, 96.05, 96.15, 96.25, 96.45, 96.45, 96.65, 96.45, 96.95, 96.95, 94.95, 95.0, 97.1, 97.1, 97.0, 96.95, 97.5, 97.55, 95.95, 96.1, 97.4, 97.35, 97.25, 97.4, 97.6, 97.65, 97.4, 97.45, 97.45, 97.3, 97.65, 97.65, 97.7, 97.7, 97.65]

# final-acc
# [14.325, 27.975, 27.875, 42.7, 42.575, 51.5, 51.95, 59.725, 59.35, 69.275, 69.25, 64.65, 64.775, 69.625, 69.575, 74.525, 74.7, 72.1, 71.725, 77.425, 77.3, 76.825, 76.925, 78.225, 77.825, 70.975, 70.975, 81.5, 81.425, 77.55, 77.55, 78.8, 79.0, 81.4, 81.875, 81.45, 81.4, 81.775, 81.875, 80.05, 80.2, 79.625, 79.225, 81.875, 82.3, 83.1, 82.875, 78.25, 78.425, 78.175]
# val acc
# [14.15, 47.15, 47.2, 73.75, 74.1, 85.35, 85.5, 92.4, 92.4, 94.35, 94.6, 94.35, 94.6, 95.45, 95.55, 96.45, 96.55, 96.45, 96.55, 96.55, 96.65, 96.6, 96.6, 97.0, 97.0, 96.3, 96.1, 97.6, 97.65, 96.9, 96.9, 97.1, 97.1, 97.65, 97.5, 97.95, 98.2, 98.1, 98.0, 97.55, 97.9, 97.55, 97.6, 98.2, 98.15, 98.2, 98.1, 97.65, 97.8, 97.25]

# final-acc
# [10.225, 18.3, 18.45, 34.35, 34.475, 53.25, 53.35, 58.25, 58.25, 61.85, 61.675, 63.85, 63.625, 65.575, 65.35, 62.775, 62.675, 69.95, 69.6, 68.1, 68.275, 74.375, 73.925, 64.225, 63.8, 73.8, 73.675, 73.775, 73.225, 68.875, 69.2, 73.85, 73.625, 73.35, 73.725, 71.925, 71.925, 73.45, 73.35, 74.45, 74.275, 74.875, 74.85, 71.9, 71.8, 75.1, 74.725, 76.325, 75.925, 74.375]
# val acc
# [8.65, 23.45, 23.65, 65.2, 65.45, 87.6, 87.55, 92.25, 92.35, 91.95, 92.05, 94.75, 94.55, 95.1, 95.25, 94.3, 94.5, 96.5, 96.65, 96.95, 96.85, 96.9, 96.8, 94.45, 94.45, 97.0, 96.9, 97.05, 97.2, 96.35, 96.15, 97.2, 97.45, 97.7, 97.7, 97.25, 97.15, 97.55, 97.65, 97.8, 97.9, 97.95, 97.65, 97.5, 97.45, 97.7, 97.7, 97.7, 97.65, 97.75]

#110
# final-acc
# [10.475, 29.95, 29.825, 48.075, 48.2, 54.425, 54.55, 63.9, 63.8, 63.375, 63.05, 69.25, 69.05, 69.025, 69.175, 68.95, 68.975, 73.65, 73.325, 79.075, 78.825, 77.325, 77.125, 75.775, 75.8, 75.475, 75.85, 77.175, 76.875, 80.4, 80.125, 80.15, 80.25, 79.875, 80.1, 78.7, 78.825, 81.275, 81.6, 82.675, 82.825, 81.675, 81.275, 81.4, 81.35, 81.05, 81.075, 82.575, 82.75, 83.475]
# val acc
# [11.9, 45.45, 45.4, 76.15, 76.4, 86.4, 86.45, 90.9, 90.9, 93.05, 93.4, 94.55, 94.55, 94.95, 95.0, 95.25, 95.15, 96.0, 95.85, 96.1, 96.15, 96.4, 96.35, 96.5, 96.45, 96.65, 96.4, 96.75, 96.75, 96.35, 96.55, 97.05, 96.9, 97.15, 97.05, 96.8, 96.95, 97.1, 97.2, 97.0, 97.1, 97.4, 97.3, 97.4, 97.55, 97.45, 97.45, 97.55, 97.55, 97.5]

# final-acc
# [9.675, 27.15, 27.075, 35.125, 35.525, 42.925, 42.825, 45.95, 45.65, 49.775, 49.725, 55.275, 55.2, 53.1, 52.925, 48.125, 48.075, 57.325, 57.525, 56.025, 56.1, 61.55, 61.45, 57.325, 57.7, 55.025, 55.1, 58.95, 58.65, 59.925, 59.95, 60.925, 61.075, 61.625, 61.675, 61.6, 62.0, 59.275, 59.1, 62.45, 62.475, 63.075, 62.775, 63.25, 63.325, 62.7, 63.15, 62.625, 62.475, 62.45]
# val acc
# [10.25, 54.45, 54.1, 83.45, 83.3, 90.7, 90.8, 92.7, 93.0, 95.35, 95.45, 95.45, 95.4, 96.6, 96.4, 95.15, 94.95, 97.15, 97.4, 97.6, 97.15, 97.45, 97.5, 97.5, 97.6, 97.4, 97.65, 97.9, 98.05, 97.5, 97.55, 98.15, 98.35, 98.25, 98.25, 98.2, 98.05, 97.9, 97.95, 98.45, 98.4, 98.45, 98.4, 98.6, 98.55, 98.4, 98.45, 98.45, 98.4, 98.4]

# final-acc
# [8.65, 26.35, 26.125, 35.475, 35.35, 45.65, 46.125, 47.875, 48.0, 49.95, 50.2, 53.925, 53.725, 55.65, 55.45, 58.75, 58.55, 56.55, 56.475, 59.375, 59.35, 62.525, 62.475, 62.6, 62.55, 59.075, 59.125, 68.8, 68.7, 62.8, 62.675, 66.625, 66.5, 65.825, 65.825, 63.4, 63.4, 66.15, 66.225, 64.75, 64.575, 67.475, 67.7, 64.575, 64.575, 66.95, 67.15, 67.65, 67.65, 68.5]
# val acc
# [9.75, 49.45, 49.0, 82.55, 82.0, 90.75, 90.8, 93.0, 93.1, 94.2, 94.15, 95.3, 95.5, 95.4, 95.45, 96.45, 96.65, 96.5, 96.55, 95.7, 95.75, 97.0, 96.8, 96.4, 96.4, 95.3, 95.35, 97.05, 96.9, 97.5, 97.45, 97.55, 97.6, 97.4, 97.45, 97.35, 97.15, 94.55, 94.4, 97.6, 97.45, 97.7, 97.4, 96.65, 96.6, 97.75, 97.75, 97.55, 97.5, 97.7]

# final-acc
# [10.225, 24.925, 25.025, 47.1, 47.175, 57.325, 57.325, 63.85, 63.7, 71.225, 71.475, 69.625, 69.725, 66.525, 66.6, 76.925, 76.825, 77.925, 78.25, 78.025, 77.7, 78.325, 78.4, 80.125, 80.15, 79.675, 80.15, 82.5, 82.8, 82.8, 82.725, 80.45, 80.35, 81.6, 81.625, 81.825, 82.0, 82.05, 81.925, 83.775, 83.7, 79.575, 79.775, 84.625, 84.875, 83.175, 83.0, 80.725, 80.65, 83.6]
# val acc
# [9.9, 42.85, 42.6, 74.5, 74.6, 88.6, 88.75, 93.65, 93.4, 93.6, 93.65, 94.8, 94.85, 95.1, 95.25, 95.75, 95.85, 96.5, 96.65, 96.45, 96.25, 97.6, 97.55, 97.7, 97.7, 97.1, 97.1, 97.55, 97.55, 97.55, 97.65, 97.65, 97.6, 98.0, 98.1, 96.15, 95.95, 98.05, 98.05, 98.3, 98.4, 98.25, 98.1, 98.35, 98.2, 97.95, 98.25, 97.9, 98.0, 98.05]

# final-acc
# [11.55, 21.7, 21.7, 41.65, 41.525, 49.85, 50.3, 60.125, 59.95, 65.025, 64.875, 68.375, 68.4, 67.725, 67.3, 67.4, 67.025, 68.05, 68.175, 73.475, 73.475, 73.15, 73.15, 76.1, 76.3, 75.675, 76.15, 76.775, 76.25, 75.875, 76.275, 76.2, 75.85, 75.725, 75.575, 78.1, 77.925, 69.475, 69.125, 70.25, 70.275, 77.35, 77.275, 77.85, 78.05, 75.025, 74.75, 78.6, 78.625, 78.3]
# val acc
# [9.75, 40.9, 40.95, 76.25, 76.25, 88.45, 88.35, 91.75, 91.7, 94.0, 94.0, 94.05, 94.15, 95.6, 95.2, 94.8, 95.05, 95.15, 95.3, 96.15, 96.05, 97.15, 97.2, 96.9, 97.0, 96.8, 97.0, 97.45, 97.4, 97.55, 97.65, 96.75, 96.8, 97.95, 97.85, 97.85, 97.9, 93.8, 93.75, 95.6, 95.9, 97.95, 98.1, 97.9, 98.05, 96.9, 96.9, 98.05, 97.95, 97.75]