"""
Code modified from https://github.com/xunzheng/notears/blob/master/notears/utils.py
"""
import copy

import numpy as np
from scipy.special import expit as sigmoid
import igraph as ig
import random

from utils.utils import is_dag


def simulate_data(dag, samplenum, interv_targets, selection_parents,
                  noise_type='gauss', data_type='linear'):
    """Simulate selection-biased observational and interventional datasets.

    A large pool of samples is drawn from the SEM defined on ``dag``, a
    selection mask is applied to the pool, and the exogenous noise of the
    surviving rows is re-used to regenerate one dataset per domain: the
    observational domain first, then one per entry of ``interv_targets``.

    Args:
        dag: adjacency matrix of the DAG; ``len(dag)`` is the node count.
        samplenum: number of rows in each returned array.
        interv_targets: list with one collection of intervened node indices
            per interventional domain. It is concatenated to ``[()]``, so it
            has to be a list.
        selection_parents: one entry per selection variable, each entry being
            the column indices whose (jittered) sum drives that selection.
        noise_type: noise family forwarded to the SEM simulators.
        data_type: ``'linear'`` or ``'mlp'``.

    Returns:
        An ordered list of ``1 + len(interv_targets)`` numpy arrays, each of
        shape (samplenum, nodenum).
    """
    assert data_type in {'linear', 'mlp'}
    nodenum = len(dag)

    # One (min, max) percentile window per selection variable; the window
    # width is drawn so that roughly 20% to 50% of the data is selected.
    selection_minmax_thresholds = []
    for _ in selection_parents:
        lower_pct = np.random.uniform(low=10, high=70)
        upper_pct = np.random.uniform(low=lower_pct + 20, high=min(lower_pct + 50, 100))
        selection_minmax_thresholds.append((lower_pct, upper_pct))

    use_linear = data_type == 'linear'
    if use_linear:
        sem_params = simulate_linear_parameter(dag)
    elif data_type == 'mlp':
        sem_params = simulate_mlp_parameter(dag)
    else:
        raise ValueError("Unknown data type.")

    noise_scale = np.sqrt(np.random.uniform(low=1, high=4, size=(nodenum,)))
    domains = [()] + interv_targets    # observational data is the first domain
    required_rows = samplenum * len(domains)

    # Start from a very large pool and keep doubling it until enough rows
    # survive the selection to fill every domain.
    pool_size = 10 ** 5
    while True:
        if use_linear:
            X_pool, E_pool = simulate_linear_sem(sem_params, pool_size, noise_type, noise_scale)
        else:
            X_pool, E_pool = simulate_mlp_sem(dag, sem_params, pool_size, noise_type, noise_scale)
        keep = get_selection_mask(X_pool, selection_parents, selection_minmax_thresholds)
        if np.sum(keep) >= required_rows:
            E_selected = E_pool[keep]
            break
        pool_size *= 2

    if use_linear:
        print(f"Intervention target: {domains}")

    all_data = []
    for interv_id, interv_target in enumerate(domains):
        # Each domain consumes its own disjoint slice of the selected noise.
        E_domain = E_selected[samplenum * interv_id:samplenum * (interv_id + 1)]
        if use_linear:
            X = regenerate_intervened_linear_sem(sem_params, E_domain, interv_target)
            print(interv_id, X.shape)
        else:
            X = regenerate_intervened_mlp_sem(dag, sem_params, E_domain, interv_target)
        all_data.append(X)
    return all_data



