import torch
import numpy as np
import torch.nn as nn

from models.wrapper import get_model
from utils import InfIterator, accuracy, share_params, get_optimizer

def _log_sparsity(net, logger, threshold=1e-3):
    """Print mask statistics of a ``sparsify`` model and log its mean sparsity.

    Args:
        net: model exposing ``hyper_layers`` (each with ``get_mask(training=...)``)
            and ``reg_term()``.
        logger: metric sink with a ``meter(split, name, value)`` method.
        threshold: mask entries strictly below this value count as pruned.
    """
    masks = [layer.get_mask(training=False) for layer in net.hyper_layers]
    # Per-layer fraction of mask entries below the threshold.
    # NOTE(review): assumes one mask value per leading index (the per-row `.item()` form
    # this replaces required that) — confirm for new mask layouts.
    count_list = [(mask < threshold).sum().item() / mask.shape[0] for mask in masks]
    flat = torch.cat(masks)
    print("min", flat.min().item())
    print("max", flat.max().item())
    print("mean", flat.mean().item())
    print("sparsity", count_list)
    print("reg_term", net.reg_term())
    # Average over the actual number of masked layers rather than assuming a fixed count.
    logger.meter("meta_train", "sparsity", torch.tensor(sum(count_list) / len(count_list)))


def _corrupt_labels(labels, num_classes, device, ratio=0.4):
    """Replace the first ``ratio`` fraction of ``labels`` with uniformly random classes.

    The number of corrupted entries is derived from the actual batch length, so a
    short batch keeps its length (matching its inputs) and the same noise ratio.
    """
    n_corrupt = int(labels.size(0) * ratio)
    # Sampled with torch's default CPU generator, then moved to ``device``.
    noise = torch.randint(0, num_classes, [n_corrupt]).to(device)
    return torch.cat([noise, labels[n_corrupt:]]).to(device)


def _train_loss(net, x, y, criterion, FLAGS):
    """Return the training loss of ``net`` on ``(x, y)``.

    For ``FLAGS.model == "mwn"`` the per-sample cross-entropy is reweighted by
    ``net.lwm`` (applied to the per-sample losses) before averaging; otherwise
    ``criterion`` is applied directly.
    """
    if FLAGS.model == "mwn":
        per_sample = nn.functional.cross_entropy(net(x), y, reduction='none')
        return (net.lwm(per_sample) * per_sample).mean()
    return criterion(net(x), y)


