import numpy as np
from sklearn.linear_model import LinearRegression
from torch import autograd

from utils import MLP1, MLP2, Exponential_regression
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Torch computations in this module run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class InfluenceSelector:
    """Influence-function sample scorer for ordinary least-squares linear regression.

    A linear model (with intercept) is fit on the training split. Each test point
    is then scored by the total absolute influence of all training points on its
    loss, ``sum_j | -g_train_j^T H^{-1} g_test_i |``. Here ``g`` are per-sample
    gradients of ``0.5 * ([1, x] @ theta - y)^2`` and ``H`` is the mean Hessian
    over the training set. Test points scoring below ``threshold`` are recorded
    in ``selected_indices_``.
    """

    def __init__(self, dataset, threshold, x_test, y_test, x_train, y_train, model_fit="linear"):
        """Store the data splits and selection settings.

        Args:
            dataset: Sub-directory name used when saving the influence figure.
            threshold: Test points whose score is strictly below this are selected.
            x_test, y_test: Points to score; a 1-D ``x_train`` means a single feature.
            x_train, y_train: Points the linear model is fit on.
            model_fit: Label used in the saved figure's file name (may be reassigned).
        """
        self.x_test = x_test
        self.y_test = y_test
        self.x_train = x_train
        self.y_train = y_train
        self.threshold = threshold
        self.dataset = dataset
        # Read by _plot_influence_scores; previously never set, so plotting always failed.
        self.model_fit = model_fit

    def fit(self):
        """Fit the linear model and compute one influence score per test point.

        Sets ``self.influences`` (shape ``(n_test,)``) and ``self.selected_indices_``.
        """
        if self.x_train.ndim == 1:
            # A 1-D feature array is a single feature: promote both splits to columns.
            X_train = self.x_train.reshape(-1, 1)
            X_test = self.x_test.reshape(-1, 1)
        else:
            X_train = self.x_train
            X_test = self.x_test
        y_train = self.y_train
        y_test = self.y_test

        # Train the initial model; theta = [intercept, coef_1, ..., coef_p].
        # ravel() keeps this valid when y is passed as a (n, 1) column.
        model = LinearRegression(fit_intercept=True).fit(X_train, y_train)
        theta = np.concatenate([np.ravel(model.intercept_), np.ravel(model.coef_)])

        H_inv = self._invert_hessian(self._compute_hessian(X_train))

        # Each per-sample gradient is computed once, not once per test point.
        train_grads = self._stack_gradients(X_train, y_train, theta)
        test_grads = self._stack_gradients(X_test, y_test, theta)
        self.influences = self._total_abs_influence(train_grads, test_grads, H_inv)
        self.selected_indices_ = np.where(np.abs(self.influences) < self.threshold)[0]

    def _stack_gradients(self, X, y, theta):
        """Return an ``(n_samples, n_params)`` array of per-sample loss gradients."""
        rows = [self._compute_gradient(x_i, y_i, theta) for x_i, y_i in zip(X, y)]
        # reshape keeps the column count correct even when there are no samples.
        return np.asarray(rows, dtype=float).reshape(-1, len(theta))

    @staticmethod
    def _total_abs_influence(train_grads, test_grads, H_inv):
        """Return ``sum_j |-g_train_j^T H^{-1} g_test_i|`` for every test point ``i``."""
        # (n_train, n_test) matrix of pairwise influence values; the sign is
        # irrelevant once the absolute value is taken.
        pairwise = train_grads @ H_inv @ test_grads.T
        return np.abs(pairwise).sum(axis=0)

    @staticmethod
    def _invert_hessian(H):
        """Invert ``H``, falling back to the pseudo-inverse when it is singular."""
        try:
            return np.linalg.inv(H)
        except np.linalg.LinAlgError:
            # E.g. collinear or constant features make X^T X singular.
            return np.linalg.pinv(H)

    def _plot_influence_scores(self, true_biases=None):
        """Visualize the selected subjects.

        Scatters the sorted influence scores and colours the selected test points.
        Saves to ``results/figure/<dataset>/influence_<model_fit>.png``.
        ``true_biases`` is accepted for interface compatibility but unused.
        """
        import os

        plt.figure(figsize=(8, 6))

        sorted_idx = np.argsort(self.influences)
        plt.scatter(range(len(self.influences)), self.influences[sorted_idx],
                    c=np.isin(sorted_idx, self.selected_indices_),
                    cmap='coolwarm', alpha=0.7)
        plt.axhline(self.threshold, color='k', linestyle='--')
        plt.axhline(-self.threshold, color='k', linestyle='--')
        plt.title("Estimated influence with Selection")
        plt.xlabel("Subject Index (sorted)")
        plt.ylabel("Estimated influence")

        plt.tight_layout()

        path = "results/figure/" + self.dataset + "/" + "influence_" + self.model_fit + ".png"
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        plt.savefig(path)

    def _compute_gradient(self, x, y, theta):
        """Return the gradient of ``0.5 * ([1, x] @ theta - y)^2`` w.r.t. ``theta`` for one sample."""
        x = np.array(x)
        X_with_intercept = np.insert(x, 0, 1)  # prepend the intercept feature
        y = np.array(y).reshape(-1, 1)
        theta = np.array(theta).reshape(-1, 1)
        predictions = X_with_intercept @ theta
        errors = predictions - y
        gradients = errors * X_with_intercept
        return gradients.squeeze()

    def _compute_hessian(self, x):
        """Return the mean squared-error Hessian ``(1/m) [1, X]^T [1, X]`` over ``m`` rows of ``x``."""
        m = x.shape[0]
        X_with_intercept = np.column_stack([np.ones(m), x])
        hessian = (1 / m) * X_with_intercept.T @ X_with_intercept
        return hessian

    def select_samples(self, top_k=150):
        """Return indices of the ``top_k`` test points with the smallest influence.

        Among the selected points, indices run from the largest to the smallest
        influence. ``top_k <= 0`` returns an empty index array; ``top_k`` larger
        than the number of points returns all of them.
        """
        if top_k <= 0:
            # Slicing with [-0:] would otherwise return every index.
            return np.array([], dtype=np.intp)
        return np.argsort(self.influences)[:top_k][::-1]

    def return_influence(self):
        """Return the influence scores sorted in ascending order."""
        return np.sort(self.influences)


