import os
from datetime import datetime
from zoneinfo import ZoneInfo
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

def adapt_features(X, fitted_models, task="binary"):
    """Return ``X`` as float32 features, falling back to model outputs.

    If ``X`` converts directly to a float32 NumPy array, that array is
    returned. Otherwise each model in ``fitted_models.values()`` is applied
    to ``X`` and its outputs become the features:

    * ``task == "binary"``: ``model.predict_proba(X)`` per model;
    * any other task: ``model.predict(X)`` flattened to 1-D per model.

    The per-model outputs are stacked along axis 1. The result is cast to
    float32 and reshaped to 2-D ``(n_rows, -1)``, so multi-column outputs,
    such as class probabilities, are laid out side by side per model.

    Args:
        X: Raw features or inputs accepted by the fitted models.
        fitted_models: Mapping whose values expose ``predict_proba`` / ``predict``.
        task: ``"binary"`` selects ``predict_proba``; anything else ``predict``.

    Returns:
        A 2-D float32 array in the fallback case; otherwise ``X`` as float32.
    """
    try:
        return np.asarray(X, dtype=np.float32)
    except Exception:
        # X is not directly numeric: use the models' outputs as meta-features.
        models = fitted_models.values()
        if task == "binary":
            per_model = [model.predict_proba(X) for model in models]
        else:
            per_model = [np.asarray(model.predict(X)).ravel() for model in models]
        stacked = np.stack(per_model, axis=1).astype(np.float32)
        n_rows = stacked.shape[0]
        return stacked.reshape(n_rows, -1)
        
        
def log_run(
    results,
    *,
    csv_path: str = "results_log.csv",
    methods=None,
    all_metrics=None,
    method_times: dict | None = None,
    timezone: str = "America/New_York",
):

    import os
    import numpy as np
    import pandas as pd
    from datetime import datetime
    try:
        from zoneinfo import ZoneInfo
        tz = ZoneInfo(timezone)
        now_str = datetime.now(tz).isoformat(timespec="seconds")
    except Exception:
        tz = None
        now_str = datetime.now().isoformat(timespec="seconds")

    # Defaults
    if methods is None:
        methods = ["Best-single","Uniform Avg.","Freq Avg.","BMA","MoE","DLA","SMC","BHS","IABMA"]
    if all_metrics is None:
        all_metrics = ["Accuracy","ECE","R2","RMSE"]
    time_cols = ["moe_time","dla_time","smc_time","bhs_time","iabma_time"]

    import inspect
    caller_g = {}
    try:
        caller_g = inspect.currentframe().f_back.f_globals or {}
    except Exception:
        pass

    EXPERIMENT = caller_g.get("EXPERIMENT", "unknown")
    UCI_DATASET = caller_g.get("UCI_DATASET")
    TASK = caller_g.get("TASK", None)

    try:
        exp_label = f"uci-{UCI_DATASET}" if (EXPERIMENT == "UCI" and UCI_DATASET) else str(EXPERIMENT)
    except Exception:
        exp_label = "unknown"

    task_str = (str(TASK).lower() if TASK is not None else None)

    row = {"experiment": exp_label, "timestamp": now_str, "task": task_str}

    def _to_float(x):
        try:
            return float(x)
        except Exception:
            return np.nan

    for c in time_cols:
        row[c] = np.nan

    if caller_g:
        for c in time_cols:
            if c in caller_g:
                row[c] = _to_float(caller_g[c])

    if isinstance(method_times, dict):
        mapping = {
            "moe_time":   method_times.get("MoE",   None),
            "dla_time":   method_times.get("DLA",   None),
            "smc_time":   method_times.get("SMC",   None),
            "bhs_time":   method_times.get("BHS",   None),
            "iabma_time": method_times.get("IABMA", None),
        }
        for k, v in mapping.items():
            if v is not None:
                row[k] = _to_float(v)

    for m in methods:
        vals = (results.get(m, {}) or {})
        for met in all_metrics:
            row[f"{m}__{met}"] = vals.get(met, np.nan)

    fixed_cols = (
        ["experiment", "timestamp", "task"]
        + time_cols
        + [f"{m}__{met}" for m in methods for met in all_metrics]
    )

    new_row_df = pd.DataFrame([row])

    if os.path.exists(csv_path):
        old = pd.read_csv(csv_path)
        all_cols = list(dict.fromkeys(list(old.columns) + fixed_cols))
        old = old.reindex(columns=all_cols)
        new_row_df = new_row_df.reindex(columns=all_cols, fill_value=np.nan)
        out = pd.concat([old, new_row_df], ignore_index=True)
    else:
        new_row_df = new_row_df.reindex(columns=fixed_cols, fill_value=np.nan)
        out = new_row_df

    out.to_csv(csv_path, index=False)

    print(f"Appended run to {csv_path}. Schema fixed to: {len(fixed_cols)} columns.")
     
     
