# -*- coding: utf-8 -*-
"""
Evaluates text generated by a language model using GPT-4 for scoring.

This script reads a JSONL file containing generated text, sends each entry to the
OpenAI API for evaluation based on a predefined prompt, and saves the results
(including the scores) to an output JSONL file.

It uses multiprocessing to handle API requests in parallel for faster processing
and includes signal handling for graceful shutdown.

Key Features:
- Reads input from a JSONL file.
- Evaluates text using a specified GPT model (e.g., gpt-4-turbo).
- Supports different evaluation criteria via selectable prompts ('formality' or 'knowledge').
- Uses multiprocessing to speed up evaluation.
- Resumes progress by skipping already evaluated items.
- Securely loads API credentials from environment variables.

Setup:
1.  Install required libraries: pip install openai
2.  Set environment variables:
    export OPENAI_API_KEY="your_api_key_here"
    export OPENAI_BASE_URL="your_base_url_here"  # optional, for custom endpoints

Usage:
    python evaluate_with_gpt.py --input_file ./merged_output.jsonl --output_file ./evaluation_results.jsonl --prompt_type formality
"""

import json
import os
import argparse
import signal
import sys
from openai import OpenAI
from multiprocessing import Process

# --- Argument Parsing ---
# Parsed at import time; `args` is read by ask_gpt() and the __main__ block.
parser = argparse.ArgumentParser(
    description="Use GPT-4 to evaluate generated text from a JSONL file.",
)
parser.add_argument(
    '--input_file',
    type=str,
    default='./merged_output.jsonl',
    help='Path to the input JSONL file (generated text).',
)
parser.add_argument(
    '--output_file',
    type=str,
    default='./evaluation_results.jsonl',
    help='Path to the output JSONL file (with scores).',
)
parser.add_argument(
    '--prompt_type',
    type=str,
    default='formality',
    choices=['formality', 'knowledge'],
    help='The type of evaluation prompt to use.',
)
parser.add_argument(
    '--model',
    type=str,
    default='gpt-4.1',
    help='The GPT model to use for evaluation.',
)
parser.add_argument(
    '--num_processes',
    type=int,
    default=10,
    help='Number of parallel processes to run.',
)
args = parser.parse_args()
# --------------------

# --- Globals ---
# Worker Process objects created by the __main__ block; walked by signal_handler().
processes = []
# OpenAI client instance; populated by setup_api_client().
client = None
# Prompt template in use; the __main__ block picks it from --prompt_type.
EVALUATION_PROMPT = ""
# ---------------

# --- Evaluation Prompts ---
# Prompt template used when --prompt_type is 'knowledge'.
# It asks for four 1-10 scores (relevance, fluency, scientific_accuracy,
# knowledge_difficulty) returned as a JSON object.
# The template is rendered with str.format(text=...): `{text}` is the
# placeholder for the text under evaluation, and the doubled braces `{{` / `}}`
# render as literal `{` / `}` in the final prompt.
# NOTE(review): the fence markers around the JSON example are middle dots
# (···) rather than backticks - presumably intentional; confirm before changing,
# since this text is sent to the model verbatim.
KNOWLEDGE_DIFFICULTY_PROMPT ='''
You are an expert evaluator for natural language generation quality. 
Evaluate the following text based on four dimensions. Focus especially on knowledge difficulty.

1. Relevance: Does the response stay on topic and directly address the question or topic without irrelevant content?
2. Fluency: Is the language clear, coherent, and easy to read, with proper grammar and logical flow?
3. Scientific Accuracy: Are the knowledge points and explanations factually correct and reliable based on authoritative sources?
4. Knowledge Difficulty: Assess whether the knowledge level matches the expected difficulty. 
   - Higher scores: Highly technical or research-level content suitable for graduate-level understanding.
   - Lower scores: Simple, introductory, or popular-science level explanations.
Give each dimension a score from 1 (poor) to 10 (excellent).
Output your score in the following JSON format, without additional explanation:
···json
{{
  "relevance": [score],
  "fluency": [score],
  "scientific_accuracy": [score],
  "knowledge_difficulty": [score]
}}
···
Evaluate the following text:
{text}
Output:

'''

