from collections.abc import Callable
from typing import Literal

from imle.noise import SumOfGammaNoiseDistribution
from imle.target import BaseTargetDistribution, TargetDistribution
from imle.wrapper import imle
from pathos.multiprocessing import ProcessPool
import torch
from torch import Tensor
from torch_geometric.data import Batch, Data

from karger import karger_stein_repeated
from ._imle_karger_config import IMLEKargerConfig
from ._target_distributions import KargerTargetDistribution


class IMLEKarger:
    """
    Runs Karger-Stein multiple times and estimates gradients using I-MLE [1].

    The steering weights for each run of Karger-Stein are obtained by adding noise to the input edge predictions and
    inverting them:
    `steering_weights = 1 - (edge_predictions + noise)`.
    The number of noise samples (and therefore the number of sets of steering weights on which Karger-Stein is run)
    is determined by the `num_noise_samples` field of the `IMLEKargerConfig` passed to the constructor.

    Note that the number of Karger-Stein runs per noise sample can separately be adjusted using the
    `num_karger_runs_per_noise_sample` configuration parameter.
    The cut returned for each noise sample is the best cut found during the `num_karger_runs_per_noise_sample` runs.
    This means Karger-Stein is run `num_noise_samples * num_karger_runs_per_noise_sample` times in total during the
    forward pass.
    Gradient estimation requires another `num_noise_samples * num_karger_runs_per_noise_sample` runs of Karger-Stein.

    Parameters (of `__call__`):

    - `edge_predictions`: The inverse of the steering weights, before the noise is added.
                          Size `[1, graphs.num_edges]`, covering the edges of every graph in the batch.
    - `graphs`: The batch of graphs to run Karger-Stein on.

    Returns the cuts found by Karger-Stein, size `[num_noise_samples, graphs.num_edges]`.

    [1] Niepert et al., "Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions", NeurIPS, 2021
    """

    # Number of Karger-Stein runs whose best cut is kept for each noise sample.
    _num_karger_runs_per_noise_sample: int
    # The I-MLE-wrapped version of `_karger_stein_batch`, built once in `__init__`.
    _imle_karger_stein: Callable[[Tensor, Batch], Tensor]

    def __init__(self, config: IMLEKargerConfig, device: str):
        """
        Builds the I-MLE-wrapped Karger-Stein function from `config`.

        Parameters:

        - `config`: Supplies the target distribution name, the sum-of-gamma noise settings, the number of noise
                    samples, the noise temperatures and the number of Karger-Stein runs per noise sample.
        - `device`: Torch device string handed (as `torch.device(device)`) to the sum-of-gamma noise distribution.
        """
        self._num_karger_runs_per_noise_sample = config.num_karger_runs_per_noise_sample

        self._imle_karger_stein = imle(
            self._karger_stein_batch,
            target_distribution=_create_target_distribution(config.target_distribution),
            noise_distribution=SumOfGammaNoiseDistribution(
                k=config.sog_noise_k,
                nb_iterations=config.sog_noise_iterations,
                device=torch.device(device)
            ),
            nb_samples=config.num_noise_samples,
            input_noise_temperature=config.input_noise_temperature,
            target_noise_temperature=config.target_noise_temperature,
        )

    def __call__(self, edge_predictions: Tensor, graphs: Batch) -> Tensor:
        """Runs the I-MLE-wrapped Karger-Stein on `graphs`; see the class docstring for shapes."""
        return self._imle_karger_stein(edge_predictions, graphs)

    def _karger_stein_batch(self, edge_predictions: Tensor, graphs: Batch) -> Tensor:
        """
        Runs Karger-Stein multiple times using a different set of steering weights each time, without estimating
        gradients.

        Parameters:

        - `edge_predictions`: The inverse of the steering weights to use for each run
                            (i.e. `steering_weights = 1 - edge_predictions`).
                            Size `[num_noise_samples, graphs.num_edges]`.
        - `graphs`: The batch of graphs to run Karger-Stein on.

        Returns the cuts found by Karger-Stein, size `[num_noise_samples, graphs.num_edges]`.
        """
        steering_weights = 1 - edge_predictions

        # Select each graph's edge columns. `edge_index_batch` maps every edge in the batch to its graph index
        # (presumably the Batch was built with `follow_batch=["edge_index"]` — confirm at the call site).
        steering_weights_list = [steering_weights[:, graphs.edge_index_batch == i] for i in range(graphs.num_graphs)]
        graphs_list = [graphs.get_example(i) for i in range(graphs.num_graphs)]

        # run karger-stein on each graph in parallel
        # pathos.multiprocessing.ProcessPool is much quicker here than multiprocessing.Pool from the standard library
        with ProcessPool() as pool:
            cuts = pool.map(self._karger_stein_single_graph, steering_weights_list, graphs_list)

        # Concatenating per-graph cuts along the edge dimension restores the batch-wide edge order, assuming each
        # graph's edges are stored contiguously and in graph order within the Batch.
        return torch.cat(cuts, dim=1)

    def _karger_stein_single_graph(self, steering_weights: Tensor, graph: Data) -> Tensor:
        """
        Runs Karger-Stein on one graph, once per row (noise sample) of `steering_weights`.

        Parameters:

        - `steering_weights`: One set of steering weights per noise sample,
                              size `[num_noise_samples, graph.num_edges]`.
        - `graph`: The graph to run Karger-Stein on.

        Returns the cuts stacked along a new leading dimension, one row per noise sample.
        """
        # Iterating a 2-D tensor yields its rows directly; no intermediate list is needed.
        # NOTE(review): the literal `2` is presumably the number of partitions of the cut — confirm against
        # `karger_stein_repeated`.
        cuts = [
            karger_stein_repeated(graph, 2, self._num_karger_runs_per_noise_sample, weights, mode="train")
            for weights in steering_weights
        ]
        return torch.stack(cuts)


def _create_target_distribution(target_distribution: Literal["general_purpose", "karger"]) -> BaseTargetDistribution:
    """
    Builds the I-MLE target distribution selected by `target_distribution`.

    Parameters:

    - `target_distribution`: `"general_purpose"` selects I-MLE's `TargetDistribution(alpha=1.0, beta=1.0)`;
                             `"karger"` selects `KargerTargetDistribution()`.

    Prints the selected name, then returns a freshly constructed distribution.
    Raises `ValueError` for any other value.
    """
    print("I-MLE target distribution:", target_distribution)

    if target_distribution == "general_purpose":
        # TODO: make parameters configurable
        return TargetDistribution(alpha=1.0, beta=1.0)

    if target_distribution == "karger":
        return KargerTargetDistribution()

    message = (
        f'I-MLE target_distribution must be either "general_purpose" or "karger", but was {target_distribution}'
    )
    raise ValueError(message)
