'''

Samplers used to generate data in the experiments
Created by: Ezinne Nwankwo 
Modified from: https://github.com/wenshuoguo/robust-causal-code-public/tree/main

'''


import numpy as np
import random
from scipy.special import expit
import math
import pandas as pd
from sklearn.linear_model import LogisticRegression 

def gen_KS_samples(num_samples,l=1,q=2,p=4, seed=9876, norm=False,
                   max_balance_attempts=None):
  '''
  Simulate a dataset with a binary treatment T, a continuous latent outcome Z
  and a two-dimensional observed outcome Y.

  Generation steps (the order of random draws matches the original
  implementation so that a given seed keeps producing the same data):

    1. U ~ N(0, I_q), one row per sample (additive noise on Y).
    2. X ~ N(M, I_4) with M = [0.2, 3, 0.3, 1], one row per sample.
    3. Per sample:
         t ~ Bernoulli(min(1, exp(-x1 - 2*x2 - 0.25*x3 - 0.1*x4)))
         z = c + 27.4*x1 + 13.72*x2 + 13.7*x3 + 13.7*x4 + N(0, 1)
         y = z * coef(t) + u
       where c = 10 and coef(0) = [2.3, 3.11] when norm is false,
       c = 210 and coef(0) = [12.3, 13.11] when norm is true, and
       coef(1) = [1.3, 0.11] in both cases.
    4. A logistic regression of t on X is fitted and a new treatment vector is
       drawn from its predicted probabilities, repeatedly, until the treated
       share is within 0.05 of 0.5.
    5. When norm is true, X, Y and Z are standardised column-wise.

  NOTE(review): the treatment vector that is returned is the one redrawn in
  step 4, whereas Y was built from the treatment drawn in step 3. This is kept
  exactly as in the original; confirm that it is intended.

  NOTE(review): step 4 has no exit other than hitting the balance window. If
  the fitted probabilities average far from 0.5 the loop can run for a very
  long time; pass max_balance_attempts to bound it.

  Parameters
    num_samples: number of rows to generate.
    l: unused; kept for backward compatibility.
    q: dimension of the noise U. It has to broadcast against the two outcome
       coefficients (the default is 2).
    p: unused; kept for backward compatibility (the original overwrote it
       with the per-sample propensity).
    seed: seed passed to np.random.seed before any draw.
    norm: selects the constants of step 3 and enables step 5.
    max_balance_attempts: None (default) retries step 4 without limit, as the
       original did; a positive integer caps the number of redraws.

  Returns
    (X, Y, Z, T): X has one row of 4 covariates per sample, Y is reshaped to
    (num_samples, 2), Z holds one latent outcome per sample and T holds the
    redrawn binary treatment per sample.

  Raises
    ValueError: max_balance_attempts is given but smaller than 1.
    RuntimeError: max_balance_attempts redraws did not yield balanced groups.
  '''
  if max_balance_attempts is not None and max_balance_attempts < 1:
    raise ValueError("max_balance_attempts must be None or a positive integer")

  np.random.seed(seed=seed)

  U_mean = np.zeros((q,))
  U_cov = np.identity(q)
  U_samples = np.random.multivariate_normal(U_mean, U_cov, size=num_samples)

  M = [0.2,3,0.3,1]
  X_mean = np.array(M).reshape(4,)
  X_cov = np.identity(4)
  X_samples = np.random.multivariate_normal(X_mean, X_cov, size=num_samples)

  # The two settings run the same steps and differ only in these constants.
  if norm:
    z_intercept = 210
    control_coef = np.array([12.3, 13.11])
  else:
    z_intercept = 10
    control_coef = np.array([2.3, 3.11])
  treated_coef = np.array([1.3, 0.11])

  initial_T = []
  Z_samples = []
  Y_samples = []
  for u, (x_1, x_2, x_3, x_4) in zip(U_samples, X_samples):
    # treatments (the binomial draw must precede the normal draw below to
    # keep the random stream identical to the original)
    propensity = min(1,math.exp(-x_1 -2*x_2 -0.25*x_3 -0.1*x_4))
    t = np.random.binomial(1, propensity)
    initial_T.append(t)
    # latent outcomes
    z = z_intercept+27.4*x_1 +13.72*x_2 + 13.7*x_3 + 13.7*x_4 + np.random.normal(0,1)
    Z_samples.append(z)
    # observed outcomes
    coef = treated_coef if t == 1 else control_coef
    Y_samples.append(z*coef + u)

  # Redraw the treatment from fitted propensities until the groups are balanced.
  logreg = LogisticRegression(solver='liblinear')
  attempts = 0
  while True:
    if max_balance_attempts is not None and attempts >= max_balance_attempts:
      raise RuntimeError(
        "no balanced treatment assignment after %d attempts" % attempts)
    attempts += 1

    # The refit stays inside the loop to reproduce the original call sequence.
    # NOTE(review): hoisting it could shift the global random stream if the
    # solver draws from it -- confirm before moving.
    p_treatment = logreg.fit(X_samples, initial_T).predict_proba(X_samples)[:, 1]
    treatment = np.random.binomial(n=1, p=p_treatment, size=num_samples)

    # Within 5% of balance
    if np.abs(np.mean(treatment) - 0.5) < 0.05:
      break
  T_samples = treatment

  if norm:
    # Column-wise standardisation to zero mean and unit standard deviation.
    Z_samples = np.array(Z_samples)
    Y_samples = np.array(Y_samples)
    X_samples = np.array(X_samples)
    Z_samples = (Z_samples - Z_samples.mean(axis=0)) / Z_samples.std(axis=0)
    Y_samples = (Y_samples - Y_samples.mean(axis=0)) / Y_samples.std(axis=0)
    X_samples = (X_samples - X_samples.mean(axis=0)) / X_samples.std(axis=0)

  return(np.array(X_samples), np.array(Y_samples).reshape(num_samples,2), np.array(Z_samples), np.array(T_samples))
  