"""Script to evaluate the OODD scores (LLR and L>k) for a saved HVAE"""

import argparse
import os
import logging

from collections import defaultdict
from typing import *

from tqdm import tqdm

import rich
import numpy as np
import torch

import oodd
import oodd.datasets
import oodd.evaluators
import oodd.models
import oodd.losses
import oodd.utils
from oodd.utils import str2bool, get_device, log_sum_exp, set_seed, plot_gallery

import matplotlib.pyplot as plt
import cv2

LOGGER = logging.getLogger()


# ---------------------------------------------------------------------------
# Command-line configuration.
# The model to visualise is selected by `cifar_model_dir` (despite its name it
# currently holds a FashionMNIST model name), together with the `--beta` default
# and `importance_samples`. Exactly one of the FASHIONMNIST / CIFAR10 sections
# below should be active; swap which section is commented out to change models.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# =======> FASHIONMNIST (active configuration)
cifar_model_dir = "5-layer-fashionmnist-vae-false"
# NOTE(review): `args.beta` is not read anywhere else in this script.
parser.add_argument("--beta", type=float, default=0.001)
importance_samples = 1
# Candidate values for `cifar_model_dir` (model directory names) per model family:
# OURS: vae-dc-fashinmnist   5-layer-fashionmnist-vaedc-268617
# HVAE: VAE-FASIONMINST-EPOCH600-046043   vae-binaryfashionmnist-epoch1000  5-layer-fashionmnist-vae-false
# LVAE: LVAE-FASHIONMNIST-EPOCH400-773564  lvae-binaryfashionmnist-epoch1000
# BIVA: BIVA-FashionMNIST-EPOCH349-987397   biva-binaryFashionMNIST-epoch600


# ======> CIFAR10 (inactive; uncomment this section and comment out the FASHIONMNIST one to use it)
# cifar_model_dir = "5-layer-vaedc-cifar10"
# importance_samples = 1
# parser.add_argument("--beta", type=float, default=1.0)
# Candidate values for `cifar_model_dir` (model directory names) per model family:
# OURS: "vae-dc-warm-cifar10-0.978-epo51-454529"    "vae-dc-0.1-cifar10-0.96-epo32-317843"   VAEdc_CIFAR10Dequantized-2022-03-13-00-44-02.083833
#       5-layer-vaedc-cifar10
# VAE : VAE-CIFAR10Dequantized-EPOCH300-075664   vae-cifar-show
# LVAE: LVAE-CIFAR10Dequantized-EPOCH100-150420   LVAE-CIFAR10-EPOCH257-150420  LVAE-cifar-show
# BIVA: BIVA-CIFAR10Dequantized-0.86-epoch1-946984    BIVA-choose-cifar10    BIVA-CIFAR10-GOOD  biva-cifar10-densityshow


# Figure settings.
show_n_figs = 1  # default for --batch_size; also caps `n` in the (disabled) reconstruction code
iter_n_show = 100  # only referenced by the commented-out reconstruction code at the end of the script
# Seed torch (CPU and all CUDA devices) and NumPy so prior samples are reproducible across runs.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
# random.seed(seed)
# Earlier --model_dir choices:
# parser.add_argument("--model_dir", type=str, default="./models/VAEdc_CIFAR10Dequantized-2022-03-13-00-44-02.083833", help="model")
# parser.add_argument("--model_dir", type=str, default="./models/vae-dc-warm-cifar10-0.978-epo51-454529", help="model")  # im202: 0.7316
# Checkpoint directory loaded by oodd.models.Checkpoint below.
parser.add_argument("--model_dir", type=str, default="./models/"+cifar_model_dir, help="model")  # best: im1 -> celeba 0.723; im101:celeba 0.7265

# Number of prior samples drawn for the generated-samples gallery.
parser.add_argument("--n_eval_samples", type=int, default=100, help="samples from prior for quality inspection")


# Output directory for the saved figures (PNG/PDF); the help text still says "scores".
parser.add_argument("--save_dir", type=str, default="./figs_patial_2/"+cifar_model_dir, help="directory to store scores in")
# The importance-sample counts and n_latents_skip are only used to build FILE_NAME_SETTINGS_SPEC
# (and by commented-out code); the active figure code does not read them.
parser.add_argument("--iw_samples_elbo", type=int, default=importance_samples, help="importances samples for regular ELBO")

