import os
import sys
import time
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
from torch import distributions as D
from tqdm import tqdm

from dataloader import DLoader
import utils_runs
import utils
import wandb
from models import VAE_MNIST, VAE_CIFAR
from models_large import VAE_L
import attack.trainer as atk_trainer


# Hyper-parameters of the adversarial-attack evaluation, populated by
# get_attack_arguments from the ``--attack_*`` command-line options.
# The typename must equal the module attribute name: pickle locates the class
# by its typename, so a mismatch makes instances unpicklable.
AttackArgs = namedtuple(
    "AttackArgs",
    [
        "type",
        "N_trg",
        "N_ref",
        "max_iter",
        "loss_type",
        "p",
        "eps_norm",
        "lr",
        "save_dir",
        "wandb",
    ],
)


def get_attack_arguments(args, save_dir, use_wandb=True):
    """Bundle the ``--attack_*`` command-line options into an ``AttackArgs`` tuple.

    Args:
        args: parsed ``argparse.Namespace`` carrying the ``attack_*`` options.
        save_dir: directory in which the attack results are saved.
        use_wandb: value stored in the ``wandb`` field (default ``True``);
            presumably consulted by the attack trainer to decide whether to log
            to wandb — confirm in ``attack.trainer``.

    Returns:
        AttackArgs populated from ``args``.
    """
    return AttackArgs(
        # attack type: either "unsupervised" or "supervised"
        type=args.attack_type,
        # number of target samples
        N_trg=args.attack_N_trg,
        # number of reference samples
        N_ref=args.attack_N_ref,
        # max iterations for optimization ("budget")
        max_iter=args.attack_max_iter,
        # attack loss type, i.e. "skl", "kl_forward", "kl_reverse", "means", "clf"
        # (Note: "clf" is not implemented yet)
        loss_type=args.attack_loss_type,
        # norm in which the "epsilon-image" is measured
        p=args.attack_p,
        # maximum norm of the "epsilon-image", measured in the "p"-norm
        eps_norm=args.attack_eps_norm,
        # learning rate for the optimization
        lr=args.attack_lrate,
        # directory to save results
        save_dir=save_dir,
        # whether to use wandb
        wandb=use_wandb,
    )

def safe_repeat(x, n):
    """Tile tensor ``x`` ``n`` times along dim 0; all other dims are kept as-is."""
    trailing_ones = (1,) * (x.dim() - 1)
    return x.repeat(n, *trailing_ones)


def test(model, test_dataloader, device=None):
    """Evaluate ``model`` over ``test_dataloader`` and print the averaged losses.

    ``model(batch)`` must return ``(negative ELBO, log-likelihood, KL)``, each a
    one-element tensor (``.item()`` is called on them). Every value is weighted
    by the batch size, so the results are per-example averages over the loader.

    Args:
        model: module returning the three quantities above from ``forward``.
        test_dataloader: iterable of ``(images, labels)``; labels are ignored.
        device: where to move each batch. Defaults to the device of the model's
            first parameter (CPU if it has none), so the function also works
            when imported instead of being run via this file's ``__main__``.

    Returns:
        ``(nelbo_avg, ll_avg, kl_avg)`` as Python floats.

    Raises:
        ValueError: if ``test_dataloader`` yields no examples.
    """
    if device is None:
        device = next(model.parameters(), torch.empty(0)).device
    model.eval()

    nelbo_sum = ll_sum = kl_sum = 0.0
    num_data = 0
    with torch.no_grad():
        for image_batch, _ in tqdm(test_dataloader, desc="Test"):
            image_batch = image_batch.to(device)
            batch_size = image_batch.size(0)
            nElbo, ll, kl = model(image_batch)
            nelbo_sum += batch_size * nElbo.item()
            ll_sum += batch_size * ll.item()
            kl_sum += batch_size * kl.item()
            num_data += batch_size

    if num_data == 0:
        raise ValueError("test_dataloader yielded no examples")
    nelbo_avg = nelbo_sum / num_data
    ll_avg = ll_sum / num_data
    kl_avg = kl_sum / num_data
    print(
        "Test average negative ELBO: %f, LL: %f, KL: %f"
        % (nelbo_avg, ll_avg, kl_avg)
    )

    return nelbo_avg, ll_avg, kl_avg


