import os
from networks import LeNet5Feats, ResNetFeats18, classifier
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse
import higher
import time
import numpy as np
import matplotlib.pyplot as plt
import pickle


def parse_args():
    """Parse command-line options and prepare the results directory.

    Returns:
        argparse.Namespace with the parsed options plus two derived fields:
        ``model_name`` -- the ``--model_name`` value when one is given, otherwise
        a name encoding the dataset and hyperparameters; and ``save_folder`` --
        the results root (default ``./save_results``) joined with ``model_name``.
        That directory is created if it does not exist yet.
    """
    parser = argparse.ArgumentParser(description='Bilevel Training')
    # Basic model parameters.
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['MNIST', 'cifar10'])
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--epochs', type=int, default=200, help='outer update')
    parser.add_argument('--iterations', type=int, default=1, help='T, number of inner iterations')
    parser.add_argument('--batch_size', type=int, default=128)     # 128 for cifar-10
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha')
    parser.add_argument('--gamma', type=float, default=0.1, help='gamma')
    parser.add_argument('--zeta', type=float, default=0.1, help='zeta')
    parser.add_argument('--eta', type=float, default=0.5, help='eta')
    parser.add_argument('--lamb', type=float, default=0.1, help='lamb')

    parser.add_argument('--seed', type=int, default=1)

    parser.add_argument('--save_folder', type=str, default='', help='path to save result')
    parser.add_argument('--model_name', type=str, default='', help='Experiment name')

    args = parser.parse_args()

    if not args.save_folder:
        args.save_folder = './save_results'

    # Only synthesize a name when none was supplied; an explicit --model_name
    # must not be silently discarded.
    if not args.model_name:
        args.model_name = 'F3SA_{}_bs_{}_alpha_{}_gamma_{}_zeta_{}_eta_{}_lamb_{}_xupd_{}_yupd_{}'.format(
            args.dataset,
            args.batch_size,
            args.alpha, args.gamma,
            args.zeta, args.eta,
            args.lamb, args.epochs,
            args.iterations)
    args.save_folder = os.path.join(args.save_folder, args.model_name)

    # exist_ok removes the check-then-create race; a regular file already at
    # this path still raises FileExistsError, as before.
    os.makedirs(args.save_folder, exist_ok=True)

    return args


def get_data_loaders_network(args):
    """Build the data loaders and the two networks for ``args.dataset``.

    Args:
        args: namespace providing ``dataset`` ('MNIST' or 'cifar10'), ``data``
            (dataset root directory) and ``batch_size``.

    Returns:
        ``(train_loader, test_loader, hyper_net, c_net)``. ``train_loader``
        iterates the official training split and ``test_loader`` the official
        test split. ``hyper_net`` is the feature extractor (LeNet5Feats for MNIST,
        ResNetFeats18 for CIFAR-10). ``c_net`` is the 10-class classifier head
        sized to that extractor's feature dimension (84 or 512).

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets.
    """
    if args.dataset == 'MNIST':
        # The transform pipeline is stateless, so both splits can share it.
        mnist_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        data_test = MNIST(args.data,
                          train=False,
                          download=True,
                          transform=mnist_transform)
        data_train = MNIST(args.data,
                           train=True,
                           download=True,
                           transform=mnist_transform)

        train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=0)
        test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=True, num_workers=0)

        hyper_net = LeNet5Feats()
        c_net = classifier(n_features=84, n_classes=10)

    elif args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # Training split (augmented) feeds train_loader and the test split feeds
        # test_loader, matching the MNIST branch. Previously the two were swapped,
        # so "test" metrics were actually measured on the augmented training split.
        data_train = CIFAR10(args.data,
                             train=True,
                             transform=transform_train,
                             download=True)
        data_test = CIFAR10(args.data,
                            train=False,
                            transform=transform_test,
                            download=True)

        train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=True, num_workers=4)

        hyper_net = ResNetFeats18()
        c_net = classifier(n_features=512, n_classes=10)

    else:
        # Without this branch an unknown dataset surfaced as an UnboundLocalError.
        raise ValueError("Unsupported dataset: {!r} (expected 'MNIST' or 'cifar10')".format(args.dataset))

    return train_loader, test_loader, hyper_net, c_net



