import os
from networks import LeNet5Feats, ResNetFeats18, classifier
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import time
import numpy as np
import matplotlib.pyplot as plt
import pickle


def parse_args():
    """Parse the command-line options of the LazyHessian bilevel experiment.

    Post-processing performed after ``parse_args``:

    * ``save_folder`` defaults to ``./save_results`` when left empty.
    * ``model_name`` is derived from the hyper-parameters unless a non-empty
      ``--model_name`` was given explicitly.
    * ``save_folder`` is replaced by ``save_folder/model_name`` and that
      directory is created if it does not exist yet.

    Returns:
        argparse.Namespace: the parsed and post-processed arguments.
    """
    parser = argparse.ArgumentParser(description='Bilevel Training')
    # Basic model parameters.
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['MNIST', 'cifar10'])
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--epochs', type=int, default=512, help='outer update')
    parser.add_argument('--iterations', type=int, default=1, help='y-update iterations, must be 1 in LazyHessian')
    parser.add_argument('--N', type=int, default=4, help='number of inner iterations for x- and y-update')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--alpha', type=float, default=0.01, help='learning rate of x-update')
    parser.add_argument('--beta', type=float, default=0.5, help='learning rate of y-update')
    parser.add_argument('--gamma', type=float, default=0.1, help='learning rate of z-update')

    parser.add_argument('--seed', type=int, default=1)

    parser.add_argument('--save_folder', type=str, default='', help='path to save result')
    parser.add_argument('--model_name', type=str, default='', help='Experiment name')

    args = parser.parse_args()

    if not args.save_folder:
        args.save_folder = './save_results'

    # Only derive the experiment name when none was supplied; an explicit
    # --model_name must not be silently discarded.
    if not args.model_name:
        args.model_name = 'LazyHessian_{}_bs_{}_xlr_{}_ylr_{}_zlr_{}_xupd_{}_yupd_{}_zupd_1'.format(
            args.dataset, args.batch_size, args.alpha, args.beta, args.gamma, args.epochs, args.N)
    args.save_folder = os.path.join(args.save_folder, args.model_name)

    os.makedirs(args.save_folder, exist_ok=True)

    return args


def get_data_loaders_network(args):
    """Build the data loaders and the two networks for ``args.dataset``.

    Reads ``args.dataset``, ``args.data`` (dataset root) and ``args.batch_size``.

    In the ``__main__`` script the returned ``train_loader`` supplies the inner
    (g) batches and ``test_loader`` the outer (f) batches. The parameters of
    ``hyper_net`` are the outer variables (x), and those of ``c_net`` are the
    inner variables (y).

    Returns:
        tuple: ``(train_loader, test_loader, hyper_net, c_net)``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'MNIST'`` nor ``'cifar10'``.
    """
    if args.dataset == 'MNIST':
        # Both splits share the same preprocessing: resize to 32x32, tensor, normalize.
        mnist_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        data_test = MNIST(args.data,
                          train=False,
                          download=True,
                          transform=mnist_transform)
        data_train = MNIST(args.data,
                           train=True,
                           download=True,
                           transform=mnist_transform)

        train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=0)
        test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=True, num_workers=0)

        hyper_net = LeNet5Feats()
        c_net = classifier(n_features=84, n_classes=10)

    elif args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # NOTE(review): unlike the MNIST branch, the CIFAR-10 training split
        # (no ``train=`` given, i.e. torchvision's default) with augmentation
        # is bound to ``data_test`` (outer problem), and the ``train=False``
        # split to ``data_train`` (inner problem). This is kept as-is to
        # preserve experiment behaviour; confirm that it is intended.
        data_test = CIFAR10(args.data,
                            transform=transform_train,
                            download=True)
        data_train = CIFAR10(args.data,
                             train=False,
                             transform=transform_test,
                             download=True)

        train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=True, num_workers=4)

        hyper_net = ResNetFeats18()
        c_net = classifier(n_features=512, n_classes=10)

    else:
        # Previously this fell through to an UnboundLocalError at the return.
        raise ValueError("Unsupported dataset '{}'; expected 'MNIST' or 'cifar10'".format(args.dataset))

    return train_loader, test_loader, hyper_net, c_net



def outer_loss(images, labels, params, hparams, more=False):
    """Score the composed model ``fcnet(fhnet(images))`` on one batch.

    Relies on the module-level globals ``fhnet`` (feature network, run with
    ``hparams``), ``fcnet`` (classifier, run with ``params``) and ``criterion``.

    Args:
        images: input batch fed to ``fhnet``.
        labels: target class indices for ``criterion``.
        params: parameter list for the classifier ``fcnet``.
        hparams: parameter list for the feature network ``fhnet``.
        more: when True, also return the batch accuracy.

    Returns:
        ``loss``, or ``(loss, acc)`` when ``more`` is True. ``acc`` is the
        fraction of samples whose arg-max logit equals the label.
    """
    features = fhnet(images, params=hparams)
    logits = fcnet(features, params=params)
    loss = criterion(logits, labels)

    # Accuracy is computed on detached tensors so it never joins the graph.
    predicted = logits.detach().max(dim=1).indices
    n_correct = (predicted == labels.detach().view_as(predicted)).sum()
    accuracy = float(n_correct) / labels.size(0)

    return (loss, accuracy) if more else loss


