import os
from typing import Optional

import pandas as pd

from .constants import N_TRIALS


class ResultManager:
    """Store per-trial MSE results keyed by (dataset, n_trees, feature) in a CSV file.

    A (dataset, n_trees, feature) configuration counts as complete only when it
    has exactly ``n_trials`` rows. On construction, incomplete configurations are
    dropped from the in-memory table. The file on disk is only rewritten by
    :meth:`save_results`, which :meth:`push_results` calls.
    """

    # Columns identifying one experiment configuration (tuple: never mutated).
    _KEY_COLUMNS = ("dataset", "n_trees", "feature")
    # Full schema of the results table, in on-disk column order.
    _COLUMNS = ("dataset", "n_trees", "feature", "mse", "trial_index")

    def __init__(
        self,
        results_file_path: str = os.path.join(
            os.path.dirname(__file__), "..", "results.csv"
        ),
        n_trials: Optional[int] = None,
    ):
        """Load existing results (if the file exists) and drop incomplete ones.

        Args:
            results_file_path (str): CSV file to read from and write to. Defaults
                to ``results.csv`` in the parent of this module's directory.
            n_trials (int, optional): Number of rows a configuration must have
                to count as complete. ``None`` (default) means ``N_TRIALS``.
        """
        self.results_file_path = results_file_path
        self.n_trials = N_TRIALS if n_trials is None else n_trials
        if os.path.exists(results_file_path):
            # Force the name columns to str so e.g. a dataset called "123" is
            # not parsed as an int and still matches the str callers pass in.
            self.results_df = pd.read_csv(
                results_file_path, dtype={"dataset": str, "feature": str}
            )
        else:
            self.results_df = pd.DataFrame(columns=list(self._COLUMNS))

        self.check_and_clean()

    def _key_mask(self, dataset: str, n_trees: int, feature: str) -> pd.Series:
        """Return a boolean mask selecting the rows of one configuration."""
        df = self.results_df
        # Plain comparisons instead of DataFrame.query: names containing quotes
        # would otherwise break the generated query string.
        return (
            (df["dataset"] == dataset)
            & (df["n_trees"] == n_trees)
            & (df["feature"] == feature)
        )

    def check_results_exist(self, dataset: str, n_trees: int, feature: str) -> bool:
        """
        Check whether a configuration has a complete set of results in memory.

        Args:
            dataset (str): The dataset name.
            n_trees (int): The number of trees.
            feature (str): The feature name.

        Returns:
            bool: True if exactly ``self.n_trials`` rows match the configuration.
        """
        return int(self._key_mask(dataset, n_trees, feature).sum()) == self.n_trials

    def clean_existing_results(self, dataset: str, n_trees: int, feature: str):
        """
        Remove all in-memory rows for a given dataset, n_trees, and feature.

        Args:
            dataset (str): The dataset name.
            n_trees (int): The number of trees.
            feature (str): The feature name.
        """
        self.results_df = self.results_df[
            ~self._key_mask(dataset, n_trees, feature)
        ].reset_index(drop=True)

    def check_and_clean(self):
        """
        Drop every configuration whose row count differs from ``self.n_trials``.

        Only the in-memory table is modified; call :meth:`save_results` to persist.
        Rows with a missing key value are dropped as well, since they can never be
        matched by :meth:`check_results_exist`.
        """
        if self.results_df.empty:
            return
        # One grouped pass instead of a query per (dataset, n_trees, feature)
        # combination of the unique values.
        group_sizes = self.results_df.groupby(list(self._KEY_COLUMNS), sort=False)[
            "mse"
        ].transform("size")
        self.results_df = self.results_df[group_sizes == self.n_trials].reset_index(
            drop=True
        )

    def get_results(self):
        """Return the in-memory results DataFrame (not a copy)."""
        return self.results_df

    def push_results(
        self, dataset: str, n_trees: int, feature: str, values: list[float]
    ):
        """
        Store the results for one configuration and write the file.

        Any rows already held for the same (dataset, n_trees, feature) are
        replaced, so re-running a configuration does not duplicate its rows.

        Args:
            dataset (str): The dataset name.
            n_trees (int): The number of trees.
            feature (str): The feature name.
            values (list[float]): One MSE per trial; its position becomes
                ``trial_index``.
        """
        new_rows = pd.DataFrame(
            {
                "dataset": [dataset] * len(values),
                "n_trees": [n_trees] * len(values),
                "feature": [feature] * len(values),
                "mse": list(values),
                "trial_index": list(range(len(values))),
            }
        )
        # Appending alone would leave 2 * n_trials rows after a second push,
        # which check_and_clean would then discard entirely on the next load.
        self.clean_existing_results(dataset, n_trees, feature)
        if self.results_df.empty:
            # Skip concatenating onto an empty frame (deprecated dtype inference
            # in recent pandas; it would also turn "mse" into object dtype).
            self.results_df = new_rows
        elif not new_rows.empty:
            self.results_df = pd.concat(
                [self.results_df, new_rows], ignore_index=True
            )
        self.save_results()

    def save_results(self):
        """Write the in-memory results to ``results_file_path`` (no index column)."""
        self.results_df.to_csv(self.results_file_path, index=False)


if __name__ == "__main__":
    # Load the default results file (dropping incomplete configurations in
    # memory) and print what remains.
    manager = ResultManager()
    loaded = manager.get_results()
    print(loaded)
