import os
from openai import OpenAI
# The OpenAI key is taken from the environment; an unset variable reads as "".
api_key = os.getenv("OPENAI_API_KEY", "")

# Always bind `client`, even without a key, so that later code can test for a
# missing client explicitly instead of hitting a NameError on first use.
if api_key:
    client = OpenAI(api_key=api_key)
else:
    client = None
    print("Warning: OPENAI_API_KEY is not set")
import re
import json
import time
import pandas as pd
import sys
from tqdm import tqdm
from datetime import datetime
from collections import Counter
import argparse
import random
import numpy as np
import torch

# One shared seed is applied to Python's `random`, NumPy, and torch
# (the CPU generator plus every CUDA device).
SEED = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Folder that receives the result files; it is created at import time if it
# does not exist yet.
output_dir = "YOUR_DIRECTORY/"
os.makedirs(output_dir, exist_ok=True)

# System message sent with every scoring request to the chat model.
system_prompt = "You are an evaluator of persuasiveness. Given an original opinion and a counterargument, rate how persuasive the counterargument would likely feel to the person who wrote the original opinion. Output a score from 1 to 10, where higher scores mean more persuasive."

# User-message template selected when --title_mode is 'title_only'.
# `{original_opinion}` and `{counterargument}` are filled in with str.format,
# so any literal brace added to this text would have to be doubled.
title_only_scoring_prompt = '''
[INSTRUCTION]
You are given an original opinion and a counterargument.
Rate how persuasive the counterargument is on its own. Focus on how effectively it introduces a meaningful perspective.

Strong counterarguments:
- Present a clear and coherent point of view
- Offer insight from a distinctive identity or experience
- Reframe the issue in a thoughtful and authentic way
- Feel more compelling when written in a voice that adds depth and contrast

Output only a single integer between 1 and 10, based on the overall persuasiveness of the counterargument.

Score from 1–10:
1–2: Very weak, unlikely to affect the author’s view.
3–4: Limited impact, may raise points but not convincing.
5–6: Moderate, somewhat persuasive and could prompt reflection.
7–8: Strong, likely to make the author reconsider aspects of their view.
9–10: Very strong, realistically could change the author’s stance.

[Original Opinion]
{original_opinion}

[Counterargument]
{counterargument}

Persuasiveness Score (1–10):
'''

# User-message template selected when --title_mode is 'title_and_main'
# (the default). `{original_opinion}` and `{counterargument}` are filled in
# with str.format, so any literal brace added here would have to be doubled.
title_and_main_scoring_prompt = '''
[INSTRUCTION]
You are given an original opinion and a counterargument.
Rate how persuasive the counterargument would likely be to the person who wrote the original opinion — that is, how much it could realistically make them reflect, reconsider, or change their view.

A persuasive counterargument does more than simply oppose the claim. It introduces a way of seeing the issue that the author may not have fully considered, often by drawing from a distinct perspective, background, or line of reasoning. The strength of persuasiveness comes from how effectively the counterargument broadens the author’s frame of reference and shows that the issue can be understood differently.

Give a single score from 1 to 10, where higher numbers mean the counterargument is more likely to prompt genuine reflection or a shift in the author’s stance.

Score from 1–10:
1–2: Very weak, unlikely to affect the author’s view.
3–4: Limited impact, may raise points but not convincing.
5–6: Moderate, somewhat persuasive and could prompt reflection.
7–8: Strong, likely to make the author reconsider aspects of their view.
9–10: Very strong, realistically could change the author’s stance.

Output only a single integer between 1 and 10.

[Original Opinion]
{original_opinion}

[Counterargument]
{counterargument}

Persuasiveness Score (1–10):
'''

def majority_vote_with_max_tiebreak(scores):
    """Return the most frequent score, preferring the largest one on ties.

    Args:
        scores: Iterable of hashable, mutually comparable values.

    Returns:
        The value that occurs most often in ``scores``. When several values
        share the highest frequency, the largest of them is returned.
        ``None`` is returned for an empty input instead of letting ``max()``
        fail on an empty sequence.
    """
    counter = Counter(scores)
    if not counter:
        return None
    max_freq = max(counter.values())
    return max(score for score, freq in counter.items() if freq == max_freq)

# A reply counts as a score only if, once stripped, it is nothing but a one-
# or two-digit number.
SCORE_PATTERN = re.compile(r"\d{1,2}")


def parse_score(text):
    """Turn a raw model reply into an integer score, or None if it is invalid.

    Args:
        text: The reply content returned by the chat model (may be None).

    Returns:
        The integer value when the reply, ignoring surrounding whitespace,
        is a bare integer between 1 and 10 inclusive; otherwise None.
    """
    if not isinstance(text, str):
        return None
    cleaned = text.strip()
    if SCORE_PATTERN.fullmatch(cleaned) is None:
        return None
    value = int(cleaned)
    return value if 1 <= value <= 10 else None


