import random
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from einops import repeat
from torch import Tensor
from torch.utils.data import Dataset


@dataclass
class Batch:
    """A batch of samples.

    ---
    Parameters:
        x: Cities coordinates.
            Shape of [batch_size, n_cities, (x, y)].
        s: Solutions.
            Shape of [batch_size, n_cities].
        m: Masks.
            Shape of [batch_size, n_cities].
        o: Origins.
            Shape of [batch_size,].
        e: Destinations.
            Shape of [batch_size,].
        l: Labels.
            Shape of [batch_size,].
        n: The number of unpadded cities.
    """

    x: Tensor
    s: Tensor
    m: Tensor
    o: Tensor
    e: Tensor
    l: Tensor
    n: int

    def to(self, device: torch.device) -> "Batch":
        """Return a new batch whose tensors have been moved to `device`.

        ---
        Args:
            device: Target device for every tensor of the batch.

        ---
        Returns:
            A new `Batch` holding the moved tensors and the same `n`.
        """
        return Batch(
            self.x.to(device),
            self.s.to(device),
            self.m.to(device),
            self.o.to(device),
            self.e.to(device),
            self.l.to(device),
            self.n,
        )

    def pad(self, multiple: int = 128) -> "Batch":
        """Zero-pad the cities dimension up to the next multiple of `multiple`.

        Fix https://github.com/pytorch/pytorch/issues/153799

        Only `x`, `s` and `m` are padded (with zeros, i.e. `False` for the masks). `o`, `e` and
        `l` are passed through untouched. A batch whose number of cities is already a multiple
        of `multiple` keeps its size.

        ---
        Args:
            multiple: The padded number of cities is the smallest multiple of this value that is
                greater than or equal to the current number of cities.

        ---
        Returns:
            A new `Batch` with freshly allocated `x`, `s` and `m`, whose `n` is the number of
            cities before padding.

        ---
        Raises:
            ValueError: If `multiple` is not strictly positive.
        """
        if multiple <= 0:
            raise ValueError(f"`multiple` must be a positive integer, got {multiple}.")

        batch_size, n_cities = self.s.shape
        device = self.x.device
        # Python's modulo is non-negative for a positive divisor, so this is 0 when `n_cities` is
        # already aligned (instead of adding a whole extra `multiple` of padding).
        to_add = (-n_cities) % multiple
        padded = n_cities + to_add

        x = torch.zeros((batch_size, padded, 2), dtype=self.x.dtype, device=device)
        s = torch.zeros((batch_size, padded), dtype=self.s.dtype, device=device)
        m = torch.zeros((batch_size, padded), dtype=self.m.dtype, device=device)

        x[:, :n_cities] = self.x
        s[:, :n_cities] = self.s
        m[:, :n_cities] = self.m

        return Batch(x, s, m, self.o, self.e, self.l, n_cities)

    @classmethod
    def from_samples(cls, batch: list[tuple[Tensor, Tensor]], device: torch.device) -> "Batch":
        """Collate `(cities, solution)` samples into a padded batch.

        The cities of each sample are reordered to follow its solution, so the solutions of the
        returned batch are simply `0, 1, ..., n_cities - 1`. The origin is always the city 0, the
        destination is drawn at random and the label is the city right after the origin.

        ---
        Args:
            batch: Samples as `(cities, solution)` pairs, all with the same number of cities.
                Cities have a shape of [n_cities, (x, y)], solutions a shape of [n_cities,].
            device: Device on which the batch is built.

        ---
        Returns:
            The collated batch, padded with the default `pad()` multiple.

        ---
        Raises:
            ValueError: If the samples have fewer than 3 cities.
        """
        x = torch.stack([sample[0] for sample in batch]).to(device)
        s = torch.stack([sample[1] for sample in batch]).to(device)

        batch_size, n_cities = s.shape
        if n_cities < 3:
            # `torch.randint` below needs `low < high`, i.e. at least 3 cities.
            raise ValueError(f"Expected at least 3 cities per sample, got {n_cities}.")

        # Row indices, created on the same device as the tensors they index to avoid an implicit
        # host-to-device copy.
        a = torch.arange(batch_size, device=device)

        # Reorder according to the solutions.
        x = torch.vmap(lambda coords, order: coords[order])(x, s)
        s = repeat(torch.arange(n_cities, device=device), "c -> b c", b=batch_size)

        # Randomly choose a destination node. The draw is in [3, n_cities]; the value `n_cities`
        # wraps around to 0, which is the origin.
        o = torch.zeros((batch_size,), dtype=torch.long, device=device)
        e = torch.randint(
            size=(batch_size,), low=3, high=n_cities + 1, dtype=torch.long, device=device
        )
        e = e % n_cities
        l = o + 1

        # Build the mask. The `True` value set at the destination is propagated back to the
        # origin by applying a cummax in the reverse order, so positions 0..e are `True`. Also
        # handle the special case when o == e (in that case the mask must be full ones).
        m = torch.zeros((batch_size, n_cities), dtype=torch.bool, device=device)
        m[a, e] = True
        m = torch.flip(m, dims=(1,))
        m = torch.cummax(m, dim=1).values
        m = torch.flip(m, dims=(1,))
        m = torch.masked_fill(m, (e == o)[:, None], True)

        return cls(x, s, m, o, e, l, n_cities).pad()


