import os
from networks import LeNet5Feats, ResNetFeats18, classifier
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import time
import numpy as np
import matplotlib.pyplot as plt
import pickle


def parse_args(argv=None):
    """Parse command-line options and create the per-experiment output folder.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The parsed namespace, normalized as follows:
          * for ``SOBA`` and ``MASOBA``, ``iterations`` and ``z_ite`` are forced to 1;
          * an empty ``save_folder`` becomes ``'./save_results'``;
          * an empty ``model_name`` is derived from the algorithm, dataset and
            hyperparameters (with a ``_mu_<mu>`` suffix for ``MASOBA``); a
            non-empty ``--model_name`` is kept as given;
          * ``save_folder`` is replaced by ``<save_folder>/<model_name>``, which
            is created if it does not exist.
    """
    parser = argparse.ArgumentParser(description='Bilevel Training')
    # Basic model parameters.
    parser.add_argument('--algorithm', type=str, default='AmIGO', choices=['AmIGO', 'SOBA', 'MASOBA'])
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['MNIST', 'cifar10'])
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--epochs', type=int, default=200, help='x update')
    parser.add_argument('--iterations', type=int, default=2, help='y update')
    parser.add_argument('--z_ite', type=int, default=2, help='z update')
    parser.add_argument('--batch_size', type=int, default=128)     #  128 for cifar-10
    parser.add_argument('--alpha', type=float, default=0.01, help='alpha, outer learning rate')
    parser.add_argument('--beta', type=float, default=0.5, help='beta, inner learning rate')
    parser.add_argument('--gamma', type=float, default=0.1, help='gamma, z update stepsize')
    parser.add_argument('--mu', type=float, default=0.5, help='momentum coefficient')

    parser.add_argument('--seed', type=int, default=1)

    parser.add_argument('--save_folder', type=str, default='', help='path to save result')
    parser.add_argument('--model_name', type=str, default='', help='Experiment name')

    args = parser.parse_args(argv)

    # These algorithms take exactly one y step and one z step per x update.
    if args.algorithm in ('SOBA', 'MASOBA'):
        args.iterations = 1
        args.z_ite = 1

    if not args.save_folder:
        args.save_folder = './save_results'

    # Respect an explicit --model_name; otherwise derive one from the settings.
    if not args.model_name:
        args.model_name = '{}_{}_bs_{}_xlr_{}_ylr_{}_zlr_{}_xupd_{}_yupd_{}_zupd_{}'.format(
            args.algorithm, args.dataset, args.batch_size,
            args.alpha, args.beta, args.gamma,
            args.epochs, args.iterations, args.z_ite)
        if args.algorithm == 'MASOBA':
            args.model_name += '_mu_{}'.format(args.mu)

    args.save_folder = os.path.join(args.save_folder, args.model_name)
    os.makedirs(args.save_folder, exist_ok=True)

    return args


def get_data_loaders_network(args):
    """Create the data loaders and the two networks for ``args.dataset``.

    Args:
        args: parsed options; reads ``dataset``, ``data`` (dataset root, downloaded
            if missing) and ``batch_size``.

    Returns:
        ``(train_loader, test_loader, hyper_net, c_net)``: shuffled loaders over the
        training split and the test split, the feature network (``LeNet5Feats``
        for MNIST, ``ResNetFeats18`` for CIFAR-10) and a 10-class ``classifier``
        head whose input width matches that feature network (84 / 512).

    Raises:
        ValueError: if ``args.dataset`` is neither ``'MNIST'`` nor ``'cifar10'``.
    """
    if args.dataset == 'MNIST':
        mnist_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        data_test = MNIST(args.data, train=False, download=True, transform=mnist_transform)
        data_train = MNIST(args.data, train=True, download=True, transform=mnist_transform)
        num_workers = 0
        feature_net_cls, n_features = LeNet5Feats, 84
    elif args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        # The training split (with augmentation) feeds train_loader and the test
        # split (deterministic preprocessing) feeds test_loader -- the same
        # split-to-loader mapping as the MNIST branch.
        data_train = CIFAR10(args.data, train=True, transform=transform_train, download=True)
        data_test = CIFAR10(args.data, train=False, transform=transform_test, download=True)
        num_workers = 4
        feature_net_cls, n_features = ResNetFeats18, 512
    else:
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)

    # Networks are constructed after the datasets/loaders, in the same order for both datasets.
    hyper_net = feature_net_cls()
    c_net = classifier(n_features=n_features, n_classes=10)

    return train_loader, test_loader, hyper_net, c_net



