from typing import Any, Dict, Iterable, List, Tuple
import networkx as nx
import numpy as np
from numba import njit
from scipy.stats import beta, gamma, bernoulli
from scipy.special import expit
import pdb


def init_params(graph, n_hidden=0, max_corr=1.0, min_corr=0.25):
    """
    Sample random SEM parameters for the generated graph.
        Graph should include nodes, modules and intervention nodes.

    Args:
        graph: directed acyclic networkx graph. Its adjacency matrix is used
            as a mask so that column ``j`` of the weights is non-zero only in
            the rows of ``j``'s parents.
            NOTE(review): the adjacency matrix follows ``graph.nodes``
            ordering while the simulators index columns by node id, and edge
            ``weight`` attributes (if any) scale the mask -- this presumably
            expects an unweighted graph with nodes ``0..n-1`` inserted in
            order; confirm against the graph generator.
        n_hidden (int): 0 selects the linear model; a positive value selects
            a one-hidden-layer MLP per node with that many hidden units.
        max_corr (float): upper bound on the magnitude of a linear weight.
            Unused when ``n_hidden > 0``.
        min_corr (float): lower bound on the magnitude of a linear weight.
            Unused when ``n_hidden > 0``.

    Returns:
        ``(causal_order, weights)`` where ``causal_order`` is an array holding
        a topological ordering of the nodes and ``weights`` is either a
        ``(nodes, nodes)`` array (linear model) or a tuple
        ``(weights_1, weights_2)`` of shapes ``(nodes, nodes, n_hidden)`` and
        ``(nodes, n_hidden)`` (MLP model).
    """
    causal_order = np.array(list(nx.topological_sort(graph)))
    nodes = len(causal_order)
    # The mask does not depend on the branch or on the hidden unit, so build
    # it once instead of once per hidden unit.
    adjacency = nx.adjacency_matrix(graph).toarray()
    if n_hidden == 0:
        # linear model: magnitude in [min_corr, max_corr) with a random sign,
        # masked with the adjacency so a node depends only on its parents
        weights = np.random.uniform(min_corr, max_corr, size=(nodes, nodes))
        weights *= 2 * np.random.binomial(1, 0.5, size=(nodes, nodes)) - 1
        weights *= adjacency
    else:
        # non-linear model: first layer is masked per hidden unit (the mask is
        # broadcast along the hidden axis), second layer mixes hidden units
        weights_1 = np.random.normal(size=(nodes, nodes, n_hidden))
        weights_1 *= adjacency[:, :, np.newaxis]
        weights_2 = np.random.normal(size=(nodes, n_hidden))
        weights = (weights_1, weights_2)
    return causal_order, weights


def _linear_sem(
    data: np.ndarray,
    weights: np.ndarray,
    node: int,
    dropout: float,
    noise_level: float,
    dependent_dropout: bool,
    uniform_noise: bool,
    is_obs: bool
):
    """Evaluate the linear structural equation of one node for every sample.

    Args:
        data: ``(n_samples, n_nodes)`` array of the values generated so far.
        weights: weight matrix; column ``node`` holds the incoming weights.
        node: index of the node being generated.
        dropout: probability of zeroing an input; ``0`` disables dropout.
        noise_level: standard deviation of the Gaussian noise; for uniform
            noise the half-width of the interval is ``1.73 * noise_level``.
        dependent_dropout: if true, a single keep/drop draw per sample is
            shared by all inputs of that sample; otherwise every
            (sample, input) entry is drawn independently.
        uniform_noise: draw uniform instead of Gaussian noise.
        is_obs: additive noise is applied only when this is true.

    Returns:
        Array of shape ``(n_samples,)`` with the node's values.
    """
    n_samples, n_nodes = data.shape

    if dropout == 0:
        inputs = data
    else:
        # A one-column mask broadcasts over all inputs of a sample.
        mask_width = 1 if dependent_dropout else n_nodes
        keep = bernoulli.rvs(1 - dropout, size=(n_samples, mask_width))
        inputs = data * keep
    obs = np.dot(inputs, weights[:, node])

    if not is_obs:
        return obs

    # Add noise
    if uniform_noise:
        half_width = noise_level * 1.73
        obs += np.random.uniform(-half_width, half_width, size=(n_samples,))
    else:
        obs += np.random.normal(0, noise_level, size=(n_samples,))
    return obs


