DATASET_NAME = "color_mnist"

# One boolean flag per value of each attribute:
#   "label": digits 0-4 are flagged True, digits 5-9 False.
#   "color": only "red" is flagged True; "green" and "blue" are False.
# NOTE(review): True presumably marks the values treated as in-distribution
# (it lines up with the [:5] / [:1] slices used for the *_IND prompt tables
# in this file) -- confirm against the code that consumes this config.
DATASET_CONFIG = {
    "label": {digit: digit < 5 for digit in range(10)},
    "color": {name: name == "red" for name in ("red", "green", "blue")},
}

# Prompt templates. The "{}" placeholder is filled with a single number or
# color word through str.format.
# NOTE(review): "a image of {} letter" differs from "an image of {}" above it.
# The strings are emitted at runtime, so the wording is kept verbatim here --
# confirm whether the inconsistency is intentional before changing it.
TEMPLATES = [
    # Bare value and "a <medium> of <value>" phrasings.
    "{}",
    "a number {}",
    "an image of {}",
    "a picture of {}",
    "a photo of {}",
    "a drawing of {}",
    "a sketch of {}",
    "a figure of {}",
    # The same phrasings with a trailing "letter".
    "{} letter",
    "a number {} letter",
    "a image of {} letter",
    "a picture of {} letter",
    "a photo of {} letter",
    "a drawing of {} letter",
    "a sketch of {} letter",
    "a figure of {} letter",
    # Remaining one-off phrasings.
    "a letter of {}",
    "a letter of number {}",
    "a photo of the number: '{}'",
]


def _expand_prompts(word_groups):
    """Return one list of filled-in TEMPLATES for every word in *word_groups*.

    The groups are flattened in order, so the result holds one entry per
    word and each entry contains ``len(TEMPLATES)`` formatted strings.
    """
    return [
        [template.format(word) for template in TEMPLATES]
        for group in word_groups
        for word in group
    ]


# Each entry pairs a number's English name with its decimal string,
# covering 0 through 20: ["zero", "0"], ["one", "1"], ..., ["twenty", "20"].
NUMBER_WORDS = [
    [name, str(value)]
    for value, name in enumerate(
        (
            "zero", "one", "two", "three", "four",
            "five", "six", "seven", "eight", "nine",
            "ten", "eleven", "twelve", "thirteen", "fourteen",
            "fifteen", "sixteen", "seventeen", "eighteen", "nineteen",
            "twenty",
        )
    )
]

# Prompt tables for numbers: 0-4, 5-9, and everything from 5 upward (5-20).
PROMPT_NUMBER_IND = _expand_prompts(NUMBER_WORDS[:5])
PROMPT_NUMBER_OOD = _expand_prompts(NUMBER_WORDS[5:10])
PROMPT_NUMBER_NOT_IND = _expand_prompts(NUMBER_WORDS[5:])
# Concatenation (not a fresh expansion) so the inner lists stay shared with
# the two tables above.
PROMPT_NUMBER = PROMPT_NUMBER_IND + PROMPT_NUMBER_NOT_IND

# Each row lists a base color name followed by related shade names.
COLOR_WORDS = [
    ["red", "ruby", "scarlet", "crimson", "maroon", "carmine", "vermilion"],
    ["green", "lime", "olive", "jade"],
    ["blue", "azure", "sky blue", "navy"],
    ["yellow", "gold", "amber", "lemon"],
    ["orange", "titian", "coral"],
    ["purple", "violet", "lavender", "lilac", "mauve", "plum"],
    ["pink", "rose", "magenta", "fuchsia"],
    ["brown", "tan", "sepia", "beige"],
    ["black", "ebony", "sable", "jet"],
    ["white", "ivory", "snow", "chalk", "pearl", "cream"],
]

# Prompt tables for colors: the red row, the green and blue rows, and every
# row after red.
PROMPT_COLOR_IND = _expand_prompts(COLOR_WORDS[:1])
PROMPT_COLOR_OOD = _expand_prompts(COLOR_WORDS[1:3])
PROMPT_COLOR_NOT_IND = _expand_prompts(COLOR_WORDS[1:])
PROMPT_COLOR = PROMPT_COLOR_IND + PROMPT_COLOR_NOT_IND


