"""
This module contains implementations of the computational domains.

The classes provided are: Hyperrectangle, UnitSquare, Interval,
UnitInterval and ProductDomain.

"""
import numbers

import tensorflow as tf

class Hyperrectangle():
    """
    A hyperrectangle, i.e., a Cartesian product of intervals.

    A hyperrectangle in d dimensions is specified through d intervals and
    is given as the Cartesian product of these intervals. The intervals
    are specified by passing left bounds and right bounds in any format
    that can be converted to float32 tensorflow tensors which have the
    same length after flattening.

    Parameters
    ----------
    l_bounds
        Lower bounds of the intervals. An object convertible to a tensor
        with d elements; it is flattened to shape (d,).
    r_bounds
        Upper bounds of the intervals. An object convertible to a tensor
        with d elements; it is flattened to shape (d,).

    Attributes
    ----------
    _l_bounds
        The float32 tensor of shape (d,) that stores the left interval ends.
    _r_bounds
        The float32 tensor of shape (d,) that stores the right interval ends.
    _dimension : int
        The dimension d of the hyperrectangle.

    Raises
    ------
    TypeError
        If 'l_bounds' and 'r_bounds' do not have the same length after
        being flattened to one-dimensional tensors.
    ValueError
        If 'l_bounds' is not strictly smaller than 'r_bounds' in every
        component.

    """
    def __init__(self, l_bounds, r_bounds):
        # Flatten to rank 1 so that scalars, lists and higher-rank inputs
        # are all treated as a plain list of interval ends.
        self._l_bounds = tf.reshape(
            tf.convert_to_tensor(l_bounds, dtype=tf.float32),
            shape=[-1],
            )

        self._r_bounds = tf.reshape(
            tf.convert_to_tensor(r_bounds, dtype=tf.float32),
            shape=[-1],
            )

        if len(self._l_bounds) != len(self._r_bounds):
            raise TypeError("[In constructor of Hyperrectangle]: l_bounds"
                    " and r_bounds must be of same length after flattening.")

        # Component-wise strict inequality; a NaN bound also fails here
        # because every comparison with NaN is False.
        if not tf.math.reduce_all(self._l_bounds < self._r_bounds):
            raise ValueError("[In constructor of Hyperrectangle]: The"
                    " l_bounds must be smaller than the r_bounds.")

        self._dimension = len(self._l_bounds)

    def measure(self):
        """
        Returns the measure (d-dimensional volume) of the domain.

        Returns
        -------
        tf.Tensor
            A scalar float32 tensor holding the product of the interval
            lengths.

        """
        return tf.reduce_prod(self._r_bounds - self._l_bounds)

    def random_integration_points(self, N):
        """
        Returns uniformly distributed, randomly drawn points from the domain.

        Returns `N` uniformly drawn points from the domain. The drawn points
        are of shape (N, dimension), where dimension is the attribute
        self._dimension of the domain.

        Parameters
        ----------
        N : int
            A positive integer specifying the number of randomly drawn
            points. Any integral type (e.g. numpy integers) is accepted;
            booleans are rejected.

        Returns
        -------
        x
            A float32 tensor of shape (N, self._dimension) of randomly
            drawn points; column j is uniform on [l_bounds[j], r_bounds[j]).

        Raises
        ------
        ValueError
            If N is not a positive integer.

        """
        # bool is a subclass of int, so it has to be excluded explicitly.
        if (not isinstance(N, numbers.Integral) or isinstance(N, bool)
                or N <= 0):
            raise ValueError("A positive integer N must be specified.")
        N = int(N)

        # For float dtypes minval/maxval broadcast against `shape`, so the
        # (d,) bounds give every column its own interval.
        return tf.random.uniform(
            shape=(N, self._dimension),
            minval=self._l_bounds,
            maxval=self._r_bounds,
            )


