# 00 ->  MIG-GPU-0b63aaba-6ele-a51f-97e5-c8bb77699c04/1/0
# 01 -> MIG-GPU-0b63aaba-6ele-a51f-97e5-c8bb77699c04/2/0
# 10 -> MIG-GPU-1c0365e0-78b1-1672-907a-68efbe86c467/1/0
# 11 -> MIG-GPU-1c0365e0-78b1-1672-907a-68efbe86c467/2/0
# 20 -> MIG-GPU-d0cce0e7-c51a-ba4c-2728-7b6d73beacc8/1/0
# 21 -> MIG-GPU-d0cce0e7-c51a-ba4c-2728-7b6d73beacc8/2/0
# 40 -> MIG-GPU-612f1f55-f57a-e899-0561-2ded7ff24ee7/1/0
# 41 -> MIG-GPU-612f1f55-f57a-e899-0561-2ded7ff24ee7/2/0

import os
# Pin this script to one MIG slice (see the index -> UUID table above) unless the
# caller has already chosen devices by exporting CUDA_VISIBLE_DEVICES.
# Set before `import torch` below so CUDA initialisation sees it.
# NOTE(review): the segment "6ele" contains the letter "l", which is not a hex
# digit; GPU/MIG UUIDs listed by `nvidia-smi -L` are hex, so this may be a
# transcription error (e.g. for "6e1e") — confirm against `nvidia-smi -L`.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "MIG-GPU-0b63aaba-6ele-a51f-97e5-c8bb77699c04/1/0")

import argparse
import os
import logging

from collections import defaultdict

import rich
import numpy as np
import torch

import oodd.datasets

import oodd.utils
import torch.utils.data
from tqdm import tqdm
import oodd.losses
import torchvision
from torchvision import transforms as tv_transforms
from model import ConvVAE
from train_vae import analyze_z
from ConvVAE.lstm_mdn import MixtureDensityRNN as qz_rnn
from torchvision import datasets
from ConvVAE.plug_calculate_roc_etc import *

from oodd.utils.argparsing import json_file_or_json_unique_keys
from oodd.utils.argparsing import str2bool, json_file_or_json
from oodd.utils import str2bool, get_device, log_sum_exp, set_seed, plot_gallery
import matplotlib.pyplot as plt

# Root logger (getLogger() with no name); no handler or level is configured in this file.
LOGGER = logging.getLogger()

# Default --train_datasets / --val_datasets values: JSON objects mapping a dataset
# name to its DataModule options. Alternative configurations kept for reference:
#
# train_datasets = '{ "CIFAR10Dequantized": {"dynamic": true, "split": "train"}}'
# val_datasets = '{"CIFAR10Dequantized": {"dynamic": false, "split": "validation"}, "SVHNDequantized": {"dynamic": false, "split": "validation"}}'
#
# train_datasets = '{ "CIFAR10Quantized28x28": {"dynamic": true, "split": "train"}}'
# val_datasets = '{"CIFAR10Quantized28x28": {"dynamic": false, "split": "validation"}, "SVHNQuantized28x28": {"dynamic": false, "split": "validation"}}'
#
# train_datasets = '{ "FashionMNISTQuantized28x28": {"dynamic": true, "split": "train"}}'
# val_datasets = '{"FashionMNISTQuantized28x28": {"dynamic": false, "split": "validation"},\
#     "MNISTQuantized28x28": {"dynamic": false, "split": "validation"}}'

train_datasets = '{ "FashionMNISTQuantized": {"dynamic": true, "split": "train"}}'
# Adjacent literals are concatenated; the four spaces reproduce the indentation
# that the original backslash line-continuation carried into the string.
val_datasets = (
    '{"FashionMNISTQuantized": {"dynamic": false, "split": "validation"},'
    '    "MNISTQuantized": {"dynamic": false, "split": "validation"}}'
)

# Command-line interface. Unknown flags are tolerated (parse_known_args below)
# because a second parser later in the file reads the same sys.argv.
parser = argparse.ArgumentParser(description="VAE MNIST Example")
parser.add_argument("--model", default="VAE", help="model type (VAE | LVAE | BIVA)")
parser.add_argument("--augment", default="False", help="augmentation type (False | dc | mask | oversm)")
parser.add_argument("--warmup_epochs", type=int, default=0, help="epochs to warm up the KL term.")
parser.add_argument("--kl_weight", type=float, default=1.0, help="fixed KL weight, used when not warming up.")
parser.add_argument("--mask_ratio", type=float, default=0.0, help=" ")

parser.add_argument("--train_datasets", type=json_file_or_json_unique_keys, default=train_datasets)
parser.add_argument("--val_datasets", type=json_file_or_json_unique_keys, default=val_datasets)
parser.add_argument("--test_datasets", type=json_file_or_json_unique_keys, default=[])
#
# parser.add_argument("--config_deterministic", type=json_file_or_json, default=config_deterministic, help="")
# parser.add_argument("--config_stochastic", type=json_file_or_json, default=config_stochastic, help="")

