import json, os, re, requests, time, random, csv, sys, argparse
from openai import OpenAI
#from bert_score import score as bert_score
from sklearn.metrics import precision_score, recall_score, f1_score
from concurrent.futures import ThreadPoolExecutor


# Prompt sent to the extraction LLM. The `{model_res}` placeholder is filled
# with the raw model response via `str.format` (see `extract_answer_list`);
# the reply is expected to be a JSON list of strings such as ["A", "B"].
PROMPT_TEMPLATE = """
The following paragraph contains the reasoning process and answer(s) to a specific question. Please read and understand the paragraph, and extract the answer(s) from it.

### Instructions:
1.Extract the answer(s) only from the paragraph. Do not fabricate or infer any information that is not explicitly stated.
2.Output the extracted answer(s) in a list format, using English quotation marks and commas.
Your output format must be: ["<Answer1>", "<Answer2>", "<Answer3>"]. If the paragraph does not mention any answer, please output an empty list: "[]". 

### Paragraph:
{model_res}

### Attention:
# ** You only need to output the answer list !!! Do not output anything else.**
"""

def query_llm(prompt):
    """Return the completion text for ``prompt`` from the Qwen endpoint.

    The request goes to the OpenAI-compatible server at
    ``http://localhost:9104/v1`` through the text-completions API. Any
    exception is printed and the request is retried after a random 1-3 s
    pause; there is no retry limit, so the call only returns on success.

    Args:
        prompt: Text prompt to complete.

    Returns:
        The text of the first completion choice.
    """
    print("Query......")
    client = OpenAI(api_key="EMPTY", base_url="http://localhost:9104/v1")
    request_kwargs = {
        "model": "Qwen-2.5-72B-Instruct",
        "prompt": prompt,
        "stream": False,
        "max_tokens": 1024,
    }
    while True:
        try:
            completion = client.completions.create(**request_kwargs)
            text = completion.choices[0].text
            print(text)
            # Success: hand the text back and leave the retry loop.
            return text
        except Exception as err:
            print(f"{err},请求失败，进行重试...")
            time.sleep(random.uniform(1, 3))

def query_4omini(prompt):
    """Return the completion text for ``prompt`` from the ``gpt-4o-mini`` model.

    The API key and base URL are placeholders that must be filled in before
    use. Any exception is printed and the request is retried after a random
    1-3 s pause; there is no retry limit, so the call only returns on success.

    Args:
        prompt: Text prompt to complete.

    Returns:
        The text of the first completion choice.
    """
    print("Query......")
    client = OpenAI(
        api_key="your openai api key",
        base_url="your openai base url",
    )
    request_kwargs = {
        "model": "gpt-4o-mini",
        "prompt": prompt,
        "stream": False,
        "max_tokens": 1024,
    }
    while True:
        try:
            completion = client.completions.create(**request_kwargs)
            text = completion.choices[0].text
            print(text)
            # Success: hand the text back and leave the retry loop.
            return text
        except Exception as err:
            print(f"{err},请求失败，进行重试...")
            time.sleep(random.uniform(1, 3))
    
