import torch
from torchmetrics import Accuracy, MetricCollection

from src.metrics.segmentation_diff import SegmentationAbsoluteDiff
from src.metrics.segmentation_iou import SegmentationIOU
from src.metrics.sparsity import Sparsity
from src.tasks.base_task import BaseTaskStrategy


class BinaryClassificationStrategy(BaseTaskStrategy):
    """Task strategy for binary classification trained on raw logits.

    Supplies the default loss, the logits-to-labels conversion, and the
    train/val/test metric collection. The model presumably emits one logit per
    sample, shaped ``(batch,)`` or ``(batch, 1)`` — TODO confirm against the
    model head.
    """

    def get_default_loss(self) -> torch.nn.BCEWithLogitsLoss:
        """Return the default training loss.

        ``BCEWithLogitsLoss`` applies the sigmoid internally, so it must be fed
        raw logits rather than probabilities or thresholded predictions.
        """
        return torch.nn.BCEWithLogitsLoss()

    def process_outputs(self, logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """Convert raw logits into hard 0/1 predictions.

        Args:
            logits: Raw model outputs with the batch in dimension 0.
            threshold: Probability cut-off; a sample is predicted positive when
                ``sigmoid(logit) >= threshold``. Defaults to 0.5.

        Returns:
            An ``int32`` tensor of 0/1 labels with all size-1 non-batch
            dimensions removed (e.g. ``(batch, 1)`` -> ``(batch,)``). The batch
            dimension is always kept, even when the batch holds one sample.
        """
        if logits.dim() > 1:
            # Remove singleton dims *except* the batch dim. A bare ``squeeze()``
            # would collapse a batch of one, shape (1, 1), into a 0-d tensor that
            # no longer lines up with (1,)-shaped targets. 0-d and 1-D inputs are
            # left untouched for the same reason.
            non_singleton = [size for size in logits.shape[1:] if size != 1]
            logits = logits.reshape(logits.shape[0], *non_singleton)
        probs = torch.sigmoid(logits)
        preds = (probs >= threshold).int()
        return preds

    def configure_metrics(self, patch_size, *args, **kwargs) -> MetricCollection:
        """Build the metric collection used across train/val/test.

        Args:
            patch_size: Forwarded to the segmentation metrics; when ``None``,
                those metrics are not created.
            *args, **kwargs: Accepted and ignored — presumably so the signature
                matches other task strategies; verify against ``BaseTaskStrategy``.

        Returns:
            A ``MetricCollection`` containing binary accuracy for each stage and
            ``test_sparsity``, plus ``test_segmentation_diff`` and
            ``test_segmentation_iou`` when ``patch_size`` is not ``None``.
        """
        # Separate Accuracy instances per stage so each stage accumulates its own
        # state independently.
        metrics = {
            "train_acc": Accuracy(task="binary"),
            "val_acc": Accuracy(task="binary"),
            "test_acc": Accuracy(task="binary"),
            "test_sparsity": Sparsity(),
        }
        if patch_size is not None:
            metrics.update({
                "test_segmentation_diff": SegmentationAbsoluteDiff(patch_size=patch_size),
                # NOTE(review): the meaning of this 0.5 threshold is defined by
                # SegmentationIOU (not visible here) — confirm what it binarizes.
                "test_segmentation_iou": SegmentationIOU(patch_size=patch_size, threshold=0.5),
            })

        return MetricCollection(metrics)
