import re
from enum import StrEnum
from io import BytesIO
from pathlib import Path

from PIL import Image

from llms.llm_utils import call_llm
from utils.osw_utils import annotate_action_on_image
from verifier.prompts import (SECOND_PASS_KNOWLEDGE_PROMPT,
                              SECOND_PASS_OBJECTIVE_PROMPT,
                              SECOND_PASS_REQUEST_PROMPT,
                              SECOND_PASS_SYSTEM_PROMPT)


class Evaluation(StrEnum):
    """Verdict labels for the second-pass verification of an agent run.

    ``parse_response_text`` builds a member by value from the text of the
    ``EVALUATION`` section of the model response, so each value below must
    match that text exactly.
    """

    SUCCESS = 'SUCCESS'
    # The value contains a space (not an underscore), unlike the member name.
    PARTIAL_SUCCESS = 'PARTIAL SUCCESS'
    FAILURE = 'FAILURE'


def get_actions(thoughts: list[str]) -> list[str]:
    """Extract the ``Action: ...`` part of every agent thought.

    For each thought, everything from the first occurrence of ``'Action:'``
    to the end of the string is kept. Any action that contains
    ``'finished('`` is collapsed to the fixed string ``'Action: finished()'``,
    discarding whatever arguments were passed to ``finished``.

    Args:
        thoughts: Raw thought strings, one per agent step.

    Returns:
        One action string per input thought, in the same order. A thought
        without an ``'Action:'`` marker is returned unchanged (subject to the
        ``finished(`` collapsing above).
    """
    actions = []
    for thought in thoughts:
        _, marker, remainder = thought.partition('Action:')
        # Without a marker, ``str.find`` would return -1 and slicing from -1
        # would keep only the last character of the thought. Fall back to
        # the whole thought so that no information is silently dropped.
        action = marker + remainder if marker else thought
        if 'finished(' in action:
            action = 'Action: finished()'
        actions.append(action)
    return actions


def get_prompt_messages(objective: str, screenshots: list[bytes],
                        thoughts: list[str],
                        first_pass_knowledge: str) -> list:
    """Assemble the message list for the second-pass verifier prompt.

    The list is laid out as follows:

    1. a ``system`` role dict holding the stripped system prompt;
    2. the objective prompt formatted with ``objective``;
    3. for every (screenshot, action) pair, a two-item list made of a
       ``## STATE t-<k> screenshot`` caption and the annotated screenshot,
       followed by an ``assistant`` role dict holding the stripped action;
    4. the knowledge prompt formatted with ``first_pass_knowledge``;
    5. the request prompt.

    Args:
        objective: Task objective inserted into the objective prompt.
        screenshots: Encoded image bytes, opened with ``PIL.Image.open``.
        thoughts: Agent thoughts from which actions are extracted with
            ``get_actions``.
        first_pass_knowledge: Text inserted into the knowledge prompt.

    Returns:
        The list of messages (dicts, strings and lists) described above.
    """
    system_message = {
        'role': 'system',
        'content': SECOND_PASS_SYSTEM_PROMPT.strip(),
    }
    objective_message = SECOND_PASS_OBJECTIVE_PROMPT.strip().format(
        objective=objective)
    messages = [system_message, objective_message]

    # The caption counts down from len(thoughts); pairing stops at the
    # shorter of ``screenshots`` and the extracted actions.
    total_steps = len(thoughts)
    paired_steps = zip(screenshots, get_actions(thoughts))
    for step, (raw_image, action_text) in enumerate(paired_steps):
        image = Image.open(BytesIO(raw_image))
        annotated = annotate_action_on_image(image, action_text)
        caption = f'## STATE t-{total_steps - step} screenshot'
        messages += [
            [caption, annotated],
            {'role': 'assistant', 'content': action_text.strip()},
        ]

    messages.append(SECOND_PASS_KNOWLEDGE_PROMPT.strip().format(
        knowledge=first_pass_knowledge))
    messages.append(SECOND_PASS_REQUEST_PROMPT.strip())
    return messages


def parse_response_text(response_text: str) -> tuple[Evaluation, str]:
    """Split a verifier response into its verdict and feedback.

    The response is scanned for sections that start at the beginning of a
    line with a header made of uppercase letters and spaces followed by a
    colon. A section body runs until the next such header or the end of the
    text and is stored stripped. If a header occurs more than once, the last
    occurrence wins.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        A tuple of the ``Evaluation`` built from the ``EVALUATION`` section
        and the text of the ``FEEDBACK`` section.

    Raises:
        KeyError: If the ``EVALUATION`` or ``FEEDBACK`` section is missing.
        ValueError: If the ``EVALUATION`` text is not a valid ``Evaluation``
            value.
    """
    section_matches = re.finditer(
        r'^([A-Z ]+):\s*(.*?)(?=^[A-Z ]+:|\Z)',
        response_text,
        flags=re.DOTALL | re.MULTILINE,
    )
    sections = {}
    for found in section_matches:
        header, body = found.group(1, 2)
        sections[header] = body.strip()

    # Resolve the verdict before the feedback so a malformed response fails
    # on the EVALUATION section first.
    verdict = Evaluation(sections['EVALUATION'])
    return verdict, sections['FEEDBACK']


def get_second_pass_evaluation(objective: str, screenshots: list[bytes],
                               thoughts: list[str], first_pass_knowledge: str,
                               run_results_path: Path, domain: str,
                               task_id: str) -> tuple[Evaluation, str, str]:
    """Run the second verification pass and parse the model's verdict.

    Builds the prompt with ``get_prompt_messages``, sends it through
    ``call_llm`` with fixed generation settings, and parses the first
    returned text with ``parse_response_text``.

    Args:
        objective: Task objective forwarded to the prompt builder.
        screenshots: Encoded screenshots forwarded to the prompt builder.
        thoughts: Agent thoughts forwarded to the prompt builder.
        first_pass_knowledge: First-pass knowledge forwarded to the prompt
            builder.
        run_results_path: Directory whose ``conversations`` and ``usage``
            sub-paths are handed to ``call_llm``.
        domain: Used together with ``task_id`` to form the call id.
        task_id: Used together with ``domain`` to form the call id.

    Returns:
        A tuple ``(evaluation, feedback, text)`` where ``text`` is the
        stripped response text that was parsed.

    Raises:
        KeyError, ValueError: Propagated from ``parse_response_text`` when
            the response does not have the expected sections or verdict.
    """
    prompt = get_prompt_messages(objective, screenshots, thoughts,
                                 first_pass_knowledge)
    generation_settings = {
        'model': 'gemini-2.5-flash-preview-04-17',
        'thinking_budget': 0,
        'temperature': 0.5,
        'top_p': 0.01,
        'top_k': 40,
    }
    # Only the text outputs are needed; the first return value is ignored.
    _, texts = call_llm(
        gen_kwargs=generation_settings,
        prompt=prompt,
        conversation_dir=str(run_results_path / 'conversations'),
        usage_dir=str(run_results_path / 'usage'),
        call_id=f'{domain}_{task_id}_verifier_second_pass',
        dump_txt=False,
    )
    text = texts[0].text().strip()
    evaluation, feedback = parse_response_text(text)
    return evaluation, feedback, text
