import numpy as np
import matplotlib.pyplot as plt

# ==========================================
# 1. Global Matplotlib style
# ==========================================
# Serif text for every figure; serif families are tried in list order
# (Times first, then DejaVu Serif).
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times", "DejaVu Serif"],
})


# ==========================================
# 2. Data-generation class (simplified version for visualization)
# ==========================================
class KillerDatasetViz:
    """1-D synthetic regression dataset with clean, high-leverage and outlier groups.

    All groups are defined relative to the linear truth model
    ``y = x @ w_true + b_true`` with ``w_true = [2.0]`` and ``b_true = 1.0``:

    * Normal (``x_norm`` / ``y_norm``): ``x ~ U(-2, 2)``; ``y`` is the truth
      model plus ``N(0, 1)`` noise.
    * High leverage (``x_lev`` / ``y_lev``): ``x`` placed about 10 units along
      the direction of ``w_true`` with ``N(0, 0.2)`` jitter; ``y`` is the truth
      model plus small ``N(0, 0.1)`` noise.
    * Proximal outliers (``x_out`` / ``y_out``): ``x ~ N(0, 1)`` near the
      normal data, but ``y`` uses the inverted weight ``-w_true`` and an extra
      ``+10.0`` offset, plus ``N(0, 0.5)`` noise.

    For every group the features have shape ``(n, 1)`` and the labels have
    shape ``(n,)``.

    Args:
        n_samples: Total number of points. The normal and leverage groups get
            ``int(0.7 * n_samples)`` and ``int(0.1 * n_samples)`` points. The
            outlier group gets the remainder (about 20%), so the three groups
            always add up to ``n_samples``.
        seed: Seed for a private ``np.random.RandomState``. Seeding it with 42
            (the default) produces the same data as seeding NumPy's global
            generator with 42, but the global random state is left untouched.
    """

    def __init__(self, n_samples=300, seed=42):
        self.dim = 1
        self.n_samples = n_samples
        self.b_true = 1.0

        # Private legacy generator: RandomState(seed) produces the same stream
        # as np.random.seed(seed), without mutating global state that other
        # code may depend on.
        rng = np.random.RandomState(seed)

        # === true weight ===
        self.w_true = np.array([2.0])

        # ---------------------------------------
        # A. Normal data, ~70%
        # ---------------------------------------
        n_center = int(0.7 * n_samples)
        self.x_norm = rng.uniform(-2, 2, (n_center, self.dim))
        self.y_norm = self.x_norm @ self.w_true + self.b_true + rng.normal(0, 1.0, n_center)

        # ---------------------------------------
        # B. High leverage, ~10%
        # ---------------------------------------
        # Samples sit far out along the direction of w, at a distance of about 10.
        n_lev = int(0.1 * n_samples)
        direction = self.w_true / np.linalg.norm(self.w_true)
        center_lev = direction * 10.0
        self.x_lev = center_lev + rng.normal(0, 0.2, (n_lev, self.dim))
        self.y_lev = self.x_lev @ self.w_true + self.b_true + rng.normal(0, 0.1, n_lev)

        # ---------------------------------------
        # C. Proximal outliers, remainder (~20%)
        # ---------------------------------------
        # Located near the centre, but with an inverted slope and a large
        # positive bias (+10.0). Using the remainder rather than
        # int(0.2 * n_samples) keeps the group sizes summing to n_samples
        # despite int() truncation of each fraction.
        n_out = n_samples - n_center - n_lev
        self.x_out = rng.normal(0, 1.0, (n_out, self.dim))

        w_outlier = -1.0 * self.w_true  # inverted weight
        self.y_out = self.x_out @ w_outlier + self.b_true + 10.0 + rng.normal(0, 0.5, n_out)

    def plot_dataset(self, filename='data_2d.pdf', show=True):
        """Scatter-plot the three groups with the truth line, save it, and optionally show it.

        Args:
            filename: Path the figure is written to. The file format is
                inferred from the extension (PDF for the default name).
            show: If true, call ``plt.show()`` after saving.
        """
        fig, ax = plt.subplots(figsize=(10, 7), dpi=150)
        # Normal
        ax.scatter(self.x_norm, self.y_norm, c='blue', alpha=0.3, s=40,
                   edgecolors='k', linewidth=0.5, label='Normal (70%)')

        # Leverage
        ax.scatter(self.x_lev, self.y_lev, c='green', alpha=0.3, s=60, marker='^',
                   edgecolors='k', linewidth=0.5, label='High Leverage (10%)')

        # Outliers
        ax.scatter(self.x_out, self.y_out, c='red', alpha=0.3, s=50, marker='X',
                   edgecolors='k', linewidth=0.5, label='Proximal Outliers (20%)')

        # Draw the truth line across every plotted point (all three groups),
        # with a margin of 1 on each side. This also works when a group is empty.
        all_x = np.concatenate([self.x_norm, self.x_lev, self.x_out]).ravel()
        x_line = np.linspace(all_x.min() - 1, all_x.max() + 1, 100).reshape(-1, 1)
        y_line = x_line @ self.w_true + self.b_true

        ax.plot(x_line, y_line, color='black', linestyle='--', linewidth=2.5,
                label='Truth Model')

        ax.set_xlabel(r"Feature $X$", fontsize=24)
        ax.set_ylabel(r"Label $Y$", fontsize=24)
        ax.tick_params(axis='both', labelsize=14)
        ax.grid(True, alpha=0.3)
        for spine in ax.spines.values():
            spine.set_linewidth(1.8)
        # Only one legend call: each ax.legend() call replaces the previous legend.
        ax.legend(fontsize=20)

        fig.tight_layout()
        fig.savefig(filename, bbox_inches='tight')
        if show:
            plt.show()


# ==========================================
# 3. Script entry point
# ==========================================
if __name__ == "__main__":
    # Build a 1000-point dataset, then save and display its scatter plot.
    KillerDatasetViz(n_samples=1000).plot_dataset()