'''
Toy functions for testing purposes
'''
import sys
sys.path.append('.')

import numpy as np

import torch
from typing import Tuple
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

from hvbll.basic import cal_total_variance, cal_Va_Vm_from_data


class SimpleFnDataset(Dataset):
    '''
    The simple function used for regression experiments.

    Samples are generated as ``y = func_mean(x) + func_noise(x) * eps`` with
    ``eps ~ N(0, 1)`` drawn independently for every sample and output. In
    this base class the inputs are uniform on ``[x_min, x_max]^dim_input``,
    the mean is zero and the noise is homoscedastic. Subclasses override
    ``func_mean``, ``func_noise`` and/or ``get_data`` to define other toy
    problems.

    Parameters
    ----------
    num_samples : int
        Number of samples in the dataset.

    dim_input : int
        Dimension of the input.

    dim_output : int
        Dimension of the output.

    noise_level : float
        Standard deviation of the homoscedastic noise.
        Or the scale of the heteroscedastic noise.

    seed : int | None
        Random seed. When given, NumPy's *global* RNG is reseeded
        (``np.random.seed``), which also affects other code using it.

    gpu_id : int | None
        GPU ID. If None or a negative integer, use CPU. The data also stays
        on the CPU when CUDA is not available.

    x_min, x_max : float
        The range of the input data.

    Attributes
    ----------
    X : torch.Tensor (num_samples, dim_input)
        Input data (float32).

    Y : torch.Tensor (num_samples, dim_output)
        Output data (float32).

    X_cpu, Y_cpu : ndarray
        NumPy versions of the data, as produced by ``get_data``.

    Y_mean : torch.Tensor (num_samples, dim_output)
        Mean output data (noise is zero). Only set after ``get_mean_data``
        has been called.
    '''
    def __init__(self, num_samples: int,
                dim_input=1, dim_output=1,
                noise_level=0.01, seed=None, gpu_id=None,
                x_min=0.0, x_max=1.0) -> None:

        self.name = 'SimpleFnDataset'

        self.num_samples = num_samples
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.noise_level = noise_level
        self.seed = seed
        # Negative ids are treated like None (CPU).
        self.gpu_id = None if (gpu_id is None or gpu_id < 0) else gpu_id

        # Seeding the global RNG makes get_data() reproducible.
        if seed is not None:
            np.random.seed(seed)

        self.x_min = x_min
        self.x_max = x_max

        self.X, self.Y = self.get_data()

        self.X = self._to_device(self.X)
        self.Y = self._to_device(self.Y)

    def _to_device(self, tensor: torch.Tensor) -> torch.Tensor:
        '''Move ``tensor`` to ``self.gpu_id`` if CUDA is available and a GPU was requested.'''
        if torch.cuda.is_available() and self.gpu_id is not None:
            return tensor.to(self.gpu_id)
        return tensor

    def func_mean(self, x: np.ndarray) -> np.ndarray:
        '''
        The function of the mean of the output (identically zero here).

        Parameters
        ----------
        x : ndarray [num_samples, dim_input]
            Input data.

        Returns
        -------
        y_mean : ndarray [num_samples, dim_output]
            Mean output.
        '''
        # The shape must follow dim_output: np.zeros_like(x) would be
        # [num_samples, dim_input] and mis-shape Y when the two dims differ.
        return np.zeros((x.shape[0], self.dim_output))

    def func_noise(self, x: np.ndarray) -> np.ndarray:
        '''
        The function of the noise level (constant ``noise_level`` here).

        Parameters
        ----------
        x : ndarray [num_samples, dim_input]
            Input data.

        Returns
        -------
        noise : ndarray [num_samples, dim_output]
            The noise level of each sample.
        '''
        return self.noise_level * np.ones((x.shape[0], self.dim_output))

    def get_data(self):
        '''
        Get the data for the simple function.

        Draws inputs uniformly from ``[x_min, x_max]`` and stores the NumPy
        arrays in ``X_cpu`` / ``Y_cpu``.

        Returns
        -------
        X : torch.Tensor [num_samples, dim_input]
            Input data (float32, CPU).

        Y : torch.Tensor [num_samples, dim_output]
            Output data (float32, CPU).
        '''
        self.X_cpu = np.random.uniform(self.x_min, self.x_max, (self.num_samples, self.dim_input))
        self.Y_cpu = self.func_mean(self.X_cpu) + self.func_noise(self.X_cpu) * np.random.randn(self.num_samples, self.dim_output)

        return torch.tensor(self.X_cpu).float(), torch.tensor(self.Y_cpu).float()

    def get_mean_data(self):
        '''
        Get the mean function dataset, where the noise is zero, the X is the same.

        Sets ``Y_mean_cpu`` (ndarray) and ``Y_mean`` (float32 tensor, moved
        to the configured GPU when applicable).
        '''
        self.Y_mean_cpu = self.func_mean(self.X_cpu)

        self.Y_mean = self._to_device(torch.tensor(self.Y_mean_cpu).float())

    def get_average_aleatoric_uncertainty(self, num_points=1001) -> np.ndarray:
        '''
        Get the average aleatoric uncertainty in the system.

        - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]

        Computed as the mean of ``func_noise(x)**2`` with x uniform on
        ``[x_min, x_max]``. A regular grid is used for 1-d inputs; a Monte
        Carlo sample (from the global NumPy RNG) otherwise.

        Parameters
        ----------
        num_points : int
            Number of grid points / Monte Carlo samples.

        Returns
        -------
        V_a : ndarray [dim_output]
            Average aleatoric uncertainty.
        '''
        if self.dim_input == 1:
            xx = np.linspace(self.x_min, self.x_max, num_points, endpoint=True)[..., None]
        else:
            xx = np.random.uniform(self.x_min, self.x_max, (num_points, self.dim_input))

        noise_level = self.func_noise(xx)

        return np.mean(noise_level**2, axis=0)

    def get_samples(self, x: np.ndarray) -> np.ndarray:
        '''
        Get new samples.

        Parameters
        ----------
        x : ndarray [num_samples, dim_input]
            Input data.

        Returns
        -------
        y : ndarray [num_samples, dim_output]
            Output data.
        '''
        num_samples = x.shape[0]
        return self.func_mean(x) + self.func_noise(x) * np.random.randn(num_samples, self.dim_output)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    def plot_ground_truth_1d(self, x_min=None, x_max=None, num_points=1001):
        '''
        Plot the mean function and noise level (1 std band) on the current axes.

        Raises
        ------
        ValueError
            If the dataset is not 1-d in both input and output.
        '''
        if self.dim_input != 1 or self.dim_output != 1:
            raise ValueError('The function is only supported for 1-d input and output.')

        x_min = self.x_min if x_min is None else x_min
        x_max = self.x_max if x_max is None else x_max

        xx = np.linspace(x_min, x_max, num_points, endpoint=True)[..., None]
        yy = self.func_mean(xx)
        noise_level = self.func_noise(xx)

        plt.plot(xx[:,0], yy[:,0], 'k', label='Mean function')
        plt.fill_between(xx[:,0], yy[:,0] - noise_level[:,0], yy[:,0] + noise_level[:,0],
                            alpha=0.2, color='b', label='Noise level (1 std)')

    @staticmethod
    def cal_total_variance(y: np.ndarray) -> np.ndarray:
        '''
        Calculate the total QoI variance in the system.

        - V_t = Var_{x~p(x), e~p(e)}[y] = V_a + V_m
        - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]
        - V_m = Var_{x~p(x)}[E_{e~p(e)}[y|x,e]]

        Parameters
        ----------
        y : ndarray [num_samples, dim_output]
            Output data.

        Returns
        -------
        V_t : ndarray [dim_output]
            Total variance.
        '''
        # Delegates to the module-level helper imported from hvbll.basic.
        return cal_total_variance(y)

    @staticmethod
    def cal_variance_of_mean_function(y_mean: np.ndarray) -> np.ndarray:
        '''
        Calculate the variance of the mean function in the system.

        - V_m = Var_{x~p(x)}[E_{e~p(e)}[y|x,e]]

        Parameters
        ----------
        y_mean : ndarray [num_samples, dim_output]
            Mean output data.

        Returns
        -------
        V_m : ndarray [dim_output]
            Variance of the mean function.
        '''
        return cal_total_variance(y_mean)

    @staticmethod
    def cal_Va_Vm_from_data(x: np.ndarray, y: np.ndarray, n_neighbor=3) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Calculate the average aleatoric uncertainty (V_a)
        and the variance of the mean function (V_m) in the system.

        - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]
        - V_m = Var_{x~p(x)}[E_{e~p(e)}[y|x,e]]

        Parameters
        ----------
        x : ndarray [num_samples, dim_input]
            Input data.

        y : ndarray [num_samples, dim_output]
            Output data.

        n_neighbor : int
            Number of neighbors for grouping.

        Returns
        -------
        V_a : ndarray [dim_output]
            Average aleatoric uncertainty.

        V_m : ndarray [dim_output]
            Variance of the mean function.
        '''
        # The hvbll.basic helper returns a third value, which is discarded.
        V_a, V_m, _ = cal_Va_Vm_from_data(x, y, n_neighbor)

        return V_a, V_m


class ToyFn_DUE1D(SimpleFnDataset):
    '''
    Homoscedastic 1-d regression toy problem from the DUE experiments.

    The mean function is a sum of 30 randomly drawn cosines (scaled by
    1/5). Inputs come from two clusters centred at -0.5 and +0.5, each with
    standard deviation 0.1 and truncated to +/-0.2 around its centre.

    Adapted from:

    https://github.com/y0ast/DUE/blob/main/toy_regression.ipynb
    '''
    def __init__(self, num_samples: int,
                noise_level=0.1, seed=None, gpu_id=None) -> None:

        if seed is not None:
            np.random.seed(seed)

        # The random cosine frequencies/phases must exist before the base
        # constructor runs, because it calls get_data() -> func_mean().
        n_components = 30
        self.W = np.random.randn(n_components, 1) * 10
        self.b = np.random.rand(n_components, 1) * 2 * np.pi

        super().__init__(
            num_samples,
            dim_input=1,
            dim_output=1,
            noise_level=noise_level,
            seed=seed,
            gpu_id=gpu_id,
            x_min=-1,
            x_max=1,
        )

        self.name = 'ToyFn_DUE1D'

    def func_mean(self, x: np.ndarray) -> np.ndarray:
        '''Random cosine mixture evaluated at ``x`` ([num_samples, 1]); returns [num_samples, 1].'''
        # Broadcasting (30, 1) against (num_samples,) gives one row per component.
        phases = self.W * x.squeeze(-1) + self.b
        mixture = np.cos(phases).sum(0) / 5.
        return mixture[..., None]

    def get_data(self):
        '''
        Draw the bimodal training inputs and noisy outputs.

        Returns
        -------
        X : torch.Tensor [num_samples, 1]
        Y : torch.Tensor [num_samples, 1]
        '''
        n = self.num_samples

        # Cluster centre (+/-0.5) first, then the clipped jitter (same RNG order).
        centres = 0.5 * np.sign(np.random.randn(n))
        jitter = 0.1 * np.random.randn(n).clip(-2, 2)
        self.X_cpu = (centres + jitter)[..., None]

        eps = self.func_noise(self.X_cpu) * np.random.randn(n, 1)
        self.Y_cpu = self.func_mean(self.X_cpu) + eps

        return torch.tensor(self.X_cpu).float(), torch.tensor(self.Y_cpu).float()


class ToyFn_Lin_Noise_Lin(SimpleFnDataset):
    '''
    The heteroscedastic linear function used in regression experiments.
    The noise level is linearly dependent on the input.

    - mean:  y_mean(x) = sum_i x_i / dim_input
    - noise: sigma(x)  = noise_level * (0.1 + noise_level_slope * sum_i x_i) / dim_input

    Parameters
    ----------
    noise_level_slope: float
        Slope of the local noise level relative to the input. The default
        (1.0) reproduces the previous behaviour, in which this argument was
        accepted but ignored.
    '''
    def __init__(self, num_samples: int, dim_input=1, dim_output=1,
                noise_level=0.05, noise_level_slope = 1.0,
                seed=None, gpu_id=None):

        # Must be set before the base constructor, which calls get_data()
        # and therefore func_noise().
        self.noise_level_slope = noise_level_slope

        super().__init__(num_samples, dim_input, dim_output, noise_level=noise_level,
                            seed=seed, gpu_id=gpu_id)

        self.name = 'ToyFn_Lin_Noise_Lin'

    def func_mean(self, x: np.ndarray) -> np.ndarray:
        '''Mean of the first ``dim_input`` input columns, shape [num_samples, dim_output].'''
        y_mean = np.zeros((x.shape[0], self.dim_output))

        for i in range(self.dim_input):
            y_mean = y_mean + x[:,i:i+1]

        return y_mean / self.dim_input

    def func_noise(self, x: np.ndarray) -> np.ndarray:
        '''Linear noise scale, shape [num_samples, dim_output].'''
        # 0.1 offset keeps the noise level positive at x = 0 (for x >= 0).
        g = np.zeros((x.shape[0], self.dim_output)) + 0.1

        for i in range(self.dim_input):
            g = g + self.noise_level_slope * x[:,i:i+1]

        noise_level = self.noise_level * g / self.dim_input

        return noise_level


class ToyFn_Lin_Noise_Sin(SimpleFnDataset):
    '''
    Heteroscedastic linear regression problem with a sinusoidal noise level.

    - mean:  y_mean(x) = (x_1 + ... + x_d) / d,   with d = dim_input
    - noise: sigma(x)  = noise_level * prod_i sin(noise_level_omega * x_i)

    ``sigma(x)`` is not forced to be non-negative. It multiplies a standard
    normal draw, so its sign does not change the output distribution.

    Parameters
    ----------
    noise_level_omega : float
        Frequency of the sinusoidal noise level.
    '''
    def __init__(self, num_samples: int, dim_input=1,
                noise_level=0.05, noise_level_omega = 2*np.pi,
                seed=None, gpu_id=None):

        # func_noise reads this. The base constructor already calls it
        # (through get_data), so it has to be assigned first.
        self.noise_level_omega = noise_level_omega

        super().__init__(
            num_samples,
            dim_input,
            dim_output=1,
            noise_level=noise_level,
            seed=seed,
            gpu_id=gpu_id,
        )

        self.name = 'ToyFn_Lin_Noise_Sin'

    def func_mean(self, x: np.ndarray) -> np.ndarray:
        '''Average of the first ``dim_input`` input columns, shape [num_samples, dim_output].'''
        running_sum = np.zeros((x.shape[0], self.dim_output))
        for col in range(self.dim_input):
            running_sum = running_sum + x[:, col:col + 1]
        return running_sum / self.dim_input

    def func_noise(self, x: np.ndarray) -> np.ndarray:
        '''Sinusoidal noise scale, shape [num_samples, dim_output].'''
        envelope = np.ones((x.shape[0], self.dim_output))
        for col in range(self.dim_input):
            envelope = envelope * np.sin(self.noise_level_omega * x[:, col:col + 1])
        return self.noise_level * envelope


class ToyFn_Sin_Noise_Lin(SimpleFnDataset):
    '''
    The heteroscedastic sinusoidal function used in regression experiments.
    The noise level is linearly dependent on the input.

    - mean:  y_mean(x) = 0.5 * prod_i sin(omega * x_i)
    - noise: sigma(x)  = noise_level * (0.1 + noise_level_slope * sum_i x_i) / dim_input

    Parameters
    ----------
    omega : float
        Frequency of the sinusoidal function.
        The input ranges in [0, 1].

    noise_level_slope: float
        Slope of the local noise level relative to the input. The default
        (1.0) reproduces the previous behaviour, in which this argument was
        accepted but ignored.
    '''
    def __init__(self, num_samples: int, dim_input=1, omega=2*np.pi,
                noise_level=0.1, noise_level_slope = 1.0,
                seed=None, gpu_id=None):

        # Both are read by func_mean / func_noise during the base
        # constructor's call to get_data(), so set them first.
        self.omega = omega
        self.noise_level_slope = noise_level_slope

        super().__init__(num_samples, dim_input, dim_output=1, noise_level=noise_level,
                            seed=seed, gpu_id=gpu_id)

        self.name = 'ToyFn_Sin_Noise_Lin'

    def func_mean(self, x: np.ndarray) -> np.ndarray:
        '''Product of sines over the input columns, scaled by 0.5; shape [num_samples, dim_output].'''
        y_mean = np.ones((x.shape[0], self.dim_output))

        for i in range(self.dim_input):
            y_mean = y_mean * np.sin(self.omega * x[:,i:i+1])

        return y_mean * 0.5

    def func_noise(self, x: np.ndarray) -> np.ndarray:
        '''Linear noise scale, shape [num_samples, dim_output].'''
        # 0.1 offset keeps the noise level positive at x = 0 (for x >= 0).
        g = np.zeros((x.shape[0], self.dim_output)) + 0.1

        for i in range(self.dim_input):
            g = g + self.noise_level_slope * x[:,i:i+1]

        noise_level = self.noise_level * g / self.dim_input

        return noise_level


class ToyFn_Sin_Noise_Sin(SimpleFnDataset):
    '''
    Heteroscedastic sinusoidal regression problem whose noise level is also
    sinusoidal in the input.

    - mean:  y_mean(x) = 0.5 * prod_i sin(omega * x_i)
    - noise: sigma(x)  = noise_level * prod_i sin(noise_level_omega * x_i)

    Parameters
    ----------
    omega : float
        Frequency of the sinusoidal function.
        The input ranges in [0, 1].

    noise_level_omega : float
        Frequency of the sinusoidal noise level.
    '''
    def __init__(self, num_samples: int, dim_input=1, omega=2*np.pi,
                noise_level=0.05, noise_level_omega = 2*np.pi,
                seed=None, gpu_id=None):

        # The frequencies are consumed while the base constructor builds the
        # data (get_data -> func_mean / func_noise), so assign them first.
        self.omega = omega
        self.noise_level_omega = noise_level_omega

        super().__init__(
            num_samples,
            dim_input,
            dim_output=1,
            noise_level=noise_level,
            seed=seed,
            gpu_id=gpu_id,
        )

        self.name = 'ToyFn_Sin_Noise_Sin'

    def func_mean(self, x: np.ndarray) -> np.ndarray:
        '''Half the product of ``sin(omega * x_i)``; shape [num_samples, dim_output].'''
        wave = np.ones((x.shape[0], self.dim_output))
        for col in range(self.dim_input):
            wave = wave * np.sin(self.omega * x[:, col:col + 1])
        return wave * 0.5

    def func_noise(self, x: np.ndarray) -> np.ndarray:
        '''Sinusoidal noise scale; shape [num_samples, dim_output].'''
        envelope = np.ones((x.shape[0], self.dim_output))
        for col in range(self.dim_input):
            envelope = envelope * np.sin(self.noise_level_omega * x[:, col:col + 1])
        return self.noise_level * envelope


class ToyFn_xSin_Noise_Lin(SimpleFnDataset):
    '''
    Heteroscedastic ``x^2 * sin`` regression problem used in the BLL 1-d
    regression experiments.

    - mean:  y_mean(x) = x_1^2 * prod_i sin(omega * x_i)
    - noise: sigma(x)  = noise_level * clip(1 + noise_level_slope * sum_i x_i, 1, 100)

    The input range is x in [-0.5, 1.5] (``x_min`` / ``x_max``). The
    inherited ``get_data`` draws training inputs uniformly from that whole
    range. NOTE(review): an earlier description stated that training data
    should only cover [0, 1]; confirm which range is intended.

    Based on the implementation presented in:

    https://ieeexplore.ieee.org/document/10305157

    https://github.com/4flixt/2023_Paper_BLL_LML/blob/main/bll/tools.py

    Parameters
    ----------
    omega : float
        Frequency of the sinusoidal factor of the mean.

    noise_level_slope : float
        Slope of the local noise level relative to the input.
    '''
    def __init__(self, num_samples: int, dim_input=1,
                omega=4*np.pi, noise_level=0.1, noise_level_slope = 1.0,
                seed=None, gpu_id=None):

        # Needed by func_mean / func_noise while the base constructor
        # generates the data.
        self.omega = omega
        self.noise_level_slope = noise_level_slope

        super().__init__(
            num_samples,
            dim_input,
            dim_output=1,
            noise_level=noise_level,
            seed=seed,
            gpu_id=gpu_id,
            x_min=-0.5,
            x_max=1.5,
        )

        self.name = 'ToyFn_xSin_Noise_Lin'

    def func_mean(self, x: np.ndarray) -> np.ndarray:
        '''``x_1^2`` times the product of ``sin(omega * x_i)``; shape [num_samples, dim_output].'''
        envelope = np.ones((x.shape[0], self.dim_output)) * x[:, 0:1]**2
        for col in range(self.dim_input):
            envelope = envelope * np.sin(self.omega * x[:, col:col + 1])
        return envelope

    def func_noise(self, x: np.ndarray) -> np.ndarray:
        '''Linear noise scale clipped to [1, 100] x noise_level; shape [num_samples, dim_output].'''
        trend = np.full((x.shape[0], self.dim_output), 1.0)
        for col in range(self.dim_input):
            trend = trend + self.noise_level_slope * x[:, col:col + 1]
        # The lower clip of 1.0 makes sigma == noise_level wherever the trend drops below 1.
        return self.noise_level * np.clip(trend, 1.0, 100.0)


class ToyFn_MeanEst_Simple():
    '''
    A nonlinear system for the mean value estimation in hierarchical systems.

    E_{x,y}[f(x,y)] v.s. E_{x}[E_{y}[f(x,y)]]

    The data-generating chain (see ``get_data``) is::

        x  ~ Uniform(x_min, x_max)          shape [num_samples, 2]
        e  ~ Normal((e_min + e_max) / 2,
                    (e_max - e_min) / 4)    shape [num_samples, 1]
        vs = func_vs(x)                     sub-system input
        vr = func_vr(x, e)                  remaining (unobservable) variable
        y  = func_y_of_vi(vs, vr, e)

    Parameters
    ----------
    seed : int | None
        Random seed; reseeds NumPy's global RNG when given.

    i_vs : int
        Index (0-2) selecting the sub-system input function ``func_vs``.

    i_vr : int
        Index (0-2) selecting the remaining intermediate function ``func_vr``.

    i_function : int
        Index (0-3) selecting the output function ``func_y_of_vi``.

    ratio_e : float
        Weight of the additive noise term in the output.

    x_min, x_max : float | ndarray [dim_input]
        Bounds of the uniform input distribution.

    e_min, e_max : float | ndarray [dim_noise]
        Define the Gaussian noise distribution: mean ``(e_min + e_max) / 2``
        and standard deviation ``(e_max - e_min) / 4``. Samples are not
        clipped to this range.

    Attributes
    ----------
    dim_input : int
        Dimension of the input (design variables); fixed to 2.

    dim_output : int
        Dimension of the output (quantity of interest); fixed to 1.

    dim_noise : int
        Dimension of the noise (unobservable variables); fixed to 1.

    dim_vi_sub : int
        Dimension of the intermediate variables as the sub-system input; fixed to 1.

    dim_vi_remain : int
        Dimension of the remaining intermediate variables; fixed to 1.

    X : ndarray [num_samples, dim_input]
        Input data (set by ``get_data``).

    Y : ndarray [num_samples, dim_output]
        Output data (set by ``get_data``).

    E : ndarray [num_samples, dim_noise]
        Noise data (set by ``get_data``).

    VS : ndarray [num_samples, dim_vi_sub]
        Intermediate variables as the sub-system input (set by ``get_data``).

    VR : ndarray [num_samples, dim_vi_remain]
        Remaining intermediate (unobservable) variables (set by ``get_data``).
    '''
    def __init__(self, seed=None,
                i_vs=0, i_vr=0,
                i_function=0, ratio_e=1.0,
                x_min=0.0, x_max=1.0, e_min=0.0, e_max=1.0):

        # NOTE(review): the name says "VarEst" while the class is "MeanEst";
        # kept unchanged because callers may key results on this string.
        self.name = 'ToyFn_VarEst_Simple'

        self.i_vs = i_vs
        self.i_vr = i_vr
        self.i_function = i_function
        self.ratio_e = ratio_e

        self.dim_input = 2
        self.dim_output = 1
        self.dim_noise = 1
        self.dim_vi_sub = 1
        self.dim_vi_remain = 1
        self.x_min = x_min
        self.x_max = x_max
        self.e_min = e_min
        self.e_max = e_max

        if seed is not None:
            np.random.seed(seed)

    def func_vs(self, x: np.ndarray) -> np.ndarray:
        '''
        The function of the intermediate variables as the sub-system input.

        Parameters
        ----------
        x : ndarray [num_samples, >=2]
            Input data; only columns 0 and 1 are used.

        Returns
        -------
        vs : ndarray [num_samples, dim_vi_sub]

        Raises
        ------
        ValueError
            If ``i_vs`` is not a supported index.
        '''
        vs = np.zeros((x.shape[0], self.dim_vi_sub))

        if self.i_vs == 0:
            vs[:,0] = np.sin(x[:,0]*np.pi*2) + x[:,1]**2

        elif self.i_vs == 1:
            vs[:,0] = x[:,0]*x[:,1] + x[:,1]

        elif self.i_vs == 2:
            vs[:,0] = 1E3 * np.sin(x[:,0]*np.pi*2) + 200*x[:,1]**2

        else:
            raise ValueError('The sub-system index is not supported.')

        return vs

    def func_vr(self, x: np.ndarray, e: np.ndarray) -> np.ndarray:
        '''
        The function of the remaining intermediate (unobservable) variables.

        Parameters
        ----------
        x : ndarray [num_samples, >=2]
            Input data.

        e : ndarray [num_samples, >=1]
            Noise data; only column 0 is used (and not at all for ``i_vr == 0``).

        Returns
        -------
        vr : ndarray [num_samples, dim_vi_remain]

        Raises
        ------
        ValueError
            If ``i_vr`` is not a supported index.
        '''
        vr = np.zeros((x.shape[0], self.dim_vi_remain))

        if self.i_vr == 0:
            # Noise-free variant: depends on x[:, 0] only.
            vr[:,0] = np.cos(x[:,0]*np.pi*2) + x[:,0]

        elif self.i_vr == 1:
            vr[:,0] = x[:,0]*x[:,1] + 0.1*(x[:,0]+x[:,1]) + e[:,0]

        elif self.i_vr == 2:
            vr[:,0] = 33*np.sin(x[:,0]*np.pi*2) + 100*x[:,0]*x[:,1] + e[:,0]

        else:
            raise ValueError('The remaining system index is not supported.')

        return vr

    def func_y_of_vi(self, vs: np.ndarray, vr: np.ndarray, e: np.ndarray) -> np.ndarray:
        '''
        The function of the output, evaluated element-wise.

        Parameters
        ----------
        vs : ndarray [num_samples, 1]
        vr : ndarray [num_samples, 1]
        e : ndarray [num_samples, 1]

        Returns
        -------
        y : ndarray [num_samples, 1]

        Raises
        ------
        ValueError
            If ``i_function`` is not a supported index.
        '''
        if self.i_function == 0:
            #* f adheres to the principle of additivity
            y = vs + vr**2 + self.ratio_e*e

        elif self.i_function == 1:
            #* f is (x+y)^2 + e
            y = (vs+vr)**2 + self.ratio_e*e

        elif self.i_function == 2:
            #* general form
            y = vs*vr + vs + np.sin(vs+vr) + self.ratio_e*e

        elif self.i_function == 3:
            y = 66*np.sin(vs) + 35*np.cos(vr) + self.ratio_e*e

        else:
            raise ValueError('The function index is not supported.')

        return y

    def get_data(self, num_samples : int) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Get the data for the simple function.

        Draws from NumPy's global RNG and stores ``X``, ``E``, ``VS``, ``VR``
        and ``Y`` on the instance.

        Parameters
        ----------
        num_samples : int
            Number of samples in the dataset.

        Returns
        -------
        X : ndarray [num_samples, dim_input]
            Input data.

        Y : ndarray [num_samples, dim_output]
            Output data.
        '''
        self.X = np.random.uniform(self.x_min, self.x_max, (num_samples, self.dim_input))

        # Gaussian noise centred in [e_min, e_max] with std = range / 4.
        loc = 0.5*(self.e_min + self.e_max)
        scale = 0.25*(self.e_max - self.e_min)

        self.E = np.random.normal(loc, scale, (num_samples, self.dim_noise))

        self.VS = self.func_vs(self.X)
        self.VR = self.func_vr(self.X, self.E)

        self.Y = self.func_y_of_vi(self.VS, self.VR, self.E)

        return self.X, self.Y


