from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.preprocessing import StandardScaler
from skopt import BayesSearchCV
from skopt.space import Real
from sklearn.base import BaseEstimator
from sklearn.neural_network import MLPRegressor
import numpy as np
from sklearn.model_selection import KFold
from sklearn.linear_model import LinearRegression, Lasso
from scipy.optimize import minimize
from tqdm import tqdm
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from sklearn.covariance import LedoitWolf
import pandas as pd
import torch
from utils import MLP1, MLP2, Exponential_regression

# Torch models in this module are placed on the GPU when CUDA is visible, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SelectiveBorrowingLasso:
    """Select comparable external subjects via a sample-weighted Lasso on estimated biases.

    Workflow (driven by ``fit``):

    1. ``fit_sigma_b`` fits one outcome model on the internal (RCT) data and one on
       the external data. It takes the absolute gap between their predictions on the
       external covariates as the raw per-subject bias ``b_hat``.
    2. ``adaptive_lasso_parallel_cv`` tunes a Lasso of ``b_hat`` on the external
       covariates, with sample weights ``1 / |b_hat|**nu``, over a ``(nu, alpha)`` grid.
    3. The refit Lasso yields the smoothed bias ``b_tilde``. External subjects with
       ``|b_tilde| < threshold`` are kept in ``selected_indices_``.
    """

    def __init__(self, dataset, batch_size=64, lr=0.005, num_epoch=10000, threshold=0.1,
                 lambda_N_candidates=None, nu_candidates=None, n_folds=5, n_jobs=-1):
        """
        Complete selective borrowing framework with CV tuning

        Parameters:
        - dataset: outcome-model family used by ``fit_sigma_b``: "linear", "NSW" or "exp"
        - batch_size, lr, num_epoch: MLP training settings (only used when dataset == "NSW")
        - threshold: cutoff on |b_tilde| for selecting comparable subjects
        - lambda_N_candidates: array of lambda_N values to try
          (default ``np.logspace(-3, 1, 20)``)
        - nu_candidates: list of nu values (default ``[1, 2]``)
        - n_folds: number of cross-validation folds used by ``fit``
        - n_jobs: joblib worker count used by ``fit`` (-1 = all cores)

        NOTE(review): ``lambda_N_candidates`` and ``nu_candidates`` are stored but
        ``fit`` currently searches its own fixed (nu, alpha) grid instead.
        """
        # Build fresh containers per instance; list/array defaults in the signature
        # would otherwise be one object shared by every instance.
        if lambda_N_candidates is None:
            lambda_N_candidates = np.logspace(-3, 1, 20)
        if nu_candidates is None:
            nu_candidates = [1, 2]

        self.lambda_N_candidates = lambda_N_candidates
        self.nu_candidates = nu_candidates
        self.n_folds = n_folds
        self.n_jobs = n_jobs
        self.threshold = threshold
        self.best_params_ = None
        self.cv_results_ = None
        self.b_hat = None
        self.selected_indices_ = None
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.lr = lr

    @staticmethod
    def _adaptive_weights(b, nu, lower=1e-6, upper=1e6):
        """Return Lasso sample weights ``1 / |b|**nu`` clipped to ``[lower, upper]``.

        Clipping keeps every weight finite: an exactly-zero bias would otherwise give
        ``1 / 0 = inf``, which Lasso's ``sample_weight`` rejects.
        """
        with np.errstate(divide="ignore"):  # the inf from a zero bias is clipped below
            weights = 1.0 / (np.abs(b) ** nu)
        return np.clip(weights, lower, upper)

    def fit_sigma_b(self, X_rct, y_rct, X_ext, y_ext):
        """Fit μ₀ and μ₀,ε models and derive the raw bias ``b_hat``.

        Sets ``model_rct_`` and ``model_ext_``. Also sets ``b_hat`` (1-D), the absolute
        gap between the external and RCT model predictions on ``X_ext``. Finally sets
        ``Sigma_inv``, the inverse of the Ledoit-Wolf covariance of ``b_hat`` (a 1x1
        matrix).

        Raises:
        - ValueError: if ``self.dataset`` is not "linear", "NSW" or "exp".

        Returns ``self``.
        """
        # sklearn-style models need a 2-D design matrix; no scaling is applied.
        if X_rct.ndim == 1:
            X_rct_2d = X_rct.reshape(-1, 1)
            X_ext_2d = X_ext.reshape(-1, 1)
        else:
            X_rct_2d = X_rct
            X_ext_2d = X_ext

        # Train the outcome models
        if self.dataset == "linear":
            self.model_rct_ = LinearRegression().fit(X_rct_2d, y_rct)
            self.model_ext_ = LinearRegression().fit(X_ext_2d, y_ext)

        elif self.dataset == "NSW":
            self.model_rct_ = MLP1(X_rct_2d.shape[1], 16, 1).to(device)
            # NOTE(review): the RCT model uses a fixed batch_size=32, not self.batch_size.
            self.model_rct_.fit(X_rct_2d, y_rct, batch_size=32, num_epoch=self.num_epoch, lr=self.lr)
            self.model_ext_ = MLP2(X_ext_2d.shape[1], 16, 1).to(device)
            self.model_ext_.fit(X_ext_2d, y_ext, batch_size=self.batch_size, num_epoch=self.num_epoch, lr=self.lr)

        elif self.dataset == "exp":
            self.model_rct_ = Exponential_regression().fit(X_rct_2d, y_rct)
            self.model_ext_ = Exponential_regression().fit(X_ext_2d, y_ext)

        else:
            # Fail loudly instead of crashing later or reusing models from an earlier fit.
            raise ValueError(
                f"Unknown dataset {self.dataset!r}; expected 'linear', 'NSW' or 'exp'"
            )

        gap = np.abs(self.model_ext_.predict(X_ext_2d) - self.model_rct_.predict(X_ext_2d))
        # Flatten: b_hat feeds Lasso targets and sample_weight, which must be 1-D.
        self.b_hat = gap.reshape(-1)
        cov_estimator = LedoitWolf().fit(self.b_hat.reshape(-1, 1))
        self.Sigma_inv = np.linalg.inv(cov_estimator.covariance_)

        return self

    def parallel_cv_wrapper(self, X, b_hat, Sigma_inv, alpha, nu, train_idx, val_idx):
        """Computation unit for one CV fold at one ``(alpha, nu)`` setting.

        Fits a sample-weighted Lasso of ``b_hat`` on ``X`` over ``train_idx`` and scores
        it on ``val_idx``.

        Returns a (1, 1) array: the mean squared validation residual scaled by
        ``Sigma_inv``. This assumes ``Sigma_inv`` is 1x1, as built by ``fit_sigma_b``.
        """
        X_train, X_val = X[train_idx], X[val_idx]
        b_train, b_val = b_hat[train_idx], b_hat[val_idx]
        weights = self._adaptive_weights(b_train, nu)

        lasso = Lasso(alpha=alpha, max_iter=10000)
        lasso.fit(X_train, b_train, sample_weight=weights)
        pred = lasso.predict(X_val)

        resid = (b_val - pred).reshape(-1, 1)
        # Evaluates as ((Sigma_inv * resid.T) @ resid) / n, i.e. Sigma_inv * mean(resid**2).
        return Sigma_inv * resid.T @ resid / len(b_val)

    def adaptive_lasso_parallel_cv(self, X, b_hat, Sigma_inv, nu_values=(1, 1.5, 2),
                                   alpha_range=None, n_folds=5, n_jobs=-1):
        """Main function for the parallelised (nu, alpha) grid search.

        Parameters:
        - nu_values: exponents for the ``1 / |b|**nu`` sample weights
        - alpha_range: Lasso penalties to try (default ``np.logspace(-4, 2, 20)``)
        - n_folds: number of (unshuffled) KFold splits
        - n_jobs: joblib worker count

        Returns a DataFrame with columns ``nu``, ``alpha`` and ``mse``. ``mse`` is
        averaged over folds.
        """
        if alpha_range is None:
            alpha_range = np.logspace(-4, 2, 20)

        # Unshuffled KFold is deterministic, so compute the splits once and reuse them.
        folds = list(KFold(n_splits=n_folds).split(X))
        param_grid = [(nu, alpha) for nu in nu_values for alpha in alpha_range]

        def evaluate_params(nu, alpha):
            fold_mses = [
                self.parallel_cv_wrapper(X, b_hat, Sigma_inv, alpha, nu, train_idx, val_idx).item()
                for train_idx, val_idx in folds
            ]
            return {'nu': nu, 'alpha': alpha, 'mse': np.mean(fold_mses)}

        # Parallel computation over the grid
        results = Parallel(n_jobs=n_jobs)(
            delayed(evaluate_params)(nu, alpha) for nu, alpha in tqdm(param_grid)
        )

        return pd.DataFrame(results)

    def fit(self, X_internal, Y_internal, X_external, Y_external):
        """Complete fitting procedure with optimal parameters.

        Sets ``b_hat``, ``Sigma_inv``, ``cv_results_``, ``best_params_``, ``b_tilde`` and
        ``selected_indices_``. Returns ``self``.
        """
        # Promote 1-D covariates to single-column matrices (no scaling is applied).
        if X_internal.ndim == 1:
            X_internal = X_internal.reshape(-1, 1)
            X_external = X_external.reshape(-1, 1)

        self.fit_sigma_b(X_internal, Y_internal, X_external, Y_external)

        cv_results = self.adaptive_lasso_parallel_cv(
            X_external, self.b_hat, self.Sigma_inv,
            nu_values=[1, 1.5, 2],
            alpha_range=np.logspace(-4, 2, 50),
            n_folds=self.n_folds,
            n_jobs=self.n_jobs,
        )
        self.cv_results_ = cv_results

        best_params = cv_results.loc[cv_results['mse'].idxmin()]
        self.best_params_ = {'nu': best_params['nu'], 'alpha': best_params['alpha']}

        # Same clipped weights as in CV, so the refit matches the model that was tuned
        # and a zero bias cannot produce an infinite sample weight.
        final_weights = self._adaptive_weights(self.b_hat, self.best_params_["nu"])
        final_lasso = Lasso(alpha=self.best_params_["alpha"], max_iter=10000)
        final_lasso.fit(X_external, self.b_hat, sample_weight=final_weights)

        self.b_tilde = final_lasso.predict(X_external)
        # Select comparable subjects
        self.selected_indices_ = np.where(np.abs(self.b_tilde) < self.threshold)[0]
        return self

    def return_coef(self):
        """Return ``(theta_0, theta_1)``, each ``[intercept, first coefficient]``.

        ``theta_0`` comes from the RCT model and ``theta_1`` from the external model.
        Only meaningful for models exposing ``intercept_`` and ``coef_``, such as the
        ``LinearRegression`` fits made when dataset == "linear". ``mu0_model`` /
        ``mu0e_model`` take precedence when a caller has assigned them.
        """
        rct_model = getattr(self, "mu0_model", None)
        if rct_model is None:
            rct_model = self.model_rct_
        ext_model = getattr(self, "mu0e_model", None)
        if ext_model is None:
            ext_model = self.model_ext_

        theta_0 = np.array([rct_model.intercept_, rct_model.coef_[0]])
        theta_1 = np.array([ext_model.intercept_, ext_model.coef_[0]])
        return theta_0, theta_1

    def plot_selection_results(self, generata_pattern, true_biases=None, model_fit=None):
        """Visualize the selected subjects.

        Scatter-plots ``b_tilde`` sorted ascending and coloured by selection, with
        ``±threshold`` guide lines. Saves the figure to
        ``results/figure/<dataset>_<generata_pattern>/bias_<model_fit>.png`` and creates
        the directory if needed.

        Parameters:
        - generata_pattern: label used in the output directory name
        - true_biases: currently unused; kept for call-site compatibility
        - model_fit: label used in the file name; falls back to ``self.model_fit``

        Raises:
        - AttributeError: if neither ``model_fit`` nor ``self.model_fit`` is set.
        """
        from pathlib import Path

        if model_fit is None:
            model_fit = getattr(self, "model_fit", None)
        if model_fit is None:
            raise AttributeError(
                "plot_selection_results needs a model label: pass model_fit=... "
                "or set the 'model_fit' attribute"
            )

        plt.figure(figsize=(8, 6))
        sorted_idx = np.argsort(self.b_tilde)
        plt.scatter(range(len(self.b_tilde)), self.b_tilde[sorted_idx],
                    c=np.isin(sorted_idx, self.selected_indices_),
                    cmap='coolwarm', alpha=0.7)
        plt.axhline(self.threshold, color='k', linestyle='--')
        plt.axhline(-self.threshold, color='k', linestyle='--')
        plt.title("Estimated Biases with Selection")
        plt.xlabel("Subject Index (sorted)")
        plt.ylabel("Estimated Bias")

        plt.tight_layout()
        out_dir = Path("results") / "figure" / (self.dataset + "_" + generata_pattern)
        out_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(out_dir / ("bias_" + model_fit + ".png"))

    def get_selected_subjects(self, X_external, Y_external=None):
        """Return selected comparable subjects (and their outcomes when ``Y_external`` is given)."""
        if Y_external is not None:
            return X_external[self.selected_indices_], Y_external[self.selected_indices_]
        return X_external[self.selected_indices_]

    def select_samples(self, top_k=100):
        """Return indices of the ``top_k`` subjects with the smallest ``|b_tilde|``.

        Indices are ordered from the largest to the smallest ``|b_tilde|`` among those
        selected.
        """
        order = np.argsort(np.abs(self.b_tilde))
        # Slice from the front, then reverse. Taking ``[-top_k:]`` of the reversed array
        # would select everything for top_k=0, since -0 == 0.
        return order[:top_k][::-1]

    def return_b_tilde(self):
        """Return ``b_tilde`` sorted ascending (not in per-subject order)."""
        sorted_idx = np.argsort(self.b_tilde)
        return self.b_tilde[sorted_idx]
