import click
import os
import random 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, RandomSampler
from torchvision import models
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
import torchvision.transforms as transforms

import time
import numpy as np
import copy

from loss_functions import SupConLoss, MdarLoss, BTLoss, InfoNCE
from network import mnist_net, res_net, alex_net, cifar_net, vit_net, reg_net, generator
from network.modules import get_resnet, get_generator, freeze, unfreeze, freeze_, unfreeze_
from tools.miro_utils import *
from tools.farmer import *
import data_loader
from main_base import evaluate
from main_test import evaluate_digit, evaluate_image, evaluate_pacs, evaluate_officehome, evaluate_vlcs

from tools.randaugment import RandAugment
from torchvision import transforms, datasets

import matplotlib.pyplot as plt

####Multi-GPU
from torch.utils.data.distributed import DistributedSampler

# User home directory. Uses $HOME when set and otherwise falls back to
# os.path.expanduser('~'), so importing this script does not raise KeyError
# in environments without HOME (e.g. some Windows or container setups).
HOME = os.environ.get('HOME', os.path.expanduser('~'))

# Single-source PACS variants: 'pacs_<domain>' trains on <domain> only.
_PACS_SINGLE_SOURCE = ('pacs_art', 'pacs_cartoon', 'pacs_sketch')

# Root directory of the pretrained oracle checkpoints. '$USERNAME' is a
# placeholder; it is expanded from the environment when a checkpoint is loaded
# (os.path.expandvars leaves it untouched if the variable is unset).
_ORACLE_CKPT_ROOT = '/home/$USERNAME/PEER/saved-model'


def _unwrap(model):
    """Return the module inside an ``nn.DataParallel`` wrapper, or *model* itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _save_cls_net(model, path):
    """Save *model*'s weights to *path* as ``{'cls_net': state_dict}`` (DataParallel-safe)."""
    torch.save({'cls_net': _unwrap(model).state_dict()}, path)


def _load_oracle_ckpt(subdir):
    """Load ``<_ORACLE_CKPT_ROOT>/<subdir>/best.pkl`` after expanding ``$VARS`` in the path."""
    return torch.load(os.path.expandvars(f'{_ORACLE_CKPT_ROOT}/{subdir}/best.pkl'))


def _load_datasets(data, ntr, autoaug):
    """Return ``(trset, teset, imsize)`` for the dataset named *data*.

    Raises:
        ValueError: if *data* is not a supported dataset name.
    """
    if data == 'mnist':
        trset = data_loader.load_mnist('train', ntr=ntr, autoaug=autoaug)
        teset = data_loader.load_mnist('test')
        return trset, teset, [32, 32]
    if data == 'cifar10':
        # NOTE(review): autoaug is not forwarded for CIFAR-10 -- confirm intended.
        trset = data_loader.load_cifar10(split='train', autoaug=None)
        teset = data_loader.load_cifar10(split='test')
        return trset, teset, [32, 32]
    if data == 'pacs':
        trset = data_loader.load_pacs(split='train', autoaug=autoaug)
        teset = data_loader.load_pacs(split='test')
        return trset, teset, [224, 224]
    if data in _PACS_SINGLE_SOURCE:
        source = data.split('_', 1)[1]  # 'pacs_art' -> 'art'
        trset = data_loader.load_pacs_cross(split='train', source=source)
        teset = data_loader.load_pacs_cross(split='test', source=source)
        return trset, teset, [224, 224]
    if data == 'officehome':
        trset = data_loader.load_officehome(split='train')
        teset = data_loader.load_officehome(split='test')
        return trset, teset, [224, 224]
    if data == 'vlcs':
        trset = data_loader.load_vlcs(split='train')
        teset = data_loader.load_vlcs(split='test')
        return trset, teset, [224, 224]
    raise ValueError(f'Unsupported dataset: {data}')


def _resnet_convnet(backbone, pretrained, projection_dim, output_dim):
    """Build a CUDA ``res_net.ConvNet`` around a ``get_resnet(backbone, pretrained)`` encoder."""
    encoder = get_resnet(backbone, pretrained)
    n_features = encoder.fc.in_features
    return res_net.ConvNet(encoder, projection_dim, n_features, output_dim).cuda()


