import sys

import numpy as np
import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def get_data(config):
	"""Load the dataset named by ``config.dataset``.

	Args:
		config: object exposing ``dataset`` (str), ``exp_num`` (int) and
			``data``; for the file-backed datasets ``data`` must provide
			``train_data``, ``test_data`` and ``val_rate`` (plus
			``max_samples`` for ``'dissim'``).

	Returns:
		dict of tensors produced by the matching ``get_*`` loader.

	Raises:
		ValueError: if ``config.dataset`` is not one of 'dissim', 'cisim',
			'ihdp100', 'ihdp1000', 'jobs' or 'toy1' ... 'toy9'.
	"""
	dcfg = config.data
	print()
	name = config.dataset
	if name == 'dissim':
		return get_dissim(config.exp_num, dcfg.train_data, dcfg.test_data, dcfg.val_rate, dcfg.max_samples)
	if name == 'cisim':
		return get_cisim(config.exp_num, dcfg.train_data, dcfg.test_data, dcfg.val_rate)
	if name == 'ihdp100':
		return get_ihdp100(config.exp_num, dcfg.train_data, dcfg.test_data, dcfg.val_rate)
	if name == 'ihdp1000':
		return get_ihdp1000(config.exp_num, dcfg.train_data, dcfg.test_data, dcfg.val_rate)
	if name == 'jobs':
		return get_jobs(config.exp_num, dcfg.train_data, dcfg.test_data, dcfg.val_rate)

	# 'toyK' selects synthetic mixing-function case K (K = 1..9).
	toy_cases = {'toy%d' % k: k for k in range(1, 10)}
	if name in toy_cases:
		return get_toy(n_samples=10000, case=toy_cases[name], noise='gaussian')

	# Previously an unknown name silently returned None and failed later.
	raise ValueError('Unknown dataset: %r' % (name,))


def get_ihdp100(exp_num, train_data, test_data, val_rate):
	"""Load experiment ``exp_num`` of the IHDP-100 data as float tensors.

	The training file is shuffled with the global NumPy RNG and split into a
	training part and a validation part of ``int(n * val_rate)`` rows; the
	test file is used whole.

	Args:
		exp_num: experiment index (last axis of ``x``, column of the others).
		train_data, test_data: paths to ``.npz`` files holding ``x`` with
			shape (n, d, R) and ``t``, ``yf``, ``ycf``, ``mu0``, ``mu1`` with
			shape (n, R).
		val_rate: fraction of training rows held out for validation.

	Returns:
		dict with ``x_<split>`` of shape (n_split, d) and ``t_``, ``y_f_``,
		``y_cf_``, ``mu0_``, ``mu1_<split>`` of shape (n_split, 1), for
		split in train, valid, test.
	"""
	# (key in the .npz file, key prefix in the returned dict)
	fields = (('t', 't'), ('yf', 'y_f'), ('ycf', 'y_cf'), ('mu0', 'mu0'), ('mu1', 'mu1'))

	def select(arrays):
		# Slice experiment exp_num: x -> (n, d), every other field -> (n, 1).
		x = torch.from_numpy(arrays['x']).float()[:, :, exp_num]
		cols = [(name, torch.from_numpy(arrays[key]).float()[:, exp_num:exp_num + 1])
				for key, name in fields]
		return x, cols

	# np.load on an .npz returns an NpzFile that keeps the archive open until
	# closed; the with-block releases it once the arrays have been read.
	with np.load(train_data) as npz:
		x, cols = select(npz)

	n = len(x)
	perm = np.random.permutation(n)
	n_train = n - int(n * val_rate)

	example = {}
	for split, idx in (('train', perm[:n_train]), ('valid', perm[n_train:])):
		example['x_' + split] = x[idx]
		for name, col in cols:
			example['%s_%s' % (name, split)] = col[idx]

	with np.load(test_data) as npz:
		x, cols = select(npz)
	example['x_test'] = x
	for name, col in cols:
		example[name + '_test'] = col
	return example

