import pandas as pd
import numpy as np
import sklearn
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, StandardScaler, FunctionTransformer
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, brier_score_loss, log_loss
from scipy import sparse
from packaging.version import Version
from time import perf_counter
from packaging import version
import xgboost as xgb


def stratified_split(
    X: pd.DataFrame,
    y: pd.Series,
    *,
    test_size=0.3,
    cat_cols=("ProductCD","card4"),
    amt_col="TransactionAmt",
    n_amt_bins=4,
    n_miss_bins=3,
    min_count=2,
    random_state=None,
):
    """Split rows into train/test positions, stratified on a composite key.

    The stratification key combines:
      * the values of ``cat_cols`` (missing values become ``"<NA>"``; a column
        absent from ``X`` is treated as all ``"<NA>"``),
      * a quantile bin of the per-row missing-value rate,
      * a quantile bin of ``X[amt_col]``,
      * the label ``y``.
    Strata with fewer than ``min_count`` rows are pooled per class into a
    ``_RARE_`` bucket. If a pooled bucket is itself still below ``min_count``,
    it is merged into the most populous stratum of the same class (when one
    exists).

    Parameters
    ----------
    X : pandas.DataFrame
        Feature frame; it is copied, not modified.
    y : pandas.Series
        Labels aligned with ``X`` by index; cast to ``int``.
    test_size : float or int
        Forwarded to ``StratifiedShuffleSplit``.
    cat_cols, amt_col, n_amt_bins, n_miss_bins
        Control how the stratification key is built (see above).
    min_count : int
        Strata smaller than this are pooled as described above.
    random_state : None, int or RandomState
        Forwarded to ``StratifiedShuffleSplit``. ``None`` keeps the previous
        unseeded behaviour.

    Returns
    -------
    (tr_idx, te_idx)
        Positional index arrays into ``X`` / ``y`` (use with ``.iloc``).

    Raises
    ------
    KeyError
        If ``amt_col`` is not a column of ``X``.
    ValueError
        Propagated from ``StratifiedShuffleSplit``, e.g. when a class has fewer
        than two members or ``test_size`` is smaller than the number of strata.
    """
    X = X.copy()
    y = y.astype(int)
    y_str = y.astype(str)

    # Missingness bin (per-row NA rate)
    miss_rate = X.isna().mean(axis=1)
    miss_bins = pd.qcut(miss_rate, q=n_miss_bins, duplicates="drop").astype(str)

    amt_bins = pd.qcut(X[amt_col], q=n_amt_bins, duplicates="drop").astype(str)

    # Key categoricals (NA as token)
    for c in cat_cols:
        if c in X.columns:
            X[c] = (
                X[c].astype("category")
                    .cat.add_categories(["<NA>"])
                    .fillna("<NA>")
                    .astype(str)
            )
        else:
            X[c] = "<NA>"

    strata = X[list(cat_cols)].astype(str).agg("|".join, axis=1) + "|" + miss_bins + "|" + amt_bins
    strat_y = (strata + "||y=" + y_str).astype(str)

    counts = strat_y.value_counts()
    rare_keys = counts[counts < min_count].index
    if len(rare_keys) > 0:
        # Pool rare strata but keep the class label in the pooled key.
        rare_mask = strat_y.isin(rare_keys)
        strat_y = strat_y.where(~rare_mask, "_RARE_||y=" + y_str)

        # A pooled bucket can itself remain below min_count (e.g. a single
        # rare fraud row), which StratifiedShuffleSplit rejects. Fold any such
        # bucket into the largest stratum of the same class, if there is one.
        counts = strat_y.value_counts()
        sizes = counts.to_numpy()
        for key in counts.index[sizes < min_count]:
            suffix = key[key.rindex("||y="):]
            is_peer = np.array(
                [k.endswith(suffix) for k in counts.index], dtype=bool
            ) & (sizes >= min_count)
            if is_peer.any():
                target = counts[is_peer].idxmax()
                strat_y = strat_y.mask(strat_y == key, target)

    splitter = StratifiedShuffleSplit(
        n_splits=1, test_size=test_size, random_state=random_state
    )
    (tr_idx, te_idx), = splitter.split(np.zeros(len(y)), strat_y)
    return tr_idx, te_idx