def outer_loss(images, labels, params, hparams, more=False):
    """Evaluate the full feature-extractor + classifier stack on one batch.

    Features come from the module-level functional network ``fhnet`` run with
    weights ``hparams``. They are classified by ``fcnet`` run with weights
    ``params``, and the loss is the module-level ``criterion``.

    Returns:
        The loss tensor. When ``more`` is true, returns ``(loss, accuracy)``
        instead, where accuracy is the fraction of samples whose arg-max
        prediction equals the label.
    """
    logits = fcnet(fhnet(images, params=hparams), params=params)
    loss = criterion(logits, labels)

    if not more:
        return loss

    predicted = logits.data.max(1)[1]
    n_correct = predicted.eq(labels.data.view_as(predicted)).sum()
    return loss, float(n_correct) / labels.size(0)


def update_tensor_grads(params, grads):
    """Accumulate ``grads`` into the ``.grad`` field of the matching tensors.

    A tensor whose ``.grad`` is still ``None`` first gets a zero-filled
    gradient, even when its paired entry in ``grads`` is ``None``. ``None``
    entries in ``grads`` are otherwise skipped. Pairs are formed with ``zip``,
    so extra items in the longer sequence are ignored.
    """
    for tensor, extra in zip(params, grads):
        if tensor.grad is None:
            tensor.grad = torch.zeros_like(tensor)
        if extra is None:
            continue
        tensor.grad += extra


def y_update(imagesg, labelsg, imagesf, labelsf, hparams, params, hparams_old, params_old, alpha, eta, lamb,
             h_old=None):
    """Take one momentum-corrected gradient step on the lower-level variable y (``params``).

    Two estimators are formed per tensor of ``params``, each of the form
    ``grad(current) + (1 - eta) * (previous_estimator - grad(previous))``:

    * ``hfy`` from ``grad_y f`` on ``(imagesf, labelsf)``, with previous estimator ``h_old[1]``;
    * ``hgy`` from ``grad_y g`` on ``(imagesg, labelsg)``, with previous estimator ``h_old[2]``.

    "Current" gradients are taken at ``(hparams, params)`` and "previous" ones at
    ``(hparams_old, params_old)``. The step then overwrites ``params_old`` in place
    with the pre-step ``params`` and moves ``params`` in place by
    ``-alpha * (hfy + lamb * hgy)``.

    Args:
        h_old: list of previous estimators, indexed as above. When omitted, the
            module-level ``h_old`` is used, which is what callers written before
            this argument existed rely on.

    Returns:
        ``(grad_fy, grad_gy)``: lists holding the new estimators, one tensor per
        entry of ``params``.
    """
    if h_old is None:
        h_old = globals()['h_old']

    # hfy: gradient of f w.r.t. y at the current and previous iterates.
    featsf = fhnet(imagesf, params=hparams)
    outputs = fcnet(featsf, params=params)
    loss = criterion(outputs, labelsf)
    grads = torch.autograd.grad(loss, params, create_graph=False)

    featsf = fhnet(imagesf, params=hparams_old)
    outputs = fcnet(featsf, params=params_old)
    loss = criterion(outputs, labelsf)
    grads_old = torch.autograd.grad(loss, params_old, create_graph=False)

    # hgy: gradient of g w.r.t. y at the current and previous iterates.
    featsg = fhnet(imagesg, params=hparams)
    outputs = fcnet(featsg, params=params)
    loss = criterion(outputs, labelsg)
    gradsg = torch.autograd.grad(loss, params, create_graph=False)

    featsg = fhnet(imagesg, params=hparams_old)
    outputs = fcnet(featsg, params=params_old)
    loss = criterion(outputs, labelsg)
    gradsg_old = torch.autograd.grad(loss, params_old, create_graph=False)

    grad_fy = []
    grad_gy = []
    for j in range(len(params)):
        # Compute both estimators once, BEFORE params[j] is modified: if h_old
        # aliases the parameter tensors, recomputing them after the step (as
        # the earlier code did for the returned values) would read updated
        # weights and return estimators that differ from the step actually taken.
        d_fy = grads[j] + (1 - eta) * (h_old[1][j] - grads_old[j])
        d_gy = gradsg[j] + (1 - eta) * (h_old[2][j] - gradsg_old[j])

        with torch.no_grad():
            params_old[j].data.copy_(params[j].data)
            params[j].data.sub_(alpha * (d_fy + lamb * d_gy))

        grad_fy.append(d_fy)
        grad_gy.append(d_gy)
    return grad_fy, grad_gy


