
from sklearn.model_selection import train_test_split, KFold
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from lime.lime_tabular import LimeTabularExplainer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import pairwise_distances
import numpy as np
import cvxpy as cp
from sklearn.metrics import r2_score, mean_squared_error
from new_select_feature import nonnegative_garrotte_path

from new_select_feature import find_steps_with_exact_k,apply_centered_weights_train



def build_mix_matrix(X, continuous_idx, categorical_idx=None, knots=None, n0=None):
    """Assemble a design matrix from piecewise-linear bases and raw categorical columns.

    For every continuous column, ``n0`` ramp basis functions are created, one per
    knot interval ``[knots[i][j], knots[i][j + 1]]``.  Each ramp is 0 at or below
    the left knot, 1 at or above the right knot and linear in between.
    Categorical columns (or column groups) are appended unchanged after all
    continuous bases.

    Args:
        X: 2-D array indexed as ``X[:, idx]``.
        continuous_idx: Iterable of column indices treated as continuous.
            ``None`` is treated as "no continuous columns".
        categorical_idx: Optional iterable whose items are either a single
            column index or a sequence of column indices forming one group.
            ``None`` or an empty iterable adds no categorical columns.
        knots: Sequence aligned with ``continuous_idx``; ``knots[i]`` must hold
            at least ``n0 + 1`` knot positions for the i-th continuous column.
            Required whenever there is at least one continuous column.
        n0: Number of intervals (basis functions) per continuous column.  When
            ``None`` it is inferred per column as ``len(knots[i]) - 1``.

    Returns:
        tuple: ``(A, cat_var_sizes)`` where ``A`` is the horizontally stacked
        design matrix with ``X.shape[0]`` rows (zero columns if nothing was
        selected) and ``cat_var_sizes`` lists the number of columns contributed
        by each categorical item.
    """
    n = X.shape[0]
    cont_cols = [] if continuous_idx is None else continuous_idx
    # Explicit None test: truth-testing a NumPy index array raises ValueError.
    cat_groups = [] if categorical_idx is None else categorical_idx

    A_parts = []
    cat_var_sizes = []

    for i, idx in enumerate(cont_cols):
        xi = X[:, idx]
        k = knots[i]
        n_seg = len(k) - 1 if n0 is None else n0
        B = np.zeros((n, n_seg))
        for j in range(n_seg):
            left, right = k[j], k[j + 1]
            width = right - left
            if width == 0:
                # Degenerate interval (duplicated knot): dividing would yield
                # NaN at xi == left.  Use the limiting step function instead;
                # it is 0 at the left knot, like a regular ramp.
                B[:, j] = (xi > left).astype(float)
            else:
                B[:, j] = np.clip((xi - left) / width, 0, 1)
        A_parts.append(B)

    for idx in cat_groups:
        X_cat = X[:, idx]
        if X_cat.ndim == 1:
            # A single column index yields a 1-D slice; make it a column.
            X_cat = X_cat.reshape(-1, 1)
        A_parts.append(X_cat)
        cat_var_sizes.append(X_cat.shape[1])

    if not A_parts:
        # np.hstack refuses an empty list; return an explicit (n, 0) matrix.
        return np.empty((n, 0)), cat_var_sizes

    return np.hstack(A_parts), cat_var_sizes


