import argparse
import json
import logging
import os
from time import time

import importlib
import numpy as np
import torch
import matplotlib.pyplot as plt

from pathlib import Path

import wandb
from pythae.data.preprocessors import DataProcessor
from pythae.models import AutoModel
from pythae.pipelines import GenerationPipeline
from pythae.trainers import BaseTrainerConfig
from TTUR.fid import compute_fid
from improved_gan.inception_score.model import get_inception_score
import shutil


# Module logger: records at INFO level and above are echoed to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)

# Absolute directory holding this script; the `data/<dataset>` folders are
# looked up relative to it.
PATH = os.path.dirname(os.path.abspath(__file__))

# Command-line interface of the generation / evaluation script.
ap = argparse.ArgumentParser(
    description="Generate samples from a trained pythae model and report IS / FID."
)

ap.add_argument(
    "--models_path",
    help="The path to the folder of a trained model to generate from",
    required=True,
)
ap.add_argument(
    "--sampler_name",
    help="the name of the sampler to use",
    choices=[
        "normal_sampler",
        "gmm_sampler",
        "vamp_sampler",
        "rhvae_sampler",
        "two_stage_sampler",
        "hypersphere_sampler",
        "maf_sampler",
        "iaf_sampler",
        "pixelcnn_sampler",
        "man_sampler",
    ],
    default="normal_sampler",
)
ap.add_argument(
    "--sampler_config",
    # The file is loaded as a *sampler* config (`<SamplerConfig>.from_json_file`),
    # not a model config.
    help="path to the sampler config file (expected json file)",
    default=None,
)
ap.add_argument(
    "--num_samples",
    help="number of samples to generate",
    type=int,
    default=10,
)
ap.add_argument(
    "--use_wandb",
    help="whether to log the metrics in wandb",
    action="store_true",
)
ap.add_argument(
    "--wandb_project",
    help="wandb project name",
    default="generation_reb",
)
ap.add_argument(
    "--wandb_entity",
    help="wandb entity name",
    default="benchmark_team",
)

# Parsed at import time (module level), so even importing this file requires a
# command line that satisfies the parser (e.g. the required `--models_path`).
args = ap.parse_args()


# Benchmark dataset associated with each model input shape (C, H, W).
_DATASET_BY_INPUT_DIM = {
    (1, 28, 28): "mnist",
    (3, 32, 32): "cifar10",
    (3, 64, 64): "celeba",
}

# Samplers that are built without access to the training / eval data.
_SAMPLERS_WITHOUT_DATA = frozenset(
    {"normal_sampler", "rhvae_sampler", "hypersphere_sampler", "vamp_sampler"}
)

# CLI sampler name -> (config class name in `pythae.samplers`, default kwargs
# used when no JSON config file is given).
_SAMPLER_CONFIGS = {
    "normal_sampler": ("NormalSamplerConfig", {}),
    "gmm_sampler": ("GaussianMixtureSamplerConfig", {}),
    "vamp_sampler": ("VAMPSamplerConfig", {}),
    "rhvae_sampler": ("RHVAESamplerConfig", {}),
    "two_stage_sampler": ("TwoStageVAESamplerConfig", {}),
    "hypersphere_sampler": ("HypersphereUniformSamplerConfig", {}),
    "maf_sampler": ("MAFSamplerConfig", {}),
    "iaf_sampler": ("IAFSamplerConfig", {}),
    "pixelcnn_sampler": ("PixelCNNSamplerConfig", {}),
    "man_sampler": (
        "UnifManVAESamplerConfig",
        {"n_medoids": 1000, "n_lf": 5, "eps_lf": 0.005, "lbd": 1},
    ),
}


def _infer_dataset(input_dim):
    """Return the dataset name matching a model's ``input_dim``.

    Raises:
        ValueError: if ``input_dim`` matches none of the known datasets.
    """
    # Normalise to a tuple: a config reloaded from JSON may hold a list.
    key = tuple(input_dim) if input_dim is not None else None
    try:
        return _DATASET_BY_INPUT_DIM[key]
    except KeyError:
        raise ValueError(
            f"Cannot infer the dataset from model input_dim={input_dim}; "
            f"expected one of {list(_DATASET_BY_INPUT_DIM)}."
        ) from None


