from .base import Model
from typing import Optional
import pandas as pd
import cvxpy as cp
import numpy as np

from .base import Model
from typing import Optional
import pandas as pd
import cvxpy as cp
import numpy as np

class ConstrainedModel(Model):
    """Model whose coefficient vector ``beta`` is subject to optional constraints.

    Constraint groups map a group name to a collection of row labels of the
    design matrix ``A``. For each group ``g`` the quantity
    ``sum(A[g] @ beta)`` is driven towards ``group_total``, either as a hard
    equality constraint or as a squared soft penalty (see
    ``_build_constraints``). The fitting code itself is not part of this class.

    Args:
        constraint_groups: Mapping ``group name -> iterable of row labels``.
            Labels that are absent from ``A.index`` are ignored. ``None`` (or an
            empty mapping) disables group constraints.
        group_total: Target value for every group sum.
        nonneg: If True, add the hard constraint ``beta >= 0``.
        sum_to_one: If True, add the hard constraint ``sum(beta) == 1``.
        verbose: Forwarded to ``Model.__init__``.
        soft_group_sum: If True, group sums are enforced through a squared
            penalty term; if False, through hard equality constraints.
        group_penalty_strength: Multiplier applied to the summed squared
            group-sum deviations when ``soft_group_sum`` is True.
    """

    def __init__(
        self,
        constraint_groups: Optional[dict] = None,
        group_total: float = 1.0,
        nonneg: bool = False,
        sum_to_one: bool = False,
        verbose: bool = False,
        soft_group_sum: bool = True,
        group_penalty_strength: float = 100.0,  # Large enough to approximately enforce
    ):
        super().__init__(verbose=verbose)
        self.constraint_groups = constraint_groups
        self.group_total = group_total
        self.nonneg = nonneg
        self.sum_to_one = sum_to_one
        self.soft_group_sum = soft_group_sum
        self.group_penalty_strength = group_penalty_strength

    def _convert_index_groups_to_numeric(
        self, A: pd.DataFrame, drop_empty: bool = True
    ) -> dict[str, list[int]]:
        """Translate label-based constraint groups into row positions of ``A``.

        ``constraint_groups`` may describe the whole survey while ``A`` only
        holds the relevant questions, so labels missing from ``A.index`` are
        skipped.

        Args:
            A: Frame whose index holds the row labels.
            drop_empty: If True, groups with no members in ``A.index`` are
                silently omitted; if False, such a group raises.

        Returns:
            Mapping ``group name -> list of integer row positions`` in ``A``.
            Only groups with at least one member are included.

        Raises:
            ValueError: If ``drop_empty`` is False and a group has no members
                in ``A.index``.
        """
        # NOTE(review): with duplicate labels in A.index only the last position
        # of each label is kept — confirm A.index is expected to be unique.
        row_index_map = {label: i for i, label in enumerate(A.index)}
        result = {}
        for group_name, members in self.constraint_groups.items():
            idx_list = [row_index_map[label] for label in members if label in row_index_map]
            if idx_list:
                result[group_name] = idx_list
            elif not drop_empty:
                raise ValueError(f"Group '{group_name}' has no members in A.index")
        return result

    def _build_constraints(
        self, X: np.ndarray, A: pd.DataFrame, beta: cp.Variable
    ) -> tuple[list[cp.Constraint], cp.Expression]:
        """Build the hard constraints and the soft group penalty for ``beta``.

        Args:
            X: Numeric matrix row-aligned with ``A.index`` (presumably
                ``safe_df_to_numpy(A)`` — confirm at the caller).
            A: Frame used only for its index, to locate each group's rows.
            beta: Coefficient variable being optimized.

        Returns:
            constraints: Hard constraints — group-sum equalities (only when
                ``soft_group_sum`` is False), then nonnegativity and
                sum-to-one if enabled.
            penalty_term: ``group_penalty_strength * sum_g (sum(X[g] @ beta) - group_total)**2``
                when ``soft_group_sum`` is True and at least one group has
                members; otherwise the plain number 0. Meant to be added to
                the loss.
        """
        constraints = []
        # Remains the scalar 0 when no soft penalty is accumulated; adding 0 to
        # a cvxpy objective is harmless.
        penalty_term = 0

        if self.constraint_groups:
            G_idx = self._convert_index_groups_to_numeric(A)
            for g in G_idx.values():
                group_sum_expr = cp.sum(X[g] @ beta)
                if self.soft_group_sum:
                    penalty_term += cp.square(group_sum_expr - self.group_total)
                else:
                    constraints.append(group_sum_expr == self.group_total)

        if self.nonneg:
            constraints.append(beta >= 0)
        if self.sum_to_one:
            constraints.append(cp.sum(beta) == 1)

        return constraints, penalty_term * self.group_penalty_strength

    def compute_group_constraint_violations(
        self, A: Optional[pd.DataFrame] = None
    ) -> dict[str, float]:
        """Compute ``|sum(A[g] @ beta) - group_total|`` for every group.

        Args:
            A: Frame to evaluate on. Defaults to ``self.A_``, which is read
                only after the model is confirmed to be fitted.

        Returns:
            Dict of group name -> absolute constraint violation. Groups with no
            members in ``A.index`` are skipped.

        Raises:
            ValueError: If ``constraint_groups`` is None or the model has not
                been fitted (no ``beta_``).
        """
        if self.constraint_groups is None:
            raise ValueError("constraint_groups must be defined.")
        if not hasattr(self, "beta_"):
            raise ValueError("Model must be fit before computing constraint violations.")
        # Resolve the default only after the fit check, so an unfitted model
        # reports the ValueError above rather than a missing-attribute error.
        if A is None:
            A = self.A_

        X = self.safe_df_to_numpy(A)
        beta = self.beta_

        G_idx = self._convert_index_groups_to_numeric(A)  # drop_empty=True by default
        return {
            group_name: float(abs(np.sum(X[row_indices] @ beta) - self.group_total))
            for group_name, row_indices in G_idx.items()
        }

    @property
    def mean_group_constraint_violation(self) -> float:
        """Mean of ``compute_group_constraint_violations()`` over all groups.

        Returns NaN when no group has members in ``A_``. This avoids numpy's
        "Mean of empty slice" warning.
        """
        v = self.compute_group_constraint_violations()
        if not v:
            return float("nan")
        return float(np.mean(list(v.values())))