def z_update(imagesg, labelsg, hparams, zparams, hparams_old, zparams_old, h_old, eta, gamma):
    """Take one momentum-corrected gradient step on the auxiliary variable z (``zparams``).

    For each tensor of ``zparams`` the estimator is
    ``d = grad_z g(hparams, zparams) + (1 - eta) * (h_old[0] - grad_z g(hparams_old, zparams_old))``.
    Both gradients are taken on the same lower-level batch ``(imagesg, labelsg)``.
    The step then overwrites ``zparams_old`` in place with the pre-step
    ``zparams`` and moves ``zparams`` in place by ``-gamma * d``.

    Returns:
        list of tensors: the new estimator ``d``, one entry per tensor of ``zparams``.
    """
    featsg = fhnet(imagesg, params=hparams)
    outputs = fcnet(featsg, params=zparams)
    loss = criterion(outputs, labelsg)
    hgz = torch.autograd.grad(loss, zparams, create_graph=False)

    featsg = fhnet(imagesg, params=hparams_old)
    outputs = fcnet(featsg, params=zparams_old)
    loss = criterion(outputs, labelsg)
    hgz_old = torch.autograd.grad(loss, zparams_old, create_graph=False)

    grad_z = []
    # Iterate over zparams itself, not the module-level ``params`` (that only
    # happened to work because z is created as a clone of y).
    for i in range(len(zparams)):
        estimate = hgz[i] + (1 - eta) * (h_old[0][i] - hgz_old[i])

        with torch.no_grad():
            zparams_old[i].data.copy_(zparams[i].data)
            zparams[i].data.sub_(gamma * estimate)

        grad_z.append(estimate)
    return grad_z