def get_selection_mask(X_star, selection_parents, selection_minmax_thresholds):
    """Return a boolean row mask of the samples that pass every selection.

    For each selection variable the listed parent columns are summed, Gaussian
    jitter with standard deviation ``0.2 * std(sum)`` is added, and a row is
    kept when its jittered sum lies between the given min and max percentiles
    of the jittered sums (both ends inclusive). A row must pass all selections.

    Args:
        X_star: 2-D sample array, rows are samples.
        selection_parents: one list of column indices per selection variable.
        selection_minmax_thresholds: matching (min_percentile, max_percentile)
            pairs.

    Returns:
        Boolean array of length ``len(X_star)``.
    """
    n = len(X_star)
    keep = np.ones(n, dtype=bool)
    for columns, (pct_low, pct_high) in zip(selection_parents, selection_minmax_thresholds):
        score = X_star[:, columns].sum(axis=1)
        # Jitter the score so the selection boundary is not perfectly sharp.
        jitter = np.random.normal(loc=0, scale=0.2 * np.std(score), size=(n,))
        noisy_score = score + jitter
        lower = np.percentile(noisy_score, pct_low)
        upper = np.percentile(noisy_score, pct_high)
        in_window = (noisy_score <= upper) & (noisy_score >= lower)
        keep = keep & in_window
    return keep


def regenerate_intervened_linear_sem(B, E, interv_target):
    """Rebuild linear-SEM samples from fixed noise, optionally intervened.

    Args:
        B: [d, d] weighted adjacency matrix; ``B[i, j]`` is the weight with
            which node ``i`` enters the equation of node ``j``. Not modified.
        E: [n, d] exogenous noise, one column per node.
        interv_target: collection of intervened node indices; ``None`` or an
            empty collection means no intervention.

    Returns:
        X: array shaped like ``E`` holding the regenerated samples.
    """
    weights = np.copy(B)
    if interv_target is not None and len(interv_target) > 0:
        # New causal mechanism: redraw weights on the same support and swap in
        # only the columns (incoming edges) of the intervened nodes.
        support = (weights != 0).astype(int)
        redrawn = simulate_linear_parameter(support)
        weights[:, interv_target] = redrawn[:, interv_target]

    graph = ig.Graph.Weighted_Adjacency(weights.tolist())
    topo_order = graph.topological_sorting()
    X = np.zeros_like(E)
    for node in topo_order:
        pa = graph.neighbors(node, mode=ig.IN)
        X[:, node] = X[:, pa] @ weights[pa, node] + E[:, node]
        if interv_target is None or node not in interv_target:
            continue
        # As suggested by Haoyue, add some small Gaussian noise to intervened
        # nodes, to (1) reflect intervention randomness and (2) let root
        # variables be intervened on as well.
        shift = np.random.uniform(low=1, high=2) * np.random.choice([-1, 1])
        spread = np.sqrt(np.random.uniform(low=0.2, high=2))
        X[:, node] += np.random.normal(loc=shift, scale=spread, size=(len(X),))
    return X


def regenerate_intervened_mlp_sem(dag, mlp_params, E, interv_target):
    """Rebuild MLP-SEM samples from fixed noise, optionally intervened.

    Args:
        dag: [d, d] adjacency matrix of the DAG.
        mlp_params: mapping node -> ``{'W1': ..., 'W2': ...}`` (entries of
            parentless nodes are never read here). Deep-copied, so the
            caller's object is not modified.
        E: [n, d] exogenous noise, one column per node.
        interv_target: collection of intervened node indices; ``None`` or an
            empty collection means no intervention.

    Returns:
        X: array shaped like ``E`` holding the regenerated samples.
    """
    params = copy.deepcopy(mlp_params)
    if interv_target is not None and len(interv_target) > 0:
        # New causal mechanism: draw a fresh parameter set for the whole DAG
        # and adopt it only for the intervened nodes.
        redrawn = simulate_mlp_parameter(dag)
        for node in interv_target:
            params[node] = redrawn[node]

    graph = ig.Graph.Weighted_Adjacency(dag.tolist())
    topo_order = graph.topological_sorting()
    X = np.zeros_like(E)
    for node in topo_order:
        pa = graph.neighbors(node, mode=ig.IN)
        X[:, node] = E[:, node]
        if pa:
            layer = params[node]
            X[:, node] += sigmoid(X[:, pa] @ layer['W1']) @ layer['W2']
        if interv_target is None or node not in interv_target:
            continue
        # As suggested by Haoyue, add some small Gaussian noise to the
        # intervened nodes.
        shift = np.random.uniform(low=1, high=2) * np.random.choice([-1, 1])
        spread = np.sqrt(np.random.uniform(low=0.2, high=2))
        X[:, node] += np.random.normal(loc=shift, scale=spread, size=(len(X),))
    return X


