import torch
import torch.nn as nn

from models.wrapper import get_model
from utils import InfIterator, accuracy

# Fraction of each "mwn" training batch whose labels are replaced by
# uniformly random class indices.
_MWN_LABEL_NOISE = 0.4


def _forward(task_learner, x, model_name):
    """Run one forward pass of ``task_learner`` using the model-specific call.

    ``"sparsify"`` models are called with ``hardmask=True`` and ``"nas"``
    models with ``fix_architecture=True``; any other model is called on ``x``
    alone. Shared by the adaptation and evaluation loops so both always use
    the same forward mode.
    """
    if model_name == "sparsify":
        return task_learner(x, hardmask=True)
    if model_name == "nas":
        return task_learner(x, fix_architecture=True)
    return task_learner(x)


def run(rank, model, meta_ds, device, criterion, logger, FLAGS, meta_level):
    """Adapt a copy of ``model`` to a single task and log its metrics.

    A fresh learner is built with ``get_model(FLAGS, track_bn_stats=True)``
    and initialised from ``model``'s ``state_dict``; only that copy is
    updated. It is trained for ``FLAGS.episode_train_steps`` steps of SGD with
    momentum on the tensors returned by ``get_adapt_params()`` and then
    evaluated on the task's test split.

    Args:
        rank: Forwarded as the first argument of ``meta_ds.get_task``.
        model: Source network whose ``state_dict`` initialises the learner.
        meta_ds: Provides ``get_task(...)`` returning
            ``(train_loader, test_loader)``; both loaders yield ``(x, y, _)``.
        device: Device the learner and every batch are moved to.
        criterion: Loss ``criterion(logits, y)`` used for adaptation of
            non-"mwn" models and for the test loss.
        logger: Receives ``meter(meta_level, name, value)`` calls.
        FLAGS: Config; reads ``num_classes``, ``batch_size``, ``num_workers``,
            ``model``, ``episode_train_steps``, ``adapt_momentum`` and
            ``adapt_inner_lr``.
        meta_level: Passed through to every ``logger.meter`` call.

    Logged metrics:
        loss_tr / accuracy_tr: from the last adaptation step only; skipped
            when no adaptation step runs.
        loss_te: once per test batch.
        accuracy_te: over the whole test split.
    """
    # Get a task
    train_loader, test_loader = meta_ds.get_task(
        rank,
        num_classes=FLAGS.num_classes,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.num_workers,
    )
    train_iter = InfIterator(train_loader)

    # Task learner: a fresh copy of the meta-model; ``model`` is not modified.
    task_learner = get_model(FLAGS, track_bn_stats=True).to(device)
    task_learner.load_state_dict(model.state_dict(), strict=True)

    # One momentum buffer per adaptable tensor, aligned by position with
    # get_adapt_params().
    velocity = [torch.zeros_like(w) for w in task_learner.get_adapt_params()]

    is_mwn = FLAGS.model == "mwn"
    # Per-sample losses so "mwn" can reweight them with ``task_learner.lwm``;
    # built once instead of on every step.
    per_sample_criterion = nn.CrossEntropyLoss(reduction="none").to(device)

    # Stay None if no adaptation step runs, so the train-metric logging below
    # is skipped instead of raising UnboundLocalError.
    loss = y_pred = None

    # Train on task
    task_learner.train()
    for _ in range(FLAGS.episode_train_steps):
        x, y, _ = next(train_iter)
        x, y = x.to(device), y.to(device)
        if is_mwn:
            # Corrupt a prefix of the labels. Size it from the actual batch:
            # a partial batch may be smaller than FLAGS.batch_size, and
            # sizing from the flag would make ``y`` longer than ``x``.
            n_noisy = int(y.size(0) * _MWN_LABEL_NOISE)
            y_noisy = torch.randint(0, FLAGS.num_classes, [n_noisy]).to(device)
            y = torch.cat([y_noisy, y[n_noisy:]])

        # A single forward pass, with the model-specific flags, feeds both the
        # loss and the logged training accuracy. A second plain forward would
        # drop those flags and update BatchNorm running stats twice.
        y_pred = _forward(task_learner, x, FLAGS.model)
        if is_mwn:
            losses = per_sample_criterion(y_pred, y)
            loss = (task_learner.lwm(losses) * losses).mean()
        else:
            loss = criterion(y_pred, y)

        grads = torch.autograd.grad(loss, task_learner.get_adapt_params())

        # SGD step with momentum on the adaptable parameters only.
        for v, g, w in zip(velocity, grads, task_learner.get_adapt_params()):
            v.data.mul_(FLAGS.adapt_momentum).add_(g.data)
            w.data.add_(v.data, alpha=-FLAGS.adapt_inner_lr)

    if loss is not None:
        # Detach so the logged values do not keep the last step's autograd
        # graph alive.
        logger.meter(meta_level, "loss_tr", loss.detach())
        logger.meter(meta_level, "accuracy_tr", accuracy(y_pred.detach(), y))

    # Evaluate the adapted learner on the task's test split.
    task_learner.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y, _ in test_loader:
            x, y = x.to(device), y.to(device)

            y_pred = _forward(task_learner, x, FLAGS.model)
            loss = criterion(y_pred, y)
            if FLAGS.model == "sparsify":
                # NOTE(review): the sparsity regulariser is added to the test
                # loss only, not to the adaptation loss above; confirm intent.
                loss += task_learner.reg_term()
            logger.meter(meta_level, "loss_te", loss)

            correct += y_pred.argmax(dim=1).eq(y).sum()
            total += y.size(0)
    logger.meter(meta_level, "accuracy_te", 1.0 * correct / total)