parser.add_argument("--epochs", type=int, default=200, help="number of epochs to train")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate")
parser.add_argument("--train_samples", type=int, default=1, help="samples from approximate posterior")
parser.add_argument("--test_samples", type=int, default=1, help="samples from approximate posterior")
parser.add_argument("--train_importance_weighted", type=str2bool, default=False, const=True, nargs="?", help="use iw bound")
parser.add_argument("--test_importance_weighted", type=str2bool, default=False, const=True, nargs="?", help="use iw bound")

parser.add_argument("--free_nats_epochs", type=int, default=0, help="epochs to warm up the KL term.")
parser.add_argument("--free_nats", type=float, default=0, help="nats considered free in the KL term")
parser.add_argument("--n_eval_samples", type=int, default=32, help="samples from prior for quality inspection")
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed")
parser.add_argument("--test_every", type=int, default=1, help="epochs between evaluations")
parser.add_argument("--save_dir", type=str, default="./data_models", help="directory for saving models")
# const/nargs like the other str2bool flags, so a bare `--use_wandb` means True
# while `--use_wandb false` keeps working.
parser.add_argument("--use_wandb", type=str2bool, default=False, const=True, nargs="?", help="use wandb tracking")
# NOTE(review): default=True on a str-typed option looks unintended (argparse does
# not convert non-string defaults); kept unchanged for compatibility.
parser.add_argument("--name", type=str, default=True, help="wandb tracking name")
# Adds the DataModule's own options (e.g. --data_workers, used below).
parser = oodd.datasets.DataModule.get_argparser(parents=[parser])

args, unknown_args = parser.parse_known_args()
rich.print(vars(args))

os.makedirs(args.save_dir, exist_ok=True)
device = get_device(0)
LOGGER.info("Device %s", device)

# FILE_NAME_SETTINGS_SPEC = f"k{args.n_latents_skip}-iw_elbo{args.iw_samples_elbo}-iw_lK{args.iw_samples_Lk}"


def get_save_path(name):
    """Return the path of *name* inside the module-level ``args.save_dir``.

    Every space in *name* is turned into a dash; no other normalisation is done.
    """
    slug = "-".join(name.split(" "))
    return f"{args.save_dir}/{slug}"


def get_decode_from_p(n_latents, k=0, semantic_k=True):
    """Build the boolean ``decode_from_p`` mask over ``n_latents`` latent layers.

    With ``semantic_k=True`` the first ``k`` entries are True and the rest False.
    With ``semantic_k=False`` the first ``k + 1`` entries are False and the rest True.

    Results for ``n_latents=3``::

        k  semantic_k  result
        0  True        [False, False, False]
        1  True        [True, False, False]
        2  True        [True, True, False]
        3  True        [True, True, True]
        0  False       [False, True, True]
        1  False       [False, False, True]
        2  False       [False, False, False]

    Args:
        n_latents: length of the returned list.
        k: split point described above.
        semantic_k: which of the two layouts above to use.

    Returns:
        A list of ``n_latents`` booleans.

    Raises:
        ValueError: if ``k`` is outside ``0..n_latents`` (``semantic_k=True``)
            or ``0..n_latents - 1`` (``semantic_k=False``). Outside those ranges
            list repetition by a negative count silently yields a list whose
            length is not ``n_latents``.
    """
    upper = n_latents if semantic_k else n_latents - 1
    if not 0 <= k <= upper:
        raise ValueError(
            f"k={k} out of range [0, {upper}] for n_latents={n_latents}, semantic_k={semantic_k}"
        )

    if semantic_k:
        return [True] * k + [False] * (n_latents - k)

    return [False] * (k + 1) + [True] * (n_latents - k - 1)


def get_lengths(dataloaders):
    """Return ``len()`` of each value of the mapping *dataloaders*, in iteration order."""
    # Only the values are needed, so iterate .values() rather than .items().
    return [len(loader) for loader in dataloaders.values()]


def print_stats(llr, l, lk):
    """Print mean, variance and standard deviation of three score collections.

    One comma-separated row is printed per collection, in the order ``l``,
    ``llr``, ``lk`` (not the argument order). Each value is formatted as
    ``8.3f``, and each row has its own leading padding and separator width.
    ``np.var``/``np.std`` are used with their default ``ddof=0``.

    In the commented-out caller at the end of this file the arguments are the
    per-example scores (ELBO minus the L>k bound), the ELBOs and the L>k bounds.
    """
    def summarize(values):
        return np.mean(values), np.var(values), np.std(values)

    # (leading padding, column separator, data) for each printed row, top to bottom.
    layout = [
        ("  ", ",   ", l),
        ("", ", ", llr),
        (" ", ",  ", lk),
    ]
    rows = [
        lead + sep.join(f"{stat:8.3f}" for stat in summarize(values))
        for lead, sep, values in layout
    ]
    print("\n".join(rows))


# # # Define checkpoints and load model
# checkpoint = oodd.models.Checkpoint(path=args.model_dir)
# checkpoint.load(device=device)
# datamodule = checkpoint.datamodule
# # model = checkpoint.model
# # model.eval()
# # criterion = oodd.losses.ELBO()
# # rich.print(datamodule)

