import os
from networks import LeNet5Feats, ResNetFeats18, classifier
# import resnet
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import hypergrad as hg
# from utils import save_checkpoint
import time
import matplotlib.pyplot as plt
import pickle


def outer_loss(images, labels, params, hparams, more=False):
    """Evaluate the loss of the two-stage model on a single batch.

    The batch is pushed through the module-level functional networks:
    ``fhnet`` (called with ``hparams``) produces features and ``fcnet``
    (called with ``params``) maps them to per-class outputs, which are scored
    with the module-level ``criterion``.

    Args:
        images: Input batch handed to ``fhnet``.
        labels: Targets for the batch; also compared element-wise with the
            predicted class indices to compute the accuracy.
        params: Parameter list passed to ``fcnet``.
        hparams: Parameter list passed to ``fhnet``.
        more: When true, return the batch accuracy alongside the loss.

    Returns:
        The loss tensor, or ``(loss, accuracy)`` when ``more`` is true.
        ``accuracy`` is a Python float: correct predictions divided by
        ``labels.size(0)``.
    """
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)

    # The index of the largest output along dim 1 is the predicted class.
    _, predicted = logits.data.max(1)
    n_correct = predicted.eq(labels.data.view_as(predicted)).sum()
    accuracy = float(n_correct) / labels.size(0)

    return (loss, accuracy) if more else loss


def update_tensor_grads(params, grads):
    """Accumulate ``grads`` into the ``.grad`` attribute of each tensor.

    Args:
        params: Iterable of tensors whose ``.grad`` is updated in place.
        grads: Iterable paired element-wise with ``params``. A ``None`` entry
            contributes nothing, but the matching ``.grad`` is still
            initialised to zeros if it was unset.
    """
    for tensor, delta in zip(params, grads):
        # Make sure there is a buffer to accumulate into.
        if tensor.grad is None:
            tensor.grad = torch.zeros_like(tensor)
        if delta is None:
            continue
        tensor.grad += delta


def _classifier_grads(images, labels, feat_params, cls_params):
    """Return the gradient of ``criterion`` w.r.t. ``cls_params`` on one batch."""
    feats = fhnet(images, params=feat_params)
    outputs = fcnet(feats, params=cls_params)
    loss = criterion(outputs, labels)
    return torch.autograd.grad(loss, cls_params)


def y_update(imagesg, labelsg, imagesf, labelsf, hparams, params, hparams_old, params_old, alpha, eta, lamb):
    """Take one step on the classifier parameters ``params`` (in place).

    For the f-batch (``imagesf``/``labelsf``) and the g-batch
    (``imagesg``/``labelsg``), the gradient of ``criterion`` w.r.t. the
    classifier parameters is evaluated at the current point
    ``(hparams, params)`` and at the previous point
    ``(hparams_old, params_old)``. Each pair is combined with the previous
    estimate stored in the module-level ``h_old``::

        est = grad_new + (1 - eta) * (h_old[k] - grad_old)

    with ``k = 1`` for the f-batch and ``k = 2`` for the g-batch. The step is
    ``params[j] <- params[j] - alpha * (est_f + lamb * est_g)``.

    Besides the arguments, this function reads the module-level names
    ``fhnet``, ``fcnet``, ``criterion`` and ``h_old``.

    Args:
        imagesg, labelsg: Batch used for the g-gradients.
        imagesf, labelsf: Batch used for the f-gradients.
        hparams: Current parameter list passed to ``fhnet``.
        params: Current parameter list passed to ``fcnet``; updated in place.
        hparams_old: Previous ``fhnet`` parameter list.
        params_old: Previous ``fcnet`` parameter list; entry ``j`` is
            overwritten with the pre-step ``params[j]``.
        alpha: Step size.
        eta: Mixing weight of the recursive estimator.
        lamb: Weight of the g-estimate in the step direction.

    Returns:
        ``(grad_fy, grad_gy)``: lists holding the f- and g-estimates that were
        used for the step, one entry per element of ``params``.
    """
    grads = _classifier_grads(imagesf, labelsf, hparams, params)
    grads_old = _classifier_grads(imagesf, labelsf, hparams_old, params_old)
    gradsg = _classifier_grads(imagesg, labelsg, hparams, params)
    gradsg_old = _classifier_grads(imagesg, labelsg, hparams_old, params_old)

    grad_fy = []
    grad_gy = []
    for j in range(len(params)):
        # Compute both estimates exactly once, *before* params[j] is replaced:
        # h_old[1] / h_old[2] may be the very same list object as `params`
        # (that is how the script initialises them), so re-reading them after
        # the assignment below would yield different values than the ones used
        # for the step.
        est_f = grads[j] + (1 - eta) * (h_old[1][j] - grads_old[j])
        est_g = gradsg[j] + (1 - eta) * (h_old[2][j] - gradsg_old[j])

        params_old[j] = params[j]
        # Detach so every iterate is a fresh leaf; otherwise each step would
        # chain one more node onto an ever-growing autograd graph.
        params[j] = (params[j] - alpha * (est_f + lamb * est_g)).detach().requires_grad_(True)

        grad_fy.append(est_f)
        grad_gy.append(est_g)
    return grad_fy, grad_gy


