import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.impute import SimpleImputer
from sklearn.metrics import f1_score

import xgboost as xgb
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import copy
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
from tqdm import tqdm as tqdm
import argparse
import os
import sys

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true, so a dedicated converter is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--eps", help="Contamination ratio",
                    type=float, default=0.0)
parser.add_argument("--run", help="Run (for averaging)",
                    type=int, default=0)
parser.add_argument("--comment", help="Comment",
                    type=str, default="")
parser.add_argument("--annp", help="ANN percentage",
                    type=float, default=100.0)
# nargs="?" + const=True: a bare `--force` means True, while `--force True` /
# `--force False` keep working (and `False` is now actually false).
parser.add_argument("--force", help="Force overwriting old run",
                    type=_str2bool, nargs="?", const=True, default=False)
# argparse %-formats help strings, hence the escaped "%%".
parser.add_argument("--t", help="Assumed percentage of the removed (100 - annp)%% "
                                "of points that were mislabeled",
                    type=float, default=80.0)

parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
args = parser.parse_args()
# annp is used as a divisor below; reject values that would divide by zero
# or describe keeping more than the whole dataset.
if not 0.0 < args.annp <= 100.0:
    parser.error("--annp must be in (0, 100]")
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
t = args.t
# Contamination left after prefiltering: subtract the noisy mass assumed to be
# removed ((100-annp)% of the data, t% of which is noisy), clamp at zero, then
# renormalise by the fraction of data that is kept.
new_eps = max(0, eps - ((100. - annp) / 100.) * (t / 100.)) / (annp / 100.)
print(f"DS size after prefiltering: {annp}")
print(f"Contamination ratio after oracle prefiltering: {new_eps}")

filename = f"adultsfixeddefaultv2_oraclev2_t_{t}_partial_{size_percent}_eps_{eps}_annp_{annp}_run_{run}.csv"
if os.path.isfile(filename) and not args.force:
    # A result file for this exact configuration already exists; skip the run.
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)

def add_label_noise(y, noise_rate=0.3, random_state=None):
    """Flip a fraction of labels to a different, uniformly chosen class.

    Exactly ``int(noise_rate * len(y))`` distinct positions are drawn at random
    and each is reassigned to one of the *other* labels present in ``y``, so
    every drawn position is guaranteed to change.

    Args:
        y: 1-D array of labels. It is copied, never modified in place.
        noise_rate: Fraction of positions to corrupt, in ``[0, 1]``.
        random_state: ``None`` (default, fresh entropy as before), an int seed,
            or an ``np.random.RandomState`` for reproducible noise.

    Returns:
        ``(y_noisy, idx)``: the corrupted copy and the indices that were flipped.

    Raises:
        ValueError: if ``noise_rate`` is outside ``[0, 1]``, or if labels must
            be flipped but ``y`` has fewer than two distinct classes.
    """
    if not 0.0 <= noise_rate <= 1.0:
        raise ValueError(f"noise_rate must be in [0, 1], got {noise_rate}")
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    y_noisy = y.copy()
    n = len(y)
    n_noisy = int(noise_rate * n)
    idx = rng.choice(n, n_noisy, replace=False)
    labels = np.unique(y)
    if n_noisy and len(labels) < 2:
        # There is no "other" class to flip to.
        raise ValueError("Cannot flip labels: y contains fewer than two classes")

    for i in idx:
        # idx is sampled without replacement, so y_noisy[i] is still the clean label.
        choices = labels[labels != y_noisy[i]]
        y_noisy[i] = rng.choice(choices)
    return y_noisy, idx

