import os
import numpy as np
import pandas as pd

class IHDP(object):
    """One replication ("trial") of the IHDP data, split into train and test.

    ``train_data`` and ``test_data`` hold the per-split dictionaries built by
    ``get_trial`` (keys ``'x_bin'``, ``'x_cont'``, ``'y_factual'``,
    ``'treatment'``, ``'y_cfactual'``, ``'mu0'``, ``'mu1'``).
    """

    def __init__(
            self,
            path,
            trial,
            center=False,
            exclude_population=False
    ):
        """Load ``trial`` from the train/test ``.npz`` archives under ``path``.

        Args:
            path: directory containing ``ihdp_npci_1-1000.train.npz`` and
                ``ihdp_npci_1-1000.test.npz``.
            trial: replication index forwarded to ``get_trial``.
            center: accepted for interface compatibility; currently unused.
            exclude_population: forwarded to ``get_trial`` for both splits.
        """
        self.trial = trial
        # np.load on an .npz returns an NpzFile that keeps the archive open
        # until closed; the context managers release each file as soon as
        # get_trial has extracted the arrays it needs.
        with np.load(
            os.path.join(path, 'ihdp_npci_1-1000.train.npz')
        ) as train_dataset:
            self.train_data = get_trial(
                dataset=train_dataset,
                trial=trial,
                training=True,
                exclude_population=exclude_population
            )
        with np.load(
            os.path.join(path, 'ihdp_npci_1-1000.test.npz')
        ) as test_dataset:
            self.test_data = get_trial(
                dataset=test_dataset,
                trial=trial,
                training=False,
                exclude_population=exclude_population
            )

    @staticmethod
    def _get_treatment(data):
        """Return the treatment array of a split dict.

        ``get_trial`` stores it under ``'treatment'``; ``'t'`` is accepted as a
        fallback for dicts built some other way.
        """
        if 'treatment' in data:
            return data['treatment']
        return data['t']

    def get_training_data(self):
        """Return ``(x, y, t, mu0, mu1, examples_per_treatment)`` for the train split.

        ``x``, ``y`` and ``t`` come from ``preprocess``; ``mu0``/``mu1`` are
        returned as stored (no dtype cast). ``examples_per_treatment`` is
        ``t.sum(0)``; for a 1-D treatment vector that is the number of units
        with ``t == 1``.
        """
        x, y, t = self.preprocess(self.train_data)
        mu0, mu1 = self.get_mu(test_set=False)
        examples_per_treatment = t.sum(0)
        return x, y, t, mu0, mu1, examples_per_treatment

    def get_test_data(self, test_set=True):
        """Return ``(x, y, t, mu0, mu1, examples_per_treatment)`` for a split.

        Args:
            test_set: use the test split when True, the train split otherwise.

        Unlike ``get_training_data``, ``mu0``/``mu1`` are cast to float32.
        """
        _data = self.test_data if test_set else self.train_data
        x, y, t = self.preprocess(_data)
        examples_per_treatment = t.sum(0)
        mu1 = _data['mu1'].astype('float32')
        mu0 = _data['mu0'].astype('float32')
        return x, y, t, mu0, mu1, examples_per_treatment

    def get_subpop(self, test_set=True):
        """Return the stored ``'ind_subpop'`` indicator for the chosen split.

        Raises:
            KeyError: if the split dict has no ``'ind_subpop'`` entry (the
                ``get_trial`` helper in this module does not store one).
        """
        _data = self.test_data if test_set else self.train_data
        if 'ind_subpop' not in _data:
            raise KeyError(
                "'ind_subpop' is not present in the split data; get_trial "
                "does not export it"
            )
        return _data['ind_subpop']

    def get_t(self, test_set=True):
        """Return the raw (uncast) treatment array for the chosen split."""
        _data = self.test_data if test_set else self.train_data
        return self._get_treatment(_data)

    def preprocess(self, dataset):
        """Build model inputs from a split dict.

        Returns:
            x: ``x_cont`` and ``x_bin`` stacked column-wise (continuous first).
            y: factual outcome ``'y_factual'`` cast to float32.
            t: treatment cast to float32.
        """
        x = np.hstack([dataset['x_cont'], dataset['x_bin']])
        y = dataset['y_factual'].astype('float32')
        # get_trial stores the treatment under 'treatment', not 't'.
        t = self._get_treatment(dataset).astype('float32')
        return x, y, t

    def get_mu(self, test_set=True):
        """Return the stored ``(mu0, mu1)`` arrays for the chosen split, uncast."""
        _data = self.test_data if test_set else self.train_data
        return _data['mu0'], _data['mu1']