def downsample_preserving_characteristics(
    X: pd.DataFrame,
    y: pd.Series,
    *,
    target_size: int | None = None,
    frac: float | None = None,       # e.g., 0.2 for 20%
    cat_cols=("ProductCD", "card4"),
    amt_col="TransactionAmt",
    n_amt_bins=4,
    n_miss_bins=3,
    min_per_group=1,
):

    X = X.copy()
    y = y.astype(int)
    N = len(X)

    if (target_size is None) == (frac is None):
        raise ValueError("Specify exactly one of target_size or frac.")
    if frac is not None:
        target_size = int(round(N * float(frac)))
        target_size = max(2, min(target_size, N))

    miss_rate = X.isna().mean(axis=1)
    miss_bins = pd.qcut(miss_rate, q=n_miss_bins, duplicates="drop").astype(str)

    if amt_col not in X.columns:
        raise ValueError(f"{amt_col} not in X")
    amt_bins = pd.qcut(X[amt_col], q=n_amt_bins, duplicates="drop").astype(str)

    for c in cat_cols:
        if c in X.columns:
            X[c] = (X[c].astype("category")
                          .cat.add_categories(["<NA>"])
                          .fillna("<NA>")
                          .astype(str))
        else:
            X[c] = "<NA>"

    strata = X[list(cat_cols)].astype(str).agg("|".join, axis=1) + "|" + miss_bins + "|" + amt_bins
    strat_y = (strata + "||y=" + y.astype(str)).astype(str)

    counts = strat_y.value_counts()
    props  = counts / counts.sum()
    ideal  = props * target_size

    alloc = np.floor(ideal).astype(int)
    remainder = target_size - int(alloc.sum())
    if remainder > 0:
        order = np.argsort(-(ideal - alloc))
        keys = counts.index.to_numpy()
        for k in keys[order][:remainder]:
            alloc[k] += 1

    alloc = np.minimum(alloc, counts)
    current = int(alloc.sum())

    if current < target_size:
        deficit = target_size - current
        capacity = (counts - alloc)
        for k in capacity.sort_values(ascending=False).index:
            if deficit <= 0: break
            take = min(capacity[k], deficit)
            if take > 0:
                alloc[k] += take
                deficit -= take

    if min_per_group > 0:
        for k in alloc.index:
            if counts[k] >= min_per_group:
                alloc[k] = max(alloc[k], min_per_group)
        excess = int(alloc.sum()) - target_size
        if excess > 0:
            give = (alloc - min_per_group).clip(lower=0)
            if give.sum() > 0:
                frac_give = (give / give.sum()).fillna(0.0)
                cut = np.floor(frac_give * excess).astype(int)
                alloc -= cut
                residue = int(alloc.sum()) - target_size
                if residue > 0:
                    for k in give.sort_values(ascending=False).index:
                        if residue == 0: break
                        if alloc[k] > min_per_group:
                            alloc[k] -= 1
                            residue -= 1

    sel_idx = []
    for key, idxs in strat_y.groupby(strat_y).groups.items():
        idxs = pd.Index(idxs)
        k = int(alloc.get(key, 0))
        if k <= 0:
            continue
        if len(idxs) <= k:
            sel_idx.append(idxs)
        else:
            take = np.random.choice(idxs.to_numpy(), size=k, replace=False)
            sel_idx.append(pd.Index(take))

    if not sel_idx:
        return pd.Index([])

    sel_idx = pd.Index(np.concatenate([s.to_numpy() for s in sel_idx]))
    if len(sel_idx) > target_size:
        drop = len(sel_idx) - target_size
        keep_mask = np.ones(len(sel_idx), dtype=bool)
        keep_mask[np.random.choice(len(sel_idx), size=drop, replace=False)] = False
        sel_idx = sel_idx[keep_mask]

    return sel_idx

