import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
# One-hot encoder for the categorical confounder W5 (values 0, 1, 2). It is
# shared at module level and (re)fitted by gen_train / gen_test.
try:
    # scikit-learn >= 1.2 spells the dense-output flag ``sparse_output``;
    # the old ``sparse`` keyword was removed in 1.4.
    ohenc_W5 = OneHotEncoder(sparse_output=False, drop=None)
except TypeError:
    # scikit-learn < 1.2 only accepts ``sparse``.
    ohenc_W5 = OneHotEncoder(sparse=False, drop=None)

# Divisor applied to the response Y after the offset C is subtracted.
DENOMINATOR_Y = 50
# Constant offset subtracted from the response before scaling (presumably a
# rough centering term -- TODO confirm its derivation).
C = (1 + 3.5**2) + (1 + 24**2)


def _return_true_response_from_X(
        A_intervention,
        X1_1_intervention,
        X1_2_intervention,
        X1_3_intervention):
    """Return the noiseless response implied by a treatment and the X1 proxies.

    The observed covariates X1_1, X1_2, X1_3 are deterministic transforms of
    the latent variables W1, W2, W3 (as generated in gen_train / gen_test)::

        X1_1 = exp(W1 / 2)
        X1_2 = W2 / (1 + exp(W1)) + 10
        X1_3 = W1 * W3 / 25 + 0.6

    This function inverts those transforms to recover W1, W2, W3 and then
    evaluates the response formula used to generate Y, without its additive
    noise term.

    Parameters
    ----------
    A_intervention, X1_1_intervention, X1_2_intervention, X1_3_intervention
        Arrays holding one value per sample; each is flattened, so e.g. shape
        ``(n, 1)`` or ``(n,)`` both work. X1_1 must be positive (it is passed
        to ``np.log``).

    Returns
    -------
    numpy.ndarray
        The noiseless responses, shape ``(n, 1)``.
    """
    a = A_intervention.flatten()
    x1_1 = X1_1_intervention.flatten()
    x1_2 = X1_2_intervention.flatten()
    x1_3 = X1_3_intervention.flatten()

    # Invert X1_1 = exp(W1 / 2).
    W1_intervention = 2*np.log(x1_1)
    # Invert X1_2 = W2 / (1 + exp(W1)); exp(W1) == X1_1**2.
    W2_intervention = (x1_2 - 10)*(1 + x1_1**2)
    # Invert X1_3 = W1*W3/25 + 0.6. The divisor is W1 (not X1_1). When W1 is
    # exactly 0, X1_3 carries no information about W3, so fall back to 0 (the
    # mean of W3 in the generator) instead of producing inf/nan.
    numerator = 25*(x1_3 - 0.6)
    W3_intervention = np.divide(
        numerator, W1_intervention,
        out=np.zeros(numerator.shape,
                     dtype=np.result_type(numerator, W1_intervention)),
        where=W1_intervention != 0)

    # Same response formula as the data generator, minus the noise.
    true_response = (
        - 0.15*a**2
        + a*(W1_intervention**2 + W2_intervention**2)
        - 15
        + (W1_intervention + 3)**2
        + 2*(W2_intervention - 25)**2
        + W3_intervention - C)/DENOMINATOR_Y

    return true_response.reshape(-1, 1)


def gen_train(train_data_size: int):
    """Simulate an observational training sample.

    Latent variables:
      * W1..W4: independent normals, means (-0.5, 1, 0, 1), unit variance.
      * W5: categorical on {0, 1, 2} with probabilities (0.7, 0.15, 0.15).
    Treatment A is noncentral chi-square with 3 degrees of freedom and a
    noncentrality that depends on W1, W2, W4 and W5. Response Y is a function
    of A, W1, W2, W3 plus standard-normal noise. The observed covariates are
    X1_1, X1_2, X1_3 (nonlinear transforms of W1..W3), X4 = (W4 - 1)**2 and
    X5 = one-hot(W5).

    Side effect: fits the module-level ``ohenc_W5`` encoder.

    Parameters
    ----------
    train_data_size : int
        Number of rows to generate.

    Returns
    -------
    list
        ``[explanatories_train, response_train]``. explanatories_train is a
        float32 array with columns ``[A, X1_1, X1_2, X1_3, X4, X5...]``;
        response_train is a float32 array of shape ``(n, 1)``.
    """
    sample_size = train_data_size

    # Latent variables. The RNG draw order below (multivariate_normal,
    # choice, noncentral_chisquare, normal) matches the original so seeded
    # runs reproduce the same sample.
    W_mat = np.random.multivariate_normal(
        [-0.5, 1, 0, 1], np.eye(4), sample_size)
    W1, W2, W3, W4 = W_mat.T
    W5 = np.random.choice(a=3, size=sample_size, p=(0.7, 0.15, 0.15))
    W5_mat = W5.reshape(sample_size, 1)
    # Fit on every level W5 can take rather than on this sample: a small
    # sample that happens to miss a level would otherwise yield fewer X5
    # columns than another call (e.g. train vs. test).
    ohenc_W5.fit(np.arange(3).reshape(-1, 1))

    # Treatment: noncentrality depends on W1, W2, W4 and the level of W5.
    myuA = (5*np.abs(W1) + 6*np.abs(W2) + np.abs(W4)
            + 1*(W5 == 1) + 5*(W5 == 2))
    A = np.random.noncentral_chisquare(3, myuA)

    # Response with additive standard-normal noise.
    Y = ((-0.15*A**2 + A*(W1**2 + W2**2) - 15
          + (W1 + 3)**2 + 2*(W2 - 25)**2 + W3 - C)/DENOMINATOR_Y
         + np.random.normal(0, 1, sample_size))

    # Observed covariates.
    X1_1 = np.exp(W1/2)
    X1_2 = W2/(1 + np.exp(W1)) + 10
    X1_3 = (W1*W3)/25 + 0.6
    X4 = (W4 - 1)**2
    X5 = ohenc_W5.transform(W5_mat)

    explanatories_train = np.column_stack(
        [v.astype(np.float32) for v in (A, X1_1, X1_2, X1_3, X4)]
        + [X5.astype(np.float32)])
    response_train = Y.astype(np.float32).reshape(sample_size, 1)

    return [explanatories_train,
            response_train]


