import numpy as np
import pandas as pd
import random
from sklearn.metrics import mean_squared_error as mse
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression, LinearRegression
from xgboost import XGBRegressor
from causalml.dataset.regression import synthetic_data
from learner import *
import matplotlib.pyplot as plt


def link_r(s):
    """Apply the logistic (sigmoid) link, mapping real scores into (0, 1).

    Works element-wise on numpy arrays as well as on scalars.
    """
    denominator = 1 + np.exp(-s)
    return 1 / denominator

def get_synthetic(n, mode, ps=2, p=5):
    """Generate a synthetic dataset with surrogates and missing outcomes.

    Base data come from causalml's ``synthetic_data``.  On top of it:
      * each unit gets ``ps`` surrogate features, drawn from N(1, 1) for
        treated units (``a == 1``) and N(-1, 1) for control units;
      * treated outcomes are shifted down by the sum of their surrogates;
      * an outcome-observed indicator ``r`` is drawn with probability
        ``link_r(mean(surrogates) + 1)``.

    Args:
        n: number of units.
        mode: ``synthetic_data`` simulation mode.
        ps: number of surrogate features (default 2, as originally fixed).
        p: number of covariates passed to ``synthetic_data`` (default 5).

    Returns:
        Tuple ``(x, a, s, y, r, ite, pr, e)``: covariates, treatment,
        surrogates ``(n, ps)``, shifted outcome, observed-outcome indicator
        (0/1 ints), reference ITE ``tau - ps``, observation probability, and
        the propensity score from ``synthetic_data``.

    Raises:
        ValueError: if ``ps < 1`` (the response probability divides by ps).
    """
    if ps < 1:
        raise ValueError("ps must be at least 1")
    y, x, a, tau, b, e = synthetic_data(mode=mode, n=n, p=p, sigma=1.0)
    # Draw both candidate surrogate sets up front (same RNG order as before),
    # then pick per unit according to treatment.
    s1 = np.random.normal(loc=1, scale=1, size=(n, ps))
    s0 = np.random.normal(loc=-1, scale=1, size=(n, ps))
    treated = a == 1
    s = np.where(treated[:, None], s1, s0)
    s_total = s.sum(axis=1)
    # Treated outcomes lose the sum of their surrogates; controls are unchanged.
    y[treated] -= s_total[treated]
    pr = link_r(s_total / ps + 1)
    r = (np.random.rand(n) < pr).astype(int)
    # E[sum(s1)] = ps * 1, so the treatment effect shifts by -ps on average
    # (ps=2 reproduces the original ``tau - 2``).
    return x, a, s, y, r, tau - ps, pr, e

if __name__ == '__main__':
    # Simulation study: for each sample size, repeat n_exper replications of
    # n_splits-fold cross-fitting and record the MSE of each learner's CATE
    # estimate against the reference ITE returned by get_synthetic.
    n_list = [1000, 2000, 3000]
    res = []
    n_exper = 200
    n_splits = 3

    # Baselines are trained only on units whose outcome is observed (r == 1).
    base_learners = [
        ('S-learner', BaseSRegressor),
        ('T-learner', BaseTRegressor),
        ('X-learner', BaseXRegressor),
        ('DR-learner', BaseDRRegressor),
    ]
    # "(S)" variants are trained on all training units, using outcomes imputed
    # by an XGBoost model fitted on (a, s, x) of the observed units.
    surrogate_learners = [
        ('S-learner (S)', BaseSRegressor),
        ('T-learner (S)', BaseTRegressor),
        ('X-learner (S)', BaseXRegressor),
        ('DR-learner (S)', BaseDRRegressor),
    ]
    method_names = ([name for name, _ in base_learners]
                    + [name for name, _ in surrogate_learners]
                    + ['SA-learner'])
    tick_labels = ['S', 'T', 'X', 'DR', 'S1', 'T1', 'X1', 'DR1', 'SA']

    for n in n_list:
        mse_all = {name: np.zeros(n_exper) for name in method_names}

        kf = KFold(n_splits=n_splits)
        for i in range(n_exper):
            x, a, s, y, r, ite0, rho0, p0 = get_synthetic(n=n, mode=1)
            X = np.column_stack((a, s, x))

            tau_all = {name: np.zeros(n) for name in method_names}
            for train_index, test_index in kf.split(x):
                train_index_r = train_index[r[train_index] == 1]

                for name, learner_cls in base_learners:
                    learner = learner_cls()
                    learner.fit(x[train_index_r], a[train_index_r], y[train_index_r])
                    tau_all[name][test_index] = learner.predict(x[test_index])

                # Impute outcomes for every training unit from the observed ones.
                sr = XGBRegressor()
                sr.fit(X[train_index_r], y[train_index_r])
                mu = sr.predict(X[train_index])

                for name, learner_cls in surrogate_learners:
                    learner = learner_cls()
                    learner.fit(x[train_index], a[train_index], mu)
                    tau_all[name][test_index] = learner.predict(x[test_index])

                sa_learner = SurrogateRegressor()
                sa_learner.fit(x[train_index], a[train_index], s[train_index],
                               y[train_index], r[train_index])
                tau_all['SA-learner'][test_index] = sa_learner.predict(x[test_index])

            for name in method_names:
                # sklearn signature is (y_true, y_pred); MSE is symmetric anyway.
                mse_all[name][i] = mse(ite0, tau_all[name])

        out = {
            'Meta-learner': method_names,
            'AMSE': [mse_all[name].mean() for name in method_names],
            # NOTE(review): this is the std. deviation across replications; the
            # standard error of AMSE would be std / sqrt(n_exper) — confirm intent.
            'SE': [mse_all[name].std() for name in method_names],
        }
        res.append(pd.DataFrame(out))

    tmp = np.arange(len(method_names))  # x locations
    width = 0.1  # bar width
    # Centre the bar groups around each tick; for 3 sizes: -width, 0, +width.
    offsets = (np.arange(len(n_list)) - (len(n_list) - 1) / 2) * width

    fig, ax = plt.subplots()
    # Previously the n = 3000 bars re-plotted the n = 2000 means; each sample
    # size now plots its own AMSE column.
    for df, n, offset in zip(res, n_list, offsets):
        ax.bar(tmp + offset, df['AMSE'], width, label=f"n = {n}")
    ax.set_ylabel("PEHE")
    ax.set_xticks(tmp)
    ax.set_xticklabels(tick_labels)
    ax.legend()
    plt.savefig('pehe.png')
    plt.show()
