import numpy as np
import torch

from torch.utils.data import Dataset

class Dataset(Dataset):
    """Dataset for handling input, output, and optional parameters."""

    def __init__(
        self,
        input: np.ndarray | torch.Tensor,
        output: np.ndarray | torch.Tensor,
        params: np.ndarray | torch.Tensor | None = None,
    ) -> None:
        """Initialize Dataset

        Args:
            input: Data input
            output: Data output
            params: Additional parameters. Defaults to None.
        """
        self.x: torch.Tensor = self._to_tensor(input)
        self.y: torch.Tensor = self._to_tensor(output)
        self.r: torch.Tensor | None = self._to_tensor(params) if params is not None else None

    def _to_tensor(self, data: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert data to a Tensor if it is not already one."""
        return data if isinstance(data, torch.Tensor) else torch.tensor(data).float()

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.x)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Retrieve a sample from the dataset at the given index."""
        if self.r is None:
            return self.x[index], self.y[index]
        else:
            return self.x[index], self.y[index], self.r[index]