def run(rank, model, meta_ds, hyper_opt, device, criterion, logger, FLAGS, step):
    """Run one meta-training episode on task ``rank`` and update ``model`` in place.

    The episode proceeds in three stages:

    1. A copy of ``model`` (``fmodel``) is adapted to the task for
       ``FLAGS.episode_train_steps`` momentum-SGD steps.
    2. The hypergradient of the task's test loss is estimated by reverse-mode
       differentiation through a linearly interpolated approximation of that
       inner trajectory, and applied with a task-local copy of ``hyper_opt``.
    3. The new hyper params, the optimizer state, and a Reptile-style update of
       the adapt params are written back into ``model`` / ``hyper_opt``.

    Args:
        rank: rank passed to ``meta_ds.get_task``. Rank 0 also prints sparsity
            statistics when ``FLAGS.model == "sparsify"``.
        model: meta-model; its hyper and adapt params are updated in place.
        meta_ds: task provider with ``get_task`` and ``get_data_by_index``.
        hyper_opt: optimizer over the hyper params. Its state seeds the task
            optimizer and is overwritten with the task optimizer's state afterwards.
        device: device the task models and batches are moved to.
        criterion: loss for test batches and (except for "mwn") training batches.
        logger: metric sink with ``meter(split, name, value)``; used only for "sparsify".
        FLAGS: configuration object holding the hyperparameters read below.
        step: current meta-training step, used by the optional linear decay of
            the hypergradient (``FLAGS.lr_decay``).

    Returns:
        None.
    """
    # Get a task
    train_loader, test_loader = meta_ds.get_task(
        rank,
        num_classes=FLAGS.num_classes,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.num_workers
    )
    train_iter = InfIterator(train_loader)
    test_iter = InfIterator(test_loader)

    # fmodel: task copy trained in the inner loop.
    # gmodel: snapshot of fmodel taken at the anchor step.
    # hmodel: walks back along the approximated trajectory in the reverse pass.
    fmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    gmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    hmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    fmodel.load_state_dict(model.state_dict(), strict=True)

    # Task-local optimizer over fmodel's hyper params, resumed from hyper_opt's state.
    task_hyper_opt = get_optimizer(
        FLAGS.hyper_opt,
        fmodel.get_hyper_params(),
        FLAGS.hyper_lr,
        momentum=FLAGS.hyper_momentum,
        weight_decay=FLAGS.hyper_weight_decay,
        nesterov=FLAGS.hyper_nesterov,
    )
    task_hyper_opt.load_state_dict(hyper_opt.state_dict())

    # Momentum buffers for the inner-loop updates, one per adapt param.
    velocity = [torch.zeros_like(w) for w in fmodel.get_adapt_params()]

    if FLAGS.model == "sparsify" and rank == 0:
        _log_sparsity(fmodel, logger)

    K = FLAGS.episode_train_steps
    anchor = 0  # inner step at which gmodel snapshots fmodel

    idxs = []
    y_tr_corrupted = []
    # Inner loop: full gradient steps on fmodel's adapt params.
    for k in range(K):
        if k == anchor:
            gmodel.load_state_dict(fmodel.state_dict())

        x_tr, y_tr, idx = next(train_iter)
        x_tr, y_tr = x_tr.to(device), y_tr.to(device)
        # Record sample indices so the reverse pass can re-fetch the same samples.
        idxs.append(idx)
        if FLAGS.model == "mwn":
            # Inject label noise; keep the noisy labels so the reverse pass uses the same targets.
            y_tr = _corrupt_labels(y_tr, FLAGS.num_classes, device)
            y_tr_corrupted.append(y_tr)

        fmodel.train()
        loss_tr = _train_loss(fmodel, x_tr, y_tr, criterion, FLAGS)

        # Momentum SGD without a graph: buf = m * buf + g;  w -= lr * buf.
        grad_tr = torch.autograd.grad(loss_tr, fmodel.get_adapt_params())
        for buf, g, w in zip(velocity, grad_tr, fmodel.get_adapt_params()):
            buf.data.mul_(FLAGS.adapt_momentum).add_(g.data)
            w.data.add_(buf.data, alpha=-FLAGS.adapt_inner_lr)

    hmodel.load_state_dict(fmodel.state_dict(), strict=True)

    # Direct term, averaged over K test batches at the adapted weights:
    #   alpha = d(test loss)/d(adapt params),  mgrad = d(test loss)/d(hyper params).
    alpha = [torch.zeros_like(w) for w in hmodel.get_adapt_params()]
    mgrad = [torch.zeros_like(w) for w in hmodel.get_hyper_params()]
    # The direct hyper term is skipped for "mwn", whose hyper params enter only
    # through the reweighted training loss in this function.
    has_direct_term = FLAGS.model != "mwn"

    hmodel.eval()
    num_test_iter = K
    for _ in range(num_test_iter):
        x_te, y_te, _ = next(test_iter)
        x_te, y_te = x_te.to(device), y_te.to(device)

        loss_te = criterion(hmodel(x_te), y_te)
        if FLAGS.model == "sparsify":
            loss_te = loss_te + hmodel.reg_term()

        # Keep the graph only if the hyper-param gradient is taken from it next.
        alpha_curr = torch.autograd.grad(
            loss_te, hmodel.get_adapt_params(), retain_graph=has_direct_term)
        for alp, alpc in zip(alpha, alpha_curr):
            alp.add_(alpc, alpha=1. / float(num_test_iter))

        if has_direct_term:
            mgrad_curr = torch.autograd.grad(loss_te, hmodel.get_hyper_params())
            for mgr, mgrc in zip(mgrad, mgrad_curr):
                mgr.add_(mgrc, alpha=1. / float(num_test_iter))

    # Reverse-mode pass through an approximated trajectory: the weights at step k are
    # placed on the straight line between the anchor snapshot (gmodel) and the final
    # weights. Plain SGD with step adapt_inner_lr is modelled here; the forward pass's
    # adapt_momentum is not taken into account.
    hmodel.train()
    for k in range(K - 1, anchor - 1, -1):
        # Move hmodel from w_{k+1} to w_k = g + (k - anchor) / (K - anchor) * (w_K - g).
        for hap, gap in zip(hmodel.get_adapt_params(), gmodel.get_adapt_params()):
            hap.data.add_(gap.data - hap.data, alpha=1. / float(k - anchor + 1))

        # Re-fetch the samples used at inner step k.
        x_tr, y_tr, _ = meta_ds.get_data_by_index(idxs[k], train=True)
        x_tr = x_tr.to(device)
        if FLAGS.model == "mwn":
            y_tr = y_tr_corrupted[k]
        else:
            y_tr = y_tr.type(torch.LongTensor).view(-1).to(device)

        loss_tr = _train_loss(hmodel, x_tr, y_tr, criterion, FLAGS)

        grad_tr = torch.autograd.grad(
            loss_tr, hmodel.get_adapt_params(), create_graph=True)

        # Mixed second-order term alpha^T d(grad_tr)/d(hyper params).
        alpha_B = torch.autograd.grad(
            grad_tr, hmodel.get_hyper_params(), grad_outputs=alpha,
            retain_graph=True)
        for mgr, alpB in zip(mgrad, alpha_B):
            mgr.add_(alpB, alpha=-FLAGS.adapt_inner_lr)

        # Hessian-vector product H @ alpha; last use of this step's graph.
        alpha_H = torch.autograd.grad(
            grad_tr, hmodel.get_adapt_params(), grad_outputs=alpha)
        for alp, alpH in zip(alpha, alpha_H):
            alp.add_(alpH, alpha=-FLAGS.adapt_inner_lr)

    # Share and apply the meta gradient.
    # NOTE(review): share_params presumably reduces the tensors across workers in place — confirm.
    share_params(mgrad)
    task_hyper_opt.zero_grad()
    for w, g in zip(fmodel.get_hyper_params(), mgrad):
        hyper_grad = torch.clamp(g, -3, 3)
        if FLAGS.lr_decay:
            # Linearly decay the hypergradient over the course of meta-training.
            hyper_grad *= 1. - ((step - 1) / FLAGS.train_steps)
        if w.grad is None:
            w.grad = hyper_grad
        else:
            w.grad.copy_(hyper_grad)
    task_hyper_opt.step()

    # Write the results back for the next episode.
    # 1. hyper params
    for w, fw in zip(model.get_hyper_params(), fmodel.get_hyper_params()):
        w.data.copy_(fw.data)

    # 2. hyper optimizer state
    hyper_opt.load_state_dict(task_hyper_opt.state_dict())

    # 3. adapt params: Reptile-style step towards the adapted weights,
    #    w <- w - adapt_outer_lr * shared(w - w_adapted).
    reptile_grad = [
        w.detach() - fw.detach()
        for w, fw in zip(model.get_adapt_params(), fmodel.get_adapt_params())
    ]
    share_params(reptile_grad)
    for w, rg in zip(model.get_adapt_params(), reptile_grad):
        w.data.add_(rg.data, alpha=-FLAGS.adapt_outer_lr)
