import logging
import os
from pathlib import Path
from typing import List
import hydra
from omegaconf import DictConfig, OmegaConf, OmegaConf as OC
from haipr.utils.resolvers import register_resolvers
from haipr.train import HAIPRTrainer
from haipr.inference import HAIPRInference
from haipr.data import HAIPRData
from haipr.optimize import HAIPROptimizer
import gc

logger = logging.getLogger(__name__)


def _parse_stages(stages) -> List[str]:
    """Normalize the ``stages`` config value into a list of stage names.

    Args:
        stages: ``None``, a string, or a sequence of stage names. The string
            ``"all"`` (case-insensitive) selects the default stages; any
            other string is split on ``"-"`` and each part is stripped
            (e.g. ``"features-train"`` -> ``["features", "train"]``).

    Returns:
        A new list of stage names. ``None``, ``"all"`` and unsupported types
        yield the default ``["features", "train", "inference"]``.
    """
    # Local import so the block does not depend on a new file-level import.
    from collections.abc import Sequence

    if stages is None:
        return ["features", "train", "inference"]
    if isinstance(stages, str):
        if stages.lower() == "all":
            return ["features", "train", "inference"]
        return [s.strip() for s in stages.split("-")]
    # Accept any sequence, not only list/tuple: a list coming from an
    # OmegaConf config is a ListConfig, which is a sequence but not a ``list``
    # subclass, and would otherwise silently fall back to the defaults.
    # Bytes-like values are sequences of ints, never stage names.
    if isinstance(stages, Sequence) and not isinstance(stages, (bytes, bytearray)):
        return list(stages)
    return ["features", "train", "inference"]

def _run_self_test() -> None:
    """Placeholder self-test hook.

    Currently it only logs that the self-test is running; no checks are
    performed yet.
    """
    logger.info("Running self-test...")
    return None

def _log_stage_banner(title: str) -> None:
    """Log ``title`` framed by two 80-character separator lines."""
    separator = "=" * 80
    logger.info(separator)
    logger.info(title)
    logger.info(separator)


@hydra.main(version_base=None, config_path="conf", config_name="haipr")
def main(cfg: DictConfig):
    """Run the selected HAIPR pipeline stages in a fixed order.

    The stage list is read from ``cfg.stages`` (via ``_parse_stages``) and the
    requested stages run in this order: ``test``, ``features``, ``optimize``,
    ``train``, ``inference``. Stages not listed are skipped.

    Args:
        cfg: Hydra/OmegaConf configuration. It is mutated in place:
            ``cfg.mlflow.log_models`` is forced to ``True`` when inference is
            requested, ``cfg.data.subsample_threshold`` is set to ``0`` before
            training, and ``cfg.models_from_parent_run`` is set to the
            trainer's ``parent_run_id`` after training.
    """
    register_resolvers()
    logger.info("Starting HAIPR Pipeline")
    stages = _parse_stages(cfg.get("stages"))
    if "inference" in stages:
        # Set before any stage runs so that a training stage in this same
        # invocation already sees log_models=True.
        logger.warning("Setting log_models to True for inference...")
        cfg.mlflow.log_models = True
    logger.info("Running stages: %s", stages)
    logger.info("=" * 80)

    if "test" in stages:
        _log_stage_banner("Running self-test...")
        _run_self_test()

    if "features" in stages:
        _log_stage_banner("Preparing features...")
        data = HAIPRData(cfg)
        # Only generate features when the data object reports they are not
        # ready yet.
        if not data._check_features_ready():
            logger.info("Generating features: %s", data.cache_key)
            logger.info(cfg.embedder)
            data.prepare_features()
        # Drop the data object before the next stage allocates its own.
        del data
        gc.collect()

    if "optimize" in stages:
        _log_stage_banner("Running optimization...")
        optimizer = HAIPROptimizer(cfg)
        optimizer.optimize()
        # The optimized parameters are not consumed by the later stages in
        # this function; log them so the result is not silently discarded.
        logger.info("Optimized parameters: %s", optimizer.get_optimized_params())
        del optimizer
        gc.collect()

    if "train" in stages:
        # NOTE(review): the threshold is forced to 0 for training; the
        # rationale is not visible here -- presumably this disables
        # subsampling. Confirm against the data module.
        cfg.data.subsample_threshold = 0
        _log_stage_banner("Running training...")
        trainer = HAIPRTrainer(cfg)
        trainer.setup_data(cfg)
        trainer.tune(cfg)
        # Record the parent run id of this training run in the config before
        # the inference stage is constructed from it.
        cfg.models_from_parent_run = trainer.parent_run_id
        del trainer
        gc.collect()

    if "inference" in stages:
        _log_stage_banner("Running inference...")
        inference = HAIPRInference(cfg)
        inference.run()

    logger.info("=== HAIPR Pipeline completed successfully ===")


if __name__ == "__main__":
    # Register the resolvers up front so that they are already registered by
    # the time main() gets called.
    register_resolvers()
    main()
