import random
import numpy as np
import sdeint
import igraph as ig
from scipy.integrate import odeint
from scipy.special import expit as sigmoid


def make_var_stationary(beta, radius=0.97, shrink=0.95):
    """Rescale VAR coefficients until the process is stable.

    ``beta`` holds the stacked lag coefficient matrices, shape ``(p, p * lag)``.
    It is embedded in the VAR companion matrix. While that matrix's spectral
    radius exceeds ``radius``, ``beta`` is multiplied by ``shrink``.

    Args:
        beta: ``(p, p * lag)`` array of lag coefficients.
        radius: Largest spectral radius accepted as stationary; must be > 0.
        shrink: Factor applied on each rescaling step; must lie in (0, 1).

    Returns:
        ``beta`` itself when it is already within ``radius``; otherwise a
        rescaled copy.

    Raises:
        ValueError: If ``radius`` or ``shrink`` would prevent termination.
    """
    if not 0 < shrink < 1:
        raise ValueError("shrink must lie in (0, 1), got {}".format(shrink))
    if radius <= 0:
        raise ValueError("radius must be positive, got {}".format(radius))
    p = beta.shape[0]
    lag = beta.shape[1] // p
    # The companion matrix's lower rows only shift lagged states and do not
    # depend on beta, so build them once.
    bottom = np.hstack((np.eye(p * (lag - 1)), np.zeros((p * (lag - 1), p))))
    # Iterate instead of recursing: strongly non-stationary inputs can need
    # more shrink steps than Python's recursion limit allows.
    while np.max(np.abs(np.linalg.eigvals(np.vstack((beta, bottom))))) > radius:
        beta = shrink * beta
    return beta

def simulate_var(B, n, p, lags, ins, noise_scale=None):
    """Simulate a linear (S)VAR process whose support is given by ``B``.

    Edge weights are drawn with magnitude uniform in [0.5, 2.0) and a random
    sign on every edge of ``B``, then shrunk by ``make_var_stationary``. The
    first 100 simulated steps are discarded as burn-in.

    Args:
        B: Square 0/1 support matrix with ``B[src, dst]`` marking ``src -> dst``.
            Columns ``p * lags:`` are the current-time targets. When ``ins``,
            ``B[-p:, -p:]`` is the contemporaneous DAG. (Presumably laid out like
            ``simulate_dag`` output; confirm.)
        n: Number of samples returned.
        p: Number of variables.
        lags: Maximum lag.
        ins: Whether contemporaneous edges drive the current time step.
        noise_scale: Std of the Gaussian innovations; ``None`` means 1.

    Returns:
        ``(n, p)`` array of simulated observations.
    """
    # ``is None`` so that an explicit 0 (noise-free run) is honoured.
    noise_scale = 1 if noise_scale is None else noise_scale
    burn_in = 100
    # Rows of ``support`` are source nodes (lagged blocks, plus the current
    # block when ``ins``); its columns are the current-time targets.
    support = B[:, p * lags:] if ins else B[:p * lags, p * lags:]
    W = np.random.uniform(low=0.5, high=2.0, size=support.shape)
    W[np.random.rand(*W.shape) < 0.5] *= -1
    beta = make_var_stationary((support * W).T)
    if ins:
        contemp_dag = ig.Graph.Adjacency(B[-p:, -p:].tolist())
        contemp_causal_order = contemp_dag.topological_sorting()
    errors = np.random.normal(scale=noise_scale, size=(p, n + burn_in))
    X = np.zeros((p, n + burn_in))
    X[:, :lags] = errors[:, :lags]
    if ins:
        for t in range(lags, n + burn_in):
            # Topological order guarantees contemporaneous parents of j are
            # already filled at time t; not-yet-simulated entries are still 0.
            for j in contemp_causal_order:
                window = X[:, t - lags:t + 1].flatten(order="F")
                X[j, t] = np.dot(beta[j, :], window) + errors[j, t]
    else:
        for t in range(lags, n + burn_in):
            window = X[:, t - lags:t].flatten(order="F")
            X[:, t] = np.dot(beta, window) + errors[:, t]
    return X.T[burn_in:]