parser.add_argument("--iw_samples_Lk", type=int, default=importance_samples, help="importances samples for L>k bound")
# argparse does not apply `type` to a non-string default, so the default stays float("inf").
parser.add_argument("--n_eval_examples", type=int, default=float("inf"), help="cap on the number of examples to use")
parser.add_argument("--n_latents_skip", type=int, default=2, help="the value of k in the paper")
parser.add_argument("--batch_size", type=int, default=show_n_figs, help="batch size for evaluation")
# Pass "auto" to let oodd.utils.get_device() choose (see `device` below).
parser.add_argument("--device", type=str, default="cuda:2", help="device to evaluate on")


args = parser.parse_args()
rich.print(vars(args))

os.makedirs(args.save_dir, exist_ok=True)
device = oodd.utils.get_device() if args.device == "auto" else torch.device(args.device)
LOGGER.info("Device %s", device)

# Settings tag for output file names; not referenced elsewhere in this script.
FILE_NAME_SETTINGS_SPEC = f"k{args.n_latents_skip}-iw_elbo{args.iw_samples_elbo}-iw_lK{args.iw_samples_Lk}"


def get_save_path(name):
    """Return the path of ``name`` inside ``args.save_dir``.

    Every space in ``name`` is replaced by a dash, and the directory and file
    name are joined with a literal "/" separator.
    """
    dashed_name = "-".join(name.split(" "))
    return f"{args.save_dir}/{dashed_name}"


def get_decode_from_p(n_latents, k=0, semantic_k=True):
    """Build the per-layer ``decode_from_p`` mask for a hierarchical model.

    One boolean is returned per latent layer, index 0 first. ``True`` presumably
    means "take this layer from the prior p" (per the name) -- confirm against
    the model's ``decode_from_p`` handling.

    Values actually returned for ``n_latents=3``:

    k   semantic_k  out
    0   True        [False, False, False]
    1   True        [True, False, False]
    2   True        [True, True, False]
    3   True        [True, True, True]
    -1  False       [True, True, True]
    0   False       [False, True, True]
    1   False       [False, False, True]
    2   False       [False, False, False]

    Note that with ``semantic_k=False`` the first ``k + 1`` entries are False,
    so ``k=0`` does NOT give an all-True mask.

    Args:
        n_latents: number of latent layers (length of the returned mask).
        k: with ``semantic_k=True``, the number of leading True entries; with
            ``semantic_k=False``, the first ``k + 1`` entries are False and
            the remaining ones True.
        semantic_k: selects between the two conventions above.

    Returns:
        A list of exactly ``n_latents`` booleans.

    Raises:
        ValueError: if ``k`` is out of range, i.e. the mask would not have
            exactly ``n_latents`` entries (``0 <= k <= n_latents`` when
            ``semantic_k`` is True, ``-1 <= k <= n_latents - 1`` otherwise).
    """
    if semantic_k:
        # Otherwise the list multiplication below silently yields a mask of the wrong length.
        if not 0 <= k <= n_latents:
            raise ValueError(f"k must be in [0, {n_latents}] when semantic_k=True, got {k}")
        return [True] * k + [False] * (n_latents - k)

    if not -1 <= k <= n_latents - 1:
        raise ValueError(f"k must be in [-1, {n_latents - 1}] when semantic_k=False, got {k}")
    return [False] * (k + 1) + [True] * (n_latents - k - 1)


def get_lengths(dataloaders):
    """Return ``len()`` of every value in a mapping, in the mapping's iteration order.

    Args:
        dataloaders: mapping from a name to a sized object (e.g. a dataset or a
            data loader); only the values are used.

    Returns:
        A list of integer lengths, one per value.
    """
    return [len(loader) for loader in dataloaders.values()]


def print_stats(llr, l, lk):
    """Print mean, variance and standard deviation of three score arrays.

    Three rows are printed in the order ``l``, ``llr``, ``lk`` (not the
    argument order). Each row is ``mean, var, std``, each value formatted as
    ``8.3f``. The rows use slightly different leading padding and separators.

    Args:
        llr: array-like of values (LLR scores, per the parameter name).
        l: array-like of values (ELBO scores, per the parameter name).
        lk: array-like of values (L>k scores, per the parameter name).
    """
    def summarize(values):
        # NumPy defaults: population variance / std (ddof=0).
        return np.mean(values), np.var(values), np.std(values)

    l_stats = summarize(l)
    llr_stats = summarize(llr)
    lk_stats = summarize(lk)
    rows = (
        f"  {l_stats[0]:8.3f},   {l_stats[1]:8.3f},   {l_stats[2]:8.3f}",
        f"{llr_stats[0]:8.3f}, {llr_stats[1]:8.3f}, {llr_stats[2]:8.3f}",
        f" {lk_stats[0]:8.3f},  {lk_stats[1]:8.3f},  {lk_stats[2]:8.3f}",
    )
    print("\n".join(rows))


