from typing import Tuple, List

import torch
from pydgn.training.callback.metric import (
    MulticlassClassification,
    Classification,
    MulticlassAccuracy,
)
from sklearn.metrics import f1_score
from torch import softmax


class MulticlassClassification(MulticlassClassification):
    """
    Variant of the imported metric of the same name whose ground truth is
    read from the model outputs rather than from the ``targets`` argument.
    """

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extracts predictions and ground truth from the model outputs.

        Predictions are ``outputs[0]``, returned untouched. The ground truth
        is ``outputs[2][0]``; the ``targets`` argument is not used
        (presumably because the model evaluates only part of the samples
        and returns the matching labels itself -- confirm against the
        model's forward pass).

        Args:
            targets (:class:`torch.Tensor`): ground truth passed by the
                caller, ignored by this implementation
            outputs (List[:class:`torch.Tensor`]): outputs of the model,
                at least three elements are required

        Returns:
            A tuple of tensors ``(predicted_values, target_values)``
        """
        predictions = outputs[0]
        labels = outputs[2][0]

        # Two-dimensional labels get dimension 1 squeezed out, which only
        # has an effect when that dimension has size one.
        if labels.ndim == 2:
            labels = labels.squeeze(dim=1)

        return predictions, labels


class MulticlassAccuracy(MulticlassAccuracy):
    """
    Variant of the imported metric of the same name whose ground truth is
    read from the model outputs rather than from the ``targets`` argument.
    """

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extracts processed predictions and ground truth from the model
        outputs.

        ``outputs[0]`` is passed through ``self._get_correct`` (inherited
        from the parent class) and the result is returned as the
        predictions. The ground truth is ``outputs[2][0]``; the ``targets``
        argument is not used (presumably because the model evaluates only
        part of the samples and returns the matching labels itself --
        confirm against the model's forward pass).

        Args:
            targets (:class:`torch.Tensor`): ground truth passed by the
                caller, ignored by this implementation
            outputs (List[:class:`torch.Tensor`]): outputs of the model,
                at least three elements are required

        Returns:
            A tuple of tensors ``(predicted_values, target_values)``
        """
        predicted = self._get_correct(outputs[0])
        labels = outputs[2][0]

        # Two-dimensional labels get dimension 1 squeezed out, which only
        # has an effect when that dimension has size one.
        if labels.ndim == 2:
            labels = labels.squeeze(dim=1)

        return predicted, labels


class MicroF1Score(Classification):
    """
    F1 score computed with scikit-learn on arg-max class predictions.

    NOTE(review): despite the class name, the score is **macro**-averaged
    (``average="macro"``) and is reported under the name
    ``"Macro F1 Score"``. The class name is left unchanged so that existing
    references to it keep working.
    """

    def __init__(
        self,
        use_as_loss=False,
        reduction="mean",
        accumulate_over_epoch: bool = True,
        force_cpu: bool = True,
        device: str = "cpu",
    ):
        """
        Forwards every argument unchanged to the parent class and selects
        scikit-learn's ``f1_score`` as the underlying metric function.
        """
        super().__init__(
            use_as_loss=use_as_loss,
            reduction=reduction,
            accumulate_over_epoch=accumulate_over_epoch,
            force_cpu=force_cpu,
            device=device,
        )
        self.metric = f1_score

    @property
    def name(self) -> str:
        """
        The name of the metric to be used in configuration files and
        displayed on Tensorboard
        """
        return "Macro F1 Score"

    @staticmethod
    def _to_numpy(values):
        """
        Converts a tensor to a host NumPy array; any other object is
        returned as is so that scikit-learn can handle it directly.
        """
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the arg-max over dimension 1 of ``outputs[0]`` as predicted
        class indices, together with the targets stored in
        ``outputs[2][0]``. The ``targets`` argument is ignored.

        Args:
            targets (:class:`torch.Tensor`): ground truth passed by the
                caller, ignored by this implementation
            outputs (List[:class:`torch.Tensor`]): outputs of the model,
                at least three elements are required

        Returns:
            A tuple of tensors ``(predicted_classes, target_values)``
        """
        pred = outputs[0].argmax(dim=1)

        # override targets because we are taking a subset of the nodes in
        # the graph
        targets = outputs[2][0]

        # Two-dimensional targets get dimension 1 squeezed out, which only
        # has an effect when that dimension has size one.
        if len(targets.shape) == 2:
            targets = targets.squeeze(dim=1)

        return pred, targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the macro-averaged F1 score of the predictions.

        Args:
            targets (:class:`torch.Tensor`): tensor of ground truth class
                indices
            predictions (:class:`torch.Tensor`): tensor of predicted class
                indices

        Returns:
            A scalar tensor with the metric value
        """
        # scikit-learn operates on host arrays: detach and move the tensors
        # to the CPU first so the call also works when they live on an
        # accelerator.
        y_true = self._to_numpy(targets)
        y_pred = self._to_numpy(predictions)

        score = self.metric(y_true, y_pred, average="macro")

        # f1_score yields a plain float; wrap it so the documented tensor
        # return type is honoured.
        return torch.tensor(float(score))
