
import numpy as np
import pandas as pd
from typing import Tuple
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.datasets import fetch_openml
from packaging import version

import sklearn
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler, FunctionTransformer, KBinsDiscretizer
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    roc_auc_score, brier_score_loss,
    mean_squared_error, mean_absolute_error, r2_score
)
from sklearn.base import BaseEstimator, ClassifierMixin

# --- classifiers
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import (RandomForestClassifier,
                              ExtraTreesClassifier,
                              GradientBoostingClassifier,
                              HistGradientBoostingClassifier)

def _sparse_to_dense(X):
    """Return ``X.toarray()`` when ``X`` has it (e.g. a scipy sparse matrix); pass anything else through."""
    return X.toarray() if hasattr(X, "toarray") else X


# A named module-level function (instead of a lambda) keeps any pipeline containing
# this transformer picklable: pickle/joblib cannot serialise lambdas.
to_dense = FunctionTransformer(_sparse_to_dense, accept_sparse=True)

# --- regressors
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.neural_network import MLPRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import (RandomForestRegressor,
                              ExtraTreesRegressor,
                              GradientBoostingRegressor,
                              HistGradientBoostingRegressor)

# --- Helpers
def build_gate_features(fitted_models: dict, X: pd.DataFrame,
                        prefer: tuple[str, ...] = (
                            "KNN (B5)", "LinearSVM (B5)", "Ridge(B2)",
                            "RF(B3)", "ExtraTrees(B2)"
                        )) -> np.ndarray:
    """Transform ``X`` with a fitted preprocessor borrowed from ``fitted_models``.

    The preprocessor is chosen in this order:

    1. the ``"preproc"`` step of the first model named in ``prefer`` that is present
       in ``fitted_models`` and has such a step (names not present are skipped);
    2. the ``"preproc"`` step of the first model, in dict order, that has one;
    3. the first step exposing ``transform`` of the first model with ``named_steps``
       that has any such step.

    Args:
        fitted_models: mapping of model name -> fitted model (typically a Pipeline).
        X: rows to transform.
        prefer: model names to try first, in priority order.

    Returns:
        Dense ``float32`` array of the transformed features.

    Raises:
        KeyError: no transform-capable step was found in any model.
    """
    def _preproc_step(model):
        # The model's "preproc" step, or None when it has no such step.
        steps = getattr(model, "named_steps", None)
        if steps is not None and "preproc" in steps:
            return steps["preproc"]
        return None

    preferred = [fitted_models[n] for n in (prefer or ()) if n in fitted_models]
    candidates = [*preferred, *fitted_models.values()]
    pre = next(
        (step for step in map(_preproc_step, candidates) if step is not None),
        None,
    )

    if pre is None:
        # Last resort: any step that can transform, from any pipeline-like model.
        for model in fitted_models.values():
            steps = getattr(model, "named_steps", None)
            if steps is None:
                continue
            pre = next((s for s in steps.values() if hasattr(s, "transform")), None)
            if pre is not None:
                break

    if pre is None:
        raise KeyError("No transform-capable preprocessor found in fitted_models.")

    Xr = pre.transform(X)
    if hasattr(Xr, "toarray"):  # sparse output -> dense
        Xr = Xr.toarray()
    return np.asarray(Xr, dtype=np.float32)