def load_fraud_data(
    transaction_csv_path: str,
    identity_csv_path: str,
    test_size: float = 0.2,
    return_df: bool = True,
    *,
    sample_frac: float | None = None,     # e.g., 0.2 = keep 20% of rows
    sample_size: int | None = None,
    cat_cols=("ProductCD","card4"),
    amt_col="TransactionAmt",
    n_amt_bins=4,
    n_miss_bins=3,
    min_count=2,
):
    """Load the IEEE-CIS Fraud Detection data and split it into train/test.

    Preparation steps:
      1. The transaction and identity CSVs are read with ``TransactionID`` as
         the index and left-joined on it.
      2. Rows with a missing ``isFraud`` label are discarded.
      3. Feature columns whose missing fraction is 50% or more are dropped.
      4. Optionally, the rows are downsampled with
         ``downsample_preserving_characteristics`` when ``sample_frac`` or
         ``sample_size`` is given.
      5. The rows are split with ``stratified_split``.

    Returns
    -------
    tuple
        ``(x_train, y_train, x_test, y_test)``. These are pandas objects when
        ``return_df`` is true and NumPy arrays otherwise.

    Raises
    ------
    ValueError
        If downsampling selects no rows (other errors propagate from the
        helpers).
    """
    transactions = pd.read_csv(transaction_csv_path, index_col="TransactionID")
    identity = pd.read_csv(identity_csv_path, index_col="TransactionID")

    merged = transactions.join(identity, how="left").dropna(subset=["isFraud"])
    y = merged["isFraud"].astype(int)
    X = merged.drop(columns=["isFraud"])

    # Keep only columns that are less than half missing.
    mostly_present = X.isna().mean() < 0.5
    X = X.loc[:, mostly_present]

    wants_sample = sample_frac is not None or sample_size is not None
    if wants_sample:
        chosen = downsample_preserving_characteristics(
            X, y,
            frac=sample_frac,
            target_size=sample_size,
            cat_cols=cat_cols,
            amt_col=amt_col,
            n_amt_bins=n_amt_bins,
            n_miss_bins=n_miss_bins,
            min_per_group=1,
        )
        if len(chosen) == 0:
            raise ValueError("Downsampling produced an empty selection; try reducing bin counts or increasing sample size.")
        X, y = X.loc[chosen], y.loc[chosen]

    train_pos, test_pos = stratified_split(
        X, y,
        test_size=test_size,
        cat_cols=cat_cols,
        amt_col=amt_col,
        n_amt_bins=n_amt_bins,
        n_miss_bins=n_miss_bins,
        min_count=min_count,
    )

    splits = (X.iloc[train_pos], y.iloc[train_pos], X.iloc[test_pos], y.iloc[test_pos])
    if return_df:
        return splits
    return tuple(part.to_numpy() for part in splits)
        
