
import json
from typing import Dict, Union, List, Union
import torch
import numpy as np
from PIL import Image
from copy import deepcopy
from verifiers.DSG.dsg.vqa_utils import InstructBLIP
import typing_extensions as typing
from tqdm import tqdm
import requests
import re
import base64
from pydantic import BaseModel, conlist
import concurrent.futures
import os
from google import genai
from google.genai import types, errors
from utils import preload_ollama, base_post_args
import time


# Grading instructions sent to the VLM for every image. The only placeholder is
# ``{prompt}``, filled with ``str.format``; any literal braces added to this text
# must therefore be doubled (``{{``/``}}``) or formatting will fail.
EVAL_PROMPT = (
    "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model.\n"
    "Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. \n"
    "The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. \n"
    "Below is a comprehensive guide to follow in your evaluation process: \n\n"
    "1. Key Evaluation Aspects and Scoring Criteria: \n"
    "For each aspect, provide a score from 0 to 100, where 0 represents poor performance and 100 represents excellent performance. \n"
    "For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. \n"
    "The aspects to evaluate are as follows: \n"
    "a) Accuracy to Prompt \n"
    "Assess how well the image matches the description given in the prompt.\n"
    "Consider whether all requested elements are present and if the scene, objects, and setting align accurately with the text. \n"
    "Score: 0 (no alignment) to 100 (perfect match to prompt).\n"
    "b) Creativity and Originality \n"
    "Evaluate the uniqueness and creativity of the generated image. \n"
    "Does the model present an imaginative or aesthetically engaging interpretation of the prompt?  \n"
    "Is there any evidence of creativity beyond a literal interpretation? \n"
    "Score: 0 (lacks creativity) to 100 (highly creative and original). \n"
    "c) Visual Quality and Realism \n"
    "Assess the overall visual quality, including resolution, detail, and realism.\n"
    "Look for coherence in lighting, shading, and perspective.\n"
    "Even if the image is stylized or abstract, judge whether the visual elements are well-rendered and visually appealing.\n"
    "Score: 0 (poor quality) to 100 (high-quality and realistic). \n"
    "d) Consistency and Cohesion\n"
    "Check for internal consistency within the image.\n"
    "Are all elements cohesive and aligned with the prompt? \n"
    "For instance, does the perspective make sense, and do objects fit naturally within the scene without visual anomalies?\n"
    "Score: 0 (inconsistent) to 100 (fully cohesive and consistent). \n"
    "e) Emotional or Thematic Resonance \n"
    "Evaluate how well the image evokes the intended emotional or thematic tone of the prompt. \n"
    "For example, if the prompt is meant to be serene, does the image convey calmness? \n"
    "If it’s adventurous, does it evoke excitement? \n"
    "Score: 0 (no resonance) to 100 (strong resonance with the prompt’s theme).\n\n"

    "2. Overall Score\n"
    "After scoring each aspect individually, provide an overall score, representing the model’s general performance on this image.\n"
    # The explicit paragraph break keeps this sentence from running into
    # "Now grade the image..." in the rendered prompt.
    "This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.\n\n"

    "Now grade the image based on the above criteria. \n"
    "Below is the prompt: \n"
    "Prompt: {prompt} \n\n"
)


# Grade for a single evaluation aspect: a short free-text justification and a
# numeric score (EVAL_PROMPT asks for 0-100). Both keys are required.
# Declared with the functional TypedDict form and deliberately left without a
# docstring so the JSON schema generated from it stays unchanged.
Score = typing.TypedDict("Score", {"explanation": str, "score": float})


# Structured grade the VLM must return for one image. VLMVerifier passes this
# model's JSON schema to both backends (Ollama ``format`` / GenAI
# ``response_schema``) and validates replies with it. Field names and order
# therefore shape the schema the model sees. Documentation is kept in comments
# rather than a docstring because pydantic places a model docstring into the
# JSON-schema "description", which would change what is sent to the model.
class Grading(BaseModel):
    # Aspects a)-e) of EVAL_PROMPT, each an explanation + numeric score.
    accuracy_to_prompt: Score
    creativity_and_originality: Score
    visual_quality_and_realism: Score
    consistency_and_cohesion: Score
    emotional_or_thematic_resonance: Score
    # Aggregate grade; VLMVerifier.score uses its ``score`` as the reward.
    overall_score: Score


