import argparse
import json
import os
import re
import time
from typing import List, Dict, Optional, Union
import openai
from openai import OpenAI

# Fallback only: keep an OPENAI_API_KEY that is already exported in the
# environment instead of overwriting it with an empty placeholder. Prefer
# exporting the key in the shell over hard-coding a secret in this file.
os.environ.setdefault("OPENAI_API_KEY", "")

class FluencyJudge:
    """Scores the fluency of text fragments on a 0-2 scale using a chat model.

    Each text is sent to the chat completions endpoint together with a fixed
    judging prompt; the numeric rating is then parsed out of the model's
    free-text answer.
    """

    # Compiled once and tried in this order by ``_extract_rating``: the exact
    # format requested in the prompt first, then a looser "Rating: <digits>".
    _BRACKETED_RATING = re.compile(r"Rating:\s*\[\[(\d+)\]\]")
    _PLAIN_RATING = re.compile(r"Rating:\s*(\d+)")

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o-mini"):
        """Create the API client and store the judging prompt.

        Args:
            api_key: API key to use. When falsy, the ``OPENAI_API_KEY``
                environment variable is used instead.
            model: Name of the chat model that performs the evaluation.
        """
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.model = model
        
        self.system_prompt = """Please act as an impartial judge and evaluate the fluency of the sentence fragment provided below. Focus solely on fluency,
disregarding its completeness, relevance, coherence with any broader context, or informativeness.
Begin your evaluation by briefly describing the fluency of the sentence, noting any unnatural phrasing, awkward transitions, grammatical errors, or repetitive structures that may hinder readability. After providing your explanation, rate the sentence's
fluency on a scale from 0 to 2, where 0 indicates the sentence is not fluent and highly unnatural (e.g., incomprehensible or repetitive),
1 indicates it is somewhat fluent but contains noticeable errors or awkward phrasing, and 2 indicates the sentence is fluent and
almost perfect. Provide your rating using this exact format: "Rating: [[score]]"."""

    def evaluate_fluency(self, text: str, max_retries: int = 3) -> Dict[str, Union[str, int, float]]:
        """Ask the model to judge ``text`` and return a result record.

        Args:
            text: The text fragment to evaluate.
            max_retries: Maximum number of attempts. Failed attempts are
                separated by an exponentially growing pause (1s, 2s, 4s, ...).

        Returns:
            A dict with keys ``text``, ``evaluation``, ``rating``, ``model``
            and ``success``. On success ``rating`` is 0-2, or -1 when no
            rating could be parsed from the answer. When every attempt fails
            (or ``max_retries`` allows no attempt at all) ``success`` is
            False, ``rating`` is -1 and an ``error`` key holds the message of
            the last failure.
        """
        # Reported if the loop below never runs (max_retries <= 0), so the
        # method always returns a dict instead of silently returning None.
        last_error = "no attempt was made because max_retries is less than 1"

        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": text}
                    ],
                    temperature=0,  # For consistent evaluation
                    max_tokens=500
                )

                content = response.choices[0].message.content
                if content is None:
                    # Surface a readable error instead of an AttributeError
                    # from calling .strip() on None; handled like any other
                    # failed attempt.
                    raise ValueError("model returned no message content")

                evaluation = content.strip()
                return {
                    "text": text,
                    "evaluation": evaluation,
                    "rating": self._extract_rating(evaluation),
                    "model": self.model,
                    "success": True
                }

            except Exception as e:
                # Best-effort boundary: any failure of the call or of reading
                # the response counts as a failed attempt and is retried.
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                last_error = str(e)
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff

        return {
            "text": text,
            "evaluation": "",
            "rating": -1,
            "model": self.model,
            "success": False,
            "error": last_error
        }

    def _extract_rating(self, evaluation: str) -> int:
        """Parse the 0-2 rating out of the model's answer.

        The first occurrence of ``Rating: [[n]]`` is tried first, then the
        first occurrence of ``Rating: n``. A match only counts when ``n`` lies
        in the range 0-2.

        Args:
            evaluation: The model's full free-text answer.

        Returns:
            The rating (0, 1 or 2), or -1 when no valid rating is found; in
            that case a warning is printed.
        """
        for pattern in (self._BRACKETED_RATING, self._PLAIN_RATING):
            match = pattern.search(evaluation)
            if match:
                rating = int(match.group(1))
                if 0 <= rating <= 2:
                    return rating

        print(f"Warning: Could not extract rating from: {evaluation}")
        return -1

    def evaluate_batch(self, texts: List[str], delay: float = 0.5) -> List[Dict]:
        """Evaluate several texts one after another.

        Args:
            texts: The texts to evaluate, in order.
            delay: Seconds to pause between consecutive evaluations.

        Returns:
            One result dict per input text (see ``evaluate_fluency``), in the
            same order as ``texts``.
        """
        results = []
        total = len(texts)

        for i, text in enumerate(texts):
            print(f"Evaluating {i + 1}/{total}: {text[:50]}...")
            results.append(self.evaluate_fluency(text))

            if i < total - 1:  # Don't wait after the last item
                time.sleep(delay)

        return results

