import pandas as pd
import numpy as np
from datetime import datetime
import xgboost as xgb
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    mean_absolute_error,
    mean_squared_error,
)
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
from sklearn.preprocessing import LabelEncoder
from models import DeepCAE, StackedCAE, JointVAE, ConvAE, StandardAE, TransformerAE, PCA

import os
import logging
import json
import argparse


# Embedding models benchmarked downstream; their class names double as the
# sub-folder names under the embeddings directory and as result column keys.
MODEL_CLASSES = [
    PCA,
    DeepCAE,
    StackedCAE,
    JointVAE,
    ConvAE,
    StandardAE,
    TransformerAE,
]
MODEL_NAMES = [model_cls.__name__ for model_cls in MODEL_CLASSES]


def get_predictor_performance(
    data_path: str = "artifacts/data",
    config_root: str = "artifacts/data",
    random_state: "int | None" = None,
    train_fraction: float = 0.8,
) -> dict:
    """
    For all datasets in data_path, train an XGBoost predictor and assess its performance on a held-out split.

    Every sub-folder of ``data_path`` is treated as one dataset (MNIST and Thermography are skipped).
    The target columns are always taken from the dataset's ``conf.json`` inside ``config_root`` (the
    original data directory), so raw data and embeddings are evaluated against the same targets.

    Parameters:
    data_path (str): Directory containing one sub-folder per dataset. If the path contains
        "embeddings", ``embeddings.csv`` is read from each sub-folder; otherwise, if it contains
        "data", ``processed.csv`` is read.
    config_root (str): Directory containing ``<dataset_name>/conf.json`` for every dataset.
    random_state (int | None): Seed for the random train/test split. ``None`` draws from NumPy's
        global random state (the previous behaviour).
    train_fraction (float): Probability with which each row is assigned to the training split.

    Returns:
    dict: Maps each dataset name to the metrics dict returned by ``get_xgb_predictor_results``.

    Raises:
    ValueError: If ``data_path`` is recognized neither as an embeddings nor as an original-data location.
    """
    # Decide once which file to read. Previously an unrecognized path only logged a warning and
    # then used an unbound (or stale, from the previous dataset) ``data`` variable.
    if "embeddings" in data_path:
        file_name = "embeddings.csv"
    elif "data" in data_path:
        file_name = "processed.csv"
    else:
        raise ValueError(
            f"The data path {data_path} was not detected as location for embeddings and also not for original data."
        )

    # Load names of all available datasets.
    dataset_names = get_folders(data_path)
    logging.info(f"Dataset names: {dataset_names}")
    # MNIST needs special treatment using its dedicated dataloader; Thermography is skipped here too.
    excluded_datasets = {"MNIST", "Thermography"}
    dataset_names = [name for name in dataset_names if name not in excluded_datasets]

    # ``np.random`` keeps the old global-state behaviour; a Generator makes splits reproducible.
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    results = {}

    for dataset_name in dataset_names:
        # First load the config from the original data directory.
        with open(f"{config_root}/{dataset_name}/conf.json", "r") as config_file:
            config = json.load(config_file)

        # Load the dataset and escape "<" / ">" in its column names.
        data = pd.read_csv(f"{data_path}/{dataset_name}/{file_name}")
        data.rename(columns=replace_special_characters, inplace=True)

        # Escape the configured targets the same way so they match the renamed columns.
        # A new list is built so the loaded config is not mutated.
        target_names = [replace_special_characters(name) for name in config["target"]]
        labels = data[target_names]

        # Drop one label column if there are two to avoid multicollinearity
        if labels.shape[1] == 1:
            logging.info(
                f"There is only one label for {dataset_name}, we continue as normal."
            )
        elif labels.shape[1] == 2:
            logging.info(
                f"There are two labels {target_names} for {dataset_name}. We have to drop one of them to avoid multicollinearity."
            )
            # This is fine for Adult, BlastChar and ChurnModelling. For Support2 we also just predict if dead or not.
            # Non-inplace selection avoids modifying a frame derived from ``data``.
            labels = labels.iloc[:, :-1]
        elif labels.shape[1] > 2:
            # This applies only for the Students dataset.
            logging.info(
                f"There are more than two labels {target_names}, we will use a MultiOutputClassifier in training for this."
            )

        # Encode categorical (object) labels on the full label set BEFORE splitting so that train
        # and test labels share one encoding; encoding only the training labels would leave the
        # test labels as raw strings and break the metric computation.
        labels = labels.copy()
        for col in labels.columns:
            if labels[col].dtype == "object":
                labels[col] = LabelEncoder().fit_transform(labels[col])

        data = data.drop(columns=target_names)

        # Perform train test split for data and label
        mask = rng.random(len(data)) < train_fraction
        train, test = data[mask], data[~mask]
        train_label, test_label = labels[mask], labels[~mask]

        logging.info(f"Now training predictor for {dataset_name}.")
        predictor = train_xgb_predictor(train, train_label)

        logging.info(f"Now evaluating predictor for {dataset_name}")
        result = get_xgb_predictor_results(predictor, test, test_label)
        results[dataset_name] = result

        logging.info(
            f"Finished original predictor performance training and evaluation for dataset {dataset_name}!\n Result:\n{result}"
        )

    return results


