import logging
from omegaconf import DictConfig
import torch
import torch.nn as nn
import numpy as np
from transformers import (
    AutoTokenizer,
    EsmModel,
)
from peft import LoraConfig, TaskType, get_peft_model
from typing import Any, Dict, Optional, List

from haipr.models.module import HAIPRModule
from haipr.predictor import BasePredictor
from haipr.utils import loss_funcs
import lightning.pytorch as pl
from haipr.data import HAIPRData
from tqdm import tqdm

# Install a default handler on the root logger (a no-op when the root logger
# already has handlers configured).
logging.basicConfig()
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Let INFO-and-above records from this module through.
logger.setLevel(logging.INFO)


class ESM2Predictor(HAIPRModule):
    def __init__(
        self,
        model_name_or_path: str = "esm2_t6_8M_UR50D",
        num_classes: int = 0,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 1,
        loss: str = "mse",
        **kwargs,
    ):
        """Build an ESM2 backbone plus a user-supplied prediction head.

        Args:
            model_name_or_path: ESM2 checkpoint name; it is prefixed with
                ``facebook/`` before being passed to ``from_pretrained``.
            num_classes: ``0`` is treated as regression elsewhere in this
                class, any other value as classification.
            prediction_head: Module applied to the pooled embeddings in
                ``forward``. Required, despite the ``None`` default.
            learning_rate: Forwarded to the base class.
            weight_decay: Forwarded to the base class.
            batch_size: Stored and later used by ``fit_model``.
            loss: Key looked up with ``loss_funcs.get``; when the lookup
                returns ``None`` an ``nn.MSELoss`` is used instead.
            **kwargs: Forwarded unchanged to the base class constructor.

        Raises:
            ValueError: If ``prediction_head`` is ``None``.
        """
        # Fail fast: validate the cheap argument before loading (and possibly
        # downloading) the pretrained weights and tokenizer.
        if prediction_head is None:
            raise ValueError("No prediction head provided")

        logger.info(
            "Initializing ESM2Predictor with model_name: %s", model_name_or_path)

        criterion = loss_funcs.get(loss)
        if criterion is None:
            logger.warning(
                "Loss function '%s' not found. Defaulting to MSELoss.", loss)
            criterion = nn.MSELoss()

        hub_id = f"facebook/{model_name_or_path}"
        model = EsmModel.from_pretrained(hub_id)
        # NOTE(review): assigned before super().__init__() as in the original;
        # kept in that position in case the base constructor relies on it.
        self.tokenizer = AutoTokenizer.from_pretrained(hub_id)

        # Initialize HAIPRModule
        super().__init__(
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            **kwargs,
        )

        self.prediction_head = prediction_head

        self.num_classes = num_classes
        self.model_name = model_name_or_path
        self.batch_size = batch_size
        self.data = None  # Set by setup_model / fit_model
        self.pdb_path = None  # ESM2 typically doesn't use PDB directly like ESM3
        self.signature = None  # defines the input / output signature for the model
        self._dataparallel_initialized = False  # Track if DataParallel has been set up

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """Wrap ``self.model`` with LoRA (or DoRA) adapters.

        Optional keys read from ``lora_config_dict`` (default in brackets):
        ``peft_type`` ["lora"; "dora" switches on ``use_dora``], ``rank`` [8],
        ``alpha`` [16], ``dropout`` [0.0], ``bias`` ["none"],
        ``fan_in_fan_out`` [False] and ``target_modules`` [the attention
        projections and the dense layers listed below].

        Args:
            lora_config_dict: Adapter settings, accessed via ``.get``.
        """
        # PEFT matches list entries against module *names* (exact match or
        # dotted suffix), not class names. Entries such as
        # "EsmSelfOutput.dense" therefore never match anything and leave the
        # dense layers without adapters. In the Hugging Face ESM encoder these
        # layers are registered as "...attention.output.dense",
        # "...intermediate.dense" and "...output.dense".
        default_targets = [
            "query",
            "key",
            "value",
            "attention.output.dense",
            "intermediate.dense",
            "output.dense",
        ]
        configured_targets = lora_config_dict.get("target_modules", None)
        if configured_targets is None:
            target_modules = default_targets
        elif isinstance(configured_targets, str):
            # A plain string is passed through unchanged (PEFT treats it as a
            # regex over module names).
            target_modules = configured_targets
        else:
            # Convert config containers to a plain list for LoraConfig.
            target_modules = list(configured_targets)

        peft_kind = lora_config_dict.get("peft_type", "lora").lower()
        peft_config_instance = LoraConfig(
            use_dora=(peft_kind == "dora"),
            task_type=TaskType.FEATURE_EXTRACTION,
            r=lora_config_dict.get("rank", 8),
            lora_alpha=lora_config_dict.get("alpha", 16),
            lora_dropout=lora_config_dict.get("dropout", 0.0),
            target_modules=target_modules,
            bias=lora_config_dict.get("bias", "none"),
            # fan_in_fan_out describes the weight layout of the wrapped layer
            # ((fan_in, fan_out) storage, e.g. Conv1D); it is unrelated to the
            # task type, so it defaults to False here.
            fan_in_fan_out=lora_config_dict.get("fan_in_fan_out", False),
        )
        self.model = get_peft_model(self.model, peft_config_instance)

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """Attach the dataset and the run configuration to this predictor.

        Args:
            data: Dataset object, stored as ``self.data``.
            cfg: Configuration object, stored as ``self.cfg``.
        """
        self.data, self.cfg = data, cfg

    def forward(self, batch):
        """Forward pass of the model."""
        if "inputs" in batch:
            # New format from FeatureTensorDataset
            input_ids = batch["inputs"]["input_ids"]
            attention_mask = batch["inputs"]["attention_mask"]
        else:
            raise ValueError("Invalid batch format, needs 'inputs' key")

        # Inputs are expected to be tokenized sequences (input_ids, attention_mask)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Convert embeddings to Float32 to avoid dtype mismatch with prediction head
        embeddings = outputs.last_hidden_state
        if embeddings.dtype == torch.bfloat16:
            embeddings = embeddings.float()

        pred_out = self.prediction_head(embeddings.mean(dim=1))

        if (
            self.num_classes == 0 and len(pred_out.shape) == 1
        ):  # Regression, single target
            pred_out = pred_out.unsqueeze(1)
        elif (
            self.num_classes > 0 and len(pred_out.shape) == 1
        ):  # Should not happen with CrossEntropyLoss expecting [B, C]
            logger.warning(
                f"Unexpected pred_out shape for classification: {pred_out.shape}"
            )

        return pred_out

    def prepare_batch_features(self, batch_items: list[dict]):
        """Collate raw items into a tokenized batch with labels.

        Args:
            batch_items: Dicts that each carry a ``"sequence"`` entry and,
                optionally, ``"labels"`` and ``"sample_id"`` entries. A
                missing label is replaced by ``0.0`` (e.g. at prediction time).

        Returns:
            ``{"inputs": {...}, "labels": tensor}`` where ``inputs`` holds
            ``input_ids``, ``attention_mask`` and, when the first item has a
            ``"sample_id"``, a ``sample_id`` tensor as well.
        """
        seqs = []
        targets = []
        for entry in batch_items:
            seqs.append(entry["sequence"])
            targets.append(entry.get("labels", 0.0))

        # Sample IDs are carried along for DDP compatibility; their presence
        # is decided by looking at the first item only.
        has_ids = bool(batch_items) and "sample_id" in batch_items[0]
        id_tensor = None
        if has_ids:
            id_tensor = torch.tensor(
                [entry["sample_id"] for entry in batch_items], dtype=torch.long
            )

        # Pad every sequence to the longest one in this batch.
        encoded = self.tokenizer(seqs, padding="longest", return_tensors="pt")
        token_ids = encoded["input_ids"]
        token_mask = encoded["attention_mask"]

        if self.num_classes != 0:
            # Classification: CrossEntropyLoss wants integer labels of shape [batch_size]
            label_tensor = torch.tensor(targets, dtype=torch.long)
        else:
            # Regression: float labels, always [batch_size, num_targets]
            label_tensor = torch.tensor(targets, dtype=torch.float32)
            if label_tensor.dim() == 1:
                label_tensor = label_tensor.unsqueeze(1)

        logger.debug(
            f"Input IDs shape: {token_ids.shape}, Device: {token_ids.device}")
        logger.debug(
            f"Attention Mask shape: {token_mask.shape}, Device: {token_mask.device}"
        )
        logger.debug(
            f"Collated labels shape: {label_tensor.shape}, Device: {label_tensor.device}"
        )

        model_inputs = {"input_ids": token_ids, "attention_mask": token_mask}
        if id_tensor is not None:
            model_inputs["sample_id"] = id_tensor

        return {"inputs": model_inputs, "labels": label_tensor}

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, torch.Tensor]:
        """Tokenize and collate the rows selected by ``indices`` in one pass.

        Sequences come from the ``"sequence"`` column of ``dataset.data`` and
        labels from the column named by ``dataset.label_col``; both are
        selected positionally with ``.iloc``. All resulting tensors are moved
        to ``self.device``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``.
        """
        logger.info(f"Preparing training features for {len(indices)} samples")

        frame = dataset.data
        seqs = frame["sequence"].iloc[indices].tolist()
        raw_targets = frame[dataset.label_col].iloc[indices].values

        # One tokenizer call for everything, so padding is to the longest
        # sequence among all selected rows.
        encoded = self.tokenizer(seqs, padding="longest", return_tensors="pt")

        target_device = self.device
        token_ids = encoded["input_ids"].to(target_device)
        token_mask = encoded["attention_mask"].to(target_device)

        is_regression = self.num_classes == 0
        label_dtype = torch.float32 if is_regression else torch.long
        label_tensor = torch.tensor(raw_targets, dtype=label_dtype).to(target_device)
        if is_regression and label_tensor.dim() == 1:
            # Regression labels are kept 2D: [batch_size, num_targets]
            label_tensor = label_tensor.unsqueeze(1)

        logger.info(
            f"Prepared features: input_ids {token_ids.shape}, attention_mask {token_mask.shape}, labels {label_tensor.shape}"
        )

        return {
            "input_ids": token_ids,
            "attention_mask": token_mask,
            "labels": label_tensor,
        }

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices,
        val_indices,
        trainer_instance: Optional[pl.Trainer] = None,
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """Fit the model with a Lightning trainer and collect validation output.

        Args:
            dataset: Data source; stored on ``self.data`` and passed to
                ``prepare_training_features``.
            train_indices: Row positions used for training.
            val_indices: Row positions used for validation.
            trainer_instance: Trainer to use. When ``None`` a ``pl.Trainer`` is
                created with ``max_epochs=self.hparams.get("num_epochs", 3)``.
            cfg: Accepted for interface compatibility; not read by this method.

        Returns:
            ``{"metrics": ..., "predictions": {...}}`` where the predictions
            dict holds ``indices``, ``predictions``, ``true_values`` and, when
            available, ``probabilities``.

        Raises:
            ValueError: If the number of stored validation predictions differs
                from ``len(val_indices)``.
        """
        if trainer_instance is None:
            trainer_instance = pl.Trainer(
                max_epochs=self.hparams.get("num_epochs", 3))
        self.data = dataset

        n_train = len(train_indices)
        n_val = len(val_indices)

        # Prepare features once for all data (train + val)
        all_indices = np.concatenate([train_indices, val_indices])
        features_dict = self.prepare_training_features(dataset, all_indices)

        # The feature tensors are laid out as [train rows..., val rows...], so
        # the loaders are given positions relative to that concatenation
        # rather than the original dataset indices.
        logger.info(
            f"Creating DataLoaders with train_indices: {n_train} | val_indices: {n_val}")
        train_loader, val_loader = self._create_dataloaders(
            features_dict=features_dict,
            labels=features_dict["labels"],
            train_indices=np.arange(n_train),
            val_indices=np.arange(n_train, len(all_indices)),
            batch_size=self.batch_size,
            shuffle_train=True,
        )

        trainer_instance.fit(self, train_loader, val_loader)
        metrics = self.best_val_metrics
        predictions = self.best_val_predictions

        n_preds = len(predictions["preds"])
        if n_preds != n_val:
            # Implicit string concatenation keeps the message on one line
            # without embedding source indentation in it.
            raise ValueError(
                f"Number of predictions ({n_preds}) does not "
                f"match number of samples in validation set ({n_val})"
            )

        pred_dict = {
            "indices": (
                val_indices.tolist()
                if hasattr(val_indices, "tolist")
                else list(val_indices)
            ),
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }
        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()

        logger.debug("Metrics: %s", metrics)
        return {"metrics": metrics, "predictions": pred_dict}

    def save_model(self, save_dir: str) -> str:
        """Save the model (for checkpoint restoration during training).
        Returns the path to the saved model.
        """
        save_path = save_dir + "model.pt"
        torch.save(self.model.state_dict(), save_path)
        return save_path
