
from __future__ import annotations

import argparse
import json
import os
import time
from typing import Any

import openai
from tqdm import tqdm

import time
import random
from functools import wraps
from openai import AuthenticationError, RateLimitError
import pandas as pd

# from safe_rlhf.utils import PromptManager

import re

from jinja2 import Environment, FileSystemLoader

class PromptManager:
    """Renders Jinja2 prompt templates and splits them into system/user parts."""

    def __init__(self, prompts_dir: str = "prompts/"):
        """Builds a template environment rooted at ``prompts_dir``.

        Raises:
            FileNotFoundError: If ``prompts_dir`` is not an existing directory.
        """
        if os.path.isdir(prompts_dir):
            self.env = Environment(loader=FileSystemLoader(prompts_dir))
        else:
            raise FileNotFoundError(f"Prompts directory not found: {prompts_dir}")

    def get_prompt(self, template_name: str, **kwargs) -> tuple[str, str]:
        """
        Renders a template and returns its system and user sections.

        A rendered template is split only when it contains both the
        ``---SYSTEM---`` and ``---USER---`` markers. Everything before the first
        ``---USER---`` (with the system marker removed) becomes the system
        prompt; everything after it becomes the user prompt. Without both
        markers the whole rendered text is returned as the user prompt.

        Args:
            template_name: Template path relative to the prompts directory,
                           e.g., "evaluation/judge_v1_score_and_reason.j2"
            **kwargs: Variables made available to the template.

        Returns:
            A tuple (system_prompt, user_prompt), both whitespace-stripped.

        Raises:
            Exception: Any error from loading or rendering is printed and then
                       re-raised unchanged.
        """
        system_marker = "---SYSTEM---"
        user_marker = "---USER---"
        try:
            text = self.env.get_template(template_name).render(**kwargs)

            has_both_markers = system_marker in text and user_marker in text
            if not has_both_markers:
                # No separator convention in use: the whole file is the user prompt.
                return "", text.strip()

            # Cut at the first user marker only; later occurrences stay in the user part.
            before, _, after = text.partition(user_marker)
            return before.replace(system_marker, "").strip(), after.strip()
        except Exception as exc:
            print(f"Error loading or rendering prompt '{template_name}': {exc}")
            raise

def parse_scores_from_judgment(content: "str | None") -> tuple[float, float]:
    """
    Parses the LLM's judgment to find a score pattern like [[score1, score2]].

    The first occurrence of the pattern anywhere in the text is used. Whitespace
    (including newlines) is allowed around the numbers and the comma, and each
    number may be an integer or a non-negative decimal such as ``7``, ``7.5``
    or ``7.``.

    Args:
        content: The full text response from the LLM. ``None`` (or any non-string
                 value) is tolerated so that a failed API call upstream does not
                 crash the parser.

    Returns:
        A tuple containing (score1, score2).
        Returns (0.0, 0.0) if the pattern is not found or there is no text.
    """
    # A missing judgment is treated like a judgment without scores;
    # re.search would otherwise raise TypeError on None.
    if not isinstance(content, str):
        return 0.0, 0.0

    # \[\[ ... \]\]  - literal double brackets around the pair
    # \s*            - optional whitespace (spaces, tabs, newlines)
    # (\d+\.?\d*)    - one captured number: digits, optional dot, optional digits
    pattern = r"\[\[\s*(\d+\.?\d*)\s*,\s*(\d+\.?\d*)\s*\]\]"

    match = re.search(pattern, content)
    if match is None:
        return 0.0, 0.0

    # Both groups are guaranteed by the pattern to be valid float literals.
    return float(match.group(1)), float(match.group(2))

