import torch
import pandas as pd
from utils import dataframe_to_dict, MLP1, MLP2, Exponential_regression
from estimator import aipw_borrow
from lasso_selector import SelectiveBorrowingLasso
from influence_selector import InfluenceSelector, InfluenceSelector_mlp, InfluenceSelector_exp
from sklearn.linear_model import LogisticRegression, LinearRegression
from prettytable import PrettyTable
from sklearn.preprocessing import MinMaxScaler
# Use the CUDA device when torch reports one is available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _load_trial_and_external(dataset, mu):
    """Load the treated RCT arm, the control RCT arm and the external data.

    Args:
        dataset: One of "NSW", "linear" or "exp".
        mu: Only used (via ``str(mu)``) to build the simulated-data file
            names; ignored for "NSW".

    Returns:
        Tuple ``(data_rct_treated, data_rct_control, data_ext, covariates)``
        where the three frames carry an ``R`` column (1 for RCT rows, 0 for
        external rows) and ``covariates`` is the list of feature column names.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "NSW":
        column_names = ["A", "age", "education", "Black", "Hispanic", "married", "nodegree", "RE74", "RE75", "Y"]
        numeric_cols = ["age", "education", "RE74", "RE75", "Y"]

        def read_nsw(path, r_flag):
            # Whitespace-separated text file without a header row.
            frame = pd.read_csv(path, sep=r'\s+', header=None)
            frame.columns = column_names
            frame["R"] = r_flag
            # Min-max scale the numeric columns.
            # NOTE(review): every frame is scaled with its own MinMaxScaler
            # (outcome "Y" included), so the three frames do not share a
            # common scale -- confirm this is intended.
            frame[numeric_cols] = MinMaxScaler().fit_transform(frame[numeric_cols])
            return frame

        # RCT data
        data_rct_treated = read_nsw("dataset/NSW/nswre74_treated.txt", 1)
        data_rct_control = read_nsw("dataset/NSW/nswre74_control.txt", 1)
        # External data
        data_ext = read_nsw("dataset/NSW/PSID_control.txt", 0)

        covariates = ["age", "education", "Black", "Hispanic", "married", "nodegree", "RE74", "RE75"]
        return data_rct_treated, data_rct_control, data_ext, covariates

    if dataset in ("linear", "exp"):
        rct_file = "dataset/simulated_data/rct_data_" + dataset + "_" + str(mu) + ".csv"
        ext_file = "dataset/simulated_data/rwe_data_" + dataset + "_" + str(mu) + ".csv"
        covariates = ["X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8"]

        # The RCT file holds both arms; split them on the treatment flag A.
        data_rct_all = pd.read_csv(rct_file)
        data_rct_treated = data_rct_all[data_rct_all['A'] == 1].copy()
        data_rct_treated["R"] = 1
        data_rct_control = data_rct_all[data_rct_all['A'] == 0].copy()
        data_rct_control["R"] = 1

        data_ext = pd.read_csv(ext_file)
        data_ext["R"] = 0
        return data_rct_treated, data_rct_control, data_ext, covariates

    raise ValueError(
        "Unknown dataset {!r}; expected 'NSW', 'linear' or 'exp'".format(dataset))


def _markdown_dir(dataset, mu, control_n):
    """Return the directory that receives the per-top_k result tables."""
    if dataset == "NSW":
        return "results/markdowntable/" + dataset + "/" + "control_n_" + str(control_n)
    return ("results/markdowntable/" + dataset + "/" + dataset + "_mu_" + str(mu)
            + "/" + "control_n_" + str(control_n))


def pred(dataset, control_n, top_k, batch_size, lr, mu):
    """Run the full / lasso / influence borrowing comparison for one setting.

    Loads the data for ``dataset``, samples ``control_n`` RCT control rows
    (fixed ``random_state=42``), fits the lasso selector, the propensity and
    outcome models and the influence selector, and then, for every entry of
    ``top_k``, calls ``aipw_borrow`` on the data augmented with the external
    rows chosen by each selector. One PrettyTable per ``top_k`` value is
    written to ``results/markdowntable/...``.

    Args:
        dataset: One of "NSW", "linear" or "exp".
        control_n: Number of RCT control rows to sample.
        top_k: Iterable of values passed as ``top_k`` to each selector's
            ``select_samples``.
        batch_size: Forwarded to the lasso selector, the NSW MLP fits, the
            NSW influence selector and ``aipw_borrow``.
        lr: Forwarded to the lasso selector and the NSW influence selector.
        mu: Used only to build simulated-data file names and the output
            directory; ignored for "NSW".

    Returns:
        Tuple ``(lasso_score, influence_score, bias_full_list, se_full_list,
        bv_full_list, bias_lasso_list, se_lasso_list, bv_lasso_list,
        bias_if_list, se_if_list, bv_if_list)``; every ``*_list`` has one
        entry per element of ``top_k``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Local import so the function does not depend on ``os`` having been
    # bound as a module global by whichever script happens to be running.
    import os

    bias_full_list, se_full_list, bv_full_list = [], [], []
    bias_lasso_list, se_lasso_list, bv_lasso_list = [], [], []
    bias_if_list, se_if_list, bv_if_list = [], [], []

    print("dataset = {}".format(dataset))
    print("control_n={}".format(control_n))

    data_rct_treated, data_rct_control, data_ext, covariates = _load_trial_and_external(dataset, mu)
    X_ext = data_ext[covariates].values
    y_ext = data_ext["Y"].values

    # Independent copies so each selector can flag its own choice in column "S".
    data_ext_lasso = data_ext.copy()
    data_ext_if = data_ext.copy()

    data_rct_control["S"] = 1
    data_rct_treated["S"] = 1

    selected = data_rct_control.sample(n=control_n, random_state=42)

    data_rct = pd.concat([data_rct_treated, selected], ignore_index=True)
    # data_ext has no "S" column at this point, so its rows get NaN in "S" here.
    full_data = pd.concat([data_rct_treated, selected, data_ext], ignore_index=True)

    rct_dict = dataframe_to_dict(data_rct, covariates)
    full_data_dict = dataframe_to_dict(full_data, covariates)
    selected_control_dict = dataframe_to_dict(selected, covariates)

    # Lasso-based selector
    selector_lasso = SelectiveBorrowingLasso(dataset=dataset, batch_size=batch_size, lr=lr, num_epoch=10000, threshold=1e-6)
    selector_lasso.fit(selected_control_dict["X"], selected_control_dict["Y"], X_ext, y_ext)
    lasso_score = selector_lasso.return_b_tilde()

    # Results
    print(f"Optimal parameters: lambda_N={selector_lasso.best_params_['alpha']:.4f}, nu={selector_lasso.best_params_['nu']}")
    print(f"Selected {len(selector_lasso.selected_indices_)}/{len(X_ext)} comparable subjects")

    ps_model = LogisticRegression(C=1, solver='lbfgs', max_iter=1000).fit(rct_dict["X"], rct_dict["A"])

    treated_mask = rct_dict["A"] == 1
    control_mask = rct_dict["A"] == 0
    # Arguments shared by all three influence-selector variants.
    selector_kwargs = dict(dataset=dataset, threshold=0.1, x_test=X_ext, y_test=y_ext,
                           x_train=selected_control_dict["X"], y_train=selected_control_dict["Y"])

    if dataset == "NSW":
        mu1_model = MLP2(rct_dict["X"].shape[1], 16, 1).to(device)
        mu1_model.fit(rct_dict["X"][treated_mask], rct_dict["Y"][treated_mask], batch_size=batch_size)
        mu0_model = MLP1(rct_dict["X"].shape[1], 16, 1).to(device)
        # NOTE(review): the control-arm model is fitted with a hard-coded
        # batch size of 32 rather than ``batch_size`` -- confirm intended.
        mu0_model.fit(rct_dict["X"][control_mask], rct_dict["Y"][control_mask], batch_size=32)

        selector_if = InfluenceSelector_mlp(batch_size=batch_size, lr=lr, num_epoch=10000, **selector_kwargs)
    elif dataset == "linear":
        mu1_model = LinearRegression()
        mu1_model.fit(rct_dict["X"][treated_mask], rct_dict["Y"][treated_mask])
        mu0_model = LinearRegression()
        mu0_model.fit(rct_dict["X"][control_mask], rct_dict["Y"][control_mask])

        selector_if = InfluenceSelector(**selector_kwargs)
    else:  # "exp" -- any other name was rejected while loading the data
        mu1_model = Exponential_regression()
        mu1_model.fit(rct_dict["X"][treated_mask], rct_dict["Y"][treated_mask])
        mu0_model = Exponential_regression()
        mu0_model.fit(rct_dict["X"][control_mask], rct_dict["Y"][control_mask])

        selector_if = InfluenceSelector_exp(**selector_kwargs)

    selector_if.fit()
    influence_score = selector_if.return_influence()

    # The "full" estimate borrows every external row; it does not depend on
    # top_k, so it is computed once. Its tau_exp / se_rct are kept under their
    # own names so the later lasso / IF calls cannot overwrite what the
    # "full" table row reports.
    tau_full, tau_exp_full, bias_full, se_rct_full, bv_full, se_full, CI_full = aipw_borrow(
        full_data_dict, ps_model, mu0_model, mu1_model, batch_size, dataset)

    dir_markdown = _markdown_dir(dataset, mu, control_n)
    os.makedirs(dir_markdown, exist_ok=True)

    valid_metric_header = ['estimator', 'tau_borrow', 'tau_exp', 'bias', 'se_rct', "bv", "se", 'CI']

    for k in top_k:
        print("top_k: ", k)
        lasso_idx = selector_lasso.select_samples(top_k=k)
        data_ext_lasso["S"] = 0
        data_ext_lasso.loc[lasso_idx, "S"] = 1

        if_idx = selector_if.select_samples(top_k=k)
        data_ext_if["S"] = 0
        data_ext_if.loc[if_idx, "S"] = 1

        # RCT rows plus only the external rows flagged by the respective selector.
        lasso_data = pd.concat([data_rct_treated, selected, data_ext_lasso[data_ext_lasso["S"] == 1]], ignore_index=True)
        if_data = pd.concat([data_rct_treated, selected, data_ext_if[data_ext_if["S"] == 1]], ignore_index=True)

        lasso_data_dict = dataframe_to_dict(lasso_data, covariates)
        if_data_dict = dataframe_to_dict(if_data, covariates)

        table = PrettyTable(valid_metric_header)

        table.add_row(["full", tau_full, tau_exp_full, bias_full, se_rct_full, bv_full, se_full, CI_full])
        bias_full_list.append(bias_full)
        se_full_list.append(se_full)
        bv_full_list.append(bv_full)

        tau_lasso, tau_exp, bias_lasso, se_rct, bv_lasso, se_lasso, CI_lasso = aipw_borrow(
            lasso_data_dict, ps_model, mu0_model, mu1_model, batch_size, dataset)
        table.add_row(["lasso", tau_lasso, tau_exp, bias_lasso, se_rct, bv_lasso, se_lasso, CI_lasso])
        bias_lasso_list.append(bias_lasso)
        se_lasso_list.append(se_lasso)
        bv_lasso_list.append(bv_lasso)

        tau_if, tau_exp, bias_if, se_rct, bv_if, se_if, CI_if = aipw_borrow(
            if_data_dict, ps_model, mu0_model, mu1_model, batch_size, dataset)
        table.add_row(["if", tau_if, tau_exp, bias_if, se_rct, bv_if, se_if, CI_if])
        bias_if_list.append(bias_if)
        se_if_list.append(se_if)
        bv_if_list.append(bv_if)

        table_path = dir_markdown + "/" + "valid_markdowntable" + "_top_k_" + str(k) + ".txt"
        with open(table_path, 'w', encoding="utf-8") as fp:
            fp.write(table.get_string())

    return lasso_score, influence_score, bias_full_list, se_full_list, bv_full_list, bias_lasso_list, se_lasso_list, bv_lasso_list, bias_if_list, se_if_list, bv_if_list

