import numpy as np
import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder, OrdinalEncoder
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier, RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import f1_score
from xgboost import XGBClassifier
from imblearn.over_sampling import SMOTE, BorderlineSMOTE
from imblearn.combine import SMOTEENN
from imblearn.ensemble import BalancedBaggingClassifier
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from ray import tune
import sys
import argparse
import os



def _str2bool(value):
    """Convert a command-line token such as "true"/"false"/"1"/"0" to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every non-empty value would enable the flag.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the experiment script.
parser = argparse.ArgumentParser()
parser.add_argument("--eps", help="Contamination ratio",
                    type=float, default=0.0)
parser.add_argument("--run", help="Run (for averaging)",
                    type=int, default=0)
parser.add_argument("--comment", help="Comment",
                    type=str, default="")
parser.add_argument("--annp", help="ANN percentage",
                    type=float, default=100.0)
# Accepts "--force", "--force True" and "--force False"; the latter now
# really disables overwriting.
parser.add_argument("--force", help="Force overwriting old run",
                    type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument('--size_percent',
                    type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
args = parser.parse_args()
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent

# Result file for this parameter combination; an existing file means the run
# was already done, so the script exits early unless --force is given.
filename = f"adultsshortcutdefaultv2_partial_{size_percent}_eps_{eps}_annp_{annp}_run_{run}.csv"
if os.path.isfile(filename) and (not args.force):
    # File exists
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)

# Hyperparameter for noise correlation on positive class
NOISE_CORRELATION = eps


def prepare_data():
    """Load the OpenML "adult" (version 2) data and build the experiment arrays.

    Rows with missing values are dropped, the dataset is optionally reduced
    to ``size_percent`` percent (stratified on the label), and the remainder
    is split 80/20 into train and test. A fraction ``NOISE_CORRELATION`` of
    the class-1 training rows gets ``education-num`` overwritten with the
    column maximum (the "contamination").

    Returns:
        tuple: ``(X_train_clean_scaled, X_train_scaled, X_test_scaled,
        y_clean_train, y_train, y_test, contaminated_mask)`` where
        ``X_train_clean_scaled`` is the preprocessed uncontaminated training
        set, ``X_train_scaled`` the preprocessed contaminated one,
        ``X_test_scaled`` the test set transformed with the preprocessor
        fitted on the contaminated training set, and ``contaminated_mask`` a
        boolean array (positional, over training rows) marking contaminated
        samples.
    """
    df = fetch_openml("adult", version=2, as_frame=True).frame
    df = df.replace("?", np.nan)
    df = df.dropna()

    # Binary target: 1 for the '>50K' class, 0 otherwise.
    df['label'] = (df['class'] == '>50K').astype(int)

    # Reduce the dataset exactly once; subsampling a second time would keep
    # only (size_percent / 100) ** 2 of the rows.
    if size_percent < 100.0:
        df, _ = train_test_split(df, train_size=size_percent / 100., stratify=df['label'])

    X = df.drop(columns=['class', 'label'])
    y = df["label"]
    print("Y", y)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y
    )

    # Contaminate a fraction NOISE_CORRELATION of the class-1 training points.
    rng = np.random.RandomState(None)
    class1_idx = np.where(y_train.values == 1)[0]
    n_contam = int(NOISE_CORRELATION * len(class1_idx))
    contam_idx = rng.choice(class1_idx, size=n_contam, replace=False)
    # Record mask of contaminated samples
    contaminated_mask = np.zeros(len(X_train), dtype=bool)
    contaminated_mask[contam_idx] = True

    # Max value of education-num over the (possibly reduced) dataset,
    # train and test rows alike.
    max_edunum = df['education-num'].max()

    # Apply contamination on a copy so the clean training set stays intact.
    X_train_cont = X_train.copy()
    X_train_cont.loc[contaminated_mask, 'education-num'] = max_edunum

    # Preprocessing
    numeric_features = ['age', 'fnlwgt', 'education-num', 'capital-gain',
                        'capital-loss', 'hours-per-week']
    categorical_features = [c for c in X.columns if c not in numeric_features]

    numeric_transformer = Pipeline([
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ])

    categorical_transformer = Pipeline([
        ('imputer', SimpleImputer(strategy='most_frequent')),
        # Categories that only occur in the test split are mapped to -1
        # instead of raising when the test set is transformed.
        ('label', OrdinalEncoder(handle_unknown='use_encoded_value',
                                 unknown_value=-1)),
    ])

    preprocessor = ColumnTransformer([
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features),
    ])

    # The clean training set gets its own fit; the preprocessor is then
    # re-fitted on the contaminated training set.
    X_train_clean_scaled = preprocessor.fit_transform(X_train)
    X_train_scaled = preprocessor.fit_transform(X_train_cont)
    # The test set must only be transformed: re-fitting on it would scale
    # and encode it differently from the data the models are trained on.
    X_test_scaled = preprocessor.transform(X_test)

    # One encoder for both label vectors so the class -> integer mapping is
    # guaranteed to be the same for train and test.
    label_encoder = LabelEncoder()
    y_train = label_encoder.fit_transform(y_train)
    y_test = label_encoder.transform(y_test)
    y_clean_train = y_train.copy()
    return X_train_clean_scaled, X_train_scaled, X_test_scaled, y_clean_train, y_train, y_test, contaminated_mask

def train_and_prefilter_rf(X, y, percentile, contaminated_mask):
    """Drop the training samples a random forest scores highest for class 1.

    A random forest is fitted on ``(X, y)`` and its class-1 probability is
    computed for every training sample. Samples whose probability is at or
    above the ``100 - percentile`` percentile of those probabilities are
    removed.

    Args:
        X: Feature matrix; must support boolean-mask row indexing.
        y: Labels aligned with the rows of ``X``.
        percentile: Upper percentage (0-100) of the confidence distribution
            to discard. ``0`` (or less) disables filtering.
        contaminated_mask: Not used by the filter; kept so existing callers
            keep working.

    Returns:
        tuple: ``(X_filtered, y_filtered)``.
    """
    # Nothing to remove: without this guard the strict comparison below would
    # still drop the sample(s) holding the maximum confidence.
    if percentile <= 0:
        return X, y

    # Train on the (possibly contaminated) data.
    clf = RandomForestClassifier(n_estimators=200, max_depth=10, min_samples_leaf=10)
    clf.fit(X, y)

    # Prediction confidence (probability of class 1) on the TRAINING set.
    confidences_all = clf.predict_proba(X)[:, 1]

    threshold = np.percentile(confidences_all, 100. - percentile)
    keep = confidences_all < threshold
    return X[keep], y[keep]

# Build the preprocessed train/test arrays, then run the random-forest
# prefilter on the contaminated training set, passing 100 - annp as the
# percentile to discard.
(
    X_clean_trval,
    X_trval,
    X_test,
    y_clean_trval,
    y_trval,
    y_test,
    contaminated_mask,
) = prepare_data()
X_trval, y_trval = train_and_prefilter_rf(
    X_trval,
    y_trval,
    percentile=100. - annp,
    contaminated_mask=contaminated_mask,
)

# Resample the prefiltered training data with SMOTEENN; the result is the
# training set used by every model below.
smote_enn = SMOTEENN()
X_train, y_train = smote_enn.fit_resample(X_trval, y_trval)


def _train_evaluate_ffn(params):
    """Train a one-hidden-layer network on the global training set and score it.

    Args:
        params: Dict with the keys ``hidden_size``, ``lr``, ``batch_size``
            and ``epochs``; any other keys are ignored.

    Returns:
        tuple: ``(test_acc, test_f1)`` measured on the global test set.
    """
    hidden = params.pop('hidden_size')
    lr = params.pop('lr')
    batch_size = params.pop('batch_size')
    epochs = params.pop('epochs')

    input_dim = X_train.shape[1]
    net = nn.Sequential(
        nn.Linear(input_dim, hidden), nn.ReLU(),
        nn.Linear(hidden, 2)
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    # CrossEntropyLoss needs int64 class indices, hence the explicit .long().
    train_ds = TensorDataset(
        torch.from_numpy(X_train).float(),
        torch.from_numpy(y_train).long()
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        net.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = criterion(net(xb), yb)
            loss.backward()
            optimizer.step()
    net.eval()

    test_ds = TensorDataset(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(y_test).long()
    )
    test_loader = DataLoader(test_ds, batch_size=batch_size)
    correct_test = total_test = 0
    test_preds = []
    test_corrects = []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            preds = net(xb).argmax(dim=1)
            test_preds += preds.detach().cpu().tolist()
            test_corrects += yb.detach().cpu().tolist()
            correct_test += (preds == yb).sum().item()
            total_test += yb.size(0)
    test_acc = correct_test / total_test
    test_f1 = f1_score(test_corrects, test_preds)
    return test_acc, test_f1


def train_evaluate(config):
    """Train the model described by ``config`` and evaluate it on the test set.

    The model is fitted on the module-level ``X_train``/``y_train`` and
    scored on ``X_test``/``y_test``.

    Args:
        config: Dict with a ``"model"`` key (one of ``xgb``, ``adaboost``,
            ``logitboost``, ``bagging``, ``rf``, ``svm``, ``ffn``). The
            remaining entries are passed to the estimator constructor, or,
            for ``ffn``, read as ``hidden_size``/``lr``/``batch_size``/
            ``epochs``. The dict is not modified.

    Returns:
        dict: ``{"model": ..., "test_f1": ..., "test_acc": ...}``.

    Raises:
        KeyError: If ``config`` has no ``"model"`` entry.
        ValueError: If the model type is not one of the supported names.
    """
    # Work on a copy: popping from the caller's dict would strip the model
    # name and hyperparameters from whatever the caller stores afterwards.
    params = dict(config)
    model_type_full = params.pop("model")
    model_type = model_type_full

    if model_type == "ffn":
        test_acc, test_f1 = _train_evaluate_ffn(params)
        return {"model": model_type_full, "test_f1": test_f1, "test_acc": test_acc}

    # Lazy constructors: only the selected estimator is instantiated.
    builders = {
        "xgb": lambda: XGBClassifier(
            use_label_encoder=False,
            eval_metric='logloss',
            **params
        ),
        "adaboost": lambda: AdaBoostClassifier(**params),
        "logitboost": lambda: GradientBoostingClassifier(loss='log_loss', **params),
        "bagging": lambda: BaggingClassifier(**params),
        "rf": lambda: RandomForestClassifier(**params),
        "svm": lambda: SVC(**params),
    }
    if model_type not in builders:
        raise ValueError(f"Unknown model type: {model_type}")

    noisy_model = builders[model_type]()
    noisy_model.fit(X_train, y_train)

    # Predict once and derive both metrics from the same predictions
    # (accuracy is the fraction of matching labels, as estimator.score gives).
    test_y_pred = noisy_model.predict(X_test)
    test_acc = float(np.mean(test_y_pred == y_test))
    test_f1 = f1_score(y_test, test_y_pred)

    return {"model": model_type_full, "test_f1": test_f1, "test_acc": test_acc}


if __name__ == '__main__':
    # One entry per model; every model except the feed-forward network runs
    # with its library defaults.
    experiments = [
        {'model': 'xgb'},
        {'model': 'adaboost'},
        {'model': 'logitboost'},
        {'model': 'bagging'},
        {'model': 'rf'},
        {'model': 'svm'},
        {
            'model': 'ffn',
            'hidden_size': 64, 'lr': 1e-3, 'batch_size': 64, 'epochs': 15
        },
    ]

    summary = {}
    for config in experiments:

        print(config['model'])
        # Hand over a copy so that 'best_config' below keeps the model name
        # and all hyperparameters even if train_evaluate consumes its input.
        res = train_evaluate(dict(config))
        summary[res["model"]] = {
            'test_acc': res['test_acc'],
            'test_f1': res['test_f1'],
            'best_config': config,
        }

    # One row per model, in the order the experiments were run.
    df = pd.DataFrame.from_dict(summary, orient='index')
    print("\nSummary of best results:")
    print("\nSummary:\n")
    print(df.to_markdown())
    df.to_csv(filename, index=False)
    print("Results saved to ", filename)
    print("=" * 50)