from __future__ import annotations

import argparse
import math
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, matthews_corrcoef
from sklearn.model_selection import train_test_split

import os
import sys

@dataclass
class ResampleConfig:
    """Settings for one resampling run.

    Names the label and sensitive columns, selects how the privileged group
    is derived, and carries the three mixing weights (alpha, beta, gamma)
    together with the seed used when rows are sampled.
    """

    # Column holding the class label.
    label_col: str
    # Column compared against ``age_threshold`` in "age_threshold" mode.
    sensitive_col: str
    # How privileged rows are identified: "age_threshold" or "bool_col".
    privileged_def: str = "age_threshold"
    # Column cast to bool to mark privileged rows in "bool_col" mode.
    privileged_bool_col: Optional[str] = None
    # Rows whose sensitive value is strictly below this are privileged.
    age_threshold: Optional[float] = 50.0
    # Label value treated as the unfavourable outcome; every other value
    # counts as favourable.
    unfavourable_value: int | str | bool = 1
    # Weight in [0, 1] moving the privileged share towards 0.5.
    alpha: float = 0.0
    # Weight in [0, 1] moving the favourable share towards 0.5.
    beta: float = 0.0
    # Weight in [0, 1] moving the favourable-rate ratio towards 1.0.
    gamma: float = 0.0
    # Seed for the row sampling and the final shuffle.
    random_state: Optional[int] = 42


def _check_abg(a: float, b: float, g: float):
    """Validate that alpha, beta and gamma each lie in the closed interval [0, 1].

    Args:
        a: Value of alpha.
        b: Value of beta.
        g: Value of gamma.

    Raises:
        ValueError: For the first weight (checked in the order alpha, beta,
            gamma) that falls outside [0, 1]; NaN is rejected as well because
            it fails both comparisons.
    """
    weights = {"alpha": a, "beta": b, "gamma": g}
    for name, v in weights.items():
        within_unit_interval = 0.0 <= float(v) <= 1.0
        if within_unit_interval:
            continue
        raise ValueError(f"{name} must be in [0,1], got {v}")


def _make_privileged_series(df: pd.DataFrame, cfg: ResampleConfig):
    """Return a boolean Series that is True for rows in the privileged group.

    Args:
        df: Data to classify.
        cfg: Configuration; ``privileged_def`` selects the rule.

    Returns:
        In "age_threshold" mode, ``df[cfg.sensitive_col] < cfg.age_threshold``.
        In "bool_col" mode, ``df[cfg.privileged_bool_col]`` cast to bool.

    Raises:
        ValueError: If the mode is unknown, or the setting required by the
            chosen mode is missing.
    """
    mode = cfg.privileged_def

    if mode == "age_threshold":
        threshold = cfg.age_threshold
        if threshold is None:
            raise ValueError("age_threshold must be set for 'age_threshold' mode.")
        # Strictly below the threshold counts as privileged.
        return df[cfg.sensitive_col] < threshold

    if mode == "bool_col":
        flag_col = cfg.privileged_bool_col
        if not flag_col:
            raise ValueError("Set privileged_bool_col when privileged_def='bool_col'.")
        return df[flag_col].astype(bool)

    raise ValueError("privileged_def must be 'age_threshold' or 'bool_col'.")


def _partition_masks(df: pd.DataFrame, cfg: ResampleConfig, priv: pd.Series):
    """Split the rows into the four (group, outcome) cells.

    Args:
        df: Data containing ``cfg.label_col``.
        cfg: Configuration supplying the label column and unfavourable value.
        priv: Boolean Series, True for privileged rows.

    Returns:
        Dict of boolean masks keyed, in this order, by "p_fav", "p_unfav",
        "up_fav" and "up_unfav" (p = privileged, up = unprivileged).
    """
    is_unfav = df[cfg.label_col] == cfg.unfavourable_value
    # Anything that is not the unfavourable value counts as favourable.
    is_fav = ~is_unfav
    not_priv = ~priv

    cells = {}
    cells["p_fav"] = priv & is_fav
    cells["p_unfav"] = priv & is_unfav
    cells["up_fav"] = not_priv & is_fav
    cells["up_unfav"] = not_priv & is_unfav
    return cells


