from exisiting_ds import import_ds


def get_tabular_dataset(ds_name: str, seed: int = 0):
    """Look up a tabular dataset by name and load it with the given seed.

    Names are matched case-insensitively. Two forms are accepted:

    * a fixed key of the dispatch table below (e.g. ``"adult"``,
      ``"acs_income"``); its loader is called as ``loader(seed)``;
    * ``"civilcomments-<id>"`` (``full_features=False``) or
      ``"civilcomments-<x>-<id>"`` (``full_features=True``), where ``<id>``
      is the integer sensitive id. The middle term of the three-part form is
      not inspected; its presence alone selects the full-feature variant.

    Args:
        ds_name: Dataset name in one of the forms above.
        seed: Seed forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If the name is not a known key, or a ``civilcomments-``
            name has the wrong number of terms or a non-integer id.
    """
    ds_name = ds_name.lower()

    if ds_name.startswith("civilcomments-"):
        terms = ds_name.split("-")
        # Reject unexpected shapes explicitly; otherwise sensitive_id and
        # full_feats would be left unbound and fail with UnboundLocalError.
        if len(terms) not in (2, 3):
            raise ValueError(
                f"Malformed civilcomments dataset name {ds_name}. Expected "
                "'civilcomments-<id>' or 'civilcomments-<x>-<id>'"
            )
        try:
            # The id is always the last term in both accepted forms.
            sensitive_id = int(terms[-1])
        except ValueError as err:
            raise ValueError(
                f"Malformed civilcomments dataset name {ds_name}: "
                f"sensitive id {terms[-1]!r} is not an integer"
            ) from err
        full_feats = len(terms) == 3
        return get_civilcomments(seed, sensitive_id, full_feats)

    ds_dict = {
        "adult": imported_adult,
        "hsls": imported_hsls,
        "compas": imported_compas,
        "tabular_waterbirds": get_waterbirds,
        "tabular_waterbirds-full": get_waterbirds_full,
        "acs_income": get_acs_income,
        "imported_adult": imported_adult,
        "imported_hsls": imported_hsls,
        "imported_compas": imported_compas,
    }
    if ds_name not in ds_dict:
        raise ValueError(f"Unknown dataset {ds_name}. Must be in {ds_dict.keys()}")
    return ds_dict[ds_name](seed)


def imported_adult(seed):
    """Load the "adult" dataset through ``import_ds.load_any``.

    ``seed`` is forwarded unchanged as the ``seed`` keyword argument.
    """
    dataset_key = "adult"
    return import_ds.load_any(dataset_key, seed=seed)


def imported_hsls(seed):
    """Load the "hsls" dataset through ``import_ds.load_any``.

    ``seed`` is forwarded unchanged as the ``seed`` keyword argument.
    """
    dataset_key = "hsls"
    return import_ds.load_any(dataset_key, seed=seed)


def imported_compas(seed):
    """Load the "compas" dataset through ``import_ds.load_any``.

    ``seed`` is forwarded unchanged as the ``seed`` keyword argument.
    """
    dataset_key = "compas"
    return import_ds.load_any(dataset_key, seed=seed)


def get_adult(seed):
    """Load the adult dataset via ``datasets.adult.get_adult``.

    ``seed`` is passed to the underlying loader as ``random_state``.
    """
    # Imported lazily and under an alias: the underlying loader has the same
    # name as this wrapper.
    from datasets.adult import get_adult as _load_adult

    dataset = _load_adult(random_state=seed)
    return dataset


def get_acs_employment(seed):
    """Load the ACS employment dataset via ``datasets.acs.get_employment``.

    ``seed`` is passed to the underlying loader as ``random_state``.
    """
    # Lazy import: the dataset module is only loaded when actually requested.
    from datasets.acs import get_employment as _load_employment

    dataset = _load_employment(random_state=seed)
    return dataset


def get_acs_income(seed):
    """Load the ACS income dataset via ``datasets.acs.get_income``.

    ``seed`` is passed to the underlying loader as ``random_state``.
    """
    # Lazy import: the dataset module is only loaded when actually requested.
    from datasets.acs import get_income as _load_income

    dataset = _load_income(random_state=seed)
    return dataset


def get_celeba(seed=0):
    """Return the tabular CelebA splits as ``(train_set, test_set)``.

    Both splits are requested with ``full_features=False``. ``seed`` is
    forwarded only for the train split; the test split is requested without
    a seed argument.
    """
    from datasets.tabular_images import get_tabular_celeba as _load_split

    train_split = _load_split("train", full_features=False, seed=seed)
    test_split = _load_split("test", full_features=False)
    return train_split, test_split


