import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, average_precision_score
from sklearn.impute import SimpleImputer

import xgboost as xgb
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import copy
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
from tqdm import tqdm as tqdm
import argparse
import os
import sys

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse``'s ``type=bool`` would call ``bool(str)``, which is ``True``
    for every non-empty string (including ``"False"``), so an explicit
    parser is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--eps", help="Contamination ratio",
                    type=float, default=0.0)
parser.add_argument("--run", help="Run (for averaging)",
                    type=int, default=0)
parser.add_argument("--comment", help="Comment",
                    type=str, default="")
parser.add_argument("--annp", help="ANN percentage",
                    type=float, default=100.0)
# Accepts "--force", "--force True" and "--force False" (the latter is now
# actually False; with type=bool it used to evaluate to True).
parser.add_argument("--force", help="Force overwriting old run",
                    type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
args = parser.parse_args()
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
# The output file name encodes every run parameter, so an existing file means
# this exact configuration has already been evaluated.
filename = f"adultsfixeddefaultv3_partial_{size_percent}_eps_{eps}_annp_{annp}_run_{run}{args.comment}.csv"
if os.path.isfile(filename) and (not args.force):
    # File exists: skip the run unless --force was given.
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)

def add_label_noise(y, noise_rate=0.3, seed=None):
    """Flip a fraction of labels to a different, uniformly chosen class.

    Args:
        y: 1-D array-like of labels. It is not modified.
        noise_rate: Fraction in [0, 1]; exactly ``int(noise_rate * len(y))``
            distinct positions are flipped.
        seed: Optional seed for ``np.random.RandomState``. ``None`` (the
            default) seeds from fresh entropy, so each call differs.

    Returns:
        ``(y_noisy, idx)``: a corrupted copy of ``y`` and the array of
        positions that were flipped.

    Raises:
        ValueError: if ``noise_rate`` is outside [0, 1], or if labels must be
            flipped but ``y`` has fewer than two distinct classes.
    """
    if not 0.0 <= noise_rate <= 1.0:
        raise ValueError(f"noise_rate must be in [0, 1], got {noise_rate}")
    rng = np.random.RandomState(seed)
    y_noisy = np.asarray(y).copy()
    n = len(y_noisy)
    n_noisy = int(noise_rate * n)
    idx = rng.choice(n, n_noisy, replace=False)
    if n_noisy == 0:
        return y_noisy, idx

    labels = np.unique(y_noisy)  # sorted distinct classes
    k = len(labels)
    if k < 2:
        raise ValueError(
            "Cannot flip labels: y contains fewer than two distinct classes"
        )
    # Position of each selected label inside the sorted `labels` array.
    pos = np.searchsorted(labels, y_noisy[idx])
    # A non-zero offset in 1..k-1 (mod k) lands on every *other* class with
    # equal probability 1/(k-1), never on the original class.
    offsets = rng.randint(1, k, size=n_noisy)
    y_noisy[idx] = labels[(pos + offsets) % k]
    return y_noisy, idx

