# This script is to extract the results from a set of stored checkpoints.
import os
import joblib
import math
import argparse
import pandas as pd
import json
import logging
import multiprocessing as mp
import torch
import torch.nn as nn
import queue
from datetime import datetime
import traceback


from benchmarking.utils import Experiment
from models.data_utils.util import get_device
from models.data_utils import get_dataloaders_for_tabular
from models.architectures.PCA.training import dataloader_to_numpy
from models import DeepCAE, StackedCAE, JointVAE, ConvAE, StandardAE, TransformerAE, PCA


# Model classes that are evaluated when the caller does not select a subset.
# This stays a list: the command-line entry point compares the parsed
# ``--models`` value against it.
ALL_MODELS = [
    JointVAE,
    PCA,
    TransformerAE,
    DeepCAE,
    StackedCAE,
    ConvAE,
    StandardAE,
]
# Lookup table from a class name (as given on the command line) to the class.
NAME_TO_MODEL_CLASS = {cls.__name__: cls for cls in ALL_MODELS}


def get_results(
    data_path: str = "artifacts/data",
    models: list = ALL_MODELS,
    max_epochs: int = 800,
    num_trainings: int = 3,
    dim_reduction: float = 0.505,
) -> pd.DataFrame:
    """
    Retrieve all results from model checkpoints on all datasets.

    Every (dataset, model, checkpoint directory) combination becomes one job
    that is evaluated by a pool of ``worker`` processes; the single-row
    result frames are concatenated into one DataFrame.

    Parameters:
    ---------------
    data_path: str
        Directory whose sub-folders name the available datasets. The folders
        "MNIST" and "Thermography" are skipped.
    models: list
        Model classes whose checkpoints are evaluated.
    max_epochs: int
        Forwarded to ``get_configs``, which uses it to override the "epochs"
        entry of every loaded config.
    num_trainings: int
        Maximum number of checkpoint directories evaluated per model and
        dataset (the first ones in sorted order).
    dim_reduction: float
        Factor applied to the input dimension to obtain the latent dimension.

    Returns:
    ---------------
    pd.DataFrame
        One row per successfully evaluated checkpoint; empty if there was
        nothing to evaluate or every job failed.
    """

    # Load names of all available datasets.
    dataset_names = get_folders(data_path)
    logging.info(f"Dataset names: {dataset_names}")
    # MNIST would need special treatment using its dedicated dataloader;
    # Thermography is excluded from result retrieval as well.
    excluded_datasets = {"MNIST", "Thermography"}
    dataset_names = [name for name in dataset_names if name not in excluded_datasets]

    model_names = [mdl.__name__ for mdl in models]
    logging.info(f"Model names: {model_names}")
    model_configs = get_configs(model_names, dataset_names, max_epochs)
    # Lazy %-formatting: passing the dict as an extra argument without a
    # placeholder triggers a logging error as soon as INFO is enabled.
    logging.info("Model configs: %s", model_configs)

    # Filter the dataset names for datasets that actually have configs.
    # Sorting keeps the job order reproducible between runs.
    dataset_names = sorted(set(dataset_names).intersection(model_configs))
    logging.info(f"Updated dataset names: {dataset_names}")

    # Create queues for multiprocessing
    job_queue = mp.Queue()
    result_queue = mp.Queue()

    # Populate queue
    num_tasks = 0
    for dataset_name in dataset_names:
        for model_class in models:
            checkpoint_base_path = (
                f"artifacts/benchmarking/{model_class.__name__}/{dataset_name}"
            )
            if not os.path.isdir(checkpoint_base_path):
                # A missing combination must not abort the whole retrieval.
                logging.warning(
                    f"No checkpoints found at {checkpoint_base_path} - skipping."
                )
                continue
            # os.listdir returns entries in arbitrary order, so sort before
            # taking the first num_trainings checkpoints.
            directories = sorted(os.listdir(checkpoint_base_path))[:num_trainings]
            for directory in directories:
                job_queue.put((model_class, dataset_name, directory))
                num_tasks += 1

    if num_tasks == 0:
        logging.warning("No checkpoints to evaluate - returning an empty result.")
        return pd.DataFrame()

    # Start worker processes
    processes = []
    num_devices = torch.cuda.device_count()
    if not num_devices:
        # Then run on CPU. os.cpu_count() may return None.
        num_devices = os.cpu_count() or 1
        logging.info(
            f"Detected NO GPU devices - will distribute training jobs across {num_devices} discovered CPU kernels!"
        )
    else:
        logging.info(
            f"Detected {num_devices} GPU devices - will distribute training jobs!"
        )

    # There is no point in starting more workers than there are jobs.
    for device_id in range(min(num_devices, num_tasks)):
        p = mp.Process(
            target=worker,
            args=(
                device_id,
                job_queue,
                result_queue,
                model_configs,
                dim_reduction,
            ),
        )
        p.start()
        processes.append(p)

    # Collect results from result_queue. A job may never report back (e.g.
    # the worker failed or exited early), so poll with a timeout and stop once
    # no worker is alive any more instead of blocking forever.
    results_list = []
    reported_tasks = 0
    while reported_tasks < num_tasks:
        try:
            result = result_queue.get(timeout=1.0)
        except queue.Empty:
            if any(p.is_alive() for p in processes):
                continue
            # All workers are gone; fetch anything that arrived in between.
            try:
                result = result_queue.get(timeout=1.0)
            except queue.Empty:
                logging.warning(
                    f"Only {reported_tasks} of {num_tasks} jobs reported a result."
                )
                break
        reported_tasks += 1
        # None marks a job that failed inside a worker.
        if result is not None:
            results_list.append(result)

    for p in processes:
        p.join()

    if not results_list:
        logging.warning("No job produced a result - returning an empty result.")
        return pd.DataFrame()
    return pd.concat(results_list, ignore_index=True)