def query_gpt4o(prompt):
    """POST ``prompt`` as a single user message to one of the configured endpoints.

    One endpoint/token pair is picked at random per call and the token is
    sent in the ``api-key`` header. Up to five attempts are made, each with
    a 20 s request timeout and a randomized, growing pause after a failure.

    Args:
        prompt: Content of the user message.

    Returns:
        The ``choices[0].message.content`` string of the first successful
        response, or ``None`` if all five attempts fail.
    """
    endpoints = [
        {"url": "", "apitoken": "your token"},
        {"url": "", "apitoken": "your token"},
    ]
    target = random.choice(endpoints)
    payload = {"messages": [{"role": "user", "content": prompt}]}
    request_headers = {
        'api-key': target["apitoken"],
        'Content-Type': 'application/json',
    }

    answer = None
    for attempt in range(1, 6):
        try:
            reply = requests.post(
                target["url"], json=payload, headers=request_headers, timeout=20
            )
            reply.close()
            status = reply.status_code
            if status == 200:
                answer = reply.json()['choices'][0]['message']['content']
                break  # got the data, stop retrying
            if status == 429:
                # Rate limited: take an extra pause before the regular back-off.
                print("请求频率超限, 进行重试...")
                time.sleep(random.uniform(5, 10))
            else:
                print(f"请求失败，状态码: {status}, 进行重试...")
        except requests.exceptions.Timeout:
            print("请求超时，进行重试...")
        except requests.exceptions.ConnectionError:
            print("网络连接错误，进行重试...")
        except Exception as err:
            print(f"发生未知错误: {err}，进行重试...")
        # Back-off grows by one second with every failed attempt.
        time.sleep(random.uniform(0.5 + 1 * attempt, 1.5 + 1 * attempt))
    return answer
    
    
def count_metrics(answer_list, ground_truth_list):
    """Score predicted answers against the ground truth as case-insensitive sets.

    Both inputs are lower-cased and de-duplicated before comparison. Items
    are converted with ``str()`` first so that non-string ground-truth
    values (e.g. numbers) do not crash on ``.lower()``.

    Args:
        answer_list: Iterable of predicted answers.
        ground_truth_list: Iterable of reference answers.

    Returns:
        A tuple ``(accuracy, precision, recall, f1)`` of floats:
        - accuracy: 1.0 only when the two sets are identical, else 0.0
        - precision: |answer ∩ truth| / |answer| (0.0 for an empty answer set)
        - recall: |answer ∩ truth| / |truth| (0.0 for an empty truth set)
        - f1: harmonic mean of precision and recall (0.0 when both are 0)

        BERTScore is not computed here.
    """
    answer_set = {str(a).lower() for a in answer_list}
    ground_truth_set = {str(g).lower() for g in ground_truth_list}

    matched = len(answer_set & ground_truth_set)
    # "truth is a subset of answer AND same size" is exactly set equality.
    accuracy = 1.0 if answer_set == ground_truth_set else 0.0
    precision = matched / len(answer_set) if answer_set else 0.0
    recall = matched / len(ground_truth_set) if ground_truth_set else 0.0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return accuracy, precision, recall, f1

def check_answer_list_format(res_answer_list):
    """Tell whether ``res_answer_list`` is JSON text encoding a list of strings.

    Args:
        res_answer_list: Candidate value, normally the raw LLM reply.

    Returns:
        True only if the value is ``str``/``bytes``/``bytearray``, parses as
        JSON, and the parsed value is a list whose items are all strings
        (an empty list qualifies). False otherwise.
    """
    if isinstance(res_answer_list, (str, bytes, bytearray)):
        try:
            parsed = json.loads(res_answer_list)
        except json.JSONDecodeError:
            return False
        if isinstance(parsed, list):
            return all(isinstance(entry, str) for entry in parsed)
    return False
    
def extract_answer_list(model_res):
    """Ask the LLM to pull the answer list out of a free-text model response.

    The prompt built from ``PROMPT_TEMPLATE`` is sent through
    ``query_4omini`` up to five times; the first reply that passes
    ``check_answer_list_format`` is parsed and returned.

    Args:
        model_res: Raw model response containing reasoning and answer(s).

    Returns:
        The parsed list of answer strings, or ``[]`` if no attempt produced
        a well-formed reply.
    """
    prompt = PROMPT_TEMPLATE.format(model_res=model_res)
    for attempt in range(1, 6):
        raw_reply = query_4omini(prompt)
        if check_answer_list_format(raw_reply):
            return json.loads(raw_reply)
        print(f"Retry {attempt}: Invalid format, trying again...")
    return []
    