def _load_npz_array(dataset, filename):
    """Return the array stored under key 'data' in ``PATH/data/<dataset>/<filename>``."""
    # The context manager closes the underlying NpzFile handle once read.
    with np.load(os.path.join(PATH, "data", dataset, filename)) as npz:
        return npz["data"]


def _load_sampler_data(dataset):
    """Load the train/eval splits of ``dataset`` divided by 255, for samplers that need data.

    Raises:
        FileNotFoundError: if either split cannot be loaded.
    """
    logger.info(f"\nLoading {dataset} data...\n")
    try:
        train_data = _load_npz_array(dataset, "train_data.npz") / 255.0
        eval_data = _load_npz_array(dataset, "eval_data.npz") / 255.0
    except Exception as e:
        raise FileNotFoundError(
            f"Unable to load the data from 'data/{dataset}' folder. Please check that both a "
            "'train_data.npz' and 'eval_data.npz' are present in the folder.\n Data must be "
            "under the key 'data', in the range [0-255] and shaped with channel in first "
            "position\n"
            f"Exception raised: {type(e)} with message: " + str(e)
        ) from e

    logger.info("Successfully loaded data !\n")
    logger.info("------------------------------------------------------------")
    logger.info("Dataset \t \t Shape \t \t \t Range")
    logger.info(
        f"{dataset.upper()} train data: \t {train_data.shape} \t [{train_data.min()}-{train_data.max()}] "
    )
    logger.info(
        f"{dataset.upper()} eval data: \t {eval_data.shape} \t [{eval_data.min()}-{eval_data.max()}]"
    )
    logger.info("------------------------------------------------------------\n")
    return train_data, eval_data


def _build_sampler_config(sampler_name, config_path=None):
    """Return the pythae sampler config for ``sampler_name``.

    The config is read from the JSON file ``config_path`` when given, otherwise
    it is built from the defaults registered in ``_SAMPLER_CONFIGS``.

    Raises:
        ValueError: if ``sampler_name`` is not a known sampler.
    """
    try:
        class_name, default_kwargs = _SAMPLER_CONFIGS[sampler_name]
    except KeyError:
        raise ValueError(
            f"Unknown sampler '{sampler_name}'. Choose one of {sorted(_SAMPLER_CONFIGS)}."
        ) from None

    # Imported lazily; only the selected config class is looked up.
    from pythae import samplers

    config_cls = getattr(samplers, class_name)
    if config_path is not None:
        return config_cls.from_json_file(config_path)
    return config_cls(**default_kwargs)


def _to_channel_last_rgb(samples):
    """Convert a channel-first torch batch to a float32 channel-last RGB array scaled by 255.

    Single-channel images are repeated to 3 channels so the Inception-based
    metrics receive RGB input. Assumes ``samples`` is (N, C, H, W) with values
    in [0, 1] — TODO confirm against the pipeline output.
    """
    images = 255.0 * torch.movedim(samples, 1, 3).cpu().detach().numpy()
    if images.shape[-1] == 1:
        images = np.repeat(images, repeats=3, axis=-1)
    # Kept as float (not quantised to uint8): metrics are computed "without save".
    return images.astype(np.float32)


def _load_reference_data(dataset, filename):
    """Load a stored split as a channel-last float64 array (values left as stored)."""
    # Stored arrays are channel-first; compute_fid receives them channel-last,
    # matching the generated samples.
    return np.moveaxis(_load_npz_array(dataset, filename).astype(np.float64), 1, 3)


