import os
import sys
from glob import glob

import yaml

sys.path.append("..")
from codes.run_eval import run_eval
from codes.run_train import run_train
from common.experiment_manager import ExperimentManager

# Repository root: two directory levels above the directory that contains
# this file (i.e. three dirname() steps from the file path itself).
repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
cfg_file = os.path.join(repo_root, "config.yaml")
# Decode explicitly as UTF-8 so that loading the config does not depend on
# the platform's locale encoding.
with open(cfg_file, encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Anchor a relative base_dir at the repository root so it resolves to the
# same absolute path regardless of the current working directory.
if not os.path.isabs(config["path"]["base_dir"]):
    config["path"]["base_dir"] = os.path.abspath(
        os.path.join(repo_root, config["path"]["base_dir"])
    )

class ClassificationExperimentManager(ExperimentManager):
    """
    Experiment manager for PCG classification training.

    Extends the base ExperimentManager with classification-specific hooks:
    locating the experiment YAML, running training (with optional retraining
    for the synthetic dataset), resolving finetune checkpoints, and running
    evaluation.

    Attributes:
        exe_mode: Key under config["experiment"] from which this manager
            reads its settings (synthetic dataset versions, per-model
            batch sizes).
    """

    exe_mode = "clf_exp01"

    # Retraining on the synthetic dataset stops once the validation loss
    # drops below this value ...
    _RERUN_LOSS_THRESHOLD = 0.5
    # ... or after this many additional training runs.
    _MAX_RERUNS = 5

    def _fetch_config_file(self, exp_id: int):
        """
        Build the path of the experiment configuration file for an ID.

        The path is relative to the current working directory and has the
        form ./exp_yamls/exp{exp_id // 100:02d}s/exp{exp_id:04d}.yaml, e.g.
        exp_id=123 gives ./exp_yamls/exp01s/exp0123.yaml.

        Args:
            exp_id: Experiment ID. Must be convertible with int(); numeric
                strings are accepted and converted.

        Returns:
            str: Path to the experiment configuration YAML file

        Raises:
            ValueError: If exp_id cannot be converted to an integer
        """
        # The integer division and the ":04d" format both require an int.
        exp_id = int(exp_id)

        return os.path.join(
            "./exp_yamls",
            f"exp{exp_id // 100:02d}s",
            f"exp{exp_id:04d}.yaml",
        )

    def _run_train(self, params, save_loc):
        """
        Execute the training process with specified parameters.

        Before training, params is updated in place from the module-level
        config:
            - dataset == "syn": dataset_ver_norm / dataset_ver_dx are read
              from syn_dataset_ver, and data_lim / val_data_lim are set to
              fixed values (10000 / 2500).
            - batch_size == "per_model": batch_size is looked up from
              model_to_batchsize using params.modelname.
            - finetune_target is not None: the path is resolved to a
              concrete directory via _update_finetune_target.

        After the first run, the retraining loop is entered only when the
        dataset is "syn", params has a 'rerun' attribute, and the validation
        loss is not below the rerun threshold. Otherwise the result of the
        first run is returned unchanged.

        Args:
            params: Parameter object with attribute access. Reads dataset,
                target_dx, batch_size, modelname and finetune_target.
            save_loc: Location passed through to run_train

        Returns:
            tuple: (best_result, save_dir) where:
                - best_result: Result mapping from run_train; its "loss"
                  entry is used for the retraining decision
                - save_dir: Save directory reported by run_train
        """
        if params.dataset == "syn":
            syn_versions = config["experiment"][self.exe_mode]["syn_dataset_ver"]
            params.dataset_ver_norm = syn_versions["Normal"]
            params.dataset_ver_dx = syn_versions[params.target_dx]
            params.data_lim = 10000
            params.val_data_lim = 2500

        if params.batch_size == "per_model":
            params.batch_size = \
                config["experiment"][self.exe_mode]["model_to_batchsize"][params.modelname]

        if params.finetune_target is not None:
            params = self._update_finetune_target(params)
        best_result, save_dir = run_train(params, save_loc)

        # Retrain only when it is both enabled and needed. Keeping the guard
        # here preserves the full result dict of the first run when no
        # retraining happens. "not (loss < threshold)" mirrors the stop
        # condition of the loop, so a NaN loss also triggers retraining.
        needs_rerun = (
            params.dataset == "syn"
            and hasattr(params, "rerun")
            and not best_result["loss"] < self._RERUN_LOSS_THRESHOLD
        )
        if needs_rerun:
            best_result, save_dir = self._train_loop(
                best_result["loss"], params, save_dir, save_loc)

        return best_result, save_dir

    def _train_loop(self, best_loss, params, save_dir, save_loc):
        """
        Retrain with adjusted hyperparameters until the loss is low enough.

        Each iteration halves the learning rate, updates the augmentation
        parameters through _update_aug_params, and calls run_train again.
        The loop stops as soon as the validation loss is below the rerun
        threshold (0.5) or after 5 additional runs. No retraining happens
        when params has no 'rerun' attribute or when best_loss is already
        below the threshold.

        Args:
            best_loss: Validation loss from the previous training run
            params: Parameter object; modified in place. Retraining is only
                enabled when it has a 'rerun' attribute.
            save_dir: Save directory of the previous run
            save_loc: Location passed through to run_train

        Returns:
            tuple: (best_result, save_dir) where:
                - best_result: Result mapping of the last training run. When
                  no retraining was performed only the loss is known here,
                  so this is {"loss": best_loss}.
                - save_dir: Save directory of the last run, or the one
                  passed in when no retraining was performed
        """
        # Always hand back a mapping so callers can index the result the
        # same way whether or not a retraining run took place.
        best_result = {"loss": best_loss}
        if not hasattr(params, "rerun"):
            return best_result, save_dir

        for _ in range(self._MAX_RERUNS):
            # Checked before training so that an already-good result does
            # not trigger a needless extra run.
            if best_loss < self._RERUN_LOSS_THRESHOLD:
                break

            params.learning_rate = params.learning_rate / 2

            params.aug_mask_ratio = self._update_aug_params(
                params.aug_mask_ratio, False)
            params.max_shift_ratio = self._update_aug_params(
                params.max_shift_ratio, False)
            params.flip_rate = self._update_aug_params(
                params.flip_rate, False)
            params.breathing_scale = self._update_aug_params(
                params.breathing_scale, True)
            params.scale_ratio = self._update_aug_params(
                params.scale_ratio, True)
            params.stretch_ratio = self._update_aug_params(
                params.stretch_ratio, True)

            best_result, save_dir = run_train(params, save_loc)
            best_loss = best_result["loss"]
        return best_result, save_dir

    def _update_aug_params(self, param_val, zero_to_one: bool):
        """
        Compute the next value of a data augmentation parameter.

        Args:
            param_val: Current value of the augmentation parameter
            zero_to_one: Selects the update rule:
                - True: always halve the value
                - False: double the value if it is below 1, otherwise
                  halve it

        Returns:
            float: Updated parameter value
        """
        # NOTE(review): with zero_to_one=False a value in [0.5, 1) doubles
        # to 1 or more; confirm that is acceptable for every parameter
        # passed this way (e.g. flip_rate).
        if not zero_to_one and param_val < 1:
            return param_val * 2.
        return param_val * 0.5

    def _update_finetune_target(self, params):
        """
        Resolve params.finetune_target to a concrete entry on disk.

        Looks inside <finetune_target>/multirun/train/seed<seed:04d>/ and
        picks the entry whose path sorts last. glob() itself returns
        matches in arbitrary order, so the matches are sorted to make the
        choice deterministic.

        Args:
            params: Parameter object; modified in place. Reads
                finetune_target and seed, and must support the 'in'
                operator (used to test for 'ft_target_ep').

        Returns:
            The same params object with finetune_target replaced by the
            selected path

        Raises:
            NotImplementedError: If 'ft_target_ep' is present in params
            FileNotFoundError: If the seed directory contains no entries
        """
        seed_dir = os.path.join(
            params.finetune_target,
            "multirun",
            "train",
            f"seed{params.seed:04d}",
        )
        if "ft_target_ep" in params:
            raise NotImplementedError(
                "selecting a specific epoch via ft_target_ep is not supported"
            )

        # presumably entries are named so that lexicographic order matches
        # creation order (e.g. timestamps) -- verify against the trainer.
        candidates = sorted(glob(seed_dir + "/*"))
        if not candidates:
            raise FileNotFoundError(
                f"no finetune target found under {seed_dir}"
            )

        params.finetune_target = candidates[-1]

        return params

    def _run_eval(self, eval_target, device, dump_loc, multiseed_run):
        """
        Execute evaluation on a trained model.

        Thin wrapper that forwards all arguments unchanged to run_eval from
        the codes.run_eval module.

        Args:
            eval_target: Evaluation target passed to run_eval
            device: Device passed to run_eval (the CLI default is "cuda:0")
            dump_loc: Dump location passed to run_eval
            multiseed_run: Multi-seed flag passed to run_eval

        Returns:
            Whatever run_eval returns
        """
        return run_eval(eval_target, device, dump_loc, multiseed_run)

if __name__ == "__main__":
    # Command-line entry point: build the experiment manager for the
    # requested experiment ID and run it.

    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Run a classification experiment by experiment ID."
    )

    # type=int lets argparse reject a non-integer ID with a usage error
    # instead of failing later with a raw ValueError traceback.
    parser.add_argument(
        '--exp',
        type=int,
        default=0,
        help="integer experiment ID (default: 0)",
    )
    parser.add_argument(
        '--device',
        default="cuda:0",
        help="device string passed to the experiment manager (default: cuda:0)",
    )
    parser.add_argument(
        '--debug',
        action="store_true",
        help="pass debug=True to the experiment manager",
    )
    parser.add_argument(
        '--multirun',
        action="store_true",
        help="enable multirun mode (main() is called with False instead of True)",
    )
    args = parser.parse_args()

    print(args)

    executer = ClassificationExperimentManager(
        args.exp,
        args.device,
        debug=args.debug,
    )
    # main() receives the negation of the --multirun flag.
    executer.main(not args.multirun)
