import hydra
import lightning.pytorch as pl
from torch.nn import (
    MSELoss,
    L1Loss,
    SmoothL1Loss,
    HuberLoss,
    CrossEntropyLoss,
    BCEWithLogitsLoss,
    NLLLoss,
)
from peft import LoraConfig, TaskType, get_peft_model
import torch
import torch.distributed as dist
import logging
from haipr.utils import compute_classification_metrics, compute_regression_metrics
from typing import Dict, Any, List
from omegaconf import DictConfig
from tqdm import tqdm
import numpy as np
logger = logging.getLogger(__name__)


class HAIPRModule(pl.LightningModule):
    """
    Lightning Module for HAIPR models that are based on torch.
    Handles training, validation, and prediction steps.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        criterion,
        num_classes=0,
        learning_rate=1e-4,
        weight_decay=0.01,
        val_chunk_size=100,
        val_chunk_log_interval=200,
        **kwargs,
    ):
        """
        Initialize the Lightning Module.

        Args:
            model: The PyTorch model to train (e.g., ESM3 instance with prediction head)
            criterion: The loss function to use
            num_classes: Number of classes (0 for regression)
            learning_rate: Learning rate for optimization
            weight_decay: Weight decay for optimization
            val_chunk_size: Size of the validation chunk for frequent evaluation
            val_chunk_log_interval: Number of training steps between validation chunk evaluations
            **kwargs: Extra options. When a non-None ``peft`` entry is present, its
                ``"target_modules"`` value replaces the one in the ``peft`` section of
                the composed Hydra ``train`` config and PEFT adapters are injected
                into ``model``.

        Raises:
            ValueError: If the criterion type does not match the task implied by
                ``num_classes``.
        """
        super().__init__()

        # Validate the criterion first so a misconfiguration fails before the
        # model is wrapped (and therefore modified) by PEFT.
        if num_classes == 0 and not isinstance(
            criterion, (MSELoss, L1Loss, SmoothL1Loss, HuberLoss)
        ):
            raise ValueError(
                "Regression task requires a regression loss function")
        elif num_classes > 0 and not isinstance(
            criterion, (CrossEntropyLoss, BCEWithLogitsLoss, NLLLoss)
        ):
            raise ValueError(
                "Classification task requires a classification loss function"
            )

        self.model = model
        self.criterion = criterion
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.val_chunk_size = val_chunk_size
        self.val_chunk_log_interval = val_chunk_log_interval
        self.peft_initialized = False
        # Validation chunk tracking
        self.val_chunk_data = None
        self.training_step_count = 0

        peft_options = kwargs.get("peft", None)
        if peft_options is not None:
            peft_config = hydra.compose(config_name="train").peft
            peft_config["target_modules"] = peft_options["target_modules"]
            self._initialize_peft_adapters(peft_config)
            self.peft_initialized = True

        self.save_hyperparameters(
            ignore=["model", "criterion", "prediction_head"])
        # Initialize lists to store validation outputs
        self.validation_step_outputs = []

        self.last_val_predictions = None
        self.last_val_metrics = None
        # Per-epoch accumulators filled by validation_step. All four are created
        # here so they exist even if validation_step runs before
        # on_validation_epoch_start.
        self.val_preds = []
        self.val_labels = []
        self.val_probs = []
        self.val_sample_ids = []
        self.best_val_loss = float("inf")
        self.best_val_metrics = {}
        self.best_val_predictions = None

    def _is_ddp_enabled(self):
        """Return True when a trainer is attached and torch.distributed is initialized.

        The process-group checks run first so that the common single-process
        case never touches ``self.trainer``.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return False
        try:
            trainer = self.trainer
        except RuntimeError:
            # Lightning 2.x raises instead of returning None when the module is
            # not attached to a Trainer; treat that as "not distributed".
            return False
        return trainer is not None and hasattr(trainer, "is_global_zero")

    def _gather_predictions(self, tensor_list):
        """Collect every tensor in ``tensor_list`` from all processes.

        Outside of DDP the input list is returned unchanged. Under DDP the
        result is a flat list that, for each input tensor in order, holds the
        copies received from every process.
        """
        if not self._is_ddp_enabled():
            return tensor_list

        collected = []
        for local_tensor in tensor_list:
            # all_gather needs the tensor on this module's device
            if local_tensor.device != self.device:
                local_tensor = local_tensor.to(self.device)

            # One receive buffer per process
            receive_buffers = [
                torch.zeros_like(local_tensor)
                for _ in range(dist.get_world_size())
            ]
            dist.all_gather(receive_buffers, local_tensor)
            collected += receive_buffers

        return collected

    def _gather_single_tensor(self, tensor):
        """Collect one tensor from all processes and concatenate along dim 0.

        Outside of DDP the tensor is returned unchanged.
        """
        if not self._is_ddp_enabled():
            return tensor

        # all_gather needs the tensor on this module's device
        local_tensor = tensor if tensor.device == self.device else tensor.to(
            self.device)

        # One receive buffer per process
        receive_buffers = [
            torch.zeros_like(local_tensor) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(receive_buffers, local_tensor)

        return torch.cat(receive_buffers, dim=0)

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """Inject PEFT adapters (LoRA/DoRA) into ``self.model`` directly.

        Args:
            lora_config_dict: Mapping-like config read with ``.get``. Recognised
                keys: ``target_modules`` (comma-separated str, list/tuple or
                OmegaConf list; ``None`` lets ``LoraConfig`` pick its default),
                ``peft_type`` (``"dora"`` enables DoRA, default ``"lora"``),
                ``rank``, ``alpha``, ``dropout`` and ``bias``.

        Raises:
            ValueError: If ``target_modules`` has an unsupported type.
        """
        # Local import: the module header only brings in DictConfig.
        from omegaconf import ListConfig

        logger.info(
            f"Initializing PEFT adapters with config: {lora_config_dict}")

        target_modules_conf = lora_config_dict.get("target_modules", None)
        if target_modules_conf is None:
            # Let LoraConfig handle default based on model type
            target_modules_list = None
        elif isinstance(target_modules_conf, str):
            # Tolerate "a, b" style strings: trim whitespace, drop empty entries
            target_modules_list = [
                name.strip()
                for name in target_modules_conf.split(",")
                if name.strip()
            ]
        elif isinstance(
            target_modules_conf, (list, tuple, ListConfig, DictConfig)
        ):
            # ListConfig is what OmegaConf produces for YAML lists. DictConfig
            # stays accepted for backward compatibility (its keys are used).
            target_modules_list = [str(item) for item in target_modules_conf]
        else:
            raise ValueError(
                f"Invalid target_modules type: {type(target_modules_conf)}. Expected str or list/ListConfig."
            )
        logger.info(f"PEFT target modules: {target_modules_list}")

        peft_type = lora_config_dict.get("peft_type", "lora")
        peft_config_instance = LoraConfig(
            use_dora=(peft_type.lower() == "dora"),
            task_type=TaskType.FEATURE_EXTRACTION,
            r=lora_config_dict.get("rank", 8),
            lora_alpha=lora_config_dict.get("alpha", 16),
            lora_dropout=lora_config_dict.get("dropout", 0.1),
            target_modules=target_modules_list,
            bias=lora_config_dict.get("bias", "none"),
        )

        self.model = get_peft_model(self.model, peft_config_instance)

        logger.info(
            f"PEFT adapters ({peft_type}) injected and activated."
        )

        # Report how much of the wrapped model is actually trainable
        all_param = 0
        trainable_params = 0
        for param in self.model.parameters():
            count = param.numel()
            all_param += count
            if param.requires_grad:
                trainable_params += count

        if all_param > 0:
            trainable_percentage = 100 * trainable_params / all_param
            logger.info(
                f"Trainable params: {trainable_params:,} || All params: {all_param:,} || Trainable %: {trainable_percentage:.2f}%"
            )
        else:
            logger.info("Model has no parameters.")

    def forward(self, batch_for_predictor: Dict[str, Any]):
        """Run the wrapped model on a batch.

        Args:
            batch_for_predictor: The batch object, handed to ``self.model``
                unchanged.

        Returns:
            Whatever ``self.model`` returns for this batch.
        """
        predictor = self.model
        return predictor(batch_for_predictor)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: Batch mapping; must contain ``"labels"``. The whole batch is
                passed to ``forward``.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss produced by ``self.criterion``.
        """
        labels = batch["labels"]
        # The 'batch' itself is passed to HAIPRModule.forward,
        # which then passes it to the predictor's forward.
        outputs = self(batch)

        if self.num_classes == 0:  # regression
            labels = labels.float()
            # Ensure consistent shapes for regression
            if outputs.dim() == 1:
                outputs = outputs.view(-1, 1)
            if labels.dim() == 1:
                labels = labels.view(-1, 1)
            # Element-wise losses broadcast silently, so force the outputs into
            # the label layout. Only valid for regression: classification
            # logits carry an extra class dimension and must not be reshaped.
            if outputs.shape != labels.shape:
                outputs = outputs.view_as(labels)
        else:  # classification
            labels = labels.long()

        loss = self.criterion(outputs, labels)
        self.log("train_loss", loss, on_step=True,
                 on_epoch=False, prog_bar=True)

        # Increment training step counter and evaluate validation chunk if needed.
        # A missing or non-positive interval disables the chunk evaluation
        # instead of triggering a modulo-by-zero.
        self.training_step_count += 1
        interval = self.val_chunk_log_interval
        if (
            self.val_chunk_data is not None
            and interval
            and interval > 0
            and self.training_step_count % interval == 0
        ):
            self._evaluate_validation_chunk()

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and buffer its predictions.

        Predictions, labels, class probabilities (classification only) and, when
        the batch has a ``"sample_id"`` entry, the sample ids are appended to
        the per-epoch lists consumed by ``on_validation_epoch_end``.

        Args:
            batch: Batch mapping; must contain ``"labels"``.
            batch_idx: Index of the batch (unused).

        Returns:
            ``{"val_loss": loss}``.
        """
        labels = batch["labels"]
        # The 'batch' itself is passed to HAIPRModule.forward,
        # which then passes it to the predictor's forward.
        outputs = self(batch)

        if self.num_classes == 0:  # regression
            labels = labels.float()
            # Match the label layout so the element-wise loss cannot broadcast.
            if outputs.shape != labels.shape:
                outputs = outputs.view_as(labels)
            loss = self.criterion(outputs, labels)
            # Ensure both predictions and labels are 2D tensors
            preds = outputs.view(-1, 1)
            labels_for_metric = labels.view(-1, 1)
            probs = None
        else:  # classification
            # Logits keep their class dimension; reshaping them to the label
            # shape is impossible whenever there is more than one class.
            labels = labels.long()
            loss = self.criterion(outputs, labels)
            preds = torch.argmax(
                outputs, dim=1
            )  # argmax of the logits [0.2, 0.3, 0.5] -> 2
            labels_for_metric = labels
            probs = torch.softmax(
                outputs, dim=1
            )  # probabilities for each class [2, 3, 5] -> [0.2, 0.3, 0.5]
            logger.debug(f"probs: {probs.shape}")
            logger.debug(f"preds: {preds.shape}")

        # Log validation metrics
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=True,
        )

        # Store predictions, labels, and sample_ids for gathering at epoch end
        self.val_preds.append(preds.detach())
        self.val_labels.append(labels_for_metric.detach())
        if probs is not None:
            self.val_probs.append(probs.detach())
        if "sample_id" in batch:
            # The list is normally created in on_validation_epoch_start; create
            # it lazily so a direct call to validation_step does not fail.
            if not hasattr(self, "val_sample_ids"):
                self.val_sample_ids = []
            sample_ids = torch.as_tensor(batch["sample_id"])
            self.val_sample_ids.append(sample_ids.detach().cpu())

        return {"val_loss": loss}

    def on_validation_epoch_start(self):
        """Start each validation epoch with empty accumulator lists."""
        for buffer_name in ("val_preds", "val_labels", "val_probs", "val_sample_ids"):
            # A fresh list per attribute; the previous epoch's lists are dropped
            setattr(self, buffer_name, [])

    def on_validation_epoch_end(self):
        """Aggregate the buffered validation outputs, compute metrics and log them.

        Steps:
            1. If nothing was buffered, log NaN placeholders and return.
            2. Under DDP, gather predictions/labels/probabilities/sample ids
               from all processes.
            3. Convert to numpy and, when sample ids are available, sort every
               array (including probabilities) by sample id.
            4. Compute task metrics, log them, and update the best-so-far
               record.

        Every process computes and logs the metrics. After the gather all
        processes hold the same data, and calling ``self.log`` with
        ``sync_dist=True`` from a single rank only is a collective mismatch
        that can deadlock.
        """
        if not self.val_preds or not self.val_labels:  # Check if lists are empty
            logger.warning(
                "Validation epoch end: No predictions or labels to compute metrics."
            )
            # Log placeholder metrics so downstream consumers still see the keys
            metrics_to_log = {"val_loss": float("nan")}
            if self.num_classes == 0:  # Regression
                placeholder_names = ["mse", "mae",
                                     "r2", "pearson_r", "spearman_r"]
            else:  # Classification
                placeholder_names = [
                    "accuracy",
                    "precision",
                    "recall",
                    "f1",
                    "roc_auc",
                    "mcc",
                ]
            for metric_name in placeholder_names:
                metrics_to_log[f"val_{metric_name}"] = float("nan")
            self.log_dict(metrics_to_log, on_epoch=True,
                          prog_bar=False, sync_dist=True)
            return

        # The sample-id list may be missing if validation ran without the
        # epoch-start hook; treat that the same as "no sample ids".
        local_sample_ids = getattr(self, "val_sample_ids", [])

        # Gather predictions from all GPUs if DDP is enabled
        if self._is_ddp_enabled():
            gathered_preds = self._gather_predictions(self.val_preds)
            gathered_labels = self._gather_predictions(self.val_labels)
            gathered_probs = (
                self._gather_predictions(self.val_probs)
                if self.val_probs else []
            )
            gathered_sample_ids = (
                self._gather_predictions(local_sample_ids)
                if local_sample_ids else []
            )
        else:
            gathered_preds = self.val_preds
            gathered_labels = self.val_labels
            gathered_probs = self.val_probs
            gathered_sample_ids = local_sample_ids

        preds_all = torch.cat(gathered_preds)
        labels_all = torch.cat(gathered_labels)
        sample_ids_all = torch.cat(
            gathered_sample_ids) if gathered_sample_ids else None

        # Fetch val_loss from trainer as it's aggregated across batches/devices
        val_loss_tensor = self.trainer.callback_metrics.get("val_loss")
        val_loss = (
            val_loss_tensor.item() if val_loss_tensor is not None else float("nan")
        )

        # numpy has no bfloat16, so widen before converting
        if preds_all.dtype == torch.bfloat16:
            preds_all = preds_all.float()
        if labels_all.dtype == torch.bfloat16:
            labels_all = labels_all.float()

        preds_np = preds_all.cpu().numpy()
        labels_np = labels_all.cpu().numpy()
        if sample_ids_all is not None:
            sample_ids_all = sample_ids_all.cpu().numpy().flatten()

        probs_np = None
        if gathered_probs:
            try:
                probs_all = torch.cat(gathered_probs)
                if probs_all.dtype == torch.bfloat16:
                    probs_all = probs_all.float()
                probs_np = probs_all.cpu().numpy()
            except RuntimeError as e:
                logger.warning(
                    f"Could not concatenate validation probabilities: {e}")

        # If sample_ids are present, sort predictions/labels/probs by sample_id
        # so that all arrays stay row-aligned.
        if sample_ids_all is not None:
            if len(sample_ids_all) == len(preds_np):
                sort_idx = sample_ids_all.argsort()
                preds_np = preds_np[sort_idx]
                labels_np = labels_np[sort_idx]
                sample_ids_all = sample_ids_all[sort_idx]
                if probs_np is not None and len(probs_np) == len(sort_idx):
                    probs_np = probs_np[sort_idx]
            else:
                logger.warning(
                    "Number of sample ids (%d) does not match number of "
                    "predictions (%d); skipping sort by sample id.",
                    len(sample_ids_all),
                    len(preds_np),
                )

        self.last_val_predictions = {
            "preds": preds_np,
            "labels": labels_np,
        }
        if sample_ids_all is not None:
            self.last_val_predictions["sample_ids"] = sample_ids_all
        if probs_np is not None:
            self.last_val_predictions["probs"] = probs_np

        # Log the current learning rate. No optimizers are configured when
        # validation runs outside of fitting, so guard the lookup.
        optimizers = self.trainer.optimizers
        if optimizers:
            self.log("learning_rate", optimizers[0].param_groups[0]["lr"],
                     on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

        if self.num_classes == 0:  # Regression
            metrics = compute_regression_metrics(labels_np, preds_np)
        else:  # Classification
            metrics = compute_classification_metrics(
                labels_np,
                preds_np,
                probs_np,  # Pass probs_np, which could be None
            )

        # Ensure val_loss is part of metrics logged
        metrics["val_loss"] = val_loss

        metrics_to_log = {
            f"val_{k}" if not k.startswith("val_") else k: v for k, v in metrics.items()
        }
        self.log_dict(metrics_to_log, on_epoch=True,
                      prog_bar=False, sync_dist=True)
        self.last_val_metrics = metrics_to_log

        # Never let the pre-training sanity check count as the best epoch.
        if getattr(self.trainer, "sanity_checking", False):
            return

        if self.best_val_loss > val_loss:
            self.best_val_loss = val_loss
            self.best_val_metrics = metrics
            self.best_val_predictions = self.last_val_predictions

    def on_sanity_check_start(self):
        """Emit a debug trace when the sanity check begins.

        NOTE(review): confirm the installed Lightning version invokes this hook
        on the module and not only on callbacks.
        """
        # Use the module logger rather than print so the trace honours the
        # logging configuration.
        logger.debug("on_sanity_check_start")

    def on_sanity_check_end(self):
        """Discard everything recorded during the sanity check.

        Resets the best-so-far loss/metrics/predictions and the last validation
        predictions and metrics, so that values from the sanity-check pass do
        not leak into the real training run.

        NOTE(review): confirm the installed Lightning version invokes this hook
        on the module and not only on callbacks.
        """
        logger.debug("on_sanity_check_end")
        self.best_val_loss = float("inf")
        self.best_val_metrics = {}
        self.best_val_predictions = None
        self.last_val_predictions = None
        self.last_val_metrics = None

    def configure_optimizers(self):
        """
        Build the default optimizer and learning-rate schedule.

        AdamW over all module parameters, with a LinearLR warmup (factor 0.1 -> 1.0
        over 5 scheduler steps) followed by CosineAnnealingLR (T_max=50, eta_min=0),
        chained through SequentialLR.

        Returns:
            A one-element optimizer list and a one-element scheduler list.
        """
        # TODO: make this configurable through hydra
        warmup_epochs = 5
        lr_schedulers = torch.optim.lr_scheduler

        adamw = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        # The schedulers are created in the same order as before (cosine first,
        # then warmup) because each constructor touches the optimizer state.
        annealing = lr_schedulers.CosineAnnealingLR(
            adamw, T_max=50, eta_min=0)
        warmup = lr_schedulers.LinearLR(
            adamw,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=warmup_epochs,
        )

        # Warmup runs first, then the cosine schedule takes over at the milestone
        combined = lr_schedulers.SequentialLR(
            adamw,
            schedulers=[warmup, annealing],
            milestones=[warmup_epochs],
        )

        return [adamw], [combined]

    def _create_validation_chunk(self):
        """Cache the first validation batches for frequent evaluation during training.

        Whole batches are read from the (first) validation dataloader until at
        least ``self.val_chunk_size`` samples have been collected; the batches
        are stored in ``self.val_chunk_data``. Any error while iterating is
        logged and leaves ``self.val_chunk_data`` untouched.
        """
        if self.trainer is None or self.trainer.val_dataloaders is None:
            logger.warning(
                "No validation dataloader available for chunk creation")
            return

        val_dataloader = self.trainer.val_dataloaders
        if isinstance(val_dataloader, list):
            if not val_dataloader:
                logger.warning(
                    "No validation dataloader available for chunk creation")
                return
            # Use first dataloader if multiple
            val_dataloader = val_dataloader[0]

        chunk_data = []
        total_samples = 0

        # Seed only inside a forked RNG scope: the fixed seed keeps the chunk
        # selection reproducible, while the global RNG state (and therefore the
        # user's own seed) is restored afterwards.
        cuda_devices = [self.device] if self.device.type == "cuda" else []

        try:
            with torch.random.fork_rng(devices=cuda_devices):
                torch.manual_seed(42)

                for batch in val_dataloader:
                    chunk_data.append(batch)

                    # Estimate the batch size: prefer the labels, otherwise the
                    # first tensor entry; fall back to 1.
                    batch_size = 1
                    if isinstance(batch, dict):
                        candidates = [batch.get("labels"), *batch.values()]
                        for value in candidates:
                            if torch.is_tensor(value) and value.dim() > 0:
                                batch_size = value.size(0)
                                break

                    total_samples += batch_size

                    if total_samples >= self.val_chunk_size:
                        break

        except Exception as e:
            # Best effort: the chunk is optional, so do not abort training
            logger.warning(f"Error creating validation chunk: {e}")
            return

        if chunk_data:
            self.val_chunk_data = chunk_data
            logger.info(
                f"Created validation chunk with {total_samples} samples from {len(chunk_data)} batches")
        else:
            logger.warning(
                "Failed to create validation chunk - no data collected")

    def _evaluate_validation_chunk(self):
        """Evaluate the cached validation chunk and log its sample-weighted mean loss.

        Runs without gradients with the model in eval mode, then restores the
        train/eval flag of every sub-module to what it was before. Batches that
        fail are logged and skipped. Logs ``val_chunk_loss`` when at least one
        batch succeeded.
        """
        if self.val_chunk_data is None:
            return

        def _to_device(obj):
            # Move tensors (also inside nested dicts) to this module's device
            if torch.is_tensor(obj):
                return obj.to(self.device)
            if isinstance(obj, dict):
                return {k: _to_device(v) for k, v in obj.items()}
            return obj

        # Remember each sub-module's mode so deliberately frozen (eval) parts
        # are not switched to train mode afterwards.
        previous_modes = [(m, m.training) for m in self.model.modules()]
        self.model.eval()
        total_loss = 0.0
        total_samples = 0

        try:
            with torch.no_grad():
                for batch in self.val_chunk_data:
                    # Move batch to device
                    if isinstance(batch, dict):
                        batch = _to_device(batch)
                    else:
                        batch = batch.to(self.device)

                    try:
                        labels = batch["labels"]
                        outputs = self(batch)

                        if self.num_classes == 0:  # regression
                            labels = labels.float()
                            # Ensure consistent shapes for regression
                            if outputs.dim() == 1:
                                outputs = outputs.view(-1, 1)
                            if labels.dim() == 1:
                                labels = labels.view(-1, 1)
                            # Only regression outputs are reshaped; class
                            # logits keep their class dimension.
                            if outputs.shape != labels.shape:
                                outputs = outputs.view_as(labels)
                        else:  # classification
                            labels = labels.long()

                        batch_loss = self.criterion(outputs, labels)

                        # Get batch size for weighted average
                        batch_size = labels.size(0)
                        total_loss += batch_loss.item() * batch_size
                        total_samples += batch_size

                    except Exception as e:
                        # Best effort: this is a monitoring signal only
                        logger.warning(
                            f"Error evaluating validation chunk batch: {e}")
                        continue
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        if total_samples > 0:
            avg_loss = total_loss / total_samples
            self.log("val_chunk_loss", avg_loss, on_step=True,
                     on_epoch=False, prog_bar=True)
            logger.debug(f"Validation chunk loss: {avg_loss:.4f}")

    def on_train_start(self):
        """Run the parent hook (if it exists), then build the validation chunk."""
        parent = super()
        if hasattr(parent, 'on_train_start'):
            parent.on_train_start()
        self._create_validation_chunk()

    def on_train_epoch_start(self):
        """Run the parent hook (if it exists) and restart the step counter."""
        parent = super()
        if hasattr(parent, 'on_train_epoch_start'):
            parent.on_train_epoch_start()
        # The counter restarts every epoch, so the validation-chunk interval is
        # measured from the beginning of each epoch.
        self.training_step_count = 0

    def predict(self, sequences: List[str], batch_size: int = 1) -> Dict[str, Any]:
        """
        Generic predict method that prepares features once and then processes in batches.

        Each sequence is wrapped as ``{"sequence": seq}`` and handed to
        ``self.prepare_features`` (expected to be provided by the concrete
        predictor class). The prepared tensors are split into mini-batches,
        moved to this module's device and passed through ``forward``.

        Args:
            sequences: List of protein sequence strings to predict on
            batch_size: Batch size for processing sequences

        Returns:
            Dictionary containing:
                - predictions: numpy array with the concatenated model outputs
                - probabilities: numpy array with the softmax of the outputs over
                  the last dimension (only when ``num_classes > 0``)
        """
        from torch.utils.data import DataLoader, TensorDataset

        self.model.eval()  # Set model to evaluation mode

        # Prepare all features upfront in one go
        sequences_as_dicts = [{"sequence": seq} for seq in sequences]
        prepared_features = self.prepare_features(sequences_as_dicts)

        # Two dict layouts are supported:
        #   {"inputs": ..., "labels": ...}                      (ESMC / ESM3)
        #   {"input_ids": ..., "attention_mask": ..., "labels": ...}  (ESM2)
        wrapped = isinstance(
            prepared_features, dict) and "inputs" in prepared_features
        if isinstance(prepared_features, dict):
            labels = prepared_features.get("labels", None)
            if wrapped:
                inputs = prepared_features["inputs"]
            else:
                inputs = {k: v for k, v in prepared_features.items()
                          if k != "labels"}
        else:
            inputs = prepared_features
            labels = None

        def _as_tensor(value):
            # numpy arrays and lists become tensors; everything else is kept
            if isinstance(value, (np.ndarray, list)):
                return torch.tensor(value)
            return value

        if isinstance(inputs, dict):
            inputs = {k: _as_tensor(v) for k, v in inputs.items()}
            input_keys = list(inputs)
            dataset = TensorDataset(*inputs.values())
        else:
            inputs = _as_tensor(inputs)
            input_keys = None
            dataset = TensorDataset(inputs)

        num_samples = len(dataset)

        # Labels are sliced per mini-batch only when they have one entry per
        # sample; otherwise they are forwarded unchanged.
        try:
            labels_per_sample = labels is not None and len(
                labels) == num_samples
        except TypeError:
            labels_per_sample = False

        predictions_list = []
        offset = 0

        with torch.no_grad():
            for batch in tqdm(DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
            ), desc=f"Predicting {self.model.__class__.__name__}"):

                # The tensors are created on CPU; the model may live elsewhere
                batch = [t.to(self.device) for t in batch]
                current = batch[0].size(0)

                # Reconstruct the batch in the layout prepare_features produced
                if input_keys is not None:
                    tensors_by_key = dict(zip(input_keys, batch))
                    batch_dict = (
                        {"inputs": tensors_by_key} if wrapped else tensors_by_key
                    )
                else:
                    batch_dict = {"inputs": batch[0]}

                if labels is not None:
                    batch_labels = (
                        labels[offset:offset + current]
                        if labels_per_sample else labels
                    )
                    if torch.is_tensor(batch_labels):
                        batch_labels = batch_labels.to(self.device)
                    batch_dict["labels"] = batch_labels
                offset += current

                # Use the forward method
                pred_out = self.forward(batch_dict)

                # Convert BFloat16 to Float32 if needed (numpy has no bfloat16)
                if pred_out.dtype == torch.bfloat16:
                    pred_out = pred_out.float()

                # Ensure output shape is consistent
                if self.num_classes == 0 and pred_out.dim() == 1:
                    pred_out = pred_out.unsqueeze(1)

                predictions_list.append(pred_out.cpu())

        # Concatenate all predictions
        predictions_tensor = torch.cat(predictions_list, dim=0)

        results = {"predictions": predictions_tensor.numpy()}

        # Add probabilities for classification tasks
        if self.num_classes > 0:
            probs = torch.softmax(predictions_tensor, dim=-1)
            results["probabilities"] = probs.numpy()

        return results