def _load_processed_ids(path):
    """Return the set of ``id`` values already stored in the JSONL file at ``path``."""
    processed = set()
    if not os.path.exists(path):
        return processed
    with open(path, "r", encoding="utf-8") as done_file:
        for line in done_file:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            record_id = record.get("id")
            if record_id is not None:
                processed.add(record_id)
    return processed


def _aggregate_metrics(path):
    """Sum the per-record metrics of the JSONL file at ``path``; return ``(count, totals)``."""
    totals = {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0, "bert_f1": 0.0}
    required = ("accuracy", "precision", "recall", "f1")
    count = 0
    if not os.path.exists(path):
        return count, totals
    with open(path, "r", encoding="utf-8") as results_file:
        for line in results_file:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                # Read every metric before accumulating so that a partial
                # record cannot skew the totals.
                values = {name: record[name] for name in required}
                # Records written without BERTScore contribute 0.0.
                values["bert_f1"] = record.get("bert_f1", 0.0)
            except (json.JSONDecodeError, KeyError, TypeError):
                continue
            for name, value in values.items():
                totals[name] += value
            count += 1
    return count, totals


def main(log_file):
    """Evaluate one QA log file and append its averaged metrics to ``results.txt``.

    Reads ``log/qa_test_log_new/<log_file>`` (JSONL with ``id``,
    ``question``, ``true_answer`` and ``model_response``), extracts the
    answer list of every record not yet present in
    ``log/qa_format/<log_file>``, scores it, and appends the scored record
    to that output file. The run is resumable: ids already in the output
    file are skipped. Finally the averages over the whole output file are
    appended to ``results.txt``.

    Args:
        log_file: File name shared by the input log and the output file.
    """
    save_dir = "log/qa_format"
    qa_log_dir = "log/qa_test_log_new"
    writer_path = os.path.join(save_dir, log_file)
    qa_log_path = os.path.join(qa_log_dir, log_file)

    # The append below would fail if the output directory did not exist yet.
    os.makedirs(os.path.dirname(writer_path), exist_ok=True)

    # Records handled by a previous (possibly interrupted) run.
    processed_ids = _load_processed_ids(writer_path)

    with open(qa_log_path, "r", encoding="utf-8") as qa_log:
        for line in qa_log:
            if not line.strip():
                continue
            try:
                line_data = json.loads(line)
                record_id = line_data["id"]
                if record_id in processed_ids:
                    continue  # skip ids that were already processed

                ques = line_data["question"]
                true_answer_list = line_data["true_answer"]
                model_res = line_data["model_response"]
            except (KeyError, json.JSONDecodeError):
                continue

            res_list = extract_answer_list(model_res)
            accuracy, precision, recall, f1 = count_metrics(res_list, true_answer_list)
            writer = {
                "id": record_id,
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "question": ques,
                "true_answer": true_answer_list,
                "extracted_answer": res_list,
                "model_response": model_res,
            }
            # Re-open per record so every result is on disk before the next
            # (slow) LLM call; this is what makes the run resumable.
            with open(writer_path, "a", encoding="utf-8") as out_file:
                out_file.write(json.dumps(writer, ensure_ascii=False) + "\n")
            # A repeated id later in the same input is skipped, exactly as
            # it would be on a resumed run.
            processed_ids.add(record_id)

    # Summary statistics over everything stored in the output file.
    count, totals = _aggregate_metrics(writer_path)
    divisor = count if count else 1  # report 0.0000 instead of dividing by zero
    with open("results.txt", "a", encoding="utf-8") as report:
        report.write(f"******************** File: {log_file} ********************\n")
        report.write(f"Total Count: {count}\n")
        report.write(f"Average Accuracy: {totals['accuracy'] / divisor:.4f}\n")
        report.write(f"Average Precision: {totals['precision'] / divisor:.4f}\n")
        report.write(f"Average Recall: {totals['recall'] / divisor:.4f}\n")
        report.write(f"Average F1: {totals['f1'] / divisor:.4f}\n")
        report.write(f"Average BERT F1: {totals['bert_f1'] / divisor:.4f}\n")
        report.write("\n")
      
