"""MPNN-based predictor.

This class follows the same structure as the other predictors (ESM2/ESMC/ESM3)
and implements the BasePredictor interface on top of the HAIPR Lightning module.
The ProteinMPNN backbone and its featurization helpers (`parse_PDB`,
`tied_featurize`) are imported from `haipr.models.protein_mpnn_utils`; see also
`benchmark/BindingGYM/baselines/protein_mpnn/protein_mpnn_utils.py`.

Note: `ProteinMPNN` and `peft` are imported at module level, so both must be
installed for this module to import. `parse_PDB` and `tied_featurize` are
imported lazily inside the methods that use them.
"""

from __future__ import annotations

import logging
import json
import copy
from typing import Any, Dict, Optional, List, Tuple

from peft import LoraConfig, get_peft_model
import os
import torch
import torch.nn as nn
import lightning.pytorch as pl

from omegaconf import DictConfig
from haipr.models.protein_mpnn_utils import ProteinMPNN
from haipr.predictor import BasePredictor
from haipr.models.module import HAIPRModule
from haipr.data import HAIPRData
from haipr.utils import loss_funcs
from torch.utils.data import Subset, DataLoader, Dataset

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class MPNNDataset(Dataset):
    """Minimal map-style dataset over raw sequences.

    ``labels`` and ``batch_size`` are kept as attributes but are not used by
    this class: indexing returns only the entry of ``sequences`` at ``idx``.
    """

    def __init__(self, sequences: List[str], labels: List[float], batch_size: int):
        self.sequences, self.labels, self.batch_size = sequences, labels, batch_size

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        return sequence