def fit_and_prefilter(X, y, idx, lr, epochs=5, val_frac=0.2, batch_size=128,
                      percentile=None):
    """Train a small MLP on (X, y) and keep only the lowest-loss samples.

    A 2-layer network (hidden width 128) is trained with Adam and
    cross-entropy on all of (X, y). The per-sample loss is then computed,
    and samples whose loss is at or below the ``percentile``-th percentile
    are kept. Diagnostics about how many of the known-noisy indices ``idx``
    survive the filter are printed.

    Args:
        X: Feature matrix (numpy); converted to float32.
        y: Integer labels (numpy); converted to int64.
        idx: Indices into X/y of the samples whose labels were corrupted.
            Used only for the printed diagnostics.
        lr: Adam learning rate.
        epochs: Number of training passes over the data.
        val_frac: Unused; kept for call-site compatibility. No validation
            split is made.
        batch_size: Mini-batch size for training.
        percentile: Loss percentile used as the keep threshold. Defaults to
            the module-level ``annp`` (the ``--annp`` option) when ``None``.

    Returns:
        ``(X[selected], y[selected])`` for the retained samples.
    """
    if percentile is None:
        percentile = annp

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X.shape[1]
    num_classes = len(np.unique(y))
    hidden = 128
    model = nn.Sequential(
        nn.Linear(input_dim, hidden), nn.ReLU(),
        nn.Linear(hidden, num_classes),
    ).to(device)

    # All data lives on `device` already, so batches need no further moves.
    X_t = torch.from_numpy(X).float().to(device)
    y_t = torch.from_numpy(y).long().to(device)
    train_loader = DataLoader(TensorDataset(X_t, y_t),
                              batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for _ in tqdm(range(epochs)):
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss_fn(model(xb), yb).backward()
            optimizer.step()

    # Per-sample loss over the whole dataset in one forward pass.
    per_sample_loss = nn.CrossEntropyLoss(reduction="none")
    model.eval()
    with torch.no_grad():
        losses = per_sample_loss(model(X_t), y_t).cpu().numpy()

    threshold = np.percentile(losses, percentile)
    selected_indices = np.where(losses <= threshold)[0]
    n_total = len(X)
    print(f"Keeping {len(selected_indices)} out of {n_total} datapoints "
            f"({100*len(selected_indices)/n_total:.2f}%) with loss under the {percentile}th percentile.")

    # How many corrupted samples slipped through the filter.
    n_noisy = len(idx)
    kept_noisy = len(set(np.asarray(selected_indices).tolist())
                     & set(np.asarray(idx).tolist()))
    print(kept_noisy)
    print(n_noisy)
    # With no injected noise (eps == 0) the ratio is undefined; the old code
    # raised ZeroDivisionError here.
    print(kept_noisy / n_noisy if n_noisy else float("nan"))
    return X[selected_indices], y[selected_indices]

def load_adult():
    """Load, clean, encode and split the OpenML "adult" dataset (version 2).

    Steps:
      1. Fetch the dataset with ``fetch_openml``.
      2. Replace literal ``"?"`` with NaN and drop every row with a missing value.
      3. If the module-level ``size_percent`` (``--size_percent``) is below
         100, keep a stratified ``size_percent``% subsample.
      4. Standardize numeric columns and ordinal-encode categorical ones.
      5. Label-encode the ``class`` target.

    Returns:
        ``X_train, X_test, y_train, y_test`` as numpy arrays, from an
        unseeded 80/20 ``train_test_split``.
    """
    df = fetch_openml("adult", version=2, as_frame=True).frame

    df = df.replace("?", np.nan)
    df = df.dropna()

    # Split off features and target.
    X = df.drop("class", axis=1)
    y = df["class"]
    if size_percent < 100.0:
        X, _, y, _ = train_test_split(
            X, y, train_size=size_percent/100, stratify=y
        )
    # Take ownership of the frame: X may be a row-subset produced by
    # train_test_split, and the column assignments below must not be
    # chained-assignment writes into a derived frame.
    X = X.copy()

    num_cols = X.select_dtypes(include=["int64", "float64"]).columns
    cat_cols = X.select_dtypes(include=["category", "object"]).columns

    # After dropna() these imputers have nothing left to fill; they are
    # kept as a safety net should the NaN handling above change.
    # NOTE(review): scaler and encoder are fitted on all rows before the
    # train/test split below, so test statistics leak into preprocessing.
    num_imputer = SimpleImputer(strategy="mean")
    X[num_cols] = num_imputer.fit_transform(X[num_cols])
    X[num_cols] = StandardScaler().fit_transform(X[num_cols])

    cat_imputer = SimpleImputer(strategy="most_frequent")
    X[cat_cols] = cat_imputer.fit_transform(X[cat_cols])
    X[cat_cols] = OrdinalEncoder().fit_transform(X[cat_cols])

    y = LabelEncoder().fit_transform(y)

    return train_test_split(X.values, y, test_size=0.2)

def _make_classical_model(model_type):
    """Instantiate the non-neural estimator named `model_type` with this script's defaults.

    Raises:
        ValueError: if ``model_type`` is not a known classical model.
    """
    factories = {
        "xgboost": lambda: xgb.XGBClassifier(
            use_label_encoder=False,
            eval_metric="logloss",
        ),
        "random_forest": RandomForestClassifier,
        "svm": SVC,
        "adaboost": AdaBoostClassifier,
        # GradientBoosting with log-loss is used as the "LogitBoost" stand-in.
        "logitboost": lambda: GradientBoostingClassifier(loss="log_loss"),
        "bagging": BaggingClassifier,
    }
    try:
        factory = factories[model_type]
    except KeyError:
        raise ValueError(f"Unknown model_type {model_type}") from None
    return factory()


def _fit_torch(model, X, y, lr, epochs=15, batch_size=64):
    """Train `model` with Adam + cross-entropy on numpy arrays (X, y) and return it.

    Mini-batches are reshuffled every epoch. The model is moved to CUDA when
    available, and each batch is moved to the same device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    dataset = TensorDataset(torch.from_numpy(X).float(),
                            torch.from_numpy(y).long())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss_fn(model(xb), yb).backward()
            optimizer.step()
    return model


def _eval_torch(model, X):
    """Return ``(predicted_class, class_probabilities)`` of `model` on numpy X.

    Probabilities come from a softmax over the logits, so each row sums to 1.
    The previous sigmoid-per-logit did not yield valid probabilities.
    """
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(torch.from_numpy(X).float().to(device))
        preds = logits.argmax(dim=1).cpu().numpy()
        probs = torch.softmax(logits, dim=1).cpu().numpy()
    return preds, probs


def train_evaluate(config):
    """Fit one model on the noisy training labels and score it on the clean test set.

    Args:
        config: Dict with ``"model_type"`` in {"ffn", "xgboost",
            "random_forest", "svm", "adaboost", "logitboost", "bagging"}.
            For ``"ffn"`` it also needs ``"hidden_size"`` and ``"lr"``, and
            may give ``"epochs"`` (default 15) and ``"batch_size"``
            (default 64).

    Reads the module-level globals ``X_train``, ``y_train``,
    ``y_train_noisy``, ``X_test`` and ``y_test`` set in the ``__main__``
    block.

    Returns:
        ``{"noisy_accuracy": ..., "noisy_f1": ...}`` on the test set.
        F1 uses scikit-learn's default binary averaging.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    m = config["model_type"]
    if m == "ffn":
        # Small 2-layer feed-forward net; class count taken from y_train.
        input_dim = X_train.shape[1]
        num_classes = len(np.unique(y_train))
        hidden = int(config["hidden_size"])
        model_noisy = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, num_classes)
        )
        model_noisy = _fit_torch(
            model_noisy, X_train, y_train_noisy,
            lr=config["lr"],
            epochs=int(config.get("epochs", 15)),
            batch_size=int(config.get("batch_size", 64)),
        )
        preds_noisy, _ = _eval_torch(model_noisy, X_test)
    else:
        model_noisy = _make_classical_model(m)
        model_noisy.fit(X_train, y_train_noisy)
        preds_noisy = model_noisy.predict(X_test)

    return {
        "noisy_accuracy": accuracy_score(y_test, preds_noisy),
        "noisy_f1": f1_score(y_test, preds_noisy),
    }


