import openml
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def prepare_sklearn_dataset(dataset_name, random_state, train_size=50, test_size=50):
    """Load a bundled sklearn dataset, split it with stratification, and standardize it.

    Only ``"breast_cancer"`` is currently supported.

    Args:
        dataset_name: Name of the sklearn dataset to load.
        random_state: Seed forwarded to ``train_test_split`` for a reproducible split.
        train_size: Size of the training split (absolute count when an int,
            fraction when a float, as interpreted by ``train_test_split``).
        test_size: Size of the test split, same semantics as ``train_size``.

    Returns:
        ``(X_train_scaled, X_test_scaled, y_train, y_test)``. The scaler is fit
        on the training split only and then applied to the test split.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Fail fast with a clear message; previously an unknown name left X/y
    # unbound and crashed later with an unrelated UnboundLocalError.
    if dataset_name != "breast_cancer":
        raise ValueError(f"Unsupported sklearn dataset: {dataset_name!r}")

    data = load_breast_cancer()
    X, y = data.data, data.target

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_size, test_size=test_size, stratify=y, random_state=random_state
    )
    # Fit scaling statistics on the training split only to avoid test leakage.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    return X_train_scaled, X_test_scaled, y_train, y_test


def prepare_seaborn_dataset(dataset_name, random_state, train_size=50, test_size=50):
    """Load a seaborn dataset, split it with stratification, impute, and standardize it.

    Only ``"titanic"`` is currently supported; it uses the numeric columns
    ``pclass``, ``age``, ``sibsp``, ``parch``, ``fare`` as features and
    ``survived`` as the target.

    Args:
        dataset_name: Name of the seaborn dataset to load.
        random_state: Seed forwarded to ``train_test_split`` for a reproducible split.
        train_size: Size of the training split (int count or float fraction).
        test_size: Size of the test split (int count or float fraction).

    Returns:
        ``(X_train_scaled, X_test_scaled, y_train, y_test)`` where the labels
        are NumPy arrays. Missing feature values are filled with training-split
        column means, and scaling is fit on the training split only.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Fail fast; previously an unknown name left X/y unbound.
    if dataset_name != "titanic":
        raise ValueError(f"Unsupported seaborn dataset: {dataset_name!r}")

    data = sns.load_dataset("titanic")
    X = data[["pclass", "age", "sibsp", "parch", "fare"]]
    y = data["survived"]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_size, test_size=test_size, stratify=y, random_state=random_state
    )
    # Impute with statistics from the training split only. Computing the means
    # over the full dataset (as before) leaked test-set information, which is
    # inconsistent with fitting the scaler on the training split below.
    train_means = X_train.mean()
    X_train = X_train.fillna(train_means)
    X_test = X_test.fillna(train_means)

    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    return X_train_scaled, X_test_scaled, y_train.to_numpy(), y_test.to_numpy()


def prepare_kaggle_dataset(file_path, target_column, random_state, train_size=50, test_size=50):
    """Read a CSV file, split it with stratification on the target, and standardize it.

    Args:
        file_path: Path (or anything ``pd.read_csv`` accepts) to the CSV file.
        target_column: Name of the column used as the label; every other
            column is treated as a feature.
        random_state: Seed forwarded to ``train_test_split``.
        train_size: Size of the training split (int count or float fraction).
        test_size: Size of the test split (int count or float fraction).

    Returns:
        ``(X_train_scaled, X_test_scaled, y_train, y_test)`` with the labels
        converted to NumPy arrays. The scaler is fit on the training split only.
    """
    frame = pd.read_csv(file_path)
    features = frame.drop(columns=[target_column])
    labels = frame[target_column]

    (features_train, features_test,
     labels_train, labels_test) = train_test_split(
        features,
        labels,
        train_size=train_size,
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    # Learn mean/std from training rows only, then reuse them for the test rows.
    standardizer = StandardScaler()
    train_scaled = standardizer.fit_transform(features_train)
    test_scaled = standardizer.transform(features_test)

    return (
        train_scaled,
        test_scaled,
        labels_train.to_numpy(),
        labels_test.to_numpy(),
    )

def prepare_open_ml_dataset(dataset_id, target_column, random_state, train_size=50, test_size=50):
    """Fetch an OpenML dataset, binarize its target, split with stratification, and standardize.

    Args:
        dataset_id: OpenML dataset id (int or numeric string) or any identifier
            accepted by ``openml.datasets.get_dataset``.
        target_column: Name of the target attribute passed to ``get_data``.
        random_state: Seed forwarded to ``train_test_split``.
        train_size: Size of the training split (int count or float fraction).
        test_size: Size of the test split (int count or float fraction).

    Returns:
        ``(X_train_scaled, X_test_scaled, y_train, y_test)`` with the labels as
        NumPy arrays. The scaler is fit on the training split only.
    """
    dataset = openml.datasets.get_dataset(dataset_id)
    X, y, _, _ = dataset.get_data(target=target_column)

    # Datasets 847 and 761 are treated as having "P"/"N" string targets.
    # The previous check `dataset_id == 847 or 761` was always true (761 is
    # truthy), so every other dataset's labels were mapped to NaN. Comparing
    # as strings also matches ids passed as "847"/"761".
    if str(dataset_id) in {"847", "761"}:
        y = y.map({"P": 1, "N": 0})
    # Always hand NumPy labels downstream; previously every path went through
    # the mapping above and ended up as NumPy, so callers already expect this.
    y = y.to_numpy()

    # Collapse multiclass targets to "class 1 vs rest". pd.unique counts NaN
    # once, unlike set(), which can count every NaN as a distinct value.
    # NOTE(review): assumes numeric labels where 1 is the positive class; string
    # labels like "1" would all compare unequal to 1 — confirm per dataset.
    if len(pd.unique(y)) > 2:
        y = (y == 1).astype(int)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_size, test_size=test_size, stratify=y, random_state=random_state
    )
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    return X_train_scaled, X_test_scaled, y_train, y_test