def get_xgb_predictor_results(
    predictor: xgb.XGBModel, test: pd.DataFrame, test_label: pd.DataFrame
) -> dict:
    """
    Return a dict containing the following metrics of the predictor:
    Classification
    - Accuracy
    - Precision (weighted)
    - Recall (weighted)
    - F1-Score (weighted)

    Regression
    - Mean Absolute Error
    - Root Mean Squared Error

    With several label columns, each metric is computed per column and then averaged.

    Parameters:
    predictor (xgb.XGBModel): The trained XGBoost model, possibly wrapped in a
        MultiOutputClassifier / MultiOutputRegressor.
    test (pd.DataFrame): The test data.
    test_label (pd.DataFrame): The true labels for the test data.

    Returns:
    dict: A dictionary containing the metrics.
    """
    num_labels = test_label.shape[1]

    # Predict the labels. Multi-output wrappers already return (n_samples, n_labels); for a bare
    # model with several labels the predictions are reshaped to that layout.
    predictions = np.asarray(predictor.predict(test))
    if num_labels > 1:
        predictions = predictions.reshape(-1, num_labels)

    # Inspect the wrapped estimator of multi-output models so that multi-output regression is
    # evaluated with regression metrics; previously only a bare XGBRegressor was recognized and
    # multi-output regressors fell through to the classification metrics.
    if isinstance(predictor, (MultiOutputClassifier, MultiOutputRegressor)):
        base_estimator = predictor.estimator
    else:
        base_estimator = predictor
    is_regression = isinstance(predictor, MultiOutputRegressor) or isinstance(
        base_estimator, xgb.XGBRegressor
    )

    def column_pairs():
        # Yield (y_true, y_pred) for every label column.
        for i in range(num_labels):
            y_pred = predictions[:, i] if num_labels > 1 else predictions
            yield test_label.iloc[:, i], y_pred

    metrics = {}

    # Calculate metrics for each label
    # NOTE: Changes to the used metrics have also to be taken over into the plots_and_tables.py script.
    if is_regression:
        mae = []
        rmse = []
        for y_true, y_pred in column_pairs():
            mae.append(mean_absolute_error(y_true, y_pred))
            rmse.append(np.sqrt(mean_squared_error(y_true, y_pred)))

        metrics["MAE"] = float(np.mean(mae))
        metrics["RMSE"] = float(np.mean(rmse))
    else:
        accuracy = []
        precision = []
        recall = []
        f1 = []
        for y_true, y_pred in column_pairs():
            accuracy.append(accuracy_score(y_true, y_pred))
            precision.append(
                precision_score(y_true, y_pred, average="weighted", zero_division=0)
            )
            recall.append(
                recall_score(y_true, y_pred, average="weighted", zero_division=0)
            )
            f1.append(f1_score(y_true, y_pred, average="weighted", zero_division=0))

        metrics["Accuracy"] = float(np.mean(accuracy))
        metrics["Precision"] = float(np.mean(precision))
        metrics["Recall"] = float(np.mean(recall))
        metrics["F1-Score"] = float(np.mean(f1))

    return metrics