def get_ihdp1000(exp_num, train_data, test_data, val_rate):
	"""Load experiment ``exp_num`` of the IHDP-1000 data as float tensors.

	The training file is shuffled with the global NumPy RNG and split into a
	training part and a validation part of ``int(n * val_rate)`` rows; the
	test file is used whole.

	Args:
		exp_num: experiment index (last axis of ``x``, column of the others).
		train_data, test_data: paths to ``.npz`` files holding ``x`` with
			shape (n, d, R) and ``t``, ``yf``, ``ycf``, ``mu0``, ``mu1`` with
			shape (n, R).
		val_rate: fraction of training rows held out for validation.

	Returns:
		dict with ``x_<split>`` of shape (n_split, d) and ``t_``, ``y_f_``,
		``y_cf_``, ``mu0_``, ``mu1_<split>`` of shape (n_split, 1), for
		split in train, valid, test.
	"""
	# (key in the .npz file, key prefix in the returned dict)
	fields = (('t', 't'), ('yf', 'y_f'), ('ycf', 'y_cf'), ('mu0', 'mu0'), ('mu1', 'mu1'))

	def select(arrays):
		# Slice experiment exp_num: x -> (n, d), every other field -> (n, 1).
		x = torch.from_numpy(arrays['x']).float()[:, :, exp_num]
		cols = [(name, torch.from_numpy(arrays[key]).float()[:, exp_num:exp_num + 1])
				for key, name in fields]
		return x, cols

	# np.load on an .npz returns an NpzFile that keeps the archive open until
	# closed; the with-block releases it once the arrays have been read.
	with np.load(train_data) as npz:
		x, cols = select(npz)

	n = len(x)
	perm = np.random.permutation(n)
	n_train = n - int(n * val_rate)

	example = {}
	for split, idx in (('train', perm[:n_train]), ('valid', perm[n_train:])):
		example['x_' + split] = x[idx]
		for name, col in cols:
			example['%s_%s' % (name, split)] = col[idx]

	with np.load(test_data) as npz:
		x, cols = select(npz)
	example['x_test'] = x
	for name, col in cols:
		example[name + '_test'] = col
	return example


def get_jobs(exp_num, train_data, test_data, val_rate):
	"""Load the Jobs data as float tensors.

	The files store a single experiment (index 0 on the last axis), so
	``exp_num`` is used only to seed the train/validation shuffle: the same
	``exp_num`` always yields the same partition, without touching the global
	NumPy RNG.

	Args:
		exp_num: seed for the train/validation permutation.
		train_data, test_data: paths to ``.npz`` files holding ``x`` of shape
			(n, d, R) and ``t``, ``e``, ``yf`` of shape (n, R).
		val_rate: fraction of training rows held out for validation.

	Returns:
		dict with ``x_<split>`` of shape (n_split, d) and ``t_``, ``e_``,
		``y_f_<split>`` of shape (n_split, 1), for split in train, valid, test.
	"""
	split_idx = 0  # the only experiment stored in the files
	fields = (('t', 't'), ('e', 'e'), ('yf', 'y_f'))

	def select(arrays):
		# x -> (n, d), every other field -> (n, 1), all for split_idx.
		x = torch.from_numpy(arrays['x']).float()[:, :, split_idx]
		cols = [(name, torch.from_numpy(arrays[key]).float()[:, split_idx:split_idx + 1])
				for key, name in fields]
		return x, cols

	# Close the NpzFile once its arrays are in memory (np.load keeps it open).
	with np.load(train_data) as npz:
		print(npz['x'].shape, ' >><<<<<<<< jobs shape')
		x, cols = select(npz)

	# we only have 1 split for jobs dataset
	# so we randomly shuffle the training data and report the average results
	rng = np.random.RandomState(exp_num)
	n = len(x)
	perm = rng.permutation(n)
	n_train = n - int(n * val_rate)

	example = {}
	for split, idx in (('train', perm[:n_train]), ('valid', perm[n_train:])):
		example['x_' + split] = x[idx]
		for name, col in cols:
			example['%s_%s' % (name, split)] = col[idx]

	with np.load(test_data) as npz:
		x, cols = select(npz)
	example['x_test'] = x
	for name, col in cols:
		example[name + '_test'] = col

	print('training: %d samples, valid: %d samples, test: %d samples' % (
		len(example['y_f_train']), len(example['y_f_valid']), len(example['y_f_test'])))
	return example


