from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, Sequence, Union

import torch
from torch.utils.data import Dataset


# Alias used for tensor-valued annotations. A ``Union`` with a single member
# collapses to that member, so the alias is simply ``torch.Tensor``.
TensorLike = torch.Tensor


class BaseSampler(Dataset, ABC):
    """Abstract, effectively unbounded dataset whose items are freshly generated batches.

    Indexing ignores the index and returns a new batch from
    :meth:`data_generation`, which subclasses must implement. ``__len__``
    reports ``10**12`` so consumers treat the dataset as practically endless.

    Args:
        batch_size: Number of samples per generated batch (stored as ``int``).
        seed: Seed for the sampler's private ``torch.Generator``.
        device: Device on which batches are produced. Accepts a
            ``torch.device`` or anything ``torch.device(...)`` accepts
            (e.g. ``"cpu"``, ``"cuda:0"``). Defaults to CPU.
    """

    def __init__(self, batch_size: int, *, seed: int = 1234, device: Optional[torch.device] = None):
        self.batch_size = int(batch_size)
        # Normalise strings such as "cuda" into a torch.device so the attribute
        # always has a consistent type.
        self.device = torch.device(device) if device is not None else torch.device("cpu")
        # The generator must live on the same device as the tensors drawn with
        # it: torch.rand/randn/randint reject a CPU generator when allocating on
        # a non-CPU device, which would break every subclass off-CPU.
        self._gen = torch.Generator(device=self.device)
        self._gen.manual_seed(int(seed))

    def __len__(self) -> int:
        # Nominal length only; the data is generated on demand.
        return 10**12

    def __getitem__(self, index):
        # ``index`` is intentionally ignored: every access draws a new batch.
        return self.data_generation()

    @abstractmethod
    def data_generation(self) -> torch.Tensor:
        """Produce one batch of samples on ``self.device``."""
        raise NotImplementedError


class UniformSampler(BaseSampler):
    """Draw independent points uniformly from an axis-aligned box.

    Args:
        dom: Tensor of shape ``[dim, 2]``; row ``d`` holds the
            ``(low, high)`` bounds of coordinate ``d``.
        batch_size: Number of points per batch.
        seed: Seed forwarded to :class:`BaseSampler`.
        device: Target device forwarded to :class:`BaseSampler`.

    Raises:
        ValueError: If ``dom`` is not two-dimensional with exactly two columns.
    """

    def __init__(self, dom: torch.Tensor, batch_size: int, *, seed: int = 1234, device: Optional[torch.device] = None):
        super().__init__(batch_size, seed=seed, device=device)
        bounds = dom.to(self.device)
        well_formed = bounds.ndim == 2 and bounds.shape[1] == 2
        if not well_formed:
            raise ValueError("dom must be shape [dim, 2]")
        self.dom = bounds
        self.dim = int(bounds.shape[0])

    def data_generation(self) -> torch.Tensor:
        """Return a ``[batch_size, dim]`` batch, coordinate ``d`` uniform between its bounds."""
        low, high = self.dom.unbind(dim=1)
        span = high - low
        unit = torch.rand((self.batch_size, self.dim), generator=self._gen, device=self.device)
        # Affinely map samples from the unit cube onto the box.
        return low + span * unit


class MeshSampler(BaseSampler):
    """Deterministic "sampler" that always returns the full tensor-product grid.

    Axis ``d`` gets ``res[d]`` evenly spaced nodes from ``dom[d, 0]`` to
    ``dom[d, 1]`` inclusive (a single node at ``dom[d, 0]`` when
    ``res[d] == 1``). The grid is built with ``indexing="ij"`` and flattened,
    so the last axis varies fastest, giving shape ``[prod(res), dim]``.
    ``batch_size`` and ``seed`` are accepted for interface parity with the
    other samplers but do not affect the output.

    Raises:
        ValueError: If ``dom`` is not ``[dim, 2]``, if ``len(res) != dim``, or
            if any entry of ``res`` is smaller than 1.
    """

    def __init__(
        self,
        dom: torch.Tensor,
        res: Sequence[int],
        batch_size: int,
        *,
        seed: int = 1234,
        device: Optional[torch.device] = None,
    ):
        super().__init__(batch_size, seed=seed, device=device)
        dom = dom.to(self.device)
        if dom.ndim != 2 or dom.shape[1] != 2:
            raise ValueError("dom must be shape [dim, 2]")
        self.dom = dom
        self.dim = int(dom.shape[0])
        self.res = [int(r) for r in res]
        if len(self.res) != self.dim:
            raise ValueError("len(res) must equal dom.shape[0]")
        if any(r < 1 for r in self.res):
            # A zero-node axis would silently produce an empty grid (and thus
            # empty batches forever); a negative one fails inside linspace.
            raise ValueError("every entry of res must be >= 1")

        # Read the bounds back as Python numbers once instead of passing
        # 0-dim tensors to linspace on every axis.
        bounds = self.dom.tolist()
        grids = [
            torch.linspace(lo, hi, n, device=self.device)
            for (lo, hi), n in zip(bounds, self.res)
        ]
        mesh = torch.meshgrid(*grids, indexing="ij")
        self.coords = torch.stack([m.reshape(-1) for m in mesh], dim=-1)

    def data_generation(self) -> torch.Tensor:
        """Return the precomputed grid.

        The same tensor object is returned on every call, so in-place edits by
        a caller would alter later batches.
        """
        return self.coords


