import torch
from torchmetrics import Accuracy, MetricCollection

from src.metrics.segmentation_diff import SegmentationAbsoluteDiff
from src.metrics.segmentation_iou import SegmentationIOU
from src.metrics.sparsity import Sparsity
from src.tasks.base_task import BaseTaskStrategy


class MulticlassClassificationStrategy(BaseTaskStrategy):
    """Task strategy for multiclass classification.

    Supplies the default loss, the logits-to-prediction conversion, and the
    metric collection used for this task.
    """

    def get_default_loss(self, **kwargs):
        """Return a new ``torch.nn.CrossEntropyLoss`` with default settings.

        Any keyword arguments are accepted but not used.
        """
        return torch.nn.CrossEntropyLoss()

    def process_outputs(self, logits):
        """Convert logits to predictions by taking the argmax over dim 1.

        NOTE(review): dim 1 is presumably the class dimension — confirm
        against the model's output layout.
        """
        return logits.argmax(dim=1)

    def configure_metrics(self, patch_size, num_classes):
        """Build the ``MetricCollection`` for this task.

        Args:
            patch_size: Forwarded to the segmentation metrics. When ``None``,
                the segmentation metrics are left out of the collection.
            num_classes: Number of classes given to each ``Accuracy`` metric.

        Returns:
            A ``MetricCollection`` holding train/val/test accuracy, test
            sparsity and, if ``patch_size`` is not ``None``, the two test
            segmentation metrics.
        """

        def make_accuracy():
            # Each entry needs its own instance; never share one metric object.
            return Accuracy(task="multiclass", num_classes=num_classes)

        collected = {}
        collected["train_acc"] = make_accuracy()
        collected["val_acc"] = make_accuracy()
        collected["test_acc"] = make_accuracy()
        collected["test_sparsity"] = Sparsity()

        # Without a patch size the segmentation metrics cannot be configured.
        if patch_size is None:
            return MetricCollection(collected)

        collected["test_segmentation_diff"] = SegmentationAbsoluteDiff(
            patch_size=patch_size,
        )
        collected["test_segmentation_iou"] = SegmentationIOU(
            patch_size=patch_size,
            threshold=0.5,
        )
        return MetricCollection(collected)
