from .constrained_model import ConstrainedModel
import cvxpy as cp
import numpy as np
import pandas as pd

class ConstrainedLasso(ConstrainedModel):
    """L1-penalised least-squares regression solved as a cvxpy problem.

    The minimised objective is
    ``sum_squares(X @ beta - y) / (2 * n) + alpha * norm1(beta) + penalty_soft``
    where ``penalty_soft`` and the hard constraints come from
    ``self._build_constraints`` (not defined in this class).
    """

    def __init__(self, alpha: float = 1.0, zero_threshold: float = 1e-5, **kwargs):
        """
        Constrained Lasso regression model with optional group constraints and nonnegativity.

        Args:
            alpha: L1 regularization strength.
            zero_threshold: Coefficients whose absolute value is strictly below
                this value are set to exactly 0.0 in ``beta_`` after fitting.
            **kwargs: Additional arguments passed to ConstrainedModel, such as:
                - constraint_groups: dict of group labels to row indices
                - group_total: float (default=1.0)
                - nonneg: bool
                - sum_to_one: bool
                - verbose: bool
        """
        super().__init__(**kwargs)
        self.alpha = alpha
        self.zero_threshold = zero_threshold

    def fit(self, A: pd.DataFrame, y: np.ndarray) -> "ConstrainedLasso":
        """Solve the constrained Lasso problem and store the results on ``self``.

        Args:
            A: Design matrix; its columns become ``feature_names_``. It is
                converted to a numpy array with ``self.safe_df_to_numpy``.
            y: Target values, one per row of the converted design matrix.
                Column vectors are flattened before entering the loss.

        Returns:
            ``self``, with ``beta_raw_``, ``beta_``, ``status_``,
            ``objective_value_``, ``feature_names_``, ``X_``, ``A_``, ``y_``,
            ``n_obs_`` and ``n_features_`` set.

        Raises:
            ValueError: If the number of target values does not match the
                number of rows of the converted design matrix.
            RuntimeError: If the solver returns without a solution (for
                example an infeasible or unbounded problem). ``status_`` and
                ``objective_value_`` are set before raising.
        """
        X = self.safe_df_to_numpy(A)
        n, p = X.shape

        # Use a flat 1-D target in the loss so that an (n, 1) input cannot
        # broadcast against the 1-D expression X @ beta.
        y_vec = np.asarray(y).reshape(-1)
        if y_vec.shape[0] != n:
            raise ValueError(
                f"y has {y_vec.shape[0]} values but the design matrix has {n} rows"
            )

        beta = cp.Variable(p)

        loss = cp.sum_squares(X @ beta - y_vec) / (2 * n)
        penalty_l1 = self.alpha * cp.norm1(beta)

        constraints, penalty_soft = self._build_constraints(X, A, beta)

        objective = cp.Minimize(loss + penalty_l1 + penalty_soft)
        problem = cp.Problem(objective, constraints)
        problem.solve(verbose=self.verbose)

        # Record the solver outcome first so it can be inspected even when no
        # solution was produced.
        self.beta_raw_ = beta.value  # original unmodified coefficients
        self.status_ = problem.status
        self.objective_value_ = problem.value

        if self.beta_raw_ is None:
            # Without this guard the failure would surface as an opaque
            # AttributeError on ``None.copy()``.
            raise RuntimeError(
                f"ConstrainedLasso solve produced no solution (status: {problem.status})"
            )

        # Thresholded copy: tiny solver noise is snapped to exact zeros.
        self.beta_ = self.beta_raw_.copy()
        self.beta_[np.abs(self.beta_) < self.zero_threshold] = 0.0
        self.feature_names_ = A.columns.tolist()

        self.X_ = X  # safe_df_to_numpy(A)
        self.A_ = A  # original df with index info
        self.y_ = y  # target exactly as passed by the caller
        self.n_obs_, self.n_features_ = X.shape

        self.log_fit_summary()
        return self

    def count_nonzero_weights(self) -> int:
        """Return the number of entries of ``beta_`` that are not exactly zero."""
        return int(np.count_nonzero(self.beta_))

    def get_support(self) -> np.ndarray:
        """Return a boolean mask that is True where ``beta_`` is non-zero."""
        return self.beta_ != 0