import numpy as np


def kci_paper_type_2(n, d=1, is_type2=False, k=1):
    """Simulate ``(x, y, z)`` where x and y are nonlinear functions of z plus noise.

    Five candidate conditioning columns are always sampled; the first ``d`` of
    them drive both x and y through a chain of cubic / tanh transforms, and
    only those ``d`` columns are returned. All randomness comes from the
    global ``np.random`` state, and the draws are made in a fixed order
    (level-1 noise, z, then fresh noise per extra level, then the optional
    shared term), so seeding ``np.random`` makes the output reproducible.

    Parameters
    ----------
    n : int
        Number of samples (rows).
    d : int, default 1
        Number of z columns used and returned; must be 1, 2, 3, 4 or 5.
    is_type2 : bool, default False
        When exactly ``True``, the same Gaussian term is added to both x and
        y after standardisation (scale 0.15 for ``d == 1``, 0.5 otherwise),
        giving them a shared component that z does not account for.
    k : int, default 1
        Distribution of each z entry: 1 for standard normal, 3 or 10 for a
        unit-variance Gaussian mixture with that many components (columns are
        standardised after sampling).

    Returns
    -------
    x, y : numpy.ndarray of shape (n, 1)
        Standardised to zero mean and unit standard deviation (before the
        optional shared term is added).
    z : numpy.ndarray of shape (n, d)
        The conditioning columns. For ``d > 1`` every column is standardised;
        for ``d == 1`` with ``k == 1`` the column is the raw normal draw.

    Raises
    ------
    ValueError
        If ``d`` is not in 1..5 or ``k`` is not one of 1, 3, 10.
    """
    # Fail early: an unsupported k used to leave z undefined, and an
    # unsupported d used to fall off the end and return None.
    if d not in (1, 2, 3, 4, 5):
        raise ValueError("d must be an integer from 1 to 5, got {!r}".format(d))
    if k not in (1, 3, 10):
        raise ValueError("k must be 1, 3 or 10, got {!r}".format(k))
    depth = int(d)

    # Level-1 noise is drawn before z; keeping this order keeps seeded runs
    # reproducible.
    noise_x = np.random.normal(0, 1, (n, 1))
    noise_y = np.random.normal(0, 1, (n, 1))

    if k == 1:
        z = np.random.normal(0, 1, (n, 5))
    else:
        if k == 3:
            norm_params = np.array([[0, 1],
                                    [2, 1],
                                    [-2, 1]])
            pi = np.array([0.4, 0.3, 0.3])
        else:
            norm_params = np.array([[-4, 1], [-3, 1], [-2, 1],
                                    [-1, 1], [0, 1], [1, 1],
                                    [2, 1], [3, 1], [4, 1], [5, 1]])
            pi = np.array([.08, .08, .11, .13, .1, .09, .11, .1, .12, .08])
        # Pick a component per entry, then draw every entry in one
        # vectorised call with per-entry (mean, std).
        mixture_idx = np.random.choice(len(pi), size=5 * n, p=pi)
        z = np.random.normal(norm_params[mixture_idx, 0],
                             norm_params[mixture_idx, 1]).reshape(-1, 5)
        z = (z - np.mean(z, axis=0)) / np.std(z, axis=0)

    # Level 1: deterministic effect of the first z column on x and on y.
    z_col = z[:, 0].reshape(-1, 1)
    signal_x = .7 * ((z_col**3) / 5 + z_col/2)
    signal_y = (z_col**3 / 4 + z_col) / 3

    # Each further level mixes in one more z column and draws fresh noise;
    # only the noise of the last level ends up in the returned x and y.
    for level in range(2, depth + 1):
        noise_x = np.random.normal(0, 1, (n, 1))
        noise_y = np.random.normal(0, 1, (n, 1))
        z_col = z[:, level - 1].reshape(-1, 1)
        if level == 2:
            # The second level uses different mixing weights than levels 3-5.
            signal_x = signal_x / 2 + z_col
            signal_y = signal_y / 2 + z_col
        else:
            signal_x = signal_x * 2/3 + z_col * 5/6
            signal_y = signal_y * 2/3 + z_col * 5/6
        signal_x = signal_x / 2 + .7 * np.tanh(signal_x)
        signal_y = signal_y / 2 + .7 * np.tanh(signal_y)

    x = signal_x + np.tanh(noise_x)
    x = x + (x**3) / 3 + np.tanh(x/3) / 2
    y = noise_y + signal_y
    y = y + np.tanh(y/3)

    x = (x - np.mean(x)) / np.std(x)
    y = (y - np.mean(y)) / np.std(y)
    if depth > 1:
        z = (z - np.mean(z, axis=0)) / np.std(z, axis=0)

    if is_type2 is True:
        shared_scale = 0.15 if depth == 1 else 0.5
        ff = shared_scale * np.random.normal(size=(n, 1))
        x += ff
        y += ff

    return x, y, z[:, :depth]