def simulate_mlp_parameter(dag):
    """Sample one-hidden-layer MLP weights for every node of a DAG.

    Args:
        dag (np.ndarray): [d, d] adjacency matrix of the DAG.

    Returns:
        mlp_params (dict): maps each node index to ``None`` when the node has
            no parents, otherwise to ``{'W1': [num_parents, 100] array,
            'W2': [100] array}``. Weight magnitudes are uniform in
            [0.5, 2.0) and each sign is flipped with probability 0.5.
    """
    graph = ig.Graph.Adjacency(dag.tolist())
    topo_order = graph.topological_sorting()
    assert len(topo_order) == len(dag)

    hidden = 100
    mlp_params = {}
    for node in topo_order:
        num_parents = len(graph.neighbors(node, mode=ig.IN))
        if num_parents == 0:
            mlp_params[node] = None
            continue
        W1 = np.random.uniform(low=0.5, high=2.0, size=[num_parents, hidden])
        W1[np.random.rand(*W1.shape) < 0.5] *= -1
        W2 = np.random.uniform(low=0.5, high=2.0, size=hidden)
        W2[np.random.rand(hidden) < 0.5] *= -1
        mlp_params[node] = {'W1': W1, 'W2': W2}
    return mlp_params


def simulate_mlp_sem(dag, mlp_params, n, noise_type='gauss', noise_scale=None):
    """Simulate samples from an MLP SEM with given per-node parameters.

    Args:
        dag (np.ndarray): [d, d] adjacency matrix of the DAG
        mlp_params (dict): node -> ``{'W1': ..., 'W2': ...}``; the entry of a
            parentless node is looked up but not used
        n (int): num of samples
        noise_type (str): gauss, exp, gumbel, uniform
        noise_scale: scale of the additive noise; a scalar, a length-d
            sequence, or None for all ones

    Returns:
        X (np.ndarray): [n, d] sample matrix
        E (np.ndarray): [n, d] additive noise used to produce X

    Raises:
        ValueError: on an unknown noise type, or a non-scalar noise_scale
            whose length is not d.
    """
    def _draw_noise(scale):
        """Return n noise draws of the configured family."""
        if noise_type == 'gauss':
            return np.random.normal(scale=scale, size=n)
        if noise_type == 'exp':
            return np.random.exponential(scale=scale, size=n)
        if noise_type == 'gumbel':
            return np.random.gumbel(scale=scale, size=n)
        if noise_type == 'uniform':
            return np.random.uniform(low=-scale, high=scale, size=n)
        raise ValueError('unknown noise type')

    d = len(dag)
    if noise_scale is None:
        scale_vec = np.ones(d)
    elif np.isscalar(noise_scale):
        scale_vec = noise_scale * np.ones(d)
    elif len(noise_scale) == d:
        scale_vec = noise_scale
    else:
        raise ValueError('noise scale must be a scalar or has length d')

    X = np.zeros([n, d])
    E = np.zeros([n, d])
    graph = ig.Graph.Adjacency(dag.tolist())
    topo_order = graph.topological_sorting()
    assert len(topo_order) == d
    for node in topo_order:
        pa = graph.neighbors(node, mode=ig.IN)
        layer, scale = mlp_params[node], scale_vec[node]
        z = _draw_noise(scale)
        if pa:
            X[:, node] = sigmoid(X[:, pa] @ layer['W1']) @ layer['W2'] + z
        else:
            # Root node: the sample is just its noise.
            X[:, node] = z
        E[:, node] = z
    return X, E