def _counts(masks: Dict[str, pd.Series]):
    """Return the number of True entries in each mask as a plain ``int``.

    Keys and their order are taken unchanged from ``masks``.
    """
    totals = {}
    for name, mask in masks.items():
        totals[name] = int(mask.sum())
    return totals


def _safe_div(a: int | float, b: int | float):
    """Divide ``a`` by ``b`` as floats, returning 0.0 when ``b`` is falsy (zero)."""
    if not b:
        return 0.0
    return float(a) / float(b)


@dataclass
class TargetRatios:
    """Target composition produced by ``_compute_target_ratios``.

    The three primed quantities are the blended targets; the four ``r_*``
    fields are the fractions of the resampled data that should fall in each
    (group, outcome) cell.
    """

    # Target share of privileged rows (P').
    Pp: float
    # Target overall share of favourable labels (F').
    Fp: float
    # Target ratio of favourable rates, privileged over unprivileged (A').
    Ap: float
    # Target fraction of rows that are privileged and favourable.
    r_p_fav: float
    # Target fraction of rows that are privileged and unfavourable.
    r_p_unfav: float
    # Target fraction of rows that are unprivileged and favourable.
    r_up_fav: float
    # Target fraction of rows that are unprivileged and unfavourable.
    r_up_unfav: float


def _compute_target_ratios(df: pd.DataFrame, cfg: ResampleConfig, masks: Dict[str, pd.Series]):
    """Derive the target cell fractions for the resampled dataset.

    The observed privileged share P, favourable share F and favourable-rate
    ratio A (privileged over unprivileged) are blended towards 0.5, 0.5 and
    1.0 using ``cfg.alpha``, ``cfg.beta`` and ``cfg.gamma`` respectively. The
    per-group favourable rates are then solved from the blended values and
    turned into the four joint cell fractions.

    Args:
        df: Source data; only its length is used here.
        cfg: Configuration carrying alpha, beta and gamma.
        masks: The four cell masks keyed "p_fav", "p_unfav", "up_fav",
            "up_unfav".

    Returns:
        A ``TargetRatios`` with the blended P', F', A' and the four cell
        fractions.

    Raises:
        ValueError: If alpha, beta or gamma is outside [0, 1].
    """
    total_rows = len(df)
    cell = _counts(masks)
    n_priv = cell["p_fav"] + cell["p_unfav"]
    n_unpriv = cell["up_fav"] + cell["up_unfav"]
    n_fav = cell["p_fav"] + cell["up_fav"]

    # Observed composition of the source data.
    priv_share = _safe_div(n_priv, total_rows)
    fav_share = _safe_div(n_fav, total_rows)
    fav_rate_priv = _safe_div(cell["p_fav"], max(n_priv, 1))
    fav_rate_unpriv = _safe_div(cell["up_fav"], max(n_unpriv, 1))
    # With no favourable unprivileged rows the ratio falls back to the
    # privileged rate itself (division by 1.0).
    if fav_rate_unpriv > 0:
        divisor = fav_rate_unpriv
    else:
        divisor = 1.0
    rate_ratio = fav_rate_priv / divisor

    w_alpha, w_beta, w_gamma = cfg.alpha, cfg.beta, cfg.gamma
    _check_abg(w_alpha, w_beta, w_gamma)

    # Linear blends between the observed value and the balanced value.
    target_priv_share = priv_share * (1 - w_alpha) + 0.5 * w_alpha
    target_fav_share = fav_share * (1 - w_beta) + 0.5 * w_beta
    target_rate_ratio = rate_ratio * (1 - w_gamma) + 1.0 * w_gamma

    # Solve F' = P'*A'*F_up' + (1 - P')*F_up' for F_up'; the floor on the
    # denominator avoids division by zero.
    denominator = max(1.0 + target_rate_ratio * target_priv_share - target_priv_share, 1e-12)
    target_fav_unpriv = float(np.clip(target_fav_share / denominator, 0.0, 1.0))
    target_fav_priv = float(np.clip(target_rate_ratio * target_fav_unpriv, 0.0, 1.0))

    fractions = [
        target_priv_share * target_fav_priv,
        target_priv_share * (1.0 - target_fav_priv),
        (1.0 - target_priv_share) * target_fav_unpriv,
        (1.0 - target_priv_share) * (1.0 - target_fav_unpriv),
    ]

    # Renormalise only if the four fractions have drifted away from 1.
    total = fractions[0] + fractions[1] + fractions[2] + fractions[3]
    if not np.isclose(total, 1.0):
        fractions = [value / total for value in fractions]

    return TargetRatios(target_priv_share, target_fav_share, target_rate_ratio, *fractions)