def outer_loss(images, labels, params, hparams, more=False):
    """Cross-entropy of the classifier head applied to hyper-net features.

    Features are produced by the module-global ``fhnet`` with weights ``hparams``
    and logits by the module-global ``fcnet`` with weights ``params``; the loss is
    the module-global ``criterion`` (all three are assigned in ``__main__``).

    Returns:
        The loss tensor, or ``(loss, accuracy)`` when ``more`` is truthy, where
        accuracy is the fraction of rows whose argmax logit equals the label.
    """
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)
    if not more:
        return loss

    predicted = logits.data.max(1)[1]
    n_correct = predicted.eq(labels.data.view_as(predicted)).sum()
    return loss, float(n_correct) / labels.size(0)


def update_tensor_grads(params, grads):
    """Accumulate each tensor in ``grads`` into the matching tensor's ``.grad``.

    Tensors are paired positionally (stopping at the shorter sequence). A tensor
    whose ``.grad`` is ``None`` first receives a zero gradient of its own shape;
    a ``None`` entry in ``grads`` leaves the corresponding ``.grad`` as it is.
    """
    for tensor, grad in zip(params, grads):
        if tensor.grad is None:
            tensor.grad = torch.zeros_like(tensor)
        if grad is None:
            continue
        tensor.grad += grad


def inner_solver(images, labels, hparams, params, steps, beta):
    """Take ``steps`` plain gradient-descent steps on the classifier weights.

    Each step evaluates ``criterion(fcnet(fhnet(images, hparams), params), labels)``
    (module-global ``fhnet``/``fcnet``/``criterion``), differentiates it w.r.t.
    ``params`` only, and subtracts ``beta`` times that gradient in place.
    ``hparams`` is not modified. The same ``params`` list is returned.
    """
    for _ in range(steps):
        logits = fcnet(fhnet(images, params=hparams), params=params)
        step_grads = torch.autograd.grad(criterion(logits, labels), params, create_graph=False)

        with torch.no_grad():
            for weight, grad in zip(params, step_grads):
                weight.data.sub_(beta * grad)

    return params


def update_v(imagesg, labelsg, imagesf, labelsf, hparams, params, v, gamma, z_ite):
    """Run ``z_ite`` gradient steps of size ``gamma`` on the auxiliary vector ``v``.

    The step direction is ``cal_grad_R(...)``, the inner-loss Hessian-vector
    product with ``v`` plus the outer-loss gradient w.r.t. ``params``. The tensors
    in ``v`` are updated in place, and the same list is returned.
    """
    for _ in range(z_ite):
        direction = cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v)

        with torch.no_grad():
            for v_part, d_part in zip(v, direction):
                v_part.data.sub_(gamma * d_part)

    return v


def cal_grad_R(imagesg, labelsg, imagesf, labelsf, hparams, params, v):
    """Return the per-tensor step direction used by ``update_v``.

    For each classifier weight tensor, this is the Hessian-vector product of the
    inner loss on the ``g`` batch (w.r.t. ``params``, applied to ``v``) plus the
    gradient of the outer loss on the ``f`` batch w.r.t. ``params``.
    """
    hvp = Hessian_vector_product(fhnet(imagesg, params=hparams), labelsg, params, v)

    outer_logits = fcnet(fhnet(imagesf, params=hparams), params=params)
    outer_grad_y = torch.autograd.grad(criterion(outer_logits, labelsf), params, create_graph=False)

    return [h_part + g_part for h_part, g_part in zip(hvp, outer_grad_y)]


