# synthetic/noise_scale.py
import numpy as np
from .abstract import NoiseModel
from .utils import draw_rff_params


class SimpleNoise(NoiseModel):
    """
    Homoscedastic noise: one independent draw from `dist` per observation,
    multiplied by the constant `scale` unless `scale` is None (in which case
    the raw draws are returned unchanged).
    """
    def __init__(self, dist, scale):
        self.dist = dist
        self.scale = scale

    def __call__(self, rng, x, is_parent):
        # One draw per row of x; `is_parent` is irrelevant for constant-scale noise.
        draws = self.dist(rng, shape=(x.shape[0],))
        return draws if self.scale is None else self.scale * draws


class HeteroscedasticRFFNoise(NoiseModel):
    """
    Heteroscedastic noise whose scale depends on the parent variables.

    A random Fourier feature (RFF) function f(.) is drawn once at construction
    (parameters from `draw_rff_params`). When called, each observation's noise
    is a draw from `dist` multiplied by sqrt(softplus(f(x_parents))), where
    softplus(z) = log(1 + exp(z)).
    """
    def __init__(self, dist, *, rng, d, length_scale, output_scale, n_rff=100):
        self.dist = dist
        # Dict with (at least) the keys "omega", "b", "c", "w" read in __call__.
        # NOTE(review): d presumably must equal the number of parents passed at
        # call time, since omega's first axis is contracted with x_parents' last axis.
        self.param = draw_rff_params(rng=rng, d=d, length_scale=length_scale,
                                     output_scale=output_scale, n_rff=n_rff)
        self.n_rff = n_rff

    def __call__(self, rng, x, is_parent):
        # Compute the rff function f(x) = sqrt(2) * c * w'phi(x) / sqrt(n_rff),
        # with phi(x) = cos(omega'x + b), using only the parent columns of x.
        x_parents = x[:, np.where(is_parent)[0]]
        phi = np.cos(np.einsum('db,...d->...b', self.param["omega"], x_parents) + self.param["b"])
        f_x = np.sqrt(2.0) * self.param["c"] * np.einsum('b,...b->...', self.param["w"], phi) / np.sqrt(self.n_rff)

        # Sample noise with scale^2 = softplus(f(x)) = log(1 + exp(f(x))).
        # logaddexp(0, f) evaluates this without overflowing exp() for large f
        # and without losing precision for very negative f.
        scale = np.sqrt(np.logaddexp(0.0, f_x))
        return scale * self.dist(rng, shape=scale.shape)


def init_noise_dist(*, rng, seed, dim, dist, noise_scale_constant, noise_scale, noise_scale_heteroscedastic):
    """
    Initialize noise distribution.

    Options are checked in this order:

    * ``noise_scale_constant is True`` -> unscaled ``SimpleNoise``; all
      other scale options are ignored.
    * ``noise_scale`` given -> ``SimpleNoise`` with the constant scale
      returned by ``noise_scale(rng)``.
    * ``noise_scale_heteroscedastic`` given -> ``HeteroscedasticRFFNoise``
      built from its ``"length_scale"`` and ``"output_scale"`` entries
      (with ``n_rff=100``).

    Args:
        rng: random generator, passed to ``noise_scale`` and used to draw the
            RFF parameters of the heteroscedastic model.
        seed: not used by this function.
        dim: dimension ``d`` of the RFF scale function (cast to ``int``).
        dist: noise distribution, called as ``dist(rng, shape=...)``.
        noise_scale_constant: only the exact value ``True`` selects unscaled noise.
        noise_scale: callable ``noise_scale(rng)`` returning a constant scale, or None.
        noise_scale_heteroscedastic: mapping with an ``"rff"`` key plus
            ``"length_scale"`` and ``"output_scale"``, or None.

    Raises:
        ValueError: if both ``noise_scale`` and ``noise_scale_heteroscedastic``
            are given, or the heteroscedastic config has no ``"rff"`` key.
        KeyError: if no noise option is given at all.
    """
    if noise_scale_constant is True:
        return SimpleNoise(dist, None)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if noise_scale is not None and noise_scale_heteroscedastic is not None:
        raise ValueError("`noise_scale` and `noise_scale_heteroscedastic` are mutually exclusive")

    if noise_scale is not None:
        scale = noise_scale(rng)
        return SimpleNoise(dist, scale)

    if noise_scale_heteroscedastic is not None:
        if "rff" not in noise_scale_heteroscedastic:
            raise ValueError("`noise_scale_heteroscedastic` must contain an 'rff' entry")
        return HeteroscedasticRFFNoise(dist, rng=rng, d=int(dim),
                                       length_scale=noise_scale_heteroscedastic["length_scale"],
                                       output_scale=noise_scale_heteroscedastic["output_scale"],
                                       n_rff=100)

    raise KeyError("neither `noise_scale_constant`, `noise_scale` nor `noise_scale_heteroscedastic` are given")


def init_noise_dist_torch(*, rng, seed, dim, dist, noise_scale_constant, noise_scale, noise_scale_heteroscedastic):
    """
    Initialize noise distribution for torch.

    Same selection logic as the numpy variant, except that the constant scale
    is obtained from ``noise_scale(seed)`` rather than ``noise_scale(rng)``.
    Options are checked in this order:

    * ``noise_scale_constant is True`` -> unscaled ``SimpleNoise``; all
      other scale options are ignored.
    * ``noise_scale`` given -> ``SimpleNoise`` with the constant scale
      returned by ``noise_scale(seed)``.
    * ``noise_scale_heteroscedastic`` given -> ``HeteroscedasticRFFNoise``
      built from its ``"length_scale"`` and ``"output_scale"`` entries
      (with ``n_rff=100``).

    Args:
        rng: random generator used to draw the RFF parameters of the
            heteroscedastic model.
        seed: passed to ``noise_scale``.
        dim: dimension ``d`` of the RFF scale function (cast to ``int``).
        dist: noise distribution, called as ``dist(rng, shape=...)``.
        noise_scale_constant: only the exact value ``True`` selects unscaled noise.
        noise_scale: callable ``noise_scale(seed)`` returning a constant scale, or None.
        noise_scale_heteroscedastic: mapping with an ``"rff"`` key plus
            ``"length_scale"`` and ``"output_scale"``, or None.

    Raises:
        ValueError: if both ``noise_scale`` and ``noise_scale_heteroscedastic``
            are given, or the heteroscedastic config has no ``"rff"`` key.
        KeyError: if no noise option is given at all.
    """
    if noise_scale_constant is True:
        return SimpleNoise(dist, None)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if noise_scale is not None and noise_scale_heteroscedastic is not None:
        raise ValueError("`noise_scale` and `noise_scale_heteroscedastic` are mutually exclusive")

    if noise_scale is not None:
        scale = noise_scale(seed)
        return SimpleNoise(dist, scale)

    if noise_scale_heteroscedastic is not None:
        if "rff" not in noise_scale_heteroscedastic:
            raise ValueError("`noise_scale_heteroscedastic` must contain an 'rff' entry")
        return HeteroscedasticRFFNoise(dist, rng=rng, d=int(dim),
                                       length_scale=noise_scale_heteroscedastic["length_scale"],
                                       output_scale=noise_scale_heteroscedastic["output_scale"],
                                       n_rff=100)

    raise KeyError("neither `noise_scale_constant`, `noise_scale` nor `noise_scale_heteroscedastic` are given")