def _max_feasible_total(c_avail: Dict[str, int], ratios: TargetRatios):
    """Return the largest dataset size the available rows can support.

    Every cell with a positive target fraction limits the total to
    ``floor(available / fraction)``; the result is the smallest such limit.

    Args:
        c_avail: Available row count per cell.
        ratios: Target cell fractions.

    Returns:
        The feasible total (never negative), or 0 when no cell has a
        positive target fraction.
    """
    wanted = {
        "p_fav": ratios.r_p_fav,
        "p_unfav": ratios.r_p_unfav,
        "up_fav": ratios.r_up_fav,
        "up_unfav": ratios.r_up_unfav,
    }
    # Cells with a zero target fraction place no limit on the total.
    limits = [
        int(math.floor(c_avail[cell] / fraction))
        for cell, fraction in wanted.items()
        if fraction > 0
    ]
    if not limits:
        return 0
    return max(0, min(limits))


def _rounded_counts(N: int, ratios: TargetRatios):
    """Turn the target fractions into integer cell counts for a total of ``N``.

    Each cell starts at ``floor(fraction * N)``. Any shortfall against ``N``
    is handed out one row at a time, starting with the cell that has the
    largest fractional remainder (ties keep the cell order p_fav, p_unfav,
    up_fav, up_unfav).

    Args:
        N: Desired total number of rows.
        ratios: Target cell fractions.

    Returns:
        Dict mapping each cell name to its integer count.
    """
    exact = {
        "p_fav": ratios.r_p_fav * N,
        "p_unfav": ratios.r_p_unfav * N,
        "up_fav": ratios.r_up_fav * N,
        "up_unfav": ratios.r_up_unfav * N,
    }
    counts = {}
    for cell, value in exact.items():
        counts[cell] = int(math.floor(value))

    shortfall = N - sum(counts.values())
    if shortfall <= 0:
        return counts

    # Stable sort: cells with equal remainders keep their declaration order.
    by_remainder = sorted(exact, key=lambda cell: exact[cell] - counts[cell], reverse=True)
    for step in range(shortfall):
        counts[by_remainder[step % 4]] += 1
    return counts


