import os

import torch

import torch.nn as nn

import torch.nn.functional as F

import pytorch_lightning as pl

from torch.utils.data import DataLoader, TensorDataset

from typing import Any, Dict, Optional

import torchmetrics



from src.downstream_dataset import DownstreamDataModule

from src.downstream import PoolingHead

from src.ema import EMAHelper, resolve_dtype





class EMACallback(pl.Callback):
    """Keep an exponential moving average (EMA) of the module's parameters.

    The EMA is maintained by ``EMAHelper``. It can optionally be swapped into the
    module for validation and/or written into saved checkpoints.

    Keys read from ``ema_config``:
        decay (float, default 0.9999): EMA decay handed to ``EMAHelper``.
        update_interval (int, default 1): update the EMA only on global steps
            divisible by this value; must be >= 1.
        use_ema_for_validation (bool, default True): swap the EMA weights into
            the module between ``on_validation_start`` and ``on_validation_end``.
        use_ema_for_checkpoint (bool, default False): overwrite parameter entries
            of saved checkpoints' ``state_dict`` with EMA weights; requires
            ``use_ema_for_validation``.
        store_on_cpu (bool, default False): pass ``device=cpu`` to ``EMAHelper``
            (presumably where it keeps the EMA copy -- confirm in src.ema).
        dtype: resolved with ``resolve_dtype`` and handed to ``EMAHelper``.

    Raises:
        ValueError: if ``update_interval < 1``, or if ``use_ema_for_checkpoint``
            is set without ``use_ema_for_validation``.
    """

    def __init__(self, ema_config: Dict[str, Any]):
        super().__init__()
        self.decay = float(ema_config.get("decay", 0.9999))
        self.update_interval = int(ema_config.get("update_interval", 1))
        if self.update_interval < 1:
            raise ValueError("EMA update_interval must be >= 1")
        self.use_ema_for_validation = bool(ema_config.get("use_ema_for_validation", True))
        self.use_ema_for_checkpoint = bool(ema_config.get("use_ema_for_checkpoint", False))
        if self.use_ema_for_checkpoint and not self.use_ema_for_validation:
            raise ValueError("EMA use_ema_for_checkpoint requires use_ema_for_validation=true")
        store_on_cpu = bool(ema_config.get("store_on_cpu", False))
        dtype = resolve_dtype(ema_config.get("dtype"))
        device = torch.device("cpu") if store_on_cpu else None

        self.ema = EMAHelper(decay=self.decay, device=device, dtype=dtype)
        # True while the EMA weights are loaded into the module for validation.
        self._ema_swapped = False
        # Last global step at which the EMA was updated; guards against repeated
        # updates when several batches end on the same global step.
        self._last_update_step = -1

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Register only when the shadow is empty, so EMA state that was already
        # loaded (e.g. through load_state_dict) is not overwritten.
        if not self.ema.shadow:
            self.ema.register(pl_module)

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx: int,
    ) -> None:
        """Update the EMA once per new global step that is a multiple of ``update_interval``."""
        if trainer.sanity_checking:
            return
        global_step = trainer.global_step
        if global_step == 0:
            return
        if global_step <= self._last_update_step:
            return
        if global_step % self.update_interval != 0:
            return
        self.ema.update(pl_module)
        self._last_update_step = global_step
        pl_module.log("ema/num_updates", self.ema.num_updates, on_step=True, prog_bar=False, logger=True)

    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Back up the training weights and load the EMA weights for validation."""
        if trainer.sanity_checking or not self.use_ema_for_validation:
            return
        if not self.ema.shadow:
            return
        self.ema.store(pl_module)
        self.ema.copy_to(pl_module)
        self._ema_swapped = True

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._restore_training_weights(pl_module)

    def on_exception(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException
    ) -> None:
        # If fit aborts while the EMA weights are swapped in (e.g. an error or an
        # interrupt during validation), put the training weights back so the
        # module is not silently left holding EMA weights.
        self._restore_training_weights(pl_module)

    def _restore_training_weights(self, pl_module: pl.LightningModule) -> None:
        """Restore the backed-up training weights if EMA weights are currently swapped in."""
        if not self._ema_swapped:
            return
        self.ema.restore(pl_module)
        self._ema_swapped = False

    def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]) -> None:
        """Overwrite parameter entries of ``checkpoint["state_dict"]`` with EMA weights.

        This is only active when ``use_ema_for_checkpoint`` is set. Every state_dict
        key that refers to a tracked parameter gets the EMA value, cast to the live
        tensor's device and dtype. That includes every alias of a tied/shared
        parameter, which ``named_parameters()`` reports only once.

        NOTE(review): resuming training from such a checkpoint initialises the
        module from these EMA weights -- confirm that is intended.
        """
        if not self.use_ema_for_checkpoint:
            return
        if not self.ema.shadow:
            return
        state_dict = checkpoint.get("state_dict")
        if not state_dict:
            return
        # Map each parameter object to the (deduplicated) name it is tracked under.
        canonical = {id(param): name for name, param in pl_module.named_parameters()}
        for key, tensor in pl_module.state_dict(keep_vars=True).items():
            name = canonical.get(id(tensor))
            if name is None or name not in self.ema.shadow:
                continue
            ema_tensor = self.ema.shadow[name].to(device=tensor.device, dtype=tensor.dtype)
            state_dict[key] = ema_tensor.detach().clone()

    def state_dict(self) -> Dict[str, torch.Tensor]:
        return self.ema.state_dict()

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        self.ema.load_state_dict(state_dict)





class EMACheckpointCallback(pl.Callback):
    """Write a standalone checkpoint that holds the EMA weights of an ``EMACallback``.

    ``<checkpoint_dir>/<checkpoint_name>`` is (re)written on global rank zero
    after validation and at the end of training, at most once per global step.
    The file uses Lightning-style keys (``state_dict``, ``epoch``,
    ``global_step``, ``loops``, ``hyper_parameters``). It contains no
    optimizer or scheduler state.
    """

    def __init__(self, ema_callback: EMACallback, checkpoint_dir: str, checkpoint_name: str = "ema_last.ckpt"):
        super().__init__()
        self.ema_callback = ema_callback
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_name = checkpoint_name
        # Global step of the last successful save; used to skip duplicate saves.
        self._last_saved_step: Optional[int] = None

    def _build_state_dict(self, pl_module: pl.LightningModule) -> Dict[str, torch.Tensor]:
        """Return a CPU copy of the module's full state_dict with EMA values substituted.

        The copy starts from the live state_dict, so buffers and untracked
        entries are present and a strict load succeeds. Tied/shared parameters
        are listed once by ``named_parameters()``; each alias key is mapped back
        to that name so every alias receives the EMA value. EMA entries that
        are not keys of the live state_dict are kept as-is.
        """
        shadow = self.ema_callback.ema.shadow
        canonical = {id(param): name for name, param in pl_module.named_parameters()}
        state_dict: Dict[str, torch.Tensor] = {}
        for key, tensor in pl_module.state_dict(keep_vars=True).items():
            name = canonical.get(id(tensor), key)
            source = shadow[name] if name in shadow else tensor
            # Copy straight to CPU: avoids a full extra model copy on the GPU.
            state_dict[key] = source.detach().to("cpu", copy=True)
        for name, tensor in shadow.items():
            if name not in state_dict:
                state_dict[name] = tensor.detach().to("cpu", copy=True)
        return state_dict

    def _save_ema_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Save the EMA checkpoint on global rank zero unless already saved at this step."""
        if trainer.sanity_checking or not trainer.is_global_zero:
            return
        if not self.ema_callback.ema.shadow:
            return
        global_step = trainer.global_step
        if self._last_saved_step == global_step:
            return

        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.checkpoint_dir, self.checkpoint_name)

        hparams = getattr(pl_module, "hparams", None)
        checkpoint = {
            "epoch": trainer.current_epoch,
            "global_step": global_step,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": self._build_state_dict(pl_module),
            # Lightning checkpoints nest loop state under the loop's name.
            "loops": (
                {"fit_loop": trainer.fit_loop.state_dict()}
                if hasattr(trainer.fit_loop, "state_dict")
                else {}
            ),
            # Plain dict (as Lightning itself stores it) rather than the AttributeDict class.
            "hyper_parameters": dict(hparams) if hparams else {},
        }
        # Write to a temp file, then rename over the target. An interrupted save
        # therefore never leaves a truncated checkpoint at `path`.
        tmp_path = f"{path}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
        self._last_saved_step = global_step

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._save_ema_checkpoint(trainer, pl_module)

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._save_ema_checkpoint(trainer, pl_module)





