import numpy as np
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn

from tqdm import trange

def _warm_start_coef(prototypes):
    """Lay out per-class prototypes the way ``LogisticRegression`` reads ``coef_``.

    ``prototypes`` holds one row per class, ordered by ascending label (the same
    order sklearn uses for ``classes_``). With three or more classes sklearn's
    multinomial solver warm-starts from an ``(n_classes, n_features)`` matrix,
    so the rows are returned unchanged. With exactly two classes sklearn fits a
    single logistic model for the positive class ``classes_[1]`` and only uses
    the first row of ``coef_``; handing it the raw prototypes would initialise
    that model with the *negative* class's prototype. Because
    ``softmax([w0.x, w1.x])[1] == sigmoid((w1 - w0).x)``, the equivalent binary
    coefficient is ``w1 - w0``.
    """
    if prototypes.shape[0] == 2:
        return prototypes[1:] - prototypes[:1]
    return prototypes


def run(args, num_episodes, encoder, decoder, meta_ds, device):
    """Measure few-shot accuracy of logistic regression warm-started from decoder prototypes.

    For each episode a task is drawn with
    ``meta_ds.get_task(args.way, args.support, args.query)``. Support and query
    inputs go through ``args.transform(..., test=True)`` and ``encoder``. The
    support embeddings, grouped by label into a ``(way, support, dim)`` tensor,
    are passed to ``decoder`` to obtain per-class prototypes. Those prototypes
    initialise a bias-free ``LogisticRegression``, which is then fit on the
    support embeddings and evaluated on the query embeddings.

    Args:
        args: namespace providing ``transform``, ``way``, ``support`` and ``query``.
        num_episodes: number of episodes to evaluate.
        encoder: module mapping inputs to embeddings (feature axis last).
        decoder: module mapping grouped support embeddings to prototypes;
            presumably returns one ``dim``-vector per class -- TODO confirm.
        meta_ds: task sampler exposing ``get_task``.
        device: device the sampled tensors are moved to.

    Returns:
        ``(mean, ci)``: mean query accuracy in percent, and
        ``1.96 * std / sqrt(num_episodes)`` (a normal-approximation 95% half-width).

    Note:
        ``encoder`` and ``decoder`` are switched to eval mode and left in it.
    """
    transform = args.transform
    encoder.eval()
    decoder.eval()
    accuracies = []

    with torch.no_grad():
        for _ in trange(num_episodes):
            # Sample one task and move every tensor to the target device.
            batch = meta_ds.get_task(args.way, args.support, args.query)
            x_tr, y_tr, x_te, y_te = (b.to(device) for b in batch)
            x_tr, x_te = transform(x_tr, test=True), transform(x_te, test=True)

            # Embed, then group support embeddings by ascending label so that
            # prototype row i lines up with sklearn's classes_[i]. The view
            # assumes exactly args.support samples per class.
            x_tr, x_te = encoder(x_tr), encoder(x_te)
            index = torch.sort(y_tr, descending=False)[1]
            prototypes = x_tr[index].view(args.way, args.support, x_tr.shape[-1])
            prototypes = decoder(prototypes)
            x_tr, y_tr, x_te, y_te, prototypes = (
                b.detach().cpu().numpy() for b in (x_tr, y_tr, x_te, y_te, prototypes)
            )

            # Warm-start the classifier from the prototypes. No intercept is
            # fitted: the warm start supplies weights only.
            clf = LogisticRegression(verbose=0, warm_start=True, fit_intercept=False)
            clf.coef_ = _warm_start_coef(prototypes)
            clf.fit(x_tr, y_tr)

            # Query accuracy for this episode, in percent.
            y_te_pred = clf.predict(x_te)
            accuracy = np.mean(np.equal(y_te_pred, y_te).astype(float)) * 100
            accuracies.append(accuracy)

    mean = np.mean(accuracies)
    ci = 1.96 * np.std(accuracies) / float(np.sqrt(num_episodes))
    return mean, ci

