import os
from networks import LeNet5Feats, ResNetFeats18, classifier
# import resnet
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import hypergrad as hg
# from utils import save_checkpoint
import time
import matplotlib.pyplot as plt
import pickle


def outer_loss(images, labels, params, hparams, more=False):
    """Evaluate the outer objective, and optionally accuracy, on one batch.

    The batch is passed through the module-level functional networks
    ``fhnet`` (driven by ``hparams``) and ``fcnet`` (driven by ``params``),
    and the module-level ``criterion`` is applied to the resulting outputs.

    Args:
        images: Input batch handed to ``fhnet``.
        labels: Targets handed to ``criterion``; also compared with the
            predictions to obtain the accuracy.
        params: Parameters forwarded to ``fcnet``.
        hparams: Parameters forwarded to ``fhnet``.
        more: When true, the accuracy is returned alongside the loss.

    Returns:
        ``loss`` if ``more`` is false, otherwise ``(loss, acc)`` where
        ``acc`` is the float ratio of matching predictions to
        ``labels.size(0)``.
    """
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)

    # Predicted class = index of the largest output along dimension 1.
    _, predicted = logits.data.max(1)
    n_correct = predicted.eq(labels.data.view_as(predicted)).sum()
    acc = float(n_correct) / labels.size(0)

    return (loss, acc) if more else loss


def update_tensor_grads(params, grads):
    """Accumulate ``grads`` into the ``.grad`` attributes of ``params``.

    The two sequences are walked in lockstep. A tensor whose ``.grad`` is
    ``None`` first receives a zero gradient of its own shape; a ``None``
    entry in ``grads`` then leaves that (possibly freshly zeroed) gradient
    untouched.

    Args:
        params: Tensors whose ``.grad`` is updated in place.
        grads: Gradients to add, one per tensor; entries may be ``None``.
    """
    for tensor, grad in zip(params, grads):
        if tensor.grad is None:
            tensor.grad = torch.zeros_like(tensor)
        if grad is None:
            continue
        tensor.grad += grad


def inner_solver(images, labels, hparams, params, steps, beta):
    """Run ``steps`` gradient-descent updates on the inner variables.

    Each step evaluates ``criterion(fcnet(fhnet(images, hparams), params),
    labels)`` with the module-level networks, differentiates it with respect
    to ``params`` and applies ``p <- p - beta * grad``.

    Args:
        images: Input batch for the inner objective.
        labels: Targets for the inner objective.
        hparams: Parameters forwarded to ``fhnet`` (not modified here).
        params: List of inner tensors; every entry must require grad. The
            list is updated in place.
        steps: Number of gradient steps to perform.
        beta: Step size.

    Returns:
        The same ``params`` list object, now holding the updated tensors.
    """
    for _ in range(steps):
        feats = fhnet(images, params=hparams)
        outputs = fcnet(feats, params=params)
        loss = criterion(outputs, labels)
        grads = torch.autograd.grad(loss, params)

        # Detach each result and re-mark it as a fresh leaf. Without this,
        # every update stays linked to the previous tensor through its
        # grad_fn, so the autograd history grows with each call over the
        # whole training run. The values are unchanged because `grads`
        # carries no graph (create_graph is not requested).
        # Slice assignment keeps mutating the caller's list object.
        params[:] = [
            (p - beta * g).detach().requires_grad_(True)
            for p, g in zip(params, grads)
        ]
    return params


def update_omega(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta, lamb):
    """Take one gradient step of size ``lamb`` on the auxiliary vector ``v``.

    The direction comes from ``cal_grad_R`` and the update is
    ``v[i] <- v[i] - lamb * grad_R[i]``.

    Args:
        imagesg, labelsg: Batch forwarded to ``cal_grad_R`` for the
            finite-difference term.
        imagesf, labelsf: Batch forwarded to ``cal_grad_R`` for the outer
            gradient term.
        hparams: Parameters of the feature network.
        params: Parameters of the classifier head.
        v: List of tensors, one per entry of ``params``; updated in place.
        delta: Finite-difference radius forwarded to ``cal_grad_R``.
        lamb: Step size.

    Returns:
        The same ``v`` list object, now holding the updated tensors.
    """
    grad_R = cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta)
    for i, g in enumerate(grad_R):
        previous = v[i]
        # Detach so successive updates do not chain an ever-growing autograd
        # history onto `v`; the original requires_grad flag is preserved.
        v[i] = (previous - lamb * g).detach().requires_grad_(previous.requires_grad)
    return v


def cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta):
    """Return the update direction used for the auxiliary vector ``v``.

    For every entry of ``params`` the result is the finite-difference term
    produced by ``H`` on the ``g`` batch minus the gradient, with respect to
    ``params``, of ``criterion`` evaluated on the ``f`` batch.

    Args:
        imagesg, labelsg: Batch used for the ``H`` term.
        imagesf, labelsf: Batch used for the gradient term.
        hparams: Parameters forwarded to ``fhnet``.
        params: Parameters forwarded to ``fcnet``; differentiated against.
        v: Direction forwarded to ``H``.
        delta: Finite-difference radius forwarded to ``H``.

    Returns:
        A list of tensors, one per entry of ``params``.
    """
    # Finite-difference term on the "g" batch.
    feats_g = fhnet(imagesg, params=hparams)
    hv = H(feats_g, labelsg, params, delta, v)

    # Gradient of the criterion on the "f" batch with respect to params.
    feats_f = fhnet(imagesf, params=hparams)
    loss_f = criterion(fcnet(feats_f, params=params), labelsf)
    grad_y = torch.autograd.grad(loss_f, params)

    return [h_term - g_term for h_term, g_term in zip(hv, grad_y)]


def H(feats, labels, params, delta, v):
    """Central difference of the ``params``-gradient along direction ``v``.

    The criterion is evaluated through ``fcnet`` at ``params + delta * v``
    and at ``params - delta * v``; the two gradients (taken with respect to
    the shifted parameters) are subtracted and divided by ``2 * delta``.
    This is the standard finite-difference estimate of a Hessian-vector
    product, obtained without second-order autograd.

    Args:
        feats: Features fed to ``fcnet``.
        labels: Targets for ``criterion``.
        params: Parameters of ``fcnet`` around which the difference is taken.
        delta: Finite-difference radius.
        v: Direction, indexed in parallel with ``params``.

    Returns:
        A list of tensors, one per entry of ``params``.
    """
    shifted_up = [p + delta * v[k] for k, p in enumerate(params)]
    shifted_down = [p - delta * v[k] for k, p in enumerate(params)]

    # Gradient at params + delta * v.
    loss_up = criterion(fcnet(feats, params=shifted_up), labels)
    grad_up = torch.autograd.grad(loss_up, shifted_up)

    # Gradient at params - delta * v.
    loss_down = criterion(fcnet(feats, params=shifted_down), labels)
    grad_down = torch.autograd.grad(loss_down, shifted_down)

    return [(up - down) / (2 * delta) for up, down in zip(grad_up, grad_down)]


def update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta, alpha):
    """Take one gradient step of size ``alpha`` on the outer variables.

    The direction comes from ``cal_grad_f`` and the update is
    ``hparams[i] <- hparams[i] - alpha * grad_f[i]``.

    Args:
        imagesg, labelsg: Batch forwarded to ``cal_grad_f`` for the
            finite-difference term.
        imagesf, labelsf: Batch forwarded to ``cal_grad_f`` for the outer
            gradient term.
        hparams: List of outer tensors; every entry must require grad. The
            list is updated in place.
        params: Parameters of the classifier head.
        v: Auxiliary direction forwarded to ``cal_grad_f``.
        delta: Finite-difference radius forwarded to ``cal_grad_f``.
        alpha: Step size.

    Returns:
        The same ``hparams`` list object, now holding the updated tensors.
    """
    grad_f = cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, delta, v)
    for i, g in enumerate(grad_f):
        # Detach and re-mark as a leaf. Otherwise every outer step stays
        # linked to the previous tensor via its grad_fn and the autograd
        # history keeps growing for the whole run. Values are unchanged
        # because `grad_f` carries no graph.
        hparams[i] = (hparams[i] - alpha * g).detach().requires_grad_(True)

    return hparams


def cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, delta, v):
    """Return the update direction used for the outer variables.

    For every entry of ``hparams`` the result is the gradient, with respect
    to ``hparams``, of ``criterion`` evaluated on the ``f`` batch minus the
    finite-difference term produced by ``J`` on the ``g`` batch.

    Args:
        imagesg, labelsg: Batch used for the ``J`` term.
        imagesf, labelsf: Batch used for the gradient term.
        hparams: Parameters forwarded to ``fhnet``; differentiated against.
        params: Parameters forwarded to ``fcnet``.
        delta: Finite-difference radius forwarded to ``J``.
        v: Direction forwarded to ``J``.

    Returns:
        A list of tensors, one per entry of ``hparams``.
    """
    # Gradient of the criterion on the "f" batch with respect to hparams.
    feats_f = fhnet(imagesf, params=hparams)
    loss_f = criterion(fcnet(feats_f, params=params), labelsf)
    grad_x = torch.autograd.grad(loss_f, hparams)

    # Finite-difference cross term on the "g" batch.
    feats_g = fhnet(imagesg, params=hparams)
    cross = J(feats_g, labelsg, hparams, params, delta, v)

    return [direct - mixed for direct, mixed in zip(grad_x, cross)]


def J(feats, labels, hparams, params, delta, v):
    """Central difference of the ``hparams``-gradient along direction ``v``.

    The criterion is evaluated through ``fcnet`` at ``params + delta * v``
    and at ``params - delta * v``; both losses are differentiated with
    respect to ``hparams`` and the difference is divided by ``2 * delta``.

    Args:
        feats: Features fed to ``fcnet``. They are expected to depend on
            ``hparams`` through the autograd graph, since the gradients are
            taken with respect to ``hparams``.
        labels: Targets for ``criterion``.
        hparams: Tensors the gradients are taken with respect to.
        params: Parameters of ``fcnet`` around which the difference is taken.
        delta: Finite-difference radius.
        v: Direction, indexed in parallel with ``params``.

    Returns:
        A list of tensors, one per entry of ``hparams``.
    """
    shifted_up = [p + delta * v[k] for k, p in enumerate(params)]
    shifted_down = [p - delta * v[k] for k, p in enumerate(params)]

    # Both backward passes go through the graph behind `feats`, so the first
    # one has to retain it.
    loss_up = criterion(fcnet(feats, params=shifted_up), labels)
    grad_up = torch.autograd.grad(loss_up, hparams, retain_graph=True)

    loss_down = criterion(fcnet(feats, params=shifted_down), labels)
    grad_down = torch.autograd.grad(loss_down, hparams)

    return [(up - down) / (2 * delta) for up, down in zip(grad_up, grad_down)]


def _next_batch(batch_iter, loader):
    """Fetch the next batch, restarting ``loader`` once ``batch_iter`` is spent.

    Returns:
        ``(batch, batch_iter)`` where ``batch_iter`` is the iterator to pass
        to the following call (a fresh one if a restart happened).
    """
    try:
        batch = next(batch_iter)
    except StopIteration:
        batch_iter = iter(loader)
        batch = next(batch_iter)
    return batch, batch_iter


