import argparse
import copy
import json
import logging
import multiprocessing as mp
import os
import queue
import traceback
from datetime import datetime

import pandas as pd
import torch
import yaml

from models.train_utils.train import train_model
from models import DeepCAE, StackedCAE, JointVAE, ConvAE, StandardAE, TransformerAE, PCA


# Every model class available for benchmarking.
ALL_MODELS = [
    JointVAE,
    PCA,
    TransformerAE,
    DeepCAE,
    StackedCAE,
    ConvAE,
    StandardAE,
]

# Maps a class name (as given on the command line) back to its model class.
NAME_TO_MODEL_CLASS = dict((cls.__name__, cls) for cls in ALL_MODELS)

# Datasets benchmarked by default. "Thermography" is deliberately left out;
# benchmark() also drops it when it is requested explicitly.
DATASET_NAMES = [
    "Abalone",
    "Adult",
    "AirQuality",
    "BankMarketing",
    "BlastChar",
    "CaliforniaHousing",
    "ChurnModelling",
    "Parkinsons",
    "Shoppers",
    "Students",
    "Support2",
    "TeaRetail",
    "Walmart",
]

def benchmark(
    data_path: str = "artifacts/data",
    models: list = ALL_MODELS,
    datasets: list = DATASET_NAMES,
    max_epochs: int = 800,
    num_trainings: int = 3,
    dim_reduction: float = 0.505,
) -> pd.DataFrame:
    """
    Benchmark models on datasets in parallel and collect the results.

    Every (dataset, model) combination is trained ``num_trainings`` times.
    The jobs are distributed over worker processes: one per visible CUDA
    device, or one per CPU core reported by ``os.cpu_count()`` when no GPU
    is found (never more workers than jobs).

    Parameters:
    ---------------
    data_path: str
        Directory containing one sub-folder per dataset; the folder names are
        used as dataset names.
    models: list
        Model classes to benchmark.
    datasets: list
        Dataset names to consider. A dataset is only used if it also exists as
        a folder in ``data_path`` and at least one config was loaded for it.
        "MNIST" and "Thermography" are always skipped.
    max_epochs: int
        Passed to ``get_configs``, which uses it to override ``epochs``.
    num_trainings: int
        Number of training runs per (dataset, model) combination.
    dim_reduction: float
        Forwarded unchanged to ``train_model`` for every job.

    Returns:
    ---------------
    pd.DataFrame
        All results reported by the workers concatenated row-wise; an empty
        DataFrame if no training run produced a result.
    """
    # Requested datasets that actually exist on disk.
    dataset_names = set(get_folders(data_path)).intersection(datasets)
    logging.info(f"Dataset names: {dataset_names}")
    # MNIST will need special treatment using its dedicated dataloader;
    # Thermography is excluded from benchmarking.
    dataset_names.discard("MNIST")
    dataset_names.discard("Thermography")

    model_names = [mdl.__name__ for mdl in models]
    logging.info(f"Model names: {model_names}")
    model_configs = get_configs(model_names, dataset_names, max_epochs)
    logging.info(f"Model configs: {model_configs}")

    # get_configs() creates a (possibly empty) entry for every dataset, so a
    # plain key intersection filters nothing: keep only datasets for which at
    # least one config was actually loaded. Sorting makes job order
    # reproducible (set iteration order of strings varies between runs).
    dataset_names = sorted(d for d in dataset_names if model_configs.get(d))
    logging.info(f"Updated dataset names: {dataset_names}")

    # Create queues for multiprocessing and enqueue one job per training run.
    job_queue = mp.Queue()
    result_queue = mp.Queue()
    num_tasks = 0
    for dataset_name in dataset_names:
        for model_class in models:
            for _ in range(num_trainings):
                job_queue.put((model_class, dataset_name))
                num_tasks += 1

    num_devices = torch.cuda.device_count()
    if not num_devices:
        # Then train on CPU. os.cpu_count() returns None if undeterminable.
        num_devices = os.cpu_count() or 1
        logging.info(
            f"Detected NO GPU devices - will distribute training jobs across {num_devices} discovered CPU kernels!"
        )
    else:
        logging.info(
            f"Detected {num_devices} GPU devices - will distribute training jobs!"
        )

    # Workers beyond the number of jobs would only start and exit again.
    processes = []
    for device_id in range(min(num_devices, num_tasks)):
        p = mp.Process(
            target=worker,
            args=(
                device_id,
                job_queue,
                result_queue,
                model_configs,
                dim_reduction,
            ),
        )
        p.start()
        processes.append(p)

    # Drain the result queue before joining: a process that has put items on
    # a queue may not terminate until those items are consumed.
    results_list = _collect_results(result_queue, processes, num_tasks)

    for p in processes:
        p.join()

    # Failed jobs may be reported as None; pd.concat also raises on an empty
    # list, so handle "nothing succeeded" explicitly.
    results_list = [r for r in results_list if r is not None]
    if not results_list:
        logging.warning("No training run produced a result.")
        return pd.DataFrame()
    return pd.concat(results_list, ignore_index=True)