def type2_x_cause_y(n, k=1):
    """Simulate ``(x, y, z)`` where y depends on x directly as well as on z.

    x and y are both nonlinear functions of z plus noise, and y additionally
    receives ``0.6 * (x/2 + 0.7 * tanh(x))`` computed from the intermediate x,
    so x and y stay dependent given z. All randomness comes from the global
    ``np.random`` state.

    Parameters
    ----------
    n : int
        Number of samples (rows).
    k : int, default 1
        Distribution of the base variables: 1 for standard normal, 3 or 10
        for a unit-variance Gaussian mixture with that many components.

    Returns
    -------
    x, y : numpy.ndarray of shape (n, 1)
        Standardised to zero mean and unit standard deviation.
    z : numpy.ndarray of shape (n, 1)
        The conditioning variable; standardised for the mixture cases, the
        raw normal draw for ``k == 1``.

    Raises
    ------
    ValueError
        If ``k`` is not one of 1, 3, 10.
    """
    # An unsupported k used to leave x, y and z undefined.
    if k not in (1, 3, 10):
        raise ValueError("k must be 1, 3 or 10, got {!r}".format(k))

    if k == 1:
        z = np.random.normal(0, 1, (n, 1))
        x = np.random.normal(0, 1, (n, 1))
        y = np.random.normal(0, 1, (n, 1))
    else:
        if k == 3:
            norm_params = np.array([[0, 1],
                                    [2, 1],
                                    [-2, 1]])
            pi = np.array([0.4, 0.3, 0.3])
        else:
            norm_params = np.array([[-4, 1], [-3, 1], [-2, 1],
                                    [-1, 1], [0, 1], [1, 1],
                                    [2, 1], [3, 1], [4, 1], [5, 1]])
            pi = np.array([.08, .08, .11, .13, .1, .09, .11, .1, .12, .08])
        # Five values per row are drawn although only three columns are
        # used; the draw size is kept so seeded runs consume the random
        # stream as before.
        mixture_idx = np.random.choice(len(pi), size=5 * n, p=pi)
        xyz = np.random.normal(norm_params[mixture_idx, 0],
                               norm_params[mixture_idx, 1]).reshape(-1, 5)
        x = xyz[:, 0].reshape(-1, 1)
        y = xyz[:, 1].reshape(-1, 1)
        z = xyz[:, 2].reshape(-1, 1)
        # Standardise each variable with its OWN mean and std. Using the
        # statistics of the already-standardised x (mean ~0, std 1) left y
        # and z uncentred and unscaled.
        x = (x - np.mean(x)) / np.std(x)
        y = (y - np.mean(y)) / np.std(y)
        z = (z - np.mean(z)) / np.std(z)

    zz1 = .7 * ((z**3) / 5 + z/2)
    x = zz1 + np.tanh(x)
    # Direct x -> y effect, taken from x before its final cubic/tanh step.
    xx1 = x/2 + .7 * np.tanh(x)
    x = x + (x**3) / 3 + np.tanh(x/3) / 2
    zz2 = (z**3 / 4 + z) / 3
    y = y + zz2
    y = y + np.tanh(y/3) + 0.6 * xx1
    x = (x - np.mean(x)) / np.std(x)
    y = (y - np.mean(y)) / np.std(y)

    return x, y, z

