import argparse
import datetime
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import rich
import sklearn.metrics
import torch
import torch.utils.data

from tqdm import tqdm

import oodd
import oodd.models
import oodd.datasets
import oodd.variational
import oodd.losses

from oodd.utils import str2bool, get_device, log_sum_exp, set_seed, plot_gallery
from oodd.evaluators import Evaluator
from oodd.utils.argparsing import json_file_or_json_unique_keys
from oodd.utils.argparsing import str2bool, json_file_or_json

from sklearn.mixture import GaussianMixture as GMM
import copy
import time
from torch.distributions.normal import Normal
from sklearn.manifold import TSNE
import random

LOGGER = logging.getLogger(name=__file__)


# Weights & Biases is optional: when it is not installed the script runs with local
# output only and `wandb_available` is False.
try:
    import wandb
except ImportError:
    wandb_available = False
    LOGGER.warning("Running without remote tracking!")
else:
    wandb_available = True

# Default experiment settings (JSON strings are parsed by the argparse types below).
# Training data: dequantized CIFAR-10; validation: CIFAR-10 and SVHN validation splits.
train_datasets = '{ "CIFAR10Dequantized": {"dynamic": true, "split": "train"}}'
val_datasets = (
    '{"CIFAR10Dequantized": {"dynamic": false, "split": "validation"}, '
    '"SVHNDequantized": {"dynamic": false, "split": "validation"}}'
)
# NOTE(review): `likelihood` is not read anywhere in the visible code.
likelihood = 'DiscretizedLogisticMixLikelihoodConv2d'
# Default for --model_dir: checkpoint directory loaded in the main block.
pretrained_model = './models/3-layer-CIFAR10-VAE-DC-16-32-728-BETA-0.1-BEST'
# Index into the list returned by `model.infer(...)` whose posterior means are analysed.
layer_index = 2
# train_datasets = '{ "FashionMNISTDequantized": {"dynamic": true, "split": "train"}}'
# val_datasets = '{"FashionMNISTDequantized": {"dynamic": false, "split": "validation"},\
#     "MNISTDequantized": {"dynamic": false, "split": "validation"}}'
# likelihood  = 'DiscretizedLogisticLikelihoodConv2d'
# pretrained_model = './models/5-layer-fashionmnist-vae-dc-epoch25'

# Default for --config_deterministic: a JSON list holding one stage of five
# ResBlockConv2d blocks. Adjacent string literals inside the parentheses are joined
# at compile time, so no line-continuation backslashes are needed.
config_deterministic = (
    '['
    '[{"block": "ResBlockConv2d", "out_channels": 128, "kernel_size": 5, "stride": 2, "weightnorm": true, "gated": false},'
    '{"block": "ResBlockConv2d", "out_channels": 128, "kernel_size": 5, "stride": 2, "weightnorm": true, "gated": false},        '
    '{"block": "ResBlockConv2d", "out_channels": 64, "kernel_size": 5, "stride": 1, "weightnorm": true, "gated": false},        '
    '{"block": "ResBlockConv2d", "out_channels": 32, "kernel_size": 5, "stride": 1, "weightnorm": true, "gated": false},        '
    '{"block": "ResBlockConv2d", "out_channels": 16, "kernel_size": 5, "stride": 1, "weightnorm": true, "gated": false}'
    ']'
    ']'
)
                       # '[{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false},' \
                       # '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false},        ' \
                       # '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 2, "weightnorm": true, "gated": false}' \
                       # '], ' \
                       # '[{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false}, ' \
                       # '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false},        ' \
                       # '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 2, "weightnorm": true, "gated": false} ' \
                       # '], ' \
                       #  '[{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false}, ' \
                       #  '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false},        ' \
                       #  '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false} ' \
                       #  '], ' \
                       #  '[{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false}, ' \
                       #  '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false},        ' \
                       #  '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false} ' \
                       #  '] ' \
                       #  ']'
                       # '[{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false}, ' \
                       # '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false},        ' \
                       # '{"block": "ResBlockConv2d", "out_channels": 256, "kernel_size": 3, "stride": 1, "weightnorm": true, "gated": false} ] ' \


# Default for --config_stochastic: a JSON list with a single GaussianConv2d latent block.
# The literal is closed inside parentheses with no trailing line continuation, so
# (un)commenting the alternative configurations on the following lines can never splice
# an extra string literal onto this value.
config_stochastic = (
    '['
    '{"block": "GaussianConv2d", "latent_features": 4, "weightnorm": true}'
    ']'
)
    # '{"block": "GaussianConv2d", "latent_features": 64, "weightnorm": true},' \
    # '{"block": "GaussianDense", "latent_features": 32, "weightnorm": true},' \
    # '{"block": "GaussianDense", "latent_features": 32, "weightnorm": true},' \
    # '{"block": "GaussianDense", "latent_features": 32, "weightnorm": true}' \
    # ']'
    # '{"block": "GaussianDense", "latent_features": 8, "weightnorm": true}' \
    # '{"block": "GaussianDense", "latent_features": 16, "weightnorm": true}' \
# 128 64 32
# NOTE(review): `n_latents` and `sacale` (sic) are not read anywhere in the visible code.
n_latents, sacale = 1, False
# Torch device for every tensor in this script; the argument 0 is presumably a GPU index —
# confirm against oodd.utils.get_device.
device = get_device(0)

# Command-line interface. Most defaults come from the module-level constants above.
# The final parser is built by DataModule.get_argparser with this parser as a parent;
# `args.batch_size` / `args.data_workers` used later are not declared here, so they
# presumably come from that DataModule parser. parse_known_args() tolerates unknown flags.
parser = argparse.ArgumentParser(description="VAE MNIST Example")
parser.add_argument("--model", default="VAE", help="model type (VAE | LVAE | BIVA)")
# parser.add_argument("--model_dir", type=str, default="./models/VAEdc_CIFAR10Dequantized-2022-03-13-00-44-02.083833", help="model")
parser.add_argument("--model_dir", type=str, default=pretrained_model, help="model")
# In the visible code this value is only used to name the run directory.
parser.add_argument("--augment", default="False", help="augmentation setting ( False | dc | dc-policy)")
parser.add_argument("--warmup_epochs", type=int, default=0, help="epochs to warm up the KL term.")
parser.add_argument("--kl_weight", type=float, default=1e-5, help="fixes_kl_weigth, if not use warmup.")
parser.add_argument("--gamma", type=float, default=0.9, help="fixes_discount, if use policy gradient.")