def eval_elbo(model, data_batch, debug=False, device=None):
    """Compute the amortised ELBO of ``model`` on one batch, without gradients.

    ``model(batch)`` must return ``(negative ELBO, log-likelihood, KL)`` as
    one-element tensors. The model is switched to eval mode.

    Args:
        model: module whose ``forward`` returns the three quantities above.
        data_batch: input batch tensor.
        debug: print the three values.
        device: where to move ``data_batch``. Defaults to the device of the
            model's first parameter (CPU if it has none), so the function no
            longer needs the module-level ``device`` set in ``__main__``.

    Returns:
        ``(elbo, ll, kl)`` as Python floats, where ``elbo`` is the negated first
        output of the model.
    """
    if device is None:
        device = next(model.parameters(), torch.empty(0)).device
    model.eval()

    with torch.no_grad():
        nElbo, ll, kl = model(data_batch.to(device))
        elbo, ll, kl = -nElbo.item(), ll.item(), kl.item()

    if debug:
        print(f"ELBO: {elbo}; LL: {ll}; KL: {kl}")

    return elbo, ll, kl


def _q_star_elbo(model, data_var, qz_mu, qz_logvar, reparameterise):
    """Per-row ELBO ``log p(x|z) - KL[q(z|x) || N(0, I)]`` for a diagonal Gaussian q."""
    qz = D.independent.Independent(
        D.normal.Normal(qz_mu, qz_logvar.mul(0.5).exp()), 1
    )
    pz = D.independent.Independent(
        D.normal.Normal(torch.zeros_like(qz_mu), torch.ones_like(qz_mu)), 1
    )

    # For: KL[q(z|x) || p(z)]
    kl = D.kl.kl_divergence(qz, pz)

    # For likelihood <log p(x|z)>_q: rsample() keeps the sample differentiable
    # w.r.t. (mu, logvar); sample() is used when no gradient is needed.
    z = qz.rsample() if reparameterise else qz.sample()
    z = z.view(z.size(0), -1)
    logpx = model.loglikelihood_x_y(data_var, model.p_x(z))

    return logpx - kl


def eval_amortisation_gap(
    model, data_var, k=100, check_every=100, sentinel_thres=10, debug=False,
    max_epochs=999998,
):
    """
    Estimate the ELBO under a per-example optimised posterior q*(z|x).

    The Gaussian mean and log-variance are initialised from the encoder's
    q(z|x) and then optimised directly with Adam (the model itself is not
    updated). Optimisation stops when the loss averaged over a window of
    ``check_every`` steps has failed to improve for more than
    ``sentinel_thres`` consecutive windows, or after ``max_epochs`` steps.

    Args:
        model: VAE exposing ``q_z(x) -> (_, mu, std)``, ``p_x(z)`` and
            ``loglikelihood_x_y(x, p_x(z))``.
        data_var: input batch (dim 0 is the batch), already on the model's device.
        k: number of copies of each example, i.e. z samples per example.
        check_every: optimisation steps per early-stopping window.
        sentinel_thres: tolerated number of consecutive non-improving windows.
        debug: write progress to stderr at every check.
        max_epochs: hard cap on the number of optimisation steps.

    Returns:
        Mean ELBO (Python float) over all ``k * batch`` rows, evaluated with one
        fresh sample of z per row.
    """
    # Tile the batch k times along dim 0 so each copy gets its own
    # variational parameters and z sample.
    data_var = data_var.repeat(k, *([1] * (data_var.dim() - 1)))

    # Initialise q*(z|x) parameters with the amortised q(z|x).
    with torch.no_grad():
        _, z_mu, z_std = model.q_z(data_var)
    qz_mu = nn.Parameter(z_mu, requires_grad=True)
    qz_logvar = nn.Parameter(z_std.pow(2).log(), requires_grad=True)
    params = [qz_mu, qz_logvar]

    optimizer = torch.optim.Adam(params, lr=1e-3)
    best_avg, sentinel, prev_seq = float("inf"), 0, []

    # perform local opt
    time_ = time.time()
    for epoch in range(1, max_epochs + 1):
        loss = -_q_star_elbo(model, data_var, qz_mu, qz_logvar, True).mean()

        # Differentiate w.r.t. the variational parameters only. A plain
        # loss.backward() would also accumulate (unused) gradients into every
        # model parameter's .grad on each step.
        grads = torch.autograd.grad(loss, params)
        for param, grad in zip(params, grads):
            param.grad = grad
        optimizer.step()

        prev_seq.append(loss.item())
        if epoch % check_every == 0:
            last_avg = np.mean(prev_seq)
            if debug:  # debugging helper
                sys.stderr.write(
                    "Epoch %d, time elapse %.4f, last avg %.4f, prev best %.4f\n"
                    % (epoch, time.time() - time_, -last_avg, -best_avg)
                )
            if last_avg < best_avg:
                sentinel, best_avg = 0, last_avg
            else:
                sentinel += 1
            if sentinel > sentinel_thres:
                break
            prev_seq = []
            time_ = time.time()

    # evaluation
    with torch.no_grad():
        vae_elbo = _q_star_elbo(model, data_var, qz_mu, qz_logvar, False).mean()

    return vae_elbo.item()


