import logging
from omegaconf import DictConfig
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    EsmModel,
)
from peft import LoraConfig, TaskType, get_peft_model
from typing import Any, Dict, Optional

from haipr.models.module import HAIPRModule
from haipr.predictor import BasePredictor
from haipr.utils import loss_funcs
import lightning.pytorch as pl
from torch.utils.data import Subset, DataLoader
from haipr.data import HAIPRData
from tqdm import tqdm

# Import-time logging setup. logging.basicConfig() attaches a default handler to
# the root logger (it is a no-op if the root logger already has handlers), and
# this module's logger is set to INFO so the logger.info(...) progress messages
# in this file are emitted.
# NOTE(review): configuring the root logger from library code affects every
# importer of this module; consider leaving basicConfig() to the entry point.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ESM2Dataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized ESM2 features.

    ``features`` must provide ``"input_ids"``, ``"attention_mask"`` and
    ``"labels"`` entries that can be indexed along their first axis. Each item
    is the ``(input_ids, attention_mask, labels)`` triple at that index.
    """

    def __init__(self, features):
        # Keys are read in the order input_ids, attention_mask, labels.
        self.input_ids, self.attention_mask, self.labels = (
            features["input_ids"],
            features["attention_mask"],
            features["labels"],
        )

    def __len__(self):
        # The number of rows is defined by the input_ids tensor.
        num_rows = self.input_ids.shape[0]
        return num_rows

    def __getitem__(self, idx):
        columns = (self.input_ids, self.attention_mask, self.labels)
        return tuple(column[idx] for column in columns)


class ESM2Predictor(HAIPRModule, BasePredictor):
    """ESM2 encoder plus a user-supplied prediction head on pooled embeddings.

    The backbone (``EsmModel``) and its tokenizer are loaded from
    ``facebook/<model_name_or_path>``. ``forward`` averages the final hidden
    states over the positions marked by ``attention_mask`` and passes the
    pooled vector to ``prediction_head``. ``num_classes == 0`` selects the
    regression paths (float32 labels, 2-D outputs); any other value selects
    the classification paths (int64 labels).
    """

    def __init__(
        self,
        model_name_or_path: str = "esm2_t6_8M_UR50D",
        num_classes: int = 0,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        batch_size: int = 1,
        loss: str = "mse",
        **kwargs,
    ):
        """Build the predictor.

        Args:
            model_name_or_path: Checkpoint name under the ``facebook/`` hub
                namespace, used for both the model and the tokenizer.
            num_classes: 0 for regression; any other value for classification.
            prediction_head: Module applied to the pooled embeddings. Required.
            learning_rate: Forwarded to ``HAIPRModule``.
            weight_decay: Forwarded to ``HAIPRModule``.
            batch_size: Batch size of the DataLoaders built in ``fit_model``.
            loss: Key into ``loss_funcs``; unknown keys fall back to
                ``nn.MSELoss``.
            **kwargs: Forwarded unchanged to ``HAIPRModule.__init__``.

        Raises:
            ValueError: If ``prediction_head`` is None.
        """
        # Validate before downloading/loading the (possibly large) backbone.
        if prediction_head is None:
            raise ValueError("No prediction head provided")

        logger.info(
            f"Initializing ESM2Predictor with model_name: {model_name_or_path}")
        criterion = loss_funcs.get(loss)
        if criterion is None:
            logger.warning(
                f"Loss function '{loss}' not found. Defaulting to MSELoss."
            )
            criterion = nn.MSELoss()

        model = EsmModel.from_pretrained(f"facebook/{model_name_or_path}")
        # Kept before HAIPRModule.__init__ as in the original ordering.
        # NOTE(review): confirm whether HAIPRModule.__init__ relies on it.
        self.tokenizer = AutoTokenizer.from_pretrained(
            f"facebook/{model_name_or_path}")

        super().__init__(
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            **kwargs,
        )

        self.prediction_head = prediction_head
        self.num_classes = num_classes
        self.model_name = model_name_or_path
        self.batch_size = batch_size
        self.data = None  # set by setup_model() / fit_model()
        self.pdb_path = None  # ESM2 typically doesn't use PDB directly like ESM3
        self._dataparallel_initialized = False  # Track if DataParallel has been set up

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """Wrap ``self.model`` with LoRA (or DoRA) adapters via PEFT.

        Optional keys read from ``lora_config_dict``:
        ``peft_type`` ("lora" or "dora", case-insensitive; default "lora"),
        ``rank`` (8), ``alpha`` (16), ``dropout`` (0.0), ``bias`` ("none"),
        ``fan_in_fan_out`` (False).
        """
        # PEFT matches each entry against module *paths* from named_modules(),
        # either exactly or as a dot-separated suffix. Paths are attribute
        # names such as "encoder.layer.0.intermediate.dense", never class names,
        # so entries like "EsmIntermediate.dense" would silently match nothing.
        target_modules = [
            "query",
            "key",
            "value",
            "attention.output.dense",  # EsmSelfOutput projection
            "intermediate.dense",  # EsmIntermediate (FFN up-projection)
            "output.dense",  # EsmOutput (FFN down-projection)
        ]
        peft_config_instance = LoraConfig(
            use_dora=(lora_config_dict.get(
                "peft_type", "lora").lower() == "dora"),
            # FEATURE_EXTRACTION: the wrapped backbone returns hidden states
            # that forward() pools itself.
            task_type=TaskType.FEATURE_EXTRACTION,
            r=lora_config_dict.get("rank", 8),
            lora_alpha=lora_config_dict.get("alpha", 16),
            lora_dropout=lora_config_dict.get("dropout", 0.0),
            target_modules=target_modules,
            bias=lora_config_dict.get("bias", "none"),
            # fan_in_fan_out describes the weight layout: True only for layers
            # storing weights as (fan_in, fan_out), e.g. GPT-2 Conv1D. ESM uses
            # nn.Linear, so the default is False regardless of the task type.
            fan_in_fan_out=lora_config_dict.get("fan_in_fan_out", False),
        )
        self.model = get_peft_model(self.model, peft_config_instance)

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """Attach the dataset and run configuration to this predictor."""
        self.data = data
        self.cfg = cfg

    def _max_length(self) -> int:
        """Return the tokenizer's ``model_max_length``, or 512 if unset/falsy."""
        max_len = getattr(self.tokenizer, "model_max_length", None)
        return max_len if max_len else 512

    def _tokenize(self, sequences):
        """Tokenize with the padding/truncation policy shared by all collate paths."""
        return self.tokenizer(
            sequences,
            padding="longest",  # pad to the longest sequence in this call
            truncation=True,  # truncate anything longer than _max_length()
            return_tensors="pt",
            max_length=self._max_length(),
        )

    def _collate_labels(self, raw_labels) -> torch.Tensor:
        """Build a label tensor.

        Regression (``num_classes == 0``) yields float32 labels, and a 1-D
        result is reshaped to ``[n, 1]``. Classification yields int64 labels,
        as expected by CrossEntropyLoss.
        """
        if self.num_classes == 0:
            labels = torch.tensor(raw_labels, dtype=torch.float32)
            if len(labels.shape) == 1:
                labels = labels.unsqueeze(1)
            return labels
        return torch.tensor(raw_labels, dtype=torch.long)

    def forward(self, batch):
        """Predict from a tokenized batch.

        Args:
            batch: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors, as built by ``prepare_features`` or
                ``predict_collate_fn``.

        Returns:
            The output of ``prediction_head``. For regression
            (``num_classes == 0``), a 1-D output is reshaped to
            ``[batch_size, 1]``.
        """
        attention_mask = batch["attention_mask"]
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=attention_mask,
        )

        # Convert bf16 embeddings to float32 to avoid a dtype mismatch with the head.
        embeddings = outputs.last_hidden_state
        if embeddings.dtype == torch.bfloat16:
            embeddings = embeddings.float()

        # Masked mean over the sequence axis. Padded positions (mask == 0) are
        # excluded, so a sequence's pooled embedding does not depend on how long
        # the other sequences in its batch are. The clamp avoids division by zero.
        mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1.0)
        pooled = (embeddings * mask).sum(dim=1) / token_counts

        pred_out = self.prediction_head(pooled)

        if self.num_classes == 0 and len(pred_out.shape) == 1:
            # Regression with a single target: make it [batch_size, 1].
            pred_out = pred_out.unsqueeze(1)
        elif self.num_classes > 0 and len(pred_out.shape) == 1:
            # CrossEntropyLoss expects [B, C]; a 1-D output indicates a head mismatch.
            logger.warning(
                f"Unexpected pred_out shape for classification: {pred_out.shape}"
            )

        return pred_out

    def prepare_features(self, batch_items: list[dict]):
        """Collate dataset items into a tokenized batch on ``self.device``.

        Each item needs ``"sequence"`` and ``"labels"``. If the first item has
        ``"sample_id"``, every item is expected to have it, and the ids are
        returned as an int64 tensor under ``"sample_id"``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``, plus
            ``sample_id`` when present.
        """
        sequences = [item["sequence"] for item in batch_items]
        raw_labels = [item["labels"] for item in batch_items]

        # Sample IDs are carried through for DDP compatibility.
        sample_ids = None
        if "sample_id" in batch_items[0]:
            sample_ids = torch.tensor(
                [item["sample_id"] for item in batch_items], dtype=torch.long
            )

        tokenized_output = self._tokenize(sequences)
        input_ids = tokenized_output["input_ids"].to(self.device)
        attention_mask = tokenized_output["attention_mask"].to(self.device)
        collated_labels = self._collate_labels(raw_labels).to(self.device)

        logger.debug(
            f"Input IDs shape: {input_ids.shape}, Device: {input_ids.device}")
        logger.debug(
            f"Attention Mask shape: {attention_mask.shape}, Device: {attention_mask.device}"
        )
        logger.debug(
            f"Collated labels shape: {collated_labels.shape}, Device: {collated_labels.device}"
        )

        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": collated_labels,
        }
        if sample_ids is not None:
            result["sample_id"] = sample_ids.to(self.device)
        return result

    def prepare_features_all(self):
        """Tokenize the whole dataset once and cache the result.

        The in-memory cache is checked first, then the dataset's own cache
        (``self.data._load_from_cache``). Otherwise all sequences are tokenized
        (padded to the longest in the dataset) and stored both on
        ``self._cached_features`` and via ``self.data.cache_features``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` (CPU
            tensors when freshly computed).

        Raises:
            ValueError: If ``self.data`` has not been set.
        """
        if self.data is None:
            raise ValueError(
                "self.data must be set before calling prepare_features_all. Use setup_data()."
            )
        if getattr(self, "_cached_features", None) is not None:
            return self._cached_features

        cache_key = self.data._get_cache_key()
        cached = self.data._load_from_cache(cache_key)
        if cached is not None:
            logger.info(f"Loaded features from cache: {cache_key}")
            self._cached_features = cached
            return cached

        logger.info(
            "No cached features found. Preparing features for all data.")
        all_sequences = [item["sequence"] for item in self.data]
        all_labels = [item["labels"] for item in self.data]
        tokenized_output = self._tokenize(all_sequences)
        features = {
            "input_ids": tokenized_output["input_ids"],
            "attention_mask": tokenized_output["attention_mask"],
            "labels": self._collate_labels(all_labels),
        }
        self.data.cache_features(features)
        self._cached_features = features
        return features

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices,
        val_indices,
        trainer_instance: Optional[pl.Trainer] = None,
        cfg: DictConfig = None,
    ) -> Dict[str, Any]:
        """Train on ``train_indices`` and report validation results.

        If no trainer is given, one is created with ``max_epochs`` taken from
        ``self.hparams["num_epochs"]`` (default 3). Validation predictions and
        metrics are read from ``self.best_val_predictions`` /
        ``self.best_val_metrics``, which are presumably populated by
        HAIPRModule's validation hooks (verify against HAIPRModule).

        Returns:
            ``{"metrics": ..., "predictions": {"indices", "predictions",
            "true_values"[, "probabilities"]}}``.

        Raises:
            ValueError: If the number of stored predictions differs from
                ``len(val_indices)``.
        """
        if trainer_instance is None:
            trainer_instance = pl.Trainer(
                max_epochs=self.hparams.get("num_epochs", 3))
        self.data = dataset

        train_loader = DataLoader(
            Subset(dataset, train_indices),
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.prepare_features,
        )
        val_loader = DataLoader(
            Subset(dataset, val_indices),
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.prepare_features,
        )
        trainer_instance.fit(self, train_loader, val_loader)

        predictions = self.best_val_predictions
        metrics = self.best_val_metrics
        if len(predictions["preds"]) != len(val_indices):
            raise ValueError(
                f"Number of predictions ({len(predictions['preds'])}) does not "
                f"match number of samples in validation set ({len(val_indices)}). "
                "This can happen with sanity_check setting best_val_predictions "
                "and actual training is not making improvements"
            )

        pred_dict = {
            "indices": (
                val_indices.tolist()
                if hasattr(val_indices, "tolist")
                else list(val_indices)
            ),
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }
        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()
        logger.debug(f"Metrics: {metrics}")
        return {"metrics": metrics, "predictions": pred_dict}

    def load_model(self, model_path: str) -> None:
        """Load trained weights from a Lightning checkpoint into this instance.

        The checkpoint's ``state_dict`` (or the file itself, if it is a bare
        state dict) is loaded strictly into this module. That restores both the
        backbone (``self.model``) and ``self.prediction_head`` without
        re-running ``__init__`` (no re-download of the pretrained backbone).
        The current instance must have the same structure as the saved one,
        e.g. PEFT adapters applied in both or in neither. If the checkpoint's
        hyper-parameters contain ``learning_rate``, it is applied too.

        Raises:
            Exception: Any error while reading or applying the checkpoint is
                logged and re-raised.
        """
        try:
            # SECURITY: weights_only=False unpickles arbitrary objects (Lightning
            # checkpoints store hyper-parameters); only load trusted files.
            checkpoint = torch.load(
                model_path, map_location="cpu", weights_only=False)
            state_dict = checkpoint.get("state_dict", checkpoint)
            # load_state_dict copies into the existing parameters, so they stay
            # on their current device.
            self.load_state_dict(state_dict)
            hparams = checkpoint.get("hyper_parameters") or {}
            if "learning_rate" in hparams:
                self.learning_rate = hparams["learning_rate"]
            logger.info(f"ESM2 model loaded from checkpoint: {model_path}")
        except Exception as e:
            logger.error(
                f"Failed to load ESM2 model from {model_path}: {e}", exc_info=True
            )
            raise

    def predict_collate_fn(self, batch_items):
        """Tokenize raw sequence strings for prediction (no labels).

        Defined as a method rather than a lambda so the DataLoader can pickle
        it together with the module.
        """
        tokenized_output = self._tokenize(batch_items)
        return {  # model inputs only
            "input_ids": tokenized_output["input_ids"].to(self.device),
            "attention_mask": tokenized_output["attention_mask"].to(self.device),
        }

    def predict(self, sequences: list[str], batch_size: int = 1) -> Dict[str, Any]:
        """Predict for raw sequences via ``HAIPRModule.predict``.

        The separator character ``"|"`` is stripped from every sequence first.
        """
        # TODO: make this more robust (predictor should have smth. like alphabet)
        sequences = [s.replace("|", "") for s in sequences]
        return super().predict(sequences, batch_size)
