"""Datasets and loaders for time-series experiments."""

from __future__ import annotations

from typing import Sequence, Tuple

import torch
from torch.utils.data import DataLoader, Dataset


class MixtureDecayDataset(Dataset):
    """Synthetic dataset whose target is a noisy mixture of exponentials.

    All samples share one noise-free trace

        y[t] = sum_k coeffs[k] * lambdas[k] ** t,    t = 0, ..., T - 1

    and differ only by additive Gaussian noise scaled by ``noise_std``.
    The input of every sample is a constant sequence of ones.

    Args:
        T: Number of time steps per sample; must be positive.
        num_samples: Number of samples to generate; must be positive.
        lambdas: Base of each exponential term.
        coeffs: Weight of each exponential term; same length as ``lambdas``.
        noise_std: Factor applied to the standard-normal noise.
        seed: Seed of the private generator the noise is drawn from.
        dtype: Floating-point dtype of the stored ``inputs`` and ``targets``.
            The trace and the noise are always generated in float64 and cast
            afterwards, so the default reproduces the float64 data exactly.

    Raises:
        ValueError: If ``lambdas`` and ``coeffs`` differ in length, if ``T``
            or ``num_samples`` is not positive, or if ``dtype`` is not a
            floating-point dtype.
    """

    def __init__(
        self,
        T: int,
        num_samples: int,
        lambdas: Sequence[float],
        coeffs: Sequence[float],
        noise_std: float = 0.01,
        seed: int = 0,
        dtype: torch.dtype = torch.float64,
    ):
        super().__init__()

        if len(lambdas) != len(coeffs):
            raise ValueError("lambdas and coeffs must have the same length.")
        if T <= 0 or num_samples <= 0:
            raise ValueError("T and num_samples must be positive.")
        if not dtype.is_floating_point:
            # An integer dtype would silently truncate the decay trace.
            raise ValueError("dtype must be a floating-point dtype.")

        self.T = T
        self.num_samples = num_samples
        # Private generator: drawing the noise leaves the global RNG untouched.
        self._generator = torch.Generator().manual_seed(seed)

        lambda_tensor = torch.tensor(lambdas, dtype=torch.float64)
        coeff_tensor = torch.tensor(coeffs, dtype=torch.float64)
        t = torch.arange(T, dtype=torch.float64)

        # (K, 1) ** (1, T) broadcasts to (K, T): one row per exponential term.
        decay_terms = torch.pow(lambda_tensor.unsqueeze(1), t.unsqueeze(0))
        base_trace = (coeff_tensor.unsqueeze(1) * decay_terms).sum(dim=0)  # (T,)

        # Noise is drawn in float64 regardless of ``dtype`` so that the random
        # stream for a given seed does not depend on the requested precision.
        noise = torch.randn(
            num_samples, T, generator=self._generator, dtype=torch.float64
        ) * noise_std
        targets = base_trace.unsqueeze(0) + noise  # (num_samples, T)

        # Constant all-ones input sequence, shape (num_samples, T, 1).
        self.inputs = torch.ones(num_samples, T, 1, dtype=dtype)
        # Noisy decay traces, shape (num_samples, T, 1).
        self.targets = targets.unsqueeze(-1).to(dtype)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(input, target)`` pair at ``idx``, each of shape (T, 1)."""
        return self.inputs[idx], self.targets[idx]


def make_decay_dataloaders(
    T: int = 60,
    lambdas: Sequence[float] = (0.9, 0.5),
    coeffs: Sequence[float] = (1.0, 0.7),
    train_size: int = 2048,
    val_size: int = 512,
    batch_size: int = 32,
    noise_std: float = 0.01,
    seed: int = 0,
) -> tuple[DataLoader, DataLoader]:
    """Build train/val dataloaders for the synthetic decay mixture task.

    Both splits use the same ``T``, ``lambdas``, ``coeffs`` and ``noise_std``;
    they differ in size and in the seed used to draw their noise.

    Args:
        T: Number of time steps per sample.
        lambdas: Base of each exponential term.
        coeffs: Weight of each exponential term; same length as ``lambdas``.
        train_size: Number of training samples.
        val_size: Number of validation samples.
        batch_size: Batch size of both loaders.
        noise_std: Factor applied to the standard-normal noise.
        seed: Seeds the training noise and the training shuffle order; the
            validation noise is seeded with ``seed + 1``.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader keeps dataset order.
    """
    train_ds = MixtureDecayDataset(
        T=T,
        num_samples=train_size,
        lambdas=lambdas,
        coeffs=coeffs,
        noise_std=noise_std,
        seed=seed,
    )
    val_ds = MixtureDecayDataset(
        T=T,
        num_samples=val_size,
        lambdas=lambdas,
        coeffs=coeffs,
        noise_std=noise_std,
        seed=seed + 1,  # different seed so validation noise differs from train
    )

    # Without an explicit generator the shuffle order comes from the global
    # RNG, so two calls with the same ``seed`` would batch the data differently.
    shuffle_generator = torch.Generator().manual_seed(seed)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        generator=shuffle_generator,
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
