import os
from networks import LeNet5Feats, ResNetFeats18, classifier
# import resnet
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import hypergrad as hg
# from utils import save_checkpoint
import time
import matplotlib.pyplot as plt
import pickle


def outer_loss(images, labels, params, hparams, more=False):
    """Compute the loss (and accuracy) of the two-stage model on one batch.

    The batch is pushed through the module-level functional feature network
    ``fhnet`` using weights ``hparams``, then through the module-level
    functional classifier ``fcnet`` using weights ``params``; the module-level
    ``criterion`` scores the result against ``labels``.

    Args:
        images: input batch forwarded to ``fhnet``.
        labels: target class indices for the batch.
        params: weights handed to ``fcnet``.
        hparams: weights handed to ``fhnet``.
        more: when true, also return the batch accuracy.

    Returns:
        ``loss`` if ``more`` is false, otherwise ``(loss, acc)`` where ``acc``
        is the fraction of correct predictions as a Python float.
    """
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)

    # Accuracy is always evaluated (even when it is not returned) and is
    # computed on the raw ``.data`` tensors.
    _, predicted = logits.data.max(1)
    n_correct = predicted.eq(labels.data.view_as(predicted)).sum()
    acc = float(n_correct) / labels.size(0)

    return (loss, acc) if more else loss


def update_tensor_grads(params, grads):
    """Accumulate ``grads`` into the ``.grad`` attributes of ``params``.

    The two sequences are walked in lockstep. A tensor whose ``.grad`` is
    ``None`` first receives a zero buffer of its own shape (this happens even
    when the matching gradient is ``None``); a gradient that is not ``None``
    is then added to that buffer in place.

    Args:
        params: iterable of tensors whose ``.grad`` should be updated.
        grads: iterable of gradients (entries may be ``None``), paired
            positionally with ``params``.
    """
    for tensor, grad in zip(params, grads):
        if tensor.grad is None:
            tensor.grad = torch.zeros_like(tensor)
        if grad is None:
            # Nothing to accumulate for this tensor; keep the buffer as is.
            continue
        tensor.grad += grad


def inner_solver(images, labels, imagesf, labelsf, hparams, params, zparams, steps, alpha, gamma, lamb):
    """Run ``steps`` manual gradient steps on both classifier-weight lists.

    With the feature weights ``hparams`` held fixed, every step does:

    * ``zparams``: one gradient step of size ``gamma`` on the loss of the
      ``(images, labels)`` batch.
    * ``params``: one gradient step of size ``alpha`` along
      ``grad(loss on (imagesf, labelsf)) + lamb * grad(loss on (images, labels))``.

    The forward passes go through the module-level ``fhnet`` / ``fcnet``
    functional networks and the module-level ``criterion``.

    Args:
        images, labels: batch used for the ``zparams`` step and for the
            ``lamb``-weighted term of the ``params`` step.
        imagesf, labelsf: batch used for the unweighted term of the
            ``params`` step.
        hparams: feature-network weights (read only here).
        params: classifier weights; list entries are replaced in place.
        zparams: second set of classifier weights; list entries are replaced
            in place.
        steps: number of update steps to perform.
        alpha: step size for ``params``.
        gamma: step size for ``zparams``.
        lamb: weight of the ``(images, labels)`` gradient in the ``params``
            step.

    Returns:
        The tuple ``(zparams, params)`` (the same list objects that were
        passed in, holding the updated tensors).
    """
    for _ in range(steps):
        # --- zparams step -------------------------------------------------
        feats = fhnet(images, params=hparams)
        zloss = criterion(fcnet(feats, params=zparams), labels)
        zgrads = torch.autograd.grad(zloss, zparams)
        # Iterate over the z-gradients themselves (one per zparams entry)
        # instead of relying on len(params) matching len(zparams).
        for j, zgrad in enumerate(zgrads):
            # detach(): the new weight is a fresh leaf. Without it every step
            # would keep the whole chain of earlier updates attached to the
            # tensor's autograd history, which grows without bound.
            zparams[j] = (zparams[j] - gamma * zgrad).detach().requires_grad_(True)

        # --- params step --------------------------------------------------
        # Gradient of the (imagesf, labelsf) loss w.r.t. params.
        featsf = fhnet(imagesf, params=hparams)
        loss_f = criterion(fcnet(featsf, params=params), labelsf)
        grads = torch.autograd.grad(loss_f, params)
        # Gradient of the (images, labels) loss w.r.t. params.
        feats = fhnet(images, params=hparams)
        loss_g = criterion(fcnet(feats, params=params), labels)
        grads_outer_x = torch.autograd.grad(loss_g, params)

        for j, (grad_f, grad_g) in enumerate(zip(grads, grads_outer_x)):
            # Same detach-to-leaf treatment as for zparams above.
            params[j] = (params[j] - alpha * (grad_f + lamb * grad_g)).detach().requires_grad_(True)
    return zparams, params


