import random 
import numpy as np
from scipy.special import expit
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
import copy
import matplotlib.pyplot as plt
import pickle 

from global_functions import *


def sigmoid(x):
	"""Return the logistic function 1 / (1 + exp(-x)), applied elementwise.

	Uses ``scipy.special.expit`` (already imported by this file) instead of the
	hand-written formula: ``np.exp(-x)`` overflows for large negative ``x`` and
	emits a RuntimeWarning, whereas ``expit`` is evaluated stably.

	Args:
		x: scalar or array-like (NumPy array, pandas Series) of logits.

	Returns:
		Values in [0, 1] with the same shape as ``x``.
	"""
	return expit(x)

def pi_star_X1(C1, C2, C3):
	"""Probability that X1 = 1 given (C1, C2, C3); used to draw X1 when S == 0."""
	covariate_sum = C1 + C2 + C3
	return sigmoid(1 * covariate_sum - 2)

def pi_X1(C1, C2, C3):
	"""Probability that X1 = 1 given (C1, C2, C3); used to draw X1 when S != 0."""
	covariate_sum = C1 + C2 + C3
	return sigmoid(0.5 * covariate_sum - 1)

def pi_star_X2(X1, W1, C1, C2, C3):
	"""Probability that X2 = 1 given (X1, W1, C1, C2, C3); used to draw X2 when S == 0."""
	covariate_term = 0.5 * (C1 + C2 + C3)
	x1_term = 2 * (2 * X1 - 1)  # maps X1 in {0, 1} to {-2, +2}
	logit = covariate_term + x1_term - 0.5 * W1 + 1
	return sigmoid(logit)

def pi_X2(X1, W1, C1, C2,C3):
	"""Probability that X2 = 1 given (X1, W1, C1, C2, C3); used to draw X2 when S != 0."""
	covariate_term = 1 * (C1 + C2 + C3)
	x1_term = 1 * (2 * X1 - 1)  # maps X1 in {0, 1} to {-1, +1}
	logit = covariate_term + x1_term + 0.5 * W1 - 1
	return sigmoid(logit)

def pi_star_X3(X1, W1, X2, W2, C1, C2, C3):
	"""Probability that X3 = 1 given (X1, W1, X2, W2, C1, C2, C3); used to draw X3 when S == 0."""
	covariate_term = 0.25 * (C1 + C2 + C3)
	x1_term = 1 * (2 * X1 - 1)  # maps X1 in {0, 1} to {-1, +1}
	x2_term = 1 * (2 * X2 - 1)  # maps X2 in {0, 1} to {-1, +1}
	logit = covariate_term + x1_term - 0.25 * W1 + 1 + x2_term - 0.25 * W2 + 1
	return sigmoid(logit)

def pi_X3(X1, W1, X2, W2, C1, C2,C3):
	"""Probability that X3 = 1 given (X1, W1, X2, W2, C1, C2, C3); used to draw X3 when S != 0."""
	covariate_term = 0.5 * (C1 + C2 + C3)
	x1_term = 2 * (2 * X1 - 1)  # maps X1 in {0, 1} to {-2, +2}
	x2_term = 2 * (2 * X2 - 1)  # maps X2 in {0, 1} to {-2, +2}
	logit = covariate_term + x1_term + 0.25 * W1 - 1 + x2_term + 0.25 * W2 - 1
	return sigmoid(logit)

