import os
import hydra
import pandas as pd
import mlflow
from mlflow.tracking import MlflowClient
import logging
from datetime import datetime
from typing import Dict, Any, Tuple
from omegaconf import DictConfig
import torch
from haipr.data import HAIPRData
from haipr.sequence_generators.base_generator import BaseSequenceGenerator
from haipr.utils import AA_ALPHABET, AA_ALPHABET_INFERENCE, AA_ALPHABET_WITH_EXTRAS, load_sequences
from haipr.evaluators import EvaluatorFactory
from haipr.utils.results_logger import ResultsLogger
from typing import List
from typing import Optional
import numpy as np
from haipr.embedding_manager import EmbeddingManager
from haipr.models.esmc import ESMCPredictor
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from haipr.utils.resolvers import register_resolvers

# Module-level logger shared by everything in this file. The name is fixed to
# "haipr.inference" (not derived from __name__), so logging configuration must
# target that exact logger name.
logger = logging.getLogger("haipr.inference")


class HAIPRInference:

    def __init__(self, cfg: DictConfig):
        """
        Prepare an inference session for the mode named by ``cfg.inference_mode``.

        Mode-specific state (dataset, representative sequence, per-position
        alphabet) is set up first. The components that the scoring and logging
        methods of this class read regardless of mode (MLflow experiment and
        run, results logger, embedding manager, per-generation bookkeeping)
        are then created for every mode.

        Args:
            cfg: Run configuration. It must match the configuration used for
                training (e.g. focus on/off).

        Raises:
            ValueError: If ``cfg.inference_mode`` is not ``"design"``,
                ``"score_sequences"`` or ``"random_search"``.
        """
        self.cfg = cfg

        # initialize stages and evaluators
        self.stages: List[str] = []
        self.evaluators: Dict[str, Any] = {}
        self.batch_size = int(getattr(self.cfg, "batch_size", 1024))

        # NOTE: important that cfg is same as used for training, e.g focus on/off etc.
        mode = self.cfg.inference_mode
        if mode == "design":
            self._setup_design_mode()
        elif mode == "score_sequences":
            self._setup_score_sequences_mode()
        elif mode == "random_search":
            self._setup_random_search_mode()
        else:
            raise ValueError(f"Invalid mode: {self.cfg.inference_mode}")

        # Shared components must exist in *every* mode: load_predictors reads
        # self.experiment, log_stage_outputs reads self.results_logger, and
        # _log_final_results reads self.mlflow_run. Creating them only in
        # design mode left the other modes failing with AttributeError.
        self._setup_shared_components()

    def _setup_design_mode(self):
        """
        Setup design mode: load the dataset and derive the per-position alphabet.
        """
        self._load_data_and_alphabet()

    def _setup_score_sequences_mode(self):
        """
        Setup score sequences mode.

        No dataset is required; sequences are provided at call-time, so the
        dataset-derived attributes are explicitly set to ``None``.
        """
        self.data = None
        self.representative = None
        self.aa_per_position = None

    def _setup_random_search_mode(self):
        """
        Setup random search mode.

        Random search needs the alphabet per position, so it uses the dataset
        exactly like design mode.
        """
        self._load_data_and_alphabet()

    def _load_data_and_alphabet(self):
        """
        Load the dataset and set ``data``, ``representative`` and ``aa_per_position``.

        If ``cfg.alphabet_per_position`` is truthy it is stored as-is. Otherwise
        every position listed in ``data.mut_pos`` may take any residue of
        ``AA_ALPHABET`` and all remaining positions are pinned to the residue
        of the representative sequence.
        """
        self.data = HAIPRData(self.cfg)
        self.representative = self.data.get_representative()

        if self.cfg.alphabet_per_position:
            self.aa_per_position = self.cfg.alphabet_per_position
            return

        self.aa_per_position = [
            list(AA_ALPHABET) if i in self.data.mut_pos else [residue]
            for i, residue in enumerate(self.representative)
        ]

    def _setup_shared_components(self):
        """
        Create the MLflow run, results logger, embedding manager and bookkeeping.

        The experiment is looked up by ``cfg.mlflow.experiment_id`` when that is
        set, otherwise by ``cfg.mlflow.experiment_name``. When the name lookup
        succeeds the resolved ID is written back to ``cfg.mlflow.experiment_id``;
        when it finds nothing, ``self.experiment`` is ``None`` and the run is
        started without an explicit experiment ID.
        """
        mlflow.set_tracking_uri(self.cfg.mlflow.tracking_uri)

        configured_id = self.cfg.mlflow.experiment_id
        if configured_id is not None:
            # The ID may be configured as an int; pass it on as a string.
            self.experiment = mlflow.get_experiment(str(configured_id))
        else:
            experiments = mlflow.search_experiments(
                filter_string=f"name='{self.cfg.mlflow.experiment_name}'"
            )
            if len(experiments) == 0:
                logger.warning(
                    f"No experiment found with name {self.cfg.mlflow.experiment_name}"
                )
                self.experiment = None
            else:
                # Pick the first matching experiment by name
                self.experiment = experiments[0]
                self.cfg.mlflow.experiment_id = self.experiment.experiment_id

        # Always create an MLflow run and results logger, whatever the mode
        self.mlflow_run = mlflow.start_run(
            experiment_id=(
                self.experiment.experiment_id if self.experiment else None),
            run_name=self.cfg.run_name,
        )
        self.results_logger = ResultsLogger(cfg=self.cfg, run=self.mlflow_run)
        self.results_logger.log_config(self.cfg)

        # The embedding manager is only built when an embedder is configured;
        # it is handed to the evaluator factory in create_evaluators.
        self.embedding_manager = None
        if getattr(self.cfg, "embedder", None) is not None:
            self.embedding_manager = EmbeddingManager(self.cfg)

        # Track stage outputs for detailed logging
        self.generation_stage_outputs: List[Dict[str, np.ndarray]] = []

        # Track retry count for naming new runs
        self.retry_count = 0

    def _should_parallelize_evaluators(self) -> bool:
        """
        Determine if evaluators should be parallelized.
        PyTorch evaluators always run sequentially.
        """
        # Don't parallelize if DDP is enabled or only one evaluator
        if getattr(self.cfg, "ddp", False) or len(self.evaluators) <= 1:
            return False

        # Don't parallelize if any evaluator is PyTorch-based
        return not any(evaluator.model_type == "pytorch" for evaluator in self.evaluators.values())

    def _evaluate_single_evaluator(
        self, evaluator_key: str, evaluator: Any, sequences: List[str], precomputed_embeddings: np.ndarray = None
    ) -> Tuple[str, np.ndarray]:
        """Evaluate a single evaluator on sequences."""
        try:
            if evaluator.model_type == "sklearn" and precomputed_embeddings is not None:
                # Use pre-computed embeddings for sklearn evaluators
                scores = evaluator.predict_with_embeddings(
                    precomputed_embeddings)
            else:
                # Use regular predict for PyTorch evaluators or if embeddings failed
                scores = evaluator.predict(
                    sequences, batch_size=self.batch_size)
            return evaluator_key, scores
        except Exception as e:
            logger.error(f"Error in evaluator {evaluator_key}: {e}")
            return evaluator_key, np.zeros(len(sequences))

    def _parallel_evaluate_evaluators(
        self, sequences: List[str], precomputed_embeddings: np.ndarray = None
    ) -> Dict[str, np.ndarray]:
        """Evaluate non-PyTorch evaluators in parallel, PyTorch evaluators sequentially."""
        results = {}

        # Run PyTorch evaluators sequentially
        for key, evaluator in self.evaluators.items():
            if evaluator.model_type == "pytorch":
                results[key] = evaluator.predict(
                    sequences, batch_size=self.batch_size)

        # Run non-PyTorch evaluators in parallel
        non_pytorch = {k: v for k, v in self.evaluators.items()
                       if v.model_type != "pytorch"}
        if non_pytorch:
            with ThreadPoolExecutor(max_workers=min(len(non_pytorch), multiprocessing.cpu_count())) as executor:
                futures = {executor.submit(self._evaluate_single_evaluator, k, v, sequences, precomputed_embeddings): k
                           for k, v in non_pytorch.items()}
                for future in futures:
                    try:
                        k, scores = future.result()
                        results[k] = scores
                    except Exception as e:
                        logger.error(
                            f"Failed evaluator {futures[future]}: {e}")
                        results[futures[future]] = np.zeros(len(sequences))

        return results

    def log_stage_outputs(self, stage_outputs: Dict[str, np.ndarray]):
        """
        Log aggregate statistics for each stage to MLflow.
        This provides insights into how each evaluator stage is performing.
        """
        if not self.results_logger:
            return

        stage_metrics = {}
        for stage_name, scores in stage_outputs.items():
            scores_array = np.array(scores)
            # Log aggregate statistics for each stage
            stage_metrics.update(
                {
                    f"batch_{stage_name}_min": np.min(scores_array),
                    f"batch_{stage_name}_max": np.max(scores_array),
                    f"batch_{stage_name}_mean": np.mean(scores_array),
                    f"batch_{stage_name}_std": np.std(scores_array),
                    f"batch_{stage_name}_median": np.median(scores_array),
                }
            )

        # Store stage outputs for generation-level aggregation
        self.generation_stage_outputs.append(stage_outputs)

        # Log the stage metrics
        self.results_logger.log_metrics(stage_metrics)

        # Log raw stage outputs for each generation
        if hasattr(self, 'generator') and hasattr(self.generator, 'generation_count'):
            generation_num = self.generator.generation_count
            for stage_name, scores in stage_outputs.items():
                # Convert numpy array to list for JSON serialization
                scores_list = scores.tolist() if hasattr(scores, 'tolist') else scores
                mlflow.log_dict(
                    {stage_name: scores_list},
                    f"generation_{generation_num}_stage_outputs.json"
                )
        else:
            # Fallback for cases where generator is not available yet
            for stage_name, scores in stage_outputs.items():
                # Convert numpy array to list for JSON serialization
                scores_list = scores.tolist() if hasattr(scores, 'tolist') else scores
                mlflow.log_dict(
                    {stage_name: scores_list},
                    f"pre_generation_stage_outputs.json"
                )

    def load_predictors(self, model_ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Load predictors from MLflow using the ``search_logged_models`` API.

        Logged models of the configured experiment are narrowed down by:
        1. the ``benchmark`` tag (always, taken from ``cfg.benchmark.name``),
        2. specific model IDs (if ``model_ids`` is provided),
        3. the finished runs tagged with ``cfg.models_from_parent_run``
           (if that option is configured).

        Models whose source run is tagged as a parent or inference run, whose
        ``model_type`` tag is neither ``sklearn`` nor ``pytorch``, or which
        fail to load are skipped with a log message. For every loaded model the
        derived stage name is appended to ``self.stages``.

        Args:
            model_ids: Optional list of specific model IDs to load

        Returns:
            Dict[str, Any]: Maps ``"<stage>_<model_type>_<run_id[:8]>"`` to a
            dict holding the loaded model and its metadata.

        Raises:
            ValueError: If the experiment is unknown, a search fails, no run or
                logged model matches, or no predictor could be loaded.
        """
        experiment_name = self.cfg.mlflow.experiment_name
        logger.info(f"Loading predictors from experiment: {experiment_name}")
        if self.experiment is None:
            raise ValueError(f"Experiment {experiment_name} not found")

        # Determine which runs to search for models
        parent_run_id = getattr(self.cfg, "models_from_parent_run", None)
        if parent_run_id:
            target_run_ids = self._find_child_run_ids(parent_run_id)
        else:
            target_run_ids = None
            logger.info(
                f"No models_from_parent_run specified, using all models from experiment {experiment_name}")

        logged_models = self._search_logged_models(model_ids)

        # Filter logged models by target run IDs if specified
        if target_run_ids is not None:
            logged_models = [
                model for model in logged_models
                if model.source_run_id in target_run_ids
            ]
            logger.info(
                f"Filtered to {len(logged_models)} models from target runs")

        predictors = {}
        for logged_model in logged_models:
            # Best effort: one broken model must not prevent loading the rest.
            try:
                entry = self._load_predictor_entry(logged_model)
                if entry is None:
                    continue

                predictor_key = (
                    f"{entry['stage_name']}_{entry['model_type']}_"
                    f"{logged_model.source_run_id[:8]}"
                )
                if predictor_key in predictors:
                    # Same stage, model type and run prefix: the later model wins.
                    logger.warning(
                        f"Predictor key {predictor_key} already loaded; "
                        f"replacing it with model_id {logged_model.model_id}")
                predictors[predictor_key] = entry
                self.stages.append(entry["stage_name"])

                logger.info(
                    f"Loaded {entry['model_type']} model {logged_model.name} from run {logged_model.source_run_id[:8]} as {predictor_key}"
                )
            except Exception as e:
                logger.warning(
                    f"Failed to process logged model {logged_model.model_id}: {e}")
                continue

        logger.info(f"Successfully loaded {len(predictors)} predictors")
        if len(predictors) == 0:
            raise ValueError(
                f"No predictors loaded from experiment {experiment_name}")
        return predictors

    def _find_child_run_ids(self, parent_run_id: str) -> set[str]:
        """Return the IDs of finished runs tagged with ``parent_run_id`` for the configured benchmark."""
        logger.info(f"Filtering models by parent_run_id: {parent_run_id}")

        run_filter_string = " AND ".join(
            [
                "status = 'FINISHED'",
                f"tags.benchmark = '{self.cfg.benchmark.name}'",
                f"tags.parent_run_id = '{parent_run_id}'",
            ]
        )
        logger.info(f"Searching runs with filter: {run_filter_string}")

        # Only the MLflow call sits inside the try, so the "no runs" error
        # below is not caught and re-wrapped by this handler.
        try:
            runs_df = mlflow.search_runs(
                experiment_ids=[self.experiment.experiment_id],
                filter_string=run_filter_string,
                output_format="pandas"
            )
        except Exception as e:
            logger.error(f"Failed to search runs: {e}")
            raise ValueError(
                f"No runs found matching parent_run_id criteria: {e}") from e

        if len(runs_df) == 0:
            raise ValueError(
                f"No finished runs found in experiment {self.cfg.mlflow.experiment_name} "
                f"matching parent_run_id: {parent_run_id}"
            )

        # A set gives constant-time membership tests when filtering models.
        run_ids = set(runs_df['run_id'].tolist())
        logger.info(
            f"Found {len(run_ids)} runs matching parent_run_id criteria")
        return run_ids

    def _search_logged_models(self, model_ids: Optional[List[str]] = None) -> List[Any]:
        """Return the experiment's logged models matching the benchmark tag and optional model IDs."""
        # The benchmark filter always applies
        model_filter_conditions = [
            f"tags.benchmark = '{self.cfg.benchmark.name}'"]

        # Add model ID filter if specific models requested
        if model_ids:
            # Convert model_ids to quoted strings for IN clause
            quoted_ids = ",".join(f"'{model_id}'" for model_id in model_ids)
            model_filter_conditions.append(f"model_id IN ({quoted_ids})")
            logger.info(f"Filtering models by specific IDs: {model_ids}")

        model_filter_string = " AND ".join(model_filter_conditions)
        logger.info(
            f"Searching logged models with filter: {model_filter_string}")

        try:
            logged_models = mlflow.search_logged_models(
                experiment_ids=[self.experiment.experiment_id],
                filter_string=model_filter_string,
                output_format="list"
            )
        except Exception as e:
            logger.error(f"Failed to search logged models: {e}")
            raise ValueError(
                f"No logged models found matching criteria: {e}") from e

        if len(logged_models) == 0:
            raise ValueError(
                f"No logged models found in experiment {self.cfg.mlflow.experiment_name} "
                f"matching filter: {model_filter_string}"
            )

        logger.info(f"Found {len(logged_models)} logged models")
        return logged_models

    def _load_predictor_entry(self, logged_model: Any) -> Optional[Dict[str, Any]]:
        """Load one logged model and return its predictor record, or ``None`` if it is skipped."""
        run_id = logged_model.source_run_id

        # Get run details to access tags and parameters
        run_info = mlflow.get_run(run_id)
        tags = run_info.data.tags

        # Skip parent and inference runs
        if tags.get("is_parent") == "True":
            logger.info(f"Skipping parent run {run_id}")
            return None
        if tags.get("is_inference") == "True":
            logger.info(f"Skipping inference run {run_id}")
            return None

        # Extract model type and task from tags
        model_type = tags.get("model_type", "unknown")
        task_type = tags.get("task", "regression")

        # Classification models act as filters, regression models as scorers
        if task_type == "classification":
            stage_name = "filter"
        elif task_type == "regression":
            stage_name = "score"
        else:
            stage_name = "unknown"

        try:
            model_uri = f"models:/{logged_model.model_id}"
            logger.info(
                f"Loading model {logged_model.name} from run {run_id} using model_id {logged_model.model_id}")

            # Flavours are resolved lazily so that only the one actually
            # needed is touched.
            if model_type == "sklearn":
                model = mlflow.sklearn.load_model(model_uri)
            elif model_type == "pytorch":
                model = mlflow.pytorch.load_model(model_uri)
            else:
                logger.warning(
                    f"Unknown model type {model_type}, skipping")
                return None
        except Exception as e:
            logger.warning(
                f"Failed to load model {logged_model.name}: {e}")
            return None

        return {
            "name": logged_model.name,
            "model": model,
            "run_id": run_id,
            "model_type": model_type,
            "task_type": task_type,
            "stage_name": stage_name,
            "run_info": run_info,
            "model_id": logged_model.model_id,
        }

    def create_evaluators(
        self,
        predictors: Dict[str, Any],
        default_filters: Optional[List[str]] = None,
    ) -> None:
        """
        Create evaluators from loaded predictors and optional physics configurations.

        The evaluators are stored in ``self.evaluators`` and
        ``self.has_sklearn_evaluator`` records whether any of them is an
        sklearn model. When ``cfg.filter_by_perplexity`` is set and a dataset
        is loaded, an additional ``"filter_esmc_prob"`` evaluator is added whose
        threshold is the minimum or median ESMC perplexity over the dataset
        sequences.

        Args:
            predictors: Dictionary of loaded predictors from MLflow
            default_filters: Currently unused; kept so existing callers that
                pass it keep working.

        Raises:
            ValueError: If ``cfg.filter_by_perplexity`` is set to something
                other than ``"min"`` or ``"median"``.
        """
        # Get physics evaluator configurations from config if available
        physics_configs = getattr(self.cfg, "physics_evaluators", None)

        # Create evaluators using the factory, passing config and embedding manager
        self.evaluators = EvaluatorFactory.create_evaluators_from_predictors(
            predictors=predictors,
            physics_configs=physics_configs,
            cfg=self.cfg,
            embedding_manager=getattr(self, "embedding_manager", None),
        )
        self.has_sklearn_evaluator = any(
            evaluator.model_type == "sklearn" for evaluator in self.evaluators.values())

        perplexity_mode = self.cfg.filter_by_perplexity
        data = getattr(self, "data", None)
        if perplexity_mode and data is None:
            # The threshold is derived from the dataset sequences, so without
            # a dataset the filter cannot be built.
            logger.warning(
                "filter_by_perplexity is set but no dataset is loaded; "
                "skipping ESMC perplexity filter")
        elif perplexity_mode:
            # Validate the option before the ESMC model is instantiated so a
            # typo fails fast.
            threshold_funcs = {"min": np.min, "median": np.median}
            if perplexity_mode not in threshold_funcs:
                raise ValueError(
                    f"Unknown filter_by_perplexity: {self.cfg.filter_by_perplexity}"
                )
            threshold_func = threshold_funcs[perplexity_mode]

            # threshold comes from the training data
            # since we know that this definitely expresses and is correlated with fitness
            esmc_model = ESMCPredictor(model_name="esmc_300m", num_classes=0)
            esmc_model.model.eval()
            logger.info("Computing perplexity threshold for ESMC model")

            with torch.no_grad():
                perplexities = esmc_model.predict(
                    sequences=data.get_sequences(),
                    batch_size=self.batch_size,
                    perplexities=True,
                )["perplexities"]
            prob_threshold = threshold_func(perplexities)

            logger.info(
                f"Perplexity threshold for ESMC model: {prob_threshold:.6f}")
            logger.info("Creating ESMC filter evaluator")
            self.evaluators["filter_esmc_prob"] = EvaluatorFactory.create_ml_evaluator(
                name="filter_esmc_prob",
                model=esmc_model,
                model_type="pytorch",
                task_type="filter",
                is_seq_prob_evaluator=True,
                prob_threshold=prob_threshold,
            )

        logger.info(f"Created {len(self.evaluators)} evaluators")

    def create_generator(self) -> BaseSequenceGenerator:
        """
        Create the sequence generator and wire it to this inference session.

        The generator is instantiated from ``cfg.generator``, given the
        dataset, the per-position alphabet and ``self.batch_score`` as fitness
        callback, and hooked up with a metrics logger and the MLflow
        restart callback.

        Returns:
            The configured generator (also stored as ``self.generator``).

        Raises:
            ValueError: If ``cfg.alphabet_per_position`` is not ``"train"``,
                ``"full_per_mut_pos"`` or ``"de-novo"``.
        """
        aa_per_position = self._build_alphabet_per_position()
        self.generator = hydra.utils.instantiate(self.cfg.generator)

        # Pass batch_score method as the fitness callback
        self.generator.setup_generator(
            self.data, aa_per_position, self.batch_score)

        # Set up generation-level metrics logging for the generator
        def generation_metrics_logger(metrics, step=None):
            # Fold the per-stage statistics collected during this generation
            # into the generator's metrics, then reset the collection.
            if self.generation_stage_outputs:
                metrics.update(self._compute_generation_stage_metrics())
                self.generation_stage_outputs = []

            self.results_logger.log_metrics(
                self._to_loggable_metrics(metrics), step=step)

        self.generator.set_metrics_logger(generation_metrics_logger)

        # Set the new run callback for trajectory restarts
        self.generator.set_new_run_callback(self.create_new_mlflow_run)

        return self.generator

    def _build_alphabet_per_position(self) -> List[List[str]]:
        """Return, for every position of the representative, the list of residues allowed there."""
        mode = self.cfg.alphabet_per_position
        representative = self.data.representative

        if mode == "train":
            # Use the alphabet stored per position on the dataset
            alphabets = self.data.pos_wise_alphabet
            return [
                (
                    list(alphabets[i])
                    if i < len(alphabets)
                    # positions without a stored alphabet get the full one
                    else list(AA_ALPHABET_INFERENCE)
                )
                for i in range(len(representative))
            ]

        if mode == "full_per_mut_pos":
            # Positions with more than one stored residue are opened up to the
            # full inference alphabet; the others keep their stored alphabet.
            alphabets = self.data.pos_wise_alphabet
            return [
                (
                    list(AA_ALPHABET_INFERENCE)
                    if len(alphabets[i]) > 1
                    else list(alphabets[i])
                )
                for i in range(len(representative))
            ]

        if mode == "de-novo":
            sep_token = self.cfg.data.separator_token
            mut_pos = self.data.mut_pos

            # Start from the representative with every position fixed. Each
            # entry is a one-element list so the result has the same
            # list-of-lists shape as the other modes.
            aa_per_position = [[residue] for residue in representative]

            # Split the representative into chains at the separator tokens;
            # boundaries are (start, end) with end exclusive.
            chain_boundaries = []
            chain_start = 0
            for idx, residue in enumerate(representative):
                if residue == sep_token:
                    chain_boundaries.append((chain_start, idx))
                    chain_start = idx + 1
            chain_boundaries.append((chain_start, len(representative)))

            # A chain containing at least one mutated position is opened up
            # completely (separator tokens excluded).
            for start, end in chain_boundaries:
                chain_positions = [
                    pos for pos in range(start, end)
                    if representative[pos] != sep_token
                ]
                if any(pos in mut_pos for pos in chain_positions):
                    for pos in chain_positions:
                        aa_per_position[pos] = list(AA_ALPHABET_INFERENCE)
            return aa_per_position

        raise ValueError(
            f"Unknown alphabet_per_position: {self.cfg.alphabet_per_position}")

    @staticmethod
    def _to_loggable_metrics(metrics: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only metric values that can be logged: scalars, numpy scalars and number lists (as strings)."""
        loggable = {}
        for key, value in metrics.items():
            # Python scalars
            if isinstance(value, (int, float, str, bool)):
                loggable[key] = value
            # Objects exposing item() and dtype (e.g. numpy scalars)
            elif hasattr(value, 'item') and hasattr(value, 'dtype'):
                try:
                    loggable[key] = value.item()
                except (ValueError, AttributeError):
                    logger.debug(
                        f"Could not convert numpy scalar {key}: {type(value)}")
            elif isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], (int, float)):
                # Convert lists of numbers to comma-separated strings
                loggable[f"{key}_list"] = ",".join(map(str, value))
            else:
                logger.debug(
                    f"Skipping non-scalar metric {key}: {type(value)}")
        return loggable

    def _log_final_results(
        self,
        all_sequences: List[str],
        all_fitnesses: List[float],
    ):
        """
        Log the scored sequences to MLflow as a FASTA artifact.

        Sequences are ranked by fitness (best first) and written with headers
        of the form ``>fitness_<score>_rank_<rank>``. The file is named
        ``generated_sequences_<run_id>.fasta`` and stored under the
        ``results`` artifact directory of the active MLflow run.

        Args:
            all_sequences: Sequences to write.
            all_fitnesses: Fitness of each sequence, aligned with
                ``all_sequences``.
        """
        import tempfile  # only needed here; keeps the module imports untouched

        # Combine sequences and fitness, sort by fitness (best first)
        ranked = sorted(
            zip(all_sequences, all_fitnesses),
            key=lambda pair: pair[1],
            reverse=True,
        )

        fasta_lines = []
        for rank, (sequence, fitness) in enumerate(ranked, start=1):
            # Use fitness score as sequence ID
            fasta_lines.append(f">fitness_{fitness:.6f}_rank_{rank}")
            fasta_lines.append(sequence)

        file_name = f"generated_sequences_{self.mlflow_run.info.run_id}.fasta"

        # Write into a throw-away directory instead of a hard-coded /tmp path:
        # portable across platforms, and the file is removed once logged.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, file_name)
            with open(temp_path, "w", encoding="utf-8") as f:
                f.write("\n".join(fasta_lines))

            # Log as text artifact (must happen before the directory is removed)
            mlflow.log_artifact(temp_path, "results")

    def _count_mutations(self, sequence: str) -> int:
        """Count mutations relative to the representative sequence."""
        if len(sequence) != len(self.representative):
            return len(sequence)  # Can't compare if lengths differ

        mutations = sum(
            1 for i, (a, b) in enumerate(zip(sequence, self.representative)) if a != b
        )
        return mutations

    def _compute_generation_stage_metrics(self) -> Dict[str, float]:
        """Compute aggregate metrics across all stage outputs for the current generation."""
        if not self.generation_stage_outputs:
            return {}

        generation_metrics = {}

        # Collect all stage names
        all_stage_names: set[str] = set()
        for stage_output in self.generation_stage_outputs:
            all_stage_names.update(stage_output.keys())

        # For each stage, aggregate across all batches in this generation
        for stage_name in all_stage_names:
            all_scores: list[float] = []
            for stage_output in self.generation_stage_outputs:
                if stage_name in stage_output:
                    all_scores.extend(stage_output[stage_name])

            if all_scores:
                scores_array = np.array(all_scores)
                generation_metrics.update(
                    {
                        f"generation_{stage_name}_min": np.min(scores_array),
                        f"generation_{stage_name}_max": np.max(scores_array),
                        f"generation_{stage_name}_mean": np.mean(scores_array),
                        f"generation_{stage_name}_std": np.std(scores_array),
                        f"generation_{stage_name}_median": np.median(scores_array),
                    }
                )

        return generation_metrics

    def create_new_mlflow_run(self):
        """
        Replace the current MLflow run with a fresh one for a trajectory restart.

        Increments ``self.retry_count``, ends the run held by this instance (if
        any), starts a run named ``<cfg.run_name>_retry_<retry_count>``, tags it
        as an inference/restart run, and rebuilds ``self.results_logger`` for
        the new run before logging the config and the retry information.
        """
        self.retry_count += 1
        attempt = self.retry_count

        # Close whatever run this instance currently holds
        if getattr(self, 'mlflow_run', None) is not None:
            mlflow.end_run()
            logger.info("Ended previous MLflow run")

        # The retry number is carried in the run name
        restart_name = f"{self.cfg.run_name}_retry_{attempt}"
        logger.info(f"Creating new MLflow run: {restart_name}")

        if self.experiment:
            experiment_id = self.experiment.experiment_id
        else:
            experiment_id = None
        self.mlflow_run = mlflow.start_run(
            experiment_id=experiment_id,
            run_name=restart_name,
        )

        # Mark the run as an inference run that originated from a restart
        restart_tags = (
            ("is_inference", "True"),
            ("trajectory_restart", "True"),
            ("retry_count", str(attempt)),
        )
        for tag_name, tag_value in restart_tags:
            mlflow.set_tag(tag_name, tag_value)

        # The results logger is bound to a run, so it has to be rebuilt
        self.results_logger = ResultsLogger(cfg=self.cfg, run=self.mlflow_run)
        self.results_logger.log_config(self.cfg)

        # Log retry information
        retry_info = {
            "retry_count": attempt,
            "trajectory_restart": True
        }
        self.results_logger.log_metrics(retry_info)

        logger.info(
            f"Successfully created new MLflow run: {self.mlflow_run.info.run_id}")

    def build_fitness_from_stage_outputs(self, stage_outputs):
        """
        Compute final scores from stage outputs.

        Sequences that pass every filter stage receive the mean of all score
        stage outputs; sequences rejected by any filter receive ``-inf``.
        Filter stages are the entries of ``stage_outputs`` whose key starts
        with "filter"; score stages are the keys of ``self.evaluators`` that
        start with "score".

        Args:
            stage_outputs: Mapping from stage key to that stage's per-sequence
                output (array-like). All outputs are expected to have the same
                length; multi-dimensional score outputs contribute their first
                column only.

        Returns:
            1-D float array with one score per sequence. When no score stage
            is configured, the output of the last stage is returned unchanged
            (the filter mask is not applied to it).

        Raises:
            ValueError: If ``stage_outputs`` is empty.
            KeyError: If a score stage of ``self.evaluators`` has no entry in
                ``stage_outputs``.
        """
        if not stage_outputs:
            # Fail with a clear message instead of leaking a bare StopIteration
            # from next(iter(...)) below.
            raise ValueError(
                "stage_outputs is empty; cannot compute fitness without any stage output"
            )

        # Get number of sequences from first stage output
        num_sequences = len(next(iter(stage_outputs.values())))

        # All sequences pass initially, which also covers the no-filter case
        sequence_mask = np.ones(num_sequences, dtype=bool)

        # Apply all filters; np.asarray lets evaluators return plain lists too
        for stage, scores in stage_outputs.items():
            if stage.startswith("filter"):
                sequence_mask &= np.asarray(scores).astype(bool)
        num_passed = int(np.sum(sequence_mask))

        if self.results_logger is not None:
            step_value = getattr(
                getattr(self, "generator", None), "generation_count", 0)
            self.results_logger.log_metrics(
                {
                    # NOTE: despite its name this is the number of sequences
                    # that PASSED all filters.
                    "filtered_sequences": num_passed,
                    "total_sequences": num_sequences,
                },
                step=step_value,
            )

        # collect all score stages
        score_stages = [
            stage for stage in self.evaluators if stage.startswith("score")
        ]
        if not score_stages:
            logger.warning("no score stages found, using last stage")
            # get last stage output
            _, last_stage_output = list(stage_outputs.items())[-1]
            logger.debug(f"last stage: {last_stage_output}")
            return last_stage_output

        # stack all score stage outputs, ensuring 1d shape for each
        score_arrays = []
        for stage in score_stages:
            scores = np.asarray(stage_outputs[stage])
            if scores.ndim > 1:
                scores = scores[:, 0]
            score_arrays.append(scores)
        # compute mean across all score stages
        mean_scores = np.mean(np.stack(score_arrays, axis=0), axis=0)

        # initialize final scores with -inf, ensuring all sequences have a score
        final_scores = np.full(num_sequences, -np.inf, dtype=float)
        # set mean scores for sequences that passed all filters
        final_scores[sequence_mask] = mean_scores[sequence_mask]

        if num_sequences > 0:
            logger.info(
                f"Final scores range [{np.min(final_scores):.3f}, {np.max(final_scores):.3f}], "
                f"{num_passed}/{len(sequence_mask)} sequences passed filters"
            )
        else:
            # np.min/np.max raise on zero-size arrays, so skip the range here
            logger.debug("No sequences in stage outputs; returning empty scores")

        return final_scores

    def score_sequences(self, sequences: List[str] | str | None = None):
        """Score all provided sequences and log the results.

        Args:
            sequences: Passed straight to ``load_sequences``; its result is
                what gets scored.

        The engine is shut down when scoring finishes, including when scoring
        or logging raises, so that components are not left running after a
        failure.
        """
        sequences = load_sequences(sequences)
        try:
            # mark this run as inference
            mlflow.set_tag("is_inference", "True")
            # Same lookup as run() / _run_random_search(): model_ids is optional
            model_ids = getattr(self.cfg, "model_ids", None)
            self.create_evaluators(self.load_predictors(model_ids))
            # score sequences (per-stage outputs are not needed here)
            final_scores, _ = self.batch_score(sequences)
            self.results_logger.log_sequences(
                sequences=sequences,
                fitness=final_scores,
                step=0,
            )  # create csv
            self._log_final_results(sequences, final_scores)  # creates fasta file
        finally:
            self.shutdown()

    def batch_score(
        self, sequences: List[str], evaluator_type: str = "all"
    ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Score sequences with a single evaluator or the full staged pipeline.

        With ``evaluator_type="all"`` the filter evaluators (``task_type ==
        "filter"``) are applied first, in the order of ``self.evaluators``,
        each one seeing only the sequences that passed the previous filters.
        The evaluators are then run on the surviving sequences and combined by
        ``build_fitness_from_stage_outputs``.

        Args:
            sequences: Sequences to score.
            evaluator_type: ``"all"`` for the full pipeline, or a key of
                ``self.evaluators`` to run only that evaluator.

        Returns:
            ``(final_scores, stage_outputs)``.
            For the full pipeline ``final_scores`` has one entry per input
            sequence, in input order; sequences rejected by a filter get
            ``-inf``. ``stage_outputs`` maps stage keys to the evaluator
            outputs for the sequences that passed the filters. If every
            sequence is rejected, zeros and an empty dict are returned.
            For a single evaluator both elements carry its raw output.

        Raises:
            ValueError: If ``evaluator_type`` is neither ``"all"`` nor a key of
                ``self.evaluators``.
        """
        logger.debug(f"Batch scoring {len(sequences)} sequences")

        if evaluator_type != "all":
            if evaluator_type not in self.evaluators:
                raise ValueError(f"Unknown evaluator type: {evaluator_type}")
            scores = self.evaluators[evaluator_type].predict(
                sequences, batch_size=self.batch_size
            )
            return scores, {evaluator_type: scores}

        current_sequences = sequences
        # True for every input sequence that is still in current_sequences
        sequence_mask = np.ones(len(sequences), dtype=bool)

        # Apply all filter evaluators in order, narrowing the survivors
        for evaluator in self.evaluators.values():
            if evaluator.task_type != "filter":
                continue
            if len(current_sequences) == 0:
                # Nothing left to filter; do not call predict on an empty batch
                break
            stage_mask = np.asarray(
                evaluator.predict(current_sequences, batch_size=self.batch_size)
            ).astype(bool)
            # stage_mask is aligned with the current survivors, i.e. with the
            # True positions of sequence_mask
            sequence_mask[sequence_mask] = stage_mask
            current_sequences = [
                seq for seq, keep in zip(current_sequences, stage_mask) if keep
            ]

        # check sequences are not empty
        if len(current_sequences) == 0:
            logger.warning("No sequences left after filtering")
            # NOTE(review): this returns 0 for every sequence, whereas rejected
            # sequences otherwise get -inf; kept as-is, confirm it is intended.
            return np.zeros(len(sequences)), {}

        # Pre-compute embeddings for sklearn evaluators if needed
        precomputed_embeddings = None
        has_sklearn_evaluator = any(
            evaluator.model_type == "sklearn"
            for evaluator in self.evaluators.values()
        )
        if has_sklearn_evaluator and self.embedding_manager:
            logger.info(
                f"Pre-computing embeddings for {len(current_sequences)} sequences")
            try:
                precomputed_embeddings = self.embedding_manager.get_embeddings(
                    current_sequences)
                logger.info(
                    f"Successfully computed embeddings with shape: {precomputed_embeddings.shape}")
            except Exception as e:
                logger.error(f"Failed to compute embeddings: {e}")
                # Fall back to individual embedding computation
                precomputed_embeddings = None

        # Evaluate evaluators on the survivors (parallel helper or sequential)
        if self._should_parallelize_evaluators():
            stage_outputs = self._parallel_evaluate_evaluators(
                current_sequences, precomputed_embeddings=precomputed_embeddings)
        else:
            stage_outputs = {}
            for key, evaluator in self.evaluators.items():
                if evaluator.model_type == "sklearn" and precomputed_embeddings is not None:
                    # Use pre-computed embeddings for sklearn evaluators
                    scores = evaluator.predict_with_embeddings(
                        precomputed_embeddings)
                else:
                    # Regular predict for other evaluators or if embeddings failed
                    scores = evaluator.predict(
                        current_sequences, batch_size=self.batch_size)
                # The survivor list is deliberately NOT narrowed here: every
                # stage output must stay aligned with current_sequences and
                # with precomputed_embeddings. Filter results from this pass
                # are applied by build_fitness_from_stage_outputs.
                stage_outputs[key] = scores

        self.log_stage_outputs(stage_outputs)
        # Compute final scores (aligned with the survivors) from stage outputs
        survivor_scores = self.build_fitness_from_stage_outputs(stage_outputs)

        if self.results_logger is not None:
            step_value = getattr(
                getattr(self, "generator", None), "generation_count", 0)
            self.results_logger.log_sequences(
                sequences=current_sequences,
                fitness=survivor_scores,
                step=step_value,
            )

        # clear cache after scoring
        if self.embedding_manager is not None:
            self.embedding_manager.clear_cache()
            logger.debug("Cleared embedding cache")

        if sequence_mask.all():
            # No sequence was rejected: survivor order is input order
            final_scores = survivor_scores
        else:
            # Scatter survivor scores back to input order so callers get one
            # fitness value per input sequence; rejected sequences get -inf.
            survivor_scores = np.asarray(survivor_scores, dtype=float)
            final_scores = np.full(
                (len(sequences),) + survivor_scores.shape[1:],
                -np.inf,
                dtype=float,
            )
            final_scores[sequence_mask] = survivor_scores

        return final_scores, stage_outputs

    def run(self):
        """Run the generator-driven optimization and log its results.

        Loads the predictors, builds the evaluators and the generator, runs
        the generator, then logs summary metrics and the generated sequences
        through ``self.results_logger`` and ``_log_final_results``.
        """
        mlflow.set_tag("is_inference", "True")
        # Load predictors - pass model_ids from config if available
        model_ids = getattr(self.cfg, "model_ids", None)
        predictors = self.load_predictors(model_ids)

        # Create evaluators from predictors
        self.create_evaluators(predictors)

        # Create the generator that drives the optimization
        generator = self.create_generator()

        # Run the genetic algorithm - it will handle everything internally
        logger.info("Starting genetic algorithm optimization")
        best_sequence, best_fitness = generator.run()

        # Get all solutions generated during evolution
        all_sequences, all_fitnesses = generator.get_all_solutions()

        logger.info("Optimization completed!")
        logger.info(f"Generated {len(all_sequences)} total sequences")
        logger.info(f"Best sequence: {best_sequence}")
        logger.info(f"Best fitness: {best_fitness:.4f}")
        # Fetch the labels once instead of once per bound
        labels = self.data.get_labels()
        logger.info(
            f"Training data fitness range: [{np.min(labels):.3f}, {np.max(labels):.3f}]"
        )

        # Log final metrics using ResultsLogger
        final_metrics = {
            "final_best_fitness": best_fitness,
            "total_sequences_generated": len(all_sequences),
            "total_generations": generator.generation_count,
        }
        self.results_logger.log_metrics(final_metrics)

        # Log comprehensive fitness statistics.
        # Only finite values are summarised: sequences rejected by a filter
        # carry -inf fitness, which would turn the mean into -inf and the std
        # into nan. The size check (rather than truthiness) also works when
        # the generator returns a NumPy array.
        fitness_values = np.asarray(all_fitnesses, dtype=float).ravel()
        finite_fitnesses = fitness_values[np.isfinite(fitness_values)]
        if finite_fitnesses.size > 0:
            fitness_stats = {
                "overall_min_fitness": float(np.min(finite_fitnesses)),
                "overall_max_fitness": float(np.max(finite_fitnesses)),
                "overall_mean_fitness": float(np.mean(finite_fitnesses)),
                "overall_std_fitness": float(np.std(finite_fitnesses)),
                "overall_median_fitness": float(np.median(finite_fitnesses)),
            }
            self.results_logger.log_metrics(fitness_stats)

        # Log the best sequences and their fitness as artifacts
        self._log_final_results(all_sequences, all_fitnesses)

    def shutdown(self):
        """Shutdown the inference engine.

        Sets ``schedule_shutdown`` and then calls ``shutdown()`` on the
        generator, the embedding manager and the results logger, in that
        order. A component that was never created (attribute missing or
        ``None``) or that has no callable ``shutdown`` is skipped, so this is
        safe to call in modes that do not build every component.
        """
        self.schedule_shutdown = True
        for component_name in ("generator", "embedding_manager", "results_logger"):
            # getattr with a default: the attribute may not exist at all
            component = getattr(self, component_name, None)
            stop = getattr(component, "shutdown", None)
            if callable(stop):
                stop()

    def _run_random_search(self, num_sequences: int):
        """Sample random sequences, score them in one batch and log the results.

        No generator is involved: each sequence is built by drawing one
        residue per position from ``self.aa_per_position`` with
        ``np.random.choice``.

        Args:
            num_sequences: Number of random sequences to sample.
        """
        # open the mlflow run for this search and tag it as inference
        experiment_id = self.experiment.experiment_id
        run_label = f"random_search_{self.cfg.run_name}"
        self.mlflow_run = mlflow.start_run(
            experiment_id=experiment_id, run_name=run_label)
        mlflow.set_tag("is_inference", "True")

        self.results_logger = ResultsLogger(cfg=self.cfg, run=self.mlflow_run)

        # predictors -> evaluators; model_ids is taken from the config if set
        predictors = self.load_predictors(getattr(self.cfg, "model_ids", None))
        self.create_evaluators(predictors)

        # one residue per position, positions visited in order for each sequence
        sampled_sequences = [
            "".join(
                np.random.choice(list(alphabet))
                for alphabet in self.aa_per_position
            )
            for _ in range(num_sequences)
        ]

        # score the whole sample at once and log it
        sampled_fitness, _ = self.batch_score(sampled_sequences)
        self._log_final_results(sampled_sequences, sampled_fitness)


def _shutdown_after_failure(inference) -> None:
    """Shut the inference engine down, logging (not raising) shutdown errors."""
    shutdown = getattr(inference, "shutdown", None)
    if not callable(shutdown):
        logger.info("No shutdown method found on inference engine.")
        return
    try:
        shutdown()
    except Exception as shutdown_exc:
        logger.error(f"Error during shutdown: {shutdown_exc}")


@hydra.main(version_base=None, config_path="conf", config_name="inference")
def main(cfg: DictConfig):
    """Main entry point for inference.

    Dispatches on ``cfg.inference_mode``: "random_search" and
    "score_sequences" call their dedicated methods; any other mode accepted
    by ``HAIPRInference`` goes through ``run()``. For ``run()``, a
    KeyboardInterrupt or any other exception is logged and followed by a
    best-effort shutdown of the engine; the exception is not re-raised.
    """
    # Register custom Hydra resolvers
    register_resolvers()

    inference = HAIPRInference(cfg)

    if cfg.inference_mode == "random_search":
        inference._run_random_search(cfg.num_sequences)
    elif cfg.inference_mode == "score_sequences":
        inference.score_sequences(cfg.sequences)
    else:
        try:
            inference.run()
        except KeyboardInterrupt:
            logger.warning(
                "KeyboardInterrupt received. Shutting down inference engine gracefully..."
            )
            _shutdown_after_failure(inference)
        except Exception as e:
            # logger.exception keeps the traceback, which the plain message loses
            logger.exception(f"Exception during inference: {e}")
            _shutdown_after_failure(inference)


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
