from typing import List
import hashlib
import numpy as np
import torch
from omegaconf import DictConfig
import logging
import hydra

# Module logger; all cache/embedding diagnostics are emitted under the
# "haipr.cache" logger name.
logger = logging.getLogger(name="haipr.cache")


class EmbeddingManager:
    """
    Centralized embedding manager for inference pipeline.

    Computes embeddings through a single ``protenc`` encoder and keeps an
    in-memory cache of the results, so several evaluators requesting the
    same batch of sequences share one computation.

    Attributes:
        cfg: The configuration passed at construction.
        embedder_instance: Encoder returned by ``protenc.get_encoder``, or
            ``None`` when no embedder has been initialized.
        embedder_config: The ``cfg.embedder`` node, or ``None``.
        cache: Mapping from cache key to the stacked embedding array, kept in
            insertion order (oldest first) for FIFO eviction.
        max_cache_size: Maximum number of cached entries
            (``cfg.embedding_cache_size``); ``None`` disables eviction.
        instance_id: ``id(self)``, used to tag log messages.
    """

    def __init__(self, cfg: DictConfig):
        """
        Create the manager and, if ``cfg`` has an ``embedder`` entry,
        initialize the encoder immediately.

        Args:
            cfg: Configuration; must provide ``embedding_cache_size`` and may
                provide an ``embedder`` node.

        Raises:
            NotImplementedError: If the configured embedder is not "protenc".
            ValueError: If the protenc encoder cannot be created.
        """
        self.cfg = cfg
        self.embedder_instance = None
        self.embedder_config = None
        self.cache: dict[str, np.ndarray] = {}
        # NOTE(review): this value was previously documented as "Gb depending
        # on device RAM", but _manage_cache_size compares it against the
        # number of cached entries, not their memory footprint. Confirm the
        # intended unit of cfg.embedding_cache_size.
        self.max_cache_size = cfg.embedding_cache_size
        self.instance_id = id(self)
        logger.info(f"Created CacheManager instance {self.instance_id}")

        # Initialize embedder if available in config
        if hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
            self._initialize_embedder()

    def _initialize_embedder(self):
        """
        Initialize the embedder instance from ``self.embedder_config``.

        Uses CUDA when ``torch.cuda.is_available()``, otherwise CPU. Reads
        ``model`` (required), ``batch_size`` (default 32) and ``ddp``
        (default False, forwarded as ``data_parallel``) from the config.

        Raises:
            NotImplementedError: If the configured embedder name is not
                "protenc".
            ValueError: If importing protenc or creating the encoder fails;
                the original exception is chained as the cause.
        """
        if not self.embedder_config:
            logger.warning("No embedder configuration found.")
            return

        if self.embedder_config.name != "protenc":
            raise NotImplementedError(
                f"Embedder '{self.embedder_config.name}' not supported. Only 'protenc' is implemented."
            )

        try:
            import protenc

            model_name = self.embedder_config.model
            batch_size = getattr(self.embedder_config, "batch_size", 32)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            data_parallel = getattr(self.embedder_config, "ddp", False)

            logger.info(f"Initializing protenc embedder: {model_name}")

            self.embedder_instance = protenc.get_encoder(
                model_name,
                device=device,
                batch_size=batch_size,
                data_parallel=data_parallel,
            )

            logger.info("Protenc embedder initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize protenc embedder: {e}")
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Could not initialize protenc embedder: {e}") from e

    def _get_cache_key(self, sequences: List[str]) -> str:
        """
        Build a cache key for ``sequences`` under the current embedder config.

        The key depends on the ORDER of ``sequences``: the cached value is an
        array whose rows follow the request order, so an order-insensitive
        key would hand a permuted request rows that do not match its inputs.
        Every field is length-prefixed before hashing, so no two different
        inputs can collide through a separator character (sequences may
        themselves contain "|").

        Args:
            sequences: Sequences in the order the caller expects results.

        Returns:
            Hex digest identifying (config, sequences).
        """
        seq_list = list(sequences)
        logger.debug(
            f"Generating cache key for sequences s[:10]: {[s[:10] for s in seq_list[:3]]}... (showing first 3)"
        )
        if self.embedder_config:
            config_items = [
                ("model", getattr(self.embedder_config, "model", "")),
                (
                    "average_sequence",
                    getattr(self.embedder_config, "average_sequence", True),
                ),
            ]
            config_str = str(sorted(config_items))
        else:
            config_str = ""

        digest = hashlib.sha256()

        def _feed(chunk: bytes) -> None:
            # Length-prefix each field so the concatenation parses uniquely.
            digest.update(len(chunk).to_bytes(8, "little"))
            digest.update(chunk)

        # Config goes first; every remaining frame is one sequence.
        _feed(config_str.encode("utf-8"))
        for seq in seq_list:
            _feed(seq.encode("utf-8"))
        return digest.hexdigest()

    def _manage_cache_size(self):
        """
        Evict the oldest entries until at most ``max_cache_size`` remain.

        Eviction is FIFO: dicts preserve insertion order and cache hits do not
        reorder entries. A ``max_cache_size`` of ``None`` disables eviction.
        """
        if self.max_cache_size is None:
            return
        removed = 0
        # Pop from the front (oldest insertion) until within the limit; the
        # emptiness guard keeps a non-positive limit from looping forever.
        while self.cache and len(self.cache) > self.max_cache_size:
            del self.cache[next(iter(self.cache))]
            removed += 1
        if removed:
            logger.debug(
                f"Removed {removed} old cache entries to manage memory"
            )

    def get_embeddings(self, sequences: List[str]) -> np.ndarray:
        """
        Get embeddings for sequences, using cache if available.

        Results are cached under a key derived from the sequences (in order)
        and the embedder's ``model`` / ``average_sequence`` settings.

        Args:
            sequences: List of protein sequences to embed. When the configured
                model name starts with "esm2", every "|" is replaced with
                "<sep>" before encoding.

        Returns:
            np.ndarray: The encoder's outputs stacked with ``np.vstack``, in
            the order the encoder yields them (assumed to match
            ``sequences`` — depends on protenc). The returned array is the
            same object stored in the cache; do not modify it in place.

        Raises:
            RuntimeError: If no embedder has been initialized, or if encoding
                or stacking fails (the original error is chained as cause).
        """
        if self.embedder_instance is None:
            raise RuntimeError(
                "No embedder available. Please ensure embedder is configured."
            )

        # Check cache first
        cache_key = self._get_cache_key(sequences)
        logger.debug(
            f"CacheManager {self.instance_id}: Cache key for {len(sequences)} sequences: {cache_key[:16]}..."
        )

        cached = self.cache.get(cache_key)
        if cached is not None:
            logger.debug(
                f"CacheManager {self.instance_id}: CACHE HIT - Loading embeddings from cache for {len(sequences)} sequences"
            )
            return cached

        logger.debug(
            f"CacheManager {self.instance_id}: CACHE MISS - Computing embeddings for {len(sequences)} sequences"
        )

        try:
            average_embeddings = getattr(
                self.embedder_config, "average_sequence", True
            )
            # esm2 models: map "|" to the "<sep>" token (assumed to be the
            # chain separator these models expect — confirm).
            if self.embedder_config.model.startswith("esm2"):
                sequences = [s.replace("|", "<sep>") for s in sequences]

            embeddings_list = list(
                self.embedder_instance(
                    sequences,
                    average_sequence=average_embeddings,
                    return_format="numpy",
                )
            )
            # Stack all embeddings into a single array
            features = np.vstack(embeddings_list)
        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            raise RuntimeError(f"Failed to generate embeddings: {e}") from e

        logger.debug(f"Generated features with shape: {features.shape}")

        # Cache the results, then trim the oldest entries if over the limit.
        self.cache[cache_key] = features
        self._manage_cache_size()
        logger.debug(
            f"CacheManager {self.instance_id}: CACHED embeddings for {len(sequences)} sequences (total cache entries: {len(self.cache)})"
        )

        return features

    def clear_cache(self):
        """Remove every cached embedding and log how many entries were dropped."""
        cache_size = len(self.cache)
        self.cache.clear()
        logger.info(f"Cleared embedding cache ({cache_size} entries)")

    def get_cache_stats(self):
        """
        Get statistics about the embedding cache.

        Returns:
            dict with ``total_entries`` (number of cached arrays),
            ``total_memory_gb`` (sum of ``nbytes`` over cached arrays, in GiB)
            and ``max_cache_size`` (the configured limit, as stored).
        """
        total_entries = len(self.cache)
        total_memory_gb = sum(arr.nbytes for arr in self.cache.values()) / (
            1024 * 1024 * 1024
        )
        return {
            "total_entries": total_entries,
            "total_memory_gb": total_memory_gb,
            "max_cache_size": self.max_cache_size,
        }
