# Identifier of the dataset this module configures.
DATASET_NAME = "celeba"
# Attributes of interest. The trailing number on each entry is that
# attribute's column index, mirrored in DATASET_ATTR_INDICES below.
# NOTE(review): the boolean values presumably describe the "normal" class;
# they are not read in the visible part of this module — confirm with callers.
DATASET_CONFIG = {
    "Blond_Hair": True,  # 9
    "Eyeglasses": False,  # 15
}
# Column indices used by get() to reduce attribute arrays to two columns,
# in the order [blond hair, eyeglasses].
DATASET_ATTR_INDICES = [9, 15]

# Shared building blocks for the prompt templates below. Both template lists
# are the cross product of these (5 subjects x 5 photo descriptions = 25
# entries each), ordered subject-major: all prefixes for "person" first, then
# all prefixes for "portrait", and so on.
_TEMPLATE_PREFIXES = (
    "a photo",
    "a bad photo",
    "a good photo",
    "a small photo",
    "a big photo",
)
# "portrait" was previously misspelled "potrait" in every generated prompt.
_TEMPLATE_SUBJECTS = ("person", "portrait", "face", "man", "woman")

# Hair-colour templates; "{}" is filled with an entry of BLOND_WORDS.
BLOND_TEMPLATES = [
    f"{prefix} of a {subject} with {{}} hair"
    for subject in _TEMPLATE_SUBJECTS
    for prefix in _TEMPLATE_PREFIXES
]
# Eyeglasses templates; "{}" is filled with an entry of GLASS_WORDS.
GLASS_TEMPLATES = [
    f"{prefix} of a {subject} {{}}"
    for subject in _TEMPLATE_SUBJECTS
    for prefix in _TEMPLATE_PREFIXES
]

# Hair-colour words substituted into BLOND_TEMPLATES, one inner list per
# colour family. PROMPT_BLOND_IND is built from the first group only; every
# later group feeds PROMPT_BLOND_OOD.
BLOND_WORDS = [
    # NOTE(review): "auburn" usually denotes a reddish-brown shade — confirm
    # it is meant to be grouped with blond.
    ["blond", "blonde", "auburn"],
    ["black", "jet black"],
    ["brown", "ash brown", "dark brown"],
    ["gray", "grey", "silver", "ash gray", "silver gray"],
    ["red", "ginger", "burgundy"],
    ["white"],
    ["green", "mint"],
    ["blue", "cyan", "turquoise", "aquamarine", "teal"],
    ["purple", "lavender", "violet", "indigo"],
    # NOTE(review): "violet" also appears in the purple group above, so its
    # prompts are generated twice — confirm the duplicate is intended.
    ["pink", "magenta", "violet", "pastel pink"],
    ["orange", "yellow"],
]
# Phrases substituted into GLASS_TEMPLATES. The first three groups (glasses
# absent) feed PROMPT_GLASS_IND; the last three (glasses present) feed
# PROMPT_GLASS_OOD.
GLASS_WORDS = [
    ["without glasses", "not wearing glasses"],
    ["without eyeglasses", "not wearing eyeglasses"],
    ["without sunglasses", "not wearing sunglasses"],
    ["with glasses", "wearing glasses", "glasses on"],
    ["with eyeglasses", "wearing eyeglasses", "eyeglasses on"],
    ["with sunglasses", "wearing sunglasses", "sunglasses on"],
]

def _expand_prompts(templates, word_groups):
    """Return, for every word in ``word_groups``, the list of all filled-in templates.

    The result has one inner list per word (groups are flattened in order),
    and each inner list has one prompt per template.
    """
    expanded = []
    for group in word_groups:
        for word in group:
            expanded.append([template.format(word) for template in templates])
    return expanded


# Hair prompts: the first BLOND_WORDS group forms the IND set, the remaining
# groups form the OOD set. NOT_IND is an alias of the OOD list (same object).
PROMPT_BLOND_IND = _expand_prompts(BLOND_TEMPLATES, BLOND_WORDS[:1])
PROMPT_BLOND_OOD = _expand_prompts(BLOND_TEMPLATES, BLOND_WORDS[1:])
PROMPT_BLOND_NOT_IND = PROMPT_BLOND_OOD
PROMPT_BLOND = PROMPT_BLOND_IND + PROMPT_BLOND_OOD

# Eyeglasses prompts: the first three GLASS_WORDS groups form the IND set,
# the remaining groups form the OOD set.
PROMPT_GLASS_IND = _expand_prompts(GLASS_TEMPLATES, GLASS_WORDS[:3])
PROMPT_GLASS_OOD = _expand_prompts(GLASS_TEMPLATES, GLASS_WORDS[3:])
PROMPT_GLASS_NOT_IND = PROMPT_GLASS_OOD
PROMPT_GLASS = PROMPT_GLASS_IND + PROMPT_GLASS_OOD


