
import os
import yaml
import time
import hydra
import wandb
import random
import pprint
import argparse
import numpy as np
from tqdm import tqdm
from functools import partial
from omegaconf import DictConfig, OmegaConf
from collections import Counter
import pandas as pd
import pickle

import torch
from torch import nn
from torch import optim
from torch.autograd import grad as torch_grad
import torch.nn.functional as F

from torchvision import transforms
import learn2learn as l2l
from learn2learn.data.transforms import (
    NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels)
from utils import (
    CifarCNN, kernel_mats, kernel_mats_batch, active_dpp, active_dpp_precomp,
    active_prob_cover, active_typiclust, active_coreset, active_gmm, active_gmm_train,
    active_gmm_precomp, active_nic, active_margin, active_ent, active_random, active_vopt_total,
    accuracy, update_config, split_dict, DotDict, revert_to_dict, save_tsne_dataset_features, save_tsne_batch_features,
    pairwise_distances_logits, compute_entropy)
from config import Config
from mixture import GaussianMixture, MixRegEM
from kfda import KFDA


def adapt_maml(batch, learner, loss, adaptation_steps, train_shots, test_shots,
               ways, device):
    """Adapt ``learner`` on the support part of one task and score the rest.

    The split is purely positional: the batch is treated as ``ways``
    consecutive groups of ``train_shots + test_shots`` samples, and the first
    ``train_shots`` samples of every group form the adaptation set.
    NOTE(review): this presumes the batch is grouped by class in that layout;
    labels are never consulted for the split -- confirm against the sampler.

    Args:
        batch: 3-tuple ``(data, labels, _)``; the third element is ignored.
        learner: callable model that also exposes ``adapt(loss_value)``.
        loss: ``'l2'`` selects MSE against one-hot targets; any other value
            selects cross-entropy on the integer labels.
        adaptation_steps: number of ``learner.adapt`` calls.
        train_shots: adaptation samples taken from the start of each group.
        test_shots: remaining samples per group, used for evaluation.
        ways: number of groups (and one-hot width for ``'l2'``).
        device: device the data and labels are moved to.

    Returns:
        Tuple ``(evaluation_error, evaluation_accuracy)``.
    """
    inputs, targets, _ = batch
    inputs, targets = inputs.to(device), targets.to(device)

    # Positions of the first `train_shots` samples inside every group.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    is_support = np.zeros(inputs.size(0), dtype=bool)
    is_support[support_positions] = True
    query_mask = torch.from_numpy(~is_support)
    support_mask = torch.from_numpy(is_support)

    support_x, support_y = inputs[support_mask], targets[support_mask]
    query_x, query_y = inputs[query_mask], targets[query_mask]

    if loss == 'l2':
        criterion = nn.MSELoss(reduction='mean')
        # MSE needs dense targets, so turn the integer labels into one-hot rows.
        one_hot = torch.eye(ways).to(device)
        support_y = one_hot[support_y]
        query_y = one_hot[query_y]
    else:
        criterion = nn.CrossEntropyLoss(reduction='mean')

    # Inner loop: adapt on the support samples only.
    for _ in range(adaptation_steps):
        support_error = criterion(learner(support_x), support_y)
        learner.adapt(support_error)

    # Score the adapted learner on the held-out samples.
    query_out = learner(query_x)
    evaluation_error = criterion(query_out, query_y)
    evaluation_accuracy = accuracy(query_out, query_y)
    return evaluation_error, evaluation_accuracy



def adapt_anil(batch, learner, features, loss, adaptation_steps,
               train_shots, test_shots, ways, device):
    """Adapt ``learner`` on extracted features of one task and score the rest.

    Identical in structure to a plain inner-loop adaptation, except that the
    whole batch is first passed through ``features`` and ``learner`` only ever
    sees those outputs. The support/query split is positional: the first
    ``train_shots`` samples of each of the ``ways`` consecutive groups of
    ``train_shots + test_shots`` samples are used for adaptation.
    NOTE(review): presumes the batch is grouped by class -- confirm.

    Args:
        batch: 3-tuple ``(data, labels, _)``; the third element is ignored.
        learner: callable head exposing ``adapt(loss_value)``.
        features: callable applied once to the full batch of inputs.
        loss: ``'l2'`` for MSE on one-hot targets, otherwise cross-entropy.
        adaptation_steps: number of ``learner.adapt`` calls.
        train_shots: adaptation samples per group.
        test_shots: evaluation samples per group.
        ways: number of groups (and one-hot width for ``'l2'``).
        device: device the data and labels are moved to.

    Returns:
        Tuple ``(evaluation_error, evaluation_accuracy)``.
    """
    inputs, targets, _ = batch
    inputs, targets = inputs.to(device), targets.to(device)
    embedded = features(inputs)

    # Positions of the first `train_shots` samples inside every group.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    is_support = np.zeros(embedded.size(0), dtype=bool)
    is_support[support_positions] = True
    query_mask = torch.from_numpy(~is_support)
    support_mask = torch.from_numpy(is_support)

    support_x, support_y = embedded[support_mask], targets[support_mask]
    query_x, query_y = embedded[query_mask], targets[query_mask]

    if loss == 'l2':
        criterion = nn.MSELoss(reduction='mean')
        # MSE needs dense targets, so turn the integer labels into one-hot rows.
        one_hot = torch.eye(ways).to(device)
        support_y = one_hot[support_y]
        query_y = one_hot[query_y]
    else:
        criterion = nn.CrossEntropyLoss(reduction='mean')

    # Inner loop: adapt the head on the support features only.
    for _ in range(adaptation_steps):
        support_error = criterion(learner(support_x), support_y)
        learner.adapt(support_error)

    # Score the adapted head on the held-out features.
    query_out = learner(query_x)
    evaluation_error = criterion(query_out, query_y)
    evaluation_accuracy = accuracy(query_out, query_y)
    return evaluation_error, evaluation_accuracy


def adapt_ntk(batch, learner, loss, adaptation_steps, train_shots, test_shots,
               ways, device, time, kernel_batch_size):
    """Predict query targets in closed form from the kernels of ``learner``.

    The kernel matrices come from ``kernel_mats``; the prediction is
    ``K_test,train @ inv(K_train,train + 0.001 * I) @ Y_train`` where the
    targets are one-hot rows shifted by ``-1 / ways``. The inverse is taken in
    double precision and cast back to float.

    ``adaptation_steps`` and ``time`` are accepted for signature
    compatibility but are not used by this function.

    Args:
        batch: 3-tuple ``(data, labels, _)``; the third element is ignored.
        learner: model handed to ``kernel_mats``.
        loss: ``'l2'`` for MSE against the shifted one-hot query targets,
            otherwise cross-entropy of the predictions vs. integer labels.
        adaptation_steps: unused.
        train_shots: support samples taken from the start of each group.
        test_shots: remaining samples per group, used as queries.
        ways: number of groups / target columns.
        device: device used for data, labels and all intermediates.
        time: unused.
        kernel_batch_size: forwarded to ``kernel_mats`` as ``batch_size``.

    Returns:
        Tuple ``(valid_error, valid_accuracy)``.
    """
    def _shifted_one_hot(class_ids, num_rows):
        # (num_rows, ways) float targets: 1 - 1/ways at the label column and
        # -1/ways everywhere else.
        encoded = torch.full(
            (num_rows, ways), -1.0 / ways, dtype=torch.float32, device=device)
        encoded[torch.arange(num_rows, device=device), class_ids] = 1.0 - 1.0 / ways
        return encoded

    inputs, targets, _ = batch
    inputs, targets = inputs.to(device), targets.to(device)

    # Positional split: first `train_shots` samples of every group are support.
    # NOTE(review): presumes the batch is grouped by class -- confirm.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    is_support = np.zeros(inputs.size(0), dtype=bool)
    is_support[support_positions] = True
    query_mask = torch.from_numpy(~is_support)
    support_mask = torch.from_numpy(is_support)

    support_x, support_y = inputs[support_mask], targets[support_mask]
    query_x, query_y = inputs[query_mask], targets[query_mask]

    # Kernel blocks, then the regularised inverse of the train/train block.
    K_query_support, K_support_support = kernel_mats(
        learner, support_x, query_x, device,
        kernels='both', batch_size=kernel_batch_size)
    ridge = 0.001 * torch.eye(train_shots * ways, dtype=torch.double).to(device)
    K_support_inv = torch.inverse(K_support_support.double() + ridge).float()
    projection = torch.mm(K_query_support, K_support_inv)

    num_support = support_x.size(0)
    Y_support = _shifted_one_hot(support_y, num_support)
    num_query = query_x.size(0)
    Y_query = _shifted_one_hot(query_y, num_query)

    predictions = torch.mm(projection, Y_support)
    if loss == 'l2':
        valid_error = nn.MSELoss(reduction='mean')(predictions, Y_query)
    else:
        valid_error = nn.CrossEntropyLoss(reduction='mean')(predictions, query_y)

    # Accuracy of the arg-max prediction on the query samples.
    predicted_ids = torch.argmax(predictions, dim=1)
    valid_accuracy = torch.sum(predicted_ids == query_y).float() / num_query
    return valid_error, valid_accuracy