def _collect_results(
    result_queue: mp.Queue,
    processes: list,
    num_tasks: int,
    poll_interval: float = 5.0,
) -> list:
    """
    Gather up to ``num_tasks`` items from ``result_queue``.

    Stops early once every process in ``processes`` has exited, so a worker
    that crashed or never reported a job cannot block the caller forever.
    """
    results = []
    while len(results) < num_tasks:
        try:
            results.append(result_queue.get(timeout=poll_interval))
        except queue.Empty:
            if not any(p.is_alive() for p in processes):
                break

    # Only reached with missing results when every worker has exited: pick up
    # anything they put just before exiting, then give up on the rest.
    while len(results) < num_tasks:
        try:
            results.append(result_queue.get(timeout=0.1))
        except queue.Empty:
            logging.warning(
                f"All workers exited after reporting {len(results)} of {num_tasks} jobs."
            )
            break
    return results


# Worker process: trains queued jobs on one device (a GPU index, or a CPU slot when no GPU is available)
def worker(
    device_id: int,
    job_queue: mp.Queue,
    result_queue: mp.Queue,
    model_configs: dict,
    dim_reduction: float,
    poll_timeout: float = 1.0,
) -> None:
    """
    Train jobs from ``job_queue`` on one device until the queue is drained.

    Each job is a ``(model_class, dataset_name)`` tuple. For every job taken,
    exactly one item is put on ``result_queue``: the value returned by
    ``train_model``, or ``None`` if training raised. This keeps the number of
    reports equal to the number of jobs, so a consumer counting reports per
    job always reaches its target.

    Parameters:
    ---------------
    device_id: int
        Index of the device, forwarded to ``train_model``.
    job_queue: mp.Queue
        Source of ``(model_class, dataset_name)`` jobs.
    result_queue: mp.Queue
        Destination for one result (or ``None``) per job.
    model_configs: dict
        Configs keyed by dataset name; ``model_configs[dataset_name]`` is
        forwarded to ``train_model``.
    dim_reduction: float
        Forwarded to ``train_model``.
    poll_timeout: float
        Seconds to wait for a job before concluding the queue is empty.
    """
    while True:
        try:
            # get_nowait() may raise Empty spuriously right after the parent
            # put the jobs (they can still sit in its feeder buffer), which
            # would make this worker quit early; a short blocking get avoids it.
            model_class, dataset_name = job_queue.get(timeout=poll_timeout)
        except queue.Empty:
            break
        logging.info(
            f"Device {device_id} is now training {model_class.__name__} on {dataset_name}..."
        )
        try:
            result = train_model(
                model_class,
                dataset_name,
                model_configs[dataset_name],
                dim_reduction,
                device_id,
                early_stopping=True,
                patience=30,
                min_delta=0.998,
            )
        except Exception:
            # Report the failure instead of dropping the job silently: the
            # consumer waits for one item per job.
            logging.exception(
                f"Device {device_id} failed training {model_class.__name__} on {dataset_name}"
            )
            result = None
        result_queue.put(result)


def get_folders(directory):
    """Return the names of all immediate sub-directories of ``directory``."""
    return [
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    ]