parser.add_argument("--train_datasets", type=json_file_or_json_unique_keys, default=train_datasets)
parser.add_argument("--val_datasets", type=json_file_or_json_unique_keys, default=val_datasets)
parser.add_argument("--test_datasets", type=json_file_or_json_unique_keys, default=[])

parser.add_argument("--config_deterministic", type=json_file_or_json, default=config_deterministic, help="")
parser.add_argument("--config_stochastic", type=json_file_or_json, default=config_stochastic, help="")

parser.add_argument("--epochs", type=int, default=1000, help="number of epochs to train")
parser.add_argument("--learning_rate", type=float, default=3e-4, help="learning rate")
parser.add_argument("--train_samples", type=int, default=1, help="samples from approximate posterior")
parser.add_argument("--test_samples", type=int, default=1, help="samples from approximate posterior")
parser.add_argument("--train_importance_weighted", type=str2bool, default=False, const=True, nargs="?", help="use iw bound")
parser.add_argument("--test_importance_weighted", type=str2bool, default=False, const=True, nargs="?", help="use iw bound")

parser.add_argument("--free_nats_epochs", type=int, default=0, help="epochs to warm up the KL term.")
parser.add_argument("--free_nats", type=float, default=0, help="nats considered free in the KL term")
parser.add_argument("--n_eval_samples", type=int, default=32, help="samples from prior for quality inspection")
parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed")
parser.add_argument("--test_every", type=int, default=1, help="epochs between evaluations")
parser.add_argument("--save_dir", type=str, default="./models", help="directory for saving models")
parser.add_argument("--use_wandb", type=str2bool, default=True, help="use wandb tracking")
# argparse applies `type` only to string defaults, so the default must itself be a string;
# a bool default would leave `args.name` as a bool.
parser.add_argument("--name", type=str, default="", help="wandb tracking name")
parser = oodd.datasets.DataModule.get_argparser(parents=[parser])

args, unknown_args = parser.parse_known_args()

# Derived settings: a filesystem-safe start timestamp (used in the run-directory name) and
# the per-sample reduction (log-sum-exp for importance-weighted bounds, otherwise the mean).
args.start_time = str(datetime.datetime.now()).replace(" ", "-").replace(":", "-")
args.train_sample_reduction = log_sum_exp if args.train_importance_weighted else torch.mean
args.test_sample_reduction = log_sum_exp if args.test_importance_weighted else torch.mean
# Remote tracking is only enabled when wandb imported successfully AND the flag allows it.
args.use_wandb = wandb_available and args.use_wandb

set_seed(args.seed)


@torch.no_grad()
def test(epoch, dataloader, evaluator, dataset_name="test", max_test_examples=float("inf")):
    """Evaluate the global ``model`` on ``dataloader`` and record ELBO / KL statistics.

    First saves a gallery (inputs, likelihood means, likelihood samples) for up to 8 examples
    of the first batch into ``args.save_dir``. Then, for each ``k`` in
    ``0 .. model.n_latents - 1``, runs the model with the first ``k`` entries of
    ``decode_from_p`` (also passed as ``use_mode``) set to ``True``. It records the ELBO,
    likelihood and KL terms under ``skip-*`` keys. The ``k == 0`` pass is additionally
    recorded under the regular ``elbo`` / ``likelihoods`` / ``divergences`` keys.

    Args:
        epoch: epoch number, used in the gallery file name.
        dataloader: iterable of ``(x, label)`` batches with a ``batch_size`` attribute.
        evaluator: object exposing ``update(dataset_name, key, values_dict)`` as called below
            (e.g. ``Evaluator``).
        dataset_name: name used in the gallery file name and evaluator keys.
        max_test_examples: if finite, only ``max_test_examples // dataloader.batch_size``
            batches are evaluated for each ``k``.

    Depends on the module-level globals ``model``, ``args``, ``device``, ``in_shape`` and
    ``criterion``. NOTE(review): ``criterion`` is only assigned in commented-out code in this
    file; define it (e.g. ``oodd.losses.ELBO()``) before calling this function.
    """
    LOGGER.info(f"Testing: {dataset_name}")
    model.eval()

    # Reconstruction gallery from the first batch.
    x, _ = next(iter(dataloader))
    x = x.to(device)
    batch_size = x.size(0)
    n = min(batch_size, 8)
    likelihood_data, stage_datas, skip_likelihood = model(x, n_posterior_samples=args.test_samples)
    # Keep the zeroth posterior "sample" (the first `batch_size` rows). Slicing by the actual
    # batch size rather than `args.batch_size` avoids a view-size error whenever
    # `args.batch_size` exceeds the number of rows the model returned for this batch.
    p_x_mean = likelihood_data.mean[:batch_size].view(batch_size, *in_shape)
    p_x_samples = likelihood_data.samples[:batch_size].view(batch_size, *in_shape)
    comparison = torch.cat([x[:n], p_x_mean[:n], p_x_samples[:n]])
    comparison = comparison.permute(0, 2, 3, 1)  # [B, H, W, C]
    fig, ax = plot_gallery(comparison.cpu().numpy(), ncols=n)
    fig.savefig(os.path.join(args.save_dir, f"reconstructions_{dataset_name}_{epoch:03}"))
    plt.close()

    # One mask per k = 0 .. n_latents - 1 with k leading True entries.
    decode_from_p_combinations = [
        [True] * n_p + [False] * (model.n_latents - n_p) for n_p in range(model.n_latents)
    ]
    for decode_from_p in tqdm(decode_from_p_combinations, leave=False):
        n_skipped_latents = sum(decode_from_p)

        if max_test_examples != float("inf"):
            n_batches = max_test_examples // dataloader.batch_size
            iterator = tqdm(zip(range(n_batches), dataloader), smoothing=0.9, total=n_batches, leave=False)
        else:
            iterator = tqdm(enumerate(dataloader), smoothing=0.9, total=len(dataloader), leave=False)

        for _, (x, _) in iterator:
            x = x.to(device)

            likelihood_data, stage_datas, skip_recon_likelihood = model(
                x, n_posterior_samples=args.test_samples, decode_from_p=decode_from_p, use_mode=decode_from_p
            )
            kl_divergences = [
                stage_data.loss.kl_elementwise
                for stage_data in stage_datas
                if stage_data.loss.kl_elementwise is not None
            ]
            loss, elbo, likelihood, kl_divergences = criterion(
                likelihood_data.likelihood,
                kl_divergences,
                samples=args.test_samples,
                free_nats=0,
                beta=1,
                sample_reduction=args.test_sample_reduction,
                batch_reduction=None,
            )
            # Per-stage sample-wise KLs (stages without one are skipped); shared by both the
            # regular and the skip-* divergence records below.
            samplewise_kls = [sd.loss.kl_samplewise for sd in stage_datas if sd.loss.kl_samplewise is not None]

            if n_skipped_latents == 0:  # Regular ELBO
                evaluator.update(dataset_name, "elbo", {"log p(x)": elbo})
                evaluator.update(
                    dataset_name, "likelihoods", {"loss": -loss, "log p(x)": elbo, "log p(x|z)": likelihood}
                )
                klds = {f"KL z{i+1}": kl for i, kl in enumerate(samplewise_kls)}
                klds["KL(q(z|x), p(z))"] = kl_divergences
                evaluator.update(dataset_name, "divergences", klds)

            evaluator.update(dataset_name, "skip-elbo", {f"{n_skipped_latents} log p(x)": elbo})
            evaluator.update(dataset_name, f"skip-elbo-{dataset_name}", {f"{n_skipped_latents} log p(x)": elbo})
            evaluator.update(
                dataset_name,
                f"skip-likelihoods-{dataset_name}",
                {
                    f"{n_skipped_latents} loss": -loss,
                    f"{n_skipped_latents} log p(x)": elbo,
                    f"{n_skipped_latents} log p(x|z)": likelihood,
                },
            )
            klds = {f"{n_skipped_latents} KL z{i+1}": kl for i, kl in enumerate(samplewise_kls)}
            klds[f"{n_skipped_latents} KL(q(z|x), p(z))"] = kl_divergences
            evaluator.update(dataset_name, f"skip-divergences-{dataset_name}", klds)