class InfluenceSelector_exp:
    """Influence-function sample scorer for an exponential model ``y ~ a * exp(x @ b)``.

    The model is fit on the training split with ``Exponential_regression``. Each
    test point is then scored by the total absolute influence of all training
    points on its loss, ``sum_j | -g_train_j^T H^{-1} g_test_i |``. Gradients are
    of the per-sample loss ``0.5 * (y - a * exp(x @ b))^2`` w.r.t.
    ``theta = [a, b_1, ..., b_p]``. ``H`` is that loss's Hessian summed (not
    averaged) over the training set. Test points scoring below ``threshold`` are
    recorded in ``selected_indices_``.
    """

    def __init__(self, dataset, threshold, x_test, y_test, x_train, y_train, model_fit="exp"):
        """Store the data splits and selection settings.

        Args:
            dataset: Sub-directory name used when saving the influence figure.
            threshold: Test points whose score is strictly below this are selected.
            x_test, y_test: Points to score; a 1-D ``x_train`` means a single feature.
            x_train, y_train: Points the exponential model is fit on.
            model_fit: Label used in the saved figure's file name (may be reassigned).
        """
        self.x_test = x_test
        self.y_test = y_test
        self.x_train = x_train
        self.y_train = y_train
        self.threshold = threshold
        self.dataset = dataset
        # Read by _plot_influence_scores; previously never set, so plotting always failed.
        self.model_fit = model_fit

    def fit(self):
        """Fit the exponential model and compute one influence score per test point.

        Sets ``self.influences`` (shape ``(n_test,)``) and ``self.selected_indices_``.
        """
        if self.x_train.ndim == 1:
            # A 1-D feature array is a single feature: promote both splits to columns.
            X_train = self.x_train.reshape(-1, 1)
            X_test = self.x_test.reshape(-1, 1)
        else:
            X_train = self.x_train
            X_test = self.x_test
        y_train = self.y_train
        y_test = self.y_test

        model = Exponential_regression().fit(X_train, y_train)
        # Flat parameter vector [a, b_1, ..., b_p]; assumes model.params uses this
        # layout, which is how _compute_gradient/_compute_hessian index it.
        theta = np.asarray(model.params, dtype=float).ravel()

        H_inv = self._invert_hessian(self._compute_hessian(X_train, y_train, theta))

        # Each per-sample gradient is computed once, not once per test point.
        train_grads = self._stack_gradients(X_train, y_train, theta)
        test_grads = self._stack_gradients(X_test, y_test, theta)
        self.influences = self._total_abs_influence(train_grads, test_grads, H_inv)
        self.selected_indices_ = np.where(np.abs(self.influences) < self.threshold)[0]

    def _stack_gradients(self, X, y, theta):
        """Return an ``(n_samples, n_params)`` array of per-sample loss gradients."""
        rows = [self._compute_gradient(x_i, y_i, theta) for x_i, y_i in zip(X, y)]
        # reshape keeps the column count correct even when there are no samples.
        return np.asarray(rows, dtype=float).reshape(-1, len(theta))

    @staticmethod
    def _total_abs_influence(train_grads, test_grads, H_inv):
        """Return ``sum_j |-g_train_j^T H^{-1} g_test_i|`` for every test point ``i``."""
        # (n_train, n_test) matrix of pairwise influence values; the sign is
        # irrelevant once the absolute value is taken.
        pairwise = train_grads @ H_inv @ test_grads.T
        return np.abs(pairwise).sum(axis=0)

    @staticmethod
    def _invert_hessian(H):
        """Invert ``H``, falling back to the pseudo-inverse when it is singular."""
        try:
            return np.linalg.inv(H)
        except np.linalg.LinAlgError:
            return np.linalg.pinv(H)

    def _plot_influence_scores(self, true_biases=None):
        """Visualize the selected subjects.

        Scatters the sorted influence scores and colours the selected test points.
        Saves to ``results/figure/<dataset>/influence_<model_fit>.png``.
        ``true_biases`` is accepted for interface compatibility but unused.
        """
        import os

        plt.figure(figsize=(8, 6))

        sorted_idx = np.argsort(self.influences)
        plt.scatter(range(len(self.influences)), self.influences[sorted_idx],
                    c=np.isin(sorted_idx, self.selected_indices_),
                    cmap='coolwarm', alpha=0.7)
        plt.axhline(self.threshold, color='k', linestyle='--')
        plt.axhline(-self.threshold, color='k', linestyle='--')
        plt.title("Estimated influence with Selection")
        plt.xlabel("Subject Index (sorted)")
        plt.ylabel("Estimated influence")

        plt.tight_layout()

        path = "results/figure/" + self.dataset + "/" + "influence_" + self.model_fit + ".png"
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        plt.savefig(path)

    def _compute_gradient(self, X, y, theta):
        """Compute the gradient of ``0.5 * sum (y - a * exp(X @ b))^2`` w.r.t. ``[a, b]``.

        Accepts a single sample (``X`` 1-D, ``y`` scalar) or a batch (``X`` 2-D,
        ``y`` 1-D); for a batch the per-sample gradients are summed.
        """
        a = theta[0]
        b = theta[1:]
        exp_term = np.exp(np.dot(X, b))
        error = y - a * exp_term

        grad_a = -np.sum(error * exp_term)
        grad_b = -np.dot(X.T, error * a * exp_term)

        return np.concatenate(([grad_a], grad_b))

    def _compute_hessian(self, X, y, theta):
        """Compute the Hessian of ``0.5 * sum_i (y_i - a * exp(x_i @ b))^2`` w.r.t. ``[a, b]``.

        With ``e_i = exp(x_i @ b)`` and residual ``r_i = y_i - a e_i``:
            H_aa   = sum e_i^2
            H_ab_k = sum (a e_i^2 - r_i e_i) x_ik
            H_bb   = sum (a^2 e_i^2 - r_i a e_i) x_i x_i^T
        The previous implementation dropped the factor ``a`` in ``H_ab``.

        Args:
            X: 2-D feature matrix ``(n, p)``.
            y: Targets, ``n`` values.
            theta: ``[a, b_1, ..., b_p]``.

        Returns:
            ``(p + 1, p + 1)`` Hessian matrix.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float).ravel()
        theta = np.asarray(theta, dtype=float).ravel()
        a, b = theta[0], theta[1:]

        exp_term = np.exp(X @ b)
        error = y - a * exp_term

        # Jacobian of the prediction a * e_i w.r.t. (a, b): rows [e_i, a e_i x_i].
        jac = np.column_stack([exp_term, (a * exp_term)[:, None] * X])
        hessian = jac.T @ jac  # Gauss-Newton term J^T J

        # Second-order term: -sum_i r_i * d^2(prediction_i) / dtheta^2.
        cross = X.T @ (error * exp_term)  # d^2 pred / (da db) = e x
        hessian[0, 1:] -= cross
        hessian[1:, 0] -= cross
        hessian[1:, 1:] -= (X * (error * a * exp_term)[:, None]).T @ X  # d^2 pred / db^2 = a e x x^T
        return hessian

    def select_samples(self, top_k=150):
        """Return indices of the ``top_k`` test points with the smallest influence.

        Among the selected points, indices run from the largest to the smallest
        influence. ``top_k <= 0`` returns an empty index array; ``top_k`` larger
        than the number of points returns all of them.
        """
        if top_k <= 0:
            # Slicing with [-0:] would otherwise return every index.
            return np.array([], dtype=np.intp)
        return np.argsort(self.influences)[:top_k][::-1]

    def return_influence(self):
        """Return the influence scores sorted in ascending order."""
        return np.sort(self.influences)



class InfluenceSelector_mlp:
    """Influence-function sample scorer for an ``MLP1`` regressor trained with MSE.

    After training, each test point is scored by the total absolute influence of
    all training points on its loss,
    ``sum_j | -g_train_j^T (H + damping * I)^+ g_test_i |``. Here ``g`` are
    per-sample MSE gradients over all model parameters and ``H`` is the dense
    Hessian of the mean training loss. Test points scoring below ``threshold``
    are recorded in ``selected_indices_``.
    """

    def __init__(self, dataset, threshold, x_test, y_test, x_train, y_train, batch_size, lr, num_epoch, top_k=150,
                 model_fit="mlp", damping=1e-6):
        """Store the data splits, training settings and build the model.

        Args:
            dataset: Sub-directory name used when saving the influence figure.
            threshold: Test points whose score is strictly below this are selected.
            x_test, y_test: Points to score; a 1-D ``x_train`` means a single feature.
            x_train, y_train: Points the MLP is trained on.
            batch_size, lr, num_epoch: Forwarded to ``MLP1.fit``; ``batch_size`` is
                also used by ``_compute_diag_hessian``.
            top_k: Stored as ``self.top_k``.
            model_fit: Label used in the saved figure's file name (may be reassigned).
            damping: Value added to the Hessian diagonal before inversion.
        """
        self.x_test = x_test
        self.y_test = y_test
        self.x_train = x_train
        self.y_train = y_train
        self.top_k = top_k
        # fit() treats a 1-D x_train as a single feature; x_train.shape[1] would raise for it.
        n_features = 1 if np.ndim(x_train) == 1 else x_train.shape[1]
        self.model = MLP1(n_features, 16, 1).to(device)
        self.criterion = nn.MSELoss()
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.lr = lr
        self.threshold = threshold
        self.dataset = dataset
        # Read by _plot_influence_scores; previously never set, so plotting always failed.
        self.model_fit = model_fit
        self.damping = damping

    def fit(self):
        """Train the MLP and compute one influence score per test point.

        Sets ``self.influences`` (shape ``(n_test,)``) and ``self.selected_indices_``.
        """
        if self.x_train.ndim == 1:
            # A 1-D feature array is a single feature: promote both splits to columns.
            X_train = torch.Tensor(self.x_train.reshape(-1, 1)).to(device)
            # Previously built from x_train, which scored the training set instead of the test set.
            X_test = torch.Tensor(self.x_test.reshape(-1, 1)).to(device)
        else:
            X_train = torch.Tensor(self.x_train).to(device)
            X_test = torch.Tensor(self.x_test).to(device)
        y_train = torch.Tensor(self.y_train).to(device)
        y_test = torch.Tensor(self.y_test).to(device)

        # Honour the configured batch size (it was hard-coded to 32).
        self.model.fit(X_train, y_train, batch_size=self.batch_size, num_epoch=self.num_epoch, lr=self.lr)
        self.model.eval()

        params = list(self.model.parameters())
        hessian = self._full_hessian(self._loss(X_train, y_train), params)
        # Ridge on the diagonal (not on every entry) to improve conditioning; an MLP
        # loss Hessian may be singular or indefinite, hence the pseudo-inverse.
        eye = torch.eye(hessian.shape[0], dtype=hessian.dtype, device=hessian.device)
        hessian_inv = torch.linalg.pinv(hessian + self.damping * eye)

        # Each per-sample gradient is computed once, not once per test point.
        n_params = hessian.shape[0]
        train_grads = self._stack_gradients(X_train, y_train, n_params)
        test_grads = self._stack_gradients(X_test, y_test, n_params)
        self.influences = self._total_abs_influence(train_grads, test_grads, hessian_inv).cpu().numpy()
        self.selected_indices_ = np.where(np.abs(self.influences) < self.threshold)[0]

    def _loss(self, x, y):
        """Return the MSE loss of the model on ``x`` against ``y``.

        The target is reshaped to the prediction's shape when both hold the same
        number of elements. Otherwise MSELoss broadcasts a ``(n,)`` target against
        a ``(n, 1)`` prediction into an ``(n, n)`` difference.
        """
        out = self.model(x)
        if y.shape != out.shape and y.numel() == out.numel():
            y = y.reshape(out.shape)
        return self.criterion(out, y)

    @staticmethod
    def _full_hessian(loss, params):
        """Return the dense Hessian of ``loss`` w.r.t. the flattened ``params``.

        Costs one backward pass per scalar parameter (rows are d(grad_k)/d(params)).
        """
        grads = torch.autograd.grad(loss, params, create_graph=True)
        flat_grad = torch.cat([g.reshape(-1) for g in grads])
        n = flat_grad.numel()
        hessian = flat_grad.new_zeros((n, n))
        if not flat_grad.requires_grad:
            # Gradient does not depend on the parameters: the Hessian is zero.
            return hessian.detach()
        for k in range(n):
            row = torch.autograd.grad(flat_grad[k], params, retain_graph=True, allow_unused=True)
            hessian[k] = torch.cat([
                (r if r is not None else torch.zeros_like(p)).reshape(-1)
                for r, p in zip(row, params)
            ])
        return hessian.detach()

    def _stack_gradients(self, X, y, n_params):
        """Return an ``(n_samples, n_params)`` tensor of per-sample loss gradients."""
        rows = [self._compute_gradient(x_i, y_i) for x_i, y_i in zip(X, y)]
        if not rows:
            return torch.zeros((0, n_params), device=X.device)
        return torch.stack(rows)

    @staticmethod
    def _total_abs_influence(train_grads, test_grads, hessian_inv):
        """Return ``sum_j |-g_train_j^T H^{-1} g_test_i|`` for every test point ``i``."""
        # (n_train, n_test) matrix of pairwise influence values; the sign is
        # irrelevant once the absolute value is taken.
        pairwise = train_grads @ hessian_inv @ test_grads.T
        return pairwise.abs().sum(dim=0).detach()

    def _compute_diag_hessian(self, X_train, y_train):
        """Return ``diag(1 / (mean per-batch Hessian diagonal + damping))`` as a dense matrix.

        The exact diagonal is taken from each mini-batch's Hessian. The previous
        implementation summed Hessian rows within each parameter tensor instead.

        Raises:
            ValueError: if ``X_train`` is empty.
        """
        dataloader = DataLoader(TensorDataset(X_train, y_train), batch_size=self.batch_size)
        self.model.eval()
        params = list(self.model.parameters())

        diag_sum = None
        for inputs, targets in dataloader:
            batch_diag = torch.diagonal(self._full_hessian(self._loss(inputs, targets), params))
            diag_sum = batch_diag if diag_sum is None else diag_sum + batch_diag
        if diag_sum is None:
            raise ValueError("X_train is empty")

        hessian_diag = diag_sum / len(dataloader)
        return torch.diag(1 / (hessian_diag + self.damping))

    def _plot_influence_scores(self, true_biases=None):
        """Visualize the selected subjects.

        Scatters the sorted influence scores and colours the selected test points.
        Saves to ``results/figure/<dataset>/influence_<model_fit>.png``.
        ``true_biases`` is accepted for interface compatibility but unused.
        """
        import os

        plt.figure(figsize=(8, 6))
        sorted_idx = np.argsort(self.influences)
        plt.scatter(range(len(self.influences)), self.influences[sorted_idx],
                    c=np.isin(sorted_idx, self.selected_indices_),
                    cmap='coolwarm', alpha=0.7)
        plt.axhline(self.threshold, color='k', linestyle='--')
        plt.axhline(-self.threshold, color='k', linestyle='--')
        plt.title("Estimated influence with Selection")
        plt.xlabel("Subject Index (sorted)")
        plt.ylabel("Estimated influence")

        plt.tight_layout()

        path = "results/figure/" + self.dataset + "/" + "influence_" + self.model_fit + ".png"
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        plt.savefig(path)

    def _compute_gradient(self, x, y):
        """Return the loss gradient (flattened into one vector) for a single sample."""
        self.model.zero_grad()
        if x.dim() == 1:
            # Feed the sample as a batch of one, like the batched training-loss call.
            x = x.unsqueeze(0)
        loss = self._loss(x, y)
        # First-order gradients only: no graph is kept, so repeated calls don't pile up memory.
        grads = torch.autograd.grad(loss, list(self.model.parameters()))
        return torch.cat([g.reshape(-1) for g in grads])

    def select_samples(self, top_k=150):
        """Return indices of the ``top_k`` test points with the smallest influence.

        Among the selected points, indices run from the largest to the smallest
        influence. ``top_k <= 0`` returns an empty index array. This argument is
        independent of ``self.top_k``.
        """
        if top_k <= 0:
            # Slicing with [-0:] would otherwise return every index.
            return np.array([], dtype=np.intp)
        return np.argsort(self.influences)[:top_k][::-1]

    def return_influence(self):
        """Return the influence scores sorted in ascending order."""
        return np.sort(self.influences)
