"""
This module contains the integration methods.

"""

import tensorflow as tf
import functools

class RandomIntegrator():
    """
    Random integration using Monte Carlo with uniform distribution.

    Given a domain and a number N of integration points this class
    performs Monte Carlo integration over the domain with uniformly
    sampled integration points. A fresh set of N points is requested
    from the domain on every call.

    Parameters
    ----------
    domain
        A domain instance. Needs to provide the methods
        domain.random_integration_points(<int>) and
        domain.measure().
    N : int
        Number of integration points.

    Raises
    ------
    TypeError
        If N is not exactly of type int (int subclasses such as bool
        are rejected as well).
    ValueError
        If N is an int smaller than 1.
    """
    def __init__(self, domain, N):
        # Exact type comparison: subclasses of int (e.g. bool) fail here.
        if type(N) is not int:
            raise TypeError(
                '[Constructor of RandomIntegrator:] N must be positive'
                f' integer, not of type {type(N)}'
            )
        if N < 1:
            raise ValueError(
                '[Constructor of RandomIntegrator:] N must be positive.'
            )
        self._domain = domain
        self._N = N
        # The measure is queried once here and reused by every call.
        self._measure = domain.measure()

    def __call__(self, R):
        """
        Estimate the integral of R over the domain.

        Draws N new points via domain.random_integration_points(N) and
        returns measure * mean(R(points)).

        Parameters
        ----------
        R : function
            Must map (K,dim) to (K,1) tensors, where K and dim must
            be any positive integers.

        Returns
        -------
        The result of the integration. A tensor of shape ().
        """
        x = self._domain.random_integration_points(self._N)
        return self._measure * tf.math.reduce_mean(R(x))
        

class RandomDeterministicIntegrator():
    """
    Integrator that is random in the parameters and deterministic in the
    physical domain.

    The integration points are requested once, at construction time, from
    prod_domain.rand_det_integration_points(N_param=..., N_phys=...). The
    same points are reused by every call. The domain is expected to
    return N_param * N_phys points.

    Parameters
    ----------
    prod_domain
        A ProductDomain instance. Needs to provide the methods
        rand_det_integration_points(N_param=<int>, N_phys=<int>) and
        measure().
    N_param : int
        Number of parameter integration points.
    N_phys : int
        Number of physical integration points.

    Raises
    ------
    ValueError
        If N_param or N_phys is not exactly of type int, or is smaller
        than 1. N_param is checked first.
    """
    def __init__(self, prod_domain, N_param, N_phys):
        # (value, error message) pairs, validated in this order.
        requirements = (
            (N_param,
             '[Constructor RandomDeterministicIntegrator:] N_param '
             'needs to be a positive integer.'),
            (N_phys,
             '[Constructor RandomDeterministicIntegrator:] N_phys '
             'needs to be a positive integer'),
        )
        for count, message in requirements:
            if type(count) is not int or count < 1:
                raise ValueError(message)

        points = prod_domain.rand_det_integration_points(
            N_param=N_param,
            N_phys=N_phys,
        )
        self._evaluation_points = points
        self._measure = prod_domain.measure()

    def __call__(self, R):
        """
        Estimate the integral of R using the stored integration points.

        Parameters
        ----------
        R : function
            Maps (K,dim) to (K,1) tensors, where K and dim must
            be any positive integers.

        Returns
        -------
        measure * mean(R(points)). A tensor of shape ().
        """
        integrand_values = R(self._evaluation_points)
        average = tf.math.reduce_mean(integrand_values)
        return self._measure * average


class RandomRandomIntegrator():
    """
    Random integration in parameters, random in physical domain.

    Works like RandomDeterministicIntegrator, but the points in both
    physical and parameter space are drawn randomly by the domain.

    Note that the points are drawn only once, in the constructor, and
    the same points are reused by every call. Repeated calls with the
    same deterministic R therefore return the same value.

    Parameters
    ----------
    prod_domain
        A ProductDomain instance. Needs to provide the methods
        rand_rand_integration_points(N_param=<int>, N_phys=<int>) and
        measure().
    N_param : int
        Number of parameter integration points.
    N_phys : int
        Number of physical integration points.

    Raises
    ------
    ValueError
        If N_param or N_phys is not exactly of type int, or is smaller
        than 1. N_param is checked first.
    """
    def __init__(self, prod_domain, N_param, N_phys):
        if not type(N_param) == int or N_param < 1:
            raise ValueError(
                '[Constructor RandomRandomIntegrator:] N_param '
                'needs to be a positive integer'
            )

        if not type(N_phys) == int or N_phys < 1:
            raise ValueError(
                '[Constructor RandomRandomIntegrator:] N_phys '
                'needs to be a positive integer'
            )

        # Drawn once here; __call__ does not resample.
        self._evaluation_points = prod_domain.rand_rand_integration_points(
            N_param=N_param,
            N_phys=N_phys
            )

        self._measure = prod_domain.measure()

    def __call__(self, R):
        """
        Estimate the integral of R using the points drawn at construction.

        Parameters
        ----------
        R : function
            Maps (K,dim) to (K,1) tensors, where K and dim must
            be any positive integers.

        Returns
        -------
        measure * mean(R(points)). A tensor of shape ().
        """
        x = self._evaluation_points
        return self._measure * tf.math.reduce_mean(R(x))


def accuracy(precision = 1e-3, max_iter = 5000):
    """
    Repeats a function (with stochasticity) in output for more precision.

    This is a decorator factory meant for functions whose result is a
    stochastic scalar, e.g. a Monte Carlo estimate of an expectation.
    The decorated function is evaluated repeatedly, and the cumulative
    mean of all evaluations is returned.

    After each new evaluation, the absolute change of the cumulative mean
    is recorded. Evaluation stops as soon as the average of all recorded
    changes is no longer greater than ``precision``, or once ``max_iter``
    evaluations have been made. The function is always evaluated at least
    once, and at least twice when ``max_iter >= 2``.

    Parameters
    ----------
    precision : float
        Threshold for the average increment of the cumulative means.
    max_iter : int
        Upper limit for the number of function evaluations.

    Returns
    -------
    A decorator. The decorated function accepts the same arguments as the
    original and returns the cumulative mean of its evaluations. Only
    arithmetic operators and abs() are applied to the results, so scalar
    tensors stay tensors.
    """
    def decorator_accuracy(integrator):
        @functools.wraps(integrator)
        def wrapper_accuracy(*args, **kwargs):
            # Running sums keep every iteration O(1). Re-stacking all past
            # values into a new tensor on each step made the loop
            # quadratic in max_iter.
            first = integrator(*args, **kwargs)
            total = first          # sum of all evaluations so far
            mean = first           # current cumulative mean
            n_values = 1
            increment_total = 0.   # sum of |mean_k - mean_{k-1}|
            # inf (rather than a finite sentinel) guarantees the loop is
            # entered for any precision value.
            mean_increment = float('inf')

            while mean_increment > precision and n_values < max_iter:
                # Rebind instead of `+=` so that an in-place add (e.g. on
                # numpy arrays) cannot alias `first`/`mean`.
                total = total + integrator(*args, **kwargs)
                n_values += 1
                new_mean = total / n_values
                increment_total = increment_total + abs(new_mean - mean)
                mean = new_mean
                mean_increment = increment_total / (n_values - 1)

            return mean
        return wrapper_accuracy
    return decorator_accuracy