import os
from openai import OpenAI
# Read the OpenAI API key from the environment; an unset variable yields "".
api_key = os.getenv("OPENAI_API_KEY", "")

if api_key:
    client = OpenAI(api_key=api_key)
else:
    # Keep `client` bound even without a key so later code can test for a
    # missing client instead of hitting a NameError at first use.
    client = None
    print("Warning: OPENAI_API_KEY is not set")
import re
import json
import time
import pandas as pd
import sys
from tqdm import tqdm
from datetime import datetime
from collections import Counter
import random
import numpy as np
import torch

# One shared seed applied to Python's `random`, NumPy, and PyTorch
# (default generator plus all CUDA devices).
SEED = 42
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)

# Folder that receives the evaluation output; created if it does not exist.
output_dir = "YOUR_DIRECTORY/"
os.makedirs(output_dir, exist_ok=True)

# System message for the judge model, written one sentence per source line.
# NOTE(review): "diffences" below is a typo for "differences". It is kept
# verbatim because this text is sent to the model as-is; fix it only if a
# change to the evaluation prompt is acceptable.
perspective_diversity_system_prompt = (
    "You are an objective evaluator assessing the diversity of multiple counterarguments. "
    "Your task is to judge how distinctly each counterargument approaches the original opinion, focusing on diffences in perspectives. "
    "Evaluate whether the responses reflect opinions that could come from individuals with different values, belief systems, or life experiences. "
    "Your assessment should consider whether the counterarguments present genuinely varied worldviews, not just superficial differences in wording or logic."
)

# User-message template for a single diversity judgement.
# The placeholders {original_opinion} and {counterargument_1} ..
# {counterargument_3} are filled in with str.format before the request is
# sent. The template ends with "- Diversity:" so the model is prompted to
# reply with the score only. The text is part of the evaluation protocol;
# keep it byte-for-byte stable.
perspective_diversity_scoring_prompt = '''
[INSTRUCTION]
Below is an original opinion and three counterarguments written in response.
Evaluate the overall diversity among the counterarguments using the criterion below.
Be as objective and concise as possible.
Provide only a score from 1 (very low diversity) to 5 (very high diversity).

[Original Opinion]
{original_opinion}

[Counterargument 1]
{counterargument_1}

[Counterargument 2]
{counterargument_2}

[Counterargument 3]
{counterargument_3}

[Evaluation Criterion]
Diversity: Assess whether the three counterarguments approach the original opinion from clearly different perspectives, drawing on distinct social identities, belief systems, or lived experiences. High scores should be given when each response plausibly reflects the worldview of a different kind of individual. Low scores indicate surface-level variation or repetition of the same underlying reasoning.

Evaluation Form (scores ONLY):

- Diversity:
'''

def majority_vote_with_max_tiebreak(scores):
    """Return the most frequent score, preferring the largest one on a tie.

    Args:
        scores: Iterable of hashable, mutually comparable scores.

    Returns:
        The score with the highest frequency; when several scores share that
        frequency, the largest of them. ``None`` when ``scores`` is empty
        (previously this raised ``ValueError`` from ``max()``).
    """
    counter = Counter(scores)
    # Test the Counter rather than `scores` so one-shot iterables work too.
    if not counter:
        return None
    top_frequency = max(counter.values())
    return max(score for score, freq in counter.items() if freq == top_frequency)

# Number of counterarguments the scoring prompt has slots for.
_NUM_COUNTERARGUMENTS = 3
# Number of judgements requested per post (input to the majority vote).
_NUM_SAMPLES = 5
# Attempts allowed for one judgement before asking the operator to check.
_MAX_RETRIES = 5
# Pause between failed attempts, in seconds.
_RETRY_DELAY_SECONDS = 2


def _load_original_opinions(json_path):
    """Build a ``post_id -> original opinion text`` map from the evaluation JSON.

    Only items with exactly three ``delta_coms`` entries are kept. The text is
    the space-joined ``sub_title`` followed by the space-joined ``sub_text``
    (presumably lists of strings -- confirm against the data file).
    """
    # A context manager closes the handle; the bare open() used before leaked it.
    with open(json_path, 'r', encoding='utf-8') as f:
        items = json.load(f)
    return {
        item['post_id']: ' '.join(item['sub_title']) + ' ' + ' '.join(item['sub_text'])
        for item in items
        if len(item['delta_coms']) == _NUM_COUNTERARGUMENTS
    }


def _load_result_frames(directories):
    """Read every ``*.csv`` file in ``directories`` into ``{model_name: DataFrame}``.

    The model name is the file name without its extension.
    """
    frames = {}
    for directory in directories:
        for file_name in os.listdir(directory):
            path = os.path.join(directory, file_name)
            if file_name.endswith(".csv") and os.path.isfile(path):
                # splitext on the file name avoids depending on '/' separators
                # or on the extension being exactly four characters long.
                frames[os.path.splitext(file_name)[0]] = pd.read_csv(path)
    return frames


