import math
from typing import Final

from pathos.multiprocessing import ProcessPool
import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Batch, Data
from typing_extensions import override

from decoders import Insertion, TSPDecoder
from util import sum_of_edge_weights


# Number of tours that are always sampled per graph before any comparison is made.
MIN_ATTEMPTS: Final[int] = 10
# Upper bound on the total number of sampled tours per graph. If all sampled tours still have (nearly) the same
# length once this many have been drawn, the graph contributes a zero gradient.
MAX_ATTEMPTS: Final[int] = 20
# Heuristic used to sample tours. `mode="random"` presumably randomizes the construction so repeated calls can yield
# different tours — NOTE(review): confirm against `decoders.Insertion`.
TSP_HEURISTIC: Final[TSPDecoder] = Insertion(mode="random")


# TODO There is probably a way to do this with less duplicated code


class TSPDirectGradientLoss(Module):
    """
    Calculates the gradient at the input to this function directly, without requiring a ground truth TSP tour.

    For every graph in the batch, the graph's slice of the input is used to steer `TSP_HEURISTIC`. Several tours are
    sampled, and a gradient is derived from how much longer each sampled tour is than the shortest one (see
    `_get_gradient`). The graphs are processed in parallel in worker processes.

    The gradient is backpropagated after calculating it, meaning that `.backward()` does not need to be called on the
    output.
    In fact, the loss this returns is always a constant 0, and calling `.backward()` on it does nothing.
    Even though this is not actually a loss function, this setup means that this can be used as a drop-in replacement
    for a loss function.
    """

    @override
    def forward(self, input: Tensor, graph: Batch) -> Tensor:
        """
        Compute the sampled-tour gradient for `input` and immediately backpropagate it.

        Args:
            input: Steering weights for all graphs in `graph`. Entries along the first dimension are assigned to
                graphs via `graph.edge_index_batch`.
            graph: The batch of graphs that `input` belongs to.

        Returns:
            A constant zero scalar on `input`'s device and dtype (see the class docstring).
        """
        masks = [graph.edge_index_batch == i for i in range(graph.num_graphs)]
        # The heuristic only needs the values; autograd state cannot cross process boundaries anyway.
        detached_input = input.detach()
        input_list = [detached_input[mask] for mask in masks]
        graphs_list = [graph.get_example(i) for i in range(graph.num_graphs)]

        with ProcessPool() as pool:
            gradients = pool.map(_get_gradient, input_list, graphs_list)

        # Scatter each graph's gradient back onto exactly the entries it was computed for. This does not rely on the
        # entries of each graph being contiguous and in batch order, and it also handles an empty batch.
        # `.to(grad_input)` brings worker results onto the device/dtype of `input`.
        grad_input = torch.zeros_like(input)
        for mask, gradient in zip(masks, gradients):
            grad_input[mask] = gradient.to(grad_input)

        # TODO should probably use a custom backward implementation on this class instead. can use i-mle as inspiration
        input.backward(grad_input)

        return input.new_zeros((), requires_grad=True)


def _get_gradient(steering_weights: Tensor, graph: Data) -> Tensor:
    """
    Estimate a gradient for one graph's steering weights by comparing sampled TSP tours.

    First, `MIN_ATTEMPTS` tours are sampled with `TSP_HEURISTIC`, steered by `steering_weights`. While all sampled tours
    have (relatively) the same length, more tours are sampled, up to `MAX_ATTEMPTS` in total. If no two lengths differ
    by then, a zero gradient is returned because there is nothing to compare.

    Otherwise, the shortest sampled tour is the winner. Every tour whose length is not close to the winner's contributes
    ``(length / winning_length - 1) * (losing_tour - winning_tour)``, i.e. its difference to the winning tour weighted
    by how much longer it is relative to the winner.

    NOTE(review): assumes the tours returned by `TSP_HEURISTIC` have the same shape as `steering_weights` (presumably
    one entry per edge), since the result is used as the gradient of `steering_weights` — confirm.

    Args:
        steering_weights: Weights for `graph` that steer the heuristic.
        graph: The single graph to sample tours on.

    Returns:
        A tensor with the shape, dtype and device of `steering_weights`.
    """
    rel_tol = 1e-6  # tolerance for treating two tour lengths as equal
    tsp_tours: list[Tensor] = []
    tour_lengths: list[float] = []

    def sample_tour() -> None:
        """Sample one tour and record it together with its total edge weight."""
        tour = TSP_HEURISTIC(graph, steering_weights)
        tsp_tours.append(tour)
        tour_lengths.append(float(sum_of_edge_weights(graph, tour)))

    for _ in range(MIN_ATTEMPTS):
        sample_tour()

    # Keep sampling until at least two tours differ in length; identical lengths give no signal to learn from.
    while math.isclose(min(tour_lengths), max(tour_lengths), rel_tol=rel_tol):
        # limit number of attempts for finding a different TSP tour
        if len(tsp_tours) >= MAX_ATTEMPTS:
            return torch.zeros_like(steering_weights)
        sample_tour()

    # First index of the shortest tour.
    winning_index = min(range(len(tour_lengths)), key=tour_lengths.__getitem__)
    winning_tour_length = tour_lengths[winning_index]

    # Accumulate in the dtype of the weights this gradient is for. Tours may be integer/bool indicators, which would
    # fail on subtraction or on in-place accumulation of float-scaled differences.
    gradient = torch.zeros_like(steering_weights)
    winning_tour = tsp_tours[winning_index].to(gradient)

    for losing_tour, losing_tour_length in zip(tsp_tours, tour_lengths):
        if not math.isclose(losing_tour_length, winning_tour_length, rel_tol=rel_tol):
            # Relative excess length of the losing tour over the winner.
            scaling_factor = losing_tour_length / winning_tour_length - 1
            gradient += scaling_factor * (losing_tour.to(gradient) - winning_tour)

    return gradient
