import torch
import torch.nn.functional as F

import copy

##################################################################################################################
### MECHANISMS
##################################################################################################################

class Mechanism:
    """ Base class for a causal mechanism f(x, u) that maps a node's parents'
    values and its noise values to the node's value.

    Concrete mechanisms override ``__call__``; this base holds no parameters.
    """

    def __init__(self):
        """ Create a parameter-free mechanism (nothing to initialise). """

    def __call__(self, x, u):
        """ Evaluate the causal mechanism f(x, u).

        Arguments:
        ----------
        x : torch.tensor, shape (n_samples, n_parents)
            Values of the node's parents.
        u : torch.tensor, shape (n_samples, 1)
            Values of the node's noise.

        Raises:
        -------
        NotImplementedError
            Always; subclasses must supply the implementation.
        """
        raise NotImplementedError


class LinearMechanism(Mechanism):
    """ Linear causal mechanism f(x, u) = <weight_parents, x> + weight_noise*u + bias. """

    def __init__(self, weight_parents=None, weight_noise=None, bias=None):
        """ Linear causal mechanism, implementing the function:

                f(x, u) = <weight_parents, x> + weight_noise*u + bias

        Arguments:
        ----------
        weight_parents : torch.tensor, shape (1, n_parents), optional
            Linear combination weight for node's parents' values.
            Defaults to a fresh ``torch.ones((1, 2))`` (two parents, unit weights).
        weight_noise : torch.tensor, shape (1, 1), optional
            Linear combination weight for node's noise values.
            Defaults to a fresh ``torch.ones((1, 1))``.
        bias : torch.tensor, shape (1,), optional
            Linear combination bias term. Defaults to a fresh ``torch.zeros(1)``.
        """
        super().__init__()
        # Defaults are built per instance: tensors written directly in the
        # signature would be evaluated once and shared (mutably) by every
        # instance that relies on them.
        self.weight_parents = torch.ones((1, 2)) if weight_parents is None else weight_parents
        self.weight_noise = torch.ones((1, 1)) if weight_noise is None else weight_noise

        self.bias = torch.zeros(1) if bias is None else bias

    def __call__(self, x, u):
        """ Evaluate the mechanism.

        Arguments:
        ----------
        x : torch.tensor, shape (n_samples, n_parents)
            Node's parents' values.
        u : torch.tensor, shape (n_samples, 1)
            Node's noise values.

        Returns:
        --------
        torch.tensor, shape (n_samples, 1)
        """
        # The bias is added once, via the noise term.
        out = F.linear(x, self.weight_parents, bias=None) + F.linear(u, self.weight_noise, bias=self.bias)
        return out

class QuadraticMechanism(Mechanism):
    """ Quadratic causal mechanism f(x, u) = <weight_parents, x**2> + weight_noise*u + bias. """

    def __init__(self, weight_parents=None, weight_noise=None, bias=None):
        """ Quadratic causal mechanism, implementing the function:

                f(x, u) = <weight_parents, x**2> + weight_noise*u + bias

        The parents enter squared (element-wise); the noise enters linearly.

        Arguments:
        ----------
        weight_parents : torch.tensor, shape (1, n_parents), optional
            Linear combination weight for node's squared parents' values.
            Defaults to a fresh ``torch.ones((1, 2))`` (two parents, unit weights).
        weight_noise : torch.tensor, shape (1, 1), optional
            Linear combination weight for node's noise values.
            Defaults to a fresh ``torch.ones((1, 1))``.
        bias : torch.tensor, shape (1,), optional
            Linear combination bias term. Defaults to a fresh ``torch.zeros(1)``.
        """
        super().__init__()
        # Defaults are built per instance: tensors written directly in the
        # signature would be evaluated once and shared (mutably) by every
        # instance that relies on them.
        self.weight_parents = torch.ones((1, 2)) if weight_parents is None else weight_parents
        self.weight_noise = torch.ones((1, 1)) if weight_noise is None else weight_noise

        self.bias = torch.zeros(1) if bias is None else bias

    def __call__(self, x, u):
        """ Evaluate the mechanism.

        Arguments:
        ----------
        x : torch.tensor, shape (n_samples, n_parents)
            Node's parents' values.
        u : torch.tensor, shape (n_samples, 1)
            Node's noise values.

        Returns:
        --------
        torch.tensor, shape (n_samples, 1)
        """
        # Square parents element-wise before the linear combination; bias is
        # added once, via the noise term.
        out = F.linear(x**2, self.weight_parents, bias=None) + F.linear(u, self.weight_noise, bias=self.bias)
        return out