class ConstrainedModelLegacy(Model):
    """Constrained model in which every group-sum constraint is a hard equality.

    Unlike ``ConstrainedModel`` there is no soft-penalty option: each
    constraint group ``g`` yields the equality ``sum(X[g] @ beta) == group_total``.

    Args:
        constraint_groups: Mapping ``group name -> iterable of row labels``;
            labels absent from ``A.index`` are ignored. ``None`` (or an empty
            mapping) disables group constraints.
        group_total: Required value of every group sum.
        nonneg: If True, add ``beta >= 0``.
        sum_to_one: If True, add ``sum(beta) == 1``.
        verbose: Forwarded to ``Model.__init__``.
    """

    def __init__(self,
                 constraint_groups: Optional[dict] = None,
                 group_total: float = 1.0,
                 nonneg: bool = False,
                 sum_to_one: bool = False,
                 verbose: bool = False):
        super().__init__(verbose=verbose)
        self.constraint_groups = constraint_groups
        self.group_total = group_total
        self.nonneg = nonneg
        self.sum_to_one = sum_to_one

    def _convert_index_groups_to_numeric(self, A: pd.DataFrame, drop_empty=True) -> dict[str, list[int]]:
        """Map each label-based constraint group onto integer row positions of ``A``.

        The groups may cover the whole survey, so labels not present in
        ``A.index`` are dropped. A group left with no positions is omitted when
        ``drop_empty`` is True and raises ``ValueError`` otherwise.
        """
        position_of = dict(zip(A.index, range(len(A.index))))
        numeric_groups = {}

        for name, labels in self.constraint_groups.items():
            positions = [position_of[lbl] for lbl in labels if lbl in position_of]
            if positions:
                numeric_groups[name] = positions
            elif not drop_empty:
                raise ValueError(f"Group '{name}' has no members in A.index")

        return numeric_groups

    def _build_constraints(self, X: np.ndarray, A: pd.DataFrame, beta: cp.Variable) -> list[cp.Constraint]:
        """Return the hard constraints on ``beta``.

        Order: one group-sum equality per non-empty group (in group order),
        then ``beta >= 0`` if ``nonneg``, then ``sum(beta) == 1`` if
        ``sum_to_one``. ``A`` is only used for its index to locate group rows.
        """
        constraints = []

        if self.constraint_groups:
            numeric_groups = self._convert_index_groups_to_numeric(A)
            constraints.extend(
                cp.sum(X[rows] @ beta) == self.group_total
                for rows in numeric_groups.values()
            )

        if self.nonneg:
            constraints.append(beta >= 0)
        if self.sum_to_one:
            constraints.append(cp.sum(beta) == 1)

        return constraints