def eval_one_chkpt_elbos(model, test_dataloader, mc_num, chkpt_path, device, debug=False):
    """Load one checkpoint and compute its amortised and optimised ELBOs.

    For every batch, ``eval_elbo`` gives the ELBO under the encoder's q(z|x),
    and ``eval_amortisation_gap`` gives the ELBO under the per-example
    optimised q*(z|x). Both are averaged with batch-size weights.

    Args:
        model: VAE instance; its weights are replaced by the checkpoint's.
        test_dataloader: iterable of ``(images, labels)``; labels are ignored.
        mc_num: ``k`` passed to ``eval_amortisation_gap``.
        chkpt_path: path of the state-dict checkpoint to load.
        device: device for the checkpoint tensors and the batches.
        debug: forwarded to the evaluation helpers; also prints the averages.

    Returns:
        ``(elbo_avg, elbo_amort_avg)`` as floats.

    Raises:
        ValueError: if ``test_dataloader`` yields no examples.
    """
    model.load_state_dict(torch.load(chkpt_path, map_location=device))
    model.eval()

    elbo_sum = 0.0
    elbo_amort_sum = 0.0
    num_data = 0
    for image_batch, _ in tqdm(test_dataloader, desc="Eval Amort ELBO batches"):
        image_batch = image_batch.to(device)
        batch_size = image_batch.size(0)

        # 1. Get ELBO under the amortised q(z|x)
        elbo, _, _ = eval_elbo(model, image_batch, debug=debug)

        # 2. Get ELBO under the optimised q*(z|x) (for the amortisation gap)
        elbo_amort = eval_amortisation_gap(model, image_batch, k=mc_num, debug=debug)

        elbo_sum += batch_size * elbo
        elbo_amort_sum += batch_size * elbo_amort
        num_data += batch_size

    if num_data == 0:
        raise ValueError(f"test_dataloader yielded no examples for {chkpt_path}")
    elbo_avg = elbo_sum / num_data
    elbo_amort_avg = elbo_amort_sum / num_data

    if debug:
        print(f"ELBO: {elbo_avg}; ELBO Amort: {elbo_amort_avg}")

    return elbo_avg, elbo_amort_avg


def eval_list_chkpts_elbos(model, dataloader, mc_num, chkpt_list, eval_chkpt_epochs, device, seed, debug=False):
    """Evaluate the ELBO and the optimised-posterior ELBO for each checkpoint.

    Before each checkpoint the seed is reset with ``utils.set_seed(seed)`` and
    the test split is reloaded, so that (assuming ``set_seed`` seeds every RNG
    in use) all checkpoints are evaluated under the same randomness. Results
    are also logged to wandb under the ``Eval/`` prefix.

    Args:
        model: VAE instance that each checkpoint is loaded into.
        dataloader: object whose ``load_data()`` returns ``(train, test)`` loaders.
        mc_num: ``k`` used for the amortisation-gap estimate.
        chkpt_list: checkpoint paths to evaluate.
        eval_chkpt_epochs: epoch number for each entry of ``chkpt_list``.
        device: device for checkpoints and data.
        seed: seed reset before every checkpoint.
        debug: forwarded to ``eval_one_chkpt_elbos``.

    Returns:
        ``(elbo_list, elbo_amort_list)``, one float per checkpoint.
    """
    elbo_list = []
    elbo_amort_list = []
    for i, chkpt_path in enumerate(tqdm(chkpt_list, desc="Eval List chkpts")):
        utils.set_seed(seed)
        _, test_dataloader = dataloader.load_data()
        elbo, elbo_amort = eval_one_chkpt_elbos(
            model, test_dataloader, mc_num, chkpt_path, device, debug=debug
        )
        elbo_list.append(elbo)
        elbo_amort_list.append(elbo_amort)

        wandb.log({
            "Eval/Epoch": eval_chkpt_epochs[i],
            "Eval/ELBO_q": elbo,
            "Eval/ELBO_q_star": elbo_amort,
        })

    return elbo_list, elbo_amort_list