def simulate_dy_var(B, b, n, p, lags, ins, noise_scale=None):
    """Simulate ``b`` realisations of a VAR whose coefficients vary over time.

    Static weights are drawn on the support of ``B`` as in ``simulate_var``.
    After burn-in (100 steps), at step ``t`` each coefficient is modulated by
    the phase ``phi = pi / n * (t - burn_in)``:

    - ``beta * cos(phi)`` when its random draw ``a > 0.5``;
    - ``beta * sin(phi)`` otherwise.

    Magnitudes below 0.2 are then zeroed.

    Args:
        B: Square 0/1 support matrix, ``B[src, dst]`` marking ``src -> dst``.
            Columns ``p * lags:`` are the current-time targets. When ``ins``,
            ``B[-p:, -p:]`` is the contemporaneous DAG.
        b: Number of independent realisations (batch size).
        n: Number of samples returned per realisation.
        p: Number of variables.
        lags: Maximum lag.
        ins: Whether contemporaneous edges drive the current time step.
        noise_scale: Std of the Gaussian innovations; ``None`` means 1.

    Returns:
        ``(X, betas)``:

        - ``X`` has shape ``(b, n, p)``.
        - ``betas`` holds ``beta_t.T`` for every step ``t > burn_in``, i.e.
          ``n - 1`` matrices. ``betas[k]`` generated returned sample ``k + 1``.
        - Returned sample 0 is generated with the unmodulated ``beta``.
          NOTE(review): confirm this one-step offset is intended.
    """
    # ``is None`` so that an explicit 0 (noise-free run) is honoured.
    noise_scale = 1 if noise_scale is None else noise_scale
    burn_in = 100
    support = B[:, p * lags:] if ins else B[:p * lags, p * lags:]
    W = np.random.uniform(low=0.5, high=2.0, size=support.shape)
    W[np.random.rand(*W.shape) < 0.5] *= -1
    beta = make_var_stationary((support * W).T)
    if ins:
        contemp_dag = ig.Graph.Adjacency(B[-p:, -p:].tolist())
        contemp_causal_order = contemp_dag.topological_sorting()
    errors = np.random.normal(scale=noise_scale, size=(b, p, n + burn_in))
    X = np.zeros((b, p, n + burn_in))
    X[:, :, :lags] = errors[:, :, :lags]
    # Per-coefficient choice between the cosine and sine schedule.
    use_cos = np.random.rand(*beta.shape) > 0.5
    window_len = p * (lags + 1) if ins else p * lags
    betas = []
    for t in range(lags, n + burn_in):
        if t > burn_in:
            phase = np.pi / n * (t - burn_in)
            beta_t = np.where(use_cos, beta * np.cos(phase), beta * np.sin(phase))
            beta_t[abs(beta_t) < 0.2] = 0
            betas.append(beta_t.T)
        else:
            beta_t = beta
        # Row i of ``window`` equals X[i, :, t-lags:t(+1)].flatten(order="F"),
        # i.e. lag blocks ordered oldest first, matching beta's columns.
        if ins:
            for j in contemp_causal_order:
                window = X[:, :, t - lags:t + 1].transpose(0, 2, 1).reshape(b, window_len)
                X[:, j, t] = window @ beta_t[j, :] + errors[:, j, t]
        else:
            window = X[:, :, t - lags:t].transpose(0, 2, 1).reshape(b, window_len)
            X[:, :, t] = window @ beta_t.T + errors[:, :, t]
    return X.transpose((0, 2, 1))[:, burn_in:, :], betas

