from typing import Iterable, Optional, Sequence

import dataclasses
import itertools

import dm_env

import numpy as np

import jax
import jax.numpy as jnp

import haiku as hk


def _as_iterable(bound, max_repeat=None):
    """Return `bound` as something that can be iterated over.

    Args:
        bound: Either an iterable of per-dimension values, or a single
            scalar value (including a 0-d numpy array).
        max_repeat: How many times a scalar `bound` is repeated. When
            `None`, the scalar is repeated indefinitely. Ignored when
            `bound` is already a proper iterable.

    Returns:
        `bound` itself when it is already iterable; a list holding
        `max_repeat` copies of the scalar when a count is given (so the
        result can be iterated more than once); otherwise an endless
        `itertools.repeat` iterator.
    """
    # A 0-d ndarray defines __iter__, so it passes the Iterable check, but
    # iterating it raises TypeError. Treat it as a scalar.
    is_zero_dim_array = isinstance(bound, np.ndarray) and bound.ndim == 0
    if isinstance(bound, Iterable) and not is_zero_dim_array:
        return bound

    if max_repeat is None:
        # itertools.repeat rejects `None` as a count; omit it to repeat forever.
        return itertools.repeat(bound)
    return [bound] * max_repeat


def uniform_centers(minimums, maximums, centers_per_dim: int):
    """Lay out RBF centers on an evenly spaced grid.

    Args:
        minimums: Per-dimension lower bounds.
        maximums: Per-dimension upper bounds.
        centers_per_dim: Number of grid points along each dimension
            (both bounds are included).

    Returns:
        An array with one row per dimension and one column per grid
        point, i.e. `centers_per_dim ** num_dims` columns.
    """
    axes = []
    for low, high in zip(minimums, maximums):
        axes.append(np.linspace(low, high, num=centers_per_dim))

    # Every grid has the same shape, so flattening each one keeps the
    # coordinates of a given center aligned across rows.
    grids = np.meshgrid(*axes)
    return np.stack([np.ravel(grid) for grid in grids], axis=0)


def _scale_values(a, center):
    """Apply a signed square root to the offsets of `a` from `center`.

    Each offset `a - center` keeps its sign while its magnitude is
    replaced by the square root of that magnitude.
    """
    offset = np.subtract(a, center)
    magnitude = np.sqrt(np.abs(offset))
    return np.sign(offset) * magnitude


def clustered_centers(minimums, maximums, cluster_center, centers_per_dim):
    """Build a grid of RBF centers that is denser around `cluster_center`.

    The bounds are first expressed as signed-square-root offsets from
    `cluster_center`, a uniform grid is laid out between those warped
    bounds, and every grid point is then multiplied by its own Euclidean
    norm before being shifted back by `cluster_center`.

    Args:
        minimums: Per-dimension lower bounds.
        maximums: Per-dimension upper bounds.
        cluster_center: Point around which the centers concentrate.
        centers_per_dim: Number of grid points along each dimension.

    Returns:
        An array with one row per dimension and one column per center,
        the same layout `uniform_centers` produces.
    """
    warped_low = _scale_values(minimums, cluster_center)
    warped_high = _scale_values(maximums, cluster_center)

    # Transpose so that each row is one center and the norm is per center.
    points = uniform_centers(warped_low, warped_high, centers_per_dim).T

    # NOTE(review): with more than one dimension this radial scaling is not
    # the exact inverse of the per-dimension signed square root, so outer
    # centers can land outside [minimums, maximums] - confirm intended.
    radii = np.linalg.norm(points, axis=-1)[..., None]
    shifted = points * radii + cluster_center

    return shifted.T


def uniform_scales(minimums, maximums, scale: float):
    """Return per-dimension RBF widths proportional to the bound ranges.

    Args:
        minimums: Per-dimension lower bounds.
        maximums: Per-dimension upper bounds.
        scale: Factor applied to each `maximum - minimum` range.

    Returns:
        The element-wise range multiplied by `scale`.
    """
    return np.subtract(maximums, minimums) * scale


