import numpy as np
import pandas as pd
from sklearn.preprocessing import SplineTransformer, StandardScaler, PolynomialFeatures
from sklearn.model_selection import KFold
from sklearn.linear_model import LassoCV, Lasso
from group_lasso import GroupLasso
import matplotlib.pyplot as plt
import xgboost as xgb
import shap
import warnings
import torch
from itertools import combinations
warnings.filterwarnings("ignore")

class AdditiveInteractionSelector:
    """
    Fit additive models with candidate main effects and interactions
    using spline basis expansions + group sparsity (group lasso).

    Every input column contributes one group of spline basis columns (a main
    effect); every requested pair of columns contributes one group holding the
    tensor product of two spline bases (an interaction).  After fitting, the
    L2 norm of each group's coefficients is used as that group's importance.

    Parameters
    ----------
    n_splines : int
        Passed as ``n_knots`` to ``SplineTransformer`` for main effects.
    spline_degree : int
        Polynomial degree of the spline basis.
    include_intercept : bool
        Passed as ``include_bias`` to ``SplineTransformer``.
    interaction_splines : int
        Passed as ``n_knots`` to ``SplineTransformer`` for each margin of an
        interaction's tensor-product basis.
    random_state : int
        Seed for the K-fold shuffling used during penalty selection.

    Notes
    -----
    The spline transformers are fitted on the training data and not kept, so
    this class selects groups but cannot transform or predict on new data.
    """

    def __init__(self, n_splines=10, spline_degree=3, include_intercept=False,
                 interaction_splines=10, random_state=0):
        self.n_splines = n_splines
        self.spline_degree = spline_degree
        self.include_intercept = include_intercept
        self.interaction_splines = interaction_splines
        self.random_state = random_state

        # Internal storage (populated by _build_design / fit)
        self.groups = []           # per group: list of design-matrix column indices
        self.group_names = []      # per group: (col,) or (col_a, col_b)
        self.scaler = None         # StandardScaler fitted on the design matrix
        self.model = None          # fitted GroupLasso or LassoCV
        self.group_norms_ = None   # per group: L2 norm of its coefficients
        self.design_matrix_ = None

    # -----------------------------
    # Basis construction utilities
    # -----------------------------
    def _build_univariate_basis(self, x, n_knots=None):
        """
        Build spline basis for one variable.

        Parameters
        ----------
        x : array-like of shape (n_samples,)
            Values of a single variable.
        n_knots : int, optional
            Number of knots; defaults to ``self.n_splines``.

        Returns
        -------
        np.ndarray
            Basis matrix with one row per sample.
        """
        x = np.asarray(x).reshape(-1, 1)
        sp = SplineTransformer(
            degree=self.spline_degree,
            n_knots=self.n_splines if n_knots is None else n_knots,
            include_bias=self.include_intercept
        )
        return sp.fit_transform(x)

    def _build_bivariate_basis(self, x1, x2):
        """
        Build tensor product spline basis for interaction.

        Each margin uses ``self.interaction_splines`` knots; the result has
        one column for every pair (basis column of x1, basis column of x2).
        """
        B1 = self._build_univariate_basis(x1, n_knots=self.interaction_splines)
        B2 = self._build_univariate_basis(x2, n_knots=self.interaction_splines)
        # Row-wise outer product, flattened to (n_samples, p1 * p2)
        return np.einsum("ij,ik->ijk", B1, B2).reshape(B1.shape[0], -1)

    def _build_design(self, X_df, interactions=None):
        """
        Construct design matrix with groups for univariates and interactions.

        Parameters
        ----------
        X_df : pd.DataFrame
            One column per candidate feature.
        interactions : iterable of (str, str), optional
            Pairs of column names of ``X_df`` to include as interactions.

        Returns
        -------
        np.ndarray
            Horizontally stacked basis blocks; also stored in
            ``self.design_matrix_``.  ``self.groups`` and ``self.group_names``
            are rebuilt to describe the blocks in order.
        """
        blocks, self.groups, self.group_names = [], [], []
        col_idx = 0

        def add_block(basis, name):
            # Record which design-matrix columns belong to this group.
            nonlocal col_idx
            blocks.append(basis)
            self.groups.append(list(range(col_idx, col_idx + basis.shape[1])))
            self.group_names.append(name)
            col_idx += basis.shape[1]

        # Main effects
        for col in X_df.columns:
            add_block(self._build_univariate_basis(X_df[col].values), (col,))

        # Interactions
        if interactions:
            for a, b in interactions:
                add_block(
                    self._build_bivariate_basis(X_df[a].values, X_df[b].values),
                    (a, b),
                )

        self.design_matrix_ = np.hstack(blocks)
        return self.design_matrix_

    # -----------------------------
    # Fitting
    # -----------------------------
    @staticmethod
    def _make_group_lasso(col_to_group, lam):
        """Create an unfitted GroupLasso with group penalty ``lam`` and no L1 term."""
        return GroupLasso(
            groups=col_to_group,
            group_reg=lam, l1_reg=0.0,
            scale_reg="group_size",
            supress_warning=True,  # (sic) spelling used by the group_lasso API
            n_iter=2000, tol=1e-3
        )

    def fit(self, X_df, y, interactions=None, cv=5, HAS_GROUP_LASSO=True):
        """
        Fit model with group lasso (preferred) or fallback to plain Lasso.

        Parameters
        ----------
        X_df : pd.DataFrame
            Feature table, one column per candidate feature.
        y : array-like of shape (n_samples,)
            Response.  Converted to a NumPy array so the fold indices are
            applied positionally (a pandas Series with a non-default index
            would otherwise be indexed by label).
        interactions : iterable of (str, str), optional
            Column-name pairs to include as interaction groups.
        cv : int
            Number of cross-validation folds.
        HAS_GROUP_LASSO : bool
            If True use ``GroupLasso`` with a cross-validated penalty,
            otherwise use ``LassoCV``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``y`` does not have one entry per row of ``X_df``, or if no
            penalty value yields a finite cross-validation score.
        """
        X = self._build_design(X_df, interactions)

        y = np.asarray(y)
        if y.ndim == 0 or y.shape[0] != X.shape[0]:
            raise ValueError(
                f"y must have one entry per row of X_df "
                f"(expected {X.shape[0]}, got shape {y.shape})"
            )

        self.scaler = StandardScaler()
        Xs = self.scaler.fit_transform(X)

        if HAS_GROUP_LASSO:
            # Build group vector: entry j is the group id of design column j
            col_to_group = np.zeros(X.shape[1], dtype=int)
            for gid, idxs in enumerate(self.groups):
                col_to_group[idxs] = gid

            # Cross-validate group lasso penalty on a fixed set of folds
            lambdas = np.logspace(-3, 1, 10)
            kf = KFold(n_splits=cv, shuffle=True, random_state=self.random_state)
            folds = list(kf.split(Xs))

            best_score, best_lam = -np.inf, None
            for lam in lambdas:
                scores = []
                for tr, va in folds:
                    gl = self._make_group_lasso(col_to_group, lam)
                    gl.fit(Xs[tr], y[tr])
                    scores.append(gl.score(Xs[va], y[va]))
                mean_score = np.mean(scores)
                if mean_score > best_score:
                    best_score, best_lam = mean_score, lam

            if best_lam is None:
                # Every mean score was NaN, so nothing compared above -inf.
                raise ValueError(
                    "Cross-validation produced no finite score for any penalty"
                )

            # Refit once on all data with the selected penalty
            self.model = self._make_group_lasso(col_to_group, best_lam)
            self.model.fit(Xs, y)
            coefs = self.model.coef_.ravel()

        else:
            # Fallback to plain Lasso
            lasso = LassoCV(cv=cv).fit(Xs, y)
            self.model = lasso
            coefs = lasso.coef_

        # Compute group norms
        self.group_norms_ = [
            np.linalg.norm(coefs[idxs], ord=2) for idxs in self.groups
        ]
        return self

    # -----------------------------
    # Reporting
    # -----------------------------
    def get_group_importance(self):
        """Return DataFrame of group names and their norms, largest norm first."""
        return pd.DataFrame({
            "group": self.group_names,
            "norm": self.group_norms_
        }).sort_values("norm", ascending=False).reset_index(drop=True)

    @staticmethod
    def _feature_index(name):
        """Map a feature name such as "x3" to its zero-based index (2)."""
        try:
            return int(name[1:]) - 1
        except ValueError as err:
            raise ValueError(
                f"Cannot derive a feature index from column name {name!r}; "
                f"expected names of the form 'x<N>' (e.g. 'x1')"
            ) from err

    def get_important_groups(self, threshold=0.1):
        """
        Return groups with norms above threshold.

        Parameters
        ----------
        threshold : float
            A group is kept when its coefficient norm is strictly greater.

        Returns
        -------
        list[list[int]]
            For each kept group, the zero-based feature indices of its
            members, in the order the groups were built.

        Raises
        ------
        ValueError
            If a group member's name is not one leading character followed by
            an integer (e.g. "x1").
        """
        return [
            [self._feature_index(name) for name in names]  # convert "x1" → 0
            for names, norm in zip(self.group_names, self.group_norms_)
            if norm > threshold
        ]

    def summary(self):
        """Print ranked group importance."""
        df = self.get_group_importance()
        print("Group importance (higher = more important):")
        print(df)


def extract_active_features(X: "torch.Tensor | np.ndarray", active_idx: list[int]) -> pd.DataFrame:
    """
    Extract active features from X based on active indices.

    Parameters
    ----------
    X : torch.Tensor or np.ndarray
        Data matrix of shape (n_samples, n_features).  Tensors are detached
        and moved to the CPU before conversion, so tensors that require grad
        or live on an accelerator are accepted.
    active_idx : list[int]
        Zero-based indices of active features.

    Returns
    -------
    pd.DataFrame
        DataFrame with one column per active feature, named ``x{idx+1}``
        (index 0 becomes "x1"), in the order given by ``active_idx``.
    """
    if isinstance(X, torch.Tensor):
        # pandas converts columns through NumPy, which rejects tensors that
        # require grad or are not on the CPU.
        X = X.detach().cpu().numpy()
    else:
        X = np.asarray(X)

    return pd.DataFrame({f"x{idx + 1}": X[:, idx] for idx in active_idx})



