import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPRegressor
import train

def I_fun(x):
    """Return the element-wise indicator 1{x > 0} as an (n, 1) integer column.

    Parameters
    ----------
    x : array-like with n entries along its first axis, given either as a
        1-D sequence of length n or as an (n, 1) column.

    Returns
    -------
    numpy.ndarray of shape (n, 1) holding 1 where the entry is strictly
    positive and 0 otherwise (NaN compares as not positive, hence 0).

    Raises
    ------
    ValueError
        If ``x`` holds more than one value per row, because the result cannot
        be laid out as a single column.
    """
    x = np.asarray(x)
    # Vectorized comparison replaces the per-element Python loop.
    return (x > 0).astype(int).reshape(x.shape[0], 1)



def _load_trial_data(datasets, T, sub_mumu, sub_sigma, j):
    """Load column-standardized covariates and the outcome table for trial ``j``.

    Returns ``(x_tr, out_treat)``. Raises ``ValueError`` when ``datasets`` is
    neither ``'news'`` nor ``'ihdp'``.
    """
    if datasets == 'news':
        # NpzFile keeps the archive open; the context manager closes it once
        # the array has been read into memory.
        with np.load('data/jobs_DW_bin.new.10.train.npz') as train_data:
            x_tr = train_data['x'][:, :, 0]  # original author's note: 2570*17 (not verifiable here)
        x_tr = (x_tr - np.mean(x_tr, axis=0)) / np.std(x_tr, axis=0)
        out_treat = np.loadtxt(
            'data/NEWS/test/' + str(sub_mumu) + '_var' + str(sub_sigma) + '_beta_2_var1_T' + str(
                T) + '_' + str(j) + '.txt', delimiter=',')
    elif datasets == 'ihdp':
        # Trials cycle through the replication files numbered 1..10
        # (a trial index that is a multiple of 10 maps to file 10).
        csv_id = (j - 1) % 10 + 1
        x_tr = np.loadtxt('data/IHDP/csv/ihdp_npci_' + str(csv_id) + '.csv', delimiter=',')
        # NOTE(review): every CSV column is standardized and used as a covariate.
        # The original carried a commented-out `x_tr[:, 5:]` slice labelled
        # "covariate" -- confirm whether only columns 5+ should be used.
        x_tr = (x_tr - np.mean(x_tr, axis=0)) / np.std(x_tr, axis=0)
        out_treat = np.loadtxt(
            'data/IHDP/test/' + str(sub_mumu) + '_var' + str(sub_sigma) + '_beta_3_var1_T' + str(
                T) + '_' + str(j) + '.txt', delimiter=',')
    else:
        raise ValueError("unknown dataset {!r}; expected 'news' or 'ihdp'".format(datasets))
    return x_tr, out_treat


def _fit_predict_mlp(x_fit, y_fit, x_eval, max_iter):
    """Fit a single-hidden-layer (100 units) MLP regressor and predict on ``x_eval``."""
    reg = MLPRegressor(hidden_layer_sizes=(100,), max_iter=max_iter, random_state=42)
    reg.fit(x_fit, y_fit.ravel())
    return reg.predict(x_eval)


def _policy_metrics(policy, s0, s1, y0, y1):
    """Score a 0/1 policy against the ground-truth potential outcomes.

    Returns, in order: total short-term reward, total long-term reward, and the
    summed short-term, long-term and 50/50 mixed gains over rows with policy == 1.
    """
    act = policy.reshape(-1, 1)
    treated = policy == 1
    welfare_s = s1 - s0
    welfare_y = y1 - y0
    welfare_m = 0.5 * (s1 - s0) + 0.5 * (y1 - y0)
    # np.sum always yields a plain scalar, whether the mask selects rows
    # (1-D policy) or elements ((N, 1) policy); builtin sum() could yield a
    # size-1 array, whose assignment to a scalar slot NumPy deprecates.
    return [
        np.sum(act * s1 + (1 - act) * s0),
        np.sum(act * y1 + (1 - act) * y0),
        np.sum(welfare_s[treated]),
        np.sum(welfare_y[treated]),
        np.sum(welfare_m[treated]),
    ]