def _confidence_bin_stats(p1, y_arr, edges):
    """Per-bin error statistics of 0.5-thresholded predictions.

    Confidence is ``|p1 - 0.5|``. A sample counts as an error when
    ``(p1 >= 0.5)`` (as 0/1) differs from its label. Bins are half-open
    ``[lo, hi)``, except the last, which is closed ``[lo, hi]``. This way a
    confidence exactly equal to ``edges[-1]`` is counted: with
    ``edges[-1] == 0.5`` those are the samples with ``p1`` exactly 0 or 1.

    Returns:
        ``(means, sds, counts)``, each of length ``len(edges) - 1``.
        ``means`` / ``sds`` are NaN for empty bins. ``sds`` is the
        population (ddof=0) standard deviation of the 0/1 errors.
    """
    conf = np.abs(p1 - 0.5)
    err = ((p1 >= 0.5).astype(int) != y_arr).astype(float)
    n_bins = len(edges) - 1
    means = np.full(n_bins, np.nan)
    sds = np.full(n_bins, np.nan)
    counts = np.zeros(n_bins, dtype=int)
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        below_hi = (conf <= hi) if b == n_bins - 1 else (conf < hi)
        in_bin = (conf >= lo) & below_hi
        counts[b] = int(in_bin.sum())
        if counts[b] > 0:
            e = err[in_bin]
            means[b] = e.mean()
            sds[b] = e.std(ddof=0)
    return means, sds, counts