def train_xgb_predictor(
    train_set: pd.DataFrame, train_label: pd.DataFrame, max_classes: int = 20
) -> xgb.XGBModel:
    """
    Return a trained XGBoost predictor for a given training set.

    The task is treated as regression as soon as any label column has more than
    ``max_classes`` distinct values; otherwise it is treated as classification.
    Several label columns are handled by a MultiOutputRegressor / MultiOutputClassifier wrapper.

    Parameters:
    train_set (pd.DataFrame): The training data.
    train_label (pd.DataFrame): The training labels (one column per target). Not modified.
    max_classes (int): Largest number of distinct values a label column may have and still be
        treated as a classification target (heuristic threshold).

    Returns:
    xgb.XGBModel: A trained XGBoost model (possibly wrapped in a multi-output estimator).
    """
    num_labels = train_label.shape[1]

    # Positional access keeps this correct even if label column names are duplicated.
    is_regression = any(
        train_label.iloc[:, i].nunique() > max_classes for i in range(num_labels)
    )

    if is_regression:
        # Use XGBRegressor for regression tasks
        if num_labels == 1:
            model = xgb.XGBRegressor()
            model.fit(train_set, train_label.iloc[:, 0])
        else:
            # MultiOutputRegressor, not MultiOutputClassifier: the classifier wrapper collects
            # ``classes_`` from its sub-estimators after fitting, which regressors do not have.
            model = MultiOutputRegressor(xgb.XGBRegressor(), n_jobs=-1)
            model.fit(train_set, train_label)
        return model

    # Encode categorical labels on a copy so the caller's DataFrame is left untouched
    # (and pandas' SettingWithCopyWarning is avoided).
    train_label = train_label.copy()
    for col in train_label.columns:
        if train_label[col].dtype == "object":
            train_label[col] = LabelEncoder().fit_transform(train_label[col])

    # Use XGBClassifier for classification tasks
    if num_labels == 1:
        model = xgb.XGBClassifier()
        model.fit(train_set, train_label.iloc[:, 0])
    else:
        model = MultiOutputClassifier(xgb.XGBClassifier(), n_jobs=-1)
        model.fit(train_set, train_label)

    return model


# Helpers
def get_folders(directory: str) -> list:
    """
    Return the names of the sub-directories directly inside ``directory``, sorted alphabetically.

    ``os.listdir`` yields entries in arbitrary order; sorting makes the order in which datasets
    are processed (and therefore the order of the results) reproducible across runs and platforms.
    Regular files are ignored.
    """
    return sorted(
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )


_SPECIAL_CHARACTER_ESCAPES = str.maketrans({"<": "&lt;", ">": "&gt;"})


def replace_special_characters(col_name: str) -> str:
    """Escape ``<`` and ``>`` in a column name as the HTML entities ``&lt;`` and ``&gt;``."""
    return col_name.translate(_SPECIAL_CHARACTER_ESCAPES)


def main() -> None:
    """
    Benchmark downstream XGBoost performance on the raw data and on each model's embeddings,
    and write the combined results to ``artifacts/results/<results-dir>/downstream_results.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="artifacts/data")
    parser.add_argument("--embeddings_path", type=str, default="artifacts/embeddings")
    parser.add_argument("--results-dir", type=str, default="FinalResults")
    parser.add_argument("--dataset-name", nargs="+", help="Dataset name, e.g. TeaRetail", required=False) # This is not used yet, but we could refactor the get_predictor_performance function accordingly.
    parser.add_argument("--model-name", nargs="+", help="Model name, e.g. DeepCAE", required=False)

    args = parser.parse_args()

    # Without a configured handler the root logger drops INFO records, so none of the progress
    # logging in this script would be visible. No-op if logging was already configured.
    logging.basicConfig(level=logging.INFO)

    start_time = datetime.now()

    downstream_results = {}
    downstream_results["RawData"] = get_predictor_performance(args.data_path)

    # Benchmark all known models unless specific ones were requested.
    model_names = args.model_name or MODEL_NAMES

    for model_name in model_names:
        downstream_results[model_name] = get_predictor_performance(
            f"{args.embeddings_path}/{model_name}"
        )

    # Assemble the result dataframe and save it; create the target directory if it is missing.
    df = pd.DataFrame(downstream_results)
    results_dir = f"artifacts/results/{args.results_dir}"
    os.makedirs(results_dir, exist_ok=True)
    df.to_csv(f"{results_dir}/downstream_results.csv")

    end_time = datetime.now()
    runtime = end_time - start_time
    logging.info(f"Finished downstream benchmarking after a duration of {runtime}!")


if __name__ == "__main__":
    main()