# Worker loop that evaluates queued checkpoints on a given device (GPU or CPU)
def worker(
    device_id: int,
    job_queue: mp.Queue,
    result_queue: mp.Queue,
    model_configs: dict,
    dim_reduction: float,
    poll_timeout: float = 1.0,
) -> None:
    """
    Evaluate queued checkpoints until the job queue stays empty.

    Parameters:
    ---------------
    device_id: int
        Index of the device this worker passes on to the evaluation.
    job_queue: mp.Queue
        Queue of ``(model_class, dataset_name, directory)`` jobs.
    result_queue: mp.Queue
        Queue that receives exactly one item per processed job: the result
        of the evaluation, or ``None`` if the job raised an exception.
    model_configs: dict
        Configs keyed by dataset name.
    dim_reduction: float
        Factor applied to the input dimension to obtain the latent dimension.
    poll_timeout: float
        Seconds to wait for a new job before the worker concludes that the
        queue is exhausted and returns.
    """
    while True:
        try:
            # Wait briefly instead of using get_nowait(): items put on a
            # multiprocessing queue become visible with a small delay, so an
            # immediate check can report "empty" while jobs are in flight.
            model_class, dataset_name, directory = job_queue.get(
                timeout=poll_timeout
            )
        except queue.Empty:
            break
        logging.info(
            f"Device {device_id} is now getting results from {model_class.__name__} on {dataset_name} for checkpoint in {directory}..."
        )
        try:
            result = get_result_from_checkpoint(
                model_class,
                dataset_name,
                directory,
                model_configs[dataset_name],
                dim_reduction,
                device_id,
            )
        except Exception:
            logging.exception(
                f"Failed to get results from {model_class.__name__} on {dataset_name} for checkpoint in {directory}"
            )
            # Report the failure so the consumer can account for every job
            # instead of waiting for a result that never arrives.
            result = None
        result_queue.put(result)


