import numpy as np
import pandas as pd
import random
from sklearn.metrics import mean_squared_error as mse
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression, LinearRegression
from xgboost import XGBRegressor
from causalml.dataset.regression import synthetic_data
from learner import *
import matplotlib.pyplot as plt


def _sigmoid(z):
    return 1 / (1 + np.exp(-z))

def _logit(p):
    return np.log(p / (1-p))

def _sign(z):
    tmp = np.ones(len(z))
    for _ in range(len(z)):
        if z[_] > 0:
            tmp[_] = 1
        elif z[_] < 0:
            tmp[_] = -1
    return tmp

def data_gen(n = 2000, alpha = 0.1):
    """Simulate one dataset together with noisy nuisance estimates.

    Random numbers come from numpy's global RNG, drawn in this order:
    x, a, s, r, outcome noise, nuisance errors ``err``, regression errors
    ``err0``.

    Args:
        n: number of units.
        alpha: rate exponent; the logit-scale perturbations of ``pi`` and
            ``rho`` are Normal with mean and standard deviation ``n**-alpha``.

    Returns:
        An 18-tuple ``(x, a, s, y, r, ite0, pi_hat, rho_hat, mu1_hat,
        mu0_hat, nu1_hat, nu0_hat, pi, rho, mu1, mu0, nu1, nu0)``.
        ``x`` has shape (n, 1), ``s`` has shape (n, 2), and every other entry
        has shape (n,).
    """
    # Covariate, treatment with propensity pi, surrogates s, label indicator r.
    x = np.random.uniform(low=-1, high=1, size=n)
    pi = 0.5 + 0.4 * _sign(x)  # 0.5 + 0.4 where x >= 0, 0.5 - 0.4 where x < 0
    a = np.random.binomial(n=1, p=pi)
    s = np.random.normal(size=(n, 2))
    rho = np.full(n, 0.3)
    r = np.random.binomial(n=1, p=rho, size=n)

    # Linear surrogate contribution and a piecewise baseline in x.
    bs = np.array((0.1, -0.1)) @ s.T
    bx = np.select(
        [x < -0.5, x < 0, x < 0.5],
        [(x + 2) ** 2 / 2, x / 2 + 0.875, -5 * (x - 0.2) ** 2 + 1.075],
        default=x + 0.125,
    )

    # The true individual treatment effect is 1 for every unit.
    ite0 = np.ones(n)
    # mu* include the surrogate term bs; nu* are the same without it.
    mu0 = bs + bx
    mu1 = mu0 + ite0
    mu = mu1 * a + mu0 * (1 - a)
    nu0 = bx
    nu1 = bx + ite0

    # Heteroscedastic outcome noise with standard deviation in [0.1, 0.3].
    noise_sd = 0.2 - 0.1 * np.cos(2 * np.pi * x)
    y = mu + np.random.normal(scale=noise_sd)

    # Estimated propensity / labelling probability: perturb on the logit
    # scale by an n**-alpha sized error, then clip away from 0 and 1.
    err = np.random.normal(loc=n ** (-alpha), scale=n ** (-alpha), size=(n, 2))
    pi_hat = np.clip(_sigmoid(_logit(pi) + err[:, 0]), 0.01, 0.99)
    rho_hat = np.clip(_sigmoid(_logit(rho) + err[:, 1]), 0.01, 0.99)

    # Outcome-regression estimates: mu1_hat/nu1_hat share error column 0,
    # mu0_hat/nu0_hat share column 1 (columns 2-3 are drawn but unused).
    err0 = np.random.normal(loc=1, scale=1, size=(n, 4))
    mu1_hat, mu0_hat = mu1 + err0[:, 0], mu0 + err0[:, 1]
    nu1_hat, nu0_hat = nu1 + err0[:, 0], nu0 + err0[:, 1]

    return (x.reshape(-1, 1), a, s, y, r, ite0, pi_hat, rho_hat,
            mu1_hat, mu0_hat, nu1_hat, nu0_hat, pi, rho, mu1, mu0, nu1, nu0)

