import numpy as np
from typing import Union
from causally.scm.causal_mechanism import PredictionModel
from causally.scm.scm import BaseStructuralCausalModel
from causally.scm.noise import RandomNoiseDistribution, Distribution
from causally.scm.context import SCMContext
from causally.graph.random_graph import GraphGenerator

class NonAdditiveNoiseModel(BaseStructuralCausalModel):
    """Class for data generation from a nonlinear model with non-additive noise.

    The noise of a node is not added to the output of the causal mechanism.
    Instead, it is appended as an extra input column next to the values of the
    node's parents, and the resulting matrix is passed to
    ``causal_mechanism.predict``. The effect is therefore of the form
    ``f(parents, noise)``.

    Parameters
    ----------
    num_samples: int
        Number of samples in the dataset.
    graph_generator: GraphGenerator
        Random graph generator implementing the ``get_random_graph`` method.
    noise_generator: RandomNoiseDistribution or Distribution
        Sampler of the noise random variables. It must be an instance of
        ``causally.scm.noise.RandomNoiseDistribution`` or of a class
        inheriting from ``causally.scm.noise.Distribution``, implementing
        the ``sample`` method.
    causal_mechanism: PredictionModel
        Object for the generation of the nonlinear causal mechanism.
        It must be an instance of a class inheriting from
        ``causally.scm.causal_mechanism.PredictionModel``, implementing
        the ``predict`` method.
    scm_context: SCMContext, default None
        ``SCMContext`` object specifying the modeling assumptions of the SCM.
        If ``None`` this is equivalent to an ``SCMContext`` object with no
        assumption specified.
    seed: int, default None
        Seed for reproducibility. If ``None``, then the random seed is not set.
    """

    def __init__(
        self,
        num_samples: int,
        graph_generator: GraphGenerator,
        noise_generator: Union[RandomNoiseDistribution, Distribution],
        causal_mechanism: PredictionModel,
        scm_context: Union[SCMContext, None] = None,
        seed: Union[int, None] = None,
    ):
        # Sampling configuration (samples, graph, noise, context, seed) is
        # handled by the base class; only the mechanism is stored here.
        super().__init__(
            num_samples, graph_generator, noise_generator, scm_context, seed
        )
        self.causal_mechanism = causal_mechanism

    def _sample_mechanism(
        self, parents: np.ndarray, child_noise: np.ndarray
    ) -> np.ndarray:
        """Generate the values of a child node from its parents and its noise.

        Parameters
        ----------
        parents: np.ndarray
            Values of the parent nodes with one row per sample, i.e. of shape
            ``(num_samples, num_parents)``. A 1-D array of shape
            ``(num_samples,)`` is treated as a single parent column.
        child_noise: np.ndarray
            Noise terms of the child node, of shape ``(num_samples,)`` or
            ``(num_samples, 1)``.

        Returns
        -------
        effect: np.ndarray
            Output of ``causal_mechanism.predict`` evaluated on the parents'
            values with the child's noise appended as the last input column.
        """
        parents = np.asarray(parents)
        child_noise = np.asarray(child_noise)

        # Promote 1-D inputs to column vectors so that both operands of the
        # concatenation are 2-D. Already 2-D inputs are used unchanged (a
        # ``(num_samples, 1)`` noise array would otherwise become 3-D).
        if parents.ndim == 1:
            parents = parents[:, None]
        if child_noise.ndim == 1:
            child_noise = child_noise[:, None]

        # The noise enters the mechanism as an extra feature rather than
        # being added to its output: this is what makes the noise non-additive.
        mechanism_input = np.concatenate([parents, child_noise], axis=1)
        effect = self.causal_mechanism.predict(mechanism_input)
        return effect