# -*- coding: utf-8 -*-
"""
Created on Fri Feb 19 14:11:44 2021

@author: grego
"""

import numpy as np
import math
from math import sin 
from math import cos


# Module-wide alias for math.pi, used by the trigonometric integrals below.
PI = math.pi

 
    
def recursive_integrate(poly_deg, trig_deg):
    r'''
    Return the integral on [-1, 1] of x^d * (sin(n*\pi*x) + cos(n*\pi*x)).

    By parity only one term survives. For even d the sine part integrates to
    zero, so the value is \int_{-1}^1 x^d cos(n*\pi*x) dx. For odd d the
    cosine part vanishes, so the value is \int_{-1}^1 x^d sin(n*\pi*x) dx.
    The integral is evaluated with the integration-by-parts recurrence, which
    lowers d by one at each step.

    Parameters
    ----------
    poly_deg : int
        Degree d >= 0 of the monomial x^d.
    trig_deg : int or float
        Multiplier n in the trigonometric functions. Any real n is supported.
        n == 0 is handled directly, because the integrand is then just x^d.

    Returns
    -------
    float
        Value of the integral above.

    Raises
    ------
    ValueError
        If poly_deg is negative.
    '''
    if poly_deg < 0:
        raise ValueError("poly_deg must be non-negative, got %r" % (poly_deg,))

    # n == 0: sin(0) == 0 and cos(0) == 1, so this is \int_{-1}^1 x^d dx.
    # The recurrence below would divide by zero here.
    if trig_deg == 0:
        return 2 / (poly_deg + 1) if poly_deg % 2 == 0 else 0.0

    freq = trig_deg * PI

    # Even d: \int x^d cos(a x) = [x^d sin(a x)/a]_{-1}^{1} - (d/a) \int x^{d-1} sin(a x)
    if poly_deg % 2 == 0:
        boundary = 2 * sin(freq) / freq
        # d == 0: only the boundary term 2 sin(a)/a remains. It is 0 for
        # integer n, but not for general n.
        if poly_deg == 0:
            return boundary
        return boundary - poly_deg / freq * recursive_integrate(poly_deg - 1, trig_deg)

    # Odd d: \int x^d sin(a x) = [-x^d cos(a x)/a]_{-1}^{1} + (d/a) \int x^{d-1} cos(a x)
    return (-2 * cos(freq)) / freq + poly_deg / freq * recursive_integrate(poly_deg - 1, trig_deg)
    
    

class PolynomialDist:
    '''
    Random polynomial probability distribution on [-1, 1].

    Attributes
    ----------
    pdf : numpy.poly1d
        Probability density function. It is non-negative and integrates to 1
        on [-1, 1].
    cdf : numpy.poly1d
        Cumulative distribution function, with cdf(-1) == 0 and cdf(1) == 1.
    '''

    pdf = None
    cdf = None

    def __init__(self, degree, seed=None):
        '''
        Generate a random polynomial g(x) that is a valid probability density
        on [-1, 1].

        The density is proportional to sum_i c_i * (x - o_i)^(2(i+1)) for
        i = 0 .. degree//2 - 1, with c_i, o_i ~ U(0, 1). Every term is an even
        power with a non-negative weight, so g >= 0 everywhere. An odd degree
        behaves like degree - 1. The result is stored in the fields "pdf" and
        "cdf".

        Parameters
        ----------
        degree : int
            The maximum degree of the polynomial (should be even, >= 2).
        seed : int, optional
            If given, numpy's global random seed is set to this value first.
            The default None leaves the global generator untouched.

        Raises
        ------
        ValueError
            If degree < 2, in which case no non-trivial term exists and the
            density cannot be normalized.
        '''
        if degree < 2:
            raise ValueError("degree must be at least 2, got %r" % (degree,))

        # `is not None` so that seed=0 is honoured too (0 is falsy).
        if seed is not None:
            np.random.seed(seed)

        n_terms = degree // 2
        # Draw order (coefficients, then offsets) is kept for reproducibility.
        coefficients = np.random.uniform(0, 1, n_terms)
        offsets = np.random.uniform(0, 1, n_terms)

        poly = np.poly1d([0])
        for i in range(n_terms):
            # poly1d(..., r=True) builds the monic polynomial (x - offsets[i]).
            temp_poly = np.poly1d([offsets[i]], r=True)
            poly = poly + coefficients[i] * temp_poly**(2 * (i + 1))

        # Normalize the antiderivative so that its increase over [-1, 1] is 1.
        int_poly = np.polyint(poly)
        int_poly = int_poly / (int_poly(1) - int_poly(-1))
        self.cdf = int_poly - int_poly(-1)
        self.pdf = int_poly.deriv()

    def compute_phi_int(self, length):
        r'''
        Integrate PolynomialDist.pdf on [-1, 1] against each trigonometric
        function in the vector

        [sin(1*\pi*s), cos(1*\pi*s), ..., sin(length*\pi*s), cos(length*\pi*s)]

        Even positions hold sines and odd positions hold cosines. This matches
        the basis ordering used for the Z-matrix. The constant function cos(0)
        is not included, because its integral with the pdf is always 1.

        Parameters
        ----------
        length : int
            The highest trigonometric degree in the vector to be integrated.

        Returns
        -------
        numpy.array
            Array of size 2*length holding the integral of each \phi_k with
            the pdf.
        '''
        # Ascending coefficients: poly_arr[k] multiplies x^k.
        poly_arr = np.flip(np.array(self.pdf))
        # Monomials x^1 .. x^deg. The constant x^0 integrates to 0 against
        # sin/cos(n*pi*s) for integer n, so it is skipped.
        n_monomials = poly_arr.shape[0] - 1

        trig_int_arr = np.zeros((2 * length, n_monomials))
        for i in range(2 * length):
            freq = i // 2 + 1
            for j in range(n_monomials):
                # Row i is sin (i even) or cos (i odd). Monomial x^(j+1) is
                # odd when j is even. Sine times an even monomial, or cosine
                # times an odd one, integrates to zero, so those entries
                # stay 0.
                if i % 2 == j % 2:
                    trig_int_arr[i, j] = recursive_integrate(j + 1, freq)

        return trig_int_arr @ poly_arr[1:]

    def sample(self, precision=8):
        '''
        Draw one sample by inverse transform sampling.

        A uniform u is drawn, and cdf(x) = u is solved on [-1, 1] by bisection.
        This works because cdf is non-decreasing there.

        Parameters
        ----------
        precision : int, optional
            Number of bisection iterations. The final bracket has width
            2 / 2**precision. The default is 8.

        Returns
        -------
        float
            Midpoint of the final bracket.
        '''
        rand_u = np.random.uniform(0, 1)
        left_lim = -1
        right_lim = 1
        for _ in range(precision):
            avg = (left_lim + right_lim) / 2
            if self.cdf(avg) < rand_u:
                left_lim = avg
            else:
                right_lim = avg

        return (left_lim + right_lim) / 2
            
        
        
    