def parse_arguments() -> argparse.Namespace:
    """Parse the command-line arguments.

    Returns:
        The parsed namespace with ``prompt``, ``model_1_response_file``,
        ``model_2_response_file``, ``openai_api_key``, ``output_dir`` and
        ``dry_run`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate model responses with GPT-4 as a judge.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Selects a template by name; it is not a file path, hence the fixed choices.
    parser.add_argument(
        '--prompt',
        type=str,
        help='Name of the evaluation prompt template to use.',
        choices=['scores', 'helpfulness', 'harmlessness'],
        required=True,
    )

    # Input Files
    parser.add_argument(
        '--model_1_response_file',
        type=str,
        help='Path to the JSONL file containing responses from the first model.',
        required=True,
    )
    parser.add_argument(
        '--model_2_response_file',
        type=str,
        help='Path to the JSONL file containing responses from the second model.',
        required=True,
    )

    # API and Logging
    parser.add_argument(
        '--openai_api_key',
        type=str,
        default=None,
        help='Your OpenAI API key. If not provided, it will be read from the OPENAI_API_KEY environment variable.',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Where to store the evaluation output.',
    )
    parser.add_argument(
        '--dry_run',
        action='store_true',
        help='Prepare prompts and print them without calling the OpenAI API.',
    )
    return parser.parse_args()


def _read_jsonl(file_path: str) -> list[dict[str, Any]]:
    """Parses one JSON value per non-blank line of ``file_path``."""
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a stray empty line at the end) are not records.
        return [json.loads(line) for line in f if line.strip()]


def _extract_answer(response_text: str) -> str:
    """Returns the text after the first 'ASSISTANT:' marker, or all of it if absent."""
    _, marker, answer = response_text.partition('ASSISTANT:')
    # Without the marker, partition() yields an empty tail; fall back to the
    # full text instead of silently discarding the response.
    return (answer if marker else response_text).strip()


def load_and_validate_data(file_path_1: str, file_path_2: str) -> list[dict[str, Any]]:
    """Loads data from two JSONL files and validates that their prompts match.

    Each record must provide a ``'prompt'`` and a ``'response_text'`` entry.
    Records are paired by position.

    Args:
        file_path_1: JSONL file with the first model's responses.
        file_path_2: JSONL file with the second model's responses.

    Returns:
        One dict per pair with the keys ``'prompt'``, ``'answer1'`` and
        ``'answer2'``. Each answer is the stripped text following the first
        ``'ASSISTANT:'`` marker, or the whole stripped response when the marker
        is missing.

    Raises:
        ValueError: If the files hold a different number of records, if the
            prompts at some position differ, or if a line is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``).
        KeyError: If a record lacks ``'prompt'`` or ``'response_text'``.
    """
    data1 = _read_jsonl(file_path_1)
    data2 = _read_jsonl(file_path_2)

    if len(data1) != len(data2):
        raise ValueError("Mismatched number of responses between the two files.")

    validated_data = []
    for i, (item1, item2) in enumerate(zip(data1, data2)):
        if item1['prompt'] != item2['prompt']:
            raise ValueError(f"Mismatched prompts at line {i+1}. Ensure both files are generated from the same prompts in the same order.")

        validated_data.append({
            'prompt': item1['prompt'],
            'answer1': _extract_answer(item1['response_text']),
            'answer2': _extract_answer(item2['response_text']),
        })

    return validated_data

def retry_with_exponential_backoff(
    max_retries=3,
    initial_delay=1,
    backoff_factor=2,
):
    """
    A decorator to retry a function call with exponential backoff if it fails
    due to an OpenAI RateLimitError.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so the wrapped function is always called at least once.
        initial_delay: Base wait in seconds before the second attempt.
        backoff_factor: Multiplier applied to the base wait after each failure.

    Returns:
        A decorator. The decorated function returns whatever the wrapped
        function returns. Once the attempts are exhausted, the last
        RateLimitError raised by the wrapped function is re-raised. Any other
        exception propagates immediately without a retry.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            attempts = max(1, max_retries)
            delay = initial_delay
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except RateLimitError:
                    if attempt == attempts:
                        # Re-raise the original error: constructing a new
                        # RateLimitError from a bare message fails because the
                        # client library's status errors also need the HTTP
                        # response and body.
                        print("All retries failed. Could not recover from rate limit.")
                        raise
                    # Up to one second of random jitter on top of the base delay.
                    pause = delay + random.uniform(0, 1)
                    print(f"Rate limit hit. Retrying in {pause:.2f} seconds... (Attempt {attempt}/{attempts})")
                    time.sleep(pause)
                    delay *= backoff_factor
        return wrapper
    return decorator

@retry_with_exponential_backoff()
def gpt4_eval(sys_prompt: str, user_prompt: str, client: openai.OpenAI) -> str | None:
    """Sends one system/user prompt pair to the chat completions API.

    Args:
        sys_prompt: Content of the system message.
        user_prompt: Content of the user message.
        client: The OpenAI client used to issue the request.

    Returns:
        The message content of the first choice, or ``None`` when the request
        failed with an authentication error or any other non-rate-limit error
        (the error is printed).

    Raises:
        RateLimitError: Propagated untouched so that the retry decorator wrapped
            around this function can back off and try again.
    """
    try:
        chat_completion = client.chat.completions.create(
            model='gpt-4',
            messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': user_prompt},
            ],
            temperature=0.7,
            max_tokens=2048,
        )
        return chat_completion.choices[0].message.content
    except RateLimitError:
        # Must come before the generic handler below; otherwise the rate limit
        # is swallowed here and the retry decorator never gets to see it.
        raise
    except AuthenticationError:
        # This is a different, fatal error. The key is bad.
        print("FATAL: OpenAI API Authentication Error. Your key is likely invalid.")
        # We don't retry here, we stop. Returning None is one option.
        return None
    except Exception as e:
        # Catch any other unexpected errors
        print(f"An unexpected API error occurred: {e}")
        return None