def _nonlinear_sem(
    data: np.ndarray,
    weights_1: np.ndarray,
    weights_2: np.ndarray,
    node: int,
    dropout: float,
    noise_level: float,
    dependent_dropout: bool,
    uniform_noise: bool,
    is_obs: bool
):
    """Evaluate the one-hidden-layer tanh MLP equation of one node.

    Args:
        data: ``(n_samples, n_nodes)`` array of the values generated so far.
        weights_1: first-layer weights; ``weights_1[:, node]`` is the
            ``(n_nodes, n_hidden)`` matrix feeding the node's hidden units.
        weights_2: second-layer weights; ``weights_2[node]`` combines the
            hidden units into the node value.
        node: index of the node being generated.
        dropout: probability of zeroing an input; ``0`` disables dropout.
        noise_level: standard deviation of the Gaussian noise; for uniform
            noise the half-width of the interval is ``1.73 * noise_level``.
        dependent_dropout: if true, one keep/drop draw per (sample, input) is
            shared by all hidden units; otherwise every
            (sample, input, hidden unit) entry is drawn independently.
        uniform_noise: draw uniform instead of Gaussian noise.
        is_obs: additive noise is applied only when this is true.

    Returns:
        Array of shape ``(n_samples,)`` with the node's values.
    """
    n_samples, n_nodes = data.shape
    incoming = weights_1[:, node]

    # Pre-activation of the hidden layer, shape (n_samples, n_hidden).
    if dropout == 0:
        pre_activation = np.dot(data, incoming)
    else:
        keep_prob = 1 - dropout
        if dependent_dropout:
            keep = bernoulli.rvs(keep_prob, size=(n_samples, n_nodes))
            pre_activation = np.einsum('ij,ij,jk->ik', data, keep, incoming)
        else:
            n_hidden = weights_1.shape[2]
            keep = bernoulli.rvs(keep_prob, size=(n_samples, n_nodes, n_hidden))
            pre_activation = np.einsum('ij,ijk,jk->ik', data, keep, incoming)

    hidden = np.tanh(pre_activation)
    obs = np.dot(hidden, weights_2[node])

    if not is_obs:
        return obs

    if uniform_noise:
        half_width = noise_level * 1.73
        obs += np.random.uniform(-half_width, half_width, size=(n_samples,))
    else:
        obs += np.random.normal(0, noise_level, size=(n_samples,))
    return obs


def simulate_data_linear(
    n_samples,
    weights,
    causal_order,
    targets: Iterable[int],
    intv_src,
    intv_nodes,
    obs_nodes,
    hard=False,
    alpha=1.0,
    scale=1.0,
    noise_level=0.01,
    dropout=0.0,
    dependent_dropout=True,
    uniform_noise=False
) -> np.ndarray:
    """ Simulate data from a linear model

    Nodes are generated in causal order so that every parent column is filled
    before its children are computed. Node ids are used directly as column
    indices of the returned array.

    Args:
        n_samples (int): Number of samples to generate
        weights (Any): weight matrix
        causal_order (np.ndarray): causal order of the nodes and factors
        targets (Iterable[int]): set of intervened nodes
        intv_src (Iterable[int]): set of activated intervention nodes
        intv_nodes (Iterable[int]): set of intervention nodes
        obs_nodes (Iterable[int]): nodes that receive additive noise
        hard (bool, optional): whether to use hard intervention. Activated
            intervention nodes are then set to 1 and targets are replaced by
            standard normal draws. Defaults to False.
        alpha (float, optional): Gamma distribution shape, used for activated
            intervention nodes under soft intervention. Defaults to 1.0.
        scale (float, optional): Gamma distribution scale. Defaults to 1.0.
        noise_level (float, optional): Gaussian noise std. Defaults to 0.01.
        dropout (float, optional): Dropout rate. Defaults to 0.0.
        dependent_dropout (bool, optional): If true, one keep/drop draw per
            sample is shared by all inputs; otherwise each input is dropped
            independently. Defaults to True.
        uniform_noise (bool, optional): Whether to use uniform noise. Defaults to False.

    Returns:
        np.ndarray: simulated data of shape (n_samples, n_nodes)
    """
    # Membership is tested once per node, so materialise the collections up
    # front: this keeps one-shot iterables (e.g. generators) from being
    # exhausted after the first lookup and makes each lookup O(1).
    target_set = frozenset(targets)
    src_set = frozenset(intv_src)
    intv_set = frozenset(intv_nodes)
    obs_set = frozenset(obs_nodes)

    n_nodes = causal_order.shape[0]
    data = np.zeros(shape=(n_samples, n_nodes))

    for node in causal_order:
        # each node is a function of its parents
        if node in intv_set:  # intervention nodes
            if node not in src_set:
                data[:, node] = 0.0  # inactive intervention
            elif hard:
                data[:, node] = 1.0
            else:
                data[:, node] = gamma(alpha, scale=scale).rvs(size=(n_samples,))
        elif hard and node in target_set:  # hard-intervened modules
            data[:, node] = np.random.normal(0, 1, size=(n_samples,))
        else:
            # Soft-intervened targets and unintervened nodes follow the same
            # structural equation; the intervention acts through the
            # intervention-node columns already written to `data`.
            data[:, node] = _linear_sem(
                data,
                weights,
                node,
                dropout,
                noise_level,
                dependent_dropout,
                uniform_noise,
                node in obs_set
            )
    return data