def generate_samples(n0, n1, n2, n3, seednum):
	"""Simulate one dataset for the target domain (S=0) and one per source domain (S=1, 2, 3).

	Every dataset is drawn with the same sequence of structural functions
	(latent noise -> C1, C2, C3 -> X1 -> W1 -> X2 -> W2 -> X3 -> Y). The domain
	index S changes (a) the shift/noise added to the covariates, W1, W2 and Y and
	(b) which treatment probabilities are used: ``pi_star_*`` when S == 0,
	``pi_*`` otherwise.

	Three mechanisms are evaluated with S=0 inside a source domain: W1 in
	source 1, W2 in source 2 and Y in source 3. These are kept exactly as in the
	original code and are now spelled out as explicit keyword overrides.
	NOTE(review): presumably these encode the assumption that the respective
	mechanism is shared with the target domain -- confirm with the experiment
	design.

	Args:
		n0: number of rows for the target domain (S=0).
		n1: number of rows for source domain 1.
		n2: number of rows for source domain 2.
		n3: number of rows for source domain 3.
		seednum: seed passed to ``np.random.seed``; ``random`` is always seeded
			with 123.

	Returns:
		Tuple ``(data_S0, data_S1, data_S2, data_S3)`` of DataFrames with columns
		``['C1', 'C2', 'C3', 'X1', 'W1', 'X2', 'W2', 'X3', 'Y']``.
	"""
	np.random.seed(seednum)
	random.seed(123)

	columns = ['C1', 'C2', 'C3', 'X1', 'W1', 'X2', 'W2', 'X3', 'Y']

	def fU(n):
		# Latent noise terms. U_X1X2 and U_X2X3 are drawn but not used by any
		# mechanism below; the draws are kept so the random stream is unchanged.
		U_X1X2 = np.random.normal(size = n)
		U_X2X3 = np.random.normal(size = n)
		U_W1X1 = np.random.normal(size = n)
		U_W1X2 = np.random.normal(size = n)
		U_W2X2 = np.random.normal(size = n)
		U_W2X3 = np.random.normal(size = n)
		U_C1C2 = np.random.normal(size = n)
		U_C2C3 = np.random.normal(size = n)
		U_X3Y = np.random.normal(size = n)

		return U_X1X2, U_X2X3, U_W1X1, U_W1X2, U_W2X2, U_W2X3, U_C1C2, U_C2C3, U_X3Y

	# The domain noise is always drawn, even when S == 0 multiplies it away,
	# so every domain consumes the same number of random draws per mechanism.
	def fC1(n, U_C1C2, S):
		S_noise = 0.25*S*np.random.normal(loc=0, scale=1, size=(n)) + 0.1*S
		return np.random.normal(loc=0, scale=0.5, size=(n)) + U_C1C2 + S_noise

	def fC2(n, U_C1C2, U_C2C3, S):
		S_noise = 0.25*S*np.random.normal(loc=0, scale=1, size=(n)) + 0.1*S
		return np.random.normal(loc=0, scale=0.5, size=(n)) + U_C1C2 + U_C2C3 + S_noise

	def fC3(n, U_C2C3, S):
		S_noise = 0.25*S*np.random.normal(loc=0, scale=1, size=(n)) + 0.1*S
		return np.random.normal(loc=0, scale=0.5, size=(n)) + U_C2C3 + S_noise

	def fX1(C1, C2, C3, S):
		if S == 0:
			pi_val = pi_star_X1(C1, C2, C3)
		else:
			pi_val = pi_X1(C1, C2, C3)
		X1 = (np.random.rand(len(pi_val)) < pi_val).astype(int)
		return X1

	def fW1(n, C1, C2, C3, X1, U_W1X1, U_W1X2, S):
		S_noise = 0.25*S*np.random.normal(loc=0, scale=1, size=(n))
		W1 = sigmoid(0.5*(C1+C2+C3)-1 + 3*X1 + 0.5*(U_W1X1 + U_W1X2) + S_noise )
		return W1

	def fX2(X1, W1, C1, C2, C3, S):
		if S == 0:
			pi_val = pi_star_X2(X1, W1, C1, C2, C3)
		else:
			pi_val = pi_X2(X1, W1, C1, C2, C3)
		X2 = (np.random.rand(len(pi_val)) < pi_val).astype(int)
		return X2

	def fW2(n, C1, C2, C3, X1, X2, W1, U_W2X2, U_W2X3, S):
		S_noise = 0.25*S*np.random.normal(loc=0, scale=1, size=(n))
		W2 = sigmoid(0.5*(C1+C2+C3)-1 + 3*(X1 + X2) + 0.5*(U_W2X2 + U_W2X3) + S_noise )
		return W2

	def fX3(X1, W1, X2, W2, C1, C2, C3, S):
		if S == 0:
			pi_val = pi_star_X3(X1, W1, X2, W2, C1, C2, C3)
		else:
			pi_val = pi_X3(X1, W1, X2, W2, C1, C2, C3)
		X3 = (np.random.rand(len(pi_val)) < pi_val).astype(int)
		return X3

	def fY(n, X1, W1, X2, W2, X3, C1, C2, C3, U_X3Y, S):
		S_noise = 0.25*S*np.random.normal(loc=0, scale=1, size=(n))
		Y = sigmoid( 0.5*(C1 + C2 + C3) + (2*(X1 + X2 + X3)-1) - 0.5*(W1 + W2) + 0.1*U_X3Y + S_noise)
		return Y

	def simulate_domain(n, S, S_W1=None, S_W2=None, S_Y=None):
		# Draw n rows for domain S. S_W1 / S_W2 / S_Y optionally replace the
		# domain index handed to fW1 / fW2 / fY; None means "use S".
		S_W1 = S if S_W1 is None else S_W1
		S_W2 = S if S_W2 is None else S_W2
		S_Y = S if S_Y is None else S_Y

		# The call order below fixes the order of random draws; do not reorder.
		U_X1X2, U_X2X3, U_W1X1, U_W1X2, U_W2X2, U_W2X3, U_C1C2, U_C2C3, U_X3Y = fU(n)
		C1 = fC1(n, U_C1C2, S=S)
		C2 = fC2(n, U_C1C2, U_C2C3, S=S)
		C3 = fC3(n, U_C2C3, S=S)
		X1 = fX1(C1, C2, C3, S=S)
		W1 = fW1(n, C1, C2, C3, X1, U_W1X1, U_W1X2, S=S_W1)
		X2 = fX2(X1, W1, C1, C2, C3, S=S)
		W2 = fW2(n, C1, C2, C3, X1, X2, W1, U_W2X2, U_W2X3, S=S_W2)
		X3 = fX3(X1, W1, X2, W2, C1, C2, C3, S=S)
		Y = fY(n, X1, W1, X2, W2, X3, C1, C2, C3, U_X3Y, S=S_Y)
		return pd.DataFrame(np.column_stack((C1, C2, C3, X1, W1, X2, W2, X3, Y)), columns=columns)

	# Target domain, then the three source domains, in the original order.
	data_S0 = simulate_domain(n0, S=0)
	data_S1 = simulate_domain(n1, S=1, S_W1=0)  # W1 generated with S=0
	data_S2 = simulate_domain(n2, S=2, S_W2=0)  # W2 generated with S=0
	data_S3 = simulate_domain(n3, S=3, S_Y=0)   # Y generated with S=0

	return data_S0, data_S1, data_S2, data_S3

