import hydra
import mlflow
import logging
from datetime import datetime
from typing import Dict, Any, Tuple
from omegaconf import DictConfig
import torch
from haipr.data import HAIPRData
from haipr.sequence_generators.base_generator import BaseSequenceGenerator
from haipr.utils import AA_ALPHABET, AA_ALPHABET_INFERENCE, load_sequences
from haipr.evaluators import EvaluatorFactory
from haipr.utils.results_logger import ResultsLogger
from typing import List
from typing import Optional
import numpy as np
from haipr.embedding_manager import EmbeddingManager
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from haipr.utils.resolvers import register_resolvers
from tqdm import tqdm
# Module-level logger, registered under the explicit name "haipr.inference".
logger = logging.getLogger("haipr.inference")


class HAIPRInference:

    def __init__(self, cfg: DictConfig):
        self.cfg = cfg

        # initialize stages and evaluators
        self.stages: List[str] = []
        self.evaluators: Dict[str, Any] = {}
        self.batch_size = int(getattr(self.cfg, "inference_batch_size", 1024))

        # Setup common components
        self._setup_common_components()

        # Setup MLflow tracking - create parent run
        self._setup_parent_run()

        # Setup mode-specific components
        if self.cfg.inference_mode == "design":
            self._setup_design_mode()
        elif self.cfg.inference_mode == "score_sequences":
            self._setup_score_sequences_mode()
        elif self.cfg.inference_mode == "random_search":
            self._setup_random_search_mode()
        else:
            raise ValueError(f"Invalid mode: {self.cfg.inference_mode}")

    def _setup_parent_run(self):
        """
        Setup parent MLflow run for inference session.
        """
        # Get experiment by name
        if not isinstance(self.cfg.mlflow.experiment_id, str):
            try:
                experiment_id = str(self.cfg.mlflow.experiment_id)
            except Exception:
                raise ValueError(
                    f"Invalid experiment ID: {self.cfg.mlflow.experiment_id}"
                )
        else:
            experiment_id = self.cfg.mlflow.experiment_id

        # Set up MLflow tracking
        mlflow.set_tracking_uri(self.cfg.mlflow.tracking_uri)

        # get experiment by id if provided, otherwise search by name
        if self.cfg.mlflow.experiment_id is not None:
            self.experiment = mlflow.get_experiment(experiment_id)
        else:
            experiments = mlflow.search_experiments(
                filter_string=f"name='{self.cfg.mlflow.experiment_name}'"
            )
            if len(experiments) == 0:
                logger.warning(
                    f"No experiment found with name {self.cfg.mlflow.experiment_name}"
                )
                self.experiment = None
            else:
                # Pick the first matching experiment by name
                self.experiment = experiments[0]
                self.cfg.mlflow.experiment_id = self.experiment.experiment_id

        # Check if there's an active run (e.g., from haipr.py orchestration)
        # If so, create a nested run
        active_run = mlflow.active_run()
        is_nested = active_run is not None

        # Create parent run
        run_name = self.cfg.run_name if self.cfg.run_name else "Inference"
        parent_run_name = f"{run_name}_{self.cfg.inference_mode}"
        logger.info(f"Creating parent MLflow run: {parent_run_name}")

        self.parent_run = mlflow.start_run(
            experiment_id=(
                self.experiment.experiment_id if self.experiment else None),
            run_name=parent_run_name,
            nested=is_nested,
        )

        # Set parent run tags
        mlflow.set_tag("is_parent", "True")
        mlflow.set_tag("is_inference", "True")
        mlflow.set_tag("inference_mode", self.cfg.inference_mode)

        logger.info(
            f"Successfully created parent MLflow run: {self.parent_run.info.run_id}"
        )

        # Create results logger for parent run
        self.results_logger = ResultsLogger(cfg=self.cfg, run=self.parent_run)
        self.results_logger.log_config(self.cfg)

    def _create_trajectory_run(self):
        """
        Create a child trajectory run under the parent run.
        """
        self.trajectory_count += 1
        trajectory_name = f"trajectory_{self.trajectory_count}"
        logger.info(f"Creating trajectory run: {trajectory_name}")

        self.mlflow_run = mlflow.start_run(
            experiment_id=self.parent_run.info.experiment_id,
            run_name=trajectory_name,
            nested=True,
        )

        # Set trajectory run tags
        mlflow.set_tag("is_trajectory", "True")
        mlflow.set_tag("is_inference", "True")
        mlflow.set_tag("parent_run_id", self.parent_run.info.run_id)
        mlflow.set_tag("trajectory_number", str(self.trajectory_count))

        # Create results logger for this trajectory
        self.results_logger = ResultsLogger(cfg=self.cfg, run=self.mlflow_run)
        self.results_logger.log_config(self.cfg)

        logger.info(
            f"Successfully created trajectory run: {self.mlflow_run.info.run_id}"
        )

    def _setup_common_components(self):
        """
        Setup common components shared across all inference modes.
        """
        # Initialize embedding managers per unique embedder config
        self.embedding_managers: Dict[str, EmbeddingManager] = {}

        self.generation_stage_outputs: List[Dict[str, np.ndarray]] = []

        self.retry_count = 0
        self.trajectory_count = 0

    def _setup_data_and_alphabet(self):
        """
        Setup data and alphabet per position for modes that need it.
        """
        self.data = HAIPRData(self.cfg)
        self.representative = self.data.get_representative()

        if isinstance(self.cfg.alphabet_per_position, list):
            self.aa_per_position = self.cfg.alphabet_per_position
        else:
            self.aa_per_position = self._resolve_alphabet_per_position()

    def _setup_design_mode(self):
        """Setup design mode - requires data and alphabet."""
        self._setup_data_and_alphabet()

    def _setup_score_sequences_mode(self):
        """Setup score sequences mode - no data required."""
        self.data = None
        self.representative = None
        self.aa_per_position = None

    def _setup_random_search_mode(self):
        """Setup random search mode - requires data and alphabet."""
        self._setup_data_and_alphabet()

    def _should_parallelize_evaluators(self) -> bool:
        """
        Determine if evaluators should be parallelized.
        All pyfunc models can be parallelized uniformly.
        """
        # Don't parallelize if DDP is enabled or only one evaluator
        if getattr(self.cfg, "inference_ddp", False) or len(self.evaluators) <= 1:
            return False

        # All pyfunc models can be parallelized
        return True

    def _evaluate_single_evaluator(
        self,
        evaluator_key: str,
        evaluator: Any,
        sequences: List[str],
        precomputed_embeddings: np.ndarray | None = None,
    ) -> Tuple[str, np.ndarray]:
        """Evaluate a single evaluator on sequences."""
        try:
            result = evaluator.predict(sequences)
            return evaluator_key, result["predictions"]
        except Exception as e:
            logger.error(f"Error in evaluator {evaluator_key}: {e}")
            return evaluator_key, np.zeros(len(sequences))

    def _parallel_evaluate_evaluators(
        self, sequences: List[str], precomputed_embeddings: np.ndarray | None = None
    ) -> Dict[str, np.ndarray]:
        """Evaluate all evaluators in parallel (unified pyfunc interface)."""
        results = {}

        # Run all evaluators in parallel
        with ThreadPoolExecutor(
            max_workers=min(len(self.evaluators), multiprocessing.cpu_count())
        ) as executor:
            futures = {
                executor.submit(
                    self._evaluate_single_evaluator,
                    k,
                    v,
                    sequences,
                    precomputed_embeddings,
                ): k
                for k, v in self.evaluators.items()
            }
            for future in futures:
                try:
                    k, scores = future.result()
                    results[k] = scores
                except Exception as e:
                    logger.error(f"Failed evaluator {futures[future]}: {e}")
                    results[futures[future]] = np.zeros(len(sequences))

        return results

    def log_stage_outputs(self, stage_outputs: Dict[str, np.ndarray]):
        """
        Log aggregate statistics for each stage to MLflow.
        This provides insights into how each evaluator stage is performing.
        """
        if not self.results_logger:
            return

        stage_metrics = {}
        for stage_name, scores in stage_outputs.items():
            scores_array = np.array(scores)
            # Log aggregate statistics for each stage
            stage_metrics.update(
                {
                    f"batch_{stage_name}_min": np.min(scores_array),
                    f"batch_{stage_name}_max": np.max(scores_array),
                    f"batch_{stage_name}_mean": np.mean(scores_array),
                    f"batch_{stage_name}_std": np.std(scores_array),
                    f"batch_{stage_name}_median": np.median(scores_array),
                }
            )

        # Store stage outputs for generation-level aggregation
        self.generation_stage_outputs.append(stage_outputs)

        # Log the stage metrics
        self.results_logger.log_metrics(stage_metrics)

        # Log raw stage outputs for each generation
        if hasattr(self, "generator") and hasattr(self.generator, "generation_count"):
            generation_num = self.generator.generation_count
            for stage_name, scores in stage_outputs.items():
                # Convert numpy array to list for JSON serialization
                scores_list = scores.tolist() if hasattr(scores, "tolist") else scores
                mlflow.log_dict(
                    {stage_name: scores_list},
                    f"generation_{generation_num}_stage_outputs.json",
                )
        else:
            # Fallback for cases where generator is not available yet
            for stage_name, scores in stage_outputs.items():
                # Convert numpy array to list for JSON serialization
                scores_list = scores.tolist() if hasattr(scores, "tolist") else scores
                mlflow.log_dict(
                    {stage_name: scores_list}, "pre_generation_stage_outputs.json"
                )

    def load_predictors(self, model_ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Load predictors from MLflow using search_logged_models API.

        Supports three modes:
        1. Specific model IDs (if model_ids provided)
        2. Models from specific parent run (if models_from_parent_run configured)
        3. All models from experiment (default)

        Args:
            model_ids: Optional list of specific model IDs to load

        Returns:
            Dict[str, Any]: Dictionary mapping stage names to loaded models
        """
        logger.info(
            f"Loading predictors from experiment: {self.cfg.mlflow.experiment_name}"
        )
        if self.experiment is None:
            raise ValueError(
                f"Experiment {self.cfg.mlflow.experiment_name} not found")

        # Determine which runs to search for models
        target_run_ids = None

        if (
            hasattr(self.cfg, "models_from_parent_run")
            and self.cfg.models_from_parent_run
        ):
            logger.info(
                f"Filtering models by parent_run_id: {self.cfg.models_from_parent_run}"
            )

            # First, find runs that match the parent_run_id criteria
            run_filter_conditions = []
            run_filter_conditions.append("status = 'FINISHED'")
            run_filter_conditions.append(
                f"tags.benchmark = '{self.cfg.benchmark.name}'"
            )
            run_filter_conditions.append(
                f"tags.parent_run_id = '{self.cfg.models_from_parent_run}'"
            )

            run_filter_string = " AND ".join(run_filter_conditions)
            logger.info(f"Searching runs with filter: {run_filter_string}")

            try:
                runs_df = mlflow.search_runs(
                    experiment_ids=[self.experiment.experiment_id],
                    filter_string=run_filter_string,
                    output_format="pandas",
                )

                if len(runs_df) == 0:
                    raise ValueError(
                        f"No finished runs found in experiment {self.cfg.mlflow.experiment_name} "
                        f"matching parent_run_id: {self.cfg.models_from_parent_run}"
                    )

                target_run_ids = runs_df["run_id"].tolist()  # type: ignore
                logger.info(
                    f"Found {len(target_run_ids)} runs matching parent_run_id criteria"
                )

            except Exception as e:
                logger.error(f"Failed to search runs: {e}")
                raise ValueError(
                    f"No runs found matching parent_run_id criteria: {e}")
        else:
            logger.info(
                f"No models_from_parent_run specified, using all models from experiment {self.cfg.mlflow.experiment_name}"
            )

        # Build filter string for logged models
        model_filter_conditions = []

        # Add model ID filter if specific models requested
        if model_ids:
            # Convert model_ids to quoted strings for IN clause
            quoted_ids = [f"'{model_id}'" for model_id in model_ids]
            model_filter_conditions.append(
                f"model_id IN ({','.join(quoted_ids)})")
            logger.info(f"Filtering models by specific IDs: {model_ids}")

        # Combine all conditions
        model_filter_string = (
            " AND ".join(
                model_filter_conditions) if model_filter_conditions else None
        )

        logger.info(
            f"Searching logged models with filter: {model_filter_string}")

        # Search for logged models using the new API
        try:
            logged_models = mlflow.search_logged_models(
                experiment_ids=[self.experiment.experiment_id],
                filter_string=model_filter_string,
                output_format="list",
            )
        except Exception as e:
            logger.error(f"Failed to search logged models: {e}")
            raise ValueError(f"No logged models found matching criteria: {e}")

        if len(logged_models) == 0:
            raise ValueError(
                f"No logged models found in experiment {self.cfg.mlflow.experiment_name} "
                f"matching filter: {model_filter_string}"
            )

        logger.info(f"Found {len(logged_models)} logged models")

        # Filter logged models by benchmark and target run IDs
        filtered_logged_models = []
        for model in logged_models:
            try:
                # Get run details to check tags
                run_info = mlflow.get_run(model.source_run_id)
                tags = run_info.data.tags

                # Check benchmark tag
                if tags.get("benchmark") != self.cfg.benchmark.name:
                    logger.debug(
                        f"Skipping model {model.name} from run {model.source_run_id} - benchmark mismatch"
                    )
                    continue

                # Check target run IDs if specified
                if (
                    target_run_ids is not None
                    and model.source_run_id not in target_run_ids
                ):
                    logger.debug(
                        f"Skipping model {model.name} from run {model.source_run_id} - not in target runs"
                    )
                    continue

                filtered_logged_models.append(model)

            except Exception as e:
                logger.warning(
                    f"Failed to get run info for model {model.name}: {e}")
                continue

        logger.info(
            f"Filtered to {len(filtered_logged_models)} models matching benchmark and run criteria"
        )
        logged_models = filtered_logged_models

        # Load models from logged models
        predictors = {}

        for logged_model in logged_models:
            try:
                # Get run details to access tags and parameters
                run_info = mlflow.get_run(logged_model.source_run_id)
                tags = run_info.data.tags

                # Skip parent and inference runs
                if tags.get("is_parent") == "True":
                    logger.info(
                        f"Skipping parent run {logged_model.source_run_id}")
                    continue
                if tags.get("is_inference") == "True":
                    logger.info(
                        f"Skipping inference run {logged_model.source_run_id}")
                    continue

                # Extract model type and task from tags
                model_type = tags.get("model_type", "unknown")
                task_type = tags.get("task", "regression")

                # Determine stage name based on task type
                if task_type == "classification":
                    stage_name = "filter"
                elif task_type == "regression":
                    stage_name = "score"
                else:
                    stage_name = "unknown"

                # Load model as pyfunc (unified interface)
                try:
                    model_uri = f"models:/{logged_model.model_id}"
                    logger.info(f"Loading pyfunc model from {model_uri}")

                    # Load returns BasePredictor instance with loaded context
                    model_config = {
                        "batch_size": self.batch_size,
                    }
                    model = mlflow.pyfunc.load_model(
                        model_uri, model_config=model_config
                    )
                    # Store with unique key
                    predictor_key = f"{stage_name}_{logged_model.source_run_id[:8]}"

                    predictors[predictor_key] = {
                        "name": logged_model.name,
                        "model": model,
                        "run_id": logged_model.source_run_id,
                        "task_type": task_type,
                        "stage_name": stage_name,
                        "run_info": run_info,
                        "model_id": logged_model.model_id,
                    }
                    self.stages.append(stage_name)

                    logger.info(f"Loaded pyfunc model as {predictor_key}")

                except Exception as e:
                    logger.warning(f"Failed to load pyfunc model: {e}")
                    # Try legacy loading
                    model = self._load_legacy_model(logged_model, tags)
                    if model:
                        # Store with unique key
                        predictor_key = f"{stage_name}_{logged_model.source_run_id[:8]}"

                        predictors[predictor_key] = {
                            "name": logged_model.name,
                            "model": model,
                            "run_id": logged_model.source_run_id,
                            "task_type": task_type,
                            "stage_name": stage_name,
                            "run_info": run_info,
                            "model_id": logged_model.model_id,
                        }
                        self.stages.append(stage_name)

                        logger.info(f"Loaded legacy model as {predictor_key}")
                    else:
                        logger.warning(
                            f"Failed to load legacy model: {logged_model.name}"
                        )
                        continue

            except Exception as e:
                logger.warning(
                    f"Failed to process logged model {logged_model.model_id}: {e}"
                )
                continue

        logger.info(f"Successfully loaded {len(predictors)} predictors")
        if len(predictors) == 0:
            raise ValueError(
                f"No predictors loaded from experiment {self.cfg.mlflow.experiment_name}"
            )

        # Group predictors by embedder config and create centralized embedding managers
        self._setup_embedding_managers(predictors)

        return predictors

    def _load_legacy_model(self, logged_model, tags):
        """Fallback loader for legacy sklearn/pytorch models."""
        model_type = tags.get("model_type", "unknown")
        model_uri = f"models:/{logged_model.model_id}"

        try:
            if model_type == "sklearn":
                return mlflow.sklearn.load_model(model_uri)
            elif model_type == "pytorch":
                return mlflow.pytorch.load_model(model_uri)
        except Exception as e:
            logger.error(f"Failed to load legacy model: {e}", stack_info=True)
            return None

    def _setup_embedding_managers(self, predictors: Dict[str, Any]):
        """
        Create centralized embedding managers for each unique embedder config.
        This prevents recomputing the same embeddings across multiple models.
        """
        # Collect unique embedder configs from all predictors
        embedder_configs = {}

        for predictor_key, predictor_info in predictors.items():
            run_info = predictor_info.get("run_info")
            if (
                run_info
                and hasattr(run_info, "data")
                and hasattr(run_info.data, "params")
            ):
                # Extract embedder config from run parameters
                params = run_info.data.params

                # Create embedder config key from relevant parameters
                embedder_key = self._create_embedder_key(params)

                if embedder_key not in embedder_configs:
                    embedder_configs[embedder_key] = {
                        "params": params,
                        "predictors": [],
                    }

                embedder_configs[embedder_key]["predictors"].append(
                    predictor_key)

        # Create EmbeddingManager for each unique embedder config
        for embedder_key, config_info in embedder_configs.items():
            try:
                # Create a config object for this embedder
                embedder_cfg = self._create_embedder_config(
                    config_info["params"])

                if embedder_cfg:
                    embedding_manager = EmbeddingManager(embedder_cfg)
                    self.embedding_managers[embedder_key] = embedding_manager

                    logger.info(
                        f"Created embedding manager for {embedder_key} "
                        f"serving {len(config_info['predictors'])} predictors: "
                        f"{config_info['predictors']}"
                    )
                else:
                    logger.warning(
                        f"No embedder config found for {embedder_key}")

            except Exception as e:
                logger.error(
                    f"Failed to create embedding manager for {embedder_key}: {e}"
                )
                continue

        logger.info(
            f"Created {len(self.embedding_managers)} embedding managers")

    def _create_embedder_key(self, params: Dict[str, Any]) -> str:
        """Create a unique key for embedder configuration."""
        # Extract embedder-related parameters
        embedder_params = {}

        for key, value in params.items():
            if key.startswith("embedder."):
                embedder_params[key] = value

        if not embedder_params:
            return "no_embedder"

        # Create a hash of the embedder config for uniqueness
        import hashlib

        config_str = str(sorted(embedder_params.items()))
        return hashlib.md5(config_str.encode()).hexdigest()[:8]

    def _create_embedder_config(self, params: Dict[str, Any]) -> Optional[DictConfig]:
        """Create embedder config from run parameters."""
        embedder_params = {}

        for key, value in params.items():
            if key.startswith("embedder."):
                # Remove embedder. prefix and convert to nested structure
                nested_key = key[9:]  # Remove "embedder."
                embedder_params[nested_key] = value

        if not embedder_params:
            return None

        # Create OmegaConf config
        from omegaconf import DictConfig

        return DictConfig(embedder_params)

    def create_evaluators(
        self,
        predictors: Dict[str, Any],
    ) -> None:
        """
        Create evaluators from loaded predictors and optional physics configurations.

        Args:
            predictors: Dictionary of loaded predictors from MLflow
        """
        # Get physics evaluator configurations from config if available
        physics_configs = getattr(self.cfg, "physics_evaluators", None)

        # Create evaluators using the factory, passing config and embedding managers
        self.evaluators = EvaluatorFactory.create_evaluators_from_predictors(
            predictors=predictors,
            physics_configs=physics_configs,
            cfg=self.cfg,
            embedding_managers=self.embedding_managers,
        )
        logger.info(f"Created {len(self.evaluators)} evaluators")

    def _resolve_alphabet_per_position(self) -> List[str]:
        # Create alphabet per position based on config
        if self.cfg.alphabet_per_position == "train":
            # Use the training alphabet from the data
            aa_per_position = [
                (
                    list(self.data.pos_wise_alphabet[i])
                    if i < len(self.data.pos_wise_alphabet)
                    # positions not in training data
                    else list(AA_ALPHABET_INFERENCE)
                )
                for i in range(len(self.data.representative))
            ]
            return list(aa_per_position)
        elif self.cfg.alphabet_per_position == "full_per_mut_pos":
            aa_per_position = [
                (
                    list(AA_ALPHABET_INFERENCE)
                    if len(self.data.pos_wise_alphabet[i]) > 1
                    else list(self.data.pos_wise_alphabet[i])
                )
                for i in range(len(self.data.representative))
            ]
            return list(aa_per_position)
        elif self.cfg.alphabet_per_position == "de-novo":
            # Initialize aa_per_position as a list of the current representative sequence
            aa_per_position = list(self.data.representative)
            sep_token = self.cfg.data.chain_break_token
            mut_pos = self.data.mut_pos

            # Find all chain break token indices
            sep_indices = [
                i for i, c in enumerate(self.data.representative) if c == sep_token
            ]
            seq_len = len(self.data.representative)

            # Define chain boundaries as (start, end) tuples (exclusive of end)
            chain_boundaries = []
            prev = 0
            for idx in sep_indices:
                chain_boundaries.append((prev, idx))
                prev = idx + 1
            chain_boundaries.append((prev, seq_len))

            # For each chain, if any mut_pos falls within, set only those positions to AA_ALPHABET_INFERENCE
            for start, end in chain_boundaries:
                # Only set positions t
                chain_positions = [
                    pos
                    for pos in range(start, end)
                    if self.data.representative[pos] != sep_token
                ]
                if any(pos in mut_pos for pos in chain_positions):
                    for pos in chain_positions:
                        aa_per_position[pos] = list(AA_ALPHABET_INFERENCE)
            return list(aa_per_position)

        else:
            raise ValueError(
                f"Unknown alphabet_per_position: {self.cfg.alphabet_per_position}"
            )

    def create_generator(self) -> BaseSequenceGenerator:
        """Create sequence generator with loaded evaluators."""
        self.generator = hydra.utils.instantiate(self.cfg.generator)
        aa_per_position = self._resolve_alphabet_per_position()

        # Pass batch_score method as the fitness callback
        self.generator.setup_generator(
            self.data, aa_per_position, self.batch_score)

        # Create callback for computing generation-level stage metrics
        def stage_outputs_callback():
            """Callback to compute and return generation-level stage metrics."""
            if not self.generation_stage_outputs:
                return {}
            generation_stage_metrics = self._compute_generation_stage_metrics()
            # Clear stage outputs after computing metrics
            self.generation_stage_outputs = []
            return generation_stage_metrics

        # Configure logging infrastructure separately
        self.generator.configure_logging(
            results_logger=self.results_logger,
            new_run_callback=self.create_new_mlflow_run,
            stage_outputs_callback=stage_outputs_callback,
        )

        return self.generator

    def _log_final_results(
        self,
        all_sequences: List[str],
        all_fitnesses: List[float],
    ):
        """Log final results as a FASTA file with fitness scores as sequence IDs."""

        # Create FASTA content with fitness scores as sequence IDs
        fasta_lines = []

        # Combine sequences and fitness, sort by fitness (best first)
        sequence_fitness_pairs = list(zip(all_sequences, all_fitnesses))
        sequence_fitness_pairs.sort(key=lambda x: x[1], reverse=True)

        for i, (sequence, fitness) in enumerate(sequence_fitness_pairs):
            # Use fitness score as sequence ID
            fasta_lines.append(f">fitness_{fitness:.6f}_rank_{i + 1}")
            fasta_lines.append(sequence)

        # Write FASTA file
        fasta_content = "\n".join(fasta_lines)
        temp_path = f"/tmp/generated_sequences_{self.mlflow_run.info.run_id}.fasta"

        with open(temp_path, "w") as f:
            f.write(fasta_content)

        # Log as text artifact
        mlflow.log_artifact(temp_path, "results")

    def _count_mutations(self, sequence: str) -> int:
        """Count mutations relative to the representative sequence."""
        if len(sequence) != len(self.representative):
            return len(sequence)  # Can't compare if lengths differ

        mutations = sum(
            1 for i, (a, b) in enumerate(zip(sequence, self.representative)) if a != b
        )
        return mutations

    def _compute_generation_stage_metrics(self) -> Dict[str, float]:
        """Compute aggregate metrics across all stage outputs for the current generation."""
        if not self.generation_stage_outputs:
            return {}

        generation_metrics = {}

        # Collect all stage names
        all_stage_names: set[str] = set()
        for stage_output in self.generation_stage_outputs:
            all_stage_names.update(stage_output.keys())

        # For each stage, aggregate across all batches in this generation
        for stage_name in all_stage_names:
            all_scores: list[float] = []
            for stage_output in self.generation_stage_outputs:
                if stage_name in stage_output:
                    all_scores.extend(stage_output[stage_name])

            if all_scores:
                scores_array = np.array(all_scores)
                generation_metrics.update(
                    {
                        f"generation_{stage_name}_min": np.min(scores_array),
                        f"generation_{stage_name}_max": np.max(scores_array),
                        f"generation_{stage_name}_mean": np.mean(scores_array),
                        f"generation_{stage_name}_std": np.std(scores_array),
                        f"generation_{stage_name}_median": np.median(scores_array),
                    }
                )

        return generation_metrics

    def create_new_mlflow_run(self):
        """Start a new trajectory MLflow run after a restart.

        Increments `self.retry_count`, ends the active MLflow run when
        `self.mlflow_run` is set, calls `_create_trajectory_run()`, and then
        records the restart both as MLflow tags and as logged metrics.
        """
        self.retry_count += 1

        # Close the trajectory run that is still open, if there is one.
        if getattr(self, "mlflow_run", None) is not None:
            mlflow.end_run()
            logger.info("Ended previous trajectory run")

        self._create_trajectory_run()

        # Mark the new run as a restart.
        restart_tags = {
            "trajectory_restart": "True",
            "retry_count": str(self.retry_count),
        }
        for tag_name, tag_value in restart_tags.items():
            mlflow.set_tag(tag_name, tag_value)

        # Record the same information as metrics.
        retry_metrics = {"retry_count": self.retry_count, "trajectory_restart": True}
        self.results_logger.log_metrics(retry_metrics)

        new_run_id = self.mlflow_run.info.run_id
        logger.info(f"Successfully created new trajectory run: {new_run_id}")

    def build_fitness_from_stage_outputs(self, stage_outputs):
        """
        Compute final scores from stage outputs.
        Uses the last score stage value for sequences that pass all filters.
        Returns scores array of same length as input sequences.
        """
        logger.debug(f"Building fitness from stage outputs: {stage_outputs}")
        # Get number of sequences from first stage output
        num_sequences = len(next(iter(stage_outputs.values())))

        # Initialize mask for sequences that pass all filters, all pass initially, for no-filter case
        sequence_mask = np.ones(num_sequences, dtype=bool)

        # Apply all filters
        for stage, scores in stage_outputs.items():
            if stage.startswith("filter"):
                sequence_mask &= scores.astype(bool)
        if self.results_logger is not None:
            step_value = getattr(
                getattr(self, "generator", None), "generation_count", 0
            )
            self.results_logger.log_metrics(
                {
                    "filtered_sequences": np.sum(sequence_mask),
                    "total_sequences": num_sequences,
                },
                step=step_value,
            )

        # collect all score stage outputs
        score_stages = [
            stage for stage in self.evaluators.keys() if stage.startswith("score")
        ]
        if not score_stages:
            logger.warning("no score stages found, using last stage")
            # get last stage output
            last_stage, last_stage_output = list(stage_outputs.items())[-1]
            logger.debug(f"last stage: {last_stage_output}")
            return last_stage_output

        # stack all score stage outputs, ensuring 1d shape for each
        score_arrays = []
        for stage in score_stages:
            scores = stage_outputs[stage]
            if len(scores.shape) > 1:
                scores = scores[:, 0]
            score_arrays.append(scores)
        # compute mean across all score stages
        mean_scores = np.mean(np.stack(score_arrays, axis=0), axis=0)

        # initialize final scores with -inf, ensuring all sequences have a score
        final_scores = np.full(num_sequences, -np.inf, dtype=float)
        # set mean scores for sequences that passed all filters
        final_scores[sequence_mask] = mean_scores[sequence_mask]

        logger.info(
            f"Final scores range [{np.min(final_scores):.3f}, {np.max(final_scores):.3f}], "
            f"{np.sum(sequence_mask)}/{len(sequence_mask)} sequences passed filters"
        )

        return final_scores

    def score_sequences(self, sequences: List[str] | str | None = None):
        """Simply score all provided sequences"""
        if sequences is None:
            raise ValueError(
                "sequences parameter is required for score_sequences mode")

        try:
            sequences = load_sequences(sequences)
        except ValueError as e:
            raise ValueError(f"Failed to load sequences: {e}")

        # Create first trajectory run
        self._create_trajectory_run()

        self.create_evaluators(self.load_predictors(self.cfg.model_ids))
        # score sequences
        final_scores, stage_outputs = self.batch_score(sequences)
        self.results_logger.log_sequences(
            sequences=sequences,
            fitness=final_scores,
            step=0,
        )  # create csv
        self._log_final_results(
            sequences, final_scores.tolist())  # creates fasta file
        self.shutdown()

    def batch_score(
        self, sequences: List[str], evaluator_type: str = "all"
    ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Score sequences using specified evaluators in the order defined by stages"""
        logger.debug(f"Batch scoring {len(sequences)} sequences")

        if evaluator_type == "all":
            current_sequences = sequences
            sequence_mask = np.ones(len(sequences), dtype=bool)

            # Dictionary to store stage outputs
            stage_outputs = {}

            # Apply all filter evaluators in order, updating the sequence mask and current_sequences
            for key, evaluator in self.evaluators.items():
                if evaluator.task_type == "filter":
                    # Call the actual predict method for the filter
                    result = evaluator.predict(current_sequences)
                    stage_mask = result["predictions"]
                    stage_outputs[key] = stage_mask
                    sequence_mask[sequence_mask] = stage_mask
                    current_sequences = [
                        seq for i, seq in enumerate(current_sequences) if stage_mask[i]
                    ]
            # check sequences are not empty
            if len(current_sequences) == 0:
                logger.warning("No sequences left after filtering")
                return np.zeros(len(sequences)), {}

            # Evaluate all evaluators (score and other types)
            if self._should_parallelize_evaluators():
                stage_outputs = self._parallel_evaluate_evaluators(
                    current_sequences)
            else:
                # Sequential evaluation for all evaluators
                stage_outputs = {}
                for key, evaluator in self.evaluators.items():
                    result = evaluator.predict(current_sequences)
                    scores = result["predictions"]
                    stage_outputs[key] = scores

            self.log_stage_outputs(stage_outputs)
            # Compute final scores from stage outputs
            final_scores = self.build_fitness_from_stage_outputs(stage_outputs)

            if self.results_logger is not None:
                step_value = getattr(
                    getattr(self, "generator", None), "generation_count", 0
                )
                self.results_logger.log_sequences(
                    sequences=current_sequences,
                    fitness=final_scores,
                    step=step_value,
                )

            return final_scores, stage_outputs

        # only use specific type of evaluator if provided
        elif evaluator_type in self.evaluators:
            result = self.evaluators[evaluator_type].predict(
                sequences, batch_size=self.batch_size
            )
            scores = result["predictions"]
            return scores, {evaluator_type: scores}
        else:
            raise ValueError(f"Unknown evaluator type: {evaluator_type}")

    def run(self):

        self._create_trajectory_run()
        model_ids = getattr(self.cfg, "model_ids", None)
        predictors = self.load_predictors(model_ids)
        self.create_evaluators(predictors)
        generator = self.create_generator()
        logger.info("Starting Generator")
        all_sequences, all_fitnesses = generator.run_generator()
        best_sequence, best_fitness = generator.get_best_solution()

        logger.info("Optimization completed!")
        logger.info(f"Generated {len(all_sequences)} total sequences")
        logger.info(f"Best sequence: {best_sequence}")
        logger.info(f"Best fitness: {best_fitness:.4f}")
        logger.info(
            f"Training data fitness range: [{np.min(self.data.get_labels()):.3f}, {np.max(self.data.get_labels()):.3f}]"
        )

        final_metrics = {
            "final_best_fitness": best_fitness,
            "total_sequences_generated": len(all_sequences),
        }
        self.results_logger.log_metrics(final_metrics)
        if all_fitnesses:
            fitness_stats = {
                "overall_min_fitness": np.min(all_fitnesses),
                "overall_max_fitness": np.max(all_fitnesses),
                "overall_mean_fitness": np.mean(all_fitnesses),
                "overall_std_fitness": np.std(all_fitnesses),
                "overall_median_fitness": np.median(all_fitnesses),
            }
            self.results_logger.log_metrics(fitness_stats)
        self._log_final_results(all_sequences, all_fitnesses)

    def shutdown(self):
        """Shutdown the inference engine."""
        self.schedule_shutdown = True
        if hasattr(self, "generator") and hasattr(self.generator, "shutdown"):
            self.generator.shutdown()

        if hasattr(self, "embedding_managers"):
            for embedder_key, embedding_manager in self.embedding_managers.items():
                if hasattr(embedding_manager, "shutdown"):
                    embedding_manager.shutdown()
                    logger.info(
                        f"Shutdown embedding manager for {embedder_key}")

        if hasattr(self, "results_logger") and hasattr(self.results_logger, "shutdown"):
            self.results_logger.shutdown()

    def _run_random_search(self, num_sequences: int):
        """Run random search. no generator needed"""
        # Create first trajectory run
        self._create_trajectory_run()

        # Load predictors - pass model_ids from config if available
        model_ids = getattr(self.cfg, "model_ids", None)
        predictors = self.load_predictors(model_ids)
        self.create_evaluators(predictors)

        all_sequences = []
        logger.info(f"Generating {num_sequences} random sequences")
        aa_per_position = self.aa_per_position
        num_positions = len(aa_per_position)

        # Build a 2D numpy array of possible amino acids (pad positions for ragged arrays)
        max_aa_count = max(len(pos) for pos in aa_per_position)
        aa_matrix = np.full((num_positions, max_aa_count), '', dtype='<U1')
        for i, pos in enumerate(aa_per_position):
            aa_matrix[i, :len(pos)] = pos

        # For each position, randomly sample indices
        random_indices = np.stack(
            [np.random.randint(len(aa_per_position[pos]), size=num_sequences)
             for pos in range(num_positions)]
        )  # shape: (num_positions, num_sequences)

        # Gather the amino acids for all sequences at all positions
        # shape: (num_positions, num_sequences)
        sequences_array = aa_matrix[np.arange(
            num_positions)[:, None], random_indices]

        # Combine per-position into sequence strings (across axis 0)
        all_sequences = ["".join(seq) for seq in sequences_array.T]

        # Score sequence and log results
        logger.info("Score Sequences")
        fitness, _ = self.batch_score(all_sequences)
        self._log_final_results(all_sequences, fitness)


@hydra.main(version_base=None, config_path="conf", config_name="inference")
def main(cfg: DictConfig):
    """Hydra entry point: build the inference engine and run the configured mode.

    "random_search" and "score_sequences" are dispatched directly. Every other
    mode calls `inference.run()`; a KeyboardInterrupt or any other exception
    raised there is logged and followed by a best-effort `shutdown()`.
    """
    # Register custom Hydra resolvers
    register_resolvers()

    inference = HAIPRInference(cfg)

    def _shutdown_if_available() -> bool:
        """Call `inference.shutdown()` when it exists; return False when it does not."""
        if not callable(getattr(inference, "shutdown", None)):
            return False
        try:
            inference.shutdown()
        except Exception as shutdown_exc:
            logger.error(f"Error during shutdown: {shutdown_exc}")
        return True

    mode = cfg.inference_mode
    if mode == "random_search":
        inference._run_random_search(cfg.num_sequences)
        return
    if mode == "score_sequences":
        inference.score_sequences(cfg.sequences)
        return

    try:
        inference.run()
    except KeyboardInterrupt:
        logger.warning(
            "KeyboardInterrupt received. Shutting down inference engine gracefully..."
        )
        if not _shutdown_if_available():
            logger.info("No shutdown method found on inference engine.")
    except Exception as e:
        logger.error(f"Exception during inference: {e}")
        _shutdown_if_available()


# Script entry point: invoke the Hydra-decorated main() only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
