import torch
import numpy as np
import torch.nn as nn
import math

from models.wrapper import get_model
from utils import InfIterator, accuracy, share_params, get_optimizer
from algorithms.utils import compute_sub, compute_cosdist, compute_eucdist, compute_len, \
        compute_abs_mean

def pre_run(rank, model, meta_ds, hyper_opt, device, criterion, logger, FLAGS):
    """Collect (X, y) samples comparing the distilled and exact second-order terms.

    The function
      1. adapts a copy of ``model`` (``fmodel``) for ``K_last`` momentum-SGD
         steps on the task's training stream, remembering the batch indices;
      2. averages the test-loss gradient w.r.t. the adapted weights over
         ``K_last - K_first`` test batches (``alpha``);
      3. walks ``hmodel`` backwards from the adapted weights toward the
         initial weights (held by ``gmodel``) and, at every step, accumulates
         ``-lr * alpha^T d(grad_tr)/d(hyper)`` into ``mgrad`` while also
         propagating ``alpha`` through ``-lr * alpha^T d(grad_tr)/d(adapt)``;
      4. at every backward step, evaluates the single-shot approximation on a
         weighted-average model / mixed batch and records the least-squares
         scalar ``<approx, mgrad> / <approx, approx>``.

    Args:
        rank: Index forwarded to ``meta_ds.get_task``.
        model: Source of the initial weights (read via ``state_dict()`` only).
        meta_ds: Dataset object providing ``get_task`` and
            ``get_data_by_index``.
        hyper_opt: Not used by this function.
        device: Torch device that models and batches are moved to.
        criterion: Loss callable applied as ``criterion(logits, targets)``.
        logger: Not used by this function.
        FLAGS: Configuration namespace. Reads ``num_classes``, ``batch_size``,
            ``num_workers``, ``gamma``, ``adapt_momentum`` and
            ``adapt_inner_lr`` here; it is also passed to ``get_model``.

    Returns:
        dict with
          * ``"X"``: numpy array of ``sum_{j<i} gamma**j`` for
            ``i = 1 .. K_last - K_first``;
          * ``"y"``: numpy array built from the matching per-step scalar
            ratios described in (4).
        NOTE(review): presumably consumed to fit the ``coeff`` passed to
        ``run`` -- confirm against the caller.
    """

    # Get a task
    train_loader, test_loader = meta_ds.get_task(
        rank,
        num_classes=FLAGS.num_classes,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.num_workers
    )
    train_iter = InfIterator(train_loader)
    test_iter = InfIterator(test_loader)

    # fmodel: adapted forward, gmodel: keeps the initial weights,
    # hmodel: traverses the (approximated) trajectory backwards
    fmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    gmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    hmodel = get_model(FLAGS, track_bn_stats=True).to(device)

    fmodel.load_state_dict(model.state_dict(), strict=True)
    gmodel.load_state_dict(model.state_dict(), strict=True)

    velocity = [ torch.zeros_like(w) for w in fmodel.get_adapt_params() ]
    K_last, K_first = 150, 50
    n_back = K_last - K_first  # number of backward steps / test batches
    gamma = FLAGS.gamma
    batchsz = FLAGS.batch_size

    # adapt K steps with fmodel (SGD with momentum, applied by hand on .data)
    idxs = []
    fmodel.train()
    for k in range(K_last):
        x_tr, y_tr, idx = next(train_iter)
        x_tr, y_tr = x_tr.to(device), y_tr.to(device)
        idxs.append(idx)  # remembered so the same batches can be replayed below
        loss_tr = criterion(fmodel(x_tr), y_tr)
        grad_tr = torch.autograd.grad(loss_tr, fmodel.get_adapt_params())
        for v, g, w in zip(velocity, grad_tr, fmodel.get_adapt_params()):
            v.data.mul_(FLAGS.adapt_momentum).add_(g.data)
            w.data.add_(v.data, alpha=-FLAGS.adapt_inner_lr)

    # hmodel goes backward through the learning trajectory
    hmodel.load_state_dict(fmodel.state_dict(), strict=True)

    # compute the direct grad: mean over n_back test batches of
    # d(loss_te)/d(adapt params) at the adapted weights
    alpha = [torch.zeros_like(w) for w in hmodel.get_adapt_params()]
    mgrad = [torch.zeros_like(w) for w in hmodel.get_hyper_params()]
    hmodel.eval()
    for _ in range(n_back):
        x_te, y_te, _ = next(test_iter)
        x_te, y_te = x_te.to(device), y_te.to(device)
        loss_te = criterion(hmodel(x_te), y_te)
        alpha_curr = torch.autograd.grad(loss_te, hmodel.get_adapt_params())
        for alp, alpc in zip(alpha, alpha_curr):
            alp.add_(alpc, alpha=1/n_back)

    # snapshot: alpha itself is updated in place during the backward walk
    alpha_ori = [w.clone().detach() for w in alpha]

    # initialize (avgx, avgy) and avgmodel
    # (the buffers take shape/dtype from the last test batch drawn above)
    avgx = torch.zeros_like(x_te)
    avgy = torch.zeros_like(y_te)
    avgmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    avgmodel.load_state_dict(hmodel.state_dict(), strict=True)
    for aap in avgmodel.get_adapt_params():
        aap.data.copy_(torch.zeros_like(aap))

    inputs, outputs = [], []
    # reverse mode diff through the approximated trajectory
    hmodel.train()
    for k in range(K_last-1, K_first-1, -1):
        # move hmodel 1/(k+1) of the remaining way toward the initial weights
        for hap, gap in zip(hmodel.get_adapt_params(), gmodel.get_adapt_params()):
            hap.data.add_(gap.data - hap.data, alpha=1./float(k+1))

        # replay the training batch that was used at forward step k
        x_tr, y_tr, _ = meta_ds.get_data_by_index(idxs[k], train=True)
        x_tr, y_tr = x_tr.to(device), y_tr.type(torch.LongTensor).view(-1).to(device)
        loss_tr = criterion(hmodel(x_tr), y_tr)
        grad_tr = torch.autograd.grad(
                loss_tr, hmodel.get_adapt_params(), create_graph=True)

        # compute Jacobian-vector products (aB, aA)
        # aB must be taken with the current alpha, i.e. before alpha is updated
        alpha_B = torch.autograd.grad(
                grad_tr, hmodel.get_hyper_params(), grad_outputs=alpha,
                retain_graph=True)
        for mgr, alpB in zip(mgrad, alpha_B):
            mgr.add_(alpB, alpha=-FLAGS.adapt_inner_lr)

        alpha_H = torch.autograd.grad(
                grad_tr, hmodel.get_adapt_params(), grad_outputs=alpha,
                retain_graph=True)
        for alp, alpH in zip(alpha, alpha_H):
            alp.add_(alpH, alpha=-FLAGS.adapt_inner_lr)

        # update avgmodel
        i = K_last-k # 1,2,3,...,K_last-K_first
        # mu = S_{i-1} / S_i with S_n = 1 + gamma + ... + gamma^(n-1);
        # for gamma == 1 the closed form is 0/0, so use the limit (i-1)/i.
        if gamma == 1.:
            mu = (i - 1.) / i
        else:
            mu = (1.-math.pow(gamma,i-1)) / (1.-math.pow(gamma,i))
        for aap, hap in zip(avgmodel.get_adapt_params(), hmodel.get_adapt_params()):
            aap.data.mul_(mu).add_(hap, alpha=1.-mu)

        # update avgdata: keep a random mu-fraction of the running batch and
        # fill the rest with random samples of the current batch
        N_mu = int(batchsz * mu)
        ridx1 = np.random.permutation(batchsz)
        ridx2 = np.random.permutation(batchsz)
        avgx = torch.cat([avgx[ridx1][:N_mu], x_tr[ridx2][:batchsz-N_mu]], 0).to(device)
        avgy = torch.cat([avgy[ridx1][:N_mu], y_tr[ridx2][:batchsz-N_mu]], 0).to(device)

        loss_avg = criterion(avgmodel(avgx), avgy)
        grad_avg = torch.autograd.grad(
                loss_avg, avgmodel.get_adapt_params(), create_graph=True)

        # compute output: one-shot approximation using the original alpha
        second_approx = torch.autograd.grad(
                grad_avg, avgmodel.get_hyper_params(), grad_outputs=alpha_ori)
        for sa in second_approx:
            sa.mul_(-FLAGS.adapt_inner_lr)

        # least-squares scale s minimizing ||s * approx - mgrad||^2.
        # Neither operand is modified here and neither carries a graph, so
        # no defensive copies are needed.
        aTa = 0.
        aTt = 0.
        for a, t in zip(second_approx, mgrad):  # t: second ground truth
            aTa = aTa + (a*a).sum()
            aTt = aTt + (a*t).sum()
        outputs.append((aTt/(aTa + 1e-20)).cpu().detach())

        # S_i = sum_{j<i} gamma^j (equals i when gamma == 1)
        if gamma == 1.:
            inputs.append(float(i))
        else:
            inputs.append( (1.-math.pow(gamma, i)) / (1.-gamma) )

    return { "X": np.array(inputs), "y": np.array(outputs) }

