import concurrent
import json
import os.path
from openai import OpenAI
import httpx
from tqdm import tqdm


def prepare_input_data(data, include_reasoning=False):
    """Render one reasoning record as a plain-text prompt.

    The text consists of a numbered "Facts:" section (from
    ``data['facts-tuned-nl']``), a numbered "Rules:" section (from
    ``data['rules-tuned-nl']``) and a "Query:" line built from the
    ``(entity, attribute)`` pair stored in ``data['query']``. Missing list
    keys yield empty sections; a missing query renders as ``None``.

    When *include_reasoning* is true (used for few-shot examples), the worked
    reasoning from ``data['reasoning_process_nl']`` and an answer of the form
    ``\\boxed{...}`` looked up in ``data['values'][entity][attribute]`` are
    appended.
    """
    def numbered(entries):
        # "1. first\n2. second\n..." -- numbering starts at 1.
        return "".join(f"{pos}. {entry}\n" for pos, entry in enumerate(entries, start=1))

    facts_block = "Facts:\n" + numbered(data.get('facts-tuned-nl', []))
    rules_block = "Rules:\n" + numbered(data.get('rules-tuned-nl', []))
    entity, attribute = data.get('query', (None, None))
    pieces = [
        facts_block,
        rules_block,
        f"Query:\nWhat is the value of {entity}'s {attribute}?\n",
    ]

    if include_reasoning:
        pieces.append("After a detailed explanation, you would conclude as follows.\n")
        pieces.append("Reasoning:\n" + (data.get("reasoning_process_nl", "") + "\n"))
        expected = data.get("values", {}).get(entity, {}).get(attribute, "")
        pieces.append(f"Answer: \\boxed{{{expected}}}\n")

    return "".join(pieces)


