import time

import numpy as np
import torch
from tqdm import tqdm

import utils_runs
import utils_wandb
import utils
import wandb
from augment import get_augmentation_pipeline
from dataloader import DLoader
from models import VAE_MNIST, VAE_CIFAR
from models_large import VAE_L


def train(
    model, dataloader, num_epochs, learning_rate, out_dir, run_name, chkpt_epochs, augmentation_pipeline
):
    """Train ``model`` with Adam, evaluating and logging after every epoch.

    Per epoch this function:

    1. asks ``dataloader.load_data()`` for a fresh ``(train, test)`` loader pair
       and reads ``dataloader.non_augmented_train_dataloader``,
    2. runs one optimisation pass over the training loader (optionally passing
       each batch through ``augmentation_pipeline`` first),
    3. evaluates the model with ``test()`` on the non-augmented training loader
       and on the test loader,
    4. logs the epoch metrics to wandb, and
    5. saves a checkpoint and logs reconstruction/sample plots when the
       1-based epoch number is contained in ``chkpt_epochs``.

    Args:
        model: Module whose forward call returns ``(loss, ll, kl)`` for an image
            batch and that exposes a ``device`` attribute.
        dataloader: Object providing ``load_data()``,
            ``non_augmented_train_dataloader`` and ``dataset_name``.
        num_epochs: Number of epochs to train for.
        learning_rate: Learning rate handed to ``torch.optim.Adam``.
        out_dir: Output directory forwarded to ``utils_runs.save_model``.
        run_name: Run name forwarded to ``utils_runs.save_model``.
        chkpt_epochs: Container of 1-based epoch numbers at which to checkpoint.
        augmentation_pipeline: Callable applied to each training batch (its
            first return element is used), or ``None`` to disable augmentation.

    Returns:
        A 7-tuple of per-epoch lists, in this order: training negative ELBO,
        negative ELBO on the non-augmented training data, training LL,
        training KL, test negative ELBO, test LL, test KL.

    Raises:
        ZeroDivisionError: If the training loader yields no samples in an epoch.
    """
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

    # Per-epoch histories returned to the caller.
    train_nelbo_avg, train_ll_avg, train_kl_avg = [], [], []
    train_naug_nelbo_avg = []
    test_nelbo_avg, test_ll_avg, test_kl_avg = [], [], []

    print("Training ...")
    for epoch in range(num_epochs):
        # Loaders are re-requested every epoch, exactly as before.
        # NOTE(review): presumably load_data() reshuffles/resamples - confirm.
        train_dataloader, test_dataloader = dataloader.load_data()
        noAug_train_dataloader = dataloader.non_augmented_train_dataloader

        # Sample-weighted running sums for this epoch.
        nelbo_sum = ll_sum = kl_sum = 0.0
        num_data = 0

        model.train()
        for data_batch in tqdm(train_dataloader, desc="Train One Epoch"):
            # Loaders may yield either a bare image tensor or an
            # (images, labels) pair; labels are not used for training.
            if isinstance(data_batch, (list, tuple)):
                image_batch, _ = data_batch
            else:
                image_batch = data_batch

            # Augmentation runs before the batch is moved to the model device.
            if augmentation_pipeline is not None:
                image_batch = augmentation_pipeline(image_batch)[0]
                if "BinaryMNIST" in dataloader.dataset_name:
                    # Round augmented pixels back to {0, 1} for binary data.
                    image_batch = image_batch.round()

            optimizer.zero_grad()

            image_batch = image_batch.to(model.device)
            loss, ll, kl = model(image_batch)

            loss.backward()
            optimizer.step()

            # Weight each batch by its size so a smaller final batch does not
            # skew the epoch average.
            batch_size = image_batch.size(0)
            nelbo_sum += batch_size * loss.item()
            ll_sum += batch_size * ll.item()
            kl_sum += batch_size * kl.item()
            num_data += batch_size

        train_nelbo_avg.append(nelbo_sum / num_data)
        train_ll_avg.append(ll_sum / num_data)
        train_kl_avg.append(kl_sum / num_data)

        # Negative ELBO on the training data without augmentation.
        naug_nelbo_avg, _, _ = test(model, noAug_train_dataloader)
        train_naug_nelbo_avg.append(naug_nelbo_avg)
        print(
            "Epoch [%d / %d] average negative ELBO: %f, LL: %f, KL: %f, noAug_nELBO: %f"
            % (
                epoch + 1,
                num_epochs,
                train_nelbo_avg[-1],
                train_ll_avg[-1],
                train_kl_avg[-1],
                train_naug_nelbo_avg[-1],
            )
        )

        nelbo_avg, ll_avg, kl_avg = test(model, test_dataloader)
        test_nelbo_avg.append(nelbo_avg)
        test_ll_avg.append(ll_avg)
        test_kl_avg.append(kl_avg)

        # wandb receives the ELBO (sign flipped), not the negative ELBO.
        wandb.log(
            {
                "Train/ELBO": -train_nelbo_avg[-1],
                "Train/NonAugELBO": -train_naug_nelbo_avg[-1],
                "Train/LL": train_ll_avg[-1],
                "Train/KL": train_kl_avg[-1],
                "Test/ELBO": -test_nelbo_avg[-1],
                "Test/LL": test_ll_avg[-1],
                "Test/KL": test_kl_avg[-1],
            }
        )

        if (epoch + 1) in chkpt_epochs:
            utils_runs.save_model(out_dir, run_name, model, epoch + 1)

            # Wandb Plots
            utils_wandb.wandb_plot_recon_and_sample(model, test_dataloader, model.device)

    return (
        train_nelbo_avg,
        train_naug_nelbo_avg,
        train_ll_avg,
        train_kl_avg,
        test_nelbo_avg,
        test_ll_avg,
        test_kl_avg,
    )


