import logging
import torch
import torch.nn as nn
import numpy as np
from omegaconf import DictConfig
from typing import Dict, Any, List, Optional, Union
import lightning.pytorch as pl
from torch.utils.data import DataLoader, Subset, Dataset

from haipr.predictor import BasePredictor
from haipr.models.module import HAIPRModule
from haipr.data import HAIPRData
from haipr.utils import loss_funcs

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class HAIPR(HAIPRModule, BasePredictor):
    """
    High-Throughput Affinity Prediction Model

    A generic predictor that can be configured to use different underlying models
    and prediction heads for protein affinity prediction tasks.

    Inference runs in two stages: ``self.model`` maps the prepared embeddings to
    an intermediate output, and ``self.prediction_head`` maps that output to the
    final predictions.
    """

    def __init__(
        self,
        cfg: DictConfig,
        model: Optional[nn.Module] = None,
        prediction_head: Optional[nn.Module] = None,
        num_classes: int = 0,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 1,
        loss_fn: str = "mse",
        **kwargs,
    ):
        """
        Initialize the HAIPR predictor.

        Args:
            cfg: Configuration dictionary
            model: PyTorch model (if None, will be created from config)
            prediction_head: Prediction head module (required)
            num_classes: Number of classes (0 for regression)
            learning_rate: Learning rate for optimization
            weight_decay: Weight decay for optimization
            batch_size: Batch size for training/inference
            loss_fn: Loss function name
            **kwargs: Additional arguments passed to HAIPRModule

        Raises:
            ValueError: If ``prediction_head`` is None.
        """
        # Fail fast: reject a missing head before building the model or
        # initialising the base classes.
        if prediction_head is None:
            raise ValueError("prediction_head is required for HAIPR")

        self.cfg = cfg

        # Initialize model if not provided
        if model is None:
            model = self._create_model_from_config(cfg)

        # Look up the loss by name, falling back to MSE for unknown names
        criterion = loss_funcs.get(loss_fn, nn.MSELoss())

        # Initialize HAIPRModule
        super().__init__(
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            batch_size=batch_size,
            **kwargs,
        )

        # Assigned after the base initialiser has run.
        self.prediction_head = prediction_head

        # Store additional attributes
        self.data = None
        self.model_name = getattr(cfg.model, 'name', 'haipr')
        self.loss_fn = loss_fn

        logger.info("Initialized HAIPR predictor with %s", self.model_name)

    def _create_model_from_config(self, cfg: DictConfig) -> Optional[nn.Module]:
        """
        Build the base model from the configuration.

        This is a placeholder: it builds nothing and returns None, so the
        constructor forwards ``model=None`` to the base class when no model is
        supplied.

        Args:
            cfg: Configuration dictionary

        Returns:
            None (no model is constructed here)
        """
        return None

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: np.ndarray,
        val_indices: np.ndarray,
        trainer_instance: Optional[pl.Trainer] = None,
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """
        Fit the model using PyTorch Lightning trainer.

        Args:
            dataset: HAIPRData instance
            train_indices: Training set indices
            val_indices: Validation set indices
            trainer_instance: PyTorch Lightning trainer (optional; a trainer
                with ``max_epochs=10`` is created when omitted)
            cfg: Configuration dictionary (optional; currently not used by
                this method)

        Returns:
            Dictionary with keys ``"metrics"`` and ``"predictions"`` taken from
            ``self.best_val_metrics`` and ``self.best_val_predictions``
        """
        if trainer_instance is None:
            trainer_instance = pl.Trainer(max_epochs=10)

        self.data = dataset

        # Prepare features once for all data (train rows first, then val rows)
        all_indices = np.concatenate([train_indices, val_indices])
        features_dict = self.prepare_training_features(dataset, all_indices)

        # Convert array-like values to tensors; existing tensors are untouched.
        # Only existing keys are reassigned, so iterating items() is safe.
        if isinstance(features_dict, dict):
            for key, value in features_dict.items():
                if isinstance(value, (np.ndarray, list)):
                    features_dict[key] = torch.tensor(value)

        # Features are laid out in all_indices order, so the loaders index by
        # position: the first len(train_indices) rows are training rows and
        # the remainder are validation rows.
        n_train = len(train_indices)
        train_loader, val_loader = self._create_dataloaders(
            features_dict=features_dict,
            labels=features_dict["labels"],
            train_indices=np.arange(n_train),
            val_indices=np.arange(n_train, len(all_indices)),
            batch_size=self.batch_size,
            shuffle_train=True,
        )

        # Train the model
        trainer_instance.fit(self, train_loader, val_loader)

        return {
            "metrics": self.best_val_metrics,
            "predictions": self.best_val_predictions,
        }

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare features for training.

        Dataset items may be dicts with ``"sequence"`` and ``"labels"`` keys,
        or ``(features, label)`` pairs holding precomputed features.

        Args:
            dataset: HAIPRData instance
            indices: Indices to prepare features for

        Returns:
            Dictionary of features ready for training, including a
            ``"labels"`` entry

        Raises:
            ValueError: If ``indices`` is empty.
        """
        if len(indices) == 0:
            raise ValueError("indices must contain at least one entry")

        # Collect model inputs and labels for the specified indices
        inputs = []
        labels = []
        for i in indices:
            item = dataset[i]
            if isinstance(item, dict):
                # For sequence-based features
                inputs.append(item["sequence"])
                labels.append(item["labels"])
            else:
                # For embedded features, item is a tuple (features, labels)
                embedded, label = item
                inputs.append(embedded)
                labels.append(label)

        # The type of the first input decides how the whole batch is handled
        if isinstance(inputs[0], str):
            # Sequence-based features
            batch_items = [{"sequence": seq} for seq in inputs]
            features = self.prepare_batch_features(batch_items)
        else:
            # Embedded features: as_tensor passes tensors through unchanged and
            # lets array-like features be stacked as well
            features = {
                "embeddings": torch.stack([torch.as_tensor(x) for x in inputs])
            }

        # Add labels: stack tensor labels, otherwise build a float32 tensor
        if isinstance(labels[0], torch.Tensor):
            features["labels"] = torch.stack(labels)
        else:
            features["labels"] = torch.tensor(labels, dtype=torch.float32)

        return features

    def prepare_batch_features(self, batch_items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Prepare features for a single batch during inference.

        This class defines no embedding step of its own, so this method must
        be overridden. ``predict_sequences`` reads the ``"embeddings"`` entry
        of the returned dictionary, so overrides should return the model
        inputs under that key.

        Args:
            batch_items: List of items containing sequences

        Returns:
            Dictionary of features ready for model prediction

        Raises:
            NotImplementedError: Always, until overridden.
        """
        raise NotImplementedError(
            "HAIPR.prepare_batch_features has no embedding step; override it "
            "to return {'embeddings': ...} for the given batch items"
        )

    def load_model(self, model: str) -> None:
        """
        Load the model from a checkpoint.

        Accepts a checkpoint dictionary holding the weights under
        ``"state_dict"``, or a bare state dict.

        Args:
            model: Path to the model checkpoint
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(model, map_location=self.device)
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
        self.load_state_dict(state_dict)
        logger.info("Loaded model from %s", model)

    def _load_artifacts(self, context) -> None:
        """
        Load model-specific artifacts from MLflow context.

        Args:
            context: MLflow context containing artifacts
        """
        # Load model weights if available
        if "model" in context.artifacts:
            self.load_model(context.artifacts["model"])

        logger.info("Loaded HAIPR artifacts from MLflow context")

    def predict_sequences(
        self, sequences: List[str], batch_size: int = 1, **kwargs
    ) -> Union[np.ndarray, Dict[str, Any]]:
        """
        Make predictions on sequences.

        Both the base model and the prediction head are switched to eval mode
        and stay in eval mode after the call.

        Args:
            sequences: List of protein sequences
            batch_size: Batch size for processing
            **kwargs: Additional arguments (currently unused)

        Returns:
            For regression (``num_classes == 0``), a flat numpy array of
            predictions. For classification, a dictionary with the raw
            ``"predictions"`` and their softmax ``"probabilities"``.
        """
        # Nothing to predict: return empty results rather than failing when
        # concatenating an empty list of batches.
        if len(sequences) == 0:
            if self.num_classes == 0:
                return np.empty((0,), dtype=np.float32)
            empty_scores = np.empty((0, self.num_classes), dtype=np.float32)
            return {
                "predictions": empty_scores,
                "probabilities": empty_scores.copy(),
            }

        # Put both stages in eval mode; the head is a separate module from
        # self.model and would otherwise keep its training-time behaviour.
        self.model.eval()
        self.prediction_head.eval()

        predictions_list = []

        with torch.no_grad():
            for start in range(0, len(sequences), batch_size):
                batch_sequences = sequences[start:start + batch_size]
                batch_items = [{"sequence": seq} for seq in batch_sequences]

                # Prepare features
                features = self.prepare_batch_features(batch_items)

                # Keep tensor inputs on the same device as the module
                embeddings = features["embeddings"]
                if isinstance(embeddings, torch.Tensor):
                    embeddings = embeddings.to(self.device)

                # Forward pass through base model, then the prediction head
                base_output = self.model(embeddings)
                predictions = self.prediction_head(base_output)

                predictions_list.append(predictions.cpu())

        # Concatenate all predictions
        predictions_tensor = torch.cat(predictions_list, dim=0)
        predictions_np = predictions_tensor.numpy()

        # Return appropriate format based on task type
        if self.num_classes == 0:  # Regression
            return predictions_np.flatten()

        # Classification
        return {
            "predictions": predictions_np,
            "probabilities": torch.softmax(predictions_tensor, dim=-1).numpy(),
        }