import logging
from typing import List
import hydra
import mlflow
from omegaconf import DictConfig
from haipr.utils.resolvers import register_resolvers
from haipr.utils.results_logger import ResultsLogger
from haipr.train import HAIPRTrainer
from haipr.inference import HAIPRInference
from haipr.data import HAIPRData
from haipr.optimize import HAIPROptimizer
import torch
import gc

logger = logging.getLogger(__name__)  # module-level logger used by the helpers and main()


def _parse_stages(stages) -> List[str]:
    if stages is None:
        return ["features", "train", "inference"]
    if isinstance(stages, str):
        if stages.lower() == "all":
            return ["features", "train", "inference"]
        return [s.strip() for s in stages.split("-")]
    if isinstance(stages, (list, tuple)):
        return list(stages)
    return ["features", "train", "inference"]


def _run_self_test():
    """
    Placeholder hook for the pipeline self-test ("test" stage).

    No checks are implemented yet; calling it only emits a single
    info-level log line.
    """
    logger.info("Running self-test...")


def _log_banner(message: str) -> None:
    """Log ``message`` between two separator lines to mark the start of a stage."""
    logger.info("=" * 80)
    logger.info(message)
    logger.info("=" * 80)


def _run_features_stage(cfg: DictConfig) -> None:
    """Prepare features, regenerating them only when needed.

    Features are generated when ``HAIPRData._check_features_ready()`` reports
    they are not ready, or when ``cfg.data.recompute_features`` is truthy.
    Runs without MLflow tracking.
    """
    _log_banner("Preparing features...")
    data = HAIPRData(cfg)
    if not data._check_features_ready() or cfg.data.recompute_features:
        logger.info("Generating features: %s", data.cache_key)
        logger.info(cfg.embedder)
        data.prepare_features()
    # Release the data handle and collect before the following stages run.
    del data
    gc.collect()


def _run_optimize_stage(cfg: DictConfig) -> DictConfig:
    """Run the optimizer and return the config merged with the optimized params.

    The optimized parameters are logged via ``mlflow.log_params`` (prefixed
    with ``optimized_``) to whichever MLflow run is active. That is intended
    to be the orchestration parent run.
    """
    _log_banner("Running optimization...")
    optimizer = HAIPROptimizer(cfg)
    optimizer.is_nested = True  # make optimizer runs nested under the parent run
    optimizer.optimize()
    optimization_params = optimizer.get_optimized_params()

    # Merge optimized parameters into config for subsequent stages.
    logger.info("Merging optimized parameters into config: %s", optimization_params)
    cfg = optimizer.update_trial_config(cfg, optimization_params)  # type: ignore

    mlflow.log_params({f"optimized_{k}": v for k, v in optimization_params.items()})

    del optimizer
    gc.collect()
    return cfg


def _run_train_stage(cfg: DictConfig, after_optimize: bool) -> None:
    """Train/tune models and store the trainer's parent run id on ``cfg``.

    Args:
        cfg: Pipeline config. Mutated in place: ``models_from_parent_run`` is set.
            When ``after_optimize`` is true, ``mlflow.log_models``,
            ``data.subsample_threshold`` and ``trainer.subsample_train`` are
            also overwritten.
        after_optimize: True when an optimize stage ran earlier in this pipeline.
    """
    if after_optimize:
        # Ensure models are logged when training after optimization.
        logger.info("Setting log_models to True for training after optimization...")
        cfg.mlflow.log_models = True
        # Reset the subsampling used during optimization
        # (presumably 0 disables subsampling; confirm in HAIPRData/HAIPRTrainer).
        logger.info(
            "Resetting subsampling threshold and trainer subsampling to 0 "
            "for training after optimization"
        )
        cfg.data.subsample_threshold = 0
        cfg.trainer.subsample_train = 0
    _log_banner("Running training...")
    trainer = HAIPRTrainer(cfg)
    trainer.is_nested = True  # make trainer runs nested under the parent run
    trainer.setup_data(cfg)
    trainer.tune(cfg)
    # Presumably consumed by HAIPRInference to locate the trained models; verify.
    cfg.models_from_parent_run = trainer.parent_run_id
    del trainer
    gc.collect()


