
"""Dataset for the baseline task."""
from math import pi

import torch

from .utils import StackDataset


def sin(t, f):
    r"""Evaluate a separable sine/cosine wave on 2-D coordinates.

    The last axis of ``t`` is indexed at 0 and 1 to obtain the two
    coordinates :math:`x` and :math:`y`, and the function computes

    .. math::
        \sin(\pi f x) \cos(\pi f y)

    Args:
        t (torch.Tensor): Input tensor whose last axis holds at least the
            two coordinates ``x`` (index 0) and ``y`` (index 1).
        f (float): Frequency.

    Returns:
        torch.Tensor: The function values, with the coordinate axis
            replaced by a trailing axis of size 1.
    """
    x_coord = t[..., 0]
    y_coord = t[..., 1]
    # Scale by pi first and by f second so the floating-point result is
    # the same as evaluating ``coord * pi * f`` left to right.
    sine_part = torch.sin(x_coord * pi * f)
    cosine_part = torch.cos(y_coord * pi * f)
    wave = sine_part * cosine_part
    return wave.unsqueeze(-1)


def datasets(device, nx, type, f=None, *, n_train=10000, n_val=1000):
    """Create the datasets for the baseline task.

    Inputs are sampled from a standard normal distribution using a
    ``torch.Generator`` seeded with 0 and are then moved to ``device``.
    Targets are either ``sin(inputs, f)`` or independent standard normal
    noise drawn from the same generator.

    Each returned dataset is ``StackDataset(StackDataset(cx, cy, tx), ty)``
    where ``tx`` and ``ty`` are clones of ``cx`` and ``cy``.

    Args:
        device (torch.device): Device to use.
        nx (int): Number of points in the grid.
        type (str): Type of the dataset. ``"sin"`` selects the sinusoidal
            targets; any other value (e.g. ``"random"``) selects random
            targets.
        f (float): Frequency of the function. Required when ``type`` is
            ``"sin"``, ignored otherwise.
        n_train (int): Number of training samples. Defaults to 10000.
        n_val (int): Number of validation samples. Defaults to 1000.

    Returns:
        tuple(torch.utils.data.Dataset): The train and validation
            datasets.

    Raises:
        TypeError: If ``type`` is ``"sin"`` and ``f`` is ``None``.
    """
    # Fail fast: otherwise a missing frequency only surfaces as an opaque
    # tensor-times-None error inside ``sin`` after the inputs were sampled.
    if type == "sin" and f is None:
        raise TypeError('datasets() requires a frequency `f` when type == "sin"')

    gen = torch.Generator()
    gen.manual_seed(0)

    # The draw order (train inputs, val inputs, then any random targets)
    # is kept fixed so the seeded generator yields the same tensors.
    train_cx = torch.randn(n_train, nx, 2, generator=gen).to(device)
    val_cx = torch.randn(n_val, nx, 2, generator=gen).to(device)

    if type == "sin":
        train_cy = sin(train_cx, f)
        val_cy = sin(val_cx, f)
    else:
        train_cy = torch.randn(n_train, nx, 1, generator=gen).to(device)
        val_cy = torch.randn(n_val, nx, 1, generator=gen).to(device)

    # The second pair of tensors is an independent copy of the first.
    train_tx = train_cx.clone()
    train_ty = train_cy.clone()
    val_tx = val_cx.clone()
    val_ty = val_cy.clone()

    train_inputs = StackDataset(train_cx, train_cy, train_tx)
    train_dataset = StackDataset(train_inputs, train_ty)
    val_inputs = StackDataset(val_cx, val_cy, val_tx)
    val_dataset = StackDataset(val_inputs, val_ty)

    return train_dataset, val_dataset
