import numpy as np

import torch
import torch.nn as nn

from utils import compute_logits
from tqdm import trange

def run(args, num_episodes, encoder, decoder, meta_ds, device):
    """Evaluate an encoder/decoder pair on randomly sampled few-shot episodes.

    For each episode a task is drawn from ``meta_ds``. Support and query
    inputs are transformed and embedded by ``encoder``. The support
    embeddings are grouped per class and passed through ``decoder`` to form
    prototypes. Queries are classified by the argmax of
    ``compute_logits(query, prototypes, alpha=1.0)``.

    Args:
        args: Config object; reads ``transform``, ``way``, ``support`` and
            ``query``.
        num_episodes: Number of episodes to sample.
        encoder: Module mapping transformed inputs to embeddings.
        decoder: Module applied to the grouped support embeddings of shape
            ``(way, support, dim)``.
        meta_ds: Object exposing ``get_task(way, support, query)``, which
            yields the four tensors ``(x_tr, y_tr, x_te, y_te)``.
        device: Device the episode tensors are moved to.

    Returns:
        ``(mean, ci)``: the mean query accuracy (in percent) over all
        episodes and the half-width of its 95% confidence interval,
        ``1.96 * std / sqrt(n_episodes)`` (population std).

    The train/eval mode of ``encoder`` and ``decoder`` is restored on exit,
    so this can safely be called in the middle of a training loop.
    """
    transform = args.transform
    # Remember each module's mode so evaluation does not leave them in eval
    # mode afterwards. Eval mode changes the behaviour of layers such as
    # dropout / batch norm, if the modules contain any.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()
    accuracies = []

    try:
        with torch.no_grad():
            for _ in trange(num_episodes):
                # Sample an episode and move it to the target device.
                batch = meta_ds.get_task(args.way, args.support, args.query)
                x_tr, y_tr, x_te, y_te = (b.to(device) for b in batch)
                x_tr = transform(x_tr, test=True)
                x_te = transform(x_te, test=True)

                # Embed both sets, then group the support embeddings by
                # label. NOTE(review): assumes labels are 0..way-1 with
                # exactly ``support`` samples each, so that after sorting,
                # prototype row i corresponds to class i -- confirm against
                # meta_ds.get_task.
                x_tr, x_te = encoder(x_tr), encoder(x_te)
                order = torch.argsort(y_tr)
                grouped = x_tr[order].view(
                    args.way, args.support, x_tr.shape[-1]
                )
                prototypes = decoder(grouped)

                # Classify each query by its highest-scoring prototype.
                logits = compute_logits(x_te, prototypes, alpha=1.0)
                y_te_pred = logits.argmax(dim=-1)
                accuracy = torch.eq(y_te_pred, y_te).float().mean().item()
                accuracies.append(accuracy * 100)
    finally:
        # Restore the caller's modes even if an episode raised.
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    mean = np.mean(accuracies)
    ci = 1.96 * np.std(accuracies) / np.sqrt(len(accuracies))
    return mean, ci