def resample_dataset(df: pd.DataFrame, cfg: ResampleConfig):
    """Subsample ``df`` so its group/outcome composition matches the targets.

    The rows are split into four cells (privileged/unprivileged x
    favourable/unfavourable), target fractions are computed from ``cfg``, and
    the largest total that every cell can supply is drawn without replacement.
    The selected rows are shuffled and re-indexed from 0.

    Rows are selected by position rather than by index label, so the result
    is correct even when ``df`` has duplicate index labels (label-based
    lookup would return every row sharing a sampled label). For a unique
    index the selected rows are the same as with label-based lookup.

    Args:
        df: Source data.
        cfg: Resampling configuration.

    Returns:
        A tuple ``(Dp, achieved)`` where ``Dp`` is the resampled frame and
        ``achieved`` is a dict with the realised "N_prime" (row count, as a
        float), "P_prime" (privileged share), "F_prime" (favourable share)
        and "A_prime" (favourable-rate ratio, privileged over unprivileged).

    Raises:
        RuntimeError: If no positive dataset size is feasible.
        ValueError: If the configuration is invalid (propagated from the
            helpers).
    """
    priv = _make_privileged_series(df, cfg)
    masks = _partition_masks(df, cfg, priv)
    available = _counts(masks)
    ratios = _compute_target_ratios(df, cfg, masks)

    n_target = _max_feasible_total(available, ratios)
    if n_target == 0:
        raise RuntimeError("Infeasible sampling (N'=0). Try smaller alpha/beta or adjust gamma.")
    target_counts = _rounded_counts(n_target, ratios)

    rng = np.random.RandomState(cfg.random_state)
    picked_blocks = []
    for cell, mask in masks.items():
        n_take = target_counts[cell]
        if n_take <= 0:
            continue
        # Row positions (not labels) of this cell's members.
        positions = np.flatnonzero(mask.to_numpy())
        # Rounding may ask for one row more than the cell holds.
        n_take = min(n_take, len(positions))
        picked_blocks.append(rng.choice(positions, size=n_take, replace=False))
    picked = np.concatenate(picked_blocks)
    Dp = df.iloc[picked].sample(frac=1.0, random_state=cfg.random_state).reset_index(drop=True)

    # Achieved summary, measured on the resampled data.
    achieved_masks = _partition_masks(Dp, cfg, _make_privileged_series(Dp, cfg))
    got = _counts(achieved_masks)
    n = len(Dp)
    n_priv = got["p_fav"] + got["p_unfav"]
    n_unpriv = got["up_fav"] + got["up_unfav"]
    n_fav = got["p_fav"] + got["up_fav"]
    Pp = _safe_div(n_priv, n)
    Fp = _safe_div(n_fav, n)
    rate_p = _safe_div(got["p_fav"], max(n_priv, 1))
    rate_up = _safe_div(got["up_fav"], max(n_unpriv, 1))
    # Same convention as the target computation: divide by 1.0 when the
    # unprivileged favourable rate is zero.
    Ap = rate_p / (rate_up if rate_up > 0 else 1.0)

    return Dp, {"N_prime": float(n), "P_prime": Pp, "F_prime": Fp, "A_prime": Ap}


# Metrics

def disparate_impact_on_predictions(y_pred: np.ndarray, privileged_mask: np.ndarray, unfavourable_value):
    """Return the ratio of unfavourable-prediction rates, unprivileged over privileged.

    Inputs are converted to NumPy arrays and the mask is cast to bool, so
    lists and integer 0/1 masks are handled correctly (negating an integer
    mask with ``~`` would be a bitwise NOT, not a logical one).

    Args:
        y_pred: Predicted labels, one per row.
        privileged_mask: Truthy where the row belongs to the privileged group.
        unfavourable_value: Label value counted as the unfavourable outcome.

    Returns:
        ``rate_unprivileged / rate_privileged`` as a float. A group with no
        rows has rate 0.0. NaN is returned when the privileged rate is 0.
    """
    pred_unfav = np.asarray(y_pred) == unfavourable_value
    priv = np.asarray(privileged_mask).astype(bool)
    unpriv = ~priv

    n_priv = int(priv.sum())
    n_unpriv = int(unpriv.sum())

    rate_unpriv = (pred_unfav[unpriv].sum() / n_unpriv) if n_unpriv else 0.0
    rate_priv = (pred_unfav[priv].sum() / n_priv) if n_priv else 0.0

    # The ratio is undefined when the privileged group never receives the
    # unfavourable prediction.
    if rate_priv == 0:
        return float("nan")
    return float(rate_unpriv / rate_priv)


def build_features(df: pd.DataFrame, label_col: str):
    """Split ``df`` into a one-hot encoded feature frame and the label Series.

    Args:
        df: Data containing ``label_col``; it is not modified.
        label_col: Name of the label column.

    Returns:
        ``(X, y)`` where ``X`` is every column except the label, passed
        through ``pd.get_dummies`` with all levels kept, and ``y`` is the
        label column.
    """
    features = pd.get_dummies(df.drop(columns=[label_col]), drop_first=False)
    target = df[label_col]
    return features, target


