import numpy as np
import pandas as pd

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OrdinalEncoder, LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, accuracy_score, average_precision_score
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC

import copy
import argparse
import os
import sys

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split

import xgboost as xgb

def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` is a trap: ``bool("False")`` and
    ``bool("0")`` are both ``True`` because any non-empty string is truthy.
    This accepts the usual spellings instead and rejects anything else.
    The empty string stays ``False``, matching ``bool("")``.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("", "0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line configuration of a single experiment run.
parser = argparse.ArgumentParser()
parser.add_argument("--eps", help="Contamination ratio",
                    type=float, default=0.0)
parser.add_argument("--run", help="Run (for averaging)",
                    type=int, default=0)
parser.add_argument("--comment", help="Comment",
                    type=str, default="")
parser.add_argument("--annp", help="ANN percentage",
                    type=float, default=100.0)
# `--force` alone means True; `--force False` / `--force 0` now really mean False.
parser.add_argument("--force", help="Force overwriting old run",
                    type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
args = parser.parse_args()
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
filename = f"adultsfixeddefaultv3_conflearn_partial_{size_percent}_eps_{eps}_annp_{annp}_run_{run}{args.comment}.csv"
# Skip runs whose result file already exists unless --force is given.
if os.path.isfile(filename) and (not args.force):
    # File exists
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)

def load_adult(size_percent: float = 100.0):
    """Fetch the OpenML "adult" dataset and return a preprocessed 80/20 split.

    Rows containing "?" or NaN are dropped. If ``size_percent`` is below 100,
    a stratified ``size_percent``% subsample is drawn first. The remaining
    rows are split 80/20 (stratified on the label) and only then
    preprocessed, with every transform fit on the training split alone:

    - numeric columns are mean-imputed and standardized;
    - categorical columns are mode-imputed and ordinal-encoded, and
      categories never seen in training are encoded as -1.

    Args:
        size_percent: Percentage of the cleaned dataset to keep before
            splitting.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as NumPy arrays; ``y`` holds
        the ``LabelEncoder`` codes of the ``"class"`` column.
    """
    df = fetch_openml("adult", version=2, as_frame=True).frame
    df = df.replace("?", np.nan).dropna()

    X = df.drop("class", axis=1)
    y = LabelEncoder().fit_transform(df["class"])
    if size_percent < 100.0:
        X, _, y, _ = train_test_split(
            X, y, train_size=size_percent / 100.0, stratify=y
        )

    num_cols = X.select_dtypes(include=["int64", "float64"]).columns
    cat_cols = X.select_dtypes(include=["category", "object"]).columns

    # Split BEFORE fitting any transform. Otherwise test-set statistics
    # (means, scales, category vocabulary) leak into the training features.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y
    )
    # Explicit copies: we overwrite columns below and must not write into
    # views of the original frame.
    X_train, X_test = X_train.copy(), X_test.copy()

    column_steps = (
        (num_cols, [SimpleImputer(strategy="mean"), StandardScaler()]),
        (cat_cols, [
            SimpleImputer(strategy="most_frequent"),
            OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
        ]),
    )
    for cols, steps in column_steps:
        if len(cols) == 0:
            # sklearn transformers reject zero-feature input.
            continue
        for step in steps:
            X_train[cols] = step.fit_transform(X_train[cols])
            X_test[cols] = step.transform(X_test[cols])

    return X_train.values, X_test.values, y_train, y_test


def add_label_noise(y, noise_rate=0.3, *, random_state=None):
    """Flip a random ``noise_rate`` fraction of labels to a different class.

    Exactly ``int(noise_rate * len(y))`` distinct positions are drawn
    uniformly without replacement. Each chosen label is replaced by one of
    the *other* labels present in ``y``, picked uniformly.

    Args:
        y: 1-D array of labels. It is not modified.
        noise_rate: Fraction of labels to flip, in [0, 1].
        random_state: Seed passed to ``np.random.RandomState``. The default
            ``None`` gives a fresh, unseeded generator, as before.

    Returns:
        ``(y_noisy, idx)``: a noisy copy of ``y`` and the flipped indices.

    Raises:
        ValueError: if ``noise_rate`` is outside [0, 1], or if labels must be
            flipped but ``y`` has fewer than two distinct classes.
    """
    if not 0.0 <= noise_rate <= 1.0:
        raise ValueError(f"noise_rate must be in [0, 1], got {noise_rate}")

    rng = np.random.RandomState(random_state)
    y_noisy = np.array(y, copy=True)
    n = len(y_noisy)
    n_noisy = int(noise_rate * n)
    idx = rng.choice(n, n_noisy, replace=False)
    if n_noisy == 0:
        return y_noisy, idx

    labels = np.unique(y_noisy)
    if labels.size < 2:
        raise ValueError(
            "cannot flip labels: y contains fewer than two distinct classes"
        )

    # Locate each chosen label in `labels`, then shift by a uniform offset in
    # [1, K-1] modulo K. Every other class is equally likely and the original
    # class is never drawn.
    pos = np.searchsorted(labels, y_noisy[idx])
    offset = rng.randint(1, labels.size, size=n_noisy)
    y_noisy[idx] = labels[(pos + offset) % labels.size]
    return y_noisy, idx


def _cv_probabilities(X, s, n_splits=3):
    """Return out-of-fold probabilities for the positive label column.

    A fresh logistic regression is fit on every stratified, shuffled fold
    except one. It then scores the held-out fold, so no example is ever
    scored by a model that saw its own label. Column 1 of ``predict_proba``
    is kept; presumably the labels ``s`` are binary 0/1, so this is
    P(s=1 | x). TODO confirm for other label sets.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True)
    out_of_fold = np.zeros(len(s), dtype=float)

    for fit_rows, held_rows in splitter.split(X, s):
        model = LogisticRegression(solver="lbfgs", max_iter=1000, n_jobs=None)
        model.fit(X[fit_rows], s[fit_rows])
        held_proba = model.predict_proba(X[held_rows])
        out_of_fold[held_rows] = held_proba[:, 1]

    return out_of_fold


def _confident_thresholds(g, s):

    s = np.asarray(s)
    g = np.asarray(g, dtype=float)
    LB = g[s == 1].mean() if (s == 1).any() else 0.5
    UB = g[s == 0].mean() if (s == 0).any() else 0.5
    return LB, UB


def _noise_scores(g, s, LB, UB):
    s = np.asarray(s)
    g = np.asarray(g, dtype=float)
    scores = np.where(s == 1, LB - g, g - UB)
    return np.maximum(scores, 0.0)


def prefilter_confident_learning(X_train, y_train, prune_fraction=0.20, cv=3):
    """Drop the training rows whose labels a cross-validated model distrusts most.

    Steps:

    1. ``_cv_probabilities`` produces out-of-fold probabilities
       ``g = P(s=1 | x)``.
    2. ``_confident_thresholds`` produces the thresholds LB (mean g over
       s=1) and UB (mean g over s=0).
    3. Each row's noise score is how far its g falls below LB (s=1) or above
       UB (s=0), clipped at 0.
    4. The ``floor(prune_fraction * n)`` highest-scoring rows are removed.

    A fixed fraction is always removed. If fewer rows have a positive score,
    the rest is taken from zero-score rows in whatever tie order
    ``np.argsort`` produces.

    Args:
        X_train: 2-D NumPy feature array, indexed with a boolean row mask.
        y_train: Noisy label array. Presumably binary 0/1, since column 1 of
            ``predict_proba`` is used as P(s=1). TODO confirm.
        prune_fraction: Fraction of rows to remove, in [0, 1].
        cv: Number of stratified CV folds for the probability estimates.

    Returns:
        ``(X_pref, y_pref, info)``: the kept rows and a dict of diagnostics.
        The dict has the thresholds, per-class mean probabilities, and the
        pruned/kept counts.

    Raises:
        ValueError: if ``prune_fraction`` is outside [0, 1]. A negative value
            would otherwise make ``order_desc[:k_drop]`` drop almost everything.
    """
    if not 0.0 <= prune_fraction <= 1.0:
        raise ValueError(f"prune_fraction must be in [0, 1], got {prune_fraction}")

    # Out-of-fold probabilities on the noisy labels s = y_train.
    g = _cv_probabilities(X_train, y_train, n_splits=cv)

    # Per-class thresholds.
    LB, UB = _confident_thresholds(g, y_train)

    # Per-example noise scores (0 = consistent with its label's threshold).
    scores = _noise_scores(g, y_train, LB, UB)

    # Prune the top-k rows by score.
    n = len(y_train)
    k_drop = int(np.floor(prune_fraction * n))
    order_desc = np.argsort(-scores)  # descending by score
    keep_mask = np.ones(n, dtype=bool)
    keep_mask[order_desc[:k_drop]] = False

    X_pref = X_train[keep_mask]
    y_pref = y_train[keep_mask]

    has_pos = (y_train == 1).any()
    has_neg = (y_train == 0).any()
    info = {
        "LB_y=1": float(LB),
        "UB_y=0": float(UB),
        "avg_prob_s1": float(g[y_train == 1].mean()) if has_pos else None,
        "avg_prob_s0": float(g[y_train == 0].mean()) if has_neg else None,
        "pruned": int(k_drop),
        "kept": int(keep_mask.sum()),
    }
    return X_pref, y_pref, info

def _make_classical_model(model_type):
    """Return a fresh, unfitted non-neural classifier for ``model_type``.

    Raises:
        ValueError: if ``model_type`` is not a known classical model name.
    """
    if model_type == "xgboost":
        # NOTE(review): `use_label_encoder` is deprecated in newer xgboost
        # releases (they warn that it is unused); kept for older versions.
        return xgb.XGBClassifier(use_label_encoder=False, eval_metric="logloss")
    if model_type == "random_forest":
        return RandomForestClassifier()
    if model_type == "svm":
        return SVC()
    if model_type == "adaboost":
        return AdaBoostClassifier()
    if model_type == "logitboost":
        # Gradient boosting with the logistic (log-loss) objective serves as
        # a LogitBoost surrogate.
        return GradientBoostingClassifier(loss="log_loss")
    if model_type == "bagging":
        return BaggingClassifier()
    raise ValueError(f"Unknown model_type {model_type}")


def _make_ffn(input_dim, hidden, num_classes):
    """Build a small 2-layer feed-forward net: Linear -> ReLU -> Linear (logits)."""
    return nn.Sequential(
        nn.Linear(input_dim, hidden), nn.ReLU(),
        nn.Linear(hidden, num_classes),
    )


def _fit_torch(model, X, y, lr, epochs=15, batch_size=64):
    """Train ``model`` on (X, y) with Adam and cross-entropy; return the model.

    Runs on CUDA when available, otherwise CPU. No validation split or early
    stopping is used: training always runs for ``epochs`` full passes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).long())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss_fn(model(xb), yb).backward()
            optimizer.step()
    return model


def _predict_torch(model, X):
    """Return ``(predicted labels, probability of class 1)`` for ``X``.

    Probabilities come from a softmax over the logits. The old code used an
    element-wise sigmoid, which does not yield a distribution over classes.
    """
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(torch.from_numpy(X).float().to(device))
    preds = logits.argmax(dim=1).cpu().numpy()
    probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
    return preds, probs


def train_evaluate(config, data=None):
    """Train one model on the (noisy) training set and score it on the test set.

    Args:
        config: dict with ``"model_type"``, one of "ffn", "xgboost",
            "random_forest", "svm", "adaboost", "logitboost" or "bagging".
            "ffn" also needs ``"hidden_size"`` and ``"lr"``.
        data: Optional ``(X_train, y_train, X_test, y_test)`` tuple of NumPy
            arrays. When omitted, the module-level ``X_train``,
            ``y_train_noisy``, ``X_test`` and ``y_test`` are used, as before.
            The script's ``__main__`` section sets these.

    Returns:
        dict with ``"noisy_accuracy"`` and ``"noisy_f1"`` (binary F1) on the
        test set.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    if data is None:
        data = (X_train, y_train_noisy, X_test, y_test)
    X_tr, y_tr, X_te, y_te = data

    model_type = config["model_type"]
    if model_type == "ffn":
        model = _make_ffn(
            X_tr.shape[1], int(config["hidden_size"]), len(np.unique(y_tr))
        )
        model = _fit_torch(model, X_tr, y_tr, lr=config["lr"], epochs=15)
        preds, _ = _predict_torch(model, X_te)
    else:
        model = _make_classical_model(model_type)
        model.fit(X_tr, y_tr)
        preds = model.predict(X_te)

    return {
        "noisy_accuracy": accuracy_score(y_te, preds),
        "noisy_f1": f1_score(y_te, preds),
    }

if __name__ == "__main__":
    # One fixed configuration per model family; each entry is a single
    # config dict, not a grid.
    search_spaces = {
        "ffn": {'model_type': 'ffn', 'hidden_size': 64, 'lr': 1e-3},

        "xgboost": {'model_type': 'xgboost'},

        "random_forest": {'model_type': 'random_forest'},

        "svm": {'model_type': 'svm'},

        "adaboost": {'model_type': 'adaboost'},

        "logitboost": {'model_type': 'logitboost'},

        "bagging": {'model_type': 'bagging'},

    }
    # Load the data. train_evaluate reads X_train, y_train_noisy, X_test and
    # y_test from module scope by default.
    X_train, X_test, y_train_clean, y_test = load_adult(size_percent=size_percent)

    # Add label noise to the training labels only; the test set stays clean.
    y_train_noisy, _ = add_label_noise(y_train_clean, noise_rate=eps)

    # Optionally prune the (1 - annp/100) fraction of training rows that
    # confident learning ranks as most likely mislabeled.
    if annp != 100.0:
        X_train, y_train_noisy, info = prefilter_confident_learning(
            X_train, y_train_noisy, prune_fraction=(1.0 - annp / 100.0), cv=3
        )
        print(info)

    summary = []
    for name, space in search_spaces.items():
        res = train_evaluate(space)
        print(res)
        summary.append({
            "model": name,
            "noisy_acc": res["noisy_accuracy"],
            "noisy_f1": res["noisy_f1"],
        })

    df = pd.DataFrame(summary).set_index("model")

    # Save BEFORE pretty-printing: DataFrame.to_markdown needs the optional
    # `tabulate` package, and a missing dependency must not lose the results.
    # Keep the index so the CSV records which model each row belongs to.
    df.to_csv(filename)
    print("Results saved to ", filename)

    print(df[["noisy_acc", "noisy_f1"]].to_markdown())
    print("=" * 50)