def dgp_for_rcit_comparison(n=1200, dZ=30, n_comp=10, seed=0):
    """Generate ``(X, Y, Z)`` where X and Y depend on Z through random one-layer maps.

    Z is drawn from a Gaussian mixture with ``n_comp`` components (Dirichlet
    weights, normally distributed means, diagonal covariances) and its rows
    are shuffled. X is a random tanh-layer function of Z plus Gaussian noise;
    Y is a random sine-layer function of Z plus independent Gaussian noise.

    Parameters
    ----------
    n : int, default 1200
        Total number of samples.
    dZ : int, default 30
        Number of columns of Z.
    n_comp : int, default 10
        Number of mixture components.
    seed : int, default 0
        Seed for ``np.random.default_rng``; every draw uses that generator.

    Returns
    -------
    X, Y : numpy.ndarray of shape (n, 1)
        Standardised to zero mean and unit standard deviation.
    Z : numpy.ndarray of shape (n, dZ)
        Standardised with one mean and one std taken over all entries
        (not per column).
    """
    gen = np.random.default_rng(seed)

    # Mixture specification: weights, per-component sample counts, means and
    # diagonal covariance matrices.
    weights = gen.dirichlet(np.ones(n_comp))
    counts = gen.multinomial(n, weights)
    centers = gen.normal(0, 3, size=(n_comp, dZ))
    covariances = []
    for _ in range(n_comp):
        covariances.append(np.diag(gen.uniform(0.5, 2.0, size=dZ)))

    # Sample each component's block, stack them and shuffle the rows so the
    # components are interleaved.
    blocks = []
    for center, cov, count in zip(centers, covariances, counts):
        blocks.append(gen.multivariate_normal(center, cov, size=count))
    Z = np.vstack(blocks)
    gen.shuffle(Z)

    # Random weights of the two one-hidden-layer maps (X: 40 tanh units,
    # Y: 50 sine units).
    x_hidden_w = gen.normal(size=(dZ, 40))
    x_out_w = gen.normal(size=(40, 1))
    y_hidden_w = gen.normal(size=(dZ, 50))
    y_out_w = gen.normal(size=(50, 1))

    x_noise = gen.normal(0, 0.1, size=n)
    y_noise = gen.normal(0, 0.1, size=n)

    X = (np.tanh(Z @ x_hidden_w) @ x_out_w).ravel() + x_noise
    Y = (np.sin(Z @ y_hidden_w) @ y_out_w).ravel() + y_noise

    def _standardise(values):
        # Global (whole-array) centring and scaling.
        return (values - np.mean(values)) / np.std(values)

    X = _standardise(X)
    Y = _standardise(Y)
    Z = _standardise(Z)

    return X.reshape(-1, 1), Y.reshape(-1, 1), Z


def simulate_nongaussian(
    n,
    z_dim=1,
    x_dim=1,
    y_dim=1,
    z_sampler=None,
    fx=None,
    fy=None,
    noise_x=0.5,
    noise_y=0.5,
    seed=None
):
    """Generate ``(X, Y, Z)`` with X and Y each a function of Z plus Gaussian noise.

    Parameters
    ----------
    n : int
        Number of samples.
    z_dim, x_dim, y_dim : int, default 1
        Number of columns of Z, X and Y.
    z_sampler : callable, optional
        Called as ``z_sampler(n, z_dim, rng)`` to produce Z. The default
        draws ``u`` uniformly on ``[-pi, pi]`` and returns
        ``sin(u) + 0.3 * sin(3u) + 0.1 * u**2``.
    fx, fy : callable, optional
        Called as ``fx(Z, rng)`` / ``fy(Z, rng)`` to produce the noise-free
        parts of X and Y. Defaults: ``tanh(Z @ W)`` for X and ``(Z**3) @ W``
        for Y, with ``W`` a freshly drawn standard-normal matrix.
    noise_x, noise_y : float, default 0.5
        Standard deviation of the Gaussian noise added to X and Y.
    seed : optional
        Passed to ``np.random.default_rng``.

    Returns
    -------
    X, Y, Z : numpy.ndarray
        Each standardised with one mean and one std taken over all of its
        entries (not per column).
    """
    gen = np.random.default_rng(seed)

    def _wavy_z(count, width, source):
        angles = source.uniform(low=-np.pi, high=np.pi, size=(count, width))
        return np.sin(angles) + 0.3 * np.sin(3 * angles) + 0.1 * (angles ** 2)

    def _tanh_projection(cond, source):
        weights = source.normal(size=(cond.shape[1], x_dim))
        return np.tanh(cond @ weights)

    def _cubic_projection(cond, source):
        weights = source.normal(size=(cond.shape[1], y_dim))
        return (cond ** 3) @ weights

    def _standardise(values):
        return (values - np.mean(values)) / np.std(values)

    sample_z = _wavy_z if z_sampler is None else z_sampler
    make_x = _tanh_projection if fx is None else fx
    make_y = _cubic_projection if fy is None else fy

    # Draw order: Z, then the X map, then the Y map, then the two noise terms.
    Z = sample_z(n, z_dim, gen)
    x_signal = make_x(Z, gen)
    y_signal = make_y(Z, gen)

    X = x_signal + gen.normal(scale=noise_x, size=(n, x_dim))
    Y = y_signal + gen.normal(scale=noise_y, size=(n, y_dim))

    return _standardise(X), _standardise(Y), _standardise(Z)
