import types
from typing import Any, Dict, Optional

import torch
from omegaconf import DictConfig
from torch.nn import MSELoss

from haipr.models.module import HAIPRModule
from haipr.predictor import BasePredictor
from haipr.data import HAIPRData

# Import ProteinNPTModel and Alphabet utilities
from haipr.models.ProteinNPT.proteinnpt.proteinnpt.model import ProteinNPTModel
from haipr.models.ProteinNPT.proteinnpt.utils.esm.data import Alphabet


class ProteinNPTPredictor(HAIPRModule, BasePredictor):
    """
    HAIPR Predictor for the ProteinNPT (Neural Process Transformer) model.
    Inherits from HAIPRModule and BasePredictor for full integration.

    The wrapped ``ProteinNPTModel`` is built from ``cfg.model`` either in the
    constructor (when ``cfg`` is given) or later in ``setup_data`` (from
    ``data.config``). Training and prediction are still placeholders; see the
    TODO notes in ``fit_model`` and ``predict``.
    """

    def __init__(
        self,
        cfg: Optional[DictConfig] = None,
        criterion=None,
        num_classes: int = 0,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        **kwargs,
    ):
        """Create the predictor and, if ``cfg`` is given, build the model.

        Args:
            cfg: Configuration whose ``model`` node may provide ``arch`` and
                ``params``. When ``None``, no model is built here and
                ``setup_data`` builds it from the dataset's config instead.
            criterion: Loss passed to the parent constructor. Defaults to
                ``MSELoss()`` when ``None``.
            num_classes: Forwarded unchanged to the parent constructor.
            learning_rate: Forwarded unchanged to the parent constructor.
            weight_decay: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        self.cfg = cfg
        self.data = None
        self.alphabet = None
        model = None
        if cfg is not None:
            self.alphabet, model = self._build_proteinnpt(cfg)
        # Default to MSELoss for regression if not provided
        if criterion is None:
            criterion = MSELoss()
        super().__init__(
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            **kwargs,
        )

    @staticmethod
    def _build_proteinnpt(cfg: DictConfig) -> tuple:
        """Build the alphabet and ProteinNPT model described by ``cfg.model``.

        Args:
            cfg: Configuration with a ``model`` node. ``model.arch`` selects
                the alphabet (``"ESM-1b"`` when absent or ``None``);
                ``model.params`` supplies the model arguments (empty when
                absent or ``None``).

        Returns:
            An ``(alphabet, model)`` pair.
        """
        model_cfg = cfg.model
        # Treat None like "missing": some config containers return None for
        # absent keys instead of raising AttributeError.
        arch = getattr(model_cfg, "arch", None)
        if arch is None:
            arch = "ESM-1b"
        alphabet = Alphabet.from_architecture(arch)

        params = getattr(model_cfg, "params", None)
        # Copy into a plain dict so the caller's config is never mutated
        # (and so read-only / struct-mode configs do not raise on assignment).
        model_args = dict(params) if params is not None else {}
        model_args["model_type"] = "ProteinNPT"

        args_ns = types.SimpleNamespace(**model_args)
        model = ProteinNPTModel(args=args_ns, alphabet=alphabet)
        return alphabet, model

    def setup_data(self, data: HAIPRData) -> None:
        """Store the HAIPRData and build the model from it if none exists yet.

        When no config was given at construction, ``data.config`` is adopted
        and used to build the alphabet and model. An existing config is left
        untouched.

        Args:
            data: Dataset to attach; its ``config`` attribute is read only
                when ``self.cfg`` is ``None``.
        """
        self.data = data
        if self.cfg is not None:
            return
        cfg = data.config
        # Build first and commit afterwards: if construction fails, cfg stays
        # None so a later call can retry instead of silently skipping the build.
        alphabet, model = self._build_proteinnpt(cfg)
        self.cfg = cfg
        self.alphabet = alphabet
        self.model = model

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: Any,
        val_indices: Any,
        trainer_instance: Optional[object] = None,
    ) -> Dict[str, Any]:
        """
        Train the ProteinNPT model on the given dataset and indices.
        Returns metrics and predictions for the validation set.

        NOTE: training is not implemented yet. This currently only attaches
        the dataset, marks the predictor as trained, and returns empty
        ``metrics`` and ``predictions`` dicts; ``train_indices``,
        ``val_indices`` and ``trainer_instance`` are unused.
        """
        self.setup_data(dataset)
        # TODO: Integrate with Lightning Trainer for actual training
        self.trained = True
        return {"metrics": {}, "predictions": {}}

    def load_model(self, model_path: str) -> None:
        """Load a trained ProteinNPT model from disk.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load`` that
                contains a ``"model_state_dict"`` entry.

        Raises:
            RuntimeError: If no model has been built yet (no ``cfg`` was given
                to the constructor and ``setup_data`` has not been called).
        """
        if getattr(self, "model", None) is None:
            raise RuntimeError(
                "Model has not been built; pass cfg to the constructor or "
                "call setup_data() before load_model()."
            )
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.trained = True

    def predict(self, data: Any) -> Dict[str, Any]:
        """Make predictions on given data (expects HAIPRData or compatible batch).

        NOTE: prediction is not implemented yet; an empty dict is returned.

        Raises:
            RuntimeError: If the model has been neither trained nor loaded.
        """
        if not getattr(self, "trained", False):
            raise RuntimeError("Model must be trained or loaded before prediction.")
        # TODO: Implement batching and data processing for prediction
        return {}