def get(data, guidance: str):
    """Assemble the features, labels and text prompts for one guidance mode.

    Args:
        data: Mapping with the keys ``"train"``, ``"train_all"``, ``"valid"``
            and ``"test"``; each value unpacks into a ``(features, attrs)``
            pair. ``attrs`` is indexed as ``attrs[:, 0]`` (number) and
            ``attrs[:, 1]`` (color) and combined with ``1 - ...`` -- assumes
            a 2-D NumPy/torch-style array; confirm against the caller.
        guidance: Mode string. Its suffix (``"number"`` or ``"color"``)
            selects the prompt tables and word lists that are returned. A
            leading ``"ignore"`` swaps the attend/ignore names and labels
            and flattens ``prompt``, ``prompt_ind``, ``prompt_half`` and
            ``prompt_exact`` into flat lists of strings.

    Returns:
        dict: Every named value computed here, keyed by its name. The keys
        are listed explicitly (instead of ``locals()``) so that adding a
        temporary variable can never silently change the return contract.

    Raises:
        ValueError: If ``guidance`` ends with neither ``"number"`` nor
            ``"color"``.
    """
    train_features, train_attrs = data["train"]
    train_all_features, train_all_attrs = data["train_all"]
    valid_features, valid_attrs = data["valid"]
    test_features, test_attrs = data["test"]

    # Labels are the complement (1 - value) of the raw attribute columns:
    # column 0 is the number attribute, column 1 the color attribute.
    # No label arrays are derived for the plain "train" split.
    train_all_number_labels = 1 - train_all_attrs[:, 0]
    valid_number_labels = 1 - valid_attrs[:, 0]
    test_number_labels = 1 - test_attrs[:, 0]

    train_all_color_labels = 1 - train_all_attrs[:, 1]
    valid_color_labels = 1 - valid_attrs[:, 1]
    test_color_labels = 1 - test_attrs[:, 1]

    if guidance.endswith("number"):
        prompt_ind = PROMPT_NUMBER_IND
        prompt_ood = PROMPT_NUMBER_NOT_IND
        words_ind = [v for w in NUMBER_WORDS[:5] for v in w]
        words_ood = [v for w in NUMBER_WORDS[5:] for v in w]
        prompt = PROMPT_NUMBER
        words = [v for w in NUMBER_WORDS for v in w]

        attend_name = "number"
        ignore_name = "color"
        train_all_attend_labels = train_all_number_labels
        train_all_ignore_labels = train_all_color_labels
        valid_attend_labels = valid_number_labels
        valid_ignore_labels = valid_color_labels
        test_attend_labels = test_number_labels
        test_ignore_labels = test_color_labels

        # "half" keeps the first three fifths of the 5-9 prompt table;
        # "exact" keeps all of it.
        prompt_half = PROMPT_NUMBER_IND + PROMPT_NUMBER_OOD[:len(PROMPT_NUMBER_OOD) // 5 * 3]
        prompt_exact = PROMPT_NUMBER_IND + PROMPT_NUMBER_OOD

    elif guidance.endswith("color"):
        prompt_ind = PROMPT_COLOR_IND
        prompt_ood = PROMPT_COLOR_NOT_IND
        words_ind = [v for w in COLOR_WORDS[:1] for v in w]
        words_ood = [v for w in COLOR_WORDS[1:] for v in w]
        prompt = PROMPT_COLOR
        words = [v for w in COLOR_WORDS for v in w]

        attend_name = "color"
        ignore_name = "number"
        train_all_attend_labels = train_all_color_labels
        train_all_ignore_labels = train_all_number_labels
        valid_attend_labels = valid_color_labels
        valid_ignore_labels = valid_number_labels
        test_attend_labels = test_color_labels
        test_ignore_labels = test_number_labels

        # "half" keeps the first half of the green/blue prompt table;
        # "exact" keeps all of it.
        prompt_half = PROMPT_COLOR_IND + PROMPT_COLOR_OOD[:len(PROMPT_COLOR_OOD) // 2]
        prompt_exact = PROMPT_COLOR_IND + PROMPT_COLOR_OOD

    else:
        raise ValueError(f"Invalid guidance: {guidance}")

    if guidance.startswith("ignore"):  # Swap target
        # The prompts describe the attribute to ignore, so the attend/ignore
        # roles are exchanged and the per-word prompt lists are flattened
        # into plain lists of strings. prompt_ood and the words* lists are
        # left untouched.
        prompt = [v for w in prompt for v in w]
        attend_name, ignore_name = ignore_name, attend_name
        train_all_attend_labels, train_all_ignore_labels = train_all_ignore_labels, train_all_attend_labels
        valid_attend_labels, valid_ignore_labels = valid_ignore_labels, valid_attend_labels
        test_attend_labels, test_ignore_labels = test_ignore_labels, test_attend_labels

        prompt_ind = [v for w in prompt_ind for v in w]
        prompt_half = [v for w in prompt_half for v in w]
        prompt_exact = [v for w in prompt_exact for v in w]

    return {
        "data": data,
        "guidance": guidance,
        "train_features": train_features,
        "train_attrs": train_attrs,
        "train_all_features": train_all_features,
        "train_all_attrs": train_all_attrs,
        "valid_features": valid_features,
        "valid_attrs": valid_attrs,
        "test_features": test_features,
        "test_attrs": test_attrs,
        "train_all_number_labels": train_all_number_labels,
        "valid_number_labels": valid_number_labels,
        "test_number_labels": test_number_labels,
        "train_all_color_labels": train_all_color_labels,
        "valid_color_labels": valid_color_labels,
        "test_color_labels": test_color_labels,
        "prompt_ind": prompt_ind,
        "prompt_ood": prompt_ood,
        "words_ind": words_ind,
        "words_ood": words_ood,
        "prompt": prompt,
        "words": words,
        "attend_name": attend_name,
        "ignore_name": ignore_name,
        "train_all_attend_labels": train_all_attend_labels,
        "train_all_ignore_labels": train_all_ignore_labels,
        "valid_attend_labels": valid_attend_labels,
        "valid_ignore_labels": valid_ignore_labels,
        "test_attend_labels": test_attend_labels,
        "test_ignore_labels": test_ignore_labels,
        "prompt_half": prompt_half,
        "prompt_exact": prompt_exact,
    }
