from typing import Optional, Union, Tuple

import numpy as np
import torch

from torch import Tensor
from torch.utils.data import Dataset

class Dataset(Dataset):
    """Dataset for handling input, output, and optional parameters.

    Wraps two (or three) array-likes that are indexed in lockstep along their
    first dimension. Each item is ``(x, y)``, or ``(x, y, r)`` when ``params``
    was supplied.

    Note: this class deliberately reuses the name ``Dataset``; the base class
    in the ``class`` statement is resolved to ``torch.utils.data.Dataset``
    before the module-level name is rebound to this subclass.
    """

    def __init__(
        self,
        input: Union[np.ndarray, Tensor],
        output: Union[np.ndarray, Tensor],
        params: Optional[Union[np.ndarray, Tensor]] = None,
    ) -> None:
        """Initialize Dataset.

        NumPy inputs are copied into new ``float32`` tensors. Inputs that are
        already ``Tensor`` objects are stored as-is (same object, same dtype).

        Args:
            input (Union[np.ndarray, Tensor]): Data input
            output (Union[np.ndarray, Tensor]): Data output
            params (Optional[Union[np.ndarray, Tensor]]): Additional parameters. Defaults to None.

        Raises:
            ValueError: If ``output`` or ``params`` does not have the same size
                along the first dimension as ``input``.
        """
        self.x: Tensor = self._to_tensor(input)
        self.y: Tensor = self._to_tensor(output)
        self.r: Optional[Tensor] = self._to_tensor(params) if params is not None else None

        # Fail fast on mismatched sample counts: __len__ only looks at self.x,
        # so a shorter y/r would otherwise surface as an IndexError deep inside
        # a DataLoader, and a longer one would be silently truncated.
        # Comparing shape[:1] (rather than len()) keeps 0-d inputs from
        # raising here, matching the original constructor's behavior for them.
        expected = self.x.shape[:1]
        for name, tensor in (("output", self.y), ("params", self.r)):
            if tensor is not None and tensor.shape[:1] != expected:
                raise ValueError(
                    f"`{name}` has leading shape {tuple(tensor.shape[:1])}, "
                    f"expected {tuple(expected)} to match `input`"
                )

    def _to_tensor(self, data: Union[np.ndarray, Tensor]) -> Tensor:
        """Convert data to a Tensor if it is not already one.

        Existing tensors are returned unchanged (no copy, dtype preserved).
        Anything else goes through ``torch.tensor`` (which copies) and is cast
        to ``float32``.
        """
        return data if isinstance(data, Tensor) else torch.tensor(data).float()

    def __len__(self) -> int:
        """Return the number of samples (size of the first dimension of input)."""
        return len(self.x)

    def __getitem__(self, index: int) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        """Retrieve a sample from the dataset at the given index.

        Returns ``(x, y, r)`` when params were provided, otherwise ``(x, y)``.
        """
        if self.r is not None:
            return self.x[index], self.y[index], self.r[index]
        return self.x[index], self.y[index]