import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse
from src.dataset_utils.dataset import get_dataset

from tqdm import tqdm

from src.llm.utils.gpt_utils import OpenAI_API, API_Manager, OpenAI_config


def main(args):
    """Run LLM prediction over a dataset and append the results to a prediction file.

    The run resumes where a previous one stopped: the number of predictions
    already read back from ``args.prediction_path`` is used as the index of
    the first sample to process in this run.

    Args:
        args: Parsed command-line namespace. Fields read directly here are
            ``dataset``, ``data_type``, ``api_name``, ``prediction_path`` and
            ``max_num_this_run`` (a negative value means "all remaining
            samples"). The whole namespace is also forwarded to
            ``dataset.get_result_by_LLM``.
    """
    dataset = get_dataset(args.dataset, args.data_type)

    dataset.show_statistic_information()
    api: OpenAI_API = API_Manager('src/llm/resources/ampi.json').load().get_api_by_id(args.api_name)
    dr_samples = dataset.get_samples_for_default_reasoning(slice(None))

    print('开始运行...')
    prediction_path = args.prediction_path
    if os.path.exists(prediction_path):
        prediction = dataset.read_prediction(prediction_path)
    else:
        parent_dir = os.path.dirname(prediction_path)
        # A bare filename has no parent directory; os.makedirs('') would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        prediction = []

    # NOTE(review): resuming assumes exactly one stored prediction per sample.
    # When get_result_by_LLM returns a list, several predictions are saved for
    # one sample and len(prediction) may overshoot the true sample offset — confirm.
    total = len(dr_samples)
    flag_start = len(prediction)
    num_this_run = args.max_num_this_run
    if num_this_run < 0:
        flag_end = total
    else:
        # Clamp so the reported range/count reflects what will actually run.
        flag_end = min(flag_start + num_this_run, total)
    # If stored predictions already cover every sample, run nothing rather
    # than report a negative count.
    flag_end = max(flag_end, flag_start)

    print('本次运行范围:', flag_start, '~', flag_end-1)
    print('本次运行个数为:', flag_end - flag_start, '\n')

    # `with` guarantees the handle is flushed and closed even if an LLM call raises.
    with open(prediction_path, 'a', encoding='utf-8') as prediction_IO:
        for data in tqdm(dr_samples[flag_start:flag_end]):
            handled_result = dataset.get_result_by_LLM(data=data, api=api, args=args)
            if handled_result is None:
                continue
            # Treat a single result and a list of results uniformly.
            results = handled_result if isinstance(handled_result, list) else [handled_result]
            for result in results:
                # The raw LLM response is not returned by get_result_by_LLM, so None is recorded.
                p = dataset.save_one_prediction(current_data=data, IO=prediction_IO, prediction=result,
                                                LLM_response=None)
                prediction.append(p)

def _build_arg_parser():
    """Create the command-line parser for the LLM label-prediction script."""
    cli = argparse.ArgumentParser(description="llm predict labels")

    # Dataset selection.
    cli.add_argument('--dataset', default='LabelClassification', type=str, help='The name of the dataset.')
    cli.add_argument('--data_dir_path', type=str, help='The path of dataset')
    cli.add_argument('--data_type', default='related_word_symbolic', type=str, help='The type of dataset')

    # Required I/O and prompt settings: output prediction path, prompt JSON
    # file path, prompt id, and the name of the API to call.
    cli.add_argument("--prediction_path", required=True, type=str, help="存放预测结果的路径")
    cli.add_argument("--prompt_path", required=True, type=str, help="存放提示的文件路径，提示是json文件")
    cli.add_argument("--prompt_id", required=True, type=str, help="提示的id")
    cli.add_argument("--api_name", required=True, type=str, help="调用api的名字")

    # Run controls: samples per run (-1 = all), extraction retry count,
    # and the sampling temperature passed to the LLM config.
    cli.add_argument("--max_num_this_run", help="本次运行跑几个样例，-1为所有", type=int, default=-1)
    cli.add_argument("--error_extraction_count", help="错误抽取尝试", type=int, default=3)
    cli.add_argument("--temperature", help="llm temperature", type=float, default=1e-10)
    return cli


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    # The temperature is set globally on the shared OpenAI config before any call is made.
    OpenAI_config['temperature'] = args.temperature
    print(args)
    main(args)