# Model is a simple class that handles the API requests to different end services.
import json

# import Model class
from openai import OpenAI
from src.entity.models.Model import Model
from credentials import OPENAI_CREDENTIAL
from pydantic import BaseModel

class Llama3(Model):
    """Chat client for a ``llama3.1-8b`` model behind an OpenAI-compatible API.

    Every request is sent to ``http://0.0.0.0:8080/v1`` through the
    ``openai`` client, and the reply is decoded from JSON before it is
    returned to the caller.
    """
    api_key: str

    def __init__(self, ctx_len: int=4096):
        """Store the settings used for every request.

        Args:
            ctx_len: Value sent as ``max_tokens`` on each completion request.
        """
        # A fixed placeholder is used; nothing is read from the environment.
        self.api_key = 'EMPTY'
        self.ctx_len = ctx_len

    def interact(self, messages: str, **kwargs):
        """Send a chat request and return the reply decoded from JSON.

        Args:
            messages: Either a single prompt string, which is wrapped as one
                ``user`` message, or an already-built message list, which is
                passed through unchanged.
            **kwargs: Must contain ``json_format`` and ``temperature``.
                ``json_format`` may be a pydantic ``BaseModel`` subclass
                (converted to a ``json_object`` response format carrying its
                JSON schema) or a ready-made response-format value that is
                forwarded as-is. ``temperature`` is the sampling temperature
                sent with the request.

        Returns:
            The object produced by ``json.loads`` on the first choice's
            message content.

        Raises:
            ValueError: If ``json_format`` or ``temperature`` is missing, if
                the reply carries no content, or if the content is not valid
                JSON (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        json_format = kwargs.get("json_format")
        temperature = kwargs.get("temperature")

        if json_format is None:
            raise ValueError("json_format is required")
        if temperature is None:
            raise ValueError("temperature is required")

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # ``issubclass`` raises TypeError for non-class arguments, so check
        # that we were handed a class before testing for a pydantic model.
        # Anything else (e.g. a response-format dict) is forwarded untouched.
        if isinstance(json_format, type) and issubclass(json_format, BaseModel):
            json_format = {
                "type": "json_object",
                "schema": json_format.model_json_schema(),
            }

        client = OpenAI(api_key=self.api_key, base_url="http://0.0.0.0:8080/v1")

        prompt_input = {
            "model": "llama3.1-8b",
            "messages": messages,
            # Honour the caller's required temperature instead of a fixed one.
            "temperature": temperature,
            "n": 1,
            "max_tokens": self.ctx_len,
            "response_format": json_format,
        }

        response = client.beta.chat.completions.parse(**prompt_input)

        content = response.choices[0].message.content
        if content is None:
            raise ValueError("model returned no message content")

        # Newlines are flattened to spaces before decoding; presumably this
        # tolerates raw line breaks in the model output -- confirm with the
        # serving setup before changing.
        return json.loads(content.replace('\n', ' '))