def _log_to_wandb(args, sampler_config_dict, model_config_dict, images, is_score, eval_fid, test_fid):
    """Log configs, a random table of up to 100 generated images and the metrics to wandb.

    Raises:
        ModuleNotFoundError: if ``wandb`` is not installed.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    # NOTE(review): `wandb` is also imported at module level, so this check can
    # only fail if that import is made optional.
    if importlib.util.find_spec("wandb") is None:
        raise ModuleNotFoundError(
            "`wandb` package must be installed. Run `pip install wandb`"
        )

    wandb.init(project=args.wandb_project, entity=args.wandb_entity)
    wandb.config.update(
        {
            "sampler_config": sampler_config_dict,
            "model_config": model_config_dict,
        }
    )

    num_images = images.shape[0]
    n_cols = min(10, num_images)
    n_to_log = min(100, num_images)

    # Random subset, converted to plain ints for numpy indexing.
    picked = torch.randperm(num_images)[:n_to_log].tolist()
    wandb_images = [wandb.Image(images[i]) for i in picked]

    # Only complete rows go into the table (it needs exactly n_cols per row);
    # with fewer than 10 images a single row of all of them is logged.
    n_rows = n_to_log // n_cols if n_cols else 0
    rows = [wandb_images[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]
    sampling_table = wandb.Table(data=rows, columns=[str(c) for c in range(n_cols)])

    wandb.log(
        {
            "sampling_data": sampling_table,
            "test/is_mean": is_score[0],
            "test/is_std": is_score[1],
            "eval/fid": eval_fid,
            "test/fid": test_fid,
        }
    )
    wandb.finish()


def main(args):
    """Generate samples from a trained pythae model and report IS / FID.

    Reloads the model from ``args.models_path`` and infers the benchmark
    dataset from its input dimension. It then builds the requested sampler
    config, loading the dataset's train/eval splits when the sampler needs
    them, and generates ``args.num_samples`` images. Finally it computes the
    Inception Score and the FID against the dataset's eval and test splits,
    and optionally logs everything to wandb.

    Args:
        args: parsed CLI namespace with ``models_path``, ``sampler_name``,
            ``sampler_config``, ``num_samples``, ``use_wandb``,
            ``wandb_project`` and ``wandb_entity``.

    Raises:
        ValueError: if the model's input dimension matches no known dataset.
        FileNotFoundError: if the data needed by the sampler cannot be loaded.
        ModuleNotFoundError: if ``--use_wandb`` is set but wandb is missing.
    """
    # NOTE: pins the run to the first GPU; this only takes effect if CUDA has
    # not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    model_path = args.models_path
    trained_model = AutoModel.load_from_folder(model_path).eval()
    logger.info(f"Successfully reloaded {trained_model.model_name.upper()} model !\n")

    # Fail early (before generation) if the dataset cannot be inferred.
    dataset = _infer_dataset(trained_model.model_config.input_dim)

    train_data = eval_data = None
    if args.sampler_name not in _SAMPLERS_WITHOUT_DATA:
        train_data, eval_data = _load_sampler_data(dataset)

    sampler_config = _build_sampler_config(args.sampler_name, args.sampler_config)

    extended_sampler_config_dict = sampler_config.to_dict()
    extended_sampler_config_dict["model_path"] = model_path
    extended_sampler_config_dict["num_samples"] = args.num_samples

    logger.info(f"Sampler config: {sampler_config}\n")

    # The trainer output dir is keyed by the SLURM array task id (presumably to
    # keep concurrent array jobs apart); falls back to 0 when run outside SLURM.
    task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))

    pipe = GenerationPipeline(model=trained_model, sampler_config=sampler_config)
    generated_samples = pipe(
        num_samples=args.num_samples,
        output_dir=None,
        return_gen=True,
        train_data=train_data,
        eval_data=eval_data,
        training_config=BaseTrainerConfig(
            num_epochs=200,
            output_dir=os.path.join("dummy_output_dir", str(task_id)),
        ),
    )
    logger.info(
        f"Generated samples range: [{generated_samples.min().item()}, "
        f"{generated_samples.max().item()}]"
    )

    images = _to_channel_last_rgb(generated_samples)

    is_open_ai = get_inception_score(list(images))

    # Reference splits are loaded one at a time to keep peak memory lower.
    eval_fid = compute_fid(
        gen_data=images,
        ref_data=_load_reference_data(dataset, "eval_data.npz"),
        inception_path=".",
    )
    test_fid = compute_fid(
        gen_data=images,
        ref_data=_load_reference_data(dataset, "test_data.npz"),
        inception_path=".",
    )

    print("----------without save----------")
    print(f"fid vs eval : {eval_fid}")
    print("----------without save----------")
    print(f"fid vs test : {test_fid}")
    print(f"IS openai: {is_open_ai}")
    print("----------without save----------")

    if args.use_wandb:
        _log_to_wandb(
            args,
            extended_sampler_config_dict,
            trained_model.model_config.to_dict(),
            images,
            is_open_ai,
            eval_fid,
            test_fid,
        )

if __name__ == "__main__":
    # Script entry point: run generation and evaluation with the parsed CLI args.
    main(args=args)