def load_data(data_path):
    """Load a JSON Lines file into a list of parsed records.

    Each non-blank line is parsed with ``json.loads``. Whitespace-only lines
    (e.g. an extra empty line at the end of the file) are skipped instead of
    raising ``json.JSONDecodeError``.

    Args:
        data_path: Path of the ``.jsonl`` file, read as UTF-8.

    Returns:
        A list with one parsed object per non-blank line, in file order.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        return [json.loads(text) for text in stripped_lines if text]


def load_prompt(data_name, example_num, prompt_dir="../prompt"):
    """Build the system prompt, optionally followed by few-shot examples.

    Reads ``{prompt_dir}/instruction.txt``. If *example_num* > 0, the first
    *example_num* records of ``{prompt_dir}/{data_name}.jsonl`` are rendered
    with ``prepare_input_data(..., include_reasoning=True)`` and appended as
    worked examples.

    Args:
        data_name: Dataset name; selects the few-shot example file.
        example_num: Number of few-shot examples; values <= 0 mean none.
        prompt_dir: Directory holding the prompt files. The default is
            relative to the current working directory (original behavior).

    Returns:
        The full instruction string.

    Raises:
        ValueError: If fewer than *example_num* examples are available.
            (This replaces an ``assert``, which is disabled under ``python -O``
            and would have silently used fewer examples.)
    """
    with open(os.path.join(prompt_dir, "instruction.txt"), 'r', encoding='utf-8') as f:
        instruction = f.read()
    if example_num <= 0:
        return instruction

    few_shot_data = load_data(os.path.join(prompt_dir, f"{data_name}.jsonl"))
    if len(few_shot_data) < example_num:
        raise ValueError(f"Not enough examples for {data_name}")

    examples = [prepare_input_data(record, include_reasoning=True)
                for record in few_shot_data[:example_num]]
    example_str = "Here are some examples:\n"
    example_str += "\n".join(examples)
    example_str += "\nPlease follow the same format to conclude the answer at last:\n"
    return instruction + example_str


def load_datasets(data_name, data_dir="../data"):
    """Load all records of the evaluation dataset *data_name*.

    Reads ``{data_dir}/{data_name}.jsonl`` via ``load_data``.

    Args:
        data_name: Dataset name (file stem of the ``.jsonl`` file).
        data_dir: Directory containing the dataset files. The default is
            relative to the current working directory (original behavior).

    Returns:
        A list of parsed records.
    """
    return load_data(os.path.join(data_dir, f"{data_name}.jsonl"))


def call_api(model, client, data):
    """Send one non-streaming chat-completion request and tag it with its id.

    Args:
        model: Model name forwarded to the API.
        client: An OpenAI-style client exposing ``chat.completions.create``.
        data: An ``(item_id, messages)`` pair.

    Returns:
        ``(item_id, response)`` where *response* is whatever
        ``client.chat.completions.create`` returned.
    """
    item_id, messages = data
    completions = client.chat.completions
    reply = completions.create(
        model=model,
        messages=messages,
        stream=False,
        timeout=1200,  # per-request timeout in seconds
    )
    return item_id, reply


def process_questions_multithreaded(model, client, concurrence, datas):
    """Run ``call_api`` for every item on a thread pool and collect the replies.

    Args:
        model: Model name forwarded to ``call_api``.
        client: API client forwarded to ``call_api``.
        concurrence: Maximum number of worker threads.
        datas: Iterable of ``(item_id, messages)`` pairs.

    Returns:
        A list of ``(item_id, response)`` tuples in completion order (not
        input order). Items whose request raised are reported on stdout,
        together with their id, and omitted from the result.
    """
    # The file-level ``import concurrent`` does not load the ``futures``
    # submodule; import it explicitly instead of relying on another library
    # having imported it first.
    import concurrent.futures

    answers = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrence) as executor:
        # Remember which item each future belongs to so failures can be traced.
        future_to_id = {executor.submit(call_api, model, client, data): data[0]
                        for data in datas}

        for future in tqdm(concurrent.futures.as_completed(future_to_id),
                           total=len(future_to_id)):
            try:
                answers.append(future.result())
            except Exception as e:  # best-effort: failed items are retried on a later run
                print(f"An error occurred for id {future_to_id[future]}: {e}")

    return answers


def pack_message(instr, question):
    """Wrap a system instruction and a user question as a chat message list."""
    system_message = dict(role="system", content=instr)
    user_message = dict(role="user", content=question)
    return [system_message, user_message]


def _mask_secret(secret, visible=4):
    """Return *secret* with all but its last *visible* characters replaced by '*'."""
    secret = "" if secret is None else str(secret)
    if len(secret) <= visible:
        return "*" * len(secret)
    return "*" * (len(secret) - visible) + secret[-visible:]


def _response_to_jsonable(response):
    """Convert an SDK response object into plain data that ``json.dump`` accepts."""
    # ``model_dump`` is the current pydantic API; ``dict`` is the deprecated
    # pydantic-v1 spelling, kept as a fallback for older objects.
    if hasattr(response, 'model_dump'):
        return response.model_dump()
    if hasattr(response, 'dict'):
        return response.dict()
    return response


def _dump_json_atomic(obj, path):
    """Write *obj* as JSON to *path* through a temp file so a crash never truncates it."""
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)
    os.replace(tmp_path, path)


def normal_api_request(key, url, model, concurrence, shot_num=0, MAX_PROCESS_NUM=1,
                       all_dataset_name=('el-en', 'el-hn', 'hl-en', 'hl-hn'),
                       results_dir='results'):
    """Query *model* on each dataset and cache the raw responses on disk.

    Responses are stored in ``{results_dir}/{model}-shot{shot_num}.raw.json``
    as ``{dataset_name: {str(item_id): response_dict}}``. If that file already
    exists, it is loaded and items whose id is already present are skipped.
    An interrupted run can therefore be resumed by calling this again. Items
    whose request failed are not recorded, so they are retried on the next call.

    Args:
        key: API key for the OpenAI client (only a masked form is printed).
        url: Base URL of the OpenAI-compatible endpoint.
        model: Model name; also used in the output file name.
        concurrence: Number of worker threads issuing requests.
        shot_num: Number of few-shot examples in the system prompt.
        MAX_PROCESS_NUM: If > 0, only the first this-many items of each dataset
            are considered; <= 0 means all items.
        all_dataset_name: Names of the datasets to process.
        results_dir: Directory for the output file.

    Returns:
        Path of the raw-results JSON file.
    """
    output_file = os.path.join(results_dir, f"{model}-shot{shot_num}.raw.json")
    # Never echo the full API key to stdout / logs.
    print(_mask_secret(key), url, model, output_file)
    client = OpenAI(
        api_key=key,
        base_url=url,
        timeout=1200
    )

    if os.path.exists(output_file):
        with open(output_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
    else:
        out_dir = os.path.dirname(output_file)
        if out_dir:  # '' when results_dir is empty; makedirs('') would raise
            os.makedirs(out_dir, exist_ok=True)
        results = {}

    total_failed = 0
    for data_name in tqdm(all_dataset_name, desc="Processing datasets"):
        data = load_datasets(data_name)
        if MAX_PROCESS_NUM > 0:
            data = data[:MAX_PROCESS_NUM]

        dataset_results = results.setdefault(data_name, {})
        processed_ids = set(dataset_results)
        data = [item for item in data if str(item['id']) not in processed_ids]
        print(
            f"For dataset {data_name}: Processed {len(processed_ids)} items, remaining {len(data)} items to process.")

        if not data:
            print(f"Dataset {data_name} already fully processed, skipping...")
            continue

        prompt = load_prompt(data_name, shot_num)
        datas_to_process = [
            (item['id'], pack_message(prompt, prepare_input_data(item, include_reasoning=False)))
            for item in data
        ]

        raw_results = process_questions_multithreaded(
            model, client, concurrence, datas_to_process)
        total_failed += len(datas_to_process) - len(raw_results)

        for dataid, result_response in raw_results:
            dataset_results[str(dataid)] = _response_to_jsonable(result_response)

        # Persist after every dataset so progress survives a later crash.
        _dump_json_atomic(results, output_file)

        print(f"Completed processing dataset {data_name}")

    if total_failed:
        print(f"Finished with {total_failed} failed request(s); rerun to retry them.")
    else:
        print("All datasets processed successfully!")
    return output_file


def _extract_llm_output(response):
    """Return the first choice's message content, or ``str(response)`` as a fallback."""
    if isinstance(response, dict):
        choices = response.get('choices')
        if choices:  # missing, None or [] all fall back to the string form
            return choices[0]['message']['content']
    return str(response)


