import yaml
from google.generativeai.types.safety_types import SafetySettingDict, HarmCategory, HarmBlockThreshold

# Dataset name strings known to this module, kept in a fixed order.
# Every name listed here also appears as a key of DATASET_TYPE.
ALL_DATASETS = [
    "gsm8k",
    "gsm-ic",
    "svamp",
    "aqua",

    "csqa",
    "sqa",
    "socialiqa",
    "date",
    "ruin_names",
    "causal_judgement",

    "reasoning_about_colored_objects",
    "tracking_shuffled_objects",
    "logical_deduction",
    "coin_flips",
    "last_letters",
]

# One SafetySettingDict per harm category listed below, in this order; every
# entry uses the same BLOCK_ONLY_HIGH threshold.
# NOTE(review): presumably handed to the PaLM generation calls as their
# safety settings -- confirm against the call sites.
PALM_SAFETY_SETTINGS = [
    SafetySettingDict({
        "category": harm_category,
        "threshold": HarmBlockThreshold.BLOCK_ONLY_HIGH,
    })
    for harm_category in (
        HarmCategory.HARM_CATEGORY_UNSPECIFIED,
        HarmCategory.HARM_CATEGORY_DEROGATORY,
        HarmCategory.HARM_CATEGORY_TOXICITY,
        HarmCategory.HARM_CATEGORY_VIOLENCE,
        HarmCategory.HARM_CATEGORY_SEXUAL,
        HarmCategory.HARM_CATEGORY_MEDICAL,
        HarmCategory.HARM_CATEGORY_DANGEROUS,
    )
]

# Maps each dataset name to a short answer-type tag.
# NOTE(review): the tags presumably stand for numeric ("num"), multiple-choice
# ("mc"), yes/no ("yn") and free-word ("word") answers -- confirm with the
# code that consumes this mapping.
DATASET_TYPE = {
    "gsm8k": "num", "gsm-ic": "num", "svamp": "num", "aqua": "mc",

    "csqa": "mc", "sqa": "yn", "socialiqa": "mc",
    "date": "mc", "causal_judgement": "yn", "ruin_names": "mc",

    "reasoning_about_colored_objects": "mc", "logical_deduction": "mc",
    "coin_flips": "yn", "last_letters": "word",
    "tracking_shuffled_objects": "mc",
}