# Build the data module from the CLI dataset specs, forcing batch size 1 for
# both the training and the evaluation loaders.
args.batch_size = 1
datamodule = oodd.datasets.DataModule(
    train_datasets=args.train_datasets,
    val_datasets=args.val_datasets,
    test_datasets=args.test_datasets,
    batch_size=args.batch_size,
    test_batch_size=args.batch_size,
    data_workers=args.data_workers,
)

# Add additional datasets to evaluation
TRAIN_DATASET_KEY = list(datamodule.train_datasets.keys())[0]
LOGGER.info("Train dataset %s", TRAIN_DATASET_KEY)

# Bare dataset name without preprocessing tags, e.g. "FashionMNISTQuantized" ->
# "FashionMNIST" and "CIFAR10Quantized28x28" -> "CIFAR10". str.replace removes
# whole tags. str.strip("Quantized") would instead remove any of the *characters*
# Q,u,a,n,t,i,z,e,d from both ends (e.g. "notMNISTQuantized" -> "otMNIST").
# Same tags, in the same order, as the per-loader renaming in the evaluation loop.
MAIN_DATASET_NAME = (
    datamodule.primary_val_name
    .replace("Binarized", "")
    .replace("Quantized", "")
    .replace("Dequantized", "")
    .replace("28x28", "")
)
LOGGER.info("Main dataset %s", MAIN_DATASET_NAME)

IN_DIST_DATASET = MAIN_DATASET_NAME + " test"
TRAIN_DATASET = MAIN_DATASET_NAME + " train"
LOGGER.info("Main in-distribution dataset %s", IN_DIST_DATASET)
# Extra OOD validation sets (uncomment to enable) plus the training split of the
# in-distribution dataset, evaluated as an extra test set.
if MAIN_DATASET_NAME in ["FashionMNIST", "MNIST"]:
    extra_val = dict(
        # SVHNQuantizedGrey28x28=dict(split="validation"),
        # CIFAR100QuantizedGrey28x28=dict(split="validation"),
        # STL10QuantizedGrey28x28=dict(split='train'),
        # CIFAR10QuantizedGrey28x28=dict(split="validation"),
        # CelebAQuantizedGrey28x28=dict(split="test"),
        # Flowers102QuantizedGrey28x28=dict(split="test"),
        # Places365QuantizedGrey28x28=dict(split="test"),
        # FakeDataQuantizedGrey28x28=dict(split="test"),
        # LFWPeopleQuantizedGrey28x28=dict(split="train"),
        # SUN397QuantizedGrey28x28=dict(split='test'),
        # RenderedSST2QuantizedGrey28x28=dict(split='test'),
        # constant_1=dict(split='test'),

        # notMNISTQuantized=dict(split='validation'),
        # Omniglot28x28Quantized=dict(split='validation'),
        # Omniglot28x28InvertedQuantized=dict(split='validation'),
        # KMNISTQuantized=dict(split='validation'),  # Effectively quantized
    )
    extra_test = {TRAIN_DATASET_KEY: dict(split="train", dynamic=False)}
elif MAIN_DATASET_NAME in ["CIFAR10", "SVHN"]:
    extra_val = dict(
        # CIFAR10DequantizedGrey=dict(split='test', preprocess='deterministic'),
        # CIFAR100Quantized28x28=dict(split='test'),
        # CelebAQuantized28x28=dict(split='test'),
        # STL10Quantized28x28=dict(split='train'),
        # Flowers102Quantized28x28=dict(split='test'),
        # Places365Quantized28x28=dict(split='test'),
        # constant_3=dict(split='test'),
        # FakeDataQuantized28x28=dict(split='test'),
        # LFWPeopleQuantized28x28=dict(split='train'),
        # SUN397Quantized28x28=dict(split='test'),
        # GTSRBQuantized28x28=dict(split='train'),
        # DTDQuantized28x28=dict(split='train', dynamic=True),
    )
    extra_test = {TRAIN_DATASET_KEY: dict(split="train", dynamic=True)}
else:
    raise ValueError(f"Unknown main dataset name {MAIN_DATASET_NAME}")

datamodule.add_datasets(val_datasets=extra_val, test_datasets=extra_test)
datamodule.data_workers = 4
# datamodule.batch_size = args.batch_size
# datamodule.test_batch_size = args.batch_size

datamodule.batch_size = 1
datamodule.test_batch_size = 1
LOGGER.info("%s", datamodule)


# Sizes of every evaluation dataset: validation sets first, then the extra test sets.
val_lengths = get_lengths(datamodule.val_datasets)
n_test_batches = val_lengths + get_lengths(datamodule.test_datasets)
for ds_name, ds in datamodule.val_datasets.items():
    print(f'dataset:{ds_name}-->len:{len(ds)}')
