import openai

from azure.identity import AzureCliCredential

from data_utils import pre_process_dataset
from label_cost_data import label_cost_data_main


# Run configuration shared by every (prompt type, dataset) job in main().
SEED: int = 0             # forwarded to both pre_process_dataset and label_cost_data_main
DATA_SIZE: int = 1_000    # size argument forwarded to pre_process_dataset
# Forwarded to label_cost_data_main. The original note called this the "number of
# datapoints to use"; NOTE(review): confirm how it relates to DATA_SIZE.
NUM_PERTURBS: int = 999
TEMP: float = 1.0         # LLM temperature (randomness of the LLM output)
# Original note: "number of extra connections per node". NOTE(review): the name
# suggests a multiplier rather than a count; confirm in label_cost_data_main.
CONNECTION_MUL: int = 10


def main(prompt_types=('normal', 'custom'),
         datasets=('heloc', 'adult', 'german_credit')):
    """Run the label/cost caching step for every (prompt type, dataset) pair.

    For each prompt type, each dataset is pre-processed with the module-level
    ``SEED`` and ``DATA_SIZE``. The resulting features and normalisation
    statistics are then passed to ``label_cost_data_main`` together with the
    ``NUM_PERTURBS``, ``TEMP`` and ``CONNECTION_MUL`` settings.

    Args:
        prompt_types: Prompt variants to run, forwarded as the last argument
            of ``label_cost_data_main``. Defaults to the previously hard-coded
            ``('normal', 'custom')``.
        datasets: Dataset names forwarded to ``pre_process_dataset`` and
            ``label_cost_data_main``. Defaults to the previously hard-coded
            ``('heloc', 'adult', 'german_credit')``.
    """
    for prompt_type in prompt_types:
        for dataset in datasets:
            # Each dataset is pre-processed once per prompt type.
            # NOTE(review): this could be cached across prompt types only if
            # pre_process_dataset is deterministic and label_cost_data_main
            # neither mutates X/means/std nor relies on RNG state left by
            # pre-processing. Confirm before hoisting.
            # The labels and encoder are not needed for caching.
            X, _y, means, std, _encoder = pre_process_dataset(SEED, dataset, DATA_SIZE)
            print('Starting caching...')
            label_cost_data_main(SEED, dataset, X, means, std, NUM_PERTURBS,
                                 TEMP, CONNECTION_MUL, prompt_type)
            
            
def _configure_openai():
    """Authenticate through the Azure CLI and configure the module-level openai client.

    A token is obtained from ``AzureCliCredential`` and installed as
    ``openai.api_key``. The endpoint settings are placeholders that must be
    filled in before running.
    """
    # NOTE(review): "???" is a placeholder scope. For Azure OpenAI it is usually
    # "https://cognitiveservices.azure.com/.default"; confirm for this deployment.
    credential = AzureCliCredential()
    openai.api_key = credential.get_token("???").token
    # NOTE(review): within this file the token is fetched once and never
    # refreshed. Azure AD tokens expire, so a long caching run may start
    # failing authentication part-way through. Confirm.
    # NOTE(review): module-level api_base/api_type look like the openai<1.0
    # interface. With an Azure AD bearer token, that interface expects
    # api_type "azure_ad" rather than "azure". Confirm the pinned version.
    openai.api_base = "???"     # required
    openai.api_type = "???"     # required
    openai.api_version = "???"  # change as needed


if __name__ == '__main__':
    _configure_openai()
    main()

