import torch
import numpy as np
import torch.nn as nn

from models.wrapper import get_model
from utils import InfIterator, accuracy, share_params, get_optimizer

def _log_sparsity(fmodel, logger):
    """Print mask statistics of ``fmodel.hyper_layers`` and log the mean sparsity.

    A mask entry is counted as pruned when its value is below 1e-3; a layer's
    sparsity is the fraction of such entries along the mask's first dimension.
    """
    masks = [layer.get_mask(training=False) for layer in fmodel.hyper_layers]
    # Vectorised count: one host sync per layer instead of one ``.item()``
    # call per mask entry.
    sparsity = [(mask < 1e-3).sum().item() / mask.shape[0] for mask in masks]

    flat_mask = torch.cat(masks)
    print("min", flat_mask.min().item())
    print("max", flat_mask.max().item())
    print("mean", flat_mask.mean().item())
    print("sparsity", sparsity)
    print("reg_term", fmodel.reg_term())
    # Average over the real number of hyper layers rather than a fixed 5.
    logger.meter("meta_train", "sparsity", torch.tensor(sum(sparsity) / len(sparsity)))


def _train_loss(fmodel, criterion, per_sample_criterion, x_tr, y_tr):
    """Return the scalar training loss of ``fmodel`` on one batch.

    When ``per_sample_criterion`` is given (the "mwn" model) each sample's loss
    is weighted by ``fmodel.lwm`` before averaging; otherwise ``criterion`` is
    applied directly.
    """
    if per_sample_criterion is None:
        return criterion(fmodel(x_tr), y_tr)
    per_sample = per_sample_criterion(fmodel(x_tr), y_tr)
    return (fmodel.lwm(per_sample) * per_sample).mean()