# Grid search

def grid_search_abg_with_rf(
    df: pd.DataFrame,
    label_col: str,
    sensitive_col: str,
    unfavourable_value=1,
    privileged_def: str = "age_threshold",
    age_threshold: float = 50.0,
    privileged_bool_col: Optional[str] = None,
    alphas: Iterable[float] = np.linspace(0.0, 1.0, 11),
    betas: Iterable[float] = np.linspace(0.0, 1.0, 11),
    gammas: Iterable[float] = np.linspace(0.0, 1.0, 11),
    test_size: float = 0.2,
    random_state: int = 42,
    rf_kwargs: Optional[dict] = None,
):
    """Evaluate a random forest over a grid of (alpha, beta, gamma) resamplings.

    The data is split once into a stratified train/test pair. For every grid
    point the training part is resampled with ``resample_dataset``, a
    ``RandomForestClassifier`` is fitted on it, and the model is scored on
    the untouched test part. All columns except the label (including the
    sensitive column) are used as features.

    Args:
        df: Full dataset.
        label_col: Name of the label column.
        sensitive_col: Name of the sensitive column.
        unfavourable_value: Label value treated as the unfavourable outcome.
        privileged_def: "age_threshold" or "bool_col".
        age_threshold: Threshold used in "age_threshold" mode.
        privileged_bool_col: Column used in "bool_col" mode.
        alphas: Candidate alpha values.
        betas: Candidate beta values.
        gammas: Candidate gamma values.
        test_size: Fraction of rows held out for testing.
        random_state: Seed for the split, the resampling and (by default)
            the forest.
        rf_kwargs: Keyword arguments for ``RandomForestClassifier``; when
            None, 200 trees with unlimited depth on all cores are used.

    Returns:
        A DataFrame with one row per grid point. Rows whose resampling failed
        carry only alpha, beta, gamma and a ``status`` of "error: ..."; the
        others have ``status`` "ok" plus train_size, acc, prec_unfav,
        rec_unfav, f1_unfav, di_pred and mcc.

    Raises:
        ValueError: If the privileged-group settings are invalid. This is
            checked once, up front, through ``_make_privileged_series`` so
            the test-set mask follows exactly the same rule as the
            resampling.
    """
    # Local import: the module does not import itertools at file level.
    from itertools import product

    if rf_kwargs is None:
        rf_kwargs = dict(
            n_estimators=200,
            max_depth=None,
            n_jobs=-1,
            random_state=random_state,
        )

    def make_cfg(alpha, beta, gamma):
        return ResampleConfig(
            label_col=label_col,
            sensitive_col=sensitive_col,
            privileged_def=privileged_def,
            privileged_bool_col=privileged_bool_col,
            age_threshold=age_threshold,
            unfavourable_value=unfavourable_value,
            alpha=alpha, beta=beta, gamma=gamma,
            random_state=random_state,
        )

    df_train, df_test = train_test_split(
        df, test_size=test_size, random_state=random_state, stratify=df[label_col]
    )

    # The mask does not depend on alpha/beta/gamma, so any valid weights do.
    priv_test = _make_privileged_series(df_test, make_cfg(0.0, 0.0, 0.0)).values

    X_test_raw, y_test = build_features(df_test, label_col)
    # Loop-invariant: binary ground truth for the MCC computation.
    y_true_bin = (y_test.values == unfavourable_value).astype(int)

    # Same visiting order as three nested loops (alpha outermost).
    grid = list(product(alphas, betas, gammas))
    results: List[dict] = []

    with tqdm(total=len(grid), desc="Grid search (alpha,beta,gamma)") as pbar:
        for a, b, g in grid:
            cfg = make_cfg(a, b, g)
            try:
                train_resampled, achieved = resample_dataset(df_train, cfg)
            except Exception as e:
                # Deliberate best-effort: an infeasible grid point is
                # recorded and the search continues.
                results.append(dict(alpha=a, beta=b, gamma=g, status=f"error: {e}"))
                pbar.update(1)
                continue

            X_train, y_train = build_features(train_resampled, label_col)

            # Give the test matrix exactly the training columns; dummy
            # columns absent from the test split are filled with 0.
            X_test = X_test_raw.reindex(columns=X_train.columns, fill_value=0)

            rf = RandomForestClassifier(**rf_kwargs)
            rf.fit(X_train, y_train)
            y_pred = rf.predict(X_test)

            acc = accuracy_score(y_test, y_pred)
            # Only the unfavourable class is reported, so request just that label.
            prec, rec, f1, _ = precision_recall_fscore_support(
                y_test, y_pred, labels=[unfavourable_value], zero_division=0
            )
            di_pred = disparate_impact_on_predictions(y_pred, priv_test, unfavourable_value)
            y_pred_bin = (np.asarray(y_pred) == unfavourable_value).astype(int)
            mcc_unfav = matthews_corrcoef(y_true_bin, y_pred_bin)

            results.append({
                "alpha": a, "beta": b, "gamma": g,
                "status": "ok",
                "train_size": int(achieved["N_prime"]),
                "acc": acc,
                "prec_unfav": prec[0],
                "rec_unfav": rec[0],
                "f1_unfav": f1[0],
                "di_pred": di_pred,
                "mcc": mcc_unfav,
            })
            pbar.update(1)

    return pd.DataFrame(results)