class MarginSigmoidWrapper(BaseEstimator, ClassifierMixin):
    """Binary classifier that maps a base model's margins to probabilities.

    ``fit`` fits ``base_estimator`` on ``(X, y)``. It then fits a Platt-style sigmoid
    ``P(y = classes_[-1]) = 1 / (1 + exp(-(a * m + b)))`` to the model's margins ``m``
    on the same training data, minimising the mean log-loss with BFGS. Calibration
    reuses the training set, so probabilities may be over-confident.

    Parameters
    ----------
    base_estimator : estimator, optional
        Binary classifier exposing ``decision_function`` (``predict_proba`` is used
        as a fallback). Defaults to ``LinearSVC()``. It is fitted in place.

    Attributes
    ----------
    classes_ : ndarray
        Class labels taken from the base estimator (``[0, 1]`` if it has none).
    a_, b_ : float
        Fitted sigmoid slope and intercept.
    """

    def __init__(self, base_estimator=None):
        self.base_estimator = base_estimator if base_estimator is not None else LinearSVC()

    def _margin(self, X):
        """Return a 1-D margin-like score for ``X`` from the fitted base estimator."""
        est = self.base_estimator
        if hasattr(est, "decision_function"):
            return np.asarray(est.decision_function(X), dtype=float).ravel()
        # No margin available: map P(positive) from [0, 1] onto [-4, 4] so it
        # plays the role of a margin (same rescaling in fit and predict_proba).
        return np.asarray(est.predict_proba(X), dtype=float)[:, 1] * 8 - 4

    def fit(self, X, y):
        """Fit the base estimator, then the sigmoid parameters ``a_`` and ``b_``."""
        import scipy.optimize as opt
        from scipy.special import expit

        self.base_estimator.fit(X, y)
        self.classes_ = getattr(self.base_estimator, "classes_", np.array([0, 1]))
        if hasattr(self.base_estimator, "n_features_in_"):
            self.n_features_in_ = self.base_estimator.n_features_in_

        m = self._margin(X)
        # Target is "is the positive class" (the last class, which sklearn's binary
        # decision_function scores positively), so labels need not already be 0/1.
        yb = (np.asarray(y).ravel() == np.asarray(self.classes_)[-1]).astype(float)

        def nll(ab):
            a, b = ab
            p = expit(a * m + b)  # overflow-free logistic sigmoid
            eps = 1e-9
            return -np.mean(yb * np.log(p + eps) + (1 - yb) * np.log(1 - p + eps))

        ab0 = np.array([1.0, 0.0])
        self.a_, self.b_ = opt.minimize(nll, ab0, method="BFGS").x
        return self

    def predict_proba(self, X):
        """Return ``[P(classes_[0]), P(classes_[-1])]`` per row, clipped to [1e-6, 1 - 1e-6]."""
        from scipy.special import expit

        p1 = expit(self.a_ * self._margin(X) + self.b_)
        p1 = np.clip(p1, 1e-6, 1 - 1e-6)
        return np.c_[1 - p1, p1]

    def predict(self, X):
        """Return ``classes_[-1]`` where its probability is >= 0.5, else ``classes_[0]``."""
        labels = np.asarray(self.classes_)
        return labels[(self.predict_proba(X)[:, 1] >= 0.5).astype(int)]


def _safe_drop_sparse_missing(df: pd.DataFrame, max_missing=0.4) -> pd.DataFrame:
    """Keep only the columns whose fraction of missing values is at most ``max_missing``.

    The comparison is inclusive: a column exactly at the threshold is kept.
    """
    threshold = float(max_missing)
    missing_frac = df.isna().mean()
    kept = missing_frac.loc[lambda frac: frac <= threshold].index
    return df.loc[:, kept]


