# -*- coding: utf-8 -*-
"""
Created on Sat Feb 20 16:48:12 2021

@author: grego
"""

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from math import sin, cos, log
import multiprocessing

from polynomial_dist import PolynomialDist, MixtureModel

# Pi truncated to 12 significant digits; phi_vec uses it as the base angular
# frequency of the trigonometric basis.
# NOTE(review): math.pi / np.pi carry full double precision -- confirm whether
# the truncated value is required for consistency with polynomial_dist.
PI = 3.14159265359


def phi_vec(s, length):
    r'''
    Build the truncation of '\phi', a vector of the form:
        [cos(0), sin(1*\pi*s), cos(1*\pi*s), sin(2*\pi*s), cos(2*\pi*s),...]

    The docstring is a raw string so that the LaTeX-style backslashes are not
    parsed as (invalid) escape sequences.

    Parameters
    ----------
    s : float
        Point that the vector '\phi' should be evaluated at.
    length : int
        Highest degree of the trigonometric polynomials in the truncated vector.
        The size of the returned vector will be '2*length + 1'.

    Returns
    -------
    phi : np.array
        Truncated vector \phi evaluated at the point 's'.

    '''

    phi = np.empty(2*length + 1)

    # angles[k] = k * PI * s for k = 0..length.
    angles = np.arange(length + 1) * PI * s

    # Even slots hold cos(k*PI*s) for k = 0..length (slot 0 is cos(0));
    # odd slots hold sin(k*PI*s) for k = 1..length.
    phi[0::2] = np.cos(angles)
    phi[1::2] = np.sin(angles[1:])

    return phi



def test_sampling(size, N):
    '''
    Build a MixtureModel (from 'lambda s: s**2', with no explicit seed),
    estimate its truncated Z-matrix with 'sample_Z' (Algorithm 1 in the paper)
    using 'N' samples, and compare the estimate against the matrix returned
    by 'model.compute_Zmat(size)'.

    Parameters
    ----------
    size : int
        Highest degree of the trigonometric polynomials in the truncated series.
        The computed Z-matrix will have dimension (2*size+1)x(2*size+1).
    N : int
        Number of samples when estimating the Z-matrix.

    Returns
    -------
    float
        Returns the infinity-norm (maximum absolute row sum) of the difference
        of the exact truncated Z-matrix and the estimated truncated Z-matrix.

    '''

    model = MixtureModel(lambda s: s**2)

    # Delegate to the shared estimator instead of repeating its loop here.
    # Sampling still happens before the exact matrix is computed.
    Zhmat = sample_Z(model, size, N)
    Zmat = model.compute_Zmat(size)

    # Infinity norm of a matrix: largest absolute row sum.
    return np.absolute(Zmat - Zhmat).sum(axis=1).max()


def sample_Z(model, size, N):
    r'''
    Implements Algorithm 1 of the paper to estimate the truncated Z-matrix
    of a mixture model.

    Each iteration draws a start point uniformly from [-1, 1], asks the model
    for the next point via 'model.sample', and adds the scaled outer product
    of the two 'phi_vec' evaluations to the running estimate.

    The docstring is a raw string so that '\hat' is not parsed as an
    (invalid) escape sequence.

    Parameters
    ----------
    model : MixtureModel
        Represents the conditional probability distribution to be estimated.
        Only its 'sample(start_s)' method is used here.
    size : int
        Highest degree of the trigonometric polynomials in the truncated series.
        The computed Z-matrix will have dimension (2*size+1)x(2*size+1).
    N : int
        Number of samples used to estimate the matrix.

    Returns
    -------
    Zhmat : np.array
        Estimated matrix (referred to as \hat{Z} in the paper).

    '''

    dim = 2*size + 1
    Zhmat = np.zeros((dim, dim))

    for _ in range(N):

        start_s = np.random.uniform(-1, 1)
        next_s = model.sample(start_s)

        # Rows are indexed by the basis evaluated at the next point,
        # columns by the basis evaluated at the start point.
        Zhmat += (2/N) * np.outer(phi_vec(next_s, size), phi_vec(start_s, size))

    # The (0, 0) entry is not taken from the samples; it is fixed to 1/2.
    Zhmat[0, 0] = 1/2

    return Zhmat

