import os
import argparse
import json
import time
import re
import random
import base64
import requests
import anthropic
from openai import OpenAI
from zai import ZhipuAiClient
from typing import Dict, List, Tuple, Optional
import ast
from PIL import Image
from io import BytesIO
# Default directory that holds the prompt JSON files; LLMEvaluator falls back
# to it when no `prompt_dir` argument is supplied.
# NOTE(review): this is a machine-specific absolute path — pass `prompt_dir`
# explicitly when running on another machine.
PROMPT_DIR = "/Users/xx/Documents/codes/cuarewardbench/config/prompt"


class LLMEvaluator:
    def __init__(
            self,
            api_type: str = "openai",
            model: str = "gpt-4o-2024-11-20",
            base_url: str = None,
            api_key: Optional[str] = None,
            prompt_file: str = None,
            prompt_dir: str = None,
            temperature: float = 0.0,
            voting_type: Optional[str] = None,
            voting_num: int = 1,
    ):
        """Initialize LLM evaluator with specified API
        
        Args:
            api_type: Type of LLM API to use ('openai', 'claude', 'qwen', 'local')
            api_key: API key if required
        """
        self.api_type = api_type
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.temperature = temperature
        self.voting_type = voting_type
        self.voting_num = voting_num
        print(f"Initializing LLMEvaluator with {api_type}")
        self._setup_client()

        # load prompt file
        prompt_dir = prompt_dir if prompt_dir is not None else PROMPT_DIR
        if prompt_file is not None:
            prompt_path = os.path.join(prompt_dir, prompt_file)
            with open(prompt_path, 'r') as f:
                self.prompt = json.load(f)

    def _setup_client(self):
        """Create `self.client` for the configured `api_type`.

        'openai' and '3rd_openai_glm' get an OpenAI SDK client, 'zhipu' a
        ZhipuAiClient and 'claude' an Anthropic client. '3rd_openai' and any
        type containing 'qwen' talk raw HTTP later on, so their client is None.

        Raises:
            NotImplementedError: For any other `api_type`.
            Exception: Whatever the SDK constructor raises; it is logged and
                re-raised unchanged.
        """
        kind = self.api_type
        try:
            if kind in ("openai", "3rd_openai_glm"):
                client = OpenAI(api_key=self.api_key, base_url=self.base_url)
            elif kind == "zhipu":
                client = ZhipuAiClient(api_key=self.api_key)
            elif kind == "claude":
                client = anthropic.Anthropic(api_key=self.api_key)
            elif kind == "3rd_openai" or "qwen" in kind.lower():
                # These backends are called through `requests`; no SDK client.
                client = None
            else:
                print(f"Unsupported LLM type: {self.api_type}")
                raise NotImplementedError(f"Unsupported LLM type: {self.api_type}")
            self.client = client
        except Exception as err:
            print(f"Failed to initialize client: {str(err)}")
            raise

    def _encode_image(self, image_bytes: bytes=None, image_path: str=None, img_scale=1.0) -> str:
        """Convert image to base64 string"""
        if image_bytes is not None:
            return base64.b64encode(image_bytes).decode("utf-8")
        elif image_path is not None:
            if img_scale == 1.0:
                with open(image_path, "rb") as image_file:
                    return base64.b64encode(image_file.read()).decode('utf-8')
            else:
                with open(image_path, "rb") as image_file:
                    image = Image.open(image_file)

                    new_size = (int(image.width * img_scale), int(image.height * img_scale))
                    image = image.resize(new_size)

                    buffer = BytesIO()
                    image.save(buffer, format="PNG")
                    encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
                return encoded_image
        else:
            raise NotImplementedError

    def call_llm(self, prompt: str, n_samples: int = 1, max_tokens: int = 4096, temperature: float = 0.0, 
                 retry_times: int = 1, retry_delay: int = 1, timeout: int = 360, return_list: bool = False):
        """Send a chat request to the configured backend, retrying on failure.

        Args:
            prompt: Payload forwarded as the backend's ``messages`` argument
                (the evaluate_* methods pass a list of chat-message dicts).
            n_samples: Number of completions requested (sent as ``n``; the
                'zhipu' and 'claude' requests do not send it).
            max_tokens: Completion token limit.
            temperature: Sampling temperature (not sent for '3rd_openai_glm').
            retry_times: Total number of attempts.
            retry_delay: Delay in seconds between attempts.
            timeout: Timeout in seconds of the raw HTTP request used by the
                '3rd_openai' / qwen backends.
            return_list: When True return a list with one text per choice,
                otherwise a single string (which requires ``n_samples == 1``).

        Returns:
            The response text (or list of texts). For 'zhipu' with
            ``return_list=False`` the text is prefixed with the model's
            reasoning content. No exception escapes this method:
              * the string "failed in response" is returned when the last
                attempt raised (including an unsupported `api_type`), and
              * ``[]`` / ``""`` when the raw HTTP backend answered with a
                non-200 status on the last attempt or ``retry_times <= 0``.
        """
        def select(texts):
            """Shape per-choice texts according to `return_list`."""
            if return_list:
                return texts
            assert n_samples == 1
            return texts[0]

        for attempt in range(retry_times):
            is_last_attempt = attempt >= retry_times - 1
            try:
                print(f"Calling {self.api_type} API (attempt {attempt+1}/{retry_times})")
                if self.api_type == "openai":
                    response = self.client.chat.completions.create(
                        model=self.model,
                        messages=prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        n=n_samples,
                        # This model is served with its "thinking" mode switched off.
                        extra_body={
                            "chat_template_kwargs": {"enable_thinking": False}
                        } if self.model == "MiniCPM-V-4_5" else {}
                    )
                    return select([choice.message.content for choice in response.choices])

                elif self.api_type == "3rd_openai_glm":
                    # Temperature is intentionally not forwarded to this backend.
                    response = self.client.chat.completions.create(
                        model=self.model,
                        messages=prompt,
                        max_tokens=max_tokens,
                        n=n_samples,
                    )
                    return select([choice.message.content for choice in response.choices])

                elif self.api_type == "zhipu":
                    response = self.client.chat.completions.create(
                        model=self.model,
                        messages=prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        thinking={
                            "type": "enabled"
                        }
                    )
                    if return_list:
                        return [choice.message.content for choice in response.choices]
                    assert n_samples == 1
                    message = response.choices[0].message
                    # Fixed: the separator used to contain an invalid "\c"
                    # escape, which emitted a stray backslash before "content:".
                    return f"reasoning_content:\n{message.reasoning_content}\n\ncontent:\n{message.content}"

                elif self.api_type == "claude":
                    response = self.client.messages.create(
                        model=self.model,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        messages=prompt
                    )
                    # `response.content` is a list of content blocks, not a
                    # string; join the text blocks so callers get plain text
                    # like they do from every other backend.
                    text = "".join(
                        block.text for block in response.content
                        if getattr(block, "type", None) == "text"
                    )
                    return [text] if return_list else text

                elif self.api_type == "3rd_openai" or "qwen" in self.api_type.lower():
                    headers = {
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {self.api_key}"
                    }
                    # URLs always use "/", so do not build them with os.path.join.
                    url = f"{self.base_url.rstrip('/')}/chat/completions"
                    response = requests.post(
                        url,
                        headers=headers,
                        json={
                            "model": self.model,
                            "messages": prompt,
                            "max_tokens": max_tokens,
                            "temperature": temperature,
                            "n": n_samples,
                        },
                        timeout=timeout,
                    )
                    if response.status_code != 200:
                        print(f"API request failed: {response.text}")
                        if not is_last_attempt:  # Don't sleep on the last attempt
                            print(f"Retrying in {retry_delay} seconds...")
                            time.sleep(retry_delay)
                            continue
                        return [] if return_list else ""

                    choices = response.json()['choices']
                    return select([choice['message']['content'] for choice in choices])

                else:
                    print(f"Unsupported LLM type: {self.api_type}")
                    raise ValueError(f"Unsupported LLM type: {self.api_type}")

            except Exception as e:
                print(f"Error calling {self.api_type} API (attempt {attempt+1}/{retry_times}): {str(e)}")
                if not is_last_attempt:  # Don't sleep on the last attempt
                    print(f"Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    print(f"Failed after {retry_times} attempts")
                    # Sentinel kept as-is (a str even when return_list=True)
                    # for backward compatibility with existing callers.
                    return "failed in response"

        return [] if return_list else ""

    def parse_from_response(self, response: str) -> float:
        """Parse completion status and confidence from LLM response
        
        Args:
            response: Raw response string from LLM
            
        Returns:
            Tuple of (success, confidence) where:
                success: True if task completed successfully, False otherwise
                confidence: Float between 0.0 and 1.0, or None if not found
        """
        try:
            # Find all completion patterns in the response
            response = response.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "")
            completion_pattern = r'(?:SCORE|[Ss]core)(?:\*\*)?[:]?\s*(?:\*\*)?(?:\[)?([0-9]+)(?:\])?'
            completion_matches = list(re.finditer(completion_pattern, response, re.IGNORECASE))
            
            if completion_matches:
                # Get the last occurrence
                last_match = completion_matches[-1]
                success_text = float(last_match.group(1))
                success = (success_text == 1.0)
                reward = 1.0 if success else 0.0
            else:
                raise ValueError("Completion value not found in analysis")
        except (ValueError, AttributeError) as e:
            print(f"Error extracting completion: {e}")
            # import pdb; pdb.set_trace()
            reward = -2.0

        return reward

    def parse_from_response_sewsm(self, response: str) -> Tuple[float, dict]:
        """Parse completion status from LLM response for sewsm mode

        The result dictionary is looked for, in order, inside
        ``<res_dict>...</res_dict>`` tags, the last ```` ```json ```` fence
        (only when it mentions 'Correctness'), the last ```` ```python ````
        fence, or the text following a bare ``"Correctness":`` key. If that
        fails, only the Correctness verdict is recovered with a regex.

        Args:
            response: Raw response string from LLM

        Returns:
            Tuple of (reward, res_dict) where:
                reward: 1.0 if Correctness is True, 0.0 if False, -2.0 if error
                res_dict: The extracted dictionary from the response. It holds
                    just the 'Correctness' key when the fallback was used and
                    is empty on error.
        """
        def to_literal(text):
            """Rewrite JSON true/false/null so `ast.literal_eval` accepts them."""
            return text.replace("true", "True").replace("false", "False").replace("null", "None")

        try:
            # Try to extract the dictionary text from the response.
            if '<res_dict>' in response:
                payload = response.split('<res_dict>')[1].split('</res_dict>')[0].strip()
                payload = payload.replace('<|end_of_box|>', '')
                # Tolerate missing curly braces.
                if not payload.startswith('{'):
                    payload = '{' + payload
                if not payload.endswith('}'):
                    payload = payload + '}'
                # Also collapse a duplicated closing brace some models emit.
                payload = to_literal(payload).replace("\n}  \n}", "}")
            elif '```json' in response and 'Correctness' in response.split('```json')[-1].split('```')[0]:
                payload = to_literal(response.split('```json')[-1].split('```')[0].strip())
            elif '```python' in response:
                without_assignment = response.replace("res_dict =", "")
                payload = to_literal(without_assignment.split('```python')[-1].split('```')[0].strip())
            elif '"Correctness":' in response:
                # Keep only the Correctness entry (cut before "Correct_Action").
                payload = response.split('"Correctness":')[1].split('"Correct_Action":')[0].strip()
                payload = '{"Correctness":' + payload
                if not payload.endswith('}'):
                    payload = payload + '}'
                payload = to_literal(payload)
            else:
                raise ValueError("No valid marker found in response")

            res_dict = ast.literal_eval(payload)
            if not isinstance(res_dict, dict):
                raise TypeError(f"Expected a dict, got {type(res_dict).__name__}")
            if 'Correctness' not in res_dict:
                raise KeyError("Key 'Correctness' not found in res_dict")

            verdict = res_dict['Correctness']
            # A quoted verdict such as "false" must not count as success just
            # because non-empty strings are truthy.
            if isinstance(verdict, str) and verdict.strip().lower() in ("true", "false"):
                verdict = verdict.strip().lower() == "true"
            return (1.0 if verdict else 0.0), res_dict
        except Exception as e:
            print(f"Error extracting res_dict: {e}")
            # Parsing failed: recover only the Correctness field. The match is
            # case-insensitive so JSON-style `true`/`false` is recognised too.
            try:
                if re.search(r'''["']Correctness["']\s*:\s*true\b''', response, re.IGNORECASE):
                    return 1.0, {'Correctness': True}
                if re.search(r'''["']Correctness["']\s*:\s*false\b''', response, re.IGNORECASE):
                    return 0.0, {'Correctness': False}
                # Not even a Correctness field was found.
                return -2.0, {}
            except Exception as fallback_e:
                print(f"Fallback extraction failed: {fallback_e}")
                return -2.0, {}

    def parse_from_response_opencua(self, response: str) -> dict:
        """Parse the step-judgement dictionary from an LLM response (opencua mode).

        The keys ``last_step_correct`` / ``last_step_redundant`` are first
        renamed to ``Target_Step_Correct`` / ``Target_Step_Redundant``. The
        dictionary is then read from ``<res_dict>...</res_dict>`` tags, the
        last ```` ```json ```` fence or the last ```` ```python ```` fence. If
        that fails, the two boolean fields are recovered by substring search.

        Args:
            response: Raw response string from LLM

        Returns:
            The extracted dictionary. When the fallback is used it contains
            'Target_Step_Correct' and, if found, 'Target_Step_Redundant'. An
            empty dict is returned when nothing could be extracted; this
            method always returns a dict.
        """
        if not isinstance(response, str):
            print(f"Error extracting res_dict: expected str, got {type(response).__name__}")
            return {}

        def to_literal(text):
            """Rewrite JSON true/false/null so `ast.literal_eval` accepts them."""
            return text.replace("true", "True").replace("false", "False").replace("null", "None")

        response = response.replace("last_step_correct", "Target_Step_Correct")
        response = response.replace("last_step_redundant", "Target_Step_Redundant")
        try:
            # Try to extract the dictionary text from the response.
            if '<res_dict>' in response:
                payload = response.split('<res_dict>')[1].split('</res_dict>')[0].strip()
                payload = payload.replace('<|end_of_box|>', '')
                # Tolerate missing curly braces.
                if not payload.startswith('{'):
                    payload = '{' + payload
                if not payload.endswith('}'):
                    payload = payload + '}'
                # Also collapse a duplicated closing brace some models emit.
                payload = to_literal(payload).replace("\n}  \n}", "}")
            elif '```json' in response:
                payload = to_literal(response.split('```json')[-1].split('```')[0].strip())
            elif '```python' in response:
                without_assignment = response.replace("res_dict =", "")
                payload = to_literal(without_assignment.split('```python')[-1].split('```')[0].strip())
            else:
                raise ValueError("No valid marker found in response")

            res_dict = ast.literal_eval(payload)
            # A list/str/number literal is not a usable result; use the fallback.
            if not isinstance(res_dict, dict):
                raise TypeError(f"Expected a dict, got {type(res_dict).__name__}")
            return res_dict

        except Exception as e:
            print(f"Error extracting res_dict: {e}")
            # Parsing failed: recover just the two boolean fields directly
            # from the text (after normalising JSON-style booleans).
            normalized = response.replace("true", "True").replace("false", "False")
            res_dict = {}
            fields = (("Target_Step_Correct", True), ("Target_Step_Redundant", False))
            for key, required in fields:
                if f'"{key}": True' in normalized or f"'{key}': True" in normalized:
                    res_dict[key] = True
                elif f'"{key}": False' in normalized or f"'{key}': False" in normalized:
                    res_dict[key] = False
                elif required:
                    # Without the correctness verdict there is nothing to report.
                    return {}
            return res_dict

    def evaluate_task(self, task_config: dict, trajectory: dict, action=None, eval_mode="zerogui", img_scale=1.0):
        """Ask the LLM to judge a trajectory from its screenshots and instruction.

        Args:
            task_config: Task config; its "instruction" entry fills the prompt.
            trajectory: Dict with equally long "actions" and "screenshots"
                lists plus a "screenshots_marker" list.
            action: Optional 1-based step index. It is required when the user
                prompt contains ``{action_code}`` and also fills ``{step_index}``.
            eval_mode: When it contains "w_action", each image is preceded by
                its action text and the marker screenshot replaces the plain one.
            img_scale: Scale factor forwarded to `_encode_image`.

        Returns:
            ``{"llm_output": <value returned by call_llm>}``.

        Raises:
            ValueError: If the prompt needs ``{action_code}`` but no `action`
                was given.
            NotImplementedError: If `api_type` is unsupported.
            Exception: Any other error is logged and re-raised.
        """
        def media_type_of(b64_data):
            """Guess the image MIME type from the base64-encoded magic bytes."""
            if b64_data.startswith("iVBORw0KGg"):
                return "image/png"
            if b64_data.startswith("R0lGOD"):
                return "image/gif"
            if b64_data.startswith("UklGR"):
                return "image/webp"
            return "image/jpeg"

        try:
            # construct prompt
            instruction = task_config["instruction"]
            system_prompt = self.prompt["system_prompt"]
            user_prompt = self.prompt["user_prompt"]
            actions = trajectory["actions"]
            action_code = None
            if action:
                action = int(action)
                action_code = actions[action - 1]
            format_args = {'instruction': instruction}
            if "{step_index}" in user_prompt:
                format_args['step_index'] = action
            if "{action_code}" in user_prompt:
                if action_code is None:
                    # Previously this surfaced as an unbound-local NameError.
                    raise ValueError("The prompt uses {action_code} but no `action` step index was given")
                format_args['action_code'] = action_code
            user_prompt = user_prompt.format(**format_args)
            screenshots = trajectory["screenshots"]
            screenshots_marker = trajectory["screenshots_marker"]
            assert len(screenshots) == len(actions)
            with_actions = "w_action" in eval_mode

            # prepare messages based on LLM type
            if self.api_type in ["openai", "3rd_openai", "3rd_openai_glm", "zhipu"] or "qwen" in self.api_type.lower():
                user_content = [{"type": "text", "text": user_prompt}]

                # Add screenshots
                for i, image_path in enumerate(screenshots):
                    if with_actions:
                        user_content.append({"type": "text", "text": actions[i]})
                        image_path = screenshots_marker[i]

                    user_content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{self._encode_image(image_path=image_path, img_scale=img_scale)}",
                            "detail": "high",
                        },
                    })
                formatted_prompt = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content},
                ]
            elif self.api_type == "claude":
                # NOTE(review): the system prompt is not forwarded on this path.
                content = [{"type": "text", "text": user_prompt}]
                for i, image_path in enumerate(screenshots):
                    if with_actions:
                        # Fixed: this used to append to `user_content`, which
                        # does not exist in this branch.
                        content.append({"type": "text", "text": actions[i]})
                        image_path = screenshots_marker[i]

                    encoded = self._encode_image(image_path=image_path, img_scale=img_scale)
                    content.append({
                        "type": "image",
                        "source": {
                            "type": "base64",
                            # Rescaled images are PNG, so declare the real
                            # type instead of a fixed "image/jpeg".
                            "media_type": media_type_of(encoded),
                            "data": encoded
                        }
                    })
                formatted_prompt = [{"role": "user", "content": content}]
            else:
                print(f"Unsupported LLM type: {self.api_type}")
                raise NotImplementedError(f"Unsupported LLM type: {self.api_type}")

            # Call LLM with formatted prompt
            analysis = self.call_llm(formatted_prompt, temperature=self.temperature)
            return {"llm_output": analysis}

        except Exception as e:
            print(f"Error in evaluate_task: {str(e)}")
            raise

    def evaluate_task_step(self, task_config: dict, trajectory: dict, action=None, eval_mode="zerogui", img_scale=1.0):
        """Ask the LLM to judge one step of a trajectory.

        The user message is built as: first prompt template (instruction),
        the history screenshots, second prompt template, two target images
        (``screenshots_marker[action-2]`` and ``screenshots[action-1]``) and
        the third prompt template (step index and action code).

        Args:
            task_config: Task config; its "instruction" entry fills the prompt.
            trajectory: Dict with equally long "actions" and "screenshots"
                lists plus a "screenshots_marker" list.
            action: 1-based index of the step to judge (int or numeric string).
                Required despite the ``None`` default.
            eval_mode: "fulltraj" in the mode shows every screenshot as
                history instead of only those before the step; "w_action"
                precedes each history image with its action text and uses the
                marker screenshot.
            img_scale: Scale factor forwarded to `_encode_image`.

        Returns:
            ``{"llm_output": <value returned by call_llm>}``.

        Raises:
            ValueError: If `action` is missing, not numeric or smaller than 1.
            NotImplementedError: If `api_type` is not an OpenAI-style backend.
            Exception: Any other error is logged and re-raised.
        """
        try:
            if action is None:
                raise ValueError("`action` (1-based step index) is required")
            action = int(action)
            if action < 1:
                # A non-positive index would silently wrap to the list's end.
                raise ValueError(f"`action` must be >= 1, got {action}")

            # construct prompt
            instruction = task_config["instruction"]
            actions = trajectory["actions"]
            action_code = actions[action - 1]
            system_prompt = self.prompt["system_prompt"]
            user_prompt = self.prompt["user_prompt"]
            user_prompt1 = user_prompt[0].format(instruction=instruction)
            user_prompt2 = user_prompt[1]
            user_prompt3 = user_prompt[2].format(step_index=action, action_code=action_code)
            screenshots = trajectory["screenshots"]
            screenshots_marker = trajectory["screenshots_marker"]
            assert len(screenshots) == len(actions)

            # prepare messages based on LLM type
            if not (self.api_type in ["openai", "3rd_openai", "3rd_openai_glm", "zhipu"] or "qwen" in self.api_type.lower()):
                print(f"Unsupported LLM type: {self.api_type}")
                raise NotImplementedError(f"Unsupported LLM type: {self.api_type}")

            user_content = [{"type": "text", "text": user_prompt1}]

            def add_text(text):
                """Append a text part and echo it to stdout."""
                user_content.append({"type": "text", "text": text})
                print(user_content[-1])

            def add_image(path):
                """Append an image part for `path` and echo the path."""
                user_content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{self._encode_image(image_path=path, img_scale=img_scale)}",
                        "detail": "high",
                    },
                })
                print(path)

            # History: the whole trajectory, or only the steps before `action`.
            with_actions = "w_action" in eval_mode
            len_history = len(screenshots) if "fulltraj" in eval_mode else action - 1
            for i in range(len_history):
                if with_actions:
                    add_text(actions[i])
                    add_image(screenshots_marker[i])
                else:
                    add_image(screenshots[i])

            # Target step: marker image followed by the step's screenshot.
            add_text(user_prompt2)
            # NOTE(review): for action == 1 the index action-2 is -1 and picks
            # the LAST marker screenshot — confirm that this is intended.
            add_image(screenshots_marker[action - 2])
            add_image(screenshots[action - 1])
            add_text(user_prompt3)

            formatted_prompt = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ]

            analysis = self.call_llm(formatted_prompt, temperature=self.temperature)
            return {"llm_output": analysis}

        except Exception as e:
            # Fixed: the message used to name evaluate_task.
            print(f"Error in evaluate_task_step: {str(e)}")
            raise