def get_posterior_mu(model, dataloader, layer_index, max_test_examples=float("inf")):
    """Collect the posterior means of one latent layer over a dataloader.

    Each batch is moved to the global ``device`` and repeated ``args.train_samples`` times
    along dim 0. It is then passed through ``model.lambda_init`` (when set) and
    ``model.infer``. The mean of ``posteriors[layer_index]`` is flattened to one row per
    (repeated) input and copied to the CPU.

    Args:
        model: model exposing ``lambda_init`` and ``infer`` (project type).
        dataloader: iterable of ``(x, label)`` batches with a ``batch_size`` attribute.
        layer_index: index into the list returned by ``model.infer``.
        max_test_examples: if finite, only the first
            ``max_test_examples // dataloader.batch_size`` batches are used.

    Returns:
        np.ndarray of shape ``(n_rows, n_features)`` with every visited batch, including a
        smaller final batch. Rows are joined with ``np.concatenate``, which (unlike
        ``np.stack``) does not require equal batch sizes.

    Raises:
        ValueError: if no batch was visited.
    """
    if max_test_examples != float("inf"):
        n_batches = max_test_examples // dataloader.batch_size
        batches = tqdm(zip(range(n_batches), dataloader), smoothing=0.9, total=n_batches, leave=False)
    else:
        batches = tqdm(enumerate(dataloader), smoothing=0.9, total=len(dataloader), leave=False)

    z_data = []
    with torch.no_grad():
        for _, (x, _) in batches:
            x = x.to(device)
            x = x.repeat(args.train_samples, *((1,) * (x.ndim - 1)))  # Posterior samples
            if model.lambda_init is not None:
                x = model.lambda_init(x)
            posteriors = model.infer(x)
            layer_posterior_mu = posteriors[layer_index].mean.reshape(x.shape[0], -1)
            z_data.append(layer_posterior_mu.cpu().numpy())

    if not z_data:
        raise ValueError("get_posterior_mu: the dataloader yielded no batches")
    return np.concatenate(z_data, axis=0)


def collapse_multiclass_to_binary(y_true, zero_label=None):
    """Map labels to a binary vector: 0 where ``y_true == zero_label``, 1 everywhere else.

    Args:
        y_true: array-like of labels (lists are accepted too).
        zero_label: the label that becomes 0; every other label becomes 1.

    Returns:
        A new integer ndarray. The input is left untouched (writing into it in place would
        silently overwrite the caller's labels).
    """
    y_true = np.asarray(y_true)
    return np.where(y_true == zero_label, 0, 1)


def compute_roc_auc(y_true=None, y_score=None, zero_label=None):
    """ROC analysis of ``zero_label`` (mapped to 0) against every other label (mapped to 1).

    Returns:
        ``(roc_auc, fpr, tpr, thresholds)`` as produced by ``sklearn.metrics``.
    """
    binary_labels = collapse_multiclass_to_binary(y_true, zero_label)
    curve = sklearn.metrics.roc_curve(binary_labels, y_score)
    area = sklearn.metrics.roc_auc_score(binary_labels, y_score, average="macro")
    return (area, *curve)


