import time
import numpy as np
import scipy
from scipy.optimize import fsolve


def evaluate_comprehensive(z, oracle):
    """
    Score the linear classifier stored in ``z`` on the oracle's data.

    The first ``oracle.dim`` entries of ``z`` are used as a weight vector;
    a sample is predicted positive when its score ``oracle.A @ w`` is >= 0.
    Labels ``oracle.b`` count as positive when > 0.  The protected attribute
    ``oracle.c`` is mapped to {0, 1} via ``> 0`` if it contains any negative
    value, and is otherwise used as-is (Z: unpriv=0, priv=1).

    Reported metrics:

    - accuracy, balanced_accuracy: accuracy and 0.5*(TPR+TNR).
    - dpd (Demographic Parity Difference, Feldman et al. 2015):
      |P(Ŷ=1|Z=0) - P(Ŷ=1|Z=1)|
    - eod (Equalized Odds Difference, Hardt et al. 2016):
      |TPR0 - TPR1| + |FPR0 - FPR1|, with TPRz=P(Ŷ=1|Y=1,Z=z) and
      FPRz=P(Ŷ=1|Y=0,Z=z)
    - pred_loss, adv_loss: the pair returned by ``oracle.compute_losses(z)``.

    dpd and eod are NaN when either protected group has no samples.

    Returns:
        dict with keys "accuracy", "balanced_accuracy", "dpd", "eod",
        "pred_loss" and "adv_loss", all Python floats.
    """
    weights = np.asarray(z[: oracle.dim]).reshape(-1)
    margins = (oracle.A @ weights).reshape(-1)

    # Work on boolean vectors: non-negative score -> positive prediction,
    # strictly positive label -> positive example.
    pred_pos = margins >= 0.0
    true_pos = np.asarray(oracle.b).reshape(-1) > 0

    groups = np.asarray(oracle.c).reshape(-1)
    if np.any(groups < 0):
        # Signed encoding ({-1,+1}): collapse to {0,1}.
        group01 = (groups > 0).astype(np.float64)
    else:
        group01 = groups.astype(np.float64)

    def _confusion(pred, truth, reduce):
        """Return (tp, fn, fp, tn) aggregated with ``reduce`` (sum or mean)."""
        return (
            reduce(pred & truth),
            reduce(~pred & truth),
            reduce(pred & ~truth),
            reduce(~pred & ~truth),
        )

    def _group_rates(value):
        """Return (positive rate, TPR, FPR) within one protected group."""
        members = group01 == value
        if members.sum() == 0:
            return np.nan, np.nan, np.nan
        pred_g = pred_pos[members]
        g_tp, g_fn, g_fp, g_tn = _confusion(pred_g, true_pos[members], np.mean)
        # The small constant keeps the ratios finite when a class is absent.
        return (
            np.mean(pred_g),
            g_tp / (g_tp + g_fn + 1e-8),
            g_fp / (g_fp + g_tn + 1e-8),
        )

    # Overall accuracy and balanced accuracy (from raw counts).
    acc = float(np.mean(pred_pos == true_pos))
    n_tp, n_fn, n_fp, n_tn = _confusion(pred_pos, true_pos, np.sum)
    recall_pos = n_tp / (n_tp + n_fn + 1e-8)
    recall_neg = n_tn / (n_tn + n_fp + 1e-8)
    bal_acc = float(0.5 * (recall_pos + recall_neg))

    rate_0, tpr_0, fpr_0 = _group_rates(0.0)
    rate_1, tpr_1, fpr_1 = _group_rates(1.0)

    # DPD = |P(Ŷ=1|Z=0) - P(Ŷ=1|Z=1)|
    if np.isnan(rate_0) or np.isnan(rate_1):
        dpd = float("nan")
    else:
        dpd = float(np.abs(rate_0 - rate_1))

    # EOD = |TPR0 - TPR1| + |FPR0 - FPR1| (NaN propagates for empty groups)
    eod = float(np.abs(tpr_0 - tpr_1) + np.abs(fpr_0 - fpr_1))

    pred_loss, adv_loss = oracle.compute_losses(z)

    metrics = {}
    metrics["accuracy"] = acc
    metrics["balanced_accuracy"] = bal_acc
    metrics["dpd"] = dpd
    metrics["eod"] = eod
    metrics["pred_loss"] = float(pred_loss)
    metrics["adv_loss"] = float(adv_loss)
    return metrics