def z_update(imagesg, labelsg, hparams, zparams, hparams_old, zparams_old, h_old, eta, gamma):
    """Take one step on the auxiliary classifier parameters ``zparams`` (in place).

    The gradient of ``criterion`` w.r.t. the classifier parameters is
    evaluated on the g-batch at the current point ``(hparams, zparams)`` and
    at the previous point ``(hparams_old, zparams_old)``, then combined with
    the previous estimate::

        est = grad_new + (1 - eta) * (h_old[0] - grad_old)

    The step is ``zparams[i] <- zparams[i] - gamma * est``.

    Besides the arguments, this function reads the module-level names
    ``fhnet``, ``fcnet`` and ``criterion``.

    Args:
        imagesg, labelsg: Batch used for the gradients.
        hparams: Current parameter list passed to ``fhnet``.
        zparams: Current parameter list passed to ``fcnet``; updated in place.
        hparams_old: Previous ``fhnet`` parameter list.
        zparams_old: Previous ``fcnet`` parameter list; entry ``i`` is
            overwritten with the pre-step ``zparams[i]``.
        h_old: Sequence whose element 0 holds the previous estimates, indexed
            like ``zparams``.
        eta: Mixing weight of the recursive estimator.
        gamma: Step size.

    Returns:
        List with the estimate used for each element of ``zparams``.
    """
    featsg = fhnet(imagesg, params=hparams)
    outputs = fcnet(featsg, params=zparams)
    loss = criterion(outputs, labelsg)
    hgz = torch.autograd.grad(loss, zparams)

    featsg = fhnet(imagesg, params=hparams_old)
    outputs = fcnet(featsg, params=zparams_old)
    loss = criterion(outputs, labelsg)
    hgz_old = torch.autograd.grad(loss, zparams_old)

    grad_z = []
    # Iterate over the argument being updated (the original looped over the
    # length of the unrelated module-level `params`). Index access is kept on
    # purpose: h_old[0] may be the same list object as `zparams`, and entry i
    # must be read before zparams[i] is replaced.
    for i in range(len(zparams)):
        true_grad = hgz[i] + (1 - eta) * (h_old[0][i] - hgz_old[i])
        zparams_old[i] = zparams[i]
        # Detach so every iterate is a fresh leaf; otherwise each step would
        # chain one more node onto an ever-growing autograd graph.
        zparams[i] = (zparams[i] - gamma * true_grad).detach().requires_grad_(True)
        grad_z.append(true_grad)
    return grad_z


def _feature_grads(images, labels, feat_params, cls_params):
    """Return the gradient of ``criterion`` w.r.t. ``feat_params`` on one batch."""
    feats = fhnet(images, params=feat_params)
    outputs = fcnet(feats, params=cls_params)
    loss = criterion(outputs, labels)
    return torch.autograd.grad(loss, feat_params)