def compute_ece(y_true, p_hat, n_bins=15, eps=1e-12):
    """Return the Expected Calibration Error of positive-class probabilities.

    Predictions are bucketed into ``n_bins`` equal-width bins over [0, 1].
    For each non-empty bin, |observed positive rate - mean predicted
    probability| is weighted by the fraction of samples in that bin, and the
    weighted terms are summed.

    Parameters
    ----------
    y_true : array-like
        Binary labels (cast to int).
    p_hat : array-like
        Predicted probability of the positive class, same length as ``y_true``.
    n_bins : int
        Number of equal-width bins.
    eps : float
        Unused; kept for backward compatibility of the signature.

    Returns
    -------
    float
        ECE in [0, 1]; 0.0 for empty input. NaN probabilities make the result
        NaN.
    """
    y_true = np.asarray(y_true).ravel().astype(int)
    p_hat = np.asarray(p_hat, dtype=float).ravel()
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize maps p == 1.0 to index n_bins (one past the last bin). Clip
    # so fully confident predictions land in the last bin instead of being
    # silently dropped; this also clamps any value < 0 into the first bin.
    inds = np.clip(np.digitize(p_hat, bins) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = inds == b
        if not mask.any():
            continue
        conf = p_hat[mask].mean()
        acc = y_true[mask].mean()
        ece += mask.mean() * abs(acc - conf)
    return float(ece)


def _to_dense(X):
    """Return ``X`` as a dense NumPy array (``.toarray()`` for sparse input).

    Defined at module level instead of as a lambda so that pipelines using it
    can be pickled.
    """
    return X.toarray() if hasattr(X, "toarray") else np.asarray(X)


def train_models(x_train, y_train):
    """Fit four baseline classifiers and return them keyed by name.

    Columns of ``object``/``category`` dtype are one-hot encoded. All other
    columns are treated as numeric.

    Per-model preprocessing:
      * Logistic and MLP: numeric columns are median-imputed and standardised;
        categoricals are mode-imputed before encoding.
      * XGB and HGB: numeric columns are passed through unchanged (NaNs kept
        for the tree models); categoricals are one-hot encoded.
      * ``min_frequency=10`` is applied to the one-hot encoder on
        scikit-learn >= 1.1.

    Every pipeline receives its own clone of the preprocessing steps, so
    fitting or refitting one model never mutates another.

    Parameters
    ----------
    x_train : pandas.DataFrame
        Training features (``select_dtypes`` is used, so a DataFrame is
        required).
    y_train : array-like
        Training labels, expected to be 0/1. The negative/positive ratio sets
        XGBoost's ``scale_pos_weight``.

    Returns
    -------
    dict[str, sklearn.pipeline.Pipeline]
        Fitted pipelines under ``"Logistic"``, ``"MLP"``, ``"XGB"``, ``"HGB"``.
    """
    from sklearn.base import clone

    neg = int((y_train == 0).sum())
    pos = int((y_train == 1).sum())
    # Negative/positive ratio (at least 1) to up-weight the minority class in XGB.
    xgb_pos_w = max(1.0, neg / max(1, pos))

    # --- Column splits ---
    categorical_cols = x_train.select_dtypes(include=["object", "category"]).columns.tolist()
    numeric_cols = x_train.columns.difference(categorical_cols).tolist()

    # --- Version-safe OHE kwargs (``sparse`` was renamed in sklearn 1.2) ---
    sk_version = version.parse(sklearn.__version__)
    ohe_params = dict(handle_unknown="ignore")
    if sk_version >= version.parse("1.2"):
        ohe_params["sparse_output"] = True
    else:
        ohe_params["sparse"] = True
    if sk_version >= version.parse("1.1"):
        ohe_params["min_frequency"] = 10

    # --- Preprocessing templates (cloned per pipeline below) ---
    num_dense = Pipeline([
        ("imp", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ])
    cat_dense = Pipeline([
        ("imp", SimpleImputer(strategy="most_frequent")),
        ("ohe", OneHotEncoder(**ohe_params)),
    ])
    preproc_dense_sparse_out = ColumnTransformer(
        [("num", num_dense, numeric_cols),
         ("cat", cat_dense, categorical_cols)]
    )
    preproc_tree = ColumnTransformer(
        [("num", "passthrough", numeric_cols),
         ("cat", OneHotEncoder(**ohe_params), categorical_cols)]
    )
    to_dense = FunctionTransformer(_to_dense, accept_sparse=True)

    # --- Models ---
    models = {
        # L1-penalised logistic regression; no class re-weighting here
        # (unlike XGB below, which uses scale_pos_weight).
        "Logistic": Pipeline([
            ("preproc", clone(preproc_dense_sparse_out)),
            ("clf", LogisticRegression(
                solver="saga",
                penalty="l1",
                C=0.05,
                max_iter=4000, tol=1e-3, n_jobs=-1,
            )),
        ]),

        # Two-layer ReLU MLP with L2 penalty and early stopping; needs dense input.
        "MLP": Pipeline([
            ("preproc", clone(preproc_dense_sparse_out)),
            ("densify", clone(to_dense)),
            ("clf", MLPClassifier(
                hidden_layer_sizes=(384, 192),
                activation="relu",
                alpha=3e-3,
                batch_size=512,
                learning_rate="adaptive",
                learning_rate_init=1e-3,
                early_stopping=True,
                validation_fraction=0.12,
                n_iter_no_change=12,
                max_iter=300,
                tol=1e-4,
            )),
        ]),

        # Histogram-based gradient-boosted trees (gbtree booster) with
        # scale_pos_weight to counter class imbalance.
        "XGB": Pipeline([
            ("preproc", clone(preproc_tree)),
            ("clf", xgb.XGBClassifier(
                booster="gbtree",
                tree_method="hist",
                max_bin=256,
                n_estimators=300,
                max_depth=5,
                learning_rate=0.1,
                subsample=0.8,
                colsample_bytree=0.7,
                scale_pos_weight=xgb_pos_w,
                missing=np.nan,
                n_jobs=-1,
                eval_metric="logloss",
                reg_lambda=1.0,
            )),
        ]),

        # sklearn HistGradientBoosting on densified tree features.
        "HGB": Pipeline([
            ("preproc", clone(preproc_tree)),
            ("densify", clone(to_dense)),
            ("clf", HistGradientBoostingClassifier(
                max_iter=350,
                max_depth=4,
                learning_rate=0.07,
                l2_regularization=0.5,
                early_stopping=True,
            )),
        ]),
    }

    # --- Fit all models ---
    fitted = {}
    for name, pipe in models.items():
        print(f"--- Fitting {name} ---")
        pipe.fit(x_train, y_train)
        fitted[name] = pipe

    return fitted


def evaluate_models(fitted: dict, x_test, y_test, n_bins: int = 15) -> pd.DataFrame:
    """Score every fitted model on a held-out set.

    Column 1 of ``predict_proba`` is taken as the positive-class probability.
    Hard predictions use a 0.5 threshold for accuracy, and the probabilities
    feed ``compute_ece`` with ``n_bins`` bins.

    Returns
    -------
    pandas.DataFrame
        Columns ``model``, ``Accuracy``, ``ECE``, one row per model, sorted by
        accuracy (best first) with a fresh 0..n-1 index.
    """
    def _score(name, pipe):
        p_pos = pipe.predict_proba(x_test)[:, 1]
        hard = (p_pos >= 0.5).astype(int)
        return {
            "model": name,
            "Accuracy": accuracy_score(y_test, hard),
            "ECE": compute_ece(y_test, p_pos, n_bins=n_bins),
        }

    table = pd.DataFrame([_score(name, pipe) for name, pipe in fitted.items()])
    table = table.sort_values("Accuracy", ascending=False)
    return table.reset_index(drop=True)

def format_results(results: pd.DataFrame) -> pd.DataFrame:
    """Return the ``Accuracy`` and ``ECE`` columns of *results*, indexed by ``model``.

    Any other columns are dropped; the input frame is left unmodified.
    """
    wanted = ["model", "Accuracy", "ECE"]
    return results.loc[:, wanted].set_index("model")
    
def plot_confidence_error_bars(methods, y_true, n_bins=10, min_conf=0.25):
    """Draw a grouped bar chart of error rate versus confidence for each method.

    Confidence is ``|p - 0.5|``, where ``p`` is column 1 of each method's
    probability array. The range [0, 0.5] is cut into ``n_bins`` equal-width
    bins, and only bins whose lower edge is >= ``min_conf`` are drawn. Bins
    are half-open ``[lo, hi)`` except the last, which is closed, so
    predictions with ``p`` exactly 0 or 1 are counted. Empty bins are skipped.

    Parameters
    ----------
    methods : dict
        Method name -> array-like of class probabilities (assumed shape
        ``(n_samples, 2)``; only column 1 is read). Known names are drawn in
        a fixed order with fixed colours; if none are known, dict order is
        used.
    y_true : array-like or pandas Series
        Binary labels aligned with every probability array.
    n_bins : int
        Number of confidence bins over [0, 0.5].
    min_conf : float
        Lower-edge cutoff for displayed bins (default 0.25, the previous
        hard-coded value).

    Returns
    -------
    matplotlib.figure.Figure
    """
    # ``plt`` is not imported at the top of this file; import locally so the
    # function does not depend on a name defined elsewhere.
    import matplotlib.pyplot as plt
    from matplotlib.patches import Patch

    y_arr = y_true.to_numpy() if hasattr(y_true, "to_numpy") else np.asarray(y_true)

    method_color = {
        "DLA":   "#1f77b4",  # blue
        "IABMA": "#00441b",  # deep green
        "MoE":   "#ff7f0e",  # orange
        "SMC":   "#9467bd",  # purple
        "BHS":   "#d62728",  # red
    }
    fallback_color = "#7f7f7f"
    preferred = ["MoE", "DLA", "SMC", "BHS", "IABMA"]
    meth_order = [m for m in preferred if m in methods] or list(methods.keys())

    edges = np.linspace(0, 0.5, n_bins + 1)
    keep_mask = edges[:-1] >= min_conf
    bins_lo = edges[:-1][keep_mask]
    bins_hi = edges[1:][keep_mask]
    x = np.arange(len(bins_lo))

    # Per-method mean error in every bin (NaN where the bin is empty).
    bin_err_map = {}
    for name in meth_order:
        p1 = np.asarray(methods[name])[:, 1]
        conf = np.abs(p1 - 0.5)
        err = ((p1 >= 0.5).astype(int) != y_arr).astype(float)

        errs_all = np.full(n_bins, np.nan)
        for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
            # Close the last bin so conf == 0.5 (p exactly 0 or 1) is not dropped.
            upper = (conf <= hi) if b == n_bins - 1 else (conf < hi)
            m = (conf >= lo) & upper
            if m.any():
                errs_all[b] = err[m].mean()
        bin_err_map[name] = errs_all[keep_mask]

    k = len(meth_order)
    width = 0.85 / max(1, k)
    fig, ax = plt.subplots(figsize=(10, 5.2))

    for j, name in enumerate(meth_order):
        heights = bin_err_map[name]
        pos = x - 0.5 * (k - 1) * width + j * width
        finite = np.isfinite(heights)  # skip empty bins
        ax.bar(
            pos[finite], heights[finite], width=width,
            color=method_color.get(name, fallback_color),
            edgecolor="white", linewidth=0.5,
            hatch=("//" if name == "IABMA" else None),
            alpha=1.0 if name == "IABMA" else 0.75,
        )

    ax.set_xticks(x)
    ax.set_xticklabels([f"{lo:.2f}-{hi:.2f}" for lo, hi in zip(bins_lo, bins_hi)], rotation=0)
    ax.set_xlabel("Confidence |p - 0.5|")
    ax.set_ylabel("Average error rate in bin")
    ax.set_ylim(0, 0.8)
    ax.grid(True, axis="y", alpha=0.25)

    # Proxy artists for the legend (avoids drawing extra zero-height bars).
    handles = [
        Patch(
            facecolor=method_color.get(name, fallback_color),
            edgecolor="white", linewidth=0.5,
            hatch=("//" if name == "IABMA" else None),
        )
        for name in meth_order
    ]
    fig.legend(handles, list(meth_order),
               loc="lower center", ncol=max(1, min(5, k)), frameon=True,
               bbox_to_anchor=(0.5, -0.02))
    plt.tight_layout(rect=(0, 0.06, 1, 1))
    plt.show()
    return fig