def plot_confidence_error_bars(methods, y_true, n_bins=10, min_conf=0.25):
    """Bar chart of hard-prediction error rate per confidence bin, per method.

    For each method, predictions are thresholded at 0.5. Confidence
    ``|p - 0.5|`` is split into ``n_bins`` equal-width bins over ``[0, 0.5]``,
    and only bins whose lower edge is ``>= min_conf`` are drawn. Each bin is
    shown as a cluster of bars, one per method:

    * bar height is the error rate in that bin;
    * the error bar is the standard error ``SD / sqrt(n)``, drawn only when
      the bin holds at least 2 samples;
    * empty bins are left blank.

    Args:
        methods: Mapping ``name -> array-like`` of predicted probabilities.
            Column 1 is used as the positive-class probability (presumably
            shape ``(n_samples, 2)``; confirm against callers). Names found
            in ``["MoE", "DLA", "SMC", "BHS", "IABMA"]`` are plotted in that
            order. If none of them is present, every key is plotted in
            mapping order.
        y_true: Binary labels aligned with the probability rows (array-like
            or any object with ``to_numpy()``).
        n_bins: Number of equal-width confidence bins over ``[0, 0.5]``.
        min_conf: Smallest bin lower edge to display.

    Returns:
        The matplotlib Figure. It is not shown here; call ``plt.show()`` or
        save it at the call site.
    """
    # Patch is only needed for the legend proxy handles.
    from matplotlib.patches import Patch

    y_arr = y_true.to_numpy() if hasattr(y_true, "to_numpy") else np.asarray(y_true)

    # palette + hatch (your spec)
    method_color = {
        "DLA":   "#1f77b4",  # blue
        "IABMA": "#00441b",  # deep green (hatched)
        "MoE":   "#ff7f0e",  # orange
        "SMC":   "#9467bd",  # purple
        "BHS":   "#d62728",  # red
    }
    preferred = ["MoE", "DLA", "SMC", "BHS", "IABMA"]
    meth_order = [m for m in preferred if m in methods] or list(methods.keys())

    def _bar_style(name):
        """(face colour, hatch, alpha) shared by the bars and the legend."""
        return (
            method_color.get(name, "#7f7f7f"),
            "//" if name == "IABMA" else None,
            1.0 if name == "IABMA" else 0.75,
        )

    # Confidence |p - 0.5| lies in [0, 0.5]; split it into n_bins equal bins.
    bins = np.linspace(0, 0.5, n_bins + 1)
    # Display only the high-confidence bins (lower edge >= min_conf).
    keep_mask = bins[:-1] >= min_conf
    bins_lo = bins[:-1][keep_mask]
    bins_hi = bins[1:][keep_mask]
    x = np.arange(len(bins_lo))

    # name -> (mean error, error SD, count), restricted to displayed bins.
    stats = {}
    for name in meth_order:
        p1 = np.asarray(methods[name])[:, 1]
        means, sds, counts = _confidence_bin_stats(p1, y_arr, bins)
        stats[name] = (means[keep_mask], sds[keep_mask], counts[keep_mask])

    k = len(meth_order)
    width = 0.85 / max(1, k)  # total cluster width ~0.85
    fig, ax = plt.subplots(figsize=(10, 5.2))
    error_kw = dict(ecolor="black", capsize=3, elinewidth=0.8)

    for j, name in enumerate(meth_order):
        face, hatch, alpha = _bar_style(name)
        means, sds, counts = stats[name]
        pos = x - 0.5 * (k - 1) * width + j * width
        for xi, mu, sd, c in zip(pos, means, sds, counts):
            if not np.isfinite(mu):
                continue  # empty bin: leave a gap
            # Standard error of the bin's error rate; undefined for n < 2.
            yerr = (sd / np.sqrt(c)) if (np.isfinite(sd) and c >= 2) else None
            ax.bar(
                xi, mu, width=width,
                color=face,
                edgecolor="white", linewidth=0.5,
                hatch=hatch,
                alpha=alpha,
                yerr=yerr, error_kw=error_kw,
            )

    ax.yaxis.set_major_locator(MultipleLocator(0.1))
    ax.yaxis.set_minor_locator(MultipleLocator(0.05))
    ax.grid(True, axis="y", which="major", alpha=0.25)
    ax.grid(True, axis="y", which="minor", alpha=0.3)

    ax.set_xticks(x)
    ax.set_xticklabels([f"{lo:.2f}-{hi:.2f}" for lo, hi in zip(bins_lo, bins_hi)], rotation=0)
    ax.set_xlabel(f"Confidence |p - 0.5| (bins ≥ {min_conf:.2f})")
    ax.set_ylabel("Average error rate in bin")
    ax.set_ylim(0, 0.8)

    # Proxy patches for the legend: nothing extra is drawn on the axes, and
    # each entry uses the same colour/hatch/alpha as its bars.
    handles = []
    for name in meth_order:
        face, hatch, alpha = _bar_style(name)
        handles.append(
            Patch(facecolor=face, edgecolor="white", linewidth=0.5,
                  hatch=hatch, alpha=alpha, label=name)
        )

    fig.legend(handles, list(meth_order),
               loc="lower center", ncol=max(1, min(5, k)), frameon=True,
               bbox_to_anchor=(0.5, -0.02))
    fig.tight_layout(rect=(0, 0.06, 1, 1))  # leave bottom space for legend
    return fig


