import asyncio
import json
import os
import subprocess
import sys
import tempfile
from typing import Any, Dict, List, Optional, Tuple

from openai import AsyncOpenAI
from tqdm import tqdm

class EncycloBenchEvaluator:
    """Score model answers against reference programs using an LLM judge.

    Each inference record carries the model's answer (``response``) plus a
    Python program (``python_code``) whose stdout is the reference answer. The
    program is executed in a separate interpreter, and a chat model reached via
    an OpenAI-compatible API decides whether the two answers match.
    """

    def __init__(self, base_url: str, api_key: str, prompt_file: str = "prompts.txt"):
        """Create the API client and load the judge prompts.

        Args:
            base_url: Base URL of the OpenAI-compatible endpoint.
            api_key: API key for that endpoint.
            prompt_file: Prompt file read by ``load_prompts``.
        """
        self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        self.prompt_file = prompt_file
        self.system_prompt, self.user_prompt = self.load_prompts()

    def load_prompts(self) -> Tuple[str, str]:
        """Return ``(system_prompt, user_prompt)`` read from ``self.prompt_file``.

        Line 1 and line 2 of the file each have the form ``<label>: <prompt>``;
        everything after the first colon (stripped) is the prompt. The user
        prompt is later filled with ``str.format(infer_answer=...,
        question_answer=...)``.

        Deliberately best-effort: on any error (missing file, fewer than two
        lines, no colon) the error is printed and built-in defaults are returned.
        """
        try:
            with open(self.prompt_file, "r", encoding="utf-8") as f:
                lines = f.readlines()
            system_prompt = lines[0].split(":", 1)[1].strip()
            user_prompt = lines[1].split(":", 1)[1].strip()
            return system_prompt, user_prompt
        except Exception as e:
            print(f"An error occur when loading prompt: {e}")
            return (
                "You are a helpful assistant that determines whether the inferred answer matches the correct answer. If the error range of the number is less than 1, we can consider it correct. Return 'True' if they match, and 'False' if they do not.",
                "Inferred Answer: {infer_answer}\nCorrect Answer: {question_answer}\nDo these answers match? Return only 'True' or 'False'.",
            )

    async def api_predict(
        self,
        conversation: List[Dict[str, str]],
        model: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> Optional[str]:
        """Send ``conversation`` to the chat-completions API; return the first choice's text.

        The OpenAI SDK types ``message.content`` as optional, so this can be
        ``None``.
        """
        response = await self.client.chat.completions.create(
            model=model,
            messages=conversation,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content

    async def LLM_judge(
        self,
        infer_answer: str,
        question_answer: str,
        model: str,
        max_tokens: int = 1024,
        temperature: Optional[float] = None,
    ) -> bool:
        """Ask the judge model whether ``infer_answer`` matches ``question_answer``.

        Returns:
            True only if the reply, ignoring surrounding whitespace, quotes,
            backticks, asterisks and periods, is ``true`` (case-insensitive).
            An empty or ``None`` reply counts as no match.
        """
        prompt = [
            {"role": "system", "content": self.system_prompt},
            {
                "role": "user",
                "content": self.user_prompt.format(
                    infer_answer=infer_answer, question_answer=question_answer
                ),
            },
        ]
        reply = await self.api_predict(
            prompt, model=model, max_tokens=max_tokens, temperature=temperature
        )
        # The default prompt literally says "Return 'True'", so models often echo
        # the quotes or add a trailing period; strip that decoration before comparing.
        verdict = (reply or "").strip().strip("'\"`.*").strip().lower()
        return verdict == "true"

    def run_generated_code(
        self,
        result_python_code: str,
        initialized_variables: Dict[str, Any],
        timeout: Optional[float] = None,
    ) -> str:
        """Run Python source in a fresh interpreter and return what it printed.

        The source is written to a temporary ``.py`` file and executed with
        ``sys.executable``. ``str(initialized_variables)`` (the Python repr of
        the mapping) is written to the child's stdin.

        NOTE(review): the code runs unsandboxed with this process's privileges.
        Only evaluate trusted result files.

        Args:
            result_python_code: Python source to execute.
            initialized_variables: Value whose ``str()`` is fed to stdin.
            timeout: Wall-clock limit in seconds; ``None`` waits indefinitely.

        Returns:
            The child's stdout, unmodified (including the trailing newline).

        Raises:
            RuntimeError: the program exited with a non-zero status.
            subprocess.TimeoutExpired: ``timeout`` elapsed (the child is killed).
        """
        fd, temp_file_path = tempfile.mkstemp(suffix=".py")
        try:
            # UTF-8 explicitly: the child interpreter decodes source files as
            # UTF-8, while the platform default encoding may differ.
            with open(fd, "w", encoding="utf-8") as temp_file:
                temp_file.write(result_python_code)

            result = subprocess.run(
                [sys.executable, temp_file_path],
                input=str(initialized_variables),
                text=True,
                capture_output=True,
                timeout=timeout,
            )
        finally:
            os.remove(temp_file_path)

        if result.returncode != 0:
            # A crashing reference program would otherwise yield empty or partial
            # stdout that gets judged as if it were the correct answer.
            raise RuntimeError(
                f"generated code exited with status {result.returncode}: "
                f"{result.stderr.strip()[-500:]}"
            )
        return result.stdout

    def remove_code_block_markers(self, text: str) -> str:
        """Drop every line whose stripped form starts with a Markdown fence (```)."""
        lines = text.split("\n")
        return "\n".join(line for line in lines if not line.strip().startswith("```"))

    @staticmethod
    def _load_jsonl(path: str) -> List[Dict[str, Any]]:
        """Parse a UTF-8 JSON-Lines file, skipping blank lines."""
        with open(path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    async def evaluate(
        self,
        infer_result_path: str,
        model: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        code_timeout: Optional[float] = None,
    ) -> Tuple[Optional[float], Optional[Dict[str, float]]]:
        """Judge every record of a JSONL results file and print summary statistics.

        Records are processed concurrently. A record whose processing raises
        (missing keys, failing or timed-out reference code, API error) is
        counted as "question logic wrong" and excluded from the accuracy
        denominator.

        Args:
            infer_result_path: UTF-8 JSONL file, one inference record per line.
            model: Judge model name for the chat-completions API.
            max_tokens: Token limit for each judge call.
            temperature: Sampling temperature for each judge call.
            code_timeout: Per-record limit in seconds for the reference program;
                ``None`` means no limit.

        Returns:
            ``(accuracy_percent, {difficulty: accuracy_percent})``, or
            ``(None, None)`` if no record was processed successfully.
        """
        infer_results = self._load_jsonl(infer_result_path)

        stats = {
            'total_tests': 0,
            'correct_count': 0,
            'question_logic_wrong': 0,
            'difficulty_stats': {
                'easy': {'total': 0, 'correct': 0},
                'medium': {'total': 0, 'correct': 0},
                'hard': {'total': 0, 'correct': 0},
            },
        }

        tasks = [
            self.process_infer_result(
                infer_result,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                code_timeout=code_timeout,
            )
            for infer_result in infer_results
        ]

        for future in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"Evaluating {model}"):
            try:
                is_correct, difficulty = await future
            except Exception:
                # Best-effort: one broken record must not abort the whole run.
                stats['question_logic_wrong'] += 1
                continue

            stats['total_tests'] += 1
            # Unexpected difficulty labels get their own bucket; indexing directly
            # raised KeyError here (outside the try) and aborted the evaluation.
            bucket = stats['difficulty_stats'].setdefault(difficulty, {'total': 0, 'correct': 0})
            bucket['total'] += 1
            if is_correct:
                stats['correct_count'] += 1
                bucket['correct'] += 1

        print("\nFinal Statistics:")
        print(f"Total Tests: {stats['total_tests']}")
        print(f"Correct Answers: {stats['correct_count']}")
        print(f"Question logic wrong: {stats['question_logic_wrong']}")

        if stats['total_tests'] == 0:
            print("No tests were run.")
            return None, None

        accuracy = stats['correct_count'] / stats['total_tests'] * 100
        print(f"Accuracy: {accuracy:.2f}%")

        record_difficulty = {}
        for difficulty, counts in stats['difficulty_stats'].items():
            if counts['total'] > 0:
                difficulty_accuracy = counts['correct'] / counts['total'] * 100
                print(f"{str(difficulty).capitalize()} Accuracy: {difficulty_accuracy:.2f}%")
                record_difficulty[difficulty] = difficulty_accuracy
        return accuracy, record_difficulty

    async def process_infer_result(
        self,
        infer_result: Dict[str, Any],
        model: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        code_timeout: Optional[float] = None,
    ) -> Tuple[bool, str]:
        """Judge a single inference record.

        The record must provide ``python_code`` (optionally fenced in Markdown),
        ``response`` (the model's answer), ``init_variables`` (fed to the
        program's stdin) and ``difficulty``.

        Returns:
            ``(is_correct, difficulty)``.

        Raises:
            KeyError: a required key is missing.
            RuntimeError / subprocess.TimeoutExpired: from ``run_generated_code``.
            Any exception raised by the judge API call.
        """
        ans_py_code = self.remove_code_block_markers(infer_result["python_code"])
        infer_ans = infer_result["response"]
        ans_variables = infer_result["init_variables"]
        difficulty = infer_result["difficulty"]

        # run_generated_code blocks on a subprocess. Run it in the default
        # executor so the event loop (and other records' judge calls) keeps going.
        loop = asyncio.get_running_loop()
        question_answer = await loop.run_in_executor(
            None, self.run_generated_code, ans_py_code, ans_variables, code_timeout
        )

        is_correct = await self.LLM_judge(
            infer_ans,
            question_answer,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return is_correct, difficulty


async def main():
    """Evaluate a results file with placeholder connection settings.

    Fill in the base URL, API key, and file paths before running.
    """
    client_settings = {
        "base_url": "YOUR BASE URL",
        "api_key": "YOUR API KEY",
        "prompt_file": "./prompts/eval/eval_bool.txt",
    }
    judge = EncycloBenchEvaluator(**client_settings)

    results_file = "./results"
    judge_options = dict(model="Grok-3", max_tokens=1024, temperature=0.7)
    await judge.evaluate(results_file, **judge_options)


if __name__ == "__main__":
    # Script entry point: drive the async evaluation on a fresh event loop.
    asyncio.run(main())