if __name__ == '__main__':
    # Monte-Carlo study: for each nuisance rate exponent alpha, simulate
    # n_exper datasets with data_gen, estimate the ITE with four learners via
    # n_splits-fold cross-fitting, and plot the mean squared error against the
    # true ITE (ite0) as a function of alpha.
    random.seed(42)
    # data_gen draws only from numpy's global RNG, so seeding the stdlib
    # `random` module alone would not make the experiment reproducible.
    np.random.seed(42)

    n = 1000
    alpha_list = np.arange(0.06, 0.5 + 0.001, 0.05)
    n_splits = 3
    n_exper = 200

    n_alpha = len(alpha_list)
    mean_t = np.zeros(n_alpha)
    mean_x = np.zeros(n_alpha)
    mean_dr = np.zeros(n_alpha)
    mean_SA = np.zeros(n_alpha)

    for j in range(n_alpha):
        mse_t = np.zeros(n_exper)
        mse_x = np.zeros(n_exper)
        mse_dr = np.zeros(n_exper)
        mse_SA = np.zeros(n_exper)

        for i in range(n_exper):
            if i % 50 == 0:
                print('alpha:', j+1, f'/{n_alpha}; n_exper:', i + 1, f'/{n_exper}')
            (x, a, s, y, r, ite0, pi_hat, rho_hat, mu1_hat, mu0_hat,
             nu1_hat, nu0_hat, pi, rho, mu1, mu0, nu1, nu0) = data_gen(
                n=n, alpha=alpha_list[j])
            kf = KFold(n_splits=n_splits)

            ite_t = np.zeros(n)
            ite_x = np.zeros(n)
            ite_dr = np.zeros(n)
            ite_SA = np.zeros(n)

            for train_index, test_index in kf.split(x):
                x_test = x[test_index]

                ## T-learner: difference of the (noisy) outcome regressions
                ite_t[test_index] = nu1_hat[test_index] - nu0_hat[test_index]

                ## X-learner (fitted on labelled training units, r == 1)
                labelled = r[train_index] == 1
                a_train = a[train_index]

                index1r = train_index[labelled & (a_train == 1)]
                d1_r = y[index1r] - nu0_hat[index1r]
                ite_x1 = LinearRegression()
                ite_x1.fit(x[index1r], d1_r)

                index0r = train_index[labelled & (a_train == 0)]
                d0_r = nu1_hat[index0r] - y[index0r]
                ite_x0 = LinearRegression()
                ite_x0.fit(x[index0r], d0_r)

                # NOTE(review): Kunzel et al. (2019) suggest weighting the
                # control-arm model by the propensity (g*tau0 + (1-g)*tau1);
                # here pi_hat weights the treated-arm model. Confirm intended.
                pi_test = pi_hat[test_index]
                ite_x[test_index] = (pi_test * ite_x1.predict(x_test)
                                     + (1 - pi_test) * ite_x0.predict(x_test))

                ## DR-learner (fitted on labelled training units only)
                train_index_r = train_index[labelled]
                y_r = y[train_index_r]
                a_r = a[train_index_r]
                pi_hat_r = pi_hat[train_index_r]
                mu_1_r = nu1_hat[train_index_r]
                mu_0_r = nu0_hat[train_index_r]
                pseudo_dr_r = (((a_r - pi_hat_r) / (pi_hat_r * (1 - pi_hat_r)))
                               * (y_r - a_r * mu_1_r - (1 - a_r) * mu_0_r)
                               + mu_1_r - mu_0_r)
                tau_dr = LinearRegression()
                tau_dr.fit(x[train_index_r], pseudo_dr_r)
                ite_dr[test_index] = tau_dr.predict(x_test)

                # surrogate-based outcome prediction for every training unit
                pi_train = pi_hat[train_index]
                y_s = a_train * mu1_hat[train_index] + (1 - a_train) * mu0_hat[train_index]
                mu_1_s = nu1_hat[train_index]
                mu_0_s = nu0_hat[train_index]

                ## SA-learner
                # Parenthesised so the augmentation term "+ y_s - mu_a" is part
                # of res_SA; a bare continuation line starting with "+" would be
                # a separate expression statement whose value is discarded.
                res_SA = (r[train_index] / rho_hat[train_index] * (y[train_index] - y_s)
                          + y_s - a_train * mu_1_s - (1 - a_train) * mu_0_s)
                pseudo_SA = ((a_train - pi_train) / pi_train / (1 - pi_train) * res_SA
                             + mu_1_s - mu_0_s)
                tau_SA = LinearRegression()
                tau_SA.fit(x[train_index], pseudo_SA)
                ite_SA[test_index] = tau_SA.predict(x_test)

            # mean_squared_error(y_true, y_pred); MSE is symmetric anyway.
            mse_t[i] = mse(ite0, ite_t)
            mse_x[i] = mse(ite0, ite_x)
            mse_dr[i] = mse(ite0, ite_dr)
            mse_SA[i] = mse(ite0, ite_SA)

        mean_t[j] = mse_t.mean()
        mean_x[j] = mse_x.mean()
        mean_dr[j] = mse_dr.mean()
        mean_SA[j] = mse_SA.mean()

    plt.plot(alpha_list, mean_t, label = 'T-learner')
    plt.plot(alpha_list, mean_x, label = 'X-learner')
    plt.plot(alpha_list, mean_dr, label = 'DR-learner')
    plt.plot(alpha_list, mean_SA, label = 'SA-learner')
    plt.legend(loc='upper right')
    plt.xlabel(r'($\hat{\rho}$, $\hat{\pi}$) convergence rate ($\alpha$ in RMSE $n^{-\alpha}$)')
    plt.ylabel('PEHE')
    plt.savefig("dr.png")
    plt.show()