def _log_eval(name, rec, cfg):
    """Print a one-line summary of an evaluation record.

    Args:
        name: Label printed at the start of the line (e.g. the method name).
        rec: Mapping that may contain "iter", "time", "accuracy",
            "balanced_accuracy", "dpd" and "eod"; missing metrics print as nan.
        cfg: Accepted for signature compatibility; not read here.
    """
    missing = float("nan")
    elapsed, accuracy, balanced, dp_gap, eo_gap = (
        float(rec.get(key, missing))
        for key in ("time", "accuracy", "balanced_accuracy", "dpd", "eod")
    )
    metrics = (
        f"time={elapsed:.2f}s, acc={accuracy:.3f}, bal_acc={balanced:.3f}, "
        f"DPD={dp_gap:.3f}, EOD={eo_gap:.3f}"
    )

    # The step counter is only shown when the record carries one.
    step = rec.get("iter", None)
    header = f"{name} Eval: "
    if step is not None:
        header += f"step={int(step)}, "
    print(header + metrics)


def _maybe_eval(oracle, z, ellapse_time, eval_interval, next_eval_time, iter_idx):
    """Evaluate ``z`` if the next scheduled evaluation time has been reached.

    Args:
        oracle: Problem object forwarded to ``evaluate_comprehensive``.
        z: Current iterate forwarded to ``evaluate_comprehensive``.
        ellapse_time: Accumulated training time so far.
        eval_interval: Spacing between evaluations; ``None`` or a
            non-positive value disables periodic evaluation.
        next_eval_time: Elapsed time at which the next evaluation is due,
            or ``None`` if nothing has been scheduled yet.
        iter_idx: Iteration counter stored in the record under "iter".

    Returns:
        ``(next_eval_time, rec)``.  ``rec`` is ``None`` when no evaluation
        ran (and ``next_eval_time`` is returned unchanged); otherwise it is
        the metrics dict extended with "time" and "iter", and the returned
        ``next_eval_time`` is advanced by ``eval_interval``.
    """
    if eval_interval is None or eval_interval <= 0:
        return next_eval_time, None

    if next_eval_time is None:
        # Nothing scheduled yet: evaluate now and schedule the next one
        # relative to the current time (previously this path crashed on
        # ``None + eval_interval``).
        next_eval_time = ellapse_time
    elif ellapse_time < next_eval_time:
        return next_eval_time, None

    rec = {"time": ellapse_time, "iter": iter_idx, **evaluate_comprehensive(z, oracle)}
    return next_eval_time + eval_interval, rec


def _fmt_eval(rec, digits=4):
    """Return a copy of ``rec`` with values shortened for display.

    Floats are rounded to ``digits`` significant digits (and stay floats),
    NumPy arrays are rendered to strings with ``digits`` of precision, and
    every other value is passed through untouched.
    """
    compact = {}
    for key, value in rec.items():
        if isinstance(value, float):
            # Round-trip through the "g" format to keep significant digits.
            value = float(format(value, f".{digits}g"))
        elif isinstance(value, np.ndarray):
            value = np.array2string(value, precision=digits, suppress_small=True)
        compact[key] = value
    return compact