def get_configs(
    model_names: list[str],
    dataset_names: list[str],
    max_epochs: int = None,
    base_path: str = "artifacts/tuning/",
    amount: float = 0.505,
) -> dict:
    """
    Get the configs for a given set of model names and dataset names.

    For each model, ``models/architectures/<model>/conf.json`` is read. Its
    ``"static"`` entries form the base config, and its ``"tunable"`` keys name
    the hyperparameters taken from the tuning result
    ``<base_path><model>_<dataset>_<amount>_best.json``. Keys in the tuning
    result that start with ``config_`` have that prefix stripped first.

    Parameters:
    ---------------
    model_names: list[str]
        Names of the models (sub-folders of ``models/architectures``).
    dataset_names: list[str]
        Names of the datasets to build configs for.
    max_epochs: int
        If truthy, overrides the ``epochs`` entry of every config.
    base_path: str
        Directory prefix of the tuning result files.
    amount: float
        Value embedded in the tuning result file names.

    Returns:
    ---------------
    dict
        ``{dataset_name: {model_name: config}}``. Every requested dataset gets
        an entry, which is empty if no config could be loaded for it. Pairs
        whose files are missing, unreadable or incomplete are skipped and
        logged.
    """
    prefix = "config_"

    model_arguments = dict()
    for m_name in model_names:
        model_path = f"models/architectures/{m_name}/"
        try:
            with open(model_path + "conf.json", "r") as fp:
                model_arguments[m_name] = json.load(fp)
        except FileNotFoundError:
            logging.info(f"Folder does not exist: {model_path}")
        except (OSError, ValueError) as err:
            logging.warning(f"Could not load {model_path}conf.json: {err}")

    model_configs = dict()
    for d_name in dataset_names:
        model_configs[d_name] = dict()
        for m_name in model_names:
            if m_name not in model_arguments:
                continue  # Missing/unreadable conf.json was already logged.

            file_path = base_path + f"{m_name}_{d_name}_{amount}_best.json"
            try:
                with open(file_path, "r") as config_file:
                    config_dict = json.load(config_file)
            except FileNotFoundError:
                logging.info(f"File does not exist: {file_path}")
                continue
            except (OSError, ValueError) as err:
                logging.warning(f"Could not load {file_path}: {err}")
                continue

            # Strip the prefix only when it is actually a prefix.
            clean_config_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in config_dict.items()
            }

            try:
                # Deep copy: the "static" dict is shared by every dataset, so
                # writing tuned values into it directly would leak one
                # dataset's hyperparameters into all the others.
                config = copy.deepcopy(model_arguments[m_name]["static"])
                for key in model_arguments[m_name]["tunable"]:
                    config[key] = clean_config_dict[key]
            except KeyError as err:
                logging.warning(
                    f"Incomplete config for {m_name} on {d_name}: missing key {err}"
                )
                continue

            if max_epochs:
                config["epochs"] = max_epochs
            # Assigned only once fully built, so failures leave no partial entry.
            model_configs[d_name][m_name] = config
    return model_configs


def next_result_path(directory: str, stem: str = "results", suffix: str = ".csv") -> str:
    """
    Return a file path inside ``directory`` that does not exist yet.

    Tries ``<stem><suffix>`` first, then ``<stem>_1<suffix>``,
    ``<stem>_2<suffix>``, ... so repeated runs never overwrite earlier results.
    """
    candidate = os.path.join(directory, stem + suffix)
    counter = 0
    while os.path.exists(candidate):
        counter += 1
        candidate = os.path.join(directory, f"{stem}_{counter}{suffix}")
    return candidate


if __name__ == "__main__":
    # The root logger defaults to WARNING; without this every logging.info()
    # call in this script is silently dropped.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dim_reduction", type=float, default=0.505)
    parser.add_argument("--data_path", type=str, default="artifacts/data")
    parser.add_argument(
        "--models",
        type=str,
        nargs="*",
        metavar="M",
        default=ALL_MODELS,
        help="List of model names.",
    )
    parser.add_argument(
        "--datasets",
        type=str,
        nargs="*",
        metavar="D",
        default=DATASET_NAMES,
        help="List of dataset names.",
    )
    parser.add_argument("--max_epochs", type=int, default=800)
    parser.add_argument(
        "--num_trainings",
        type=int,
        default=3,
        help="Number of training runs to do for each model and dataset combination.",
    )

    args = parser.parse_args()

    # Decode model class names given on the command line to model classes.
    # argparse does not apply `type` to a non-string default, so the default
    # is already a list of classes.
    if args.models and isinstance(args.models[0], str):
        unknown = [name for name in args.models if name not in NAME_TO_MODEL_CLASS]
        if unknown:
            parser.error(
                f"Unknown model name(s): {unknown}. "
                f"Choose from {sorted(NAME_TO_MODEL_CLASS)}."
            )
        models = [NAME_TO_MODEL_CLASS[name] for name in args.models]
    else:
        models = args.models
    logging.info(f"Selected models: {[m.__name__ for m in models]}")

    start_time = datetime.now()
    results = benchmark(
        data_path=args.data_path,
        models=models,
        datasets=args.datasets,
        max_epochs=args.max_epochs,
        num_trainings=args.num_trainings,
        dim_reduction=args.dim_reduction,
    )

    base_result_path = f"artifacts/results/{datetime.today().strftime('%Y-%m-%d')}/"
    os.makedirs(base_result_path, exist_ok=True)
    # Pick results.csv, results_1.csv, ... so earlier runs are not overwritten.
    full_path = next_result_path(base_result_path)
    results.to_csv(full_path)
    logging.info(f"Saved results to {full_path}")

    end_time = datetime.now()
    runtime = end_time - start_time
    logging.info(
        f"Finished benchmarking!\nTrained {results.shape[0]} models after a duration of {runtime}!"
    )