def compute_pr_auc(y_true=None, y_score=None, zero_label=None):
    """Precision-recall analysis of ``zero_label`` (mapped to 0) against all other labels (mapped to 1).

    Returns:
        ``(pr_auc, precision, recall, thresholds)``; ``pr_auc`` is sklearn's average precision.
    """
    binary_labels = collapse_multiclass_to_binary(y_true, zero_label)
    precision, recall, cut_points = sklearn.metrics.precision_recall_curve(binary_labels, y_score)
    average_precision = sklearn.metrics.average_precision_score(binary_labels, y_score, average="macro")
    return average_precision, precision, recall, cut_points


def compute_roc_pr_metrics(y_true, y_score, classes, reference_class):
    """Compute ROC and PR metrics of the reference class against each other class in ``y_true``.

    Args:
        y_true: np.ndarray of class labels, one per example.
        y_score: np.ndarray of scores aligned with ``y_true``.
        classes: mapping ``source name -> label``; used to name each comparison.
        reference_class: label treated as class 0 in every binary comparison.

    Returns:
        ``(roc_results, pr_results, fpr80)``. ``roc_results`` / ``pr_results`` are keyed by
        the other source's name and contain that comparison's own ROC (resp. PR) thresholds.
        ``fpr80`` is the FPR at 80% TPR of the last comparison, or ``None`` if ``y_true``
        holds no class other than ``reference_class``.
    """
    roc_results = {}
    pr_results = {}
    fpr80 = None
    for class_label in sorted(set(y_true) - {reference_class}):
        idx = np.logical_or(y_true == reference_class, y_true == class_label)  # Compare primary to the other dataset

        # Separate names for the two threshold arrays: the PR call must not clobber the
        # ROC thresholds that are stored in roc_results below.
        roc_auc, fpr, tpr, roc_thresholds = compute_roc_auc(
            y_true=y_true[idx], y_score=y_score[idx], zero_label=reference_class
        )
        pr_auc, precision, recall, pr_thresholds = compute_pr_auc(
            y_true=y_true[idx], y_score=y_score[idx], zero_label=reference_class
        )

        # FPR at the first operating point where TPR reaches 0.8.
        idx_where_tpr_is_eighty = np.flatnonzero(tpr >= 0.8)[0]
        fpr80 = fpr[idx_where_tpr_is_eighty]

        ood_target = [source for source, label in classes.items() if label == class_label][0]
        roc_results[ood_target] = dict(roc_auc=roc_auc, fpr=fpr, tpr=tpr, fpr80=fpr80, thresholds=roc_thresholds)
        pr_results[ood_target] = dict(pr_auc=pr_auc, precision=precision, recall=recall, thresholds=pr_thresholds)

    return roc_results, pr_results, fpr80


def subsample_labels_and_scores(y_true, y_score, n_examples):
    """Randomly draw ``n_examples`` (label, score) pairs per class, without replacement.

    Labels and scores are drawn with the same indices, so they stay aligned. The output is
    grouped by class in ascending label order, which is deterministic given NumPy's global RNG.

    Args:
        y_true: array-like of class labels.
        y_score: array-like of scores, same length as ``y_true``.
        n_examples: number of pairs to draw from every class.

    Returns:
        ``(y_true_subset, y_score_subset)`` as ndarrays.

    Raises:
        ValueError: if the lengths differ, ``n_examples`` exceeds the total length, or some
            class has fewer than ``n_examples`` members.
    """
    if not (len(y_true) == len(y_score) >= n_examples):
        raise ValueError(f"Got {len(y_true)}, {len(y_score)}, {n_examples}")
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    indices = []
    for label in np.unique(y_true):
        members = np.flatnonzero(y_true == label)
        if len(members) < n_examples:
            raise ValueError(f"Class {label} has only {len(members)} examples; cannot draw {n_examples}")
        indices.append(np.random.choice(members, n_examples, replace=False))

    if not indices:  # empty input: nothing to draw
        return y_true[:0], y_score[:0]
    y_true = np.concatenate([y_true[idx] for idx in indices])
    y_score = np.concatenate([y_score[idx] for idx in indices])
    return y_true, y_score


