import torch
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP

from .model import Model


def schedule(base: int, n_cities: int) -> tuple[list[int], list[int]]:
    """Split the decoding of `n_cities` nodes into chunks of at most `base` steps.

    Enough chunks of `base` steps are created to cover `n_cities`. The first chunk is then
    trimmed so that the chunks sum to exactly `n_cities`, and the last chunk loses one step
    because the final step is trivial and is not predicted. The steps therefore sum to
    `n_cities - 1`, and the last chunk may be empty (when `base == 1` or `n_cities == 1`).

    Args:
        base: Number of steps in a full chunk. Must be positive.
        n_cities: Total number of nodes. Must be positive.

    Returns:
        A pair `(steps, sizes)` of equal-length lists. `steps[i]` is the number of steps of
        chunk `i`. `sizes[i]` is the number of nodes left before chunk `i` plus one, capped
        at `n_cities`.

    Raises:
        ValueError: If `base` or `n_cities` is not positive.
    """
    # A non-positive `base` can never reach `n_cities`: the loop below would not terminate.
    if base <= 0:
        raise ValueError(f"`base` must be a positive integer, got {base}.")
    if n_cities <= 0:
        raise ValueError(f"`n_cities` must be a positive integer, got {n_cities}.")

    steps = []
    total = 0  # Running sum of `steps`, so the list is not re-summed on every iteration.
    while total < n_cities:
        steps.append(base)
        total += base

    # All chunks are equal at this point: the first one absorbs the overshoot.
    steps[0] -= total - n_cities
    steps[-1] -= 1  # Do not predict the last trivial step.

    # Add 1 to the problem sizes to take into account the additional origin node that has been
    # lastly selected from previous iteration.
    sizes = []
    remaining = n_cities
    for n_steps in steps:
        sizes.append(min(n_cities, remaining + 1))
        remaining -= n_steps

    return steps, sizes


def partial_solve(x: Tensor, o: Tensor, e: Tensor, model: Model, n_steps: int) -> Tensor:
    """Greedily decode `n_steps` successive nodes with `model`.

    Args:
        x: 3-D tensor whose first two dimensions are (batch, nodes).
        o: Per-row index of the current origin node, used to index dim 1 of the mask.
        e: Per-row node index that is kept unmasked at every step (presumably the end node
            of the tour -- confirm against `Model`).
        model: Called as `model(x, m, o, e)`; the argmax over dim 1 of its output is the
            next node.
        n_steps: Number of nodes to decode.

    Returns:
        Tensor of shape (batch, n_steps) holding the selected node indices, in decoding
        order. It is empty along dim 1 when `n_steps` is 0.
    """
    batch_size, n_cities, _ = x.shape
    device = x.device

    m = torch.ones((batch_size, n_cities), dtype=torch.bool, device=device)
    b = torch.arange(batch_size, device=device)

    solution = []
    for _ in range(n_steps):
        scores = model(x, m, o, e)
        # The mask is updated after the forward pass, so the model always sees the current
        # origin as still unmasked. `e` is re-enabled in case it coincides with `o`.
        m[b, o] = False
        m[b, e] = True
        o = scores.argmax(dim=1)
        solution.append(o)

    if not solution:
        # `torch.stack` rejects an empty list: return an explicitly empty result instead.
        return torch.empty((batch_size, 0), dtype=torch.long, device=device)

    return torch.stack(solution, dim=1)


@torch.compile(dynamic=True)
def solve_(x: Tensor, model: Model, base: int = 32) -> Tensor:
    """Decode a tour chunk by chunk, each chunk on a reduced sub-problem.

    For every `(n_steps, pb_size)` pair produced by `schedule`, the nodes are reordered so
    that the unmasked ones come first, the first `pb_size` of them are handed to
    `partial_solve`, and the indices it returns are mapped back to the original numbering.

    Args:
        x: 3-D tensor whose first two dimensions are (batch, nodes).
        model: Forwarded unchanged to `partial_solve`.
        base: Chunk size forwarded to `schedule`.

    Returns:
        Tensor of node indices with one row per batch element: a leading column of zeros
        (the start node) followed by the nodes decoded in each chunk.
    """
    batch_size, n_cities, _ = x.shape
    device = x.device

    # Every row starts at node 0, which is used both as origin `o` and as `e`.
    o = e = torch.zeros((batch_size,), dtype=torch.long, device=device)
    m = torch.ones((batch_size, n_cities), dtype=torch.bool, device=device)
    b = torch.arange(batch_size, device=device)

    solutions = [o[:, None]]
    for n_steps, pb_size in zip(*schedule(base, n_cities)):
        if n_steps == 0:
            # `schedule` yields an empty final chunk when `base == 1` or `n_cities == 1`.
            # There is nothing to decode, and `sol[:, -1]` below would be out of range.
            continue

        # Moves 'True' indices to front.
        dec = torch.argsort(m, dim=1, stable=True, descending=True)
        # Places indices back to their original position.
        enc = torch.argsort(dec, dim=1, descending=False)

        # Generate the next `n_steps` part of the solution.
        # `enc` is the inverse permutation of `dec`, so `enc[b, o]` and `enc[b, e]` are the
        # positions of `o` and `e` inside the reordered sub-problem.
        sol = partial_solve(
            torch.vmap(lambda x, d: x[d])(x, dec[:, :pb_size]),
            enc[b, o],
            enc[b, e],
            model,
            n_steps,
        )

        # Project the partial solution into the original indices.
        sol = torch.vmap(lambda s, d: d[s])(sol, dec)
        solutions.append(sol)

        m[b, o] = False  # Mask previous starting node.
        o = sol[:, -1]  # The new starting node is the last step of the solution.

        # Per row, clear the mask entries of every node decoded in this chunk but the last.
        update = lambda m, s: torch.index_fill(m, 0, s, torch.tensor(False))
        m = torch.vmap(update)(m, sol[:, :-1])  # Do not mask the new starting node.

    return torch.concat(solutions, dim=1)


@torch.compile(dynamic=True)
def solve(x: Tensor, model: Model | DDP) -> Tensor:
    batch_size, n_cities, _ = x.shape
    device = x.device

    x = x.clone()  # `x` will be modified in-place.

    o = e = torch.zeros((batch_size,), dtype=torch.long, device=device)
    m = torch.ones((batch_size, n_cities), dtype=torch.bool, device=device)
    b = torch.arange(batch_size, device=device)
    c = torch.stack([torch.arange(n_cities, device=device) for _ in range(batch_size)])

    l = model(x, m, o, e)
    o = l.argmax(dim=1)

    solutions = [c[b, e], c[b, o]]
    for _ in range(n_cities - 2):
        l = model(x, m, o, e)

        # The origin cities `o` are now useless. We move last elements of the batch to the origin
        # cities, and cut the arrays.
        x[b, o] = x[b, -1]
        c[b, o] = c[b, -1]
        l[b, o] = l[b, -1]
        x = x[:, :-1]
        c = c[:, :-1]
        l = l[:, :-1]
        m = m[:, :-1]

        o = l.argmax(dim=1)
        solutions.append(c[b, o])

    return torch.stack(solutions, dim=1)
