
import torch
import numpy as np
from torch.nn.functional import sigmoid

def sample_20250111(Z,X,U,generator):
    '''
    Compute Bernoulli success probabilities from (Z, X, U) and draw outcomes.

    The logit is the sum of
      * an intercept ``U['U0_constant']``,
      * a linear term in X, scaled by ``1 / sqrt(X.shape[-1])``,
      * a linear term in Z,
      * an interaction term coupling ``U['U1']``, Z and the first two
        features of X (via ``torch.einsum('ik,ijk,ijk->ij', ...)``).

    Args:
        Z: 2-D or 3-D tensor; a 2-D Z gets a singleton axis inserted at
            position 1 before use.
        X: 3-D feature tensor; only ``X[:, :, :2]`` enters the interaction.
        U: dict with keys 'U0_constant', 'U0_x', 'U0_z' and 'U1'.
        generator: ``torch.Generator`` used for the Bernoulli draw.

    Returns:
        ``(mean, Y)``: the sigmoid of the logit, and a Bernoulli sample
        drawn with those probabilities.
    '''
    # Insert the middle axis so a 2-D Z broadcasts against the 3-D X.
    latent = Z.unsqueeze(1) if Z.dim() == 2 else Z
    n_features = X.shape[-1]

    intercept = U['U0_constant']
    x_weights = U['U0_x']
    z_weights = U['U0_z']
    cross_weights = U['U1']

    # Main effects; the X contribution is normalised by sqrt(#features).
    x_effect = (x_weights * X).sum(2) / np.sqrt(n_features)
    z_effect = (z_weights * latent).sum(2)
    linear_part = intercept + x_effect + z_effect

    # Z-by-X interaction restricted to the first two X features.
    interaction = torch.einsum(
        'ik,ijk,ijk->ij', cross_weights, latent, X[:, :, :2]
    )

    mean = torch.sigmoid(linear_part + interaction)
    return mean, torch.bernoulli(mean, generator=generator)

def generate_U_20250111(D,dimX,g, zero_out=False, zero_out_last=False, U0_const_sd = 1):
    '''
    Draw the random coefficient set U used by ``sample_20250111``.

    Args:
        D: size of the leading axis of every coefficient tensor.
        dimX: number of X features (last axis of 'U0_x').
        g: ``torch.Generator`` driving all draws.
        zero_out: if True, set the first X coefficient to 0.
        zero_out_last: if True, set the last X coefficient to 0.
        U0_const_sd: standard deviation of the intercept draw.

    Returns:
        dict with
          'U0_constant': (D, 1),       normal(mean 0, sd U0_const_sd)
          'U0_x':        (D, 1, dimX), normal(mean 1, sd 0.25)
          'U0_z':        (D, 1, 2),    normal(mean 1, sd 0.25)
          'U1':          (D, 2),       normal(mean 1, sd 0.25)
    '''
    main_effect_sd = 0.25
    interaction_sd = 0.25

    def _gaussian(center, spread):
        # One draw per element of `center`, all with the same spread.
        return torch.normal(center, torch.ones_like(center) * spread, generator=g)

    # The draw order below determines how the generator stream is consumed.
    intercept = _gaussian(torch.zeros(D, 1), U0_const_sd)
    x_coef = _gaussian(torch.ones(D, 1, dimX), main_effect_sd)
    z_coef = _gaussian(torch.ones(D, 1, 2), main_effect_sd)
    cross_coef = _gaussian(torch.ones(D, 2), interaction_sd)

    if zero_out:
        x_coef[:, :, 0] = 0
    if zero_out_last:
        x_coef[:, :, -1] = 0

    return {
        'U0_constant': intercept,
        'U0_x': x_coef,
        'U0_z': z_coef,
        'U1': cross_coef,
    }

