import pandas as pd
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn.model_selection import KFold
from sklearn.linear_model import LinearRegression, Lasso
from tqdm import tqdm
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from sklearn.covariance import LedoitWolf

class SelectiveBorrowingLasso:
    """Select comparable external subjects via an adaptive-Lasso bias model.

    Workflow (see ``fit``):

    1. Fit a least-squares outcome model on the internal data
       (``mu0_model``) and one on the external data (``mu0e_model``).
    2. For every external subject compute the estimated bias
       ``b_hat = mu0e(x) - mu0(x)`` and the inverse of its Ledoit-Wolf
       variance estimate (``Sigma_inv``).
    3. Regress ``b_hat`` on the external covariates with a sample-weighted
       Lasso, weights ``1 / |b_hat| ** nu`` (clipped), choosing ``nu`` and
       the penalty ``alpha`` by K-fold cross-validation.
    4. Mark external subjects whose fitted bias ``|b_tilde|`` is below
       ``threshold`` as comparable (``selected_indices_``).
    """

    # Bounds for the adaptive weights 1/|b|**nu: b == 0 would otherwise give
    # an infinite weight, which the Lasso solver rejects.
    _WEIGHT_MIN = 1e-6
    _WEIGHT_MAX = 1e6

    def __init__(self, lambda_N_candidates=None, nu_candidates=None,
                 n_folds=5, threshold=0.1, n_jobs=-1):
        """
        Complete selective borrowing framework with CV tuning.

        Parameters:
        - lambda_N_candidates: Lasso penalty (``alpha``) values searched by
          cross-validation. Defaults to ``np.logspace(-4, 2, 50)``.
        - nu_candidates: exponents ``nu`` of the adaptive weights
          ``1 / |b_hat| ** nu`` searched by cross-validation.
          Defaults to ``[1, 1.5, 2]``.
        - n_folds: number of cross-validation folds.
        - threshold: subjects with ``|b_tilde| < threshold`` are selected.
        - n_jobs: joblib worker count for the CV grid (-1 = all cores).
        """
        # None sentinels avoid shared mutable defaults; the resolved grids are
        # the ones fit() has always searched when nothing is passed.
        self.lambda_N_candidates = (
            np.logspace(-4, 2, 50) if lambda_N_candidates is None
            else lambda_N_candidates
        )
        self.nu_candidates = (
            [1, 1.5, 2] if nu_candidates is None else list(nu_candidates)
        )
        self.n_folds = n_folds
        self.n_jobs = n_jobs
        self.threshold = threshold
        self.best_params_ = None
        self.cv_results_ = None
        self.b_hat = None
        self.b_tilde = None
        self.selected_indices_ = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @classmethod
    def _adaptive_weights(cls, b, nu):
        """Return the sample weights ``1 / |b| ** nu`` clipped to the class bounds."""
        # b == 0 gives 1/0 = inf; the warning is silenced because clip bounds it.
        with np.errstate(divide="ignore"):
            weights = 1 / (np.abs(b) ** nu)
        return np.clip(weights, cls._WEIGHT_MIN, cls._WEIGHT_MAX)

    @staticmethod
    def _prepare_features(X_internal, X_external):
        """Return 2-D float covariate matrices for both data sources.

        1-D inputs are reshaped to a single column and left unscaled.
        Multi-column inputs are standardized with ONE scaler fitted on the
        internal data, so an identical raw covariate vector maps to an
        identical standardized vector in both sources. This keeps
        ``mu0e(x) - mu0(x)`` a comparison at the same ``x``.
        """
        X_internal = np.asarray(X_internal, dtype=float)
        X_external = np.asarray(X_external, dtype=float)
        if X_internal.ndim == 1:
            return X_internal.reshape(-1, 1), X_external.reshape(-1, 1)
        scaler = StandardScaler().fit(X_internal)
        return scaler.transform(X_internal), scaler.transform(X_external)

    def _require_fitted(self, attr):
        """Raise AttributeError if ``fit`` has not populated ``attr`` yet."""
        if getattr(self, attr, None) is None:
            raise AttributeError(
                f"{type(self).__name__}.{attr} is not set; call fit() first"
            )

    # ------------------------------------------------------------------
    # Estimation
    # ------------------------------------------------------------------
    def fit_sigma_b(self, X_internal, Y_internal, X_external, Y_external):
        """Fit the mu0 and mu0,e outcome models and the bias statistics.

        Sets ``mu0_model`` (OLS on internal data), ``mu0e_model`` (OLS on
        external data), ``b_hat`` (their prediction difference on the external
        covariates) and ``Sigma_inv`` (inverse of the Ledoit-Wolf covariance
        of ``b_hat`` treated as a single feature, i.e. a 1x1 matrix).

        Returns ``self``.
        """
        self.mu0_model = LinearRegression().fit(X_internal, Y_internal)
        self.mu0e_model = LinearRegression().fit(X_external, Y_external)

        self.b_hat = self.mu0e_model.predict(X_external) - self.mu0_model.predict(X_external)

        cov_estimator = LedoitWolf().fit(self.b_hat.reshape(-1, 1))
        self.Sigma_inv = np.linalg.inv(cov_estimator.covariance_)

        return self

    def parallel_cv_wrapper(self, X, b_hat, Sigma_inv, alpha, nu, train_idx, val_idx):
        """Evaluate one (nu, alpha) setting on one cross-validation fold.

        Fits a sample-weighted Lasso of ``b_hat`` on ``X`` over ``train_idx``.
        Returns ``Sigma_inv * sum(resid**2) / n_val`` on ``val_idx``; for the
        1x1 ``Sigma_inv`` built by ``fit_sigma_b`` this is a 1x1 array.
        """
        X_train, X_val = X[train_idx], X[val_idx]
        b_train, b_val = b_hat[train_idx], b_hat[val_idx]
        weights = self._adaptive_weights(b_train, nu)

        # Weighted Lasso fit on the training fold.
        lasso = Lasso(alpha=alpha, max_iter=10000)
        lasso.fit(X_train, b_train, sample_weight=weights)
        pred = lasso.predict(X_val)

        # Sigma_inv-scaled mean squared validation error.
        resid = (b_val - pred).reshape(-1, 1)
        return Sigma_inv * (resid.T @ resid) / len(b_val)

    def adaptive_lasso_parallel_cv(self, X, b_hat, Sigma_inv, nu_values=(1, 1.5, 2),
                                   alpha_range=np.logspace(-4, 2, 20), n_folds=5, n_jobs=-1):
        """Grid-search (nu, alpha) by K-fold CV, parallelised over the grid.

        Returns a DataFrame with columns ``nu``, ``alpha`` and ``mse`` (the
        mean of the per-fold scores from ``parallel_cv_wrapper``), one row
        per grid point.
        """
        # KFold without shuffling is deterministic, so compute the splits once
        # and reuse them for every grid point.
        splits = list(KFold(n_splits=n_folds).split(X))

        # All (nu, alpha) combinations.
        param_grid = [(nu, alpha) for nu in nu_values for alpha in alpha_range]

        def evaluate_params(nu, alpha):
            fold_mses = [
                self.parallel_cv_wrapper(X, b_hat, Sigma_inv, alpha, nu, train_idx, val_idx).item()
                for train_idx, val_idx in splits
            ]
            return {'nu': nu, 'alpha': alpha, 'mse': np.mean(fold_mses)}

        # One joblib task per grid point.
        results = Parallel(n_jobs=n_jobs)(
            delayed(evaluate_params)(nu, alpha) for nu, alpha in tqdm(param_grid)
        )

        return pd.DataFrame(results)

    def fit(self, X_internal, Y_internal, X_external, Y_external):
        """Run the full procedure: bias estimation, CV tuning, final fit, selection.

        Uses the grids and settings given to the constructor. Sets
        ``cv_results_``, ``best_params_``, ``b_tilde`` and
        ``selected_indices_`` (positions into the external data), and
        returns ``self``.
        """
        X_internal, X_external = self._prepare_features(X_internal, X_external)

        # 1. Outcome models, bias estimates b_hat and Sigma_inv.
        self.fit_sigma_b(X_internal, Y_internal, X_external, Y_external)

        # 2. Cross-validate (nu, alpha) over the configured grids.
        cv_results = self.adaptive_lasso_parallel_cv(
            X_external, self.b_hat, self.Sigma_inv,
            nu_values=self.nu_candidates,
            alpha_range=self.lambda_N_candidates,
            n_folds=self.n_folds,
            n_jobs=self.n_jobs,
        )
        self.cv_results_ = cv_results

        # 3. Pick the setting with the smallest CV score.
        best_params = cv_results.loc[cv_results['mse'].idxmin()]
        self.best_params_ = {'nu': best_params['nu'], 'alpha': best_params['alpha']}

        # 4. Refit on all external subjects with the chosen parameters, using
        #    the same clipped weights that were cross-validated.
        final_weights = self._adaptive_weights(self.b_hat, self.best_params_["nu"])
        final_lasso = Lasso(alpha=self.best_params_["alpha"], max_iter=10000)
        final_lasso.fit(X_external, self.b_hat, sample_weight=final_weights)

        # 5. Fitted biases and selection of comparable subjects.
        self.b_tilde = final_lasso.predict(X_external)
        self.selected_indices_ = np.where(np.abs(self.b_tilde) < self.threshold)[0]
        return self

    # ------------------------------------------------------------------
    # Results
    # ------------------------------------------------------------------
    def return_coef(self):
        """Return ``(theta_0, theta_1)`` = ``[intercept, first slope]`` of mu0 / mu0,e.

        Only the first coefficient is included, so this is complete for
        single-covariate models only.
        """
        theta_0 = np.array([self.mu0_model.intercept_, self.mu0_model.coef_[0]])
        theta_1 = np.array([self.mu0e_model.intercept_, self.mu0e_model.coef_[0]])
        return theta_0, theta_1

    def plot_selection_results(self, true_biases=None, filename="bias.png"):
        """Scatter the sorted fitted biases, coloured by selection, and save to ``filename``.

        ``true_biases`` is currently unused (kept for API compatibility).
        The figure is left open so callers can further modify or show it.
        """
        self._require_fitted("b_tilde")
        plt.figure(figsize=(8, 6))
        sorted_idx = np.argsort(self.b_tilde)
        plt.scatter(range(len(self.b_tilde)), self.b_tilde[sorted_idx],
                    c=np.isin(sorted_idx, self.selected_indices_),
                    cmap='coolwarm', alpha=0.7)
        plt.axhline(self.threshold, color='k', linestyle='--')
        plt.axhline(-self.threshold, color='k', linestyle='--')
        plt.title("Estimated Biases with Selection")
        plt.xlabel("Subject Index (sorted)")
        plt.ylabel("Estimated Bias")

        plt.tight_layout()
        plt.savefig(filename)

    def get_selected_subjects(self, X_external, Y_external=None):
        """Return the rows of the external data selected by ``fit``.

        Returns ``X_external[sel]``, or ``(X_external[sel], Y_external[sel])``
        when ``Y_external`` is given. Raises AttributeError before ``fit``.
        """
        # Without this check, indexing with None would silently add an axis.
        self._require_fitted("selected_indices_")
        if Y_external is not None:
            return X_external[self.selected_indices_], Y_external[self.selected_indices_]
        return X_external[self.selected_indices_]

    def select_samples(self, X_external, Y_external, top_k=100):
        """Return the ``top_k`` external samples with the smallest ``|b_tilde|``.

        Rows are ordered from the largest to the smallest ``|b_tilde|`` among
        the chosen ones. ``top_k=0`` returns empty arrays.
        """
        self._require_fitted("b_tilde")
        # Equivalent to argsort[::-1][-top_k:] for top_k > 0, but correct for
        # top_k == 0 (where [-0:] would have selected everything).
        selected_indices = np.argsort(np.abs(self.b_tilde))[:top_k][::-1]
        return X_external[selected_indices], Y_external[selected_indices]