def simulate_data_nn(
    n_samples,
    weights_1,
    weights_2,
    causal_order,
    targets,
    intv_src,
    intv_nodes,
    obs_nodes,
    hard=False,
    alpha=1.0,
    scale=1.0,
    noise_level=0.01,
    dropout=0.0,
    dependent_dropout=True,
    uniform_noise=False
):
    """ Simulate data from a non-linear (one-hidden-layer MLP) model

    Nodes are generated in causal order so that every parent column is filled
    before its children are computed. Node ids are used directly as column
    indices of the returned array.

    Args:
        n_samples (int): Number of samples to generate
        weights_1 (np.ndarray): first-layer weights of the per-node MLPs
        weights_2 (np.ndarray): second-layer weights of the per-node MLPs
        causal_order (np.ndarray): causal order of the nodes and factors
        targets (Iterable[int]): set of intervened nodes
        intv_src (Iterable[int]): set of activated intervention nodes
        intv_nodes (Iterable[int]): set of intervention nodes
        obs_nodes (Iterable[int]): nodes that receive additive noise
        hard (bool, optional): whether to use hard intervention. Activated
            intervention nodes are then set to 1 and targets are replaced by
            standard normal draws. Defaults to False.
        alpha (float, optional): Gamma distribution shape, used for activated
            intervention nodes under soft intervention. Defaults to 1.0.
        scale (float, optional): Gamma distribution scale. Defaults to 1.0.
        noise_level (float, optional): Gaussian noise std. Defaults to 0.01.
        dropout (float, optional): Dropout rate. Defaults to 0.0.
        dependent_dropout (bool, optional): If true, the dropout draw for an
            input is shared by all hidden units; otherwise it is drawn per
            hidden unit. Defaults to True.
        uniform_noise (bool, optional): Whether to use uniform noise. Defaults to False.

    Returns:
        np.ndarray: simulated data of shape (n_samples, n_nodes)
    """
    # Membership is tested once per node, so materialise the collections up
    # front: this keeps one-shot iterables (e.g. generators) from being
    # exhausted after the first lookup and makes each lookup O(1).
    target_set = frozenset(targets)
    src_set = frozenset(intv_src)
    intv_set = frozenset(intv_nodes)
    obs_set = frozenset(obs_nodes)

    n_nodes = causal_order.shape[0]
    data = np.zeros(shape=(n_samples, n_nodes))

    for node in causal_order:
        # each node is a function of its parents
        if node in intv_set:  # intervention nodes
            if node not in src_set:
                data[:, node] = 0.0  # inactive intervention
            elif hard:
                data[:, node] = 1.0
            else:
                data[:, node] = gamma(alpha, scale=scale).rvs(size=(n_samples,))
        elif hard and node in target_set:  # hard-intervened modules
            data[:, node] = np.random.normal(0, 1, size=(n_samples,))
        else:
            # Soft-intervened targets and unintervened nodes follow the same
            # structural equation; the intervention acts through the
            # intervention-node columns already written to `data`.
            data[:, node] = _nonlinear_sem(
                data,
                weights_1,
                weights_2,
                node,
                dropout,
                noise_level,
                dependent_dropout,
                uniform_noise,
                node in obs_set
            )
    return data