def _parse_bool_flag(text: str) -> bool:
    """Convert a command-line token such as ``true``, ``0`` or ``no`` to a bool.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    token = text.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def main():
    """Command-line entry point: run the grid search on a CSV and save the results.

    Exits with status 0 without doing any work when the output file already
    exists and ``--force`` is not set. Otherwise the CSV is loaded, optionally
    downsampled to ``--size_percent`` of its rows, the grid search is run and
    the result table is written to ``--out``.
    """
    parser = argparse.ArgumentParser(description="Resampling (alpha,beta,gamma) + RF grid search on CSV dataset.")
    parser.add_argument("--csv", type=str, required=True, help="Path to dataset CSV.")
    parser.add_argument("--label-col", type=str, required=True, help="Name of label column.")
    parser.add_argument("--age-col", type=str, required=True, help="Name of sensitive age column.")
    parser.add_argument("--unfavourable-value", type=str, default="1", help="Unfavourable label value (default '1').")
    parser.add_argument("--age-threshold", type=float, default=50.0, help="Age >= threshold = unprivileged.")
    parser.add_argument("--random-state", type=int, default=42, help="Random State.")
    parser.add_argument("--out", type=str, default="grid_results.csv")
    parser.add_argument("--size_percent", type=float, default=100.0, help="Percentage of rows to keep from the original dataset.")
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True. Parse the token explicitly; a bare ``--force`` means True.
    parser.add_argument(
        "--force",
        type=_parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
        help="Should results be overwritten.",
    )
    args = parser.parse_args()

    if os.path.isfile(args.out) and (not args.force):
        # File exists
        print(f"File {args.out} exists")
        print("=" * 50)
        sys.exit(0)

    if args.size_percent <= 0:
        parser.error("--size_percent must be greater than 0")

    df = pd.read_csv(args.csv)
    if args.size_percent < 100.0:
        frac = float(args.size_percent) / 100.0
        # Reproducible random downsampling of the full dataset
        df = df.sample(frac=frac, random_state=args.random_state).reset_index(drop=True)
    label_col, age_col = args.label_col, args.age_col
    # Labels are usually integers; fall back to the raw string otherwise.
    try:
        unfav_val = int(args.unfavourable_value)
    except ValueError:
        unfav_val = args.unfavourable_value

    results = grid_search_abg_with_rf(
        df=df,
        label_col=label_col,
        sensitive_col=age_col,
        unfavourable_value=unfav_val,
        age_threshold=args.age_threshold,
        random_state=args.random_state,
    )
    results.to_csv(args.out, index=False)
    print(f"Saved results to {args.out}")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