def simulate_dag(d, s0, graph_type):
    """Simulate random DAG with some expected number of edges.

    Args:
        d (int): num of nodes
        s0 (int): expected num of edges
        graph_type (str): ER, SF, BP

    Returns:
        B (np.ndarray): [d, d] binary adj matrix of DAG, with node labels
            randomly permuted

    Raises:
        ValueError: if graph_type is not one of ER, SF, BP.
    """
    def _shuffle_nodes(adj):
        # np.random.permutation only permutes the first axis, so shuffle the
        # rows of the identity and conjugate to relabel rows and columns alike.
        perm = np.random.permutation(np.eye(adj.shape[0]))
        return perm.T @ adj @ perm

    def _adjacency_of(graph):
        return np.array(graph.get_adjacency().data)

    if graph_type not in ('ER', 'SF', 'BP'):
        raise ValueError('unknown graph type')

    if graph_type == 'ER':
        # Erdos-Renyi: undirected graph, then keep the strict lower triangle
        # of a random relabelling to obtain an acyclic orientation.
        undirected = _adjacency_of(ig.Graph.Erdos_Renyi(n=d, m=s0))
        B = np.tril(_shuffle_nodes(undirected), k=-1)
    elif graph_type == 'SF':
        # Scale-free, Barabasi-Albert
        edges_per_node = int(round(s0 / d))
        B = _adjacency_of(ig.Graph.Barabasi(n=d, m=edges_per_node, directed=True))
    else:
        # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018)
        top = int(0.2 * d)
        bipartite = ig.Graph.Random_Bipartite(top, d - top, m=s0, directed=True, neimode=ig.OUT)
        B = _adjacency_of(bipartite)

    B_perm = _shuffle_nodes(B)
    assert ig.Graph.Adjacency(B_perm.tolist()).is_dag()
    return B_perm


def simulate_linear_parameter(B, w_ranges=((-2.0, -0.5), (0.5, 2.0))):
    """Simulate SEM parameters for a DAG.

    Every entry is assigned one of the weight ranges uniformly at random and
    receives a uniform draw from that range, multiplied by the matching entry
    of ``B`` (so zero entries of ``B`` stay zero).

    Args:
        B (np.ndarray): [d, d] binary adj matrix of DAG
        w_ranges (tuple): disjoint weight ranges

    Returns:
        W (np.ndarray): [d, d] weighted adj matrix of DAG
    """
    shape = B.shape
    W = np.zeros(shape)
    range_choice = np.random.randint(len(w_ranges), size=shape)  # which range
    for range_idx, bounds in enumerate(w_ranges):
        low, high = bounds
        draws = np.random.uniform(low=low, high=high, size=shape)
        chosen_here = range_choice == range_idx
        W += B * chosen_here * draws
    return W


def simulate_linear_sem(W, n, noise_type='gauss', noise_scale=None):
    """Simulate samples from linear SEM with specified type of noise.

    For uniform, noise z ~ uniform(-a, a), where a = noise_scale.

    Only the empirical (finite n) setting is implemented here.

    Args:
        W (np.ndarray): [d, d] weighted adj matrix of DAG
        n (int): num of samples
        noise_type (str): gauss, exp, gumbel, uniform
        noise_scale: scale parameter of additive noise; a scalar, a length-d
            sequence, or None for all ones

    Returns:
        X (np.ndarray): [n, d] sample matrix
        E (np.ndarray): [n, d] additive noise used to produce X

    Raises:
        ValueError: on an unknown noise type, a non-scalar noise_scale whose
            length is not d, or a W rejected by ``is_dag``.
    """
    def _draw_noise(scale):
        """Return n noise draws of the configured family."""
        if noise_type == 'gauss':
            return np.random.normal(scale=scale, size=n)
        if noise_type == 'exp':
            return np.random.exponential(scale=scale, size=n)
        if noise_type == 'gumbel':
            return np.random.gumbel(scale=scale, size=n)
        if noise_type == 'uniform':
            return np.random.uniform(low=-scale, high=scale, size=n)
        raise ValueError('unknown noise type')

    d = W.shape[0]
    if noise_scale is None:
        scale_vec = np.ones(d)
    elif np.isscalar(noise_scale):
        scale_vec = noise_scale * np.ones(d)
    elif len(noise_scale) == d:
        scale_vec = noise_scale
    else:
        raise ValueError('noise scale must be a scalar or has length d')

    if not is_dag(W):
        raise ValueError('W must be a DAG')

    graph = ig.Graph.Weighted_Adjacency(W.tolist())
    topo_order = graph.topological_sorting()
    assert len(topo_order) == d
    X = np.zeros([n, d])
    E = np.zeros([n, d])
    for node in topo_order:
        pa = graph.neighbors(node, mode=ig.IN)
        parent_values, parent_weights = X[:, pa], W[pa, node]
        z = _draw_noise(scale_vec[node])
        X[:, node] = parent_values @ parent_weights + z
        E[:, node] = z
    return X, E


