import requests
import json

import ast
# from load_dataset import load_doi
import os
from datetime import datetime
# from datasets import load_dataset
import pandas as pd



def mutate_question(question):
    """
    Paraphrase *question* with an LLM while keeping its meaning unchanged.

    The prompt asks the model for three rewrites returned as JSON of the form
    ``{"rewrites": [...]}``. The OpenAI client takes its key from the
    ``OPENAI_API_KEY`` environment variable; secrets must not be hard-coded
    here (any key previously committed in this file should be revoked).

    Args:
        question: The question text to paraphrase.

    Returns:
        The list of rewritten questions parsed from the model's reply. On any
        request or parsing failure the raw reply and the error are printed and
        an empty list is returned, so callers can always iterate the result.
    """
    import re

    from openai import OpenAI

    # With no explicit api_key the client reads OPENAI_API_KEY from the environment.
    client = OpenAI()

    prompt = f"""请对以下问题进行改写，保持语义不变，但用不同的表达方式。请生成3个改写版本。

原问题：{question}

要求：
1. 保持原问题的核心语义不变
2. 使用不同的词汇和句式结构
3. 确保改写后的问题仍然自然流畅
4. 返回格式为JSON，包含3个不同的改写结果

示例格式：
{{
    "rewrites": [
        "改写版本1",
        "改写版本2", 
        "改写版本3"
    ]
}}
"""

    # Bound before the request so the error handler can always print it,
    # even when the API call fails before any reply arrives.
    result = None
    try:
        response = client.chat.completions.create(
            model="gpt-4o",  # or another chat model
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
        )
        result = response.choices[0].message.content or ""

        # Models often wrap JSON in a Markdown fence (```json ... ``` or
        # ``` ... ```); take the fenced part if present, else the whole reply.
        fence = re.search(r"```(?:json)?\s*(.*?)\s*```", result, re.DOTALL)
        json_str = fence.group(1) if fence else result

        rewrites = json.loads(json_str)["rewrites"]
        if not isinstance(rewrites, list):
            raise ValueError(f"'rewrites' is not a list: {rewrites!r}")
        return rewrites

    except Exception as e:
        print(result)
        print(f"改写问题时出错: {e}")
        return []

def generate_more_QA(output_path="dataset/QA/coze_faq_dataset_rewritten.csv"):
    """
    Augment the QA dataset by paraphrasing every question.

    Each question from ``load_QA_dataset()`` is passed to ``mutate_question``.
    Every rewrite is paired with the original answer, and the rows are
    appended (no header, no index, UTF-8) to *output_path*. Rows are written
    after each question, so progress survives an interrupted run.

    NOTE(review): ``load_QA_dataset()`` reads
    ``dataset/QA/coze_faq_dataset_rewritten.csv`` by default, which is the
    same file appended to here by default. Re-running this function therefore
    paraphrases earlier rewrites too. Confirm the intended source file.

    Args:
        output_path: CSV file that the new question/answer rows are appended to.
    """
    qa_dataset = load_QA_dataset()

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for item in qa_dataset:
        ground_truth = item['answer']
        # mutate_question may yield None or [] on failure; nothing to write then.
        rewrites = mutate_question(item['question']) or []
        if not rewrites:
            continue

        # One append per source question instead of one per rewritten row.
        new_rows = pd.DataFrame({
            'question': list(rewrites),
            'answer': [ground_truth] * len(rewrites),
        })
        new_rows.to_csv(output_path,
                        mode='a',
                        header=False,
                        index=False,
                        encoding='utf-8')
        


def load_QA_dataset(csv_file_path="dataset/QA/coze_faq_dataset_rewritten.csv"):
    """
    Load question/answer pairs from a CSV file.

    The first row is treated as the header. The first column is taken as the
    question and the second column as the answer, whatever their names.

    Args:
        csv_file_path: Path of the CSV to read. Defaults to the rewritten Coze
            FAQ dataset.

    Returns:
        A list of ``{'question': ..., 'answer': ...}`` dicts in file order.
        Cell values are converted to native Python objects (via ``tolist``),
        so numeric cells can be passed to ``json.dump`` directly.
    """
    df = pd.read_csv(csv_file_path)

    # Select columns by position, not name, to match any header spelling.
    questions = df.iloc[:, 0].tolist()
    answers = df.iloc[:, 1].tolist()
    return [{'question': q, 'answer': a} for q, a in zip(questions, answers)]