def EG(oracle, z0, cfg):
    """Run the extragradient (EG) iteration on ``oracle`` for a time budget.

    Each iteration takes a look-ahead step ``z_half = z - eta * F(z)`` and
    then updates ``z <- z - eta * F(z_half)``, where ``F`` is
    ``oracle.GDA_field``.  Only the time spent inside these update steps is
    accumulated; evaluation and logging are not counted.

    Args:
        oracle: Problem object providing ``GDA_field`` and whatever
            ``evaluate_comprehensive`` reads.
        z0: Starting point; it is copied and not modified.
        cfg: Configuration with ``eta`` (step size), ``training_time``
            (budget, in seconds of accumulated update time) and, optionally,
            ``eval_interval`` (spacing between periodic evaluations).

    Returns:
        ``(time_lst, gnorm_lst, eval_records)``: accumulated time after each
        recorded iteration, the norm of ``F(z_half)`` for that iteration, and
        the list of evaluation records (the first one is taken at ``z0``).
        The iteration that crosses the budget still updates ``z`` but is not
        appended to the first two lists.
    """
    print("Running EG...")
    time_lst_EG = []
    gnorm_lst_EG = []
    eval_records = []
    ellapse_time = 0.0
    z = z0.copy()
    i = 0

    eta = cfg.eta
    max_time = cfg.training_time
    eval_interval = getattr(cfg, "eval_interval", None)
    # Periodic evaluation is scheduled only for a positive interval.
    next_eval_time = eval_interval if eval_interval and eval_interval > 0 else None

    # Initial evaluation (always done once, before any update).
    init_eval = {"time": ellapse_time, "iter": 0, **evaluate_comprehensive(z, oracle)}
    eval_records.append(init_eval)
    _log_eval("EG", init_eval, cfg)

    while True:
        # perf_counter is monotonic and high-resolution, so the measured
        # durations cannot be distorted by system clock adjustments.
        start = time.perf_counter()
        gz = oracle.GDA_field(z)
        z_half = z - eta * gz
        gz_half = oracle.GDA_field(z_half)
        z = z - eta * gz_half
        ellapse_time += time.perf_counter() - start

        if ellapse_time > max_time:
            break
        time_lst_EG.append(ellapse_time)
        gnorm_lst_EG.append(np.linalg.norm(gz_half).item())

        next_eval_time, eval_result = _maybe_eval(oracle, z, ellapse_time, eval_interval, next_eval_time, i)
        if eval_result is not None:
            eval_records.append(eval_result)
            _log_eval("EG", eval_result, cfg)

        i = i + 1

    # Final evaluation: like the in-loop ones, it only fires when the
    # elapsed time has reached the next scheduled evaluation time.
    next_eval_time, eval_result = _maybe_eval(oracle, z, ellapse_time, eval_interval, next_eval_time, i)
    if eval_result is not None:
        eval_records.append(eval_result)
        _log_eval("EG", eval_result, cfg)

    return time_lst_EG, gnorm_lst_EG, eval_records


def run_LEN(oracle, z0, cfg, m):
    """Run the LEN iteration (Jacobian refreshed every ``m`` steps).

    Per iteration:

    1. Evaluate ``gz = oracle.GDA_field(z)``.
    2. Every ``m`` iterations, recompute the Jacobian ``Hz`` and its complex
       Schur form ``Hz = Q U Q^H`` (``U`` upper triangular, ``Q`` unitary);
       in between, the last decomposition is reused.
    3. Solve the scalar equation
       ``gamma = 3 / (16 * m * rho * ||(Hz + I/gamma)^{-1} gz||)`` with
       ``fsolve``, warm-started at the previous ``gamma``.  The inverse is
       applied through a triangular solve with ``I/gamma + U``.
    4. Take the look-ahead step ``z_half = z - Re((Hz + I/gamma)^{-1} gz)``
       and update ``z <- z - gamma * oracle.GDA_field(z_half)``.

    Only the time spent in these steps is accumulated; evaluation and
    logging are not counted.

    Args:
        oracle: Problem object providing ``GDA_field``,
            ``Jacobian_GDA_field``, ``d`` (size of the identity added to the
            Schur factor) and whatever ``evaluate_comprehensive`` reads.
        z0: Starting point; it is copied and not modified.
        cfg: Configuration with ``rho``, ``training_time``, ``m`` and,
            optionally, ``eval_interval``.
        m: Number of iterations between Jacobian refreshes.  Log lines are
            labelled "LEN" when ``m == cfg.m`` and "NPE" otherwise.

    Returns:
        ``(time_lst, gnorm_lst, eval_records)``: accumulated time after each
        recorded iteration, the norm of the field at ``z_half`` (as a Python
        float) for that iteration, and the list of evaluation records (the
        first one is taken at ``z0``).  The iteration that crosses the
        budget still updates ``z`` but is not appended to the first two
        lists.
    """
    print(f"Running LEN (m={m})...")
    label = "LEN" if m == cfg.m else "NPE"
    time_lst_LEN = []
    gnorm_lst_LEN = []
    eval_records = []
    ellapse_time = 0.0

    z = z0.copy()
    i = 0
    rho = cfg.rho
    max_time = cfg.training_time
    eval_interval = getattr(cfg, "eval_interval", None)
    # Periodic evaluation is scheduled only for a positive interval.
    next_eval_time = eval_interval if eval_interval and eval_interval > 0 else None

    # Initial evaluation (always done once, before any update).
    init_eval = {"time": ellapse_time, "iter": 0, **evaluate_comprehensive(z, oracle)}
    eval_records.append(init_eval)
    _log_eval(label, init_eval, cfg)

    gamma = 1 / (m * rho)
    # Built once: the identity does not change between iterations.
    identity = np.eye(oracle.d)

    while True:
        # perf_counter is monotonic and high-resolution, so the measured
        # durations cannot be distorted by system clock adjustments.
        start = time.perf_counter()
        gz = oracle.GDA_field(z)

        if i % m == 0:
            # Lazy Jacobian update: refresh the Schur factors every m steps.
            Hz = oracle.Jacobian_GDA_field(z)
            U, Q = scipy.linalg.schur(Hz, output="complex")

        # gz expressed in the Schur basis; independent of gamma, so it is
        # computed once per iteration instead of once per fsolve callback.
        rotated_gz = Q.conj().T @ gz

        def _regularized_step(g_val):
            # (Hz + I/g)^{-1} gz via one upper-triangular solve.
            return Q @ scipy.linalg.solve_triangular(1 / g_val * identity + U, rotated_gz)

        def func(g_val):
            return g_val - 3 / (16 * m * rho * np.linalg.norm(_regularized_step(g_val)))

        gamma_sol = fsolve(func, gamma)
        gamma = gamma_sol[0] if isinstance(gamma_sol, (np.ndarray, list)) else gamma_sol

        # The Schur form is complex; only the real part is used for the step.
        z_half = z - _regularized_step(gamma).real

        gz_half = oracle.GDA_field(z_half)
        z = z - gamma * gz_half

        ellapse_time += time.perf_counter() - start
        if ellapse_time > max_time:
            break
        time_lst_LEN.append(ellapse_time)
        # .item() stores a plain Python float, matching the other drivers.
        gnorm_lst_LEN.append(np.linalg.norm(gz_half).item())

        next_eval_time, eval_result = _maybe_eval(oracle, z, ellapse_time, eval_interval, next_eval_time, i)
        if eval_result is not None:
            eval_records.append(eval_result)
            _log_eval(label, eval_result, cfg)

        i = i + 1

    # Final evaluation: like the in-loop ones, it only fires when the
    # elapsed time has reached the next scheduled evaluation time.
    next_eval_time, eval_result = _maybe_eval(oracle, z, ellapse_time, eval_interval, next_eval_time, i)
    if eval_result is not None:
        eval_records.append(eval_result)
        _log_eval(label, eval_result, cfg)

    return time_lst_LEN, gnorm_lst_LEN, eval_records


