import os

import numpy as np
import pandas as pd

from constants.paths import _datasets_dir


def get_tabular_celeba(split, full_features=False, seed=0, n_train_samples=50_000):
    """Load the tabular CelebA features, sensitive flags and labels for a split.

    The data is read from ``celeba_{split}.npz`` (or
    ``celeba-full_{split}.npz`` when ``full_features`` is true) inside the
    ``celeba_features`` directory under ``_datasets_dir``. Only the
    ``"features"``, ``"metadata"`` and ``"labels"`` arrays of the archive are
    read.

    For the ``"train"`` split a random subset of rows is drawn without
    replacement, so the returned rows are in sampled (not file) order.

    Args:
        split: Split name interpolated into the file name. Only the exact
            value ``"train"`` triggers subsampling.
        full_features: When true, load the ``celeba-full_*`` file instead of
            the ``celeba_*`` one.
        seed: Seed for the NumPy random generator used to subsample the
            train split.
        n_train_samples: Maximum number of rows kept for the train split.
            Clamped to the number of rows available, so a file with fewer
            rows is returned in full (in shuffled order) instead of raising.

    Returns:
        A ``(features, sensitives, labels)`` tuple. ``features`` is a
        ``pandas.DataFrame`` built from the stored feature array;
        ``sensitives`` and ``labels`` are boolean arrays that are ``True``
        where the stored ``"metadata"`` / ``"labels"`` value equals 1.
    """
    if full_features:
        file_name = f"celeba-full_{split}.npz"
    else:
        file_name = f"celeba_{split}.npz"
    fp = os.path.join(_datasets_dir, "celeba_features", file_name)

    # The archive object holds an open file handle; the context manager
    # closes it deterministically. Indexing a member reads it fully into
    # memory, so the arrays remain valid after the archive is closed.
    with np.load(fp) as data:
        sensitives = data["metadata"] == 1
        labels = data["labels"] == 1
        features = data["features"]

    if split == "train":
        n_rows = features.shape[0]
        # Sampling without replacement cannot draw more rows than exist.
        n_keep = min(n_train_samples, n_rows)
        rng = np.random.default_rng(seed)
        to_keep = rng.choice(n_rows, n_keep, replace=False)

        features = features[to_keep]
        sensitives = sensitives[to_keep]
        labels = labels[to_keep]

    features = pd.DataFrame(data=features)

    return features, sensitives, labels


def get_tabular_waterbirds(split, full_features=False):
    """Load the tabular Waterbirds features, sensitive flags and labels.

    The data is read from ``waterbirds_{split}.npz`` (or
    ``waterbirds-full_{split}.npz`` when ``full_features`` is true) inside
    the ``waterbirds_features`` directory under ``_datasets_dir``. Only the
    ``"features"``, ``"metadata"`` and ``"labels"`` arrays of the archive are
    read; any other members are left untouched.

    Args:
        split: Split name interpolated into the file name.
        full_features: When true, load the ``waterbirds-full_*`` file
            instead of the ``waterbirds_*`` one.

    Returns:
        A ``(features, sensitives, labels)`` tuple. ``features`` is a
        ``pandas.DataFrame`` built from the stored feature array;
        ``sensitives`` and ``labels`` are boolean arrays that are ``True``
        where the stored ``"metadata"`` / ``"labels"`` value equals 1.
    """
    if full_features:
        file_name = f"waterbirds-full_{split}.npz"
    else:
        file_name = f"waterbirds_{split}.npz"
    fp = os.path.join(_datasets_dir, "waterbirds_features", file_name)

    # The archive object holds an open file handle; the context manager
    # closes it deterministically. Indexing a member reads it fully into
    # memory, so the arrays remain valid after the archive is closed.
    with np.load(fp) as data:
        sensitives = data["metadata"] == 1
        labels = data["labels"] == 1
        features = data["features"]

    features = pd.DataFrame(data=features)

    return features, sensitives, labels
