import torch.nn as nn
import torch
from omegaconf import DictConfig, OmegaConf
from haipr.models.module import HAIPRModule
from haipr.data import HAIPRData
from haipr.utils import loss_funcs
from typing import Dict, Any, List, Optional
import numpy as np
import logging
import lightning.pytorch as pl

logger = logging.getLogger(__name__)


class MLP(nn.Module):
    """
    A simple Multi-layer Perceptron model.

    The layer stack is, in order: ``Linear(input_dim, hidden_dim)``,
    ``LayerNorm(hidden_dim)``, then ``num_layers - 1`` repetitions of
    ``Linear(hidden_dim, hidden_dim)`` -> ``Dropout`` -> activation, and a
    final ``Linear(hidden_dim, output_dim)``.

    Args:
        input_dim: Size of the feature dimension fed to the first linear layer.
        output_dim: Size of the model output.
        num_layers: Depth setting; ``num_layers - 1`` hidden blocks are placed
            between the input and output projections.
        hidden_dim: Width of every hidden layer.
        dropout: Dropout probability used in each hidden block.
        activation: Activation module class (instantiated here) or an
            already-built module instance.
        layer_norm: If True, ``forward`` standardises each input row with
            ``normalize`` before the first layer. The ``nn.LayerNorm`` module
            after the first linear layer is part of the stack regardless of
            this flag.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_layers: int,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
        activation=nn.ReLU,
        layer_norm: bool = True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.layer_norm = layer_norm
        # Accept either an activation class (instantiate it) or a ready instance.
        self.activation = activation() if isinstance(activation, type) else activation

        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        # LayerNorm rescales each hidden vector to zero mean / unit variance
        # (then applies a learned affine transform); it does not bound values
        # to the [0, 1] range.
        self.layers.append(nn.LayerNorm(hidden_dim))

        # The same activation module instance is appended in every hidden block.
        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(self.activation)

        self.layers.append(nn.Linear(hidden_dim, output_dim))

    def normalize(self, x, eps: float = 1e-8):
        """
        Standardise ``x`` along ``dim=1`` (zero mean, unit std per row).

        Args:
            x: Input tensor; statistics are reduced over ``dim=1``.
            eps: Lower bound applied to the standard deviation so that rows
                without spread do not cause a division by zero.

        Returns:
            Tensor of the same shape as ``x``. Rows whose std is zero (all
            values equal) or undefined (a single value) come out as zeros
            instead of NaN.
        """
        centered = x - x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        # std is 0 for constant rows and NaN when dim 1 has a single element;
        # both would otherwise turn the whole forward pass into NaN.
        safe_std = torch.nan_to_num(std, nan=0.0).clamp_min(eps)
        return centered / safe_std

    def forward(self, x):
        """Optionally standardise the input, then apply every layer in order."""
        if self.layer_norm:
            # normalize the input horizontally
            # (all values in the same row go to 0 mean and 1 std)
            x = self.normalize(x)

        for layer in self.layers:
            x = layer(x)
        return x


class MLPPredictor(HAIPRModule):
    """
    A predictor class for Multi-layer Perceptron (MLP) models.

    This class encapsulates the training, prediction, and evaluation
    of MLP models using PyTorch Lightning.
    """

    def __init__(
        self,
        input_dim: Optional[int] = None,
        output_dim: int = 1,
        num_layers: int = 2,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
        num_classes: int = 0,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 32,
        loss: str = "mse",
        **kwargs,
    ):
        """
        Initialize the MLPPredictor.

        Args:
            input_dim: Input dimension. When None, a placeholder of 1280 is
                used to build the initial network; ``setup_model`` may later
                rebuild it with a dimension inferred from the data.
            output_dim: Output dimension
            num_layers: Depth setting forwarded to ``MLP``
            hidden_dim: Hidden dimension size
            dropout: Dropout rate
            num_classes: Number of classes (0 for regression)
            learning_rate: Learning rate
            weight_decay: Weight decay
            batch_size: Batch size
            loss: Loss function name, looked up in ``loss_funcs``; an unknown
                name falls back to ``nn.MSELoss`` with a warning.
            **kwargs: Additional keyword arguments forwarded to the base class.
        """
        # Architecture settings are kept so setup_model can rebuild the network.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout

        # Initialize MLP model (will be set up properly in setup_model)
        mlp_model = MLP(
            input_dim=input_dim or 1280,  # Temporary, will be updated in setup_model
            output_dim=output_dim,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        # Look the loss up with a None default so an unknown name is actually
        # detected; a non-None default would make the warning below unreachable.
        criterion = loss_funcs.get(loss, None)
        if criterion is None:
            logger.warning(f"Loss function '{loss}' not found. Defaulting to MSELoss.")
            criterion = nn.MSELoss()

        # Initialize HAIPRModule
        super().__init__(
            model=mlp_model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            batch_size=batch_size,
            **kwargs,
        )

        # Embedding sources are populated later (setup_model / load_context /
        # external assignment); until then inference has nothing to embed with.
        self.embedder_instance = None
        self.embedder_config = None
        self.embedding_manager = None
        logger.info("Initialized MLP predictor")

    @classmethod
    def from_params(cls, params: DictConfig):
        """Build a predictor from a config whose keys match ``__init__`` arguments."""
        # Resolve interpolations and convert to plain containers before unpacking.
        init_kwargs = OmegaConf.to_container(params, resolve=True)
        return cls(**init_kwargs)

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """
        Attach the dataset and config, and finish model/embedder setup.

        If features are loaded and no input dimension was given, the dimension
        is inferred from the first sample and the MLP is rebuilt with it. If
        the config carries an ``embedder`` section, it is stored and a local
        embedder is initialised unless an embedding manager is already set.

        Args:
            data (HAIPRData): The dataset to be used.
            cfg (DictConfig): Configuration containing embedder information.
        """
        self.data, self.cfg = data, cfg

        if self.data.features_loaded and self.input_dim is None:
            # Look at a one-element slice to discover the feature width.
            sample, _ = self.data[0:1]
            if isinstance(sample, torch.Tensor):
                inferred = int(sample.shape[-1])
            elif isinstance(sample, np.ndarray):
                inferred = int(sample.shape[-1] if sample.ndim > 1 else len(sample))
            elif hasattr(sample, '__len__'):
                # Fallback for other sized containers
                inferred = len(sample)
            else:
                inferred = None
            self.input_dim = inferred
            logger.info(f"Inferred input_dim={self.input_dim} from data")

            if self.input_dim is not None:
                # Replace the placeholder network with one of the right width.
                rebuilt = MLP(
                    input_dim=self.input_dim,
                    output_dim=self.output_dim,
                    num_layers=self.num_layers,
                    hidden_dim=self.hidden_dim,
                    dropout=self.dropout,
                )
                self.model = rebuilt.to(self.device)
                logger.info("MLP model reinitialized with inferred input_dim")

        if cfg and hasattr(cfg, "embedder"):
            self.embedder_config = cfg.embedder
            # A pre-set embedding manager takes precedence over a local embedder.
            if getattr(self, "embedding_manager", None) is None:
                self._initialize_embedder()

        logger.info("Data setup for MLPPredictor complete.")

    def _initialize_embedder(self):
        """
        Initialize the embedder instance from ``self.embedder_config``.

        Does nothing (beyond a warning) when no embedder configuration is set.

        Raises:
            NotImplementedError: If the configured embedder name is not
                ``"protenc"``.
            ValueError: If importing or constructing the protenc encoder
                fails; the original exception is attached as ``__cause__``.
        """
        if not self.embedder_config:
            logger.warning(
                "No embedder configuration found. Cannot initialize embedder."
            )
            return

        if self.embedder_config.name != "protenc":
            raise NotImplementedError(
                f"Embedder '{self.embedder_config.name}' not supported. Only 'protenc' is implemented."
            )

        try:
            # Imported lazily so the module can be used without protenc installed.
            import protenc

            model_name = self.embedder_config.model
            batch_size = getattr(self.embedder_config, "batch_size", 32)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            data_parallel = getattr(self.embedder_config, "data_parallel", False)

            logger.info(f"Initializing protenc embedder: {model_name}")

            self.embedder_instance = protenc.get_encoder(
                model_name,
                device=device,
                batch_size=batch_size,
                data_parallel=data_parallel,
            )

            logger.info("Protenc embedder initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize protenc embedder: {e}")
            # Chain the cause so the original traceback is preserved for callers.
            raise ValueError(f"Could not initialize protenc embedder: {e}") from e

    def load_model(self, model: str):
        """Restore the MLP weights from the state-dict file at path ``model``."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(model, map_location=self.device)
        self.model.load_state_dict(state_dict)
        logger.info(f"Loaded MLP model from {model}")

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, torch.Tensor]:
        """
        Fetch features and labels for ``indices`` and move them to the model device.

        Returns:
            A dict with ``"features"`` and ``"labels"`` tensors. For regression
            (``num_classes == 0``) one-dimensional labels are given a trailing
            axis.
        """

        def _on_device(value):
            # Tensors keep their dtype; anything else is converted to float32.
            if isinstance(value, torch.Tensor):
                return value.to(self.device)
            return torch.tensor(value, dtype=torch.float32).to(self.device)

        raw_features, raw_labels = dataset[indices]
        features = _on_device(raw_features)
        labels = _on_device(raw_labels)

        # Ensure labels are 2D for regression
        is_regression = self.num_classes == 0
        if is_regression and labels.dim() == 1:
            labels = labels.unsqueeze(1)

        return {"features": features, "labels": labels}

    def prepare_batch_features(
        self, batch_items: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Embed the ``"sequence"`` of every batch item and wrap the result for ``forward``."""
        seqs = []
        for entry in batch_items:
            seqs.append(entry["sequence"])

        embedded = self._get_embeddings_for_sequences(seqs)
        feature_tensor = torch.tensor(embedded, dtype=torch.float32)
        feature_tensor = feature_tensor.to(self.device)
        return {"inputs": {"features": feature_tensor}}

    def _get_embeddings_for_sequences(self, sequences: List[str]) -> np.ndarray:
        """
        Embed ``sequences`` with whichever embedding source is available.

        An embedding manager, when set, is used directly. Otherwise the local
        embedder is used, being initialised on demand from the embedder config.

        Raises:
            RuntimeError: If neither an embedding manager, an embedder
                instance, nor an embedder config is available.
        """
        manager = getattr(self, "embedding_manager", None)
        if manager is not None:
            logger.debug(f"Using EmbeddingManager for {len(sequences)} sequences")
            return manager.get_embeddings(sequences)

        if getattr(self, "embedder_instance", None) is None:
            if not getattr(self, "embedder_config", None):
                raise RuntimeError(
                    "No embedding source available. Need either embedding_manager "
                    "(from inference) or embedder_config (for local embedder)."
                )
            # Lazily build the local embedder from its configuration.
            self._initialize_embedder()

        logger.debug(f"Using local embedder for {len(sequences)} sequences")
        average_embeddings = getattr(self.embedder_config, "average_sequence", True)

        embeddings_list = []
        if self.embedder_instance is not None:
            # The embedder yields its outputs; collect all of them before stacking.
            embeddings_list.extend(
                self.embedder_instance(
                    sequences,
                    average_sequence=average_embeddings,
                    return_format="numpy",
                )
            )

        return np.vstack(embeddings_list)

    def forward(self, batch: Dict[str, Any]):
        """
        Forward pass of the model.

        Args:
            batch: One of
                - a dict with ``batch["inputs"]["features"]``,
                - a dict with a top-level ``"features"`` entry,
                - a feature tensor passed directly.

        Returns:
            The output of the wrapped MLP for the selected features.
        """
        if isinstance(batch, torch.Tensor):
            # Direct feature input: a tensor supports neither the string
            # membership test nor ``.get`` used for dict batches below.
            features = batch
        elif "inputs" in batch:
            features = batch["inputs"]["features"]
        else:
            # Dict without an "inputs" wrapper; falls back to the batch itself
            # when no "features" key is present.
            features = batch.get("features", batch)

        return self.model(features)

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: Any,
        val_indices: Any,
        trainer_instance: Optional[pl.Trainer] = None,
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """
        Train the MLP on features taken from ``dataset`` and report validation results.

        Args:
            dataset: Source of features and labels.
            train_indices: Dataset indices used for training.
            val_indices: Dataset indices used for validation.
            trainer_instance: Trainer to use; a 10-epoch trainer is created
                when omitted.
            cfg: Not used by this method.

        Returns:
            A dict with ``"metrics"`` and ``"predictions"``; the latter holds
            the validation indices, predictions, true values and, when
            present, probabilities.

        Raises:
            RuntimeError: If no validation predictions were recorded.
        """
        trainer = (
            trainer_instance
            if trainer_instance is not None
            else pl.Trainer(max_epochs=10)
        )

        self.data = dataset

        # Features are prepared once for train + val; the loaders then address
        # them by position inside the concatenated array.
        all_indices = np.concatenate([train_indices, val_indices])
        features_dict = self.prepare_training_features(dataset, all_indices)

        n_train = len(train_indices)
        n_total = len(all_indices)
        train_loader, val_loader = self._create_dataloaders(
            features_dict=features_dict,
            labels=features_dict["labels"],
            train_indices=np.arange(n_train),
            val_indices=np.arange(n_train, n_total),
            batch_size=self.batch_size,
            shuffle_train=True,
        )

        trainer.fit(self, train_loader, val_loader)

        predictions = self.best_val_predictions
        metrics = self.best_val_metrics
        if predictions is None:
            raise RuntimeError("No predictions available after training")

        # Report the caller's original validation indices, not loader positions.
        if hasattr(val_indices, "tolist"):
            reported_indices = val_indices.tolist()
        else:
            reported_indices = list(val_indices)

        pred_dict = {
            "indices": reported_indices,
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }
        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()

        return {"metrics": metrics, "predictions": pred_dict}

    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Predict for raw sequences by delegating to the parent implementation."""
        # No MLP-specific logic here; the base class does the work.
        result = super().predict_sequences(sequences, params)
        return result

    def load_context(self, context):
        """Run the parent ``load_context``, then pick up the embedder config if one is present."""
        super().load_context(context)

        # Only act when a config exists on the instance and has an embedder section.
        has_embedder_cfg = hasattr(getattr(self, "cfg", None), "embedder")
        if not has_embedder_cfg:
            return

        self.embedder_config = self.cfg.embedder
        logger.info("Set embedder config from loaded config")

    def save_model(self, save_dir: str) -> str:
        """
        Write the MLP state dict to ``<save_dir>/model.pt``.

        The directory is created if it does not exist.

        Returns:
            The path of the written file.
        """
        import os

        os.makedirs(save_dir, exist_ok=True)
        target = os.path.join(save_dir, "model.pt")
        weights = self.model.state_dict()
        torch.save(weights, target)
        logger.info(f"Saved MLP model to {target}")
        return target

