"""
   Experiment class for running experiments with various models and metrics.
"""
import logging
import os
import pickle
import time
import json
import pandas as pd

from module.model import RsetWrapper
from module.utils import Hparams, NumpyEncoder
from module.datasets import DatasetLoader
from module.metrics import Metrics

class Experiment:
    """ Experiment class for running experiments with various models and metrics. """
    def __init__(self, dataset: DatasetLoader=None, params: Hparams=None,
                    metrics: Metrics=None):
        """ Constructor for the Experiment class.

        Args:
            dataset (DatasetLoader, optional): Dataset loader object. Defaults to None.
            params (Hparams, optional): Hyperparameters object. When None, a new
                Hparams() is created for this instance. Defaults to None.
            metrics (Metrics, optional): Metrics object. Defaults to None.
        """
        self.dataset = dataset
        # Build the default here rather than in the signature: a default written
        # as Hparams() is evaluated once and shared by every Experiment, and this
        # class mutates its params (retrain flag, model_params config).
        self.params = Hparams() if params is None else params

        # Instantiate the model only when the params declare a model class
        model_class = getattr(self.params, 'model_class', None)
        self.model = None if model_class is None \
                    else model_class(**self.params.model_params)

        self.metrics = metrics
        self.result = {}
        self.fold_idx = None

    #region Experiment pipeline
    def tuning(self, nested_cv, fold):
        #TODO this needs to be generalized
        """ Function to tune model parameters based on the given nested cross-validation and fold.

        Only RsetWrapper models are tuned. Tuned values are cached in a pickled
        DataFrame with one row per (dataset, fold): when exactly one row matches
        it is reused, otherwise the model is tuned and the new row is appended to
        the cache. The resulting depth budget and regularization are written into
        self.params.model_params['config'].

        Args:
            nested_cv (sklearn.model_selection): Nested cross-validation object.
            fold (int): The current fold number.
        """
        logger = logging.getLogger("experiment.tuning")
        logger.debug("Parameter tuning")

        if issubclass(self.params.model_class, RsetWrapper):
            cache_path = "out/rset_params/results.pkl" #TODO Fix path
            try:
                with open(cache_path, "rb") as pickle_file:
                    result = pickle.load(pickle_file)
            except FileNotFoundError:
                # First run: there is no cache yet, so start from an empty table
                # instead of failing before anything could be tuned and stored.
                logger.info("No cached parameters found at %s", cache_path)
                result = pd.DataFrame(
                    columns=["dataset", "fold", "depth_budget", "regularization"])
            result["fold"] = result["fold"].astype(int)
            params = result[(result["dataset"] == self.dataset.dataset_name) \
                                                    & (result["fold"] == fold)]
            if len(params) == 1:
                logger.info("Using existing parameters")
                depth_budget = params['depth_budget'].values[0]
                lamb = params['regularization'].values[0]
            else:
                result_param, _ = self.model.tune(nested_cv, self.dataset.feat_name)
                depth_budget = result_param['depth_budget']
                lamb = result_param['regularization']
                new_rows = pd.DataFrame([{
                    "dataset": self.dataset.dataset_name,
                    "fold": fold,
                    "depth_budget": depth_budget,
                    "regularization": lamb,
                }])
                # Concatenating onto an empty frame is deprecated in recent pandas,
                # so an empty cache is simply replaced by the new row.
                result = new_rows if result.empty \
                            else pd.concat([result, new_rows], ignore_index=True)
                # The cache directory may not exist yet on a fresh checkout
                os.makedirs(os.path.dirname(cache_path), exist_ok=True)
                with open(cache_path, "wb") as pickle_file:
                    pickle.dump(result, pickle_file)
            # Cast to plain Python numbers before storing them in the model config
            self.params.model_params['config']['depth_budget'] = int(depth_budget)
            self.params.model_params['config']['regularization'] = float(lamb)
        logger.info("Model Params: %s", self.params.model_params)

    def train(self, X, y):
        """ Function to train the model on the given data.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
        """
        self.model = self.params.model_class(**self.params.model_params)
        start_time = time.time()
        self.model.fit(X, y)
        end_time = time.time()
        self.result["train_time"] = end_time - start_time

    def training(self, X, y):
        """ Encode the features when an encoder is configured, fit the model and log the outcome.

        For RsetWrapper models the special tree is searched afterwards and the
        number of trees is logged.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
        """
        log = logging.getLogger("experiment.training")
        encoder = self.params.encoder
        features = X if encoder is None else encoder.transform(X)
        log.info("Training model with data shape: %s", features.shape)
        self.train(features, y)
        log.info("Train time: %s", self.result["train_time"])

        # Nothing more to do for models that are not Rashomon sets
        if not issubclass(self.params.model_class, RsetWrapper):
            return
        self.model.find_special_tree(self.params.rng)
        log.info("Number of trees in the Rashomon set: %s", self.model.ntrees())


    def selecting(self, X, y, sensitive_features=None):
        """ Pick trees from the Rashomon set for every criterion that has none yet.

        In fairness mode, a fair tree is selected for each fairness metric not
        already present in model.best_tree. Afterwards the 'default' and 'kantch'
        trees are selected if they are missing.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            sensitive_features (np.ndarray): Sensitive features for fairness (optional).
        """
        log = logging.getLogger("experiment.selecting")
        assert issubclass(self.params.model_class, RsetWrapper), "Only Rashomon set can be selected"

        if self.dataset.fairness_mode is not None:
            # Work on a copy and drop the metrics that already have a tree
            pending = self.metrics.fairness_score_func.copy()
            already_selected = [name for name in pending.keys()
                                if name in self.model.best_tree]
            for name in already_selected:
                del pending[name]

            if len(pending) != 0: #TODO Support encoders
                log.info("Selecting fair tree: %s", pending.keys())
                self.metrics.sensitive_features = sensitive_features
                self.model.select_fair_tree(X, y, self.params.rng, pending)

        # Report what has been selected so far
        for criterion, tree in self.model.best_tree.items():
            log.info("Best Tree for %s: %s with score: %s",
                     criterion, tree, self.model.best_tree_score[criterion])

        if "default" not in self.model.best_tree:
            log.info("Selecting default tree")
            if self.params.encoder is None:
                X_default = X
            else:
                X_default = self.params.encoder.transform(X)
            self.model.select_tree(X_default, y, 'default', self.params.rng)

        if "kantch" not in self.model.best_tree:
            log.info("Selecting kantch tree")
            # Examples are generated from the raw features, then encoded if needed
            optimal_tree = self.model.get_optimal_tree(self.params.encoder)
            X_adv = self.metrics.kantch_example_gen(optimal_tree, X, y)
            if self.params.encoder is not None:
                X_adv = self.params.encoder.transform(X_adv)
            self.model.select_tree(X_adv, y, 'kantch', self.params.rng)



    def testing(self, X, y, X_test, y_test, X_test_sensitive=None):
        """ Evaluate the current model with the configured metrics and store the scores.

        Args:
            X (np.ndarray): Feature matrix.
            y (np.ndarray): Target vector.
            X_test (np.ndarray): Test feature matrix.
            y_test (np.ndarray): Test target vector.
            X_test_sensitive (np.ndarray, optional): test set sensitive features. Defaults to None.
        """
        log = logging.getLogger("experiment.testing")
        log.info("Testing model")

        if self.params.reset_results:
            # Start from a clean result dict, keeping only the training time
            self.result = {"train_time": self.result.get("train_time")}

        evaluator = self.metrics
        evaluator.eval_model_init(self.model, self.params.encoder)
        evaluator.metric_data_init(X, y, X_test, y_test, X_test_sensitive)
        scores = evaluator.evaluation()

        self.result.update(scores)
        log.info("Result: %s", json.dumps(self.result, indent=4, cls=NumpyEncoder))

    def cross_validate(self, train=True, test=True, fold=None) -> None:
        """ Function to perform cross-validation on the dataset.

        For every fold yielded by the dataset generator (or only the fold whose
        index equals `fold`), the stored model is loaded or the model is tuned
        and retrained, a tree is optionally selected, the model is optionally
        tested, and the experiment state is saved.

        Args:
            train (bool, optional): Train the model. Defaults to True.
            test (bool, optional): Test the model. Defaults to True.
            fold (int, optional): Index of the single fold to run; None runs
                every fold. Defaults to None.
        """
        logger = logging.getLogger("experiment.cross_validate")
        # A selection split is only requested when tree selection is enabled
        select_size = 0.1 if self.params.selection else 0
        folds_data = self.dataset.kfold_normalized_generator(select_size=select_size)
        for data in folds_data:
            X_select_sensitive, X_test_sensitive = None, None
            if self.params.selection and self.dataset.fairness_mode is not None:
                # Fairness mode: every split carries its sensitive features
                fold_idx, nested_cv, (X, y, _), \
                    (X_select, y_select, X_select_sensitive), \
                    (X_test, y_test, X_test_sensitive) = data
            else:
                #TODO add fairness mode when selection is disabled
                fold_idx, nested_cv, (X, y), (X_select, y_select), (X_test, y_test) = data

            if fold is not None and fold_idx != fold:
                continue

            self.fold_idx = fold_idx
            logger.info("Fold %s", fold_idx)

            if not self.params.retrain:
                trained = self.load_cross_val_model()
                if not trained:
                    # NOTE(review): the flag stays set on self.params, so later
                    # folds take the retrain path as well - confirm this is intended
                    self.params.retrain = True
            else:
                self.clean_cross_val_model()

            ### Tuning
            if self.params.tune and self.params.retrain:
                # Pass the index of the fold being processed, not the optional
                # `fold` filter, which is None when all folds are run.
                self.tuning(nested_cv, fold_idx)

            ### Preprocessing
            if self.params.encoder is not None:
                self.params.encoder.fit(X, y, self.dataset.feat_name)

            ### Training
            if train and self.params.retrain:
                self.training(X, y)
                self.save_cross_val_model()
            else:
                logger.info("Model parameters: %s", self.params.model_params)
                logger.info("Model fitted time: %s", self.result["train_time"])
                if issubclass(self.params.model_class, RsetWrapper):
                    logger.info("Number of trees in rset: %s", self.model.ntrees())

            ### Selection
            if self.params.selection and issubclass(self.params.model_class, RsetWrapper):
                # X_select_sensitive is still None unless fairness mode unpacked it
                self.selecting(X_select, y_select, X_select_sensitive)
                self.save_cross_val_model()

            if test:
                self.testing(X, y, X_test, y_test, X_test_sensitive)

            self.save_cross_val_model()
    #endregion

    ###
    #region I/O Save/Load functions
    ###
    def get_state(self) -> dict:
        """ Function to get the current state of the experiment.

        Returns:
            dict: A dictionary containing the current state of the experiment.
        """
        return {
            "model": self.model,
            "param_state": self.params.get_state(),
            "result": self.result
        }

    def set_state(self, model, param_state, result) -> None:
        """ Function to set the current state of the experiment.

        Args:
            model (model_wrapper): Model wrapper object.
            param_state (hparams_state): Hyperparameters state.
            result (dict): Result dictionary.
        """
        self.params.set_state(param_state)
        self.model = model
        self.result = result

    def filename(self) -> str:
        """ Function to get the filename for saving the experiment state.

        Returns:
            str: The filename for saving the experiment state.
        """
        return f"{self.params.model_name}_{self.dataset.dataset_name}_fold_{self.fold_idx}.pkl"

    def param_path(self) -> str:
        """ Function to get the parameter file path.
        Returns:
            str: The parameter file path.
        """
        return self.file_path("param_dir")

    def model_path(self):
        """ Function to get the model file path.
        Returns:
            str: The model file path.
        """
        return self.file_path("model_dir")

    def result_path(self) -> str:
        """ Function to get the result file path.

        Returns:
            str: The result file path.
        """
        return self.file_path("result_dir")

    def file_path(self, path_param: str) -> str:
        """

        Args:
            path_param (str): _description_

        Returns:
            str: _description_
        """
        return os.path.join(
                self.params.io_params["output_dir"],
                self.params.io_params[path_param],
                self.filename()
            )

    def save_cross_val_model(self) -> None:
        """ Function to save the cross-validation model. """
        logger = logging.getLogger("experiment.save_cross_val_model")
        logger.info("Saving cross-validation model")
        state = self.get_state()
        with open(self.model_path(), 'wb') as file:
            pickle.dump(state["model"], file)
        with open(self.param_path(), 'wb') as file:
            pickle.dump(state["param_state"], file)
        with open(self.result_path(), 'wb') as file:
            pickle.dump(state["result"], file)

    def load_experiment(self, model_file, param_file, result_file) -> None:
        """ Function to load the experiment state from the given files.

        Args:
            model_file (str): Model file path.
            param_file (str): Parameter file path.
            result_file (str): Result file path.
        """
        logger = logging.getLogger("experiment.load_experiment")
        logger.info("Loading experiment")
        with open(model_file, 'rb') as file:
            model_state = pickle.load(file)
        with open(param_file, 'rb') as file:
            param_state = pickle.load(file)
        with open(result_file, 'rb') as file:
            result_state = pickle.load(file)
        self.set_state(model_state, param_state, result_state)

    def load_cross_val_model(self) -> bool:
        """ Function to load the cross-validation model.

        Returns:
            bool: True if the model was loaded successfully, False otherwise.
        """
        logger = logging.getLogger("experiment.load_cross_val_model")
        logger.info("Loading cross-validation model")
        try:
            with open(self.model_path(), 'rb') as file:
                model_state = pickle.load(file)
            with open(self.param_path(), 'rb') as file:
                param_state = pickle.load(file)
            with open(self.result_path(), 'rb') as file:
                result_state = pickle.load(file)
        except FileNotFoundError:
            logger.info("Model not found")
            return False
        self.set_state(model_state, param_state, result_state)
        return True

    def clean_cross_val_model(self) -> None:
        """ Function to clean the cross-validation model files. """
        for filename in [self.file_path("model_dir"),
                         self.file_path("param_dir"),
                         self.file_path("result_dir")]:
            self.clean_file(filename)

    def clean_file(self, filename: str) -> None:
        """ Function to clean a file if it exists.

        Args:
            filename (str): The name of the file to be cleaned.
        """
        logger = logging.getLogger("experiment.clean_file")
        logger.info("Cleaning experiment file %s", filename)
        if os.path.exists(filename):
            os.remove(filename)

    #endregion
