import os
from openai import OpenAI
# Read the API key from the environment; an unset variable yields "".
api_key = os.getenv("OPENAI_API_KEY", "")

# Always bind `client` (None when no key is configured) so that code using it
# can test for the missing client instead of failing with a NameError.
client = OpenAI(api_key=api_key) if api_key else None
if client is None:
    print("Warning: OPENAI_API_KEY is not set")
import re
import json
import time
import pandas as pd
import sys
from tqdm import tqdm
from datetime import datetime
from collections import Counter

import random
import numpy as np
import torch

# One shared seed for every random number generator this script imports.
SEED = 42
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)
del _seed_rng  # keep the module namespace free of the loop variable

# Directory that receives the JSONL result files; created on first use.
output_dir = "YOUR_DIRECTORY/"
os.makedirs(output_dir, exist_ok=True)

# System message sent with every scoring request: it frames the model as an
# evaluator of stance opposition that answers with a score between 0 and 100.
system_prompt = "You are an objective evaluator of stance opposition between an original opinion and a counterargument. Your task is to assess how strongly the counterargument challenges or contradicts the original opinion. Give a score between 0 and 100, where higher scores indicate greater opposition in stance."

# User-message template, filled in with
# str.format(original_opinion=..., counterargument=...).
# It asks for the bare score only (0 = full support, 100 = full opposition),
# with no explanation.
scoring_prompt = '''
[INSTRUCTION]
Below is the Original Opinion and Counterargument.

Please score the stance relationship between their statements on a continuous scale from 0 to 100:
- A score of 0 means "Counterargument totally supports Original Opinion"
- A score of 100 means "Counterargument completely opposes or contradicts Original Opinion"

Be as objective as possible. Do not explain your reasoning—just output the score.

[Original Opinion]
{original_opinion}

[Counterargument]
{counterargument}

Score (0–100):
'''

def majority_vote_with_max_tiebreak(scores):
    """Return the most frequent score, preferring the largest one on a tie.

    Args:
        scores: Iterable of hashable, mutually comparable scores.

    Returns:
        The score that occurs most often. When several scores share the
        highest frequency, the largest of them is returned. ``None`` is
        returned for an empty input instead of raising ``ValueError``.
    """
    counter = Counter(scores)
    if not counter:
        # Nothing to vote on; max() below would raise on an empty sequence.
        return None
    top_freq = max(counter.values())
    return max(score for score, freq in counter.items() if freq == top_freq)

def _parse_score(raw):
    """Parse a model reply into an integer score.

    Surrounding whitespace is ignored, so a reply with a leading space or a
    trailing newline still counts as a valid score.

    Args:
        raw: The text content of the model reply.

    Returns:
        The reply as an ``int`` when it consists solely of decimal digits,
        otherwise ``None``.
    """
    text = raw.strip()
    # isdecimal() accepts exactly the digit characters int() can convert;
    # isdigit() also accepts characters such as superscripts that int() rejects.
    return int(text) if text.isdecimal() else None


def _request_score(pid, original_opinion, counterargument, model_name, max_retries=100):
    """Request one stance score from the chat model, retrying on any failure.

    Args:
        pid: Post id, used only in the progress/error messages.
        original_opinion: Text inserted into the prompt as the original opinion.
        counterargument: Text inserted into the prompt as the counterargument.
        model_name: Name of the chat model to query.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        A ``(raw_reply, parsed_score)`` tuple, where ``parsed_score`` is
        ``None`` if the reply is not a bare integer; or ``None`` when every
        attempt failed (after pausing for operator confirmation).
    """
    user_prompt = scoring_prompt.format(
        original_opinion=original_opinion,
        counterargument=counterargument,
    )
    for _ in range(max_retries):
        try:
            # The prompt asks for a bare number, so the reply is capped at a
            # few tokens.
            # NOTE(review): the same seed is sent with every request, so the
            # repeated samples for one counterargument may be near-identical;
            # confirm that this is intended for the majority vote.
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.2,
                max_tokens=5,
                seed=SEED,
            )
            raw = response.choices[0].message.content
            return raw, _parse_score(raw)
        except Exception as e:
            # Best-effort boundary around the remote call: report, wait, retry.
            print(f"[Process-{pid}] Error: {e}")
            time.sleep(1)

    # All attempts failed: block until the operator acknowledges, then move on.
    input(f"Check: [Process-{pid}] Failed after {max_retries} retries.")
    return None