def _find_checkpoint_file(
    checkpoint_dir: str, suffix: str, dataset_name: str, model_class
) -> str:
    """Return the path of the first file (sorted order) in checkpoint_dir ending with suffix."""
    file_name = next(
        (file for file in sorted(os.listdir(checkpoint_dir)) if file.endswith(suffix)),
        None,
    )
    if file_name is None:
        # Raise instead of assert: asserts are stripped when running with -O.
        raise FileNotFoundError(
            f"No model Checkpoint found for dataset {dataset_name} and model {model_class}"
        )
    return os.path.join(checkpoint_dir, file_name)


def get_result_from_checkpoint(
    model_class,
    dataset_name: str,
    directory: str,
    config: dict,
    dim_reduction: float,
    device_id: int = None,
):
    """
    Load one stored checkpoint and measure its reconstruction loss on the test data.

    Parameters:
    ---------------
    model_class
        Model class to evaluate. PCA checkpoints are loaded from a ".pkl"
        file, all other models from a ".pth" state dict.
    dataset_name: str
        Name of the dataset; the data is read from
        "artifacts/data/<dataset_name>/processed.csv".
    directory: str
        Name of the checkpoint directory below
        "artifacts/benchmarking/<model>/<dataset_name>/".
    config: dict
        Model config. It may be nested under a "models" key and/or the model
        class name. The passed dict is not modified.
    dim_reduction: float
        Factor applied to the input dimension to obtain the latent dimension.
    device_id: int
        Device index handed to ``get_device``.

    Returns:
    ---------------
    pd.DataFrame
        Single-row frame built from the flattened experiment.

    Raises:
    ---------------
    FileNotFoundError
        If the checkpoint directory contains no file with the expected suffix.
    ValueError
        If the resolved config has no "epochs" entry.
    """
    device = get_device(device_id)

    # Get data first, before model size definition
    _, test_loader = get_dataloaders_for_tabular(
        batch_size=64,
        path_to_data=f"artifacts/data/{dataset_name}/processed.csv",
        device=device,
    )

    input_dim = next(iter(test_loader)).size()[1]
    model_class_name = model_class.__name__
    hidden_dim = round(input_dim * dim_reduction)
    checkpoint_base_path = (
        f"artifacts/benchmarking/{model_class_name}/{dataset_name}/{directory}"
    )
    logging.info(checkpoint_base_path)

    if model_class == PCA:
        # Special treatment for PCA
        experiment = Experiment(
            {"hidden_dim": hidden_dim},
            model_class_name,
            f"Infer from date: {datetime.now()}",
            "",
            dataset_name,
        )
        pkl_file_path = _find_checkpoint_file(
            checkpoint_base_path, ".pkl", dataset_name, model_class
        )

        # NOTE: joblib.load unpickles arbitrary objects - only load
        # checkpoints from a trusted location.
        model = joblib.load(pkl_file_path)
        data = dataloader_to_numpy(test_loader)

        # Perform transformation
        # NOTE(review): fit_transform re-fits the loaded model on the test
        # data, so the stored fit is not what gets evaluated - confirm whether
        # transform() was intended here.
        transformed_reduced = model.fit_transform(data)

        # Reconstruct data.
        recon = model.inverse_transform(transformed_reduced)

        # Measure reconstruction quality.
        test_loss = nn.MSELoss()(torch.Tensor(recon), torch.Tensor(data)).item()
        logging.info(
            f"Kernel PCA result retrieval completed! The test loss is {test_loss}."
        )

        experiment.train_loss_mse.append(0.0)
        experiment.test_loss_mse.append(test_loss)

        # Now extract the final metrics from the experiment and return them
        return pd.DataFrame(experiment.flatten(), index=[0])

    if "models" in config:
        config = config["models"]
    if model_class_name in config:
        config = config[model_class_name]
    if "epochs" not in config:
        raise ValueError(
            f"Config has not attr epochs (model {model_class_name}, dataset {dataset_name})"
        )

    # Work on a copy so that the adjusted hidden_spec below is recorded in the
    # experiment without leaking into the config object owned by the caller.
    config = dict(config)

    # Init model
    hidden_spec = config.get("hidden_spec", None)
    if hidden_spec:
        hidden_spec = list(hidden_spec)
        hidden_spec[-1] = hidden_dim
        if len(hidden_spec) == 2:
            # For the MultiLayer experiments comparing DeepCAE and StackedCAE.
            hidden_spec[-2] = round((input_dim + hidden_spec[-1]) / 2)
        config["hidden_spec"] = hidden_spec

    model_kwargs = {
        "input_dim": input_dim,
    }

    if hidden_spec:
        model_kwargs["hidden_spec"] = hidden_spec
    else:
        # If there is no hidden spec, it has to be JointVAE and
        # there is a hidden_dim parameter instead.
        model_kwargs["hidden_dim"] = hidden_dim

    if channel_spec := config.get("channel_spec", None):
        model_kwargs["channel_spec"] = channel_spec

    if latent_spec := config.get("latent_spec", None):
        model_kwargs["latent_spec"] = latent_spec
    elif model_class == JointVAE:
        # Per default, we use no discrete variables.
        # We only use the same number of hidden continuous features as hidden_dim
        model_kwargs["latent_spec"] = {"cont": hidden_dim}

    model = model_class(**model_kwargs)
    pth_file_path = _find_checkpoint_file(
        checkpoint_base_path, ".pth", dataset_name, model_class
    )
    model.load_state_dict(torch.load(pth_file_path, map_location=torch.device("cpu")))
    model.to(device)
    model.eval()

    criterion = nn.MSELoss()
    test_loss = 0.0
    # Count the batches that are actually iterated instead of deriving the
    # number from the dataset length, which can disagree with the loader.
    num_test_batches = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)

            if model_class == ConvAE:
                # Add a channel dimension: (batch, 1, input_dim).
                batch = batch.view(len(batch), 1, input_dim)

            if model_class == JointVAE or model_class == TransformerAE:
                batch = batch.unsqueeze(1)

            recon, _ = model(batch)
            if model_class == JointVAE:
                loss = criterion(recon.view(-1, input_dim), batch.view(-1, input_dim))
            else:
                loss = criterion(recon, batch)
            test_loss += loss.item()
            num_test_batches += 1

    experiment = Experiment(
        config,
        model_class_name,
        f"Infer from date: {datetime.now()}",
        model,
        dataset_name,
    )

    experiment.train_loss_mse.append(0.0)
    # Mean of the per-batch MSE values.
    experiment.test_loss_mse.append(test_loss / num_test_batches)

    # Now extract the final metrics from the experiment and return them
    return pd.DataFrame(experiment.flatten(), index=[0])