# Coze personal access token, read from the environment so the secret is not
# committed to source control. Any token previously hard-coded here should be
# treated as leaked and revoked on the Coze platform.
API_KEY = os.environ.get("COZE_API_KEY", "")

# ID of the Coze workflow to invoke; override with COZE_WORKFLOW_ID.
WORKFLOW_ID = os.environ.get("COZE_WORKFLOW_ID", "7533871945704800292")

# Coze workflow chat endpoint; main() requests it with stream=True and reads SSE lines.
URL = "https://api.coze.cn/v1/workflows/chat"

# Request headers shared by every call in main().
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}

def _extract_answer(lines):
    """
    Assemble the assistant's answer text from the lines of a Coze SSE stream.

    Args:
        lines: Iterable of raw SSE lines (``bytes`` or ``str``), e.g. from
            ``requests.Response.iter_lines()``.

    Returns:
        The concatenated ``content`` of payloads sent under a
        ``conversation.message.delta`` event. If the stream contains no such
        events, falls back to concatenating the ``content`` of every JSON
        ``data:`` payload, which was this script's original behaviour.
    """
    event = None
    deltas = []
    all_contents = []
    for raw in lines:
        if not raw:
            # A blank line terminates the current SSE event.
            event = None
            continue
        line = raw.decode('utf-8', errors='replace') if isinstance(raw, bytes) else raw
        if line.startswith("event:"):
            event = line[len("event:"):].strip()
            continue
        if not line.startswith("data:"):
            continue
        try:
            data = json.loads(line[len("data:"):])
        except json.JSONDecodeError:
            continue
        # Non-object payloads (e.g. a quoted "[DONE]" marker) carry no content.
        if not isinstance(data, dict) or "content" not in data:
            continue
        all_contents.append(data["content"])
        if event == "conversation.message.delta":
            deltas.append(data["content"])
    # NOTE(review): per the Coze chat streaming docs, a completed-message event
    # repeats the full text of the preceding deltas (and verbose/follow-up
    # messages also carry "content"). Joining every payload would therefore
    # duplicate the answer, so deltas are preferred when present. Confirm
    # against this workflow's actual stream.
    return "".join(deltas if deltas else all_contents)


def main():
    """
    Send every QA question to the Coze workflow and log its replies.

    Each item from ``load_QA_dataset()`` is posted to ``URL`` as a streamed
    request. The answer assembled from the SSE stream is written, together
    with the question and the ground-truth answer, to
    ``output/QA-rewritten/QA_log_<index>.json``. Items whose question is empty
    or whose request fails are reported and skipped, so one bad item does
    not abort the whole run.
    """
    output_dir = 'output/QA-rewritten'
    # Create the output directory once, before the loop.
    os.makedirs(output_dir, exist_ok=True)

    qa_dataset = load_QA_dataset()
    for i, item in enumerate(qa_dataset):
        question = item['question']
        ground_truth = item['answer']

        # pandas reads empty cells as NaN, which is not valid JSON to send.
        if not isinstance(question, str) or not question.strip():
            print(f"跳过第 {i} 条：问题为空")
            continue

        payload = {
            "workflow_id": WORKFLOW_ID,
            "parameters": {},
            "additional_messages": [
                {
                    "content_type": "text",
                    "role": "user",
                    "type": "question",
                    "content": question,
                }
            ],
        }

        try:
            # `with` releases the streamed connection. The (connect, read)
            # timeout stops a stalled stream from hanging the run forever.
            with requests.post(URL, headers=headers, json=payload,
                               stream=True, timeout=(10, 300)) as response:
                if not response.ok:
                    response.encoding = 'utf-8'
                    print(f"第 {i} 条请求失败 (HTTP {response.status_code}): {response.text}")
                    continue
                final_text = _extract_answer(response.iter_lines())
        except requests.RequestException as e:
            print(f"第 {i} 条请求出错: {e}")
            continue

        final_result = {
            'id': i,
            'question': question,
            'answer': final_text,
            'ground_truth': ground_truth
        }

        output_file = os.path.join(output_dir, f'QA_log_{i}.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_result, f, ensure_ascii=False, indent=2)
        print("运行成功，结果已写入JSON文件：", output_file)





if __name__ == "__main__":
    # Script entry point: query the Coze workflow for every QA pair.
    # To append paraphrased questions to the dataset instead, call
    # generate_more_QA() here.
    main()