def binary_linear_dgp_20250111(D,N,dimX,ave_U,one_X_per_col,generator,
                               num_gen=100,
                               fewer_U=False,
                               fewer_Z=False,  
                               no_cross=False,
                               even_fewer_U=False,
                               Z=None,
                               X_sigmoid_transform=None,
                               X_product_sign=False,
                               X_product_sign_more=False,
                               X_product_sign_more_last=False,
                               X_discont=False,
                               X=None,
                               U=None,
                               zero_out=False,
                               zero_out_last=False,
                               U0_const_sd=1,
                              ):
    '''
    Generate one synthetic binary-outcome dataset.

    Z, X and (unless supplied) U are drawn, then success probabilities and
    Bernoulli outcomes are produced by ``sample_20250111``.  Optionally the
    features handed back in ``res['X']`` are a transformed version of the
    features that generated the outcomes; in that case the untransformed
    features are returned under ``res['X_gen']``.

    Args:
        D, N, dimX: X is drawn with shape (D, N, dimX).
        ave_U: if True, also average the success probability over
            ``num_gen`` fresh draws of U.
        one_X_per_col: if True, draw a single (1, N, dimX) X and repeat it
            D times along the first axis.
        generator: ``torch.Generator`` used for every random draw.
        num_gen: number of U draws used when ``ave_U`` is True.
        fewer_Z: if True, the second Z coordinate is set equal to the first
            and only the first is returned.  Incompatible with passing Z.
        Z: optional tensor of shape (D, 2) or (D, 1, 2); drawn standard
            normal when None.
        X_sigmoid_transform: optional scale ``s``; when given,
            ``res['X'] = 2 * sigmoid(s * X) - 1``.
        X_product_sign: multiply feature 0 by the sign of feature 1.
        X_product_sign_more: multiply features 1.. by the sign of feature 0.
        X_product_sign_more_last: multiply all but the last feature by the
            sign of the last feature.
        X_discont: only used with ``X_product_sign_more_last``; maps
            positive entries to ``0.2 * x + 0.5`` and negative entries to
            ``0.2 * x - 0.5``.
        U: optional coefficient dict (see ``generate_U_20250111``); drawn
            when None.
        zero_out, zero_out_last, U0_const_sd: forwarded to
            ``generate_U_20250111``.
        fewer_U, no_cross, even_fewer_U, X: accepted but not used by this
            function.  NOTE(review): in particular a caller-supplied ``X``
            is ignored and X is always drawn afresh -- confirm whether
            that is intended.

    Returns:
        dict with keys
          'Z': (D, 2), or (D, 1) when ``fewer_Z``
          'X': (D, N, dimX) features, possibly transformed
          'X_gen': the untransformed X (only present if a transform is on)
          'U': the coefficient dict used for 'Y' and 'click_rate'
          'Y': (D, N) Bernoulli outcomes
          'click_rate': (D, N) success probabilities
          'outcome_mean_ave_U', 'click_rate_ave_U': (D, N) success
              probability averaged over U draws (only if ``ave_U``)

    Raises:
        ValueError: if a supplied Z is neither 2- nor 3-dimensional.
        AssertionError: on inconsistent flags or a mis-shaped Z.
    '''
    g = generator

    # ---- latent Z -------------------------------------------------------
    if Z is not None:
        assert not fewer_Z
        assert Z.shape[0] == D
        assert Z.shape[-1] == 2
        if len(Z.shape) == 2:
            Z = Z.unsqueeze(1)
        elif len(Z.shape) == 3:
            assert Z.shape[1] == 1
        else:
            raise ValueError('Inputted Z must have dimension 2 or 3')
    else:
        Z = torch.normal( torch.zeros((D,1,2)), generator=g )
    if fewer_Z:
        # Only reachable with a freshly drawn Z (see the assert above).
        Z[:,:,1] = Z[:,:,0]

    # ---- generating features X -------------------------------------------
    if one_X_per_col:
        X = torch.normal( torch.zeros((1,N,dimX)), generator=g ).repeat( (D,1,1) )
    else:
        X = torch.normal( torch.zeros((D,N,dimX)), generator=g )

    if U is None:
        U = generate_U_20250111(D,dimX,g,zero_out,zero_out_last,U0_const_sd)

    mean, Y = sample_20250111(Z,X,U,g)

    res = {}
    res['Z'] = Z[:,0,:1] if fewer_Z else Z[:,0,:]

    assert not (X_sigmoid_transform and X_product_sign)
    assert not (X_sigmoid_transform and X_product_sign_more)
    if X_sigmoid_transform is None:
        res['X'] = X
    else:
        res['X'] = 2*torch.sigmoid(X_sigmoid_transform*X)-1
        res['X_gen'] = X

    assert not ((X_product_sign or X_product_sign_more) and X_product_sign_more_last)
    if X_product_sign or X_product_sign_more or X_product_sign_more_last:
        # Transform a copy: X itself must stay untouched because it is what
        # generated Y / click_rate, is returned as 'X_gen', and is reused
        # by the ave_U loop below.
        X_train = X.clone()
        if X_product_sign:
            X_train[:,:,0] = X_train[:,:,0]*X_train[:,:,1].sign()
        if X_product_sign_more:
            # Feature 0 is not modified here, so one broadcasted product
            # matches the column-by-column update.
            X_train[:,:,1:] = X_train[:,:,1:]*X_train[:,:,:1].sign()
        if X_product_sign_more_last:
            X_train[:,:,:-1] = X_train[:,:,:-1]*X_train[:,:,-1:].sign()
            if X_discont:
                # Masks are taken up front; positive entries stay positive
                # after their update, so this matches sequential masking.
                positive = X_train > 0
                negative = X_train < 0
                X_train[positive] = X_train[positive]*0.2+0.5
                X_train[negative] = X_train[negative]*0.2-0.5
        res['X'] = X_train
        res['X_gen'] = X

    res['U'] = U
    res['Y'] = Y
    res['click_rate'] = mean

    if ave_U:
        # Monte-Carlo average of the success probability over fresh U draws,
        # holding Z and the generating X fixed.
        means = []
        for _ in range(num_gen):
            U_draw = generate_U_20250111(D,dimX,g,zero_out,zero_out_last, U0_const_sd)
            mean_draw, _ = sample_20250111(Z,X,U_draw,g)
            means.append(mean_draw)
        averaged = torch.stack(means, 0).mean(0)
        res['outcome_mean_ave_U'] = averaged
        # Separate tensor so the two entries do not share storage.
        res['click_rate_ave_U'] = averaged.clone()

    return res

# Registry of named data-generating processes.  Each entry is a callable
# taking (D, N, dimX, ave_U, one_X_per_col, g) -- plus Z for the "withZ"
# variant -- and returning the result dict of binary_linear_dgp_20250111.
CONTEXT_DGPs = {
    # All optional settings left at their defaults.
    '0111_logistic': (
        lambda D, N, dimX, ave_U, one_X_per_col, g:
            binary_linear_dgp_20250111(
                D, N, dimX, ave_U, one_X_per_col, generator=g,
            )
    ),
    # Caller-supplied Z, last X coefficient zeroed, and the returned
    # features multiplied by the sign of the last feature.
    '0111_logistic_withZ_zero_prodsign_last': (
        lambda D, N, dimX, ave_U, one_X_per_col, g, Z:
            binary_linear_dgp_20250111(
                D, N, dimX, ave_U, one_X_per_col, generator=g,
                Z=Z,
                zero_out_last=True,
                X_product_sign_more_last=True,
            )
    ),
}