def update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams,
             hparams_old, params_old, zparams_old, eta, lamb, alpha, h_old=None):
    """Take one momentum-corrected gradient step on the upper-level variable x (``hparams``).

    Three estimators are formed per tensor of ``hparams``, each of the form
    ``grad(current) + (1 - eta) * (previous_estimator - grad(previous))``:

    * ``hfx``  from ``grad_x f`` on ``(imagesf, labelsf)`` with y-weights ``params``
      (previous estimator ``h_old[3]``);
    * ``hgxy`` from ``grad_x g`` on ``(imagesg, labelsg)`` with y-weights ``params``
      (previous estimator ``h_old[4]``);
    * ``hgxz`` from ``grad_x g`` on ``(imagesg, labelsg)`` with z-weights ``zparams``
      (previous estimator ``h_old[5]``).

    "Previous" gradients use ``hparams_old`` together with ``params_old`` or
    ``zparams_old``. The step then overwrites ``hparams_old`` in place with the
    pre-step ``hparams`` and moves ``hparams`` in place by
    ``-alpha * (hfx + lamb * (hgxy - hgxz))``.

    Args:
        h_old: list of previous estimators, indexed as above. When omitted, the
            module-level ``h_old`` is used, which is what callers written before
            this argument existed rely on.

    Returns:
        ``(grad_fx, grad_gxy, grad_gxz)``: lists holding the new estimators, one
        tensor per entry of ``hparams``.
    """
    if h_old is None:
        h_old = globals()['h_old']

    grad_fx = []
    grad_gxy = []
    grad_gxz = []
    # hfx
    featsf = fhnet(imagesf, params=hparams)
    outputs = fcnet(featsf, params=params)
    loss = criterion(outputs, labelsf)
    grads = torch.autograd.grad(loss, hparams, create_graph=False)

    featsf = fhnet(imagesf, params=hparams_old)
    outputs = fcnet(featsf, params=params_old)
    loss = criterion(outputs, labelsf)
    grads_old = torch.autograd.grad(loss, hparams_old, create_graph=False)

    # hgxy
    featsg = fhnet(imagesg, params=hparams)
    outputs = fcnet(featsg, params=params)
    loss = criterion(outputs, labelsg)
    gradsg = torch.autograd.grad(loss, hparams, create_graph=False)

    featsg = fhnet(imagesg, params=hparams_old)
    outputs = fcnet(featsg, params=params_old)
    loss = criterion(outputs, labelsg)
    gradsg_old = torch.autograd.grad(loss, hparams_old, create_graph=False)

    # hgxz. Features are recomputed rather than reusing the graph above because
    # each autograd.grad call frees its graph (retain_graph defaults to False).
    featsg = fhnet(imagesg, params=hparams)
    outputs = fcnet(featsg, params=zparams)
    loss = criterion(outputs, labelsg)
    hgz = torch.autograd.grad(loss, hparams, create_graph=False)

    featsg = fhnet(imagesg, params=hparams_old)
    outputs = fcnet(featsg, params=zparams_old)
    loss = criterion(outputs, labelsg)
    hgz_old = torch.autograd.grad(loss, hparams_old, create_graph=False)

    for j in range(len(hparams)):
        hfx_temp = grads[j] + (1 - eta) * (h_old[3][j] - grads_old[j])
        hgxy_temp = gradsg[j] + (1 - eta) * (h_old[4][j] - gradsg_old[j])
        hgxz_temp = hgz[j] + (1 - eta) * (h_old[5][j] - hgz_old[j])

        with torch.no_grad():
            hparams_old[j].data.copy_(hparams[j].data)
            hparams[j].data.sub_(alpha * (hfx_temp + lamb * (hgxy_temp - hgxz_temp)))

        grad_fx.append(hfx_temp)
        grad_gxy.append(hgxy_temp)
        grad_gxz.append(hgxz_temp)
    return grad_fx, grad_gxy, grad_gxz