def get_dissim(exp_num, train_data, test_data, val_rate, max_samples):
	"""Load experiment ``exp_num`` of the 'dissim' data as float tensors.

	The training file is shuffled with the global NumPy RNG and split into
	training and validation parts (``int(n * val_rate)`` validation rows); at
	most ``max_samples`` shuffled training rows are kept (``None`` keeps all).
	The validation split carries no counterfactual outcome.

	Args:
		exp_num: experiment index (last axis of ``x``, column of the others).
		train_data, test_data: paths to ``.npz`` files holding ``x`` of shape
			(n, d, R) and ``t``, ``yf``, ``ycf`` of shape (n, R).
		val_rate: fraction of training rows held out for validation.
		max_samples: cap on the number of training rows (slice bound).

	Returns:
		dict with ``x_*`` tensors of shape (n_split, d) and ``t_*``, ``y_f_*``,
		``y_cf_*`` tensors of shape (n_split, 1).
	"""
	def select(arrays):
		# Slice experiment exp_num: x -> (n, d), outcome columns -> (n, 1).
		x = torch.from_numpy(arrays['x']).float()[:, :, exp_num]
		cols = {name: torch.from_numpy(arrays[key]).float()[:, exp_num:exp_num + 1]
				for key, name in (('t', 't'), ('yf', 'y_f'), ('ycf', 'y_cf'))}
		return x, cols

	# np.load keeps the .npz archive open until closed; release it here.
	with np.load(train_data) as npz:
		x, cols = select(npz)

	n = len(x)
	perm = np.random.permutation(n)
	n_train = n - int(n * val_rate)
	# Truncating the index array is equivalent to gathering the rows and then
	# slicing, but avoids gathering rows that would be thrown away.
	I_train = perm[:n_train][:max_samples]
	I_valid = perm[n_train:]

	example = {
		'x_train': x[I_train],
		't_train': cols['t'][I_train],
		'y_f_train': cols['y_f'][I_train],
		'y_cf_train': cols['y_cf'][I_train],
		'x_valid': x[I_valid],
		't_valid': cols['t'][I_valid],
		'y_f_valid': cols['y_f'][I_valid],
	}

	with np.load(test_data) as npz:
		x, cols = select(npz)
	example['x_test'] = x
	example['t_test'] = cols['t']
	example['y_f_test'] = cols['y_f']
	example['y_cf_test'] = cols['y_cf']

	print(example['x_train'].size(), example['t_train'].size())
	return example


def get_cisim(exp_num, train_data, test_data, val_rate):
	"""Load experiment ``exp_num`` of the 'cisim' data as float tensors.

	The training file is shuffled with the global NumPy RNG and split into
	training and validation parts (``int(n * val_rate)`` validation rows).
	The validation split carries no counterfactual outcome.

	Args:
		exp_num: experiment index (last axis of ``x``, column of the others).
		train_data, test_data: paths to ``.npz`` files holding ``x`` of shape
			(n, d, R) and ``t``, ``yf``, ``ycf`` of shape (n, R).
		val_rate: fraction of training rows held out for validation.

	Returns:
		dict with ``x_*`` tensors of shape (n_split, d) and ``t_*``, ``y_f_*``,
		``y_cf_*`` tensors of shape (n_split, 1).
	"""
	def select(arrays):
		# Slice experiment exp_num: x -> (n, d), outcome columns -> (n, 1).
		x = torch.from_numpy(arrays['x']).float()[:, :, exp_num]
		cols = {name: torch.from_numpy(arrays[key]).float()[:, exp_num:exp_num + 1]
				for key, name in (('t', 't'), ('yf', 'y_f'), ('ycf', 'y_cf'))}
		return x, cols

	# np.load keeps the .npz archive open until closed; release it here.
	with np.load(train_data) as npz:
		x, cols = select(npz)

	n = len(x)
	perm = np.random.permutation(n)
	n_train = n - int(n * val_rate)
	I_train = perm[:n_train]
	I_valid = perm[n_train:]

	example = {
		'x_train': x[I_train],
		't_train': cols['t'][I_train],
		'y_f_train': cols['y_f'][I_train],
		'y_cf_train': cols['y_cf'][I_train],
		'x_valid': x[I_valid],
		't_valid': cols['t'][I_valid],
		'y_f_valid': cols['y_f'][I_valid],
	}

	with np.load(test_data) as npz:
		x, cols = select(npz)
	example['x_test'] = x
	example['t_test'] = cols['t']
	example['y_f_test'] = cols['y_f']
	example['y_cf_test'] = cols['y_cf']
	return example