def update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams,
             hparams_old, params_old, zparams_old, eta, lamb, alpha):
    """Take one step on the feature-network parameters ``hparams`` (in place).

    Three gradients w.r.t. the ``fhnet`` parameters are each evaluated at the
    current and at the previous point:

    * f-batch loss with classifier ``params`` / ``params_old``,
    * g-batch loss with classifier ``params`` / ``params_old``,
    * g-batch loss with classifier ``zparams`` / ``zparams_old``.

    Each pair is combined with the previous estimate held in the module-level
    ``h_old`` (entries 3, 4 and 5 respectively)::

        est = grad_new + (1 - eta) * (h_old[k] - grad_old)

    The step is
    ``hparams[j] <- hparams[j] - alpha * (est_fx + lamb * (est_gxy - est_gxz))``.

    Besides the arguments, this function reads the module-level names
    ``fhnet``, ``fcnet``, ``criterion`` and ``h_old``.

    Args:
        imagesg, labelsg: Batch used for the g-gradients.
        imagesf, labelsf: Batch used for the f-gradients.
        hparams: Current ``fhnet`` parameter list; updated in place.
        params, zparams: Current ``fcnet`` parameter lists.
        hparams_old: Previous ``fhnet`` parameter list; entry ``j`` is
            overwritten with the pre-step ``hparams[j]``.
        params_old, zparams_old: Previous ``fcnet`` parameter lists.
        eta: Mixing weight of the recursive estimator.
        lamb: Weight of the g-difference term in the step direction.
        alpha: Step size.

    Returns:
        ``(grad_fx, grad_gxy, grad_gxz)``: lists with the three estimates, one
        entry per element of ``hparams``.
    """
    # hfx
    grads = _feature_grads(imagesf, labelsf, hparams, params)
    grads_old = _feature_grads(imagesf, labelsf, hparams_old, params_old)

    # hgxy
    gradsg = _feature_grads(imagesg, labelsg, hparams, params)
    gradsg_old = _feature_grads(imagesg, labelsg, hparams_old, params_old)

    # hgxz
    hgz = _feature_grads(imagesg, labelsg, hparams, zparams)
    hgz_old = _feature_grads(imagesg, labelsg, hparams_old, zparams_old)

    grad_fx = []
    grad_gxy = []
    grad_gxz = []
    for j in range(len(hparams)):
        # All three estimates are formed before hparams[j] is replaced, since
        # the h_old entries may be the same list object as `hparams`.
        hfx_temp = grads[j] + (1 - eta) * (h_old[3][j] - grads_old[j])
        hgxy_temp = gradsg[j] + (1 - eta) * (h_old[4][j] - gradsg_old[j])
        hgxz_temp = hgz[j] + (1 - eta) * (h_old[5][j] - hgz_old[j])

        hparams_old[j] = hparams[j]
        direction = hfx_temp + lamb * (hgxy_temp - hgxz_temp)
        # Detach so every iterate is a fresh leaf; otherwise each step would
        # chain one more node onto an ever-growing autograd graph.
        hparams[j] = (hparams[j] - alpha * direction).detach().requires_grad_(True)

        grad_fx.append(hfx_temp)
        grad_gxy.append(hgxy_temp)
        grad_gxz.append(hgxz_temp)
    return grad_fx, grad_gxy, grad_gxz


def cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha):
    """Return a scaled combination of three gradients w.r.t. ``hparams``.

    For every element of ``hparams`` the result is::

        zeta * alpha * (d_fx + lamb * (d_gxy - d_gxz))

    where ``d_fx`` comes from the f-batch with classifier ``params``,
    ``d_gxy`` from the g-batch with classifier ``params`` and ``d_gxz`` from
    the g-batch with classifier ``zparams``.

    Note that ``lamb`` is not an argument: it is read from module scope, as
    are ``fhnet``, ``fcnet`` and ``criterion``.

    Args:
        imagesg, labelsg: Batch used for the two g-gradients.
        imagesf, labelsf: Batch used for the f-gradient.
        hparams: Parameter list passed to ``fhnet``; gradients are taken
            w.r.t. these tensors.
        params, zparams: Parameter lists passed to ``fcnet``.
        zeta, alpha: Scalar factors multiplied onto the combined gradient.

    Returns:
        List with one tensor per element of ``hparams``.
    """
    # hfxk: f-batch, classifier `params`
    loss_f = criterion(fcnet(fhnet(imagesf, params=hparams), params=params), labelsf)
    d_fx = torch.autograd.grad(loss_f, hparams)

    # hgxyk: g-batch, classifier `params`
    loss_gy = criterion(fcnet(fhnet(imagesg, params=hparams), params=params), labelsg)
    d_gxy = torch.autograd.grad(loss_gy, hparams)

    # hgxzk: g-batch, classifier `zparams`
    loss_gz = criterion(fcnet(fhnet(imagesg, params=hparams), params=zparams), labelsg)
    d_gxz = torch.autograd.grad(loss_gz, hparams)

    return [
        zeta * alpha * (fx + lamb * (gxy - gxz))
        for fx, gxy, gxz in zip(d_fx, d_gxy, d_gxz)
    ]