def Hessian_vector_product(feats, labels, params, v):
    """Return the Hessian of the classifier loss w.r.t. ``params`` applied to ``v``.

    Uses double backpropagation: the first-order gradient is built with
    ``create_graph=True``, then differentiated again against ``v``. ``feats`` are
    used as given. The returned tensors are detached.
    """
    inner_loss = criterion(fcnet(feats, params=params), labels)
    first_order = torch.autograd.grad(inner_loss, params, create_graph=True)
    second_order = torch.autograd.grad(first_order, params, grad_outputs=v, retain_graph=False)
    return [term.detach() for term in second_order]


def update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, v, alpha):
    """Take one gradient step of size ``alpha`` on the hyper-net weights.

    The direction is ``cal_grad_f(...)``. The tensors in ``hparams`` are updated in
    place.

    Returns:
        ``(hparams, direction)`` — the same list that was passed in, and the
        direction that was applied.
    """
    direction = cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, v)

    with torch.no_grad():
        for weight, step in zip(hparams, direction):
            weight.data.sub_(alpha * step)

    return hparams, direction


def update_x_ma(imagesg, labelsg, imagesf, labelsf, hparams, params, v, alpha, grad_f_old, mu):
    """Hyper-net step along a moving average of ``cal_grad_f`` directions.

    The applied direction is ``mu * fresh + (1 - mu) * grad_f_old`` per tensor,
    where ``fresh`` comes from ``cal_grad_f``. The tensors in ``hparams`` are
    updated in place.

    Returns:
        ``(hparams, mixed)`` — the list that was passed in and the mixed
        directions. The caller feeds ``mixed`` back as ``grad_f_old`` on the
        next call.
    """
    mixed = cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, v)

    with torch.no_grad():
        for idx, (weight, previous) in enumerate(zip(hparams, grad_f_old)):
            blended = mu * mixed[idx] + (1 - mu) * previous
            weight.data.sub_(alpha * blended)
            mixed[idx] = blended  # overwrite with the direction actually applied

    return hparams, mixed


def cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, v):
    """Return the hyper-net update direction used by ``update_x``/``update_x_ma``.

    For each hyper-net tensor, this is the gradient of the outer loss on the
    ``f`` batch w.r.t. ``hparams``, plus the mixed second-derivative term from
    ``Jacobian_product`` evaluated on the ``g`` batch with vector ``v``.
    """
    outer_logits = fcnet(fhnet(imagesf, params=hparams), params=params)
    outer_grad_x = torch.autograd.grad(criterion(outer_logits, labelsf), hparams, create_graph=False)

    cross_term = Jacobian_product(fhnet(imagesg, params=hparams), labelsg, hparams, params, v)

    return [gx_part + j_part for gx_part, j_part in zip(outer_grad_x, cross_term)]


def Jacobian_product(feats, labels, hparams, params, v):
    """Return the mixed second derivative of the classifier loss applied to ``v``.

    Differentiates ``<grad_params loss, v>`` w.r.t. ``hparams``. For that to
    work, ``feats`` must have been computed from ``hparams`` with autograd
    tracking (``cal_grad_f`` passes ``fhnet(imagesg, params=hparams)``). The
    returned tensors are detached.
    """
    inner_loss = criterion(fcnet(feats, params=params), labels)
    grad_wrt_y = torch.autograd.grad(inner_loss, params, create_graph=True)
    cross = torch.autograd.grad(grad_wrt_y, hparams, grad_outputs=v, retain_graph=False)
    return [term.detach() for term in cross]