def fit_and_prefilter(X, y, idx, lr, epochs=5, val_frac=0.2, batch_size=128,
                      keep_percent=None):
    """Train a small MLP on ``(X, y)`` and keep the lowest-loss samples.

    Samples whose per-sample cross-entropy is at or below the
    ``keep_percent``-th percentile of all losses are retained. A diagnostic
    printout reports how many of the known-corrupted samples (``idx``)
    survived the filter.

    Args:
        X: numpy feature matrix; converted to float32.
        y: numpy integer labels (possibly noisy).
        idx: indices of samples known to be mislabeled; used only for the
            diagnostic printout.
        lr: Adam learning rate.
        epochs: number of training epochs.
        val_frac: unused; kept for call compatibility.
        batch_size: minibatch size during training.
        keep_percent: loss percentile used as the keep threshold. Defaults to
            the module-level ``annp``.

    Returns:
        ``(X[selected], y[selected])`` for the retained samples.
    """
    if keep_percent is None:
        keep_percent = annp

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hidden = 128
    model = nn.Sequential(
        nn.Linear(X.shape[1], hidden), nn.ReLU(),
        nn.Linear(hidden, len(np.unique(y))),
    ).to(device)

    # The whole dataset is placed on the device once, so batches need no copies.
    X_t = torch.from_numpy(X).float().to(device)
    y_t = torch.from_numpy(y).long().to(device)
    train_ds = TensorDataset(X_t, y_t)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for _ in tqdm(range(1, epochs + 1)):
        # Training pass
        model.train()
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss_fn(model(xb), yb).backward()
            optimizer.step()

    # Unreduced loss gives one value per training sample for ranking.
    per_sample_loss = nn.CrossEntropyLoss(reduction="none")
    model.eval()
    with torch.no_grad():
        losses = per_sample_loss(model(X_t), y_t).cpu().numpy()

    percentile = keep_percent
    threshold = np.percentile(losses, percentile)
    selected_indices = np.where(losses <= threshold)[0]
    print(f"Keeping {len(selected_indices)} out of {len(train_ds)} datapoints "
            f"({100*len(selected_indices)/len(train_ds):.2f}%) with loss under the {percentile}th percentile.")

    # Corrupted samples that survived the filter (intersect1d deduplicates,
    # matching a set intersection).
    n_noisy_kept = len(np.intersect1d(selected_indices, idx))
    print(n_noisy_kept)
    print(len(idx))
    # The survival fraction is undefined when there was no injected noise.
    print(n_noisy_kept / len(idx) if len(idx) else float("nan"))
    return X[selected_indices], y[selected_indices]

def load_adult(partial_percent=None, kept_percent=None, test_size=0.2):
    """Load, subsample, encode and split the OpenML "adult" dataset (version 2).

    Rows containing missing values (including "?" placeholders) are dropped.
    The data is then optionally shrunk twice with stratified sampling: first to
    ``partial_percent`` %, then to ``kept_percent`` % (simulating the size left
    after prefiltering). Numeric columns are standardised and categorical
    columns ordinal-encoded; the string target is label-encoded.

    NOTE: the imputers, scaler and encoders are fit on all rows before the
    train/test split, so test-set statistics influence the preprocessing.

    Args:
        partial_percent: initial dataset size in percent; defaults to the
            module-level ``size_percent``.
        kept_percent: post-prefiltering size in percent; defaults to the
            module-level ``annp``.
        test_size: fraction passed to the final ``train_test_split``.

    Returns:
        ``X_train, X_test, y_train, y_test`` as numpy arrays (float64 features).
    """
    if partial_percent is None:
        partial_percent = size_percent
    if kept_percent is None:
        kept_percent = annp

    df = fetch_openml("adult", version=2, as_frame=True).frame
    # Treat "?" placeholders as missing, then drop every incomplete row.
    df = df.replace("?", np.nan).dropna()

    X = df.drop("class", axis=1)
    y = df["class"]

    # Reduction in DS size due to initially set DS size
    if partial_percent < 100.0:
        X, _, y, _ = train_test_split(
            X, y, train_size=partial_percent / 100, stratify=y
        )
    # Reduction in size due to prefiltering
    if kept_percent < 100.0:
        X, _, y, _ = train_test_split(
            X, y, train_size=kept_percent / 100., stratify=y
        )

    # Explicit copy so the column assignments below never target a view.
    X = X.copy()

    # "number" covers every int/float width, not only int64/float64.
    num_cols = X.select_dtypes(include="number").columns
    cat_cols = X.select_dtypes(include=["category", "object"]).columns

    if len(num_cols):
        X[num_cols] = SimpleImputer(strategy="mean").fit_transform(X[num_cols])
        X[num_cols] = StandardScaler().fit_transform(X[num_cols])
    if len(cat_cols):
        X[cat_cols] = SimpleImputer(strategy="most_frequent").fit_transform(X[cat_cols])
        X[cat_cols] = OrdinalEncoder().fit_transform(X[cat_cols])

    y = LabelEncoder().fit_transform(y)

    # A float array (not object dtype) is required by torch.from_numpy downstream.
    return train_test_split(X.to_numpy(dtype=np.float64), y, test_size=test_size)

