import numpy as np
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn

from tqdm import trange

def run(args, num_episodes, encoder, decoder, meta_ds, device):
    """Score encoder embeddings with a per-episode logistic-regression probe.

    For each of ``num_episodes`` tasks drawn from ``meta_ds``, the support
    and query inputs are transformed and encoded, a fresh
    ``LogisticRegression`` is fitted on the support embeddings, and its
    accuracy on the query embeddings is recorded as a percentage.

    Args:
        args: Configuration object; this function reads ``args.transform``
            (called as ``transform(x, test=True)``), ``args.way``,
            ``args.support`` and ``args.query``.
        num_episodes: Number of tasks to sample and evaluate.
        encoder: Model called on the transformed inputs to get embeddings.
        decoder: Only switched to eval mode here; it is not otherwise used.
        meta_ds: Task source exposing ``get_task(way, support, query)``,
            unpacked as ``(x_tr, y_tr, x_te, y_te)``.
        device: Target passed to ``.to(...)`` for every item of a task.

    Returns:
        A ``(mean, ci)`` pair: the mean per-episode accuracy in percent and
        ``1.96 * std / sqrt(num_episodes)`` (the half-width of a
        normal-approximation 95% interval on that mean).
    """
    preprocess = args.transform

    # Both networks are left in eval mode when this function returns.
    encoder.eval()
    decoder.eval()

    def as_numpy(tensor):
        # scikit-learn consumes NumPy arrays, so bring the tensor to host memory.
        return tensor.detach().cpu().numpy()

    episode_scores = []

    with torch.no_grad():
        for _ in trange(num_episodes):
            # Sample one task and move all four pieces to the target device.
            task = meta_ds.get_task(args.way, args.support, args.query)
            support_x, support_y, query_x, query_y = [t.to(device) for t in task]

            # Apply the test-time transform to both splits before encoding.
            support_x = preprocess(support_x, test=True)
            query_x = preprocess(query_x, test=True)

            # Embed support first, then query.
            support_feat = encoder(support_x)
            query_feat = encoder(query_x)

            support_feat = as_numpy(support_feat)
            support_labels = as_numpy(support_y)
            query_feat = as_numpy(query_feat)
            query_labels = as_numpy(query_y)

            # A new classifier per episode: nothing carries over between tasks.
            probe = LogisticRegression(verbose=0)
            probe.fit(support_feat, support_labels)

            # Fraction of correctly labelled query samples, as a percentage.
            predicted = probe.predict(query_feat)
            hits = np.equal(predicted, query_labels).astype(float)
            episode_scores.append(np.mean(hits) * 100)

    mean_accuracy = np.mean(episode_scores)
    spread = np.std(episode_scores)
    half_width = 1.96 * spread / float(np.sqrt(num_episodes))
    return mean_accuracy, half_width