# Define checkpoints and load model.
# The checkpoint in --model_dir provides both the trained model and the datamodule it was trained with;
# everything is loaded onto `device`.
checkpoint = oodd.models.Checkpoint(path=args.model_dir)
checkpoint.load(device=device)
datamodule = checkpoint.datamodule
model = checkpoint.model
model.eval()  # evaluation mode for all figure generation below
criterion = oodd.losses.ELBO()  # NOTE(review): not referenced elsewhere in this script
rich.print(datamodule)

# Add additional datasets to evaluation
TRAIN_DATASET_KEY = list(datamodule.train_datasets.keys())[0]
LOGGER.info("Train dataset %s", TRAIN_DATASET_KEY)

# Bare dataset name without its preprocessing suffix, e.g. "CIFAR10Dequantized" -> "CIFAR10".
# An explicit endswith() check is used because str.strip(chars) removes any of the given
# *characters* from both ends rather than a suffix (e.g. "DTDQuantized" would become "T").
MAIN_DATASET_NAME = datamodule.primary_val_name
for _suffix in ("Binarized", "Quantized", "Dequantized"):
    if MAIN_DATASET_NAME.endswith(_suffix):
        MAIN_DATASET_NAME = MAIN_DATASET_NAME[: -len(_suffix)]
LOGGER.info("Main dataset %s", MAIN_DATASET_NAME)

IN_DIST_DATASET = MAIN_DATASET_NAME + " test"
TRAIN_DATASET = MAIN_DATASET_NAME + " train"
LOGGER.info("Main in-distribution dataset %s", IN_DIST_DATASET)
# Extra datasets added on top of the checkpoint's own splits; uncomment entries to include them.
if MAIN_DATASET_NAME in ["FashionMNIST", "MNIST"]:
    extra_val = dict(
        # notMNISTQuantized=dict(split='validation'),
        # Omniglot28x28Quantized=dict(split='validation'),
        # Omniglot28x28InvertedQuantized=dict(split='validation'),
        # SmallNORB28x28Quantized=dict(split='validation'),
        # SmallNORB28x28InvertedQuantized=dict(split='validation'),
        # KMNISTQuantized=dict(split='validation'),  # Effectively quantized
    )
    # extra_test = {TRAIN_DATASET_KEY: dict(split="train", dynamic=False)}
    extra_test = {}
elif MAIN_DATASET_NAME in ["CIFAR10", "SVHN"]:
    extra_val = dict(
        # CIFAR10DequantizedGrey=dict(split='test', preprocess='deterministic'),
        # CIFAR100Quantized=dict(split='test'),  # bad
        # CelebAQuantized=dict(split='test'),
        # STL10Dequantized=dict(split='train'),  # bad
        # Food101Quantized=dict(split='test'),   # food
        # Flowers102Quantized=dict(split='test'),  # flowers
        # Places365Quantized=dict(split='test'),   # scenes
        # LSUNQuantized=dict(split='test'),
        # FakeDataQuantized=dict(split='test'),   # random-noise images
        # LFWPeopleQuantized=dict(split='train'),  # faces
        # SUN397Quantized=dict(split='test'),   # scenery
        # RenderedSST2Quantized=dict(split='train'),  # images rendered from text
        # PCAMQuantized=dict(split='test'),  # medical lymph-node images
        # EuroSATQuantized=dict(split='test'),  # satellite images
        # GTSRBQuantized=dict(split='train'),  # traffic signs
        # DTDQuantized=dict(split='train', dynamic=True),  # textures
    )
    # extra_test = {TRAIN_DATASET_KEY: dict(split="train", dynamic=True)}
    extra_test = {TRAIN_DATASET_KEY: dict(split="test", dynamic=False)}
else:
    raise ValueError(f"Unknown main dataset name {MAIN_DATASET_NAME}")

datamodule.add_datasets(val_datasets=extra_val, test_datasets=extra_test)
datamodule.data_workers = 4
# Both train and test loaders use --batch_size (one figure per batch by default).
datamodule.batch_size = args.batch_size
datamodule.test_batch_size = args.batch_size
LOGGER.info("%s", datamodule)


