from pathos.multiprocessing import ProcessPool
import torch
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import BackwardCFunction
from torch.nn import Module
from torch_geometric.data import Batch, Data
from typing_extensions import override

from decoders import Christofides, Noisy, TSPDecoder
from util import sum_of_edge_weights


# TODO: make these configurable
# Number of additional decoder runs that `_get_gradient` averages and subtracts
# from the scored tour.
NUM_SAMPLES = 3
# Decoder invoked as `DECODER(graph, steering_weights)` by `_get_gradient`.
# NOTE(review): presumably `Noisy` randomizes the wrapped Christofides decoder so
# that repeated calls yield different tours -- confirm in `decoders`.
DECODER: TSPDecoder = Noisy(Christofides(), scale_factor=1)


class ReinforceLoss(Module):
    """Loss module that delegates its whole computation to `_ReinforceFunction`.

    The gradient reaching `input` during backpropagation is whatever
    `_ReinforceFunction.backward` returns.
    """

    @override
    def forward(self, input: Tensor, graph: Batch) -> Tensor:
        """Apply `_ReinforceFunction` to `input` and `graph` and return its output."""
        loss = _ReinforceFunction.apply(input, graph)
        return loss


# see https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
class _ReinforceFunction(Function):
    """Custom autograd function that injects a precomputed gradient estimate.

    `forward` runs `_get_gradient` once per graph of the batch (in worker
    processes), stores the concatenated result and returns a constant zero
    scalar. `backward` hands the stored estimate, scaled by the upstream
    gradient, to autograd as the gradient of `input`.
    """

    @override
    @staticmethod
    def forward(ctx: BackwardCFunction, input: Tensor, graph: Batch) -> Tensor:
        """Compute and stash the gradient estimate for `input`.

        Args:
            ctx: autograd context used to carry the estimate to `backward`.
            input: values that are split per graph via `graph.edge_index_batch`
                and passed to `_get_gradient` as steering weights.
            graph: batch of graphs; must expose `edge_index_batch`.

        Returns:
            A zero scalar on `input`'s device. The value carries no
            information; only the gradient produced in `backward` matters.
        """
        # Work on CPU copies; `clone()` gives us our own batch object to move.
        input_cpu = input.cpu()
        graph_cpu = graph.clone().cpu()

        num_graphs = graph_cpu.num_graphs
        edge_to_graph = graph_cpu.edge_index_batch
        input_list = [input_cpu[edge_to_graph == i] for i in range(num_graphs)]
        graphs_list = [graph_cpu.get_example(i) for i in range(num_graphs)]

        with ProcessPool() as pool:
            gradients = pool.map(_get_gradient, input_list, graphs_list)

        # NOTE(review): concatenating in graph order matches the layout of `input`
        # only if the entries of each graph are contiguous and ordered by graph
        # index in `edge_index_batch` -- confirm against how the batch is built.
        gradients = torch.cat(gradients).to(input.device)
        ctx.save_for_backward(gradients)

        return torch.zeros((), device=input.device)

    @override
    @staticmethod
    def backward(ctx: BackwardCFunction, _grad_output: Tensor) -> tuple[Tensor, None]:
        """Return the stored estimate, scaled by the upstream gradient.

        Args:
            ctx: autograd context holding the estimate saved in `forward`.
            _grad_output: gradient of the final objective with respect to the
                scalar returned by `forward`.

        Returns:
            One gradient per `forward` argument: the scaled estimate for
            `input`, and `None` for `graph`.
        """
        gradients, = ctx.saved_tensors
        # Chain rule: any scaling applied to the loss downstream (loss weights,
        # gradient accumulation, AMP loss scaling) arrives via `_grad_output` and
        # must be propagated. For a plain `loss.backward()` this factor is 1.
        # `None` is the gradient for `graph` (see parameters of forward()).
        return gradients * _grad_output, None


@torch.no_grad()
def _get_gradient(steering_weights: Tensor, graph: Data) -> Tensor:
    """
    Estimate the gradient for `steering_weights` on a single graph.

    One tour is decoded with `DECODER` and scored with `sum_of_edge_weights`;
    `NUM_SAMPLES` further decoded tours are averaged and subtracted from it
    before scaling by that score.

    Returns: gradient estimated using REINFORCE
    """
    sampled_tour = DECODER(graph, steering_weights)
    sampled_length = sum_of_edge_weights(graph, sampled_tour)

    # Decode the additional tours one by one, after the scored tour.
    extra_tours = []
    for _ in range(NUM_SAMPLES):
        extra_tours.append(DECODER(graph, steering_weights))
    averaged_tour = torch.mean(torch.stack(extra_tours), dim=0)

    # TODO check if the sign is correct
    # no minus sign, because the inputs are inverted inside christofides_steered
    deviation = sampled_tour - averaged_tour
    return sampled_length * deviation