def update_tensor_grads(params, grads):
    """Accumulate ``grads`` into the ``.grad`` field of each tensor in ``params``.

    A missing ``.grad`` is first initialised to zeros of the tensor's shape,
    even when the matching gradient is ``None``. ``None`` gradients are then
    skipped. Accumulation is done in place on ``.grad``.
    """
    for tensor, increment in zip(params, grads):
        if tensor.grad is None:
            tensor.grad = torch.zeros_like(tensor)
        if increment is None:
            continue
        tensor.grad += increment


def inner_solver(images, labels, hparams, params, steps, beta):
    """Take ``steps`` plain gradient-descent steps on the inner variables.

    Each step evaluates ``criterion(fcnet(fhnet(images, hparams), params), labels)``
    (module-level globals). It then replaces every entry of ``params`` by
    ``p - beta * dL/dp``.

    The list ``params`` is modified in place and also returned.
    """
    for _ in range(steps):
        logits = fcnet(fhnet(images, params=hparams), params=params)
        inner_loss = criterion(logits, labels)
        gradients = torch.autograd.grad(inner_loss, params)
        # Slice assignment keeps the caller's list object, like per-index assignment.
        params[:] = [p - beta * g for p, g in zip(params, gradients)]
    return params


def update_v(imagesg, labelsg, imagesf, labelsf, hparams, params, v, gamma):
    """Perform one gradient step of size ``gamma`` on the auxiliary variable ``v``.

    The direction is ``cal_grad_R(...)``, i.e. the inner-loss Hessian-vector
    product with ``v`` plus the outer-loss gradient w.r.t. ``params``.
    The list ``v`` is modified in place and returned.
    """
    direction = cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v)
    for idx, d in enumerate(direction):
        v[idx] = v[idx] - gamma * d
    return v


def cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v):
    """Return the update direction for ``v``, one tensor per entry of ``params``.

    The direction is ``H_yy(g) @ v + grad_y f``, where:

    * ``H_yy(g) @ v`` is the Hessian-vector product of the inner loss on
      ``(imagesg, labelsg)`` w.r.t. ``params`` (see ``Hessian_vector_product``).
    * ``grad_y f`` is the gradient of the outer loss on ``(imagesf, labelsf)``
      w.r.t. ``params``.

    Uses the module-level globals ``fhnet``, ``fcnet`` and ``criterion``.
    """
    # Inner-batch features first, then outer-batch features (same order as before).
    hvp = Hessian_vector_product(fhnet(imagesg, params=hparams), labelsg, params, v)

    outer_logits = fcnet(fhnet(imagesf, params=hparams), params=params)
    outer_grads = torch.autograd.grad(criterion(outer_logits, labelsf), params)

    return [h + g for h, g in zip(hvp, outer_grads)]


def Hessian_vector_product(feats, labels, params, v):
    """Return ``H @ v`` for the Hessian ``H`` of the inner loss w.r.t. ``params``.

    The inner loss is ``criterion(fcnet(feats, params), labels)``, using the
    module-level globals. Its gradient is built with ``create_graph=True`` so
    that it can be differentiated again w.r.t. ``params``, weighted by
    ``grad_outputs=v``. Returns a tuple aligned with ``params``.
    """
    inner_loss = criterion(fcnet(feats, params=params), labels)
    first_order = torch.autograd.grad(inner_loss, params, create_graph=True)
    return torch.autograd.grad(first_order, params, grad_outputs=v, retain_graph=True)


def update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, v, alpha):
    """Perform one gradient step of size ``alpha`` on the outer variables ``hparams``.

    The direction is ``cal_grad_f(...)``, i.e. the outer-loss gradient w.r.t.
    ``hparams`` plus the mixed second-derivative product with ``v``.
    The list ``hparams`` is modified in place and returned.
    """
    direction = cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, v)
    for idx, d in enumerate(direction):
        hparams[idx] = hparams[idx] - alpha * d
    return hparams


def cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, v):
    """Return the hypergradient estimate for ``hparams``.

    The estimate is ``grad_x f + (d/dx grad_y g) @ v``, one tensor per entry of
    ``hparams``, where:

    * ``grad_x f`` is the gradient of the outer loss on ``(imagesf, labelsf)``
      w.r.t. ``hparams``.
    * The second term is the mixed second-derivative product of the inner loss
      on ``(imagesg, labelsg)`` with ``v`` (see ``Jacobian_product``).

    Uses the module-level globals ``fhnet``, ``fcnet`` and ``criterion``.
    """
    # Outer-batch gradient first, then inner-batch Jacobian product (same order as before).
    outer_logits = fcnet(fhnet(imagesf, params=hparams), params=params)
    outer_grads = torch.autograd.grad(criterion(outer_logits, labelsf), hparams)

    mixed = Jacobian_product(fhnet(imagesg, params=hparams), labelsg, hparams, params, v)

    return [g + j for g, j in zip(outer_grads, mixed)]


