from __future__ import print_function
from xmlrpc.client import Boolean

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pdb
import os, shutil
import argparse
import time

from torch.utils.tensorboard import SummaryWriter
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
from aug import *
import pdb
from pacs_rtdataset import *
from pacs_dataset import *


import dg_maml_vi_model
import sys
import numpy as np
from torch.nn import init
from sklearn.model_selection import train_test_split
bird = False
import psutil 
cpu_workers = psutil.cpu_count()

from timm.loss import JsdCrossEntropy
from math import remainder

import learn2learn as l2l

from learn2learn.data.transforms import (NWays,
                                         KShots,
                                         LoadData,
                                         RemapLabels,
                                         ConsecutiveLabels)

import math 


def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` evaluates ``bool("False")``, which is True
    for every non-empty string, so such flags could never be turned off from
    the command line. This accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.0001, type=float, help='learning rate')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')

parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')

parser.add_argument('--log_dir', default='log1', help='Log dir [default: log1]')
parser.add_argument('--dataset', default='PACS', help='datasets')
parser.add_argument('--batch_size', type=int, default=512, help='Batch Size during training [default: 512]')

parser.add_argument('--shuffle', type=int, default=0, help='shuffle flag (0/1) [default: 0]')
parser.add_argument('--optimizer', default='adam', help='adam, momentum, nesterov or rmsp [default: adam]')

parser.add_argument('--net', default='res18', help='res18 or res50')

parser.add_argument('--autodecay', action='store_true')

parser.add_argument('--test_domain', default='art_painting', help='held-out test domain: art_painting, photo, cartoon or sketch')
parser.add_argument('--train_domain', default='', help='comma-separated training domains [default: all except test_domain]')
# Boolean options use _str2bool; plain type=bool would turn "False" into True.
parser.add_argument('--ite_train', default=True, type=_str2bool, help='iteration-based training: cut each epoch after a fixed number of batches')
parser.add_argument('--max_ite', default=10000, type=int, help='max_ite')
parser.add_argument('--test_ite', default=50, type=int, help='test_ite')
parser.add_argument('--bias', default=1, type=int, help='use bias (0/1)')
parser.add_argument('--test_batch', default=100, type=int, help='test batch size')
parser.add_argument('--data_aug', default=1, type=int, help='training data augmentation (0/1)')
parser.add_argument('--difflr', default=1, type=int, help='train the backbone with lr * reslr (0/1)')

parser.add_argument('--reslr', default=0.5, type=float, help='backbone learning-rate multiplier (applied to --lr)')

parser.add_argument('--agg_model', default='concat', help='concat or bayes or rank1')
parser.add_argument('--agg_method', default='mean', help='ensemble or mean or ronly')

parser.add_argument('--ctx_num', default=10, type=int, help='context-set size parameter (passed to rtPACS)')
parser.add_argument('--hierar', default=2, type=int, help='hierarchical model')

parser.add_argument('--model_saving_dir', default= './new_models/sampler/tt', type = str, help=' place to save the best model obtained during training')

parser.add_argument('--resume_from_checkpoint', type = str, default =  "/home/sambekar/code_meta/das5_files/dulcet-dawn-69/checkpoint/best_model.pth", help=' resume from checkpoint')
parser.add_argument('--lr_adam_maml', default=0.0001, type=float, help='learning rate Maml')
parser.add_argument('--lr_adam', default=0.0001, type=float, help='learning rate Adam Optimizer')

parser.add_argument('--entropy_threshold_perc', default=0.7, type=float, help='entropy_threshold_perc')
parser.add_argument('--num_iterations', default=250, type=float, help='num_iterations')

parser.add_argument('--variational_refinement', default=True, type=_str2bool, help='variational_refinement')

parser.add_argument('--update_pseudo_label_times', default=1, type=int, help='update_pseudo_label_times')


args = parser.parse_args()

# ---- Unpack the parsed command-line options into module-level settings ----
BATCH_SIZE = args.batch_size
OPTIMIZER = args.optimizer

backbone = args.net

max_ite = args.max_ite
test_ite = args.test_ite
test_batch = args.test_batch
iteration_training = args.ite_train

test_domain = args.test_domain
train_domain = args.train_domain

ctx_num = args.ctx_num

# Integer 0/1 CLI switches are normalised to bool.
difflr = bool(args.difflr)
res_lr = args.reslr
hierar = args.hierar
agg_model = args.agg_model

with_bias = bool(args.bias)

data_aug = bool(args.data_aug)
model_saving_dir = args.model_saving_dir
resume_from_checkpoint = args.resume_from_checkpoint

# Output locations. These names are used right below (directory creation,
# log files, TensorBoard writers, checkpoints) but were never assigned, which
# raised NameError at import time. They now come from --log_dir and
# --model_saving_dir.
LOG_DIR = args.log_dir
MODEL_DIR = model_saving_dir




# Make sure the model directory, the log root and its per-split
# sub-directories exist (created in this order, only when missing).
for _needed_dir in (
    MODEL_DIR,
    LOG_DIR,
    os.path.join(LOG_DIR, 'validation'),
    os.path.join(LOG_DIR, 'test'),
    os.path.join(LOG_DIR, 'logs'),
):
    if not os.path.exists(_needed_dir):
        os.makedirs(_needed_dir)

# Training log (written by log_string) and the stdout tee file (written by Logger).
text_file = os.path.join(LOG_DIR, 'log_train.txt')
text_file2 = os.path.join(LOG_DIR, 'log_std_output.txt')


import sys

class Logger(object):
    """Tee for ``sys.stdout``: every write goes to the terminal and to a file.

    Args:
        path: file to append to. Defaults to the module-level ``text_file2`` so
            existing ``Logger()`` calls keep their original destination.
    """

    def __init__(self, path=None):
        self.terminal = sys.stdout
        # The default is resolved lazily so an explicit path can be used as well.
        self.log = open(text_file2 if path is None else path, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Previously a no-op, so buffered log-file output could be lost if the
        # process died. Flush both sinks.
        self.terminal.flush()
        self.log.flush()

    def __getattr__(self, name):
        # Only called for attributes not found normally: forward stream
        # attributes such as isatty/encoding/fileno to the real terminal so
        # code that inspects sys.stdout keeps working.
        if name == 'terminal':
            raise AttributeError(name)
        return getattr(self.terminal, name)

# From here on, everything printed is also appended to text_file2.
sys.stdout = Logger()

# Training log used by log_string(); opened with 'w', so each run starts a fresh file.
LOG_FOUT = open(text_file, 'w')

print(args)
LOG_FOUT.write('{}\n'.format(args))


def log_string(out_str, print_out=True):
    """Append *out_str* and a newline to LOG_FOUT, flush, and optionally echo it.

    Args:
        out_str: text to record; it is concatenated with a newline, so it must be a str.
        print_out: when truthy, the text is also printed to stdout.
    """
    LOG_FOUT.write(out_str + '\n')
    LOG_FOUT.flush()
    if not print_out:
        return
    print(out_str)


# log_string's second parameter is the print flag; the directory used to be
# passed there and never appeared in the message.
log_string('Saving models to ' + str(MODEL_DIR))

log_string('==> Writing text file and stdout pushing file output to ')
log_string(text_file)
log_string(text_file2)

# TensorBoard writers: 'train/*' scalars go to LOG_DIR, 'val/*' and 'te/*'
# scalars to the 'validation' and 'test' sub-directories.
tr_writer = SummaryWriter(LOG_DIR)
val_writer = SummaryWriter(os.path.join(LOG_DIR, 'validation'))
te_writer = SummaryWriter(os.path.join(LOG_DIR, 'test'))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def worker_init_fn(worker_id):
    """Give each DataLoader worker a distinct NumPy RNG seed.

    The seed is the first word of the current global NumPy Mersenne-Twister
    key plus *worker_id*. The sum is reduced modulo 2**32 because
    ``np.random.seed`` only accepts values in ``[0, 2**32 - 1]``, and the plain
    sum can leave that range.

    Note: an identical definition appears later in this file and rebinds this name.
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2 ** 32)


def init_params(net):
    '''Initialise Conv2d / BatchNorm2d / Linear parameters of *net* in place.

    * Conv2d: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    * BatchNorm2d: weight 1, bias 0 (skipped when ``affine=False``, where both are None).
    * Linear: weights drawn from N(0, 0.01**2), zero bias.

    Uses the in-place ``torch.nn.init`` functions; the un-suffixed aliases
    (``kaiming_normal``, ``constant``, ``normal``) are deprecated.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False BatchNorm has no learnable weight/bias tensors.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-2)
            if m.bias is not None:
                init.constant_(m.bias, 0)

def worker_init_fn(worker_id):
    """Give each DataLoader worker a distinct NumPy RNG seed.

    This definition rebinds an identical earlier one and is the version in
    effect at runtime. The seed is the first word of the current global NumPy
    Mersenne-Twister key plus *worker_id*. It is reduced modulo 2**32 because
    ``np.random.seed`` rejects values outside ``[0, 2**32 - 1]``.
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2 ** 32)

# Best test / validation accuracy seen so far and the epoch to start from
# (the resume logic may overwrite best_acc and start_epoch).
best_acc = best_valid_acc = 0
start_epoch = 0

# Epoch milestones, presumably for learning-rate decay (not used in the visible code).
decay_inter = [250, 450]


print('==> Preparing data..')

# Only PACS is supported.
if args.dataset != 'PACS':
    raise NotImplementedError

NUM_CLASS = 7
num_domain = 4
batchs_per_epoch = 0
ctx_test = ctx_num

# Training domains default to every PACS domain except the held-out one;
# an explicit comma-separated --train_domain overrides that.
domains = ['art_painting', 'photo', 'cartoon', 'sketch']
assert test_domain in domains
domains.remove(test_domain)
if train_domain:
    domains = train_domain.split(',')

log_string('data augmentation is ' + str(data_aug))

# ImageNet channel statistics shared by the train and test pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

if data_aug:
    _train_ops = [
        transforms.RandomResizedCrop(224, scale=(0.8, 1.2), ratio=(0.75, 1.33), interpolation=2),
        transforms.RandomHorizontalFlip(),
        ImageJitter(jitter_param),
    ]
else:
    _train_ops = [transforms.RandomHorizontalFlip()]
transform_train = transforms.Compose(
    _train_ops + [transforms.ToTensor(), transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD)]
)

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

log_string('train_domain: ' + str(domains))
log_string('test: ' + str(test_domain))

all_dataset = PACS(test_domain)
rt_context = rtPACS(test_domain, ctx_num)

def get_parameter_number(net):
    """Count the parameter elements of *net*.

    Returns:
        ``{'Total': n_all, 'Trainable': n_requires_grad}`` where both values
        are element counts (``numel``), not tensor counts.
    """
    total = trainable = 0
    for param in net.parameters():
        n_elem = param.numel()
        total += n_elem
        if param.requires_grad:
            trainable += n_elem
    return {'Total': total, 'Trainable': trainable}

# Record dataset-derived settings on the parsed-args namespace.
vars(args).update(num_classes=NUM_CLASS, num_domains=num_domain, bird=bird)

print('--> --> LOG_DIR <-- <--', LOG_DIR)



def inplace_relu(m):
    """Set ``inplace = True`` on any module whose class name contains 'ReLU'.

    Intended for ``Module.apply``. An identical definition follows and rebinds this name.
    """
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True

def inplace_relu(m):
    """Set ``inplace = True`` on any module whose class name contains 'ReLU'.

    Intended for ``Module.apply``; this is the definition in effect at runtime.
    """
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True









print('==> Building model..')
net = dg_maml_vi_model.ResNet18_vi()
print(net)

# ReLU modules overwrite their inputs instead of allocating new outputs.
net.apply(inplace_relu)

net = net.to(device)

pc = get_parameter_number(net)
log_string('Total: %.4fM, Trainable: %.4fM' %(pc['Total']/float(1e6), pc['Trainable']/float(1e6)))

net.train()

# `device` is a torch.device, which never compares equal to the string
# 'cuda', so the old `device == 'cuda'` test was always False and cudnn
# autotuning was never enabled. Compare the device type instead.
if device.type == 'cuda':
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()

WEIGHT_DECAY = args.weight_decay

# Optimizer selection.
# * momentum / nesterov: SGD with momentum 0.9 and weight decay.
# * adam + difflr + 'concat' aggregation (hierar 0 or 2): two parameter groups,
#   backbone (net.features) at lr * reslr and head (net.fc) at lr, no weight decay.
# * adam without difflr: single group with weight decay.
# * rmsp: RMSprop with weight decay.
# Any other combination is unsupported.
if OPTIMIZER in ('momentum', 'nesterov'):
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=args.lr,
        weight_decay=WEIGHT_DECAY,
        momentum=0.9,
        nesterov=(OPTIMIZER == 'nesterov'),
    )
elif OPTIMIZER == 'adam' and difflr and agg_model == 'concat' and hierar in (0, 2):
    optimizer = torch.optim.Adam([
        {'params': net.features.parameters(), 'lr': args.lr * res_lr},
        {'params': net.fc.parameters(), 'lr': args.lr},
    ])
elif OPTIMIZER == 'adam' and not difflr:
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=WEIGHT_DECAY)
elif OPTIMIZER == 'rmsp':
    optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=WEIGHT_DECAY)
else:
    raise NotImplementedError




if args.resume:
    # Restore from --resume_from_checkpoint. The checkpoints written by test()
    # in this file contain only 'net', 'acc' and 'epoch'. The old code had two
    # problems: it read a hard-coded './checkpoint/ckpt.t7' that this script
    # never writes, and it unconditionally loaded checkpoint['optimizer']
    # (KeyError). Optimizer state is now restored only when present.
    print('==> Resuming from checkpoint..')
    if not os.path.isfile(resume_from_checkpoint):
        raise FileNotFoundError('Error: no checkpoint found at %s' % resume_from_checkpoint)
    checkpoint = torch.load(resume_from_checkpoint, map_location=device)
    net.load_state_dict(checkpoint['net'])
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])

    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

def find_neighbors(pseudo_labels, targets):
    """Snap every pseudo label to the nearest value present in *targets*.

    For each entry of *pseudo_labels* the element of *targets* with the smallest
    absolute difference is selected. Ties go to the first such element, as
    with ``np.argmin``. Its value replaces the pseudo label.

    Args:
        pseudo_labels: tensor of predicted labels (assumed 1-D).
        targets: tensor of reference labels (assumed 1-D).

    Returns:
        A new NumPy array with the dtype of *pseudo_labels*. The input tensor is
        left untouched. The previous loop wrote through ``.numpy()``, which
        shares memory with CPU tensors and so silently mutated the caller's
        tensor.
    """
    pseudo = pseudo_labels.cpu().numpy()
    ref = targets.cpu().numpy()
    if pseudo.size == 0:
        return pseudo.copy()
    # Pairwise |targets[j] - pseudo[i]|; pick the nearest j for every i.
    nearest = np.argmin(np.abs(ref[None, :] - pseudo[:, None]), axis=1)
    return ref[nearest].astype(pseudo.dtype, copy=False)


def wasserstein_distance_torch (x, y):
    """1-D Wasserstein-1 (earth mover's) distance between two samples.

    Both tensors are flattened and treated as empirical distributions of scalar
    values. The previous body returned ``mean(x) * mean(y)``, which is not a
    distance (it is non-zero even for identical inputs).

    The computation runs in NumPy/SciPy, so the result carries no gradient.
    It is returned as ``np.float64`` so callers can keep using ``.item()`` and
    adding it to tensors.

    Returns ``nan`` when either input is empty; the old mean-based code also
    produced ``nan`` there.
    """
    from scipy.stats import wasserstein_distance

    x_np = x.detach().float().cpu().numpy().ravel()
    y_np = y.detach().float().cpu().numpy().ravel()
    if x_np.size == 0 or y_np.size == 0:
        return np.float64(np.nan)
    return np.float64(wasserstein_distance(x_np, y_np))

def _kl_div_2d(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    
    batch, chans, height, width = p.shape
    unsummed_kl = F.kl_div(
        q.reshape(batch * chans, height * width).log(), p.reshape(batch * chans, height * width), reduction='none'
    )
    kl_values = unsummed_kl.sum(-1).view(batch, chans)
    return kl_values

def js_div2(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    
    m = 0.5 * (p + q)
    return 0.5 * _kl_div_2d(p, m) + 0.5 * _kl_div_2d(q, m)

def js_div3(net_1_logits, net_2_logits, target_device=None):
    """Jensen-Shannon divergence between the softmax distributions of two logit tensors.

    JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = (P + Q) / 2.

    Fixes two defects of the previous version:
    * the mixture was ``0.5 * (P + P)``, a typo that made ``M == P``;
    * ``F.kl_div(log_softmax(x), M)`` evaluates KL(M || softmax(x)), the wrong
      direction for JS.

    Args:
        net_1_logits, net_2_logits: logits of identical shape.
        target_device: device to compute on. Defaults to the module-level
            ``device``, as the original did.

    Returns:
        Scalar tensor. ``reduction="batchmean"`` divides each summed KL by
        ``size(0)`` of the inputs, as before.

    NOTE(review): softmax is taken over ``dim=0`` exactly as before; for
    (batch, classes) logits that normalises across the batch rather than the
    classes. Confirm this is intended.
    """
    dev = device if target_device is None else target_device
    p_logits = net_1_logits.to(dev).float()
    q_logits = net_2_logits.to(dev).float()
    p = F.softmax(p_logits, dim=0)
    q = F.softmax(q_logits, dim=0)

    log_m = (0.5 * (p + q)).log()
    # F.kl_div(input=log M, target=P) == KL(P || M)
    kl_pm = F.kl_div(log_m, p, reduction="batchmean")
    kl_qm = F.kl_div(log_m, q, reduction="batchmean")
    return 0.5 * (kl_pm + kl_qm)



columns=[ "domain", "Epoch", "id", "image", "guess", "truth"]

def num_to_str_converter(num):
    if num == 0:
        return "Dog"
    elif num == 1:
        return "Elephant"
    elif num == 2:
        return "Giraffe"
    elif num == 3:
        return "Guitar"
    elif num == 4:
        return "Horse"
    elif num == 5:
        return "House"
    elif num == 6:
        return "Person"
    else:
        return "Unknown"
for digit in range(0,7):
    digit = num_to_str_converter(digit)
    
    columns.append((digit))
    

NUM_IMAGES_PER_BATCH = 20  



def num_to_str_converter_batch(num_batch):
    str_batch = []
    for num in num_batch:
        str_batch.append(num_to_str_converter(num))
    return str_batch
    

def log_test_predictions(epoch, domain, images, labels, outputs, predicted, test_table, log_counter):
    """Build per-image prediction records for at most NUM_IMAGES_PER_BATCH images.

    The batch is moved to NumPy: softmax scores over dim 1, images moved to
    channels-last with ``np.moveaxis(..., 1, -1)``. Label and prediction ids
    are mapped to class names.

    NOTE(review): the per-row values are discarded and ``test_table`` is never
    used, so apart from the computation this call has no observable effect.
    A logging call appears to have been removed; confirm before relying on it.
    """
    score_rows = F.softmax(outputs.float(), dim=1).cpu().numpy()
    image_rows = np.moveaxis(images.cpu().numpy(), 1, -1)
    label_rows = labels.cpu().numpy()
    pred_rows = predicted.cpu().numpy()

    n_rows = len(image_rows)
    domain_col = np.repeat(domain, n_rows)
    epoch_col = np.repeat(epoch, n_rows)

    rows = zip(domain_col, epoch_col, image_rows, label_rows, pred_rows, score_rows)
    for row_idx, (_dom, _ep, _img, label_id, pred_id, _scores) in enumerate(rows):
        img_id = str(row_idx) + "_" + str(log_counter)
        label_name = num_to_str_converter(label_id)
        pred_name = num_to_str_converter(pred_id)
        if row_idx + 1 == NUM_IMAGES_PER_BATCH:
            break



# timm JSD cross-entropy criterion. NOTE(review): train() below never calls
# it (js_div_tot stays 0), so confirm whether it is still needed.
js_div = JsdCrossEntropy()
def train(epoch):
    """Run one training pass for *epoch*.

    What the code below does:
      * picks a source-domain index: ``epoch`` itself for epochs 0-2, otherwise
        a random index in ``range(len(domains))``;
      * resets ``all_dataset`` to that index's 'train' split and ``rt_context``
        to its 'train' split. Only the last batch from ``context_loader`` is
        kept as ``context_img`` / ``context_label``;
      * for each training batch calls ``net(inputs, context_img)``, which
        returns ``(outputs_ul, outputs)``. ``loss`` scores ``outputs_ul``
        against the context labels and ``adapt_loss`` scores ``outputs``
        against the batch targets;
      * when ``int(remainder(epoch, pseudo_label_update_epoch)) == 0`` it
        back-propagates ``loss + w_loss + adapt_loss``, otherwise only ``loss``;
      * with ``iteration_training`` set, stops after ``batchs_per_epoch + 1`` batches;
      * logs averaged statistics and writes TensorBoard scalars.

    NOTE(review): ``pseudo_label_update_epoch`` is not defined in the visible
    part of this file; presumably it is a module-level global set elsewhere.
    NOTE(review): if the training loader yields no batches, ``batch_idx``
    still refers to the context loop and the accuracy division by ``total``
    raises ZeroDivisionError.
    """
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    t0 = time.time()
    kl_loss_tot = 0
    w_loss_tot = 0
    js_div_tot= 0  # never incremented below; always logged as 0
    correct_source = 0
    total_source = 0
    adapt_loss_tot = 0 
    # Epochs 0-2 walk the domain indices in order; later epochs sample one at random.
    if epoch<3:
        domain_id = epoch
        loss_rate = 1e-8  # NOTE(review): loss_rate is not used anywhere in this function
    else:
        domain_id = np.random.randint(len(domains))
        loss_rate = 1
    print(domain_id)
    all_dataset.reset('train', domain_id, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(all_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=cpu_workers, drop_last=False, worker_init_fn= worker_init_fn )
    # Local criterion (shadows the module-level one of the same name).
    kl_loss_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
    rt_context.reset('train', transform=transform_train)
    # Batch size presumably chosen to hold the whole context set in one batch — TODO confirm with rtPACS.
    context_loader = torch.utils.data.DataLoader(rt_context, batch_size=(num_domain-1)*NUM_CLASS*ctx_num, shuffle=False, num_workers=cpu_workers, drop_last=False, worker_init_fn=worker_init_fn)


    # NOTE(review): f1 is never written to and never closed, so one file handle leaks per epoch.
    f1 = open((os.path.join(LOG_DIR, 'log_labels.txt')), 'a')

    # Only the last yielded batch survives this loop.
    for batch_idx, (inputs, targets, img_name1 ) in enumerate(context_loader):
        context_img, context_label = inputs.to(device), targets.to(device)



    for batch_idx, (inputs, targets, img_name2 ) in enumerate(trainloader):

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()



        outputs_ul, outputs  = net(inputs, context_img )

        # Supervised loss on the context set and loss on the current batch.
        loss = criterion(outputs_ul, context_label)
        pseudo_label = torch.argmax(outputs, dim=1) 
        adapt_loss = criterion(outputs, targets)

        # Snap each argmax prediction to the nearest label value present in targets.
        pseudo_label = find_neighbors(pseudo_label, targets)

        img_name1 = list(img_name1)  # result unused

        pseudo_label = torch.from_numpy(pseudo_label).long().to(device)
        pseudo_label = pseudo_label.to(device)

        # Logged only. NOTE(review): this feeds integer label tensors to a
        # KLDivLoss with log_target=True, i.e. it treats label ids as
        # log-probabilities; the value is not a meaningful divergence.
        kl_loss = kl_loss_criterion(pseudo_label, targets)
        # wasserstein_distance_torch computes in NumPy, so w_loss carries no
        # gradient: adding it below changes the loss value, not the update.
        w_loss = wasserstein_distance_torch(pseudo_label, targets)



        js_div_loss = 0  # unused


        if int(remainder(epoch,pseudo_label_update_epoch)) == 0:

            # NOTE(review): this message is logged once per batch, not once per epoch.
            log_string('Pseudo_Label updating ...')
            loss_total = loss + w_loss   + adapt_loss
            loss_total.backward()


        else:
            loss_total = loss 
            loss_total.backward()



        optimizer.step()


        # Running sums, averaged over (batch_idx + 1) in the log line below.
        train_loss += loss.item()
        kl_loss_tot += kl_loss.item()
        w_loss_tot += w_loss.item()

        adapt_loss_tot += adapt_loss.item()

        # Accuracy on the training batch.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Accuracy on the context (source) set.
        _, predicted_source = outputs_ul.max(1)
        total_source  += context_label.size(0)
        correct_source += predicted_source.eq(context_label).sum().item()



        # Iteration-based training: stop after batchs_per_epoch + 1 batches.
        if iteration_training and batch_idx>=batchs_per_epoch:
            break



    log_string('Batch: %d, Train_Source_Loss: %f, Train_Source_Acc: %f (%d|%d) Acc: %f (%d/%d), Adapt_loss: %f, KL_Loss: %f, W_Loss: %f, JS_Div_loss: %f,' % (batch_idx, train_loss/(batch_idx+1), 100.*correct_source/total_source, correct_source, total_source, 100.*correct/total, correct, total, adapt_loss_tot/(batch_idx+1), kl_loss_tot/(batch_idx+1), w_loss_tot/(batch_idx+1), js_div_tot/(batch_idx+1)))



    tr_writer.add_scalar('train/loss', train_loss/(batch_idx+1), epoch)
    tr_writer.add_scalar('train/acc', 100.*correct/total, epoch)
    tr_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], epoch)
    tr_writer.add_scalar('train/weight_decay', WEIGHT_DECAY, epoch)
    tr_writer.add_scalar('train/time', time.time()-t0, epoch)
    tr_writer.add_scalar('train/batch_size', BATCH_SIZE, epoch)
    tr_writer.add_scalar('train/epoch', epoch, epoch)
    tr_writer.add_scalar('train/total_batch', len(trainloader), epoch)

    tr_writer.add_scalar('train/total_data', len(all_dataset), epoch)

    # NOTE(review): logs the number of batches again (same value as train/total_batch), not a sample count.
    tr_writer.add_scalar('train/total_train_data', len(trainloader), epoch)



    print('time elapsed: %f' % (time.time()-t0))

# Number of test batches handed to log_test_predictions per evaluation.
NUM_BATCHES_TO_LOG = 2


def test(epoch):
    """Evaluate on the held-out test domain and checkpoint the best model.

    The context set comes from ``rt_context``'s 'test' split; only the last
    batch yielded by its loader is kept. Test batches are scored with
    ``net(inputs, context_img)``, the same call convention as validation() and
    train(), using the second returned output as logits. The old
    ``net(inputs)`` ignored the context set it had just loaded.

    Fixes:
      * the TensorBoard loss was logged as ``test_loss/batch_idx + 1``
        (operator precedence); it is now the mean over batches;
      * the save condition compared against ``best_valid_acc`` and never
        updated ``best_acc`` despite ``global best_acc``, so any later model
        that beat the validation score overwrote a better saved one. A model is
        now saved only when test accuracy improves on ``best_acc``, which is
        then updated;
      * an empty test set no longer crashes on division by zero.

    NOTE(review): ``test_table`` is not defined in the visible part of this
    file; presumably it is a module-level global set elsewhere.
    """
    global best_acc

    log_counter = 0

    net.eval()
    all_dataset.reset('test', 0, transform=transform_test)
    testloader = torch.utils.data.DataLoader(all_dataset, batch_size=test_batch, shuffle=False, num_workers=cpu_workers, worker_init_fn=worker_init_fn)
    rt_context.reset('test', transform=transform_test)
    context_loader = torch.utils.data.DataLoader(rt_context, batch_size=(num_domain-1)*NUM_CLASS*ctx_test, shuffle=False, num_workers=cpu_workers, drop_last=False, worker_init_fn=worker_init_fn)
    # Only the last context batch is kept.
    for inputs, targets, _ in context_loader:
        context_img, context_label = inputs.to(device), targets.to(device)

    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    batch_count = 0
    with torch.no_grad():
        for inputs, targets, _ in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            _, outputs = net(inputs, context_img)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_count < NUM_BATCHES_TO_LOG:
                log_test_predictions(epoch, test_domain, inputs, targets, outputs, predicted, test_table, log_counter)
                log_counter += 1
                batch_count += 1

    if total == 0:
        log_string('\t Test set is empty; skipping evaluation')
        return

    avg_loss = test_loss / num_batches
    acc = 100.*correct/total
    log_string('\t Test Loss %f, Acc: %f' % (avg_loss, acc))

    te_writer.add_scalar('te/loss', avg_loss, epoch)
    te_writer.add_scalar('te/acc', acc, epoch)

    if acc > best_acc:
        print('Saving best model... %f' % acc)
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        checkpoint_dir = os.path.join(MODEL_DIR, 'checkpoint')
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state, os.path.join(checkpoint_dir, 'best_model.pth'))
        best_acc = acc
        

def validation(epoch):
    """Evaluate on the 'val' split of every domain index and track the best accuracy.

    The context set comes from ``rt_context``'s 'val' split; only the last
    batch yielded by its loader is kept as ``context_img``. Each domain index
    in ``range(num_domain)`` is evaluated with ``net(inputs, context_img)``,
    using the second output as logits.

    The reported loss is the mean over *all* validation batches of all
    domains. The old code divided the accumulated loss by the batch count of
    the last domain only, which overstated it. The per-batch
    ``optimizer.zero_grad()`` was dropped because nothing is back-propagated here.

    Updates the global ``best_valid_acc``. Returns 0.

    NOTE(review): despite the 'Saving..' message no checkpoint is written;
    only ``LOG_DIR/checkpoint`` is created.
    NOTE(review): ``range(num_domain)`` visits 4 indices while training draws
    from ``len(domains)`` (3 by default); confirm what 'val' index 3 denotes.
    """
    global best_valid_acc

    net.eval()
    val_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    rt_context.reset('val', transform=transform_test)
    context_loader = torch.utils.data.DataLoader(rt_context, batch_size=(num_domain-1)*NUM_CLASS*ctx_test, shuffle=False, num_workers=cpu_workers, drop_last=False, worker_init_fn=worker_init_fn)
    for inputs, targets, _ in context_loader:
        context_img, context_label = inputs.to(device), targets.to(device)

    with torch.no_grad():
        for i in range(num_domain):
            all_dataset.reset('val', i, transform=transform_test)
            valloader = torch.utils.data.DataLoader(all_dataset, batch_size=test_batch, shuffle=False, num_workers=cpu_workers, worker_init_fn=worker_init_fn)

            for inputs, targets, _ in valloader:
                inputs, targets = inputs.to(device), targets.to(device)
                _, outputs = net(inputs, context_img)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                num_batches += 1
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

    if total == 0:
        log_string('Val set is empty; skipping evaluation')
        return 0

    avg_loss = val_loss / num_batches
    acc = 100.*correct/total
    log_string('Val Loss: %f, Acc: %f' % (avg_loss, acc))

    val_writer.add_scalar('val/loss', avg_loss, epoch)
    val_writer.add_scalar('val/acc', acc, epoch)

    if acc > best_valid_acc:
        print('Saving..')
        checkpoint_dir = os.path.join(LOG_DIR, 'checkpoint')
        os.makedirs(checkpoint_dir, exist_ok=True)
        best_valid_acc = acc
    return 0


def accuracy(predictions, targets):
    """Fraction (float tensor) of rows whose argmax over dim 1 equals *targets*."""
    hits = predictions.argmax(dim=1).view(targets.shape) == targets
    return hits.sum().float() / targets.size(0)


def fast_adapt_old(data,labels, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt *learner* on (data, labels) and evaluate it on the very same samples.

    Runs ``adaptation_steps`` ``learner.adapt`` steps, then returns
    ``(evaluation_error, evaluation_accuracy)`` computed on the same data.
    There is no support/query split in this variant. ``shots`` and ``ways``
    are unused.
    """
    support_x = data.to(device)
    support_y = labels.to(device)

    for _ in range(adaptation_steps):
        learner.adapt(loss(learner(support_x), support_y))

    preds = learner(support_x)
    return loss(preds, support_y), accuracy(preds, support_y)



def fast_adapt(data,labels, learner, loss, adaptation_steps, shots, ways, device):
    """One MAML-style task: stratified 50/50 split, adapt on one half, evaluate on the other.

    The learner takes ``adaptation_steps`` ``adapt`` steps on the adaptation
    half and is then scored on the evaluation half. ``shots`` and ``ways`` are
    unused and kept for signature compatibility.

    Returns:
        ``(evaluation_error, evaluation_accuracy)``.

        If the stratified split is impossible (``train_test_split`` raises
        ``ValueError``, e.g. a class with a single sample), a message is logged
        and ``(tensor(0), tensor(0))`` is returned. The old fallback returned
        *three* tensors, which crashed callers unpacking two values.

        Only the split is guarded. Errors raised by the learner are no longer
        silently swallowed by a bare ``except`` (which also caught
        KeyboardInterrupt).
    """
    data = data.cpu().numpy()
    labels = labels.cpu().numpy()

    try:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(
            data, labels, test_size=0.5, stratify=labels)
    except ValueError:
        log_string('Error in fast_adapt, train test split')
        val = 0
        return torch.tensor(val).to(device), torch.tensor(val).to(device)

    adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
    evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

    # Inner-loop adaptation on the first half.
    for step in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Evaluation on the held-out half.
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy

# Module-level KL criterion (log-probability inputs and targets). train()
# creates its own local instance; nothing visible in this file uses this one.
kl_loss_criterion = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)


def fast_adapt_pl(data,labels, learner, loss, adaptation_steps, shots, ways, device):
    """Pseudo-label variant of fast_adapt.

    After a stratified 50/50 split, each adaptation step does three things:
      * takes the learner's argmax predictions on the adaptation half;
      * snaps them to the nearest label value present (find_neighbors);
      * adapts on those pseudo labels.
    The learner is then evaluated on the other half against the true labels.
    ``shots`` and ``ways`` are unused.

    Returns:
        ``(evaluation_error, w_loss, evaluation_accuracy)``. ``w_loss`` is the
        distance between pseudo and true labels from the last adaptation step,
        or 0.0 when ``adaptation_steps == 0``. Previously it was unbound in
        that case; the NameError was swallowed by the bare ``except`` and the
        call silently returned zeros.

        If the stratified split raises ``ValueError``, a message is logged and
        three zero tensors are returned. Only the split is guarded now, so
        errors from the learner propagate.
    """
    data = data.cpu().numpy()
    labels = labels.cpu().numpy()

    try:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(
            data, labels, test_size=0.5, stratify=labels)
    except ValueError:
        log_string('Error in fast_adapt, train test split')
        val = 0
        return torch.tensor(val).to(device), torch.tensor(val).to(device), torch.tensor(val).to(device)

    adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
    evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

    w_loss = np.float64(0.0)
    for step in range(adaptation_steps):
        predictions = learner(adaptation_data)
        pseudo_label = torch.argmax(predictions, dim=1)
        pseudo_label = torch.from_numpy(find_neighbors(pseudo_label, adaptation_labels)).long().to(device)

        w_loss = wasserstein_distance_torch(pseudo_label, adaptation_labels)

        # A second forward pass is kept (rather than reusing `predictions`)
        # so stochastic layers behave exactly as before.
        adaptation_error = loss(learner(adaptation_data), pseudo_label)
        learner.adapt(adaptation_error)

    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, w_loss,  evaluation_accuracy


# Negative log-likelihood criterion (not referenced in the visible part of this file).
nll_loss = torch.nn.NLLLoss()


def alpha_weight_for_pseudo_label_loss(iter):
    """Logistic ramp ``1 / (1 + exp(-iter / 1000))`` for the pseudo-label loss weight.

    Equals 0.5 at ``iter == 0`` and approaches 1 as *iter* grows.
    """
    decay = np.exp(-iter / 1000)
    return 1 / (1 + decay)

def fast_adapt_pl_entropy(data,labels, learner, loss, adaptation_steps, shots, ways, device, iter):
    """Adapt on half of the task with true labels, then add a self-training term on the other half.

    ``shots`` and ``ways`` are unused.

    Returns:
        ``(pseudo_label_error, pseudo_label_accuracy, evaluation_error, w_loss,
        evaluation_accuracy, actual_loss)`` where:
          * ``pseudo_label_error`` is
            ``(alpha_weight_for_pseudo_label_loss(iter) + 0.5) * loss(predictions, argmax(predictions))``;
          * ``pseudo_label_accuracy`` is the constant placeholder 0.5;
          * ``w_loss`` comes from the last adaptation step, or 0.0 when
            ``adaptation_steps == 0``. Previously it was unbound in that case;
            the NameError was swallowed by the bare ``except`` and six zeros
            were returned silently;
          * ``actual_loss`` is computed exactly like ``evaluation_error``.

        If the stratified split raises ``ValueError``, a message is logged and
        six zero tensors are returned. Only the split is guarded now, so errors
        from the learner propagate.
    """
    data = data.cpu().numpy()
    labels = labels.cpu().numpy()

    try:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(
            data, labels, test_size=0.5, stratify=labels)
    except ValueError:
        log_string('Error in fast_adapt, train test split')
        val = 0
        return (torch.tensor(val).to(device), torch.tensor(val).to(device), torch.tensor(val).to(device),
                torch.tensor(val).to(device), torch.tensor(val).to(device), torch.tensor(val).to(device))

    adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
    evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

    w_loss = np.float64(0.0)
    for step in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        # NOTE(review): distance of the labels to themselves; carries no information.
        w_loss = wasserstein_distance_torch(adaptation_labels, adaptation_labels)
        learner.adapt(adaptation_error)

    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)

    # Self-training term: predictions scored against their own argmax,
    # weighted by a ramp in [0.5, 1.5).
    _, pseudo_label = torch.max(predictions, 1)
    alpha = alpha_weight_for_pseudo_label_loss(iter) + 0.5
    pseudo_label_error = alpha*(loss(predictions, pseudo_label))

    pseudo_label_accuracy = 0.5  # placeholder constant

    actual_loss = loss(predictions, evaluation_labels)

    return pseudo_label_error, pseudo_label_accuracy, evaluation_error, w_loss,  evaluation_accuracy, actual_loss


# Label-smoothed cross-entropy (not referenced in the visible part of this file).
ul_loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)


def filter_predictions_based_on_entropy(predictions, entropy_threshold=0.5):
    """Keep rows of *predictions* where more than 4 class probabilities are below 0.2.

    Despite the names, no entropy is computed. The row-wise softmax
    probabilities are compared with a fixed 0.2, and a row is kept when more
    than 4 of its entries fall below it. Debug information is printed.

    NOTE(review): the ``entropy_threshold`` argument is overwritten with 0.2
    and therefore ignored; honouring it would change the default behaviour
    (0.5 vs 0.2), so that decision is left to the owner.

    Returns:
        ``(filtered_predictions, indices)`` where *indices* is the tuple
        returned by ``torch.where`` on the row mask.
    """
    entropy_threshold = 0.2
    probs = torch.nn.functional.softmax(predictions, dim=1)
    print('entropy size', probs.shape)

    below = probs < entropy_threshold
    print('True False', below)

    row_mask = below.sum(dim=1) > 4
    print('entropy size', probs.shape)

    kept_idx = torch.where(row_mask)
    print('torch where', kept_idx)

    return predictions[kept_idx], kept_idx

def filter_data_and_labels_one_occurence(data,labels):
    """Drop every sample whose label occurs exactly once in *labels*.

    Used before a stratified ``train_test_split`` in this file; a stratified
    split cannot place a class that has a single member.

    Label frequencies now come from a single ``np.unique(..., return_counts=True)``
    call. The previous per-element ``np.sum(labels == labels[i])`` was O(n^2).

    Args:
        data: array-like whose first axis is aligned with *labels*.
        labels: 1-D array-like of labels.

    Returns:
        ``(data, labels)`` as new NumPy arrays with singleton-label rows removed.
    """
    labels = np.asarray(labels)
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    keep = counts[inverse.reshape(-1)] != 1
    return np.asarray(data)[keep], labels[keep]


def filter_data_and_labels_two_occurence(data,labels):
    """Keep at most the first two samples of each label.

    For every distinct label, occurrences after the second one are removed
    from both *labels* and the rows of *data* (axis 0).

    Returns:
        ``(data, labels)`` as NumPy arrays.
    """
    surplus = [
        position
        for label_value in np.unique(labels)
        for position in np.where(labels == label_value)[0][2:]
    ]
    trimmed_labels = np.delete(labels, surplus)
    trimmed_data = np.delete(data, surplus, axis=0)
    return trimmed_data, trimmed_labels


def fast_adapt_pl_entropy_new(data,labels_actual, learner, loss, adaptation_steps, shots, ways, device, iter):
    """Adapt ``learner`` on its own pseudo labels and score it against true labels.

    As implemented below:
      1. Pseudo-label ``data`` with the arg-max of ``learner(data)``.
      2. Drop samples whose pseudo label occurs once, then keep at most two
         samples per pseudo label.
      3. Split the filtered samples 50/50, stratified on the pseudo labels.
         Independently, split the *unfiltered* ``data`` 50/50, stratified on
         ``labels_actual``.
      4. Run ``adaptation_steps`` ``learner.adapt`` steps on the adaptation half
         with its pseudo labels.
      5. Return:
         - an ``alpha``-weighted loss of the adapted learner against its own
           arg-max predictions on the evaluation half;
         - accuracy against ``labels_actual`` on the true-label test half;
         - a ``beta``-weighted loss against ``labels_actual`` on that test half.

    Args:
        data: input batch tensor; moved to CPU / NumPy for splitting.
        labels_actual: ground-truth label tensor aligned with ``data``.
        learner: cloned l2l learner; mutated in place by ``learner.adapt``.
        loss: callable ``loss(logits, targets)``.
        adaptation_steps: number of inner ``adapt`` steps.
        shots, ways: unused here.
        device: device the split tensors are moved to.
        iter: outer iteration index passed to ``alpha_weight_for_pseudo_label_loss``.

    Returns:
        ``(pseudo_label_error, actual_accuracy, actual_loss, alpha, beta)``.

    NOTE(review): presumably raises ``ValueError`` from ``train_test_split``
    when stratification is impossible (e.g. no samples survive the filters,
    or a true class has a single sample). Nothing here catches it.
    """
    # Pseudo-label the whole batch with the current (not yet adapted) learner.
    labels = learner(data)
    data = data.cpu().numpy()
    # Unfiltered copy, used later for the split on the ground-truth labels.
    data_actual = data.copy()
    labels = torch.argmax(labels, dim=1)
    labels = labels.cpu().numpy()
    # Stratified splitting needs >= 2 samples per pseudo label ...
    data, labels = filter_data_and_labels_one_occurence(data, labels)
    # ... and keeping at most two should leave one sample per label in each half.
    data, labels = filter_data_and_labels_two_occurence(data, labels)
    # Always true; presumably a leftover from a removed try/except.
    if 1==1:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(data, labels, test_size=0.5,  stratify=labels)
        adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
        evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

        # Second, independent split of the unfiltered batch on the true labels.
        labels_actual = labels_actual.cpu().numpy()
        actual_data_train, actual_data_test, actual_labels_train, actual_labels_test = train_test_split(data_actual, labels_actual, test_size=0.5, stratify=labels_actual)
        actual_data_train, actual_labels_train = torch.from_numpy(actual_data_train).to(device), torch.from_numpy(actual_labels_train).to(device)
        actual_data_test, actual_labels_test = torch.from_numpy(actual_data_test).to(device), torch.from_numpy(actual_labels_test).to(device)

        for step in range(adaptation_steps):
            adaptation_error = loss(learner(adaptation_data), adaptation_labels)
            # NOTE(review): w_loss compares the labels with themselves and is never used afterwards.
            w_loss = wasserstein_distance_torch(adaptation_labels, adaptation_labels)
            learner.adapt(adaptation_error)

        # Targets are the adapted learner's own arg-max predictions; evaluation_labels is not used.
        predictions = learner(evaluation_data)
        _, pseudo_label = torch.max(predictions, 1)
        alpha = alpha_weight_for_pseudo_label_loss(iter)
        alpha = alpha
        pseudo_label_error = alpha*(loss(predictions, pseudo_label))

        # The learner is run on actual_data_test twice: once for accuracy, once for the loss.
        actual_accuracy = accuracy(learner(actual_data_test), actual_labels_test)
        beta = (1-alpha)
        # Once alpha exceeds 1, fall back to a small fixed weight instead of a negative one.
        if beta < 0:
            beta = 0.1
        actual_loss = beta * (loss(learner(actual_data_test), actual_labels_test))

        return pseudo_label_error, actual_accuracy, actual_loss, alpha, beta
    
    
    
    
    

def softmax_entropy(x: torch.Tensor, temperature: float = 1) -> torch.Tensor:
    """Entropy of the softmax distribution over dim 1 of the logits ``x``.

    Args:
        x: logits tensor; the class axis is dim 1.
        temperature: logits are divided by this before the softmax. The
            default of 1 leaves them unchanged (the previously hard-coded
            behaviour); larger values flatten the distribution.

    Returns:
        Tensor of per-row entropies (``x`` with dim 1 reduced away), in nats.
    """
    scaled = x / temperature
    return -(scaled.softmax(1) * scaled.log_softmax(1)).sum(1)


def entropy_predictions(predictions, entropy_threshold_perc=0.9):
    """Return the indices of the lowest-entropy fraction of ``predictions``.

    Rows are ranked by ``softmax_entropy`` in ascending order. The first
    ``int((1 - entropy_threshold_perc) * N)`` indices are returned, i.e. the
    most confident rows, as a NumPy array on the CPU.

    NOTE(review): the callers in this file pass these ids to
    ``filter_data_entropy_ids``, which *deletes* them. The most confident
    samples are therefore the ones discarded. Confirm this is intended and
    not an inverted sort.

    Args:
        predictions: logits tensor with the class axis at dim 1.
        entropy_threshold_perc: the fraction of rows *not* returned.

    Returns:
        1-D NumPy array of row indices.
    """
    ascending_order = torch.sort(softmax_entropy(predictions), descending=False)[1].cpu().numpy()
    selected_fraction = 1 - entropy_threshold_perc
    count = int(selected_fraction * len(ascending_order))
    return ascending_order[:count]


def filter_data_entropy_ids(data, labels, entropy_ids):
    """Remove the samples at positions ``entropy_ids`` from ``data`` and ``labels``.

    Args:
        data: array indexed along axis 0.
        labels: label array aligned with ``data`` (flattened by ``np.delete``).
        entropy_ids: integer positions to drop.

    Returns:
        ``(data, labels)`` without the given positions.
    """
    remaining_data = np.delete(data, entropy_ids, axis=0)
    remaining_labels = np.delete(labels, entropy_ids)
    return remaining_data, remaining_labels

def fast_adapt_test_time(data, learner, loss, adaptation_steps, shots, ways, device, iter, entropy_threshold_perc):
    """One test-time adaptation episode driven purely by the learner's own pseudo labels.

    As implemented below:
      1. Pseudo-label ``data`` with the arg-max of ``learner(data)``.
      2. Delete the rows returned by ``entropy_predictions``, then drop pseudo
         labels that occur only once.
      3. Split the rest 50/50, stratified on the pseudo labels.
      4. Run one ``learner.adapt`` step on the adaptation half. Its targets
         are the learner's current arg-max on that half.
      5. Return an ``alpha``-weighted loss of the adapted learner against its
         own arg-max on the evaluation half.

    Args:
        data: input batch tensor; moved to CPU / NumPy for filtering.
        learner: cloned l2l learner returning logits; mutated by ``adapt``.
        loss: callable ``loss(logits, targets)``.
        adaptation_steps: ignored; overwritten with 1 below.
        shots, ways: unused here.
        device: device the split tensors are moved to.
        iter: outer iteration index passed to ``alpha_weight_for_pseudo_label_loss``.
        entropy_threshold_perc: forwarded to ``entropy_predictions``.

    Returns:
        ``(pseudo_label_error, evaluation_accuracy, ent_num_samples)``.
        ``evaluation_accuracy`` is a constant placeholder (0.5).
        ``ent_num_samples`` is the number of samples left after filtering.

    NOTE(review): nothing catches a failing stratified split (presumably a
    ``ValueError`` from ``train_test_split``, e.g. if filtering leaves no samples).
    """
    labels = learner(data)

    entropy_ids = entropy_predictions(labels, entropy_threshold_perc)
    data = data.cpu().numpy()

    labels = torch.argmax(labels, dim=1)
    labels = labels.cpu().numpy()

    # Removes the ids chosen by entropy_predictions (its lowest-entropy rows).
    data, labels = filter_data_entropy_ids(data, labels, entropy_ids)

    # Stratified splitting below needs >= 2 samples per pseudo label.
    data, labels = filter_data_and_labels_one_occurence(data, labels)
    ent_num_samples = len(labels)

    # Always true; presumably a leftover from a removed try/except.
    if 1==1:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(data, labels, test_size=0.5,  stratify=labels)
        adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
        evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

        # Overrides the argument: always a single inner step.
        adaptation_steps = 1
        for step in range(adaptation_steps):
            # Targets are recomputed from the current learner, not taken from adaptation_labels.
            predictions = learner(adaptation_data)
            _, pseudo_label = torch.max(predictions, 1)

            adaptation_error = loss(predictions, pseudo_label)

            learner.adapt(adaptation_error)

        predictions = learner(evaluation_data)

        _, pseudo_label = torch.max(predictions, 1)
        alpha = alpha_weight_for_pseudo_label_loss(iter)
        alpha = alpha + 0.5
        pseudo_label_error = alpha*(loss(predictions, pseudo_label))

        # Placeholder constant; no accuracy is computed here.
        evaluation_accuracy = 0.5

        return pseudo_label_error, evaluation_accuracy, ent_num_samples
    
    
    
    



def fast_adapt_test_time_vi(data, learner, loss, adaptation_steps, shots, ways, device, iter, entropy_threshold_perc, update_pseudo_label_times):
    """Test-time adaptation episode with variationally refined pseudo labels.

    ``learner`` is called as ``preds, features = learner(x)``. Its wrapped
    module exposes ``classifier``, which maps per-class feature means to
    ``(qw_mu, qw_sigma)``.

    As implemented below:
      1. Pseudo-label ``data`` with the arg-max of the first learner output.
      2. Delete the ``entropy_predictions`` rows, then drop singleton pseudo
         labels.
      3. Split the rest 50/50, stratified on the pseudo labels.
      4. Repeat ``update_pseudo_label_times`` times:
         a. Build class prototypes ``qDz`` as per-pseudo-label feature means.
            A class absent from the batch falls back to the overall mean.
         b. Obtain ``qw_mu`` from ``learner.module.classifier``.
         c. Score the features against ``qw_mu``.
         d. Draw 5 Gumbel-max samples of a class per example from those
            scores.
         e. ``adapt`` on the cross-entropy of the (repeated) predictions
            against the samples.
      5. On the evaluation half, recompute prototype scores the same way
         (without sampling). Return an ``alpha``-weighted loss of the
         predictions against their arg-max.

    Hard-coded: 7 classes and a 512-dim feature size (``view(512, 7)``).
    ``ways`` and ``shots`` are not used.

    Args:
        data: input batch tensor.
        learner: cloned l2l learner; mutated by ``adapt``.
        loss: callable ``loss(logits, targets)``; here it also receives
            ``(N, 7, 5)`` / ``(N, 5)`` inputs.
        adaptation_steps: ignored; overwritten with 1 below.
        shots, ways: unused here.
        device: device tensors are moved to.
        iter: outer iteration index passed to ``alpha_weight_for_pseudo_label_loss``.
        entropy_threshold_perc: forwarded to ``entropy_predictions``.
        update_pseudo_label_times: number of refinement / adapt rounds.

    Returns:
        ``(pseudo_label_error, evaluation_accuracy, ent_num_samples)``.
        ``evaluation_accuracy`` is a constant placeholder (0.5).

    NOTE(review): nothing catches a failing stratified split (presumably a
    ``ValueError`` from ``train_test_split``).
    """
    labels, _ = learner(data)

    entropy_ids = entropy_predictions(labels, entropy_threshold_perc)
    data = data.cpu().numpy()

    labels = torch.argmax(labels, dim=1)
    labels = labels.cpu().numpy()

    # Removes the ids chosen by entropy_predictions (its lowest-entropy rows).
    data, labels = filter_data_entropy_ids(data, labels, entropy_ids)

    # Stratified splitting below needs >= 2 samples per pseudo label.
    data, labels = filter_data_and_labels_one_occurence(data, labels)
    ent_num_samples = len(labels)

    # Always true; presumably a leftover from a removed try/except.
    if 1==1:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(data, labels, test_size=0.5,  stratify=labels)
        adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
        evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

        # Overrides the argument: always a single outer step.
        adaptation_steps = 1
        for step in range(adaptation_steps):
            for i in range(update_pseudo_label_times):

                _, features = learner(adaptation_data)
                # Per-class prototype = mean feature of that pseudo label (overall mean if absent).
                qDz = []
                for cate in range(7):
                    if cate in adaptation_labels.unique():
                        qDz.append(features[adaptation_labels==cate].mean(0,keepdim=True))
                    else:
                        qDz.append(features.mean(0,keepdim=True))
                qDz = torch.cat(qDz,0)
                qw_mu , qw_sigma = learner.module.classifier(qDz)

                # Feature-vs-class scores; hard-codes feature dim 512 and 7 classes.
                y = torch.mm(features, qw_mu.permute(1,0).contiguous().view(512, 7))

                # (N, 7) -> (N, 5, 7): five copies of each score row.
                y=  y.view(len(adaptation_labels), 7)

                y = y.unsqueeze(1)

                y = y.repeat(1,5,1)

                preds, _ = learner(adaptation_data)

                # Presumably (N, 7) logits -> (N, 7, 5), matching the 5 draws below.
                preds = preds.unsqueeze(2)

                preds = preds.repeat(1,1,5)

                # NOTE(review): max over dim 1 (the repeated axis) is meaningless and the result is unused.
                _, updated_pseudo_label = torch.max(y, 1)

                # Gumbel-max sampling: argmax(y + Gumbel noise) over the class axis -> (N, 5) class draws.
                unifo = torch.rand(y.size())
                unifo = unifo.to(device)
                samples = torch.argmax(y - torch.log(-torch.log(unifo)), dim=2)

                samples = samples.to(device)

                adaptation_error = loss(preds, samples)
                learner.adapt(adaptation_error)

        predictions, _ = learner(evaluation_data)

        # Refine evaluation targets with the same prototype scoring (no sampling).
        _, features = learner(evaluation_data)
        qDz = []
        for cate in range(7):
            if cate in evaluation_labels.unique():
                qDz.append(features[evaluation_labels==cate].mean(0,keepdim=True))
            else:
                qDz.append(features.mean(0,keepdim=True))
        qDz = torch.cat(qDz,0)
        qw_mu , qw_sigma = learner.module.classifier(qDz)

        y = torch.mm(features, qw_mu.permute(1,0).contiguous().view(512, 7))
        y=  y.view(len(evaluation_labels), 7)

        _, updated_pseudo_label = torch.max(y, 1)

        alpha = alpha_weight_for_pseudo_label_loss(iter)
        alpha = alpha + 0.5

        pseudo_label_error = alpha*(loss(predictions, updated_pseudo_label))

        # Placeholder constant; no accuracy is computed here.
        evaluation_accuracy = 0.5

        return pseudo_label_error, evaluation_accuracy, ent_num_samples



def fast_adapt_test_time_vi_normal_refinement(data, learner, loss, adaptation_steps, shots, ways, device, iter, entropy_threshold_perc):
    """Test-time adaptation episode with plain arg-max pseudo labels.

    Same flow as ``fast_adapt_test_time``, but ``learner`` returns a pair and
    only its first element (treated as logits) is used:
      1. Pseudo-label ``data``.
      2. Delete the ``entropy_predictions`` rows, then drop singleton pseudo
         labels.
      3. Split 50/50, stratified on the pseudo labels.
      4. ``adapt`` once on the adaptation half against its current arg-max.
      5. Return an ``alpha``-weighted loss on the evaluation half against
         its own arg-max.

    Args:
        data: input batch tensor.
        learner: cloned l2l learner returning ``(logits, _)``; mutated by ``adapt``.
        loss: callable ``loss(logits, targets)``.
        adaptation_steps: ignored; overwritten with 1 below.
        shots, ways: unused here.
        device: device the split tensors are moved to.
        iter: outer iteration index passed to ``alpha_weight_for_pseudo_label_loss``.
        entropy_threshold_perc: forwarded to ``entropy_predictions``.

    Returns:
        ``(pseudo_label_error, evaluation_accuracy, ent_num_samples)``.
        ``evaluation_accuracy`` is a constant placeholder (0.5).

    NOTE(review): nothing catches a failing stratified split (presumably a
    ``ValueError`` from ``train_test_split``).
    """
    labels, _ = learner(data)

    entropy_ids = entropy_predictions(labels, entropy_threshold_perc)
    data = data.cpu().numpy()

    labels = torch.argmax(labels, dim=1)
    labels = labels.cpu().numpy()

    # Removes the ids chosen by entropy_predictions (its lowest-entropy rows).
    data, labels = filter_data_entropy_ids(data, labels, entropy_ids)

    # Stratified splitting below needs >= 2 samples per pseudo label.
    data, labels = filter_data_and_labels_one_occurence(data, labels)
    ent_num_samples = len(labels)

    # Always true; presumably a leftover from a removed try/except.
    if 1==1:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(data, labels, test_size=0.5,  stratify=labels)
        adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
        evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

        # Overrides the argument: always a single inner step.
        adaptation_steps = 1
        for step in range(adaptation_steps):

            # Targets are recomputed from the current learner, not taken from adaptation_labels.
            preds, _ = learner(adaptation_data)
            _, pseudo_label = torch.max(preds, 1)
            adaptation_error = loss(preds, pseudo_label)
            learner.adapt(adaptation_error)

        predictions, _ = learner(evaluation_data)

        _, pseudo_label = torch.max(predictions, 1)
        alpha = alpha_weight_for_pseudo_label_loss(iter)
        alpha = alpha + 0.5
        pseudo_label_error = alpha*(loss(predictions, pseudo_label))

        # Placeholder constant; no accuracy is computed here.
        evaluation_accuracy = 0.5

        return pseudo_label_error, evaluation_accuracy, ent_num_samples
    
    
    
    


def eval_test_data(data,labels, learner, loss, adaptation_steps, shots, ways, device):
    """Split a labelled batch 50/50 and evaluate ``learner`` on the second half.

    Args:
        data: input batch tensor.
        labels: ground-truth label tensor aligned with ``data``.
        learner: model/learner returning logits.
        loss: callable ``loss(logits, targets)``.
        adaptation_steps: number of forward passes over the adaptation half.
            NOTE(review): no ``learner.adapt`` call is made, so these passes
            only run the model (e.g. they may update BatchNorm statistics in
            train mode).
        shots, ways: unused here.
        device: device the split tensors are moved to.

    Returns:
        ``(evaluation_error, evaluation_accuracy)``. If the stratified split
        is impossible, the failure is logged and two zero tensors are
        returned instead. This matches the arity of the success path.
    """
    data = data.cpu().numpy()
    labels = labels.cpu().numpy()

    try:
        adaptation_data, evaluation_data, adaptation_labels, evaluation_labels = train_test_split(data, labels, test_size=0.5,stratify=labels)
    except ValueError:
        # Stratification impossible (e.g. a class with a single sample): best-effort zeros.
        log_string('Error in fast_adapt, train test split')
        val = 0
        return torch.tensor(val).to(device), torch.tensor(val).to(device)

    adaptation_data, adaptation_labels = torch.from_numpy(adaptation_data).to(device), torch.from_numpy(adaptation_labels).to(device)
    evaluation_data, evaluation_labels = torch.from_numpy(evaluation_data).to(device), torch.from_numpy(evaluation_labels).to(device)

    for step in range(adaptation_steps):
        # Forward-only: the loss value is discarded and no adaptation happens.
        loss(learner(adaptation_data), adaptation_labels)

    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy


def train_meta(args):
    """Meta-adapt the global ``net`` at test time and checkpoint the best model.

    Wraps ``net`` in ``l2l.algorithms.MAML``, then for each of
    ``args.num_iterations`` iterations:

    1. Sample ``batch_size`` tasks from ``test_taskset``, which is built over
       ``all_dataset`` right after it is reset to ``'ttt'``.
    2. Adapt a clone of the learner on each task with entropy-filtered pseudo
       labels (variational refinement or plain arg-max). Back-propagate the
       scaled pseudo-label loss into ``maml``.
    3. Average the accumulated gradients over the task batch and take one
       Adam step.
    4. Evaluate the *updated* weights with ``test_meta_model``. Save them to
       ``<MODEL_DIR>/checkpoint/best_model.pth`` when the accuracy improves,
       so the stored ``acc`` matches the stored weights.

    Training stops early once the test accuracy has been below the initial
    (pre-adaptation) accuracy on more than three iterations. The count is
    cumulative and is never reset.

    Uses module-level names defined elsewhere in the file: ``net``,
    ``device``, ``WEIGHT_DECAY``, ``MODEL_DIR``, ``all_dataset``,
    ``transform_test``, ``log_string``, ``test_meta_model``,
    ``fast_adapt_test_time_vi`` and ``fast_adapt_test_time_vi_normal_refinement``.

    Args:
        args: parsed command-line namespace. Reads ``num_iterations``,
            ``lr_adam_maml``, ``lr_adam``, ``entropy_threshold_perc``,
            ``update_pseudo_label_times``, ``resume_from_checkpoint`` and
            ``variational_refinement``.
    """
    python_file_name = os.path.join(os.getcwd(), os.path.basename(__file__))
    log_string('Uploaded file: %s' % python_file_name)

    net.train()

    batch_size = 8  # tasks whose gradients are accumulated per optimizer step
    adaptation_steps = 1
    num_iterations = args.num_iterations
    shots = 5
    ways = 7
    lr_adam_maml = args.lr_adam_maml
    lr_adam = args.lr_adam
    entropy_threshold_perc = args.entropy_threshold_perc
    update_pseudo_label_times = args.update_pseudo_label_times

    maml = l2l.algorithms.MAML(net, lr=lr_adam_maml, first_order=False, allow_nograd=True)
    maml.to(device)
    print(maml)

    optimizer = torch.optim.Adam(maml.parameters(), lr=lr_adam, weight_decay=WEIGHT_DECAY)

    loss = torch.nn.CrossEntropyLoss(reduction='mean')
    checkpoint_dir = os.path.join(MODEL_DIR, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_valid_acc = 0

    if args.resume_from_checkpoint:
        if not os.path.exists(args.resume_from_checkpoint):
            log_string('No checkpoint found at {}'.format(args.resume_from_checkpoint))
            sys.exit()
        log_string('Loading checkpoint from {}'.format(args.resume_from_checkpoint))
        checkpoint = torch.load(args.resume_from_checkpoint, map_location = device)
        maml.load_state_dict(checkpoint['net'])
        best_valid_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
        log_string('Loaded checkpoint from epoch {}'.format(start_epoch))
        log_string('\n \n ===> Loaded model from checkpoint <==========')

    maml.to(device)

    # Baseline accuracy before any adaptation; used for early stopping.
    es_acc = test_meta_model(maml, -1)

    print('Number of trainable parameters: ', sum(p.numel() for p in maml.parameters() if p.requires_grad))
    all_dataset.reset('ttt', 0, transform=transform_test)

    # NOTE(review): test_meta_model resets the same all_dataset object to 'test', and
    # test_taskset wraps that object. Tasks drawn after the first evaluation therefore
    # come from whatever that reset leaves behind — confirm this is intended.
    test_dataset = l2l.data.MetaDataset(all_dataset)
    transforms_test = [
        l2l.data.transforms.NWays(test_dataset, ways),
        l2l.data.transforms.KShots(test_dataset, 2*shots),
        l2l.data.transforms.LoadData(test_dataset),
    ]
    test_taskset = l2l.data.TaskDataset(test_dataset, transforms_test, num_tasks = 999999)

    # Decided once so the startup log and the per-task dispatch always agree.
    use_variational = args.variational_refinement == True and update_pseudo_label_times != 0

    log_string('***********************************')
    if use_variational:
        log_string('Variational refinement is ON')
    else:
        log_string('Using Normal Pseudo labels')
    log_string('***********************************')

    count_epoch_acc_best = 0

    for iteration in range(num_iterations):
        t0 = time.time()

        print('Adapting test time')
        print('Testing..')

        meta_test_error = 0
        ent_num_samples_tot = 0
        for counter, (inputs, _targets, _img_names) in enumerate(test_taskset):
            learner = maml.clone()
            data = inputs.to(device)

            if use_variational:
                eval_error_test, eval_accuracy_test, ent_num_samples = fast_adapt_test_time_vi(data, learner, loss, adaptation_steps, shots, ways, device, iteration, entropy_threshold_perc, update_pseudo_label_times)
            else:
                eval_error_test, eval_accuracy_test, ent_num_samples = fast_adapt_test_time_vi_normal_refinement(data, learner, loss, adaptation_steps, shots, ways, device, iteration, entropy_threshold_perc)
            eval_error_test = 0.0001 * eval_error_test
            # Accumulates gradients into maml's parameters across the task batch.
            eval_error_test.backward()
            meta_test_error += eval_error_test.item()
            ent_num_samples_tot += ent_num_samples

            if counter == batch_size - 1:
                break

        log_string('\t Iteration %d: test error %.2f ent_num_samples %.2f' % (iteration, meta_test_error/batch_size, ent_num_samples_tot/batch_size))

        # Average accumulated gradients over the task batch. With allow_nograd=True some
        # parameters may have no gradient at all.
        # NOTE(review): only `features` and `fc` are rescaled; other sub-modules
        # (e.g. the `classifier` used by the VI path) are not — confirm this is intended.
        for sub_module in (maml.module.features, maml.module.fc):
            for p in sub_module.parameters():
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / batch_size)
        optimizer.step()
        optimizer.zero_grad()

        # Evaluate after the step so the accuracy describes the weights saved below.
        test_acc = test_meta_model(maml, iteration+1)

        t1 = time.time()
        log_string('Time taken for iteration %d: %.2f minutes' % (iteration, (t1-t0)/60))
        print('\n')

        acc = test_acc
        if acc > best_valid_acc:
            print('Saving best model... %f' % acc)
            state = {
                'net': maml.state_dict(),
                'acc': acc,
                'epoch': iteration,
            }
            best_valid_acc = acc
            torch.save(state, os.path.join(checkpoint_dir, 'best_model.pth'))

        # Cumulative (never reset) count of iterations below the initial accuracy.
        if acc < es_acc:
            count_epoch_acc_best += 1
        if count_epoch_acc_best > 3:
            log_string('*********Early stopping*********')
            log_string('Accuracy is not increasing for continosuly 3 iterations, so stopping the training')
            log_string('*********Early stopping*********')
            break


def test_meta_model(maml, iter):
    """Evaluate ``maml`` on the full ``'test'`` split and return accuracy in percent.

    Resets the shared ``all_dataset`` and ``rt_context`` to the ``'test'``
    split. Runs every batch through ``maml`` under ``torch.no_grad()``.
    Logs the mean per-batch loss and the accuracy, and writes both to
    TensorBoard (``te_writer``) at step ``iter``.

    NOTE(review): ``maml`` is not switched to ``eval()`` here. Layers such as
    BatchNorm/Dropout run in whatever mode the caller left them in
    (``train_meta`` calls ``net.train()``). Confirm this is intended.

    Args:
        maml: model whose call returns a pair; the first element is treated as logits.
        iter: global step used for the TensorBoard scalars.

    Returns:
        Test accuracy as a float percentage.

    Raises:
        RuntimeError: if the test loader yields no samples.
    """
    print('Testing on real test set...')

    all_dataset.reset('test', 0, transform=transform_test)
    testloader = torch.utils.data.DataLoader(all_dataset, batch_size=test_batch, shuffle=False, num_workers=cpu_workers, worker_init_fn=worker_init_fn)
    rt_context.reset('test', transform=transform_test)

    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets, _img_names in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = maml(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    if total == 0:
        raise RuntimeError('test_meta_model: the test loader yielded no samples')

    mean_loss = test_loss / num_batches
    test_acc = 100. * correct / total
    log_string('\t Real Test Loss %f, Acc: %f' % (mean_loss, test_acc))
    print(mean_loss)

    # Log the mean loss (previously `test_loss/batch_idx+1`, an operator-precedence bug).
    te_writer.add_scalar('te/loss', mean_loss, iter)
    te_writer.add_scalar('te/acc', test_acc, iter)

    return test_acc



        







    

    








# Iteration index at 60% of ``max_ite``; presumably a learning-rate decay milestone.
# NOTE(review): ``decay_ite`` is not read anywhere in this part of the file — confirm it is still needed.
decay_ite = [0.6*max_ite]


# Start training only when run as a script, so importing this module has no training side effect.
if __name__ == "__main__":
    train_meta(args)
