import torch
import torch.nn as nn
import time
import os
import sys
import torch.optim as optim
import argparse
import json
import copy
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms

sys.path.append("../../../src")
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from metrics import *
from models import ModelConstructor
from utils import *

from datasets import read_data_sets

from sparseml.pytorch.optim import (
    ScheduledModifierManager,
)
from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity

from FishLeg.src.optim.FishLeg import *

# Loss / likelihood classes, looked up by the exact name given in the
# compression settings' ``loss_func["name"]``.
LOSSES = dict(
    fixedgaussianlikelihood=FixedGaussianLikelihood,
    crossentropy=nn.CrossEntropyLoss,
    bernoullilikelihood=BernoulliLikelihood,
    softmaxlikelihood=SoftMaxLikelihood,
)

# Accuracy metrics, selected by name from the compression settings' "accuracy" list.
ACCURACY = dict(mse=MSE_accuracy, top1=top1_accuracy, top5=top5_accuracy)


# Optimizer classes, looked up by the lower-cased optimizer name from the config.
OPTIMIZERS = dict(adam=optim.Adam, fishleg=FishLeg)

if __name__ == "__main__":
    cli_parser = argparse.ArgumentParser(
        description="SparseFL PyTorch Experimental Interface."
    )
    cli_parser.add_argument(
        "--config",
        "-c",
        type=str,
        help="File containing configuration for experiment.",
        default="../config/test.json",
    )
    args = cli_parser.parse_args()

    # Load the experiment configuration (JSON). An empty --config value is
    # rejected; read errors are reported and then re-raised unchanged.
    if args.config:
        try:
            # JSON text is UTF-8 by specification; don't rely on the locale.
            with open(args.config, encoding="utf-8") as json_file:
                config = json.load(json_file)
        except OSError:
            print("Input/Output error when loading in file:\n")
            raise
    else:
        raise ValueError("No config file provided.")

    # Fix the torch RNG seed and force deterministic cuDNN kernels so runs
    # are more reproducible.
    seed = 13
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Split the raw config into its per-stage argument groups.
    parser = Parser(config)
    (
        exp_args,
        data_args,
        model_args,
        pretrain_args,
        compression_args,
    ) = parser.parse()

    # Experiment bookkeeping object (provides save_path and the log_* calls
    # used below), labelled with the dataset name and compression type.
    exp_info = ExperimentInfo(
        target=data_args["name"],
        exp_type=compression_args["type"],
        **exp_args,
    )

    # Compute device: CUDA only when it is available AND requested.
    if torch.cuda.is_available() and exp_args["device"]:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # Device handed to the dataset reader (may differ from `device`).
    data_device = torch.device("cuda" if data_args["data_on_cuda"] else "cpu")

    # build() returns the model plus an optimizer state (opt_state).
    model_builder = ModelConstructor(**model_args)
    pretrain_model, opt_state = model_builder.build()
    pretrain_model.to(device)
    print("Model Built")

    # Build the train/test datasets for the configured architecture.
    arch_name = model_args["architecture"]["name"]
    if arch_name == "autoencoder":
        # Dataset
        dataset = read_data_sets(
            data_args["name"],
            exp_args["data_dir"],
            if_autoencoder=True,
            data_device=data_device,
        )
        train_dataset = dataset.train
        test_dataset = dataset.test

    elif arch_name == "resnet":
        pretrain_model.eval()
        # Resize/crop to 224x224 and normalise with the standard torchvision
        # ImageNet mean/std constants.
        preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Expects ImageFolder layouts under <data_dir>/train and <data_dir>/val.
        train_dataset = ImageFolder(
            root=os.path.join(exp_args["data_dir"], "train"), transform=preprocess
        )
        test_dataset = ImageFolder(
            root=os.path.join(exp_args["data_dir"], "val"), transform=preprocess
        )

    else:
        # Fail fast: otherwise train_dataset stays undefined and the
        # DataLoader construction below raises an unhelpful NameError.
        raise ValueError(
            f"Unsupported architecture {arch_name!r}; "
            "expected 'autoencoder' or 'resnet'."
        )

    print("Dataset Ready")

    # Shuffled loaders over the datasets. train_loader drives training and
    # pruning; aux_loader is a second, independent iterator over the same
    # training set (handed to pretrain and FishLeg below).
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, shuffle=True, batch_size=data_args["batch_size"]
    )

    # One sample per batch.
    # NOTE(review): not referenced elsewhere in this script - confirm before removing.
    initial_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, shuffle=True, batch_size=1
    )

    aux_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, shuffle=True, batch_size=data_args["batch_size"]
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, shuffle=True, batch_size=data_args["batch_size"]
    )

    # Save a copy of the experiment config next to the results. The context
    # manager closes the file even if json.dump raises.
    with open(os.path.join(exp_info.save_path, "config.json"), "w") as out_file:
        json.dump(config, out_file, indent=6)

    ## PRETRAINING
    # Pretrain only when pretraining settings exist, the model builder gave
    # back no (truthy) optimizer state, and the architecture is not "resnet".
    run_pretraining = (
        pretrain_args
        and not opt_state
        and model_args["architecture"]["name"] != "resnet"
    )

    if not run_pretraining:
        pretrain_model_state_dict = pretrain_model.state_dict()
        pretrain_opt_state_dict = opt_state
    else:
        pretrain_scores, pretrain_opt_state = pretrain(
            pretrain_model,
            train_loader,
            aux_loader,
            test_loader,
            **pretrain_args,
            device=device,
            model_save_path=exp_info.save_path,
            opt_state=opt_state,
        )
        pretrain_model_state_dict = pretrain_model.state_dict()
        pretrain_opt_state_dict = pretrain_opt_state
        exp_info.log_pretraining(pretrain_scores)

    # Compression-stage setup: recipe path (pop() removes it from
    # compression_args), extra SparseML kwargs, a working copy of the model,
    # the loss object and the accuracy metric functions.
    recipe_path = compression_args.pop("recipe_path")
    manager_kwargs = dict()

    model = copy.deepcopy(pretrain_model)

    loss_func = compression_args["loss_func"]
    loss_cls = LOSSES[loss_func["name"]]
    compression_loss = loss_cls(**loss_func["args"], device=device)

    accuracy_funcs = list(map(ACCURACY.__getitem__, compression_args["accuracy"]))

    # define for OBS/M-FAC pruner: gradient-sampling pruners receive a batch
    # generator and the loss via manager_kwargs["grad_sampler"].
    if compression_args["grad_sampler"]:

        def data_loader_builder(device=device, **kwargs):
            """Yield training batches from ``train_loader`` forever.

            Each item is ``([inputs], {}, target)``. Presumably that is the
            (forward args, forward kwargs, target) triple SparseML's
            grad_sampler expects; confirm against the SparseML docs. Inputs
            are cast to float32, and both tensors are moved to ``device``.
            Extra keyword arguments are accepted and ignored.
            """
            while True:
                for inputs, target in train_loader:
                    inputs = inputs.to(device, torch.float32)
                    target = target.to(device)
                    yield [inputs], {}, target

        manager_kwargs["grad_sampler"] = {
            "data_loader_builder": data_loader_builder,
            "loss_function": compression_loss,
        }

    if compression_args["method"] == "FLS":

        def nll(model, data):
            """Return ``compression_loss.nll`` of the model's predictions.

            ``data`` is an ``(inputs, targets)`` pair.
            """
            data_x, data_y = data
            pred_y = model.forward(data_x)
            return compression_loss.nll(pred_y, data_y)

        def draw(model, data):
            """Return ``(inputs, compression_loss.draw(predictions))``.

            Presumably these are targets sampled from the model's predictive
            likelihood, as FishLeg needs them; confirm.
            """
            data_x, data_y = data
            pred_y = model.forward(data_x)
            return (data_x, compression_loss.draw(pred_y))

        opt = FishLeg(
            model,
            draw,
            nll,
            aux_loader,
            likelihood=compression_loss,
            **compression_args["optimizer"]["args"],
            device=device,
            num_steps=compression_args["pretrain"]["iterations"] * 2,
            warmup_data=train_loader,
            warmup_loss=nll,
        )

        if compression_args["load_path"]:
            # Restore a previously saved model instead of running
            # pretrain_fish. map_location remaps storages onto `device`, so a
            # checkpoint saved on GPU also loads on a CPU-only machine.
            pretrain_model_path = compression_args["load_path"]
            checkpoint = torch.load(pretrain_model_path, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])
            print("model loaded from ", pretrain_model_path)

        else:
            aux_losses = opt.pretrain_fish(
                train_loader,
                nll,
                testloader=test_loader,
                **compression_args["pretrain"],
            )

            save_path = os.path.join(exp_info.save_path, "flmodel_checkpoint_best.pth")
            with open(save_path, mode="wb") as file_:
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                    },
                    file_,
                )

    elif compression_args["optimizer"]:
        opt = OPTIMIZERS[compression_args["optimizer"]["name"].lower()](
            model.parameters(), **compression_args["optimizer"]["args"]
        )

    # Otherwise no optimizer is created here; the "1shot+FT" branch can still
    # create one from compression_args["ft_optimizer"].

    model.eval()

    if compression_args["type"] == "1shot":
        # One-shot pruning: for each target sparsity, prune a fresh copy of
        # the model with the recipe, evaluate it and save a checkpoint.
        print("\n----------------------")
        print("One-shot Pruning")
        print("----------------------")
        model.train()
        scores = {}

        check_path = os.path.join(exp_info.save_path, "1shot-checkpoints")
        os.makedirs(check_path)
        # NOTE(review): 0o7777 also sets the setuid/setgid/sticky bits on the
        # directory; 0o777 may have been intended - confirm.
        os.chmod(check_path, 0o7777)

        # Baseline scores of the unpruned model, stored under sparsity key 0.
        scores[0] = validation_scores(
            model,
            test_loader,
            compression_loss,
            accuracy_funcs,
            device,
        )
        scores[0]["time_taken"] = 0

        print("\n----------------------------")
        print("Dense Model Results")
        val_accs = "".join(
            f"Val Accuracy {m}: {acc:.3f}, "
            for m, acc in enumerate(scores[0]["accuracy"])
        )
        print(
            f"Val Loss: {scores[0]['loss']:.3f}, {val_accs}Time taken: {scores[0]['time_taken']:.2f}s\n"
        )

        for sparsity in compression_args["sparsities"]:
            model.train()
            print("\n----------------------------")
            print(f"Sparsity {100*sparsity:.1f}%")
            # Prune a deep copy so every level starts from the same model
            # rather than from the previous level's pruned weights.
            model_sparse = copy.deepcopy(model)
            # Re-read the recipe and pin its first modifier to this level.
            manager = ScheduledModifierManager.from_yaml(recipe_path)
            manager.modifiers[0].init_sparsity = sparsity
            manager.modifiers[0].final_sparsity = sparsity

            st = time.time()
            manager.apply(model_sparse, **manager_kwargs, finalize=True)

            for name, layer in get_prunable_layers(model_sparse):
                print(
                    "{}.weight: {:.4f}".format(
                        name, tensor_sparsity(layer.weight).item()
                    )
                )
            # The measured time covers pruning plus the per-layer report above.
            et = time.time()

            scores[sparsity] = validation_scores(
                model_sparse,
                test_loader,
                compression_loss,
                accuracy_funcs,
                device,
            )

            scores[sparsity]["time_taken"] = et - st

            val_accs = "".join(
                f"Val Accuracy {m}: {acc:.3f}, "
                for m, acc in enumerate(scores[sparsity]["accuracy"])
            )

            print(
                f"Val Loss: {scores[sparsity]['loss']:.3f}, {val_accs}Time taken: {scores[sparsity]['time_taken']:.2f}s\n"
            )

            # Round rather than truncate: int(0.29 * 100) == 28 because of
            # binary floating-point error.
            sparsity_pct = int(round(sparsity * 100))
            with open(
                os.path.join(check_path, f"sparse-{sparsity_pct}_model.pth"),
                mode="wb",
            ) as file_:
                torch.save(
                    {
                        "model_state_dict": model_sparse.state_dict(),
                    },
                    file_,
                )
        print("\n----------------------------")

        exp_info.log_one_shot(scores)

    elif compression_args["type"] == "1shot+FT":
        if compression_args["ft_optimizer"]:
            opt = OPTIMIZERS[compression_args["ft_optimizer"]["name"].lower()](
                model.parameters(), **compression_args["ft_optimizer"]["args"]
            )

        print("\n----------------------------")
        print("One-shot Pruning + Finetuning")
        print("----------------------------")
        train_scores = {}
        val_scores = {}

        check_path = os.path.join(exp_info.save_path, "1shotFT-checkpoints")
        os.makedirs(check_path)
        os.chmod(check_path, 0o7777)

        train_scores[0] = run_one_epoch(
            0,
            model,
            train_loader,
            compression_loss,
            accuracy_funcs,
            device,
        )

        val_scores[0] = run_one_epoch(
            0,
            model,
            test_loader,
            compression_loss,
            accuracy_funcs,
            device,
        )

        manager = ScheduledModifierManager.from_yaml(recipe_path)

        opt = manager.modify(
            model, opt, steps_per_epoch=len(train_loader), **manager_kwargs
        )

        train_scores[0]["sparsity"] = 0.0
        val_scores[0]["sparsity"] = 0.0
        train_scores[0]["time_taken"] = 0

        train_accs = ""
        for m, acc in enumerate(train_scores[0]["accuracy"]):
            train_accs += f"Train Accuracy {m}: {acc:.3f}, "

        val_accs = ""
        for m, acc in enumerate(val_scores[0]["accuracy"]):
            val_accs += f"Val Accuracy {m}: {acc:.3f}, "

        print("--------------------")
        print("Epoch 0/{} - Sparsity: {:.2f}%".format(manager.max_epochs, 0.0))
        print(
            f"Train Loss: {train_scores[0]['loss']:.3f}, Test Loss: {val_scores[0]['loss']:.3f}, {train_accs}{val_accs}"
        )

        # Run model pruning
        start_epoch = manager.min_epochs
        for epoch in range(start_epoch, manager.max_epochs):
            # run training loop
            epoch_name = "{}/{}".format(epoch + 1, manager.max_epochs)

            st = time.time()

            train_scores[epoch + 1] = run_one_epoch(
                epoch + 1,
                model,
                train_loader,
                compression_loss,
                accuracy_funcs,
                device,
                train=True,
                optimizer=opt,
            )

            et = time.time() - st
            train_scores[epoch + 1]["time_taken"] = 0

            val_scores[epoch + 1] = run_one_epoch(
                epoch + 1, model, test_loader, compression_loss, accuracy_funcs, device
            )

            sparsity = get_current_sparsity(manager, epoch)

            zeros, total = 0, 0
            for name, layer in get_prunable_layers(model):
                zeros += (layer.weight == 0).sum()
                total += (layer.weight).numel()

            real_sparsity = zeros.float() / float(total)

            train_scores[epoch + 1]["sparsity"] = real_sparsity.item()
            val_scores[epoch + 1]["sparsity"] = real_sparsity.item()

            train_accs = ""
            for m, acc in enumerate(train_scores[epoch + 1]["accuracy"]):
                train_accs += f"Train Accuracy {m}: {acc:.3f}, "

            val_accs = ""
            for m, acc in enumerate(val_scores[epoch + 1]["accuracy"]):
                val_accs += f"Val Accuracy {m}: {acc:.3f}, "

            print("--------------------")
            print(
                "Epoch {} - Sparsity: {:.2f}%".format(epoch_name, 100 * real_sparsity)
            )
            print(
                f"Train Loss: {train_scores[epoch + 1]['loss']:.3f}, Test Loss: {val_scores[epoch + 1]['loss']:.3f}, {train_accs}{val_accs}"
            )

        with open(
            os.path.join(check_path, f"sparse-{str(int(sparsity*100))}FT_model.pth"),
            mode="wb",
        ) as file_:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                },
                file_,
            )

        print("--------------------")
        manager.finalize(model)

        exp_info.log_one_shot_FT(train_scores, val_scores)

    else:
        print("\n----------------------------")
        print("Gradual Pruning")
        print("----------------------------")
        train_scores = {}
        val_scores = {}

        check_path = os.path.join(exp_info.save_path, "gradual-checkpoints")
        os.makedirs(check_path)
        os.chmod(check_path, 0o7777)

        train_scores[0] = run_one_epoch(
            0,
            model,
            train_loader,
            compression_loss,
            accuracy_funcs,
            device,
        )

        val_scores[0] = run_one_epoch(
            0,
            model,
            test_loader,
            compression_loss,
            accuracy_funcs,
            device,
        )

        manager = ScheduledModifierManager.from_yaml(recipe_path)

        opt = manager.modify(
            model, opt, steps_per_epoch=len(train_loader), **manager_kwargs
        )

        train_scores[0]["time_taken"] = 0

        train_scores[0]["sparsity"] = 0.0
        val_scores[0]["sparsity"] = 0.0

        train_accs = ""
        for m, acc in enumerate(train_scores[0]["accuracy"]):
            train_accs += f"Train Accuracy {m}: {acc:.3f}, "

        val_accs = ""
        for m, acc in enumerate(val_scores[0]["accuracy"]):
            val_accs += f"Val Accuracy {m}: {acc:.3f}, "

        print("--------------------")
        print("Epoch 0/{} - Sparsity: {:.2f}%".format(manager.max_epochs, 0.0))
        print(
            f"Train Loss: {train_scores[0]['loss']:.3f}, Test Loss: {val_scores[0]['loss']:.3f}, {train_accs}{val_accs}"
        )

        # Run model pruning
        start_epoch = manager.min_epochs
        for epoch in range(start_epoch, manager.max_epochs):
            # run training loop
            epoch_name = "{}/{}".format(epoch + 1, manager.max_epochs)

            st = time.time()

            train_scores[epoch + 1] = run_one_epoch(
                epoch + 1,
                model,
                train_loader,
                compression_loss,
                accuracy_funcs,
                device,
                train=True,
                optimizer=opt,
            )

            et = time.time()
            train_scores[epoch + 1]["time_taken"] = et - st

            val_scores[epoch + 1] = run_one_epoch(
                epoch + 1, model, test_loader, compression_loss, accuracy_funcs, device
            )

            sparsity = get_current_sparsity(manager, epoch)

            zeros, total = 0, 0
            for name, layer in get_prunable_layers(model):
                zeros += (layer.weight == 0).sum()
                total += (layer.weight).numel()

            real_sparsity = zeros.float() / float(total)

            train_scores[epoch + 1]["sparsity"] = real_sparsity.item()
            val_scores[epoch + 1]["sparsity"] = real_sparsity.item()

            train_accs = ""
            for m, acc in enumerate(train_scores[epoch + 1]["accuracy"]):
                train_accs += f"Train Accuracy {m}: {acc:.3f}, "

            val_accs = ""
            for m, acc in enumerate(val_scores[epoch + 1]["accuracy"]):
                val_accs += f"Val Accuracy {m}: {acc:.3f}, "

            print("--------------------")
            print(
                "Epoch {} - Sparsity: {:.2f}%".format(epoch_name, 100 * real_sparsity)
            )
            print(
                f"Train Loss: {train_scores[epoch + 1]['loss']:.3f}, Test Loss: {val_scores[epoch + 1]['loss']:.3f}, {train_accs}{val_accs}"
            )

        with open(
            os.path.join(check_path, f"sparse-{str(int(sparsity*100))}_model.pth"),
            mode="wb",
        ) as file_:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                },
                file_,
            )

        print("--------------------")
        manager.finalize(model)

        exp_info.log_gradual(train_scores, val_scores)

    print("\nHave a great day!\n")