def run(rank, model, meta_ds, hyper_opt, device, criterion, logger, FLAGS, step):
    """Run one meta-training episode for the task assigned to ``rank``.

    A fresh task model is initialised from ``model``, adapted for
    ``FLAGS.episode_train_steps`` momentum-SGD steps, and during the final part
    of the episode its hyper-parameters are updated after every step with a
    hyper-gradient built from ``FLAGS.num_neumann_steps`` Hessian-vector
    products.  At the end the results are written back in place:

    * ``model``'s hyper-parameters are overwritten with the task model's,
    * ``hyper_opt`` receives the task optimizer's state,
    * ``model``'s adapt-parameters move along the shared Reptile direction
      ``model - task_model`` scaled by ``FLAGS.adapt_outer_lr``.

    Args:
        rank: Worker index; forwarded to ``meta_ds.get_task``.  Only rank 0
            prints and logs sparsity statistics.
        model: Global model; read for initialisation and updated in place.
        meta_ds: Provides ``get_task`` returning ``(train_loader, test_loader)``.
        hyper_opt: Global hyper-parameter optimizer; its state is loaded into
            the task optimizer and replaced with the task optimizer's state.
        device: Device the task model and all batches are moved to.
        criterion: Loss applied to ``(logits, labels)`` returning a scalar.
        logger: Object exposing ``meter(group, name, value)``.
        FLAGS: Configuration namespace (model type, learning rates, step
            counts, ...).
        step: Current global step, used for the linear hyper-gradient decay
            when ``FLAGS.lr_decay`` is set.

    Returns:
        None.
    """
    is_mwn = FLAGS.model == "mwn"

    # Get a task
    train_loader, test_loader = meta_ds.get_task(
        rank,
        num_classes=FLAGS.num_classes,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.num_workers
    )
    train_iter = InfIterator(train_loader)
    test_iter = InfIterator(test_loader)

    # Task-local copy of the global model.
    fmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    fmodel.load_state_dict(model.state_dict(), strict=True)

    # Task-local hyper optimizer that continues from the global optimizer state.
    task_hyper_opt = get_optimizer(
        FLAGS.hyper_opt,
        fmodel.get_hyper_params(),
        FLAGS.hyper_lr,
        momentum=FLAGS.hyper_momentum,
        weight_decay=FLAGS.hyper_weight_decay,
        nesterov=FLAGS.hyper_nesterov,
    )
    task_hyper_opt.load_state_dict(hyper_opt.state_dict())

    # Momentum buffers for the hand-written inner-loop SGD.
    velocity = [torch.zeros_like(w) for w in fmodel.get_adapt_params()]

    if FLAGS.model == "sparsify" and rank == 0:
        _log_sparsity(fmodel, logger)

    K = FLAGS.episode_train_steps
    N = FLAGS.num_neumann_steps
    # e.g. K / 2 == 50 is compatible with N in {1, 2, 5, 10}
    assert (K / 2) % N == 0, "episode_train_steps / 2 must be a multiple of num_neumann_steps"

    # Loop invariants: first step index that performs a hyper update, and the
    # per-sample loss needed only by the "mwn" model.
    hyper_start = K - FLAGS.neumann_factor * int((K / 2) / N)
    per_sample_criterion = nn.CrossEntropyLoss(reduction='none').to(device) if is_mwn else None

    for k in range(K):
        # ---- inner step: fmodel takes one momentum-SGD step ----
        x_tr, y_tr, _ = next(train_iter)
        x_tr, y_tr = x_tr.to(device), y_tr.to(device)
        if is_mwn:
            # Replace the first 40% of the labels with random classes.  The
            # count is taken from the actual batch so that a short batch can
            # never yield more labels than samples.
            n_corrupt = int(y_tr.shape[0] * 0.4)
            y_rand = torch.randint(0, FLAGS.num_classes, [n_corrupt]).to(device)
            y_tr = torch.cat([y_rand, y_tr[n_corrupt:]]).to(device)

        fmodel.train()
        loss_tr = _train_loss(fmodel, criterion, per_sample_criterion, x_tr, y_tr)

        grad_tr = torch.autograd.grad(loss_tr, fmodel.get_adapt_params())
        for v, g, w in zip(velocity, grad_tr, fmodel.get_adapt_params()):
            v.data.mul_(FLAGS.adapt_momentum).add_(g.data)
            w.data.add_(v.data, alpha=-FLAGS.adapt_inner_lr)

        # Early steps only adapt the weights; hyper updates start later.
        if k < hyper_start:
            continue

        # ---- direct gradient: test loss averaged over N test batches ----
        fmodel.eval()
        direct_grad = [torch.zeros_like(w) for w in fmodel.get_hyper_params()]  # mgrad in drmad
        d_loss_te_dw = [torch.zeros_like(w) for w in fmodel.get_adapt_params()]  # alpha in drmad

        for _ in range(N):
            x_te, y_te, _ = next(test_iter)
            x_te, y_te = x_te.to(device), y_te.to(device)
            loss_te = criterion(fmodel(x_te), y_te)
            if FLAGS.model == "sparsify":
                loss_te += fmodel.reg_term()

            # The graph is needed a second time only when the direct
            # hyper-gradient is taken below, so retain it only in that case.
            d_loss_te_dw_curr = torch.autograd.grad(
                loss_te, fmodel.get_adapt_params(), retain_graph=not is_mwn
            )
            for dw, dwc in zip(d_loss_te_dw, d_loss_te_dw_curr):
                dw.data.add_(dwc, alpha=1.0 / float(N))

            # For "mwn" the direct gradient is left at zero.
            if not is_mwn:
                direct_grad_curr = torch.autograd.grad(loss_te, fmodel.get_hyper_params())
                for dg, dgc in zip(direct_grad, direct_grad_curr):
                    dg.data.add_(dgc, alpha=1.0 / float(N))

        preconditioner = torch.cat([p.clone().detach().view(-1) for p in d_loss_te_dw])

        # ---- indirect gradient via a truncated Neumann series ----
        # Re-evaluate the training loss on the last training batch at the
        # updated weights, keeping the graph for second-order derivatives.
        fmodel.train()
        loss_tr = _train_loss(fmodel, criterion, per_sample_criterion, x_tr, y_tr)

        d_loss_tr_dw = torch.autograd.grad(loss_tr, fmodel.get_adapt_params(), create_graph=True)
        d_loss_tr_dw = torch.cat([p.view(-1) for p in d_loss_tr_dw])

        # With H the Hessian of the training loss w.r.t. the adapt params and
        # v the averaged test gradient, accumulate sum_j (I - lr * H)^j v.
        accumulated_term = preconditioner.clone().detach()
        for _ in range(N):
            hess = torch.autograd.grad(
                d_loss_tr_dw, fmodel.get_adapt_params(), retain_graph=True, grad_outputs=accumulated_term
            )
            hess = torch.cat([p.contiguous().view(-1) for p in hess])
            accumulated_term -= FLAGS.adapt_inner_lr * hess
            preconditioner += accumulated_term

        preconditioner *= FLAGS.adapt_inner_lr
        # Vector-Jacobian product of the training gradient w.r.t. the hyper
        # params; this last call releases the retained graph.
        indirect_grad = torch.autograd.grad(d_loss_tr_dw, fmodel.get_hyper_params(),
                grad_outputs=preconditioner.view(-1))

        mgrad = [dg - idg for dg, idg in zip(direct_grad, indirect_grad)]

        # ---- share and apply the meta gradient ----
        share_params(mgrad)
        task_hyper_opt.zero_grad()
        for w, g in zip(fmodel.get_hyper_params(), mgrad):
            hyper_grad = torch.clamp(g, -3, 3)
            if FLAGS.lr_decay:
                # Linear decay of the hyper-gradient over training.
                hyper_grad *= 1.-((step-1) / FLAGS.train_steps)
            if w.grad is None:
                w.grad = hyper_grad
            else:
                w.grad.copy_(hyper_grad)
        task_hyper_opt.step()

    # Update the following for the next episode.
    # 1. hyper_params
    for w, fw in zip(model.get_hyper_params(), fmodel.get_hyper_params()):
        w.data.copy_(fw.data)

    # 2. hyper optimizer
    hyper_opt.load_state_dict(task_hyper_opt.state_dict())

    # 3. adapt_params: Reptile direction (global - adapted), shared, then applied.
    reptile_grad = [torch.zeros_like(w) for w in model.get_adapt_params()]
    for rg, w, fw in zip(reptile_grad, model.get_adapt_params(), fmodel.get_adapt_params()):
        rg.data.copy_(w.data - fw.data)
    share_params(reptile_grad)
    for w, rg in zip(model.get_adapt_params(), reptile_grad):
        w.data.add_(rg.data, alpha=-FLAGS.adapt_outer_lr)