def _derive_model_names(path_1: str, path_2: str) -> tuple[str, str]:
    """Returns two distinct, non-empty labels for the models behind the response files."""
    # Preferred: the name of the directory holding each response file.
    name_1 = os.path.basename(os.path.dirname(path_1))
    name_2 = os.path.basename(os.path.dirname(path_2))
    if name_1 and name_2 and name_1 != name_2:
        return name_1, name_2

    # Same (or no) parent directory: fall back to the file names without extension.
    name_1 = os.path.splitext(os.path.basename(path_1))[0]
    name_2 = os.path.splitext(os.path.basename(path_2))[0]
    if name_1 and name_2 and name_1 != name_2:
        return name_1, name_2

    # Last resort: identical labels would make the result keys overwrite each other.
    return 'model_1', 'model_2'


def main() -> None:
    """Runs the pairwise evaluation described by the command-line arguments.

    Loads the two response files, renders the selected prompt template for each
    pair and either prints the rendered prompts (``--dry_run``) or sends them to
    the judge and writes one JSON object per pair to
    ``<output_dir>/gpt4_evaluation_results.json``.

    A dry run makes no API call, needs no API key and does not touch the
    results file. When a judgment cannot be obtained for a pair, the pair is
    still written, with ``None`` for both scores and for the judgment.
    """
    args = parse_arguments()

    # Load and prepare data from files
    try:
        paired_data = load_and_validate_data(args.model_1_response_file, args.model_2_response_file)
    except ValueError as e:
        print(f"Error loading data: {e}")
        return
    model_1_name, model_2_name = _derive_model_names(
        args.model_1_response_file,
        args.model_2_response_file,
    )

    prompts_dir = 'safe_rlhf/prompts/gpt4'
    prompt_manager = PromptManager(prompts_dir=prompts_dir)
    prompt_file = f"{args.prompt}.j2"

    def render(item: dict[str, Any]) -> tuple[str, str]:
        """Renders the (system, user) prompt pair for one response pair."""
        return prompt_manager.get_prompt(
            prompt_file,
            question=item['prompt'],
            answer1=item['answer1'],
            answer2=item['answer2'],
        )

    if args.dry_run:
        # Handled before the client and the output file are created, so a dry
        # run works without an API key and cannot clobber earlier results.
        total = len(paired_data)
        for index, item in enumerate(paired_data, start=1):
            system_prompt, user_prompt = render(item)
            print(f"===== Prompt {index}/{total} =====")
            print(f"[system]\n{system_prompt}")
            print(f"[user]\n{user_prompt}\n")
        print("Dry run complete. No API calls were made and no results were written.")
        return

    print(f'Evaluating {len(paired_data)} pairs of responses with GPT-4...')
    client = openai.OpenAI(api_key=args.openai_api_key)  # Uses env var if key is None

    os.makedirs(args.output_dir, exist_ok=True)
    # NOTE: despite the .json extension, the file holds one JSON object per line.
    output_file_path = os.path.join(args.output_dir, 'gpt4_evaluation_results.json')
    failed = 0
    with open(output_file_path, mode='w', encoding='utf-8') as f:
        for item in tqdm(paired_data):
            system_prompt, user_prompt = render(item)

            try:
                content = gpt4_eval(sys_prompt=system_prompt, user_prompt=user_prompt, client=client)
            except RateLimitError:
                # Retries are exhausted for this pair; record it as failed and
                # keep going rather than aborting the whole run.
                content = None

            if content is None:
                # No judgment text: do not invent scores for a failed call.
                failed += 1
                score1, score2 = None, None
            else:
                score1, score2 = parse_scores_from_judgment(content)

            single_result = {
                'prompt': item['prompt'],
                f'{model_1_name}_answer': item['answer1'],
                f'{model_2_name}_answer': item['answer2'],
                f'{model_1_name}_score': score1,
                f'{model_2_name}_score': score2,
                'gpt4_judgment': content,
            }

            f.write(json.dumps(single_result, ensure_ascii=False) + '\n')
            f.flush()  # Ensure each result is written immediately

    if failed:
        print(f"Warning: {failed} of {len(paired_data)} pairs received no judgment.")
    print(f"Evaluation complete. Results saved to {output_file_path}")


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()