class VLMVerifier:
    """Grade text-to-image outputs with a vision-language model (VLM).

    Each ``(prompt, image)`` pair is sent to the VLM together with
    ``EVAL_PROMPT`` and the reply is validated against the ``Grading`` schema.
    Two backends are supported:

    * a local Ollama server (default), reached through ``base_post_args``;
    * the Google GenAI API, when ``verifier_args['use_genai_api']`` is truthy.

    The per-image reward is the ``score`` of ``Grading.overall_score``.
    """

    nickname = "vlmscore"
    SUPPORTED_METRIC_CHOICES = [
        "score",
    ]

    # Attempts per image in ``score`` before the last error is re-raised.
    MAX_TRIES = 10
    # Seconds to wait after a GenAI ``ClientError`` before the next attempt.
    CLIENT_ERROR_BACKOFF_S = 5

    def __init__(self, verifier_args, device: str = None):
        """
        Parameters
        ----------
        verifier_args : dict
            Must contain ``model_name``. Optional keys: ``port`` (Ollama port,
            default 11434) and ``use_genai_api`` (default False). When
            ``use_genai_api`` is truthy a GenAI client is created from the
            ``GENAI_API_KEY`` environment variable; otherwise
            ``preload_ollama`` is called for the configured port and model.
        device : str, optional
            Torch device string; defaults to "cuda" when available, else
            "cpu". Stored as ``self.device``; the grading calls in this class
            do not use it.

        Raises
        ------
        KeyError
            If ``model_name`` is missing, or if ``use_genai_api`` is set but
            ``GENAI_API_KEY`` is not in the environment.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model_name = verifier_args['model_name']
        self.port = verifier_args.get('port', 11434)
        self.genai_client = None
        if verifier_args.get('use_genai_api', False):
            if 'GENAI_API_KEY' not in os.environ:
                # Same exception type as the plain lookup, but with a clearer message.
                raise KeyError(
                    "GENAI_API_KEY environment variable must be set when use_genai_api is enabled"
                )
            self.genai_client = genai.Client(
                api_key=os.environ['GENAI_API_KEY'],
            )
        else:
            preload_ollama(port=self.port, model_name=self.model_name)

    def prepare_inputs(
        self,
        images: Union[list[str], str],
        prompts: Union[list[str], str],
        **kwargs,
    ):
        """Pair prompts with image paths for :meth:`score`.

        A single string for either argument is treated as a one-element list.

        Returns
        -------
        list[tuple[str, str]]
            ``(prompt, image_path)`` pairs in input order.

        Raises
        ------
        AssertionError
            If the two lists differ in length or any image is not a path string.
        """
        images = images if isinstance(images, list) else [images]
        prompts = prompts if isinstance(prompts, list) else [prompts]
        assert len(images) == len(prompts), "images and prompts must have the same length"
        # Check every entry (not only the first) so a bad item fails here
        # rather than later inside a worker thread.
        assert all(isinstance(image, str) for image in images), "images must be the path of the image"
        return list(zip(prompts, images))

    @torch.inference_mode()
    def score(self, inputs: list[tuple[str, Union[str, Image.Image]]], prompt_idx, verbose=False, **kwargs):
        """Grade every ``(prompt, image)`` pair concurrently.

        Parameters
        ----------
        inputs : list[tuple[str, str | PIL.Image.Image]]
            Pairs as produced by :meth:`prepare_inputs`. In-memory PIL images
            are also accepted and are sent to the VLM as PNG.
        prompt_idx, verbose, **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        list[dict]
            One dict per input, in input order, with keys ``details`` (the
            ``Grading`` fields as a dict), ``score`` (the overall score) and
            ``reward`` (same value as ``score``).

        Raises
        ------
        Exception
            The last error for any pair that failed ``MAX_TRIES`` attempts.
        """

        def score_single_image(prompt, image_data):
            query = EVAL_PROMPT.format(prompt=prompt)
            max_tries = max(1, self.MAX_TRIES)
            for attempt in range(1, max_tries + 1):
                try:
                    parsed_response = dict(self.query_llm(query, image_data))
                    overall = parsed_response['overall_score']['score']
                except Exception as e:
                    is_client_error = isinstance(e, errors.ClientError)
                    if is_client_error:
                        print(f"ClientError: {e.code} - {e.message}")
                    if attempt >= max_tries:
                        raise
                    if is_client_error:
                        # Back off only when another attempt will follow.
                        time.sleep(self.CLIENT_ERROR_BACKOFF_S)
                    print(f"Error: {e}")
                    print("Retrying...")
                    continue
                # 'reward' mirrors 'score' for compatibility with other verifiers.
                return {'details': parsed_response, 'score': overall, 'reward': overall}

        # Grading calls are network-bound, so threads overlap the waits.
        # Futures are collected in submission order, which keeps results
        # aligned with ``inputs``.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(score_single_image, prompt, image_data)
                for prompt, image_data in inputs
            ]
            return [future.result() for future in futures]

    @staticmethod
    def _load_image_bytes(image):
        """Return ``(raw_bytes, mime_type)`` for an image path or a PIL image.

        PIL images are encoded as PNG. For paths, the MIME type is guessed from
        the file extension and falls back to ``image/png`` when the guess is
        missing or not an image type.
        """
        if isinstance(image, Image.Image):
            import io
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue(), "image/png"

        import mimetypes
        with open(image, 'rb') as img_file:
            data = img_file.read()
        mime_type = mimetypes.guess_type(image)[0]
        if mime_type is None or not mime_type.startswith("image/"):
            mime_type = "image/png"
        return data, mime_type

    def query_llm(self, query, image_path):
        """Send ``query`` plus one image to the configured VLM.

        Parameters
        ----------
        query : str
            Fully formatted grading prompt.
        image_path : str or PIL.Image.Image
            Path to an image file, or an in-memory PIL image.

        Returns
        -------
        Grading
            The validated structured reply.

        Raises
        ------
        requests.HTTPError
            Ollama backend: the server answered with a non-success status.
        ValueError
            The reply carried no text to parse.
        pydantic.ValidationError
            The reply text does not match the ``Grading`` schema.
        """
        img_data, mime_type = self._load_image_bytes(image_path)
        if self.genai_client is None:
            return self._query_ollama(query, img_data)
        return self._query_genai(query, img_data, mime_type)

    def _query_ollama(self, query, img_data):
        """Grade via the Ollama HTTP API, constraining output to the Grading schema."""
        post_args = base_post_args(model_name=self.model_name, port=self.port)
        post_args['json']['prompt'] = query
        post_args['json']["format"] = Grading.model_json_schema()
        post_args['json']['images'] = [base64.b64encode(img_data).decode('utf-8')]

        raw_response = requests.post(**post_args)
        if not raw_response.ok:
            # Include the body: error replies carry the reason in their payload.
            raise requests.HTTPError(
                f"Ollama request failed with HTTP {raw_response.status_code}: {raw_response.text}",
                response=raw_response,
            )
        body = raw_response.json()
        if not isinstance(body, dict) or 'response' not in body:
            raise ValueError(f"Unexpected Ollama reply without a 'response' field: {body}")
        try:
            return Grading.model_validate_json(body['response'])
        except Exception:
            # Show what the model actually produced before propagating.
            print(f"Response: {body['response']}")
            raise

    def _query_genai(self, query, img_data, mime_type):
        """Grade via the Google GenAI API with a JSON response schema."""
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_bytes(
                        mime_type=mime_type,
                        data=img_data,
                    ),
                    types.Part.from_text(text=query),
                ],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=Grading,
        )
        raw_response = self.genai_client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=generate_content_config,
        )
        if raw_response.text is None:
            raise ValueError("GenAI API returned no text content to parse.")
        return Grading.model_validate_json(raw_response.text)

    def aggregate_to_one(self, results: List[Dict], method='mean') -> Dict:
        '''
        Aggregate per-image results from :meth:`score` into one dict.

        results: List[Dict], each result is a dict with keys:
            - details: Dict, the details of the result (Grading as dict format)
            - score: float, the score of the result
        method: aggregation method; only 'mean' is supported.

        Returns a dict with:
            - reward, score: mean of the per-result ``score`` values
            - details: every Grading aspect except ``overall_score`` mapped
              to the mean of its ``score`` across results
        '''
        assert len(results) > 0, "results should not be empty"
        # Validate the method before doing any work.
        assert method == 'mean', "only mean is supported for now"

        def mean(values):
            return sum(values) / len(values)

        # Aspect keys follow the Grading schema. overall_score is skipped
        # because it is already reported as the top-level score.
        per_aspect = {name: [] for name in Grading.model_fields if name != 'overall_score'}
        for single_result in results:
            for name, detail in single_result['details'].items():
                if name != 'overall_score':
                    per_aspect[name].append(detail['score'])

        overall = [single_result['score'] for single_result in results]
        return {
            "reward": mean(overall),
            "score": mean(overall),
            "details": {name: mean(values) for name, values in per_aspect.items()},
        }

    def parse_llm_json(self, resp: Union[str, List[str]]) -> Dict:
        """
        Parse JSON objects returned by an LLM, tolerant to Markdown code‑fences
        and line‑split inputs.

        Parameters
        ----------
        resp : str | list[str]
            Raw LLM reply (single string or list of lines).

        Returns
        -------
        dict
            Parsed JSON object.

        Raises
        ------
        ValueError
            If no JSON fragment is found or JSON is invalid.
        """

        text = "\n".join(resp) if isinstance(resp, list) else resp

        # Drop the backticks of fenced blocks so only their content remains.
        text = re.sub(r"```[\s\S]*?```",
                      lambda m: m.group(0).lstrip("`").rstrip("`"),
                      text, flags=re.M)

        # Take the outermost {...} span; this also skips any fence language tag.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end == -1 or end < start:
            raise ValueError("JSON object not found in response.")
        json_str = text[start:end + 1]

        # Remove trailing commas before a closing brace/bracket, which json rejects.
        json_str_clean = re.sub(r",\s*([}\]])", r"\1", json_str)
        return json.loads(json_str_clean)

    def check_parsed_results(self, result) -> bool:
        """
        Check whether a parsed result looks like a complete grading.

        Accepted shapes:
            - the ``Grading`` format produced by :meth:`query_llm`, with every
              Grading field present and mapping to a dict whose ``score`` is
              numeric;
            - the legacy format with a numeric top-level ``'Overall Score'``.
        """
        if not isinstance(result, dict):
            return False

        if isinstance(result.get('Overall Score'), (int, float)):
            return True

        return all(
            isinstance(result.get(name), dict)
            and isinstance(result[name].get('score'), (int, float))
            for name in Grading.model_fields
        )