def evaluate_DML(data_S0, data_S1, data_S2, data_S3, seednum, L=2, add_noise_TF = False):
	"""Compute outcome-model (OM), probability-weighting (PW) and DML estimates, averaged over folds.

	For each of the ``L`` KFold splits (indices are drawn from ``data_S1`` and
	reused positionally on ``data_S2`` and ``data_S3``, so the three source
	datasets are expected to have the same number of rows):

	1. OM: fits the nested regressions mu3 (on S3), mu2 (on S2) and mu1 (on S1)
	   with ``learn_mu`` and averages the policy-weighted mu1 over ``data_S0``.
	2. PW: multiplies three clipped (density-ratio x policy-ratio) weights
	   evaluated on the S3 test fold and averages ``Y * weight``.
	3. DML: ``OM + PW`` minus ``mean(omega_123 * mu3)`` plus two correction terms
	   evaluated on the S2 and S1 test folds.

	Each per-fold estimate is clipped to [0, 1].

	Args:
		data_S0: target-domain DataFrame.
		data_S1, data_S2, data_S3: source-domain DataFrames.
		seednum: seed passed to ``np.random.seed``; ``random`` is seeded with 123.
		L: number of KFold splits.
		add_noise_TF: forwarded as the second argument to ``add_noise``.

	Returns:
		Tuple ``(mean OM, mean PW, mean DML)`` over the ``L`` folds.

	Notes:
		``mu_params`` and ``lambda_params`` are not parameters; they are read as
		globals (this file assigns them only inside the ``__main__`` block).
		``learn_mu``, ``estimate_odds_ratio`` and ``add_noise`` are not defined in
		this file -- presumably they come from ``global_functions``.
	"""
	def evaluate_check_mu(col_feature_mu, mu_model, data, Xname, Sval):
		''' 
		Compute sum_{x in {0,1}} mu(..., Xname=x) * P*(Xname=x | ...) on `data`.

		`mu_model` is evaluated twice, with column `Xname` forced to 0 and to 1,
		and the two predictions are mixed with the pi_star probability selected
		by `Sval` (3 -> pi_star_X3, 2 -> pi_star_X2, 1 -> pi_star_X1).
		Any other `Sval` leaves `pi_star_val` unassigned.
		'''
		data_x0 = data.copy()
		data_x0[Xname] = 0
		matrix_data_x0 = xgb.DMatrix(data_x0[col_feature_mu])
		eval_mu_x0 = add_noise( mu_model.predict(matrix_data_x0), add_noise_TF )
		################
		data_x1 = data.copy()
		data_x1[Xname] = 1
		matrix_data_x1 = xgb.DMatrix(data_x1[col_feature_mu])
		eval_mu_x1 = add_noise( mu_model.predict(matrix_data_x1), add_noise_TF )
		################
		if Sval == 3: # Mixing probability for X3 (used to build check_mu3)
			pi_star_val = np.array( pi_star_X3(data['X1'], data['W1'], data['X2'], data['W2'], data['C1'], data['C2'], data['C3']) )
		################
		if Sval == 2: # Mixing probability for X2 (used to build check_mu2)
			pi_star_val = np.array( pi_star_X2(data['X1'], data['W1'], data['C1'], data['C2'], data['C3']) )
		################
		if Sval == 1: # Mixing probability for X1 (used to build check_mu1)
			pi_star_val = np.array( pi_star_X1(data['C1'], data['C2'], data['C3']) )
		################
		check_mu = (eval_mu_x1 * pi_star_val) + (eval_mu_x0 * (1-pi_star_val))
		return check_mu

	np.random.seed(seednum)
	random.seed(123)

	results_OM = []
	results_PW = []
	results_DML = []
	kf = KFold(n_splits=L, shuffle=True)

	for train_index, test_index in kf.split(data_S1):
		'''
		Estimate OM 
		'''
		# Split data_S1, data_S2 and data_S3 with the same positional indices
		data_S1_train, data_S1_test = data_S1.iloc[train_index], data_S1.iloc[test_index]
		data_S2_train, data_S2_test = data_S2.iloc[train_index], data_S2.iloc[test_index]
		data_S3_train, data_S3_test = data_S3.iloc[train_index], data_S3.iloc[test_index]

		# NOTE(review): the three outcome models below are fit on the FULL datasets
		# (data_S3 / data_S2 / data_S1), not on the *_train folds, yet are later
		# evaluated on the *_test folds for the DML correction terms -- confirm
		# this is intended.

		# Learn mu3_model := mu3(C1,C2,C3,X1,W1,X2,W2,X3) := E_{P3_pi3}[Y | C1,C2,C3,X1,W1,X2,W2,X3] by regressing Y onto {C1,C2,C3,X1,W1,X2,W2,X3} using S=3
		col_feature_mu3 = ['C1','C2','C3', 'X1','W1','X2','W2','X3']
		col_label_mu3 = ['Y']
		mu3_model = learn_mu(data_S3, col_feature_mu3, col_label_mu3, mu_params) # Train the model with data_S3

		# Compute \sum_{x3} mu3(C1,C2,C3,X1,W1,X2,W2,x3) * pi_star_X3(X1, W1, X2, W2, C1, C2, C3) using S2  
		check_mu3_S2 = evaluate_check_mu(col_feature_mu3, mu3_model, data_S2, 'X3', Sval=3)

		# Learn mu2_model := mu2(C1,C2,C3,X1,W1,X2) := E_{P2_pi2}[check_mu3 | C1,C2,C3,X1,W1,X2] by regressing check_mu3 onto {C1,C2,C3,X1,W1,X2} using S=2
		data_S2_copy = data_S2.copy()
		data_S2_copy['check_mu3'] = check_mu3_S2
		col_feature_mu2 = ['C1','C2','C3', 'X1','W1','X2']
		col_label_mu2 = ['check_mu3']
		mu2_model = learn_mu(data_S2_copy, col_feature_mu2, col_label_mu2, mu_params)

		# Compute \sum_{x2} mu2(C1,C2,C3,X1,W1,x2) * pi_star_X2(X1, W1, C1, C2, C3) evaluated from data_S1
		check_mu2_S1 = evaluate_check_mu(col_feature_mu2, mu2_model, data_S1, 'X2', Sval=2)

		# Learn mu1_model := mu1(C1,C2,C3,X1) := E_{P1_pi1}[check_mu2 | C1,C2,C3,X1] by regressing check_mu2 onto {C1,C2,C3,X1} using S=1
		data_S1_copy = data_S1.copy()
		data_S1_copy['check_mu2'] = check_mu2_S1
		col_feature_mu1 = ['C1','C2','C3', 'X1']
		col_label_mu1 = ['check_mu2']
		mu1_model = learn_mu(data_S1_copy, col_feature_mu1, col_label_mu1, mu_params)

		# Compute \sum_{x1} mu1(C1,C2,C3,x1) * pi_star_X1(C1, C2, C3) evaluated from data_S0
		check_mu1_S0 = evaluate_check_mu(col_feature_mu1, mu1_model, data_S0, 'X1', Sval=1)

		# OM
		result_OM = np.clip( np.mean(check_mu1_S0), 0, 1)
		results_OM.append(result_OM)

		'''
		Estimate PW 
		'''
		# NOTE(review): every policy ratio in this function is pi_star_Xk(...) / pi_Xk(...),
		# i.e. a ratio of the probabilities returned by those functions, and does not
		# look at the observed Xk column -- confirm this is the intended weight.

		# Compute omega_3(C1,C2,C3,X1,W1,X2,W2) = (P2(C1,C2,C3,X1,W1,X2,W2) / P3(C1,C2,C3,X1,W1,X2,W2)
		model_ratio_P2_over_P3_C1C2C3X1W1X2W2 = estimate_odds_ratio(data_S2_train, data_S3_train, ['C1','C2','C3','X1','W1','X2','W2'], len(data_S3_train), lambda_params)
		matrix_ratio_P2_over_P3_C1C2C3X1W1X2W2 = xgb.DMatrix(data_S3_test[['C1','C2','C3','X1','W1','X2','W2']])
		pred_ratio_P2_over_P3_C1C2C3X1W1X2W2 = model_ratio_P2_over_P3_C1C2C3X1W1X2W2.predict(matrix_ratio_P2_over_P3_C1C2C3X1W1X2W2) # presumably P(row comes from the second dataset | features) -- confirm against estimate_odds_ratio
		### Convert the prediction p into the odds (1-p)/p, used as the P2/P3 density ratio
		ratio_P2_over_P3_C1C2C3X1W1X2W2 = (1-pred_ratio_P2_over_P3_C1C2C3X1W1X2W2)/(pred_ratio_P2_over_P3_C1C2C3X1W1X2W2)

		## Compute pi_star_X3(X1, W1, X2, W2, C1, C2, C3) / pi_X3(X1, W1, X2, W2, C1, C2, C3) on the S3 test fold
		pi_star_X3_over_pi_X3 = np.array( pi_star_X3(data_S3_test['X1'], data_S3_test['W1'], data_S3_test['X2'], data_S3_test['W2'], data_S3_test['C1'], data_S3_test['C2'], data_S3_test['C3']) / pi_X3(data_S3_test['X1'], data_S3_test['W1'], data_S3_test['X2'], data_S3_test['W2'], data_S3_test['C1'], data_S3_test['C2'], data_S3_test['C3']) )
		# Weights are clipped to [0, 2]
		omega_3 = np.clip( (ratio_P2_over_P3_C1C2C3X1W1X2W2 * pi_star_X3_over_pi_X3), 0, 2)

		# Compute omega_2(C1,C2,C3,X1,W1) = (P1(C1,C2,C3,X1,W1) / P2(C1,C2,C3,X1,W1)
		model_ratio_P1_over_P2_C1C2C3X1W1 = estimate_odds_ratio(data_S1_train, data_S2_train, ['C1','C2','C3','X1','W1'], len(data_S2_train), lambda_params)
		matrix_ratio_P1_over_P2_C1C2C3X1W1_at_S3 = xgb.DMatrix(data_S3_test[['C1','C2','C3','X1','W1']])
		pred_ratio_P1_over_P2_C1C2C3X1W1_at_S3 = model_ratio_P1_over_P2_C1C2C3X1W1.predict(matrix_ratio_P1_over_P2_C1C2C3X1W1_at_S3)
		### Convert the prediction p into the odds (1-p)/p, used as the P1/P2 density ratio
		ratio_P1_over_P2_C1C2C3X1W1_at_S3 = (1-pred_ratio_P1_over_P2_C1C2C3X1W1_at_S3)/(pred_ratio_P1_over_P2_C1C2C3X1W1_at_S3)

		## Compute pi_star_X2(X1, W1, C1, C2, C3) / pi_X2(X1, W1, C1, C2, C3) on the S3 test fold
		pi_star_X2_over_pi_X2_at_S3 = np.array( pi_star_X2(data_S3_test['X1'], data_S3_test['W1'], data_S3_test['C1'], data_S3_test['C2'], data_S3_test['C3']) / pi_X2(data_S3_test['X1'], data_S3_test['W1'], data_S3_test['C1'], data_S3_test['C2'], data_S3_test['C3']) )
		omega_2_at_S3 = np.clip( (ratio_P1_over_P2_C1C2C3X1W1_at_S3 * pi_star_X2_over_pi_X2_at_S3), 0, 2)

		# Compute omega_1(C1,C2,C3) = (P0(C1,C2,C3) / P1(C1,C2,C3)
		model_ratio_P0_over_P1_C1C2C3 = estimate_odds_ratio(data_S0, data_S1_train, ['C1','C2','C3'], len(data_S1_train), lambda_params)
		matrix_ratio_P0_over_P1_C1C2C3_at_S3 = xgb.DMatrix(data_S3_test[['C1','C2','C3']])
		pred_ratio_P0_over_P1_C1C2C3_at_S3 = model_ratio_P0_over_P1_C1C2C3.predict(matrix_ratio_P0_over_P1_C1C2C3_at_S3)
		### Convert the prediction p into the odds (1-p)/p, used as the P0/P1 density ratio
		ratio_P0_over_P1_C1C2C3_at_S3 = (1-pred_ratio_P0_over_P1_C1C2C3_at_S3)/(pred_ratio_P0_over_P1_C1C2C3_at_S3)

		## Compute pi_star_X1(C1, C2, C3) / pi_X1(C1, C2, C3) on the S3 test fold
		pi_star_X1_over_pi_X1_at_S3 = np.array( pi_star_X1(data_S3_test['C1'], data_S3_test['C2'], data_S3_test['C3']) / pi_X1(data_S3_test['C1'], data_S3_test['C2'], data_S3_test['C3']) )
		omega_1_at_S3 = np.clip( (ratio_P0_over_P1_C1C2C3_at_S3 * pi_star_X1_over_pi_X1_at_S3), 0, 2)

		# omega_123: product of the three clipped weights, all evaluated on the S3 test fold
		omega_123 = add_noise( omega_3 * omega_2_at_S3 * omega_1_at_S3, add_noise_TF )
		result_PW = np.clip( np.mean(data_S3_test['Y'] * omega_123), 0, 1) 
		results_PW.append(result_PW)

		'''
		Estimate DML:

		result_PW + result_OM - mean_S3[ omega_123 * mu3_S3 ] + mean_S2[omega_12_S2 * (check_mu3_S2 - mu2_S2)] + mean_S1[omega_1_S1 * (check_mu2_S1 - mu1_S1)] 

		'''
		# mu3_S3, mu2_S2, mu1_S1: outcome-model predictions on the respective test folds
		mu3_S3 = add_noise( mu3_model.predict(xgb.DMatrix(data_S3_test[col_feature_mu3])), add_noise_TF )
		mu2_S2 = add_noise( mu2_model.predict(xgb.DMatrix(data_S2_test[col_feature_mu2])), add_noise_TF )
		mu1_S1 = add_noise( mu1_model.predict(xgb.DMatrix(data_S1_test[col_feature_mu1])), add_noise_TF )
		check_mu3_S2_test = evaluate_check_mu(col_feature_mu3, mu3_model, data_S2_test, 'X3', Sval=3)
		check_mu2_S1_test = evaluate_check_mu(col_feature_mu2, mu2_model, data_S1_test, 'X2', Sval=2)

		# omega_12_S2 (evaluated on the S2 test fold; unlike the PW weights it is not clipped)
		model_ratio_P0_over_P2_C1C2C3 = estimate_odds_ratio(data_S0, data_S2_train, ['C1','C2','C3'], len(data_S2_train),lambda_params)
		matrix_ratio_P0_over_P2_C1C2C3 = xgb.DMatrix(data_S2_test[['C1','C2','C3']])
		pred_ratio_P0_over_P2_C1C2C3 = model_ratio_P0_over_P2_C1C2C3.predict(matrix_ratio_P0_over_P2_C1C2C3)
		### Convert the prediction p into the odds (1-p)/p, used as the P0/P2 density ratio
		ratio_P0_over_P2_C1C2C3 = (1-pred_ratio_P0_over_P2_C1C2C3)/(pred_ratio_P0_over_P2_C1C2C3)

		## Estimate P1(C1,C2,C3,X1,W1)/P2(C1,C2,C3,X1,W1)
		### NOTE(review): same inputs as the P1/P2 model already fit in the PW section above; it is fit again here
		model_ratio_P1_over_P2_C1C2C3X1W1 = estimate_odds_ratio(data_S1_train, data_S2_train, ['C1','C2','C3','X1','W1'], len(data_S2_train),lambda_params)
		matrix_ratio_P1_over_P2_C1C2C3X1W1 = xgb.DMatrix(data_S2_test[['C1','C2','C3','X1','W1']])
		pred_ratio_P1_over_P2_C1C2C3X1W1 = model_ratio_P1_over_P2_C1C2C3X1W1.predict(matrix_ratio_P1_over_P2_C1C2C3X1W1)
		### Convert the prediction p into the odds (1-p)/p
		ratio_P1_over_P2_C1C2C3X1W1 = (1-pred_ratio_P1_over_P2_C1C2C3X1W1)/(pred_ratio_P1_over_P2_C1C2C3X1W1)

		## Estimate P2(C1,C2,C3,X1)/P1(C1,C2,C3,X1)
		### Note the size argument is len(data_S2_train) here although data_S1_train is the second dataset
		model_ratio_P2_over_P1_C1C2C3X1 = estimate_odds_ratio(data_S2_train, data_S1_train, ['C1','C2','C3','X1'], len(data_S2_train),lambda_params)
		matrix_ratio_P2_over_P1_C1C2C3X1 = xgb.DMatrix(data_S2_test[['C1','C2','C3','X1']])
		pred_ratio_P2_over_P1_C1C2C3X1 = model_ratio_P2_over_P1_C1C2C3X1.predict(matrix_ratio_P2_over_P1_C1C2C3X1)
		ratio_P2_over_P1_C1C2C3X1 = (1-pred_ratio_P2_over_P1_C1C2C3X1)/(pred_ratio_P2_over_P1_C1C2C3X1)

		## Compute pi_star_X1(C1,C2,C3) / pi_X1(C1,C2,C3) on the S2 test fold
		pi_star_X1_over_pi_X1_S2 = np.array( pi_star_X1(data_S2_test['C1'], data_S2_test['C2'], data_S2_test['C3']) / pi_X1(data_S2_test['C1'], data_S2_test['C2'], data_S2_test['C3']) )

		## Compute pi_star_X2(X1,W1,C1,C2,C3) / pi_X2(X1,W1,C1,C2,C3) on the S2 test fold
		pi_star_X2_over_pi_X2_S2 = np.array( pi_star_X2(data_S2_test['X1'], data_S2_test['W1'], data_S2_test['C1'], data_S2_test['C2'], data_S2_test['C3']) / pi_X2(data_S2_test['X1'], data_S2_test['W1'], data_S2_test['C1'], data_S2_test['C2'], data_S2_test['C3']) )

		# omega_12_S2
		omega_12_S2 = add_noise( ratio_P0_over_P2_C1C2C3 * (ratio_P1_over_P2_C1C2C3X1W1 * ratio_P2_over_P1_C1C2C3X1) * pi_star_X1_over_pi_X1_S2 * pi_star_X2_over_pi_X2_S2, add_noise_TF)

		# Compute omega1 := {P0(C) / P1(C) } * {pi_star_X1(C) / pi_1(C)}
		## Compute {P0(C) / P1(C) }
		### NOTE(review): same inputs as the P0/P1 model already fit in the PW section above; it is fit again here
		model_ratio_P0_over_P1_C1C2C3 = estimate_odds_ratio(data_S0, data_S1_train, ['C1','C2','C3'], len(data_S1_train), lambda_params)
		matrix_ratio_P0_over_P1_C1C2C3 = xgb.DMatrix(data_S1_test[['C1','C2','C3']])
		pred_ratio_P0_over_P1_C1C2C3 = model_ratio_P0_over_P1_C1C2C3.predict(matrix_ratio_P0_over_P1_C1C2C3)
		### Convert the prediction p into the odds (1-p)/p
		ratio_P0_over_P1_C1C2C3 = (1-pred_ratio_P0_over_P1_C1C2C3)/(pred_ratio_P0_over_P1_C1C2C3)

		## Compute {pi_star_X1(C) / pi_1(C)} on the S1 test fold
		pi_star_X1_over_pi_X1 = np.array( pi_star_X1(data_S1_test['C1'], data_S1_test['C2'], data_S1_test['C3']) / pi_X1(data_S1_test['C1'], data_S1_test['C2'], data_S1_test['C3']) )		

		# Omega1 (evaluated on the S1 test fold; not clipped)
		omega_1_S1 = add_noise( ratio_P0_over_P1_C1C2C3 * pi_star_X1_over_pi_X1, add_noise_TF)

		result_DML = np.clip( (result_OM + result_PW) - np.mean(omega_123 * mu3_S3) + np.mean(omega_12_S2 * (check_mu3_S2_test - mu2_S2)) + np.mean(omega_1_S1 * (check_mu2_S1_test - mu1_S1)) , 0, 1)
		results_DML.append(result_DML)

		
	# Average each estimator over the L folds
	return np.mean(results_OM), np.mean(results_PW), np.mean(results_DML)

