'''
Variance estimation
'''
import os
import sys
sys.path.append('.')
path = os.path.dirname(sys.argv[0])

import numpy as np

from hvbll.toy_functions import *
from hvbll.basic import cal_total_variance, cal_Va_Vm_from_data


if __name__ == '__main__':
    # Script entry point: for every toy function and every
    # (dim_input, num_samples) setting, build N_SEED datasets, estimate the
    # variance components on each, and write the seed-averaged values as one
    # row of 'variance-multi-dimension.dat' (also echoed to stdout).

    N_SEED = 10
    noise_level = 0.1
    noise_level_slope = 1.0
    noise_level_omega = 2 * np.pi

    # (dim_input, num_samples)
    list_setting = [
        (1, 20), (1, 100), (1, 1000),
        (10, 50), (10, 1000), (10, 10000),
        (100, 100), (100, 1000), (100, 10000),
    ]

    # Noise keyword arguments shared by the linear-noise and the sine-noise
    # toy functions, respectively.
    kwargs_noise_lin = {'noise_level': noise_level,
                        'noise_level_slope': noise_level_slope}
    kwargs_noise_sin = {'noise_level': noise_level,
                        'noise_level_omega': noise_level_omega}

    # The position in this list is the 'i_func' index written to the file.
    list_function = [
        (ToyFn_Lin_Noise_Lin, kwargs_noise_lin),
        (ToyFn_Lin_Noise_Sin, kwargs_noise_sin),
        (ToyFn_Sin_Noise_Lin, kwargs_noise_lin),
        (ToyFn_Sin_Noise_Sin, kwargs_noise_sin),
    ]

    # The context manager closes (and flushes) the output file even when a
    # dataset constructor or an estimator raises part-way through the sweep.
    with open(os.path.join(path, 'variance-multi-dimension.dat'), 'w') as f:

        f.write('Variables= %9s %9s %19s %19s %19s %19s %19s %19s\n' %
                ('i_func', 'dim_input', 'num_samples', 'E_noise_real',
                 'E_noise', 'V_total', 'V_mean', 'V_noise'))

        for i_function, (toy_function, kwargs_noise) in enumerate(list_function):

            print()
            print('>>> Function %d' % i_function)
            print()

            for dim_input, num_samples in list_setting:

                # Per-seed estimates, averaged when the row is written.
                V_total = []
                E_noise = []
                V_mean = []
                V_noise = []

                # The sine-noise functions (indices 1 and 3) are run with a
                # reduced input dimension: 10 -> 2 and 100 -> 4. The reduced
                # value is both used to build the dataset and written to the
                # output row. The loop re-binds dim_input on every iteration,
                # so this override does not leak into the next setting.
                if i_function in (1, 3):    # sine noise
                    if dim_input == 10:
                        dim_input = 2
                    elif dim_input == 100:
                        dim_input = 4

                for seed in range(N_SEED):

                    dataset = toy_function(num_samples, dim_input=dim_input,
                                           seed=seed, **kwargs_noise)

                    #* Variance estimation
                    _V_total = cal_total_variance(dataset.Y_cpu)
                    _E_noise, _V_mean, _V_noise = cal_Va_Vm_from_data(
                        dataset.X_cpu, dataset.Y_cpu,
                        n_neighbor=5, ratio_neighbor=0.001)

                    V_total.append(_V_total)
                    E_noise.append(_E_noise)
                    V_mean.append(_V_mean)
                    V_noise.append(_V_noise)

                # Reference value queried from the dataset object of the last
                # seed, using 1e6 points; only the first element of the
                # returned sequence is kept.
                E_noise_real = dataset.get_average_aleatoric_uncertainty(
                    num_points=int(1E6))[0]

                text = '%20d %9d %19d %19.6e %19.6e %19.6e %19.6e %19.6e' % \
                        (i_function, dim_input, num_samples,
                         E_noise_real, np.mean(E_noise), np.mean(V_total),
                         np.mean(V_mean), np.mean(V_noise))

                f.write(text+' \n')
                print(text)
    
