import logging
from mlflow.tracking.fluent import log_params
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from esm.models.esmc import ESMC, ESMCOutput
from typing import Any, Dict, Optional, List
from haipr.models.module import HAIPRModule
import lightning.pytorch as pl
from torch.utils.data import Subset, DataLoader, Dataset
from haipr.data import HAIPRData
from haipr.predictor import BasePredictor
from haipr.utils import loss_funcs
from peft import LoraConfig, get_peft_model
from tqdm import tqdm

# Install a default handler on the root logger (no-op if the root logger already
# has handlers) so that records from this module are emitted somewhere.
logging.basicConfig()
logger = logging.getLogger(__name__)
# Level is INFO for more visibility during development; logger.debug(...) calls
# in this module are therefore suppressed unless the level is lowered.
logger.setLevel(logging.INFO)


class ESMCDataset(Dataset):
    """Minimal map-style ``Dataset`` over an in-memory list of items.

    Items are handed back exactly as they were stored; any batching or
    tokenisation is left to the ``collate_fn`` of the consuming ``DataLoader``.
    """

    def __init__(self, sequences: List[str]):
        # The caller's list is referenced directly (no copy is made).
        self.sequences = sequences

    def __len__(self):
        """Return how many items are stored."""
        n_items = len(self.sequences)
        return n_items

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        item = self.sequences[idx]
        return item