# Example output (CIFAR10 configuration):
# dataset:CIFAR10Dequantized-->len:10000
# dataset:SVHNDequantized-->len:531131
# dataset:CelebAQuantized-->len:19962
# dataset:Food101Quantized-->len:25250
# dataset:Flowers102Quantized-->len:6149
# dataset:Places365Quantized-->len:36500
# dataset:FakeDataQuantized-->len:10000
# dataset:LFWPeopleQuantized-->len:9525
# dataset:RenderedSST2Quantized-->len:6920
# dataset:PCAMQuantized-->len:32768
# dataset:EuroSATQuantized-->len:27000
# dataset:GTSRBQuantized-->len:26640
# N_EQUAL_EXAMPLES_CAP:6000

# Smallest dataset size rounded down to a whole thousand (lengths are non-negative,
# so floor division matches truncation).
N_EQUAL_EXAMPLES_CAP = min(n_test_batches) // 1000 * 1000
print(f'N_EQUAL_EXAMPLES_CAP:{N_EQUAL_EXAMPLES_CAP}')
assert N_EQUAL_EXAMPLES_CAP % args.batch_size == 0, "Batch size must divide smallest dataset size"


# N_EQUAL_EXAMPLES_CAP = min([args.n_eval_examples, N_EQUAL_EXAMPLES_CAP])
# Never evaluate more than 5000 examples per dataset.
N_EQUAL_EXAMPLES_CAP = min(5000, N_EQUAL_EXAMPLES_CAP)
LOGGER.info("%s = %s", "N_EQUAL_EXAMPLES_CAP", N_EQUAL_EXAMPLES_CAP)

# decode_from_p = get_decode_from_p(model.n_latents, k=args.n_latents_skip)

# (display name, loader) pairs. Validation loaders are labelled " test"; the extra
# test loaders (the training split of the in-distribution data) are labelled " train".
# A list keeps the evaluation order deterministic; a set of (str, DataLoader)
# tuples would iterate in a hash-dependent order that changes between runs.
dataloaders = [(k + " test", v) for k, v in datamodule.val_loaders.items()]
dataloaders += [(k + " train", v) for k, v in datamodule.test_loaders.items()]

scores = defaultdict(list)
elbos = defaultdict(list)
elbos_k = defaultdict(list)

# ==================================== options ====================================== #
plot_entropy = False
compute_scores = False
ent_ablation = True
# =================================================================================== #

# ============================= load pretrained model =============================== #
if compute_scores:
    ## Arguments
    # Settings of the ConvVAE checkpoint to evaluate. This parser reads the same
    # sys.argv as the main parser, so it uses parse_known_args(): parse_args()
    # would exit on any flag only the main parser knows (e.g. --train_datasets).
    parser_cs = argparse.ArgumentParser(description='VAE MNIST Example')
    # Boolean options use str2bool. With type=bool, argparse calls bool("False"),
    # which is True, so these flags could never be switched off from the command line.
    parser_cs.add_argument('--analyze_mode', type=str2bool, default=True, help='whether analyze the model.')
    parser_cs.add_argument('--add_noise', type=str2bool, default=False, help='')
    parser_cs.add_argument('--use_entropy', type=str2bool, default=False, help='verify the entropy')
    parser_cs.add_argument('--kl_q_q', type=str2bool, default=True, help='verify the q_id(z)')
    parser_cs.add_argument('--qq_sigma', type=float, default=0.1, help='')
    parser_cs.add_argument('--neigh', type=str2bool, default=False, help='verify the q_id(z)')
    parser_cs.add_argument('--batch_size', type=int, default=128, metavar='N',
                           help='input batch size for training (default: 128)')
    parser_cs.add_argument('--id', type=str, default='CIFAR',
                           help='FMNIST, MNIST, CIFAR')
    parser_cs.add_argument('--epochs', type=int, default=100, metavar='N',
                           help='selects the *_checkpoint_<epochs>.pt file to load (default: 100)')
    parser_cs.add_argument('--z_dim', type=int, default=200, metavar='N',
                           help='dimension of the latent variable')
    parser_cs.add_argument('--beta', type=float, default=1.0,
                           help='beta * KL')
    parser_cs.add_argument('--no-cuda', action='store_true', default=False,
                           help='disables CUDA')
    parser_cs.add_argument('--log-interval', type=int, default=10, metavar='N',
                           help='how many batches to wait before logging training status')
    parser_cs.add_argument('--model', type=str, default='discrete_logistic', metavar='N',
                           help='which model to use: bce_vae, mse_vae,  gaussian_vae, or sigma_vae or optimal_sigma_vae, discrete_logistic')
    parser_cs.add_argument('--log_dir', type=str, default='./results', metavar='N')
    args_cs, _ = parser_cs.parse_known_args()

    if args_cs.analyze_mode:
        args_cs.batch_size = 1
    ## Cuda
    args_cs.cuda = not args_cs.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args_cs.cuda else "cpu")

    ## Dataset: images resized to 28x28 and converted to tensors.
    TRANSFORM_non_BINARIZE = torchvision.transforms.Compose(
        [
            tv_transforms.Resize((28, 28)),
            torchvision.transforms.ToTensor(),
        ]
    )
    TRANSFORM_ = TRANSFORM_non_BINARIZE

    if args_cs.id == 'FMNIST':
        train_dataset = datasets.FashionMNIST('./data', train=True, download=True,
                                              transform=TRANSFORM_)  # transforms.ToTensor()
        test_dataset = datasets.FashionMNIST('./data', train=False,
                                             transform=TRANSFORM_)
        ood_test_dataset = datasets.MNIST('./data', train=False, download=True,
                                          transform=TRANSFORM_)
    elif args_cs.id == 'MNIST':
        train_dataset = datasets.MNIST('./data', train=True, download=True,
                                       transform=TRANSFORM_)  # transforms.ToTensor()
        test_dataset = datasets.MNIST('./data', train=False,
                                      transform=TRANSFORM_)
        ood_test_dataset = datasets.FashionMNIST('./data', train=False, download=True,
                                                 transform=TRANSFORM_)
    elif args_cs.id == 'CIFAR':
        # transform_ = tv_transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
        train_dataset = datasets.CIFAR10('../../data', train=True, transform=TRANSFORM_, download=True)
        test_dataset = datasets.CIFAR10('../../data', train=False, transform=TRANSFORM_, download=True)
        ood_test_dataset = datasets.SVHN('../../data', split='test', transform=TRANSFORM_, download=True)
    else:
        # Fail here with a clear message instead of a NameError on train_dataset below.
        raise ValueError(f"Unsupported --id {args_cs.id!r}; expected one of FMNIST, MNIST, CIFAR")
    # --- data loading --- #

    kwargs = {'num_workers': 10, 'pin_memory': True} if args_cs.cuda else {}
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args_cs.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args_cs.batch_size, shuffle=True, **kwargs)

    ## Build Model: grayscale datasets have 1 input channel, CIFAR has 3.
    in_channels = 1 if args_cs.id in ['FMNIST', 'MNIST'] else 3
    model = ConvVAE(device, in_channels, args_cs).to(device)