def eval_one_chkpt_atk(model, test_dataloader, args, chkpt_path, device, debug=False):
    """Load a checkpoint into ``model`` and run the adversarial-attack evaluation.

    Args:
        model: VAE instance; its weights are replaced by the checkpoint's.
        test_dataloader: data passed through to ``atk_trainer.train``.
        args: attack hyper-parameters (an ``AttackArgs`` in this script).
        chkpt_path: path of the state-dict checkpoint to load.
        device: device for the checkpoint tensors and the attack.
        debug: print both returned log objects.

    Returns:
        ``(logs, total_logs)`` as produced by ``atk_trainer.train``.
    """
    state_dict = torch.load(chkpt_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    logs, total_logs = atk_trainer.train(model, test_dataloader, args, device)
    if debug:
        for entry in (logs, total_logs):
            print(entry)

    return logs, total_logs


def eval_list_chkpts_atk(model, dataloader, atk_args, chkpt_list, eval_chkpt_epochs, device, seed, debug=False):
    """Run the adversarial-attack evaluation on each checkpoint and log to wandb.

    Before each checkpoint the seed is reset with ``utils.set_seed(seed)`` and
    the test split is reloaded, so that (assuming ``set_seed`` seeds every RNG
    in use) all checkpoints are attacked under the same randomness.

    Args:
        model: VAE instance that each checkpoint is loaded into.
        dataloader: object whose ``load_data()`` returns ``(train, test)`` loaders.
        atk_args: attack hyper-parameters forwarded to the attack trainer.
        chkpt_list: checkpoint paths to evaluate.
        eval_chkpt_epochs: epoch number for each entry of ``chkpt_list``.
        device: device for checkpoints and data.
        seed: seed reset before every checkpoint.
        debug: forwarded to ``eval_one_chkpt_atk``.

    Returns:
        ``(logs_list, total_logs_list)``, one entry per checkpoint.
    """
    # Entries of total_logs[0] forwarded to wandb as "Eval/<name>".
    logged_metrics = ("Av_ref_sim", "Av_ref_rec_sim", "Av_eps_norm", "Av_s_kl", "Av_z_dist")

    logs_list = []
    total_logs_list = []
    for i, chkpt_path in enumerate(tqdm(chkpt_list, desc="Eval List chkpts")):
        utils.set_seed(seed)
        _, test_dataloader = dataloader.load_data()
        logs, total_logs = eval_one_chkpt_atk(
            model, test_dataloader, atk_args, chkpt_path, device, debug=debug
        )
        logs_list.append(logs)
        total_logs_list.append(total_logs)

        summary = total_logs[0]
        log_entry = {"Eval/Epoch": eval_chkpt_epochs[i]}
        log_entry.update({f"Eval/{name}": summary[name] for name in logged_metrics})
        wandb.log(log_entry)

    return logs_list, total_logs_list


if __name__ == "__main__":
    # Kept at module level so that any helper in this file referring to the
    # global ``device`` keeps working.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # ----------------------------------------------
    # parse command-line arguments
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)")

    parser.add_argument(
        "--batch_size", type=int, default=64, help="Batch size for training."
    )
    parser.add_argument(
        "--eval_chkpt_num", type=int, default=10, help="Num of checkpoints to be evaluated."
    )
    parser.add_argument(
        "--eval_one_chkpt", type=int, default=0, help="Evaluate one checkpoint of the given epoch."
    )
    # Both run names are required: they are concatenated into the wandb run
    # name and the run directory below, which fails with a TypeError on None.
    parser.add_argument("--run_name", required=True, help="Specified the name of the run.")
    parser.add_argument(
        "--run_batch_name",
        required=True,
        help="Specified the name of the batch for runs if doing a batch grid search etc.",
    )

    # eval amortisation gap
    parser.add_argument(
        "--mc_num", type=int, default=10, help="Num of MC samples for estimating likelihood."
    )

    # eval attack
    parser.add_argument(
        "--eval_attack", action="store_true", help="evaluate the model robustness"
    )
    parser.add_argument(
        "--attack_type",
        choices=["unsupervised", "supervised"],
        default="unsupervised",
        help="type of attack to perform"
    )
    parser.add_argument(
        "--attack_loss_type",
        choices=["skl", "kl_forward", "kl_reverse", "means", "clf"],
        default="skl",
        help="type of loss utilized in attack"
    )
    parser.add_argument(
        "--attack_N_trg",
        required=False,
        type=int,
        help="number of training samples"
    )
    parser.add_argument(
        "--attack_N_ref",
        type=int,
        default=50,
        help="number of reference samples"
    )
    parser.add_argument(
        "--attack_p",
        type=str,
        default="inf",
        help="p-norm to utilize in attack [0, 1, 2, 'inf']"
    )
    parser.add_argument(
        "--attack_eps_norm",
        type=float,
        default=0.2,
        help="epsilon norm for optimization objective e.g., 0.1, 0.2, 0.3"
    )
    parser.add_argument(
        "--attack_lrate",
        type=float,
        default=1.,
        help="learning rate for construction of the attack",
    )
    parser.add_argument(
        "--attack_max_iter",
        type=int,
        default=50,
        help="maximum number of iterations to look for an adversarial example",
    )
    # ------------

    parser.add_argument(
        "--wandb_entity",
        type=str,
        help="The entity for wandb.",
    )
    # ----------------------------------------------

    args = parser.parse_args()

    config = utils_runs.load_train_config(args.run_batch_name, args.run_name)

    # Initialise wandb, retrying until it succeeds. This is a workaround for the
    # intermittent `wandb.errors.UsageError: Error communicating with wandb process`.
    # `except Exception` (not a bare except) so Ctrl-C / SystemExit still abort.
    while True:
        try:
            wandb.init(
                project="DMaaPx",
                entity=args.wandb_entity,
                group=args.run_batch_name,
                settings=wandb.Settings(start_method="fork"),
            )
            break
        except Exception as err:
            print("Retrying: wandb.init")
            sys.stderr.write(f"wandb.init failed: {err!r}\n")
            time.sleep(5)
    wandb.config.update(args)

    config_string = (
        f"-EvalSettings"
        f"-BatchSize_{args.batch_size}"
        f"-EvalChkptNum_{args.eval_chkpt_num}"
        f"-EvalOneChkpt_{args.eval_one_chkpt}"
        f"-Seed_{args.seed}"
    )

    if args.eval_attack:
        config_string += (
            f"-AtkType_{args.attack_type}"
            f"-AtkLossType_{args.attack_loss_type}"
            f"-AtkNRef_{args.attack_N_ref}"
            f"-AtkP_{args.attack_p}"
            f"-AtkEpsNorm_{args.attack_eps_norm}"
            f"-AtkLRate_{args.attack_lrate}"
            f"-AtkMaxIter_{args.attack_max_iter}"
        )
        if args.attack_type == "supervised":
            config_string += f"-AtkNTrg_{args.attack_N_trg}"
    else:
        config_string += f"-MCNum_{args.mc_num}"

    run_prefix = "Eval-Atk-" if args.eval_attack else "Eval-Gaps-"
    full_run_name = run_prefix + args.run_name + config_string
    wandb.run.name = full_run_name
    wandb.config.update({"run_name": full_run_name}, allow_val_change=True)

    out_dir = "./runs/" + args.run_batch_name
    run_dir = os.path.join("./runs", args.run_batch_name, args.run_name)

    if args.eval_attack:
        save_dir = os.path.join(run_dir, "attack", config_string[1:])
        attack_args = get_attack_arguments(args, save_dir)

    # The epoch is parsed from the text after the last "_" in each checkpoint
    # path, minus its last 3 characters (presumably a ".pt" extension).
    chkpt_list = utils_runs.get_chkpt_list_from_run_dir(run_dir)
    if args.eval_one_chkpt == 0:
        # Pick `eval_chkpt_num` checkpoints evenly spaced over the run,
        # skipping the first one (index 0).
        if len(chkpt_list) - 1 < args.eval_chkpt_num:
            raise ValueError(
                f"--eval_chkpt_num={args.eval_chkpt_num} needs at least "
                f"{args.eval_chkpt_num + 1} checkpoints, found {len(chkpt_list)} in {run_dir}"
            )
        eval_chkpt_list_idx = np.linspace(
            0, len(chkpt_list) - 1, args.eval_chkpt_num + 1, dtype=int
        )[1:]
        eval_chkpt_list = [chkpt_list[i] for i in eval_chkpt_list_idx]
        eval_chkpt_epochs = [int(x.split("_")[-1][:-3]) for x in eval_chkpt_list]
    else:
        chkpt_epochs = [int(x.split("_")[-1][:-3]) for x in chkpt_list]
        if args.eval_one_chkpt not in chkpt_epochs:
            raise ValueError(
                f"No checkpoint for epoch {args.eval_one_chkpt} in {run_dir}; "
                f"available epochs: {chkpt_epochs}"
            )
        chkpt_idx = chkpt_epochs.index(args.eval_one_chkpt)
        eval_chkpt_list = [chkpt_list[chkpt_idx]]
        eval_chkpt_epochs = [chkpt_epochs[chkpt_idx]]
    print("Eval Chkpt Epochs: ", eval_chkpt_epochs)

    # Set random seeds
    utils.set_seed(args.seed)

    # dataset_name, batch_size, seed, train_transform=None,
    # path=None, augment=0.0, subset_portion=1.0
    dataloader = DLoader(
        config["dataset"], args.batch_size, args.seed, path=config["data_path"],
    )

    if config["dataset"] in ["BinaryMNIST", "Diffusion-BinaryMNIST", "FashionMNIST", "Diffusion-FashionMNIST"]:
        used_fc = bool("fc" in config and config["fc"])
        if "Binary" in config["dataset"]:
            ll_family = "Bernoulli"
        else:
            ll_family = "GaussianFixedSigma"
        used_grayscale = False
        model = VAE_MNIST(ll_family, device, grayscale=used_grayscale, fc=used_fc)
    elif config["dataset"] in ["CIFAR10", "Diffusion-CIFAR10", "SVHN", "Diffusion-SVHN"]:
        if "resnet" in config and config["resnet"]:
            if "resnet_channels" in config and config["resnet_channels"]:
                model = VAE_L(device, channels=config["resnet_channels"], z_channels=config["resnet_z_channels"])
            else:
                model = VAE_L(device)
        elif "n_c" in config and config["n_c"]:
            model = VAE_CIFAR("MoL", device, c=config["n_c"], z_dims=config["z_dims"])
        else:
            model = VAE_CIFAR("MoL", device)
    else:
        # Without this, `model` would be undefined and fail below with a NameError.
        raise ValueError(f"Unsupported dataset: {config['dataset']}")
    model.to(device)

    if args.eval_attack:
        logs_list, total_logs_list = eval_list_chkpts_atk(
            model, dataloader, attack_args, eval_chkpt_list, eval_chkpt_epochs, device, args.seed, debug=False
        )
        # Save results
        results = {
            "epochs": eval_chkpt_epochs,
            "logs_list": logs_list,
            "total_logs_list": total_logs_list,
        }
        if args.eval_one_chkpt == 0:
            save_name = "eval_atk" + config_string
        else:
            save_name = "eval_atk_one_chkpt" + config_string + f"-Epoch_{args.eval_one_chkpt}"
    else:
        elbo_list, elbo_amor_list = eval_list_chkpts_elbos(
            model, dataloader, args.mc_num, eval_chkpt_list, eval_chkpt_epochs, device, args.seed, debug=False
        )
        # Save results
        results = {
            "epochs": eval_chkpt_epochs,
            "elbo_list": elbo_list,
            "elbo_amor_list": elbo_amor_list,
        }
        if args.eval_one_chkpt == 0:
            save_name = "eval_elbo_and_gaps" + config_string
        else:
            save_name = "eval_elbo_and_gaps_one_chkpt" + config_string + f"-Epoch_{args.eval_one_chkpt}"

    utils_runs.save_train_results_as_json(
        out_dir, results, args.run_name, save_name
    )
