import argparse
import json
import logging
import os
from time import time

import importlib
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from pathlib import Path

import wandb
from pythae.data.preprocessors import DataProcessor
from pythae.models import AutoModel
from pythae.pipelines import GenerationPipeline
from pythae.trainers import BaseTrainerConfig
from TTUR.fid import compute_fid_wo_paths


logger = logging.getLogger(__name__)
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)

# Directory of this script; test datasets are looked up under PATH/data/<dataset>.
PATH = os.path.dirname(os.path.abspath(__file__))

# Command-line interface. Parsed at import time; the resulting namespace is
# handed to ``main`` from the ``__main__`` guard.
ap = argparse.ArgumentParser()

ap.add_argument(
    "--models_path",
    help=(
        "Directory containing a training-run folder whose 'final_model' "
        "subfolder holds the trained model to interpolate with"
    ),
    required=True,
)
ap.add_argument(
    "--use_wandb",
    help="whether to log the metrics in wandb",
    action="store_true",
)
ap.add_argument(
    "--wandb_project",
    help="wandb project name",
    default="latent_dim_sensi_interpolations",
)
ap.add_argument(
    "--wandb_entity",
    help="wandb entity name",
    default="benchmark_team",
)

args = ap.parse_args()

# Run on GPU when torch can see one, otherwise on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# (start, end) test-set indices to interpolate between; one pair per row of
# the logged table. MNIST and CIFAR-10 share these pairs.
_DEFAULT_INTERPOLATION_PAIRS = (
    (0, 1),
    (88, 2),
    (15, 32),
    (555, 432),
    (12, 100),
    (132, 1156),
    (818, 2527),
    (201, 3072),
    (656, 555),
    (1221, 1320),
)

# CelebA uses the same pairs except for rows 3 and 4.
_CELEBA_INTERPOLATION_PAIRS = (
    _DEFAULT_INTERPOLATION_PAIRS[:3]
    + ((76, 9865), (1019, 1106))
    + _DEFAULT_INTERPOLATION_PAIRS[5:]
)

# Model input shape -> (dataset folder name under PATH/data, index pairs).
_DATASETS_BY_INPUT_DIM = {
    (1, 28, 28): ("mnist", _DEFAULT_INTERPOLATION_PAIRS),
    (3, 32, 32): ("cifar10", _DEFAULT_INTERPOLATION_PAIRS),
    (3, 64, 64): ("celeba", _CELEBA_INTERPOLATION_PAIRS),
}


def _resolve_dataset(input_dim):
    """Return ``(dataset_name, interpolation_pairs)`` for a model input shape.

    Raises:
        ValueError: if ``input_dim`` matches none of the supported datasets
            (previously this surfaced as an ``UnboundLocalError``).
    """
    key = tuple(input_dim) if input_dim is not None else None
    try:
        return _DATASETS_BY_INPUT_DIM[key]
    except KeyError:
        raise ValueError(
            f"Unsupported model input_dim {input_dim}; expected one of "
            f"{sorted(_DATASETS_BY_INPUT_DIM)}."
        ) from None


def _load_test_data(dataset):
    """Load ``PATH/data/<dataset>/test_data.npz`` (key ``'data'``) divided by 255.

    Raises:
        FileNotFoundError: wrapping any error raised while reading the file.
    """
    data_path = os.path.join(PATH, "data", dataset, "test_data.npz")
    try:
        # NpzFile is a context manager; ``with`` closes the underlying file.
        with np.load(data_path) as archive:
            return archive["data"] / 255.0
    except Exception as e:
        raise FileNotFoundError(
            f"Unable to load the data from 'data/{dataset}' folder. Please check that a "
            "'test_data.npz' file is present in the folder.\n Data must be "
            "under the key 'data', in the range [0-255] and shaped with channel in first "
            "position\n"
            f"Exception raised: {type(e)} with message: " + str(e)
        ) from e


