from sklearn.model_selection import GroupKFold
from .cross_validation import ModelCV
import numpy as np
import pandas as pd
from typing import Union
import logging

# Module-level logger; this module always reports at INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# hasHandlers() also consults ancestor loggers, so a console handler is only
# attached when nothing in the hierarchy would otherwise emit our records
# (this also avoids stacking duplicate handlers on re-import).
if not logger.hasHandlers():
    formatter = logging.Formatter("[%(levelname)s] %(message)s")
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)


class ConstrainedModelCV(ModelCV):
    """
    Cross-validation for constrained models where constraints are tied to observation groups.

    `constraint_groups` is read from `self.model_kwargs` and should be a dictionary:
    {group_name: [obs_index_str1, obs_index_str2, ...]}. Folds are produced with
    `GroupKFold` using the group of each observation as its label, and the model
    fitted on a fold only receives the constraint groups whose members all lie in
    that fold's training rows.
    """

    def _resolve_observation_groups(self, A: pd.DataFrame, constraint_groups: dict) -> np.ndarray:
        """
        Converts a dictionary of observation groupings to a numpy array of group labels.

        Args:
            A: Feature matrix with index of observation names.
            constraint_groups: dict of group name -> list of observation index names (str).

        Returns:
            A numpy array of group labels aligned with A.index.

        Raises:
            ValueError: If an entry of A.index (compared as a string) is not listed
                in any constraint group.
        """
        # Invert the mapping: observation name -> group name. If an observation is
        # listed in several groups, the group iterated last wins.
        obs_to_group = {}
        for group, members in constraint_groups.items():
            for obs_name in members:
                obs_to_group[obs_name] = group

        try:
            # Index labels are looked up by their string form.
            return np.array([obs_to_group[str(idx)] for idx in A.index])
        except KeyError as e:
            raise ValueError(
                f"Observation '{e.args[0]}' in A.index not found in any constraint group."
            ) from e

    def _constraint_groups_for_fold(self, A: pd.DataFrame, train_idx, constraint_groups: dict) -> dict:
        """
        Returns the constraint groups that are fully contained in one training fold.

        Args:
            A: Feature matrix with index of observation names.
            train_idx: Positional row indices of the training fold.
            constraint_groups: dict of group name -> list of observation index names (str).

        Returns:
            A dict with the subset of `constraint_groups` whose members are all
            training observations.

        Raises:
            ValueError: If the retained groups do not account for exactly
                `len(train_idx)` observations.
        """
        # Compare names as strings, the same way _resolve_observation_groups does,
        # so non-string index labels still match the string names in the groups.
        train_obs_names = {str(name) for name in A.index[train_idx]}
        filtered_groups = {
            group: obs_list
            for group, obs_list in constraint_groups.items()
            if set(obs_list).issubset(train_obs_names)
        }

        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would let a model be fitted with an incomplete set of groups.
        n_grouped = sum(len(v) for v in filtered_groups.values())
        if n_grouped != len(train_idx):
            raise ValueError(
                f"Mismatch between number of training observations and total observations "
                f"in constraint_groups: got {n_grouped} in groups, "
                f"expected {len(train_idx)} in training fold."
            )
        return filtered_groups

    def fit(self, A: pd.DataFrame, y: Union[pd.Series, np.ndarray]):
        """
        Runs group-aware cross-validation over every (alpha, l1_ratio) configuration.

        When `self.model_kwargs` has no "constraint_groups" entry (or it is None),
        the call is delegated to `ModelCV.fit`. Otherwise, for each configuration a
        fresh `self.model_cls` is fitted on every training fold and scored on the
        matching validation fold; the mean and standard deviation of the fold scores
        are stored in `self.cv_mean_mse_` / `self.cv_std_mse_`, and the configuration
        with the lowest mean score is recorded in `self.best_params_`.

        NOTE(review): `cv_mean_mse_`, `cv_std_mse_` and `best_score_` are not created
        here; they are presumably initialised by `ModelCV` - confirm.

        Args:
            A: Feature matrix with index of observation names.
            y: Target values aligned with the rows of A.

        Returns:
            self (or whatever `ModelCV.fit` returns when no constraint groups are set).

        Raises:
            ValueError: If an observation in A.index belongs to no constraint group,
                or if the constraint groups contained in a training fold do not
                account for exactly the fold's observations.
        """
        self.A_ = A
        self.y_ = y
        constraint_groups = self.model_kwargs.get("constraint_groups", None)
        if constraint_groups is None:
            return super().fit(A, y)

        group_labels = self._resolve_observation_groups(A, constraint_groups)
        if len(group_labels) != len(A):
            raise ValueError("Resolved group labels do not match number of rows in A.")

        # Positional fold indices are applied to y below, so it must be array-like
        # with integer-array indexing (a plain list would not support it).
        y = y.values if isinstance(y, pd.Series) else np.asarray(y)
        ratios = self.l1_ratios if self.l1_ratios is not None else [None]
        total_configs = len(self.alphas) * len(ratios)

        # The folds and the constraint groups available to each fold depend only on
        # the data, not on the hyper-parameters, so they are computed once up front.
        group_kfold = GroupKFold(n_splits=self.k)
        folds = [
            (train_idx, val_idx, self._constraint_groups_for_fold(A, train_idx, constraint_groups))
            for train_idx, val_idx in group_kfold.split(A, y, groups=group_labels)
        ]

        config_count = 0
        for alpha in self.alphas:
            for l1_ratio in ratios:
                config_count += 1
                if self.verbose:
                    logger.info(f"Evaluating config {config_count}/{total_configs}: alpha={alpha}, l1_ratio={l1_ratio}")
                errors = []

                for train_idx, val_idx, fold_groups in folds:
                    A_train, A_val = A.iloc[train_idx], A.iloc[val_idx]
                    y_train, y_val = y[train_idx], y[val_idx]

                    kwargs = self.model_kwargs.copy()
                    kwargs.update({
                        "alpha": alpha,
                        # Fresh dict per model so no model shares the cached mapping.
                        "constraint_groups": dict(fold_groups),
                    })

                    # l1_ratio is only forwarded when an l1_ratios grid was configured.
                    if l1_ratio is not None:
                        kwargs["l1_ratio"] = l1_ratio

                    model = self.model_cls(**kwargs)
                    model.fit(A_train, y_train)
                    errors.append(model.score(A_val, y_val))

                avg_error = np.mean(errors)
                std_error = np.std(errors)
                self.cv_mean_mse_[(alpha, l1_ratio)] = avg_error
                self.cv_std_mse_[(alpha, l1_ratio)] = std_error

                if self.verbose:
                    logger.info(f"Mean MSE: {avg_error:.4f} ± {std_error:.4f}")

                # Lower mean score is treated as better.
                if avg_error < self.best_score_:
                    self.best_score_ = avg_error
                    self.best_params_ = (alpha, l1_ratio)

        # Prepare diagnostics for plotting: 1-D over alphas without an l1_ratios
        # grid, otherwise 2-D with one row per l1_ratio and one column per alpha.
        if self.l1_ratios is None:
            mean_mse_arr = np.array([self.cv_mean_mse_[(a, None)] for a in self.alphas])
            std_mse_arr = np.array([self.cv_std_mse_[(a, None)] for a in self.alphas])
        else:
            mean_mse_arr = np.array([
                [self.cv_mean_mse_[(a, l1)] for a in self.alphas] for l1 in self.l1_ratios
            ])
            std_mse_arr = np.array([
                [self.cv_std_mse_[(a, l1)] for a in self.alphas] for l1 in self.l1_ratios
            ])

        self.diagnostics_ = {
            "alphas": self.alphas,
            "l1_ratios": self.l1_ratios,
            "mean_mse": mean_mse_arr,
            "std_mse": std_mse_arr,
            "best_alpha": self.best_params_[0],
            "best_l1_ratio": self.best_params_[1],
        }

        return self