def simulate_ins_dy_var(B, b, n, p, lags, noise_scale=None):
    """Simulate ``b`` realisations of an SVAR with time-varying instantaneous effects.

    Lagged coefficients stay fixed. After burn-in (100 steps), at step ``t``
    each contemporaneous coefficient is modulated by the phase
    ``phi = pi / n * (t - burn_in)``:

    - ``beta * cos(phi)`` when its random draw ``a > 0.5``;
    - ``beta * sin(phi)`` otherwise.

    Every coefficient with magnitude below 0.2, lagged or contemporaneous, is
    then zeroed.

    Args:
        B: Square 0/1 support matrix, ``B[src, dst]`` marking ``src -> dst``.
            Columns ``p * lags:`` are the current-time targets and
            ``B[-p:, -p:]`` is the contemporaneous DAG.
        b: Number of independent realisations (batch size).
        n: Number of samples returned per realisation.
        p: Number of variables.
        lags: Maximum lag.
        noise_scale: Std of the Gaussian innovations; ``None`` means 1.

    Returns:
        ``(X, betas)``:

        - ``X`` has shape ``(b, n, p)``.
        - ``betas`` holds ``beta_t.T`` for every step ``t > burn_in``
          (``n - 1`` matrices). ``betas[k]`` generated returned sample
          ``k + 1``; sample 0 uses the unmodulated ``beta``.
    """
    # ``is None`` so that an explicit 0 (noise-free run) is honoured.
    noise_scale = 1 if noise_scale is None else noise_scale
    burn_in = 100
    support = B[:, p * lags:]
    W = np.random.uniform(low=0.5, high=2.0, size=support.shape)
    W[np.random.rand(*W.shape) < 0.5] *= -1
    beta = make_var_stationary((support * W).T)
    contemp_dag = ig.Graph.Adjacency(B[-p:, -p:].tolist())
    contemp_causal_order = contemp_dag.topological_sorting()
    errors = np.random.normal(scale=noise_scale, size=(b, p, n + burn_in))
    X = np.zeros((b, p, n + burn_in))
    X[:, :, :lags] = errors[:, :, :lags]
    # Columns of beta: lagged blocks first, contemporaneous block last.
    beta_lagged = beta[:, :p * lags]
    beta_contemp = beta[:, p * lags:]
    use_cos = np.random.rand(*beta_contemp.shape) > 0.5
    window_len = p * (lags + 1)
    betas = []
    for t in range(lags, n + burn_in):
        if t > burn_in:
            phase = np.pi / n * (t - burn_in)
            beta_ins = np.where(use_cos, beta_contemp * np.cos(phase), beta_contemp * np.sin(phase))
            beta_t = np.concatenate((beta_lagged, beta_ins), axis=1)
            beta_t[abs(beta_t) < 0.2] = 0
            betas.append(beta_t.T)
        else:
            beta_t = beta
        for j in contemp_causal_order:
            # Row i equals X[i, :, t-lags:t+1].flatten(order="F").
            window = X[:, :, t - lags:t + 1].transpose(0, 2, 1).reshape(b, window_len)
            X[:, j, t] = window @ beta_t[j, :] + errors[:, j, t]
    return X.transpose((0, 2, 1))[:, burn_in:, :], betas

def simulate_dag(p, e, graph_type, lags, ins, es):
    """Sample a random time-lagged DAG over ``p`` variables.

    Node ``p * (lags - lag) + k`` is variable ``k`` at lag ``lag``, and the last
    ``p`` nodes are the current time step. ``B[src, dst] = 1`` marks an edge.

    Args:
        p: Number of variables.
        e: Edge-density parameter of the contemporaneous graph. ER and BP use
            ``round(p * e)`` edges; SF/BA attach ``round(e)`` edges per new node.
        graph_type: 'ER', 'SF', 'BA' or 'BP' (used only when ``ins``).
        lags: Maximum lag.
        ins: Whether to sample a contemporaneous DAG in ``B[-p:, -p:]``.
        es: Lagged densities indexed from the end. Each lag-``lag`` edge into a
            current node appears independently with probability
            ``es[-lag] / p``, so ``es[-1]`` drives lag 1.

    Returns:
        A ``(p * (lags + 1), p * (lags + 1))`` float 0/1 matrix. When ``ins``
        and ``lags == 0``, only the ``(p, p)`` contemporaneous matrix is returned.

    Raises:
        ValueError: If ``ins`` and ``graph_type`` is unknown.
    """
    def _random_permutation(M):
        # np.random.permutation permutes first axis only
        P = np.random.permutation(np.eye(M.shape[0]))
        return P.T @ M @ P

    def _random_acyclic_orientation(B_und):
        return np.tril(_random_permutation(B_und), k=-1)

    def _graph_to_adjmat(G):
        return np.array(G.get_adjacency().data)

    d_total = p * (lags + 1)
    B_time = np.zeros((d_total, d_total))
    if ins:
        # igraph expects an integer edge count; round like the SF branch does.
        n_edges = int(round(p * e))
        if graph_type == 'ER':
            # Erdos-Renyi
            G_und = ig.Graph.Erdos_Renyi(n=p, m=n_edges)
            B = _random_acyclic_orientation(_graph_to_adjmat(G_und))
        elif graph_type in ('SF', 'BA'):
            # Scale-free, Barabasi-Albert
            G = ig.Graph.Barabasi(n=p, m=int(round(e)), directed=True)
            B = _graph_to_adjmat(G)
        elif graph_type == 'BP':
            # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018)
            top = int(0.2 * p)
            G = ig.Graph.Random_Bipartite(top, p - top, m=n_edges, directed=True, neimode=ig.OUT)
            B = _graph_to_adjmat(G)
        else:
            raise ValueError('unknown graph type')
        B_perm = _random_permutation(B)
        if lags == 0:
            assert ig.Graph.Adjacency(B_perm.tolist()).is_dag()
            return B_perm
        B_time[-p:, -p:] = B_perm
    for lag in range(lags, 0, -1):
        threshold = 1.0 / p * es[-lag]
        first_row = p * (lags - lag)
        # One (from_node, to_node) draw per pair in row-major order: the same
        # RNG consumption as drawing them one at a time in nested loops.
        draws = np.random.uniform(low=0.0, high=1.0, size=(p, p))
        B_time[first_row:first_row + p, -p:][draws <= threshold] = 1
    assert ig.Graph.Adjacency(B_time.tolist()).is_dag()
    return B_time