class TestEveryEpochCallback(pl.Callback):
    """
    Runs test evaluation after each validation epoch to log test MCC every epoch.
    Aggregates ALL batches of the test dataset before computing MCC.

    The test dataloader comes from ``trainer.datamodule``; its ``'test'`` stage
    is set up only once. Per-batch logits and labels are gathered with
    ``pl_module.all_gather``. The MCC is computed with scikit-learn on global
    rank zero and logged as ``test/mcc_epoch``. Afterwards, each submodule's
    train/eval flag is restored. When present,
    ``DNAChunker.backbone.downsampler.hard_inference`` is restored as well.
    """

    @staticmethod
    def _find_downsampler(pl_module: pl.LightningModule) -> Optional[Any]:
        """Return ``pl_module.DNAChunker.backbone.downsampler`` if that chain exists, else ``None``."""
        chunker = getattr(pl_module, "DNAChunker", None)
        backbone = getattr(chunker, "backbone", None)
        return getattr(backbone, "downsampler", None)

    @staticmethod
    def _run_test_batches(pl_module: pl.LightningModule, dataloader: DataLoader) -> list:
        """Run ``pl_module`` on each ``(input_ids, labels)`` batch and collect detached outputs."""
        device = pl_module.device
        outputs = []
        with torch.no_grad():
            for input_ids, labels in dataloader:
                logits = pl_module(input_ids.to(device))
                outputs.append({"logits": logits.detach(), "labels": labels.to(device).detach()})
        return outputs

    @staticmethod
    def _flatten_gathered(gathered: Any) -> tuple:
        """Turn gathered per-batch outputs into lists of flattened logits and labels.

        Accepts either a flat list of ``{"logits", "labels"}`` dicts or a list of
        such lists, one per rank. Logits with more than two dims are reshaped to
        ``(-1, num_classes)`` and labels with more than one dim to ``(-1,)``. This
        folds any extra leading dim (e.g. one added by ``all_gather``) into the
        batch dim.
        """
        if isinstance(gathered, list) and gathered and isinstance(gathered[0], list):
            batches = [batch for rank_output in gathered for batch in rank_output]
        else:
            batches = gathered
        all_logits, all_labels = [], []
        for batch_output in batches:
            logits_t = batch_output["logits"]
            labels_t = batch_output["labels"]
            if logits_t.dim() > 2:
                logits_t = logits_t.reshape(-1, logits_t.size(-1))
            if labels_t.dim() > 1:
                labels_t = labels_t.reshape(-1)
            all_logits.append(logits_t)
            all_labels.append(labels_t)
        return all_logits, all_labels

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if trainer.sanity_checking:
            return
        datamodule = trainer.datamodule
        if datamodule is None:
            return

        # Set up the test split only once; the flag is cached on the datamodule.
        if not getattr(datamodule, "_test_setup_done", False):
            datamodule.setup('test')
            datamodule._test_setup_done = True

        test_dataloader = datamodule.test_dataloader()
        if test_dataloader is None:
            return

        # Snapshot every submodule's train/eval flag and the downsampler's
        # hard_inference flag. Both are restored exactly, even if evaluation
        # raises, instead of being forced to train()/True afterwards.
        module_modes = {module: module.training for module in pl_module.modules()}
        downsampler = self._find_downsampler(pl_module)
        prev_hard_inference = getattr(downsampler, "hard_inference", True)

        pl_module.eval()
        if downsampler is not None:
            downsampler.hard_inference = False
        try:
            test_outputs = self._run_test_batches(pl_module, test_dataloader)
        finally:
            for module, was_training in module_modes.items():
                module.training = was_training
            if downsampler is not None:
                downsampler.hard_inference = prev_hard_inference

        # all_gather is a collective: every rank must reach this call.
        gathered_outputs = pl_module.all_gather(test_outputs)
        all_logits, all_labels = self._flatten_gathered(gathered_outputs)

        # Nothing to score for an empty test set (torch.cat([]) would raise).
        if not trainer.is_global_zero or not all_logits:
            return

        from sklearn.metrics import matthews_corrcoef

        preds = torch.cat(all_logits).cpu().argmax(dim=-1).reshape(-1)
        labels = torch.cat(all_labels).cpu().reshape(-1)
        mcc_value = matthews_corrcoef(labels.numpy(), preds.numpy())

        if trainer.logger:
            trainer.logger.log_metrics({"test/mcc_epoch": mcc_value}, step=trainer.global_step)
        print(f"[Test] Epoch {trainer.current_epoch} - MCC: {mcc_value:.4f}")





