import os
from networks import LeNet5Feats, ResNetFeats18, classifier
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import time
import numpy as np
import matplotlib.pyplot as plt
import pickle


def parse_args(argv=None):
    """Parse command-line options for the F2SA bilevel training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``; passing a list allows calling the
            function programmatically (tests, notebooks).

    Returns:
        argparse.Namespace with the parsed options. ``model_name`` is the
        user-supplied ``--model_name`` or, when that is empty, a name derived
        from the hyper-parameters. ``save_folder`` is rewritten to
        ``<save_folder>/<model_name>`` and that directory is created if it
        does not exist.
    """
    parser = argparse.ArgumentParser(description='Bilevel Training')
    # Basic model parameters.
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['MNIST', 'cifar10'])
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--epochs', type=int, default=200, help='outer update')
    parser.add_argument('--iterations', type=int, default=1, help='T, number of inner iterations')
    parser.add_argument('--batch_size', type=int, default=128)     # 128 for cifar-10
    parser.add_argument('--alpha', type=float, default=0.5, help='alpha')
    parser.add_argument('--gamma', type=float, default=0.5, help='gamma')
    parser.add_argument('--zeta', type=float, default=0.5, help='zeta')
    parser.add_argument('--lamb', type=float, default=0.5, help='lamb')

    parser.add_argument('--seed', type=int, default=1)

    parser.add_argument('--save_folder', type=str, default='', help='path to save result')
    parser.add_argument('--model_name', type=str, default='', help='Experiment name')

    args = parser.parse_args(argv)

    if not args.save_folder:
        args.save_folder = './save_results'

    # An explicit --model_name takes precedence; otherwise derive one from the
    # hyper-parameters (xupd = outer updates/epochs, yupd = inner iterations).
    if not args.model_name:
        args.model_name = 'F2SA_{}_bs_{}_alpha_{}_gamma_{}_zeta_{}_lamb_{}_xupd_{}_yupd_{}'.format(
            args.dataset,
            args.batch_size,
            args.alpha, args.gamma,
            args.zeta, args.lamb,
            args.epochs,
            args.iterations)
    args.save_folder = os.path.join(args.save_folder, args.model_name)

    # exist_ok avoids the race between an isdir() check and makedirs().
    os.makedirs(args.save_folder, exist_ok=True)

    return args


