################################################################################
# spectral/modules/tsp/data.py
#
# 
# 
# 
# 2024
#
# Implements a dataset which just generates TSP grids on the fly.
#
# To use this dataset with a ``DataLoader``, ensure ``batch_size`` is set to
# ``None``, and use the ``batch_size`` argument in this dataset instead. This is
# because it is far more efficient to generate the batch here than allowing
# ``DataLoader`` to do it.

import torch

from typing import Any

Tensor = torch.Tensor

# Supported values for the ``grid_type`` argument of ``TSPRandomGrids``.
ASYMMETRIC = 0
SYMMETRIC  = 1
EUCLIDEAN  = 2

class TSPRandomGrids(torch.utils.data.Dataset):
  """
  Dataset that generates random grids per batch.

  Every item of the dataset is a complete, freshly generated batch; the index
  is only used for bounds checking. ``len(dataset)`` is ``batch_count``.
  """

  def __init__(self,
      # Arguments:
      batch_size: int,
      grid_size:  int,
      # Keyword Arguments:
      batch_count:     int   = 1000,
      diag_dist:       float = 1.0,
      euclidean_dim:   int   = 2,
      grid_type:       int   = ASYMMETRIC,
    ):
    """
    Initializes ``TSPRandomGrids``.

    Args:
      batch_size (int):
        The size of batches.
      grid_size (int):
        The size of the grids.
      batch_count (int, optional):
        How many batches to consider as an epoch (including validation).
        Defaults to ``1000``.
      diag_dist (float, optional):
        Distance to set the diagonal to. It shouldn't be possible for the model
        to select diagonal edges anyway, but this value could influence things
        like normalization. Integers are accepted and stored as ``float``.
        Defaults to ``1.0``.
      euclidean_dim (int, optional):
        Dimension to use for when ``grid_type`` is ``EUCLIDEAN``.
        Defaults to ``2``.
      grid_type (int, optional):
        The type of grid to generate. Either ``ASYMMETRIC`` (0),
        ``SYMMETRIC`` (1), or ``EUCLIDEAN`` (2).
        Defaults to ``ASYMMETRIC``.

    Raises:
      AssertionError:
        If an argument has the wrong type, or ``grid_type`` is outside the
        enum range.
    """
    # NOTE: validation is kept as ``assert`` so that the exception type seen by
    # existing callers does not change.
    # Batch size.
    assert isinstance(batch_size, int), \
      "batch_size must be an integer."
    self.batch_size = batch_size
    # Grid size.
    assert isinstance(grid_size, int), \
      "grid_size must be an integer."
    self.grid_size = grid_size
    # Batch count.
    assert isinstance(batch_count, int), \
      "batch_count must be an integer."
    self.batch_count = batch_count
    # Diagonal distance. Integers such as ``1`` are valid distances too, so
    # accept them and normalize to ``float``.
    assert isinstance(diag_dist, (int, float)), \
      "diag_dist must be a float or an integer."
    self.diag_dist = float(diag_dist)
    # Euclidean dimension.
    assert isinstance(euclidean_dim, int), \
      "euclidean_dim must be an integer."
    self.euclidean_dim = euclidean_dim
    # Grid type.
    assert isinstance(grid_type, int), \
      "grid_type must be an integer."
    self.grid_type = grid_type
    assert ASYMMETRIC <= self.grid_type <= EUCLIDEAN, \
      "grid_type needs to be in the enum range [0, 2]."

  @torch.autograd.no_grad()
  def generate_grids(self) -> Tensor:
    """
    Generates a batch of adjacency grids for TSP.

    For ``EUCLIDEAN`` grids, random points are drawn and their pairwise
    distances are used. Otherwise the entries themselves are drawn at random,
    and mirrored across the diagonal for ``SYMMETRIC`` grids. The result is
    divided by the largest entry of the whole batch (not per grid).

    Returns:
      Tensor:
        An unformatted adjacency grid of shape
        ``(batch_size, 1, grid_size, grid_size)``.
    """
    if self.grid_type == EUCLIDEAN:
      points = torch.rand(
        (self.batch_size, self.grid_size, self.euclidean_dim)
      )
      grid = torch.cdist(points, points)
    else:
      grid = torch.rand(
        (self.batch_size, self.grid_size, self.grid_size)
      )
      if self.grid_type == SYMMETRIC:
        # Keep the lower triangle (with diagonal) and mirror its strict part
        # into the upper triangle.
        grid = torch.tril(grid) + torch.tril(grid, -1).mT
    # The clamp guards against dividing by zero if every entry is zero.
    grid = grid / torch.max(grid).clamp(1e-12)
    return grid.unsqueeze(1)

  @torch.autograd.no_grad()
  def format_grid(self,
      # Arguments:
      grid: Tensor
    ) -> Tensor:
    """
    Formats a generated grid into the proper form, with diagonal distances and
    masking.

    An extra channel is concatenated along dimension 1. It holds ``+1`` above
    the diagonal, ``-1`` below the diagonal and ``0`` on the diagonal. The
    diagonal of the distance channel is overwritten with ``diag_dist``.

    Args:
      grid (Tensor):
        The grid to format, as returned by ``generate_grids``.

    Returns:
      Tensor:
        The formatted grid; ``(batch, 2, grid_size, grid_size)`` for an input
        produced by ``generate_grids``.
    """
    # Build the mask on the same device as ``grid`` so that grids living on an
    # accelerator can be formatted without a device mismatch.
    tri = torch.tril(
      torch.ones((self.grid_size, self.grid_size), device = grid.device),
      -1
    )
    info = (tri.mT - tri).view(1, 1, self.grid_size, self.grid_size)
    # ``expand`` yields a broadcast view instead of materializing one copy of
    # the mask per batch element; the values are identical.
    info = info.expand(grid.size(0), -1, -1, -1)
    off_diag = torch.abs(info)
    grid = off_diag * grid + (1 - off_diag) * self.diag_dist
    return torch.concat([grid, info], dim = 1)

  @torch.autograd.no_grad()
  def __getitem__(self,
      # Arguments:
      index: Any
    ) -> "tuple[Tensor, Tensor]":
    """
    Generates and formats a fresh batch of grids.

    Args:
      index (Any):
        Index of the batch. The generated data does not depend on it; integer
        indices are only range-checked against ``batch_count``.

    Returns:
      tuple[Tensor, Tensor]:
        The same formatted batch twice (presumably consumed as an
        ``(input, target)`` pair by the training loop).

    Raises:
      IndexError:
        If ``index`` is an integer outside ``[-batch_count, batch_count)``.
    """
    # Without this check plain iteration (``for batch in dataset``) would never
    # stop, since Python's sequence protocol relies on ``IndexError``.
    if isinstance(index, int) and not isinstance(index, bool):
      if not -self.batch_count <= index < self.batch_count:
        raise IndexError(
          f"batch index {index} is out of range for {self.batch_count} batches."
        )
    grids = self.format_grid(self.generate_grids())
    return grids, grids

  def __len__(self) -> int:
    """
    Returns the number of batches that make up one epoch.
    """
    return self.batch_count