#
import numpy as np
import scipy
import scipy.stats




def generate_data_exp(n, d, z_d_dim, amount_of_missingness, missing_value=-1):
    """Simulate one dataset with treatment-linked missingness.

    Pipeline, in order: covariates -> potential outcomes -> "down" mask ->
    treatments (driven by the "down" mask) -> observed outcomes -> "up" mask
    (driven by the treatments) -> covariates with missing entries overwritten.

    Args:
        n: number of samples; must be positive.
        d: number of covariates.
        z_d_dim: number of leading covariate columns governed by the "down"
            mask (``0 < z_d_dim <= d``). The remaining ``d - z_d_dim``
            columns are governed by the "up" mask.
        amount_of_missingness: target missingness level in ``(0, .5]``. This
            is approximate, since sampling is still random.
        missing_value: value written into ``X_`` wherever a mask equals 0.

    Returns:
        The tuple ``(X, X_, Y0, Y1, Y, CATE, W, Z_up, Z_down)``. ``X`` holds
        the fully observed, signed covariates; ``X_`` holds ``|X|`` with the
        masked entries replaced by ``missing_value``.

    Raises:
        AssertionError: if an argument violates the constraints above.
    """
    assert n > 0
    assert z_d_dim > 0 and z_d_dim <= d
    # Approximate target only: the masks below are still sampled.
    assert amount_of_missingness > 0 and amount_of_missingness <= .5

    # Fully observed covariates. They contain negatives on purpose: the signed
    # values determine parts of the data and are removed (via abs) only when
    # X_ is built, which yields a basic non-linearity (identifiability in the
    # DAG).
    X = _generate_covariates(d, n)

    # Potential outcomes and their difference.
    Y0, Y1, CATE = _generate_outcomes(X)

    # "Down" mask over the first z_d_dim columns.
    Z_down = _Z_down(amount_of_missingness, X, z_d_dim)

    # Treatment assignment, then the outcome that is actually observed.
    W = _treatments(Z_down, X, z_d_dim)
    Y = _generate_observed_outcomes(Y0, Y1, W)

    # "Up" mask over the remaining columns; it depends on the treatment.
    Z_up = _Z_up(amount_of_missingness, X, z_d_dim, W)

    # Covariates as seen by a learner: |X| with masked entries overwritten.
    X_ = _complete_covariates(X, z_d_dim, Z_up, Z_down, missing_value)

    return X, X_, Y0, Y1, Y, CATE, W, Z_up, Z_down


def _generate_covariates(d, n):
    """Draw ``n`` correlated, zero-mean Gaussian covariate vectors of size ``d``.

    The covariance is ``A @ A.T`` for a uniformly random ``d x d`` matrix
    ``A``. The whole sample is then divided by its global range, so that
    ``X.max() - X.min() == 1``; values are NOT shifted and stay signed.

    Args:
        d: number of covariates; must be positive.
        n: number of samples; must be positive.

    Returns:
        Array of shape ``(n, d)``.

    Raises:
        AssertionError: if ``d`` or ``n`` is not positive.
    """
    assert d > 0
    assert n > 0

    # A random factor times its own transpose gives a symmetric covariance.
    factor = np.random.rand(d, d)
    covariance = np.dot(factor, factor.T)

    samples = np.random.multivariate_normal(np.zeros(d), covariance, size=n)

    # Rescale by the global spread (single scalar for the whole matrix).
    spread = samples.max() - samples.min()
    samples /= spread

    return samples

def _generate_outcomes(X):
    """Generate linear potential outcomes and their difference.

    Two independent weight vectors are drawn uniformly from ``[0, 1)`` (first
    the control weights, then the treated weights); each potential outcome is
    the weighted sum of a row of ``X``.

    Args:
        X: covariate array of shape ``(n, d)``.

    Returns:
        ``(Y0, Y1, CATE)``, each of shape ``(n,)``, with ``CATE = Y1 - Y0``.
    """
    n_features = X.shape[1]
    weights_control = np.random.rand(n_features)
    weights_treated = np.random.rand(n_features)

    Y0 = (X * weights_control).sum(axis=1)
    Y1 = (X * weights_treated).sum(axis=1)

    return Y0, Y1, Y1 - Y0

def _generate_observed_outcomes(Y0, Y1, W):
    """Return the noisy factual outcome for every sample.

    Sample ``i`` reveals ``Y0[i]`` when ``W[i] == 0`` and ``Y1[i]`` otherwise;
    independent Gaussian noise with standard deviation 0.1 is added on top.

    Args:
        Y0: control potential outcomes, shape ``(n,)``.
        Y1: treated potential outcomes, shape ``(n,)``.
        W: treatment indicators as an array of shape ``(n,)``.

    Returns:
        Array of shape ``(n,)`` with the observed outcomes.
    """
    # Vectorized selection instead of a per-element Python loop.
    factual = np.where(W == 0, Y0, Y1)
    noise = np.random.randn(W.shape[0]) * .1
    return factual + noise