class SphereSampler(BaseSampler):
    """Sample points uniformly on the unit sphere in R^3, optionally with a time column.

    Directions are obtained by normalising standard-normal 3-vectors. When
    ``temporal_dom`` is given, a time value drawn uniformly between
    ``temporal_dom[0]`` and ``temporal_dom[1]`` is prepended, so each row is
    ``[t, x, y, z]``; otherwise rows are ``[x, y, z]``. ``temporal_dom`` is
    presumably a 2-element tensor ``(t_start, t_end)`` -- only its first two
    entries are read.
    """

    def __init__(
        self,
        temporal_dom: Optional[torch.Tensor],
        batch_size: int,
        *,
        seed: int = 1234,
        device: Optional[torch.device] = None,
    ):
        super().__init__(batch_size, seed=seed, device=device)
        if temporal_dom is None:
            self.temporal_dom = None
        else:
            self.temporal_dom = temporal_dom.to(self.device)

    def data_generation(self) -> torch.Tensor:
        gauss = torch.randn((self.batch_size, 3), generator=self._gen, device=self.device)
        # The tiny epsilon keeps the division finite for an all-zero draw.
        radius = torch.linalg.norm(gauss, dim=1, keepdim=True) + 1e-12
        points = gauss / radius

        if self.temporal_dom is not None:
            t_start = self.temporal_dom[0]
            t_end = self.temporal_dom[1]
            unit = torch.rand((self.batch_size, 1), generator=self._gen, device=self.device)
            times = t_start + (t_end - t_start) * unit
            return torch.cat([times, points], dim=1)

        return points


class SpaceSampler(BaseSampler):
    """Sample rows uniformly, with replacement, from a fixed set of points.

    Args:
        coords: Candidate points, shape ``[N, ...]``. A 1-D tensor is treated
            as ``N`` points of a one-dimensional domain and stored as
            ``[N, 1]``.
        batch_size: Number of rows returned per batch.
        seed: Seed forwarded to :class:`BaseSampler`.
        device: Target device forwarded to :class:`BaseSampler`.

    Raises:
        ValueError: If ``coords`` is 0-dimensional or contains no points.
    """

    def __init__(self, coords: torch.Tensor, batch_size: int, *, seed: int = 1234, device: Optional[torch.device] = None):
        super().__init__(batch_size, seed=seed, device=device)
        coords = coords.to(self.device)
        # Fail at construction with a clear message rather than inside
        # torch.randint(0, 0, ...) on the first draw.
        if coords.ndim == 0 or coords.shape[0] == 0:
            raise ValueError("coords must contain at least one point")
        if coords.ndim == 1:
            # Give a flat vector a feature axis so batches are [batch_size, 1].
            coords = coords.unsqueeze(1)
        self.coords = coords

    def data_generation(self) -> torch.Tensor:
        """Return ``batch_size`` rows of ``coords`` chosen uniformly with replacement."""
        n = self.coords.shape[0]
        idx = torch.randint(0, n, (self.batch_size,), generator=self._gen, device=self.device)
        return self.coords[idx]


class TimeSpaceSampler(BaseSampler):
    """Pair a uniform random time with a randomly chosen spatial point.

    Each row of a batch is ``[t, *x]``. ``t`` is drawn uniformly between
    ``temporal_dom[0]`` and ``temporal_dom[1]``. ``x`` is a row of
    ``spatial_coords`` picked uniformly with replacement.

    Args:
        temporal_dom: Time bounds; only the first two entries are read
            (presumably a 2-element tensor ``(t_start, t_end)``).
        spatial_coords: Candidate points of shape ``[N, d]``. A 1-D tensor is
            treated as ``N`` points of a one-dimensional domain (``[N, 1]``).
        batch_size: Number of rows per batch.
        seed: Seed forwarded to :class:`BaseSampler`.
        device: Target device forwarded to :class:`BaseSampler`.

    Raises:
        ValueError: If ``spatial_coords`` is 0-dimensional or contains no points.
    """

    def __init__(
        self,
        temporal_dom: torch.Tensor,
        spatial_coords: torch.Tensor,
        batch_size: int,
        *,
        seed: int = 1234,
        device: Optional[torch.device] = None,
    ):
        super().__init__(batch_size, seed=seed, device=device)
        self.temporal_dom = temporal_dom.to(self.device)
        spatial_coords = spatial_coords.to(self.device)
        # Fail at construction with a clear message rather than inside
        # torch.randint(0, 0, ...) on the first draw.
        if spatial_coords.ndim == 0 or spatial_coords.shape[0] == 0:
            raise ValueError("spatial_coords must contain at least one point")
        if spatial_coords.ndim == 1:
            # A flat vector needs a feature axis to be concatenated next to the
            # [batch_size, 1] time column.
            spatial_coords = spatial_coords.unsqueeze(1)
        self.spatial_coords = spatial_coords

    def data_generation(self) -> torch.Tensor:
        # Time is drawn before the spatial indices; keep this order so seeded
        # streams stay reproducible.
        lo, hi = self.temporal_dom[0], self.temporal_dom[1]
        t = lo + (hi - lo) * torch.rand((self.batch_size, 1), generator=self._gen, device=self.device)

        n = self.spatial_coords.shape[0]
        idx = torch.randint(0, n, (self.batch_size,), generator=self._gen, device=self.device)
        x = self.spatial_coords[idx, :]

        return torch.cat([t, x], dim=1)