def simulate_nonlinear_sem(B, n, data_type, noise_scale=None):
    """Simulate samples from nonlinear SEM.

    The mechanism of every non-root node is drawn afresh inside this call;
    the additive noise is always Gaussian.

    Args:
        B (np.ndarray): [d, d] binary adj matrix of DAG
        n (int): num of samples
        data_type (str): mlp, mim, gp, gp-add
        noise_scale: scale parameter of additive noise; a scalar, a length-d
            sequence, or None for all ones

    Returns:
        X (np.ndarray): [n, d] sample matrix

    Raises:
        ValueError: on an unknown ``data_type`` (only detected once a node
            with parents is simulated), or a non-scalar noise_scale whose
            length is not d.
    """
    def _simulate_single_equation(X, scale):
        """X: [n, num of parents], x: [n]"""
        z = np.random.normal(scale=scale, size=n)
        pa_size = X.shape[1]
        if pa_size == 0:
            return z
        if data_type == 'mlp':
            hidden = 100
            W1 = np.random.uniform(low=0.5, high=2.0, size=[pa_size, hidden])
            W1[np.random.rand(*W1.shape) < 0.5] *= -1
            W2 = np.random.uniform(low=0.5, high=2.0, size=hidden)
            W2[np.random.rand(hidden) < 0.5] *= -1
            x = sigmoid(X @ W1) @ W2 + z
        elif data_type == 'mim':
            w1 = np.random.uniform(low=0.5, high=2.0, size=pa_size)
            w1[np.random.rand(pa_size) < 0.5] *= -1
            w2 = np.random.uniform(low=0.5, high=2.0, size=pa_size)
            w2[np.random.rand(pa_size) < 0.5] *= -1
            w3 = np.random.uniform(low=0.5, high=2.0, size=pa_size)
            w3[np.random.rand(pa_size) < 0.5] *= -1
            x = np.tanh(X @ w1) + np.cos(X @ w2) + np.sin(X @ w3) + z
        elif data_type == 'gp':
            from sklearn.gaussian_process import GaussianProcessRegressor
            gp = GaussianProcessRegressor()
            x = gp.sample_y(X, random_state=None).flatten() + z
        elif data_type == 'gp-add':
            from sklearn.gaussian_process import GaussianProcessRegressor
            gp = GaussianProcessRegressor()
            x = sum([gp.sample_y(X[:, i, None], random_state=None).flatten()
                     for i in range(X.shape[1])]) + z
        else:
            raise ValueError('unknown sem type')
        return x

    d = B.shape[0]
    # Do not rely on the truthiness of noise_scale: a multi-element ndarray
    # raises "truth value of an array is ambiguous", and a bare scalar cannot
    # be indexed per node below. Mirror the handling of the other simulators.
    if noise_scale is None:
        scale_vec = np.ones(d)
    elif np.isscalar(noise_scale):
        scale_vec = noise_scale * np.ones(d)
    else:
        if len(noise_scale) != d:
            raise ValueError('noise scale must be a scalar or has length d')
        scale_vec = noise_scale
    X = np.zeros([n, d])
    G = ig.Graph.Adjacency(B.tolist())
    ordered_vertices = G.topological_sorting()
    assert len(ordered_vertices) == d
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        X[:, j] = _simulate_single_equation(X[:, parents], scale_vec[j])
    return X