def Jacobian_product(feats, labels, hparams, params, v):
    """Return ``d/d(hparams) [grad_params(inner loss) . v]``.

    The inner loss is ``criterion(fcnet(feats, params), labels)``, using the
    module-level globals. ``feats`` must have been produced from ``hparams``
    so that the gradient w.r.t. ``params`` (built with ``create_graph=True``)
    depends on them. Returns a tuple aligned with ``hparams``.
    """
    inner_loss = criterion(fcnet(feats, params=params), labels)
    grad_wrt_params = torch.autograd.grad(inner_loss, params, create_graph=True)
    return torch.autograd.grad(grad_wrt_params, hparams, grad_outputs=v, retain_graph=True)


if __name__ == "__main__":
    # Script entry point. ``fhnet``, ``fcnet`` and ``criterion`` are created at
    # module scope on purpose: the helper functions above read them as globals.

    args = parse_args()
    print(args)

    torch.manual_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_train_loader, data_test_loader, hypernet, cnet = get_data_loaders_network(args)

    hypernet = hypernet.to(device)
    cnet = cnet.to(device)

    # Inner (g) batches are drawn from the "train" loader, outer (f) batches from the "test" loader.
    data_train_iter = iter(data_train_loader)
    data_test_iter = iter(data_test_loader)

    # Functional versions of both networks, so that the outer variables x (hparams)
    # and inner variables y (params) can be passed explicitly as tensor lists.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).to(device)
    hparams = [hparam.requires_grad_(True) for hparam in hypernet.parameters()]

    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).to(device)
    params = [param.requires_grad_(True) for param in cnet.parameters()]

    # Auxiliary variable v (z-update), randomly initialised, one tensor per inner
    # parameter. It is only used as ``grad_outputs`` and updated arithmetically,
    # so it does not need to track gradients (that would only grow an autograd chain).
    v = [torch.rand_like(param) for param in params]

    criterion = torch.nn.CrossEntropyLoss().to(device)

    running_time, test_accs, test_losses, train_accs, train_losses = [], [], [], [], []

    # Each outer "epoch" runs N (y, x) updates followed by a single v update.
    z_epoch = args.epochs // args.N
    total_steps = z_epoch * args.N
    # One row per (y, x) update: train loss, train acc, test loss, test acc, elapsed time.
    # Sized to the steps actually executed so the saved array has no trailing all-zero rows.
    loss_time_results = np.zeros((total_steps, 5))

    start_time = time.time()

    for epoch in range(z_epoch):

        for n in range(args.N):
            step = epoch * args.N + n

            try:
                imagesf, labelsf = next(data_test_iter)
            except StopIteration:
                data_test_iter = iter(data_test_loader)
                imagesf, labelsf = next(data_test_iter)
            imagesf, labelsf = imagesf.to(device), labelsf.to(device)

            try:
                imagesg, labelsg = next(data_train_iter)
            except StopIteration:
                data_train_iter = iter(data_train_loader)
                imagesg, labelsg = next(data_train_iter)
            imagesg, labelsg = imagesg.to(device), labelsg.to(device)

            # update y (inner variables): args.iterations gradient steps
            params = inner_solver(imagesg, labelsg, hparams, params, args.iterations, args.beta)

            # update x (outer variables)
            hparams = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, v, args.alpha)

            # Evaluate: metrics only, so no autograd graph is needed.
            with torch.no_grad():
                train_loss, train_acc = outer_loss(imagesg, labelsg, params, hparams, more=True)
                test_loss, test_acc = outer_loss(imagesf, labelsf, params, hparams, more=True)

            total_time = time.time() - start_time
            running_time.append(total_time)
            train_accs.append(train_acc * 100)
            train_losses.append(train_loss.item())
            test_accs.append(test_acc * 100)
            test_losses.append(test_loss.item())

            loss_time_results[step] = [train_loss.item(), train_acc, test_loss.item(), test_acc, total_time]

            if step % 5 == 0:
                print(
                    'Epoch: %d/%d, Train Loss: %f, Test Loss: %f, Train Accuracy: %f, Test Accuracy: %f, Running Time: %f' % (
                        step, args.epochs, train_loss.item(), test_loss.item(), train_acc, test_acc, total_time))

        # update v: once per N (y, x) updates, using the last inner/outer batch pair
        v = update_v(imagesg, labelsg, imagesf, labelsf, hparams, params, v, args.gamma)

    save_path = os.path.join(args.save_folder, str(args.seed) + '.pt')

    state_dict = {'runtime': running_time,
                  'train accuracy': train_accs,
                  'train loss': train_losses,
                  'test accuracy': test_accs,
                  'test loss': test_losses}
    torch.save(state_dict, save_path)

    file_addr = os.path.join(args.save_folder, str(args.seed) + '.npy')
    with open(file_addr, 'wb') as f:
        np.save(f, loss_time_results)