def compute_ece(y_true, p_hat, n_bins=10, eps=1e-12):
    """Expected Calibration Error of positive-class probabilities.

    ``p_hat`` is split into ``n_bins`` equal-width bins over ``[0, 1]``. The
    result is ``sum_b (n_b / N) * |mean(y_true in b) - mean(p_hat in b)|``,
    summed over non-empty bins.

    Bins are ``[lo, hi)``, except that values at or above 1.0 fall into the
    last bin. Values below 0.0 fall into the first bin, so every sample is
    counted. NaN probabilities land in the last bin and make the result NaN.

    Args:
        y_true: Binary labels (cast to int), any shape; flattened.
        p_hat: Predicted probabilities of class 1, same size as ``y_true``.
        n_bins: Number of equal-width bins.
        eps: Unused; accepted for backward compatibility.

    Returns:
        The ECE as a Python float (0.0 when there are no samples).
    """
    y_true = np.asarray(y_true).ravel().astype(int)
    p_hat = np.asarray(p_hat).ravel()
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # digitize returns n_bins + 1 for p == 1.0 (the top edge). Without the
    # clip, such samples would match no bin and silently drop out of the ECE.
    inds = np.clip(np.digitize(p_hat, bins) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = inds == b
        if not np.any(mask):
            continue
        conf = p_hat[mask].mean()
        acc = y_true[mask].mean()
        w = mask.mean()  # fraction of all samples in this bin
        ece += w * abs(acc - conf)
    return float(ece)
    
    
def downsample_majority(x, y, ratio=1.0, random_state=None):
    """Randomly undersample the negative (``y == 0``) class.

    Keeps every row with ``y == 1``. It also keeps ``int(n_pos * ratio)``
    rows with ``y == 0``, drawn without replacement. If fewer negatives exist
    than requested, all of them are kept rather than raising. Positive rows
    come first in the result, followed by the sampled negatives.

    Args:
        x: Feature DataFrame sharing ``y``'s index.
        y: Binary label Series (0/1).
        ratio: Negatives to keep per positive.
        random_state: ``None`` uses NumPy's global RNG. An int seed or a
            ``numpy.random.Generator`` gives reproducible sampling.

    Returns:
        ``(x_subset, y_subset)`` selected by index label.
    """
    idx_pos = y[y == 1].index
    idx_neg = y[y == 0].index

    # Cap at the available negatives: sampling more without replacement would
    # raise ValueError.
    n_neg = min(int(len(idx_pos) * ratio), len(idx_neg))
    if random_state is None:
        neg_keep = np.random.choice(np.asarray(idx_neg), size=n_neg, replace=False)
    else:
        rng = np.random.default_rng(random_state)
        neg_keep = rng.choice(np.asarray(idx_neg), size=n_neg, replace=False)

    idx_keep = np.concatenate([idx_pos, neg_keep])
    return x.loc[idx_keep], y.loc[idx_keep]

  
  
def zscore_numeric_df(X_train: pd.DataFrame, X_test: pd.DataFrame):
    """Standardize numeric/bool columns using training-set statistics.

    Both frames are copied first, so the inputs are not modified. The columns
    to scale are taken from ``X_train`` via
    ``select_dtypes(include=["number", "bool"])``. A ``StandardScaler`` is fit
    on those training columns, and the same transform is applied to the
    identically named columns of ``X_test``. Other columns pass through
    untouched, and if there are no numeric columns the copies are returned
    as-is.

    Returns:
        ``(X_train_scaled, X_test_scaled)`` as new DataFrames.
    """
    train_scaled = X_train.copy()
    test_scaled = X_test.copy()

    numeric_cols = train_scaled.select_dtypes(include=["number", "bool"]).columns
    if not numeric_cols.empty:
        scaler = StandardScaler().fit(train_scaled[numeric_cols])
        train_scaled[numeric_cols] = scaler.transform(train_scaled[numeric_cols])
        test_scaled[numeric_cols] = scaler.transform(test_scaled[numeric_cols])
    return train_scaled, test_scaled