from typing import Any, Optional

from lightning import LightningModule

from src.tasks.registry import TASK_REGISTRY


class BaseModule(LightningModule):
    """Lightning module base class with shared train/val/test step logic.

    The task ("binary" or "multiclass") is derived from ``num_classes`` and
    used to look up a strategy class in ``TASK_REGISTRY``. The strategy
    supplies the default loss, the logits-to-predictions conversion and the
    metrics container.

    The forward pass (``self(batch['x'])``) is expected to return a mapping
    that contains at least a ``'logits'`` entry; batches are expected to be
    mappings with ``'x'`` and ``'y'`` entries.
    """

    def __init__(
        self,
        patch_size: Optional[int] = None,
        num_classes: Optional[int] = None,
        loss_fn=None,
        *args: Any,
        **kwargs: Any,
    ):
        """Set up the task strategy, loss function and metrics.

        Args:
            patch_size: Forwarded to the strategy's ``configure_metrics``.
            num_classes: Number of target classes. Values ``<= 2`` select the
                "binary" task with ``output_size == 1``; larger values select
                "multiclass" with ``output_size == num_classes``.
            loss_fn: Callable ``(logits, labels) -> loss``. When ``None`` the
                strategy's default loss is used.
            *args: Forwarded to ``LightningModule.__init__``.
            **kwargs: Forwarded to ``LightningModule.__init__``.

        Raises:
            TypeError: If ``num_classes`` is ``None``.
        """
        super().__init__(*args, **kwargs)
        if num_classes is None:
            # Without this guard the comparison below fails with an opaque
            # "'<=' not supported between NoneType and int" TypeError.
            raise TypeError("num_classes must be provided as an int, got None")

        is_binary = num_classes <= 2
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.output_size = 1 if is_binary else num_classes
        self.task = "binary" if is_binary else "multiclass"
        strategy_cls = TASK_REGISTRY[self.task]
        self.strategy = strategy_cls()
        # Compare against None explicitly: a truthiness test would discard a
        # valid callable that happens to be falsy (e.g. defines __len__ == 0).
        self.loss_fn = loss_fn if loss_fn is not None else self.strategy.get_default_loss()
        self.metrics = self.strategy.configure_metrics(patch_size=self.patch_size, num_classes=num_classes)

    def _shared_step(self, batch, stage):
        """Run the forward pass, compute and log the loss, derive predictions.

        Args:
            batch: Mapping with at least ``'x'`` (model input) and ``'y'``
                (labels).
            stage: Prefix for the logged loss name, e.g. ``"train"``.

        Returns:
            A dict merging ``{"loss", "preds"}`` with ``batch`` and the model
            output, in that order. Later operands win on key collisions, so
            entries of ``batch`` and of the model output override ``"loss"``
            and ``"preds"`` if they use those keys.
        """
        output = self(batch['x'])
        logits = output['logits']
        loss = self._compute_loss(logits, batch['y'])
        self.log(f"{stage}/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        preds = self.strategy.process_outputs(logits)
        return {"loss": loss, "preds": preds} | batch | output

    def _compute_loss(self, logits, labels):
        """Apply ``self.loss_fn`` to the logits and labels.

        For the binary task the logits are squeezed and the labels are cast to
        float before the loss is called; for multiclass both are passed
        through unchanged.
        """
        if self.task != "binary":
            return self.loss_fn(logits, labels)

        flat_logits = logits.squeeze()
        # A dim-less squeeze also drops a batch axis of size 1, turning e.g.
        # (1, 1) logits into a 0-d tensor that no longer matches (1,) labels.
        # Re-align to the label shape whenever the element counts agree.
        if flat_logits.shape != labels.shape and flat_logits.numel() == labels.numel():
            flat_logits = flat_logits.reshape(labels.shape)
        return self.loss_fn(flat_logits, labels.float())

    def training_step(self, batch, batch_idx):
        """Compute the training loss and update/log training accuracy.

        Returns:
            The ``"loss"`` entry of the shared step result.
        """
        step = self._shared_step(batch, "train")
        self.metrics.train_acc(step["preds"], batch["y"])
        self.log('train/acc', self.metrics.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return step["loss"]

    def validation_step(self, batch, batch_idx):
        """Update/log validation accuracy and return the full step dict."""
        step = self._shared_step(batch, "val")
        self.metrics.val_acc(step["preds"], batch["y"])
        self.log('val/acc', self.metrics.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        return step

    def test_step(self, batch, batch_idx):
        """Update/log test accuracy plus optional importance-based metrics.

        When the step dict holds a non-None ``'importance'`` entry the
        sparsity metric is updated; when it additionally holds a non-None
        ``'segmentation'`` entry the segmentation IoU and diff metrics are
        updated as well.

        Returns:
            The full step dict from ``_shared_step``.
        """
        step = self._shared_step(batch, "test")
        self.metrics.test_acc(step["preds"], batch["y"])
        self.log('test/acc', self.metrics.test_acc, on_step=False, on_epoch=True)

        importance = step.get('importance')
        segmentation = step.get('segmentation')
        if segmentation is not None and importance is not None:
            self.metrics.test_segmentation_iou(importance, segmentation)
            self.log('test/segmentation_iou', self.metrics.test_segmentation_iou, on_step=False, on_epoch=True)
            self.metrics.test_segmentation_diff(importance, segmentation)
            self.log('test/segmentation_diff', self.metrics.test_segmentation_diff, on_step=False, on_epoch=True)
        if importance is not None:
            self.metrics.test_sparsity(importance)
            self.log('test/sparsity', self.metrics.test_sparsity, on_step=False, on_epoch=True)
        return step