def gen_test(test_data_size: int):
    """Simulate test sets for two interventional experiments.

    An observational sample is drawn from the same data-generating process as
    gen_train. Its rows are then permuted to break the dependence between the
    treatment and the covariates:

    * Experiment 1: A is shuffled with one permutation, and all covariates
      (X1_1, X1_2, X1_3, X4, X5) are shuffled together with another.
    * Experiment 2: A, the X1 block (X1_1..X1_3), X4 and X5 each get their
      own independent permutation.

    For each experiment the target is the noiseless response computed by
    ``_return_true_response_from_X`` from the shuffled A and X1 columns.

    Side effect: fits the module-level ``ohenc_W5`` encoder.

    Parameters
    ----------
    test_data_size : int
        Number of rows to generate.

    Returns
    -------
    list
        ``[explanatories_Expr1, true_response_Expr1,
        explanatories_Expr2, true_response_Expr2]``. Each explanatories entry
        is a list of float32 arrays ``[A, X1_1, X1_2, X1_3, X4, X5]``. A,
        X1_1..X1_3 and X4 have shape ``(n, 1)``; X5 has one column per W5
        level. Each true response has shape ``(n, 1)``.
    """
    n = test_data_size

    # Observational sample. The RNG draw order is kept identical to the
    # original implementation so seeded runs reproduce the same test sets.
    W_mat = np.random.multivariate_normal([-0.5, 1, 0, 1], np.eye(4), n)
    W1, W2, W3, W4 = W_mat.T
    W5 = np.random.choice(a=3, size=n, p=(0.7, 0.15, 0.15))
    W5_mat = W5.reshape(n, 1)
    # Fit on all W5 levels so X5 always has one column per level, even when
    # a small sample misses one.
    ohenc_W5.fit(np.arange(3).reshape(-1, 1))
    myuA = (5*np.abs(W1) + 6*np.abs(W2) + np.abs(W4)
            + 1*(W5 == 1) + 5*(W5 == 2))
    A = np.random.noncentral_chisquare(3, myuA)
    # The noisy observed Y is never returned (the targets are the noiseless
    # responses below). Its noise is still drawn so that the global RNG state,
    # and hence the permutations below, stays unchanged.
    np.random.normal(0, 1, n)

    def _as_column(values):
        # Cast to float32 and shape as an (n, 1) column.
        return values.astype(np.float32).reshape(n, 1)

    A_test = _as_column(A)
    # Columns: X1_1, X1_2, X1_3, X4, then one X5 column per W5 level.
    covariates_test = np.concatenate(
        [_as_column(np.exp(W1/2)),
         _as_column(W2/(1 + np.exp(W1)) + 10),
         _as_column((W1*W3)/25 + 0.6),
         _as_column((W4 - 1)**2),
         ohenc_W5.transform(W5_mat).astype(np.float32)],
        axis=1)

    def _permutation():
        # Random reordering of the n row indices.
        return np.random.choice(n, size=n, replace=False)

    def _assemble(a_rows, x1_rows, x4_rows, x5_rows):
        # Build [A, X1_1, X1_2, X1_3, X4, X5] from the given row orders, and
        # the noiseless response for the shuffled A and X1 columns (X4 and X5
        # are not inputs of _return_true_response_from_X).
        a_int = A_test[a_rows, :]
        x1_1 = covariates_test[x1_rows, 0:1]
        x1_2 = covariates_test[x1_rows, 1:2]
        x1_3 = covariates_test[x1_rows, 2:3]
        x4 = covariates_test[x4_rows, 3:4]
        x5 = covariates_test[x5_rows, 4:]
        truth = _return_true_response_from_X(a_int, x1_1, x1_2, x1_3)
        return [a_int, x1_1, x1_2, x1_3, x4, x5], truth

    # Experiment 1: one permutation for A, another shared by all covariates.
    a_rows_1 = _permutation()
    x_rows_1 = _permutation()
    explanatories_Expr1, true_response_Expr1 = _assemble(
        a_rows_1, x_rows_1, x_rows_1, x_rows_1)

    # Experiment 2: independent permutations for A, the X1 block, X4 and X5.
    a_rows_2 = _permutation()
    x1_rows_2 = _permutation()
    x4_rows_2 = _permutation()
    x5_rows_2 = _permutation()
    explanatories_Expr2, true_response_Expr2 = _assemble(
        a_rows_2, x1_rows_2, x4_rows_2, x5_rows_2)

    return [explanatories_Expr1,
            true_response_Expr1,
            explanatories_Expr2,
            true_response_Expr2]
