from imle.target import BaseTargetDistribution
from torch import Tensor


class COTargetDistribution(BaseTargetDistribution):
    """
    Target distribution for combinatorial optimisation problems, as defined in section 4 of
    Niepert et al., "Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions", NeurIPS, 2021

    Wherever the upstream gradient is non-zero, the target parameter is the negated gradient; everywhere else the
    current parameter value is kept.

    This is currently unused, because it doesn't work well with Karger's algorithm.
    """
    def params(self, theta: Tensor, dy: Tensor) -> Tensor:
        """
        Build the target parameters theta' from the current parameters and the upstream gradient.

        :param theta: current parameters. They are detached and copied, so `theta` itself is never written to.
        :param dy: gradient of the loss w.r.t. the outputs (assumed to have the same shape as `theta` — TODO confirm).
        :return: a new tensor (with `theta`'s dtype) where theta'[i] = -dy[i] if dy[i] != 0, else theta[i].
        """
        has_signal = dy.ne(0)
        target = theta.detach().clone()
        # Boolean-mask assignment overwrites only the entries where the gradient is informative (non-zero).
        target[has_signal] = dy[has_signal].neg()
        return target


class KargerTargetDistribution(BaseTargetDistribution):
    """
    Target distribution for Karger's algorithm.
    Given the parameters for Karger `theta` and the gradient w.r.t. Karger's outputs `dy`, this attempts to reconstruct
    the optimal inputs to Karger's algorithm that have 100% probability of finding the minimum cut.

    This is intended to be used in combination with the Hamming loss and a ground truth cut.
    """
    def params(self, theta: Tensor, dy: Tensor) -> Tensor:
        """
        Compute the target parameters from the upstream gradient `dy`.

        `dy` is rescaled by its largest absolute entry so that it lies in [-1, 1], then mapped affinely onto
        [0, 1] via (1 - scaled) / 2. Entries with the most negative gradient map to 1 and the most positive map to 0.

        :param theta: current parameters of Karger's algorithm. Unused here; accepted so the signature matches
            `params(theta, dy)` of the other target distributions.
        :param dy: gradient of the loss w.r.t. Karger's outputs.
        :return: tensor with the same shape as `dy` and values in [0, 1]. If `dy` is empty or all zeros,
            every entry is 0.5.
        """
        # `.max()` raises on an empty tensor, so treat "no entries" like "no gradient signal".
        max_abs = dy.abs().max().item() if dy.numel() > 0 else 0.0
        if max_abs == 0.0:
            # A zero gradient carries no information, and dividing by max_abs would raise ZeroDivisionError.
            # Here (1 - dy) / 2 is 0.5 everywhere, which is what the formula below yields for any finite scale.
            return (1 - dy) / 2
        # NOTE(review): the name suggests that, for a Hamming loss averaged over samples, every non-zero |dy|
        # equals 1 / (number of samples), so this factor undoes the averaging — confirm against the caller.
        nb_samples = 1 / max_abs
        return (1 - nb_samples * dy) / 2