def mixing_function(x, t, e, case=1):
	"""Return the synthetic outcome y = f(x, t, e) for toy scenario ``case``.

	Operates element-wise through NumPy ufuncs. get_toy calls it with torch
	tensors; create_confounder and get_confounder call it with NumPy arrays.

	Cases:
		1: x + t + e
		2: sin(x + 2*pi*t) + e            (prints 'Generating case 2')
		3: exp(t - x + 0.5) * e
		4: exp(sin(x + pi*t) + e)
		5: exp(-5t + x) + exp(t - x + 0.5) * e
		6: exp(sin(x + pi*t + e^2) + cos(e))
		7: exp(sin(x + pi*t + cos(e)) + cos(e))
		8: exp(sin(x + pi*t) + e^2)
		9: exp(sin(x + pi*t) + cos(e))

	Raises:
		ValueError: if ``case`` is not in 1..9 (previously this surfaced as an
			UnboundLocalError on ``yf``).
	"""
	if case == 1:
		yf = x + t + e
	elif case == 2:
		print('Generating case 2')
		yf = np.sin(x + 2 * np.pi * t) + e
	elif case == 3:
		yf = (np.exp(t - x + 0.5)) * e
	elif case == 4:
		yf = np.exp(np.sin(x + np.pi * t) + e)
	elif case == 5:
		yf = np.exp(-5 * t + x) + np.exp(t - x + 0.5) * e
	elif case == 6:
		yf = np.exp(np.sin(x + np.pi * t + e**2) + np.cos(e))
	elif case == 7:
		yf = np.exp(np.sin(x + np.pi * t + np.cos(e)) + np.cos(e))
	elif case == 8:
		yf = np.exp(np.sin(x + np.pi * t) + e**2)
	elif case == 9:
		yf = np.exp(np.sin(x + np.pi * t) + np.cos(e))
	else:
		raise ValueError('Unsupported mixing case: %r (expected 1..9)' % (case,))
	return yf


@torch.no_grad()
def get_toy(n_samples=10000, case=1, noise='gaussian'):
	"""Simulate a 1-D toy dataset whose outcome is mixing_function(x, t, e, case).

	Draws x ~ U(0, 1), t ~ U(0, 1) (U(-3, 3) when case == 6) and noise e
	that is N(0, 1) when noise == 'gaussian' and U(0, 1) otherwise, stored as
	(n_samples, 1) float tensors. The test set is the single point
	x = t = e = 0.5. The 'val' set holds x and e at 0.5 and sweeps t over 100
	evenly spaced values spanning the treatment range. Outcomes for all three
	sets come from one mixing_function call on the concatenated inputs.
	"""
	t_low, t_high = (-3, 3) if case == 6 else (0, 1)
	# Draw order (x, t, e) fixes the RNG stream.
	x_np = np.random.uniform(0, 1., n_samples)
	t_np = np.random.uniform(t_low, t_high, n_samples)
	e_np = np.random.normal(0, 1, n_samples) if noise == 'gaussian' else np.random.uniform(0, 1, n_samples)

	def column(values):
		return torch.from_numpy(values).float().view(-1, 1)

	def point():
		# A fresh (1, 1) float32 tensor holding 0.5.
		return torch.from_numpy(np.array([0.5]).reshape(1, 1)).float()

	example = {
		'x_train': column(x_np),
		't_train': column(t_np),
		'e_train': column(e_np),
		'x_test': point(),
		't_test': point(),
		'e_test': point(),
	}
	example['x_val'] = example['x_test'].repeat(100, 1)
	example['t_val'] = torch.linspace(t_low, t_high, 100).view(100, 1)
	example['e_val'] = example['e_test'].repeat(100, 1)

	n_train = len(example['x_train'])
	outcomes = mixing_function(
		torch.cat([example[k] for k in ('x_train', 'x_test', 'x_val')], 0),
		torch.cat([example[k] for k in ('t_train', 't_test', 't_val')], 0),
		torch.cat([example[k] for k in ('e_train', 'e_test', 'e_val')], 0),
		case,
	)

	# Rows are laid out as [train | 1 test row | 100 val rows].
	example['y_f_train'] = outcomes[:n_train].float().view(-1, 1)
	example['y_f_test'] = outcomes[n_train].float().view(1, 1)
	example['y_f_val'] = outcomes[n_train + 1:].float().view(-1, 1)
	assert len(example['y_f_train']) == n_train
	assert len(example['y_f_test']) == 1
	assert len(example['y_f_val']) == 100
	return example