class UnitSquare(Hyperrectangle):
    """
    The unit hypercube [0,1]^d in d dimensions.

    Parameters
    ----------
    dimension : int
        The dimension d of the unit hypercube. Any positive integral value
        (e.g. a numpy integer) is accepted; booleans are rejected.

    Raises
    ------
    ValueError
        If 'dimension' is not a positive integer.

    """
    def __init__(self, dimension):
        # bool is a subclass of int, so it has to be excluded explicitly.
        if (not isinstance(dimension, numbers.Integral)
                or isinstance(dimension, bool) or dimension < 1):
            raise ValueError(
                "[In Constructor UnitSquare]: dimension needs to be a "
                "positive integer"
                )
        dimension = int(dimension)
        super().__init__(tf.zeros(dimension), tf.ones(dimension))

# TODO: Interval falls a bit out of the class logic (presumably because
# UnitInterval derives from UnitSquare rather than from Interval).
class Interval(Hyperrectangle):
    """
    A one-dimensional interval [a, b].

    Parameters
    ----------
    a
        Left end point. An object convertible to a tensor with exactly one
        element.
    b
        Right end point. An object convertible to a tensor with exactly one
        element.

    Attributes
    ----------
    _a, _b
        The end points exactly as they were passed to the constructor.

    Raises
    ------
    TypeError
        If a and b differ in length after flattening (from Hyperrectangle).
    ValueError
        If a is not smaller than b, or if a and b do not each contain
        exactly one value.

    """
    def __init__(self, a, b):
        super().__init__(a, b)
        if self._dimension != 1:
            raise ValueError(
                "[In constructor of Interval]: a and b must each contain "
                "exactly one value."
                )
        self._a = a
        self._b = b

    def deterministic_integration_points(self, N):
        """
        Returns N equi-distant integration points in [a, b].

        For N > 1 both end points are included. For N == 1 the single
        point a is returned (the linspace convention).

        Parameters
        ----------
        N : int
            The number of integration points. A positive integer.

        Returns
        -------
        A float32 tensor of shape (N,1) of equi-distant integration points.

        Raises
        ------
        ValueError
            If the argument N is not a positive integer.

        """
        if (not isinstance(N, numbers.Integral) or isinstance(N, bool)
                or N < 1):
            raise ValueError(
                "[In Interval.deterministic_integration_points]: N "
                "needs to be a positive integer."
                )
        N = int(N)

        if N == 1:
            # Avoid the 1/(N-1) division by zero.
            points = tf.zeros(1)
        else:
            points = 1./(N-1) * tf.range(start=0., limit=N, dtype=float)

        # Use the flattened float32 bounds rather than the raw constructor
        # arguments, which may be lists or other non-arithmetic objects.
        points = (self._r_bounds - self._l_bounds) * points + self._l_bounds
        return tf.reshape(points, shape=(N, 1))


class UnitInterval(UnitSquare):
    """
    The unit interval [0,1].

    """
    def __init__(self):
        super().__init__(1)

    def deterministic_integration_points(self, N):
        """
        Returns N equi-distant integration points in [0,1].

        For N > 1 both end points 0 and 1 are included. For N == 1 the
        single point 0 is returned (the linspace convention).

        Parameters
        ----------
        N : int
            The number of integration points. A positive integer.

        Returns
        -------
        A float32 tensor of shape (N,1) of equi-distant integration points.

        Raises
        ------
        ValueError
            If the argument N is not a positive integer.

        """
        if (not isinstance(N, numbers.Integral) or isinstance(N, bool)
                or N < 1):
            raise ValueError(
                "[In UnitInterval.deterministic_integration_points]: N "
                "needs to be a positive integer."
                )
        N = int(N)

        if N == 1:
            # N == 1 passes validation but would divide by zero in 1/(N-1).
            points = tf.zeros(1)
        else:
            points = 1./(N-1) * tf.range(start=0., limit=N, dtype=float)
        return tf.reshape(points, shape=(N, 1))