def get_folders(directory):
    """Return the names of all immediate sub-directories of ``directory``."""
    return [
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    ]


def get_configs(
    model_names: list[str],
    dataset_names: list[str],
    max_epochs: int = None,
    base_path: str = "artifacts/tuning/",
    amount: float = 0.505,
) -> dict:
    """
    Get the configs for a given set of model names and dataset names.

    For every model, "models/architectures/<model>/conf.json" provides the
    "static" arguments and the names of the "tunable" ones. For every
    (dataset, model) pair the tuned values are read from
    "<base_path>/<model>_<dataset>_<amount>_best.json", where keys may carry a
    "config_" prefix. Pairs whose files cannot be read are skipped.

    Parameters:
    ---------------
    model_names: list[str]
        Names of the models to load configs for.
    dataset_names: list[str]
        Names of the datasets to load configs for.
    max_epochs: int
        If truthy, overrides the "epochs" entry of every config.
    base_path: str
        Directory containing the tuning result files.
    amount: float
        Value that is part of the tuning result file names.

    Returns:
    ---------------
    dict
        ``{dataset_name: {model_name: config}}``; every dataset name is
        present, possibly with an empty dict.
    """
    # Local import so this function does not depend on the file's import block.
    import copy

    model_arguments = dict()
    for m_name in model_names:
        model_path = f"models/architectures/{m_name}/"
        try:
            with open(os.path.join(model_path, "conf.json"), "r") as fp:
                model_arguments[m_name] = json.load(fp)
        except (OSError, ValueError):
            # ValueError covers malformed JSON and undecodable file content.
            logging.info(f"Folder does not exist: {model_path}")

    model_configs = dict()
    for d_name in dataset_names:
        model_configs[d_name] = dict()
        for m_name in model_names:
            file_name = f"{m_name}_{d_name}_{amount}_best.json"
            file_path = os.path.join(base_path, file_name)
            try:
                with open(file_path, "r") as config_file:
                    config_dict = json.load(config_file)
            except (OSError, ValueError):
                logging.info(f"File does not exist: {file_path}")
                continue

            arguments = model_arguments.get(m_name)
            if arguments is None:
                logging.info(f"No model arguments available for {m_name} - skipping.")
                continue

            # Strip the prefix only where it really is a prefix.
            clean_config_dict = {
                key.removeprefix("config_"): value
                for key, value in config_dict.items()
            }

            # Deep copy: the static arguments are shared by all datasets, and
            # writing the tuned values into the same object would make every
            # dataset end up with the values of the last one.
            model_config = copy.deepcopy(arguments.get("static", {}))
            for key in arguments.get("tunable", {}):
                if key in clean_config_dict:
                    model_config[key] = clean_config_dict[key]
                else:
                    logging.warning(
                        f"Tunable parameter {key} is missing in {file_path}"
                    )
            if max_epochs:
                model_config["epochs"] = max_epochs
            model_configs[d_name][m_name] = model_config
    return model_configs


