import os
from networks import LeNet5Feats, classifier
import torch
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
import argparse
import higher
import time
from torchvision.datasets import FashionMNIST
import numpy as np

# Parse the command line first so the GPU selection can be applied before any
# CUDA work happens: CUDA_VISIBLE_DEVICES only takes effect if it is set before
# the CUDA runtime is initialised.
parser = argparse.ArgumentParser(description='Bilevel Training')
parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'FashionMNIST'])
parser.add_argument('--data', type=str, default='./data')
# Previously hard-coded to '1', which hides every device on a single-GPU machine.
parser.add_argument('--gpu', type=str, default='1',
                    help='value assigned to CUDA_VISIBLE_DEVICES (e.g. "0" or "0,1")')
parser.add_argument('--seed', type=int, default=123, help='global torch random seed')
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
torch.random.manual_seed(args.seed)
# Request deterministic cuDNN kernels and disable autotuning for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load dataset
def load_data(dataset_name):
    """Build the train / validation / test loaders for MNIST or FashionMNIST.

    The official training split is randomly partitioned (using the global torch
    RNG) into 50,000 training and 10,000 validation images; ``random_split``
    requires the split to hold exactly 60,000 images. The official test split
    is the test set. Every image is resized to 32x32, converted to a tensor and
    normalised with the dataset's mean/std constants. Files are read from (and
    downloaded to) ``args.data``.

    Args:
        dataset_name: ``'MNIST'`` selects MNIST; any other value selects
            FashionMNIST.

    Returns:
        ``(data_train_loader, data_val_loader, data_test_loader)``. Each uses
        batch size 256 and 4 worker processes; only the training loader
        shuffles.
    """
    if dataset_name == 'MNIST':
        dataset_cls, mean, std = MNIST, (0.1307,), (0.3081,)
    else:
        dataset_cls, mean, std = FashionMNIST, (0.2860,), (0.3530,)

    def build_transform():
        # Resize to 32x32 (the input size of the LeNet-5 style feature net),
        # then tensorise and normalise with the per-dataset statistics.
        return transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def build_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=256, shuffle=shuffle, num_workers=4)

    full_train_data = dataset_cls(args.data, train=True, download=True, transform=build_transform())
    train_data, val_data = random_split(full_train_data, [50000, 10000])
    test_data = dataset_cls(args.data, train=False, download=True, transform=build_transform())

    return (
        build_loader(train_data, True),
        build_loader(val_data, False),
        build_loader(test_data, False),
    )


# Model evaluation is performed by `evaluate`, which is defined inside `run_experiment` below.



