from typing import Iterable, Tuple
import networkx as nx
import numpy as np
from numba import njit
from scipy.sparse import csr_matrix
from scipy.stats import beta, gamma, bernoulli


def init_params(
    graph,
    n_hidden=0,
    max_corr=1.0,
    min_corr=0.25
):
    """Sample a causal order and random mechanism weights for a DAG.

    Args:
        graph: directed acyclic ``networkx`` graph. The simulators in this
            module index data columns by node label, so labels are expected
            to be the integers ``0 .. n-1``; any other labelling falls back to
            ``graph.nodes`` order for the adjacency matrix.
        n_hidden: 0 for a linear model; otherwise the hidden width of the
            per-node MLP.
        max_corr: upper bound on the absolute value of linear edge weights.
        min_corr: lower bound on the absolute value of linear edge weights.

    Returns:
        tuple: ``(causal_order, weights)``. ``causal_order`` is a topological
        ordering of the nodes as a numpy array. If ``n_hidden == 0``,
        ``weights`` is an ``(n, n)`` matrix where ``weights[i, j]`` is the
        weight of edge ``i -> j``. Its magnitude is uniform in
        ``[min_corr, max_corr)`` with a random sign, and it is 0 where there is
        no edge. Otherwise ``weights`` is ``(weights_1, weights_2)`` with
        shapes ``(n, n, n_hidden)`` (standard normal, masked by the adjacency)
        and ``(n, n_hidden)`` (standard normal).
    """
    causal_order = np.array(list(nx.topological_sort(graph)))
    nodes = len(causal_order)

    # Row/column i of the adjacency must refer to node label i, because the
    # simulators index by label (data[:, node], weights[:, node]).
    # adjacency_matrix otherwise follows graph insertion order, which can
    # differ from label order.
    if set(graph.nodes) == set(range(nodes)):
        nodelist = list(range(nodes))
    else:
        nodelist = list(graph.nodes)
    # adjacency_matrix returns a sparse container; .toarray() densifies it.
    adj_mtx = nx.adjacency_matrix(graph, nodelist=nodelist).toarray()

    if n_hidden == 0:
        # Linear model: random magnitude and sign, masked so that each node
        # depends only on its parents.
        weights = np.random.uniform(min_corr, max_corr, size=(nodes, nodes))
        weights *= 2 * np.random.binomial(1, 0.5, size=(nodes, nodes)) - 1
        weights *= adj_mtx
        return causal_order, weights

    # Non-linear model: the first-layer weights of every hidden unit share
    # the adjacency mask (broadcast over the hidden axis).
    weights_1 = np.random.normal(size=(nodes, nodes, n_hidden))
    weights_1 *= adj_mtx[:, :, np.newaxis]
    weights_2 = np.random.normal(size=(nodes, n_hidden))
    return causal_order, (weights_1, weights_2)


def _linear_sem(
    data: np.ndarray,
    weights: np.ndarray,
    node: int,
    dropout: float,
    noise_level: float,
    dependent_dropout: bool,
    uniform_noise: bool,
    is_obs: bool
):
    n_samples, n_nodes = data.shape
    if dropout == 0:
        obs = np.dot(data, weights[:, node])
    elif dependent_dropout:
        mask = bernoulli.rvs(1 - dropout, size=(n_samples, 1))
        obs = np.dot(data * mask, weights[:, node])
    else:
        mask = bernoulli.rvs(1 - dropout, size=(n_samples, n_nodes))
        obs = np.dot(data * mask, weights[:, node])
    
    # Add noise
    if is_obs:
        if uniform_noise:
            obs += np.random.uniform(-noise_level * 1.73, noise_level * 1.73, size=(n_samples,))
        else:
            obs += np.random.normal(0, noise_level, size=(n_samples,))
    return obs


def _nonlinear_sem(
    data: np.ndarray,
    weights_1: np.ndarray,
    weights_2: np.ndarray,
    node: int,
    dropout: float,
    noise_level: float,
    dependent_dropout: bool,
    uniform_noise: bool,
    is_obs: bool
):
    n_samples, n_nodes = data.shape
    if dropout == 0:
        temp = np.dot(data, weights_1[:, node])  # shape n_samples times n_hidden
    elif dependent_dropout:
        mask = bernoulli.rvs(1 - dropout, size=(n_samples, n_nodes))
        temp = np.einsum('ij,ij,jk->ik', data, mask, weights_1[:, node])  # shape n_samples times n_hidden
    else:
        mask = bernoulli.rvs(1 - dropout, size=(n_samples, n_nodes, weights_1.shape[2]))
        temp = np.einsum('ij,ijk,jk->ik', data, mask, weights_1[:, node])  # shape n_samples times n_hidden

    obs = np.dot(
        np.tanh(temp),
        weights_2[node],
    )
    if is_obs:
        if uniform_noise:
            obs += np.random.uniform(-noise_level * 1.73, noise_level * 1.73, size=(n_samples,))
        else:
            obs += np.random.normal(0, noise_level, size=(n_samples,))

    return obs


