################################################################################
# spectral/modules/chesspuzzles/data.py
#
# 
# 
# 
# 2024
#
# Dataset handler for the chess puzzles dataset, slightly modified from
# github.com/aks2203/deep-thinking/blob/main/deepthinking/utils/chess_data.py

import torch

from easy_to_hard_data import ChessPuzzleDataset
from typing            import Optional

class FlippedChessPuzzleDataset(ChessPuzzleDataset):
  """
  A dataset where the player to move next is at the bottom of the board.

  After the parent ``ChessPuzzleDataset`` has loaded its tensors, every sample
  whose ``who_moves`` entry equals ``1`` has its puzzle reversed along dim 2
  and its target reversed along dim 1. Note that this is a reflection along a
  single board axis, not a 180-degree rotation.
  """

  def __init__(self,
      # Arguments:
      root:      str,
      # Keyword Arguments:
      train:     bool          = True,
      idx_start: Optional[int] = None,
      idx_end:   Optional[int] = None,
      who_moves: bool          = True,
      download:  bool          = True
    ):
    """
    Initializes ``FlippedChessPuzzleDataset``.

    Args:
      root (str):
        The root directory of the dataset.
      train (bool, optional):
        Whether to use the train portion of the dataset.
        Defaults to ``True``.
      idx_start (int, optional):
        The start index.
        Defaults to ``None``.
      idx_end (int, optional):
        The end index.
        Defaults to ``None``.
      who_moves (bool, optional):
        Whether to get who moves.
        Defaults to ``True``.
      download (bool, optional):
        Whether to download the dataset.
        Defaults to ``True``.
    """
    # Arguments are forwarded positionally, so their order must match the
    # parent constructor's parameter order.
    super().__init__(root, train, idx_start, idx_end, who_moves, download)

    # One boolean per sample selecting the boards to flip. ``reshape(-1)`` is
    # used instead of ``squeeze()``: with a single sample, ``squeeze()``
    # collapses the mask to a 0-d tensor, and indexing with a 0-d boolean
    # mask inserts a new leading axis instead of selecting rows, so the flips
    # below would act on the wrong dimensions.
    # NOTE(review): this reads ``self.who_moves`` regardless of the
    # ``who_moves`` argument; confirm the parent always populates it as a
    # tensor with one entry per puzzle. Presumably ``1`` marks the side whose
    # boards must be flipped (e.g. black to move) -- verify against the data.
    flip_mask = (self.who_moves == 1).reshape(-1)

    # NOTE(review): assumes puzzles are (N, C, H, W) and targets (N, H, W),
    # so dim 2 of a puzzle and dim 1 of a target are the same board axis.
    # TODO confirm.
    self.puzzles[flip_mask] = torch.flip(self.puzzles[flip_mask], [2])
    self.targets[flip_mask] = torch.flip(self.targets[flip_mask], [1])