if __name__ == "__main__":
    # Script entry point: builds the MNIST loaders and the two networks, runs
    # the z / y / x updates for a fixed wall-clock budget, and saves the
    # per-step runtime, outer accuracy and outer loss.
    #
    # NOTE: the helper functions in this file read names such as fhnet, fcnet,
    # criterion, h_old and lamb from module scope, so this body has to stay at
    # module level unless those names are passed explicitly.
    m = 1  # random seed; also embedded in the output filename
    torch.manual_seed(m)
    parser = argparse.ArgumentParser(description='Bilevel Training')

    # Basic model parameters.
    # NOTE(review): --dataset only affects the log line and the output
    # filename below; MNIST is loaded regardless of its value.
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'cifar10'])
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--output_dir', type=str, default='MNISTsave')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): names and splits look swapped -- `data_test` is built with
    # torchvision's default train=True split, while `data_train` uses
    # train=False. Confirm this is intentional before relying on the
    # "outer"/"inner" labels printed below.
    data_test = MNIST(args.data,
                      download=True,
                      transform=transforms.Compose([
                          transforms.Resize((32, 32)),
                          transforms.ToTensor(),
                          transforms.Normalize((0.1307,), (0.3081,))
                      ]))
    data_train = MNIST(args.data,
                       train=False,
                       download=True,
                       transform=transforms.Compose([
                           transforms.Resize((32, 32)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))

    data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=0)
    data_test_loader = DataLoader(data_test, batch_size=256, shuffle=True, num_workers=0)

    hypernet = LeNet5Feats().cuda()
    cnet = classifier(n_features=84, n_classes=10).cuda()

    numtest = len(data_test_loader)
    print('num of outer batches = ', numtest)
    numtrain = len(data_train_loader)
    print('num of inner batches = ', numtrain)

    data_train_iter = iter(data_train_loader)
    data_test_iter = iter(data_test_loader)

    # Functional versions of the two networks: parameters are supplied per
    # call via the `params=` keyword.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).cuda()
    hparams = list(hypernet.parameters())
    hparams = [hparam.requires_grad_(True) for hparam in hparams]
    # requires_grad_() returns the tensor it was called on, so the "old" and
    # "z" lists below start out holding the very same tensor objects as
    # hparams / params; only the list containers are distinct.
    hparams_old = [hparam.requires_grad_(True) for hparam in hparams]

    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).cuda()
    params = list(cnet.parameters())
    params = [param.requires_grad_(True) for param in params]
    params_old = [param.requires_grad_(True) for param in params]
    zparams = [param.requires_grad_(True) for param in params]
    zparams_old = [param.requires_grad_(True) for param in params]

    numhparams = sum([torch.numel(hparam) for hparam in hparams])
    numparams = sum([torch.numel(param) for param in params])

    # initialization
    # Slots: 0 = z-estimate, 1 = f/y, 2 = g/y, 3 = f/x, 4 = g/xy, 5 = g/xz.
    # NOTE(review): the initial "estimates" are the parameter lists themselves
    # (same list objects), so in-place updates of zparams / params / hparams
    # are visible through h and h_old until the slots are replaced.
    h = [zparams, params, params, hparams, hparams, hparams]
    h_old = [zparams, params, params, hparams, hparams, hparams]

    print('size of outer variable: ', numhparams)
    print('size of inner variable: ', numparams)

    init_params = []
    for param in params:
        init_params.append(torch.zeros_like(param))

    criterion = torch.nn.CrossEntropyLoss().cuda()
    # The optimizers are never stepped; the updates are applied manually by
    # z_update / y_update / update_x.
    outer_opt = torch.optim.Adam(hparams, lr=0.01)
    inner_opt = torch.optim.Adam(params)
    alpha = 0.005   # step size passed to y_update and update_x
    gamma = 0.008   # step size passed to z_update
    zeta = 0.1      # only used in the output filename (labelled "beta" there)
    eta = 0.8       # mixing weight of the recursive gradient estimators
    lamb = 0.2      # weight of the g-terms; increased by 1e-4 every step
    T = 10
    steps = 1000    # only used in the progress message; the loop is time-bounded
    warm_start = True
    total_time = 0
    running_time, outer_accs, outer_losses = [], [], []
    hessian = 0
    step = 0

    print('Bilevel training with TriBO-HF on ' + args.dataset)
    print('Number of inner iterations T=', T)

    # Run until 60 seconds of measured update time have accumulated.
    while (total_time <= 60):
        step += 1
        # Cycle through the loaders indefinitely, restarting when exhausted.
        try:
            imagesf, labelsf = next(data_test_iter)
        except StopIteration:
            data_test_iter = iter(data_test_loader)
            imagesf, labelsf = next(data_test_iter)

        imagesf, labelsf = imagesf.cuda(), labelsf.cuda()

        try:
            imagesg, labelsg = next(data_train_iter)
        except StopIteration:
            data_train_iter = iter(data_train_loader)
            imagesg, labelsg = next(data_train_iter)

        imagesg, labelsg = imagesg.cuda(), labelsg.cuda()

        # Data loading and host-to-device copies are excluded from the timing.
        t0 = time.time()
        # inner
        inner_opt.zero_grad()

        # NOTE(review): h_old is refreshed only after each update call, so the
        # call made at step k appears to see the estimates produced at step
        # k-2 (h holds those of step k-1). Confirm against the intended
        # recursion.
        hgz = z_update(imagesg, labelsg, hparams, zparams, hparams_old, zparams_old, h_old, eta, gamma)
        h_old[0], h[0] = h[0], hgz

        hfy, hgy = y_update(imagesg, labelsg, imagesf, labelsf, hparams, params, hparams_old, params_old, alpha, eta, lamb)
        h_old[1], h_old[2], h[1], h[2] = h[1], h[2], hfy, hgy

        hfx, hgxy, hgxz = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams,
             hparams_old, params_old, zparams_old, eta, lamb, alpha)
        h_old[3], h_old[4], h_old[5], h[3], h[4], h[5] = h[3], h[4], h[5], hfx, hgxy, hgxz
        lamb += 0.0001
        # test: loss/accuracy on the same f-batch that was just used
        oloss, oacc = outer_loss(imagesf, labelsf, params, hparams, more=True)

        t1 = time.time() - t0
        total_time += t1
        running_time.append(total_time)
        outer_accs.append(oacc * 100)
        outer_losses.append(oloss.item())

        if step % 10 == 0:
            print('Step: %d/%d, lr: %f, Outer Batch Loss: %f, Accuracy on outer batch: %f' % (
                step, steps, alpha, oloss.item(), oacc))

    print('Ended in {:.2e} seconds\n'.format(total_time))

    # `lamb` here is its final value after the per-step increments.
    filename = 'F3SA' + args.dataset + "alpha" + str(alpha) + "lambda" + str(lamb) \
               + "beta" + str(zeta) + '_m' + str(m) + '.pt'
    save_path = os.path.join(args.output_dir, filename)

    state_dict = {'runtime': running_time,
                  'accuracy': outer_accs,
                  'loss': outer_losses}
    torch.save(state_dict, save_path)
