import os
import sys
from glob import glob

import yaml
import pandas as pd
from scipy.special import softmax

sys.path.append("..")
from codes.run_eval import run_eval
from common.experiment_manager import ExperimentManager
from common.utils import ResultManager

# Repository root: three dirname() calls on this file's absolute path, i.e.
# the grandparent of the directory that contains this file.
repo_root = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
cfg_file = os.path.join(repo_root, "config.yaml")
# Explicit encoding so the config is read the same way regardless of the
# platform's locale default.
with open(cfg_file, encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Resolve base_dir to an absolute path relative to the repository root, so
# it does not depend on the current working directory.
if not os.path.isabs(config["path"]["base_dir"]):
    config["path"]["base_dir"] = os.path.abspath(
        os.path.join(repo_root, config["path"]["base_dir"])
    )

class ClassificationExperimentEvaluator(ExperimentManager):
    """
    Experiment evaluator for PCG classification models.

    This class manages the evaluation pipeline for trained classification models,
    including multi-seed evaluation, result aggregation, and per-sample analysis
    with demographic information. It extends the base ExperimentManager to provide
    evaluation-specific functionality.

    Attributes:
        exe_mode: Execution mode identifier used to fetch configuration from config.yaml
    """

    exe_mode = "clf_exp01"

    def _fetch_config_file(self, exp_id: int) -> str:
        """
        Build the experiment configuration file path for an experiment ID.

        Args:
            exp_id: Integer experiment ID. It must be an int, because both the
                floor division and the ``:04d`` format spec fail on a string.
                The path is constructed as
                ./resources/exp{exp_id//100:02d}s/exp{exp_id:04d}.yaml,
                e.g. 123 -> ./resources/exp01s/exp0123.yaml.

        Returns:
            str: Path to the experiment configuration YAML file. The path is
            relative to the current working directory, not to this file.
        """
        group_dir = f"exp{exp_id // 100:02d}s"
        file_name = f"exp{exp_id:04d}.yaml"
        return os.path.join("./resources", group_dir, file_name)

    def _run_eval(
        self,
        eval_target,
        device,
        dump_loc,
        multiseed_run,
        overwrite_params
    ):
        """
        Execute evaluation on a trained model with parameter overrides.

        This is a thin wrapper around ``codes.run_eval.run_eval``. Subclasses
        can override it to customise evaluation.

        Args:
            eval_target: Path to the trained model checkpoint directory
            device: Device string for model inference (e.g., "cuda:0", "cpu")
            dump_loc: Directory path where evaluation results will be saved
            multiseed_run: Boolean indicating if this is part of a multi-seed
                experiment
            overwrite_params: Dictionary of parameters to override from the saved
                model configuration (e.g., seed, dataset, data_lim, load_demos)

        Returns:
            Whatever ``run_eval`` returns. ``main`` unpacks it as
            ``(val_result, test_result)``.
        """
        return run_eval(
            eval_target,
            device,
            dump_loc,
            multiseed_run,
            overwrite_params=overwrite_params
        )

    def _fetch_trained_model_loc(self, params):
        """
        Locate the trained model run directory for evaluation.

        Looks under ``<eval_target>/multirun/train/seed<seed:04d>/`` and returns
        the entry that sorts last by name.

        Args:
            params: Parameter object supporting ``"eval_target" in params``
                and exposing the attributes ``eval_target`` (root directory
                of the trained model) and ``seed`` (int training seed).

        Returns:
            str: Full path to the selected trained model directory.

        Raises:
            ValueError: If 'eval_target' is not found in params.
            FileNotFoundError: If the seed directory has no entries.
        """
        from glob import escape as glob_escape

        if "eval_target" not in params:
            raise ValueError("`eval_target` not found in params.")

        trained_model_root = os.path.join(
            params.eval_target,
            "multirun",
            "train",
            f"seed{params.seed:04d}",
        )
        # glob() returns entries in arbitrary filesystem order, so sort them to
        # make the choice deterministic. The last name is treated as the latest
        # run. This assumes run directory names sort chronologically (e.g.
        # timestamp-named dirs) -- TODO confirm.
        # escape() stops glob metacharacters in the root path (e.g. "[") from
        # being interpreted as a pattern.
        candidates = sorted(
            glob(os.path.join(glob_escape(trained_model_root), "*"))
        )
        if not candidates:
            raise FileNotFoundError(
                f"No trained model found under {trained_model_root!r}."
            )
        return candidates[-1]

    def main(self, single_run=False):
        """
        Run the evaluation pipeline across the configured random seeds.

        For each seed in ``config["experiment"][exe_mode]["seed"]["multiseed"]``
        it does the following:
        - updates the stored params with the seed;
        - locates the trained model;
        - evaluates it into ``<save_loc>/multirun/eval/seed<seed:04d>``;
        - appends one "val" row and one "test" row to the result table.

        When ``load_demos`` is set, the per-sample test results are also
        collected for every seed and written to ``<save_loc>/demo_result.csv``.

        Args:
            single_run: If True, evaluate only the first seed (useful for
                debugging). Default is False (evaluate all seeds).

        Returns:
            str: ``result_manager.savename``. The manager was constructed with
            ``<save_loc>/ResultTableMultiSeed.csv``.
        """
        exp_config = config["experiment"][self.exe_mode]

        # Prepare result storer.
        columns = ["seed", "dataset"] + exp_config["result_cols"]
        savename = os.path.join(self.save_loc, "ResultTableMultiSeed.csv")
        result_manager = ResultManager(savename=savename, columns=columns)

        seeds = exp_config["seed"]["multiseed"]
        load_demos = self.param_manager.get_params().load_demos
        demo_results = []
        for seed in seeds:
            self.param_manager.update_params({"seed": seed})

            # Locate the trained model for this seed.
            trained_model_loc = self._fetch_trained_model_loc(
                self.param_manager.get_params())

            # Per-seed evaluation output directory.
            save_loc_eval = os.path.join(
                self.save_loc, "multirun", "eval", f"seed{seed:04d}")
            os.makedirs(save_loc_eval, exist_ok=True)

            # data_lim / val_data_lim are forced to None. This presumably
            # disables any subset limit saved with the trained model; confirm
            # in run_eval.
            overwrite_params = {
                "seed": seed,
                "data_lim": None,
                "val_data_lim": None,
                "dataset": self.param_manager.get_params().dataset,
                "load_demos": load_demos,
            }
            val_result, test_result = self._run_eval(
                eval_target=trained_model_loc,
                device=self.device,
                dump_loc=save_loc_eval,
                multiseed_run=True,
                overwrite_params=overwrite_params
            )
            result_manager.add_result(
                self._form_result_row(seed, "val", columns, val_result))
            result_manager.add_result(
                self._form_result_row(seed, "test", columns, test_result))

            if load_demos:
                demo_results.append(
                    self._form_per_sample_result_df(test_result, seed))

            # Save after every seed (is_temporal=True), presumably so partial
            # results survive a crash in a later seed.
            result_manager.save_result(is_temporal=True)

            if single_run:
                break

        result_manager.save_result()
        # pd.concat raises on an empty list, so skip the file when no seed ran.
        if load_demos and demo_results:
            pd.concat(demo_results).to_csv(
                os.path.join(self.save_loc, "demo_result.csv"))
        return result_manager.savename

    def _form_per_sample_result_df(self, result, seed):
        """
        Turn evaluation results into a per-sample DataFrame with demographics.

        Raw logits are converted to probabilities with a softmax over axis 1.
        For each sample, the DataFrame records the argmax class and the
        probability of class index 1.

        Args:
            result: Mapping with keys:
                - 'y_trues': sequence of true labels
                - 'y_preds': 2-D array-like of raw logits (before softmax),
                  with at least 2 columns
                - 'demographics': sequence of (age, gender) pairs
            seed: Random seed used for this evaluation run.

        Returns:
            pd.DataFrame with columns, in order:
                seed, age, gender, y_true, y_preds_class, y_preds_prob.
            Rows are aligned by position. As with ``zip``, extra entries in
            longer inputs are dropped. When there are no samples, the
            DataFrame is empty but still has these columns.
        """
        columns = [
            "seed", "age", "gender", "y_true", "y_preds_class", "y_preds_prob",
        ]
        probs = softmax(result["y_preds"], axis=1)
        pred_classes = probs.argmax(axis=1)
        positive_probs = probs[:, 1]
        rows = [
            {
                "seed": seed,
                "age": demo[0],
                "gender": demo[1],
                "y_true": y_true,
                "y_preds_class": pred_class,
                "y_preds_prob": positive_prob,
            }
            for y_true, pred_class, positive_prob, demo in zip(
                result["y_trues"],
                pred_classes,
                positive_probs,
                result["demographics"],
            )
        ]
        # Explicit columns keep the schema stable even when `rows` is empty.
        return pd.DataFrame(rows, columns=columns)

if __name__ == "__main__":

    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Evaluate trained PCG classification models.")

    parser.add_argument(
        "--exp",
        type=int,
        default=0,
        help="Integer experiment ID passed to the evaluator.",
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="Device string for evaluation, e.g. 'cuda:0' or 'cpu'.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Construct the evaluator with debug=True.",
    )
    parser.add_argument(
        "--multirun",
        action="store_true",
        help="Evaluate every configured seed (default: only the first seed).",
    )
    args = parser.parse_args()

    print(args)

    executor = ClassificationExperimentEvaluator(
        args.exp,
        args.device,
        debug=args.debug
    )
    # Without --multirun, only the first seed is evaluated.
    executor.main(single_run=not args.multirun)