class ESMCPredictor(HAIPRModule):
    """ESMC-specific predictor implementation using PyTorch Lightning."""

    def __init__(
        self,
        model_name: str = "esmc_300m",  # Default open license ESMC model
        num_classes: int = 0,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
        batch_size: int = 1,
        loss_fn: str = "huber",
        **kwargs,
    ):
        """Load the pretrained ESMC backbone and attach a prediction head.

        Args:
            model_name: Identifier passed to ``ESMC.from_pretrained``.
            num_classes: Forwarded to ``HAIPRModule``; elsewhere in this class
                ``0`` is treated as regression and any other value as
                classification.
            prediction_head: Module applied to the pooled embeddings in
                ``forward``. Required.
            learning_rate: Forwarded to ``HAIPRModule``.
            weight_decay: Forwarded to ``HAIPRModule``.
            batch_size: Batch size used by ``fit_model`` when building loaders.
            loss_fn: Key looked up in ``loss_funcs``; unknown keys fall back to
                ``nn.MSELoss`` (a warning is logged).
            **kwargs: Only ``num_workers`` (default ``0``) is read here; other
                entries are accepted and ignored.

        Raises:
            ValueError: If ``prediction_head`` is ``None``.
        """
        # Fail fast: reject an unusable configuration *before* paying for the
        # (potentially slow) pretrained-weight load below.
        if prediction_head is None:
            raise ValueError("No prediction head provided")

        # Plain (non-Module) attribute, so it is safe to set before the
        # parent initialiser has run.
        self.num_workers = kwargs.get("num_workers", 0)
        model = ESMC.from_pretrained(model_name)

        criterion = loss_funcs.get(loss_fn, None)
        if criterion is None:
            # Make the silent fallback visible: a typo in ``loss_fn`` would
            # otherwise train with MSE without any indication.
            logger.warning(
                "Unknown loss_fn %r; falling back to nn.MSELoss", loss_fn
            )
            criterion = nn.MSELoss()

        # Lazy %-formatting: the (large) model repr is only built when DEBUG
        # logging is actually enabled.
        logger.debug("Initializing HAIPRModule with %s", model)
        super().__init__(  # HAIPRModule init
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )

        # Module attributes must be assigned after the parent initialiser so
        # that the head is registered as a submodule.
        self.prediction_head = prediction_head

        self.model_name = model_name
        self.batch_size = batch_size
        self.data = None

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """Wrap ``self.model`` in a PEFT LoRA model.

        Target modules are selected by name: every entry of
        ``self.model.named_modules()`` whose qualified name ends with one of
        the suffixes listed below is adapted.

        Args:
            lora_config_dict: Mapping read with ``.get`` for the keys ``rank``
                (default 2), ``alpha`` (default 16), ``dropout`` (default 0.0)
                and ``bias`` (default ``"none"``).
        """
        # Custom implementation since esm3 like models are not supported by default
        lora_suffixes = (
            "attn.layernorm_qkv.1",
            "attn.out_proj",
            "ffn.1",
            "ffn.3",
        )
        target_modules = [
            qualified_name
            for qualified_name, _ in self.model.named_modules()
            if qualified_name.endswith(lora_suffixes)
        ]

        rank = lora_config_dict.get("rank", 2)
        alpha = lora_config_dict.get("alpha", 16)
        dropout = lora_config_dict.get("dropout", 0.0)
        bias_mode = lora_config_dict.get("bias", "none")
        peft_settings = LoraConfig(
            r=rank,
            lora_alpha=alpha,
            lora_dropout=dropout,
            target_modules=target_modules,
            bias=bias_mode,
        )

        # Replace the backbone held by the module with its PEFT-wrapped version.
        self.model = get_peft_model(self.model, peft_settings)

    def setup_model(self, data: HAIPRData, cfg: DictConfig):
        """Attach dataset-derived state and the run configuration to the module.

        Stores ``data`` and ``cfg`` on the instance, keeps the first entry of
        ``data.get_labels()`` as ``wt_score`` and ``data.pdb`` as ``pdb_path``.
        """
        self.data = data
        all_labels = data.get_labels()
        # presumably row 0 is the wild type (hence "wt") — verify against HAIPRData
        self.wt_score = all_labels[0]
        self.pdb_path = data.pdb
        self.cfg = cfg

    def fit_model(
        self,
        dataset: HAIPRData,
        train_indices,
        val_indices,
        trainer_instance: Optional[
            pl.Trainer
        ] = None,  # Made optional for API consistency
        cfg: Optional[DictConfig] = None,
    ) -> Dict[str, Any]:
        """Fit the model with a Lightning trainer and collect validation output.

        Features for train and validation rows are prepared in one pass
        (train rows first, validation rows after), so the loaders are built
        from positions *relative to that concatenation*.

        Args:
            dataset: Data source; also stored on ``self.data``.
            train_indices: Row positions used for training.
            val_indices: Row positions used for validation.
            trainer_instance: Trainer to use; a ``pl.Trainer(max_epochs=2)`` is
                created when omitted.
            cfg: Accepted for API compatibility; not used by this method.

        Returns:
            ``{"metrics": ..., "predictions": {...}}`` where ``predictions``
            holds ``indices``, ``predictions``, ``true_values`` and, if
            available, ``probabilities``. Metrics/predictions are read from
            ``self.best_val_metrics`` / ``self.best_val_predictions`` (expected
            to be populated during fitting by the base class — not visible here).

        Raises:
            ValueError: If the number of predictions differs from the number
                of validation indices (not checked under ``fast_dev_run``).
        """
        if trainer_instance is None:
            # Default trainer if none provided
            trainer_instance = pl.Trainer(max_epochs=2)  # default

        self.data = dataset  # Store dataset reference

        # Prepare features once for all data (train + val)
        n_train = len(train_indices)
        all_indices = np.concatenate([train_indices, val_indices])
        features_dict = self.prepare_training_features(dataset, all_indices)

        # Create DataLoaders using base class method; indices are relative
        # positions inside ``features_dict`` (train block, then val block).
        train_loader, val_loader = self._create_dataloaders(
            features_dict=features_dict,
            labels=features_dict["labels"],
            train_indices=np.arange(n_train),
            val_indices=np.arange(n_train, len(all_indices)),
            batch_size=self.batch_size,
            shuffle_train=True,
        )
        trainer_instance.fit(self, train_loader, val_loader)

        predictions = self.best_val_predictions
        metrics = self.best_val_metrics
        n_preds = len(predictions["preds"])

        # Ask the trainer we were given directly. ``hasattr(self, "trainer")``
        # is not a safe guard: Lightning's ``trainer`` property raises
        # RuntimeError (not AttributeError) when the module is detached.
        if getattr(trainer_instance, "fast_dev_run", False):
            # Skip strict checks in fast_dev_run and trim val_indices to the
            # number of predictions actually made.
            val_indices = val_indices[:n_preds]
        elif n_preds != len(val_indices):
            raise ValueError(
                f"Number of predictions ({n_preds}) does not match "
                f"number of samples in validation set ({len(val_indices)}). "
            )

        # Prepare prediction dictionary
        pred_dict = {
            "indices": val_indices,
            "predictions": predictions["preds"].tolist(),
            "true_values": predictions["labels"].tolist(),
        }

        if "probs" in predictions:
            pred_dict["probabilities"] = predictions["probs"].tolist()

        logger.debug("Metrics: %s", metrics)
        return {"metrics": metrics, "predictions": pred_dict}

    def forward(self, batch, **kwargs):
        """Run the ESMC backbone and apply the prediction head to pooled embeddings.

        Args:
            batch: Mapping with ``batch["inputs"]["sequence_tokens"]`` holding
                the token ids fed to the backbone.
            **kwargs: Accepted and ignored.

        Returns:
            The prediction-head output (a trailing dimension is added when the
            head returns a 1-D tensor). If no prediction head is set, a warning
            is logged and the backbone's ``sequence_logits`` are returned.

        Raises:
            ValueError: If a head is set but the backbone returned no embeddings.
        """
        # ESMC output is ESMCOutput(sequence_logits, embeddings, hidden_states);
        # embeddings are [batch_size, seq_len, d_model].
        tokens = batch["inputs"]["sequence_tokens"]
        out: ESMCOutput = self.model(sequence_tokens=tokens)

        head = getattr(self, "prediction_head", None)
        if head is None:
            # The former ``sequence_head`` branch called ``self.prediction_head``
            # although it is missing/None on that path and could only crash;
            # without a head the raw sequence logits are the only valid output.
            logger.warning("No Head Defined ")
            return out.sequence_logits

        embeddings = out.embeddings
        if embeddings is None:
            raise ValueError(
                "ESMC embeddings are None, cannot use prediction_head that relies on embeddings."
            )

        # Up-cast half-precision embeddings to float32 to avoid a dtype
        # mismatch with the prediction head (covers bf16 and fp16).
        if embeddings.dtype in (torch.bfloat16, torch.float16):
            embeddings = embeddings.float()

        # Mean-pool over real tokens only. A plain ``mean(dim=1)`` also averages
        # PAD positions, which makes the prediction for a sequence depend on how
        # much padding its batch happened to receive.
        backbone = (
            self.model.module
            if isinstance(self.model, torch.nn.DataParallel)
            else self.model
        )
        pad_id = getattr(getattr(backbone, "tokenizer", None), "pad_token_id", None)
        if pad_id is None:
            # No way to identify padding: keep the unmasked mean.
            pooled = embeddings.mean(dim=1)
        else:
            keep = tokens.ne(pad_id).unsqueeze(-1).to(embeddings.dtype)
            # clamp guards against division by zero for an all-PAD row.
            pooled = (embeddings * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

        pred_out = head(pooled)

        # Add dimension if single output target
        if pred_out.dim() == 1:
            pred_out = pred_out.unsqueeze(1)
        return pred_out

    def prepare_batch_features(self, batch_items: list[dict]):
        """Collate a list of sample dicts into model inputs for ESMC.

        ESMC.forward expects ``sequence_tokens``.

        Args:
            batch_items: Non-empty list of dicts. Each must contain
                ``"sequence"``; ``"labels"`` and ``"sample_id"`` are optional
                and their presence is decided from the first item.

        Returns:
            ``{"inputs": {"sequence_tokens": ...}, "labels": ...}`` plus
            ``"sample_id"`` when sample ids were provided. ``labels`` is
            ``None`` when the items carry no labels; float32 with shape
            ``[batch, num_targets]`` for regression (``num_classes == 0``),
            otherwise a long tensor.

        Raises:
            ValueError: If the backbone has no tokenizer.
        """
        sequences = [item["sequence"] for item in batch_items]
        first_item = batch_items[0]

        raw_labels = None
        if "labels" in first_item:
            raw_labels = [item["labels"] for item in batch_items]

        # Extract sample IDs if present for DDP compatibility
        sample_ids = None
        if "sample_id" in first_item:
            sample_ids = torch.tensor(
                [item["sample_id"] for item in batch_items], dtype=torch.long
            )

        # Unwrap DataParallel *before* looking for the tokenizer: the wrapper
        # does not forward attribute access, so checking it directly would
        # always report a missing tokenizer.
        backbone = (
            self.model.module
            if isinstance(self.model, torch.nn.DataParallel)
            else self.model
        )
        if getattr(backbone, "tokenizer", None) is None:
            raise ValueError(
                "ESMC model does not have a tokenizer. Cannot prepare features."
            )
        tokenized_sequences = backbone._tokenize(sequences)

        # Collate labels
        collated_labels = None
        if raw_labels is not None:
            if self.num_classes == 0:  # Regression
                collated_labels = torch.tensor(raw_labels, dtype=torch.float32)
                # Ensure labels for regression are 2D [batch_size, num_targets]
                if collated_labels.dim() == 1:
                    collated_labels = collated_labels.unsqueeze(1)
            else:  # Classification
                collated_labels = torch.tensor(raw_labels, dtype=torch.long)

        result = {
            "inputs": {"sequence_tokens": tokenized_sequences},
            "labels": collated_labels,
        }

        # Add sample IDs if present for DDP compatibility
        if sample_ids is not None:
            result["sample_id"] = sample_ids

        return result

    def prepare_training_features(
        self, dataset: HAIPRData, indices: np.ndarray
    ) -> Dict[str, torch.Tensor]:
        """Tokenize and collate all rows at ``indices`` in a single pass.

        Args:
            dataset: Source whose ``data`` frame provides the ``"sequence"``
                column, the ``dataset.label_col`` column and optionally a
                ``"sample_id"`` column.
            indices: Positional row indices (used with ``.iloc``).

        Returns:
            Dict with ``"sequence_tokens"`` and ``"labels"`` (plus
            ``"sample_id"`` when that column exists), all moved to
            ``self.device``. Regression labels (``num_classes == 0``) are
            float32 and at least 2-D; classification labels are long.

        Raises:
            ValueError: If the backbone has no tokenizer.
        """
        logger.info(f"Preparing training features for {len(indices)} samples")

        # Get all sequences and labels at once
        frame = dataset.data
        sequences = frame["sequence"].iloc[indices].tolist()
        labels = frame[dataset.label_col].iloc[indices].values

        # Extract sample IDs if present for DDP compatibility
        sample_ids = None
        if "sample_id" in frame.columns:
            sample_ids = torch.tensor(
                frame["sample_id"].iloc[indices].values, dtype=torch.long
            )

        # Unwrap DataParallel *before* looking for the tokenizer: the wrapper
        # does not forward attribute access, so checking it directly would
        # always report a missing tokenizer.
        backbone = (
            self.model.module
            if isinstance(self.model, torch.nn.DataParallel)
            else self.model
        )
        if getattr(backbone, "tokenizer", None) is None:
            raise ValueError(
                "ESMC model does not have a tokenizer. Cannot prepare features."
            )

        # Tokenize all sequences at once (padding therefore spans the whole set)
        tokenized_sequences = backbone._tokenize(sequences)

        # Collate labels
        if self.num_classes == 0:  # Regression
            collated_labels = torch.tensor(labels, dtype=torch.float32).to(self.device)
            # Ensure 2D for regression [batch_size, num_targets]
            if collated_labels.dim() == 1:
                collated_labels = collated_labels.unsqueeze(1)
        else:  # Classification
            collated_labels = torch.tensor(labels, dtype=torch.long).to(self.device)

        result = {
            # Ensure inputs are on model device
            "sequence_tokens": tokenized_sequences.to(self.device),
            "labels": collated_labels,
        }

        # Add sample IDs if present for DDP compatibility
        if sample_ids is not None:
            result["sample_id"] = sample_ids.to(self.device)

        logger.info(
            f"Prepared features: sequence_tokens {result['sequence_tokens'].shape}, labels {result['labels'].shape}"
        )
        return result

    def load_model(self, model_path: str) -> None:
        """Replace ``self.model`` with the backbone restored from a checkpoint.

        A new ``ESMCPredictor`` is built via ``load_from_checkpoint`` and only
        its ``model`` attribute is adopted. Any failure is logged and re-raised.

        NOTE(review): the restored object's ``prediction_head`` is not copied
        over, so this instance keeps its current head — confirm this is intended.
        NOTE(review): ``model=self.model`` is not a named ``__init__`` parameter
        of this class; it appears to be absorbed by ``**kwargs`` — verify against
        ``HAIPRModule`` / the saved hyperparameters.

        Args:
            model_path: Path to the checkpoint file.
        """
        try:
            restored_predictor = ESMCPredictor.load_from_checkpoint(
                model_path,
                model=self.model,
            )
            self.model = restored_predictor.model
            logger.info(f"ESMC model loaded from checkpoint: {model_path}")
        except Exception as load_error:
            # Log for context, then let the caller decide how to handle it.
            logger.error(f"Failed to load model from {model_path}: {load_error}")
            raise

    def save_model(self, model_path: str, adapter_only: bool = True):
        """Persist either only the adapter or the whole wrapped model.

        NOTE(review): a second ``save_model(self, save_dir)`` is defined further
        down in this class body and replaces this definition when the class is
        created, so this variant is not reachable through ``self.save_model``.
        Rename or merge one of the two.

        Args:
            model_path: Destination passed to the underlying save call.
            adapter_only: When true call ``self.model.save_adapter``, otherwise
                ``self.model.save_pretrained``.
        """
        persist = (
            self.model.save_adapter if adapter_only else self.model.save_pretrained
        )
        persist(model_path)

    def _predict_perplexities(
        self,
        sequences: List[str],
        batch_size: int = 1,
    ) -> Dict[str, Any]:
        """Score sequences with the backbone's token logits.

        For every sequence the log-probabilities the model assigns to its own
        (unmasked) input tokens are summed over non-PAD positions; the
        perplexity is ``exp(-sum / n_tokens)`` with ``n_tokens`` the number of
        non-PAD tokens of that sequence.

        Args:
            sequences: Sequences to score; a single string is accepted too.
            batch_size: Number of sequences per forward pass.

        Returns:
            Dict with ``"sequences"`` and, aligned with it, the lists
            ``"perplexities"``, ``"log_probabilities"`` and ``"probabilities"``.
        """
        # Convert single sequence to list for consistent processing
        if isinstance(sequences, str):
            sequences = [sequences]

        self.model.eval()
        probabilities = []
        log_probabilities = []
        perplexities = []

        # Create dataset and dataloader. Items are dicts with a "sequence" key,
        # the shape the collate function reads.
        # NOTE(review): ``prepare_features`` is not defined in this class —
        # presumably provided by HAIPRModule; confirm it is not meant to be
        # ``prepare_batch_features``.
        loader = DataLoader(
            ESMCDataset([{"sequence": seq} for seq in sequences]),
            batch_size=batch_size,
            collate_fn=self.prepare_features,
            shuffle=False,
        )

        with torch.no_grad():
            for item_batch in tqdm(loader, desc="Computing sequence perplexities"):
                # Move to device
                tokens = item_batch["inputs"]["sequence_tokens"].to(self.model.device)

                # Get model outputs: [batch_size, seq_len, vocab_size]
                sequence_logits = self.model(sequence_tokens=tokens).sequence_logits
                log_probs = F.log_softmax(sequence_logits, dim=-1)
                # pull out the log probs for the actual tokens
                token_log_probs = log_probs.gather(
                    dim=-1, index=tokens.unsqueeze(-1)
                ).squeeze(-1)

                mask = tokens.ne(self.model.tokenizer.pad_token_id).float()
                seq_log_probs = (token_log_probs * mask).sum(dim=-1)
                seq_probs = torch.exp(seq_log_probs)

                # Normalise by the number of real tokens of each sequence, not
                # by the padded width: dividing by ``tokens.size(1)`` made the
                # perplexity of a short sequence depend on the longest sequence
                # in its batch. clamp guards against an all-PAD row.
                n_tokens = mask.sum(dim=-1).clamp(min=1.0)
                seq_perplexities = torch.exp(-seq_log_probs / n_tokens)

                log_probabilities.extend(seq_log_probs.cpu().tolist())
                probabilities.extend(seq_probs.cpu().tolist())
                perplexities.extend(seq_perplexities.cpu().tolist())

        # Prepare return dictionary
        return {
            "sequences": sequences,
            "perplexities": perplexities,
            "log_probabilities": log_probabilities,
            "probabilities": probabilities,
        }

    def _load_artifacts(self, context):
        """Load backbone weights from ``context.artifacts["model"]`` if present.

        The artifact is expected to be a state dict (it is passed straight to
        ``self.model.load_state_dict``). After loading, the backbone is put in
        eval mode. Nothing happens when there is no ``"model"`` artifact.

        Args:
            context: Object exposing an ``artifacts`` mapping of name -> path.
        """
        if "model" in context.artifacts:
            # map_location="cpu": a state dict saved from GPU must also load on
            # CPU-only hosts; load_state_dict copies values onto the devices of
            # the existing parameters.
            # SECURITY NOTE: torch.load unpickles the file; only load trusted
            # artifacts (consider ``weights_only=True`` where supported).
            checkpoint = torch.load(context.artifacts["model"], map_location="cpu")
            self.model.load_state_dict(checkpoint)
            self.model.eval()

    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """Predict for ``sequences``, optionally returning perplexities instead.

        When ``params`` has a truthy ``"perplexities"`` entry the call is routed
        to ``_predict_perplexities`` using ``params["batch_size"]`` (default 1);
        in every other case the parent implementation handles the request.
        """
        wants_perplexities = params is not None and params.get("perplexities", False)
        if not wants_perplexities:
            return super().predict_sequences(sequences, params)

        per_batch = params.get("batch_size", 1)
        return self._predict_perplexities(sequences, per_batch)

    def save_model(self, save_dir: str) -> str:
        """Save the backbone's state dict (for checkpoint restoration during training).

        Only ``self.model.state_dict()`` is written; the file is named
        ``model.pt`` and placed inside ``save_dir``, which is created if it
        does not exist yet.

        Args:
            save_dir: Target directory, with or without a trailing separator.

        Returns:
            The path to the saved model file.
        """
        import os  # local import: keeps this block self-contained

        # Join with a proper separator: plain ``save_dir + "model.pt"`` turned
        # "/tmp/run" into "/tmp/runmodel.pt", i.e. a file *next to* the
        # requested directory.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "model.pt")
        torch.save(self.model.state_dict(), save_path)
        return save_path