if __name__ == "__main__":

    # Per-model configuration passed to train_evaluate().
    search_spaces = {
        "ffn": {'model_type': 'ffn', 'hidden_size': 64, 'lr': 1e-3},
        "xgboost": {'model_type': 'xgboost'},
        "random_forest": {'model_type': 'random_forest'},
        "svm": {'model_type': 'svm'},
        "adaboost": {'model_type': 'adaboost'},
        "logitboost": {'model_type': 'logitboost'},
        "bagging": {'model_type': 'bagging'},
    }
    # Load data. These are module-level globals that train_evaluate() reads,
    # so they must stay assigned at module scope.
    X_train, X_test, y_train, y_test = load_adult()
    # Corrupt a fraction `eps` of the training labels.
    y_train_noisy, idx = add_label_noise(y_train, noise_rate=eps)
    # Optionally drop high-loss (likely mislabeled) samples before training.
    if args.annp != 100.0:
        X_train, y_train_noisy = fit_and_prefilter(
            X_train, y_train_noisy, idx,
            lr=1e-3,
            epochs=20,
            val_frac=0.2
        )

    # Collect results
    summary = []
    for name, space in search_spaces.items():
        res = train_evaluate(space)
        print(res)
        summary.append({
            "model": name,
            "noisy_acc": res["noisy_accuracy"],
            "noisy_f1": res["noisy_f1"],
        })

    df = pd.DataFrame(summary).set_index("model")
    print(df[["noisy_acc", "noisy_f1"]].to_markdown())
    # Write the index too: it holds the model names. With index=False the
    # CSV rows were unlabeled.
    df.to_csv(filename)
    print("Results saved to ", filename)
    print("=" * 50)