def get_celeba_full(seed=0):
    """Return the tabular CelebA splits as ``(train_set, test_set)``.

    Both splits are requested with ``full_features=True``. ``seed`` is
    forwarded only for the train split; the test split is requested without
    a seed argument.
    """
    from datasets.tabular_images import get_tabular_celeba as _load_split

    train_split = _load_split("train", full_features=True, seed=seed)
    test_split = _load_split("test", full_features=True)
    return train_split, test_split


def get_waterbirds(seed=0):
    """Return the tabular Waterbirds splits as ``(train_set, test_set)``.

    Both splits are requested with ``full_features=False``. ``seed`` is
    accepted but not used by this function.
    """
    from datasets.tabular_images import get_tabular_waterbirds as _load_split

    train_split = _load_split("train", full_features=False)
    test_split = _load_split("test", full_features=False)
    return train_split, test_split


def get_waterbirds_full(seed=0):
    """Return the tabular Waterbirds splits as ``(train_set, test_set)``.

    Both splits are requested with ``full_features=True``. ``seed`` is
    accepted but not used by this function.
    """
    from datasets.tabular_images import get_tabular_waterbirds as _load_split

    train_split = _load_split("train", full_features=True)
    test_split = _load_split("test", full_features=True)
    return train_split, test_split


def get_civilcomments(seed, sensitive_id=7, full_features=False):
    """Return the tabular CivilComments splits as ``(train_set, test_set)``.

    The same ``sensitive_id``, ``seed`` and ``full_features`` values are
    forwarded to ``get_tabular_civilcomments`` for both the train and the
    test split.
    """
    from datasets.civilcomments import get_tabular_civilcomments as _load_split

    # Options shared by both splits; only the split name differs.
    shared_options = {
        "sensitive_id": sensitive_id,
        "seed": seed,
        "full_features": full_features,
    }
    train_split = _load_split("train", **shared_options)
    test_split = _load_split("test", **shared_options)
    return train_split, test_split


def get_hsls(seed):
    """Load the HSLS dataset via ``datasets.hsls.get_hsls``.

    ``seed`` is passed to the underlying loader as ``random_state``.
    """
    # Imported lazily and under an alias: the underlying loader has the same
    # name as this wrapper.
    from datasets.hsls import get_hsls as _load_hsls

    dataset = _load_hsls(random_state=seed)
    return dataset


def get_compas(seed):
    """Load the COMPAS dataset via ``datasets.compas.get_compas``.

    ``seed`` is passed to the underlying loader as ``random_state``.
    """
    # Imported lazily and under an alias: the underlying loader has the same
    # name as this wrapper.
    from datasets.compas import get_compas as _load_compas

    dataset = _load_compas(random_state=seed)
    return dataset


def get_image_dataset(ds_name: str, res=64, split="train"):
    """Build an image dataset object selected by name.

    Names are matched case-insensitively:

    * ``"mnist_<blue_evens>_<blue_odds>"`` -> ``Collective_MNIST`` with
      ``clr_ratios=[blue_evens, blue_odds]`` (both parsed as floats). Any
      name starting with ``"mnist"`` takes this branch, so a bare
      ``"mnist"`` is rejected as malformed;
    * ``"celeba"`` -> ``Collective_CelebA`` with ``image_size=res``;
    * ``"waterbirds"`` -> ``Collective_Waterbirds``.

    Args:
        ds_name: Dataset name in one of the forms above.
        res: Image size; only forwarded to the CelebA dataset.
        split: Split name forwarded to the dataset constructor.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If the name is unknown, or an mnist name does not have
            exactly two ``_``-separated float ratios.
    """
    available_datasets = ["mnist", "celeba", "waterbirds"]
    ds_name = ds_name.lower()

    if ds_name.startswith("mnist"):
        # Validate the name before importing, so a malformed name fails with
        # a clear message instead of a tuple-unpacking or float() error.
        parts = ds_name.split("_")
        if len(parts) != 3:
            raise ValueError(
                f"Malformed mnist dataset name {ds_name}. "
                "Expected 'mnist_<blue_evens>_<blue_odds>'"
            )
        try:
            blue_evens = float(parts[1])
            blue_odds = float(parts[2])
        except ValueError as err:
            raise ValueError(
                f"Malformed mnist dataset name {ds_name}. "
                "Expected 'mnist_<blue_evens>_<blue_odds>' with float ratios"
            ) from err

        from datasets.mnist import Collective_MNIST

        return Collective_MNIST(split=split, clr_ratios=[blue_evens, blue_odds])

    if ds_name == "celeba":
        from datasets.celeba import Collective_CelebA

        return Collective_CelebA(split=split, image_size=res)

    if ds_name == "waterbirds":
        from datasets.waterbirds import Collective_Waterbirds

        return Collective_Waterbirds(split=split)

    raise ValueError(f"Unknown dataset {ds_name}. Must be in {available_datasets}")
