import numpy as np
import matplotlib.pyplot as plt


def generate_demonstration_data(n_rct=100, n_ext=200, seed=42):
    """Simulate an RCT sample plus an external sample with influential points and outliers.

    The RCT data follow ``y = 2 * x + N(0, 0.2**2)`` on an evenly spaced grid
    over ``[0, 2]``.

    The external data follow a different line, ``y = -1 + 2.5 * x + N(0, 0.5**2)``.
    They are laid out in three consecutive groups:

    1. "normal" points with ``x ~ U(0, 1.5)``,
    2. high-influence points with ``x ~ U(1.5, 2)``,
    3. five fixed outliers at ``x = [1.7, 1.8, 1.9, 2.0, 2.0]`` with ``y = 0.5``
       (always the last five entries).

    Parameters
    ----------
    n_rct : int
        Number of RCT observations.
    n_ext : int
        Total number of external observations, including the five outliers.
        ``n_ext // 2`` of them are normal points and the rest (minus the
        outliers) are high-influence points. The default of 200 gives
        100 normal, 95 high-influence and 5 outliers.
    seed : int
        Seed passed to NumPy's global legacy RNG via ``np.random.seed``.
        Note that this reseeds the global RNG as a side effect.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x_rct, y_rct, x_ext, y_ext)`` with lengths
        ``n_rct, n_rct, n_ext, n_ext``.

    Raises
    ------
    ValueError
        If ``n_ext`` is too small to hold the outliers after the normal
        points.
    """
    x_outlier = np.array([1.7, 1.8, 1.9, 2.0, 2.0])
    n_outlier = x_outlier.size
    n_normal = n_ext // 2
    n_high = n_ext - n_normal - n_outlier
    if n_high < 0:
        raise ValueError(
            f"n_ext={n_ext} is too small: need room for {n_outlier} outliers "
            f"after the {n_normal} normal points"
        )

    np.random.seed(seed)

    # RCT data (main trend).
    x_rct = np.linspace(0, 2, n_rct)
    y_rct = 2 * x_rct + np.random.normal(0, 0.2, n_rct)

    # External data: normal samples, high-influence region samples, outliers.
    # The draw order (all x draws first, then the noise draws) is kept so the
    # default sizes reproduce the same data as before for a given seed.
    x_ext = np.concatenate([
        np.random.uniform(0.0, 1.5, n_normal),
        np.random.uniform(1.5, 2.0, n_high),
        x_outlier,
    ])
    y_ext = np.concatenate([
        -1.0 + 2.5 * x_ext[:n_normal] + np.random.normal(0, 0.5, n_normal),
        -1.0 + 2.5 * x_ext[n_normal:n_normal + n_high]
        + np.random.normal(0, 0.5, n_high),
        0.5 * np.ones(n_outlier),
    ])

    return x_rct, y_rct, x_ext, y_ext

if __name__ == '__main__':
    x_rct, y_rct, x_ext, y_ext = generate_demonstration_data()

    # Index ranges within the external sample for the default sizes
    # (100 normal points, then 95 high-influence points, then 5 outliers).
    high_influence = slice(100, 195)
    outliers = slice(-5, None)

    # Scatter layers in drawing order; later layers are drawn on top.
    layers = [
        (x_ext, y_ext, dict(c='gray', alpha=0.5, label='External Data')),
        (x_rct, y_rct, dict(c='blue', marker='x', label='RCT Data')),
        (x_ext[high_influence], y_ext[high_influence],
         dict(c='green', label='High-Influence Points')),
        (x_ext[outliers], y_ext[outliers],
         dict(c='black', label='Outliers')),
    ]

    plt.figure(figsize=(10, 6))
    for xs, ys, style in layers:
        plt.scatter(xs, ys, **style)

    plt.xlabel("X")
    plt.ylabel("Y")
    plt.legend()
    plt.title("Simulated data")
    plt.show()