def adapt_antk(batch, learner, features, loss, adaptation_steps, train_shots, test_shots,
               ways, device, time, kernel_batch_size):
    """Predict query targets in closed form from a linear kernel on features.

    The batch is embedded with ``features``; the kernels are plain inner
    products of those embeddings and the prediction is
    ``K_test,train @ inv(K_train,train + 1e-8 * I) @ Y_train`` with targets
    being one-hot rows shifted by ``-1 / ways``. Kernels are inverted in
    double precision and the result is cast back to float.

    ``learner``, ``adaptation_steps``, ``time`` and ``kernel_batch_size`` are
    accepted for signature compatibility but are not used here.

    Args:
        batch: 3-tuple ``(data, labels, _)``; the third element is ignored.
        learner: unused.
        features: callable applied once to the full batch of inputs.
        loss: ``'l2'`` for MSE against the shifted one-hot query targets,
            otherwise cross-entropy of the predictions vs. integer labels.
        adaptation_steps: unused.
        train_shots: support samples taken from the start of each group.
        test_shots: remaining samples per group, used as queries.
        ways: number of groups / target columns.
        device: device used for data, labels and all intermediates.
        time: unused.
        kernel_batch_size: unused.

    Returns:
        Tuple ``(valid_error, valid_accuracy)``.
    """
    def _shifted_one_hot(class_ids, num_rows):
        # (num_rows, ways) float targets: 1 - 1/ways at the label column and
        # -1/ways everywhere else.
        encoded = torch.full(
            (num_rows, ways), -1.0 / ways, dtype=torch.float32, device=device)
        encoded[torch.arange(num_rows, device=device), class_ids] = 1.0 - 1.0 / ways
        return encoded

    inputs, targets, _ = batch
    inputs, targets = inputs.to(device), targets.to(device)
    embedded = features(inputs)

    # Positional split: first `train_shots` samples of every group are support.
    # NOTE(review): presumes the batch is grouped by class -- confirm.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    is_support = np.zeros(embedded.size(0), dtype=bool)
    is_support[support_positions] = True
    query_mask = torch.from_numpy(~is_support)
    support_mask = torch.from_numpy(is_support)

    support_x, support_y = embedded[support_mask], targets[support_mask]
    query_x, query_y = embedded[query_mask], targets[query_mask]

    # Linear (inner-product) kernels, inverted in double precision with a
    # tiny ridge on the diagonal.
    K_query_support = torch.matmul(query_x, support_x.T).double()
    K_support_support = torch.matmul(support_x, support_x.T).double()
    ridge = 1e-8 * torch.eye(train_shots * ways).to(device)
    K_support_inv = torch.inverse(K_support_support + ridge)
    projection = torch.mm(K_query_support, K_support_inv).float()

    num_support = support_x.size(0)
    Y_support = _shifted_one_hot(support_y, num_support)
    num_query = query_x.size(0)
    Y_query = _shifted_one_hot(query_y, num_query)

    predictions = torch.mm(projection, Y_support)
    if loss == 'l2':
        valid_error = nn.MSELoss(reduction='mean')(predictions, Y_query)
    else:
        valid_error = nn.CrossEntropyLoss(reduction='mean')(predictions, query_y)

    # Accuracy of the arg-max prediction on the query samples.
    predicted_ids = torch.argmax(predictions, dim=1)
    valid_accuracy = torch.sum(predicted_ids == query_y).float() / num_query
    return valid_error, valid_accuracy

def adapt_kfda(batch, learner, features, loss, adaptation_steps, train_shots, test_shots,
               ways, device, time, kernel_batch_size):
    """Fit a ``KFDA`` model on support features and score it on the queries.

    The batch is embedded with ``features`` and split positionally (first
    ``train_shots`` samples of each of the ``ways`` consecutive groups are the
    support set). A fresh ``KFDA(kernel='rbf', n_components=ways-1,
    n_classes=ways)`` is fitted on the support set and its ``predict`` output
    ``(probs, pred_labels)`` is used for both the loss and the accuracy.

    ``learner``, ``adaptation_steps``, ``time`` and ``kernel_batch_size`` are
    accepted for signature compatibility but are not used here.

    Args:
        batch: 3-tuple ``(data, labels, _)``; the third element is ignored.
        learner: unused.
        features: callable applied once to the full batch of inputs.
        loss: ``'l2'`` for MSE between ``probs`` and double-precision one-hot
            targets, otherwise cross-entropy of ``probs`` vs. integer labels.
        adaptation_steps: unused.
        train_shots: support samples taken from the start of each group.
        test_shots: remaining samples per group, used as queries.
        ways: number of groups / classes.
        device: device used for data, labels and the KFDA model.
        time: unused.
        kernel_batch_size: unused.

    Returns:
        Tuple ``(valid_error, valid_accuracy)``.
    """
    inputs, targets, _ = batch
    inputs, targets = inputs.to(device), targets.to(device)
    embedded = features(inputs)

    # Positional split: first `train_shots` samples of every group are support.
    # NOTE(review): presumes the batch is grouped by class -- confirm.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    is_support = np.zeros(embedded.size(0), dtype=bool)
    is_support[support_positions] = True
    query_mask = torch.from_numpy(~is_support)
    support_mask = torch.from_numpy(is_support)

    support_x, support_y = embedded[support_mask], targets[support_mask]
    query_x, query_y = embedded[query_mask], targets[query_mask]

    # Fit the discriminant model on the support set, then predict the queries.
    discriminant = KFDA(kernel='rbf', n_components=ways-1, n_classes=ways, device=device)
    discriminant.fit(support_x, support_y)
    probs, pred_labels = discriminant.predict(query_x)

    num_query = query_x.size(0)
    # Double-precision one-hot targets for the MSE branch.
    one_hot = torch.eye(ways).double().to(device)
    Y_query = one_hot[query_y]

    if loss == 'l2':
        valid_error = nn.MSELoss(reduction='mean')(probs, Y_query)
    else:
        valid_error = nn.CrossEntropyLoss(reduction='mean')(probs, query_y)

    # Accuracy uses the labels returned by KFDA directly.
    valid_accuracy = torch.sum(pred_labels == query_y).float() / num_query
    return valid_error, valid_accuracy


def adapt_lwantk(batch, learner, features, loss, adaptation_steps, train_shots, test_shots,
                 ways, device, time, kernel_batch_size, dataset):
    """Closed-form linear-kernel prediction followed by a query re-weighting.

    The batch is embedded with ``features`` and split positionally (first
    ``train_shots`` samples of each of the ``ways`` consecutive groups are the
    support set). The support Gram matrix is LU-factorised in double
    precision, *without* any ridge term, and used to solve for
    ``inv(K_train,train) @ Y_train``; the query prediction is
    ``K_test,train`` times that solution. Targets are one-hot rows shifted by
    ``-1 / ways``.

    A failure inside ``torch.linalg.lu_factor`` now propagates to the caller.
    Previously it was swallowed by a bare ``except`` that opened an
    interactive IPython shell and then continued with ``LU`` undefined.

    ``learner``, ``adaptation_steps``, ``time`` and ``kernel_batch_size`` are
    accepted for signature compatibility but are not used here.

    Args:
        batch: 3-tuple ``(data, labels, indices)``.
        learner: unused.
        features: callable applied once to the full batch of inputs.
        loss: ``'l2'`` for MSE against the shifted one-hot query targets,
            otherwise cross-entropy of the predictions vs. integer labels.
        adaptation_steps: unused.
        train_shots: support samples taken from the start of each group.
        test_shots: remaining samples per group, used as queries.
        ways: number of groups / target columns.
        device: device used for data, labels and all intermediates.
        time: unused.
        kernel_batch_size: unused.
        dataset: object exposing ``load_candidates(...)``.

    Returns:
        Tuple ``(valid_error, valid_accuracy)``.

    Raises:
        Whatever ``torch.linalg.lu_factor`` / ``lu_solve`` raise when the
        support Gram matrix cannot be factorised or solved.
    """
    def _shifted_one_hot(class_ids, num_rows):
        # (num_rows, ways) float targets: 1 - 1/ways at the label column and
        # -1/ways everywhere else. Vectorised to avoid one tensor-to-host
        # conversion per sample.
        encoded = torch.full(
            (num_rows, ways), -1.0 / ways, dtype=torch.float32, device=device)
        encoded[torch.arange(num_rows, device=device), class_ids] = 1.0 - 1.0 / ways
        return encoded

    data, labels, indices = batch
    data, labels = data.to(device), labels.to(device)
    data = features(data)

    # Positional split: first `train_shots` samples of every group are support.
    # NOTE(review): presumes the batch is grouped by class -- confirm.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    adaptation_indices_bool = np.zeros(len(indices), dtype=bool)
    adaptation_indices_bool[support_positions] = True
    evaluation_indices_bool = ~adaptation_indices_bool

    adaptation_data = data[adaptation_indices_bool]
    adaptation_labels = labels[adaptation_indices_bool]
    adaptation_indices = indices[adaptation_indices_bool]
    evaluation_data = data[evaluation_indices_bool]
    evaluation_labels = labels[evaluation_indices_bool]
    evaluation_indices = indices[evaluation_indices_bool].detach().cpu().numpy()

    # NOTE(review): the returned candidates are not consumed anywhere below.
    # The call is kept only in case `load_candidates` has side effects on
    # `dataset`; drop it if it is a pure lookup.
    dataset.load_candidates(
        adaptation_indices, evaluation_indices, adaptation_labels)

    N_train = adaptation_data.size(0)
    Y_train = _shifted_one_hot(adaptation_labels, N_train)
    N_test = evaluation_data.size(0)
    Y_test = _shifted_one_hot(evaluation_labels, N_test)

    # Linear kernels on the embeddings; the train/train block is solved via
    # LU in double precision (S x S), the result cast back to float (S x C).
    K_testvtrain = torch.matmul(evaluation_data, adaptation_data.T)
    K_trainvtrain = torch.matmul(adaptation_data, adaptation_data.T).double()
    LU, pivots = torch.linalg.lu_factor(K_trainvtrain)
    K_trainvtrain_inv_Y = torch.linalg.lu_solve(LU, pivots, Y_train.double()).float()

    mean_vec = torch.mm(K_testvtrain, K_trainvtrain_inv_Y)  # Q x C

    # Detached softmax weights over the support axis (Q x E x 1), applied to
    # the prediction broadcast along that axis and summed back out.
    # NOTE(review): assuming pairwise_distances_logits returns Q x E, each row
    # of weights sums to 1, so this step looks like a numerical no-op on
    # `mean_vec` -- confirm whether per-support predictions were intended.
    query_weights = F.softmax(
        pairwise_distances_logits(evaluation_data, adaptation_data), dim=-1).unsqueeze(-1).detach()
    mean_vec = torch.sum(mean_vec.unsqueeze(1) * query_weights, dim=1)

    if loss == 'l2':
        valid_error = nn.MSELoss(reduction='mean')(mean_vec, Y_test)
    else:
        valid_error = nn.CrossEntropyLoss(reduction='mean')(mean_vec, evaluation_labels)

    # Accuracy of the arg-max prediction on the query samples.
    pred_label = torch.argmax(mean_vec, dim=1)
    valid_accuracy = torch.sum(pred_label == evaluation_labels).float() / N_test
    return valid_error, valid_accuracy


