from sklearn.ensemble import RandomForestRegressor
from omegaconf import DictConfig, OmegaConf
from haipr.predictor import BasePredictor
from haipr.data import HAIPRData
from haipr.utils import compute_regression_metrics
from typing import Dict, Any, List, Tuple, Union
import numpy as np
import logging
import torch

logger = logging.getLogger(__name__)


class RFPredictor(BasePredictor):
    """
    A predictor class for Random Forest Regression models.

    Wraps a scikit-learn ``RandomForestRegressor`` behind the BasePredictor
    interface: feature preparation, fitting with validation metrics,
    prediction on sequences or pre-computed features, and pickle-based
    persistence.
    """

    def __init__(
        self,
        params: DictConfig = None,
        n_estimators: int = 100,
        max_depth: int = None,
        min_samples_split: int = 2,
        min_samples_leaf: int = 1,
        max_features: Union[str, int, float, None] = "sqrt",
        **kwargs,
    ):
        """
        Initialize the RFPredictor.

        Args:
            params (DictConfig): Accepted for interface compatibility only;
                it is not applied to the model. Hyperparameters must be
                passed through the explicit keyword arguments below.
            n_estimators: Forwarded to ``RandomForestRegressor``.
            max_depth: Forwarded to ``RandomForestRegressor``.
            min_samples_split: Forwarded to ``RandomForestRegressor``.
            min_samples_leaf: Forwarded to ``RandomForestRegressor``.
            max_features: Forwarded to ``RandomForestRegressor``.
            **kwargs: Additional keyword arguments. They are accepted so that
                ``from_params`` can expand a wider config, but are ignored.
        """
        super().__init__()
        hyperparameters = {
            "n_estimators": n_estimators,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_samples_leaf": min_samples_leaf,
            "max_features": max_features,
        }
        self.model = RandomForestRegressor(**hyperparameters)
        if kwargs:
            # Make dropped configuration keys visible instead of silently
            # swallowing them.
            logger.debug(
                f"Ignoring unsupported RFPredictor arguments: {sorted(kwargs)}"
            )
        # Log the values the model was actually built with; ``params`` is
        # usually None here and never reaches the estimator.
        logger.info(
            f"Initialized Random Forest model with parameters: {hyperparameters}"
        )

    @classmethod
    def from_params(cls, params: DictConfig):
        """
        Build a predictor by expanding ``params`` into keyword arguments.

        Args:
            params (DictConfig): Mapping of constructor arguments. ``None``
                yields a predictor with default hyperparameters.

        Returns:
            RFPredictor: The constructed predictor.
        """
        if params is None:
            return cls()
        return cls(**params)

    @staticmethod
    def _to_numpy(values: Any) -> Any:
        """Return ``values`` as a NumPy array if it is a tensor, else unchanged."""
        if isinstance(values, torch.Tensor):
            # Tensor.numpy() raises for tensors that require grad or live on a
            # non-CPU device, so detach and move to CPU first.
            return values.detach().cpu().numpy()
        return values

    def setup_model(self, data: HAIPRData = None, cfg: DictConfig = None):
        """
        Sets up the data and embedder for the predictor.

        Args:
            data (HAIPRData): The dataset to be used. When it reports
                ``features_loaded``, embedder initialization is skipped.
            cfg (DictConfig): Configuration containing embedder information
                (read from ``cfg.embedder`` when present).
        """
        self.embedder_instance = None
        self.embedder_config = None
        # Keep an embedding manager that was injected before this call (e.g. by
        # the inference pipeline); only default it to None when absent.
        self.embedding_manager = getattr(self, "embedding_manager", None)
        if data is not None:
            self.data = data
            if self.data.features_loaded:
                logger.info(
                    "Features already loaded. Skipping embedder initialization."
                )
                return
        if cfg and hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
            # Only initialize local embedder if no embedding manager is available
            if self.embedding_manager is None:
                self._initialize_embedder()
        logger.info("Data setup for RFPredictor complete.")

    def load_model(self, model_path: str):
        """
        Load a pickled Random Forest model written by ``save_model``.

        Args:
            model_path: Path to the pickle file, or to a directory containing
                ``model.pkl``. If nothing loadable is found there, a warning
                is logged and the current model is kept.
        """
        import os
        import pickle

        if not model_path or not isinstance(model_path, (str, os.PathLike)):
            logger.warning(
                f"No usable model path given ({model_path!r}); keeping current model."
            )
            return

        resolved_path = os.fspath(model_path)
        if os.path.isdir(resolved_path):
            # save_model() receives a directory and writes model.pkl inside it.
            resolved_path = os.path.join(resolved_path, "model.pkl")
        if not os.path.isfile(resolved_path):
            logger.warning(
                f"No Random Forest model found at {resolved_path}; keeping current model."
            )
            return

        # SECURITY: pickle can execute arbitrary code on load; only use this
        # with model files from a trusted source.
        with open(resolved_path, "rb") as f:
            self.model = pickle.load(f)
        logger.info(f"Loaded Random Forest model from {resolved_path}")

    def save_model(self, save_dir: str) -> str:
        """
        Save the Random Forest model to disk.

        Args:
            save_dir: Directory to write into; created if it does not exist.

        Returns:
            str: Path of the written ``model.pkl`` file.
        """
        import pickle
        import os

        os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, "model.pkl")
        with open(model_path, "wb") as f:
            pickle.dump(self.model, f)
        logger.info(f"Saved Random Forest model to {model_path}")
        return model_path

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, np.ndarray]:
        """
        Extract features and labels for ``indices`` from ``dataset``.

        Args:
            dataset (HAIPRData): Dataset indexed as ``dataset[indices]`` and
                expected to yield a ``(features, labels)`` pair.
            indices (np.ndarray): Indices of the samples to extract.

        Returns:
            Dict[str, np.ndarray]: ``{"features": X, "labels": y}`` with torch
            tensors converted to NumPy arrays.
        """
        X, y = dataset[indices]
        return {
            "features": self._to_numpy(X),
            "labels": self._to_numpy(y),
        }

    def prepare_batch_features(
        self, batch_items: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """For sklearn models, this is not used during training.
        Only used if someone manually calls it during inference.

        Args:
            batch_items: Items that each carry a ``"sequence"`` entry.

        Returns:
            Dict[str, Any]: ``{"features": embeddings}`` for the sequences.
        """
        sequences = [item["sequence"] for item in batch_items]
        embeddings = self._get_embeddings_for_sequences(sequences)
        return {"features": embeddings}

    def prepare_features(self, sequences: List[str]) -> np.ndarray:
        """
        Prepare the features for the Random Forest model using embeddings.

        Delegates to ``_prepare_features_from_sequences``, which is defined
        outside this class body (presumably on BasePredictor - confirm).

        Args:
            sequences: List of protein sequences to embed.

        Returns:
            np.ndarray: Embedded features for the sequences.
        """
        return self._prepare_features_from_sequences(sequences)

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: List[int],
        val_indices: List[int],
        trainer_instance: Any = None,
        cfg: DictConfig = None,
    ) -> Dict[str, Any]:
        """
        Fit the Random Forest model and evaluate it on the validation split.

        Args:
            dataset (HAIPRData): Dataset passed to ``setup_model`` when no
                data has been attached to the predictor yet.
            train_indices: Indices of the training samples.
            val_indices: Indices of the validation samples.
            trainer_instance: Unused by this predictor.
            cfg (DictConfig): Configuration forwarded to ``setup_model``.

        Returns:
            Dict[str, Any]: ``{"metrics": ..., "predictions": ...}`` where
            ``predictions`` holds the validation indices, the predicted values
            and the true values.
        """
        # Also run setup when the attribute exists but was never populated.
        if getattr(self, "data", None) is None:
            self.setup_model(dataset, cfg)

        X_train, y_train = self.data[train_indices]
        X_val, y_val = self.data[val_indices]

        # Convert each array independently: features and labels are not
        # guaranteed to share the same container type.
        X_train = self._to_numpy(X_train)
        y_train = self._to_numpy(y_train)
        X_val = self._to_numpy(X_val)
        y_val = self._to_numpy(y_val)

        logger.info(f"Training Random Forest model on {len(train_indices)} samples")
        self.model.fit(X_train, y_train)

        logger.info(f"Evaluating Random Forest model on {len(val_indices)} validation samples")
        y_pred = self.model.predict(X_val)

        metrics = compute_regression_metrics(y_val, y_pred)
        pred_dict = {
            "indices": val_indices,
            "predictions": np.asarray(y_pred).tolist(),
            # np.asarray keeps this working when labels arrive as a plain
            # sequence rather than an ndarray.
            "true_values": np.asarray(y_val).tolist(),
        }

        return {"metrics": metrics, "predictions": pred_dict}

    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None = None
    ) -> np.ndarray:
        """
        Make predictions on sequences (pyfunc interface).

        Args:
            sequences: Sequences to embed and score.
            params: Unused by this predictor.

        Returns:
            np.ndarray: One prediction per sequence; an empty array for empty
            input.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not hasattr(self.model, "estimators_"):
            raise RuntimeError("Model has not been trained yet.")

        logger.info(f"Making predictions with Random Forest model on {len(sequences)} sequences")

        if len(sequences) == 0:
            # scikit-learn rejects arrays with zero samples; nothing to predict.
            return np.empty(0)

        # Get embeddings using the unified method
        embeddings = self._get_embeddings_for_sequences(sequences)
        return self.model.predict(embeddings)

    def load_context(self, context):
        """
        Load Random Forest model from MLflow artifacts.

        Args:
            context: Object exposing an ``artifacts`` mapping; when it has a
                ``"model"`` entry, that path is unpickled into ``self.model``.
        """
        super().load_context(context)

        # Load Random Forest model from pickle
        if "model" in context.artifacts:
            import pickle

            model_path = context.artifacts["model"]
            # SECURITY: pickle can execute arbitrary code on load; artifacts
            # must come from a trusted source.
            with open(model_path, "rb") as f:
                self.model = pickle.load(f)
            logger.info(f"Loaded Random Forest model from {model_path}")

    def forward(self, batch: Any) -> Any:
        """
        Forward pass for the Random Forest model.
        For scikit-learn models, this typically means calling predict.

        Args:
            batch (Any): Input data. Expected to be a list of sequences or pre-computed features.

        Returns:
            Any: Predictions from the model.
        """
        # Handle both raw sequences and pre-computed features
        is_sequence_batch = (
            isinstance(batch, (list, tuple))
            and len(batch) > 0
            and isinstance(batch[0], str)
        )
        if is_sequence_batch:
            # batch is a list of sequences - use predict_sequences
            return self.predict_sequences(list(batch), params=None)
        # batch is pre-computed features - call model directly
        return self.model.predict(self._to_numpy(batch))