# Prompt template used when --prompt_type is 'formality' (the default).
# It asks for four 1-10 scores (relevance, fluency, tone_appropriateness,
# formality) returned as a JSON object.
# The template is rendered with str.format(text=...): `{text}` is the
# placeholder for the text under evaluation, and the doubled braces `{{` / `}}`
# render as literal `{` / `}` in the final prompt.
# NOTE(review): the fence markers around the JSON example are middle dots
# (···) rather than backticks - presumably intentional; confirm before changing,
# since this text is sent to the model verbatim.
FORMALITY_PROMPT ='''
You are an expert evaluator for natural language generation quality.
Evaluate the following text based on four dimensions. Focus especially on formality.

1. Relevance: Does the response stay on topic and directly address the question or topic without irrelevant content?
2. Fluency: Is the language clear, coherent, and easy to read, with proper grammar and logical flow?
3. Tone Appropriateness: Is the tone of the response (e.g., professional, casual, empathetic, neutral) suitable for the prompt's context and implied audience?
4. Formality: Assess the text's level of formality based on word choice, sentence structure, and adherence to stylistic conventions.(Assign a score of 5 for standard model output.)
  - Higher scores: Highly formal language, sophisticated vocabulary, complex sentence structures, and an objective tone. Avoids slang, contractions, and colloquialisms.
  - Lower scores: Highly informal or conversational language. May use slang, contractions, simple sentence structures, and a personal or subjective tone.
Give each dimension a score from 1 (poor) to 10 (excellent).

Output your score in the following JSON format, without additional explanation:
···json
{{
  "relevance": [score],
  "fluency": [score],
  "tone_appropriateness": [score],
  "formality": [score]
}}
···
Evaluate the following text:
{text}
Output:

'''
# --------------------------

def signal_handler(signum, frame):
    """Terminate every live worker process, then exit the interpreter.

    Registered for SIGINT and SIGTERM by the __main__ block, and also called
    directly from its error path.

    Args:
        signum: Number of the received signal (only echoed in the log line).
        frame: Stack frame supplied by the signal machinery; unused.

    Raises:
        SystemExit: Always, with exit status 0.
    """
    # Function-scope import: only `Process` is imported at module level.
    from multiprocessing import parent_process

    # With the fork start method, workers inherit this handler together with a
    # copy of `processes`. Process.is_alive() may only be called from the
    # process that created the Process object (it asserts otherwise), so a
    # worker must not walk the list - it simply exits.
    if parent_process() is not None:
        sys.exit(0)

    print(f"\nReceived signal {signum}, terminating all processes...")
    running = [p for p in processes if p.is_alive()]
    # Signal every worker before waiting on any of them, so one slow worker
    # does not delay the shutdown request reaching the others.
    for p in running:
        p.terminate()
    for p in running:
        p.join()  # Wait for termination to complete
    sys.exit(0)

def setup_api_client():
    """Build the module-level OpenAI client from environment variables.

    Reads OPENAI_API_KEY (required) and OPENAI_BASE_URL (optional, defaults
    to the public OpenAI endpoint) and stores the resulting client in the
    global `client`.

    Raises:
        ValueError: If OPENAI_API_KEY is unset or empty.
    """
    global client

    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        raise ValueError("Error: Environment variable OPENAI_API_KEY is not set.")

    endpoint = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
    client = OpenAI(api_key=key, base_url=endpoint)

def load_data(input_file):
    """Read a JSONL file and return its records as a list.

    Blank (whitespace-only) lines are skipped; every other line is decoded
    with json.loads. The number of loaded records is printed.
    """
    with open(input_file, "r", encoding="utf-8") as handle:
        records = [json.loads(raw) for raw in handle if raw.strip()]
    print(f"Loaded {len(records)} items from '{input_file}'.")
    return records