def collect_score_samples(api_client, model, user_prompt, pid,
                          num_samples=5, max_retries=100):
    """Query the model several times for one prompt and aggregate the scores.

    Each sample is retried (with a one second pause) whenever the request or
    the extraction of the reply raises; after ``max_retries`` failed attempts
    the script pauses on ``input()`` and then moves on to the next sample.

    Args:
        api_client: OpenAI client used to issue the chat completion requests.
        model: Name of the chat model to query.
        user_prompt: Fully formatted user message.
        pid: Post id, used only in the progress/error messages.
        num_samples: Number of replies to request.
        max_retries: Maximum attempts per sample.

    Returns:
        A dict with the raw replies ('responses'), the replies that parsed
        into valid scores ('scores'), and 'majority' -- the majority vote
        over the scores, or None unless every sample yielded a valid score.
    """
    responses, scores = [], []
    for _ in range(num_samples):
        for attempt in range(1, max_retries + 1):
            try:
                response = api_client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.2,
                    max_tokens=5,
                    seed=SEED,
                )
                reply = response.choices[0].message.content
            except Exception as e:
                print(f"[Process-{pid}] Error: {e}")
                time.sleep(1)
                if attempt == max_retries:
                    input(f"Check: [Process-{pid}] Failed after {max_retries} retries.")
            else:
                # The raw reply is always recorded; only valid scores vote.
                responses.append(reply)
                score = parse_score(reply)
                if score is not None:
                    scores.append(score)
                break

    return {
        'responses': responses,
        'scores': scores,
        'majority': majority_vote_with_max_tiebreak(scores) if len(scores) == num_samples else None
    }


# Script entry point: score every generated counterargument found in the
# result CSVs and append one JSON line per scored item to a per-model file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--title_mode', choices=['title_only', 'title_and_main'], default='title_and_main',
                    help="Choose whether to use only titles or titles with main content")
    parser.add_argument('--setting', choices=['maximum', 'all', 'combined'], default='all')

    args = parser.parse_args()

    # Fail fast when no client could be created; otherwise every request
    # would fail and be retried max_retries times.
    api_client = globals().get("client")
    if api_client is None:
        sys.exit("OPENAI_API_KEY is not set; cannot query the scoring model.")

    # Load the evaluation posts once, keeping only those with exactly three
    # entries in 'delta_coms'.
    with open('../../data/for_evaluate.json', 'r', encoding='utf-8') as f:
        eval_items = [item for item in json.load(f) if len(item['delta_coms']) == 3]

    # Map post_id -> original opinion text, and pick the matching template.
    if args.title_mode == 'title_only':
        sub_dict = {item['post_id']: ' '.join(item['sub_title']) for item in eval_items}
        scoring_prompt = title_only_scoring_prompt
    else:
        sub_dict = {
            item['post_id']: ' '.join(item['sub_title']) + ' ' + ' '.join(item['sub_text'])
            for item in eval_items
        }
        scoring_prompt = title_and_main_scoring_prompt
    print(f"title_mode: {args.title_mode}")

    df_dict = {}

    directory_path_list = []
    directory_path_list.append("../../results/")

    gpt_model_name = "gpt-4o-mini"

    # One DataFrame per CSV file, keyed by the file name without extension.
    for directory_path in directory_path_list:
        for file_name in sorted(os.listdir(directory_path)):
            path = os.path.join(directory_path, file_name)
            if file_name.endswith(".csv") and os.path.isfile(path):
                df_dict[os.path.splitext(file_name)[0]] = pd.read_csv(path)

    print(df_dict.keys())
    print(len(df_dict.keys()))
    input('check!!')

    # Keep only posts that still have exactly three rows once rows lacking a
    # post_id or a generated counterargument are dropped.
    for key in list(df_dict):
        df = df_dict[key].dropna(subset=['post_id', 'gen_counter'])
        post_counts = df['post_id'].value_counts()
        valid_posts = post_counts[post_counts == 3].index
        df_dict[key] = df[df['post_id'].isin(valid_posts)]

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    for model_name, df in tqdm(df_dict.items()):
        print("Model:", model_name)
        output_file = os.path.join(output_dir, f"gpt_persuasiveness_evaluation_{args.title_mode}_{model_name}_{timestamp}.jsonl")

        for pid, group in tqdm(df.groupby('post_id'), desc=f"Processing groups for {model_name}"):
            original_opinion = sub_dict[pid]
            column = 'combined_counter' if args.setting == 'combined' else 'gen_counter'
            counterarguments = group[column].tolist()

            # 'all' scores the post's counterarguments as one joined text;
            # 'maximum' and 'combined' score each counterargument separately.
            if args.setting == 'all':
                targets = ['\n'.join(counterarguments)]
            else:
                targets = counterarguments

            rows = []
            for counterargument in targets:
                user_prompt = scoring_prompt.format(
                    original_opinion=original_opinion,
                    counterargument=counterargument
                )
                result = collect_score_samples(api_client, gpt_model_name, user_prompt, pid)
                rows.append({
                    'post_id': pid,
                    'gen_counter': counterargument,
                    'persuasiveness_score': result
                })

            # Append after every post so finished work survives an interruption.
            with open(output_file, 'a', encoding='utf-8') as f:
                for row in rows:
                    f.write(json.dumps(row, ensure_ascii=False))
                    f.write('\n')

        time.sleep(5)