from typing import Optional, Tuple

import torch
from torch.utils.data import Dataset

# Short alias for torch.Tensor, used in the type hints below.
T = torch.Tensor


def gapped_sine(x: torch.Tensor, noise: bool = True, *, noise_std: float = 0.03) -> torch.Tensor:
    """Evaluate the gapped-sine target function elementwise on ``x``.

    Computes ``x + e1 + sin(4 * (x + e2)) + sin(13 * (x + e3))``, where each
    ``e_i`` is an independent tensor shaped like ``x``. When ``noise`` is True
    the ``e_i`` are drawn with ``torch.randn_like`` and scaled by
    ``noise_std``; otherwise they are all zeros.

    Args:
        x: Input tensor of any shape.
        noise: Whether to perturb the output and both sine arguments.
        noise_std: Scale applied to each standard-normal noise draw
            (default 0.03).

    Returns:
        A tensor with the same shape as ``x``.
    """
    if noise:
        # Draw order (eps, eps2, eps3) is kept fixed so that results under a
        # given torch seed are reproducible.
        eps = torch.randn_like(x) * noise_std
        eps2 = torch.randn_like(x) * noise_std
        eps3 = torch.randn_like(x) * noise_std
    else:
        # The noise terms are only read, never mutated, so one zero tensor can be shared.
        eps = eps2 = eps3 = torch.zeros_like(x)

    return x + eps + torch.sin(4 * (x + eps2)) + torch.sin(13 * (x + eps3))


class GappedSine(Dataset):
    """Toy 1-D regression dataset whose targets come from ``gapped_sine``.

    The training split samples inputs from two disjoint intervals, [0, 0.4)
    and [0.8, 1.0), leaving a gap between them, and adds noise to the targets.
    The test split samples 400 inputs uniformly over [-0.5, 1.5) with
    noise-free targets, so it covers the gap and extends past both ends.

    Attributes:
        name: Dataset identifier, also returned by ``str()``.
        x: Inputs of shape ``(N, 1)``.
        y: Targets of shape ``(N,)``.
        mu, sigma, y_mu, y_sigma: Normalization moments; only set once
            ``normalize`` has been called.
    """

    def __init__(self, device: torch.device, test: bool = False) -> None:
        """Sample the requested split and place it on ``device``.

        Args:
            device: Device that ``x`` and ``y`` are moved to.
            test: If True, build the 400-point noise-free test split;
                otherwise build the 200-point noisy training split.
        """
        self.name = "gapped-sine"
        if not test:
            # Each half is sorted so that contiguous index ranges correspond to
            # contiguous x regions (e.g. the lowest 25 points can be treated as
            # one cluster).
            x_lower, _ = torch.sort(torch.rand(50 * 2) * 0.4)
            x_upper, _ = torch.sort(torch.rand(100) * (1.0 - 0.8) + 0.8)

            self.x = torch.cat((x_lower, x_upper)).unsqueeze(1).to(device)
            self.y = gapped_sine(self.x).squeeze(1).to(device)
        else:
            self.x = torch.rand(400).unsqueeze(1).to(device) * 2 - 0.5
            self.y = gapped_sine(self.x, noise=False).squeeze(1).to(device)

    def normalize(self, params: Optional[Tuple[T, ...]] = None) -> Tuple[T, T, T, T]:
        """Standardize ``x`` and ``y`` and return the moments that were used.

        Args:
            params: Optional ``(mu, sigma, y_mu, y_sigma)`` to apply instead of
                computing them here, for example the tuple returned by
                ``normalize`` on another split. If None, ``mu``/``sigma`` are
                the per-column mean/std of ``x`` and ``y_mu``/``y_sigma`` the
                mean/std of ``y``.

        Returns:
            The ``(mu, sigma, y_mu, y_sigma)`` tuple that was applied; it is
            also stored on the instance.

        Raises:
            ValueError: If ``params`` is given but does not have 4 entries.

        Note:
            Calling this again normalizes the already-normalized data.
        """
        if params is None:
            # Input moments come from x and target moments from y.
            mu, sigma = self.x.mean(dim=0), self.x.std(dim=0)
            y_mu, y_sigma = self.y.mean(dim=0), self.y.std(dim=0)
        else:
            if len(params) != 4:
                raise ValueError(f"params must be of length 4. got: {len(params)=}")
            mu, sigma, y_mu, y_sigma = params

        self.mu = mu
        self.sigma = sigma
        self.y_mu = y_mu
        self.y_sigma = y_sigma
        self.x = (self.x - mu) / sigma
        self.y = (self.y - y_mu) / y_sigma
        return self.mu, self.sigma, self.y_mu, self.y_sigma

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.x.size(0)

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """Return the ``(x, y)`` pair at index ``i``."""
        return self.x[i], self.y[i]

    def get_feature_ranges(self) -> torch.Tensor:
        """Return ``max - min`` of ``x`` for each input column."""
        return self.x.max(dim=0)[0] - self.x.min(dim=0)[0]

    def get_y_moments(self) -> Tuple[T, T]:
        """Return ``(y_mu, y_sigma)`` on the CPU; requires ``normalize`` first."""
        return self.y_mu.cpu(), self.y_sigma.cpu()

    def get_x_moments(self) -> Tuple[T, T]:
        """Return ``(mu, sigma)`` on the CPU; requires ``normalize`` first."""
        return self.mu.cpu(), self.sigma.cpu()

    def __str__(self) -> str:
        return self.name