def get_arm_idx(t):
    """Locate control (``t == 0``) and treated (``t == 1``) entries of ``t``.

    Args:
        t: treatment assignment array.

    Returns:
        ``(control_idx, treated_idx)``, each being the index tuple that
        ``np.where`` returns for the corresponding boolean mask.
    """
    is_control = t == 0
    is_treated = t == 1
    return np.where(is_control), np.where(is_treated)

def get_trial(
        dataset,
        trial,
        training=True,
        exclude_population=False
):
    """Extract one replication (``trial``) from a loaded IHDP dataset.

    Args:
        dataset: mapping such as the ``NpzFile`` returned by ``np.load``. It
            must provide ``'x'`` indexed as ``[unit, feature, trial]`` with
            25 features, and ``'t'``, ``'yf'``, ``'ycf'``, ``'mu0'``, ``'mu1'``
            indexed as ``[unit, trial]``.
        trial: replication index along the last axis.
        training: only relevant with ``exclude_population``. When True, keep
            only the rows whose feature-8 indicator is non-zero.
        exclude_population: drop binary feature 8 (the ``ind_subpop``
            indicator) from ``x_bin``. On training data this also restricts
            the rows as described above.

    Returns:
        dict with keys ``'x_bin'`` (float32, features 6-24), ``'x_cont'``
        (float32, features 0-5), ``'y_factual'``, ``'treatment'``,
        ``'y_cfactual'``, ``'mu0'`` and ``'mu1'``, restricted to the included
        rows.
    """
    bin_feats = list(range(6, 25))
    cont_feats = [i for i in range(25) if i not in bin_feats]
    pop_feat = bin_feats[2]

    # Each NpzFile item access re-reads and decompresses the array from the
    # archive, so pull the covariate tensor out exactly once.
    x_trial = dataset['x'][:, :, trial]
    n_units = x_trial.shape[0]
    ind_subpop = x_trial[:, pop_feat].astype('bool')

    if exclude_population:
        bin_feats.remove(pop_feat)
    if exclude_population and training:
        idx_included = np.where(ind_subpop)[0]
    else:
        idx_included = np.arange(n_units, dtype='int32')

    return {
        'x_bin': x_trial[:, bin_feats][idx_included].astype('float32'),
        'x_cont': x_trial[:, cont_feats][idx_included].astype('float32'),
        'y_factual': dataset['yf'][:, trial][idx_included],
        'treatment': dataset['t'][:, trial][idx_included],
        'y_cfactual': dataset['ycf'][:, trial][idx_included],
        'mu0': dataset['mu0'][:, trial][idx_included],
        'mu1': dataset['mu1'][:, trial][idx_included],
    }

def _trial_to_frame(trial_data):
    """Flatten one ``get_trial``-style split dict into a DataFrame.

    Covariates (``x_cont`` then ``x_bin``, stacked column-wise) become columns
    ``Z1..Zp``. Every other key becomes a column of the same name, in dict
    order. The frame has a 0-based RangeIndex.
    """
    x_combined = np.hstack([trial_data['x_cont'], trial_data['x_bin']])
    # Column count is taken from the stacked array, so it follows whatever
    # feature split the loader produced.
    x_columns = [f"Z{i + 1}" for i in range(x_combined.shape[1])]
    df_x = pd.DataFrame(x_combined, columns=x_columns)

    df_other = pd.DataFrame(
        {
            key: value
            for key, value in trial_data.items()
            if key not in ('x_cont', 'x_bin')
        }
    )
    return pd.concat([df_x, df_other], axis=1)


def process_data(path='', trial=4):
    """Load IHDP replication ``trial`` from ``path`` as a single DataFrame.

    Args:
        path: directory holding the IHDP ``.npz`` archives read by ``IHDP``.
        trial: replication index to extract.

    Returns:
        DataFrame with covariate columns ``Z1..Zp`` followed by the remaining
        split fields (``y_factual``, ``treatment``, ...). Training rows come
        first, then test rows.
    """
    data = IHDP(path=path, trial=trial)
    df_train = _trial_to_frame(data.train_data)
    df_test = _trial_to_frame(data.test_data)
    # Each part keeps its own 0-based index, so the stacked result contains
    # duplicate index labels (one run for train, one for test).
    return pd.concat([df_train, df_test], axis=0)
    


