# -*- coding: utf-8 -*-
"""
Created on Fri Oct 15 21:39:04 2021

@author: grego
"""
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math

from polynomial_dist import MixtureModel
from sampling_funcs import test_sampling




def n_func(eps, k, delta):
    '''
    Return the theoretical sample count given by Theorem 3 of the paper.

    The bound evaluated here is

        n = (8 * k**2 / eps**2) * ln(2 * k**2 / delta)

    i.e. the number of samples needed for 'eps' (epsilon) accuracy with
    probability at least (1 - delta) for a truncation of size 'k'.

    'eps' may be a scalar or a NumPy array (the result then has the same
    shape). 'k' and 'delta' are expected to be scalars, because the
    logarithm is taken with math.log.
    '''
    k_squared = k ** 2
    # Same evaluation order as the closed-form bound: (8 k^2) / eps^2 first,
    # then scaled by the logarithmic confidence term.
    leading_factor = 8 * k_squared / eps ** 2
    log_factor = math.log(2 * k_squared / delta)
    return leading_factor * log_factor


if __name__ == '__main__':
    r'''
    Run the sampling experiment and plot it against the theoretical bound.

    Set the 'SIZE' constant below to be the maximum degree of the trigonometric
    functions used in the basis (k=2*SIZE + 1), i.e., if 'SIZE = 2', then
    the last two basis functions will be 'cos(2*\pi*s), sin(2*\pi*s)'.

    Set the 'REPS' constant below to be the number of times the sampling
    experiment should be repeated for each value of 'N' in 'N_count'.

    Set the 'N_count' list below to be the values of 'N' that should be tested.
    The value of 'N' is the number of samples used to compute the matrix as
    described in Algorithm 1 of the paper.

    Set the 'DELTA' constant below to be the failure probability used for the
    theoretical curve: the bound from 'n_func' holds with probability at
    least (1 - DELTA).
    '''

    ##### SET THESE PARAMETERS
    SIZE = 4
    REPS = 40
    N_count = [100, 1000, 2000, 4000, 8000, 16000, 24000, 32000, 50000]
    DELTA = 0.8

    #####

    # Truncation size k is derived from SIZE (k = 2*SIZE + 1) so that the
    # theoretical curve describes the same basis as the sampling experiment.
    K = 2 * SIZE + 1

    # NOTE(review): 'model' is never used below (test_sampling does not
    # receive it). Kept in case the constructor has side effects -- confirm
    # whether it should be passed to the experiment or removed.
    model = MixtureModel(lambda s: s**2)

    results_list = []

    for N in N_count:
        print('N:', N)
        # Repeat the experiment REPS times at this sample size; each call's
        # return value is recorded as one 'Empirical Epsilon' observation.
        results_list.extend([N, test_sampling(SIZE, N)] for _ in range(REPS))

    # Passing the columns at construction also works when no results were
    # collected, where assigning '.columns' afterwards would raise.
    result_df = pd.DataFrame(results_list, columns=['N', 'Empirical Epsilon'])
    # Average the repetitions for each sample size N (N becomes the index).
    result_df = result_df.groupby('N').mean()

    eps_grid = np.arange(0.55, 0.05, -0.01)
    N_grid = n_func(eps_grid, K, DELTA)

    plt.plot(N_grid, eps_grid, label = 'Theoretical Samples Needed')
    plt.plot(result_df.index, result_df['Empirical Epsilon'],
             label = 'Observed Epsilon at Fixed Samples')

    plt.legend()
    plt.xlabel("Number of Samples (n)")
    plt.ylabel("Infinity Norm Error (epsilon)")
    # Needed to display the figure when run as a plain script (no-op-like
    # under interactive/inline backends).
    plt.show()