import json
import os
from datetime import datetime
from tqdm import tqdm
from vllm import LLM, SamplingParams
from prompts import icl_sys_prompt_score, icl_user_prompt_score, icl_ass_prompt, last_prompt

# Set the HIP (ROCm) device-visibility variable so this process uses GPUs 0-3.
# NOTE(review): this runs after `vllm` has already been imported above;
# presumably it still takes effect because no device has been initialised
# yet at this point — confirm, or move it above the vllm import.
os.environ["HIP_VISIBLE_DEVICES"] = "0,1,2,3"

def get_messages_for_sample(sample_info):
    """Build the chat conversation for a single retrieval sample.

    The conversation starts with one in-context example (system, user and
    assistant turns taken from the imported prompt constants) and ends with
    a user turn that carries this sample's triples and question.

    Args:
        sample_info: Mapping describing one sample. ``'question'`` is
            required. ``'scored_triples'`` is optional; when present and
            non-empty, each entry must unpack into exactly four items
            (head, relation, tail, confidence).

    Returns:
        A list of four ``{"role": ..., "content": ...}`` dicts.

    Raises:
        KeyError: If ``sample_info`` has no ``'question'`` key.
        ValueError: If a triple does not unpack into exactly four items.
    """
    # A missing key, None, or an empty list all mean "no triples".
    rendered_lines = []
    for head, relation, tail, confidence in sample_info.get('scored_triples') or []:
        rendered_lines.append(f"({head}, {relation}, {tail})[Confidence: {confidence}]")
    triples_str = "\n".join(rendered_lines)

    final_user_turn = {
        "role": "user",
        "content": f"Input Triplets:\n{triples_str}\nQuestion: {sample_info['question']}\n{last_prompt}",
    }
    in_context_example = [
        {"role": "system", "content": icl_sys_prompt_score},
        {"role": "user", "content": icl_user_prompt_score},
        {"role": "assistant", "content": icl_ass_prompt},
    ]
    return in_context_example + [final_user_turn]

if __name__ == "__main__":
    # Script entry point: load the retrieval results, run the chat model on
    # every sample for each sampling-parameter combination, keep the lines of
    # each response that start with "ans:", and write one JSON file per
    # combination.

    # Imported here so the script block carries its own dependency.
    from itertools import product

    # Run timestamp used in the output file name; %f yields microseconds and
    # the [:-3] slice trims it to milliseconds.
    ts = datetime.now().strftime('%b%d-%H:%M:%S:%f')[:-3]
    model_path = "reason/llama/Meta-Llama-3.1-8B-Instruct"
    file_path = 'retrieveresult/cwq/retrieval_result_100.json'
    # The input is iterated with .items() below, so it must be a JSON object
    # mapping sample id -> sample info.
    with open(file_path, 'r', encoding='utf-8') as f:
        retrieval_data = json.load(f)

    # The model is loaded once and reused for every parameter combination.
    llm = LLM(
        model=model_path,
        tensor_parallel_size=4,
        max_seq_len_to_capture=8192,
    )

    # Sampling grid: every combination of these three lists is evaluated.
    temperatures = [0]
    top_p_values = [1.0]
    frequency_penalties = [0]

    # product() iterates in the same order as three nested for-loops would.
    for temp, top_p, f_p in product(temperatures, top_p_values, frequency_penalties):
        output_file_path = f'outputs/cwq/xxx_temp{temp}_top{top_p}_f{f_p}_{ts}_scored.json'

        # Create the output directory *before* generating, so a bad or
        # unwritable path fails immediately instead of after the whole
        # (expensive) generation pass has finished.
        output_dir = os.path.dirname(output_file_path)
        os.makedirs(output_dir, exist_ok=True)

        sampling_params = SamplingParams(
            temperature=temp,
            max_tokens=4000,
            frequency_penalty=f_p,
            top_p=top_p,
        )

        processed_data = {}

        for sample_id, sample_info in tqdm(retrieval_data.items(), desc=f"Generating responses (temp={temp}, top_p={top_p}, f_p={f_p})"):
            messages = get_messages_for_sample(sample_info)
            output = llm.chat(messages=messages, sampling_params=sampling_params)
            # One conversation is submitted per call; take its first completion.
            text = output[0].outputs[0].text.strip()

            # Keep only the lines that start with "ans:" (case-insensitive).
            answers = [line.strip() for line in text.split('\n') if line.strip().lower().startswith('ans:')]
            processed_data[sample_id] = {
                'question': sample_info.get('question'),
                'a_entity': sample_info.get('a_entity'),
                'answers': answers,
            }

        # Saving is best-effort: a failure is reported and the sweep moves on
        # to the next parameter combination.
        try:
            with open(output_file_path, 'w', encoding='utf-8') as f:
                json.dump(processed_data, f, ensure_ascii=False, indent=4)
            print(f"Parameter combination temp={temp}, top_p={top_p}, f_p={f_p} The results have been saved to {output_file_path}")
        except Exception as e:
            print(f"Save parameter combination temp={temp}, top_p={top_p}, f_p={f_p} An error occurred while accessing the data:{e}")