if __name__ == "__main__":
    # Command-line entry point: evaluate the stored checkpoints of the
    # selected models and write the collected results to a CSV file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dim_reduction", type=float, default=0.505)
    parser.add_argument("--data_path", type=str, default="artifacts/data")
    parser.add_argument(
        "--models",
        type=str,
        nargs="*",
        metavar="M",
        default=ALL_MODELS,
        help="List of model names. By default, we take the first three checkpoints of this model.",
    )
    parser.add_argument("--max_epochs", type=int, default=800)
    parser.add_argument(
        "--num_trainings",
        type=int,
        default=3,
        help="Number of checkpoints to read in for each model and dataset combination.",
    )

    args = parser.parse_args()

    # Decode the model class names to model classes. The default already
    # holds classes, so only names given on the command line are translated.
    if args.models and isinstance(args.models[0], str):
        unknown_names = [
            name for name in args.models if name not in NAME_TO_MODEL_CLASS
        ]
        if unknown_names:
            parser.error(
                f"Unknown model name(s) {unknown_names}; choose from {sorted(NAME_TO_MODEL_CLASS)}"
            )
        models = [NAME_TO_MODEL_CLASS[name] for name in args.models]
    else:
        models = args.models

    start_time = datetime.now()
    results = get_results(
        data_path=args.data_path,
        models=models,
        max_epochs=args.max_epochs,
        num_trainings=args.num_trainings,
        dim_reduction=args.dim_reduction,
    )

    base_result_path = (
        f"artifacts/results/{datetime.today().strftime('%Y-%m-%d')}-RecoveredResults/"
    )
    os.makedirs(base_result_path, exist_ok=True)
    result_file_name = "results"
    # Never overwrite an earlier run of the same day: results.csv,
    # results_1.csv, results_2.csv, ...
    full_path = os.path.join(base_result_path, f"{result_file_name}.csv")
    run_index = 0
    while os.path.isfile(full_path):
        run_index += 1
        full_path = os.path.join(
            base_result_path, f"{result_file_name}_{run_index}.csv"
        )
    results.to_csv(full_path)

    end_time = datetime.now()
    runtime = end_time - start_time
    logging.info(
        f"Finished getting results!\nRetrieved {results.shape[0]} models after a duration of {runtime}!"
    )
