from typing import Literal

import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Batch
from typing_extensions import override

from util import sum_of_edge_weights
from ._hamming_loss import hamming_loss
from ._imle_karger import IMLEKarger
from ._imle_karger_config import IMLEKargerConfig


class IMLEKargerLoss(Module):
    """
    This loss executes the Karger-Stein algorithm, then calculates the loss based on the output of Karger-Stein.
    There are two options for the loss applied to this output cut:
    `"supervised"` compares the cut with a ground truth minimum cut using the Hamming distance, and
    `"self_supervised"` simply computes the size of the cut.
    The gradient is estimated using I-MLE [1].

    The `follow_batch` parameter to the DataLoader must include `"edge_index"` for this loss to work.
    This is required because it allows reconstructing which edge (and therefore which part of the input) belongs to
    which graph.

    Parameters:

    - `config`: Configuration for I-MLE and the Karger-Stein algorithm.
    - `mode`: Determines how to calculate the loss based on the output of Karger-Stein.
              Must be `"supervised"` or `"self_supervised"`.
    - `device`: Device specifier that is forwarded to the wrapped `IMLEKarger` module.

    Raises:

    - `ValueError`: If `mode` is neither `"supervised"` nor `"self_supervised"`.

    [1] Niepert et al., "Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions", NeurIPS, 2021
    """

    # Accepted values for `mode`. Checked in `__init__` so that a misconfigured loss is reported when it is
    # constructed instead of after the first (expensive) Karger-Stein run in `forward`.
    _VALID_MODES = ("supervised", "self_supervised")

    _imle_karger: IMLEKarger
    _mode: Literal["supervised", "self_supervised"]

    @override
    def __init__(self, config: IMLEKargerConfig, mode: Literal["supervised", "self_supervised"], device: str):
        super().__init__()

        # Fail fast: same exception type and message as the fallback check in `forward`.
        if mode not in self._VALID_MODES:
            raise ValueError(f'mode must be either "supervised" or "self_supervised", but was {mode}')

        self._imle_karger = IMLEKarger(config, device)
        self._mode = mode

    @override
    def forward(self, input: Tensor, graphs: Batch) -> Tensor:
        """
        Computes the loss for `input` on the batch `graphs`.

        `input` is passed through a sigmoid, given a leading dimension of size 1 and handed to the wrapped
        `IMLEKarger` module together with `graphs`; the result is squeezed and used as the cut.

        Parameters:

        - `input`: Raw (pre-sigmoid) model output. Presumably one score per edge of `graphs` -- TODO confirm
                   against `IMLEKarger`.
        - `graphs`: The batch of graphs the input belongs to. In `"supervised"` mode, `graphs.y` is used as the
                    ground truth that the cut is compared with.

        Returns: `hamming_loss(cuts, graphs.y)` in `"supervised"` mode and `sum_of_edge_weights(graphs, cuts)` in
        `"self_supervised"` mode.

        Raises:

        - `ValueError`: If the stored mode is neither `"supervised"` nor `"self_supervised"`.
        """
        input_sigmoid = torch.sigmoid(input)

        # NOTE(review): the argument-less `squeeze()` removes *every* dimension of size 1, not only the leading
        # one added by `unsqueeze(0)`. If the output of `IMLEKarger` can contain another size-1 dimension
        # (e.g. a batch with a single edge), that dimension is dropped as well -- confirm this is intended.
        cuts = self._imle_karger(input_sigmoid.unsqueeze(0), graphs).squeeze()

        if self._mode == "supervised":
            return hamming_loss(cuts, graphs.y)
        if self._mode == "self_supervised":
            # sum over all edges where cuts is 1 (elements of cuts are 0 or 1)
            # TODO handle the case where num_noise_samples > 1
            return sum_of_edge_weights(graphs, cuts)

        # Defensive fallback: only reachable if `_mode` was changed after construction.
        raise ValueError(f'mode must be either "supervised" or "self_supervised", but was {self._mode}')