def _build_src_net(data, backbone, pretrained, projection_dim, ckpt):
    """Build the source (task/proxy) network and load ``cls_net`` weights from *ckpt*.

    Returns:
        ``(net, output_dim)``: the CUDA network and the number of classes of *data*.

    Raises:
        ValueError: if *data* is unknown or *backbone* is unsupported for it.
    """
    net = None
    if data in ('mnist', 'mnist_t'):
        output_dim = 10
        if backbone == 'custom':
            net = mnist_net.ConvNet(projection_dim).cuda()
        elif backbone in ('resnet18', 'resnet50', 'wideresnet'):
            net = _resnet_convnet(backbone, pretrained, projection_dim, output_dim)
    elif data == 'cifar10':
        output_dim = 10
        if backbone == 'cifar_net':
            net = cifar_net.ConvNet(projection_dim=projection_dim, output_dim=output_dim).cuda()
    elif data == 'pacs' or data in _PACS_SINGLE_SOURCE:
        output_dim = 7
        if backbone == 'custom':
            raise ValueError('PLEASE USE Resnet-18/50/AlexNet For PACS')
        if backbone in ('resnet18', 'resnet50', 'wideresnet'):
            net = _resnet_convnet(backbone, pretrained, projection_dim, output_dim)
        elif backbone == 'alexnet':
            net = alex_net.ConvNet(projection_dim=projection_dim, output_dim=output_dim).cuda()
    elif data == 'officehome':
        output_dim = 65
        if backbone in ('resnet18', 'resnet50'):
            net = _resnet_convnet(backbone, pretrained, projection_dim, output_dim)
    elif data == 'vlcs':
        output_dim = 5
        if backbone in ('resnet18', 'resnet50'):
            net = _resnet_convnet(backbone, pretrained, projection_dim, output_dim)
    else:
        raise ValueError(f'Unsupported dataset: {data}')

    if net is None:
        raise ValueError(f'Unsupported backbone {backbone!r} for dataset {data!r}')
    net.load_state_dict(torch.load(ckpt)['cls_net'])
    return net, output_dim


def _build_oracle_net(oracle_type, backbone, pretrained, projection_dim, output_dim):
    """Build the oracle network for *oracle_type*, load its checkpoint and freeze it.

    *output_dim* is only used for the AlexNet PACS oracle (the source model's
    class count); the other oracles use their dataset's fixed class count.

    Raises:
        ValueError: for an unknown *oracle_type* or an unsupported *backbone*.
    """
    if oracle_type == 'self_digits':
        net = mnist_net.ConvNet(projection_dim).cuda()
        subdir = 'mnist/base_custom_False_128_run0'
    elif oracle_type == 'self_pacs':
        if backbone == 'alexnet':
            net = alex_net.ConvNet(projection_dim=projection_dim, output_dim=output_dim).cuda()
        elif backbone in ('resnet18', 'resnet50'):
            net = _resnet_convnet(backbone, pretrained, projection_dim, 7)
        else:
            raise ValueError(f'Unsupported backbone {backbone!r} for oracle {oracle_type!r}')
        subdir = f'pacs/base_{backbone}_True_1024_run0'
    elif oracle_type == 'self_cifar':
        net = cifar_net.ConvNet(projection_dim=projection_dim, output_dim=10).cuda()
        subdir = 'cifar10/base_cifar_net_False_1024_run0'
    elif oracle_type == 'self_office':
        if backbone not in ('resnet18', 'resnet50'):
            raise ValueError(f'Unsupported backbone {backbone!r} for oracle {oracle_type!r}')
        net = _resnet_convnet(backbone, pretrained, projection_dim, 65)
        subdir = f'officehome/base_{backbone}_True_1024_run0'
    elif oracle_type == 'self_vlcs':
        if backbone not in ('resnet18', 'resnet50'):
            raise ValueError(f'Unsupported backbone {backbone!r} for oracle {oracle_type!r}')
        net = _resnet_convnet(backbone, pretrained, projection_dim, 5)
        subdir = f'vlcs/base_{backbone}_True_1024_run0'
    else:
        raise ValueError("Please Check Model Type")

    net.load_state_dict(_load_oracle_ckpt(subdir)['cls_net'])
    freeze("all", net)
    return net