class TSPDataset(Dataset):
    """Dataset of cities paired with their solutions.

    Each item is a `(cities, solution)` tuple where the solution has been randomly augmented by
    `data_aug`.

    ---
    Parameters:
        name: Name of the dataset.
        cities: Cities of every sample, indexed by sample along the first dimension.
        solutions: Solution of every sample, indexed by sample along the first dimension.
    """

    def __init__(self, name: str, cities: Tensor, solutions: Tensor):
        self.name = name
        self.cities = cities
        self.solutions = solutions

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.cities)

    def __getitem__(self, i: int) -> tuple[Tensor, Tensor]:
        """Return the cities of the sample `i` along with a randomly augmented solution."""
        x = self.cities[i]
        s = TSPDataset.data_aug(self.solutions[i])
        return x, s

    @staticmethod
    def data_aug(s: Tensor) -> Tensor:
        """Randomly roll and flip the solution.

        Doesn't change the overall solution, but the model will see a different initial city and a
        different order of visited cities.

        ---
        Args:
            s: Ordered cities to visit in the solution.
                Shape of [n_cities,].

        ---
        Returns:
            The randomly augmented solution.
                Shape of [n_cities,].
        """
        (n_cities,) = s.shape
        s = torch.roll(s, shifts=random.randint(0, n_cities - 1), dims=-1)
        s = torch.flip(s, dims=(-1,)) if random.random() > 0.5 else s
        return s

    @classmethod
    def from_npz(cls, filepath: Path) -> "TSPDataset":
        """Load a dataset from a `.npz` archive.

        The archive must hold a `coords` and a `solutions` array. The dataset is named after the
        file stem.

        ---
        Args:
            filepath: Path to the `.npz` archive. A plain string is accepted as well.

        ---
        Returns:
            The dataset, with float32 cities and int64 solutions.
        """
        filepath = Path(filepath)

        # `np.load` keeps the archive file open until it is closed, so use it as a context manager
        # to release the file handle once the arrays have been read.
        with np.load(filepath) as data:
            # The last entry along the second axis of `coords` is dropped.
            # NOTE(review): presumably the stored coordinates repeat the first city at the end
            # to close the tour -- confirm against the dataset generator.
            cities = data["coords"][:, :-1]
            solutions = data["solutions"]

        return cls(
            filepath.stem,
            torch.tensor(cities, dtype=torch.float32),
            torch.tensor(solutions, dtype=torch.long),
        )