def generate_dissim(max_split=10):
	"""Build the 'dissim' train/test .npz files from datasets/simulated_CI_discrete.xlsx.

	For each of ``max_split`` splits, ids 1..1000 are shuffled; 700 go to train
	and 300 to test. For every id, one of its first two rows is picked at
	random as the factual sample. Its columns after the id are kept, and the
	other row's last column is appended as the counterfactual outcome.
	Splits are stacked on the last axis and saved with keys
	``x`` (n, 1, max_split), ``t``, ``yf``, ``ycf`` (n, max_split).
	"""
	df = pd.read_excel('datasets/simulated_CI_discrete.xlsx')
	arr = df.to_numpy()

	def make_split(ids):
		# Presumably every id has exactly two rows laid out as (id, x, t, y);
		# TODO confirm against the spreadsheet.
		split_rows = []
		for ind_id in ids:
			rows = arr[arr[:, 0] == ind_id]
			use = np.random.choice(2)
			other = 1 - use
			split_rows.append(list(rows[use][1:]) + [rows[other][-1]])
		return np.stack(split_rows, 0)

	train = []
	test = []
	for _ in range(max_split):
		all_ids = np.random.permutation(1000) + 1
		# Test ids are processed before train ids; keep this order so seeded
		# runs consume random draws identically.
		test_split = make_split(all_ids[700:])
		train_split = make_split(all_ids[:700])
		train.append(train_split)
		test.append(test_split)

	train = np.stack(train, -1)
	test = np.stack(test, -1)
	print(train.shape, test.shape)
	train_dict = {'x': train[:, :1, :], 't': train[:, 1, :], 'yf': train[:, 2, :], 'ycf': train[:, 3, :]}
	np.savez('datasets/dissim.train.npz', **train_dict)

	test_dict = {'x': test[:, :1, :], 't': test[:, 1, :], 'yf': test[:, 2, :], 'ycf': test[:, 3, :]}
	np.savez('datasets/dissim.test.npz', **test_dict)


def generate_cisim_ori(max_split=10):
	"""Build the 'cisim' train/test .npz files from datasets/simulated_CI2.xlsx.

	For each of ``max_split`` splits, ids 1..1000 are shuffled; 700 go to train
	and 300 to test. For every id, a random row is chosen as the factual
	sample. Its columns after the id are kept, and the last column of the row
	at the mirrored position (``len(rows) - 1 - use``) is appended as the
	counterfactual outcome. Splits are stacked on the last axis and saved with
	keys ``x`` (n, 1, max_split), ``t``, ``yf``, ``ycf`` (n, max_split) to
	datasets/cisim.train.npz and datasets/cisim.test.npz.
	"""
	df = pd.read_excel('datasets/simulated_CI2.xlsx')
	arr = df.to_numpy()

	def make_split(ids):
		# Presumably each id's rows are (id, x, t, y) sorted by t, so the
		# mirrored row is the counterfactual treatment; TODO confirm.
		split_rows = []
		for ind_id in ids:
			rows = arr[arr[:, 0] == ind_id]
			use = np.random.choice(len(rows))
			# The middle row would be its own counterfactual; step past it.
			# (For 21 rows this is exactly use == 10; the old hard-coded check
			# indexed out of range for 11-row ids.)
			if len(rows) > 1 and use == len(rows) - 1 - use:
				use += 1
			split_rows.append(list(rows[use][1:]) + [rows[len(rows) - 1 - use][-1]])
		return np.stack(split_rows, 0)

	train = []
	test = []
	for _ in range(max_split):
		all_ids = np.random.permutation(1000) + 1
		# Test ids are processed before train ids; keep this order so seeded
		# runs consume random draws identically.
		test_split = make_split(all_ids[700:])
		train_split = make_split(all_ids[:700])
		print(train_split, test_split)

		train.append(train_split)
		test.append(test_split)

	train = np.stack(train, -1)
	test = np.stack(test, -1)
	print(train.shape, test.shape)
	train_dict = {'x': train[:, :1, :], 't': train[:, 1, :], 'yf': train[:, 2, :], 'ycf': train[:, 3, :]}
	np.savez('datasets/cisim.train.npz', **train_dict)

	test_dict = {'x': test[:, :1, :], 't': test[:, 1, :], 'yf': test[:, 2, :], 'ycf': test[:, 3, :]}
	np.savez('datasets/cisim.test.npz', **test_dict)



