import numpy as np
import argparse
import sys
import os
import random
import pandas as pd
# import my package
from langchain.prompts import ChatPromptTemplate
# from dataset_llm_templates import langchain_templates
from template_oracle_features import langchain_templates_oracle_feature, parser_answer
import utils
from data import get_dataset_few_shot, label2num, XY2df
from utils_LLM import query_gpt_txt


def get_args(command=None):
    """Build the argument parser and parse the script options.

    Args:
        command: Optional string of whitespace-separated arguments. When it
            is None, the process command line is parsed instead.

    Returns:
        The parsed namespace with `data` (str, None when not given) and
        `query_times` (int, default 30).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--data', type=str)
    cli.add_argument('--query_times', type=int, default=30)
    # parse_args(None) reads sys.argv, exactly like calling it with no argument.
    argv = None if command is None else command.split()
    return cli.parse_args(argv)
def generate_data_gpt(df_support,info,dataset,cat_condidates,feature_name=None):
    """Build the prompt for one feature, send it to `query_gpt_txt`, and parse the reply.

    Args:
        df_support: DataFrame whose column names and dtypes are used to build
            the (row-less) frame handed to the template builder.
        info: Dataset metadata, passed through to the template and the parser.
        dataset: Dataset name, passed through to the template builder.
        cat_condidates: Mapping passed to the template builder and, as
            `cat_categories`, to `parser_answer`.
        feature_name: Feature the prompt is built for.

    Returns:
        A tuple `(answer_dict, answer)`: the parsed reply and the raw reply.
    """
    num_cnt = 5

    # Zero-row frame: only the column names and dtypes of the support set survive.
    schema_only = pd.DataFrame(
        {name: pd.Series(dtype=df_support[name].dtype) for name in df_support.columns}
    )
    prompt, task, feature_desc, format_instructions = langchain_templates_oracle_feature(
        schema_only, info, dataset, cat_condidates, feature_name, num_cnt
    )
    messages = prompt.format_messages(
        task=task, feature_desc=feature_desc, format_instructions=format_instructions
    )

    # The first line of the rendered message is used as the system prompt,
    # everything after the first newline as the user prompt.
    system_prompt, _, user_prompt = messages[0].content.partition('\n')

    answer = query_gpt_txt(text_list=[user_prompt], system_txt=system_prompt)[0]
    answer_dict = parser_answer(
        answer, info, cat_categories=cat_condidates, feature_name=feature_name, num_cnt=num_cnt
    )
    return answer_dict, answer
def get_file_prefix(args):
    """Return the file-name prefix for cached answers: `args.data` as a string."""
    return str(args.data)
    
def get_feature_proto(info,cat_condidates,dataname, cols, gpt_shot,example='False',description='True',num_cnt=5):
    """Load cached answers from disk and group their parsed values by feature and target.

    For every feature in `cols`, up to `gpt_shot` randomly chosen answer files
    from `./oracle_features/{dataname}/{feature_name}` are read and parsed
    with `parser_answer`.

    NOTE(review): the expected file name is
    `{dataname}-{example}-{description}_{num_cnt}_{shot}.txt`, whereas the
    generation script in this file writes `{prefix}_{i}.txt` with the prefix
    from `get_file_prefix` -- confirm which naming is intended.

    Args:
        info: Dataset metadata, passed through to `parser_answer`.
        cat_condidates: Passed to `parser_answer` as `cat_categories`.
        dataname: Dataset name; part of the directory and file names.
        cols: Iterable of feature names to load.
        gpt_shot: Maximum number of answer files used per feature.
        example: Must be the string 'False'.
        description: Must be the string 'True'.
        num_cnt: Part of the file name; also passed to `parser_answer`.

    Returns:
        dict mapping feature name -> {target: [feature_value_list, ...]},
        one list entry per answer file that contained that target.

    Raises:
        AssertionError: on unsupported `example`/`description` values or when
            a selected answer file is missing.
    """
    assert example == 'False', "example should be 'False'"
    assert description == 'True', "description should be 'True'"

    x_gpt = {}
    for feature_name in cols:
        by_target = x_gpt[feature_name] = {}
        feature_dir = f'./oracle_features/{dataname}/{feature_name}'

        # Candidate shot ids are 0..(number of directory entries - 1); this
        # presumes the files are numbered contiguously from 0.
        shot_ids = list(range(len(os.listdir(feature_dir))))
        random.shuffle(shot_ids)

        for shot in shot_ids[:gpt_shot]:
            answer_path = f'{feature_dir}/{dataname}-{example}-{description}_{num_cnt}_{shot}.txt'
            assert os.path.exists(answer_path), f'{answer_path} does not exist'
            with open(answer_path, 'r') as f:
                answer = f.read()
            answer_dict = parser_answer(answer, info, cat_categories=cat_condidates, feature_name=feature_name, num_cnt=num_cnt)
            for target, feature_value_list in answer_dict.items():
                by_target.setdefault(target, []).append(feature_value_list)
    return x_gpt

if __name__ == '__main__':
    # Script entry point: for every numerical and categorical feature of the
    # chosen dataset, query the model `query_times` times and store each raw
    # answer in oracle_features/{data}/{feature}/{prefix}_{i}.txt.
    # Files that already exist are skipped, so an interrupted run can be resumed.
    args = get_args()
    _SHOT = 1
    _SEED = 0
    _DATA = args.data
    query_times = args.query_times
    # Number of attempts per answer file before giving up on it.
    max_tries = 10

    df, info, X, y = get_dataset_few_shot(_DATA, _SHOT, _SEED)
    X, y = XY2df(X, y, info)
    N_cols = info['N_cols']

    # remove the c_cols that only have one unique value
    C_cols = [col for col in info['C_cols'] if len(df[col].unique()) > 1]
    info['C_cols'] = C_cols

    target = info['target']
    X_n_support = X['support_x_n'] if _SHOT >= 1 else None
    X_c_support = X['support_x_c'] if _SHOT >= 1 else None
    y_support = y['support_y'] if _SHOT >= 1 else None

    # Observed values of every categorical column and of the target.
    cat_condidates = {col: list(df[col].unique()) for col in C_cols}
    cat_condidates[target] = list(df[target].unique())
    df_support = pd.concat([X_n_support, X_c_support, y_support], axis=1)

    # The prefix depends only on the arguments, so compute it once.
    prefix = get_file_prefix(args)

    for feature_name in N_cols + C_cols:
        feature_dir = f'oracle_features/{_DATA}/{feature_name}'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(feature_dir, exist_ok=True)

        for proto_cnt in range(query_times):
            out_path = f'{feature_dir}/{prefix}_{proto_cnt}.txt'
            if os.path.exists(out_path):
                print(f'{out_path} exists')
                continue

            for try_cnt in range(max_tries):
                print(f'{_DATA} {feature_name} {proto_cnt} {try_cnt}')
                try:
                    # The answer is only written once it has been parsed
                    # successfully, so unparsable replies are never cached.
                    feature_dict, feature_answer = generate_data_gpt(
                        df_support, info, _DATA, cat_condidates, feature_name=feature_name
                    )
                    print(feature_dict)
                    with open(out_path, 'w') as f:
                        f.write(feature_answer)
                    break
                except Exception as e:
                    # Best effort: report the failure and retry with a fresh query.
                    print(e)
            else:
                # Every attempt failed: say so explicitly instead of moving on
                # silently, since the missing file leaves a gap in the numbering.
                print(f'giving up on {out_path} after {max_tries} failed attempts')
  