@dataclasses.dataclass
class RBFEncoder:
    """Encodes inputs as radial basis function activations.

    Each output feature is `exp(-d)`, where `d` is the sum over input
    dimensions of the squared, scale-normalised offset between the
    input and one center. With `normalized` set, the activations are
    passed through a softmax over the centers instead.
    """
    # Center coordinates; the second-to-last axis is matched against the
    # last axis of the inputs (e.g. the layout uniform_centers returns).
    centers: np.ndarray
    # Per-dimension widths dividing the offsets along that same axis.
    scales: np.ndarray
    # Whether to return softmax-normalised activations.
    normalized: bool

    def apply(self, inputs):
        """Compute the RBF features for `inputs`.

        Args:
            inputs: Array whose last axis holds the input dimensions.

        Returns:
            An array with the last axis of `inputs` replaced by one
            activation per center.
        """
        # Add a trailing axis so every input broadcasts against every center.
        offsets = inputs[..., None] - self.centers
        scaled = offsets / self.scales[..., None]
        # Reduce over the input-dimension axis, leaving one value per center.
        logits = -jnp.sum(jnp.square(scaled), axis=-2)

        activation = jax.nn.softmax if self.normalized else jnp.exp
        return activation(logits)


def linear_rbf(inputs,
               env: dm_env.Environment,
               centers_per_dim: int,
               scale: float,
               normalized: bool,
               with_bias: bool,
               fixed_centers: bool = True,
               fixed_scales: bool = True,
               cluster_center: Optional[Sequence[float]] = None):
    """Linear read-out on top of RBF features of the observation space.

    Creates an `hk.Linear` (and, when centers or scales are not fixed,
    Haiku parameters), so it is meant to be called from within a
    Haiku-transformed function.

    Args:
        inputs: Observations; the last axis holds the observation dimensions.
        env: Environment whose observation spec supplies the shape and
            the minimum/maximum bounds used to place the centers.
        centers_per_dim: Number of centers along each observation dimension.
        scale: Fraction of each dimension's range used as the RBF width.
        normalized: Whether the RBF activations are softmax-normalised.
        with_bias: Whether the final linear layer has a bias term.
        fixed_centers: When False, the centers become a Haiku parameter
            named "centers", initialised to the computed grid.
        fixed_scales: When False, the scales become a Haiku parameter
            named "scales", initialised to the computed widths.
        cluster_center: Optional point around which centers are
            concentrated; when `None` or empty a uniform grid is used.

    Returns:
        The output of a single-unit linear layer applied to the RBF features.
    """
    obs_spec = env.observation_spec()
    assert len(obs_spec.shape) == 1, "Only rank 1 observations are supported."

    # Bounds may be scalars or 0-d arrays. Broadcast them to one value per
    # dimension so they can be iterated and reused by both the centers and
    # the scales computations below.
    minimums = np.broadcast_to(obs_spec.minimum, obs_spec.shape)
    maximums = np.broadcast_to(obs_spec.maximum, obs_spec.shape)

    # Explicit check: truthiness of a numpy array is ambiguous (or wrong
    # for a single zero), while an empty sequence still means "no cluster".
    use_cluster = cluster_center is not None and np.size(cluster_center) > 0
    if use_cluster:
        centers = clustered_centers(minimums, maximums, cluster_center, centers_per_dim)
    else:
        centers = uniform_centers(minimums, maximums, centers_per_dim)

    if not fixed_centers:
        # Keep the initial value under its own name so the initialiser does
        # not depend on `centers` being rebound afterwards.
        initial_centers = centers
        centers = hk.get_parameter(
            "centers", initial_centers.shape, initial_centers.dtype,
            init=lambda shape, dtype: initial_centers)

    scales = uniform_scales(minimums, maximums, scale)
    if not fixed_scales:
        initial_scales = scales
        scales = hk.get_parameter(
            "scales", initial_scales.shape, initial_scales.dtype,
            init=lambda shape, dtype: initial_scales)

    encoder = RBFEncoder(
        centers=centers,
        scales=scales,
        normalized=normalized,
    )
    features = encoder.apply(inputs)
    return hk.Linear(output_size=1, with_bias=with_bias)(features)