def load_completed_indices(output_file):
    """Collect the indices of items already present in the output file.

    Used to resume an interrupted run: each non-blank line of `output_file`
    is decoded as JSON and its "index" value is recorded as a string.

    Args:
        output_file: Path to the results JSONL file; it may not exist yet.

    Returns:
        A set of stringified indices; empty if the file does not exist.
        Lines that cannot be used (invalid JSON, JSON that is not an object,
        or an object without "index") are reported with a warning and skipped.
    """
    completed_indices = set()
    if not os.path.exists(output_file):
        return completed_indices

    with open(output_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                result = json.loads(line)
                completed_indices.add(str(result["index"]))
            # TypeError: the line is valid JSON but not an object (a list,
            # string, number or null), so it cannot be subscripted by "index".
            except (json.JSONDecodeError, KeyError, TypeError):
                print(f"Warning: Could not parse line in output file: {line.strip()}")
    return completed_indices

def ask_gpt(text_data):
    """Ask the evaluation model to score one item and return the parsed scores.

    Args:
        text_data: Input record; its "generated_text" value is inserted into
            the active EVALUATION_PROMPT template.

    Returns:
        The object decoded from the outermost `{...}` span of the model's
        reply, or None when the item has no "generated_text", the API request
        fails, or the reply contains no parseable JSON object. Every failure
        is reported on stdout.
    """
    generated_text = text_data.get("generated_text")
    if generated_text is None:
        # Previously a KeyError raised here escaped and killed the whole worker.
        print(f"Skipping item {text_data.get('index')}: missing 'generated_text'.")
        return None

    messages = [
        {"role": "user", "content": EVALUATION_PROMPT.format(text=generated_text)}
    ]
    try:
        completion = client.chat.completions.create(
            model=args.model,
            messages=messages,
            temperature=0.2,  # Lower temperature for more consistent evaluation
        )
        content = completion.choices[0].message.content
    except Exception as e:
        print(f"API request failed: {e}")
        return None

    # The message content can be None; treat it as an empty reply so it is
    # reported as a parsing problem rather than as an API failure.
    answer = (content or "").strip()

    # Take the span from the first '{' to the last '}' so that any fence or
    # commentary around the JSON object is ignored.
    start = answer.find('{')
    end = answer.rfind('}')
    if start == -1 or end < start:
        print(f"JSON parsing error for response: {answer}")
        return None
    try:
        return json.loads(answer[start:end + 1])
    except json.JSONDecodeError:
        print(f"JSON parsing error for response: {answer}")
        return None

def save_result(output_file, result):
    """Append one result to the output JSONL file as a single line.

    Several worker processes append to the same file concurrently. The record
    is therefore serialized up front and handed to the OS in one unbuffered
    write on an append-mode file, instead of going through a buffered text
    stream that may flush a large record in several pieces and let lines from
    different workers interleave.

    Args:
        output_file: Path of the JSONL file to append to (created if missing).
        result: JSON-serializable object to store.
    """
    # Same text json.dump() would produce, followed by the line terminator.
    payload = (json.dumps(result, ensure_ascii=False) + "\n").encode("utf-8")
    # Binary mode with buffering=0 gives a raw file object whose write()
    # forwards the whole payload in a single call.
    # NOTE(review): binary mode always writes "\n"; text mode on Windows wrote "\r\n".
    with open(output_file, "ab", buffering=0) as f:
        f.write(payload)

def process_item(item, output_file_path):
    """Evaluate one item with GPT and append the scored record to the output.

    Nothing is written when ask_gpt() returns None.
    """
    scores = ask_gpt(item)
    if scores is None:
        return

    # Fixed leading keys so every output line starts with the same fields.
    record = {
        "index": item.get("index"),
        "prompt": item.get("prompt"),
        "generated_text": item.get("generated_text"),
        "scores": scores,
    }
    # Carry over the remaining input fields without overwriting the keys above.
    for field, field_value in item.items():
        record.setdefault(field, field_value)

    save_result(output_file_path, record)

def worker_process(tasks, output_file_path, process_id):
    """Entry point of one worker process: evaluate and save each assigned item.

    Args:
        tasks: List of input records assigned to this worker.
        output_file_path: JSONL file that results are appended to.
        process_id: Zero-based worker number, used only in log messages.
    """
    global EVALUATION_PROMPT

    # Under the fork start method the worker inherits the parent's
    # signal_handler, which walks the parent's `processes` list and is not
    # valid inside a worker. Restore the default dispositions so the parent's
    # terminate() (SIGTERM) or a Ctrl-C simply stops this worker.
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    signal.signal(signal.SIGTERM, signal.SIG_DFL)

    # Under the spawn start method the module is re-imported in the worker and
    # the __main__ block does not run, leaving `client` as None and
    # EVALUATION_PROMPT empty. Rebuild both here so requests do not fail.
    if client is None:
        setup_api_client()
    if not EVALUATION_PROMPT:
        if args.prompt_type == 'formality':
            EVALUATION_PROMPT = FORMALITY_PROMPT
        else:
            EVALUATION_PROMPT = KNOWLEDGE_DIFFICULTY_PROMPT

    print(f"Process {process_id + 1} started, handling {len(tasks)} tasks.")
    for item in tasks:
        try:
            process_item(item, output_file_path)
        except Exception as e:
            # Keep going: one bad item must not discard the rest of this
            # worker's chunk. The skipped item is retried on the next run
            # because it never reaches the output file.
            print(f"Process {process_id + 1}: failed to process item {item.get('index')}: {e}")
    print(f"Process {process_id + 1} finished all tasks.")

if __name__ == "__main__":
    # --- Setup ---
    setup_api_client()

    if args.prompt_type == 'formality':
        EVALUATION_PROMPT = FORMALITY_PROMPT
    else:
        EVALUATION_PROMPT = KNOWLEDGE_DIFFICULTY_PROMPT

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    # -------------

    data = load_data(args.input_file)
    completed_indices = load_completed_indices(args.output_file)

    # Resume support: only items whose index is not already in the output file.
    tasks_to_run = [item for item in data if str(item.get("index")) not in completed_indices]

    if not tasks_to_run:
        print("🎉 All tasks have already been completed.")
        sys.exit(0)

    print(f"Found {len(tasks_to_run)} new tasks to process.")

    # Clamp to at least one worker: --num_processes 0 (or a negative value)
    # would otherwise make the chunk-size division fail.
    num_processes = max(1, min(args.num_processes, len(tasks_to_run)))
    chunk_size = (len(tasks_to_run) + num_processes - 1) // num_processes  # Ensure all tasks are assigned
    # Contiguous, non-empty slices; rounding up can yield fewer chunks than
    # num_processes, so the count printed below is the number actually started.
    chunks = [
        tasks_to_run[start:start + chunk_size]
        for start in range(0, len(tasks_to_run), chunk_size)
    ]

    print(f"Starting {len(chunks)} worker processes...")
    try:
        for i, process_tasks in enumerate(chunks):
            p = Process(target=worker_process, args=(process_tasks, args.output_file, i))
            processes.append(p)
            p.start()

        print("All processes started. Waiting for completion...")
        for p in processes:
            p.join()

    except (KeyboardInterrupt, SystemExit):
        # signal_handler() ends with sys.exit(), which surfaces here as
        # SystemExit. Re-raise so an interrupted run is not reported as complete.
        print("\nMain process caught interrupt, cleaning up...")
        raise

    except Exception as e:
        print(f"An unexpected error occurred in the main process: {e}")
        # Terminate any remaining processes
        signal_handler(signal.SIGTERM, None)

    # A non-zero exit code means the worker crashed or was killed.
    failed = [p for p in processes if p.exitcode]
    if failed:
        print(f"\nWarning: {len(failed)} worker process(es) exited abnormally; re-run to resume the remaining items.")
        sys.exit(1)

    print("\n🎉 All tasks complete.")