def adapt_metaopt(batch, maml, features, loss, adaptation_steps, train_shots, test_shots,
                  ways, device, time, kernel_batch_size, dataset):
    """Fit ``maml`` on support embeddings via ``fit_`` and score the queries.

    ``features`` is switched to training mode, the batch is reordered by
    ascending label, embedded, and then split positionally: the first
    ``train_shots`` samples of each of the ``ways`` consecutive groups of
    ``train_shots + test_shots`` samples are the support set.

    The per-sample ``indices`` are only used for the batch length. The
    previous version also sliced them into support/query index lists, but
    those were taken from the *unsorted* ``indices`` and therefore did not
    line up with the label-sorted samples; they were unused and are removed.

    ``adaptation_steps``, ``time``, ``kernel_batch_size`` and ``dataset`` are
    accepted for signature compatibility but are not used here.

    Args:
        batch: 3-tuple ``(data, labels, indices)``.
        maml: model exposing ``fit_(x, y, ways=...)`` and callable on
            embeddings to produce logits.
        features: embedding network; ``features.train()`` is called on it.
        loss: ``'l2'`` for MSE against one-hot query targets, otherwise
            cross-entropy on the integer labels.
        adaptation_steps: unused.
        train_shots: support samples per group.
        test_shots: query samples per group.
        ways: number of groups / classes.
        device: device the data and labels are moved to.
        time: unused.
        kernel_batch_size: unused.
        dataset: unused.

    Returns:
        Tuple ``(evaluation_error, evaluation_accuracy)``.
    """
    features.train()
    data, labels, indices = batch
    data, labels = data.to(device), labels.to(device)

    # Reorder samples by label so each class occupies a contiguous run.
    # The squeeze(0) calls drop a leading singleton dimension if present.
    order = torch.sort(labels).indices
    data = data.squeeze(0)[order].squeeze(0)
    labels = labels.squeeze(0)[order].squeeze(0)
    embeddings = features(data)

    # Positional split: first `train_shots` samples of every group are support.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    support_mask = np.zeros(len(indices), dtype=bool)
    support_mask[support_positions] = True
    query_mask = ~support_mask

    adaptation_data, adaptation_labels = embeddings[support_mask], labels[support_mask]
    evaluation_data, evaluation_labels = embeddings[query_mask], labels[query_mask]

    maml.fit_(adaptation_data, adaptation_labels, ways=ways)
    logits = maml(evaluation_data)

    if loss == 'l2':
        criterion = nn.MSELoss(reduction='mean')
        # MSE needs dense targets, so turn the integer labels into one-hot rows.
        evaluation_labels = torch.eye(ways).to(device)[evaluation_labels]
    else:
        criterion = nn.CrossEntropyLoss(reduction='mean')

    # Score the fitted model on the query samples.
    evaluation_error = criterion(logits, evaluation_labels)
    evaluation_accuracy = accuracy(logits, evaluation_labels)
    return evaluation_error, evaluation_accuracy

def adapt_proto(batch, maml, features, loss, adaptation_steps, train_shots, test_shots,
                ways, device, time, kernel_batch_size, dataset):
    """Classify queries by their distance logits to per-class mean embeddings.

    The batch is reordered by ascending label, embedded with ``features`` and
    split positionally: the first ``train_shots`` samples of each of the
    ``ways`` consecutive groups of ``train_shots + test_shots`` samples are
    the support set. Support embeddings are averaged per group and the query
    logits come from ``pairwise_distances_logits(query, support_means)``.

    The per-sample ``indices`` are only used for the batch length. The
    previous version also sliced them into support/query index lists, but
    those were taken from the *unsorted* ``indices`` and therefore did not
    line up with the label-sorted samples; they were unused and are removed.

    ``maml``, ``adaptation_steps``, ``time``, ``kernel_batch_size`` and
    ``dataset`` are accepted for signature compatibility but are not used.

    Args:
        batch: 3-tuple ``(data, labels, indices)``.
        maml: unused.
        features: embedding network applied to the sorted batch.
        loss: ``'l2'`` for MSE against one-hot query targets, otherwise
            cross-entropy on the integer labels.
        adaptation_steps: unused.
        train_shots: support samples per group.
        test_shots: query samples per group.
        ways: number of groups / classes.
        device: device the data and labels are moved to.
        time: unused.
        kernel_batch_size: unused.
        dataset: unused.

    Returns:
        Tuple ``(evaluation_error, evaluation_accuracy)``.
    """
    data, labels, indices = batch
    data, labels = data.to(device), labels.to(device)

    # Reorder samples by label so each class occupies a contiguous run.
    # The squeeze(0) calls drop a leading singleton dimension if present.
    order = torch.sort(labels).indices
    data = data.squeeze(0)[order].squeeze(0)
    labels = labels.squeeze(0)[order].squeeze(0)
    embeddings = features(data)

    # Positional split: first `train_shots` samples of every group are support.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    support_mask = np.zeros(len(indices), dtype=bool)
    support_mask[support_positions] = True
    query_mask = ~support_mask

    adaptation_data = embeddings[support_mask]
    evaluation_data, evaluation_labels = embeddings[query_mask], labels[query_mask]

    # One mean embedding per group of `train_shots` support samples.
    support = adaptation_data.reshape(ways, train_shots, -1).mean(dim=1)
    logits = pairwise_distances_logits(evaluation_data, support)

    if loss == 'l2':
        criterion = nn.MSELoss(reduction='mean')
        # MSE needs dense targets, so turn the integer labels into one-hot rows.
        evaluation_labels = torch.eye(ways).to(device)[evaluation_labels]
    else:
        criterion = nn.CrossEntropyLoss(reduction='mean')

    evaluation_error = criterion(logits, evaluation_labels)
    evaluation_accuracy = accuracy(logits, evaluation_labels)
    return evaluation_error, evaluation_accuracy

def adapt_proto_semi(batch, maml, features, loss, adaptation_steps, train_shots, test_shots,
                ways, device, time, kernel_batch_size, dataset):
    """Prototype-style evaluation with ``features`` forced into training mode.

    ``features.train()`` is called first. The batch is then reordered by
    ascending label, embedded, and split positionally: the first
    ``train_shots`` samples of each of the ``ways`` consecutive groups of
    ``train_shots + test_shots`` samples are the support set. Support
    embeddings are averaged per group and the query logits come from
    ``pairwise_distances_logits(query, support_means)``. Query labels are
    cast to ``long`` before use.

    The per-sample ``indices`` are only used for the batch length. The
    previous version also sliced them into support/query index lists, but
    those were taken from the *unsorted* ``indices`` and therefore did not
    line up with the label-sorted samples; they were unused and are removed,
    together with a duplicate re-slicing of the embeddings.

    ``maml``, ``adaptation_steps``, ``time``, ``kernel_batch_size`` and
    ``dataset`` are accepted for signature compatibility but are not used.

    Args:
        batch: 3-tuple ``(data, labels, indices)``.
        maml: unused.
        features: embedding network; ``features.train()`` is called on it.
        loss: ``'l2'`` for MSE against one-hot query targets, otherwise
            cross-entropy on the integer labels.
        adaptation_steps: unused.
        train_shots: support samples per group.
        test_shots: query samples per group.
        ways: number of groups / classes.
        device: device the data and labels are moved to.
        time: unused.
        kernel_batch_size: unused.
        dataset: unused.

    Returns:
        Tuple ``(evaluation_error, evaluation_accuracy)``.
    """
    features.train()
    data, labels, indices = batch
    data, labels = data.to(device), labels.to(device)

    # Reorder samples by label so each class occupies a contiguous run.
    # The squeeze(0) calls drop a leading singleton dimension if present.
    order = torch.sort(labels).indices
    data = data.squeeze(0)[order].squeeze(0)
    labels = labels.squeeze(0)[order].squeeze(0)
    embeddings = features(data)

    # Positional split: first `train_shots` samples of every group are support.
    group_size = train_shots + test_shots
    support_positions = (
        np.arange(ways)[:, None] * group_size + np.arange(train_shots)[None, :]
    ).ravel()
    support_mask = np.zeros(len(indices), dtype=bool)
    support_mask[support_positions] = True
    query_mask = ~support_mask

    # One mean embedding per group of `train_shots` support samples.
    support = embeddings[support_mask].reshape(ways, train_shots, -1).mean(dim=1)
    query = embeddings[query_mask]
    evaluation_labels = labels[query_mask].long()

    logits = pairwise_distances_logits(query, support)
    if loss == 'l2':
        criterion = nn.MSELoss(reduction='mean')
        # MSE needs dense targets, so turn the integer labels into one-hot rows.
        evaluation_labels = torch.eye(ways).to(device)[evaluation_labels]
    else:
        criterion = nn.CrossEntropyLoss(reduction='mean')

    evaluation_error = criterion(logits, evaluation_labels)
    evaluation_accuracy = accuracy(logits, evaluation_labels)
    return evaluation_error, evaluation_accuracy