import matplotlib.patheffects as path_effects
def add_median_labels(ax, fmt='.3f', fontsize=22):
    """Annotate every box of a box plot with its median value.

    The label is bold white text centred on the median line, outlined in the
    median line's colour for contrast.

    Args:
        ax: matplotlib Axes containing the box plot.
        fmt: format spec applied to the median value.
        fontsize: label font size (defaults to the previously fixed 22).

    Assumes each box owns the same number of Line2D objects and that the
    median is the fifth line (index 4) of each box's group. Presumably this
    is matplotlib's per-box order of two whiskers and two caps before the
    median; verify if the plot uses non-default options.
    """
    lines = ax.get_lines()
    boxes = [c for c in ax.get_children() if type(c).__name__ == 'PathPatch']
    if not boxes:
        # Nothing to label; previously this divided by zero.
        return
    lines_per_box = len(lines) // len(boxes)
    if lines_per_box == 0:
        # Fewer lines than boxes: a zero slice step would raise ValueError.
        return
    for median in lines[4:len(lines):lines_per_box]:
        x, y = (data.mean() for data in median.get_data())
        # choose value depending on horizontal or vertical plot orientation
        value = x if (median.get_xdata()[1] - median.get_xdata()[0]) == 0 else y
        text = ax.text(x, y, f'{value:{fmt}}', ha='center', va='center',
                       fontweight='bold', color='white', fontdict={'size': fontsize})
        # create median-colored border around white text for contrast
        text.set_path_effects([
            path_effects.Stroke(linewidth=3, foreground=median.get_color()),
            path_effects.Normal(),
        ])

def read_data():
	"""Print the last value of checkpoints/toy5_seed<k>/tau.txt for k = 1..10."""
	last_values = []
	for seed in range(1, 11):
		values = np.loadtxt('checkpoints/toy5_seed%d/tau.txt' % seed)
		# Flatten so single-value and multi-row files are handled alike.
		last_values.append(values.ravel()[-1].item())
	print(last_values)

def compute_cdf_by_monte_carlo(case=1):
	"""Box-plot Monte Carlo quantile estimates against stored model estimates.

	For seeds 1..10, regenerates the toy dataset for ``case`` (10000 samples,
	Gaussian noise). It keeps training samples whose x and t both lie within
	0.02 of 0.5 (the toy test point's x and t) and records the fraction of
	their outcomes that are <= the test outcome. Seeds with no such samples
	are skipped. These fractions ('Monte Carlo') are plotted next to the
	hard-coded 'Ours' values for ``case``. The plot is saved to
	images/toy<case>_boxplot.pdf.

	Raises:
		ValueError: if no stored 'Ours' values exist for ``case`` (only 1..5).
			Previously case 0 silently used the case-5 values and case > 5
			failed only after the simulation loop.
	"""
	tau1 = [0.6722978353500366, 0.6762803792953491, 0.6960402131080627, 0.659024178981781, 0.6978422403335571, 0.6740140914916992, 0.7006696462631226, 0.678325891494751, 0.7019245028495789, 0.7118008732795715]
	tau2 = [0.6976876258850098, 0.6842520236968994, 0.7172662615776062, 0.7056241631507874, 0.6760654449462891, 0.665776789188385, 0.6967891454696655, 0.6928293108940125, 0.7068043351173401, 0.7077009081840515]
	tau3 = [0.6998239159584045, 0.6959832310676575, 0.6954031586647034, 0.6757087707519531, 0.7139696478843689, 0.6975005865097046, 0.6956238746643066, 0.6790811419487, 0.6665158867835999, 0.7024075984954834]
	tau4 = [0.6985689401626587, 0.6920350790023804, 0.6582285165786743, 0.6793773174285889, 0.7091084718704224, 0.6971206665039062, 0.6779647469520569, 0.6949737071990967, 0.7160433530807495, 0.6349159479141235]
	tau5 = [0.6827809810638428, 0.6938003897666931, 0.6883903741836548, 0.6758599877357483, 0.7111649513244629, 0.7110852003097534, 0.6991679072380066, 0.6735031008720398, 0.6846495866775513, 0.708888828754425]
	all_taus = [tau1, tau2, tau3, tau4, tau5]
	if not 1 <= case <= len(all_taus):
		raise ValueError('No stored estimates for toy case %r (expected 1..%d)' % (case, len(all_taus)))

	taus = []
	for seed in np.arange(1, 11):
		np.random.seed(seed)
		dataset = get_toy(n_samples=10000, case=case, noise='gaussian')
		x_f = dataset['x_train'].numpy()
		t_f = dataset['t_train'].numpy()
		y_f = dataset['y_f_train'].numpy()
		y_f_test = dataset['y_f_test'].item()
		# Vectorized neighbourhood of the test point (x, t) = (0.5, 0.5).
		near = (np.abs(x_f - 0.5) <= 0.02) & (np.abs(t_f - 0.5) <= 0.02)
		n_near = np.count_nonzero(near)
		if n_near > 0:
			taus.append(np.count_nonzero(near & (y_f <= y_f_test)) / n_near)
	print(taus)

	sns.set(style="darkgrid")
	# Series align on index, so a skipped seed (shorter Monte Carlo list) is
	# padded with NaN instead of raising a length-mismatch error.
	df = pd.DataFrame({'Monte Carlo': pd.Series(taus), 'Ours': pd.Series(all_taus[case - 1])})
	my_pal = {"Monte Carlo": "g", "Ours": "m"}
	ax = sns.boxplot(data=df, palette=my_pal)
	add_median_labels(ax)
	plt.xticks(size=22)
	plt.ylabel('Estimated Quantile', fontdict={'size': 22})
	plt.title('Toy%d' % case, fontdict={'size': 22})
	plt.grid()
	plt.savefig('images/toy%d_boxplot.pdf' % case, bbox_inches='tight')
	# Close so a later call does not draw onto the same axes.
	plt.close()


