import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler


class Influenceselector:
    """Rank samples by influence-function score under a linear regression.

    ``fit`` trains a ``LinearRegression`` on the training data. Then, for each
    test sample, it sums the absolute influence-function values against every
    training sample. ``select_samples`` returns the indices with the largest
    summed influence.

    Parameters
    ----------
    x_test, y_test : array-like
        Test features and targets. Features are either 1-D (one feature) or
        2-D ``(n_samples, n_features)``.
    x_train, y_train : array-like
        Training features and targets, in the same layout as the test data.
    """

    def __init__(self, x_test, y_test, x_train, y_train):
        self.x_test = x_test
        self.y_test = y_test
        self.x_train = x_train
        self.y_train = y_train

    def fit(self):
        """Fit the regression and compute one influence score per test sample.

        ``self.influences[i]`` is ``sum_j |-g_j^T H^{-1} g_i|``, where:

        - ``g_j`` and ``g_i`` are the squared-error gradients, with respect to
          ``[intercept, coef_1, ..., coef_d]``, of training sample ``j`` and
          test sample ``i``;
        - ``H`` is the Hessian of the summed training squared error.

        1-D features are used unscaled. 2-D features are standardized with a
        ``StandardScaler`` fitted on the training data only, and the same
        transform is applied to the test data.

        Returns
        -------
        Influenceselector
            ``self``, so calls can be chained.

        Raises
        ------
        numpy.linalg.LinAlgError
            If the Hessian is exactly singular (e.g. duplicated features).
        """
        x_train = np.asarray(self.x_train)
        x_test = np.asarray(self.x_test)
        # Positional arrays avoid label-based indexing when pandas Series are passed.
        y_train = np.asarray(self.y_train)
        y_test = np.asarray(self.y_test)

        if x_train.ndim == 1:
            X_train = x_train.reshape(-1, 1)
            X_test = x_test.reshape(-1, 1)
        else:
            # Fit the scaling statistics on the training set only. Scaling the
            # test set with its own statistics would put it in a different
            # coordinate system from the one the model was trained in.
            scaler = StandardScaler().fit(x_train)
            X_train = scaler.transform(x_train)
            X_test = scaler.transform(x_test)

        # Train the base model.
        model = LinearRegression()
        model.fit(X_train, y_train)
        # Full parameter vector [beta_0, beta_1, ..., beta_d]. ravel() handles
        # both scalar and 1-element-array intercepts.
        theta = np.concatenate((np.ravel(model.intercept_), np.ravel(model.coef_)))

        # NOTE(review): scores are per *test* sample, summed over training
        # samples. An earlier comment described a sum over test points
        # instead; confirm which axis is intended.
        scores = self._sum_abs_influences(X_train, y_train, X_test, y_test, theta)
        self.influences = list(scores)
        return self

    def _sum_abs_influences(self, X_train, y_train, X_test, y_test, theta):
        """Return, for each test sample, the summed |influence| over training samples.

        This is the vectorized form of the double loop
        ``sum_j |-g_train_j @ H^{-1} @ g_test_i|``. The result has shape
        ``(n_test,)``.
        """
        H_inv = np.linalg.inv(self._compute_hessian(X_train))
        grads_train = self._compute_gradients(X_train, y_train, theta)  # (n_train, p)
        grads_test = self._compute_gradients(X_test, y_test, theta)  # (n_test, p)
        # influence[j, i] pairs training sample j with test sample i. The sign
        # is irrelevant after abs(); it is kept to match the influence formula.
        influence = -grads_train @ H_inv @ grads_test.T
        return np.abs(influence).sum(axis=0)

    @staticmethod
    def _design_matrix(x):
        """Return the design matrix: a column of ones followed by the feature columns.

        A 0-D or 1-D ``x`` is treated as a single feature column.
        """
        x = np.asarray(x, dtype=float)
        if x.ndim < 2:
            x = x.reshape(-1, 1)
        return np.column_stack((np.ones(len(x)), x))

    def _compute_gradients(self, X, y, theta):
        """Return the squared-error gradient of each sample, one row per sample.

        For a design row ``a = [1, x]``, the loss ``(y - a @ theta)**2`` has
        gradient ``-2 * (y - a @ theta) * a``.
        """
        A = self._design_matrix(X)
        residuals = np.asarray(y, dtype=float).reshape(-1) - A @ np.asarray(theta, dtype=float)
        return -2.0 * residuals[:, np.newaxis] * A

    def _compute_gradient(self, x, y, theta):
        """Return the squared-error gradient of a single sample.

        ``x`` holds that sample's feature value(s): a scalar, or a 1-D array
        with one entry per feature. The result has ``len(theta)`` entries.
        """
        return self._compute_gradients(np.reshape(x, (1, -1)), np.reshape(y, (1,)), theta)[0]

    def _compute_hessian(self, x):
        """Return the Hessian of the summed squared error w.r.t. ``[intercept, coefs]``.

        This equals ``2 * A.T @ A`` with ``A = [1, x]``. For a single feature
        it reduces to ``2 * [[n, sum(x)], [sum(x), sum(x**2)]]``.
        """
        A = self._design_matrix(x)
        return 2.0 * A.T @ A

    def select_samples(self, top_k=150):
        """Return the indices of the ``top_k`` samples with the largest influence.

        Indices are positions in ``self.influences``, i.e. test-set positions
        as computed by ``fit``. They are ordered from most to least
        influential, and ties keep their original order. If ``top_k`` exceeds
        the number of samples, all indices are returned.

        Raises
        ------
        ValueError
            If ``top_k`` is negative.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        # Stable sort on negated scores gives descending order with deterministic ties.
        order = np.argsort(-np.asarray(self.influences, dtype=float), kind="stable")
        return order[:top_k]