def run(rank, model, meta_ds, hyper_opt, device, criterion, logger, FLAGS, coeff, step):
    """Run one meta-training episode on a task and write the result back.

    A task-local copy of ``model`` (``fmodel``) is adapted for
    ``FLAGS.episode_train_steps`` hand-written momentum-SGD steps. Alongside
    it the function maintains a normalized, geometrically weighted (factor
    ``FLAGS.gamma``) average of the visited weights (``avgmodel``) and a mixed
    batch (``avgx``, ``avgy``). From step 50 onward every step also builds a
    hyper-gradient

        g = d(loss_te)/d(hyper)                       (first-order part)
            - lr * grad_te^T d(grad_avg)/d(hyper)     (distilled second-order part)

    and applies it with a task-local hyper optimizer.

    Side effects on the arguments:
      * ``model``'s hyper-parameters are overwritten with the task-local ones;
      * ``hyper_opt`` receives the task-local optimizer state;
      * ``model``'s adapt-parameters are moved by
        ``-FLAGS.adapt_outer_lr * (w - w_adapted)`` after ``share_params``.

    Args:
        rank: Index forwarded to ``meta_ds.get_task``; rank 0 additionally
            prints/logs mask statistics when ``FLAGS.model == "sparsify"``.
        model: Shared model; read at the start and updated in place at the end.
        meta_ds: Dataset object providing ``get_task``.
        hyper_opt: Shared hyper optimizer whose state is loaded into, and
            afterwards restored from, the task-local optimizer.
        device: Torch device that models and batches are moved to.
        criterion: Loss callable applied as ``criterion(logits, targets)``.
        logger: Object exposing ``meter(group, name, value)``.
        FLAGS: Configuration namespace (model/optimizer/episode settings).
        coeff: Indexable; ``coeff[0]`` scales the second-order term for every
            model except ``"mwn"``.
        step: Current outer step, used only for the linear decay
            ``1 - (step - 1) / FLAGS.train_steps`` when ``FLAGS.lr_decay``.

    Returns:
        None.
    """

    # Get a task
    train_loader, test_loader = meta_ds.get_task(
        rank,
        num_classes=FLAGS.num_classes,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.num_workers
    )
    train_iter = InfIterator(train_loader)
    test_iter = InfIterator(test_loader)

    # fmodel
    fmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    fmodel.load_state_dict(model.state_dict(), strict=True)

    task_hyper_opt = get_optimizer(
        FLAGS.hyper_opt,
        fmodel.get_hyper_params(),
        FLAGS.hyper_lr,
        momentum=FLAGS.hyper_momentum,
        weight_decay=FLAGS.hyper_weight_decay,
        nesterov=FLAGS.hyper_nesterov,
    )
    task_hyper_opt.load_state_dict(hyper_opt.state_dict())

    velocity = [ torch.zeros_like(w) for w in fmodel.get_adapt_params() ]

    if FLAGS.model == "sparsify" and rank == 0:
        # diagnostics: per-layer fraction of mask entries below 1e-3
        count_list = []
        masks = []
        for l1norm in fmodel.hyper_layers:
            mask = l1norm.get_mask(training=False)
            masks.append(mask)
            # one reduction instead of a Python loop with an .item() per entry
            count = int((mask < 1e-3).sum().item())
            count_list.append(count / mask.shape[0])
        all_masks = torch.cat(masks)
        print("min", all_masks.min().item())
        print("max", all_masks.max().item())
        print("mean", all_masks.mean().item())
        print("sparsity", count_list)
        print("reg_term", fmodel.reg_term())
        # NOTE(review): the divisor 5 looks like an assumed number of hyper
        # layers -- confirm it matches len(fmodel.hyper_layers).
        logger.meter("meta_train", "sparsity", torch.tensor(sum(count_list)/5))

    K = FLAGS.episode_train_steps
    gamma = FLAGS.gamma
    batchsz = FLAGS.batch_size

    is_mwn = FLAGS.model == "mwn"
    # "mwn" re-weights per-sample losses, so it needs an unreduced criterion;
    # it is stateless and can be built once for the whole episode.
    per_sample_criterion = (
        nn.CrossEntropyLoss(reduction='none').to(device) if is_mwn else None
    )

    # initialize avgmodel
    avgmodel = get_model(FLAGS, track_bn_stats=True).to(device)
    avgmodel.load_state_dict(fmodel.state_dict(), strict=True)
    if FLAGS.model == "nas":
        avgmodel.inv_temp = fmodel.inv_temp
    for aap in avgmodel.get_adapt_params():
        aap.data.copy_(torch.zeros_like(aap))

    # full grad step with fmodel
    for k in range(K):

        # compute avgmodel
        # weights: mu = gamma * S_k / S_{k+1}, S_n = 1 + gamma + ... + gamma^(n-1).
        # For gamma == 1 the closed form is 0/0, so use the limit k/(k+1).
        if gamma == 1.:
            mu = k / (k + 1.)
        else:
            mu = (gamma - math.pow(gamma,k+1)) / (1.-math.pow(gamma,k+1))
        for aap, fap in zip(avgmodel.get_adapt_params(), fmodel.get_adapt_params()):
            aap.data.mul_(mu).add_(fap, alpha=1.-mu)
        # hyperparams: avgmodel always mirrors the current hyper-parameters
        for aap, fap in zip(avgmodel.get_hyper_params(), fmodel.get_hyper_params()):
            aap.data.copy_(fap.data)

        # fmodel takes a grad step
        x_tr, y_tr, _ = next(train_iter)
        x_tr, y_tr = x_tr.to(device), y_tr.to(device)
        if is_mwn:
            # replace the first 40% of the labels with uniformly random classes
            N_corruption = int(batchsz * 0.4)
            yrand = torch.randint(0, FLAGS.num_classes, [N_corruption]).to(device)
            y_tr = torch.cat([yrand, y_tr[N_corruption:]]).to(device)

        fmodel.train()
        if is_mwn:
            loss_tr = per_sample_criterion(fmodel(x_tr), y_tr)
            loss_tr = (fmodel.lwm(loss_tr) * loss_tr).mean()
        else:
            loss_tr = criterion(fmodel(x_tr), y_tr)
        # update part (SGD with momentum, applied by hand on .data)
        grad_tr = torch.autograd.grad(loss_tr, fmodel.get_adapt_params())
        for v, g, w in zip(velocity, grad_tr, fmodel.get_adapt_params()):
            v.data.mul_(FLAGS.adapt_momentum).add_(g.data)
            w.data.add_(v.data, alpha=-FLAGS.adapt_inner_lr)

        # compute avgdata: keep a random mu-fraction of the running batch and
        # fill the rest with random samples of the current batch
        if k == 0:
            avgx = torch.zeros_like(x_tr)
            avgy = torch.zeros_like(y_tr)
        N_mu = int(batchsz * mu)
        ridx1 = np.random.permutation(batchsz)
        ridx2 = np.random.permutation(batchsz)
        avgx = torch.cat([avgx[ridx1][:N_mu], x_tr[ridx2][:batchsz-N_mu]], 0).to(device)
        avgy = torch.cat([avgy[ridx1][:N_mu], y_tr[ridx2][:batchsz-N_mu]], 0).to(device)

        # no hyper-parameter update during the first 50 inner steps
        # NOTE(review): 50 equals K_first in pre_run -- confirm they are meant
        # to stay in sync.
        if k < 50:
            continue

        # compute the direct grad
        x_te, y_te, _ = next(test_iter)
        x_te, y_te = x_te.to(device), y_te.to(device)
        fmodel.eval()
        loss_te = criterion(fmodel(x_te), y_te)
        if FLAGS.model == "sparsify":
            loss_te += fmodel.reg_term()
        if is_mwn:
            # "mwn" uses no first-order term
            meta_grad = [torch.zeros_like(w) for w in fmodel.get_hyper_params()]
        else:
            # retain_graph: loss_te is differentiated once more just below
            meta_grad = torch.autograd.grad(loss_te, fmodel.get_hyper_params(),
                    retain_graph=True)

        # reverse mode diff through the approximated trajectory
        grad_te = torch.autograd.grad(loss_te, fmodel.get_adapt_params())
        avgmodel.train()
        if is_mwn:
            loss_avg = per_sample_criterion(avgmodel(avgx), avgy)
            loss_avg = (avgmodel.lwm(loss_avg) * loss_avg).mean()
        else:
            loss_avg = criterion(avgmodel(avgx), avgy)
        grad_avg = torch.autograd.grad(
                loss_avg, avgmodel.get_adapt_params(), create_graph=True)

        # distilled aB
        second_order = torch.autograd.grad(
                grad_avg, avgmodel.get_hyper_params(), grad_outputs=grad_te)
        for so in second_order:
            so.mul_(-FLAGS.adapt_inner_lr)

        # normalize and apply pi: scale by beta * S_{k+1}
        if not is_mwn:
            beta = coeff[0]
            # S_{k+1} = sum_{j<=k} gamma^j, which equals k+1 when gamma == 1
            if gamma == 1.:
                scaler = beta * (k + 1)
            else:
                scaler = beta * (1.-math.pow(gamma, k+1)) / (1.-gamma)
            for so in second_order:
                so.mul_(scaler)

        # g <- gFO + gSO
        for mg, so in zip(meta_grad, second_order):
            mg.add_(so)

        # share and update the meta gradient
        share_params(meta_grad)
        task_hyper_opt.zero_grad()
        for w, g in zip(fmodel.get_hyper_params(), meta_grad):
            hyper_grad = torch.clamp(g, -3, 3)  # element-wise gradient clipping
            if FLAGS.lr_decay:
                hyper_grad *= 1.-((step-1) / FLAGS.train_steps)
            if w.grad is None:
                w.grad = hyper_grad
            else:
                w.grad.copy_(hyper_grad)
        task_hyper_opt.step()

    # update the following for the next episode
    # 1. hyper_params
    for w, fw in zip(model.get_hyper_params(), fmodel.get_hyper_params()):
        w.data.copy_(fw.data)

    # 2. hyper optimizer
    hyper_opt.load_state_dict(task_hyper_opt.state_dict())

    # 3. adapt_params: step from the initial weights toward the adapted ones
    reptile_grad = [torch.zeros_like(w) for w in model.get_adapt_params()]
    for rg, w, fw in zip(reptile_grad, model.get_adapt_params(), fmodel.get_adapt_params()):
        rg.data.copy_(w.data - fw.data)
    share_params(reptile_grad)
    for w, rg in zip(model.get_adapt_params(), reptile_grad):
        w.data.add_(rg.data, alpha=-FLAGS.adapt_outer_lr)
