from transformers.models.fuyu.processing_fuyu import original_to_transformed_w_coords
from openai import OpenAI
from openai import LengthFinishReasonError
import base64
from typing import List, Dict, Any, Optional, Union, Tuple
from utils.api_keys import get_openai_key
from utils.prompts import *
from pydantic import BaseModel

class ExtractCommandOutput(BaseModel):
    # Structured-output schema passed as `response_format` by
    # GPTService.extract_command; that method returns the three fields below
    # as a (current_task, remaining_tasks, task_type) tuple.
    # NOTE(review): documented with comments rather than a docstring because
    # pydantic presumably copies a class docstring into the JSON schema
    # description that is sent with the request — confirm before converting.
    current_task: str  # the part of the prompt split off as the current task
    remaining_tasks: str  # the part of the prompt left over after the split
    task_type: str  # presumably a category label for current_task — verify against utils.prompts

class RewardModelOutput(BaseModel):
    # Structured-output schema passed as `response_format` by
    # GPTService.score_image_similarity, which returns both fields to its caller.
    # NOTE(review): kept as comments rather than a docstring because pydantic
    # presumably copies a class docstring into the JSON schema description
    # sent with the request — confirm before converting.
    score: float  # numeric rating of the generated image; scale is defined by the scoring prompt (not visible here)
    feedback: str  # free-text feedback accompanying the score

class EvalT2IOutput(BaseModel):
    # Structured-output schema passed as `response_format` by
    # GPTService.eval_T2I; the four fields are returned to the caller as a dict
    # with the same keys.
    # NOTE(review): kept as comments rather than a docstring because pydantic
    # presumably copies a class docstring into the JSON schema description
    # sent with the request — confirm before converting.
    # The score scales are defined by the prompts in utils.prompts (not visible here).
    alignment: float
    technical: float
    aesthetic: float
    explanation: str  # free-text justification for the three scores

class EvalI2IOutput(BaseModel):
    # Structured-output schema passed as `response_format` by
    # GPTService.eval_I2I; the four fields are returned to the caller as a dict
    # with the same keys.
    # NOTE(review): kept as comments rather than a docstring because pydantic
    # presumably copies a class docstring into the JSON schema description
    # sent with the request — confirm before converting.
    # The score scales are defined by the prompts in utils.prompts (not visible here).
    alignment: float
    preservation: float
    aesthetic: float
    explanation: str  # free-text justification for the three scores

