import torch
import torch.nn as nn
import torch.nn.functional as F

from models.wrapper import get_model
from utils import InfIterator, accuracy, share_params, get_optimizer

def _log_sparsity(task_learner, logger, threshold=1e-3):
    """Print and log mask statistics for a "sparsify" task learner.

    For every layer in ``task_learner.hyper_layers`` the mask returned by
    ``get_mask(training=False)`` is inspected. The fraction of its entries
    below ``threshold`` (relative to ``mask.shape[0]``) is that layer's
    sparsity. The mean sparsity over all layers is logged under
    ``("meta_train", "sparsity")``.

    NOTE(review): assumes each mask has one element per leading index
    (e.g. shape ``(C,)``), as the original per-index ``.item()`` loop
    required — confirm against the hyper-layer implementation.
    """
    masks = []
    sparsity = []
    for layer in task_learner.hyper_layers:
        mask = layer.get_mask(training=False)
        masks.append(mask)
        n_small = int((mask < threshold).sum().item())
        sparsity.append(n_small / mask.shape[0])
    if not masks:
        # Nothing to report (torch.cat would fail on an empty list).
        return

    all_masks = torch.cat(masks)
    print("min", all_masks.min().item())
    print("max", all_masks.max().item())
    print("mean", all_masks.mean().item())
    print("sparsity", sparsity)
    print("reg_term", task_learner.reg_term())
    # Average over the actual number of hyper layers (was hard-coded to 5).
    logger.meter("meta_train", "sparsity",
                 torch.tensor(sum(sparsity) / len(sparsity)))


def _corrupt_labels(y, num_classes, device, ratio=0.4):
    """Return ``y`` with its first ``int(len(y) * ratio)`` labels replaced.

    The replacements are drawn uniformly from ``[0, num_classes)``. The count
    is based on the actual batch length, not the configured batch size, so a
    short final batch keeps the same length as its inputs.
    """
    n_corrupt = int(y.size(0) * ratio)
    # Draw on the default generator and then move, as the original code did.
    y_rand = torch.randint(0, num_classes, [n_corrupt]).to(device)
    return torch.cat([y_rand, y[n_corrupt:]])


def _sync_to_model(model, task_learner, hyper_opt, task_hyper_opt, outer_lr):
    """Write an episode's results back into the shared ``model``/``hyper_opt``.

    - Hyper-parameters are copied verbatim from ``task_learner``.
    - The hyper-optimizer state is copied from ``task_hyper_opt``.
    - Adaptable parameters take a Reptile-style step. The difference
      ``model - task_learner`` is passed through ``share_params`` and
      subtracted, scaled by ``outer_lr``.
    """
    for w, fw in zip(model.get_hyper_params(), task_learner.get_hyper_params()):
        w.data.copy_(fw.data)

    hyper_opt.load_state_dict(task_hyper_opt.state_dict())

    reptile_grad = [
        w.data - fw.data
        for w, fw in zip(model.get_adapt_params(), task_learner.get_adapt_params())
    ]
    share_params(reptile_grad)
    for w, rg in zip(model.get_adapt_params(), reptile_grad):
        w.data.add_(rg.data, alpha=-outer_lr)