# Define the experiment
def run_experiment(experiment_id, data_train_loader, data_val_loader, data_test_loader,
                   *, steps=20000, time_budget=300, eval_every=500):
    """Run one bilevel training experiment and record test metrics over time.

    A ``LeNet5Feats`` feature extractor provides the outer variable x (its
    parameters, called ``hparams`` here). A ``classifier`` head on its 84
    features is the inner model, kept as two parameter copies: ``params_y``
    and an auxiliary ``params_z``. With F the cross-entropy on a validation
    batch and f the cross-entropy on a training batch, each step:

      1. evaluates the aggregation loss
         ``F(x, y) - rho * (f(x, y) - f(x, z)) - sigma * <z, y> + sigma / 2 * ||z||^2``;
      2. moves y by gradient ascent and z by gradient descent with step ``beta``;
      3. re-evaluates the loss at the new (y, z) and moves x by gradient
         descent with step ``alpha``.

    ``alpha``, ``beta``, ``rho`` and ``sigma`` follow polynomial schedules in
    the step index.

    Args:
        experiment_id: label used only in the log output.
        data_train_loader: loader for the inner (training) loss f.
        data_val_loader: loader for the outer (validation) loss F.
        data_test_loader: loader used only for reporting test loss/accuracy.
        steps: maximum number of optimisation steps.
        time_budget: stop once the accumulated optimisation time in seconds
            exceeds this value. The time covers only the gradient steps, not
            data loading or evaluation.
        eval_every: evaluate on the test set every ``eval_every`` steps.

    Returns:
        ``(running_time, test_accs, test_losses)``: parallel lists with one
        entry per evaluation, starting with the evaluation before training.
    """
    hypernet = LeNet5Feats().cuda()
    cnet = classifier(n_features=84, n_classes=10).cuda()
    # Iterators are created right after model initialisation (as before) so the
    # global torch RNG is consumed in the same order.
    loaders = {'train': data_train_loader, 'val': data_val_loader}
    iterators = {'train': iter(data_train_loader), 'val': iter(data_val_loader)}
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Functional versions of the modules, called with explicitly supplied parameters.
    fhnet = higher.monkeypatch(hypernet, copy_initial_weights=True).cuda()
    fcnet = higher.monkeypatch(cnet, copy_initial_weights=True).cuda()

    hparams = [hparam.requires_grad_(True) for hparam in hypernet.parameters()]
    params_y = [p.detach().clone().requires_grad_(True) for p in cnet.parameters()]
    params_z = [p.detach().clone().requires_grad_(True) for p in cnet.parameters()]
    n_y = len(params_y)

    def next_batch(split):
        """Return the next ``(images, labels)`` of ``split`` on the GPU, restarting the loader when exhausted."""
        try:
            images, labels = next(iterators[split])
        except StopIteration:
            iterators[split] = iter(loaders[split])
            images, labels = next(iterators[split])
        return images.cuda(), labels.cuda()

    def batch_loss(params, hparams, images, labels):
        """Cross-entropy of head ``params`` applied to features computed with ``hparams``."""
        feats = fhnet(images, params=hparams)
        outputs = fcnet(feats, params=params)
        return criterion(outputs, labels)

    def model_dot(param_1, param_2):
        """Inner product of two parameter lists, summed over all tensors."""
        return sum(torch.sum(p1 * p2) for p1, p2 in zip(param_1, param_2))

    def model_norm_sq(param):
        """Squared Euclidean norm of a parameter list."""
        return sum(torch.sum(p ** 2) for p in param)

    def aggregation_loss(params_y, params_z, hparams, rho, sigma, inner_batch, outer_batch):
        """Evaluate the aggregation loss described in the docstring of ``run_experiment``."""
        inner_imgs, inner_labels = inner_batch
        outer_imgs, outer_labels = outer_batch
        outer = batch_loss(params_y, hparams, outer_imgs, outer_labels)
        inner_gap = (batch_loss(params_y, hparams, inner_imgs, inner_labels)
                     - batch_loss(params_z, hparams, inner_imgs, inner_labels))
        return (outer
                - rho * inner_gap
                - sigma * model_dot(params_z, params_y)
                + sigma / 2 * model_norm_sq(params_z))

    def evaluate(params, hparams, data_test_loader):
        """Return ``(mean loss, accuracy)`` over ``data_test_loader``.

        Modules are put in eval mode for the pass and then restored to their
        previous mode. Previously they were left in eval mode, so all later
        training silently ran in eval mode.
        """
        was_training = (fhnet.training, fcnet.training)
        fhnet.eval()
        fcnet.eval()
        total_loss, total_correct, total_samples = 0.0, 0, 0
        try:
            with torch.no_grad():
                for images, labels in data_test_loader:
                    images, labels = images.cuda(), labels.cuda()
                    outputs = fcnet(fhnet(images, params=hparams), params=params)
                    # criterion averages over the batch; re-weight by batch size.
                    total_loss += criterion(outputs, labels).item() * images.size(0)
                    total_correct += outputs.argmax(dim=1).eq(labels).sum().item()
                    total_samples += labels.size(0)
        finally:
            fhnet.train(was_training[0])
            fcnet.train(was_training[1])
        return total_loss / total_samples, total_correct / total_samples

    def synced_time():
        # CUDA kernels run asynchronously. Wait for queued GPU work so the
        # timer measures the optimisation step itself, not just its launch.
        torch.cuda.synchronize()
        return time.perf_counter()

    total_time = 0
    running_time, test_accs, test_losses = [], [], []

    def record():
        """Evaluate on the test set and append the result to the history lists."""
        test_loss, test_acc = evaluate(params_y, hparams, data_test_loader)
        running_time.append(total_time)
        test_accs.append(test_acc)
        test_losses.append(test_loss)
        return test_loss, test_acc

    test_loss, test_acc = record()
    print(f'Experiment {experiment_id} - initial: loss={test_loss}, acc={test_acc}, time={total_time}')

    p, q = 0.01, 0.01
    for step in range(1, steps + 1):
        # Polynomial schedules in (step + 1): the step sizes and sigma decay,
        # while the penalty weight rho grows.
        t = step + 1
        alpha = 0.01 * t ** (-8 * p - 8 * q)
        beta = 0.01 * t ** (-2 * p - q)
        rho = 10 * t ** p
        sigma = 0.1 * t ** (-q)

        inner_batch = next_batch('train')
        outer_batch = next_batch('val')

        time_start = synced_time()

        # Step 1: y ascends and z descends the aggregation loss. A single
        # backward pass yields both gradients, and the graph need not be
        # retained because the loss is rebuilt for step 2.
        # NOTE(review): the ascent on y also increases the +F(x, y) term of the
        # loss; confirm this sign convention against the SiPBA reference.
        loss = aggregation_loss(params_y, params_z, hparams, rho, sigma, inner_batch, outer_batch)
        grads = torch.autograd.grad(loss, params_y + params_z)
        with torch.no_grad():
            for param, grad in zip(params_y, grads[:n_y]):
                param.add_(grad, alpha=beta)
            for param, grad in zip(params_z, grads[n_y:]):
                param.sub_(grad, alpha=beta)

        # Step 2: x descends the aggregation loss re-evaluated at the updated (y, z).
        loss = aggregation_loss(params_y, params_z, hparams, rho, sigma, inner_batch, outer_batch)
        grads_hparams = torch.autograd.grad(loss, hparams)
        with torch.no_grad():
            for param, grad in zip(hparams, grads_hparams):
                param.sub_(grad, alpha=alpha)

        total_time += synced_time() - time_start

        if step % eval_every == 0 or total_time > time_budget:
            test_loss, test_acc = record()
            print(f' step={step}, loss={test_loss}, acc={test_acc}, time={total_time}')
            if total_time > time_budget:
                break

    return running_time, test_accs, test_losses