def generate_cisim():
	"""Simulate the continuous-treatment 'cisim' dataset and save it as .npz files.

	Structure (X -> Y, Z -> Y):
		Y    = X    + U + Z / 100
		Y_cf = X_cf + U + Z / 100,    X_cf = 2 - X
	- Z is stored as ``x``: an integer age in [10, 50].
	- X is stored as ``t``: drawn uniformly from {0.0, 0.1, ..., 2.0}, with
	  1.0 replaced by 1.1 so that X_cf never equals X.
	- U ~ N(0, 1) is shared by both outcomes.

	For each seed 0..9, 7000 training and 3000 test individuals are drawn. The
	seeds become the last axis of every saved array: ``x`` (n, 1, 10) and
	``t``, ``yf``, ``ycf`` (n, 10). Writes datasets/cisim.train.npz and
	datasets/cisim.test.npz.
	"""
	treatments = np.arange(0, 2.1, 0.1)  # 0.0, 0.1, ..., 2.0

	def simulate(n):
		# Draw n individuals; returns lists (ages, treatments, y, y_cf).
		xs, ts, yf, ycf = [], [], [], []
		for _ in range(n):
			# Draw order (age, treatment, noise) fixes the RNG stream.
			z = np.random.randint(10, 51)  # age in [10, 50]
			x = np.random.choice(treatments)
			if np.isclose(x, 1.0):
				# 2 - 1.0 == 1.0: the counterfactual would equal the factual.
				x = x + 0.1
			x_cf = 2.0 - x
			u = np.random.randn(1)
			#Y = 2*(x ** 2) + np.log((x+0.1) * 0.5) - (z ** (1 / 3)) + 5 * u  (earlier variant)
			xs.append(z)
			ts.append(x)
			yf.append(x + u + (z / 100))
			ycf.append(x_cf + u + (z / 100))
		return xs, ts, yf, ycf

	def to_arrays(xs, ts, yf, ycf):
		# x -> (n, 1); t, yf, ycf -> (n,)
		return (np.array(xs).reshape(-1, 1), np.array(ts),
				np.array(yf).reshape(-1), np.array(ycf).reshape(-1))

	train_parts = []
	test_parts = []
	for seed in range(10):
		np.random.seed(seed)
		train = simulate(7000)
		train_parts.append(to_arrays(*train))
		print(train[0][:10], train[1][:10], train[2][:10], train[3][:10])
		test_parts.append(to_arrays(*simulate(3000)))

	# Stack the seeds on a new last axis: x -> (n, 1, 10), others -> (n, 10).
	train_arrays = [np.stack(group, -1) for group in zip(*train_parts)]
	test_arrays = [np.stack(group, -1) for group in zip(*test_parts)]
	print(*(a.shape for a in train_arrays))
	print(*(a.shape for a in test_arrays))

	keys = ('x', 't', 'yf', 'ycf')
	np.savez('datasets/cisim.train.npz', **dict(zip(keys, train_arrays)))
	np.savez('datasets/cisim.test.npz', **dict(zip(keys, test_arrays)))