def spaco(oracle, z0, cfg):
    """Run SPACO: alternating stochastic y/x updates with a growing penalty.

    Per iteration ``i`` (0-based), with ``k = i + 1``:

    - ``storm_eta = max(storm_eta0 * k**(-pilot_s), storm_eta_min)``
    - ``rho   = cfg.penalty * k**pilot_t``
    - ``eta_x = cfg.eta * k**(-6 * pilot_t - pilot_s)``
    - ``eta_y = cfg.eta * k**(-pilot_t - pilot_s)``

    One mini-batch is drawn and shared by all gradient calls of the
    iteration.  The y block ``z[oracle.dim:]`` is updated first with
    ``grad_y_stochastic_penalty``; the x block ``z[:oracle.dim]`` is then
    updated along a direction ``d`` that is either the plain stochastic
    gradient or, when ``use_storm`` is on, the STORM recursive estimate
    ``d = g(z) + (1 - storm_eta) * (d_prev - g(z_prev))`` with both
    gradients taken on the current mini-batch.

    Only the time spent in these steps is accumulated; evaluation and
    logging are not counted.

    Args:
        oracle: Problem object providing ``dim``, ``m``, ``_sample_indices``,
            ``grad_x_stochastic_penalty``, ``grad_y_stochastic_penalty`` and
            whatever ``evaluate_comprehensive`` reads.
        z0: Starting point; it is copied and not modified.
        cfg: Configuration with ``eta``, ``training_time`` and ``penalty``.
            Optional attributes (default): ``eval_interval`` (None),
            ``batch_size`` (``min(oracle.m, 256)``), ``penalty_eps`` (0.0),
            ``use_storm`` (True), ``storm_eta0`` (1.0), ``storm_eta_min``
            (0.1), ``pilot_t`` (0.05), ``pilot_s`` (0.20).

    Returns:
        ``(time_lst, gnorm_lst, eval_records)``: accumulated time after each
        recorded iteration, ``sqrt(||d||^2 + ||gy||^2)`` for that iteration,
        and the list of evaluation records (the first one is taken at
        ``z0``).  The iteration that crosses the budget still updates ``z``
        but is not appended to the first two lists.
    """
    print("Running SPACO (alternating stochastic updates, penalty)...")
    time_lst = []
    gnorm_lst = []
    eval_records = []
    ellapse_time = 0.0

    z = z0.copy()
    i = 0
    eta = cfg.eta
    max_time = cfg.training_time
    eval_interval = getattr(cfg, "eval_interval", None)
    # Periodic evaluation is scheduled only for a positive interval.
    next_eval_time = eval_interval if eval_interval and eval_interval > 0 else None
    batch_size = getattr(cfg, "batch_size", None) or min(oracle.m, 256)
    penalty_eps = getattr(cfg, "penalty_eps", 0.0)

    rho0 = cfg.penalty

    use_storm = getattr(cfg, "use_storm", True)
    storm_eta0 = getattr(cfg, "storm_eta0", 1.0)
    storm_eta_min = getattr(cfg, "storm_eta_min", 0.1)
    pilot_t = getattr(cfg, "pilot_t", 0.05)
    pilot_s = getattr(cfg, "pilot_s", 0.20)

    # STORM state: previous direction estimate and the point it was taken at.
    storm_buffer_x = None
    prev_z_x = None

    # Initial evaluation (always done once, before any update).
    init_eval = {"time": ellapse_time, "iter": 0, **evaluate_comprehensive(z, oracle)}
    eval_records.append(init_eval)
    _log_eval("SPACO", init_eval, cfg)

    while True:
        # perf_counter is monotonic and high-resolution, so the measured
        # durations cannot be distorted by system clock adjustments.
        start = time.perf_counter()

        idx_batch = oracle._sample_indices(batch_size)

        # Iteration-dependent schedules.
        step = i + 1
        storm_eta = max(storm_eta0 * (step ** (-pilot_s)), storm_eta_min)
        rho = rho0 * (step ** pilot_t)
        eta_x = eta * step ** (-6 * pilot_t - pilot_s)
        eta_y = eta * step ** (-pilot_t - pilot_s)

        # y update with the current penalty weight.
        # NOTE(review): the original comment says the ascent sign is already
        # embedded in the oracle's gradient -- confirm against the oracle.
        gy = oracle.grad_y_stochastic_penalty(z, rho=rho, eps=penalty_eps, batch_size=batch_size, idx=idx_batch)
        z[oracle.dim:] = z[oracle.dim:] - eta_y * gy

        # x update (descent) with the penalty weight, optionally with STORM.
        gx_curr = oracle.grad_x_stochastic_penalty(z, rho=rho, eps=penalty_eps, batch_size=batch_size, idx=idx_batch)

        if use_storm and storm_buffer_x is not None and prev_z_x is not None:
            # Re-evaluate the previous point on the *current* mini-batch so
            # the correction term cancels the shared sampling noise.
            gx_prev = oracle.grad_x_stochastic_penalty(prev_z_x, rho=rho, eps=penalty_eps, batch_size=batch_size, idx=idx_batch)
            d_new = gx_curr + (1.0 - storm_eta) * (storm_buffer_x - gx_prev)
        else:
            # First iteration, or STORM disabled: plain stochastic gradient.
            d_new = gx_curr

        if use_storm:
            # Must run on every iteration (including the first); otherwise
            # the buffer stays None and the momentum branch is never taken.
            # The snapshot is taken before the x step so it is the point at
            # which gx_curr was evaluated.
            storm_buffer_x = d_new
            prev_z_x = z.copy()

        z[: oracle.dim] = z[: oracle.dim] - eta_x * d_new

        ellapse_time += time.perf_counter() - start
        if ellapse_time > max_time:
            break

        time_lst.append(ellapse_time)
        gnorm = np.sqrt(np.linalg.norm(d_new) ** 2 + np.linalg.norm(gy) ** 2).item()
        gnorm_lst.append(gnorm)

        next_eval_time, eval_result = _maybe_eval(oracle, z, ellapse_time, eval_interval, next_eval_time, i)
        if eval_result is not None:
            eval_records.append(eval_result)
            _log_eval("SPACO", eval_result, cfg)

        i += 1

    # Final evaluation: like the in-loop ones, it only fires when the
    # elapsed time has reached the next scheduled evaluation time.
    next_eval_time, eval_result = _maybe_eval(oracle, z, ellapse_time, eval_interval, next_eval_time, i)
    if eval_result is not None:
        eval_records.append(eval_result)
        _log_eval("SPACO", eval_result, cfg)

    return time_lst, gnorm_lst, eval_records