class MixtureModel:
    '''
    State-dependent mixture of two random PolynomialDist densities on [-1, 1].

    Conditioned on a state s, the next value is drawn from ``poly_1`` with
    probability s**2, and from ``poly_2`` with probability 1 - s**2.
    '''

    w1_func = None
    w2_func = None
    poly_1 = None
    poly_2 = None

    def __init__(self, w_func=lambda s: s**2, seed=None, degree=4):
        '''
        Construct a probability distribution as a weighted mixture of two
        polynomial distributions given by PolynomialDist objects. This is
        described in more detail in Appendix C of the paper.

        Parameters
        ----------
        w_func : function(float) -> (0,1), optional
            Do not change this from the default value. It is kept only for
            compatibility with the written code. It was originally intended
            to allow arbitrary mixtures, but compute_Zmat and sample both
            hard-code s**2, so any other value breaks the exact Z-matrix.
            Not tested.

        seed : int, optional
            If given, numpy's global random seed is set to this value before
            the two component PolynomialDist objects are generated. The
            default is None.

        degree : int, optional
            Polynomial degree passed to both PolynomialDist components.
            The default is 4.

        Returns
        -------
        None.
        '''
        # `is not None` so that seed=0 is honoured too (0 is falsy).
        if seed is not None:
            np.random.seed(seed)

        self.w1_func = w_func
        self.w2_func = lambda s: 1 - w_func(s)

        self.poly_1 = PolynomialDist(degree)
        self.poly_2 = PolynomialDist(degree)

    def compute_Zmat(self, size):
        r'''
        Return the exact truncated Z-matrix.

        Computes the coefficients of the truncated Z-matrix. Each component
        pdf is integrated against every entry of the truncated basis vector
        \phi_k on [-1, 1], and the outer products with the integrated mixture
        weights are summed. The basis order is
        [1, sin(\pi s), cos(\pi s), ..., sin(size \pi s), cos(size \pi s)].

        Parameters
        ----------
        size : int
            Maximum degree of the trigonometric functions in the series.

        Returns
        -------
        numpy.array
            Exact Z-matrix of dimension (2*size+1)x(2*size+1).
        '''
        dim = 2 * size + 1
        Zmat = np.zeros((dim, dim))
        # NOTE(review): constant/constant entry as in the original; its
        # normalization depends on the paper's basis convention — confirm.
        Zmat[0, 0] = 1/2

        # Integrals of the hard-coded weight w1(s) = s^2 against the basis.
        # Sine entries (odd indices) stay 0 because s^2 is even.
        w1_phi = np.zeros((dim, 1))
        w1_phi[0] = 2/3
        for j in range(2, dim, 2):
            w1_phi[j] = recursive_integrate(2, j // 2)

        # w2(s) = 1 - s^2 has constant term 2 - 2/3 = 4/3. The trigonometric
        # integrals of the constant 1 vanish for integer n, which leaves -w1.
        w2_phi = -w1_phi
        w2_phi[0] = 4/3

        Zmat[1:, :] += (np.outer(self.poly_1.compute_phi_int(size), w1_phi)
                        + np.outer(self.poly_2.compute_phi_int(size), w2_phi))

        return Zmat

    def sample(self, start_s, precision=8):
        '''
        Sample from the conditional distribution given s_0 = start_s.

        poly_1 is chosen with probability start_s**2, otherwise poly_2. The
        chosen component is then sampled by inverse transform sampling. The
        'precision' parameter is passed on to PolynomialDist.sample.
        Presumably |start_s| <= 1 is expected: np.random.binomial rejects
        probabilities outside [0, 1].
        '''
        # Bernoulli draw: 1 with probability 1 - start_s**2 (the poly_2 weight).
        # Hard-coded to match compute_Zmat rather than using self.w2_func.
        poly_dist = np.random.binomial(1, 1 - start_s**2)

        if poly_dist == 0:
            return self.poly_1.sample(precision)
        return self.poly_2.sample(precision)
            
                
    
   
    
    
    
    
    
    
    
    
    
    
    
    