def create_confounder(n_samples, case, confounder):
	"""Sample (x, t, e, y, c) for a toy problem with a confounder c.

	e ~ N(0, 1) and c ~ U(0, 1). ``confounder`` selects the structure:
		1: x = 2c - U, t = -c + 2U;   y = mixing_function(x, t, e, case)
		2: x = c + U,  t = U;         y = mixing_function(x, t, e, case) + c
		3: t = c + U,  x = U;         y = mixing_function(x, t, e, case) + c
	where each U is a fresh U(0, 1) draw.

	Returns:
		tuple (x, t, e, y, c) of NumPy arrays of length ``n_samples``.

	Raises:
		ValueError: if ``confounder`` is not 1, 2 or 3 (previously an
			UnboundLocalError). Checked before any random draw.
	"""
	if confounder not in (1, 2, 3):
		raise ValueError('confounder must be 1, 2 or 3, got %r' % (confounder,))
	e = np.random.normal(0, 1, n_samples)
	if confounder == 1:
		c = np.random.uniform(0, 1, n_samples)
		x = 2 * c - np.random.uniform(0, 1, n_samples)
		t = -1 * c + 2 * np.random.uniform(0, 1, n_samples)
		y = mixing_function(x, t, e, case)
	elif confounder == 2:
		c = np.random.uniform(0, 1, n_samples)
		x = c + np.random.uniform(0, 1, n_samples)
		t = np.random.uniform(0, 1, n_samples)
		y = mixing_function(x, t, e, case) + c
	else:
		c = np.random.uniform(0, 1, n_samples)
		t = c + np.random.uniform(0, 1, n_samples)
		x = np.random.uniform(0, 1, n_samples)
		y = mixing_function(x, t, e, case) + c
	return x, t, e, y, c


def get_confounder(n_samples=10000, case=1, confounder=3, noise='gaussian'):
	"""Build a confounded toy dataset plus a single shifted test point.

	Args:
		n_samples: number of training samples.
		case: mixing_function case used for the outcome.
		confounder: structure passed to create_confounder (1, 2 or 3).
		noise: currently unused; create_confounder always draws Gaussian e.

	Returns:
		dict with ``x_train``, ``t_train``, ``yf_train`` and the alias
		``y_f_train`` as (n_samples, 1) float tensors, and ``x_test``,
		``t_test``, ``e_test``, ``y_f_test`` as (1, 1) float tensors.
	"""
	example = {}
	x, t, e, yf, c = create_confounder(n_samples, case, confounder)

	example['x_train'] = torch.from_numpy(x).float().view(-1, 1)
	example['t_train'] = torch.from_numpy(t).float().view(-1, 1)
	example['yf_train'] = torch.from_numpy(yf).float().view(-1, 1)
	# Same tensor under the 'y_f_train' key that get_toy and the file loaders
	# use, so shared training code can read either dataset.
	example['y_f_train'] = example['yf_train']

	# Test point x = t = e = 0.5. For confounder 2 (3), x (t) is shifted by
	# conf, and for confounder >= 2 the outcome is shifted by conf as well.
	# This presumably mirrors create_confounder with c = 0.2; TODO confirm.
	example['x_test'] = np.array([0.5]).reshape(1, 1)
	example['t_test'] = np.array([0.5]).reshape(1, 1)
	example['e_test'] = np.array([0.5]).reshape(1, 1)
	conf = 0.2
	if confounder == 2:
		example['x_test'] += conf
	elif confounder == 3:
		example['t_test'] += conf
	example['y_f_test'] = mixing_function(example['x_test'], example['t_test'], example['e_test'], case)
	if confounder >= 2:
		example['y_f_test'] += conf
	example['x_test'] = torch.from_numpy(example['x_test']).float()
	example['t_test'] = torch.from_numpy(example['t_test']).float()
	example['e_test'] = torch.from_numpy(example['e_test']).float()
	example['y_f_test'] = torch.from_numpy(example['y_f_test']).float()

	return example


if __name__ == '__main__':
	# Script entry point: rebuild datasets/cisim.{train,test}.npz from the
	# spreadsheet. read_data() and compute_cdf_by_monte_carlo() are other
	# helpers that can be invoked here instead.
	generate_cisim_ori()