def main(config, ith_sweep=1, num_sweeps=1):
    run = wandb.init(project='active_meta_test_fixed',
                     config=config, entity="won-bae", mode="offline", reinit=True)

    # define configs
    checkpoint_path = config.checkpoint_path
    checkpoint_dir = config.checkpoint_dir

    save_al_comp = config.save_al_comp
    save_prob_cover = config.save_prob_cover
    save_entropy = config.save_entropy
    save_tsne = config.save_tsne

    training = config.model.training
    test_dataset = config.task.test_dataset
    active_config = config.active
    seed = config.task.seed

    if checkpoint_dir is not None:
        config_path = os.path.join(checkpoint_dir, 'config.yaml')
        with open(config_path) as f:
            config_yaml = yaml.safe_load(f)
        config = Config(dict(config_yaml))

    task_config, model_config, train_config = (
        config.task, config.model, config.train)

    # define task configs
    dataset, ways, train_shots, test_shots, tag = (
        task_config.dataset, task_config.ways, task_config.train_shots,
        task_config.test_shots, task_config.tag)
    if test_dataset is None:
        test_dataset = dataset

    # define model configs
    mode, backbone, channel_size = (
        model_config.mode, model_config.backbone, model_config.channel_size)

    # define train configs
    num_iterations, meta_lr, meta_batch_size, loss = (
        train_config.num_iterations, train_config.meta_lr,
        train_config.meta_batch_size, train_config.loss)
    test_meta_batch_size = 100 # 600
    val_meta_batch_size = meta_batch_size # 100

    # maml configs
    fast_lr, adaptation_steps = (
        train_config.fast_lr, train_config.adaptation_steps)
    test_adaptation_steps = adaptation_steps

    # ntk configs
    kernel_batch_size, ode_time = (
        train_config.kernel_batch_size, train_config.ode_time)

    # active configs
    train_active_strategy, per_class, active_strategy, threshold = (
        active_config.train_strategy, active_config.per_class, active_config.strategy, active_config.threshold)
    train_per_class = True
    #if dataset == 'miniimagenet':
    #    threshold = 6.4
    ##elif dataset == 'tieredimagenet':
    #elif dataset == 'fc100':
    #    threshold = 0.14
    print(f'threshold for prob cover: {threshold}')


    # creat an experiment name
    keywords = list()
    keywords.append(f'data_{dataset}')
    keywords.append(f'ways_{ways}')
    keywords.append(f'train_{train_shots}')
    keywords.append(f'test_{test_shots}')
    keywords.append(f'mode_{mode}')
    keywords.append(f'backbone_{backbone}')
    keywords.append(f'iter_{num_iterations}')
    keywords.append(f'mlr_{meta_lr}')
    keywords.append(f'mbatch_{meta_batch_size}')
    keywords.append(f'loss_{loss}')
    if mode not in ['ntk', 'antk', 'lwantk', 'kfda']:
        keywords.append(f'flr_{fast_lr}')
        keywords.append(f'asteps_{adaptation_steps}')
    keywords.append(f'active_{active_strategy}')
    keywords.append(f'per_class_{per_class}')
    keywords.append(f'atrain_{train_active_strategy}')

    exp_name = '_'.join(keywords)
    exp_name = f'{exp_name}_{tag}' if tag else exp_name
    wandb.run.name = exp_name

    if training:
        print(f'{ith_sweep}/{num_sweeps} sweeps...', flush=True)
        pp = pprint.PrettyPrinter(indent=2)
        pp.pprint(config)


    # set seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # set a device
    device = torch.device('cpu')
    if torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')
        print(f'{torch.cuda.get_device_name()}')

    # create datasets
    def _create_dataset(dataset):
        if dataset == 'miniimagenet':
            train_dataset = l2l.vision.datasets.MiniImagenet(
                root='/home/whbae/active_meta/data', mode='train', download=True)
            valid_dataset = l2l.vision.datasets.MiniImagenet(
                root='/home/whbae/active_meta/data', mode='validation', download=True)
            test_dataset = l2l.vision.datasets.MiniImagenet(
                root='/home/whbae/active_meta/data', mode='test', download=True)
        elif dataset == 'tieredimagenet':
            transform = transforms.Compose([transforms.ToTensor()])
            train_dataset = l2l.vision.datasets.TieredImagenet(
                root='/home/whbae/active_meta/data', mode='train', transform=transform, download=True)
            valid_dataset = l2l.vision.datasets.TieredImagenet(
                root='/home/whbae/active_meta/data', mode='validation', transform=transform, download=True)
            test_dataset = l2l.vision.datasets.TieredImagenet(
                root='/home/whbae/active_meta/data', mode='test', transform=transform, download=True)
        elif dataset == 'cifar_fs':
            transform = transforms.Compose([transforms.ToTensor()])
            train_dataset = l2l.vision.datasets.CIFARFS(
                root='/home/whbae/active_meta/data', mode='train', transform=transform, download=True)
            valid_dataset = l2l.vision.datasets.CIFARFS(
                root='/home/whbae/active_meta/data', mode='validation', transform=transform, download=True)
            test_dataset = l2l.vision.datasets.CIFARFS(
                root='/home/whbae/active_meta/data', mode='test', transform=transform, download=True)
        elif dataset == 'fc100':
            transform = transforms.Compose([transforms.ToTensor()])
            train_dataset = l2l.vision.datasets.FC100(
                root='/home/whbae/active_meta/data', mode='train', transform=transform, download=True)
            valid_dataset = l2l.vision.datasets.FC100(
                root='/home/whbae/active_meta/data', mode='validation', transform=transform, download=True)
            test_dataset = l2l.vision.datasets.FC100(
                root='/home/whbae/active_meta/data', mode='test', transform=transform, download=True)
        elif dataset == 'cub':
            image_size, crop_size = 128, 84
            transform = transforms.Compose([transforms.ToTensor()])
            train_transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.RandomCrop(crop_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor()
            ])
            train_dataset = l2l.vision.datasets.CUBirds200(
                root='/home/whbae/active_meta/data', mode='train', transform=train_transform, download=True)
            valid_dataset = l2l.vision.datasets.CUBirds200(
                root='/home/whbae/active_meta/data', mode='validation', transform=test_transform, download=True)
            test_dataset = l2l.vision.datasets.CUBirds200(
                root='/home/whbae/active_meta/data', mode='test', transform=test_transform, download=True)
        elif dataset == 'fgvca':
            #transform = transforms.Compose([transforms.ToTensor()])
            image_size, crop_size = 256, 224
            train_transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.RandomCrop(crop_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])
            test_transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor()
            ])
            train_dataset = l2l.vision.datasets.FGVCAircraft(
                root='/home/whbae/active_meta/data', mode='train', transform=train_transform, download=True)
            valid_dataset = l2l.vision.datasets.FGVCAircraft(
                root='/home/whbae/active_meta/data', mode='validation', transform=test_transform, download=True)
            test_dataset = l2l.vision.datasets.FGVCAircraft(
                root='/home/whbae/active_meta/data', mode='test',  transform=test_transform, download=True)
        else:
            raise NotImplementedError(f'dataset {dataset} is not implemented')

        train_dataset = l2l.data.MetaDataset(train_dataset)
        valid_dataset = l2l.data.MetaDataset(valid_dataset)
        test_dataset = l2l.data.MetaDataset(test_dataset)
        return (train_dataset, valid_dataset, test_dataset)


    (train_dataset, valid_dataset, _) = _create_dataset(dataset)
    (_, _, test_dataset) = _create_dataset(test_dataset)

    # Choose how many classes each meta-training task contains.
    # NOTE(review): `['proto' 'proto_semi']` has no comma, so Python concatenates
    # the adjacent literals into the single string 'protoproto_semi'. Neither
    # 'proto' nor 'proto_semi' is a member, so the first branch never runs and
    # train_ways always equals ways. Presumably a comma was intended -- confirm
    # before fixing, because enabling it changes the task shape seen by the
    # training loop.
    if mode in ['proto' 'proto_semi']:
        train_ways = 30 if train_shots == 1 else 20
    else:
        train_ways = ways

    # Task pipeline for meta-training: pick classes, pick samples per class,
    # load the data, then relabel.
    train_transforms = [
        NWays(train_dataset, train_ways),
        #KShots(train_dataset, 2 * train_shots),
        # NOTE(review): same missing comma as above, so the conditional below
        # always takes the `else` arm (2 * train_shots samples per class).
        KShots(train_dataset, (train_shots + test_shots)) if mode in ['proto' 'proto_semi'] else KShots(train_dataset, 2 * train_shots),
        LoadData(train_dataset),
        RemapLabels(train_dataset),
        ConsecutiveLabels(train_dataset),
    ]
    train_tasks = l2l.data.TaskDataset(train_dataset,
                                       task_transforms=train_transforms,
                                       num_tasks=-1)

    valid_transforms = [
        NWays(valid_dataset, ways),
        KShots(valid_dataset, train_shots + test_shots),
        LoadData(valid_dataset),
        ConsecutiveLabels(valid_dataset),
        RemapLabels(valid_dataset),
    ]
    valid_tasks = l2l.data.TaskDataset(valid_dataset,
                                       task_transforms=valid_transforms,
                                       num_tasks=-1)

    test_transforms = [
        NWays(test_dataset, ways),
        KShots(test_dataset, train_shots + test_shots),
        LoadData(test_dataset),
        RemapLabels(test_dataset),
        ConsecutiveLabels(test_dataset),
    ]
    test_tasks = l2l.data.TaskDataset(test_dataset,
                                      task_transforms=test_transforms,
                                      num_tasks=-1)

    # create a model
    def _create_model(mode, ways, backbone, dataset, channel_size, meta_lr, fast_lr, device):
        """Build the feature extractor, the MAML-wrapped module and its optimizer.

        Args:
            mode: Algorithm name. Selects the head that is wrapped by MAML and
                which parameters are handed to the optimizer.
            ways: Number of classes per task; used as the output dimension
                unless `mode` is one of 'ntk', 'antk', 'lwantk', 'kfda'
                (those use a single output).
            backbone: One of 'cnn4', 'resnet12', 'wrn28'.
            dataset: Training dataset name. 'cifar_fs' and 'fc100' select the
                smaller cnn4 embedding size and the CifarCNN model.
            channel_size: Overwritten for every supported backbone (32 for
                cnn4, 640 otherwise), so the passed value is effectively
                ignored.
            meta_lr: Learning rate of the Adam meta-optimizer.
            fast_lr: Learning rate passed to `l2l.algorithms.MAML`.
            device: Device the modules are moved to.

        Returns:
            Tuple `(maml, features, opt, parameters)`. `parameters` is a list
            of the tensors registered with `opt`, so callers can iterate over
            it repeatedly.

        Raises:
            NotImplementedError: If `backbone` is not supported.
        """
        if backbone == 'cnn4':
            channel_size = 32
            if dataset in ['cifar_fs', 'fc100']:
                embedding_size = channel_size * 4
            else:
                embedding_size = channel_size * 25
            features = l2l.vision.models.CNN4Backbone(hidden_size=channel_size) # 32
        elif backbone == 'resnet12':
            channel_size = 640
            embedding_size = channel_size
            features = l2l.vision.models.ResNet12Backbone() # 640
        elif backbone == 'wrn28':
            channel_size = 640
            embedding_size = channel_size
            features = l2l.vision.models.WRN28Backbone() # 640
        else:
            raise NotImplementedError(f'backbone {backbone} is not implemented')
        features.to(device)

        output_dim = 1 if mode in ['ntk', 'antk', 'lwantk', 'kfda'] else ways
        if mode in ['anil']:
            head = torch.nn.Linear(embedding_size, output_dim)
            head.to(device)
            maml = l2l.algorithms.MAML(head, lr=fast_lr, first_order=False) # maml for only head part
            parameters = list(features.parameters()) + list(maml.parameters())
        elif mode in ['antk', 'lwantk', 'kfda']:
            head = torch.nn.Linear(embedding_size, output_dim)
            head.to(device)
            maml = l2l.algorithms.MAML(head, lr=fast_lr, first_order=False) # dummy: it won't be trained
            parameters = features.parameters()
        elif mode in ['metaopt']:
            head = l2l.nn.SVClassifier(C_reg=0.1, max_iters=15, normalize=False)
            head.to(device)
            maml = l2l.algorithms.MAML(head, lr=fast_lr)
            parameters = features.parameters()
        elif mode in ['proto', 'proto_semi']:
            head = torch.nn.Linear(embedding_size, output_dim)
            head.to(device)
            maml = l2l.algorithms.MAML(head, lr=fast_lr, first_order=False) # dummy: it won't be trained
            parameters = features.parameters()
        else:
            # Remaining modes wrap a full CNN; the backbone built above is
            # replaced by the model's own feature extractor.
            if dataset in ['cifar_fs', 'fc100']:
                model = CifarCNN(output_dim, hidden_size=channel_size)
            else:
                model = l2l.vision.models.MiniImagenetCNN(output_dim, hidden_size=channel_size)
            model.to(device)
            features = model.features
            maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
            parameters = maml.parameters()

            # NOTE(review): the enclosing scope rebinds `test_dataset` to a
            # dataset object before this function is called, so this string
            # comparison probably never matches -- confirm which name holds
            # the test dataset's string identifier.
            if dataset == 'miniimagenet' and test_dataset == 'cub':
                assert checkpoint_path is not None or checkpoint_dir is not None
                # Resolve into a separate local. Assigning to `checkpoint_path`
                # here would make that name local to this function and turn
                # the read above into an UnboundLocalError.
                ckpt_path = checkpoint_path
                if ckpt_path is None and checkpoint_dir is not None:
                    test_result_path = os.path.join(checkpoint_dir, 'best_test.yaml')
                    with open(test_result_path, 'r') as f:
                        test_result = yaml.safe_load(f)
                        test_iter = test_result['iteration']
                    ckpt_path = os.path.join(checkpoint_dir, f'iter_{test_iter}.pth')

                maml.load_state_dict(torch.load(ckpt_path))
                print('loading model from {}'.format(ckpt_path))
                # Only the feature extractor is optimized in this case.
                parameters = model.features.parameters()

        # `Module.parameters()` returns a one-shot generator and the optimizer
        # consumes it. Materialize it first so the returned `parameters` can
        # still be iterated by the caller (e.g. to rescale gradients).
        parameters = list(parameters)
        opt = optim.Adam(parameters, lr=meta_lr)
        return (maml, features, opt, parameters)

    maml, features, opt, parameters = _create_model(
        mode, ways, backbone, dataset, channel_size, meta_lr, fast_lr, device)
    if mode in ['proto', 'proto_semi']:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            opt, step_size=2000, gamma=0.5)
    else:
        lr_scheduler = None

    valid_best = 0.
    iteration_best = 0
    valid_avg_best = 0
    iteration_avg_best = 1
    best_test_acc = 0.0
    best_test_confi = 0.
    meta_valid_accuracies = []
    dpp_query_func = None
    gmm_query_func = None

    folder = f'checkpoints/{exp_name}'
    print('folder', folder)
    if not os.path.exists(folder):
        os.makedirs(folder)

    if checkpoint_path is not None:
        PATH = os.path.join(folder, checkpoint_path)

    # save a config
    with open(os.path.join(folder, "config.yaml"), "w") as f:
        yaml.dump(revert_to_dict(config), f)

    # train
    if training:
        for iteration in tqdm(range(1, num_iterations+1, 1)):
            opt.zero_grad()
            meta_train_error = 0.0
            meta_train_accuracy = 0.0
            meta_valid_error = 0.0
            meta_valid_accuracy = 0.0

            #if train_active_strategy == 'dpp_precomp' and iteration % 1 == 0:
            #    del dpp_query_func
            #    with torch.no_grad():
            #        dpp_query_func = active_dpp_precomp(
            #            maml, train_tasks.dataset, features=features, device=device)
            if train_active_strategy == 'gmm_precomp' and iteration % 1 == 0:
                del gmm_query_func
                with torch.no_grad():
                    gmm_query_func = active_gmm_precomp(
                        maml, train_tasks.dataset, features=features, device=device)

            for task in range(meta_batch_size):
                batch = train_tasks.sample()

                if train_active_strategy == 'random':
                    batch = active_random(
                        maml, train_tasks.dataset, batch, ways, features=features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=train_per_class)
                elif train_active_strategy == 'dpp':
                    batch = active_dpp(
                        maml, train_tasks.dataset, batch, ways, features=features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=train_per_class)
                #elif train_active_strategy == 'dpp_precomp':
                #    batch = dpp_query_func(
                #        train_tasks.dataset, batch, ways,
                #        train_shots=train_shots, test_shots=train_shots,
                #        per_class=train_per_class)
                elif train_active_strategy == 'gmm_precomp':
                    batch = gmm_query_func(
                        train_tasks.dataset, batch, ways,
                        train_shots=train_shots, test_shots=train_shots,
                        per_class=train_per_class)
                #elif train_active_strategy == 'gmm':
                #    batch = active_gmm(
                #        maml, train_tasks.dataset, batch, ways, features=features,
                #        train_shots=train_shots, test_shots=train_shots, device=device,
                #        per_class=train_per_class)
                elif train_active_strategy == 'prob_cover':
                    batch = active_prob_cover(
                        maml, train_tasks.dataset, batch, ways, features=features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=train_per_class, p=threshold, mode='train', dataset_name=dataset)
                elif train_active_strategy == 'typiclust':
                    batch = active_typiclust(
                        maml, train_tasks.dataset, batch, ways, features=features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=train_per_class)
                elif train_active_strategy == 'coreset':
                    batch = active_coreset(
                        maml, train_tasks.dataset, batch, ways, features=features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=train_per_class)
                elif train_active_strategy == 'margin':
                    try:
                        classifier = maml.classifier
                    except:
                        classifier = maml
                    batch = active_margin(
                        classifier, train_tasks.dataset, batch, ways, features=features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=train_per_class)
                elif train_active_strategy == 'entropy':
                    try:
                        classifier = maml.classifier
                    except:
                        classifier = maml
                    batch = active_ent(
                        classifier, train_tasks.dataset, batch, ways, features=features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=train_per_class)
                else:
                    raise NotImplementedError(f'Mode: {mode} has not been implemented')


                if mode == 'maml':
                    learner = maml.clone()
                    evaluation_error, evaluation_accuracy = adapt_maml(
                        batch, learner, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device)
                    evaluation_error.backward()
                elif mode == 'anil':
                    learner = maml.clone()
                    evaluation_error, evaluation_accuracy = adapt_anil(
                        batch, learner, features, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device)
                    evaluation_error.backward()
                elif mode == 'ntk':
                    evaluation_error, evaluation_accuracy = adapt_ntk(
                        batch, maml, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device, time=0.,
                        kernel_batch_size=kernel_batch_size)
                elif mode == 'antk':
                    evaluation_error, evaluation_accuracy = adapt_antk(
                        batch, maml, features, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device, time=0.,
                        kernel_batch_size=kernel_batch_size)
                    evaluation_error.backward()
                elif mode == 'lwantk':
                    evaluation_error, evaluation_accuracy = adapt_lwantk(
                        batch, maml, features, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device, time=0.,
                        kernel_batch_size=kernel_batch_size, dataset=train_tasks.dataset)
                    evaluation_error.backward()
                elif mode == 'kfda':
                    evaluation_error, evaluation_accuracy = adapt_kfda(
                        batch, maml, features, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device, time=0.,
                        kernel_batch_size=kernel_batch_size)
                    evaluation_error.backward()
                elif mode == 'metaopt':
                    evaluation_error, evaluation_accuracy = adapt_metaopt(
                        batch, maml, features, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device, time=0.,
                        kernel_batch_size=kernel_batch_size, dataset=train_tasks.dataset)
                    evaluation_error.backward()
                elif mode == 'proto':
                    features.train()
                    evaluation_error, evaluation_accuracy = adapt_proto(
                        batch, maml, features, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=train_ways, device=device, time=0.,
                        kernel_batch_size=kernel_batch_size, dataset=train_tasks.dataset)
                    opt.zero_grad()
                    evaluation_error.backward()
                    opt.step()
                elif mode == 'proto_semi':
                    evaluation_error, evaluation_accuracy = adapt_proto_semi(
                        batch, maml, features, loss, adaptation_steps,
                        train_shots=train_shots, test_shots=train_shots,
                        ways=ways, device=device, time=0.,
                        kernel_batch_size=kernel_batch_size, dataset=train_tasks.dataset)
                    opt.zero_grad()
                    evaluation_error.backward()
                    opt.step()
                else:
                    raise NotImplementedError(f'Mode: {mode} has not been implemented')

                meta_train_error += evaluation_error
                meta_train_accuracy += evaluation_accuracy.item()

            # Average the accumulated gradients and optimize
            if mode in ['proto', 'proto_semi']:
                lr_scheduler.step()
            else:
                if mode in ['ntk']:
                    meta_train_error.backward()
                else:
                    for p in parameters:
                        p.grad.data.mul_(1.0 / meta_batch_size)
                opt.step()

            # Print some metrics
            if iteration % 100 == 0:
                print('\n', flush=True)
                print('Iteration', iteration, flush=True)
                print('Meta Train Error', meta_train_error / meta_batch_size, flush=True)
                print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size, flush=True)

                run.log({'iteration': iteration,
                         'meta_train_err': meta_train_error / meta_batch_size,
                         'meta_train_acc': meta_train_accuracy / meta_batch_size})

            if iteration % 100 == 0:
                for task in range(val_meta_batch_size):
                    # Compute meta-validation loss
                    learner = maml.clone()
                    batch = valid_tasks.sample()

                    if active_strategy == 'random':
                        batch = active_random(
                            maml, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'dpp':
                        batch = active_dpp(
                            maml, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'prob_cover':
                        batch = active_prob_cover(
                            maml, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class, p=threshold, mode='val', dataset_name=dataset)
                    elif active_strategy == 'typiclust':
                        batch = active_typiclust(
                            maml, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'coreset':
                        batch = active_coreset(
                            maml, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'gmm':
                        batch = active_gmm(
                            maml, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'nic':
                        batch = active_nic(
                            maml, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'margin':
                        try:
                            classifier = maml.classifier
                        except:
                            classifier = maml
                        batch = active_margin(
                            classifier, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'entropy':
                        try:
                            classifier = maml.classifier
                        except:
                            classifier = maml
                        batch = active_ent(
                            classifier, valid_tasks.dataset, batch, ways, features=features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    else:
                        raise NotImplementedError(f'Mode: {mode} has not been implemented')

                    if mode == 'maml':
                        evaluation_error, evaluation_accuracy = adapt_maml(
                            batch, learner, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device)
                    elif mode == 'anil':
                        evaluation_error, evaluation_accuracy = adapt_anil(
                            batch, learner, features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device)
                    elif mode == 'ntk':
                        evaluation_error, evaluation_accuracy = adapt_ntk(
                            batch, learner, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size)
                    elif mode == 'antk':
                        evaluation_error, evaluation_accuracy = adapt_antk(
                            batch, learner, features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size)
                    elif mode == 'lwantk':
                        evaluation_error, evaluation_accuracy = adapt_lwantk(
                            batch, learner, features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=valid_tasks.dataset)
                    elif mode == 'kfda':
                        evaluation_error, evaluation_accuracy = adapt_kfda(
                            batch, learner, features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size)
                    elif mode == 'metaopt':
                        evaluation_error, evaluation_accuracy = adapt_metaopt(
                            batch, learner, features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=valid_tasks.dataset)
                    elif mode == 'proto':
                        features.eval()
                        evaluation_error, evaluation_accuracy = adapt_proto(
                            batch, learner, features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=valid_tasks.dataset)
                    elif mode == 'proto_semi':
                        evaluation_error, evaluation_accuracy = adapt_proto_semi(
                            batch, learner, features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=valid_tasks.dataset)
                    else:
                        raise NotImplementedError(f'Mode: {mode} has not been implemented')

                    meta_valid_error += evaluation_error.item()
                    meta_valid_accuracy += evaluation_accuracy.item()

                print('Meta Valid Error', meta_valid_error / val_meta_batch_size)
                print('Meta Valid Accuracy', meta_valid_accuracy / val_meta_batch_size)
                meta_valid_accuracies.append(meta_valid_accuracy / val_meta_batch_size)
                valid_avg = np.mean(meta_valid_accuracies)

                #if iteration > 1000 and valid_avg > valid_avg_best:
                #    valid_avg_best = valid_avg
                #    iteration_avg_best = iteration

                # Checkpoint whenever the latest validation accuracy beats the best so far.
                if meta_valid_accuracies[-1] > valid_best:
                    valid_best = meta_valid_accuracies[-1]
                    # `valid_avg_best` is the value written to best_valid.yaml,
                    # printed and logged below, but nothing in this loop updates
                    # it any more (its old update is commented out above).
                    # Keep it in sync so the reported best is not stuck at its
                    # initial value.
                    valid_avg_best = valid_best
                    iteration_avg_best = iteration

                    PATH = os.path.join(folder, 'iter_{}.pth'.format(str(iteration_avg_best)))
                    print('save path', PATH)
                    torch.save(maml.state_dict(), PATH)
                    # These modes keep the backbone outside of `maml`, so it is saved separately.
                    if mode in ['anil', 'antk', 'lwantk', 'kfda', 'proto', 'proto_semi', 'metaopt']:
                        FEATURE_PATH = os.path.join(
                            folder, 'features_iter_{}.pth'.format(str(iteration_avg_best)))
                        torch.save(features.state_dict(), FEATURE_PATH)

                    with open(os.path.join(folder, 'best_valid.yaml'), 'w') as f:
                        f.write(yaml.dump(
                            {'iteration': iteration_avg_best, 'best_avg_valid_acc': valid_avg_best}))

                print('Best Valid Accuracy', valid_avg_best)
                print('Average Valid Accuracy', valid_avg)
                run.log({'iteration': iteration,
                         'meta_valid_err': meta_valid_error / val_meta_batch_size,
                         'meta_valid_acc': meta_valid_accuracy / val_meta_batch_size,
                         'best_valid_acc': valid_avg_best})

            if iteration > 1000 and (iteration % 10000 == 0 or ((mode in ['proto', 'proto_semi'] or dataset in ['cifar_fs', 'fc100', 'tieredimagenet']) and iteration % 1000 == 0)) :
                PATH = os.path.join(folder, 'iter_{}.pth'.format(str(iteration_avg_best)))
                test_maml, test_features, _, _ = _create_model(
                    mode, ways, backbone, dataset, channel_size, meta_lr, fast_lr, device)
                test_maml.load_state_dict(torch.load(PATH))
                print('loading model from {}'.format(PATH))

                # Reload the separately saved backbone. This list must match the
                # modes for which features are saved at checkpoint time;
                # 'metaopt' is saved there, and without it here the test would
                # run on a freshly initialized backbone.
                if mode in ['anil', 'antk', 'lwantk', 'kfda', 'proto', 'proto_semi', 'metaopt']:
                    FEATURE_PATH = os.path.join(
                        folder, 'features_iter_{}.pth'.format(str(iteration_avg_best)))
                    test_features.load_state_dict(torch.load(FEATURE_PATH))
                    test_features.to(device)
                    print('loading model from {}'.format(FEATURE_PATH))

                meta_test_error = 0.
                meta_test_accuracies = []

                for task in range(test_meta_batch_size):
                    # Compute meta-testing loss
                    learner = test_maml.clone()
                    batch = test_tasks.sample()

                    # Re-select the task's samples with the configured active
                    # strategy, using the reloaded checkpoint (test_maml /
                    # test_features) throughout.
                    if active_strategy == 'random':
                        batch = active_random(
                            test_maml, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'dpp':
                        batch = active_dpp(
                            test_maml, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'prob_cover':
                        batch = active_prob_cover(
                            test_maml, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class, p=threshold, mode='test', dataset_name=dataset)
                    elif active_strategy == 'typiclust':
                        batch = active_typiclust(
                            test_maml, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'coreset':
                        batch = active_coreset(
                            test_maml, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'nic':
                        batch = active_nic(
                            test_maml, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'gmm':
                        batch = active_gmm(
                            test_maml, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'margin':
                        # Use the checkpointed model's classifier (not the live
                        # training `maml`) so it matches test_features; fall
                        # back to the wrapper itself when there is no
                        # `classifier` attribute.
                        classifier = getattr(test_maml, 'classifier', test_maml)
                        batch = active_margin(
                            classifier, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    elif active_strategy == 'entropy':
                        # Same classifier lookup as for 'margin'.
                        classifier = getattr(test_maml, 'classifier', test_maml)
                        batch = active_ent(
                            classifier, test_tasks.dataset, batch, ways, features=test_features,
                            train_shots=train_shots, test_shots=test_shots, device=device,
                            per_class=per_class)
                    else:
                        # Report the unsupported strategy, not the training mode.
                        raise NotImplementedError(
                            f'Active strategy: {active_strategy} has not been implemented')

                    if mode == 'maml':
                        evaluation_error, evaluation_accuracy = adapt_maml(
                            batch, learner, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device)
                    elif mode == 'anil':
                        evaluation_error, evaluation_accuracy = adapt_anil(
                            batch, learner, test_features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device)
                    elif mode == 'ntk':
                        evaluation_error, evaluation_accuracy = adapt_ntk(
                            batch, learner, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size)
                    elif mode == 'antk':
                        evaluation_error, evaluation_accuracy = adapt_antk(
                            batch, learner, test_features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size)
                    elif mode == 'lwantk':
                        evaluation_error, evaluation_accuracy = adapt_lwantk(
                            batch, learner, test_features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
                    elif mode == 'kfda':
                        evaluation_error, evaluation_accuracy = adapt_kfda(
                            batch, learner, test_features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size)
                    elif mode == 'metaopt':
                        evaluation_error, evaluation_accuracy = adapt_metaopt(
                            batch, learner, test_features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
                    elif mode == 'proto':
                        evaluation_error, evaluation_accuracy = adapt_proto(
                            batch, learner, test_features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
                    elif mode == 'proto_semi':
                        evaluation_error, evaluation_accuracy = adapt_proto_semi(
                            batch, learner, test_features, loss, test_adaptation_steps,
                            train_shots=train_shots, test_shots=test_shots,
                            ways=ways, device=device, time=0.,
                            kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
                    else:
                        raise NotImplementedError(f'Mode: {mode} has not been implemented')

                    meta_test_error += evaluation_error.item()
                    meta_test_accuracies.append(evaluation_accuracy.item())

                meta_test_accuracy = np.mean(np.asarray(meta_test_accuracies))

                if meta_test_accuracy > best_test_acc:
                    best_test_acc = meta_test_accuracy
                    best_test_confi = 1.96 * np.std(meta_test_accuracies, 0) / np.sqrt(test_meta_batch_size * test_shots)

                    TEST_PATH = os.path.join(folder, 'best_test.pth')
                    print('save path', TEST_PATH)
                    torch.save(learner.state_dict(), TEST_PATH)

                    with open(os.path.join(folder, 'best_test.yaml'), 'w') as f:
                        f.write(yaml.dump(
                            {'iteration': iteration_avg_best,
                             'best_test_acc': best_test_acc.item(),
                             'best_test_confi': best_test_confi.item()}))

                print('Meta Test Error', meta_test_error / test_meta_batch_size)
                print('Meta Test Accuracy', meta_test_accuracy,
                      'Confidence', 1.96 * np.std(meta_test_accuracies, 0) / np.sqrt(test_meta_batch_size*test_shots))
                print('Best Meta Test Accuracy', best_test_acc, 'Best Meta Test Confidence:', best_test_confi)

                run.log({'iteration': iteration,
                           'meta_test_err': meta_test_error / test_meta_batch_size,
                           'meta_test_acc': meta_test_accuracy.item(),
                           'best_test_acc': best_test_acc.item(),
                           'best_test_confi': best_test_confi.item()})

            #if active_strategy == 'random' and iteration % 100000 == 0:
            #    train_num_classes = 10
            #    train_tsne_save_path = os.path.join(
            #        folder, f'train_tsne_df_iter{iteration}.csv')
            #    save_tsne_features(
            #        train_tasks.dataset, features, train_num_classes, device, train_tsne_save_path, step_size=1)
            #    val_num_classes = 10
            #    val_tsne_save_path = os.path.join(
            #        folder, f'val_tsne_df_iter{iteration}.csv')
            #    save_tsne_features(
            #        valid_tasks.dataset, features, val_num_classes, device, val_tsne_save_path, step_size=1)

    else:
        assert checkpoint_path is not None or checkpoint_dir is not None
        if checkpoint_path is None and checkpoint_dir is not None:
            test_result_path = os.path.join(checkpoint_dir, 'best_test.yaml')
            with open(test_result_path, 'r') as f:
                test_result = yaml.safe_load(f)
                test_iter = test_result['iteration']
            checkpoint_path = os.path.join(checkpoint_dir, f'iter_{test_iter}.pth')

        test_maml, test_features, _, _ = _create_model(
            mode, ways, backbone, dataset, channel_size, meta_lr, fast_lr, device)
        test_maml.load_state_dict(torch.load(checkpoint_path))
        print('loading model from {}'.format(checkpoint_path))

        if mode in ['anil', 'antk', 'lwantk', 'kfda', 'proto', 'proto_semi']:
            FEATURE_PATH = checkpoint_path.replace('/iter', '/features_iter')
            test_features.load_state_dict(torch.load(FEATURE_PATH))
            test_features.to(device)

        meta_test_error = 0.
        meta_test_accuracies = []
        entropies = []
        distinct_class_hist = np.zeros(ways)
        for task in range(test_meta_batch_size):
            # Compute meta-testing loss
            learner = test_maml.clone()
            batch = test_tasks.sample()
            orig_batch = (batch[0].clone(), batch[1].clone(), batch[2].clone())

            if active_strategy == 'random':
                batch = active_random(
                    test_maml, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class)
            elif active_strategy == 'dpp':
                batch = active_dpp(
                    test_maml, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class)
            elif active_strategy == 'prob_cover':
                batch = active_prob_cover(
                    test_maml, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class, p=threshold)
            elif active_strategy == 'typiclust':
                batch = active_typiclust(
                    test_maml, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class)
            elif active_strategy == 'coreset':
                batch = active_coreset(
                    test_maml, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class)
            elif active_strategy == 'gmm':
                batch = active_gmm(
                    test_maml, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class)
            elif active_strategy == 'margin':
                try:
                    classifier = test_maml.classifier
                except:
                    classifier = test_maml
                batch = active_margin(
                    classifier, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class)
            elif active_strategy == 'entropy':
                try:
                    classifier = test_maml.classifier
                except:
                    classifier = test_maml
                batch = active_ent(
                    classifier, test_tasks.dataset, batch, ways, features=test_features,
                    train_shots=train_shots, test_shots=test_shots, device=device,
                    per_class=per_class)
            else:
                raise NotImplementedError(f'Mode: {mode} has not been implemented')

            if mode == 'maml':
                evaluation_error, evaluation_accuracy = adapt_maml(
                    batch, learner, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device)
            elif mode == 'anil':
                evaluation_error, evaluation_accuracy = adapt_anil(
                    batch, learner, test_features, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device)
            elif mode == 'ntk':
                evaluation_error, evaluation_accuracy = adapt_ntk(
                    batch, learner, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device, time=0.,
                    kernel_batch_size=kernel_batch_size)
            elif mode == 'antk':
                evaluation_error, evaluation_accuracy = adapt_antk(
                    batch, learner, test_features, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device, time=0.,
                    kernel_batch_size=kernel_batch_size)
            elif mode == 'lwantk':
                evaluation_error, evaluation_accuracy = adapt_lwantk(
                    batch, learner, test_features, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device, time=0.,
                    kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
            elif mode == 'kfda':
                evaluation_error, evaluation_accuracy = adapt_kfda(
                    batch, learner, test_features, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device, time=0.,
                    kernel_batch_size=kernel_batch_size)
            elif mode == 'metaopt':
                evaluation_error, evaluation_accuracy = adapt_metaopt(
                    batch, learner, test_features, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device, time=0.,
                    kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
            elif mode == 'proto':
                evaluation_error, evaluation_accuracy = adapt_proto(
                    batch, learner, test_features, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device, time=0.,
                    kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
            elif mode == 'proto_semi':
                evaluation_error, evaluation_accuracy = adapt_proto_semi(
                    batch, learner, test_features, loss, test_adaptation_steps,
                    train_shots=train_shots, test_shots=test_shots,
                    ways=ways, device=device, time=0.,
                    kernel_batch_size=kernel_batch_size, dataset=test_tasks.dataset)
            else:
                raise NotImplementedError(f'Mode: {mode} has not been implemented')

            meta_test_error += evaluation_error.item()
            meta_test_accuracies.append(evaluation_accuracy.item())

            entropy, num_distinct_classes = compute_entropy(batch, ways, train_shots, test_shots)
            entropies.append(entropy)
            distinct_class_hist[num_distinct_classes-1] += 1

        meta_test_accuracy = np.mean(np.asarray(meta_test_accuracies))
        meta_test_confidence = 1.96 * np.std(meta_test_accuracies, 0) / np.sqrt(test_meta_batch_size*test_shots)
        print('Meta Test Error', meta_test_error / test_meta_batch_size)
        print('Meta Test Accuracy', meta_test_accuracy,
              'Confidence', meta_test_confidence)

        run.log({'meta_test_err': meta_test_error / test_meta_batch_size,
                 'meta_test_acc': meta_test_accuracy.item()})

        if save_al_comp:
            new_row = pd.DataFrame({
                'active_strategy': [active_strategy],
                'per_class': [per_class],
                'radius': [threshold],
                'meta_test_accuracy': [meta_test_accuracy],
                'meta_test_confidence': [meta_test_confidence]})

            df_path = os.path.join(checkpoint_dir, f'al_comp.csv')
            if os.path.exists(df_path):
                df = pd.read_csv(df_path)
                df = pd.concat([df, new_row], ignore_index=True)
            else:
                df = new_row
            df.to_csv(df_path, index=False)


        # save prob-cover performance for different radii
        if save_prob_cover:
            new_row = pd.DataFrame({
                'active_strategy': [active_strategy],
                'per_class': [per_class],
                'radius': [threshold],
                'meta_test_accuracy': [meta_test_accuracy],
                'meta_test_confidence': [meta_test_confidence]})

            df_path = os.path.join(checkpoint_dir, f'best_test_only.csv')
            if os.path.exists(df_path):
                df = pd.read_csv(df_path)
                df = pd.concat([df, new_row], ignore_index=True)
            else:
                df = new_row
            df.to_csv(df_path, index=False)

        # report average empirical entropy of queries
        if save_entropy:
            entropies = torch.tensor(entropies)
            mean_entropy = torch.mean(entropies)
            print(f'Active learning: {active_strategy}, Entropy: {mean_entropy}')
            new_row = pd.DataFrame({
                'active_strategy': [active_strategy],
                'per_class': [per_class],
                'mean_entropy': [mean_entropy.item()]})

            df_path = os.path.join(checkpoint_dir, f'entropy_seed{seed}.csv')
            if os.path.exists(df_path):
                df = pd.read_csv(df_path)
                df = pd.concat([df, new_row], ignore_index=True)
            else:
                df = new_row
            df.to_csv(df_path, index=False)

            distinct_class_hist_path = os.path.join(
                checkpoint_dir, f'distinct_class_hist_al_{active_strategy}_seed{seed}.pkl')
            with open(distinct_class_hist_path, 'wb') as f:
                pickle.dump(distinct_class_hist, f)


        if save_tsne:
            data_type = 'test'
            tasks = train_tasks if data_type == 'train' else test_tasks
            for i in range(seed):
                orig_batch = tasks.sample()

            al_batch_dict = {}
            #active_strategies = ['dpp', 'coreset', 'typiclust', 'prob_cover', 'gmm']
            active_strategies = ['gmm']
            #active_strategies = ['coreset', 'gmm']
            for active_strategy in active_strategies:
                batch = (orig_batch[0].clone(), orig_batch[1].clone(), orig_batch[2].clone())
                if active_strategy == 'random':
                    al_batch = active_random(
                        test_maml, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class)
                elif active_strategy == 'dpp':
                    al_batch = active_dpp(
                        test_maml, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class)
                elif active_strategy == 'prob_cover':
                    al_batch = active_prob_cover(
                        test_maml, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class, p=threshold)
                elif active_strategy == 'typiclust':
                    al_batch = active_typiclust(
                        test_maml, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class)
                elif active_strategy == 'coreset':
                    al_batch = active_coreset(
                        test_maml, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class)
                elif active_strategy == 'gmm':
                    al_batch = active_gmm(
                        test_maml, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class)
                elif active_strategy == 'margin':
                    try:
                        classifier = test_maml.classifier
                    except:
                        classifier = test_maml
                    al_batch = active_margin(
                        classifier, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class)
                elif active_strategy == 'entropy':
                    try:
                        classifier = test_maml.classifier
                    except:
                        classifier = test_maml
                    al_batch = active_ent(
                        classifier, tasks.dataset, batch, ways, features=test_features,
                        train_shots=train_shots, test_shots=train_shots, device=device,
                        per_class=per_class)
                else:
                    raise NotImplementedError(f'Mode: {mode} has not been implemented')

                al_batch_dict[active_strategy] = (al_batch[0].clone(), al_batch[1].clone(), al_batch[2].clone())

            tsne_save_path = os.path.join(checkpoint_dir, f'{data_type}_tsne_df.pkl')
            save_tsne_batch_features(
                tasks.dataset, orig_batch, al_batch_dict, test_features, ways, train_shots, train_shots,
                device, tsne_save_path, cand_ratio=1.0)


            #test_tsne_save_path = os.path.join(checkpoint_dir, f'test_tsne_df_{active_strategy}2.pkl')
            #save_tsne_batch_features(
            #    test_tasks.dataset, orig_batch, batch, features, ways, train_shots, test_shots,
            #    device, test_tsne_save_path, cand_ratio=1.0)

            #orig_batch=batch=train_tasks.sample()
            #learner = maml.clone()
            #evaluation_error, evaluation_accuracy = adapt_anil(
            #    batch, learner, features, loss, adaptation_steps,
            #    train_shots=train_shots, test_shots=train_shots,
            #    ways=ways, device=device)
            #evaluation_error.backward()
            #import IPython; IPython.embed()
            #opt.step()

            #train_num_classes = 10
            #train_tsne_save_path = os.path.join(checkpoint_dir, f'train_tsne_df.csv')
            #save_tsne_dataset_features(
            #    train_tasks.dataset, features, train_num_classes, device, train_tsne_save_path, step_size=1)

            #val_num_classes = 10
            #val_tsne_save_path = os.path.join(
            #    folder, f'val_tsne_df_iter{iteration}.csv')
            #save_tsne_features(
            #    valid_tasks.dataset, features, val_num_classes, device, val_tsne_save_path, step_size=1)

    # finish a run
    run.finish()
    return (maml, features, train_dataset)


@hydra.main(version_base=None, config_path="configs", config_name="template")
def run_experiments(cfg: DictConfig) -> None:
    """Launch a single run, or one run per sweep entry, from the Hydra config.

    The Hydra config is resolved into a plain container and wrapped in
    ``Config``. If it carries a truthy ``sweep_name``, the YAML file
    ``configs/sweep/<sweep_name>.yaml`` is loaded and passed to ``split_dict``;
    each element of the result is merged into a fresh copy of the base config
    and handed to ``main``. Otherwise ``main`` is called once with the base
    config.

    Args:
        cfg: Configuration object supplied by the ``hydra.main`` decorator.
    """
    # Resolve interpolations up front; missing mandatory values raise here.
    resolved = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    # Wrap in Config so merge_from_dict is available for sweep overrides.
    base_config = Config(dict(resolved))

    sweep_name = base_config.get('sweep_name', None)
    if not sweep_name:
        # No sweep requested: a single run with the base configuration.
        main(DotDict(dict(base_config)), ith_sweep=1, num_sweeps=1)
        return

    sweep_path = os.path.join(f'configs/sweep/{sweep_name}.yaml')
    with open(sweep_path, 'r') as sweep_file:
        sweep_spec = yaml.safe_load(sweep_file)

    # presumably one override dict per sweep combination — see split_dict in utils
    sweep_overrides = split_dict(sweep_spec)
    for position, overrides in enumerate(sweep_overrides, start=1):
        # Start each run from an untouched copy of the base config.
        run_config = Config(dict(base_config))
        run_config.merge_from_dict(overrides)
        # Unwrap back into a DotDict before handing it to main.
        main(DotDict(dict(run_config)),
             ith_sweep=position, num_sweeps=len(sweep_overrides))

# Script entry point. run_experiments is wrapped by @hydra.main, so it is
# called without arguments here and the decorator supplies `cfg`.
if __name__ == '__main__':
    run_experiments()

