import os
from openai import OpenAI
# Read the OpenAI key from the environment; an unset variable is treated the
# same as an empty one.
api_key = os.getenv("OPENAI_API_KEY", "")
# Always bind `client` so later references never fail with NameError; it
# stays None when no API key is configured.
client = None
if api_key:
    client = OpenAI(api_key=api_key)
else:
    print("Warning: OPENAI_API_KEY is not set")

import re
import json
import pandas as pd
import random
import numpy as np
import torch

# Seed Python's `random`, NumPy and PyTorch (including all CUDA devices) with
# the same value. SEED is also forwarded to the OpenAI requests further down.
SEED = 42
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)
del _seed_rng
import sys
import time
from tqdm import tqdm
from datetime import datetime
from collections import Counter

# Directory that receives the per-model JSONL result files.
# NOTE(review): "YOUR_DIRECTORY/" looks like a placeholder -- set it to a real
# path before running.
output_dir = "YOUR_DIRECTORY/"
# Runs at import time; succeeds silently if the directory already exists.
os.makedirs(output_dir, exist_ok=True)

# Evaluation criteria as (name, description) pairs. Each pair is scored in its
# own request, and the name becomes a key in the output record.
criteria_list = [
    (
        "Grammaticality",
        "Evaluate whether the text adheres to standard grammar conventions, including punctuation, sentence structure, and syntax.",
    ),
    (
        "Appropriateness",
        "Evaluate whether the language and tone are suitable for the context and proportional to the significance of the issue.",
    ),
    (
        "Relevance",
        "Evaluate how directly the counterargument engages with the original opinion and addresses its key points.",
    ),
    (
        "Clarity",
        "Evaluate whether the writing is clear, well-organized, and free from ambiguity or unnecessary complexity.",
    ),
]

# System message sent with every scoring request.
system_prompt = "You are an objective evaluator of counterargument writing quality."

# User-message template, filled with str.format using the fields
# `original_opinion`, `counterargument`, `criteria` and `criteria_desc`.
# The template ends with "- {criteria}:" so the model is expected to reply
# with the bare score. The text is sent to the model verbatim, so any edit
# to it changes the evaluation itself.
scoring_prompt = '''
[INSTRUCTION]
Below is an original opinion and a counterargument written in response.
Evaluate the counterargument based on the criterion below.
Be as objective as possible.
For each aspect, provide only score from 1 (worst) to 5 (best).

[Original Opinion]
{original_opinion}

[Counterargument]
{counterargument}

[Evaluation Criteria]
{criteria}: {criteria_desc}

Evaluation Form (scores ONLY):

- {criteria}:
'''

def majority_vote_with_max_tiebreak(scores):
    """Return the most frequent score, preferring the largest one on a tie.

    Raises:
        ValueError: if *scores* is empty.
    """
    tally = Counter(scores)
    # Rank by (frequency, score): the frequency decides first, and equal
    # frequencies fall back to the larger score value.
    winner, _ = max(tally.items(), key=lambda pair: (pair[1], pair[0]))
    return winner

import multiprocessing as mp

# A reply is accepted only if it contains exactly one digit and that digit is
# in 1-5. This tolerates surrounding whitespace or markup (" 4", "4\n",
# "**4**", "- Clarity: 4") while rejecting ambiguous or out-of-range replies
# ("4/5", "4.5", "10", "0").
_SCORE_PATTERN = re.compile(r"\D*([1-5])\D*")


def _parse_score(text):
    """Return the 1-5 integer score contained in *text*, or None if absent/ambiguous."""
    if not text:
        return None
    match = _SCORE_PATTERN.fullmatch(text)
    return int(match.group(1)) if match else None


