import json
import logging
import pandas as pd
import requests
import time

from enum import StrEnum
from pathlib import Path
from typing import List
from xml.etree import ElementTree

from bencher.matcher.matcher import Matcher as BaseMatcher
from bencher.matcher.matcher import Match
from bencher.matcher.llm.matcher import StructuredResponse
from perceptor.models.detector.base import Detection


logger = logging.getLogger(__name__)


class ModelKind(StrEnum):
    """Represents the GPT models kinds.
    """

    Llama31405BInstruct = "meta-llama/Meta-Llama-3.1-405B-Instruct"
    Llama3370B = "meta-llama/Llama-3.3-70B-Instruct"
 

class Matcher(BaseMatcher):
    """A matcher that asks a hosted Llama chat model to find the matches.

    The detections of all frames are flattened into a table, serialized as
    CSV and posted, together with the query and a system prompt, to the
    Hyperbolic chat-completions endpoint. The reply is parsed as XML and
    converted into `Match` objects.
    """

    # Chat-completions endpoint every request is posted to.
    _ENDPOINT = r"https://api.hyperbolic.xyz/v1/chat/completions"

    # Column names of the detection table, in the order the rows are built.
    _COLUMNS = ["frame", "label", "identifier", "score", "xmin", "ymin", "xmax", "ymax"]

    def __init__(
        self,
        query: str,
        frames: List[List[Detection]],
        key: str,
        attempts: int = 1,
        kind: ModelKind = ModelKind.Llama3370B,
        temperature: float = 0.2,
        timeout: float = 300.0,
    ) -> None:
        """Prepare the detection table, the system prompt and the settings.

        Args:
            query: The query text sent to the model inside ``<query>`` tags.
            frames: One list of detections per frame; the position of a
                frame in this list becomes the ``frame`` column.
            key: The API key sent as the bearer token.
            attempts: How many times `run` tries the request before giving up.
            kind: The model requested from the endpoint.
            temperature: The sampling temperature sent with the request.
            timeout: Seconds passed to ``requests.post`` as its ``timeout``,
                so a stalled connection fails (and is retried) instead of
                blocking forever.

        Raises:
            OSError: If the system prompt file cannot be read.
        """

        self.query = query
        logger.debug(f"query=\"{self.query}\"; frames={len(frames)}")

        # Format the data.
        #
        # The data input to the model is sent as a CSV-styled format to save
        # token space as well as simplify the formatting.
        rows = [
            [
                index,
                detection.label,
                detection.identifier,
                detection.score,
                detection.bbox.xmin,
                detection.bbox.ymin,
                detection.bbox.xmax,
                detection.bbox.ymax,
            ]
            for index, frame in enumerate(frames)
            for detection in frame
        ]
        self.data = pd.DataFrame(rows, columns=self._COLUMNS)

        # The system prompt lives two directories above this file, under
        # data/prompts. Read it with an explicit encoding so the result does
        # not depend on the locale (assumes the prompt is stored as UTF-8).
        path = (
            Path(__file__).parent.parent
            .joinpath("data")
            .joinpath("prompts")
            .joinpath("roleA.v5.txt")
        )
        self.role = path.read_text(encoding="utf-8")

        self.key = key
        self.attempts = attempts
        self.kind = kind
        self.temperature = temperature
        self.timeout = timeout

    def _payload(self) -> dict:
        """Build the JSON body of the chat-completions request."""

        # Square brackets are stripped from the CSV text before sending.
        table = self.data.to_csv(index=False).replace("[", "").replace("]", "")

        return {
            "messages": [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": self.role}
                    ]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"<query>{self.query}</query>"},
                        {"type": "text", "text": f"<data>{table}</data>"},
                    ]
                }
            ],
            "model": self.kind,
            "temperature": self.temperature,
            "max_tokens": 2048,
        }

    @staticmethod
    def _parse_matches(content: str) -> List[Match]:
        """Convert the XML reply of the model into matches.

        Every child of the XML root must contain a ``<start>`` and an
        ``<end>`` element holding integers. An empty reply yields no matches.
        """

        if not content:
            return []

        return [
            Match(
                start=int(node.find("start").text),
                end=int(node.find("end").text)
            )
            for node in ElementTree.fromstring(content)
        ]

    def run(self) -> tuple[float | None, List[Match] | None]:
        """Find the matches.

        The request is tried up to ``self.attempts`` times. Any failure
        (network error, timeout, HTTP error status, unexpected response
        shape, unparsable reply) is logged and counts as one attempt.

        Returns:
            ``(elapsed, matches)`` on success, where ``elapsed`` is the time
            in seconds spent on the request and ``matches`` the parsed
            matches (possibly empty). ``(None, None)`` when every attempt
            failed or ``self.attempts`` is not positive.
        """

        attempts = 0
        while attempts < self.attempts:
            try:
                start = time.perf_counter_ns()
                response = requests.post(
                    self._ENDPOINT,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {self.key}"
                    },
                    json=self._payload(),
                    timeout=self.timeout,
                )

                elapsed = float((time.perf_counter_ns() - start) * 1e-9)

                # Surface HTTP errors (bad key, rate limit, server error) as
                # such instead of as a KeyError on the missing fields below.
                response.raise_for_status()
                body = response.json()

                usage = body["usage"]
                content = body["choices"][0]["message"]["content"]

                # Flatten the reply outside of the f-string: a backslash in
                # an f-string expression is a SyntaxError before Python 3.12.
                flattened = content.replace("\n", "")
                logger.debug(f"tokens(input)={usage['prompt_tokens']}; tokens(output)={usage['completion_tokens']}; tokens(total)={usage['total_tokens']}; response=\"{flattened}\"")

                matches = self._parse_matches(content)

                logger.debug(f"time={elapsed}s; matches={len(matches)}")

                return elapsed, matches

            except Exception as e:
                # Retry boundary: every failure is logged and retried until
                # the attempts are exhausted.
                logger.error(str(e))
                attempts += 1

        logger.warning(f"failed search with {attempts} attempts")
        return None, None
