import requests
import json

import ast
# from load_dataset import load_doi
import os
from datetime import datetime
# from datasets import load_dataset
import pandas as pd
import json

def load_industry_dataset(dataset_path="dataset/industry/input.json"):
    """Load the industry inputs from a JSON file as a list of strings.

    Args:
        dataset_path: Path of the JSON file to read. Defaults to
            ``dataset/industry/input.json`` (relative to the current
            working directory), which is the file this script has always
            used.

    Returns:
        A list with ``str(item)`` for every element obtained by iterating
        the top-level JSON value (expected to be an array), in file order.
        Non-string elements are converted with ``str()``, so a JSON object
        becomes the text of the corresponding Python dict, not JSON.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [str(item) for item in data]


# Coze API key (personal access token obtained from the Coze platform).
# SECURITY NOTE(review): a token literal is committed in this file. Set the
# COZE_API_KEY environment variable to supply the key instead; the literal is
# kept only as a fallback so existing runs keep working. It should be rotated
# and removed from source control.
API_KEY = os.environ.get(
    "COZE_API_KEY",
    "pat_AiScEX8faBuuwaR1mHSuNmA7UnszdqfR7gbHhy6tC3W1iNLA6x74kkNgNGx5mZha",
)

# ID of the workflow to run (example value; replace with your own workflow_id).
WORKFLOW_ID = "7533190968968921122"

# Endpoint of the streaming workflow-run API.
URL = "https://api.coze.cn/v1/workflow/stream_run"

# Request headers sent with every call.
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}


def _collect_sse_content(lines):
    """Concatenate the ``content`` values carried by the ``data:`` lines.

    Args:
        lines: Iterable of raw ``bytes`` lines, as produced by
            ``response.iter_lines()``.

    Returns:
        The string ``content`` fields of every ``data:`` line whose payload
        is a JSON object, joined in arrival order. Lines that are empty,
        carry another field, hold invalid JSON, or hold a payload without a
        string ``content`` are ignored.
    """
    pieces = []
    for raw in lines:
        if not raw:
            continue  # iter_lines() can yield empty values; nothing to parse
        text = raw.decode('utf-8')
        # Accept both "data: {...}" and "data:{...}"; json.loads tolerates
        # the optional leading space.
        if not text.startswith("data:"):
            continue
        try:
            event = json.loads(text[5:])
        except json.JSONDecodeError:
            continue
        # Payloads such as a bare string or number are valid JSON but have
        # no "content" key; indexing them would raise TypeError.
        if isinstance(event, dict) and isinstance(event.get("content"), str):
            pieces.append(event["content"])
    return "".join(pieces)


def run_industry_dataset():
    """Run the workflow once per dataset item and save each answer as JSON.

    For every string returned by ``load_industry_dataset()`` this posts a
    request to ``URL`` (using the module-level ``headers`` and
    ``WORKFLOW_ID``), prints the raw response body, gathers the ``content``
    fields of the response's ``data:`` lines, and writes
    ``{'id', 'question', 'answer'}`` to
    ``output/industry/industry_log_<index>.json``.

    A response with a non-success HTTP status is reported and skipped: no
    result file is written for that item, so a failed call can no longer be
    mistaken for an empty answer.

    Returns:
        None.

    Raises:
        requests.RequestException: If a request fails at the network level.
        OSError: If the output directory or a result file cannot be written.
    """
    industry_data = load_industry_dataset()

    # Make sure the output directory exists (no exists()/makedirs() race).
    output_dir = 'output/industry'
    os.makedirs(output_dir, exist_ok=True)

    for i, question in enumerate(industry_data):
        payload = {
            "workflow_id": WORKFLOW_ID,
            "parameters": {
                "input": question
            },
            "additional_messages": [
                {
                    "content_type": "text",
                    "role": "user",
                    "type": "question",
                    "content": question
                }
            ]
        }

        # The context manager releases the streamed connection even when an
        # exception is raised or the item is skipped.
        with requests.post(URL, headers=headers, json=payload, stream=True) as response:
            # The encoding must be set BEFORE .text is read; otherwise the
            # dump below is decoded with whatever charset requests derived
            # from the response headers, which can garble non-ASCII text.
            response.encoding = 'utf-8'
            # NOTE: reading .text downloads the whole body; iter_lines()
            # below then iterates over that buffered content.
            print(response.text)

            if not response.ok:
                print(f"请求失败（HTTP {response.status_code}），已跳过第 {i} 条数据")
                continue

            final_text = _collect_sse_content(response.iter_lines())

        print("运行成功")
        final_result = {
            'id': i,
            'question': question,
            'answer': final_text
        }

        # One result file per dataset index.
        output_file = os.path.join(output_dir, f'industry_log_{i}.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_result, f, ensure_ascii=False, indent=2)
        print("运行成功，结果已写入JSON文件：", output_file)

# Run the batch only when this file is executed as a script, so that merely
# importing the module does not send requests or write files.
if __name__ == "__main__":
    run_industry_dataset()