def _move_to_front(items, name):
    """Move the last ``(display_name, loader)`` pair whose name equals *name* to the front.

    Args:
        items: list of ``(display_name, loader)`` tuples.
        name: display name to look for.

    Returns:
        A new list with that pair first and the others in their previous order.
        If no pair matches, a warning is logged and *items* is returned unchanged.
        This avoids a NameError, or reuse of an index from an earlier search.
    """
    matches = [i for i, (item_name, _) in enumerate(items) if item_name == name]
    if not matches:
        LOGGER.warning("Dataset %r not found; evaluation order left unchanged", name)
        return items
    idx = matches[-1]
    return items[idx:idx + 1] + items[:idx] + items[idx + 1:]


def _move_to_end(labels, values, name):
    """Move the last label equal to *name*, and the value at the same index, to the end.

    *labels* and *values* are parallel lists; both are returned reordered the same
    way. When *name* is absent both are returned unchanged.
    """
    matches = [i for i, label_ in enumerate(labels) if label_ == name]
    if not matches:
        return labels, values
    idx = matches[-1]
    return (
        labels[:idx] + labels[idx + 1:] + labels[idx:idx + 1],
        values[:idx] + values[idx + 1:] + values[idx:idx + 1],
    )


# x-axis of the SVD reconstruction curves: number of singular values kept (1..25).
x_axis = np.arange(1, 26)
# Becomes 1 once analyze_z has produced the in-distribution reference scores (id_logpx).
n = 0
all_ood_logpxs = []
all_ood_names = []

# Evaluate the in-distribution train split first, then its test split, then the rest.
dataloaders = list(dataloaders)
dataloaders = _move_to_front(dataloaders, f"{TRAIN_DATASET_KEY} test")
dataloaders = _move_to_front(dataloaders, f"{TRAIN_DATASET_KEY} train")

all_train_decays = []
all_labels = []

if compute_scores:
    # ConvVAE checkpoint to evaluate, selected by the args_cs settings.
    ckpt_path = (
        f'../ConvVAE/vae_logs/results_id_{args_cs.id}/'
        f'Zdim_{args_cs.z_dim}_beta_{args_cs.beta}_{args_cs.model}_{args_cs.add_noise}_checkpoint_{args_cs.epochs}.pt'
    )
if compute_scores or plot_entropy:
    # plt.savefig does not create missing directories.
    os.makedirs('./figs', exist_ok=True)