def raw2text(raw_file, dataset_names):
    """Extract the model's text answers from a raw-results file.

    Reads the ``*.raw.json`` file produced by ``normal_api_request``. For every
    item of each dataset in *dataset_names* that has a stored response, it
    attaches the response text as ``item['llm_output']``. The mapping
    ``{dataset_name: [item, ...]}`` is written to *raw_file* with
    ``.raw.json`` replaced by ``.json``. Items without a stored response are
    omitted from the output.

    Args:
        raw_file: Path of the raw-results JSON file.
        dataset_names: Names of the datasets to extract.

    Returns:
        True if every item of every dataset had a stored response, else False.

    Raises:
        FileNotFoundError: If *raw_file* does not exist.
        ValueError: If *raw_file* has no ``.raw.json`` component, because the
            output path would then be the input path and overwrite it.
    """
    if not os.path.exists(raw_file):
        raise FileNotFoundError(f'{raw_file} not found')
    output_file = raw_file.replace('.raw.json', '.json')
    if output_file == raw_file:
        raise ValueError(f'{raw_file} has no .raw.json suffix; refusing to overwrite it')

    with open(raw_file, 'r', encoding='utf-8') as f:
        raw_results = json.load(f)

    success = True
    results = {}
    for data_name in dataset_names:
        responses = raw_results.get(data_name, {})
        extracted = []
        for item in load_datasets(data_name):
            id_str = str(item['id'])
            if id_str not in responses:
                success = False
                continue
            item['llm_output'] = _extract_llm_output(responses[id_str])
            extracted.append(item)
        results[data_name] = extracted

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    return success


if __name__ == '__main__':
    # Placeholders: fill in a real key, endpoint URL and model name before running.
    dataset_names = ["el-en", "el-hn", "hl-en", "hl-hn"]
    for shots in (0, 3):
        raw = normal_api_request("api-xxx", "url-xxx", "model-xxx",
                                 concurrence=8, shot_num=shots, MAX_PROCESS_NUM=-1)
        success = raw2text(raw, dataset_names)
        # raw2text reports whether every item had a stored response; don't ignore it.
        if not success:
            print(f"Warning: some items in {raw} have no response; rerun to fill them in.")