def _run_inference_stage(cfg: DictConfig, orchestration_run_id: str) -> None:
    """Run inference and tag it with the orchestration parent run id."""
    _log_banner("Running inference...")
    # Inference creates its own run in _setup_design_mode; tag it with the
    # parent orchestration run_id for linking.
    inference = HAIPRInference(cfg)
    if getattr(inference, "mlflow_run", None):
        # NOTE(review): mlflow.set_tag writes to the *active* run. This only tags
        # the inference run if HAIPRInference made it active (e.g. a nested
        # start_run); otherwise it tags the orchestration run itself. Confirm.
        mlflow.set_tag("orchestration_parent_run_id", orchestration_run_id)
    inference.run()
    del inference
    gc.collect()


@hydra.main(version_base=None, config_path="conf", config_name="haipr")
def main(cfg: DictConfig):
    """Hydra entry point: run the configured HAIPR pipeline stages.

    Stages always execute in the fixed order
    test -> features -> optimize -> train -> inference, regardless of the order
    given in ``cfg.stages``.

    * ``test`` and ``features`` run without MLflow tracking.
    * ``optimize``, ``train`` and ``inference`` run inside one parent
      ("orchestration") MLflow run. ``cfg`` is threaded through them:
      optimization may replace it, and training mutates it for inference.

    Unknown stage names are reported with a warning and otherwise ignored.
    """
    register_resolvers()
    logger.info("Starting HAIPR Pipeline")
    stages = _parse_stages(cfg.get("stages"))
    logger.info("Running stages: %s", stages)
    known_stages = {"test", "features", "optimize", "train", "inference"}
    unknown_stages = [s for s in stages if s not in known_stages]
    if unknown_stages:
        logger.warning(
            "Ignoring unknown stage(s) %s; known stages are %s",
            unknown_stages,
            sorted(known_stages),
        )
    logger.info("=" * 80)
    torch.set_float32_matmul_precision(cfg.trainer.matmul_precision)

    # Untracked stages.
    if "test" in stages:
        _log_banner("Running self-test...")
        _run_self_test()

    if "features" in stages:
        _run_features_stage(cfg)

    # MLflow-tracked stages share a single parent orchestration run.
    mlflow_stages = [s for s in stages if s in ("optimize", "train", "inference")]
    if not mlflow_stages:
        logger.info("=== HAIPR Pipeline completed successfully ===")
        return

    mlflow.set_tracking_uri(cfg.mlflow.tracking_uri)
    mlflow.set_experiment(cfg.mlflow.experiment_name)

    # .get() so a struct config without `pipeline_name` falls back instead of raising.
    orchestration_run_name = cfg.get("pipeline_name") or str(cfg.benchmark.name)
    with mlflow.start_run(run_name=f"Pipeline-{orchestration_run_name}") as parent_run:
        mlflow.set_tag("orchestration", "True")
        mlflow.set_tag("stages", "-".join(mlflow_stages))
        mlflow.set_tag("benchmark", cfg.benchmark.name)

        # Log config to parent run.
        results_logger = ResultsLogger(cfg=cfg, run=parent_run)
        results_logger.log_config(cfg)
        logger.info("Created orchestration run: %s", parent_run.info.run_id)

        if "optimize" in stages:
            cfg = _run_optimize_stage(cfg)
        if "train" in stages:
            _run_train_stage(cfg, after_optimize="optimize" in stages)
        if "inference" in stages:
            _run_inference_stage(cfg, parent_run.info.run_id)

    logger.info("=== HAIPR Pipeline completed successfully ===")


if __name__ == "__main__":
    # Register the custom resolvers before calling main(). The original intent
    # is that they already exist by the time Hydra composes the config for
    # main(). main() calls register_resolvers() again.
    # NOTE(review): this assumes register_resolvers() is safe to call twice; confirm.
    register_resolvers()
    main()