def simulate_nonlinear_sem(B, n, sem_type, p, lags, ins, noise_scale=None):
    """Simulate a nonlinear structural time-series model on the support of ``B``.

    Current-time variable ``j`` (node ``j + p * lags``) is generated from its
    lagged parents in ``B`` and, when ``ins``, from its contemporaneous parents
    in ``B[-p:, -p:]``. The equation is randomly weighted and chosen by
    ``sem_type``:

    - 'AdditiveNoiseModel': ``sigmoid(X @ w1) @ w2 + z``, with a 100-unit
      hidden layer.
    - 'AdditiveIndexModel': ``tanh(X @ w1) + cos(X @ w2) + sin(X @ w3) + z``.
    - 'GLMPoissonDiscrete': ``Poisson(tanh(X @ w1) + 2)``.

    Weights per variable are drawn once and reused for every time step.

    Args:
        B: Square 0/1 adjacency, ``B[src, dst]`` marking ``src -> dst``.
            Expected to have size ``p * (lags + 1)`` (only divisibility by ``p``
            is checked).
        n: Number of samples returned.
        sem_type: One of the three types listed above.
        p: Number of variables.
        lags: Maximum lag; must be non-zero.
        ins: Whether contemporaneous edges are used.
        noise_scale: Gaussian noise std, either a scalar or a length-``p``
            sequence; ``None`` means 1.

    Returns:
        ``(n, p)`` array. The first ``int(0.2 * n)`` simulated rows are
        discarded as transient.

    Raises:
        Exception: If ``lags == 0`` or ``sem_type`` is unsupported.
    """
    def _simulate_single_equation_temporal(X, scale, w1=None, w2=None, w3=None):
        # One draw for a single variable given its (1, pa_size) parent row X.
        assert X.shape[0] == 1
        if sem_type == 'GLMPoissonDiscrete':
            z = np.random.randint(low=0, high=3, size=1)
        else:
            z = np.random.normal(scale=scale, size=1)
        pa_size = X.shape[1]
        if pa_size == 0:
            return z
        if sem_type == 'AdditiveNoiseModel':
            x = sigmoid(X @ w1) @ w2 + z
        elif sem_type == 'AdditiveIndexModel':
            x = np.tanh(X @ w1) + np.cos(X @ w2) + np.sin(X @ w3) + z
        elif sem_type == 'GLMPoissonDiscrete':
            link_function = np.tanh(X @ w1) + 2
            x = np.random.poisson(link_function, size=1)
            if link_function <= 0:
                raise Exception("Link function must be positive but the link function is equal to: {}".format(link_function))
        else:
            raise ValueError('unknown sem type')
        return x

    def _draw_weights(pa_size):
        # Random weights with magnitude in [0.5, 2.0) and random sign for one equation.
        w1, w2, w3 = None, None, None
        if sem_type == 'AdditiveNoiseModel':
            hidden = 100
            w1 = np.random.uniform(low=0.5, high=2.0, size=[pa_size, hidden])
            w1[np.random.rand(*w1.shape) < 0.5] *= -1
            w2 = np.random.uniform(low=0.5, high=2.0, size=hidden)
            w2[np.random.rand(hidden) < 0.5] *= -1
        elif sem_type == 'AdditiveIndexModel' or sem_type == 'GLMPoissonDiscrete':
            w1 = np.random.uniform(low=0.5, high=2.0, size=pa_size)
            w1[np.random.rand(pa_size) < 0.5] *= -1
            w2 = np.random.uniform(low=0.5, high=2.0, size=pa_size)
            w2[np.random.rand(pa_size) < 0.5] *= -1
            w3 = np.random.uniform(low=0.5, high=2.0, size=pa_size)
            w3[np.random.rand(pa_size) < 0.5] *= -1
        else:
            raise Exception("sem_type '{}' is not supported.".format(sem_type))
        return {"w1": w1, "w2": w2, "w3": w3}

    if lags == 0:
        raise Exception("time series data only.")
    assert B.shape[0] == B.shape[1]
    assert B.shape[0] > p
    assert B.shape[0] % p == 0
    # Accept a scalar or per-variable scale. A truthiness test here would fail
    # on arrays (ambiguous truth value), and a bare scalar is not indexable.
    if noise_scale is None:
        scale_vec = np.ones(p)
    else:
        scale_vec = np.broadcast_to(np.asarray(noise_scale, dtype=float), (p,))

    G_all = ig.Graph.Adjacency(B.tolist())
    if ins:
        contemp_dag = ig.Graph.Adjacency(B[-p:, -p:].tolist())
        contemp_causal_order = contemp_dag.topological_sorting()
        assert len(contemp_causal_order) == p
    else:
        contemp_dag = None
        contemp_causal_order = list(range(p))

    # Parent sets do not change over time, so resolve them once per variable.
    # Lagged parents are stored as (lag, variable index) pairs.
    n_lagged = p * lags
    lagged_parents = {}
    contemp_parents = {}
    for j in contemp_causal_order:
        lagged_parents[j] = [
            (lags - idx // p, idx % p)
            for idx in G_all.neighbors(j + n_lagged, mode=ig.IN)
            if idx < n_lagged
        ]
        contemp_parents[j] = contemp_dag.neighbors(j, mode=ig.IN) if ins else []

    transient = int(.2 * n)
    data = np.zeros((n + transient, p))
    # The first ``lags`` rows have no usable history: pure noise.
    no_parents = np.zeros((1, 0))
    for t in range(lags):
        for j in range(p):
            data[t, j] = _simulate_single_equation_temporal(no_parents, scale_vec[j])[0]

    w_dict = {}
    for t in range(lags, n + transient):
        for j in contemp_causal_order:
            lagged_values = [data[t - lag, var] for lag, var in lagged_parents[j]]
            parents_data = np.array(lagged_values + data[t, contemp_parents[j]].tolist()).reshape((1, -1))
            # Weights are drawn lazily on first use so the RNG stream order is
            # interleaved with the noise draws exactly as before.
            if j not in w_dict:
                w_dict[j] = _draw_weights(parents_data.shape[1])
            w = w_dict[j]
            value = _simulate_single_equation_temporal(parents_data, scale_vec[j], w["w1"], w["w2"], w["w3"])
            data[t, j] = value[0]
    return data[transient:, :]

def lorenz(x, t, F=5):
    """Right-hand side of the Lorenz-96 system.

    Computes ``dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F``, with indices
    taken cyclically. ``t`` is accepted for ODE/SDE-solver compatibility and is
    unused.

    Returns:
        A new float array with the same length as ``x``.
    """
    state = np.asarray(x, dtype=float)
    ahead_one = np.roll(state, -1)   # x_{i+1}
    behind_two = np.roll(state, 2)   # x_{i-2}
    behind_one = np.roll(state, 1)   # x_{i-1}
    return (ahead_one - behind_two) * behind_one - state + F

def simulate_lorenz_96(p, T, sigma=0.05, F=10.0, delta_t=0.01, sd=0.1, burn_in=1000, seed=None):
    """Simulate a stochastic Lorenz-96 system with additive diagonal noise.

    Args:
        p: Number of variables.
        T: Number of time points returned after burn-in.
        sigma: Diffusion coefficient applied independently to each variable.
        F: Forcing constant passed to ``lorenz``.
        delta_t: Nominal step; the grid has ``T + burn_in`` points on
            ``[0, (T + burn_in) * delta_t]``.
        sd: Unused; kept for backward compatibility.
        burn_in: Number of leading samples discarded.
        seed: If given, seeds both ``random`` and ``np.random``.

    Returns:
        ``(X, GC)``:

        - ``X`` is the ``(T, p)`` trajectory after burn-in.
        - ``GC[i, j] = 1`` when ``dx_i/dt`` depends on ``x_j``, i.e. for
          ``j`` in ``{i, i+1, i-1, i-2}`` (mod ``p``).
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    def drift(x, t):
        # Forward the caller's forcing; ``lorenz`` on its own defaults to F=5.
        return lorenz(x, t, F)

    def diffusion(x, t):
        return np.diag([sigma] * len(x))

    x0 = np.random.normal(scale=0.01, size=p)
    t = np.linspace(0, (T + burn_in) * delta_t, T + burn_in)
    X = sdeint.itoint(drift, diffusion, x0, t)

    # Set up Granger causality ground truth.
    GC = np.zeros((p, p), dtype=int)
    for i in range(p):
        for offset in (0, 1, -1, -2):
            GC[i, (i + offset) % p] = 1

    return X[burn_in:], GC

def rossler(x, t, a=0, eps=0.1, b=4, d=2):
    """Right-hand side of the coupled Rossler-type chain used by ``simulate_rossler``.

    The first variable, the last two variables and the interior variables
    ``1 .. p-3`` follow distinct rules. Interior variables evolve as
    ``sin(x_{i-1}) - sin(x_{i+1})``. ``t`` is unused and present for solver
    compatibility.

    Returns:
        A float array of length ``len(x)``.
    """
    state = np.asarray(x)
    p = len(state)
    dxdt = np.zeros(p)
    dxdt[0] = a * state[0] - state[1]
    dxdt[p - 2] = state[(p - 3)]
    dxdt[p - 1] = eps + b * state[(p - 1)] * (state[(p - 2)] - d)

    interior = np.arange(1, p - 2)
    dxdt[interior] = np.sin(state[interior - 1]) - np.sin(state[interior + 1])

    return dxdt

def simulate_rossler(p, T, sigma=0.5, a=0, eps=0.1, b=4, d=2, delta_t=0.05, sd=0.1, burn_in=1000, seed=None):
    """Simulate the stochastic Rossler-type chain defined by ``rossler``.

    Args:
        p: Number of variables.
        T: Number of time points returned after burn-in.
        sigma: Diffusion coefficient applied independently to each variable.
        a, eps, b, d: Model constants forwarded to ``rossler``.
        delta_t: Nominal step; the grid has ``T + burn_in`` points on
            ``[0, (T + burn_in) * delta_t]``.
        sd: Unused; kept for backward compatibility.
        burn_in: Number of leading samples discarded.
        seed: If given, seeds ``np.random``.

    Returns:
        ``(X, GC)``:

        - ``X`` is ``400 *`` the ``(T, p)`` post-burn-in trajectory.
        - ``GC[i, j] = 1`` when ``dx_i/dt`` depends on ``x_j``.
    """
    if seed is not None:
        np.random.seed(seed)

    def drift(x, t):
        # Forward the caller's constants; previously they were silently ignored.
        return rossler(x, t, a, eps, b, d)

    def diffusion(x, t):
        return np.diag([sigma] * len(x))

    x0 = np.random.normal(scale=0.01, size=p)
    t = np.linspace(0, (T + burn_in) * delta_t, T + burn_in)
    X = sdeint.itoint(drift, diffusion, x0, t)

    # Set up Granger causality ground truth.
    GC = np.zeros((p, p), dtype=int)
    GC[0, 0] = 1
    GC[0, 1] = 1
    GC[p - 2, p - 3] = 1
    GC[p - 1, p - 1] = 1
    GC[p - 1, p - 2] = 1
    for i in range(1, p - 2):
        # Interior variables are driven by both neighbours only (no self term).
        GC[i, i + 1] = 1
        GC[i, i - 1] = 1

    return 400 * X[burn_in:], GC

def glycolytic(
    x, t, k1=0.52, K1=100, K2=6, K3=16, K4=100, K5=1.28, K6=12, K=1.8, kappa=13, phi=0.1, q=4, A=4, N=1, J0=2.5
):
    """Right-hand side of the 7-species glycolytic oscillator model.

    source:
    https://www.pnas.org/content/pnas/suppl/2016/03/23/1517384113.DCSupplemental/pnas.1517384113.sapp.pdf

    Args:
    - x (array-like of length 7): current state.
    - t: time; unused here, accepted for ODE/SDE-solver compatibility.
    - k1: constant in the inhibition term ``1 + (x[5] / k1) ** q``.
      Note that lower-case ``k1`` is this constant, while upper-case ``K1`` is
      the rate multiplying ``x[0] * x[5]`` in that flux.
    - K1, K2, K3, K4, K5, K6, K, kappa, phi, q, A, N, J0: remaining model constants.

    Returns:
    - np.ndarray of shape (7,): dx/dt.
    """
    # Fluxes that appear in several equations.
    v1 = K1 * x[0] * x[5] / (1 + (x[5] / k1) ** q)
    v2 = K2 * x[1] * (N - x[4])
    v3 = K3 * x[2] * (A - x[5])
    v4 = K4 * x[3] * x[4]
    v6 = K6 * x[1] * x[4]
    v7 = kappa * (x[3] - x[6])

    dxdt = np.zeros(7)
    dxdt[0] = J0 - v1
    dxdt[1] = 2 * v1 - v2 - v6
    dxdt[2] = v2 - v3
    dxdt[3] = v3 - v4 - v7
    dxdt[4] = v2 - v4 - v6
    dxdt[5] = -2 * v1 + 2 * v3 - K5 * x[5]
    dxdt[6] = phi * kappa * (x[3] - x[6]) - K * x[6]

    return dxdt

def simulate_glycolytic(T, sigma=0.5, delta_t=0.001, sd=0.01, burn_in=1000, seed=None, scale=True):
    """Simulate the glycolytic oscillator (``glycolytic``) as an Ito SDE.

    Args:
        T: Number of time points returned after burn-in.
        sigma: Diffusion coefficient applied independently to each of the 7 states.
        delta_t: Nominal step; the grid has ``T + burn_in`` points on
            ``[0, (T + burn_in) * delta_t]``.
        sd: Unused; kept for backward compatibility.
        burn_in: Number of leading samples discarded.
        seed: If given, seeds ``np.random``.
        scale: If True, min-max normalise each state column before the final
            ``10 *``. NOTE(review): min/max are taken over the whole trajectory,
            burn-in included, so returned values need not span ``[0, 10]``.

    Returns:
        ``(X, GC)``:

        - ``X`` is ``10 *`` the ``(T, 7)`` post-burn-in trajectory.
        - ``GC[i, j] = 1`` exactly when ``dx_i/dt`` in ``glycolytic`` depends
          on ``x_j``.
    """
    if seed is not None:
        np.random.seed(seed)

    def GG(x, t):
        return np.diag([sigma] * len(x))

    # Initial state: each species drawn uniformly from its own range, in order.
    low = (0.15, 0.19, 0.04, 0.1, 0.08, 0.14, 0.05)
    high = (1.6, 2.16, 0.2, 0.35, 0.3, 2.67, 0.1)
    x0 = np.array([np.random.uniform(lo, hi) for lo, hi in zip(low, high)])

    t = np.linspace(0, (T + burn_in) * delta_t, T + burn_in)
    X = sdeint.itoint(glycolytic, GG, x0, t)

    # Ground truth derived from the equations in ``glycolytic``:
    # row i lists the states appearing in dx_i/dt.
    GC = np.array([
        [1, 0, 0, 0, 0, 1, 0],  # dx0: x0, x5
        [1, 1, 0, 0, 1, 1, 0],  # dx1: x0, x1, x4, x5
        [0, 1, 1, 0, 1, 1, 0],  # dx2: x1, x2, x4, x5
        [0, 0, 1, 1, 1, 1, 1],  # dx3: x2, x3, x4, x5, x6
        [0, 1, 0, 1, 1, 0, 0],  # dx4: x1, x3, x4
        [1, 0, 1, 0, 0, 1, 0],  # dx5: x0, x2, x5
        [0, 0, 0, 1, 0, 0, 1],  # dx6: x3, x6
    ], dtype=int)

    if scale:
        col_min = X.min(axis=0)
        X = (X - col_min) / (X.max(axis=0) - col_min)

    return 10 * X[burn_in:], GC