import requests
import json

import ast
# from load_dataset import load_doi
import os
from datetime import datetime
from datasets import load_dataset
import pandas as pd


def load_travel_dataset():
    """Load the TravelPlanner validation split as a list of trip-request dicts.

    Every row of the ``validation`` split becomes one dict with:

    * ``departure_city`` / ``destination_city`` — the row's ``org`` / ``dest``;
    * ``departure_date`` / ``return_date`` — first and last entries of the
      row's parsed ``date`` list;
    * ``days`` and ``traveler_count`` — copied from ``days`` / ``people_number``;
    * ``budget`` — the row's budget rendered as a ``'$'``-prefixed string;
    * ``destination_country`` — always ``'America'``;
    * ``query`` — the natural-language request text.
    """
    split = load_dataset("osunlp/TravelPlanner", "validation")['validation']

    trips = []
    for row in split:
        # `date` is parsed with ast.literal_eval (literals only, no code
        # execution); the first and last elements bound the trip.
        dates = ast.literal_eval(row['date'])
        trip = {
            'departure_city': row['org'],
            'destination_city': row['dest'],
            'departure_date': dates[0],
            'days': row['days'],
            'budget': '$' + str(row['budget']),
            'traveler_count': row['people_number'],
            'destination_country': 'America',
            'return_date': dates[-1],
            'query': row['query'],
        }
        trips.append(trip)
    return trips


# Coze personal access token, read from the environment instead of being
# hard-coded: a token committed to source is exposed to anyone who can read
# the repository. (The token previously embedded here should be revoked.)
API_KEY = os.environ.get("COZE_API_KEY", "")

# ID of the Coze workflow to run; override with the COZE_WORKFLOW_ID variable.
WORKFLOW_ID = os.environ.get("COZE_WORKFLOW_ID", "7532689664160923711")

# Coze workflow chat endpoint (main() consumes its response as an SSE stream).
URL = "https://api.coze.cn/v1/workflows/chat"

# Request headers: bearer-token authentication and a JSON request body.
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}

def _collect_sse_content(lines):
    """Concatenate the ``content`` fields carried by the ``data:`` lines of an SSE stream.

    Args:
        lines: Iterable of raw lines (``bytes`` as yielded by
            ``Response.iter_lines()``, or ``str``). Empty keep-alive lines
            and non-``data:`` lines (e.g. ``event:``) are skipped.

    Returns:
        The string ``content`` values of every JSON-object data payload,
        joined in stream order. Payloads that are not valid JSON objects
        (such as a quoted ``"[DONE]"`` sentinel) are ignored.
    """
    # NOTE(review): as in the original loop, every event with a "content"
    # key is kept. If the service sends both incremental deltas and a final
    # "completed" message holding the full text, the answer will appear
    # twice. Consider filtering on the preceding "event:" line; verify
    # against the Coze streaming response format.
    parts = []
    for raw in lines:
        if not raw:
            continue
        line = raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else raw
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):]
        # Per the SSE spec, one leading space after the colon is not part of the value.
        if payload.startswith(" "):
            payload = payload[1:]
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            continue
        # Only JSON objects can carry a "content" key. `in` on other JSON
        # types would test substrings/elements, or raise TypeError for numbers.
        content = data.get("content") if isinstance(data, dict) else None
        if isinstance(content, str):
            parts.append(content)
    return "".join(parts)


def main(start_index=156, output_dir='output/travel', timeout=600):
    """Run the Coze travel workflow over the TravelPlanner validation set.

    For each dataset entry from ``start_index`` onward, POST the trip
    parameters plus the natural-language query to the workflow endpoint,
    collect the streamed answer, and write it together with the trip
    metadata to ``<output_dir>/travel_log_<index>.json``.

    Args:
        start_index: First (non-negative) dataset index to process. Defaults
            to 156, the value previously hard-coded — presumably to resume an
            earlier partial run.
        output_dir: Directory for the per-item JSON files; created if missing.
        timeout: Passed to ``requests.post`` as the connect/read timeout in
            seconds, so a stalled connection cannot hang the run forever.
            ``None`` disables it.

    Items whose request fails (network error or non-2xx HTTP status) are
    reported on stdout and skipped, so no empty result file is written for
    them. They can be retried later with a suitable ``start_index``.
    """
    travel_dataset = load_travel_dataset()
    # Create the output directory once, up front, instead of on every item.
    os.makedirs(output_dir, exist_ok=True)

    for i, item in enumerate(travel_dataset[start_index:], start=start_index):
        payload = {
            "workflow_id": WORKFLOW_ID,
            "parameters": {
                "destination": item['destination_city'],
                "departure": item['departure_city'],
                "days_num": item['days'],
                "people_num": item['traveler_count'],
                "start_date": item['departure_date'],  # optional
                "travel_theme": "travel",  # optional
            },
            "additional_messages": [
                {
                    "content_type": "text",
                    "role": "user",
                    "type": "question",
                    "content": item['query'],
                }
            ],
        }

        # The raw body is deliberately not printed before parsing. Reading
        # response.text would download the whole stream up front, which
        # defeats stream=True.
        try:
            # `with` closes the streamed response and releases its connection.
            with requests.post(URL, headers=headers, json=payload,
                               stream=True, timeout=timeout) as response:
                if not response.ok:
                    print(f"Item {i}: request failed with HTTP "
                          f"{response.status_code}: {response.text}")
                    continue
                final_text = _collect_sse_content(response.iter_lines())
        except requests.RequestException as exc:
            # Covers connection errors, timeouts, and broken streams. Keep the
            # batch going instead of aborting every remaining item.
            print(f"Item {i}: request error: {exc}")
            continue

        print("运行成功，原始返回内容如下：")
        print(final_text)
        final_result = {
            'id': i,
            'departure_city': item['departure_city'],
            'destination_city': item['destination_city'],
            'departure_date': item['departure_date'],
            'days': item['days'],
            'budget': item['budget'],
            'traveler_count': item['traveler_count'],
            'answer': final_text,
        }

        output_file = os.path.join(output_dir, f'travel_log_{i}.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_result, f, ensure_ascii=False, indent=2)
        print("运行成功，结果已写入JSON文件：", output_file)





# Run the batch only when executed as a script, never on import.
if __name__ == '__main__':
    main()