# def update_omega(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta, lamb):
#     grad_R = cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta)
#     for i in range(len(params)):
#         v[i] = v[i] - lamb * grad_R[i]
#     return v


# def cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v, delta):
#     grad_R = []
#     featsg = fhnet(imagesg, params=hparams)
#     htilde = H(featsg, labelsg, params, delta, v)
#
#     featsf = fhnet(imagesf, params=hparams)
#     outputs = fcnet(featsf, params=params)
#     loss = criterion(outputs, labelsf)
#     grads_outer_y = torch.autograd.grad(loss, params)
#     for i in range(len(params)):
#         grad_R.append(htilde[i] - grads_outer_y[i])
#     return grad_R


# def H(feats, labels, params, delta, v):
#     H = []
#     # plus
#     params_plus = []
#     params_minus = []
#     for i in range(len(params)):
#         params_plus.append(params[i] + delta * v[i])
#         params_minus.append(params[i] - delta * v[i])
#     output_plus = fcnet(feats, params=params_plus)
#     loss_plus = criterion(output_plus, labels)
#     d_inner_d_w_plus = torch.autograd.grad(loss_plus, params_plus)
#     # minus
#     output_minus = fcnet(feats, params=params_minus)
#     loss_minus = criterion(output_minus, labels)
#     d_inner_d_w_minus = torch.autograd.grad(loss_minus, params_minus)
#     for i in range(len(params)):
#         H.append((d_inner_d_w_plus[i] - d_inner_d_w_minus[i]) / (2 * delta))
#     return H


def update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha, lamb=None):
    """Apply one manual update to the feature-network weights ``hparams``.

    The update direction comes from :func:`cal_grad_f`, which already
    includes the ``zeta * alpha`` step size, so it is subtracted as is.

    Args:
        imagesg, labelsg: batch used for the two ``lamb``-weighted terms.
        imagesf, labelsf: batch used for the unweighted term.
        hparams: feature-network weights; list entries are replaced in place.
        params: classifier weights.
        zparams: second set of classifier weights.
        zeta, alpha: factors whose product scales the update.
        lamb: weight of the ``(imagesg, labelsg)`` difference term. ``None``
            (the default) keeps the historical behaviour of reading the
            module-level ``lamb``.

    Returns:
        ``hparams`` (the same list object, holding the updated tensors).
    """
    grad_f = cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha, lamb=lamb)
    for i, step in enumerate(grad_f):
        # detach(): make the new weight a fresh leaf so the autograd history
        # of all previous updates is not carried along from step to step.
        hparams[i] = (hparams[i] - step).detach().requires_grad_(True)

    return hparams


def cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha, lamb=None):
    """Build the scaled update direction for ``hparams``.

    Three gradients with respect to ``hparams`` are computed, each with its
    own forward pass through the module-level ``fhnet`` / ``fcnet`` /
    ``criterion``:

    * ``grads_outer_x``: loss on ``(imagesf, labelsf)`` with ``params``;
    * ``hgxy``: loss on ``(imagesg, labelsg)`` with ``params``;
    * ``hgxz``: loss on ``(imagesg, labelsg)`` with ``zparams``.

    Args:
        imagesg, labelsg, imagesf, labelsf: the two batches described above.
        hparams: feature-network weights to differentiate with respect to.
        params: classifier weights.
        zparams: second set of classifier weights.
        zeta, alpha: factors whose product scales the result.
        lamb: weight of ``hgxy - hgxz``. ``None`` (the default) falls back to
            the module-level ``lamb``, which is what this function always
            read before the argument existed.

    Returns:
        A list with one tensor per entry of ``hparams``:
        ``zeta * alpha * (grads_outer_x + lamb * (hgxy - hgxz))``.

    Raises:
        NameError: if ``lamb`` is not given and no module-level ``lamb``
            exists.
    """
    if lamb is None:
        # Backward-compatible fallback to the global the script maintains.
        try:
            lamb = globals()["lamb"]
        except KeyError:
            raise NameError("name 'lamb' is not defined") from None

    def _grads_wrt_hparams(images, labels, head_params):
        # One forward pass + gradient of the batch loss w.r.t. hparams.
        feats = fhnet(images, params=hparams)
        loss = criterion(fcnet(feats, params=head_params), labels)
        return torch.autograd.grad(loss, hparams)

    # Evaluation order matches the original: f-batch, g-batch/params,
    # g-batch/zparams.
    grads_outer_x = _grads_wrt_hparams(imagesf, labelsf, params)
    hgxy = _grads_wrt_hparams(imagesg, labelsg, params)
    hgxz = _grads_wrt_hparams(imagesg, labelsg, zparams)

    return [
        zeta * alpha * (g_outer + lamb * (g_y - g_z))
        for g_outer, g_y, g_z in zip(grads_outer_x, hgxy, hgxz)
    ]