def test(model, test_dataloader):
    """Evaluate ``model`` on every batch of ``test_dataloader`` without gradients.

    The model is switched to eval mode and left in that mode on return.
    Batches are moved to ``model.device`` (the same attribute ``train`` uses),
    so this function does not depend on a module-level ``device`` variable and
    also works when the module is imported rather than run as a script.

    Args:
        model: Module whose forward call returns ``(nelbo, ll, kl)`` for an
            image batch and that exposes a ``device`` attribute.
        test_dataloader: Iterable yielding either image tensors or
            ``(images, labels)`` pairs; labels are ignored.

    Returns:
        ``(nelbo_avg, ll_avg, kl_avg)``: the per-batch values returned by the
        model, averaged with each batch weighted by its number of samples.

    Raises:
        ZeroDivisionError: If ``test_dataloader`` yields no samples.
    """
    model.eval()
    eval_device = model.device

    nelbo_sum = ll_sum = kl_sum = 0.0
    num_data = 0
    with torch.no_grad():
        for data_batch in tqdm(test_dataloader, desc="Test"):
            if isinstance(data_batch, (list, tuple)):
                image_batch, _ = data_batch
            else:
                image_batch = data_batch

            image_batch = image_batch.to(eval_device)
            nelbo, ll, kl = model(image_batch)

            # Weight by batch size so a smaller final batch does not skew
            # the average.
            batch_size = image_batch.size(0)
            nelbo_sum += batch_size * nelbo.item()
            ll_sum += batch_size * ll.item()
            kl_sum += batch_size * kl.item()
            num_data += batch_size

    nelbo_avg = nelbo_sum / num_data
    ll_avg = ll_sum / num_data
    kl_avg = kl_sum / num_data
    print(
        "Test average negative ELBO: %f, LL: %f, KL: %f"
        % (nelbo_avg, ll_avg, kl_avg)
    )

    return nelbo_avg, ll_avg, kl_avg


# Datasets that are modelled with VAE_MNIST.
MNIST_STYLE_DATASETS = (
    "BinaryMNIST",
    "Diffusion-BinaryMNIST",
    "FashionMNIST",
    "Diffusion-FashionMNIST",
)
# Datasets that are modelled with VAE_L (--resnet) or VAE_CIFAR.
# NOTE(review): "SVHN" is listed here but is not an accepted --dataset choice.
CIFAR_STYLE_DATASETS = (
    "CIFAR10",
    "Diffusion-CIFAR10",
    "CIFAR10_Sub",
    "SVHN",
    "Diffusion-SVHN",
)