def pred(datasets,trials,niter,batch_size,T,missing_ratio,sub_mumu,sub_sigma,lr,pref_idx):
    """Run ``trials`` repetitions of the policy-learning experiment and summarize them.

    For each trial the function loads the data, masks part of the long-term
    outcome, fits the nuisance models (propensity, missingness, outcome
    regressions), trains ``train.OR_model_YS`` and ``train.OR_model_linear``,
    and scores both learned policies against the ground-truth columns.

    Parameters
    ----------
    datasets : str
        ``'news'`` or ``'ihdp'``; selects the input files.
    trials : int
        Number of repetitions; trial indices run from 1 to ``trials``.
    niter, batch_size, lr, pref_idx :
        Forwarded to the ``train`` models' ``fit`` / ``fit_w`` calls.
    T, sub_mumu, sub_sigma :
        Used only to build the outcome-table file name.
    missing_ratio : float
        Fraction of rows whose long-term outcome is masked; the rows with the
        largest mean of (covariates + short-term outcome) are the ones masked.

    Returns
    -------
    list
        ``[datasets, trials, niter, batch_size, T, missing_ratio, sub_mumu,
        sub_sigma, pref_idx]`` followed by (mean, std) over trials for each of
        the 10 metrics (5 for the ``OR_model_YS`` policy, then 5 for the linear
        policy), followed by the mean epsilon. ``lr`` is not part of the output.

    Raises
    ------
    ValueError
        If ``trials`` is below 1 or ``datasets`` is not a known name.

    Notes
    -----
    ``npref`` is not a parameter: it is read from module scope (this file
    assigns it in its ``__main__`` section), so it must exist as a module
    global before ``pred`` is called.
    """
    if trials < 1:
        raise ValueError("trials must be at least 1, got {}".format(trials))

    Reward_ours_all = np.zeros((trials, 10))
    epsilons = np.zeros((trials, 1))

    for j in range(1, trials + 1):
        print("datasets={},T={},batch_size={},missing_ratio={},sub_mumu={},sub_sigma={},pref_idx={},trials={}:".format(datasets,T, batch_size,missing_ratio, sub_mumu,
                                                                                             sub_sigma, pref_idx, j))
        x_tr, out_treat = _load_trial_data(datasets, T, sub_mumu, sub_sigma, j)
        N = x_tr.shape[0]

        a_tr = out_treat[:, 0]  # treatment
        s_tr = np.reshape(out_treat[:, 1], (N, 1))  # short-term effects
        y_tr = np.reshape(out_treat[:, 2], (N, 1))  # long-term effects
        R_tr = np.reshape(out_treat[:, 3], (N, 1))  # missing indicator: 0 missing; 1 not missing
        s0_tr = np.reshape(out_treat[:, 4], (N, 1))  # columns 4-7 are ground truth
        s1_tr = np.reshape(out_treat[:, 5], (N, 1))
        y0_tr = np.reshape(out_treat[:, 6], (N, 1))
        y1_tr = np.reshape(out_treat[:, 7], (N, 1))

        # Mask the long-term outcome for the rows with the largest score.
        missing_mecha = np.mean(x_tr + s_tr, axis=1)
        n_missing = min(max(int(round(N * missing_ratio)), 0), N)
        # Slice from an explicit start: `order[-0:]` would select every row
        # when the rounded count is zero.
        missing_index = missing_mecha.argsort()[N - n_missing:]

        y_tr[missing_index] = np.nan
        R_tr[missing_index] = 0

        # construct propensity score pr(A|X), clipped away from 0 and 1
        clf = LogisticRegression(C=1, solver='lbfgs', max_iter=1000)
        clf.fit(x_tr, a_tr)
        p_tr = clf.predict_proba(x_tr)[:, 1]  # size: N
        p_tr = np.clip(p_tr, 0.1, 0.9)

        # construct est_r_0 / est_r_1: pr(R=1 | X, S, A=a) from one model
        # evaluated with the treatment column forced to 0 and to 1
        clf = LogisticRegression(C=1, solver='lbfgs', max_iter=1000)
        clf.fit(np.hstack((x_tr, s_tr, a_tr.reshape(-1, 1))), np.squeeze(R_tr))
        est_r_0 = clf.predict_proba(np.hstack((x_tr, s_tr, np.zeros([N, 1]))))[:, 1]  # size: N
        est_r_0 = np.clip(est_r_0, 0.1, 0.9)
        est_r_1 = clf.predict_proba(np.hstack((x_tr, s_tr, np.ones([N, 1]))))[:, 1]  # size: N
        est_r_1 = np.clip(est_r_1, 0.1, 0.9)

        control = a_tr == 0
        treated = a_tr == 1

        # construct mu0, mu1: short-term outcome regressions per arm
        mu0_or = _fit_predict_mlp(x_tr[control], s_tr[control], x_tr, max_iter=1000)
        mu1_or = _fit_predict_mlp(x_tr[treated], s_tr[treated], x_tr, max_iter=1000)

        # Rows whose long-term outcome is still observed, split by arm.
        observed = R_tr.ravel() == 1
        bar_index0 = observed & control
        bar_index1 = observed & treated

        # construct bar_mu0, bar_mu1: long-term outcome given X
        bar_mu0_or = _fit_predict_mlp(x_tr[bar_index0], y_tr[bar_index0], x_tr, max_iter=100000)
        bar_mu1_or = _fit_predict_mlp(x_tr[bar_index1], y_tr[bar_index1], x_tr, max_iter=100000)

        # construct tilde_mu0, tilde_mu1: long-term outcome given (X, S)
        xs_tr = np.concatenate((x_tr, s_tr), axis=1)
        tilde_mu0_or = _fit_predict_mlp(xs_tr[bar_index0], y_tr[bar_index0], xs_tr, max_iter=100000)
        tilde_mu1_or = _fit_predict_mlp(xs_tr[bar_index1], y_tr[bar_index1], xs_tr, max_iter=100000)

        OR_ours = train.OR_model_YS(input_size=x_tr.shape[1])
        weights_ours = OR_ours.fit(x_tr, a_tr, s_tr, y_tr, p_tr, R_tr, est_r_1, est_r_0, mu0_or, mu1_or, bar_mu0_or,
                                   bar_mu1_or, tilde_mu0_or, tilde_mu1_or
                                   , niter, npref, pref_idx, batch_size=batch_size,lr=lr,lamb=1e-1)
        weights_ours = weights_ours.cpu()

        re_mu0_or = np.reshape(mu0_or, (mu0_or.shape[0], 1))
        b = weights_ours.numpy()
        epsilon = -np.mean(I_fun((y1_tr-y0_tr)+b[0]/b[1]*(s1_tr-s0_tr))*(s1_tr-s0_tr) + re_mu0_or)
        epsilons[j - 1, 0] = epsilon

        pred_ours = OR_ours.predict(x_tr)  # actions with estimated policy

        OR_linear = train.OR_model_linear(input_size=x_tr.shape[1])
        weights_w = OR_linear.fit_w(x_tr, a_tr, s_tr, y_tr, p_tr, R_tr, est_r_1, est_r_0, mu0_or, mu1_or,
                                    bar_mu0_or, bar_mu1_or, tilde_mu0_or, tilde_mu1_or
                                    , niter, npref, pref_idx, weights_ours,batch_size=batch_size,lr=lr,lamb=1e-1)
        weights_w = weights_w.cpu()
        pred_w = OR_linear.predict(x_tr)  # actions with estimated policy

        # evaluate according to the reward defined (overall):
        # columns 0-4 for the OR_model_YS policy, columns 5-9 for the linear one
        Reward_ours_all[j - 1, 0:5] = _policy_metrics(pred_ours, s0_tr, s1_tr, y0_tr, y1_tr)
        Reward_ours_all[j - 1, 5:10] = _policy_metrics(pred_w, s0_tr, s1_tr, y0_tr, y1_tr)

    # Summaries only matter once every trial row is filled in.
    reward_mean = Reward_ours_all.mean(axis=0)
    reward_std = Reward_ours_all.std(axis=0)
    epsilon_mean = epsilons.mean(axis=0)

    summary = [datasets, trials, niter, batch_size, T, missing_ratio, sub_mumu, sub_sigma, pref_idx]
    for mean_k, std_k in zip(reward_mean, reward_std):
        summary.extend((mean_k, std_k))
    summary.append(epsilon_mean[0])

    return summary



