import matplotlib.pyplot as plt
from example.generate_data import generate_demonstration_data
from lasso_selector import SelectiveBorrowingLasso
from influence_selector import Influenceselector



def _new_base_figure(x_rct, y_rct, x_ext, y_ext):
    """Open a new 8x6 figure and scatter the external-control and RCT samples.

    Both result figures start from the same background: external controls in
    gray and RCT samples as blue crosses.
    """
    plt.figure(figsize=(8, 6))
    plt.scatter(x_ext, y_ext, c='gray', alpha=0.7, label='External Controls')
    plt.scatter(x_rct, y_rct, c='blue', marker='x', label='RCT')


def _style_and_save(filename):
    """Apply the shared tick/label/legend/grid styling to the current figure and save it.

    Args:
        filename: Output path handed to ``plt.savefig``.
    """
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.xlabel("x", fontsize=18, fontweight='bold')
    plt.ylabel("y", fontsize=18, fontweight='bold')
    plt.legend(
        loc='best',
        fontsize=16,
        framealpha=0.5,
        facecolor='white',
        edgecolor='gray',
    )
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.savefig(filename)


if __name__ == '__main__':

    x_rct, y_rct, x_ext, y_ext = generate_demonstration_data()
    # Number of external samples requested from each selector.
    top_k = 100

    # Influence selector: returns indices into the external data.
    selector_if = Influenceselector(x_ext, y_ext, x_rct, y_rct)
    selector_if.fit()
    if_idx = selector_if.select_samples(top_k=top_k)

    # Selective-borrowing Lasso selector.
    sbf = SelectiveBorrowingLasso(threshold=1e-4)
    sbf.fit(x_rct, y_rct, x_ext, y_ext)

    # Results
    print(f"Optimal parameters: lambda_N={sbf.best_params_['alpha']:.4f}, nu={sbf.best_params_['nu']}")
    print(f"Selected {len(sbf.selected_indices_)}/{len(x_ext)} comparable subjects")

    # Unlike the influence selector, this returns the selected samples themselves.
    x_selected, y_selected = sbf.select_samples(x_ext, y_ext, top_k=top_k)
    theta_0, theta_1 = sbf.return_coef()

    # Visualize results: one figure per selector, sharing the same styling.
    _new_base_figure(x_rct, y_rct, x_ext, y_ext)
    plt.scatter(x_ext[if_idx], y_ext[if_idx], c='red', label='Influence Selected')
    _style_and_save("influence.png")

    _new_base_figure(x_rct, y_rct, x_ext, y_ext)
    plt.scatter(x_selected, y_selected, c='green', label='Lasso Selected')
    # Each coefficient pair is drawn as intercept + slope * x over its own data.
    # Labelled so the lines appear in the legend alongside the scatter series.
    plt.plot(x_rct, theta_0[0] + theta_0[1] * x_rct, color='blue', label=r'$\theta_0$ line')
    plt.plot(x_ext, theta_1[0] + theta_1[1] * x_ext, color='gray', label=r'$\theta_1$ line')
    _style_and_save("lasso.png")
    # Displays every open figure (both the influence and the Lasso plot).
    plt.show()
