# synthetic/linear.py
from .noise_scale import init_noise_dist
from .utils import sample_recursive_scm
from .abstract import MechanismModel


class LinearAdditive(MechanismModel):
    """
    Linear mechanism with additive noise.

    For every node ``j`` a weight vector ``theta_j`` and a bias ``b_j`` are
    sampled, and the node's mechanism is

        f_j(x, is_parent, z) = x @ (theta_j * is_parent) + b_j + z

    i.e. a linear combination of ``x`` with weights masked elementwise by
    ``is_parent``, plus a bias and the additive noise term ``z``. Drawing the
    actual dataset is delegated to ``sample_recursive_scm``.

    Args:
        param: sampler called as ``param(rng, shape=(n_vars,))`` to draw the
            weight vector of each node.
        bias: sampler called as ``bias(rng, shape=(1,))`` to draw the bias of
            each node.
        noise: noise distribution spec, forwarded to ``init_noise_dist`` as
            ``dist``.
        noise_scale_constant: forwarded unchanged to ``init_noise_dist``.
        noise_scale: forwarded unchanged to ``init_noise_dist``.
        noise_scale_heteroscedastic: forwarded unchanged to
            ``init_noise_dist``.
        n_interv_vars: forwarded to ``sample_recursive_scm``; when non-zero,
            ``interv_dist`` must also be provided.
        interv_dist: forwarded to ``sample_recursive_scm``.
        bias_with_ancestor: forwarded to ``sample_recursive_scm``.
    """
    def __init__(self,
                 param,
                 bias,
                 noise,
                 noise_scale_constant=None,
                 noise_scale=None,
                 noise_scale_heteroscedastic=None,
                 n_interv_vars=0,
                 interv_dist=None,
                 bias_with_ancestor=None):

        self.param = param
        self.bias = bias
        self.bias_with_ancestor = bias_with_ancestor
        self.noise = noise
        self.noise_scale = noise_scale
        self.noise_scale_constant = noise_scale_constant
        self.noise_scale_heteroscedastic = noise_scale_heteroscedastic
        self.n_interv_vars = n_interv_vars
        self.interv_dist = interv_dist

    def __call__(self, rng, g, n_observations_obs, n_observations_int, seed=None):
        """
        Sample per-node mechanisms and noise for graph ``g``, then draw data.

        Args:
            rng: random source passed to every sampler (presumably a numpy
                ``Generator`` -- TODO confirm against callers).
            g: graph matrix. ``g.shape[-1]`` is taken as the number of
                variables, and ``g[:, j].sum()`` is used as the noise
                dimension of node ``j`` (assumes ``g[i, j]`` marks an edge
                i -> j, so this is the parent count -- confirm).
            n_observations_obs: forwarded to ``sample_recursive_scm``.
            n_observations_int: forwarded to ``sample_recursive_scm``.
            seed: forwarded to ``init_noise_dist``.

        Returns:
            Whatever ``sample_recursive_scm`` returns.

        Raises:
            ValueError: if ``n_interv_vars`` is non-zero but no
                ``interv_dist`` was configured.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``. That would let a missing ``interv_dist`` reach
        # ``sample_recursive_scm`` unchecked.
        if self.interv_dist is None and self.n_interv_vars != 0:
            raise ValueError(
                "interv_dist must be provided when n_interv_vars != 0 "
                f"(got n_interv_vars={self.n_interv_vars!r})"
            )

        n_vars = g.shape[-1]

        # Construct the mechanism of each node. NOTE: every weight/bias is
        # drawn before any noise distribution is built; keep this order so
        # the rng stream, and therefore the generated data, is unchanged.
        f = []
        for _ in range(n_vars):
            w = self.param(rng, shape=(n_vars,))
            b = self.bias(rng, shape=(1,))

            # Bind w/b as default arguments. A plain closure over the loop
            # variables would late-bind, and every node would see the last draw.
            f.append(
                lambda x, is_parent, z, theta=w, bias=b:
                    (x @ (theta * is_parent)) + bias + z
            )

        # Construct the noise distribution of each node; its ``dim`` is the
        # column sum of g for that node.
        nse = [
            init_noise_dist(rng=rng,
                            seed=seed,
                            dim=int(g[:, j].sum().item()),
                            dist=self.noise,
                            noise_scale_constant=self.noise_scale_constant,
                            noise_scale=self.noise_scale,
                            noise_scale_heteroscedastic=self.noise_scale_heteroscedastic)
            for j in range(n_vars)
        ]

        # Sample recursively over g given the mechanisms and the noise
        # distributions.
        data = sample_recursive_scm(
            rng=rng,
            n_observations_obs=n_observations_obs,
            n_observations_int=n_observations_int,
            g=g,
            f=f,
            nse=nse,
            interv_dist=self.interv_dist,
            n_interv_vars=self.n_interv_vars,
            bias_with_ancestor=self.bias_with_ancestor,
        )

        return data