def _interpolate(model, data, pairs, n_steps=10):
    """Decode ``n_steps`` evenly spaced points between the encoder embeddings
    of ``data[start]`` and ``data[end]`` for every ``(start, end)`` in ``pairs``.

    Returns:
        All reconstructions concatenated along dim 0, ordered pair by pair and
        then step by step (``len(pairs) * n_steps`` items).
    """
    reconstructions = []
    steps = torch.linspace(0, 1, n_steps)
    # Encoding AND decoding run without autograd; previously the decoder calls
    # were outside no_grad and retained a graph for every reconstruction.
    with torch.no_grad():
        for start_idx, end_idx in pairs:
            start = model.encoder(data[start_idx].unsqueeze(0)).embedding
            end = model.encoder(data[end_idx].unsqueeze(0)).embedding
            for t in steps:
                z_t = start * (1 - t) + t * end
                reconstructions.append(model.decoder(z_t).reconstruction)
    return torch.cat(reconstructions)


def _log_to_wandb(args, model_path, model, recon, n_cols=10):
    """Log ``recon`` to wandb as a table with ``n_cols`` images per row.

    Only complete rows are logged. Raises ``ModuleNotFoundError`` when wandb
    is not installed.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util

    if importlib.util.find_spec("wandb") is None:
        raise ModuleNotFoundError(
            "`wandb` package must be installed. Run `pip install wandb`"
        )

    wandb.init(project=args.wandb_project, entity=args.wandb_entity)
    wandb.config.update(
        {
            "model_path": model_path,
            "model_config": model.model_config.to_dict(),
        }
    )

    images = [wandb.Image(img) for img in recon.cpu()]
    n_rows = len(images) // n_cols
    rows = [images[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]

    sampling_table = wandb.Table(
        data=rows, columns=[str(i) for i in range(n_cols)]
    )
    wandb.log({"test_interpolations": sampling_table})
    wandb.finish()


def main(args):
    """Interpolate between pairs of test images in a trained model's latent space.

    Loads ``<models_path>/<first run folder>/final_model``, encodes fixed pairs
    of test images and decodes points on the straight line between their
    embeddings. When ``args.use_wandb`` is set, it logs the reconstructions as
    a wandb table with one row per pair.

    Args:
        args: namespace with ``models_path``, ``use_wandb``, ``wandb_project``
            and ``wandb_entity``.

    Raises:
        FileNotFoundError: if ``models_path`` contains no folder, or the test
            data cannot be loaded.
        ValueError: if the model's ``input_dim`` matches no supported dataset.
        ModuleNotFoundError: if wandb logging is requested but wandb is missing.
    """
    # Only consider directories, in sorted order, so stray files (e.g.
    # ``.DS_Store``) are skipped and the choice is deterministic.
    run_folders = sorted(
        name
        for name in os.listdir(args.models_path)
        if os.path.isdir(os.path.join(args.models_path, name))
    )
    if not run_folders:
        raise FileNotFoundError(
            f"No trained model folder found in '{args.models_path}'."
        )
    model_signature = run_folders[0]
    model_path = os.path.join(args.models_path, model_signature, "final_model")

    # reload the model; eval() disables dropout / batch-norm updates
    trained_model = AutoModel.load_from_folder(model_path).to(device)
    trained_model.eval()
    logger.info(f"Successfully reloaded {trained_model.model_name.upper()} model !\n")

    dataset, inter_pairs = _resolve_dataset(trained_model.model_config.input_dim)
    test_data = _load_test_data(dataset)

    logger.info("Successfully loaded data !\n")
    logger.info("------------------------------------------------------------")
    logger.info("Dataset \t \t Shape \t \t \t Range")
    logger.info(
        f"{dataset.upper()} test data: \t {test_data.shape} \t [{test_data.min()}-{test_data.max()}]"
    )
    logger.info("------------------------------------------------------------\n")

    dataset_type = (
        "DoubleBatchDataset"
        if trained_model.model_name == "FactorVAE"
        else "BaseDataset"
    )

    data_processor = DataProcessor()
    test_data = data_processor.process_data(test_data).to(device)
    test_dataset = data_processor.to_dataset(test_data, dataset_type=dataset_type)

    test_recon = _interpolate(trained_model, test_dataset.data, inter_pairs)

    # Reconstructions are expected channels-first (N, C, H, W), matching the
    # "channel in first position" data layout, so grayscale means C == 1.
    # The old check looked at shape[-1] (the width) and never fired.
    if test_recon.dim() == 4 and test_recon.shape[1] == 1:
        test_recon = test_recon.repeat(1, 3, 1, 1)

    if args.use_wandb:
        _log_to_wandb(args, model_path, trained_model, test_recon)

if __name__ == "__main__":
    # Script entry point: run with the CLI namespace parsed at import time.
    main(args)