def get(data, guidance: str):
    """Assemble features, binary labels and text prompts for one guidance mode.

    Args:
        data: Mapping indexed with ``"train"``, ``"train_all"``, ``"valid"``
            and ``"test"``; each entry is unpacked as ``(features, attrs)``.
            ``attrs`` is indexed as ``attrs[:, columns]`` and the comparison
            result must offer ``.int()`` (assumes 2-D torch-style tensors —
            TODO confirm against the caller).
        guidance: String whose suffix (``"blond"`` or ``"glass"``) selects the
            prompt/word set and which attribute is "attended". A prefix of
            ``"ignore"`` swaps the attend/ignore names and labels while
            leaving the selected prompts and words untouched.

    Returns:
        A dict holding the arguments and every intermediate value under its
        historical name (features, sliced attrs, ``attr_normal``, per-attribute
        labels, prompts, words, and attend/ignore names and labels).

    Raises:
        ValueError: If ``guidance`` ends with neither ``"blond"`` nor
            ``"glass"``.
    """
    train_features, train_attrs = data["train"]
    train_all_features, train_all_attrs = data["train_all"]
    valid_features, valid_attrs = data["valid"]
    test_features, test_attrs = data["test"]

    # Keep only the two attribute columns of interest; afterwards column 0 is
    # blond hair and column 1 is eyeglasses.
    train_attrs, train_all_attrs, valid_attrs, test_attrs = (
        attrs[:, DATASET_ATTR_INDICES]
        for attrs in (train_attrs, train_all_attrs, valid_attrs, test_attrs)
    )

    # The first training sample supplies the reference ("normal") value of
    # each attribute; a label is 1 wherever a sample differs from it.
    attr_normal = train_attrs[0]

    def differs_from_normal(attrs, column):
        return (attrs[:, column] != attr_normal[column]).int()

    eval_splits = (train_all_attrs, valid_attrs, test_attrs)
    train_all_blond_labels, valid_blond_labels, test_blond_labels = (
        differs_from_normal(attrs, 0) for attrs in eval_splits
    )
    train_all_glass_labels, valid_glass_labels, test_glass_labels = (
        differs_from_normal(attrs, 1) for attrs in eval_splits
    )

    blond_labels = (train_all_blond_labels, valid_blond_labels, test_blond_labels)
    glass_labels = (train_all_glass_labels, valid_glass_labels, test_glass_labels)

    if guidance.endswith("blond"):
        attend_name, ignore_name = "blond", "glass"
        prompt_ind, prompt_ood = PROMPT_BLOND_IND, PROMPT_BLOND_OOD
        word_groups, ind_group_count = BLOND_WORDS, 1
        attend_labels, ignore_labels = blond_labels, glass_labels
    elif guidance.endswith("glass"):
        attend_name, ignore_name = "glass", "blond"
        prompt_ind, prompt_ood = PROMPT_GLASS_IND, PROMPT_GLASS_OOD
        word_groups, ind_group_count = GLASS_WORDS, 3
        attend_labels, ignore_labels = glass_labels, blond_labels
    else:
        raise ValueError(f"Invalid guidance: {guidance}")

    # Flatten the word groups on each side of the IND/OOD boundary.
    words_ind = [word for group in word_groups[:ind_group_count] for word in group]
    words_ood = [word for group in word_groups[ind_group_count:] for word in group]
    prompt = prompt_ind + prompt_ood
    words = words_ind + words_ood

    # "ignore..." guidance flips which attribute counts as attended; the
    # prompts and words chosen above stay as they are.
    if guidance.startswith("ignore"):
        attend_name, ignore_name = ignore_name, attend_name
        attend_labels, ignore_labels = ignore_labels, attend_labels

    train_all_attend_labels, valid_attend_labels, test_attend_labels = attend_labels
    train_all_ignore_labels, valid_ignore_labels, test_ignore_labels = ignore_labels

    return {
        "data": data,
        "guidance": guidance,
        "train_features": train_features,
        "train_attrs": train_attrs,
        "train_all_features": train_all_features,
        "train_all_attrs": train_all_attrs,
        "valid_features": valid_features,
        "valid_attrs": valid_attrs,
        "test_features": test_features,
        "test_attrs": test_attrs,
        "attr_normal": attr_normal,
        "train_all_blond_labels": train_all_blond_labels,
        "valid_blond_labels": valid_blond_labels,
        "test_blond_labels": test_blond_labels,
        "train_all_glass_labels": train_all_glass_labels,
        "valid_glass_labels": valid_glass_labels,
        "test_glass_labels": test_glass_labels,
        "prompt_ind": prompt_ind,
        "prompt_ood": prompt_ood,
        "words_ind": words_ind,
        "words_ood": words_ood,
        "prompt": prompt,
        "words": words,
        "attend_name": attend_name,
        "ignore_name": ignore_name,
        "train_all_attend_labels": train_all_attend_labels,
        "train_all_ignore_labels": train_all_ignore_labels,
        "valid_attend_labels": valid_attend_labels,
        "valid_ignore_labels": valid_ignore_labels,
        "test_attend_labels": test_attend_labels,
        "test_ignore_labels": test_ignore_labels,
    }