def _keep_posts_with_enough_counters(df):
    """Drop rows lacking ``post_id``/``gen_counter`` and posts with fewer than three rows."""
    df = df.dropna(subset=['post_id', 'gen_counter'])
    post_counts = df['post_id'].value_counts()
    valid_posts = post_counts[post_counts >= _NUM_COUNTERARGUMENTS].index
    return df[df['post_id'].isin(valid_posts)]


def _parse_score(content):
    """Return the integer in ``content`` if it is only digits (ignoring outer whitespace), else ``None``."""
    text = content.strip()
    # isdecimal() guarantees int() succeeds; the strip keeps replies such as
    # "4\n" or " 4" from being discarded.
    return int(text) if text.isdecimal() else None


def _sample_scores(pid, original_opinion, counterarguments):
    """Ask the judge model ``_NUM_SAMPLES`` times for a diversity score.

    Args:
        pid: Post identifier, used only in log messages.
        original_opinion: Text of the original opinion.
        counterarguments: Generated counterarguments; the first three are used.

    Returns:
        ``(responses, scores)``: the raw reply strings, and the replies that
        parsed as integers. A sample whose retries are exhausted contributes
        to neither list.
    """
    messages = [
        {"role": "system", "content": perspective_diversity_system_prompt},
        {"role": "user", "content": perspective_diversity_scoring_prompt.format(
            original_opinion=original_opinion,
            counterargument_1=counterarguments[0],
            counterargument_2=counterarguments[1],
            counterargument_3=counterarguments[2],
        )},
    ]

    responses, scores = [], []
    # NOTE(review): every sample uses the same seed and prompt, so the samples
    # may be largely identical -- confirm this is intended for the vote.
    for _ in range(_NUM_SAMPLES):
        for attempt in range(1, _MAX_RETRIES + 1):
            try:
                response = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    temperature=0.2,
                    max_tokens=5,
                    seed=SEED,
                )
                content = response.choices[0].message.content
                if content is None:
                    # Treat an empty reply as a failed attempt explicitly
                    # (it used to fail indirectly via an AttributeError).
                    raise ValueError("response contained no message content")
            except Exception as e:
                # Broad on purpose: any failed request is retried.
                print(f"[Process-{pid}] Error: {e}")
                time.sleep(_RETRY_DELAY_SECONDS)
                if attempt == _MAX_RETRIES:
                    # Block so the operator notices; the sample is then skipped.
                    input(f"Check: [Process-{pid}] Failed after {_MAX_RETRIES} retries.")
            else:
                responses.append(content)
                score = _parse_score(content)
                if score is not None:
                    scores.append(score)
                break
    return responses, scores


def main():
    """Score the perspective diversity of each model's counterarguments.

    Reads the original opinions and every result CSV, keeps posts with at
    least three generated counterarguments, queries the judge model for each
    post, and appends one JSON line per post to a per-model ``.jsonl`` file
    in ``output_dir``.
    """
    # Fail fast: without a client every request below would only error out.
    # globals().get also covers the case where `client` was never bound.
    if globals().get("client") is None:
        sys.exit("OPENAI_API_KEY is not set; cannot run the evaluation.")

    sub_dict = _load_original_opinions('../../data/for_evaluate.json')
    df_dict = _load_result_frames(["../../results/"])

    print(df_dict.keys())
    print(len(df_dict.keys()))
    # Manual checkpoint: wait for the operator to confirm the model list.
    input('check!!')

    df_dict = {
        model_name: _keep_posts_with_enough_counters(df)
        for model_name, df in df_dict.items()
    }

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    for model_name, df in tqdm(df_dict.items()):
        print("Model:", model_name)
        output_file = os.path.join(output_dir, f"gpt_perspective_diversity_evaluation_{model_name}_{timestamp}.jsonl")

        for pid, group in tqdm(df.groupby('post_id'), desc=f"Processing groups for {model_name}"):
            # Raises KeyError when a result row's post_id is absent from the
            # opinions file (e.g. filtered out by the delta_coms check).
            original_opinion = sub_dict[pid]
            counterarguments = group['gen_counter'].tolist()

            responses, scores = _sample_scores(pid, original_opinion, counterarguments)

            row = {
                'post_id': pid,
                'gen_counter': counterarguments,
                'Diversity': {
                    'responses': responses,
                    'scores': scores,
                    # A majority is reported only when every sample parsed.
                    'majority': majority_vote_with_max_tiebreak(scores) if len(scores) == _NUM_SAMPLES else None,
                },
            }

            # Append row by row so finished posts survive an interrupted run.
            with open(output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(row, ensure_ascii=False))
                f.write('\n')


if __name__ == "__main__":
    main()