import numpy as np
from numpy.random import normal, binomial 
import pandas as pd


def linear_data_generator(pz, A, B, n, *, rng=None):
    '''
    Generate a linear-Gaussian dataset of n samples as a DataFrame.

    The generative process is::

        z ~ N(pz[0], pz[1])          (pz[1] is the standard deviation)
        x_i = A[i] * z + B[i]        (one proxy column per entry of A / B)
        t ~ Bernoulli(sigmoid(z))
        y = z * t + eps,  eps ~ N(0, 1)

    Parameters
    ----------
    pz : sequence of length 2
        (mean, standard deviation) of the Gaussian latent z.
    A, B : scalar or 1-D array-like of equal length
        Per-proxy slope and offset; column ``x{i}`` is ``A[i] * z + B[i]``.
    n : int
        Number of rows to generate.
    rng : None, int, numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` (default) draws from NumPy's global
        random state in the same order as before, so ``np.random.seed`` still
        controls the output. An int is used to seed ``np.random.default_rng``.

    Returns
    -------
    pandas.DataFrame
        Columns ``['z', 'x0', ..., 'x{k-1}', 't', 'y']`` with k = len(A).
        All columns are float64 (``t`` holds 0.0 / 1.0) because the columns
        are stacked into a single float array.

    Raises
    ------
    ValueError
        If ``pz`` does not have exactly two entries, or ``A`` and ``B`` are
        not 1-D with the same length.
    '''
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # after which a length-1 B would silently broadcast against a longer A.
    if len(pz) != 2:
        raise ValueError(f'pz must be (mean, std), got {len(pz)} values')
    A = np.atleast_1d(np.asarray(A))
    B = np.atleast_1d(np.asarray(B))
    if A.ndim != 1 or A.shape != B.shape:
        raise ValueError(
            f'A and B must be 1-D with equal length, got shapes {A.shape} and {B.shape}'
        )

    if rng is None:
        # The numpy.random module exposes the global RandomState's
        # normal/binomial, so default draws are identical to the legacy ones.
        rng = np.random
    elif isinstance(rng, (int, np.integer)):
        rng = np.random.default_rng(rng)

    # Latent confounder z ~ N(mean, std).
    z = rng.normal(pz[0], pz[1], size=n)

    # Proxies: row j of x is A * z[j] + B.
    x = np.outer(z, A) + B

    # Treatment t ~ Bern(sigmoid(z)).
    t = rng.binomial(1, sigmoid(z))

    # Outcome y = z * t + eps with eps ~ N(0, 1).
    eps = rng.normal(0, 1, size=n)
    y = z * t + eps

    columns = ['z'] + [f'x{i}' for i in range(len(A))] + ['t', 'y']
    # column_stack turns the 1-D arrays into columns; int t is upcast to float.
    return pd.DataFrame(np.column_stack((z, x, t, y)), columns=columns)

def sigmoid(x):
    '''
    Logistic function 1 / (1 + exp(-x)), applied elementwise.

    For very negative x (below about -709 in float64), exp(-x) overflows to
    inf and the expression still evaluates to the correct limit 0.0. The
    floating-point overflow is therefore expected, and the "overflow
    encountered in exp" RuntimeWarning is suppressed.
    '''
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-x))