import torch

from models.wrapper import get_model
from utils import InfIterator, accuracy, share_params, get_optimizer

def _report_sparsity(task_learner, logger):
    """Print mask statistics of the hyper layers and log their mean sparsity.

    For each layer in ``task_learner.hyper_layers`` the evaluation-mode mask
    is fetched and the fraction of its entries below ``1e-3`` is computed.
    Min / max / mean over all masks, the per-layer fractions and the current
    regularisation term are printed; the mean of the per-layer fractions is
    sent to ``logger.meter`` under ``("meta_train", "sparsity")``.
    """
    masks = [layer.get_mask(training=False) for layer in task_learner.hyper_layers]

    # One vectorised comparison per layer instead of calling ``.item()`` on
    # every single mask entry from a Python loop.
    sparsity = [
        (mask < 1e-3).sum().item() / mask.shape[0]
        for mask in masks
    ]

    all_masks = torch.cat(masks)
    print("min", all_masks.min().item())
    print("max", all_masks.max().item())
    print("mean", all_masks.mean().item())
    print("sparsity", sparsity)
    print("reg_term", task_learner.reg_term())

    # Average over however many hyper layers exist; the divisor used to be
    # a hard-coded 5, which is only a mean for exactly five layers.
    logger.meter(
        "meta_train", "sparsity", torch.tensor(sum(sparsity) / len(sparsity))
    )


def run(rank, model, meta_ds, hyper_opt, device, criterion, logger, FLAGS, step):
    """Run one meta-training episode on a single task and update ``model``.

    A fresh task learner is created with ``get_model`` and initialised from
    ``model``'s state dict, together with a task-local hyper optimizer that
    starts from ``hyper_opt``'s state. The learner's adapt params are trained
    with momentum SGD on the task's train loader for
    ``FLAGS.episode_train_steps`` steps. During the second half of those
    steps, each inner update is followed by a first-order meta-gradient step
    on the hyper params, computed on a batch from the task's test loader.

    After the episode, ``model`` and ``hyper_opt`` are modified in place:
    the hyper params and hyper-optimizer state are copied back from the task
    learner, and the adapt params are moved by ``FLAGS.adapt_outer_lr`` times
    the shared difference between the task learner's and ``model``'s adapt
    params.

    Args:
        rank: Passed to ``meta_ds.get_task``; when it is 0 and
            ``FLAGS.model == "sparsify"``, mask sparsity is printed/logged.
        model: Meta model exposing ``state_dict``, ``get_hyper_params`` and
            ``get_adapt_params``. Updated in place.
        meta_ds: Task source; ``get_task`` must return a
            ``(train_loader, test_loader)`` pair whose batches are 3-tuples
            (the third element is ignored here).
        hyper_opt: Optimizer for the hyper params; its state is loaded into
            the task-local optimizer and overwritten with that optimizer's
            state at the end.
        device: Device the task learner and batches are moved to.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        logger: Object with a ``meter(group, name, value)`` method.
        FLAGS: Configuration namespace (see the attributes read below).
        step: Current outer training step, used for the linear decay of the
            hyper gradient when ``FLAGS.lr_decay`` is set
            (presumably 1-based, given the ``step - 1`` -- TODO confirm).

    Returns:
        None.
    """
    # Get a task
    train_loader, test_loader = meta_ds.get_task(
        rank,
        num_classes=FLAGS.num_classes,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.num_workers
    )
    train_iter = InfIterator(train_loader)
    test_iter = InfIterator(test_loader)

    # Task learner starts as an exact copy of the meta model.
    task_learner = get_model(FLAGS, track_bn_stats=True).to(device)
    task_learner.load_state_dict(model.state_dict(), strict=True)

    # Task-local hyper optimizer, continuing from the shared optimizer state.
    task_hyper_opt = get_optimizer(
        FLAGS.hyper_opt,
        task_learner.get_hyper_params(),
        FLAGS.hyper_lr,
        momentum=FLAGS.hyper_momentum,
        weight_decay=FLAGS.hyper_weight_decay,
        nesterov=FLAGS.hyper_nesterov,
    )
    task_hyper_opt.load_state_dict(hyper_opt.state_dict())

    # Momentum buffers for the inner-loop update, one per adapt param.
    velocity = [torch.zeros_like(w) for w in task_learner.get_adapt_params()]

    if FLAGS.model == "sparsify" and rank == 0:
        _report_sparsity(task_learner, logger)

    # Loop invariants: the first inner step that is followed by a meta step,
    # and the linear decay factor applied to the hyper gradient.
    meta_start = int(FLAGS.episode_train_steps / 2)
    if FLAGS.lr_decay:
        lr_scale = 1. - ((step - 1) / FLAGS.train_steps)
    else:
        lr_scale = None

    # Train on task
    for k in range(FLAGS.episode_train_steps):

        # compute inner-grad
        x_tr, y_tr, _ = next(train_iter)
        x_tr, y_tr = x_tr.to(device), y_tr.to(device)
        task_learner.train()
        loss_tr = criterion(task_learner(x_tr), y_tr)

        # Momentum SGD step on the adapt params, written through ``.data``
        # so the update itself is not tracked by autograd.
        grad_tr = torch.autograd.grad(loss_tr, task_learner.get_adapt_params())
        for vel, g, w in zip(velocity, grad_tr, task_learner.get_adapt_params()):
            vel.data.mul_(FLAGS.adapt_momentum).add_(g.data)
            w.data.add_(vel.data, alpha=-FLAGS.adapt_inner_lr)

        # The first half of the episode only adapts; no meta step yet.
        if k < meta_start:
            continue

        # compute first-order meta-grad
        x_te, y_te, _ = next(test_iter)
        x_te, y_te = x_te.to(device), y_te.to(device)
        task_learner.eval()
        loss_te = criterion(task_learner(x_te), y_te)
        if FLAGS.model == "sparsify":
            loss_te += task_learner.reg_term()
        meta_grad = torch.autograd.grad(loss_te, task_learner.get_hyper_params())

        # average the meta-grad over multiple tasks
        share_params(meta_grad)

        # perform the meta-grad step
        task_hyper_opt.zero_grad()
        for w, g in zip(task_learner.get_hyper_params(), meta_grad):
            # ``clamp`` returns a new tensor, so the in-place scaling below
            # never touches the shared ``meta_grad`` entries.
            hyper_grad = torch.clamp(g, -3, 3)
            if lr_scale is not None:
                hyper_grad *= lr_scale
            if w.grad is None:
                w.grad = hyper_grad
            else:
                w.grad.copy_(hyper_grad)
        task_hyper_opt.step()

    # update "model" for the next episode
    # hyper_params
    for w, fw in zip(model.get_hyper_params(), task_learner.get_hyper_params()):
        w.data.copy_(fw.data)

    # hyper opt
    hyper_opt.load_state_dict(task_hyper_opt.state_dict())

    # adapt_params: move the meta model along the shared (meta - adapted)
    # difference, scaled by the outer learning rate.
    reptile_grad = [
        w.data - fw.data
        for w, fw in zip(model.get_adapt_params(), task_learner.get_adapt_params())
    ]
    share_params(reptile_grad)
    for w, rg in zip(model.get_adapt_params(), reptile_grad):
        w.data.add_(rg.data, alpha=-FLAGS.adapt_outer_lr)