class ToyFn_ParamLinHS():
    '''
    A nonlinear parameterized linear system for the variance estimation in hierarchical systems.

    The output is linear in the intermediate variables ``v`` with input- and
    noise-dependent coefficients::

        y = B(x, e) @ v(x, e)

    where ``v = func_v(x, e)`` has shape [num_samples, dim_v] and
    ``B = matrix_b(x, e)`` has shape [num_samples, dim_output, dim_v].

    Parameters
    ----------
    i_v : int
        The index (0-3) of the intermediate variable function.

    i_y : int
        The index (0-3) of the output (coefficient-matrix) function.

    ratio_e : float
        The ratio of the noise level.

    dim_input : int
        Dimension of the input (design variables). The built-in functions
        read ``x[:, 0]`` and ``x[:, 1]``, so at least 2 is required.

    dim_output : int
        Dimension of the output (quantity of interest). The built-in
        functions only fill output row 0 of ``B``; other rows stay zero.

    dim_noise : int
        Dimension of the noise (unobservable variables). The built-in
        functions only read ``e[:, 0]``.

    dim_v : int
        Dimension of the intermediate variables. The built-in functions fill
        ``v[:, 0]`` and ``v[:, 1]``, so at least 2 is required.

    seed : int | None
        Random seed; reseeds NumPy's global RNG when given.

    x_min, x_max : float | ndarray [dim_input]
        Bounds of the uniform input distribution.

    e_min, e_max : float | ndarray [dim_noise]
        Define the Gaussian noise distribution: mean ``(e_min + e_max) / 2``
        and standard deviation ``(e_max - e_min) / 4``. Samples are not
        clipped to this range.

    Attributes
    ----------
    X : ndarray [num_samples, dim_input]
        Input data.

    Y : ndarray [num_samples, dim_output]
        Output data.

    E : ndarray [num_samples, dim_noise]
        Noise data.

    V : ndarray [num_samples, dim_v]
        Intermediate variables as the sub-system input.

    B : ndarray [num_samples, dim_output, dim_v]
        The coefficient matrix of the linear system.
    '''
    def __init__(self, i_v=0, i_y=0, ratio_e=1.0,
                dim_input=2, dim_output=1, dim_noise=1, dim_v=2, seed=None,
                x_min=0.0, x_max=1.0, e_min=0.0, e_max=1.0):

        self.name = 'ToyFn_ParamLinHS'

        self.i_v = i_v
        self.i_y = i_y
        self.ratio_e = ratio_e

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_noise = dim_noise
        self.dim_v = dim_v
        self.x_min = x_min
        self.x_max = x_max
        self.e_min = e_min
        self.e_max = e_max

        if seed is not None:
            np.random.seed(seed)

    def func_v(self, x: np.ndarray, e: np.ndarray) -> np.ndarray:
        '''
        Intermediate variables ``v(x, e)``, shape [num_samples, dim_v].

        Raises
        ------
        ValueError
            If ``i_v`` is not a supported index.
        '''
        v = np.zeros((x.shape[0], self.dim_v))

        if self.i_v == 0:
            #* v = v(x)
            v[:,0] = x[:,0] + x[:,1] + 0.1
            v[:,1] = x[:,0]*x[:,1] + x[:,1]**2

        elif self.i_v == 1:
            #* v = v(x)
            v[:,0] = np.sin(x[:,0]*np.pi*2) + x[:,1]**2
            v[:,1] = x[:,0]*x[:,1] + x[:,1]

        elif self.i_v == 2:
            #* v = v(x)
            v[:,0] = 0.1*x[:,0]**2 + 0.5*x[:,1] + 1.0
            v[:,1] = 0.1*(x[:,0]*x[:,1] + x[:,1]**2)

        elif self.i_v == 3:
            #* v = v(x, e)
            v[:,0] = x[:,0] + x[:,1]*self.ratio_e*e[:,0] + 0.1
            v[:,1] = x[:,0]*x[:,1] + x[:,1]**2 + self.ratio_e*e[:,0]

        else:
            raise ValueError('The sub-system index is not supported.')

        return v

    def matrix_b(self, x: np.ndarray, e: np.ndarray) -> np.ndarray:
        '''
        Coefficient matrix ``B(x, e)``, shape [num_samples, dim_output, dim_v].

        Raises
        ------
        ValueError
            If ``i_y`` is not a supported index.
        '''
        mB = np.zeros((x.shape[0], self.dim_output, self.dim_v))

        if self.i_y == 0:
            #* dB/de = 0 ===> V_a = 0.0
            mB[:,0,0] = x[:,0] + 1.0
            mB[:,0,1] = x[:,0]*x[:,1]**2

        elif self.i_y == 1:

            mB[:,0,0] = x[:,0]*x[:,1] + x[:,1]*self.ratio_e*e[:,0] + np.sin(np.pi*(x[:,0]+x[:,1])) + self.ratio_e*e[:,0]
            mB[:,0,1] = x[:,0] + x[:,1]**2 + 5*self.ratio_e*e[:,0]

        elif self.i_y == 2:

            mB[:,0,0] = 10*x[:,0]*self.ratio_e*e[:,0] + x[:,1] + 0.1
            mB[:,0,1] = x[:,0]*x[:,1] + np.cos(np.pi*(x[:,0]**2 + x[:,1])) + self.ratio_e*e[:,0]

        elif self.i_y == 3:

            mB[:,0,0] = 29*x[:,0]*x[:,1] + 600*x[:,1]*self.ratio_e*e[:,0] + 30*np.sin(np.pi*(x[:,0]+x[:,1])) + 100*self.ratio_e*e[:,0]
            # NOTE(review): `600**x` (power) differs from the `600*x` used in
            # the line above; possibly a typo. Kept as-is so that existing
            # experiments stay reproducible. Confirm intent.
            mB[:,0,1] = 600**x[:,1] + 1E3*np.cos(np.pi*(x[:,0]**2 + x[:,1])) + 50*self.ratio_e*e[:,0]

        else:
            raise ValueError('The output function index is not supported.')

        return mB

    def func_y_of_v(self, x: np.ndarray, v: np.ndarray, e: np.ndarray) -> np.ndarray:
        '''
        Get the output data ``y = B(x, e) @ v`` for given intermediate variables.

        Rows are processed independently, so ``x``, ``v`` and ``e`` only need
        matching first dimensions.

        Returns
        -------
        y : ndarray [num_samples, dim_output]
        '''
        b = self.matrix_b(x, e)
        y = np.matmul(b, v[..., None])[:,:,0]

        return y

    def func_y(self, x: np.ndarray, e: np.ndarray) -> np.ndarray:
        '''
        Get the output data ``y = B(x, e) @ v(x, e)``.

        Returns
        -------
        y : ndarray [num_samples, dim_output]
        '''
        return self.func_y_of_v(x, self.func_v(x, e), e)

    def get_data(self, num_samples : int) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Get the data for the simple function.

        Draws from NumPy's global RNG and stores ``X``, ``E``, ``V``, ``B``
        and ``Y`` on the instance.

        Parameters
        ----------
        num_samples : int
            Number of samples in the dataset.

        Returns
        -------
        X : ndarray [num_samples, dim_input]
            Input data.

        Y : ndarray [num_samples, dim_output]
            Output data.
        '''
        self.X = np.random.uniform(self.x_min, self.x_max, (num_samples, self.dim_input))

        # Gaussian noise centred in [e_min, e_max] with std = range / 4.
        loc = 0.5*(self.e_min + self.e_max)
        scale = 0.25*(self.e_max - self.e_min)

        self.E = np.random.normal(loc, scale, (num_samples, self.dim_noise))

        self.V = self.func_v(self.X, self.E)
        self.B = self.matrix_b(self.X, self.E)

        self.Y = np.matmul(self.B, self.V[..., None])[:,:,0]

        return self.X, self.Y

    def cal_total_variance(self, x: np.ndarray, e: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Calculate the total variance in the overall system.

        - V_t = Var_{x~p(x), e~p(e)}[y] = V_a + Var_{x~p(x)}[E_{e~p(e)}[y|x,e]]
        - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]

        Parameters
        ----------
        x : ndarray [num_samples, dim_input]
            Input data.

        e : ndarray [num_samples, dim_noise]
            Noise data.

        Returns
        -------
        V_t : ndarray [dim_output]
            Total variance.

        y : ndarray [num_samples, dim_output]
            Output data.
        '''
        y = self.func_y(x, e)

        return np.var(y, axis=0), y

    def cal_average_aleatoric_uncertainty(self, x: np.ndarray, e: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Calculate the average aleatoric uncertainty in the overall system.

        - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]

        Estimated by grouping in x-space. For every sample, the variance of y
        over its k nearest neighbours (Euclidean, first ``dim_input`` columns
        of ``x``) is computed, and V_a is the mean of those local variances.
        k = max(5, int(0.01 * num_samples)), capped at num_samples.

        Parameters
        ----------
        x : ndarray [num_samples, dim_input]
            Input data.

        e : ndarray [num_samples, dim_noise]
            Noise data.

        Returns
        -------
        V_a : ndarray [dim_output]
            Average aleatoric uncertainty.

        y : ndarray [num_samples, dim_output]
            Output data.

        Raises
        ------
        ValueError
            If ``x`` contains no samples.
        '''
        # Local import: scipy is only needed by this estimator.
        from scipy.spatial import cKDTree

        num_samples = x.shape[0]
        if num_samples == 0:
            raise ValueError('At least one sample is required to estimate V_a.')

        y = self.func_y(x, e)

        # At least 5 neighbours (1% of the data for large sets), but never
        # more than there are samples.
        n_neighbors = min(max(5, int(0.01*num_samples)), num_samples)

        x_design = x[:, :self.dim_input]
        # The query points are in the tree, so every neighbourhood contains a
        # point at distance 0 (normally the sample itself).
        _, indices = cKDTree(x_design).query(x_design, k=n_neighbors)
        # query() drops the neighbour axis when k == 1; restore it.
        indices = np.asarray(indices).reshape(num_samples, n_neighbors)

        # y[indices]: [num_samples, n_neighbors, dim_output]
        local_variances = np.var(y[indices], axis=1)
        V_a = np.mean(local_variances, axis=0)

        return V_a, y

    def cal_sub_system_average_variance(self, x: np.ndarray, e: np.ndarray) \
                    -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        '''
        Calculate the average variance in the sub-system.

        - V_s = Var_{x~p(x), v~p(v), e~p(e)}[y|x,v,e]

        ``y_marginal`` keeps each ``x`` but pairs it with a ``v`` and an ``e``
        drawn independently (with replacement, via the global NumPy RNG) from
        the whole sample.

        Parameters
        ----------
        x : ndarray [num_samples, dim_input]
            Input data.

        e : ndarray [num_samples, dim_noise]
            Noise data.

        Returns
        -------
        V_s : ndarray [dim_output]
            Average variance in the sub-system.

        V_t : ndarray [dim_output]
            Total variance in the overall system.

        y_marginal : ndarray [num_samples, dim_output]
            Output data calculated from randomly sampled x, v, e from their marginal distributions.

        y : ndarray [num_samples, dim_output]
            Output data.

        Raises
        ------
        ValueError
            If ``x`` contains no samples.
        '''
        num_samples = x.shape[0]
        if num_samples == 0:
            raise ValueError('At least one sample is required to estimate V_s.')

        v = self.func_v(x, e)
        y = self.func_y_of_v(x, v, e)

        # np.random.randint's upper bound is exclusive: use num_samples so
        # that every index (including the last one) can be drawn.
        ii = np.random.randint(0, num_samples, num_samples)
        jj = np.random.randint(0, num_samples, num_samples)

        # func_y_of_v works row-wise, so the resampling is vectorized.
        y_marginal = self.func_y_of_v(x, v[ii], e[jj])

        V_s = np.var(y_marginal, axis=0)
        V_t = np.var(y, axis=0)

        return V_s, V_t, y_marginal, y


if __name__ == '__main__':

    num_samples = 128
    seed = 2
    scatter_size = 20

    print('Plotting the toy functions...')

    # One entry per subplot: (dataset class, axis limits, kwargs for
    # plot_ground_truth_1d). Panels are laid out on a 2 x 3 grid.
    panels = [
        (ToyFn_DUE1D,         [-1, 1, -2, 1],         dict(x_min=-1, x_max=1)),
        (ToyFn_Lin_Noise_Lin, [-0.1, 1.1, -0.1, 1.1], dict()),
        (ToyFn_Lin_Noise_Sin, [-0.1, 1.1, -0.1, 1.1], dict()),
        (ToyFn_Sin_Noise_Lin, [-0.1, 1.1, -0.6, 0.6], dict()),
        (ToyFn_Sin_Noise_Sin, [-0.1, 1.1, -0.6, 0.6], dict()),
    ]

    plt.figure(figsize=(15, 10))

    for panel_index, (dataset_cls, limits, gt_kwargs) in enumerate(panels, start=1):
        dataset = dataset_cls(num_samples, seed=seed)
        plt.subplot(2, 3, panel_index)
        plt.scatter(dataset.X, dataset.Y, s=scatter_size)
        plt.title(dataset.name)
        plt.axis(limits)
        dataset.plot_ground_truth_1d(**gt_kwargs)

    plt.show()
    plt.close()
    
    