# Lengths of the val/test datasets; not used elsewhere in this script.
n_test_batches = get_lengths(datamodule.val_datasets) + get_lengths(datamodule.test_datasets)
for name, val_dataset in datamodule.val_datasets.items():
    print(f'dataset:{name}-->len:{len(val_dataset)}')
# Lengths printed by earlier runs (for reference):
# dataset:CIFAR10Dequantized-->len:10000
# dataset:SVHNDequantized-->len:531131
# dataset:CelebAQuantized-->len:19962
# dataset:Food101Quantized-->len:25250
# dataset:Flowers102Quantized-->len:6149
# dataset:Places365Quantized-->len:36500
# dataset:FakeDataQuantized-->len:10000
# dataset:LFWPeopleQuantized-->len:9525
# dataset:RenderedSST2Quantized-->len:6920
# dataset:PCAMQuantized-->len:32768
# dataset:EuroSATQuantized-->len:27000
# dataset:GTSRBQuantized-->len:26640
# N_EQUAL_EXAMPLES_CAP:6000


# (name, loader) pairs to process, val loaders first. This is a list rather than a set so the order is
# deterministic: the figure loop stops after the first pair, and set iteration order over string-keyed
# tuples varies between runs because of Python's per-process hash randomization.
dataloaders = [(k + " test", v) for k, v in datamodule.val_loaders.items()]
# NOTE(review): test loaders are tagged " train", but for CIFAR10/SVHN `extra_test` requests split="test";
# the tag only affects output file names -- confirm the intended label.
dataloaders += [(k + " train", v) for k, v in datamodule.test_loaders.items()]