class ProductDomain():
    """
    The Cartesian product of a parameter domain and a physical domain.

    Points of the product domain are rows [p, x] consisting of the
    parameter coordinates p followed by the physical coordinates x.

    Parameters
    ----------
    domain_para
        The parameter domain. Needs to provide measure() and
        random_integration_points(N).
    domain_phys
        The physical domain. Needs to provide measure() and, depending on
        which sampling method is called, random_integration_points(N)
        and/or deterministic_integration_points(N).

    """
    def __init__(self, domain_para, domain_phys):
        self._domain_para = domain_para
        self._domain_phys = domain_phys

    def measure(self):
        """
        Returns the measure of the product domain.

        Returns
        -------
        The product of the measures of the parameter and physical domain.

        """
        return self._domain_para.measure() * self._domain_phys.measure()

    def rand_rand_integration_points(self, N_param, N_phys):
        """
        Returns product integration points drawing points from factors.

        Draws N_param random points from the parameter domain and N_phys
        random points from the physical domain and returns all
        N_param*N_phys combinations, ordered as described in
        generate_parametric_points.
        """
        param = self._domain_para.random_integration_points(N_param)
        phys = self._domain_phys.random_integration_points(N_phys)

        return self.generate_parametric_points(param, phys)

    def rand_det_integration_points(self, N_param, N_phys):
        """
        Returns product integration points drawing random-deterministic.

        Draws N_param random points from the parameter domain and N_phys
        deterministic points from the physical domain and returns all
        N_param*N_phys combinations, ordered as described in
        generate_parametric_points.
        """
        param = self._domain_para.random_integration_points(N_param)
        phys = self._domain_phys.deterministic_integration_points(N_phys)

        return self.generate_parametric_points(param, phys)

    def parameter_specific_integration_points(self, param, N_phys):
        """
        Returns integration points for the given parameter values.

        Combines the given parameter points with N_phys deterministic
        points of the physical domain.

        Parameters
        ----------
        param : Tensor
            Tensor of shape (N_param, d_param) of parameter points.
        N_phys : int
            Number of deterministic points taken from the physical domain.

        Returns
        -------
        A tensor of shape (N_param*N_phys, d_param+d_phys), ordered as
        described in generate_parametric_points.
        """
        phys = self._domain_phys.deterministic_integration_points(N_phys)

        return self.generate_parametric_points(param, phys)

    @staticmethod
    def generate_parametric_points(param, phys):
        """
        Given parameters and physical points, yields points from product.

        Given N_param points of shape (N_param, d_param) and N_phys points
        of shape (N_phys, d_phys) the function combines these points into
        (N_param * N_phys, d_param + d_phys) points. The way these points
        are listed is as follows:

        [[p_1, x_1],
        [p_1, x_2],
        .
        .
        .
        [p_1, x_N_phys],
        [p_2, x_1],
        .
        .
        .
        [p_N_param, x_N_phys]]

        Empty inputs (N_param == 0 or N_phys == 0) yield an empty tensor of
        shape (0, d_param + d_phys).

        Parameters
        ----------
        param : Tensor
            Tensor of shape (N_param, d_param) where N_param is number of
            parameters and d_param is the dimension of the parameter space.
        phys : Tensor
            Tensor of shape (N_phys, d_phys) where N_phys is number of
            physical points and d_phys is the dimension of the domain.
            Must have the same dtype as param.

        Returns
        -------
        A tensor of shape (N_phys*N_param, d_param+d_phys)

        Raises
        ------
        ValueError
            If the (statically known) rank of param or phys is not 2.

        """
        param = tf.convert_to_tensor(param)
        phys = tf.convert_to_tensor(phys)

        for label, points in (("param", param), ("phys", phys)):
            rank = points.shape.rank
            if rank is not None and rank != 2:
                raise ValueError(
                    "[In ProductDomain.generate_parametric_points]: "
                    f"{label} must be a rank-2 tensor, got rank {rank}."
                    )

        n_param = tf.shape(param)[0]
        n_phys = tf.shape(phys)[0]

        # Each parameter row is repeated N_phys times consecutively ...
        param_rep = tf.repeat(param, repeats=n_phys, axis=0)
        # ... while the whole block of physical points is cycled N_param
        # times, which yields the row order documented above.
        phys_rep = tf.tile(phys, multiples=[n_param, 1])

        return tf.concat((param_rep, phys_rep), axis=1)