if __name__ == "__main__":
    # Fail fast: without a key every request would fail and be retried
    # max_retries times before the problem surfaced.
    if not api_key:
        sys.exit("OPENAI_API_KEY is not set; cannot run the evaluation.")

    # Number of scores requested per counterargument; a majority is only
    # reported when all of them parsed.
    samples_per_counter = 5

    # Map post_id -> "title text" for posts that have exactly three delta
    # comments. Assumes 'sub_title' and 'sub_text' are lists of strings.
    with open('../../data/for_evaluate.json', 'r', encoding='utf-8') as f:
        eval_items = json.load(f)
    sub_dict = {
        item['post_id']: ' '.join(item['sub_title']) + ' ' + ' '.join(item['sub_text'])
        for item in eval_items
        if len(item['delta_coms']) == 3
    }

    # Load every CSV in the result directories, keyed by file name without ".csv".
    df_dict = {}

    directory_path_list = []
    directory_path_list.append("../../results/")

    for directory_path in directory_path_list:
        for filename in os.listdir(directory_path):
            path = os.path.join(directory_path, filename)
            if not (filename.endswith(".csv") and os.path.isfile(path)):
                continue
            model_name = filename[:-len(".csv")]
            df_dict[model_name] = pd.read_csv(path)

    gpt_model_name = "gpt-4o-mini"

    # Drop incomplete rows, then keep only posts with enough counterarguments
    # (at least 2 for 'cmv' result files, at least 3 otherwise).
    for key in df_dict:
        df = df_dict[key].dropna(subset=['post_id', 'gen_counter'])
        post_counts = df['post_id'].value_counts()
        min_count = 2 if 'cmv' in key else 3
        valid_posts = post_counts[post_counts >= min_count].index
        df_dict[key] = df[df['post_id'].isin(valid_posts)]

    # Manual checkpoint: show what will be evaluated and wait for Enter.
    print(df_dict.keys())
    print(len(df_dict))
    input('check!!')

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    for model_name, df in tqdm(df_dict.items()):
        print("Model:", model_name)
        output_file = os.path.join(output_dir, f"gpt_op_stance_evaluation_{model_name}_{timestamp}.jsonl")

        for pid, group in tqdm(df.groupby('post_id'), desc=f"Processing groups for {model_name}"):
            if 'cmv' in model_name:
                original_opinion = group['opinion'].iloc[0]
            else:
                # NOTE(review): raises KeyError if this post is not in
                # for_evaluate.json with exactly three delta comments.
                original_opinion = sub_dict[pid]

            rows = []
            for counterargument in group['gen_counter'].tolist():
                responses, scores = [], []
                for _ in range(samples_per_counter):
                    outcome = _request_score(pid, original_opinion, counterargument, gpt_model_name)
                    if outcome is None:
                        continue
                    raw, score = outcome
                    responses.append(raw)
                    if score is not None:
                        scores.append(score)

                rows.append({
                    'post_id': pid,
                    'gen_counter': counterargument,
                    'op_stance_score': {
                        'responses': responses,
                        'scores': scores,
                        'majority': (
                            majority_vote_with_max_tiebreak(scores)
                            if len(scores) == samples_per_counter
                            else None
                        ),
                    },
                })

            # Append after every post so finished work survives an interrupted run.
            with open(output_file, 'a', encoding='utf-8') as f:
                for row in rows:
                    f.write(json.dumps(row, ensure_ascii=False))
                    f.write('\n')

        time.sleep(5)