with torch.no_grad():
    for dataset, dataloader in dataloaders:
        # Drop preprocessing tags so they do not end up in output file names.
        dataset = dataset.replace("Binarized", "").replace("Quantized", "").replace("Dequantized", "")
        print(f"Evaluating {dataset}")

        model.eval()

        iterator = tqdm(enumerate(dataloader), smoothing=0.9, total=len(dataloader), leave=False)
        # Per-example input shape, used to reshape the decoder means below.
        in_shape = datamodule.train_dataset.datasets[0].size[0]
        iter_n = 0  # only used by the commented-out reconstruction code after this loop
        for idx, (x, _) in iterator:

            # Save the first input image of the batch as a standalone PNG.
            fig_x_save_path = os.path.join(args.save_dir, f'{dataset}_{idx}_x.png')
            x_img = x[0].permute(1, 2, 0).cpu().numpy()  # [H, W, C]
            if x_img.shape[-1] == 1:
                # Single-channel input: render in grayscale instead of matplotlib's default colormap.
                plt.imshow(x_img[..., 0], cmap='gray')
            else:
                plt.imshow(x_img)
            plt.axis("off")
            plt.savefig(fig_x_save_path, bbox_inches='tight', pad_inches=0.0)
            plt.show()
            plt.close()

            x = x.to(device)
            # `n` and `p_x_means` are only used by the commented-out reconstruction code after this loop.
            n = min(x.size(0), show_n_figs)
            p_x_means = [[] for _ in range(model.n_latents)]

            # Generation: force the last latent layer to samples from the model prior (the None entries
            # are presumably sampled by the model itself) and plot the decoder means as a gallery.
            p_z_samples = model.prior.sample(torch.Size([args.n_eval_samples])).to(device)
            sample_latents = [None] * (model.n_latents - 1) + [p_z_samples]
            likelihood_data, stage_datas, skip_likelihood = model.sample_from_prior(
                n_prior_samples=args.n_eval_samples, forced_latent=sample_latents
            )
            p_x_mean = likelihood_data.mean.view(args.n_eval_samples, *in_shape)
            comparison = p_x_mean.permute(0, 2, 3, 1).cpu().numpy()  # [B, H, W, C]
            # Grayscale for single-channel data, as in the commented-out reconstruction code.
            gallery_kwargs = {"cmap": "gray"} if comparison.shape[-1] == 1 else {}
            # max(1, ...) avoids ncols == 0 when fewer than 10 samples are requested.
            fig, ax = plot_gallery(comparison, ncols=max(1, args.n_eval_samples // 10), **gallery_kwargs)
            plt.xticks([])
            plt.yticks([])
            plt.tight_layout()
            print(f'save_dir: {args.save_dir}')
            fig.savefig(os.path.join(args.save_dir, "generated_samples.pdf"), bbox_inches='tight', pad_inches=0.1)
            plt.show()
            plt.close()

            # Only the first batch of the first dataset is used: this run produces a single set of figures.
            break
        break



            #
            # for skip_latens in range(model.n_latents):
            #     decode_from_p = get_decode_from_p(model.n_latents, k=skip_latens)
            #     likelihood_data, stage_datas, skip_likelihood = model(x, n_posterior_samples=1, decode_from_p=decode_from_p, use_mode=decode_from_p)
            #     p_x_means[skip_latens] = likelihood_data.mean[: args.batch_size].view(args.batch_size, *in_shape)  # Reshape zeroth "sample"
            #
            #     save_each = p_x_means[skip_latens][0].permute(1, 2, 0).cpu().numpy()
            #     fig_recon_save_path = os.path.join(args.save_dir, f'{dataset}_{idx}_recon_z{skip_latens}.png')
            #     plt.imshow(save_each, cmap='gray')
            #     plt.axis("off")
            #     plt.savefig(fig_recon_save_path, bbox_inches='tight', pad_inches=0.0)
            #     plt.close()
            # # if dataset == 'KMNIST test':
            # #     chosen = [2,  3,  9,  16,  20,  21,  23,  28]
            # # if dataset == 'notMNIST test':
            # #     chosen = [2,  6,  9,  18,  22,  23, 29, 32]
            # if dataset == 'FashionMNIST test':
            #     chosen = [0, 1, 2, 3, 4, 5,  8,  9,  19,  30]
            # if dataset == 'MNIST test':
            #     chosen = [11,23,30,33,43,44,80,82,85,94]
            #
            # # if dataset == 'KMNIST test' or dataset == 'notMNIST test':
            # if dataset == 'FashionMNIST test' or dataset == 'MNIST test':
            #     com = x[chosen[0]]
            #     pxms = [[]for i in range(model.n_latents)]
            #     for skip_latens in range(model.n_latents):
            #         pxms[skip_latens] = p_x_means[skip_latens][chosen[0]]
            #
            #     for chosen_n in chosen[1:]:
            #         com = torch.cat([com, x[chosen_n]])
            #         for skip_latens in range(model.n_latents):
            #             pxms[skip_latens] = torch.cat([pxms[skip_latens], p_x_means[skip_latens][chosen_n]])
            #
            #     com = com.unsqueeze(1)
            #     for skip_latens in range(model.n_latents):
            #         com = torch.cat([com, pxms[skip_latens].unsqueeze(1)])
            #
            #     com = com.permute(0, 2, 3, 1)  # [B, H, W, C]
            #     fig, ax = plot_gallery(com.cpu().numpy(), ncols=10, cmap='gray')
            #
            #     plt.xticks([])
            #     plt.yticks([])
            #     plt.tight_layout()
            #     fig.savefig(os.path.join(args.save_dir, f"1111chosen_{dataset}"), bbox_inches='tight', pad_inches=0.0)
            #     plt.show()
            #     plt.close()
            #
            #
            # comparison = x[:n]
            # for skip_latens in range(model.n_latents):
            #     comparison = torch.cat([comparison , p_x_means[skip_latens][:n]])
            # comparison = comparison.permute(0, 2, 3, 1)  # [B, H, W, C]
            # if comparison.shape[-1] == 1:
            #     fig, ax = plot_gallery(comparison.cpu().numpy(), ncols=n, cmap='gray')
            # else:
            #     fig, ax = plot_gallery(comparison.cpu().numpy(), ncols=n)
            #
            #
            # plt.xticks([])
            # plt.yticks([])
            # plt.tight_layout()
            # fig_save_path = os.path.join(args.save_dir, f"{dataset}_{iter_n}_recon__n_latents_{args.n_latents_skip}")
            # fig.savefig(fig_save_path)
            #
            # # img = cv2.imread(fig_save_path, cv2.IMREAD_COLOR)
            # # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            # # plt.imshow(gray, cmap="gray")
            # # plt.axis('off')
            # # fig.savefig(fig_save_path)
            # plt.show()
            #
            # # # save single figs
            # # for single_fig in comparison:
            # #     plt.imshow(single_fig)
            # #     single_path = os.path.join(args.save_dir, f'{idx}_recon_single')
            # #     plt.savefig()
            #
            #
            # # plt.close()
            # iter_n += 1
            # if iter_n > iter_n_show:
            #     break
            #
            #
