"""SSL Finetuner for classification tasks."""
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
import torch.nn.functional as F
from torchmetrics import Accuracy, F1Score, AUROC, ConfusionMatrix, CohenKappa


class SSLFineTuner(LightningModule):
    """Classification head trained on top of a pretrained SSL backbone.

    One backbone is picked from ``backbones`` by key; its output is flattened
    and passed through a dropout + linear classifier. With
    ``finetune_backbone=False`` the backbone parameters are frozen and the
    backbone is kept in eval mode (linear probing); otherwise backbone and
    classifier are optimised together.

    Batches are expected to be mappings with a ``'psg'`` entry (backbone
    input) and a ``'label'`` entry (class indices).
    """

    def __init__(
        self,
        backbones,
        use_which_backbone: str = "all",
        config=None,
        in_features: int = 256,
        num_classes: int = 2,
        epochs: int = 10,
        dropout: float = 0.0,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        scheduler_type: str = "cosine",
        decay_epochs=10,
        gamma: float = 0.1,
        final_lr: float = 1e-5,
        use_mean_pool: bool = False,
        total_training_steps: int = None,
        finetune_backbone: bool = False,
        *args, **kwargs
    ) -> None:
        """Build the classifier, freeze/unfreeze the backbone and set up metrics.

        Args:
            backbones: An ``nn.ModuleDict`` or a mapping accepted by
                ``nn.ModuleDict``.
            use_which_backbone: Key of the backbone to use.
            config: Optional object; only its ``class_names`` attribute is read.
            in_features: Classifier input width, used when the selected
                backbone has no ``out_dim`` attribute.
            num_classes: Number of target classes.
            epochs: Number of training epochs; drives the epoch-based schedules
                and converts ``decay_epochs`` into steps.
            dropout: Dropout probability applied before the linear layer.
            lr: AdamW learning rate.
            weight_decay: AdamW weight decay.
            scheduler_type: ``"constant"``, ``"step"`` or ``"cosine"``.
            decay_epochs: Epoch (or iterable of epochs) at which the ``"step"``
                schedule multiplies the learning rate by ``gamma``.
            gamma: Decay factor of the ``"step"`` schedule.
            final_lr: Minimum learning rate of the ``"cosine"`` schedule.
            use_mean_pool: Prefer the backbone's pooled forward method when it
                provides one.
            total_training_steps: When positive, schedulers are stepped per
                optimisation step instead of per epoch.
            finetune_backbone: Train the backbone together with the classifier.
            *args: Accepted for call compatibility and ignored.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            KeyError: If ``use_which_backbone`` is not a key of ``backbones``.
        """
        super().__init__()
        # NOTE(review): this also stores `backbones` and `config` in hparams;
        # confirm that is intended before relying on checkpoint reloads.
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.scheduler_type = scheduler_type
        self.decay_epochs = decay_epochs
        self.gamma = gamma
        self.epochs = epochs
        self.final_lr = final_lr
        self.use_mean_pool = use_mean_pool
        self.total_training_steps = total_training_steps
        self.finetune_backbone = finetune_backbone

        # Backbone
        if isinstance(backbones, nn.ModuleDict):
            self.backbones = backbones
        else:
            self.backbones = nn.ModuleDict(backbones)
        self.use_which_backbone = use_which_backbone
        if self.use_which_backbone not in self.backbones:
            raise KeyError(
                f"Backbone '{self.use_which_backbone}' not found; "
                f"available: {list(self.backbones.keys())}"
            )
        self.backbone = self.backbones[self.use_which_backbone]

        # Freeze backbone for linear probing
        for p in self.backbone.parameters():
            p.requires_grad = self.finetune_backbone
        if not self.finetune_backbone:
            self.backbone.eval()
        else:
            print("[INFO] Full finetuning mode: backbone parameters are TRAINABLE")

        # Linear classifier
        final_in_features = getattr(self.backbone, "out_dim", in_features)
        self.linear_layer = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(final_in_features, num_classes)
        )

        # Metrics
        common = {"task": "multiclass", "num_classes": num_classes}
        self.train_acc = Accuracy(average="micro", **common)
        self.train_f1 = F1Score(average="macro", **common)
        self.train_auc = AUROC(average="macro", **common)

        self.val_acc = Accuracy(average="micro", **common)
        self.val_f1 = F1Score(average="macro", **common)
        self.val_auc = AUROC(average="macro", **common)
        self.val_kappa = CohenKappa(weights="quadratic", **common)

        self.test_acc = Accuracy(average="micro", **common)
        self.test_f1 = F1Score(average="macro", **common)
        self.test_auc = AUROC(average="macro", **common)
        self.test_kappa = CohenKappa(weights="quadratic", **common)

        # Per-class metrics
        self.val_acc_c = Accuracy(average=None, **common)
        self.val_f1_c = F1Score(average=None, **common)
        self.val_auc_c = AUROC(average=None, **common)
        self.val_cm = ConfusionMatrix(normalize=None, **common)

        self.test_acc_c = Accuracy(average=None, **common)
        self.test_f1_c = F1Score(average=None, **common)
        self.test_auc_c = AUROC(average=None, **common)
        self.test_cm = ConfusionMatrix(normalize=None, **common)

        # Fall back to index names when the config has no (or a None) class_names.
        names = getattr(config, "class_names", None)
        self.class_names = names if names is not None else [str(i) for i in range(num_classes)]

    def on_train_epoch_start(self) -> None:
        """Put a frozen backbone back into eval mode at the start of each epoch."""
        # Presumably needed because the trainer switches the whole module to
        # train mode before the epoch starts.
        if not self.finetune_backbone:
            self.backbone.eval()

    def training_step(self, batch, batch_idx):
        """Compute the training loss and accumulate the training metrics."""
        loss, logits, y = self.shared_step(batch)
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)
        self.train_acc.update(preds, y)
        self.train_f1.update(preds, y)
        self.train_auc.update(probs, y)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Log the accumulated training metrics and reset them."""
        acc = self.train_acc.compute()
        f1 = self.train_f1.compute()
        auc = self.train_auc.compute()
        self.log("train_acc", acc, prog_bar=True, on_epoch=True, sync_dist=True)
        self.log("train_f1", f1, prog_bar=True, on_epoch=True, sync_dist=True)
        self.log("train_auc", auc, on_epoch=True, sync_dist=True)
        self.train_acc.reset()
        self.train_f1.reset()
        self.train_auc.reset()

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and accumulate the validation metrics."""
        loss, logits, y = self.shared_step(batch)
        self._update_eval_metrics("val", logits, y)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Log aggregate and per-class validation metrics, then reset them."""
        self._log_eval_metrics("val", prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Compute the test loss and accumulate the test metrics."""
        loss, logits, y = self.shared_step(batch)
        self._update_eval_metrics("test", logits, y)
        self.log("test_loss", loss, on_epoch=True, sync_dist=True)
        return loss

    def on_test_epoch_end(self):
        """Log aggregate and per-class test metrics, then reset them."""
        self._log_eval_metrics("test", prog_bar=False)

    def _update_eval_metrics(self, stage, logits, y):
        """Feed one batch into every ``{stage}_*`` metric (stage is 'val' or 'test')."""
        probs = logits.softmax(-1)
        preds = logits.argmax(-1)
        # Label-based metrics consume hard predictions ...
        for suffix in ("acc", "f1", "kappa", "acc_c", "f1_c", "cm"):
            getattr(self, f"{stage}_{suffix}").update(preds, y)
        # ... while AUROC needs the class probabilities.
        for suffix in ("auc", "auc_c"):
            getattr(self, f"{stage}_{suffix}").update(probs, y)

    def _log_eval_metrics(self, stage, prog_bar):
        """Compute, log and reset every ``{stage}_*`` metric.

        ``prog_bar`` controls whether accuracy, F1 and kappa are shown on the
        progress bar; AUROC and the per-class values never are.
        """
        metrics = {
            suffix: getattr(self, f"{stage}_{suffix}")
            for suffix in ("acc", "f1", "auc", "kappa", "acc_c", "f1_c", "auc_c", "cm")
        }

        self.log(f"{stage}_acc", metrics["acc"].compute(), prog_bar=prog_bar, on_epoch=True, sync_dist=True)
        self.log(f"{stage}_f1", metrics["f1"].compute(), prog_bar=prog_bar, on_epoch=True, sync_dist=True)
        self.log(f"{stage}_auc", metrics["auc"].compute(), on_epoch=True, sync_dist=True)
        self.log(f"{stage}_kappa", metrics["kappa"].compute(), prog_bar=prog_bar, on_epoch=True, sync_dist=True)

        acc_c = metrics["acc_c"].compute()
        f1_c = metrics["f1_c"].compute()
        auc_c = metrics["auc_c"].compute()
        # Row sums of the confusion matrix are used as the per-class support.
        support = metrics["cm"].compute().sum(dim=1)
        for i, (acc_i, f1_i, auc_i, support_i) in enumerate(zip(acc_c, f1_c, auc_c, support)):
            name = self.class_names[i] if i < len(self.class_names) else str(i)
            self.log(f"{stage}/acc_{name}", acc_i, on_epoch=True, sync_dist=True)
            self.log(f"{stage}/f1_{name}", f1_i, on_epoch=True, sync_dist=True)
            self.log(f"{stage}/auc_{name}", auc_i, on_epoch=True, sync_dist=True)
            self.log(f"{stage}/support_{name}", support_i.float(), on_epoch=True, sync_dist=True)

        for metric in metrics.values():
            metric.reset()

    def shared_step(self, batch):
        """Run backbone and classifier on one batch.

        Args:
            batch: Mapping with ``'psg'`` (backbone input) and ``'label'``.

        Returns:
            Tuple ``(loss, logits, y)`` with the cross-entropy loss, the
            classifier logits and the labels as a long tensor.
        """
        # Never track gradients through a frozen backbone. When fine-tuning,
        # keep whatever grad mode the caller set so evaluation loops that run
        # without gradients are not forced to build an autograd graph.
        track_grad = self.finetune_backbone and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            feats = self._get_features(self.backbone, batch['psg'])
        # reshape (unlike view) also handles non-contiguous backbone outputs.
        feats = feats.reshape(feats.size(0), -1)
        logits = self.linear_layer(feats)

        y = batch["label"]
        # Labels may arrive with a trailing singleton dim or already 1-D.
        if y.dim() > 1:
            y = y.squeeze(1)
        y = y.long()
        loss = F.cross_entropy(logits, y)
        return loss, logits, y

    def _get_features(self, backbone, x):
        """Return backbone features, using a pooled forward method if requested.

        With ``use_mean_pool`` set, ``forward_encoding_mean_pool`` is preferred,
        then ``forward_avg_pool``; if the backbone has neither (or pooling is
        off) the plain forward call is used.
        """
        if self.use_mean_pool:
            if hasattr(backbone, 'forward_encoding_mean_pool'):
                return backbone.forward_encoding_mean_pool(x)
            elif hasattr(backbone, 'forward_avg_pool'):
                return backbone.forward_avg_pool(x)
        return backbone(x)

    def configure_optimizers(self):
        """Create the AdamW optimizer and the configured LR scheduler.

        Schedulers are stepped per optimisation step when
        ``total_training_steps`` is a positive number and per epoch otherwise.

        Returns:
            ``[optimizer]`` for the constant schedule, otherwise a tuple of the
            optimizer list and the scheduler list.

        Raises:
            ValueError: If ``scheduler_type`` is not one of ``"constant"``,
                ``"step"`` or ``"cosine"``.
        """
        if self.finetune_backbone:
            params = list(self.backbone.parameters()) + list(self.linear_layer.parameters())
        else:
            params = self.linear_layer.parameters()

        optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)

        if self.scheduler_type == "constant":
            return [optimizer]

        step_based = self.total_training_steps is not None and self.total_training_steps > 0
        horizon = self.total_training_steps if step_based else self.epochs

        if self.scheduler_type == "step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, self._step_milestones(step_based), gamma=self.gamma
            )
        elif self.scheduler_type == "cosine":
            scheduler = self._warmup_cosine_scheduler(optimizer, horizon)
        else:
            raise ValueError(
                f"Unknown scheduler_type '{self.scheduler_type}'; "
                "expected 'constant', 'step' or 'cosine'"
            )

        if step_based:
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        return [optimizer], [scheduler]

    def _step_milestones(self, step_based):
        """Return the MultiStepLR milestones derived from ``decay_epochs``.

        ``decay_epochs`` may be a single number or an iterable of numbers. In
        step-based mode each epoch is converted to the matching step index.
        """
        decay = self.decay_epochs
        decay_points = list(decay) if hasattr(decay, "__iter__") else [decay]
        if step_based:
            return sorted(
                int(self.total_training_steps * point / self.epochs) for point in decay_points
            )
        return sorted(int(point) for point in decay_points)

    def _warmup_cosine_scheduler(self, optimizer, total_iters):
        """Build a linear warmup (first 10%) followed by cosine annealing.

        When the warmup would last zero iterations it is skipped, so training
        starts at the full learning rate instead of at the warmup start factor.
        """
        warmup_iters = int(0.1 * total_iters)
        # Guard against a zero-length cosine phase.
        cosine_iters = max(1, total_iters - warmup_iters)

        if warmup_iters <= 0:
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=cosine_iters, eta_min=self.final_lr
            )

        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_iters
        )
        cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cosine_iters, eta_min=self.final_lr
        )
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_iters]
        )