##################################################################################################################
### STRUCTURAL CAUSAL MODELS
##################################################################################################################

class SCM:
    """ Structural causal model: a DAG with one mechanism and one noise
    distribution per node.

    Arguments:
    ----------
    dag : DAG-like object
        Must expose ``nodes`` (integer node indices, used as column indices),
        ``get_parents(node)`` and ``rewire(index, parents_new)``.
        NOTE(review): ``sample`` evaluates nodes in ``dag.nodes`` order, which
        is assumed to be a topological order -- confirm against the DAG class.
    mechanisms : list of Mechanism, length n_nodes
        Callables ``f(x, u)``. Paired positionally with ``dag.nodes`` in
        ``sample`` but indexed by node in ``intervene``; the two conventions
        agree when ``dag.nodes`` is ``0, 1, ..., n_nodes - 1``.
    noise_distrs : list of torch.distributions.Distribution, length n_nodes
        Entry ``i`` fills noise column ``i``.

    Raises:
    -------
    ValueError
        If ``mechanisms`` or ``noise_distrs`` does not have one entry per node.
    """

    def __init__(self, dag, mechanisms, noise_distrs):

        self.dag = dag
        self.noise_distrs = noise_distrs
        self.mechanisms = mechanisms
        self.n_nodes = len(self.dag.nodes)

        # zip/enumerate in `sample` stop early without complaint, so a short
        # list would silently leave output columns unset -- reject it here.
        if len(self.mechanisms) != self.n_nodes:
            raise ValueError(
                f"Expected {self.n_nodes} mechanisms (one per node), "
                f"got {len(self.mechanisms)}."
            )
        if len(self.noise_distrs) != self.n_nodes:
            raise ValueError(
                f"Expected {self.n_nodes} noise distributions (one per node), "
                f"got {len(self.noise_distrs)}."
            )

    def sample(self, n_samples):
        """ Draw joint samples of all nodes.

        Noise is drawn for every node first; then each node is computed from
        its parents' already-computed values and its own noise, visiting
        nodes in ``dag.nodes`` order.

        Arguments:
        ----------
        n_samples : int
            Number of samples to draw.

        Returns:
        --------
        torch.tensor, shape (n_samples, n_nodes)
            Node values; column ``node`` holds node ``node``.
        """
        # NaN rather than torch.empty: if a parent is read before it has been
        # computed, the result propagates NaN instead of arbitrary memory.
        x = torch.full((n_samples, self.n_nodes), float("nan"))  # node values
        u = torch.empty((n_samples, self.n_nodes))  # noise values (fully filled below)

        for i_distr, distr in enumerate(self.noise_distrs):
            u[:, [i_distr]] = distr.sample(sample_shape=(n_samples, 1))

        for node, mechanism in zip(self.dag.nodes, self.mechanisms):
            x[:, [node]] = mechanism(x[:, self.dag.get_parents(node)], u[:, [node]])

        return x

    def intervene(self, index, parents_new, mechanism_new, noise_distr_new=None):
        """ Return a new SCM after an intervention on one node.

        Arguments:
        ----------
        index : int
            Node to intervene on.
        parents_new :
            New parents of node ``index``, forwarded to ``self.dag.rewire``.
        mechanism_new : Mechanism
            Replacement mechanism for node ``index``.
        noise_distr_new : torch.distributions.Distribution, optional
            Replacement noise distribution for node ``index``; if None the
            current one is kept.

        Returns:
        --------
        SCM
            Always a plain ``SCM``, even when ``self`` is a subclass.
        """
        # Rewired DAG (presumably kept topologically sorted by `rewire`).
        dag = self.dag.rewire(index, parents_new)
        # Deep copies so this model's mechanism and noise lists are not modified
        # and their objects are not shared with the new model.
        mechanisms = copy.deepcopy(self.mechanisms)
        mechanisms[index] = mechanism_new
        noise_distrs = copy.deepcopy(self.noise_distrs)
        if noise_distr_new is not None:
            noise_distrs[index] = noise_distr_new

        return SCM(dag, mechanisms, noise_distrs)

class LinearGaussianSCM(SCM):
    """ Linear Gaussian SCM with constant noise variances.
    """

    def __init__(self, dag):

        noise_distrs = []
        mechanisms = []
        for node in dag.nodes:
            noise_distrs.append(torch.distributions.normal.Normal(0., 1.))
            
            machanism = 
            
        
            

        
        