import json
import os
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from openai import OpenAI
import requests
import tqdm
from global_vars import *

# prompt_without_comment = "Your task is to expand the opinion given topics. It doesn't need to be too long, just 3 to 4 sentences expansion. The original opinion needs to be rephrased. The topic is: {topic}. The opinion is: {opinion}. Please directly output the expanded opinion, don't include any other text."

def expand_opinion_gpt(prompt, model="gpt-4o", temperature=0.5):
    """Ask an OpenAI chat model to expand an opinion as described by ``prompt``.

    Args:
        prompt: Fully formatted user prompt describing the expansion task.
        model: Chat model name sent to the OpenAI API.
        temperature: Sampling temperature; the default is kept fairly low to
            get more deterministic answers.

    Returns:
        The content of the first completion choice, or ``None`` if the
        request failed (the error is printed, not raised).
    """
    try:
        # Read the key from the environment rather than passing a hard-coded
        # empty string: the client only falls back to OPENAI_API_KEY when
        # api_key is None, so api_key="" made every request fail to authenticate.
        # Construction stays inside the try so a missing key is reported the
        # same best-effort way as any other API failure.
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant for expanding the opinion."},
                {"role": "user", "content": prompt}
            ],
            temperature=temperature,
        )
        return completion.choices[0].message.content

    except Exception as e:
        # Best-effort: report and let the caller store None for this item.
        print(f"Error occurred: {e}")
        return None


def aggreement_validation(posts_or_comments, opinion, model="gpt-4o", temperature=0.1):
    """Ask an OpenAI chat model whether an expansion still supports an opinion.

    Args:
        posts_or_comments: The expanded opinion text. It is interpolated into
            the prompt with ``str.format``, so a list is sent as its repr.
        opinion: The original opinion text.
        model: Chat model name sent to the OpenAI API.
        temperature: Sampling temperature; kept low for a stable yes/no answer.

    Returns:
        The model's raw reply, which the prompt asks to be 'yes' or 'no'. It is
        not normalized here. Returns ``None`` if the request failed (the error
        is printed, not raised).
    """
    prompt = "Your task is to determine if the expanded opinions are still supporting the original opinion. The expanded opinion is: {posts_or_comments}. The original opinion is: {opinion}. Please just answer with 'yes' or 'no'."
    try:
        # An empty api_key string is used verbatim and never authenticates.
        # Passing None (when the env var is unset) lets the client report the
        # missing key, which is caught below like any other failure.
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant for determind if the expanded opinion is similar to the original opinion."},
                {"role": "user", "content": prompt.format(posts_or_comments=posts_or_comments, opinion=opinion)}
            ],
            temperature=temperature,
        )
        return completion.choices[0].message.content

    except Exception as e:
        # Best-effort: report and let the caller store None for this entry.
        print(f"Error occurred: {e}")
        return None

def add_model_validation(data_path):
    """Annotate every pro/con entry in a JSON file with an LLM agreement verdict.

    For each entry in each item's "pros" and "cons" lists, calls
    ``aggreement_validation(entry["expanded"], entry["opinion"])`` and stores
    the reply under ``"expanded_validation"``. The file is then rewritten in
    place.

    Args:
        data_path: Path to a JSON list of items with "pros" and "cons" lists.
    """
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)

    for item in tqdm.tqdm(data):
        # Same processing order as before: all pros, then all cons. The
        # concatenated list holds the same dict objects, so mutation sticks.
        for entry in item["pros"] + item["cons"]:
            entry["expanded_validation"] = aggreement_validation(entry["expanded"], entry["opinion"])

    # Write to a sibling temp file and atomically swap it in. The results cost
    # one API call per entry, and the input is being overwritten, so a failure
    # mid-dump must not leave a truncated file.
    tmp_path = os.fspath(data_path) + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f)
    os.replace(tmp_path, data_path)

def count_validation(data_path):
    """Print and return the fraction of pro/con entries validated as agreeing.

    Args:
        data_path: Path to a JSON list of items, each with "pros" and "cons"
            lists whose entries may carry an "expanded_validation" string
            (e.g. as written by ``add_model_validation``).

    Returns:
        The ratio of entries whose validation text contains "yes"
        (case-insensitive), or ``None`` when there are no pro/con entries.
        Entries whose validation is missing or ``None`` (a failed API call)
        count toward the total but not as agreements.
    """
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)

    total = 0
    yes = 0
    for item in data:
        for entry in item["pros"] + item["cons"]:
            total += 1
            verdict = entry.get("expanded_validation")
            # A failed validation call stores None; treat it as "not yes"
            # instead of crashing on None.lower().
            if verdict and "yes" in verdict.lower():
                yes += 1

    if total == 0:
        print("No pro/con entries found; agreement ratio is undefined.")
        return None
    ratio = yes / total
    print(ratio)
    return ratio


def _expand_with_matches(topic, entry):
    """Return one GPT expansion of ``entry["opinion"]`` per item in ``entry["top5_matches"]``."""
    template = PROMPT.opinion_expansion_with_best_match
    return [
        expand_opinion_gpt(template.format(topic=topic, opinion=entry["opinion"], post_or_comment=match))
        for match in entry["top5_matches"]
    ]


def main(save_loc):
    """Expand every pro/con opinion using its best-matching posts/comments.

    Reads the JSON at ``KIALO_DATA_PATH.best_match``. For each item, it stores
    under ``"expanded"`` a list of GPT expansions for each pro and con entry,
    one per entry in ``"top5_matches"``. The result is written to ``save_loc``.

    Args:
        save_loc: Output path for the expanded JSON data.

    Raises:
        ValueError: If ``save_loc`` is None. This is checked before any API
            calls, so a missing output path does not discard paid-for results.
    """
    if save_loc is None:
        raise ValueError("save_loc must be provided (e.g. via --save_loc)")

    with open(KIALO_DATA_PATH.best_match, encoding="utf-8") as f:
        data_with_best_match = json.load(f)

    for item in tqdm.tqdm(data_with_best_match):
        topic = item["topic"]
        # Pros first, then cons, matching the original processing order.
        for side in ("pros", "cons"):
            for entry in item[side]:
                entry["expanded"] = _expand_with_matches(topic, entry)

    with open(save_loc, "w") as f:
        json.dump(data_with_best_match, f)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Expand opinions with GPT using their best-matching posts/comments.")
    # Required: main() writes its output here, and without it the run would
    # only fail after all API calls have been made.
    parser.add_argument("--save_loc", type=str, required=True, help="Path to save the expanded data")
    args = parser.parse_args()

    main(args.save_loc)