# def J(feats, labels, hparams, params, delta, v):
#     J = []
#     # plus
#     params_plus = []
#     params_minus = []
#     for i in range(len(params)):
#         params_plus.append(params[i] + delta * v[i])
#         params_minus.append(params[i] - delta * v[i])
#     output_plus = fcnet(feats, params=params_plus)
#     loss_plus = criterion(output_plus, labels)
#     d_inner_d_x_plus = torch.autograd.grad(loss_plus, hparams, retain_graph=True)
#     # minus
#     output_minus = fcnet(feats, params=params_minus)
#     loss_minus = criterion(output_minus, labels)
#     d_inner_d_x_minus = torch.autograd.grad(loss_minus, hparams)
#     for i in range(len(hparams)):
#         J.append((d_inner_d_x_plus[i] - d_inner_d_x_minus[i]) / (2 * delta))
#     return J


if __name__ == "__main__":
    # Script entry point: builds the data loaders and networks for the chosen
    # dataset, then repeats inner_solver / update_x steps until about 60
    # seconds of measured compute time have elapsed, and finally saves the
    # recorded runtime / accuracy / loss curves with torch.save.
    #
    # NOTE: fhnet, fcnet, criterion and lamb assigned below are module-level
    # names; the functions defined above read them as globals.
    m = 0  # random seed; also embedded in the output filename
    torch.manual_seed(m)
    parser = argparse.ArgumentParser(description='Bilevel Training')

    # Basic model parameters.
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'cifar10'])
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--output_dir', type=str, default='MNISTsave')
    args = parser.parse_args()

    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)

    if args.dataset == 'MNIST':
        # NOTE(review): `data_test` is created without `train=` (torchvision's
        # MNIST default is the training split) while `data_train` passes
        # train=False, so the variable names look swapped relative to the
        # splits they hold — confirm this is intentional.
        data_test = MNIST(args.data,
                          download=True,
                          transform=transforms.Compose([
                              transforms.Resize((32, 32)),
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                          ]))
        data_train = MNIST(args.data,
                           train=False,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Resize((32, 32)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ]))

        data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=0)
        data_test_loader = DataLoader(data_test, batch_size=256, shuffle=True, num_workers=0)

        # Feature extractor (its weights become `hparams`) and classifier
        # head (its weights become `params`).
        hypernet = LeNet5Feats().cuda()
        cnet = classifier(n_features=84, n_classes=10).cuda()

    if args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # NOTE(review): same naming pattern as the MNIST branch — `data_test`
        # has no `train=` argument and uses transform_train, `data_train`
        # uses train=False with transform_test. Confirm this is intentional.
        data_test = CIFAR10(args.data,
                            transform=transform_train,
                            download=True)
        data_train = CIFAR10(args.data,
                             train=False,
                             transform=transform_test,
                             download=True)

        data_train_loader = DataLoader(data_train, batch_size=128, shuffle=True, num_workers=4)
        data_test_loader = DataLoader(data_test, batch_size=128, shuffle=True, num_workers=4)

        hypernet = ResNetFeats18().cuda()
        cnet = classifier(n_features=512, n_classes=10).cuda()
        # NOTE(review): `lr` is only assigned in this branch and is not read
        # anywhere in the visible code.
        lr = 0.05

    # Batches from data_test_loader are the "f" batches and batches from
    # data_train_loader are the "g" batches in the loop below.
    numtest = len(data_test_loader)
    print('num of outer batches = ', numtest)
    numtrain = len(data_train_loader)
    print('num of inner batches = ', numtrain)

    data_train_iter = iter(data_train_loader)
    data_test_iter = iter(data_test_loader)

    # Functional (monkeypatched) versions of both networks; the functions
    # above call them with explicit `params=` weight lists.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).cuda()
    hparams = list(hypernet.parameters())
    hparams = [hparam.requires_grad_(True) for hparam in hparams]

    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).cuda()
    params = list(cnet.parameters())
    params = [param.requires_grad_(True) for param in params]
    # requires_grad_ returns the tensor itself, so `zparams` is a separate
    # list that initially holds the very same tensor objects as `params`.
    zparams = [param.requires_grad_(True) for param in params]

    # NOTE(review): `v` is filled here but not used by any active code in
    # this file (only by the commented-out helpers).
    v = []

    for i in range(len(params)):
        v.append(torch.rand_like(params[i]))
        v[i].requires_grad = True

    numhparams = sum([torch.numel(hparam) for hparam in hparams])
    numparams = sum([torch.numel(param) for param in params])

    print('size of outer variable: ', numhparams)
    print('size of inner variable: ', numparams)

    # NOTE(review): `init_params` is built but not read afterwards.
    init_params = []
    for param in params:
        init_params.append(torch.zeros_like(param))

    criterion = torch.nn.CrossEntropyLoss().cuda()
    # NOTE(review): neither optimizer's step() is called in this block; the
    # weight updates are applied manually by inner_solver and update_x.
    outer_opt = torch.optim.Adam(hparams, lr=0.01)
    inner_opt = torch.optim.Adam(params)
    alpha = 0.003  # step size for params; also a factor of the hparams step
    gamma = 0.005  # step size for zparams
    zeta = 0.3     # extra factor of the hparams step (zeta * alpha)
    # delta = 0.1
    lamb = 0.08    # penalty weight; increased by 0.0001 after every step
    # beta = 0.5
    T = 1          # number of inner_solver steps per outer iteration
    # NOTE(review): `steps` only appears in the progress message; the loop
    # below is bounded by elapsed time, not by this count.
    steps = 1000
    # NOTE(review): `warm_start` and `hessian` are assigned but never read.
    warm_start = True
    total_time = 0
    running_time, outer_accs, outer_losses = [], [], []
    hessian = 0
    step = 0
    print('Bilevel training with TriBO-HF on ' + args.dataset)
    print('Number of inner iterations T=', T)

    # for step in range(1, steps + 1):
    # Run until the accumulated compute time exceeds 60 seconds.
    while(total_time <= 60):
        step += 1
        # Fetch the next "f" batch, restarting the iterator when exhausted.
        try:
            imagesf, labelsf = next(data_test_iter)
        except StopIteration:
            data_test_iter = iter(data_test_loader)
            imagesf, labelsf = next(data_test_iter)

        imagesf, labelsf = imagesf.cuda(), labelsf.cuda()

        # Fetch the next "g" batch, restarting the iterator when exhausted.
        try:
            imagesg, labelsg = next(data_train_iter)
        except StopIteration:
            data_train_iter = iter(data_train_loader)
            imagesg, labelsg = next(data_train_iter)

        imagesg, labelsg = imagesg.cuda(), labelsg.cuda()

        # Timing starts after data loading and host-to-GPU transfer.
        t0 = time.time()
        # inner
        inner_opt.zero_grad()
        zparams, params = inner_solver(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, T, alpha, gamma, lamb)

        # outer_opt.zero_grad()
        hparams = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha)
        lamb += 0.0001
        # test
        # Loss/accuracy are measured on the same "f" batch used for the
        # update, and this evaluation is included in the timed section.
        oloss, oacc = outer_loss(imagesf, labelsf, params, hparams, more=True)

        t1 = time.time() - t0
        total_time += t1
        running_time.append(total_time)
        outer_accs.append(oacc * 100)
        outer_losses.append(oloss.item())

        if step % 10 == 0:
            # The value printed as "lr" is alpha.
            print('Step: %d/%d, lr: %f, Outer Batch Loss: %f, Accuracy on outer batch: %f' % (
                step, steps, alpha, oloss.item(), oacc))

    print('Ended in {:.2e} seconds\n'.format(total_time))

    # NOTE(review): the filename uses the prefix 'F2SA' (the banner above
    # says TriBO-HF), stores the final (incremented) lamb, and labels zeta
    # as "beta" — confirm these labels are what downstream tooling expects.
    filename = 'F2SA' + args.dataset + "alpha" + str(alpha) + "lambda" + str(lamb) \
               + "beta" + str(zeta) + '_m' + str(m) + '.pt'
    save_path = os.path.join(args.output_dir, filename)

    # Per-step curves: cumulative runtime, accuracy in percent, loss value.
    state_dict = {'runtime': running_time,
                  'accuracy': outer_accs,
                  'loss': outer_losses}
    torch.save(state_dict, save_path)
    # save_checkpoint(state_dict, save_path)