def fit_piecewise_local_model_li(p_g, b_global, X_train, y_train, weights_train, continuous_idx, categorical_idx=None,
                                 X_test=None, y_test=None, weights_test=None,
                                 knots=None, n0=5,
                                 lambda1=0.0, lambda_sparse=0.0, kk=None):
    """Fit a weighted piecewise-linear surrogate and run the garrotte path on it.

    Steps performed:

    1. Build the design matrix with ``build_mix_matrix`` and centre its columns
       and ``y_train`` by their ``weights_train``-weighted means.
    2. Solve a weighted least-squares problem with cvxpy (SCS solver).  When
       ``lambda1`` is non-zero, a squared penalty on consecutive differences of
       the cumulative within-feature coefficients is added, scaled by
       ``lambda1 / n_cont``.
    3. Collapse each feature's fitted contribution into a single column of
       ``M_train``: one column per continuous feature (its ``n0`` basis columns
       times their coefficients) and one column per categorical item (all of
       the item's columns times their own coefficients).
    4. Pass ``M_train`` with an all-ones ``beta_init`` to
       ``nonnegative_garrotte_path``.

    Args:
        p_g: Forwarded unchanged as the first positional argument of
            ``nonnegative_garrotte_path``.
        b_global: Accepted for interface compatibility; not used here.
        X_train: 2-D training array.
        y_train: Training targets, one per row of ``X_train``.
        weights_train: Per-sample weights used for centring, the loss, the
            R^2 score and the garrotte path.
        continuous_idx: Column indices of continuous features (``None`` means
            none).
        categorical_idx: Categorical items; each is a single column index or a
            sequence of column indices forming one group (``None`` means none).
        X_test, y_test, weights_test: Accepted for interface compatibility;
            not used here.
        knots: Knot positions per continuous feature, forwarded to
            ``build_mix_matrix``.
        n0: Number of piecewise segments per continuous feature.
        lambda1: Weight of the smoothness penalty.
        lambda_sparse, kk: Accepted for interface compatibility; not used here.

    Returns:
        dict with keys ``"u_value"`` (cumulative piecewise coefficients),
        ``"knots"``, ``"piecewise_parameters"`` (raw piecewise coefficients),
        ``"col_stds"`` (column standard deviations of ``M_train``, zeros
        replaced by 1e-8), ``"w"`` (``[coefficients, column weighted means,
        target weighted mean]``), ``"beta_norm"`` (L1 norm of each feature's
        coefficients), ``"actives"`` and ``"beta_hat_path"`` (as returned by
        ``nonnegative_garrotte_path``) and ``"r2_train1"`` (weighted R^2 of the
        first-stage fit on the centred target).

    Raises:
        RuntimeError: If the solver finishes without producing a solution.
    """
    # Normalise the index arguments once; list() also makes NumPy index arrays
    # safe for truth tests inside build_mix_matrix.
    cont_cols = [] if continuous_idx is None else list(continuous_idx)
    cat_groups = [] if categorical_idx is None else list(categorical_idx)

    n = X_train.shape[0]
    n_cont = len(cont_cols)
    n_cat = len(cat_groups)
    n_pw = n_cont * n0  # total number of piecewise basis columns

    A, _ = build_mix_matrix(X_train, cont_cols, categorical_idx=cat_groups, knots=knots, n0=n0)

    # Block-diagonal cumulative-sum operator: within each continuous feature,
    # u = D @ d_u turns the per-segment coefficients into running totals.
    D = np.zeros((n_pw, n_pw))
    for i in range(n_cont):
        D[i * n0:(i + 1) * n0, i * n0:(i + 1) * n0] = np.tril(np.ones((n0, n0)))

    # First-difference operator over all piecewise columns; rows that would
    # straddle two different features are removed afterwards.
    C = np.zeros((max(n_pw - 1, 0), n_pw))
    diff_rows = np.arange(C.shape[0])
    C[diff_rows, diff_rows] = -1
    C[diff_rows, diff_rows + 1] = 1
    cross_feature_rows = np.arange(n0 - 1, n_pw - 1, n0)
    C1 = np.delete(C, cross_feature_rows, axis=0)

    # Centre with the weighted means: sum(w * x) / sum(w).
    A_wmean = np.average(A, axis=0, weights=weights_train)
    y_wmean = np.average(y_train, weights=weights_train)
    A = A - A_wmean
    y_train_c = y_train - y_wmean

    n_basis = A.shape[1]
    w = cp.Variable(n_basis)

    if C1.size == 0:
        # No within-feature differences to penalise (n0 == 1 or no continuous
        # features); this also avoids dividing by n_cont == 0.
        regularization_term1 = 0
    else:
        u = D @ w[:n_pw]
        regularization_term1 = (lambda1 / n_cont) * cp.sum_squares(C1 @ u)

    objective1 = cp.Minimize(
        cp.sum(cp.multiply(weights_train, cp.square(A @ w - y_train_c))) / n
        + regularization_term1
    )
    problem = cp.Problem(objective1)
    problem.solve(solver=cp.SCS)

    w_value = w.value
    if w_value is None:
        raise RuntimeError(
            f"Piecewise model fit did not return a solution (solver status: {problem.status})"
        )

    d_u_value = w_value[:n_pw]

    # One column per feature: the feature's total fitted contribution.
    M_train = np.zeros((n, n_cont + n_cat))
    beta_norm = []

    for j in range(n_cont):
        block = slice(j * n0, (j + 1) * n0)
        beta_j = d_u_value[block]
        M_train[:, j] = A[:, block] @ beta_j
        beta_norm.append(np.linalg.norm(beta_j, ord=1))

    cur = n_pw  # first categorical column inside A / w
    for col, group in enumerate(cat_groups, start=n_cont):
        # A group may be a single column index or span several columns,
        # e.g. [5, 6, 7].
        width = 1 if np.isscalar(group) else len(group)
        coef = w_value[cur:cur + width]
        # Each column is weighted by its own coefficient and the group is
        # summed into one contribution column, mirroring the continuous case
        # and keeping M_train aligned with beta_norm.
        M_train[:, col] = A[:, cur:cur + width] @ coef
        beta_norm.append(np.linalg.norm(coef, ord=1))
        cur += width

    r2_train1 = r2_score(y_train_c, A @ w_value, sample_weight=weights_train)

    col_stds = np.std(M_train, axis=0)
    col_stds[col_stds == 0] = 1e-8  # avoid zero scale for constant columns

    beta_init = np.ones(len(beta_norm))
    _, _, actives, beta_hat_path, _, _ = nonnegative_garrotte_path(
        p_g, M_train, y=y_train_c, beta_init=beta_init, sample_weight=weights_train,
        max_iter=5000, tol=1e-12, verbose=False,
    )

    results = {
        "u_value": D @ d_u_value,
        "knots": knots,
        "piecewise_parameters": d_u_value,
        "col_stds": col_stds,
        "w": [w_value, A_wmean, y_wmean],
        "beta_norm": beta_norm,
        "actives": actives,
        "beta_hat_path": beta_hat_path,
        "r2_train1": r2_train1,
    }
    return results


