import json
import datasets
from tqdm import tqdm
from src.prompts import build_prompt
from src.graph_utils import summary

import os
from src.models import OpenModel, Zhipu, SiliconFlow
# API credentials are read from the environment (export them in the shell or
# load them from an untracked .env file) rather than being committed to source.
# NOTE(review): the keys that were previously hard-coded here are exposed in
# version-control history and should be revoked/rotated.
import warnings

# Non-secret endpoint override; setdefault keeps any value the user exported.
os.environ.setdefault("OPENAI_API_BASE", "https://xiaoai.plus/v1")

_missing_api_keys = [
    key_name
    for key_name in ("OPENAI_API_KEY", "SILICONFLOW_API_KEY", "ZHIPUAI_API_KEY")
    if not os.environ.get(key_name)
]
if _missing_api_keys:
    # Only a warning: just one backend is used per run, so the others may
    # legitimately be unset.
    warnings.warn(
        "API key environment variable(s) not set: "
        + ", ".join(_missing_api_keys)
        + "; the corresponding model client may fail to authenticate."
    )


dataset = datasets.load_dataset("data/RoG-webqsp", split="test")

# Select the backend to query; the alternatives are kept for quick switching.
# model = OpenModel(model_name="gpt-4o")
# model = Zhipu(model_name="glm-4-flash")
model = SiliconFlow()

# Phase 1: query the model once per test question and write one JSON record
# per line to the predictions file.
predictions_path = "outputs/predictions.jsonl"
# The output directory may not exist on a fresh checkout, in which case
# open(..., "w") raises FileNotFoundError.
os.makedirs(os.path.dirname(predictions_path), exist_ok=True)
# json.dumps(..., ensure_ascii=False) emits raw non-ASCII characters, so the
# file encoding must be pinned to UTF-8; the platform default (e.g. cp1252 or
# GBK on Windows) can raise UnicodeEncodeError.
with open(predictions_path, "w", encoding="utf-8") as f:
    for ins in tqdm(dataset):
        ins_summary = summary(ins)

        prompt = build_prompt(type="saq", question=ins["question"])
        response = model.predict(prompt)

        # Saved record = per-instance summary + the prompt + the raw model
        # output, with the "truth_paths" entry removed.
        prediction = ins_summary.copy()
        prediction["prompt"] = prompt
        prediction["ori_pred"] = response.content
        prediction.pop("truth_paths")
        f.write(json.dumps(prediction, ensure_ascii=False) + "\n")

from src.eval import evaluate

import re

# Answer separators for the raw model output:
#   * a comma always separates answers;
#   * a hyphen separates answers only when it is NOT between two word
#     characters, e.g. a bullet marker ("- A") or a spaced dash ("A - B").
# Hyphens inside entity names ("Jay-Z", "Spider-Man", "2000-01-01") are kept,
# whereas the previous replace("-", ",") split them into meaningless fragments.
_ANSWER_SEPARATOR = re.compile(r",|(?<!\w)-|-(?!\w)")

# Phase 2: reload the saved predictions (UTF-8, matching how they were written).
with open("outputs/predictions.jsonl", "r", encoding="utf-8") as f:
    results = [json.loads(line) for line in f if line.strip()]

# Turn each raw model answer ("ori_pred") into a list of non-empty, stripped
# candidate answers stored under "pred" for the evaluator.
for result in results:
    pieces = _ANSWER_SEPARATOR.split(result["ori_pred"])
    result["pred"] = [p.strip() for p in pieces if p.strip()]

score = evaluate(results)
print(score)