with torch.no_grad():
    for dataset, dataloader in dataloaders:
        # Display name without preprocessing tags, e.g. "FashionMNISTQuantized test" -> "FashionMNIST test".
        dataset = dataset.replace("Binarized", "").replace("Quantized", "").replace("Dequantized", "").replace("28x28", "")
        print(f"Evaluating {dataset}")
        # if dataset != 'SVHNGrey test':
        #     continue

        if compute_scores:
            # Reload before each dataset (kept per-dataset in case analyze_z alters the
            # model — not verifiable here). map_location lets a checkpoint saved on GPU
            # load when CUDA is unavailable.
            model.load_state_dict(torch.load(ckpt_path, map_location=device))
            # # avg_ELBO:  cifar10=12978  |  fmnist=
            # NOTE(review): names here still carry a " test"/" train" suffix, so this
            # comparison never matches; left as-is because the name is passed to
            # analyze_z, whose use of it is not visible here.
            if dataset == 'constant_1' or dataset == 'constant_3':
                dataset = 'constant'

            if n == 0:
                # The first call also returns the train / ID values (train_*, id_*)
                # that every later dataset is compared against.
                (train_mus, id_mus, ood_mus, train_logpx, id_logpx, ood_logpx,
                 train_kl, id_kl, ood_kl, avg_ELBO) = analyze_z(
                    model, train_loader, test_loader, dataloader, args_cs, dataset=dataset, ent_ablation=ent_ablation)
                n += 1
            else:
                if 'CIFAR10 ' in dataset:
                    continue
                _, _, ood_mus, _, _, ood_logpx, _, _, ood_kl, avg_ELBO = analyze_z(
                    model, train_loader, test_loader, dataloader, args_cs, avg_ELBO, dataset, ent_ablation=ent_ablation)

            # compute scores: in-distribution test scores are labelled 1, this dataset 0
            reference_scores = id_logpx
            test_scores = ood_logpx
            if 'CIFAR10 ' not in dataset:
                all_ood_logpxs.append(ood_logpx)
                all_ood_names.append(dataset)

            bigger_is_id = 1
            # compute metrics
            # check bigger should be id or ood?
            y_true = np.array([*[bigger_is_id] * len(reference_scores), *[1 - bigger_is_id] * len(test_scores)])
            y_score = np.concatenate([reference_scores, test_scores])

            # Separate names for ROC and PR thresholds; a shared `thresholds` name
            # would let the PR thresholds overwrite the ROC ones in `results`.
            (
                (roc_auc, fpr, tpr, roc_thresholds),
                (pr_auc, precision, recall, pr_thresholds),
                fpr80,
            ) = compute_roc_pr_metrics(y_true=y_true, y_score=y_score, reference_class=0)

            results = dict(
                roc=dict(roc_auc=roc_auc, fpr=fpr, tpr=tpr, thresholds=roc_thresholds),
                pr=dict(pr_auc=pr_auc, precision=precision, recall=recall, thresholds=pr_thresholds),
                fpr80=fpr80,
            )

            print(f"TEST {dataset} AUROC={roc_auc:6.4f}, AUPRC={pr_auc:6.4f}, FPR80={fpr80:6.4f}\n")

            # # save the results
            # np.save(
            #     f'vae_logs/results_id_{args.id}/' + f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_ent_{args.use_entropy}_klqq_{args.kl_q_q}_{args.qq_sigma}_train_muz' + '.npy',
            #     train_mus)
            # np.save(
            #     f'vae_logs/results_id_{args.id}/' + f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_ent_{args.use_entropy}_klqq_{args.kl_q_q}_{args.qq_sigma}_id_muz' + '.npy',
            #     id_mus)
            # np.save(
            #     f'vae_logs/results_id_{args.id}/' + f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_ent_{args.use_entropy}_klqq_{args.kl_q_q}_{args.qq_sigma}_ood_muz' + '.npy',
            #     ood_mus)
            #
            # np.save(
            #     f'vae_logs/results_id_{args.id}/' + f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_ent_{args.use_entropy}_klqq_{args.kl_q_q}_{args.qq_sigma}_train_logpx' + '.npy',
            #     train_logpx)
            # np.save(
            #     f'vae_logs/results_id_{args.id}/' + f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_ent_{args.use_entropy}_klqq_{args.kl_q_q}_{args.qq_sigma}_id_logpx' + '.npy',
            #     id_logpx)
            # np.save(
            #     f'vae_logs/results_id_{args.id}/' + f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_ent_{args.use_entropy}_klqq_{args.kl_q_q}_{args.qq_sigma}_ood_logpx' + '.npy',
            #     ood_logpx)

            print(f'id logpx mean: {np.mean(id_logpx)}')
            print(f'ood logpx mean: {np.mean(ood_logpx)}')
            import seaborn as sns
            # KDE of the first column, scaled by 1 / (28 * 28 * ln 2).
            sns.kdeplot(id_logpx[:, 0]/(28*28*np.log(2)), fill=True, color='cornflowerblue', alpha=0.4, label=f'{args_cs.id} test (ID)')
            sns.kdeplot(ood_logpx[:, 0]/(28*28*np.log(2)), fill=True, color='lightcoral', alpha=0.4, label=f'{dataset} (OOD)')
            plt.legend(loc='upper left')
            plt.xlabel('log(bits/dim)')
            plt.ylabel('Density')
            plt.tight_layout()
            # NOTE(review): same file name for every dataset, so each iteration overwrites the previous figure.
            plt.savefig('./figs/kde_fm_m_avoid.pdf')
            plt.savefig('./figs/kde_fm_m_avoid.png')
            plt.show()
            plt.clf()

            plt.hist(id_logpx, bins=100, facecolor="deepskyblue", alpha=0.5, label=f'{args_cs.id} test (ID)')
            plt.hist(ood_logpx, bins=100, facecolor="orangered", alpha=0.5, label=f'{dataset}(OOD)')
            plt.legend()
            # plt.xlim(-2600, -1200)
            plt.title('log p(x) estimated by ELBO')
            plt.show()
            plt.clf()

        if plot_entropy:
            # Mean absolute error of rank-i SVD reconstructions, i = 1..len(x_axis),
            # averaged over the first examples of this dataset (more for the ID splits).
            n_examples = 500 if dataset in [TRAIN_DATASET, IN_DIST_DATASET] else 100
            count = 0
            train_decay = np.zeros(len(x_axis))
            for data, _ in tqdm(dataloader, total=N_EQUAL_EXAMPLES_CAP / args.batch_size):
                count += data.shape[0]
                data_np = np.array(data[0, 0])  # first item of the batch, first channel
                U, sigma, V = np.linalg.svd(data_np)
                for i in range(1, len(x_axis) + 1):
                    # Plain ndarray products; np.matrix is deprecated.
                    reconstimg = U[:, :i] @ np.diag(sigma[:i]) @ V[:i, :]
                    train_decay[i - 1] += np.mean(np.abs(reconstimg - data_np))
                if count >= n_examples:
                    break
            if count == 0:
                LOGGER.warning("No examples in %s; skipping its SVD curve", dataset)
            else:
                train_decay = train_decay / count
                print(f'dataset: {dataset}  train decay: {train_decay}')
                all_train_decays.append(train_decay)
                all_labels.append(f'{dataset}')
                # plt.plot(x_axis, train_decay, label=f'{dataset}')

    if ent_ablation and n == 0:
        # id_logpx and args_cs only exist after analyze_z ran (compute_scores=True
        # with at least one dataset); referencing them otherwise raises NameError.
        LOGGER.warning("ent_ablation needs compute_scores=True and an evaluated dataset; skipping its plots")
    elif ent_ablation:
        plt.hist(id_logpx, alpha=0.5, label=f'{args_cs.id}-10 test (ID)')
        # sns.kdeplot(id_logpx, shade=True, alpha=0.5, label=f'{args_cs.id}-10 test (ID)')
        plt.legend(loc='upper left')
        plt.xlabel(r'${\mathcal{C}}(x)$')
        plt.ylabel('Density')
        plt.xlim(0, 13500)
        plt.tight_layout()
        plt.savefig('./figs/estimated_Hx_of_ID_CIFAR.pdf')
        plt.savefig('./figs/estimated_Hx_of_ID_CIFAR.png')
        plt.show()
        plt.clf()

        for ood_name, ood_logpx_values in zip(all_ood_names, all_ood_logpxs):
            label = ood_name.replace(" test", "").replace("_3", "")
            if label == "FakeData":
                label = "random"
            plt.hist(ood_logpx_values, alpha=0.5, label=label)
        plt.legend(loc='upper left')
        plt.xlabel(r'Estimated $\hat{\mathcal{H}}(x)$')
        plt.ylabel('Density')
        plt.xlim(0, 13500)
        plt.tight_layout()
        plt.savefig('./figs/estimated_Hx_of_OOD_CIFAR.pdf')
        plt.savefig('./figs/estimated_Hx_of_OOD_CIFAR.png')
        plt.show()
        plt.clf()

    if plot_entropy:
        # Draw the constant and random-noise curves last.
        all_labels, all_train_decays = _move_to_end(all_labels, all_train_decays, 'constant_1 test')
        all_labels, all_train_decays = _move_to_end(all_labels, all_train_decays, 'FakeDataGrey test')

        fig, ax = plt.subplots()
        for label_name, decay in zip(all_labels, all_train_decays):
            if label_name in ['CIFAR100Grey test', 'CelebAGrey test']:
                continue
            if label_name == 'constant_1 test':
                ax.plot(x_axis, decay, linestyle='dashed', label='Constant')
            elif label_name == 'FakeDataGrey test':
                ax.plot(x_axis, decay, linestyle='dashed', label='Random')
            elif label_name in [TRAIN_DATASET, IN_DIST_DATASET]:
                # In-distribution splits: solid, thicker lines.
                ax.plot(x_axis, decay, linewidth=2, label=label_name)
            elif label_name in ['CIFAR10Grey test', 'SVHNGrey test']:
                ax.plot(x_axis, decay, linestyle='dashed', label=label_name.replace("Grey", "Gray"))
            else:
                ax.plot(x_axis, decay, linestyle='dashed', label=label_name)
        plt.yscale('log')
        plt.xlabel(r'Number of singular values ($n$) ')
        plt.ylabel(r'Reconstruction error ($|x_{re} - x|$)')
        # plt.ylim(2e-5, 1e-3)
        legend = ax.legend(loc=3, fontsize=10)
        frame = legend.get_frame()
        frame.set_facecolor('white')
        frame.set_linewidth(1.5)
        frame.set_edgecolor('lightgray')

        ax.grid(color='gray', linestyle='-', linewidth=1, alpha=0.3)
        ax.set_facecolor('white')
        for side in ('bottom', 'top', 'left', 'right'):
            ax.spines[side].set_color('black')

        plt.tight_layout()
        plt.savefig(f'./figs/{datamodule.primary_val_name}_FINAL_svd_decay_curve.pdf')
        plt.savefig(f'./figs/{datamodule.primary_val_name}_FINAL_svd_decay_curve.png')
        plt.show()
        plt.clf()

        #     sample_elbos, sample_elbos_k = [], []
        #
        #     # Regular ELBO
        #     for i in tqdm(range(args.iw_samples_elbo), leave=False):
        #         likelihood_data, stage_datas, _ = model(x, decode_from_p=False, use_mode=False)
        #
        #         kl_divergences = [
        #             stage_data.loss.kl_elementwise
        #             for stage_data in stage_datas
        #             if stage_data.loss.kl_elementwise is not None
        #         ]
        #         loss, elbo, likelihood, kl_divergences = criterion(
        #             likelihood_data.likelihood,
        #             kl_divergences,
        #             samples=1,
        #             free_nats=0,
        #             beta=args.beta,   # default : 1
        #             sample_reduction=None,
        #             batch_reduction=None,
        #         )
        #         sample_elbos.append(elbo.detach())
        #         # sample_elbos.append(-loss.detach())
        #
        #     # L>k bound
        #     for i in tqdm(range(args.iw_samples_Lk), leave=False):
        #         likelihood_data_k, stage_datas_k, _ = model(x, decode_from_p=decode_from_p, use_mode=decode_from_p)
        #         kl_divergences_k = [
        #             stage_data.loss.kl_elementwise
        #             for stage_data in stage_datas_k
        #             if stage_data.loss.kl_elementwise is not None
        #         ]
        #         loss_k, elbo_k, likelihood_k, kl_divergences_k = criterion(
        #             likelihood_data_k.likelihood,
        #             kl_divergences_k,
        #             samples=1,
        #             free_nats=0,
        #             beta=1,
        #             sample_reduction=None,
        #             batch_reduction=None,
        #         )
        #         sample_elbos_k.append(elbo_k.detach())
        #
        #     sample_elbos = torch.stack(sample_elbos, axis=0)
        #     sample_elbos_k = torch.stack(sample_elbos_k, axis=0)
        #
        #     sample_elbo = oodd.utils.log_sum_exp(sample_elbos, axis=0)
        #     sample_elbo_k = oodd.utils.log_sum_exp(sample_elbos_k, axis=0)
        #
        #     score = sample_elbo - sample_elbo_k
        #
        #     scores[dataset].extend(score.tolist())
        #     elbos[dataset].extend(sample_elbo.tolist())
        #     elbos_k[dataset].extend(sample_elbo_k.tolist())
        #
        #     if n > N_EQUAL_EXAMPLES_CAP:
        #         LOGGER.warning(f"Skipping remaining iterations due to {N_EQUAL_EXAMPLES_CAP}")
        #         break