def get_data_loaders_network(args):
    """Build the data loaders and the two networks for ``args.dataset``.

    Args:
        args: namespace providing ``dataset`` ('MNIST' or 'cifar10'),
            ``data`` (dataset root; datasets are downloaded there if missing)
            and ``batch_size``.

    Returns:
        ``(train_loader, test_loader, hyper_net, c_net)``:
        ``train_loader`` iterates the dataset's training split and
        ``test_loader`` its test split (both shuffled). ``hyper_net`` is the
        feature extractor and ``c_net`` a 10-class ``classifier`` head whose
        input width matches the extractor's features (84 for LeNet5Feats,
        512 for ResNetFeats18).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    if args.dataset == 'MNIST':
        # Digits are resized to 32x32 before being fed to LeNet5Feats.
        mnist_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        data_test = MNIST(args.data, train=False, download=True, transform=mnist_transform)
        data_train = MNIST(args.data, train=True, download=True, transform=mnist_transform)
        num_workers = 0
        feats_net_cls, n_features = LeNet5Feats, 84

    elif args.dataset == 'cifar10':
        # Random-crop / horizontal-flip augmentation for the training split only.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        # Training split -> train_loader, test split -> test_loader, as in the
        # MNIST branch (these were swapped, so "test" metrics were computed
        # on CIFAR training images).
        data_train = CIFAR10(args.data, train=True, transform=transform_train, download=True)
        data_test = CIFAR10(args.data, train=False, transform=transform_test, download=True)
        num_workers = 4
        feats_net_cls, n_features = ResNetFeats18, 512

    else:
        raise ValueError(
            "Unsupported dataset {!r}; expected 'MNIST' or 'cifar10'".format(args.dataset))

    train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)

    # Networks are instantiated after the loaders are built, keeping the order
    # of (seeded) weight initialisation relative to the data setup fixed.
    hyper_net = feats_net_cls()
    c_net = classifier(n_features=n_features, n_classes=10)

    return train_loader, test_loader, hyper_net, c_net



def outer_loss(images, labels, params, hparams, more=False):
    """Cross-entropy of the full model on one batch.

    Features are produced by the module-level ``fhnet`` using ``hparams`` and
    classified by the module-level ``fcnet`` using ``params``; the loss is the
    module-level ``criterion``. All three are created in the ``__main__`` block.

    Returns:
        The loss tensor, or ``(loss, acc)`` when ``more`` is true, where
        ``acc`` is the fraction (a Python float) of correct top-1 predictions.
    """
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)

    # Top-1 accuracy from detached logits (no gradient needed).
    predicted = logits.detach().max(1)[1]
    n_correct = (predicted == labels.detach().view_as(predicted)).sum()
    accuracy = float(n_correct) / labels.size(0)

    return (loss, accuracy) if more else loss


def update_tensor_grads(params, grads):
    """Accumulate ``grads`` into the ``.grad`` field of each tensor in ``params``.

    A missing ``.grad`` is first initialised to zeros. Entries of ``grads``
    that are ``None`` leave the corresponding ``.grad`` untouched, apart from
    that zero initialisation.
    """
    for tensor, increment in zip(params, grads):
        if tensor.grad is None:
            tensor.grad = torch.zeros_like(tensor)
        if increment is None:
            continue
        tensor.grad += increment


def inner_solver(images, labels, imagesf, labelsf, hparams, params, zparams, steps, alpha, gamma, lamb):
    """Run ``steps`` inner iterations on the classifier-head parameters.

    With x = ``hparams`` held fixed, each iteration performs

        z <- z - gamma * grad_z g(x, z)
        y <- y - alpha * (grad_y f(x, y) + lamb * grad_y g(x, y))

    where g is ``criterion`` on (``images``, ``labels``), f is ``criterion``
    on (``imagesf``, ``labelsf``), y = ``params`` and z = ``zparams``.
    Uses the module-level ``fhnet``, ``fcnet`` and ``criterion`` created in
    the ``__main__`` block.

    Args:
        images, labels: batch defining the loss g.
        imagesf, labelsf: batch defining the loss f.
        hparams: feature-extractor parameters (only read here).
        params: head parameters y, updated in place.
        zparams: auxiliary head parameters z, updated in place.
        steps: number of iterations.
        alpha: step size for y.
        gamma: step size for z.
        lamb: weight of grad_y g in the y-step.

    Returns:
        ``(zparams, params)``: the same lists that were passed in, whose
        tensors have been updated in place.
    """
    for _ in range(steps):
        # Only gradients w.r.t. the head weights (zparams / params) are taken
        # in this loop, and the features do not depend on them. Running the
        # feature extractor without autograd recording yields the same
        # gradients without storing the extractor's activations.
        with torch.no_grad():
            feats = fhnet(images, params=hparams)

        # z-step: descend on g(x, z).
        zloss = criterion(fcnet(feats, params=zparams), labels)
        zgrads = torch.autograd.grad(zloss, zparams, create_graph=False)
        with torch.no_grad():
            for zp, zg in zip(zparams, zgrads):
                zp.data.sub_(gamma * zg)

        # y-step, part 1: grad_y f(x, y) on the f-batch.
        with torch.no_grad():
            featsf = fhnet(imagesf, params=hparams)
        floss = criterion(fcnet(featsf, params=params), labelsf)
        fgrads = torch.autograd.grad(floss, params, create_graph=False)

        # y-step, part 2: grad_y g(x, y) on the g-batch. Features are
        # recomputed rather than reused from the z-step so the extractor runs
        # once per loss; this matters if it carries train-mode state (e.g.
        # BatchNorm running statistics).
        with torch.no_grad():
            feats = fhnet(images, params=hparams)
        gloss = criterion(fcnet(feats, params=params), labels)
        ggrads = torch.autograd.grad(gloss, params, create_graph=False)

        with torch.no_grad():
            for p, gf, gg in zip(params, fgrads, ggrads):
                p.data.sub_(alpha * (gf + lamb * gg))

    return zparams, params


def update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha, lamb):
    """Apply one update to the feature-extractor parameters ``hparams``.

    The update direction returned by ``cal_grad_f`` is already multiplied by
    ``zeta * alpha``, so it is subtracted as-is with no extra step size.
    The tensors in ``hparams`` are modified in place, and the same list is
    returned.
    """
    step = cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha, lamb)

    with torch.no_grad():
        for hparam, delta in zip(hparams, step):
            hparam.data.sub_(delta)

    return hparams


def cal_grad_f(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, zeta, alpha, lamb):
    """Return the scaled update direction for ``hparams`` (one tensor each).

    For every tensor in x = ``hparams`` this computes

        zeta * alpha * (d_x f(x, y) + lamb * (d_x g(x, y) - d_x g(x, z)))

    where f is ``criterion`` on (``imagesf``, ``labelsf``), g is ``criterion``
    on (``imagesg``, ``labelsg``), y = ``params`` and z = ``zparams``.
    Uses the module-level ``fhnet``, ``fcnet`` and ``criterion``.
    """
    def _grad_wrt_hparams(images, labels, head_params):
        # Gradient of the batch loss w.r.t. hparams for the given head weights.
        logits = fcnet(fhnet(images, params=hparams), params=head_params)
        return torch.autograd.grad(criterion(logits, labels), hparams, create_graph=False)

    # Evaluated in this order: f(x, y), g(x, y), g(x, z).
    df_dx = _grad_wrt_hparams(imagesf, labelsf, params)
    dg_dx_y = _grad_wrt_hparams(imagesg, labelsg, params)
    dg_dx_z = _grad_wrt_hparams(imagesg, labelsg, zparams)

    scale = zeta * alpha
    return [scale * (fx + lamb * (gy - gz)) for fx, gy, gz in zip(df_dx, dg_dx_y, dg_dx_z)]


if __name__ == "__main__":
    # Training driver. Each outer epoch draws one f-batch from the test loader
    # and one g-batch from the train loader, runs the inner y/z updates, then
    # one x (feature-extractor) update, and logs single-batch metrics.

    args = parse_args()
    print(args)

    torch.manual_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_train_loader, data_test_loader, hypernet, cnet = get_data_loaders_network(args)

    hypernet = hypernet.to(device)
    cnet = cnet.to(device)

    data_train_iter = iter(data_train_loader)
    data_test_iter = iter(data_test_loader)

    # fhnet, fcnet and criterion must live at module scope: the helper
    # functions above read them as globals. The patched networks take their
    # weights explicitly through ``params=``.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).to(device)
    hparams = [hparam.requires_grad_(True) for hparam in hypernet.parameters()]

    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).to(device)
    params = [param.requires_grad_(True) for param in cnet.parameters()]
    # z starts as an independent copy of the head weights y.
    zparams = [param.clone().detach().requires_grad_(True) for param in params]

    criterion = torch.nn.CrossEntropyLoss().to(device)

    running_time, test_accs, test_losses, train_accs, train_losses = [], [], [], [], []
    # Columns: train loss, train acc, test loss, test acc, elapsed seconds.
    loss_time_results = np.zeros((args.epochs, 5))

    start_time = time.time()

    for epoch in range(args.epochs):
        # f-batch; restart the loader when it is exhausted.
        try:
            imagesf, labelsf = next(data_test_iter)
        except StopIteration:
            data_test_iter = iter(data_test_loader)
            imagesf, labelsf = next(data_test_iter)

        imagesf, labelsf = imagesf.to(device), labelsf.to(device)

        # g-batch; restart the loader when it is exhausted.
        try:
            imagesg, labelsg = next(data_train_iter)
        except StopIteration:
            data_train_iter = iter(data_train_loader)
            imagesg, labelsg = next(data_train_iter)

        imagesg, labelsg = imagesg.to(device), labelsg.to(device)

        # update y and z
        zparams, params = inner_solver(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, args.iterations, args.alpha, args.gamma,
                                       args.lamb)

        # update x
        hparams = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams, args.zeta, args.alpha, args.lamb)
        # lamb grows by 1e-4 after every outer step.
        args.lamb += 0.0001

        # Evaluate on the current batches only. No gradients are needed, so
        # avoid building (and holding) autograd graphs.
        with torch.no_grad():
            train_loss, train_acc = outer_loss(imagesg, labelsg, params, hparams, more=True)
            test_loss, test_acc = outer_loss(imagesf, labelsf, params, hparams, more=True)

        end_time = time.time()

        total_time = end_time - start_time
        running_time.append(total_time)
        train_accs.append(train_acc * 100)
        train_losses.append(train_loss.item())
        test_accs.append(test_acc * 100)
        test_losses.append(test_loss.item())

        loss_time_results[epoch, 0] = train_loss.item()
        loss_time_results[epoch, 1] = train_acc
        loss_time_results[epoch, 2] = test_loss.item()
        loss_time_results[epoch, 3] = test_acc
        loss_time_results[epoch, 4] = total_time

        if epoch % 5 == 0:
            print('Epoch: %d/%d, Train Loss: %f, Test Loss: %f, Train Accuracy: %f, Test Accuracy: %f, Running Time: %f' % (
                epoch, args.epochs, train_loss.item(), test_loss.item(), train_acc, test_acc, total_time))

    filename = str(args.seed) + '.pt'
    save_path = os.path.join(args.save_folder, filename)

    # Per-epoch histories (accuracies in percent).
    state_dict = {'runtime': running_time,
                  'train accuracy': train_accs,
                  'train loss': train_losses,
                  'test accuracy': test_accs,
                  'test loss': test_losses}
    torch.save(state_dict, save_path)

    file_name_2 = str(args.seed) + '.npy'
    file_addr = os.path.join(args.save_folder, file_name_2)
    with open(file_addr, 'wb') as f:
        np.save(f, loss_time_results)