if __name__ == "__main__":
    # Data: a DataModule built from the CLI arguments. It is used below for the run-directory
    # name, `in_shape` and `datamodule.save`, and is then replaced by the checkpoint's DataModule.
    datamodule = oodd.datasets.DataModule(
        batch_size=args.batch_size,
        test_batch_size=250,
        data_workers=args.data_workers,
        train_datasets=args.train_datasets,
        val_datasets=args.val_datasets,
        test_datasets=args.test_datasets,
    )
    # Per-run output directory: <save_dir>/<first train dataset>-<model>-<augment>-<start_time>.
    args.save_dir = os.path.join(args.save_dir, list(datamodule.train_datasets.keys())[0] + "-"+args.model+'-'+args.augment+'-' + args.start_time)
    os.makedirs(args.save_dir, exist_ok=True)

    # Attach a file handler writing LOGGER records at INFO and above to <run dir>/dvae.log.
    # NOTE(review): LOGGER's own level is not set in this file, so whether INFO records reach
    # this handler depends on logging configuration done elsewhere — confirm.
    fh = logging.FileHandler(os.path.join(args.save_dir, "dvae.log"))
    fh.setLevel(logging.INFO)
    LOGGER.addHandler(fh)

    # Presumably the per-example input shape; read as a global by `test()` to reshape
    # reconstructions. NOTE(review): taken from the CLI-built DataModule, not the checkpoint's.
    in_shape = datamodule.train_dataset.datasets[0].size[0]
    # NOTE(review): this persists the CLI-built DataModule, which is replaced just below.
    datamodule.save(args.save_dir)

    # Define checkpoints and load model. From here on `datamodule` and `model` are the
    # objects stored in the checkpoint at --model_dir, loaded onto `device` in eval mode.
    checkpoint = oodd.models.Checkpoint(path=args.model_dir)
    checkpoint.load(device)
    datamodule = checkpoint.datamodule
    model = checkpoint.model
    model.eval()

    # Posterior means of the chosen latent layer: every training batch, plus up to 10000
    # examples from each validation set. The primary validation set is the in-distribution
    # reference and any other validation set is treated as out-of-distribution.
    # NOTE(review): with several non-primary validation sets only the last one iterated is kept.
    train_z_data = get_posterior_mu(model, dataloader=datamodule.train_loader, layer_index=layer_index)
    for val_name in datamodule.val_datasets:
        val_z = get_posterior_mu(
            model,
            dataloader=datamodule.val_loaders[val_name],
            layer_index=layer_index,
            max_test_examples=10000,
        )
        if val_name == datamodule.primary_val_name:
            id_test_z = val_z
        else:
            ood_dataset_name, ood_test_z = val_name, val_z

    # 1) Log-density of each posterior mean under the standard-normal prior N(0, I), summed
    #    over latent dimensions, for train (ID), test (ID) and test (OOD); plotted as histograms.
    std_normal = Normal(0., 1.)
    train_z_p_in_prior = std_normal.log_prob(torch.tensor(train_z_data)).sum(dim=-1)
    id_test_z_p_in_prior = std_normal.log_prob(torch.tensor(id_test_z)).sum(dim=-1)
    ood_test_z_p_in_prior = std_normal.log_prob(torch.tensor(ood_test_z)).sum(dim=-1)

    id_label = datamodule.primary_val_name.replace("Dequantized", "")
    ood_label = ood_dataset_name.replace("Dequantized", "")
    plt.hist(train_z_p_in_prior, bins=100, density=True, color="deepskyblue", alpha=0.5,
             label=f'{id_label} train (ID)')
    plt.hist(id_test_z_p_in_prior, bins=100, density=True, color="orangered", alpha=0.5,
             label=f'{id_label} test (ID)')
    plt.hist(ood_test_z_p_in_prior, bins=100, density=True, color="purple", alpha=0.5,
             label=f'{ood_label} test (OOD)')
    plt.xlabel(r"Posterior $z$'s log-probability in $\mathcal{N}(0,I)$", fontsize=15)
    plt.ylabel('Density', fontsize=15)
    # plt.xlim([-1750, -650])
    # plt.title(f'Trained on {datamodule.primary_val_name}', fontsize=20)
    plt.legend(loc=0)
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('./figs', exist_ok=True)
    plt.savefig(f'./figs/p(mu(q(z)))_{datamodule.primary_val_name}.pdf')
    plt.savefig(f'./figs/p(mu(q(z)))_{datamodule.primary_val_name}.png')
    plt.show()
    plt.clf()

    import seaborn as sns

    # Same prior log-densities as the histogram, drawn as filled kernel-density estimates.
    id_label = datamodule.primary_val_name.replace("Dequantized", "")
    ood_label = ood_dataset_name.replace("Dequantized", "")
    sns.kdeplot(train_z_p_in_prior, fill=True, color='lime', alpha=0.4,
                label=f'{id_label} train (ID)')
    sns.kdeplot(id_test_z_p_in_prior, fill=True, color="cornflowerblue", alpha=0.6,
                label=f'{id_label} test (ID)')
    sns.kdeplot(ood_test_z_p_in_prior, fill=True, color="lightcoral", alpha=0.5,
                label=f'{ood_label} test (OOD)')

    plt.xlabel(r"Posterior $z$'s log-probability in $\mathcal{N}(0,I)$", fontsize=15)
    plt.ylabel('Density', fontsize=15)
    # plt.xlim([-1750, -650])
    # plt.title(f'Trained on {datamodule.primary_val_name}', fontsize=20)
    plt.legend(loc='upper left')
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('./figs', exist_ok=True)
    plt.savefig(f'./figs/p(mu(q(z)))_{datamodule.primary_val_name}_KDE.pdf')
    plt.savefig(f'./figs/p(mu(q(z)))_{datamodule.primary_val_name}_KDE.png')
    plt.show()
    plt.clf()
    #
    # # train a GMM for the z_data from the train set
    #
    # # use sklearn gmm
    # # gpu_z_data = copy.deepcopy(z_data)
    # # for idx, z_ in enumerate(gpu_z_data[-1]):
    # #     gpu_z_data[-1][idx] = z_.flatten().cpu().numpy()
    # # gmm_input = gpu_z_data[-1]
    #
    # global_z_gmm = GMM(n_components=10, random_state=42).fit(train_z_data)
    # log_like_train_z = global_z_gmm.score_samples(train_z_data)
    # log_like_id_z = global_z_gmm.score_samples(id_test_z)
    # log_like_ood_z = global_z_gmm.score_samples(ood_test_z)
    # plt.hist(log_like_train_z, bins=100, density=True, facecolor="deepskyblue", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} train (in)')
    # plt.hist(log_like_id_z, bins=100, density=True, facecolor="orangered", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} test (in)')
    # plt.hist(log_like_ood_z, bins=100, density=True, facecolor="purple", alpha=0.5,
    #          label=f'{ood_dataset_name} test (ood)')
    # plt.xlabel('Posterior $z_{\mu}$\'s log p(x) in fitted GMM', fontsize=18)
    # plt.ylabel('Density', fontsize=18)
    # # plt.xlim([-5, -0.5])
    # plt.title(f'Trained on {datamodule.primary_val_name}', fontsize=20)
    # plt.legend(loc=0)
    # plt.tight_layout()
    # # plt.figure(figsize=(6, 8))
    # # plt.savefig(f'./figs/Posterior z_mu\'s p(x) in $N$(0,I).pdf')
    # # plt.savefig(f'./figs/Posterior z_mu\'s p(x) in $N$(0,I).png')
    # plt.show()
    # print(f'log_like in GMM: train_z:{log_like_train_z.mean()}, id_test_z:{log_like_id_z.mean()}, ood_z:{log_like_ood_z.mean()}')
    #
    # cluster_train_z = global_z_gmm.predict(train_z_data)
    # cluster_id_z = global_z_gmm.predict(id_test_z)
    # cluster_ood_z = global_z_gmm.predict(ood_test_z)
    # plt.hist(cluster_train_z, bins=10, facecolor="deepskyblue", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} train (in)')
    # plt.hist(cluster_id_z, bins=10, facecolor="orangered", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} test (in)')
    # plt.hist(cluster_ood_z, bins=10, facecolor="purple", alpha=0.5,
    #          label=f'{ood_dataset_name} test (ood)')
    # plt.legend()
    # plt.title('predicted cluster by fitted GMM')
    # plt.show()
    #
    # proba_ood_z_c8 = global_z_gmm._estimate_log_prob(ood_test_z)[0]
    # proba_id_z_c8 = global_z_gmm._estimate_log_prob(id_test_z)[10]
    #
    #
    # def final_score(gmm_, z):
    #     proba_pick = global_z_gmm._estimate_log_prob(global_z_gmm.means_)
    #     proba_each_pick = np.zeros(len(global_z_gmm.means_))
    #     for mean_idx in range(len(global_z_gmm.means_)):
    #         proba_each_pick[mean_idx] = proba_pick[mean_idx, mean_idx]
    #
    #     cluster = gmm_.predict(z)
    #     proba = gmm_._estimate_log_prob(z)
    #     final_score = torch.zeros(len(cluster))
    #     for s_idx in range(len(cluster)):
    #         final_score[s_idx] = proba[s_idx, cluster[s_idx]] / proba_each_pick[cluster[s_idx]]
    #
    #     return final_score
    # fs_train_z = final_score(global_z_gmm, train_z_data)
    # fs_id_z = final_score(global_z_gmm, id_test_z)
    # fs_ood_z = final_score(global_z_gmm, ood_test_z)
    # plt.hist(fs_train_z, bins=100, density=True, facecolor="deepskyblue", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} train (in)')
    # plt.hist(fs_id_z, bins=100, density=True,facecolor="orangered", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} test (in)')
    # plt.hist(fs_ood_z, bins=100, density=True, facecolor="purple", alpha=0.5,
    #          label=f'{ood_dataset_name} test (ood)')
    # plt.legend()
    # plt.title('final score (lower should be id) by fitted GMM')
    # plt.show()
    #
    # # analyze one cluster
    # analyze_c = 8
    # ood_z_ca = []
    # prob = global_z_gmm._estimate_log_prob(ood_test_z)
    # for ci in range(len(cluster_ood_z)):
    #     if cluster_ood_z[ci] == analyze_c:
    #         ood_z_ca.append(prob[ci, analyze_c])
    # id_z_ca = []
    # prob = global_z_gmm._estimate_log_prob(id_test_z)
    # for ci in range(len(cluster_id_z)):
    #     if cluster_id_z[ci] == analyze_c:
    #         id_z_ca.append(prob[ci, analyze_c])
    # train_z_ca = []
    # prob = global_z_gmm._estimate_log_prob(train_z_data)
    # for ci in range(len(cluster_id_z)):
    #     if cluster_train_z[ci] == analyze_c:
    #         train_z_ca.append(prob[ci, analyze_c])
    # plt.hist(np.array(ood_z_ca), bins=50, density=True, facecolor='purple', alpha=0.5,
    #          label=f'{ood_dataset_name} test (ood)')
    # plt.hist(np.array(id_z_ca), bins=50, density=True, facecolor="orangered", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} test (in)')
    # plt.hist(np.array(train_z_ca), bins=50, density=True, facecolor="deepskyblue", alpha=0.5,
    #          label=f'{datamodule.primary_val_name} train (in)')
    # plt.legend()
    # plt.title(f'analysis of log p for the {analyze_c}th cluster')
    # plt.show()

    # ============================================================================= #
    # t-sne test
    # tsne within train_z, id_z, ood_z
    # tsne_inputs = []
    #
    # train_z_1000 = np.array(random.sample(list(train_z_data), 1000))
    # tsne_inputs.append(train_z_1000)
    #
    # id_test_z_1000 = np.array(random.sample(list(id_test_z), 1000))
    # tsne_inputs.append(id_test_z_1000)
    #
    # ood_test_z_1000 = np.array(random.sample(list(ood_test_z), 1000))
    # tsne_inputs.append(ood_test_z_1000)
    #
    # tsne_inputs = np.concatenate((tsne_inputs[0], tsne_inputs[1], tsne_inputs[2]), axis=0)
    #
    # # z_embedded = TSNE(n_components=2, perplexity=50.0, early_exaggeration=12.0, learning_rate=10.0, n_iter=2000, n_iter_without_progress=300, min_grad_norm=1e-07, metric='euclidean', init='random', verbose=0, random_state=None, method='barnes_hut', angle=0.5, n_jobs=None).fit_transform(tsne_input)
    # z_embedded = TSNE(n_components=2).fit_transform(tsne_inputs)
    #
    # lables = [f'{datamodule.primary_val_name} train (in)', f'{datamodule.primary_val_name} test (in)', f'{ood_dataset_name} test (ood)']
    # colors = ["deepskyblue", "orangered", 'purple']
    # for class_i in range(3):
    #     plt.scatter(z_embedded[class_i*1000:(class_i+1)*1000, 0], z_embedded[class_i*1000:(class_i+1)*1000, 1],
    #                 c=colors[class_i], label=lables[class_i], s=10, alpha=0.6)
    #
    # # plt.title(f'stage z_{stage_i}')
    # plt.xticks([])
    # plt.yticks([])
    # # legend_text = f'"t-sne of z" -> red:MNIST(ood) | green:FashionMNIST(in) | blue:Prior(in)'
    # # plt.title(legend_text)
    # # plt.tight_layout()
    # plt.legend()
    # plt.title('t-sne of posterior $z_{\mu}$')
    # # fig_recon_save_path = os.path.join(args.save_dir, f'{dataset}_q_in_q_ood_p.pdf')
    # # print(f'save path: {fig_recon_save_path}')
    # # plt.savefig(fig_recon_save_path, bbox_inches='tight', pad_inches=0.0)
    # plt.show()

    # ================================================================================ #
    # t-SNE of posterior means: train (ID) vs. test (OOD) vs. samples from the N(0, I) prior.
    # The raw posterior means are saved first so they can be analysed elsewhere.
    # NOTE(review): the train file name lacks the "_728" token present in the ID/OOD names;
    # kept as-is because downstream consumers may rely on these exact paths.
    out_dir = '../ConvVAE/vae_logs/hvae_dc/'
    os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories
    np.save(out_dir + f'Zdim_{layer_index}_beta_1E-1_HVAE_DC_16-32-728_train_muz' + '.npy', train_z_data)
    np.save(out_dir + f'Zdim_{layer_index}_728_beta_1E-1_HVAE_DC_16-32-728_id_muz' + '.npy', id_test_z)
    np.save(out_dir + f'Zdim_{layer_index}_728_beta_1E-1_HVAE_DC_16-32-728_ood_muz' + '.npy', ood_test_z)

    # Up to 1000 points per group. random.sample raises ValueError when asked for more rows
    # than a population has, so cap the group size at the smaller of the two populations.
    n_tsne = min(1000, len(train_z_data), len(ood_test_z))
    train_z_1000 = np.array(random.sample(list(train_z_data), n_tsne))
    ood_test_z_1000 = np.array(random.sample(list(ood_test_z), n_tsne))
    prior_z_1000 = np.random.randn(train_z_1000.shape[0], train_z_1000.shape[1])
    tsne_inputs = np.concatenate((train_z_1000, ood_test_z_1000, prior_z_1000), axis=0)

    z_embedded = TSNE(n_components=2).fit_transform(tsne_inputs)

    # Rows of z_embedded are grouped in the order they were concatenated above.
    labels = [f'{datamodule.primary_val_name} train (in)', f'{ood_dataset_name} test (ood)', 'prior distribution samples']
    colors = ["deepskyblue", 'purple', 'green']
    for group_i, (label, color) in enumerate(zip(labels, colors)):
        group = z_embedded[group_i * n_tsne:(group_i + 1) * n_tsne]
        plt.scatter(group[:, 0], group[:, 1], c=color, label=label, s=10, alpha=0.6)

    plt.xticks([])
    plt.yticks([])
    plt.legend()
    plt.title(f'layer {layer_index} t-sne of latent space')
    plt.show()

    # test the learned gmm
    # pred_ws = []
    # pred_log_likelis = []
    # for name, dataloader in datamodule.val_loaders.items():
    #     time.sleep(1)
    #     test_z_data = get_top_z(model, dataloader, max_test_examples=10000)
    #     # test_z_gmm = GMM(n_components=10, random_state=42).fit(test_z_data)
    #     # print(test_z_gmm.weights_)
    #     pred_w = global_z_gmm.predict_proba(test_z_data)
    #     pred_log_likeli = global_z_gmm.score_samples(test_z_data)
    #     #
    #     pred_ws.append(pred_w)
    #     pred_log_likelis.append(pred_log_likeli)
    #     #
    #     print(f'{name} pred_neg_likeli.mean(): {pred_log_likeli.mean()}')

        # pred = global_z_gmm.predict(test_z_data)
        # label = datamodule.val_datasets[f"{name}"].dataset.targets

    # No-op placeholder: the GMM-based evaluation code around it is commented out.
    pass



    # from gmm_model import GaussianMixture
    # global_z_gmm_gpu = GaussianMixture(n_components=10, n_features=32)
    # global_z_gmm_gpu.fit(torch.tensor(z_data))
    #
    # for name, dataloader in datamodule.val_loaders.items():
    #     time.sleep(1)
    #     test_z_data = get_top_z(model, dataloader, max_test_examples=10000)
    #     # test_z_gmm = GMM(n_components=10, random_state=42).fit(test_z_data)
    #     # print(test_z_gmm.weights_)
    #     pred_w = global_z_gmm_gpu.predict_proba(torch.tensor(test_z_data))
    #     pred_neg_likeli = global_z_gmm_gpu.score_samples(torch.tensor(test_z_data)) * -1.0
    #     #
    #     pred_ws.append(pred_w)
    #     pred_neg_likelis.append(pred_neg_likeli)
    #     #
    #     print(f'pred_neg_likeli.mean(): {pred_neg_likeli.mean()}')

    #
    # p_z_samples = model.prior.sample(torch.Size([args.n_eval_samples])).to(device)
    # sample_latents = [None] * (model.n_latents - 1) + [p_z_samples]
    #
    # # Optimization
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    #
    # criterion = oodd.losses.ELBO()
    #
    # deterministic_warmup = oodd.variational.DeterministicWarmup(n=args.warmup_epochs)
    # free_nats_cooldown = oodd.variational.FreeNatsCooldown(
    #     constant_epochs=args.free_nats_epochs // 2,
    #     cooldown_epochs=args.free_nats_epochs // 2,
    #     start_val=args.free_nats,
    #     end_val=0,
    # )
    #
    # # Logging
    # LOGGER.info("Experiment config:")
    # LOGGER.info(args)
    # rich.print(vars(args))
    # LOGGER.info("%s", deterministic_warmup)
    # LOGGER.info("%s", free_nats_cooldown)
    # LOGGER.info("DataModule:\n%s", datamodule)
    # LOGGER.info("Model:\n%s", model)
    #
    # if args.use_wandb:
    #     wandb.init(project="hvae-oodd", config=args, name=f"{args.model} {datamodule.primary_val_name} {args.name}")
    #     wandb.save("*.pt")
    #     wandb.watch(model, log="all")
    #
    # # Run
    # test_elbos = [-np.inf]
    # test_evaluator = Evaluator(primary_source=datamodule.primary_val_name, primary_metric="log p(x)", logger=LOGGER, use_wandb=args.use_wandb)
    #
    # LOGGER.info("Running training...")
    # for epoch in range(1, args.epochs + 1):
    #     # current_beta = train(epoch)
    #
    #     if epoch % args.test_every == 0:
    #         # Sample
    #         with torch.no_grad():
    #             likelihood_data, stage_datas, skip_likelihood = model.sample_from_prior(
    #                 n_prior_samples=args.n_eval_samples, forced_latent=sample_latents
    #             )
    #             p_x_samples = likelihood_data.samples.view(args.n_eval_samples, *in_shape)
    #             p_x_mean = likelihood_data.mean.view(args.n_eval_samples, *in_shape)
    #             comparison = torch.cat([p_x_samples, p_x_mean])
    #             comparison = comparison.permute(0, 2, 3, 1)  # [B, H, W, C]
    #             fig, ax = plot_gallery(comparison.cpu().numpy(), ncols=args.n_eval_samples // 4)
    #             fig.savefig(os.path.join(args.save_dir, f"samples_{epoch:03}"))
    #             plt.close()
    #
    #         # Test
    #         for name, dataloader in datamodule.val_loaders.items():
    #             test(epoch, dataloader=dataloader, evaluator=test_evaluator, dataset_name=name, max_test_examples=10000)
    #
    #         # Save
    #         # test_elbo = test_evaluator.get_primary_metric().mean().cpu().numpy()
    #         # if np.max(test_elbos) < test_elbo:
    #         #     test_evaluator.save(args.save_dir, idx=epoch)
    #         #     model.save(args.save_dir, idx=f'{epoch}_beta_{current_beta}_seed_{args.seed}')
    #         #     LOGGER.info("Saved model!")
    #         # test_elbos.append(test_elbo)
    #
    #         # Compute LLR
    #         for source in test_evaluator.sources:
    #             for k in range(1, model.n_latents):
    #                 log_p_a = test_evaluator.metrics[source][f"skip-elbo"][f"0 log p(x)"]
    #                 log_p_b = test_evaluator.metrics[source][f"skip-elbo"][f"{k} log p(x)"]
    #                 llr = log_p_a - log_p_b
    #                 test_evaluator.update(source, series="LLR", metrics={f"LLR>{k}": llr})
    #
    #         # Compute AUROC score for L>k and LLR>k metrics
    #         reference_dataset = datamodule.primary_val_name
    #         max_examples = min(
    #             [len(d) for d in datamodule.val_datasets.values()]
    #         )  # Maximum number of examples to use for equal sized sets
    #
    #         # L >k
    #         for n_skipped_latents in range(model.n_latents):
    #             y_true, y_score, classes = test_evaluator.get_classes_and_scores_per_source(
    #                 f"skip-elbo", f"{n_skipped_latents} log p(x)"
    #             )
    #             y_true, y_score = subsample_labels_and_scores(y_true, y_score, max_examples)
    #             roc, pr, fpr80 = compute_roc_pr_metrics(
    #                 y_true, -y_score, classes, classes[reference_dataset]
    #             )  # Negation since higher score means more OOD
    #             for ood_target, value_dict in roc.items():
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"ROC AUC L>k",
    #                     metrics={f"ROC AUC L>{n_skipped_latents} {ood_target}": [value_dict["roc_auc"]]},
    #                 )
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"ROC AUC L>{n_skipped_latents}",
    #                     metrics={f"ROC AUC L>{n_skipped_latents} {ood_target}": [value_dict["roc_auc"]]},
    #                 )
    #             for ood_target, value_dict in pr.items():
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"PRC AUC L>k",
    #                     metrics={f"PRC AUC L>{n_skipped_latents} {ood_target}": [value_dict["pr_auc"]]},
    #                 )
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"PRC AUC L>{n_skipped_latents}",
    #                     metrics={f"PRC AUC L>{n_skipped_latents} {ood_target}": [value_dict["pr_auc"]]},
    #                 )
    #
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"fpr80 L>k",
    #                     metrics={f"fpr80 L>{n_skipped_latents} {ood_target}": [fpr80]},
    #                 )
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"fpr80 L>{n_skipped_latents}",
    #                     metrics={f"fpr80 L>{n_skipped_latents} {ood_target}": [fpr80]},
    #                 )
    #
    #         # LLR >0 >k
    #         for n_skipped_latents in range(1, model.n_latents):
    #             y_true, y_score, classes = test_evaluator.get_classes_and_scores_per_source(
    #                 f"LLR", f"LLR>{n_skipped_latents}"
    #             )
    #             y_true, y_score = subsample_labels_and_scores(y_true, y_score, max_examples)
    #             roc, pr, fpr80 = compute_roc_pr_metrics(y_true, y_score, classes, classes[reference_dataset])
    #             for ood_target, value_dict in roc.items():
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"ROC AUC LLR>k",
    #                     metrics={f"ROC AUC LLR>{n_skipped_latents} {ood_target}": [value_dict["roc_auc"]]},
    #                 )
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"ROC AUC LLR>{n_skipped_latents}",
    #                     metrics={f"ROC AUC LLR>{n_skipped_latents} {ood_target}": [value_dict["roc_auc"]]},
    #                 )
    #
    #             for ood_target, value_dict in pr.items():
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"PRC AUC LLR>k",
    #                     metrics={f"PRC AUC LLR>{n_skipped_latents} {ood_target}": [value_dict["pr_auc"]]},
    #                 )
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"PRC AUC LLR>{n_skipped_latents}",
    #                     metrics={f"PRC AUC LLR>{n_skipped_latents} {ood_target}": [value_dict["pr_auc"]]},
    #                 )
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"fpr80 LLR>k",
    #                     metrics={f"fpr80 LLR>{n_skipped_latents} {ood_target}": [fpr80]},
    #                 )
    #                 test_evaluator.update(
    #                     source=reference_dataset,
    #                     series=f"fpr80 LLR>{n_skipped_latents}",
    #                     metrics={f"fpr80 LLR>{n_skipped_latents} {ood_target}": [fpr80]},
    #                 )
    #
    #         # Report
    #         print(">>>>>>>>>>>"+args.save_dir+f'  kl_weigth:{current_beta}  gamma:{args.gamma}  '+"<<<<<<<<<<<<<")
    #         test_evaluator.report(epoch * len(datamodule.train_loader))
    #         test_evaluator.log(epoch)
    #         test_evaluator.reset()