if __name__ == '__main__':
    # Build the full experiment grid and evaluate every configuration in
    # parallel, one pred() call per configuration.

    from itertools import product

    from joblib import Parallel, delayed

    trials = [50]
    niter = 300
    npref = 10  # also read inside pred() as a module-level global

    datasets = ['news','ihdp']
    missing_set = [0.2]
    lr=[0.01]

    # Dataset-specific grid values; the remaining axes are shared.
    dataset_grids = {
        'ihdp': {'allT': [4], 'mumu': [0], 'sigma': [3], 'batch': [128]},
        'news': {'allT': [4], 'mumu': [0], 'sigma': [1], 'batch': [512]},
    }

    para = []
    for dataset in datasets:
        grid = dataset_grids[dataset]
        # product() varies the rightmost axis fastest, which reproduces the
        # ordering of the equivalent nested loops (pref_idx innermost).
        for trial, batch_size, T, missing_ratio, sub_mumu, sub_sigma, sub_lr, pref_idx in product(
                trials, grid['batch'], grid['allT'], missing_set,
                grid['mumu'], grid['sigma'], lr, range(npref)):
            para.append([dataset,trial,niter,batch_size,T,missing_ratio,sub_mumu,sub_sigma,sub_lr,pref_idx])

    result = Parallel(n_jobs=-1)(delayed(pred)(*arglist) for arglist in para)
    print(result)

# NOTE(review): this runs on plain import as well, since it sits outside the
# __main__ guard; left in place to keep import-time behaviour unchanged.
print("Done!")