if __name__ == "__main__":

    args = parse_args()
    print(args)

    torch.manual_seed(args.seed)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    data_train_loader, data_test_loader, hypernet, cnet = get_data_loaders_network(args)

    hypernet = hypernet.to(device)
    cnet = cnet.to(device)

    data_train_iter = iter(data_train_loader)
    data_test_iter = iter(data_test_loader)

    # Functional views of both networks; weights are passed explicitly via
    # ``params=`` on every call.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).to(device)
    hparams = list(hypernet.parameters())
    hparams = [hparam.requires_grad_(True) for hparam in hparams]
    hparams_old = [hparam.clone().detach().requires_grad_(True) for hparam in hparams]

    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).to(device)
    params = list(cnet.parameters())
    params = [param.requires_grad_(True) for param in params]
    params_old = [param.clone().detach().requires_grad_(True) for param in params]
    zparams = [param.clone().detach().requires_grad_(True) for param in params]
    zparams_old = [param.clone().detach().requires_grad_(True) for param in params]

    # Momentum gradient estimators, one list of tensors per term:
    #   0: grad_z g(x, z)   1: grad_y f(x, y)   2: grad_y g(x, y)
    #   3: grad_x f(x, y)   4: grad_x g(x, y)   5: grad_x g(x, z)
    # z_update receives this list as an argument, while y_update and update_x
    # read the module-level name ``h_old``. All three treat it as the estimator
    # from the previous iteration, so each entry is overwritten as soon as its
    # component has been updated (the old h/h_old swap made them read the
    # estimator from two iterations back). The estimators start at zero rather
    # than aliasing the live parameter tensors, which fed raw weights into the
    # momentum term.
    h_old = [
        [torch.zeros_like(p) for p in zparams],
        [torch.zeros_like(p) for p in params],
        [torch.zeros_like(p) for p in params],
        [torch.zeros_like(p) for p in hparams],
        [torch.zeros_like(p) for p in hparams],
        [torch.zeros_like(p) for p in hparams],
    ]

    criterion = torch.nn.CrossEntropyLoss().to(device)

    total_time = 0
    running_time, test_accs, test_losses, train_accs, train_losses = [], [], [], [], []
    loss_time_results = np.zeros((args.epochs, 5))

    start_time = time.time()

    for epoch in range(args.epochs):
        # Upper-level (f) batch; restart the loader when it is exhausted.
        try:
            imagesf, labelsf = next(data_test_iter)
        except StopIteration:
            data_test_iter = iter(data_test_loader)
            imagesf, labelsf = next(data_test_iter)

        imagesf, labelsf = imagesf.to(device), labelsf.to(device)

        # Lower-level (g) batch; restart the loader when it is exhausted.
        try:
            imagesg, labelsg = next(data_train_iter)
        except StopIteration:
            data_train_iter = iter(data_train_loader)
            imagesg, labelsg = next(data_train_iter)

        imagesg, labelsg = imagesg.to(device), labelsg.to(device)

        # update z
        hgz = z_update(imagesg, labelsg, hparams, zparams, hparams_old, zparams_old, h_old, args.eta, args.gamma)
        h_old[0] = hgz

        # update y
        hfy, hgy = y_update(imagesg, labelsg, imagesf, labelsf, hparams, params, hparams_old, params_old, args.alpha, args.eta,
                            args.lamb)
        h_old[1], h_old[2] = hfy, hgy

        # update x
        hfx, hgxy, hgxz = update_x(imagesg, labelsg, imagesf, labelsf, hparams, params, zparams,
                                   hparams_old, params_old, zparams_old, args.eta, args.lamb, args.alpha)
        h_old[3], h_old[4], h_old[5] = hfx, hgxy, hgxz
        args.lamb += 0.0001

        # Evaluate; the losses are only logged, so no autograd graph is needed.
        with torch.no_grad():
            train_loss, train_acc = outer_loss(imagesg, labelsg, params, hparams, more=True)
            test_loss, test_acc = outer_loss(imagesf, labelsf, params, hparams, more=True)

        end_time = time.time()

        total_time = end_time - start_time
        running_time.append(total_time)
        train_accs.append(train_acc * 100)
        train_losses.append(train_loss.item())
        test_accs.append(test_acc * 100)
        test_losses.append(test_loss.item())

        loss_time_results[epoch, 0] = train_loss.item()
        loss_time_results[epoch, 1] = train_acc
        loss_time_results[epoch, 2] = test_loss.item()
        loss_time_results[epoch, 3] = test_acc
        loss_time_results[epoch, 4] = total_time

        if epoch % 5 == 0:
            print('Epoch: %d/%d, Train Loss: %f, Test Loss: %f, Train Accuracy: %f, Test Accuracy: %f, Running Time: %f' % (
                epoch, args.epochs, train_loss.item(), test_loss.item(), train_acc, test_acc, total_time))

    filename = str(args.seed) + '.pt'
    save_path = os.path.join(args.save_folder, filename)

    state_dict = {'runtime': running_time,
                  'train accuracy': train_accs,
                  'train loss': train_losses,
                  'test accuracy': test_accs,
                  'test loss': test_losses}
    torch.save(state_dict, save_path)

    file_name_2 = str(args.seed) + '.npy'
    file_addr = os.path.join(args.save_folder, file_name_2)
    with open(file_addr, 'wb') as f:
        np.save(f, loss_time_results)