def minitest(dataset):
    """Score the final step of every step log of ``dataset`` and report averages.

    Each file in ``log/test_log/qwen3/ours/<dataset>`` is a JSON list of
    steps. The last step supplies the ground truth and, when its action is
    ``"Finish"``, the predicted objects (otherwise the prediction is
    empty). One scored record per file is appended to
    ``log/qa_format/ours_step_<dataset>.jsonl`` and the averages are
    appended to ``results.txt``.

    Args:
        dataset: Dataset name used in the log directory and output file name.
    """
    step_log_dir = f"log/test_log/qwen3/ours/{dataset}"
    save_dir = "log/qa_format"
    out_path = os.path.join(save_dir, f"ours_step_{dataset}.jsonl")
    # Appending below fails if the output directory is missing.
    os.makedirs(save_dir, exist_ok=True)

    total_accuracy = total_precision = total_recall = total_f1 = total_bert_f1 = 0.0
    count = 0
    for log_file in os.listdir(step_log_dir):
        with open(os.path.join(step_log_dir, log_file), "r", encoding="utf-8") as step_file:
            step_data = json.load(step_file)
        last_step = step_data[-1] if step_data else None
        if not last_step:
            print(f"No valid steps found in {log_file}")
            continue

        ground_truth_list = last_step["true_answer"]
        extract_res = last_step["extract_res"]
        answer_list = extract_res["Objects"] if extract_res["Action"] == "Finish" else []

        # count_metrics returns four values while BERTScore is disabled;
        # unpacking five raised ValueError. Accept an optional fifth value
        # and fall back to 0.0 so the record schema stays unchanged.
        metrics = count_metrics(answer_list, ground_truth_list)
        acc, prec, rec, f1 = metrics[:4]
        bert = metrics[4] if len(metrics) > 4 else 0.0

        total_accuracy += acc
        total_precision += prec
        total_recall += rec
        total_f1 += f1
        total_bert_f1 += bert
        count += 1

        writer = {
            "id": log_file.split('.')[0],
            "accuracy": acc,
            "precision": prec,
            "recall": rec,
            "f1": f1,
            "bert_f1": bert,
            "question": last_step["question"],
            "true_answer": ground_truth_list,
            "extracted_answer": answer_list,
            "model_response": last_step["model_response"],
        }
        with open(out_path, "a", encoding="utf-8") as out_file:
            out_file.write(json.dumps(writer, ensure_ascii=False) + "\n")

    divisor = count if count else 1  # report 0.0000 instead of dividing by zero
    with open("results.txt", "a", encoding="utf-8") as report:
        report.write(f"******************** File: {step_log_dir} ********************\n")
        report.write(f"Total Count: {count}\n")
        report.write(f"Average Accuracy: {total_accuracy / divisor:.4f}\n")
        report.write(f"Average Precision: {total_precision / divisor:.4f}\n")
        report.write(f"Average Recall: {total_recall / divisor:.4f}\n")
        report.write(f"Average F1: {total_f1 / divisor:.4f}\n")
        report.write(f"Average BERT F1: {total_bert_f1 / divisor:.4f}\n")
        report.write("\n")