class GPTService:
    """Wrapper around the OpenAI client for embeddings and structured-output calls.

    Every chat request goes through ``client.beta.chat.completions.parse`` with a
    pydantic model as ``response_format``. Images are attached as base64 data URLs
    in additional user messages.
    """

    # Model used for every structured-output chat request issued by this class.
    _CHAT_MODEL = "o3"
    # ``max_completion_tokens`` values tried in order. A request that ends with
    # LengthFinishReasonError is retried with the next (larger) budget; an error
    # on the last budget propagates to the caller.
    _TOKEN_BUDGETS = (2048, 4096)

    def __init__(self):
        """Create the OpenAI client using the key returned by ``get_openai_key()``."""
        self.client = OpenAI(api_key=get_openai_key())

    @staticmethod
    def _image_message(image_path: str, caption: str) -> Dict[str, Any]:
        """Build a user message with ``caption`` followed by the image as a data URL.

        Raises:
            OSError: if ``image_path`` cannot be opened or read.
        """
        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": caption},
                {
                    "type": "image_url",
                    # The MIME type is always declared as JPEG, whatever the file extension.
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                },
            ],
        }

    def _parse_structured(self, messages: List[Dict[str, Any]], response_format: Any) -> Any:
        """Send ``messages`` and return the first choice's parsed ``response_format`` instance.

        The request is retried with the next entry of ``_TOKEN_BUDGETS`` whenever
        the SDK raises ``LengthFinishReasonError``; the error from the final
        budget is not caught.

        NOTE(review): ``message.parsed`` is presumably ``None`` when the model
        refuses to answer; callers would then fail with ``AttributeError`` when
        reading fields — confirm against the OpenAI SDK documentation.
        """

        def request(budget: int) -> Any:
            completion = self.client.beta.chat.completions.parse(
                model=self._CHAT_MODEL,
                messages=messages,
                response_format=response_format,  # Structured JSON output
                max_completion_tokens=budget,
            )
            return completion.choices[0].message.parsed

        *retryable_budgets, final_budget = self._TOKEN_BUDGETS
        for budget in retryable_budgets:
            try:
                return request(budget)
            except LengthFinishReasonError:
                # The completion hit the token cap; try again with more room.
                continue
        return request(final_budget)

    def get_embedding(self, text: str, model: str = "text-embedding-3-small", dim: int = 1536) -> List[float]:
        """Get text embeddings from OpenAI's embedding model.

        Args:
            text: The input text to embed
            model: The embedding model to use
            dim: The dimension of the embedding vector (default is 1536)
        Returns:
            List of embedding values
        """
        response = self.client.embeddings.create(
            input=text,
            model=model,
            encoding_format="float",
            dimensions=dim,
        )
        return response.data[0].embedding

    def extract_command(self, original_prompt: str, remaining_tasks: str) -> Tuple[str, str, str]:
        """Split the outstanding work into the current task and what remains.

        Args:
            original_prompt: The full user prompt.
            remaining_tasks: The tasks not yet handled; an empty string means
                there is nothing left and no request is made.
        Returns:
            A ``(current_task, remaining_tasks, task_type)`` tuple of strings.
            All three are empty strings when ``remaining_tasks`` is empty.
        """
        if remaining_tasks == "":
            return "", "", ""

        user_message = EXTRACT_COMMAND_USER_MESSAGE.format(
            original_prompt=original_prompt,
            remaining_tasks=remaining_tasks,
        )
        messages = [
            {"role": "system", "content": EXTRACT_COMMAND_SYSTEM_PROMPT.strip()},
            {"role": "user", "content": user_message.strip()},
        ]

        parsed = self._parse_structured(messages, ExtractCommandOutput)
        return parsed.current_task, parsed.remaining_tasks, parsed.task_type

    def score_image_similarity(
            self,
            original_prompt: str,
            current_task: str,
            previous_feedback: str,
            generated_image_path: str,
            original_image_path: Optional[str] = None,
        ) -> Tuple[float, str]:
        """Score a generated image against the prompt, the current task and prior feedback.

        Args:
            original_prompt: The full user prompt.
            current_task: The task the generated image is meant to fulfil.
            previous_feedback: Feedback from an earlier round, inserted into the prompt.
            generated_image_path: Path of the image to score.
            original_image_path: Optional path of the original/previous image.
                If it cannot be read, a warning is printed and scoring continues
                without it.
        Returns:
            A ``(score, feedback)`` tuple. When the generated image cannot be
            read, no request is made and ``(0.0, <error description>)`` is
            returned, so both paths have the same shape.
        """
        # Load the images first so an unreadable generated image returns early,
        # before any prompt is built or request is sent.
        image_messages = []

        # Add provided image if it exists
        if original_image_path:
            try:
                image_messages.append(
                    self._image_message(original_image_path, "Here is the original/previous image:")
                )
            except Exception as e:
                # Best effort: the reference image is optional context.
                print(f"Warning: Failed to read original image {original_image_path}: {e}")

        # Add generated image
        try:
            image_messages.append(
                self._image_message(generated_image_path, "Here is the generated image:")
            )
        except Exception as e:
            print(f"Error: Failed to read generated image {generated_image_path}: {e}")
            return 0.0, f"Failed to read generated image: {e}"

        user_message = SCORING_USER_MESSAGE.format(
            original_prompt=original_prompt,
            current_task=current_task,
            previous_feedback=previous_feedback,
        )
        messages = [
            {"role": "system", "content": SCORING_SYSTEM_PROMPT.strip()},
            {"role": "user", "content": user_message.strip()},
            *image_messages,
        ]

        parsed = self._parse_structured(messages, RewardModelOutput)
        return parsed.score, parsed.feedback

    def eval_I2I(self, instruction: str, input_image_path: str, generated_image_path: str) -> Dict[str, Any]:
        """Evaluate an image-to-image result against its instruction.

        Args:
            instruction: The edit instruction inserted into the user message.
            input_image_path: Path of the input image.
            generated_image_path: Path of the output image.
        Returns:
            Dict with keys ``alignment``, ``preservation``, ``aesthetic`` and
            ``explanation``.
        Raises:
            OSError: if either image cannot be read (not caught here).
        """
        user_message = EVAL_I2I_USER_MESSAGE.format(instruction=instruction)
        # NOTE(review): the system prompt is the T2I one although the user message
        # is the I2I one, and unlike the sibling methods neither prompt is
        # stripped. This looks like a copy-paste slip — confirm whether an
        # I2I-specific system prompt exists in utils.prompts before changing it.
        messages = [
            {"role": "system", "content": EVAL_T2I_SYSTEM_PROMPT},
            {"role": "user", "content": user_message},
            self._image_message(input_image_path, "Here is the input_image:"),
            self._image_message(generated_image_path, "Here is the output_image:"),
        ]

        parsed = self._parse_structured(messages, EvalI2IOutput)
        return {
            "alignment": parsed.alignment,
            "preservation": parsed.preservation,
            "aesthetic": parsed.aesthetic,
            "explanation": parsed.explanation,
        }

    def eval_T2I(self, prompt: str, generated_image_path: str) -> Dict[str, Any]:
        """Evaluate a text-to-image result against its prompt.

        Args:
            prompt: The generation prompt inserted into the user message.
            generated_image_path: Path of the generated image.
        Returns:
            Dict with keys ``alignment``, ``technical``, ``aesthetic`` and
            ``explanation``.
        Raises:
            OSError: if the image cannot be read (not caught here).
        """
        user_message = EVAL_T2I_USER_MESSAGE.format(prompt=prompt)
        messages = [
            {"role": "system", "content": EVAL_T2I_SYSTEM_PROMPT.strip()},
            {"role": "user", "content": user_message.strip()},
            self._image_message(generated_image_path, "Here is the generated image:"),
        ]

        parsed = self._parse_structured(messages, EvalT2IOutput)
        return {
            "alignment": parsed.alignment,
            "technical": parsed.technical,
            "aesthetic": parsed.aesthetic,
            "explanation": parsed.explanation,
        }


# if __name__ == "__main__":
#     # example usage
#     from experts import create_experts
    
#     key = get_openai_key()
    # experts = create_experts()
    # gpt_service = GPTService()
    # print(gpt_service.get_embedding("Hello, world!"))
    # print(gpt_service.expert_selector("Generate an image of a man singing using the given face as a reference. Also add a black background and remove the curtains."))
    # print(gpt_service.score_image_similarity("An astronaut riding a horse in an painting style", "./images/output.jpg"))
    # print(gpt_service.extract_structured_command("model_3", "Generate an image of a man singing using the given face as a reference. Also add a black background and remove the curtains."))