if __name__ == "__main__":
    # `m` is used as the random seed, as the period (in outer steps) of the
    # omega update below, and as a tag in the output filename.
    m = 1
    torch.manual_seed(m)
    parser = argparse.ArgumentParser(description='Bilevel Training')

    # Basic model parameters.
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'cifar10'])
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--output_dir', type=str, default='MNISTsave')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Use the GPU when one is present; otherwise fall back to the CPU so the
    # script remains runnable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.dataset == 'MNIST':
        mnist_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        data_test = MNIST(args.data, train=False, download=True, transform=mnist_transform)
        data_train = MNIST(args.data, train=True, download=True, transform=mnist_transform)

        data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=0)
        data_test_loader = DataLoader(data_test, batch_size=256, shuffle=True, num_workers=0)

        hypernet = LeNet5Feats().to(device)
        cnet = classifier(n_features=84, n_classes=10).to(device)

    elif args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # The splits were previously swapped: `data_test` was built from the
        # training split with the training augmentation and `data_train`
        # from the test split. They now match the MNIST branch.
        data_test = CIFAR10(args.data,
                            train=False,
                            transform=transform_test,
                            download=True)
        data_train = CIFAR10(args.data,
                             train=True,
                             transform=transform_train,
                             download=True)

        data_train_loader = DataLoader(data_train, batch_size=128, shuffle=True, num_workers=4)
        data_test_loader = DataLoader(data_test, batch_size=128, shuffle=True, num_workers=4)

        hypernet = ResNetFeats18().to(device)
        cnet = classifier(n_features=512, n_classes=10).to(device)

    # The "test" loader feeds the outer objective, the "train" loader the
    # inner one.
    numtest = len(data_test_loader)
    print('num of outer batches = ', numtest)
    numtrain = len(data_train_loader)
    print('num of inner batches = ', numtrain)

    data_train_iter = iter(data_train_loader)
    data_test_iter = iter(data_test_loader)

    # Functional versions of both networks; the helper functions above read
    # `fhnet`, `fcnet` and `criterion` as module-level globals.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).to(device)
    hparams = [hparam.requires_grad_(True) for hparam in hypernet.parameters()]

    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).to(device)
    params = [param.requires_grad_(True) for param in cnet.parameters()]

    # Auxiliary vector, one random tensor per inner parameter.
    v = [torch.rand_like(param).requires_grad_(True) for param in params]

    numhparams = sum(torch.numel(hparam) for hparam in hparams)
    numparams = sum(torch.numel(param) for param in params)

    print('size of outer variable: ', numhparams)
    print('size of inner variable: ', numparams)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    # NOTE: these optimizers are only ever zeroed, never stepped; the actual
    # updates are applied manually by inner_solver / update_omega / update_x.
    outer_opt = torch.optim.Adam(hparams, lr=0.01)
    inner_opt = torch.optim.Adam(params)
    alpha = 0.02
    delta = 0.1
    lamb = 0.8
    beta = 0.5
    T = 1
    steps = 1000
    total_time = 0
    running_time, outer_accs, outer_losses = [], [], []

    # Start from the initial `v` so `omega` is always defined, including on
    # steps where the periodic update below is skipped.
    omega = v

    print('Bilevel training with TriBO-HF on ' + args.dataset)
    print('Number of inner iterations T=', T)

    for step in range(1, steps + 1):
        (imagesf, labelsf), data_test_iter = _next_batch(data_test_iter, data_test_loader)
        imagesf, labelsf = imagesf.to(device), labelsf.to(device)

        (imagesg, labelsg), data_train_iter = _next_batch(data_train_iter, data_train_loader)
        imagesg, labelsg = imagesg.to(device), labelsg.to(device)

        t0 = time.time()
        # update y
        inner_opt.zero_grad()
        params = inner_solver(imagesg, labelsg, hparams, params, T, beta)

        # update omega on steps 1, 1 + m, 1 + 2m, ...
        # The previous test `step % m == 1` can never hold for m == 1, which
        # left `omega` undefined and raised NameError on the first step.
        if (step - 1) % m == 0:
            omega = update_omega(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta, lamb)

        # update x
        outer_opt.zero_grad()
        hparams = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, omega, delta, alpha)

        # test
        oloss, oacc = outer_loss(imagesf, labelsf, params, hparams, more=True)

        t1 = time.time() - t0
        total_time += t1
        running_time.append(total_time)
        outer_accs.append(oacc * 100)
        outer_losses.append(oloss.item())

        if step % 10 == 0:
            print('Step: %d/%d, lr: %f, Outer Batch Loss: %f, Accuracy on outer batch: %f' % (
                step, steps, alpha, oloss.item(), oacc))

    print('Ended in {:.2e} seconds\n'.format(total_time))

    filename = 'TriBO-HF' + args.dataset + '_T' + str(T) + "alpha" + str(alpha) + "lambda" + str(lamb) \
               + "beta" + str(beta) + '_m' + str(m) + '.pt'
    save_path = os.path.join(args.output_dir, filename)

    state_dict = {'runtime': running_time,
                  'accuracy': outer_accs,
                  'loss': outer_losses}
    torch.save(state_dict, save_path)