# # print likelihoods
# for dataset in sorted(scores.keys()):
#     print("===============", dataset, "===============")
#     print_stats(scores[dataset], elbos[dataset], elbos_k[dataset])
#
# train_elbo_dataset = datamodule.primary_val_name  #.replace("Dequantized", "")
# for dataset_name in datamodule.val_datasets:
#     if dataset_name == train_elbo_dataset:
#         id_dataset_name = dataset_name.replace("Dequantized", "")
#     else:
#         ood_dataset_name = dataset_name.replace("Dequantized", "")
# train_dataset_name = train_elbo_dataset.replace("Dequantized", "")
# import matplotlib.pyplot as plt
# plt.hist(elbos[f'{train_dataset_name} train'], bins=100, density=True, facecolor="deepskyblue", alpha=0.5,
#              label=f'{train_dataset_name} train (in)')
# plt.hist(elbos[f'{id_dataset_name} test'], bins=100, density=True, facecolor="orangered", alpha=0.5,
#              label=f'{id_dataset_name} test (in)')
# plt.hist(elbos[f'{ood_dataset_name} test'], bins=100, density=True, facecolor="purple", alpha=0.5,
#              label=f'{ood_dataset_name} test (ood)')
# plt.xlabel('ELBO', fontsize=18)
# plt.ylabel('Density', fontsize=18)
# # plt.xlim([-5, -0.5])
# plt.title(f'Trained on {datamodule.primary_val_name}', fontsize=20)
# plt.legend(loc=0)
# plt.tight_layout()
# plt.show()
#
#
# # save scores
# torch.save(scores, get_save_path(f"values-scores-{IN_DIST_DATASET}-{FILE_NAME_SETTINGS_SPEC}.pt"))
# torch.save(elbos_k, get_save_path(f"values-elbos_k-{IN_DIST_DATASET}-{FILE_NAME_SETTINGS_SPEC}.pt"))
# print(f'N_EQUAL_EXAMPLES_CAP:{N_EQUAL_EXAMPLES_CAP}')