def run(rank, model, meta_ds, hyper_opt, device, criterion, logger, FLAGS, step,
        warmup_steps=50):
    """Run one meta-training episode on the task selected by ``rank``.

    The episode works in four phases:

    1. Copy ``model`` and ``hyper_opt`` into episode-local clones.
    2. Run ``FLAGS.episode_train_steps`` inner SGD-with-momentum steps on the
       adaptable parameters.
    3. After ``warmup_steps`` steps, also update the hyper-parameters on every
       step, using a meta-gradient that has a direct test-loss term plus a
       one-step second-order correction.
    4. Write the results back into ``model`` and ``hyper_opt`` (see
       ``_sync_to_model``).

    Args:
        rank: Task index passed to ``meta_ds.get_task``. Rank 0 also prints
            and logs sparsity statistics when ``FLAGS.model == "sparsify"``.
        model: Shared model; updated in place at the end of the episode.
        meta_ds: Provides ``get_task`` returning ``(train_loader, test_loader)``.
        hyper_opt: Shared hyper-optimizer; its state is replaced in place.
        device: Device used for the episode-local models and batches.
        criterion: Loss used for the test loss (and for the train loss unless
            ``FLAGS.model == "mwn"``).
        logger: Object with a ``meter(split, name, value)`` method.
        FLAGS: Configuration namespace.
        step: Outer (meta) step index, used for optional hyper-grad decay.
        warmup_steps: Number of initial inner steps during which no
            hyper-parameter update is performed (default 50, the previously
            hard-coded value).

    Returns:
        None. Results are written into ``model`` and ``hyper_opt``.
    """
    # Get a task
    train_loader, test_loader = meta_ds.get_task(
        rank,
        num_classes=FLAGS.num_classes,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.num_workers,
    )
    train_iter = InfIterator(train_loader)
    test_iter = InfIterator(test_loader)

    # Episode-local clones. ``task_learner`` holds the current weights;
    # ``fmodel`` receives the look-ahead weights after each inner step.
    task_learner = get_model(FLAGS, track_bn_stats=True).to(device)
    fmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    task_learner.load_state_dict(model.state_dict(), strict=True)

    task_hyper_opt = get_optimizer(
        FLAGS.hyper_opt,
        task_learner.get_hyper_params(),
        FLAGS.hyper_lr,
        momentum=FLAGS.hyper_momentum,
        weight_decay=FLAGS.hyper_weight_decay,
        nesterov=FLAGS.hyper_nesterov,
    )
    task_hyper_opt.load_state_dict(hyper_opt.state_dict())

    # Momentum buffers for the inner-loop SGD on the adaptable parameters.
    velocity = [torch.zeros_like(w) for w in task_learner.get_adapt_params()]

    if FLAGS.model == "sparsify" and rank == 0:
        _log_sparsity(task_learner, logger)

    is_mwn = FLAGS.model == "mwn"
    # Per-example losses, so the "mwn" weighting module can reweight them.
    per_example_criterion = nn.CrossEntropyLoss(reduction='none').to(device)
    # ``step`` is fixed for the whole episode, so the decay factor is too.
    decay = 1. - ((step - 1) / FLAGS.train_steps) if FLAGS.lr_decay else None

    # Train on task
    for k in range(FLAGS.episode_train_steps):
        in_warmup = k < warmup_steps

        x_tr, y_tr, _ = next(train_iter)
        x_tr, y_tr = x_tr.to(device), y_tr.to(device)
        if is_mwn:
            y_tr = _corrupt_labels(y_tr, FLAGS.num_classes, device)

        task_learner.train()
        if is_mwn:
            losses = per_example_criterion(task_learner(x_tr), y_tr)
            loss_tr = (task_learner.lwm(losses) * losses).mean()
        else:
            loss_tr = criterion(task_learner(x_tr), y_tr)

        # The graph of grad_tr is only differentiated again (second-order
        # term below) after warm-up; skip building it before then.
        grad_tr = torch.autograd.grad(
            loss_tr, task_learner.get_adapt_params(), create_graph=not in_warmup)

        # Look-ahead step into fmodel: v <- mom * v + g ; w <- w - lr * v.
        fmodel.load_state_dict(task_learner.state_dict(), strict=True)
        for vel, g, w in zip(velocity, grad_tr, fmodel.get_adapt_params()):
            vel.data.mul_(FLAGS.adapt_momentum).add_(g.data)
            w.data.add_(vel.data, alpha=-FLAGS.adapt_inner_lr)

        if in_warmup:
            # Warm-up: accept the inner step without touching hyper-params.
            for w, fw in zip(task_learner.get_adapt_params(),
                             fmodel.get_adapt_params()):
                w.data.copy_(fw.data)
            continue

        # Direct (first-order) meta-gradient of the test loss.
        x_te, y_te, _ = next(test_iter)
        x_te, y_te = x_te.to(device), y_te.to(device)
        fmodel.eval()
        loss_te = criterion(fmodel(x_te), y_te)
        if FLAGS.model == "sparsify":
            # Out-of-place, so the criterion's output tensor is not mutated.
            loss_te = loss_te + fmodel.reg_term()
        if is_mwn:
            # No direct term is taken for "mwn"; only the second-order
            # correction below contributes.
            meta_grad = [torch.zeros_like(w) for w in fmodel.get_hyper_params()]
        else:
            meta_grad = torch.autograd.grad(
                loss_te, fmodel.get_hyper_params(), retain_graph=True)

        # Second-order term. With w' = w - lr * (mom * v + g(lambda)) and v
        # treated as constant, dw'/dlambda = -lr * dg/dlambda. The chain rule
        # therefore adds -lr * (dg/dlambda)^T (dL_te/dw').
        alpha = torch.autograd.grad(loss_te, fmodel.get_adapt_params())
        second_order = torch.autograd.grad(
            grad_tr, task_learner.get_hyper_params(), grad_outputs=alpha)
        for mg, so in zip(meta_grad, second_order):
            mg.add_(so, alpha=-FLAGS.adapt_inner_lr)

        # average the meta-grad over multiple tasks
        share_params(meta_grad)

        # Hyper-parameter step with clipped (and optionally decayed) grads.
        task_hyper_opt.zero_grad()
        for w, g in zip(task_learner.get_hyper_params(), meta_grad):
            hyper_grad = torch.clamp(g, -3, 3)
            if decay is not None:
                hyper_grad *= decay
            if w.grad is None:
                w.grad = hyper_grad
            else:
                w.grad.copy_(hyper_grad)
        task_hyper_opt.step()

        # Accept the inner step: task_learner <- fmodel (adaptable params).
        for w, fw in zip(task_learner.get_adapt_params(),
                         fmodel.get_adapt_params()):
            w.data.copy_(fw.data)

    # Update "model" and "hyper_opt" for the next episode.
    _sync_to_model(model, task_learner, hyper_opt, task_hyper_opt,
                   FLAGS.adapt_outer_lr)