if __name__ == '__main__':
    import matplotlib
    matplotlib.use('Agg')  # figures are only saved to disk, never shown
    import matplotlib.pyplot as plt
    import os

    def _save_metric_figure(top_k, full_vals, lasso_vals, if_vals, ylabel, out_file):
        """Plot one metric against top_k for the three estimators and save it."""
        # Create the figure
        plt.figure(figsize=(8, 6))
        plt.plot(top_k, full_vals, marker='s', linestyle='--', color='red', label=r'$\hat{\tau}_{full}$', linewidth=2)
        plt.plot(top_k, lasso_vals, marker='^', linestyle=':', color='green', label=r'$\hat{\tau}_{lasso}$', linewidth=2)
        plt.plot(top_k, if_vals, marker='D', linestyle='-.', color='purple', label=r'$\hat{\tau}_{if}$', linewidth=2)

        # Tick sizes, axis labels and legend
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlabel("Top_k", fontsize=18, fontweight='bold')
        plt.ylabel(ylabel, fontsize=18, fontweight='bold')
        plt.legend(loc='best', fontsize=18, framealpha=0.5, facecolor='white', edgecolor='gray')

        # Show the grid and tighten the layout
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(out_file)
        plt.close()

    def _run_and_plot(dataset, control_n, top_k, mu, path, batch_size, lr):
        """Run pred() for one configuration and save its three figures under path."""
        os.makedirs(path, exist_ok=True)

        # The two score vectors returned first by pred() are not used here.
        (_lasso_score, _influence_score,
         bias_full_list, se_full_list, bv_full_list,
         bias_lasso_list, se_lasso_list, bv_lasso_list,
         bias_if_list, se_if_list, bv_if_list) = pred(
            dataset=dataset, control_n=control_n,
            top_k=top_k, batch_size=batch_size, lr=lr, mu=mu)

        _save_metric_figure(top_k, bias_full_list, bias_lasso_list, bias_if_list,
                            "Bias", path + "/" + "bias_tau" + ".png")
        _save_metric_figure(top_k, se_full_list, se_lasso_list, se_if_list,
                            "Standard Error", path + "/" + "se_tau" + ".png")
        # The "bv" series is plotted under the "MSE" axis label.
        _save_metric_figure(top_k, bv_full_list, bv_lasso_list, bv_if_list,
                            "MSE", path + "/" + "mse_tau" + ".png")

    datasets = ["linear"]
    batch_size = 32
    lr = 0.0001

    # Grid for the simulated ("linear" / "exp") datasets.
    simulated_mus = [0.1, 0.2, 0.3, 0.4, 0.5]
    simulated_control_sizes = [70, 80, 90, 100]
    simulated_top_k = [10, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800]

    # Grid for the NSW dataset; mu is passed as 0.0 there.
    nsw_control_sizes = [70, 75, 80, 85]
    nsw_top_k = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100, 120, 128]

    for dataset in datasets:
        if dataset in ("linear", "exp"):
            for mu in simulated_mus:
                for control_n in simulated_control_sizes:
                    dir_ = dataset + "_mu_" + str(mu) + "/" + "control_n_" + str(control_n)
                    path = "results/figure/" + dataset + "/" + dir_
                    _run_and_plot(dataset, control_n, simulated_top_k, mu, path, batch_size, lr)

        elif dataset == "NSW":
            mu = 0.0
            for control_n in nsw_control_sizes:
                dir_ = dataset + "/" + "control_n_" + str(control_n)
                path = "results/figure/" + dir_
                _run_and_plot(dataset, control_n, nsw_top_k, mu, path, batch_size, lr)

