import logging
import time
from enum import StrEnum
from pathlib import Path
from typing import List, Optional, Tuple
from xml.etree import ElementTree

import pandas as pd
from anthropic import Anthropic

from bencher.matcher.matcher import Match
from bencher.matcher.matcher import Matcher as BaseMatcher
from perceptor.models.detector.base import Detection


# Logger named after this module so callers can tune its verbosity.
logger = logging.getLogger(name=__name__)

class ModelKind(StrEnum):
    """Represents the Claude model kinds.
    """
    
    ClaudeSonnet = "claude-3-7-sonnet-20250219"
    ClaudeHaiku = "claude-3-5-haiku-20241022"
    ClaudeOpus = "claude-3-opus-20240229"


class Matcher(BaseMatcher):
    """A matcher that asks an Anthropic Claude model to locate matches.

    The detections of all frames are flattened into a table and sent as CSV,
    together with a free-text query, to the Anthropic Messages API. The
    system prompt is read from ``data/prompts/roleA.v5.txt`` under the parent
    of this module's directory. The model is expected to reply with an XML
    document whose child elements each hold integer ``<start>`` and ``<end>``
    elements; each child becomes one ``Match``.
    """

    # Output-token budget requested when no override applies.
    _DEFAULT_MAX_TOKENS = 8192

    # Models whose output limit is below the default. Claude 3 Opus caps its
    # output at 4096 tokens, so asking it for 8192 makes every request fail.
    _MAX_TOKENS = {ModelKind.ClaudeOpus: 4096}

    def __init__(self, query: str, frames: List[List[Detection]], key: str,
                 attempts: int = 1, kind: ModelKind = ModelKind.ClaudeSonnet, temperature: float = 0.2,
                 max_tokens: Optional[int] = None) -> None:
        """Prepare the detection table, the system prompt and the API client.

        Args:
            query: Free-text query, sent to the model inside ``<query>`` tags.
            frames: Detections grouped per frame; the position of a frame in
                this list is used as its frame number.
            key: Anthropic API key.
            attempts: Maximum number of requests ``run`` makes before giving up.
            kind: Claude model to query.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Output-token budget per request. ``None`` selects a
                per-model default (8192, or 4096 for Claude 3 Opus).

        Raises:
            OSError: If the prompt file cannot be read.
        """

        self.query = query
        logger.debug('query="%s"; frames=%d', self.query, len(frames))

        # The data is sent to the model in a CSV-styled format to save token
        # space and keep the formatting simple: one row per detection.
        rows = [
            [
                index,
                detection.label,
                detection.identifier,
                detection.score,
                detection.bbox.xmin,
                detection.bbox.ymin,
                detection.bbox.xmax,
                detection.bbox.ymax,
            ]
            for index, frame in enumerate(frames)
            for detection in frame
        ]
        self.data = pd.DataFrame(
            rows,
            columns=["frame", "label", "identifier", "score", "xmin", "ymin", "xmax", "ymax"],
        )

        path = Path(__file__).parent.parent / "data" / "prompts" / "roleA.v5.txt"
        self.role = path.read_text(encoding="utf-8")

        self.attempts = attempts
        self.kind = kind
        self.temperature = temperature
        self.max_tokens = (
            max_tokens
            if max_tokens is not None
            else self._MAX_TOKENS.get(kind, self._DEFAULT_MAX_TOKENS)
        )

        # Client used by ``run`` to send requests to the Anthropic API.
        self.model = Anthropic(api_key=key)

    @staticmethod
    def _parse_matches(text: str) -> List[Match]:
        """Convert the model's XML reply into ``Match`` objects.

        Each child of the root element must contain integer ``<start>`` and
        ``<end>`` elements. An empty string yields no matches.

        Raises:
            xml.etree.ElementTree.ParseError: If ``text`` is not well-formed XML.
            AttributeError: If a child lacks ``<start>`` or ``<end>``.
            TypeError, ValueError: If a ``<start>``/``<end>`` value is missing
                or not an integer.
        """
        if not text:
            return []
        # NOTE: the reply is model-generated, i.e. untrusted; the xml.etree
        # parsers are not hardened against maliciously constructed XML.
        root = ElementTree.fromstring(text)
        return [
            Match(start=int(item.find("start").text), end=int(item.find("end").text))
            for item in root
        ]

    def run(self) -> Tuple[Optional[float], Optional[List[Match]]]:
        """Query the model and parse its matches, retrying on failure.

        Up to ``self.attempts`` requests are made. Any exception raised while
        calling the API or parsing its reply is logged and counts as one
        failed attempt.

        Returns:
            ``(elapsed, matches)`` on success. ``elapsed`` is the duration of
            the API call in seconds. ``matches`` is the list of parsed
            ``Match`` objects, empty when the model replied with no text.
            ``(None, None)`` when every attempt failed.
        """

        # The request payload is identical across attempts, so build it once.
        # Square brackets are stripped from the CSV, presumably so list-like
        # cell values reach the model as plain comma-separated text (TODO confirm).
        data = self.data.to_csv(index=False).replace("[", "").replace("]", "")
        content = [
            {"type": "text", "text": f"<query>{self.query}</query>"},
            {"type": "text", "text": f"<data>{data}</data>"},
        ]

        attempts = 0
        while attempts < self.attempts:
            try:
                start = time.perf_counter_ns()
                response = self.model.messages.create(
                    model=self.kind,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    system=self.role,
                    messages=[{"role": "user", "content": content}],
                )
                elapsed = (time.perf_counter_ns() - start) * 1e-9

                # A reply without content blocks is treated like an empty reply.
                raw = response.content[0].text if response.content else ""
                usage = response.usage
                logger.debug(
                    'tokens(input)=%s; tokens(output)=%s; tokens(total)=%s; response="%s"',
                    usage.input_tokens,
                    usage.output_tokens,
                    usage.input_tokens + usage.output_tokens,
                    raw.replace("\n", ""),
                )

                # Whitespace-only replies mean "no matches", like empty ones.
                matches = self._parse_matches(raw.strip())

                logger.debug("time=%ss; matches=%d", elapsed, len(matches))

                return elapsed, matches

            except Exception as e:
                attempts += 1
                logger.error("attempt %d/%d failed: %s", attempts, self.attempts, e)

        logger.warning("failed search with %d attempts", attempts)
        return None, None