class MPNNPredictor(HAIPRModule, BasePredictor):
    """ProteinMPNN-backed predictor built on the HAIPR Lightning module.

    The backbone is a ``ProteinMPNN`` built from the ``mpnn`` keyword mapping
    and optionally initialised from ``model_name_or_path``. Per-residue
    embeddings returned by the backbone are mean-pooled over valid positions
    and mapped to the target by ``prediction_head``. Structure comes from the
    PDB resolved in ``setup_model``; ``prepare_features`` threads each item's
    sequence onto that structure before featurization.
    """

    def __init__(
        self,
        mpnn,
        name: str = "MPNNPredictor",
        model_name: str = "protein_mpnn",
        model_name_or_path: str = "<path_to_mpnn_model>",
        num_classes: int = 0,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 1,
        loss_fn: str = "mse",
        **kwargs: Any,
    ) -> None:
        """Build the ProteinMPNN backbone, optionally load weights, attach the head.

        Args:
            mpnn: Keyword mapping forwarded to ``ProteinMPNN(**mpnn)``.
            name: Accepted for interface compatibility; not used by this constructor.
            model_name: Identifier stored on the instance and logged.
            model_name_or_path: Path to pretrained backbone weights. Empty strings
                and the placeholders ``"???"`` / ``"<path_to_mpnn_model>"`` skip loading.
            num_classes: ``0`` for regression, otherwise the number of classes.
            prediction_head: Module mapping pooled embeddings to outputs. Required.
            learning_rate: Forwarded to ``HAIPRModule``.
            weight_decay: Forwarded to ``HAIPRModule``.
            batch_size: Batch size used by the loaders built in ``fit_model``.
            loss_fn: Key into ``loss_funcs``; unknown keys fall back to ``nn.MSELoss``.
            **kwargs: Forwarded to ``HAIPRModule``.

        Raises:
            ValueError: If ``prediction_head`` is ``None``.
        """
        # Fail fast, before building the backbone and reading weights from disk.
        if prediction_head is None:
            raise ValueError("Need Prediction Head")

        # Loss selection consistent with other predictors
        criterion = loss_funcs.get(loss_fn)
        if criterion is None:
            logger.warning(
                "Unknown loss_fn '%s'. Falling back to MSELoss for MPNN.", loss_fn
            )
            criterion = nn.MSELoss()

        # Plain (non-Module, non-Parameter) attributes can be set before
        # nn.Module initialisation.
        self.model_name = model_name
        self.model_name_or_path = model_name_or_path
        self.batch_size = batch_size

        model = ProteinMPNN(**mpnn)
        self._load_backbone_weights(model, model_name_or_path)

        super().__init__(
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            **kwargs,
        )
        self.prediction_head = prediction_head

        # Runtime state populated by setup_model / fit_model.
        self.data = None
        self.cfg = None
        self.pdb_path = None
        self._pdb_template = None
        # CA-only parsing/featurization; set from cfg.model.ca_only in setup_model.
        self._ca_only = False

        logger.info(
            "Initialized MPNNPredictor(model_name=%s, num_classes=%d, batch_size=%d)",
            model_name,
            num_classes,
            batch_size,
        )

    @staticmethod
    def _load_backbone_weights(model: nn.Module, path: Any) -> None:
        """Best-effort load of pretrained weights from ``path`` into ``model``.

        Empty or placeholder paths are skipped silently. A path that does not
        exist is skipped with a warning. Load failures are logged, never raised,
        so construction still succeeds (with randomly initialised weights).
        """
        if not isinstance(path, str) or not path or path in {"???", "<path_to_mpnn_model>"}:
            return
        if not os.path.exists(path):
            logger.warning(
                "MPNN weights path %s does not exist; using randomly initialised weights.",
                path,
            )
            return
        try:
            # "gpu" is not a valid torch device string; load on CPU and let the
            # module be moved to its target device afterwards.
            state = torch.load(path, map_location="cpu")
            if isinstance(state, dict):
                # Unwrap checkpoints that nest the weights: {"state_dict": ...} or
                # the official ProteinMPNN layout {"model_state_dict": ...}.
                for key in ("state_dict", "model_state_dict"):
                    if key in state:
                        state = state[key]
                        break
            missing, unexpected = model.load_state_dict(state, strict=False)
            if missing or unexpected:
                logger.info(
                    "Loaded MPNN weights with missing=%d, unexpected=%d",
                    len(missing),
                    len(unexpected),
                )
        except Exception as e:
            logger.warning("Could not load MPNN weights from %s: %s", path, e)

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig) -> None:
        """Wrap ``self.model`` with LoRA adapters via PEFT.

        Args:
            lora_config_dict: Mapping with optional ``rank`` (default 8),
                ``alpha`` (default 16) and ``dropout`` (default 0.0).
        """
        # NOTE(review): PEFT matches modules whose names end with these strings;
        # they must correspond to Linear layers in haipr's ProteinMPNN.
        target_modules = ["W_e", "W_v", "W_out"]

        peft_config = LoraConfig(
            target_modules=target_modules,
            r=lora_config_dict.get("rank", 8),
            lora_alpha=lora_config_dict.get("alpha", 16),
            lora_dropout=lora_config_dict.get("dropout", 0.0),
        )
        self.model = get_peft_model(self.model, peft_config)

    def _input_chain_list(self) -> Optional[List[str]]:
        """Chains to parse: ``data.variable_chains`` when focus is on, else ``None``.

        ``None`` lets ``parse_PDB`` use its default chain selection.
        """
        if self.data is not None and getattr(self.data, "focus", False):
            return self.data.variable_chains
        return None

    def _parse_pdb_template(self) -> Dict[str, Any]:
        """Parse ``self.pdb_path`` and return the first structure dict.

        Raises:
            ValueError: If ``parse_PDB`` returns no structures.
        """
        from haipr.models.protein_mpnn_utils import parse_PDB

        pdb_list = parse_PDB(
            self.pdb_path,
            input_chain_list=self._input_chain_list(),
            ca_only=getattr(self, "_ca_only", False),
        )
        if not pdb_list:
            raise ValueError("parse_PDB returned empty list")
        return pdb_list[0]

    def setup_model(self, data: HAIPRData, cfg: DictConfig) -> None:
        """Attach the dataset/config, resolve the PDB path and cache the parsed structure.

        The PDB path is taken from ``data.pdb`` if set, otherwise from
        ``cfg.benchmark.pdb``. ``cfg.model.ca_only`` (default ``False``) selects
        CA-only parsing and featurization. Parsing failures are logged and leave
        the template unset; ``prepare_features`` then retries parsing lazily.
        """
        self.data = data
        self.cfg = cfg
        # Never reuse a structure cached for a previous dataset.
        self._pdb_template = None
        # getattr on None yields the default, so cfg=None / missing model is safe.
        self._ca_only = bool(getattr(getattr(cfg, "model", None), "ca_only", False))

        self.pdb_path = None
        if getattr(data, "pdb", None):
            self.pdb_path = data.pdb
        elif cfg is not None and getattr(getattr(cfg, "benchmark", None), "pdb", None):
            self.pdb_path = cfg.benchmark.pdb

        if self.pdb_path is None:
            logger.warning(
                "No PDB path found in data.pdb or cfg.benchmark.pdb; "
                "MPNN requires structure to featurize."
            )
            return

        try:
            self._pdb_template = self._parse_pdb_template()
        except Exception as e:
            logger.error("Failed to parse PDB at %s: %s", self.pdb_path, e)
            self._pdb_template = None

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices: Any,
        val_indices: Any,
        trainer_instance: Any = None,
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """Train the MPNN predictor on the provided dataset.

        Args:
            dataset: Indexable dataset whose items are dicts with at least
                ``"sequence"`` and ``"labels"`` (consumed by ``prepare_features``).
            train_indices: Indices of ``dataset`` used for training.
            val_indices: Indices of ``dataset`` used for validation.
            trainer_instance: Lightning trainer; a default ``pl.Trainer`` with
                ``hparams["num_epochs"]`` (default 3) epochs is built if ``None``.
            cfg: Run configuration; a fresh empty ``DictConfig`` when omitted.

        Returns:
            Dict with keys ``"metrics"`` and ``"predictions"``.

        Raises:
            ValueError: If the number of best validation predictions does not
                match ``len(val_indices)``.
        """
        if cfg is None:
            # Fresh config per call instead of a shared mutable default.
            cfg = DictConfig({})
        self.data = dataset
        self.cfg = cfg

        if trainer_instance is None:
            trainer_instance = pl.Trainer(max_epochs=self.hparams.get("num_epochs", 3))

        train_loader = DataLoader(
            Subset(dataset, train_indices),
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.prepare_features,
        )
        val_loader = DataLoader(
            Subset(dataset, val_indices),
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.prepare_features,
        )

        trainer_instance.fit(self, train_loader, val_loader)
        predictions = self.best_val_predictions
        metrics = self.best_val_metrics or {}
        if predictions is None:
            return {"metrics": metrics, "predictions": {}}
        if len(predictions["preds"]) != len(val_indices):
            raise ValueError(
                f"Number of predictions ({len(predictions['preds'])}) does not match validation set size ({len(val_indices)})"
            )
        pred_dict = {
            "indices": (
                val_indices.tolist()
                if hasattr(val_indices, "tolist")
                else list(val_indices)
            ),
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }
        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()
        return {"metrics": metrics, "predictions": pred_dict}

    def load_model(self, model: str) -> None:
        """Load trained weights (backbone and prediction head) from a checkpoint.

        Args:
            model: Path to a Lightning checkpoint (or a bare state dict file).

        Raises:
            Exception: Any load error is logged with traceback and re-raised.
        """
        try:
            # Load into this instance instead of using ``load_from_checkpoint``.
            # Re-running ``__init__`` with ``model=``/``criterion=`` keywords
            # collides with the ones this class forwards to ``HAIPRModule``.
            # Copying only ``.model`` back would also drop the trained head.
            # NOTE: weights_only=False unpickles arbitrary objects (needed for
            # hyper-parameter payloads); only load trusted checkpoints.
            checkpoint = torch.load(model, map_location="cpu", weights_only=False)
            is_mapping = isinstance(checkpoint, dict)
            state = checkpoint.get("state_dict", checkpoint) if is_mapping else checkpoint
            self.load_state_dict(state)
            hparams = checkpoint.get("hyper_parameters") if is_mapping else None
            if hparams and "learning_rate" in hparams:
                self.learning_rate = hparams["learning_rate"]
            logger.info("MPNN model loaded from checkpoint: %s", model)
        except Exception as e:
            logger.error("Failed to load MPNN model from %s: %s", model, e, exc_info=True)
            raise

    def predict(self, sequences: List[str], batch_size: int = 1) -> Dict[str, Any]:
        """Predict outputs for ``sequences`` threaded onto the current PDB structure.

        Args:
            sequences: Sequences in any format accepted by ``prepare_features``.
            batch_size: Number of sequences featurized per forward pass.

        Returns:
            ``{"predictions": ndarray}`` plus ``"probabilities"`` (softmax over the
            last dim) when ``num_classes > 0``.

        Raises:
            ValueError: If no PDB path is set or ``sequences`` is empty.
        """
        if self.pdb_path is None:
            raise ValueError("MPNN predict requires a PDB path set in setup_model.")
        if len(sequences) == 0:
            raise ValueError("MPNN predict requires at least one sequence.")

        # Wrap sequences into dicts for prepare_features (labels are dummies).
        def _predict_collate(seq_batch: List[str]):
            return self.prepare_features(
                [{"sequence": s, "labels": 0.0} for s in seq_batch]
            )

        predict_loader = DataLoader(
            sequences,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=_predict_collate,
            num_workers=0,
        )

        # Put backbone *and* head in eval mode (no dropout), then restore the
        # caller's mode so predicting mid-training has no lasting side effect.
        was_training = self.training
        self.eval()
        preds_list: List[torch.Tensor] = []
        try:
            with torch.no_grad():
                for batch in predict_loader:
                    preds_list.append(self.forward(batch).detach().cpu())
        finally:
            self.train(was_training)

        preds = torch.cat(preds_list, dim=0)
        result: Dict[str, Any] = {"predictions": preds.numpy()}
        if self.num_classes > 0:
            result["probabilities"] = torch.softmax(preds, dim=-1).numpy()
        return result

    # --- Forward/inference helpers ---
    def forward(self, batch: Dict[str, Any]) -> Any:  # type: ignore[override]
        """Run ProteinMPNN, masked-mean-pool its embeddings and apply the head.

        Args:
            batch: Output of ``prepare_features`` (``X``, ``S``, ``mask``,
                ``chain_M``, ``residue_idx``, ``chain_encoding_all``, ``randn``).

        Returns:
            Head output; for regression (``num_classes == 0``) a 1-D output is
            reshaped to ``[B, 1]``.
        """
        mask = batch["mask"]
        chain_M = batch["chain_M"]

        _log_probs, _logits, embeds = self.model(
            X=batch["X"],
            S=batch["S"],
            mask=mask,
            chain_M=chain_M,
            residue_idx=batch["residue_idx"],
            chain_encoding_all=batch["chain_encoding_all"],
            randn=batch["randn"],
        )  # log_probs: [B, L, 21]; embeds: [B, L, H]

        # Convert embeddings to float32 to avoid dtype mismatch with the head.
        if embeds.dtype == torch.bfloat16:
            embeds = embeds.float()

        # Masked mean over positions so padding does not dilute the embedding.
        pos_mask = (mask * chain_M).float()  # [B, L]
        denom = pos_mask.sum(dim=1, keepdim=True).clamp_min(1.0)  # [B, 1]
        seq_embed = (embeds * pos_mask.unsqueeze(-1)).sum(dim=1) / denom  # [B, H]

        pred = self.prediction_head(seq_embed)  # [B, out_dim]
        if self.num_classes == 0 and pred.ndim == 1:
            pred = pred.unsqueeze(1)
        return pred

    # --- Feature preparation ---
    def _separator_token(self) -> str:
        """Chain separator from ``data.config.data.separator_token``, default ``"|"``."""
        sep = None
        config = getattr(self.data, "config", None)
        if config is not None and hasattr(config, "data"):
            sep = getattr(config.data, "separator_token", "|")
        return sep or "|"

    @staticmethod
    def _split_sequence(seq: Any, chain_letters: List[str], sep: str) -> Dict[str, str]:
        """Map an item's sequence to ``{chain_letter: sequence}``.

        Accepted forms: a dict, a JSON object string, a ``sep``-joined string
        (parts assigned to ``chain_letters`` in order, extras dropped), or a
        single sequence assigned to the first chain (``"A"`` if none).
        """
        if isinstance(seq, dict):
            return seq
        if isinstance(seq, str) and seq.strip().startswith("{"):
            try:
                parsed = json.loads(seq)
                return {str(k): str(v) for k, v in parsed.items()}
            except (ValueError, AttributeError):
                pass  # not a JSON object; try the other formats
        if isinstance(seq, str) and sep in seq:
            return dict(zip(chain_letters, seq.split(sep)))
        return {chain_letters[0] if chain_letters else "A": seq}

    @staticmethod
    def _chain_length(entry: Dict[str, Any], letter: str, default: int) -> int:
        """Residue count of chain ``letter`` from its CA (else N) coords, else ``default``."""
        coords = entry.get(f"coords_chain_{letter}")
        if coords is not None:
            for atom in ("CA", "N"):
                key = f"{atom}_chain_{letter}"
                if key in coords:
                    return len(coords[key])
        return default

    @staticmethod
    def _fit_to_structure(seq: Any, length: int) -> str:
        """Replace gaps with ``X``, then truncate or right-pad with ``X`` to ``length``."""
        return str(seq).replace("-", "X")[:length].ljust(length, "X")

    def prepare_features(self, batch_items: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate raw items into ProteinMPNN inputs plus labels.

        - Parses the PDB once (honouring focus/variable chains and ``ca_only``)
          and deep-copies the template per item.
        - Overwrites chain sequences with each item's sequence, aligned to the
          structure's chain lengths.
        - Uses ``tied_featurize`` to build the tensors for ProteinMPNN.

        Args:
            batch_items: Dicts with ``"sequence"`` and ``"labels"`` and an
                optional ``"sample_id"``.

        Raises:
            ValueError: If no PDB path is set or parsing yields no structure.
        """
        from haipr.models.protein_mpnn_utils import tied_featurize

        if self._pdb_template is None:
            if self.pdb_path is None:
                raise ValueError("PDB path not set; cannot featurize for MPNN.")
            self._pdb_template = self._parse_pdb_template()
        template = self._pdb_template
        ca_only = getattr(self, "_ca_only", False)

        # Computed once per batch: every entry is a copy of the same template.
        chain_letters = sorted(k[-1] for k in template if k.startswith("seq_chain_"))
        sep = self._separator_token()

        batch_pdb_entries: List[Dict[str, Any]] = []
        for item in batch_items:
            entry = copy.deepcopy(template)
            for letter, seq in self._split_sequence(item["sequence"], chain_letters, sep).items():
                key_seq = f"seq_chain_{letter}"
                if key_seq in entry:
                    length = self._chain_length(entry, letter, default=len(seq))
                    entry[key_seq] = self._fit_to_structure(seq, length)
            # Rebuild the concatenated sequence in sorted chain order.
            entry["seq"] = "".join(
                entry[f"seq_chain_{c}"] for c in chain_letters if f"seq_chain_{c}" in entry
            )
            batch_pdb_entries.append(entry)

        (
            X_out,
            S,
            mask,
            _lengths,
            chain_M,
            chain_encoding_all,
            _letter_list,
            _visible,
            _masked,
            _masked_len,
            _chain_M_pos,
            _omit_AA_mask,
            residue_idx,
            _dihedral_mask,
            _tied_pos_list,
            _pssm_coef_all,
            _pssm_bias_all,
            _pssm_log_odds_all,
            _bias_by_res_all,
            _tied_beta,
        ) = tied_featurize(
            batch_pdb_entries,
            device=self.device,
            chain_dict=None,
            fixed_position_dict=None,
            omit_AA_dict=None,
            tied_positions_dict=None,
            pssm_dict=None,
            bias_by_res_dict=None,
            ca_only=ca_only,
        )

        # Randomness for decoding order
        randn = torch.randn_like(mask)

        # Collate labels; regression targets are coerced to [B, 1].
        raw_labels = [item["labels"] for item in batch_items]
        if self.num_classes == 0:
            labels = torch.as_tensor(raw_labels, dtype=torch.float32, device=self.device)
            # Multi-column labels keep only the first column.
            if labels.ndim == 2 and labels.size(1) != 1:
                labels = labels[:, 0]
            if labels.ndim == 1:
                labels = labels.unsqueeze(1)
        else:
            labels = torch.as_tensor(raw_labels, dtype=torch.long, device=self.device)

        batch = {
            "X": X_out.to(self.device),
            "S": S.to(self.device),
            "mask": mask.to(self.device),
            "chain_M": chain_M.to(self.device),
            "residue_idx": residue_idx.to(self.device),
            "chain_encoding_all": chain_encoding_all.to(self.device),
            "randn": randn.to(self.device),
            "labels": labels,
        }
        # Pass through sample_id if present
        if "sample_id" in batch_items[0]:
            batch["sample_id"] = torch.tensor(
                [item["sample_id"] for item in batch_items],
                dtype=torch.long,
                device=self.device,
            )
        return batch