def build_arg_parser():
    """Return the command-line parser for the training script.

    ``--dataset`` is required: every later step (run naming, data loading,
    model selection) depends on it, so a missing value is rejected up front
    instead of failing later with an unrelated ``TypeError``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)")

    parser.add_argument(
        "--dataset",
        required=True,
        choices=["BinaryMNIST", "MNIST", "CIFAR10", "Diffusion-MNIST",
                "Diffusion-BinaryMNIST", "Diffusion-CIFAR10", "CIFAR10_Sub", "Diffusion-SVHN",
                "FashionMNIST", "Diffusion-FashionMNIST"],
        help="Choose the dataset for the experment.",
    )
    parser.add_argument(
        "--image_directory_cutoff_index",
        type=int,
        default=None,
        help="Specified the cutoff index for image directories of diffusion data.",
    )
    parser.add_argument(
        "--data_path",
        default=None,
        help="Specified the path for diffusion generated datasets.",
    )
    parser.add_argument('--resnet', action='store_true', help='Use VAE with ResNet for CIFAR10')
    parser.add_argument('--fc', action='store_true', help='Use VAE with Fully Connected Nets for MNIST')
    parser.add_argument(
        "--subset_portion", type=float, default=0, help="(CIFAR10_Sub) If use a subet of training set, need to provide its portion."
    )

    parser.add_argument(
        "--num_epochs", type=int, default=100, help="Number of epochs for training."
    )
    parser.add_argument(
        "--batch_size", type=int, default=64, help="Batch size for training."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="Learning rate."
    )
    parser.add_argument(
        "--resnet_channels", type=int, default=256, help="Number of channels for ResNet."
    )
    parser.add_argument(
        "--resnet_z_channels", type=int, default=2, help="Number of latent channels for ResNet."
    )
    parser.add_argument(
        "--n_c", type=int, default=32, help="Number of latent channels."
    )
    parser.add_argument(
        "--z_dims", type=int, default=20, help="z_dims."
    )

    parser.add_argument(
        "--num_chkpts", type=int, default=10, help="Number of checkpoint to save."
    )

    parser.add_argument("--run_name", default="", help="Specified the name of the run.")
    parser.add_argument(
        "--run_batch_name",
        default="singles",
        help="Specified the name of the batch for runs if doing a batch grid search etc.",
    )
    parser.add_argument(
        "--augment",
        type=float,
        default=0.0,
        help="Augmentation probability for dataset (EDM: 0.12)",
    )
    parser.add_argument(
        "--aug_quality",
        choices=["optimal", "optimal-MNIST", "bad", "worst", "edm"],
        default="optimal",
        help="Augmentation to apply, one of ['optimal', 'bad']",
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
        help="The entity for wandb.",
    )
    return parser


def init_wandb(args):
    """Initialise wandb, retrying every 5 seconds until it succeeds.

    Workaround for ``wandb.errors.UsageError: Error communicating with wandb
    process``. Only ``Exception`` subclasses trigger a retry, so Ctrl-C
    (``KeyboardInterrupt``) and ``SystemExit`` still stop the script instead
    of being swallowed by the retry loop.
    """
    while True:
        try:
            wandb.init(
                project="DMaaPx",
                entity=args.wandb_entity,
                group=args.run_batch_name,
                settings=wandb.Settings(start_method="fork"),
            )
        except Exception:
            print("Retrying: wandb.init")
            time.sleep(5)
        else:
            break
    wandb.config.update(args)


def build_config_string(args):
    """Return the hyper-parameter suffix that is appended to the run name."""
    parts = [
        f"-Epochs_{args.num_epochs}",
        f"-BatchSize_{args.batch_size}",
        f"-{args.dataset}",
        f"-ResNet_{args.resnet}",
        f"-Seed_{args.seed}",
    ]
    if args.dataset in CIFAR_STYLE_DATASETS:
        if args.resnet:
            parts.append(f"-ResNetChannels_{args.resnet_channels}")
            parts.append(f"-ResNetZChannels_{args.resnet_z_channels}")
        else:
            parts.append(f"-n_c_{args.n_c}")
            parts.append(f"-z_dims_{args.z_dims}")
    if args.dataset == "CIFAR10_Sub":
        parts.append(f"-Portion_{args.subset_portion}")
    if args.augment > 0:
        parts.append(f"-Augment_{args.augment}_{args.aug_quality}")
    if "MNIST" in args.dataset:
        parts.append(f"-FC_{args.fc}")
    if args.image_directory_cutoff_index is not None:
        parts.append(f"-ImgDirCutoff_{args.image_directory_cutoff_index}")
    return "".join(parts)


def build_model(args, device):
    """Construct the VAE that matches ``args.dataset``.

    Raises:
        ValueError: If no architecture is configured for ``args.dataset``
            (previously this surfaced later as a ``NameError`` on ``vae``).
    """
    if args.dataset in MNIST_STYLE_DATASETS:
        if "Binary" in args.dataset:
            return VAE_MNIST("Bernoulli", device, grayscale=False, fc=args.fc)
        # Alternative likelihood: VAE_MNIST("MoL", device, grayscale=True, fc=args.fc)
        return VAE_MNIST("GaussianFixedSigma", device, fc=args.fc)
    if args.dataset in CIFAR_STYLE_DATASETS:
        if args.resnet:
            return VAE_L(device, channels=args.resnet_channels, z_channels=args.resnet_z_channels)
        return VAE_CIFAR("MoL", device, c=args.n_c, z_dims=args.z_dims)
    raise ValueError(
        f"No VAE architecture is configured for dataset {args.dataset!r}."
    )


if __name__ == "__main__":
    # Deliberately a module-level name: other code in this file may read
    # `device` as a global.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    args = build_arg_parser().parse_args()

    init_wandb(args)

    config_string = build_config_string(args)
    if args.run_name == "":
        # No explicit name: extend the wandb-generated one for local outputs.
        run_name = wandb.run.name + config_string
    else:
        run_name = args.run_name + config_string
        wandb.run.name = run_name
    print(run_name)
    wandb.config.update({"run_name": run_name}, allow_val_change=True)

    out_dir = "./runs/" + args.run_batch_name
    # Save training config
    utils_runs.save_train_config(out_dir, run_name, vars(args))

    # Evenly spaced checkpoint epochs; the leading 0 is replaced by 1 so a
    # checkpoint is also written after the first epoch.
    chkpt_epochs = np.linspace(0, args.num_epochs, args.num_chkpts + 1, dtype=int)
    chkpt_epochs[0] = 1

    # Set random seeds
    utils.set_seed(args.seed)

    dataloader = DLoader(
        args.dataset,
        args.batch_size,
        args.seed,
        path=args.data_path,
        augment=args.augment,
        subset_portion=args.subset_portion,
        image_directory_cutoff_index=args.image_directory_cutoff_index,
    )

    vae = build_model(args, device)
    vae = vae.to(device)

    # Get augmentation pipeline (train() accepts None to disable augmentation).
    augmentation_pipeline = get_augmentation_pipeline(args)

    (
        train_nelbo_avg,
        train_naug_nelbo_avg,
        train_ll_avg,
        train_kl_avg,
        test_nelbo_avg,
        test_ll_avg,
        test_kl_avg,
    ) = train(
        vae,
        dataloader,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        out_dir=out_dir,
        run_name=run_name,
        chkpt_epochs=chkpt_epochs,
        augmentation_pipeline=augmentation_pipeline,
    )

    # Results are stored as ELBOs, i.e. with the sign of the negative ELBO flipped.
    results = {
        "train_elbo_avg": [-x for x in train_nelbo_avg],
        "train_naug_elbo_avg": [-x for x in train_naug_nelbo_avg],
        "train_ll_avg": train_ll_avg,
        "train_kl_avg": train_kl_avg,
        "test_elbo_avg": [-x for x in test_nelbo_avg],
        "test_ll_avg": test_ll_avg,
        "test_kl_avg": test_kl_avg,
    }
    utils_runs.save_train_results_as_json(
        out_dir, results, run_name, "train_test_elbos"
    )

    utils_wandb.wandb_plot_elbos(results)