def simulate_data_linear(
    n_samples: int,
    weights: np.ndarray,
    causal_order: np.ndarray,
    targets: Iterable[int],
    obs_nodes: Iterable[int],
    hard=True,
    dependent_dropout=True,
    uniform_noise=False,
    noise_level=0.01,
    dropout=0.0
) -> np.ndarray:
    """Simulate data from a linear SEM, with hard or soft interventions.

    Nodes are generated in ``causal_order``. Each node's column is filled by
    ``_linear_sem`` from the columns already generated. Intervened nodes
    (``targets``) are handled in one of two ways:

    * hard: the column is replaced by independent ``N(0, 1)`` draws.
    * soft: the node keeps its parents, but its incoming weights are shifted
      by a random ``+-U(1, 2)`` on every existing (non-zero) entry.

    Node labels are used directly as column indices, so they must lie in
    ``0 .. n_nodes - 1``.

    Args:
        n_samples (int): Number of samples to generate.
        weights (np.ndarray): ``(n_nodes, n_nodes)`` matrix; ``weights[i, j]``
            is the weight of edge ``i -> j`` (0 for no edge).
        causal_order (np.ndarray): Topological order of the nodes.
        targets (Iterable[int]): Intervention targets.
        obs_nodes (Iterable[int]): Nodes that receive additive noise.
        hard (bool, optional): Hard (True) or soft (False) interventions.
            Defaults to True.
        dependent_dropout (bool, optional): If True, dropout removes all
            inputs of a sample at once; otherwise each input is dropped
            independently. Defaults to True.
        uniform_noise (bool, optional): Uniform noise with std ~=
            ``noise_level`` instead of Gaussian noise. Defaults to False.
        noise_level (float, optional): Noise scale. Defaults to 0.01.
        dropout (float, optional): Input dropout probability. Defaults to 0.0.

    Returns:
        np.ndarray: ``(n_samples, n_nodes)`` simulated data, one column per
        node label.
    """
    # Materialise once: `in` on a generator would consume it, so later
    # membership tests would silently fail. Sets also give O(1) lookups.
    targets = set(targets)
    obs_nodes = set(obs_nodes)

    n_nodes = causal_order.shape[0]
    data = np.zeros(shape=(n_samples, n_nodes))

    # The perturbation is drawn for every entry (same random draws as
    # before) but applied only to existing edges, so a soft intervention
    # changes the mechanism without giving the target new parents.
    perturbation = np.random.uniform(1.0, 2.0, size=weights.shape) * np.random.choice([-1, 1], size=weights.shape)
    weights_intv = weights + perturbation * (weights != 0)

    for node in causal_order:
        is_target = node in targets
        if is_target and hard:
            data[:, node] = np.random.normal(0, 1, size=(n_samples,))
            continue
        # each node is a function of its parents
        data[:, node] = _linear_sem(
            data,
            weights_intv if is_target else weights,
            node,
            dropout,
            noise_level,
            dependent_dropout,
            uniform_noise,
            is_obs=node in obs_nodes
        )

    return data


def simulate_data_nn(
    n_samples,
    weights_1,
    weights_2,
    causal_order,
    targets: Iterable[int],
    obs_nodes: Iterable[int],
    hard=True,
    dependent_dropout=True,
    uniform_noise=False,
    noise_level=0.01,
    dropout=0.0,
):
    """Simulate data from a per-node MLP SEM, with hard or soft interventions.

    Nodes are generated in ``causal_order``. Each node's column is filled by
    ``_nonlinear_sem`` from the columns already generated. Intervened nodes
    (``targets``) are handled in one of two ways:

    * hard: the column is replaced by independent ``N(0, 1)`` draws.
    * soft: the node keeps its parents. The non-zero entries of
      ``weights_1`` and all entries of ``weights_2`` are shifted by a random
      ``+-U(1, 2)``.

    Node labels are used directly as column indices, so they must lie in
    ``0 .. n_nodes - 1``.

    Args:
        n_samples (int): Number of samples to generate.
        weights_1 (np.ndarray): ``(n_nodes, n_nodes, n_hidden)`` first-layer
            weights, zero where there is no edge.
        weights_2 (np.ndarray): ``(n_nodes, n_hidden)`` second-layer weights.
        causal_order (np.ndarray): Topological order of the nodes.
        targets (Iterable[int]): Intervention targets.
        obs_nodes (Iterable[int]): Nodes that receive additive noise.
        hard (bool, optional): Hard (True) or soft (False) interventions.
            Defaults to True.
        dependent_dropout (bool, optional): If True, one dropout draw per
            (sample, input) is shared across hidden units. Defaults to True.
        uniform_noise (bool, optional): Uniform noise with std ~=
            ``noise_level`` instead of Gaussian noise. Defaults to False.
        noise_level (float, optional): Noise scale. Defaults to 0.01.
        dropout (float, optional): Input dropout probability. Defaults to 0.0.

    Returns:
        np.ndarray: ``(n_samples, n_nodes)`` simulated data, one column per
        node label.
    """
    # Materialise once: `in` on a generator would consume it, so later
    # membership tests would silently fail. Sets also give O(1) lookups.
    targets = set(targets)
    obs_nodes = set(obs_nodes)

    n_nodes = causal_order.shape[0]
    data = np.zeros(shape=(n_samples, n_nodes))

    # Perturbations are drawn in the same order and amount as before. The
    # first-layer shift is applied only where an edge exists, so soft
    # interventions do not add parents.
    delta_1 = np.random.uniform(1.0, 2.0, size=weights_1.shape) * np.random.choice([-1, 1], size=weights_1.shape)
    delta_2 = np.random.uniform(1.0, 2.0, size=weights_2.shape) * np.random.choice([-1, 1], size=weights_2.shape)
    weights_1_intv = weights_1 + delta_1 * (weights_1 != 0)
    weights_2_intv = weights_2 + delta_2

    for node in causal_order:
        is_target = node in targets
        if is_target and hard:
            data[:, node] = np.random.normal(0, 1, size=(n_samples,))
            continue
        # each node is a function of its parents
        data[:, node] = _nonlinear_sem(
            data,
            weights_1_intv if is_target else weights_1,
            weights_2_intv if is_target else weights_2,
            node,
            dropout,
            noise_level,
            dependent_dropout,
            uniform_noise,
            is_obs=node in obs_nodes
        )

    return data
