from abc import ABC, abstractmethod
from functools import partial

import jax.numpy as jnp
from jax import random, pmap, local_device_count

from torch.utils.data import Dataset


class BaseSampler(Dataset):
    """
    Common base for samplers that hand out one batch per indexing operation.

    The sampler stores a PRNG key, a batch size and the number of local
    devices. Subclasses supply the actual sampling logic by overriding
    `data_generation`.
    """

    def __init__(self, batch_size, rng_key=random.PRNGKey(1234)):
        # Number of PRNG keys produced for every batch request.
        self.num_devices = local_device_count()
        # Sampler state: replaced by a fresh key on every `__getitem__` call.
        self.key = rng_key
        # Stored for use by subclasses; not read by the base class itself.
        self.batch_size = batch_size

    def __getitem__(self, index):
        """
        Generate one batch of data.

        `index` is not used: every call advances the stored key and draws a
        new batch, whatever index is requested.
        """
        next_key, batch_key = random.split(self.key)
        # The state is advanced before generation, so it moves forward even
        # if `data_generation` raises.
        self.key = next_key
        device_keys = random.split(batch_key, self.num_devices)
        return self.data_generation(device_keys)

    def data_generation(self, key):
        """Produce a batch from `key`; subclasses must override this."""
        raise NotImplementedError("Subclasses should implement this!")


class UniformSampler(BaseSampler):
    """
    Draws points uniformly at random from an axis-aligned box.

    `dom` is indexed as `dom[:, 0]` (lower bounds) and `dom[:, 1]` (upper
    bounds); its leading dimension gives the number of coordinates per sample.
    """

    def __init__(self, dom, batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.dim = dom.shape[0]
        self.dom = dom

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """
        Generates data containing batch_size samples.

        The method is pmapped over the leading axis of `key` (with `self`
        treated as static), so inside the body `key` is a single PRNG key
        and the returned per-device array has shape (batch_size, dim).
        """
        lower = self.dom[:, 0]
        upper = self.dom[:, 1]
        sample_shape = (self.batch_size, self.dim)
        return random.uniform(key, shape=sample_shape, minval=lower, maxval=upper)
    
class MeshSampler(BaseSampler):
    """
    Returns the points of an equidistant mesh grid spanning the given domain.

    Along axis `d` the grid has `res[d]` evenly spaced values between
    `dom[d, 0]` and `dom[d, 1]`. The flattened grid is also stored in
    `self.coords` at construction time.
    """

    def __init__(self, dom, res, batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.dom = dom
        self.res = res
        self.dim = dom.shape[0]

        # Full grid as an array with one row per grid point.
        self.coords = self._grid_points()

    def _grid_points(self):
        """Build the flattened mesh grid, one row per point, one column per axis."""
        axes = []
        for d in range(self.dim):
            lo, hi = self.dom[d, 0], self.dom[d, 1]
            axes.append(jnp.linspace(lo, hi, self.res[d]))
        # 'ij' indexing keeps the grid axes in the same order as the domain axes.
        mesh = jnp.meshgrid(*axes, indexing='ij')
        columns = [jnp.ravel(component) for component in mesh]
        return jnp.stack(columns, axis=-1)

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """
        Return every grid point.

        The grid is rebuilt inside the pmapped function rather than read from
        `self.coords`. Neither the value of `key` nor `batch_size` affects the
        result; `key` only serves as the argument pmap maps over.
        """
        return self._grid_points()


class SphereSampler(BaseSampler):

    """
    Generates uniform samples on the surface of the 3D unit sphere centered
    at the origin, returned as Cartesian (x, y, z) coordinates.

    If `temporal_dom` is not None, a time coordinate drawn uniformly between
    `temporal_dom[0]` and `temporal_dom[1]` is prepended, so every sample is
    laid out as (t, x, y, z). No conversion to spherical angles is performed.
    """

    def __init__(self, temporal_dom, batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        # Either None (no time coordinate) or an indexable pair (t_min, t_max).
        self.temporal_dom = temporal_dom

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """
        Generates data containing batch_size samples.

        The method is pmapped over the leading axis of `key` (with `self`
        treated as static), so inside the body `key` is a single PRNG key.

        Returns a per-device array of shape (batch_size, 3) holding unit
        vectors, or (batch_size, 4) with the time coordinate in column 0 when
        `temporal_dom` is set.
        """
        if self.temporal_dom is None:
            # Only one draw is needed, so the key can be consumed directly.
            xyz_key = key
        else:
            # Derive independent keys up front: a PRNG key must not be both
            # consumed by a sampler and split afterwards.
            xyz_key, t_key = random.split(key)

        # Normalizing isotropic Gaussian vectors yields directions that are
        # uniformly distributed on the unit sphere.
        batch = random.normal(xyz_key, shape=(self.batch_size, 3))
        norm = jnp.linalg.norm(batch, axis=1, keepdims=True)
        batch = batch / norm

        if self.temporal_dom is not None:
            t_batch = random.uniform(
                t_key,
                shape=(self.batch_size, 1),
                minval=self.temporal_dom[0],
                maxval=self.temporal_dom[1],
            )

            batch = jnp.concatenate([t_batch, batch], axis=1)

        return batch


class SpaceSampler(BaseSampler):
    """
    Draws random rows from a fixed array of candidate points.

    `coords` is indexed as `coords[idx, :]`, i.e. one candidate point per row.
    """

    def __init__(self, coords, batch_size, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.coords = coords

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """
        Generates data containing batch_size samples.

        The method is pmapped over the leading axis of `key` (with `self`
        treated as static). Row indices are drawn with `random.choice` using
        its default settings, then used to gather the matching rows.
        """
        num_points = self.coords.shape[0]
        rows = random.choice(key, num_points, shape=(self.batch_size,))
        return self.coords[rows, :]


class TimeSpaceSampler(BaseSampler):
    """
    Pairs a uniformly drawn time with a randomly chosen spatial point.

    Times are drawn between `temporal_dom[0]` and `temporal_dom[1]`; spatial
    points are rows picked from `spatial_coords`.
    """

    def __init__(
        self, temporal_dom, spatial_coords, batch_size, rng_key=random.PRNGKey(1234)
    ):
        super().__init__(batch_size, rng_key)

        self.spatial_coords = spatial_coords
        self.temporal_dom = temporal_dom

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        """
        Generates data containing batch_size samples.

        The method is pmapped over the leading axis of `key` (with `self`
        treated as static). Each returned row holds the time in column 0
        followed by the coordinates of the selected spatial point.
        """
        # Separate keys keep the temporal and spatial draws independent.
        time_key, space_key = random.split(key)

        t_min = self.temporal_dom[0]
        t_max = self.temporal_dom[1]
        times = random.uniform(
            time_key, shape=(self.batch_size, 1), minval=t_min, maxval=t_max
        )

        num_points = self.spatial_coords.shape[0]
        rows = random.choice(space_key, num_points, shape=(self.batch_size,))
        points = self.spatial_coords[rows, :]

        return jnp.concatenate([times, points], axis=1)