def _clean_uci_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with object-dtype columns normalised as stripped text.

    In each object column:
      * values are cast to ``str`` and stripped of surrounding whitespace;
      * the exact tokens "?", "NA", "NaN", "nan", "None" and "unknown" become NaN;
      * values that were already missing stay missing. Without this, ``pd.NA``
        would survive as the literal string "<NA>".

    Columns of any other dtype are left untouched. The input frame is not modified.
    """
    missing_tokens = {"?": np.nan, "NA": np.nan, "NaN": np.nan, "nan": np.nan,
                      "None": np.nan, "unknown": np.nan}
    out = df.copy()
    for col in out.columns:
        if out[col].dtype != object:
            continue
        was_missing = out[col].isna()
        cleaned = out[col].astype(str).str.strip().replace(missing_tokens)
        out[col] = cleaned.mask(was_missing)
    return out


def _stratified_reg_split(X, y, test_size=0.2, n_bins=10, *, random_state=None):
    """Train/test split for regression, stratified on quantile bins of the target.

    ``y`` is cut into at most ``n_bins`` quantile bins (tied edges are merged via
    ``duplicates="drop"``), and the bins are used as strata so both splits cover the
    target range. If the bins cannot serve as strata, a plain random split is done
    instead of raising. That happens when binning fails, ``y`` contains NaN, or some
    bin has fewer than 2 rows.

    Args:
        X: features, passed through to ``train_test_split``.
        y: regression target (converted to a ``pd.Series``).
        test_size: fraction (or count) of rows for the test split.
        n_bins: requested number of quantile bins.
        random_state: optional seed forwarded to ``train_test_split``.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as returned by ``train_test_split``.
    """
    y = pd.Series(y)
    try:
        strata = pd.qcut(y, q=n_bins, duplicates="drop")
    except ValueError:
        strata = None

    if strata is not None:
        populated = strata.value_counts()
        populated = populated[populated > 0]
        # Stratified splitting needs labels without NaN and >= 2 members per stratum.
        if strata.isna().any() or populated.empty or populated.min() < 2:
            strata = None

    return train_test_split(
        X, y, test_size=test_size, stratify=strata, random_state=random_state
    )


def _balance_train_classes(X_tr: pd.DataFrame, y_tr: pd.Series, *, random_state=None):
    """Undersample every class down to the size of the rarest class.

    ``X_tr`` and ``y_tr`` are paired by position. If all classes already have the
    same count (or ``y_tr`` is empty), the data are returned in their original order.
    Otherwise ``min_count`` rows are drawn without replacement from each class and the
    result is shuffled. Rows whose label is missing are kept only in the
    already-balanced case, because ``value_counts`` ignores NaN.

    Args:
        X_tr: training features.
        y_tr: training labels (Series or array-like), same length as ``X_tr``.
        random_state: optional int seed for reproducible sampling. ``None`` uses
            NumPy's global random state.

    Returns:
        ``(X_bal, y_bal)`` with fresh 0..n-1 indices. ``y_bal`` keeps ``y_tr``'s name
        and dtype.

    Raises:
        ValueError: ``X_tr`` and ``y_tr`` have different lengths.
    """
    X_pos = X_tr.reset_index(drop=True)
    y_pos = pd.Series(y_tr).reset_index(drop=True)
    if len(X_pos) != len(y_pos):
        raise ValueError(
            f"X_tr and y_tr must have the same length (got {len(X_pos)} and {len(y_pos)})."
        )

    counts = y_pos.value_counts()
    if counts.nunique() <= 1:
        return X_pos, y_pos  # already balanced (or empty)

    # One RandomState shared by all draws so a seed yields a single reproducible stream.
    rng = None if random_state is None else np.random.RandomState(random_state)
    n_min = int(counts.min())
    # Work with row positions instead of a temporary label column, so no feature
    # column can be clobbered.
    picks = [
        pd.Series(np.flatnonzero((y_pos == cls).to_numpy()))
        .sample(n=n_min, replace=False, random_state=rng)
        for cls in counts.index
    ]
    order = pd.concat(picks).sample(frac=1.0, random_state=rng).to_numpy()
    return (
        X_pos.iloc[order].reset_index(drop=True),
        y_pos.iloc[order].reset_index(drop=True),
    )


def load_uci_data(
    name: str,
    *,
    test_size: float = 0.2,
    drop_missing_cols: float = 0.4,
    return_df: bool = True,
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
    """Fetch a registered OpenML dataset and return a cleaned train/test split.

    Steps:
      1. Download the dataset with ``fetch_openml``.
      2. Locate the target column.
      3. Normalise string cells, turning missing-value tokens into NaN.
      4. Drop feature columns whose missing fraction exceeds ``drop_missing_cols``.
      5. Infer the task from the target (via ``_infer_task``).
      6. Drop rows whose target is missing.
      7. Split.

    * Classification: the target is converted to integer codes and split with
      stratification. Only the *training* part is then undersampled so every class
      has the same count; the test part is left as drawn.
    * Regression: the target must be numeric, and the split is stratified on
      target quantile bins.

    Args:
        name: registry key (case- and whitespace-insensitive), e.g. "spambase".
        test_size: fraction of rows held out for testing.
        drop_missing_cols: maximum allowed fraction of missing values per feature column.
        return_df: if False, return NumPy arrays instead of pandas objects.

    Returns:
        ``(X_train, y_train, X_test, y_test)``; pandas objects have 0..n-1 indices.

    Raises:
        ValueError: unknown ``name``, or a regression target that is not numeric.
        RuntimeError: the target column cannot be found in the downloaded frame.
    """
    import warnings

    name = str(name).lower().strip()

    # key -> (OpenML data_id, expected target column, task hint)
    registry = {
        # Classification
        "spambase":          (44,                   "class",               "classification"),
        "bank-marketing":    (1461,                 "y",                   "classification"),
        "credit-g":          (31,                   "class",               "classification"),
        # Regression
        "bike-sharing":      (42712,                "cnt",                 "regression"),
        "california-housing":(42165,                "MedHouseVal",         "regression"),
        "communities":       (46286,                "ViolentCrimesPerPop", "regression"),
    }
    if name not in registry:
        raise ValueError(f"Unknown dataset '{name}'. Choose one of: {list(registry.keys())}")

    ds_id_or_name, target_name, task_hint = registry[name]

    if isinstance(ds_id_or_name, int):
        openml_obj = fetch_openml(data_id=ds_id_or_name, as_frame=True)
    else:
        openml_obj = fetch_openml(name=ds_id_or_name, version=1, as_frame=True)

    df = openml_obj.frame.copy()

    # The registry's column name may differ from OpenML's; fall back to OpenML's own target.
    if target_name not in df.columns and getattr(openml_obj, "target_names", None):
        target_name = openml_obj.target_names[0]
    if target_name not in df.columns:
        raise RuntimeError(
            f"Target column '{target_name}' not found. "
            f"Available columns: {list(df.columns)[:10]}... (total {len(df.columns)})"
        )

    y_raw = df[target_name]
    X = df.drop(columns=[target_name])

    X = _clean_uci_frame(X)
    X = _safe_drop_sparse_missing(X, max_missing=drop_missing_cols)

    # The task inferred from the data always wins over the registry hint (as before);
    # the disagreement is now reported instead of silently ignored.
    task = _infer_task(y_raw)
    if task != task_hint:
        warnings.warn(
            f"Dataset '{name}': target looks like {task} but the registry says "
            f"{task_hint}; using {task}.",
            stacklevel=2,
        )

    if task == "classification":
        # Drop rows without a label. Otherwise cat.codes would encode them as an
        # extra class -1, and astype(int) would fail on nullable integers.
        mask = y_raw.notna()
        X, y_raw = X.loc[mask], y_raw.loc[mask]
        if pd.api.types.is_integer_dtype(y_raw) or pd.api.types.is_bool_dtype(y_raw):
            y = y_raw.astype(int)
        else:
            y = y_raw.astype("category").cat.codes.astype(int)

        x_tr, x_te, y_tr, y_te = train_test_split(
            X, y, test_size=test_size, stratify=y
        )

        x_tr, y_tr = _balance_train_classes(x_tr, y_tr)

    else:  # regression
        try:
            y = pd.to_numeric(y_raw, errors="raise")
        except (ValueError, TypeError) as err:
            raise ValueError(
                f"Target column '{target_name}' for dataset '{name}' is non-numeric "
                f"(example value: {y_raw.iloc[0]!r}). This is a classification target."
            ) from err
        mask = y.notna()
        X, y = X.loc[mask], y.loc[mask]

        # Balanced coverage across target range
        x_tr, x_te, y_tr, y_te = _stratified_reg_split(
            X, y, test_size=test_size, n_bins=12
        )

    # tidy indices
    x_tr = x_tr.reset_index(drop=True)
    x_te = x_te.reset_index(drop=True)
    y_tr = y_tr.reset_index(drop=True)
    y_te = y_te.reset_index(drop=True)

    if return_df:
        return x_tr, y_tr, x_te, y_te
    return x_tr.to_numpy(), y_tr.to_numpy(), x_te.to_numpy(), y_te.to_numpy()


def _ohe_kwargs():
    """Return ``OneHotEncoder`` keyword arguments for sparse output on any sklearn version.

    scikit-learn 1.2 renamed the ``sparse`` argument to ``sparse_output``; the name is
    picked from the installed version. Unknown categories are ignored at transform time.
    """
    renamed = version.parse(sklearn.__version__) >= version.parse("1.2")
    sparse_arg = "sparse_output" if renamed else "sparse"
    return {"handle_unknown": "ignore", sparse_arg: True}


def _split_cols(X: pd.DataFrame):
    """Partition ``X``'s columns into ``(numeric_cols, categorical_cols)`` name lists.

    Categorical means object, pandas ``string``, category, or bool/boolean dtype.
    Every other column is treated as numeric. Both lists keep the column order of
    ``X``. (Previously ``Index.difference`` sorted the numeric columns
    alphabetically, and ``string`` columns were misfiled as numeric.)
    """
    def _is_categorical(dtype) -> bool:
        return (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_bool_dtype(dtype)
            or isinstance(dtype, (pd.CategoricalDtype, pd.StringDtype))
        )

    num_cols, cat_cols = [], []
    for col, dtype in zip(X.columns.tolist(), X.dtypes):
        if _is_categorical(dtype):
            cat_cols.append(col)
        else:
            num_cols.append(col)
    return num_cols, cat_cols


def _as_py_str_list(cols):
    """Normalise a column selector into a list of builtin ``str``.

    * ``None`` -> ``[]``.
    * A single label (``str``, NumPy scalar such as ``np.str_``/``np.int64``, 0-d
      array, or any other non-collection) -> ``[str(label)]``.
    * list / tuple / set / ``pd.Index`` / ``pd.Series`` / 1-D array -> one ``str`` per item.

    Scalars are handled before ``.tolist()``. Without that, ``np.str_("ab")`` would be
    iterated into characters and ``np.int64(3)`` would raise ``TypeError``.
    """
    if cols is None:
        return []
    if isinstance(cols, (str, np.generic)):
        return [str(cols)]
    if hasattr(cols, "tolist"):
        cols = cols.tolist()  # a 0-d array yields a plain scalar here
    if isinstance(cols, (list, tuple, set, frozenset)):
        return [str(c) for c in cols]
    return [str(cols)]
    
    
def make_preproc(num_cols, cat_cols, *, for_tree=False):
    """Build a ``ColumnTransformer`` for the given numeric and categorical columns.

    Categorical columns: most-frequent imputation, then one-hot encoding (categories
    seen fewer than 10 times are grouped; unknown categories are ignored).
    Numeric columns: passed through untouched when ``for_tree`` is True; otherwise
    median-imputed and standardised. All other columns are dropped.
    """
    numeric = _as_py_str_list(num_cols)
    categorical = _as_py_str_list(cat_cols)
    encoder_kwargs = _ohe_kwargs()

    if for_tree:
        numeric_step = "passthrough"
    else:
        numeric_step = Pipeline([
            ("imp", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler()),
        ])

    categorical_step = Pipeline([
        ("imp", SimpleImputer(strategy="most_frequent")),
        ("ohe", OneHotEncoder(min_frequency=10, **encoder_kwargs)),
    ])

    transformers = [
        ("num", numeric_step, numeric),
        ("cat", categorical_step, categorical),
    ]
    return ColumnTransformer(transformers, remainder="drop")

def build_feature_bundles(X: pd.DataFrame, y: pd.Series | np.ndarray | None = None, *, task: str | None = None,):
    """Build named ``(numeric_cols, categorical_cols)`` feature subsets of ``X``.

    Bundles:
      * B1_num_high_corr: top-3 numeric columns by |Pearson corr| with ``y``
        (empty when ``y`` is None).
      * B2_num_high_var: top-3 numeric columns by variance.
      * B3_cat_high_card: top-3 categorical columns by number of distinct values.
      * B4_cat_low_card: up to 5 of the remaining categorical columns, in frame order.
      * B5_all_cat: all categorical columns.
      * B6_num_all: all numeric columns.
      * B7_mix: the B1 numeric columns plus the B3 categorical columns.

    A bundle that would be empty instead gets one randomly chosen column (numeric
    if ``X`` has any, else categorical), so every model has at least one input.

    Args:
        X: feature frame, split into numeric/categorical with ``_split_cols``.
        y: optional target, paired with ``X`` by position (as the models are
            fitted). Non-numeric targets are category-coded for the correlation.
        task: unused; accepted for API compatibility.

    Returns:
        Dict of bundle name -> ``(list[str], list[str])``.

    Raises:
        ValueError: ``y`` is given with a different length than ``X``.
    """
    num_all, cat_all = _split_cols(X)
    top_k = 3

    var_num = X[num_all].var().fillna(0.0) if num_all else pd.Series(dtype=float)
    card_cat = X[cat_all].nunique().astype(int) if cat_all else pd.Series(dtype=int)

    num_high_var = var_num.sort_values(ascending=False).head(top_k).index.tolist()
    cat_high_card = card_cat.sort_values(ascending=False).head(top_k).index.tolist()
    high_card_set = set(cat_high_card)
    cat_low_card = [c for c in cat_all if c not in high_card_set][:5]

    num_high_corr: list[str] = []
    if y is not None and len(num_all) > 0:
        # Pair y with X by position. An ndarray y would otherwise get a RangeIndex,
        # and Series.corr aligns on index labels, yielding all-NaN correlations
        # whenever X.index is not 0..n-1.
        y_s = pd.Series(y).reset_index(drop=True).set_axis(X.index)
        if not pd.api.types.is_numeric_dtype(y_s):
            y_s = y_s.astype("category").cat.codes

        Xn = X[num_all]
        Xn = Xn.fillna(Xn.median(numeric_only=True))

        corrs = Xn.apply(lambda col: col.corr(y_s), axis=0)
        corrs = corrs.replace([np.inf, -np.inf], np.nan).fillna(0.0)
        num_high_corr = corrs.abs().sort_values(ascending=False).head(top_k).index.tolist()

    bundles = {
        "B1_num_high_corr": (num_high_corr, []),
        "B2_num_high_var": (num_high_var, []),
        "B3_cat_high_card": ([], cat_high_card),
        "B4_cat_low_card": ([], cat_low_card),
        "B5_all_cat": ([], cat_all),
        "B6_num_all": (num_all, []),
        "B7_mix": (num_high_corr, cat_high_card),
    }

    def safe_bundle(num, cat):
        # Never hand a model an empty column selection.
        if not num and not cat:
            if num_all:
                num = [str(np.random.choice(num_all))]
            elif cat_all:
                cat = [str(np.random.choice(cat_all))]
        return _as_py_str_list(num), _as_py_str_list(cat)

    return {k: safe_bundle(*v) for k, v in bundles.items()}


def train_models(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    *,
    task: str | None = None,           # "classification" | "regression" | None = auto
):
    """
    Fit a diverse set of weaker-but-different models, each on a different feature bundle.

    Bundles come from ``build_feature_bundles(X_train, y_train)``. Every model is a
    Pipeline whose first step is ``"preproc"`` (restricted to that bundle's columns),
    followed by ``"clf"`` (classification) or ``"reg"`` (regression).

    Task auto-detection (``task=None``):
      * bool targets are classification;
      * integer or categorical targets with <= 10 distinct values are classification;
      * other numeric targets are regression;
      * everything else is classification.

    Classification models: NB (B2), KNN (B5), RF (B3), ExtraTrees (B4), LinearSVM (B5).
    Regression models: Ridge(B6), Lasso(B6), KNN(B7), RF(B7), ExtraTrees(B7).

    Returns {name: fitted Pipeline}.
    """
    if task is None:
        y_s = pd.Series(y_train)
        # isinstance(..., CategoricalDtype) replaces the deprecated is_categorical_dtype.
        if (pd.api.types.is_bool_dtype(y_s) or
            (pd.api.types.is_integer_dtype(y_s) and y_s.nunique() <= 10) or
            (isinstance(y_s.dtype, pd.CategoricalDtype) and y_s.nunique() <= 10)):
            task = "classification"
        else:
            task = "regression" if pd.api.types.is_numeric_dtype(y_s) else "classification"

    bundles = build_feature_bundles(X_train, y_train)
    models = {}

    if task == "classification":

        # Naive Bayes on one-hot quantile bins. Impute first, because
        # KBinsDiscretizer rejects NaN and numeric columns may still contain some.
        num, cat = bundles["B2_num_high_var"]
        models["NB (B2)"] = Pipeline([
            ("preproc", ColumnTransformer([
                ("num", Pipeline([
                    ("imp", SimpleImputer(strategy="median")),
                    ("bin", KBinsDiscretizer(n_bins=10, encode="onehot", strategy="quantile")),
                ]), num),
                ("cat", OneHotEncoder(min_frequency=10, **_ohe_kwargs()), cat),
            ], remainder="drop")),
            ("clf", MultinomialNB(alpha=1.0, fit_prior=True))
        ])

        # KNN
        num, cat = bundles["B5_all_cat"]
        models["KNN (B5)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("clf", KNeighborsClassifier(n_neighbors=1, weights="uniform", n_jobs=-1))
        ])

        # Random Forest
        num, cat = bundles["B3_cat_high_card"]
        models["RF (B3)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("clf", RandomForestClassifier(
                n_estimators=60, max_depth=6, min_samples_leaf=5,
                n_jobs=-1
            ))
        ])

        # ExtraTrees
        num, cat = bundles["B4_cat_low_card"]
        models["ExtraTrees (B4)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("clf", ExtraTreesClassifier(
                n_estimators=60, max_depth=6, min_samples_leaf=5,
                n_jobs=-1
            ))
        ])

        # Linear SVM with sigmoid-calibrated margins
        num, cat = bundles["B5_all_cat"]
        models["LinearSVM (B5)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("clf", MarginSigmoidWrapper(
                base_estimator=LinearSVC(
                    C=0.5, loss="hinge", tol=1e-2, max_iter=10000
                )
            ))
        ])

    else:
        # regression

        # Ridge
        num, cat = bundles["B6_num_all"]
        models["Ridge(B6)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("reg", Ridge(alpha=0.05))
        ])

        # Lasso
        num, cat = bundles["B6_num_all"]
        models["Lasso(B6)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("reg", Lasso(alpha=0.05, max_iter=5000))
        ])

        # KNN
        num, cat = bundles["B7_mix"]
        models["KNN(B7)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("reg", KNeighborsRegressor(n_neighbors=3, weights="distance", n_jobs=-1))
        ])

        # Random Forest
        num, cat = bundles["B7_mix"]
        models["RF(B7)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("reg", RandomForestRegressor(
                n_estimators=50, max_depth=None, min_samples_leaf=2,
                n_jobs=-1))
        ])

        # ExtraTrees
        num, cat = bundles["B7_mix"]
        models["ExtraTrees(B7)"] = Pipeline([
            ("preproc", make_preproc(num, cat, for_tree=False)),
            ("reg", ExtraTreesRegressor(
                n_estimators=50, max_depth=None, min_samples_leaf=2,
                n_jobs=-1))
        ])

    fitted = {}
    for name, pipe in models.items():
        print(f"Fitting {name}...")
        pipe.fit(X_train, y_train)
        fitted[name] = pipe

    return fitted


# --- Evaluation

def _predict_proba_safe(model, X):
    """
    Robustly get class probabilities for a model or Pipeline as a float32 ``(n, k)`` array.

    - If ``model`` is a Pipeline, apply each intermediate step's ``transform`` and query
      the final estimator directly. If that fails, the Pipeline itself is used.
    - Try ``predict_proba``. A 1-D result is read as P(class 1) and expanded to
      ``[1 - p, p]``.
    - Else try ``decision_function``: 1-D scores go through a sigmoid, 2-D scores
      through a softmax.
    - Else one-hot encode the hard ``predict`` output over the estimator's
      ``classes_``, or over the unique predicted labels when ``classes_`` is
      unavailable.
    """
    est, Xt = model, X

    if isinstance(model, Pipeline) and model.steps:
        try:
            for _, step in model.steps[:-1]:
                if hasattr(step, "transform"):
                    Xt = step.transform(Xt)
            est = model.steps[-1][1]  # final estimator
        except Exception:
            est, Xt = model, X

    try:
        proba_fn = getattr(est, "predict_proba", None)
        if callable(proba_fn):
            P = np.asarray(proba_fn(Xt))
            if P.ndim == 1:
                P = np.c_[1.0 - P, P]
            return P.astype(np.float32)
    except Exception:
        pass  # best effort: fall through to decision_function

    try:
        score_fn = getattr(est, "decision_function", None)
        if callable(score_fn):
            scores = np.asarray(score_fn(Xt))
            if scores.ndim == 1:
                # 0.5 * (1 + tanh(s / 2)) == 1 / (1 + exp(-s)), without exp overflow.
                sig = 0.5 * (1.0 + np.tanh(0.5 * scores))
                P = np.c_[1.0 - sig, sig]
            else:
                e = np.exp(scores - scores.max(axis=1, keepdims=True))
                P = e / e.sum(axis=1, keepdims=True)
            return P.astype(np.float32)
    except Exception:
        pass  # best effort: fall through to hard predictions

    try:
        yhat = est.predict(Xt)
    except Exception:
        yhat = model.predict(X)
    yhat = np.asarray(yhat).ravel()

    # Prefer the estimator's full label set. Then a model that predicts a single class
    # on this batch still yields one column per class, in the right position.
    classes = getattr(est, "classes_", None)
    if classes is None or not np.isin(yhat, classes).all():
        classes = np.unique(yhat)
    classes = np.asarray(classes)

    col_of = {c: i for i, c in enumerate(classes.tolist())}
    cols = np.array([col_of[c] for c in yhat.tolist()], dtype=np.intp)
    P = np.zeros((yhat.size, classes.size), dtype=np.float32)
    P[np.arange(yhat.size), cols] = 1.0
    return P


def _compute_ece_local(y_true, p_hat, n_bins=15):
    """Expected Calibration Error of binary probability predictions.

    ``[0, 1]`` is split into ``n_bins`` equal-width bins over ``p_hat`` (the predicted
    P(class 1)). The result is ``sum_b (n_b / n) * |mean(y_true in b) - mean(p_hat in b)|``.

    The top edge is inclusive, so ``p_hat == 1.0`` lands in the last bin (plain
    ``np.digitize`` would push it past every bin and silently drop it). Values outside
    ``[0, 1]`` are clipped into the first/last bin. Non-finite predictions are ignored
    but still count toward ``n``. Returns 0.0 for empty input.
    """
    y_true = np.asarray(y_true).ravel().astype(int)
    p_hat = np.asarray(p_hat, dtype=float).ravel()
    n = p_hat.size
    if n == 0:
        return 0.0

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bins = np.clip(np.digitize(p_hat, edges) - 1, 0, n_bins - 1)
    valid = np.isfinite(p_hat)
    bins, p_v, y_v = bins[valid], p_hat[valid], y_true[valid]

    conf_sum = np.bincount(bins, weights=p_v, minlength=n_bins)
    acc_sum = np.bincount(bins, weights=y_v, minlength=n_bins)
    # (n_b / n) * |acc_b - conf_b| == |acc_sum_b - conf_sum_b| / n; empty bins add 0.
    return float(np.abs(acc_sum - conf_sum).sum() / n)


def _infer_task(y):
    """Guess whether ``y`` is a classification or a regression target.

    * bool targets -> "classification" (previously misfiled as regression, since bool
      counts as numeric but not integer);
    * integer targets with at most 10 distinct values -> "classification";
    * any other numeric target -> "regression";
    * non-numeric targets (object, string, categorical) -> "classification".
    """
    if pd.api.types.is_bool_dtype(y):
        return "classification"
    if pd.api.types.is_numeric_dtype(y):
        if pd.api.types.is_integer_dtype(y) and pd.Series(y).nunique() <= 10:
            return "classification"
        return "regression"
    return "classification"


def evaluate_models(
    fitted_models: dict,
    X_test,
    y_test,
    *,
    task: str | None = None,
    ece_bins: int = 15
) -> pd.DataFrame:
    """Score every fitted model on a held-out set; one row per model.

    * Binary classification (2 probability columns): Accuracy, F1, Precision, Recall,
      ROC_AUC, Brier, ECE. P(class 1) is thresholded at 0.5, so labels are expected
      to be 0/1. A ``compute_ece`` defined elsewhere is used for ECE if it works;
      otherwise ``_compute_ece_local`` is used.
    * Other classification: Accuracy and F1_macro. The argmax column index is the
      predicted label, so labels are expected to be 0..k-1.
    * Regression: RMSE, MAE, R2.

    A model that raises gets a row holding only ``error`` (the exception repr).

    Args:
        fitted_models: mapping of model name -> fitted model.
        X_test: test features.
        y_test: test target (Series or array-like).
        task: "classification" or "regression"; inferred from ``y_test`` when None.
        ece_bins: number of bins for the ECE.

    Returns:
        DataFrame indexed by "model". Classification is sorted by Accuracy
        (descending); regression by RMSE (ascending). Empty when ``fitted_models``
        is empty.
    """
    y = y_test.to_numpy() if hasattr(y_test, "to_numpy") else np.asarray(y_test)
    task = task or _infer_task(y)

    rows = []
    for name, mdl in fitted_models.items():
        try:
            if task == "classification":
                P = _predict_proba_safe(mdl, X_test)
                if P.shape[1] == 2:
                    p1 = P[:, 1]
                    yhat = (p1 >= 0.5).astype(int)
                    try:
                        auc = roc_auc_score(y, p1)
                    except Exception:
                        auc = np.nan  # e.g. only one class present in y
                    try:
                        bri = brier_score_loss(y, p1)
                    except Exception:
                        bri = np.nan
                    try:
                        ece = compute_ece(y, p1, n_bins=ece_bins)  # may exist in your notebook
                    except Exception:
                        ece = _compute_ece_local(y, p1, n_bins=ece_bins)
                    rows.append({
                        "model": name,
                        "Accuracy": accuracy_score(y, yhat),
                        "F1": f1_score(y, yhat, zero_division=0),
                        "Precision": precision_score(y, yhat, zero_division=0),
                        "Recall": recall_score(y, yhat, zero_division=0),
                        "ROC_AUC": auc,
                        "Brier": bri,
                        "ECE": ece,
                    })
                else:
                    yhat = np.argmax(P, axis=1)
                    rows.append({
                        "model": name,
                        "Accuracy": accuracy_score(y, yhat),
                        "F1_macro": f1_score(y, yhat, average="macro", zero_division=0),
                    })
            else:  # regression
                yhat = np.asarray(mdl.predict(X_test)).ravel()
                rows.append({
                    "model": name,
                    # mean_squared_error returns the MSE; take the root so "RMSE" is honest.
                    "RMSE": float(np.sqrt(mean_squared_error(y, yhat))),
                    "MAE": mean_absolute_error(y, yhat),
                    "R2": r2_score(y, yhat),
                })
        except Exception as e:
            rows.append({"model": name, "error": repr(e)})

    if not rows:
        return pd.DataFrame(index=pd.Index([], name="model"))

    df = pd.DataFrame(rows).set_index("model", drop=True)

    if task == "classification":
        column_order = ["Accuracy", "F1", "F1_macro", "Precision", "Recall",
                        "ROC_AUC", "Brier", "ECE", "error"]
        sort_candidates, ascending = ("Accuracy", "ROC_AUC", "F1", "F1_macro"), False
    else:
        column_order = ["RMSE", "MAE", "R2", "error"]
        sort_candidates, ascending = ("RMSE", "MAE"), True

    df = df[[c for c in column_order if c in df.columns]]
    # No metric column exists when every model errored; leave the order untouched then.
    sort_key = next((c for c in sort_candidates if c in df.columns), None)
    if sort_key is not None:
        df = df.sort_values(by=sort_key, ascending=ascending)

    return df