def _evaluate_checkpoint(data, gpu, pklpath, outpath, backbone, pretrained, projection_dim):
    """Run the ``main_test`` evaluation routine matching *data* on checkpoint *pklpath*.

    *outpath* is forwarded as the routine's output path (presumably where the
    results are written -- see ``main_test``). Unknown *data* values are a no-op.
    """
    kwargs = dict(backbone=backbone, pretrained=pretrained, projection_dim=projection_dim)
    if data == 'mnist':
        evaluate_digit(gpu, pklpath, outpath, **kwargs)
    elif data == 'cifar10':
        evaluate_image(gpu, pklpath, outpath, c_level=5, **kwargs)
    elif data == 'pacs':
        evaluate_pacs(gpu, pklpath, outpath, **kwargs)
    elif data == 'officehome':
        evaluate_officehome(gpu, pklpath, outpath, **kwargs)
    elif data in _PACS_SINGLE_SOURCE:
        from main_test import evaluate_pacs_extra
        evaluate_pacs_extra(gpu, pklpath, outpath, **kwargs)
    elif data == 'vlcs':
        evaluate_vlcs(gpu, pklpath, outpath, **kwargs)


@click.command()
@click.option('--gpu', type=str, default='0', help='Choose GPU')
@click.option('--data', type=str, default='mnist', help='Dataset name')
@click.option('--ntr', type=int, default=None, help='Select the first ntr samples of the training set')
@click.option('--gen', type=str, default='cnn', help='cnn/hr')
@click.option('--gen_mode', type=str, default=None, help='Generator Mode')
@click.option('--n_tgt', type=int, default=10, help='Number of Targets')
@click.option('--tgt_epochs', type=int, default=10, help='How many epochs were trained on each target domain')
@click.option('--nbatch', type=int, default=None, help='How many batches are included in each epoch')
@click.option('--batchsize', type=int, default=256)
@click.option('--lr', type=float, default=1e-3, help='Learning Rate: Default 1e-4 in Our Experiment')
@click.option('--lr_scheduler', type=str, default='none', help='Whether to choose a learning rate decay strategy')
@click.option('--svroot', type=str, default='./saved')
@click.option('--ckpt', type=str, default='./saved/best.pkl')
@click.option('--w_oracle', type=float, default=1.0, help='Oracle loss Weight')
@click.option('--lmda', type=float, default=0.051, help='Lambda for Adversarial BT')
@click.option('--interpolation', type=str, default='pixel', help='Interpolate between the source domain and the generated domain to get a new domain, two ways：img/pixel')
@click.option('--loss_fn', type=str, default='bt', help= 'Loss Functions (supcon/mdar')
@click.option('--backbone', type=str, default= 'custom', help= 'Backbone Model (custom/resnet18,resnet50,wideresnet')
@click.option('--pretrained', type=str, default= 'False', help= 'Pretrained Backbone - ResNet18/50, Custom MNISTnet does not matter')
@click.option('--projection_dim', type=int, default=128, help= "Projection Dimension of the representation vector for Resnet; Default: 128")
@click.option('--oracle', type=str, default='False', help= "Oracle Model for large pretrained models")
@click.option('--oracle_type', type=str, default='ft', help= "(regnet/regnet_large) Oracle Model Type")
@click.option('--optimizer', type=str, default='adam', help= "adam/sgd")
@click.option('--seed', type=int, default=0, help= "random seed")
@click.option('--autoaug', type=str, default=None, help='AA FastAA RA')
def experiment(gpu, data, ntr, gen, gen_mode,
        n_tgt, tgt_epochs, nbatch, batchsize, lr, lr_scheduler, svroot, ckpt,
        w_oracle, lmda, interpolation, loss_fn,
        backbone, pretrained, projection_dim, oracle, oracle_type, autoaug, optimizer, seed):
    """Train a proxy model over a sequence of augmented target domains (PEER).

    \f
    For each of ``n_tgt`` targets a RandAugment transform with a randomly drawn
    strength argument is built and the proxy model is trained for ``tgt_epochs``
    epochs with cross-entropy on both the clean and the augmented batch. When
    ``oracle == 'True'`` a pretrained, frozen oracle network that shares the
    proxy's projection head adds ``w_oracle * peer_criterion(z_oracle, z_bar)``;
    after each target the oracle's parameters become the running average of the
    initial oracle and every post-target proxy. Otherwise the per-target proxy
    snapshots are averaged at the end ("naive" model).

    After every epoch the proxy is saved to ``{svroot}/{i_tgt}-best.pkl``
    (overwritten each epoch, so it holds the last epoch's weights) and after
    each target it is evaluated with the dataset's ``main_test`` routine. The
    final model is saved to ``peer-best.pkl`` / ``naive-best.pkl`` and evaluated.

    NOTE(review): ``gen``, ``gen_mode``, ``lmda``, ``interpolation`` and
    ``optimizer`` are accepted for CLI compatibility but unused here; every
    target is trained with a freshly constructed Adam optimizer.

    Raises:
        ValueError: for an unsupported dataset, backbone or oracle type.
    """
    settings = locals().copy()  # must stay first: captures only the CLI arguments
    my_seed_everywhere(seed + 1)
    print("Setting Seed as {s}".format(s=seed+1))
    print("===================================")
    print(settings)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('Device:', device)
    if torch.cuda.is_available():
        print('Current cuda device:', torch.cuda.current_device())
    print('Count of using GPUs:', torch.cuda.device_count())
    os.makedirs(os.path.join(svroot, 'g1'), exist_ok=True)
    use_oracle = oracle == 'True'

    # ---- Data -------------------------------------------------------------
    trset, teset, _ = _load_datasets(data, ntr, autoaug)
    print("--Training With {data} data".format(data=data))
    # nbatch=None draws len(trset) samples (with replacement) per epoch.
    num_samples = nbatch * batchsize if nbatch is not None else None
    trloader = DataLoader(trset, batch_size=batchsize, num_workers=8,
                          sampler=RandomSampler(trset, replacement=True, num_samples=num_samples))
    teloader = DataLoader(teset, batch_size=batchsize, num_workers=8, shuffle=False, drop_last=False)

    # ---- Proxy (task) model -------------------------------------------------
    src_net_save, output_dim = _build_src_net(data, backbone, pretrained, projection_dim, ckpt)

    # ---- Losses -------------------------------------------------------------
    cls_criterion = nn.CrossEntropyLoss()
    peer_criterion = BTLoss(projection_dim=projection_dim, device=device)  # Barlow Twins
    if loss_fn == 'infonce':
        peer_criterion = InfoNCE()

    # ---- Devices ------------------------------------------------------------
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        print("--Using Multiples GPUs: ", n_gpu, "GPUs")
        print("--Visible Devices: {c}".format(c=os.environ.get('CUDA_VISIBLE_DEVICES', 'all')))
        gpu_list = list(range(n_gpu))
    else:
        src_net_save = src_net_save.to(device)
        gpu_list = []

    def wrap(model):
        # Spread the model over all GPUs when more than one is visible.
        if n_gpu > 1:
            return nn.DataParallel(model, device_ids=gpu_list)
        return model.to(device)

    # ---- Oracle (PEER) model ------------------------------------------------
    oracle_opt = None
    if use_oracle:
        oracle_net = _build_oracle_net(oracle_type, backbone, pretrained, projection_dim, output_dim)
        if n_gpu > 1:
            print("--Assigning Multiples GPUs for Oracle: ", n_gpu, "GPUs")
        else:
            print("--Assigning Single GPU for Oracle: ", n_gpu, "GPU")
        oracle_net = wrap(oracle_net)
        _unwrap(oracle_net).oracle = True
        # Only zero_grad() is ever called on this optimizer; it never steps.
        oracle_opt = optim.Adam(oracle_net.parameters(), lr=lr)

    # ---- Train: n_tgt targets x tgt_epochs epochs ---------------------------
    trajectory = []  # per-target proxy snapshots, averaged at the end when no oracle
    trajectory_num = 0
    for i_tgt in range(n_tgt):
        print(f'Target domain {i_tgt}/{n_tgt}')
        trajectory_num += 1
        # NOTE(review): src_net_save is reused rather than copied, so training
        # continues from the previous target's weights.
        src_net = wrap(src_net_save)
        src_opt = optim.Adam(src_net.parameters(), lr=lr)
        if use_oracle:
            freeze("encoder", _unwrap(oracle_net))
            _unwrap(oracle_net).pro_head = _unwrap(src_net).pro_head  # shared projection head

        # Random augmentation for this target (RNG call order kept for reproducibility).
        augmentation_strategy = [transforms.ToPILImage(), RandAugment(4, random.randint(1, 30)), transforms.ToTensor()]
        if data in ['pacs', 'vlcs', 'officehome']:
            augmentation_strategy = [transforms.ToPILImage(), RandAugment(3, random.randint(1, 15)), transforms.ToTensor()]
        augmentation_module = transforms.Compose(augmentation_strategy)

        scheduler = None
        if lr_scheduler == 'cosine':
            # T_max counts optimizer steps because this scheduler is stepped per batch.
            # (https://discuss.pytorch.org/t/cosineannealinglr-step-size-t-max/104687/)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(src_opt, tgt_epochs * len(trloader))
        elif lr_scheduler == 'linear':
            scheduler = optim.lr_scheduler.LinearLR(src_opt, start_factor=0.3333333333333333, total_iters=tgt_epochs)
        elif lr_scheduler == 'step':
            scheduler = optim.lr_scheduler.MultiStepLR(src_opt, milestones=[5, 10, 15], gamma=0.5)

        for epoch in range(tgt_epochs):
            t1 = time.time()
            loss_list = []

            src_net.train()
            for x, y in trloader:
                x, y = x.to(device), y.to(device)
                x_bar = torch.stack([augmentation_module(image) for image in x]).to(device)

                p_bar, z_bar, _, _ = src_net(x_bar, mode='prof')
                p, _, _, _ = src_net(x, mode='prof')
                if use_oracle:
                    _, z_oracle, _, _ = oracle_net(x, mode='prof')

                erm_loss = cls_criterion(p_bar, y) + cls_criterion(p, y)
                # Always zero: earlier augmentation-based sDG methods used a
                # consistency loss here to learn augmentation-invariant features.
                con_loss = torch.tensor(0)
                peer_loss = peer_criterion(z_oracle, z_bar) if use_oracle else torch.tensor(0)

                loss = erm_loss + con_loss
                if use_oracle:
                    loss += w_oracle * peer_loss

                src_opt.zero_grad()
                if oracle_opt is not None:
                    oracle_opt.zero_grad()
                loss.backward()
                src_opt.step()

                # cosine/linear schedules are stepped once per batch.
                if lr_scheduler in ['cosine', 'linear']:
                    scheduler.step()

                loss_list.append([erm_loss.item(), con_loss.item(), peer_loss.item()])
            erm_avg, con_avg, peer_avg = np.mean(loss_list, 0)

            # MultiStepLR milestones are epochs, so it is stepped once per epoch.
            if lr_scheduler == 'step':
                scheduler.step()

            src_net.eval()
            teacc = evaluate(src_net, teloader)
            _save_cls_net(src_net, os.path.join(svroot, f'{i_tgt}-best.pkl'))
            t2 = time.time()
            print(f'epoch {epoch}, time {t2-t1:.2f}, erm {erm_avg:.4f} con {con_avg:.4f} oracle {peer_avg:.4f} /// teacc {teacc:2.2f}')

        if use_oracle:
            # Parameter-wise running average: after target k the oracle equals
            # (initial oracle + proxy after target 1 + ... + proxy after target k) / (k + 1).
            new_oracle = copy.deepcopy(oracle_net)
            for new_param, src_param in zip(new_oracle.parameters(), src_net.parameters()):
                new_param.data = (new_param.data * trajectory_num + src_param.data) / (trajectory_num + 1)
            oracle_net = new_oracle.to(device)
        else:
            trajectory.append(copy.deepcopy(src_net))

        # Generalization test of this target's model.
        pklpath = f'{svroot}/{i_tgt}-best.pkl'
        _evaluate_checkpoint(data, gpu, pklpath, pklpath + '.test', backbone, pretrained, projection_dim)

    # ---- Final model: PEER oracle or naive average of the trajectory ---------
    if use_oracle:
        print("Testing PEER model")
        final_net, tag = oracle_net, 'peer'
    else:
        print("Testing Naive Averaged Model")
        final_net, tag = copy.deepcopy(src_net), 'naive'
        snapshots = (model.parameters() for model in trajectory)
        for new_param, *trajectory_params in zip(final_net.parameters(), *snapshots):
            new_param.data = torch.stack(trajectory_params).mean(dim=0)

    pklpath = f'{svroot}/{tag}-best.pkl'
    _save_cls_net(final_net, pklpath)
    _evaluate_checkpoint(data, gpu, pklpath, pklpath + '.result', backbone, pretrained, projection_dim)


if __name__=='__main__':
    # Script entry point: click parses the command-line options and calls
    # experiment(); seeding is done inside experiment() from its --seed option.
    experiment()