if __name__ == "__main__":

    args = parse_args()
    print(args)

    torch.manual_seed(args.seed)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    data_train_loader, data_test_loader, hypernet, cnet = get_data_loaders_network(args)

    hypernet = hypernet.to(device)
    cnet = cnet.to(device)

    # The order of the statements below (iterator creation, then random init of v)
    # is kept unchanged so seeded runs draw random numbers in the same sequence.
    data_train_iter = iter(data_train_loader)
    data_test_iter = iter(data_test_loader)

    # Functional views of the two networks. Every call in this file passes the
    # weights explicitly (params=hparams / params=params). fhnet, fcnet and
    # criterion are module globals read by the helper functions above.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).to(device)
    hparams = [hparam.requires_grad_(True) for hparam in hypernet.parameters()]

    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).to(device)
    params = [param.requires_grad_(True) for param in cnet.parameters()]

    # Auxiliary vector v: one tensor per classifier weight, initialized with torch.rand_like.
    v = [torch.rand_like(param).requires_grad_(True) for param in params]

    criterion = torch.nn.CrossEntropyLoss().to(device)

    running_time, test_accs, test_losses, train_accs, train_losses = [], [], [], [], []
    loss_time_results = np.zeros((args.epochs, 5))

    start_time = time.time()

    for epoch in range(args.epochs):
        # Outer-problem batch (imagesf/labelsf) from test_loader; restart the
        # iterator whenever it is exhausted.
        try:
            imagesf, labelsf = next(data_test_iter)
        except StopIteration:
            data_test_iter = iter(data_test_loader)
            imagesf, labelsf = next(data_test_iter)
        imagesf, labelsf = imagesf.to(device), labelsf.to(device)

        # Inner-problem batch (imagesg/labelsg) from train_loader, cycled the same way.
        try:
            imagesg, labelsg = next(data_train_iter)
        except StopIteration:
            data_train_iter = iter(data_train_loader)
            imagesg, labelsg = next(data_train_iter)
        imagesg, labelsg = imagesg.to(device), labelsg.to(device)

        # update y: args.iterations gradient steps on the inner loss
        params = inner_solver(imagesg, labelsg, hparams, params, args.iterations, args.beta)

        # update x: plain step for SOBA/AmIGO; MASOBA averages with the previous direction
        if args.algorithm in ('SOBA', 'AmIGO'):
            hparams, grad_f_old = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, v, args.alpha)
        elif args.algorithm == 'MASOBA':
            if epoch == 0:
                hparams, grad_f_old = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, v, args.alpha)
            else:
                hparams, grad_f_old = update_x_ma(imagesg, labelsg, imagesf, labelsf, hparams, params, v,
                                                  args.alpha, grad_f_old, args.mu)

        # update v
        v = update_v(imagesg, labelsg, imagesf, labelsf, hparams, params, v, args.gamma, args.z_ite)

        # Evaluate on the current two batches. The results are only logged, so no
        # autograd graph is built.
        with torch.no_grad():
            train_loss, train_acc = outer_loss(imagesg, labelsg, params, hparams, more=True)
            test_loss, test_acc = outer_loss(imagesf, labelsf, params, hparams, more=True)

        total_time = time.time() - start_time
        train_loss, test_loss = train_loss.item(), test_loss.item()

        running_time.append(total_time)
        train_accs.append(train_acc * 100)
        train_losses.append(train_loss)
        test_accs.append(test_acc * 100)
        test_losses.append(test_loss)

        # Columns: train loss, train acc (fraction), test loss, test acc (fraction), elapsed seconds.
        loss_time_results[epoch] = (train_loss, train_acc, test_loss, test_acc, total_time)

        if epoch % 5 == 0:
            print('Epoch: %d/%d, Train Loss: %f, Test Loss: %f, Train Accuracy: %f, Test Accuracy: %f, Running Time: %f' % (
                epoch, args.epochs, train_loss, test_loss, train_acc, test_acc, total_time))

    save_path = os.path.join(args.save_folder, str(args.seed) + '.pt')
    state_dict = {'runtime': running_time,
                  'train accuracy': train_accs,
                  'train loss': train_losses,
                  'test accuracy': test_accs,
                  'test loss': test_losses}
    torch.save(state_dict, save_path)

    file_addr = os.path.join(args.save_folder, str(args.seed) + '.npy')
    with open(file_addr, 'wb') as f:
        np.save(f, loss_time_results)
