import logging
from mlflow.tracking.fluent import log_params
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.models.esmc import ESMC, ESMCOutput
from typing import Any, Dict, Optional, List
from haipr.models.module import HAIPRModule
import lightning.pytorch as pl
from torch.utils.data import Subset, DataLoader, Dataset
from haipr.data import HAIPRData
from haipr.predictor import BasePredictor
from haipr.utils import loss_funcs
from peft import LoraConfig, get_peft_model
from tqdm import tqdm

# Install the default root handler (no-op if the root logger is already
# configured) so this module's records are emitted somewhere.
logging.basicConfig()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Changed to INFO for more visibility during development
logger.setLevel(logging.INFO)


class ESMCDataset(Dataset):
    """Minimal map-style dataset that serves items from an in-memory list.

    Items are handed back exactly as they were stored; no tokenisation,
    conversion, or copying happens here.
    """

    def __init__(self, sequences: List[str]):
        """Keep a reference to ``sequences`` (the list itself is not copied)."""
        self.sequences = sequences

    def __len__(self) -> int:
        """Return how many items are stored."""
        stored = self.sequences
        return len(stored)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        stored = self.sequences
        return stored[idx]


class ESMCPredictor(HAIPRModule, BasePredictor):
    """ESMC-specific predictor implementation using PyTorch Lightning."""

    def __init__(
        self,
        model_name: str = "facebook/esm2_t6_8M_UR50D",  # Default ESMC model
        num_classes: int = 0,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
        batch_size: int = 1,
        loss_fn: str = "huber",  # Added loss_fn
        **kwargs,
    ):
        """Load the ESMC backbone and attach a prediction head.

        Args:
            model_name: Checkpoint name passed to ``ESMC.from_pretrained``.
                NOTE(review): the default looks like an ESM-2 Hugging Face id
                rather than an ESMC checkpoint name — confirm that
                ``ESMC.from_pretrained`` accepts it.
            num_classes: Forwarded to ``HAIPRModule``. ``prepare_features``
                treats ``0`` as regression (float labels) and any other value
                as classification (integer labels).
            prediction_head: Module applied to the mean-pooled embeddings in
                ``forward``. Required despite the ``None`` default.
            learning_rate: Forwarded to ``HAIPRModule``.
            weight_decay: Forwarded to ``HAIPRModule``.
            batch_size: Batch size used for the loaders built in ``fit_model``.
            loss_fn: Key looked up with ``loss_funcs.get``; ``nn.MSELoss()``
                is the fallback when the lookup returns nothing.
            **kwargs: Logged only; not forwarded to the parent initialiser.

        Raises:
            ValueError: If ``prediction_head`` is ``None``.
        """
        logger.info("Initializing ESMCPredictor with kwargs: %s", kwargs)
        for arg_name, arg_value in (
            ("model_name", model_name),
            ("num_classes", num_classes),
            ("prediction_head", prediction_head),
            ("learning_rate", learning_rate),
            ("weight_decay", weight_decay),
            ("batch_size", batch_size),
        ):
            logger.info(
                "Initializing ESMCPredictor with %s: %s", arg_name, arg_value)

        # Fail fast: reject a missing head *before* paying for the pretrained
        # weight load and the parent initialisation.
        if prediction_head is None:
            raise ValueError("No prediction head provided")

        # Initialize base model
        model = ESMC.from_pretrained(model_name)

        criterion = loss_funcs.get(loss_fn, nn.MSELoss())  # Use loss_fn

        # Initialize HAIPRModule with the model
        logger.debug("Initializing HAIPRModule with model")
        super().__init__(  # HAIPRModule init
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )

        # Must be assigned after the parent initialiser has run so the head
        # is attached to an already-initialised module.
        self.prediction_head = prediction_head

        # Store additional attributes
        self.model_name = model_name
        self.batch_size = batch_size
        self.data = None  # Initialize data attribute

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """Wrap ``self.model`` in a PEFT LoRA model.

        Target modules are selected by name: every entry of
        ``self.model.named_modules()`` whose qualified name ends with one of
        the suffixes listed below is adapted.

        Args:
            lora_config_dict: Mapping read with ``.get`` for the optional keys
                ``rank`` (default 2), ``alpha`` (default 16), ``dropout``
                (default 0.0) and ``bias`` (default ``"none"``).
        """
        # Custom implementation since esm3 like models are not supported by default
        lora_suffixes = (
            "attn.layernorm_qkv.1",
            "attn.out_proj",
            "ffn.1",
            "ffn.3",
        )
        # str.endswith accepts a tuple, so one call covers all suffixes.
        selected_names = [
            qualified_name
            for qualified_name, _ in self.model.named_modules()
            if qualified_name.endswith(lora_suffixes)
        ]

        peft_settings = LoraConfig(
            r=lora_config_dict.get("rank", 2),
            lora_alpha=lora_config_dict.get("alpha", 16),
            lora_dropout=lora_config_dict.get("dropout", 0.0),
            target_modules=selected_names,
            bias=lora_config_dict.get("bias", "none"),
        )
        self.model = get_peft_model(self.model, peft_settings)

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """Attach a dataset and a run configuration to this predictor.

        Stores ``data`` and ``cfg`` on the instance, caches the first entry of
        ``data.get_labels()`` as ``wt_score`` and ``data.pdb`` as ``pdb_path``.

        NOTE(review): ``wt_score`` presumably means the first sample is the
        wild type — confirm against ``HAIPRData``.
        """
        self.data = data
        all_labels = data.get_labels()
        self.wt_score = all_labels[0]
        self.pdb_path = data.pdb
        self.cfg = cfg

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices,
        val_indices,
        trainer_instance: Optional[
            pl.Trainer
        ] = None,  # Made optional for API consistency
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """Fit the model with a Lightning trainer and report validation results.

        Args:
            dataset: Dataset indexed by ``train_indices`` / ``val_indices``;
                also stored on ``self.data``.
            train_indices: Indices of the training subset (shuffled loader).
            val_indices: Indices of the validation subset (ordered loader).
            trainer_instance: Trainer to use. When ``None`` a
                ``pl.Trainer(max_epochs=2)`` is created.
            cfg: Accepted for API consistency; not used by this method.

        Returns:
            ``{"metrics": ..., "predictions": ...}`` where ``metrics`` is
            ``self.best_val_metrics`` and ``predictions`` holds ``indices``,
            ``predictions``, ``true_values`` and, when the module recorded
            them, ``probabilities``.

        Raises:
            ValueError: If the number of stored validation predictions differs
                from ``len(val_indices)`` outside of ``fast_dev_run``.
        """
        if trainer_instance is None:
            # Default trainer if none provided
            trainer_instance = pl.Trainer(max_epochs=2)  # default

        self.data = dataset  # Store dataset reference

        def build_loader(indices, shuffle: bool) -> DataLoader:
            # Both loaders differ only in the subset and the shuffle flag.
            return DataLoader(
                Subset(dataset, indices),
                batch_size=self.batch_size,
                shuffle=shuffle,
                collate_fn=self.prepare_features,
            )

        train_loader = build_loader(train_indices, shuffle=True)
        val_loader = build_loader(val_indices, shuffle=False)
        trainer_instance.fit(self, train_loader, val_loader)

        predictions = self.best_val_predictions
        metrics = self.best_val_metrics
        num_preds = len(predictions["preds"])

        # Read the flag from the trainer that actually ran. Probing
        # ``self.trainer`` with hasattr() is not a reliable guard: Lightning's
        # property raises RuntimeError (not AttributeError) when detached.
        if getattr(trainer_instance, "fast_dev_run", False):
            # Skip strict checks in fast_dev_run and trim val_indices to the
            # number of predictions actually made
            val_indices = val_indices[:num_preds]
        elif num_preds != len(val_indices):
            raise ValueError(
                f"Number of predictions ({num_preds}) does not match "
                f"number of samples in validation set ({len(val_indices)}). "
                "This can happen when sanity_check sets best_val_predictions "
                "and actual training is not making improvements."
            )

        # Prepare prediction dictionary
        pred_dict = {
            "indices": val_indices,
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }

        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()

        logger.debug("Metrics: %s", metrics)
        return {"metrics": metrics, "predictions": pred_dict}

    def forward(self, batch, **kwargs):
        """Run the backbone and apply the prediction head to pooled embeddings.

        Args:
            batch: Mapping with ``batch["inputs"]["sequence_tokens"]``, as
                produced by ``prepare_features``.
            **kwargs: Ignored.

        Returns:
            The head's output for the mean-pooled embeddings; a 1-D output is
            expanded to ``[batch, 1]``. When no prediction head is set, a
            warning is logged and the backbone's ``sequence_logits`` are
            returned instead.

        Raises:
            ValueError: If a head is set but the backbone returned no
                embeddings.
        """
        # ESMC output is ESMCOutput(sequence_logits: torch.Tensor, embeddings: torch.Tensor | None, hidden_states: torch.Tensor | None)
        # embeddings are [batch_size, seq_len, d_model]
        out: ESMCOutput = self.model(
            sequence_tokens=batch["inputs"]["sequence_tokens"])

        # The previous ``elif hasattr(self.model, "sequence_head")`` branch
        # called ``self.prediction_head`` although it was only reachable when
        # that attribute was missing or None, so it could never succeed. The
        # no-head case now always takes the documented fallback below.
        head = getattr(self, "prediction_head", None)
        if head is None:
            logger.warning("No Head Defined ")
            return out.sequence_logits

        embeddings = out.embeddings
        if embeddings is None:
            raise ValueError(
                "ESMC embeddings are None, cannot use prediction_head that relies on embeddings."
            )

        # Convert embeddings to Float32 to avoid dtype mismatch with prediction head
        if embeddings.dtype == torch.bfloat16:
            embeddings = embeddings.float()

        # NOTE(review): the mean runs over every position of dim 1, so any
        # padding / special-token positions present in the batch are pooled
        # too — confirm this is intended for variable-length batches.
        pred_out = head(embeddings.mean(dim=1))

        # Add dimension if single output target
        if pred_out.dim() == 1:
            pred_out = pred_out.unsqueeze(1)
        return pred_out

    def prepare_features(self, batch_items: list[dict]):
        """Collate dataset items into the batch dictionary ``forward`` consumes.

        Each item must provide ``"sequence"``; ``"labels"`` and ``"sample_id"``
        are optional and their presence is decided from the first item only.

        Returns:
            A dict with ``"inputs"`` (``{"sequence_tokens": tensor}``),
            ``"labels"`` (tensor or ``None``) and, when sample ids were
            supplied, ``"sample_id"``. All tensors are moved to
            ``self.device``.

        Raises:
            ValueError: If ``self.model`` has no tokenizer.
        """
        #  uses __getitem__ of HAIPRData
        sequences = [entry["sequence"] for entry in batch_items]
        first_item = batch_items[0]
        raw_labels = (
            [entry["labels"] for entry in batch_items]
            if "labels" in first_item
            else None
        )

        # Extract sample IDs if present for DDP compatibility
        sample_ids = None
        if "sample_id" in first_item:
            sample_ids = torch.tensor(
                [entry["sample_id"] for entry in batch_items], dtype=torch.long
            )

        # Ensure the model has its tokenizer attached, which from_pretrained should handle.
        if getattr(self.model, "tokenizer", None) is None:
            raise ValueError(
                "ESMC model does not have a tokenizer. Cannot prepare features."
            )

        # ESMC tokenizes a list of sequences through `_tokenize`; when the
        # model is wrapped in DataParallel the method lives on `.module`.
        tokenizer_owner = (
            self.model.module
            if isinstance(self.model, torch.nn.DataParallel)
            else self.model
        )
        tokenized_sequences = tokenizer_owner._tokenize(sequences)

        # Collate labels on the same device as the model
        if raw_labels is None:
            collated_labels = None
        elif self.num_classes == 0:  # Regression
            collated_labels = torch.tensor(raw_labels, dtype=torch.float32).to(
                self.device
            )
            # Regression targets are kept 2D: [batch_size, num_targets]
            if collated_labels.dim() == 1:
                collated_labels = collated_labels.unsqueeze(1)
        else:  # Classification
            collated_labels = torch.tensor(
                raw_labels, dtype=torch.long).to(self.device)

        result = {
            "inputs": {"sequence_tokens": tokenized_sequences.to(self.device)},
            "labels": collated_labels,
        }

        # Add sample IDs if present for DDP compatibility
        if sample_ids is not None:
            result["sample_id"] = sample_ids.to(self.device)

        return result

    def load_model(self, model_path: str) -> None:
        """Load a trained model/adapter from checkpoint.

        Restores the checkpoint through ``ESMCPredictor.load_from_checkpoint``
        and copies its backbone, learning rate and prediction head onto this
        instance.

        Args:
            model_path: Path of the checkpoint to load.

        Raises:
            Exception: Any error raised while loading is logged and re-raised
                unchanged.
        """
        try:
            loaded_model = ESMCPredictor.load_from_checkpoint(
                model_path, model=self.model
            )
            self.model = loaded_model.model
            self.learning_rate = loaded_model.learning_rate
            # ``forward`` feeds the backbone output through the prediction
            # head, so the trained head has to be restored together with the
            # backbone; otherwise predictions would use the old head.
            restored_head = getattr(loaded_model, "prediction_head", None)
            if restored_head is not None:
                self.prediction_head = restored_head
            logger.info(f"ESMC model loaded from checkpoint: {model_path}")
        except Exception as e:
            logger.error(f"Failed to load model from {model_path}: {e}")
            raise

    def save_model(self, model_path: str, adapter_only: bool = True):
        """Persist the backbone (or only its adapter weights) to ``model_path``.

        Args:
            model_path: Destination passed to the model's save method.
            adapter_only: When true, save only the adapter. A ``save_adapter``
                method is used if the model provides one; otherwise
                ``save_pretrained`` is used, which is how a PEFT-wrapped model
                (see ``_initialize_peft_adapters``) writes its adapter weights.
                When false, ``save_pretrained`` is called directly.
                NOTE(review): on a PEFT-wrapped model ``save_pretrained``
                reportedly writes the adapter only as well — confirm how a
                full save is expected to work.

        Raises:
            AttributeError: If the model offers no suitable save method.
        """
        if not adapter_only:
            self.model.save_pretrained(model_path)
            return

        # ``save_adapter`` is not part of PEFT's PeftModel API, so calling it
        # unconditionally failed for models produced by get_peft_model. Keep it
        # as the first choice for models that do expose it, then fall back.
        save_fn = getattr(self.model, "save_adapter", None)
        if save_fn is None:
            save_fn = getattr(self.model, "save_pretrained", None)
        if save_fn is None:
            raise AttributeError(
                f"{type(self.model).__name__} provides neither 'save_adapter' "
                "nor 'save_pretrained'; were PEFT adapters initialized?"
            )
        save_fn(model_path)

    def _predict_perplexities(
        self,
        sequences: List[str],
        batch_size: int = 1,
    ) -> Dict[str, Any]:
        """
        Compute sequence-level log-probabilities and perplexities.

        Each batch is tokenized with ``prepare_features`` and passed through
        the backbone once, unmasked. The log-softmax of ``sequence_logits`` is
        gathered at the input token ids and summed over all non-padding
        positions. The model is switched to eval mode and left there.

        Args:
            sequences: Protein sequence strings to evaluate; a single string
                is treated as a one-element list.
            batch_size: Batch size for processing sequences.

        Returns:
            Dictionary containing:
                - sequences: Input sequences
                - probabilities: ``exp`` of the summed log-probabilities
                - log_probabilities: Summed per-token log-probabilities
                - perplexities: ``exp(-log_probability / n_tokens)``, with
                  ``n_tokens`` the number of non-padding tokens of that
                  sequence (lower is better)

        Example:
            ```python
            predictor = ESMCPredictor(...)
            result = predictor.predict(
                ["MALWMRLLPLLALLALWGPDPAAA"], perplexities=True
            )

            log_probs = result["log_probabilities"]
            perplexities = result["perplexities"]
            ```
        """
        # Convert single sequence to list for consistent processing
        if isinstance(sequences, str):
            sequences = [sequences]

        self.model.eval()
        probabilities = []
        log_probabilities = []
        perplexities = []

        # Items are dicts because prepare_features reads item["sequence"].
        dataset = ESMCDataset([{"sequence": seq} for seq in sequences])
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=self.prepare_features,
            shuffle=False,
        )

        with torch.no_grad():
            for item_batch in tqdm(loader, desc="Computing sequence perplexities"):
                # Move to device
                inputs = {
                    k: v.to(self.model.device) for k, v in item_batch["inputs"].items()
                }
                tokens = inputs["sequence_tokens"]

                # Get model outputs
                outputs = self.model(sequence_tokens=tokens)
                sequence_logits = (
                    outputs.sequence_logits
                )  # [batch_size, seq_len, vocab_size]
                log_probs = F.log_softmax(sequence_logits, dim=-1)
                # pull out the log probs for the actual tokens
                token_log_probs = log_probs.gather(
                    dim=-1, index=tokens.unsqueeze(-1)
                ).squeeze(-1)

                mask = tokens.ne(self.model.tokenizer.pad_token_id).float()
                seq_log_probs = (token_log_probs * mask).sum(dim=-1)
                seq_probs = torch.exp(seq_log_probs)

                # Normalise by each sequence's own non-padding token count.
                # Dividing by the padded width (tokens.size(1)) made shorter
                # sequences in a mixed-length batch look better than they are
                # and made the score depend on batch composition.
                token_counts = mask.sum(dim=-1).clamp(min=1.0)
                seq_perplexities = torch.exp(-seq_log_probs / token_counts)

                log_probabilities.extend(seq_log_probs.cpu().tolist())
                probabilities.extend(seq_probs.cpu().tolist())
                perplexities.extend(seq_perplexities.cpu().tolist())

        # Prepare return dictionary
        result = {
            "sequences": sequences,
            "perplexities": perplexities,
            "log_probabilities": log_probabilities,
            "probabilities": probabilities,
        }

        return result

    def predict(
        self, sequences: List[str], batch_size: int = 1, perplexities=False
    ) -> Dict[str, Any]:
        """Predict for ``sequences``, or score them when ``perplexities`` is set.

        Args:
            sequences: Sequences to evaluate.
            batch_size: Batch size forwarded to the selected implementation.
            perplexities: When truthy, return the result of
                ``_predict_perplexities``; otherwise delegate to the parent
                class's ``predict``.
        """
        if not perplexities:
            # Use the generic predict method from HAIPRModule which uses forward()
            return super().predict(sequences, batch_size)

        return self._predict_perplexities(sequences, batch_size)