def _Z_down(amount_of_missingness, X, z_d_dim):
    """Build the "down" mask over the first ``z_d_dim`` covariates.

    In every row the ``k`` largest values among the first ``z_d_dim`` columns
    are marked missing, where ``k = max(round(amount_of_missingness *
    z_d_dim), 1)``. Values tied with the ``k``-th largest are marked missing
    as well.

    Args:
        amount_of_missingness: fraction of the ``z_d_dim`` columns to hide
            per row (at least one column is always hidden).
        X: covariate array of shape ``(n, d)`` with ``d >= z_d_dim``.
        z_d_dim: number of leading columns covered by the mask.

    Returns:
        Integer array of shape ``(n, z_d_dim)``; 0 = missing, 1 = present.
    """
    leading = X[:, :z_d_dim]

    # Number of entries hidden per row; never fewer than one.
    n_hidden = max(int(np.round(amount_of_missingness * z_d_dim)), 1)

    # Per-row border: the n_hidden-th largest value, kept as a column vector
    # so it broadcasts against every entry of its own row.
    border = np.sort(leading, axis=1)[:, -n_hidden][:, np.newaxis]

    hidden = leading >= border
    return 1 - hidden.astype(int)       # 0 = missing, 1 = present

def _treatments(Z_down, X, z_d_dim):
    """Assign a binary treatment to every sample from its "down" mask row.

    Rules, checked in this order for each row of ``Z_down``:
      1. last entry is 0 (missing)                         -> treatment 0
      2. any 0 among the first ``floor(z_d_dim / 2)`` entries -> treatment 1
      3. otherwise                                          -> fair coin flip

    Rows are processed sequentially, so coin flips are drawn in row order.

    Args:
        Z_down: mask array with one row per sample (0 = missing).
        X: unused; kept for signature compatibility.
        z_d_dim: number of columns covered by the "down" mask.

    Returns:
        1-D array with one treatment indicator per row of ``Z_down``.
    """
    half = int(np.floor(z_d_dim / 2))

    def assign(mask_row):
        if mask_row[-1] == 0:
            return 0
        if 0 in mask_row[:half]:
            return 1
        return np.random.binomial(1, .5)

    return np.array([assign(mask_row) for mask_row in Z_down])

def _Z_up(amount_of_missingness, X, z_d_dim, W):
    """Build the treatment-dependent "up" mask over columns ``z_d_dim:``.

    Let ``width = d - z_d_dim`` and let ``dim_count`` be
    ``round(2 * amount_of_missingness * width)`` clipped to
    ``[1, width // 2]`` (the upper bound wins, so ``dim_count`` is 0 when
    ``width < 2``). The first ``dim_count`` "up" columns of ``X`` act as
    drivers: an entry "exceeds" when its deviation from the column mean is
    larger than ``norm.ppf(1 - amount_of_missingness)`` column standard
    deviations.

    * Treated rows (``W[i]`` truthy): the LAST ``dim_count`` mask columns are
      missing wherever the matching driver exceeds.
    * Untreated rows: the FIRST ``dim_count`` mask columns are missing
      wherever the matching driver exceeds.

    All other entries are present.

    Args:
        amount_of_missingness: target tail probability used for the threshold.
        X: covariate array of shape ``(n, d)``; the signed values are used.
        z_d_dim: index of the first column covered by this mask.
        W: treatment indicators, one per row of ``X``.

    Returns:
        Float array of shape ``(n, d - z_d_dim)``; 0 = missing, 1 = present.
    """
    n, d = X.shape
    width = d - z_d_dim

    dim_count = int(np.round(amount_of_missingness * width * 2))
    dim_count = min(max(dim_count, 1), int(width / 2))

    # Standard-normal quantile: a standardized driver exceeds it with
    # probability ~amount_of_missingness.
    threshold = scipy.stats.norm.ppf(1 - amount_of_missingness)

    # NOTE: the thresholds are deterministic. These two draws are discarded;
    # they exist only so that seeded runs consume the same amount of global
    # random state as before.
    np.random.normal(loc=threshold, scale=.5, size=dim_count)
    np.random.normal(loc=threshold, scale=.5, size=dim_count)

    Z_up = np.ones((n, width))          # start with everything present
    if dim_count == 0 or n == 0:
        # Nothing to mask. Returning here also avoids the `-0:` slice, which
        # would select the whole row instead of an empty tail.
        return Z_up

    # Relying on the negatives in X helps make each dependency identifiable.
    drivers = X[:, z_d_dim:z_d_dim + dim_count]
    # Column statistics are computed once over all samples, so the threshold
    # is expressed in units of each driver column's own spread.
    exceeds = (drivers - drivers.mean(axis=0)) > threshold * drivers.std(axis=0)

    treated = np.asarray(W).astype(bool)
    Z_up[treated, width - dim_count:] = ~exceeds[treated]
    Z_up[~treated, :dim_count] = ~exceeds[~treated]

    return Z_up

def _complete_covariates(X, z_d_dim, Z_up, Z_down, missing_value):
    """Return ``|X|`` with every masked entry replaced by ``missing_value``.

    Taking the absolute value is the non-linearity used for identifiability
    of the DAG. The input ``X`` is left untouched.

    Args:
        X: signed covariate array of shape ``(n, d)``.
        z_d_dim: number of leading columns governed by ``Z_down``; the rest
            are governed by ``Z_up``.
        Z_up: mask for columns ``z_d_dim:`` (0 = missing).
        Z_down: mask for columns ``:z_d_dim`` (0 = missing).
        missing_value: value written at every missing position.

    Returns:
        New array of shape ``(n, d)``.
    """
    # np.abs allocates a fresh array, so the caller's X is never modified.
    completed = np.abs(X)

    # Both slices are views: writing through them updates `completed`.
    upper = completed[:, z_d_dim:]
    lower = completed[:, :z_d_dim]
    upper[Z_up == 0] = missing_value
    lower[Z_down == 0] = missing_value

    return completed
