import numpy as np
import argparse
import sys
import os
import pandas as pd
import json
# Project-local imports

# from dataset_llm_templates import langchain_templates
from template_weights import langchain_templates_weights, parser_weights_answer
import utils
from data import get_dataset_few_shot
from utils_LLM import query_gpt_txt
def get_args(command=None):
    """Parse the command-line options for this script.

    Args:
        command: Optional whitespace-separated option string. When given it
            is split and parsed in place of the process arguments; when
            ``None`` the arguments are taken from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding ``data`` (str, ``None`` when the
        option is omitted) and ``query_times`` (int, default 10).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--data', type=str)
    cli.add_argument('--query_times', type=int, default=10)

    # parse_args(None) falls back to sys.argv[1:], same as calling it bare.
    argv = None if command is None else command.split()
    return cli.parse_args(argv)

def get_weights_gpt(info,dataset,cat_condidates):
    """Query the LLM once and return the parsed weights.

    The prompt pieces come from ``langchain_templates_weights``; the first
    formatted message is split at its first newline, with the first line
    sent as the system text and the remainder as the user text. The single
    answer returned by ``query_gpt_txt`` is then handed to
    ``parser_weights_answer``.

    Args:
        info: Dataset info object, forwarded to the template builder and
            to the answer parser.
        dataset: Dataset identifier, forwarded to the template builder.
        cat_condidates: Mapping of column name to candidate values,
            forwarded to the template builder and (as ``cat_categories``)
            to the answer parser.

    Returns:
        Whatever ``parser_weights_answer`` produces for the LLM answer.
    """
    template, fmt_instructions, task_text, feature_text = langchain_templates_weights(
        info, dataset, cat_condidates
    )
    rendered = template.format_messages(
        format_instructions=fmt_instructions,
        task_desc=task_text,
        feature_desc=feature_text,
    )

    # First line of the message is the system prompt; everything after the
    # first newline is the user prompt.
    system_prompt, _, user_prompt = rendered[0].content.partition('\n')

    replies = query_gpt_txt(text_list=[user_prompt], system_txt=system_prompt)
    return parser_weights_answer(replies[0], info, cat_categories=cat_condidates)
def get_file_prefix(args):
    """Return the output-file prefix: the string form of ``args.data``."""
    dataset_name = args.data
    return format(dataset_name)
if __name__ == '__main__':
    # Script entry point: load the dataset, collect the observed values of
    # every categorical column (and of the target), then query the LLM
    # `args.query_times` times, caching each parsed answer as
    # ./gpt_weights/<data>/<prefix>_<i>.json. Existing files are skipped so
    # the script can be re-run to fill in only the missing queries.
    args = get_args()
    _SHOT = 0
    _SEED = 0
    _DATA = args.data

    df, info, X, y = get_dataset_few_shot(_DATA, _SHOT, _SEED)
    C_cols = info['C_cols']
    target = info['target']

    # Candidate values per categorical column; the target is added last so
    # it overrides an identically named entry, as before.
    cat_condidates = {col: list(df[col].unique()) for col in C_cols}
    cat_condidates[target] = list(df[target].unique())

    out_dir = f'./gpt_weights/{_DATA}'
    max_tries = 5
    for query_idx in range(args.query_times):
        out_path = f'{out_dir}/{get_file_prefix(args)}_{query_idx}.json'
        if os.path.exists(out_path):
            continue
        for try_cnt in range(max_tries):
            try:
                weights = get_weights_gpt(info, _DATA, cat_condidates)
                # Serialize before touching the file system: a failure here
                # must not leave a truncated JSON file behind, because later
                # runs would treat that file as a finished result and skip it.
                payload = json.dumps(weights, indent=4)
                os.makedirs(out_dir, exist_ok=True)
                with open(out_path, 'w') as out_file:
                    out_file.write(payload)
                break
            except Exception as exc:
                # Best-effort retry, but let KeyboardInterrupt/SystemExit
                # propagate and report what actually went wrong.
                print(f'error: {exc!r}')
        else:
            print(f'giving up on query {query_idx} after {max_tries} failed attempts')
    