def performance(truth, est_OM, est_PW, est_DML):
	"""Bundle the estimates with the ground truth and compute their absolute errors.

	Returns:
		``(table_data, error_data)`` where ``table_data`` maps
		'Truth'/'OM'/'PW'/'DML' to the supplied values and ``error_data`` maps
		'OM'/'PW'/'DML' to ``|truth - estimate|``.
	"""
	estimates = {'OM': est_OM, 'PW': est_PW, 'DML': est_DML}

	table_data = {'Truth': truth, **estimates}
	error_data = {method: np.abs(truth - value) for method, value in estimates.items()}

	return table_data, error_data


if __name__ == "__main__":
	# Script entry point: run the simulation for each sample size in n_list,
	# record the absolute error of OM / PW / DML per run, then save and plot
	# the mean error with its confidence margin.
	import os  # local import: only needed to create the output directories

	experiment_seed = 190602
	n_list = [2500, 5000, 10000, 20000]
	rounds_simulations = 100
	seednum_idx = 1
	L = 2
	add_noise_TF = True
	n0 = 1000000

	# NOTE: mu_params and lambda_params must stay module-level names;
	# evaluate_DML reads them as globals.
	mu_params = {
		'booster': 'gbtree',
		'eta': 0.3,
		'gamma': 0,
		'max_depth': 10,
		'min_child_weight': 1,
		'subsample': 0.75,
		'colsample_bytree': 0.75,
		'lambda': 0.0,
		'alpha': 0.0,
		'objective': 'reg:squarederror',
		'eval_metric': 'rmse',
		'n_jobs': 4  # Assuming you have 4 cores
	}
	lambda_params = {
		'booster': 'gbtree',
		'eta': 0.05,
		'gamma': 0,
		'max_depth': 20,
		'min_child_weight': 1,
		'subsample': 0.75,
		'colsample_bytree': 0.75,
		'objective': 'binary:logistic',  # Change as per your objective
		'eval_metric': 'logloss',  # Change as per your needs
		'reg_lambda': 0.0,
		'reg_alpha': 0.0,
		'nthread': 4
	}

	simulation_params = {
		'experiment_seed': experiment_seed,
		'n_list': n_list,
		'n0': n0,
		'rounds_simulations': rounds_simulations,
		'L': L,
		'add_noise_TF': add_noise_TF,
		'mu_params': mu_params,
		# Fix: this entry previously stored mu_params a second time.
		'lambda_params': lambda_params
	}

	location_file = "experiments/pkl/"
	location_fig = "experiments/plot/"
	param_filename = "param_seed" + str(experiment_seed) + "_r" + str(rounds_simulations) + "_noise" + str(add_noise_TF) + "_mgTR"
	result_filename = "result_seed" + str(experiment_seed) + "_r" + str(rounds_simulations) + "_noise" + str(add_noise_TF) + "_mgTR"
	extension = ".pkl"
	image_name = "plot_seed" + str(experiment_seed) + "_r" + str(rounds_simulations) + "_noise" + str(add_noise_TF) + "_mgTR"

	# Create the output directories up front so a missing folder fails (or is
	# fixed) before the long simulation instead of at the final savefig.
	os.makedirs(location_file, exist_ok=True)
	os.makedirs(location_fig, exist_ok=True)

	# Fix: seed the generator with experiment_seed so the per-run seeds (and
	# therefore the whole experiment) are reproducible; the seed was previously
	# only used in the output file names.
	np.random.seed(experiment_seed)
	seednum_list = np.random.randint(1000000, size=rounds_simulations)
	avg_acc = {"OM":[], "PW":[], "DML": []}
	ci_acc = {"OM":[], "PW":[], "DML": []}

	total_runs = len(seednum_list) * len(n_list)
	for n in n_list:
		n1 = n2 = n3 = n
		avg_acc_at_n = {"OM":[], "PW":[], "DML": []}
		for seednum in seednum_list:
			data_S0, data_S1, data_S2, data_S3 = generate_samples(n0, n1, n2, n3, seednum)

			# Ground truth is the mean outcome in the simulated target domain.
			truth = np.mean(data_S0['Y'])
			est_OM, est_PW, est_DML = evaluate_DML(data_S0, data_S1, data_S2, data_S3, seednum, L, add_noise_TF)

			table_data, error_data = performance(truth, est_OM, est_PW, est_DML)
			for method in ['OM', 'PW', 'DML']:
				avg_acc_at_n[method].append( error_data[method] )

			print(("%.3f%% completed") % (seednum_idx / total_runs * 100))
			seednum_idx += 1

		for method in ['OM', 'PW', 'DML']:
			mean, margin_of_error = mean_confidence_interval(avg_acc_at_n[method])
			avg_acc[method].append(mean)
			ci_acc[method].append(margin_of_error)

	# These show the values from the last simulation run only.
	print(table_data)
	print(error_data)

	# Persist the settings and aggregated errors; the file names were already
	# prepared above but nothing was written before.
	with open(location_file + param_filename + extension, "wb") as param_file:
		pickle.dump(simulation_params, param_file)
	with open(location_file + result_filename + extension, "wb") as result_file:
		pickle.dump({'n_list': n_list, 'avg_acc': avg_acc, 'ci_acc': ci_acc}, result_file)

	mean_OM, err_OM = avg_acc['OM'], ci_acc['OM']
	mean_PW, err_PW = avg_acc['PW'], ci_acc['PW']
	mean_DML, err_DML = avg_acc['DML'], ci_acc['DML']

	# Plotting with confidence intervals
	plt.figure(figsize=(12, 10))  # 12 inches wide, 10 inches tall
	plt.errorbar(n_list, mean_OM, yerr=err_OM, label='OM', marker='o', capsize=10, markersize=16, linewidth=5, elinewidth=3)
	plt.errorbar(n_list, mean_PW, yerr=err_PW, label='PW', marker='o', capsize=10, markersize=16, linewidth=5, elinewidth=3)
	plt.errorbar(n_list, mean_DML, yerr=err_DML, label='DML', marker='o', capsize=10, markersize=16, linewidth=5, elinewidth=3)
	plt.xticks(ticks=n_list, labels=n_list, size=35)
	# plt.ylabel("Error", fontsize=35)
	plt.yticks(size=45)
	plt.legend(prop={'size': 30})
	plt.grid(False)
	plt.savefig(location_fig + image_name + ".pdf")
	plt.show()

	