def test_baseline_lightrag(model, dataset):
    """Score the LightRAG baseline logs of ``model`` on ``dataset``.

    Every JSON file in ``log/lightrag_baseline_test_log/<model>/<dataset>``
    supplies ``question``, ``answer`` and (optionally) ``global``, the model
    response. The answer list is extracted from the response, scored, and
    appended to ``log/qa_format/<model>_lightrag_<dataset>.jsonl``; the
    averages over that file are then appended to ``results.txt``.

    Args:
        model: Model name used in the log directory and output file name.
        dataset: Dataset name used in the log directory and output file name.
    """
    save_dir = "log/qa_format"
    log_dir = f"log/lightrag_baseline_test_log/{model}/{dataset}"
    log_file_list = os.listdir(log_dir)

    # Appending below fails if the output directory is missing.
    os.makedirs(save_dir, exist_ok=True)
    writer_path = os.path.join(save_dir, f"{model}_lightrag_{dataset}.jsonl")
    for idx, log_file in enumerate(log_file_list):
        print(f"Processing {idx}/{len(log_file_list)}")
        try:
            with open(os.path.join(log_dir, log_file), "r", encoding="utf-8") as log_handle:
                test_data = json.load(log_handle)
        except json.JSONDecodeError:
            print(f"Error decoding JSON in file: {log_file}")
            continue
        ques = test_data["question"]
        true_answer_list = test_data["answer"]
        # A log without a "global" response is scored as a blank response.
        model_res = test_data.get("global", " ")
        res_list = extract_answer_list(model_res)
        accuracy, precision, recall, f1 = count_metrics(res_list, true_answer_list)
        writer = {
            "id": log_file.split('.')[0],
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "question": ques,
            "true_answer": true_answer_list,
            "extracted_answer": res_list,
            "model_response": model_res,
        }
        try:
            with open(writer_path, "a", encoding="utf-8") as out_file:
                out_file.write(json.dumps(writer, ensure_ascii=False) + "\n")
        except (OSError, UnicodeEncodeError):
            # Fallback: store the record without the raw model response.
            # NOTE(review): presumably meant for responses that cannot be
            # written as UTF-8; that surfaces as UnicodeEncodeError, which
            # the original `except IOError` did not cover.
            slim = {key: value for key, value in writer.items() if key != "model_response"}
            with open(writer_path, "a", encoding="utf-8") as out_file:
                out_file.write(json.dumps(slim, ensure_ascii=False) + "\n")

    # Summary statistics over everything stored in the output file.
    totals = {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}
    count = 0
    if os.path.exists(writer_path):
        with open(writer_path, "r", encoding="utf-8") as results_file:
            for line in results_file:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    # Read every metric first so a partial record cannot
                    # skew the totals.
                    values = {name: data[name] for name in totals}
                except (json.JSONDecodeError, KeyError, TypeError):
                    continue
                for name, value in values.items():
                    totals[name] += value
                count += 1

    divisor = count if count else 1  # report 0.0000 instead of dividing by zero
    with open("results.txt", "a", encoding="utf-8") as report:
        report.write(f"******************** File: {model}_lightrag_{dataset}.jsonl ********************\n")
        report.write(f"Total Count: {count}\n")
        report.write(f"Average Accuracy: {totals['accuracy'] / divisor:.4f}\n")
        report.write(f"Average Precision: {totals['precision'] / divisor:.4f}\n")
        report.write(f"Average Recall: {totals['recall'] / divisor:.4f}\n")
        report.write(f"Average F1: {totals['f1'] / divisor:.4f}\n")
        report.write("\n")


# Not a pytest test despite the ``test_`` prefix; keep pytest from collecting it.
test_baseline_lightrag.__test__ = False

def test_metrics(file_path):
    """Print the average accuracy, precision, recall and F1 stored in a JSONL file.

    Blank lines, lines that are not valid JSON, and records missing one of
    the four metrics are skipped. With no usable record the averages are
    printed as 0.0000 instead of raising ``ZeroDivisionError``.

    Args:
        file_path: Path to a JSONL file of scored records.
    """
    totals = {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}
    count = 0
    with open(file_path, "r", encoding="utf-8") as results_file:
        for line in results_file:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
                # Read every metric first so a partial record cannot skew
                # the totals.
                values = {name: data[name] for name in totals}
            except (json.JSONDecodeError, KeyError, TypeError):
                continue
            for name, value in values.items():
                totals[name] += value
            count += 1

    divisor = count if count else 1
    print(f"Total Count: {count}")
    print(f"Average Accuracy: {totals['accuracy'] / divisor:.4f}")
    print(f"Average Precision: {totals['precision'] / divisor:.4f}")
    print(f"Average Recall: {totals['recall'] / divisor:.4f}")
    print(f"Average F1: {totals['f1'] / divisor:.4f}")