def load_data(input_file: str) -> List[str]:
    """Load the texts to evaluate from a JSON or plain-text file.

    Files are read as UTF-8 regardless of the platform's default encoding.

    Args:
        input_file: Path to the input file. A path ending in ``.json``
            (matched case-insensitively) is parsed as JSON and iterated item
            by item, each item providing ``context`` and ``generation``
            strings that are concatenated into one text. Any other file is
            read as plain text with one text per line.

    Returns:
        The list of texts. For plain-text input, lines are stripped of
        surrounding whitespace and blank lines are dropped.

    Raises:
        KeyError: If a JSON item lacks ``context`` or ``generation``.
    """
    if input_file.lower().endswith('.json'):
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return [item['context'] + item['generation'] for item in data]

    # Plain text file, one text per line
    with open(input_file, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]

def save_results(results: List[Dict], output_file: str):
    """Write the evaluation results to ``output_file`` as indented JSON.

    Args:
        results: Result records to serialize.
        output_file: Destination path; an existing file is overwritten.
    """
    # ensure_ascii=False emits non-ASCII characters verbatim, so the file must
    # be opened as UTF-8 explicitly rather than with the locale's encoding.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

def main():
    """Command-line entry point: evaluate one text or a file of texts.

    With ``--text`` a single text is evaluated, printed and saved. Otherwise
    the texts are loaded from ``--input``, evaluated one by one, summarized
    on stdout and saved. At least one of ``--text`` / ``--input`` must be
    given; results are written as JSON to ``--output``.
    """
    parser = argparse.ArgumentParser(description='Evaluate text fluency using GPT-4o-mini')
    # Not marked required: --text is an alternative to a file, so forcing
    # --input would make single-text mode unusable on its own.
    parser.add_argument('--input', type=str, help='Input file with texts to evaluate (required unless --text is given)')
    parser.add_argument('--output', type=str, default='fluency_results.json', help='Output file for results')
    parser.add_argument('--model', type=str, default='gpt-4o-mini', help='OpenAI model to use')
    parser.add_argument('--delay', type=float, default=0.5, help='Delay between API calls (seconds)')
    parser.add_argument('--text', type=str, help='Single text to evaluate (instead of file)')

    args = parser.parse_args()

    if not args.text and not args.input:
        # Exits with status 2 and a usage message, like a missing required
        # argument would.
        parser.error('one of --input or --text is required')

    # Initialize the judge
    judge = FluencyJudge(model=args.model)

    if args.text:
        # Single text evaluation
        result = judge.evaluate_fluency(args.text)
        print(f"Text: {result['text']}")
        print(f"Rating: {result['rating']}")
        print(f"Evaluation: {result['evaluation']}")

        # Save single result
        save_results([result], args.output)

    else:
        # Batch evaluation from file
        texts = load_data(args.input)
        print(f"Loaded {len(texts)} texts to evaluate")

        results = judge.evaluate_batch(texts, delay=args.delay)

        # Print summary; only successful calls with a parsable rating
        # (rating >= 0) contribute to the statistics.
        successful = sum(1 for r in results if r['success'])
        ratings = [r['rating'] for r in results if r['success'] and r['rating'] >= 0]

        print("\nEvaluation complete!")
        print(f"Successful evaluations: {successful}/{len(results)}")

        if ratings:
            avg_rating = sum(ratings) / len(ratings)
            print(f"Average rating: {avg_rating:.2f}")
            print(f"Rating distribution: 0={ratings.count(0)}, 1={ratings.count(1)}, 2={ratings.count(2)}")

        save_results(results, args.output)
        print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()