def score_criterion(proc_id, q, original_opinion, counterargument, criteria, criteria_desc,
                    num_samples=5, max_retries=10):
    """Score one counterargument on one criterion and put the result on *q*.

    The model is queried ``num_samples`` times. Each query is retried up to
    ``max_retries`` times when the API call raises. Exactly one result dict
    is always put on the queue, even when every request fails, so the parent
    can rely on one message per worker.

    Args:
        proc_id: Identifier used only to tag log messages.
        q: Queue-like object; the result dict is delivered via ``q.put``.
        original_opinion: Text of the opinion being argued against.
        counterargument: Text of the counterargument to evaluate.
        criteria: Name of the criterion (also stored in the result).
        criteria_desc: Description of the criterion inserted in the prompt.
        num_samples: Number of model replies to collect.
        max_retries: Maximum attempts per reply before giving up on it.

    The result dict has the keys ``criteria``, ``responses`` (raw reply
    texts), ``scores`` (replies that parsed to an int in 1-5) and
    ``majority`` (majority vote, or None unless all ``num_samples`` replies
    produced a valid score).
    """
    # The prompt is identical for every sample, so build it once.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": scoring_prompt.format(
            original_opinion=original_opinion,
            counterargument=counterargument,
            criteria=criteria,
            criteria_desc=criteria_desc
        )}
    ]

    raw_responses, scores = [], []

    for _ in range(num_samples):
        for attempt in range(1, max_retries + 1):
            try:
                # NOTE(review): every sample reuses the same seed and prompt,
                # which presumably makes the samples (near-)identical --
                # confirm this is intended for the majority vote.
                response = client.chat.completions.create(model="gpt-4o-mini",
                    messages=messages,
                    temperature=0.2,
                    max_tokens=5,
                    seed=SEED
                )
                content = response.choices[0].message.content
            except Exception as e:
                # Broad on purpose: any API/transport failure is retried.
                print(f"[Process-{proc_id}] Error: {e}")
                if attempt == max_retries:
                    # Do not call input() here: this runs in a child process
                    # whose stdin is not usable, so it would raise EOFError
                    # and kill the worker before it reports a result.
                    print(f"Check: [Process-{proc_id}] Failed after {max_retries} retries.")
                else:
                    time.sleep(2)
                continue

            raw_responses.append(content)
            score = _parse_score(content)
            if score is not None:
                scores.append(score)
            break

    result = {
        'criteria': criteria,
        'responses': raw_responses,
        'scores': scores,
        # Only vote when every sample produced a usable score.
        'majority': majority_vote_with_max_tiebreak(scores) if len(scores) == num_samples else None
    }
    q.put(result)

if __name__ == "__main__":
    # Map post_id -> original opinion text (title followed by body), keeping
    # only posts that have exactly three delta comments.
    with open('../../data/for_evaluate.json', 'r', encoding='utf-8') as f:
        eval_items = json.load(f)
    sub_dict = {
        item['post_id']: ' '.join(item['sub_title']) + ' ' + ' '.join(item['sub_text'])
        for item in eval_items
        if len(item['delta_coms']) == 3
    }

    # model name (CSV file name without the extension) -> DataFrame of outputs
    df_dict = {}

    directory_path_list = []
    directory_path_list.append("../../results/")

    for directory_path in directory_path_list:
        for file_name in os.listdir(directory_path):
            path = os.path.join(directory_path, file_name)
            if not (file_name.endswith(".csv") and os.path.isfile(path)):
                continue
            model_name = file_name[:-len(".csv")]
            df_dict[model_name] = pd.read_csv(path)

    # Drop rows lacking a post id or generated text, then keep only posts
    # that still have at least three generations.
    for key in df_dict:
        df = df_dict[key].dropna(subset=['post_id', 'gen_counter'])
        post_counts = df['post_id'].value_counts()
        valid_posts = post_counts[post_counts >= 3].index
        df_dict[key] = df[df['post_id'].isin(valid_posts)]

    # Manual checkpoint: show what will be evaluated and wait for Enter.
    print(df_dict.keys())
    print(len(df_dict.keys()))
    input('check!!')

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    for model_name, df in tqdm(df_dict.items()):
        print("Model:", model_name)
        output_file = os.path.join(output_dir, f"gpt_quality_evaluation_{model_name}_{timestamp}.jsonl")

        df_grouped = df.groupby('post_id')

        for pid, group in tqdm(df_grouped, desc=f"Processing groups for {model_name}"):
            # Raises KeyError when the results reference a post that is not
            # in the evaluation set.
            original_opinion = sub_dict[pid]

            # idx is the position of the generation within its post.
            for idx, (_, gen_row) in enumerate(group.iterrows()):
                counterargument = gen_row['gen_counter']

                # One worker process per criterion; each reports through q.
                q = mp.Queue()
                processes = []

                for proc_id, (criteria, criteria_desc) in enumerate(criteria_list):
                    p = mp.Process(target=score_criterion, args=(proc_id, q, original_opinion, counterargument, criteria, criteria_desc))
                    p.start()
                    processes.append(p)

                # Joining before draining is only safe because each worker
                # puts a single small dict; large payloads would have to be
                # read from the queue before join() to avoid blocking.
                for p in processes:
                    p.join()

                result_dict = {}
                while not q.empty():
                    res = q.get()
                    result_dict[res['criteria']] = {
                        'responses': res['responses'],
                        'scores': res['scores'],
                        'majority': res['majority']
                    }

                # A worker that died never reports; make that visible instead
                # of silently writing an incomplete record.
                missing = [name for name, _ in criteria_list if name not in result_dict]
                if missing:
                    print(f"Warning: no result for {missing} (post_id={pid}, idx={idx})")

                record = {
                    'post_id': pid,
                    'idx': idx,
                    'gen_counter': counterargument,
                    **result_dict
                }

                # Append one JSON line per generation so partial progress
                # survives an interrupted run.
                with open(output_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(record, ensure_ascii=False))
                    f.write('\n')