def sample_Zseed(model_seed, size, N):
    '''
    This is the same as the 'sample_Z' function, except it accepts an 'int'
    'model_seed' to generate a MixtureModel instead of accepting
    a 'MixtureModel' directly.

    Parameters
    ----------
    model_seed : int
        Seed passed to the MixtureModel constructor as its second argument.
    size : int
        Highest degree of the trigonometric polynomials in the truncated series.
        The computed Z-matrix will have dimension (2*size+1)x(2*size+1).
    N : int
        Number of samples used to estimate the matrix.

    Returns
    -------
    Zhmat : np.array
        Estimated matrix, as returned by 'sample_Z'.

    '''

    # The model is built here rather than passed in; the lambda is created
    # inside this function, so it never has to cross a process boundary.
    model = MixtureModel(lambda s: s**2, model_seed)

    # Reuse the shared estimator instead of duplicating its sampling loop.
    return sample_Z(model, size, N)


def _reseed_worker():
    '''Give the current pool worker a fresh global NumPy random state.'''
    # With the 'fork' start method every worker inherits an identical copy of
    # the parent's global np.random state, so the np.random.uniform draws in
    # sample_Zseed would be the same in every worker.  Seeding without an
    # argument pulls fresh entropy from the OS.
    # NOTE(review): if MixtureModel itself reseeds the global np.random state
    # from model_seed, worker streams would still coincide -- confirm in
    # polynomial_dist.
    np.random.seed()


def parallel_sample_Z(model_seed, size, N, processors):
    r'''
    This is a parallelized version of 'sample_Zseed'.  Implements Algorithm 1
    of the paper to estimate the truncated Z-matrix of a mixture model.

    The 'N' samples are split as evenly as possible across the workers (no
    samples are dropped when 'N' is not a multiple of 'processors'), and the
    per-worker estimates are combined with weights proportional to the number
    of samples each worker drew.

    Parameters
    ----------
    model_seed : int
        Seed used to generate a MixtureModel which represents the
        conditional probability distribution to be estimated.
    size : int
        Highest degree of the trigonometric polynomials in the truncated series.
        The computed Z-matrix will have dimension (2*size+1)x(2*size+1).
    N : int
        Number of samples used to estimate the matrix.
    processors : int
        Number of processors to use for the estimation task.  At most 'N'
        worker processes are started.

    Returns
    -------
    Zhmat : np.array
        Estimated matrix (referred to as \hat{Z} in the paper).

    Raises
    ------
    ValueError
        If 'processors' is smaller than 1.

    '''

    if processors < 1:
        raise ValueError("processors must be at least 1")

    # Nothing to distribute: fall back to the serial estimator, which returns
    # its zero-sample matrix without starting a pool.
    if N <= 0:
        return sample_Zseed(model_seed, size, N)

    # Never start more workers than there are samples, otherwise some chunks
    # would contribute an estimate built from zero samples.
    workers = min(processors, N)

    # Spread the remainder over the first 'extra' workers so that the chunk
    # sizes add up to exactly N.
    base, extra = divmod(N, workers)
    chunk_sizes = [base + 1] * extra + [base] * (workers - extra)
    arg_list = [(model_seed, size, n) for n in chunk_sizes]

    with multiprocessing.Pool(workers, initializer=_reseed_worker) as pool:
        map_result = pool.starmap(sample_Zseed, arg_list)

    # Each chunk estimate is an average over its own n samples; weighting by
    # n/N yields the average over all N samples.  The fixed (0, 0) entry of
    # 1/2 is preserved because the weights sum to one.
    return sum((n / N) * Z for n, Z in zip(chunk_sizes, map_result))
        




    














    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    