# Not a pytest test despite the ``test_`` prefix; keep pytest from collecting it.
test_metrics.__test__ = False

def ours_test_metrics(log_dir):
    """Print the average accuracy and F1 over the step logs in ``log_dir``.

    Each file is a JSON list of steps. The last step supplies the ground
    truth and, when its action is ``"Finish"``, the predicted objects
    (otherwise the prediction is empty). Files with no steps are reported
    and skipped. With no usable file the averages are printed as 0.0000
    instead of raising ``ZeroDivisionError``.

    Args:
        log_dir: Directory containing one JSON step log per question.
    """
    total_accuracy = total_f1 = 0.0
    count = 0
    for log_file in os.listdir(log_dir):
        log_file_path = os.path.join(log_dir, log_file)
        with open(log_file_path, "r", encoding="utf-8") as step_file:
            data = json.load(step_file)
        try:
            last_step = data[-1]
        except IndexError:
            print(f"No valid steps found in {log_file}")
            continue
        true_answer = last_step["true_answer"]
        extract_res = last_step["extract_res"]
        model_answer = extract_res["Objects"] if extract_res["Action"] == "Finish" else []
        # Only accuracy and F1 are reported; precision and recall are ignored.
        accuracy, _, _, f1 = count_metrics(model_answer, true_answer)
        total_accuracy += accuracy
        total_f1 += f1
        count += 1

    divisor = count if count else 1
    print(f"Total Count: {count}")
    print(f"Accuracy: {total_accuracy / divisor:.4f}")
    print(f"F1: {total_f1 / divisor:.4f}")

def _perfect_accuracy_ids(path):
    """Return, in file order, the ids of JSONL records in ``path`` whose accuracy is 1.0."""
    ids = []
    with open(path, 'r', encoding='utf-8') as scored_file:
        for line in scored_file:
            if not line.strip():
                continue
            record = json.loads(line)
            if record["accuracy"] == 1.0:
                ids.append(record["id"])
    return ids


def find_case_study(ours_file="log/qa_format/qwen_ours_metaqa_3hop.jsonl", file_lists=None):
    """Print the ids answered perfectly in ``ours_file`` but in none of the baselines.

    The baseline loop used to parse a stale ``line`` variable left over
    from reading ``ours_file`` instead of iterating each baseline file;
    every baseline file is now read line by line.

    Args:
        ours_file: Scored JSONL file of the method under study.
        file_lists: Paths of scored JSONL baseline files. Defaults to no
            baselines, in which case every perfectly answered id of
            ``ours_file`` is printed.
    """
    baseline_files = [] if file_lists is None else file_lists

    # Union of all ids that at least one baseline also answered perfectly.
    solved_by_baselines = set()
    for baseline_path in baseline_files:
        solved_by_baselines.update(_perfect_accuracy_ids(baseline_path))

    for record_id in _perfect_accuracy_ids(ours_file):
        if record_id not in solved_by_baselines:
            print(record_id)

        




    




if __name__ == "__main__":
    # Command-line entry point: evaluate the QA log named by --file.
    # NOTE(review): with required=True argparse never falls back to
    # `default`, so the default below is effectively unused - confirm which
    # of the two is intended.
    cli = argparse.ArgumentParser(description="QA experiment")
    cli.add_argument(
        '--file',
        type=str,
        required=True,
        help='path to the file',
        default='tuned_ours_webqsp.jsonl',
    )
    main(cli.parse_args().file)

    
                    
                
            