def _fit_torch(model, X, y, lr, epochs=15, batch_size=128):
    """Train ``model`` with Adam + cross-entropy on numpy ``(X, y)``; return it.

    The model is moved to CUDA when available; batches are moved per step.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    X_t = torch.from_numpy(X).float()
    y_t = torch.from_numpy(y).long()
    print(X_t.shape)
    print(y_t.shape)
    loader = DataLoader(TensorDataset(X_t, y_t), batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss_fn(model(xb), yb).backward()
            optimizer.step()
    return model


def _eval_torch(model, X):
    """Return ``(predicted_class, probability_of_class_1)`` for numpy ``X``."""
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(torch.from_numpy(X).float().to(device))
        preds = logits.argmax(dim=1).cpu().numpy()
        # The head emits one logit per class (trained with CrossEntropyLoss),
        # so probabilities come from softmax across classes, not sigmoid.
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        print(probs.shape)
    return preds, probs[:, 1]


def _make_estimator(model_type):
    """Build an unfitted classifier exposing ``fit``/``predict``.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    factories = {
        # `use_label_encoder` is deprecated/removed in recent XGBoost releases;
        # labels here are already 0..K-1 integers.
        "xgboost": lambda: xgb.XGBClassifier(eval_metric="logloss"),
        "random_forest": RandomForestClassifier,
        "svm": SVC,
        "adaboost": AdaBoostClassifier,
        "logitboost": lambda: GradientBoostingClassifier(loss="log_loss"),
        "bagging": BaggingClassifier,
    }
    try:
        factory = factories[model_type]
    except KeyError:
        raise ValueError(f"Unknown model_type {model_type}") from None
    return factory()


def train_evaluate(config):
    """Train one model on the noisy training labels and score it on the test set.

    Reads the module-level ``X_train``, ``y_train``, ``y_train_noisy``,
    ``X_test`` and ``y_test`` assigned in the ``__main__`` section.

    Args:
        config: dict with ``"model_type"``; for ``"ffn"`` it also needs
            ``"hidden_size"`` and ``"lr"``.

    Returns:
        dict with ``"noisy_accuracy"`` and ``"noisy_f1"`` (binary F1).

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    m = config["model_type"]
    if m == "ffn":
        # small 2-layer feedforward net
        hidden = int(config["hidden_size"])
        model_noisy = nn.Sequential(
            nn.Linear(X_train.shape[1], hidden), nn.ReLU(),
            nn.Linear(hidden, len(np.unique(y_train))),
        )
        model_noisy = _fit_torch(model_noisy, X_train, y_train_noisy,
                                 lr=config["lr"], epochs=15)
        preds_noisy, _ = _eval_torch(model_noisy, X_test)
    else:
        model_noisy = _make_estimator(m)
        model_noisy.fit(X_train, y_train_noisy)
        preds_noisy = model_noisy.predict(X_test)

    return {
        "noisy_accuracy": accuracy_score(y_test, preds_noisy),
        "noisy_f1": f1_score(y_test, preds_noisy),
    }


if __name__ == "__main__":

    # One fixed configuration per model; each is trained and evaluated once.
    search_spaces = {
        "ffn": {'model_type': 'ffn', 'hidden_size': 64, 'lr': 1e-3},

        "xgboost": {'model_type': 'xgboost'},
        "random_forest": {'model_type': 'random_forest'},
        "svm": {'model_type': 'svm'},
        "adaboost": {'model_type': 'adaboost'},
        "logitboost": {'model_type': 'logitboost'},
        "bagging": {'model_type': 'bagging'},
    }
    # load data
    X_train, X_test, y_train, y_test = load_adult()
    # make noisy labels (train_evaluate reads these module-level names)
    y_train_noisy, idx = add_label_noise(y_train, noise_rate=new_eps)

    # Collect results
    summary = []
    for name, space in search_spaces.items():
        res = train_evaluate(space)
        print(res)
        summary.append({
            "model": name,
            "noisy_acc": res["noisy_accuracy"],
            "noisy_f1": res["noisy_f1"],
            "best_config": space,
        })

    df = pd.DataFrame(summary).set_index("model")
    # Save before any optional display step so results cannot be lost, and
    # keep the "model" index so each row stays attributable to its classifier.
    df.to_csv(filename)

    metrics = df[["noisy_acc", "noisy_f1"]]
    try:
        print(metrics.to_markdown())
    except ImportError:
        # DataFrame.to_markdown relies on the optional `tabulate` package.
        print(metrics.to_string())
    print("\nBest hyperparameter configurations:")
    for model, row in df.iterrows():
        print(f"- {model}: {row['best_config']}")
    print("Results saved to ", filename)
    print("=" * 50)