class DownstreamValidationCallback(pl.Callback):
    """
    A PyTorch Lightning Callback to perform linear probing evaluation on a downstream task
    (Nucleotide Transformer benchmark) at the end of each validation epoch during pre-training.

    Features are pooled from ``pl_module.model.net``'s ``last_hidden_state`` on the
    downstream train and val splits. A fresh ``nn.Linear`` probe is trained on the
    train features, and accuracy plus multiclass MCC on the val features are
    logged. The probe runs on global rank zero only, and only when
    ``config["downstream"]["linear_probe_enabled"]`` is true.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        self.dataset_name = config["downstream"].get(
            "dataset_name", "InstaDeepAI/nucleotide_transformer_downstream_tasks"
        )
        self.task_name = config["downstream"].get("task_name", "H3")
        self.pooling_strategy = config["downstream"].get("pooling_strategy", "mean")
        self.probe_epochs = config["downstream"].get("probe_epochs", 5)
        self.probe_lr = config["downstream"].get("probe_lr", 1e-3)
        self.probe_batch_size = config["downstream"].get("probe_batch_size", 64)
        self.max_seq_len = config["downstream"].get("max_seq_len", 1024)

        # Created lazily on the first probe run.
        self.datamodule = None
        self.num_classes = None

    def _setup_datamodule(self):
        """Lazy initialization of the datamodule."""
        if self.datamodule is None:
            self.datamodule = DownstreamDataModule(
                dataset_name=self.dataset_name,
                task_name=self.task_name,
                max_seq_len=self.max_seq_len,
                batch_size=self.probe_batch_size,
                num_workers=self.config["data"].get("num_workers", 4),
            )
            self.datamodule.setup('fit')
            self.num_classes = self.datamodule.num_classes
            print(f"[LinearProbe] Loaded {self.task_name} task with {self.num_classes} classes")

    @staticmethod
    def _masked_mean(
        hidden_states: torch.Tensor, input_ids: torch.Tensor, pad_token_id: Optional[int]
    ) -> torch.Tensor:
        """Average ``hidden_states`` over dim 1, skipping positions where ``input_ids == pad_token_id``.

        If ``pad_token_id`` is ``None``, every position counts. Rows made
        entirely of padding yield zeros.
        """
        if pad_token_id is None:
            mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            mask = input_ids != pad_token_id
        summed = (hidden_states * mask.unsqueeze(-1).float()).sum(dim=1)
        # The count is an integer, so clamp to 1. This avoids 0/0 on all-padding
        # rows without relying on how clamp(min=1e-9) promotes an int tensor.
        num_valid = mask.sum(dim=1, keepdim=True).clamp(min=1)
        return summed / num_valid

    def _extract_features(self, pl_module: pl.LightningModule, dataloader: DataLoader) -> tuple:
        """Extract pooled features from the frozen backbone.

        Puts ``pl_module.model`` in eval mode and runs ``pl_module.model.net`` under
        ``torch.no_grad()``. Each batch's ``last_hidden_state`` (assumed
        ``(batch, seq, dim)`` -- TODO confirm) is pooled per ``self.pooling_strategy``.

        Returns:
            ``(features, labels)``: float32 CPU features concatenated over all
            batches, and the labels concatenated as yielded by ``dataloader``.

        Raises:
            ValueError: if ``self.pooling_strategy`` is not "first", "mean" or "attn".
        """
        pl_module.model.eval()
        net = pl_module.model.net
        pad_token_id = net.config.pad_token_id
        device = pl_module.device

        all_features = []
        all_labels = []
        with torch.no_grad():
            for input_ids, labels in dataloader:
                input_ids = input_ids.to(device)
                hidden_states = net(input_ids, return_dict=True).last_hidden_state

                if self.pooling_strategy == "first":
                    pooled = hidden_states[:, 0]
                elif self.pooling_strategy in ("mean", "attn"):
                    # This probe trains no attention-pooling head, so "attn"
                    # uses the same padding-masked mean as "mean".
                    pooled = self._masked_mean(hidden_states, input_ids, pad_token_id)
                else:
                    raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")

                all_features.append(pooled.float().cpu())
                all_labels.append(labels)

        return torch.cat(all_features, dim=0), torch.cat(all_labels, dim=0)

    def _train_probe(self, train_dataset: TensorDataset, in_features: int, device: torch.device) -> nn.Linear:
        """Train a fresh float32 linear classifier on ``train_dataset`` for ``self.probe_epochs`` epochs."""
        loader = DataLoader(train_dataset, batch_size=self.probe_batch_size, shuffle=True)
        classifier = nn.Linear(in_features, self.num_classes, dtype=torch.float32).to(device)
        optimizer = torch.optim.Adam(classifier.parameters(), lr=self.probe_lr)
        criterion = nn.CrossEntropyLoss()

        classifier.train()
        for _ in range(self.probe_epochs):
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(device, dtype=torch.float32)
                batch_y = batch_y.to(device)
                optimizer.zero_grad()
                loss = criterion(classifier(batch_x), batch_y)
                loss.backward()
                optimizer.step()
        return classifier

    def _predict_probe(self, classifier: nn.Linear, val_dataset: TensorDataset, device: torch.device) -> tuple:
        """Return ``(predicted_class_ids, targets)`` on CPU for every example of ``val_dataset``."""
        loader = DataLoader(val_dataset, batch_size=self.probe_batch_size, shuffle=False)
        classifier.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for batch_x, batch_y in loader:
                logits = classifier(batch_x.to(device, dtype=torch.float32))
                all_preds.append(logits.argmax(dim=-1).cpu())
                all_targets.append(batch_y)
        return torch.cat(all_preds), torch.cat(all_targets)

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        if not self.config["downstream"].get("linear_probe_enabled", False):
            return
        if trainer.sanity_checking:
            return
        # Only global rank zero runs the probe.
        if not trainer.is_global_zero:
            return

        print(f"\n--- Running Linear Probe Evaluation: {self.task_name} ---")
        self._setup_datamodule()

        was_training = pl_module.model.training
        try:
            print("[LinearProbe] Extracting features...")
            train_features, train_labels = self._extract_features(
                pl_module, self.datamodule.train_dataloader()
            )
            val_features, val_labels = self._extract_features(
                pl_module, self.datamodule.val_dataloader()
            )
            print(f"[LinearProbe] Train features: {train_features.shape}, Val features: {val_features.shape}")

            device = pl_module.device
            # Lightning may run validation hooks under torch.inference_mode()
            # (Trainer(inference_mode=True) is the default). There, enable_grad()
            # alone cannot record autograd, so leave inference mode explicitly.
            # Also clone the features, which may be inference tensors, into
            # normal tensors that autograd is allowed to save.
            with torch.inference_mode(False), torch.enable_grad():
                train_dataset = TensorDataset(train_features.clone(), train_labels.clone())
                val_dataset = TensorDataset(val_features.clone(), val_labels.clone())
                # Probe width comes from the extracted features themselves.
                classifier = self._train_probe(train_dataset, train_features.size(-1), device)
                all_preds, all_targets = self._predict_probe(classifier, val_dataset, device)

            accuracy = (all_preds == all_targets).float().mean().item()
            mcc_metric = torchmetrics.MatthewsCorrCoef(task="multiclass", num_classes=self.num_classes)
            mcc = mcc_metric(all_preds, all_targets).item()

            if trainer.logger:
                trainer.logger.log_metrics({
                    f"val/linear_probe_acc/{self.task_name}": accuracy,
                    f"val/linear_probe_mcc/{self.task_name}": mcc,
                }, step=trainer.global_step)

            print(f"[LinearProbe] {self.task_name} - Accuracy: {accuracy:.4f}, MCC: {mcc:.4f}")

        finally:
            # _extract_features switched the backbone to eval mode; undo that.
            if was_training:
                pl_module.model.train()