# Run 10 independent experiments and collect each one's evaluation history.
# load_data is called once per experiment, so each gets a fresh random
# train/validation split.
#
# NOTE(review): despite their names, the "mean_*" / "std_*" entries store the
# raw per-experiment curves (one list per experiment, the same list stored
# under both keys); no mean or standard deviation is computed here. The layout
# is kept unchanged in case downstream loaders of this file rely on it.
NUM_EXPERIMENTS = 10

final_results = {
    "mean_loss": [],
    "std_loss": [],
    "mean_acc": [],
    "std_acc": [],
    "mean_time": [],
    "std_time": []
}

for exp_id in range(1, NUM_EXPERIMENTS + 1):
    print(f"Running experiment {exp_id}...")

    data_train_loader, data_val_loader, data_test_loader = load_data(args.dataset)
    running_time, test_accs, test_losses = run_experiment(exp_id, data_train_loader, data_val_loader, data_test_loader)

    for key, curve in (("mean_loss", test_losses), ("std_loss", test_losses),
                       ("mean_acc", test_accs), ("std_acc", test_accs),
                       ("mean_time", running_time), ("std_time", running_time)):
        final_results[key].append(curve)

# Create the output directory so saving cannot fail after all runs have finished.
# np.save pickles the dict into a 0-d object array; reload it with
# np.load(path, allow_pickle=True).item().
os.makedirs("result", exist_ok=True)
np.save(f"result/{args.dataset}_averaged_results_SiPBA.npy",final_results)

print("Experiment finished!")
