import math
from typing import Final

from pathos.multiprocessing import ProcessPool
import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Batch, Data
from typing_extensions import override

from karger import karger_stein_repeated
from util import sum_of_edge_weights


# Upper bound on how many additional Karger-Stein runs `_get_gradient` performs while looking for a second cut whose
# weight differs from the first one. Once the limit is reached, an all-zero gradient is returned for that graph.
MAX_KARGER_STEIN_ATTEMPTS: Final = 20


class KargerDirectGradientLoss(Module):
    """
    Drop-in replacement for a loss function that produces the gradient for its input directly, so no ground truth
    minimum k-cut is needed.

    For every graph in the batch, cuts are sampled (steered by the input) until two cuts of different sizes have been
    found. The gradient is the difference of the vector representations of these cuts ("bad cut - good cut"):
    edges of the bad cut should end up with lower scores and therefore get a positive gradient, while edges of the
    good cut should end up with higher scores and therefore get a negative gradient.

    Backpropagation already happens inside `forward`, so callers do not have to call `.backward()` on the result.
    The returned "loss" is a constant 0, and calling `.backward()` on it has no effect. This is what allows the module
    to be used wherever a regular loss function is expected, even though it is not one.
    """

    @override
    def forward(self, input: Tensor, graph: Batch) -> Tensor:
        """
        Computes one gradient per graph in parallel worker processes and backpropagates the result through `input`.

        `input` is indexed with `graph.edge_index_batch == i`, so it is assumed to hold one entry per edge of the
        batch, aligned with `graph.edge_index_batch` (TODO confirm against the caller). The return value is a scalar
        zero tensor with `requires_grad=True`.
        """
        # the sampling is steered by the complement of the sigmoid-activated input
        all_weights = 1 - torch.sigmoid(input)

        # split the batch into its individual graphs, each with its matching slice of the steering weights
        per_graph_weights = []
        per_graph_data = []
        for index in range(graph.num_graphs):
            per_graph_weights.append(all_weights[graph.edge_index_batch == index])
            per_graph_data.append(graph.get_example(index))

        with ProcessPool() as workers:
            per_graph_gradients = workers.map(_get_gradient, per_graph_weights, per_graph_data)

        # TODO should probably use a custom backward implementation on this class instead. can use i-mle as inspiration
        input.backward(torch.cat(per_graph_gradients))

        return torch.zeros((), requires_grad=True)


def _get_gradient(steering_weights: Tensor, graph: Data) -> Tensor:
    """
    Computes the "bad cut - good cut" gradient for a single graph.

    A first cut is sampled with `karger_stein_repeated`. Further cuts are then sampled until one is found whose
    weight, as reported by `sum_of_edge_weights`, is not close to the weight of the first cut (relative tolerance
    1e-6), or until `MAX_KARGER_STEIN_ATTEMPTS` further cuts have been tried.

    The result is the heavier cut minus the lighter cut. If no cut of a different weight was found within the attempt
    limit, a tensor of zeros shaped like the first cut is returned instead.

    The weight of every cut is computed exactly once, right after sampling it, instead of being recomputed for each
    comparison.
    """

    def sample_cut():
        # single Karger-Stein run; returns the cut together with its weight
        cut = karger_stein_repeated(graph, 2, num_runs=1, steering_weights=steering_weights, mode="train")
        return cut, sum_of_edge_weights(graph, cut)

    cut_a, weight_a = sample_cut()
    cut_b, weight_b = cut_a, weight_a

    attempts = 0
    while math.isclose(weight_b, weight_a, rel_tol=1e-6):
        # limit number of attempts for finding a different cut. on some graphs, karger-stein always finds the same cut
        if attempts >= MAX_KARGER_STEIN_ATTEMPTS:
            return torch.zeros_like(cut_a)

        cut_b, weight_b = sample_cut()
        attempts += 1

    # the heavier ("bad") cut gets the positive sign, the lighter ("good") cut the negative one
    if weight_a > weight_b:
        return cut_a - cut_b
    return cut_b - cut_a
