from src.searchlight.utils import AbstractLogged
from abc import ABC, abstractmethod
import openai
import tiktoken
from typing import Optional
import os


class BaseMessage(ABC):
    '''
    Abstract class for different types of messages (System, Human, AI).

    The message text is stored in ``content``; each concrete subclass
    supplies its chat role string via ``get_role``.
    '''
    def __init__(self, content: str):
        self.content = content

    def __repr__(self) -> str:
        # Readable form such as ``HumanMessage(content='hi')`` so that lists of
        # messages interpolated into log output (e.g. the debug log in
        # LLMModel.generate) show their contents instead of object addresses.
        return f"{type(self).__name__}(content={self.content!r})"

    @abstractmethod
    def get_role(self) -> str:
        '''Return the role string sent for this message in chat API requests.'''
        pass


class SystemMessage(BaseMessage):
    '''Message whose chat role is ``"system"``.'''

    def get_role(self) -> str:
        '''Return the role tag for system messages.'''
        role = "system"
        return role


class HumanMessage(BaseMessage):
    '''Message from the human side of the conversation; chat role ``"user"``.'''

    def get_role(self) -> str:
        '''Return the role tag for human-authored messages.'''
        role = "user"
        return role


class AIMessage(BaseMessage):
    '''Message from the model side of the conversation; chat role ``"assistant"``.'''

    def get_role(self) -> str:
        '''Return the role tag for model-authored messages.'''
        role = "assistant"
        return role


class LLMModel(AbstractLogged):
    '''
    Base class for LLM backends.

    ``generate`` is the public entry point: it bumps the call counter, writes
    the inputs and the reply to ``self.logger`` at DEBUG level, and delegates
    the actual work to the subclass hook ``_generate``. Subclasses must also
    implement the token-count getters.
    '''
    def __init__(self) -> None:
        # Number of generate() invocations made on this instance so far.
        self.num_calls = 0
        super().__init__()

    @abstractmethod
    def _generate(self, messages: list[BaseMessage], temperature: float = 0.7) -> str:
        '''Backend hook: return the model's reply to ``messages``.'''
        pass

    def generate(self, messages: list[BaseMessage], temperature: float = 0.7) -> str:
        '''
        Produce a reply for a conversation.

        Args:
            messages: the conversation so far, in order
            temperature: sampling temperature forwarded to ``_generate``

        Returns:
            the reply text produced by ``_generate``
        '''
        self.num_calls += 1
        self.logger.debug(f"Generating response to input messages: \n {messages}")
        reply = self._generate(messages, temperature)
        self.logger.debug(f"Generated response: \n {reply}")
        return reply

    def get_num_calls(self) -> int:
        '''Return how many times ``generate`` has been called on this instance.'''
        return self.num_calls

    @abstractmethod
    def get_num_total_tokens(self) -> int:
        '''Return the total token count as tracked by the subclass.'''
        pass

    @abstractmethod
    def get_num_output_tokens(self) -> int:
        '''Return the output (generated) token count as tracked by the subclass.'''
        pass

class OpenAIModel(LLMModel):
    '''
    LLM backend that calls the OpenAI chat-completions API.

    Token usage is estimated locally with tiktoken by encoding the raw text
    content of the input messages and of the reply; per-message overhead added
    by the chat format is not included in these counts.
    '''
    def __init__(self, model_name: str, api_key: Optional[str] = None) -> None:
        '''
        Args:
            model_name: OpenAI model to call; also used to select the tokenizer
            api_key: API key; when None it is read from the OPENAI_API_KEY
                environment variable
        '''
        super().__init__()
        self.model_name = model_name
        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")
        self.api_key = api_key
        self.client = openai.OpenAI(api_key=self.api_key)
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # tiktoken raises KeyError when it cannot map the model name to a
            # tokenizer; fall back to the o200k_base encoding in that case.
            self.encoding = tiktoken.get_encoding("o200k_base")
        # Token counts for the most recent _generate call only; each call
        # overwrites them.
        # NOTE(review): confirm callers expect per-call rather than cumulative
        # counts (num_calls on the base class is cumulative).
        self.total_tokens = 0
        self.output_tokens = 0

    def _generate(self, messages: list[BaseMessage], temperature: float = 0.7) -> str:
        '''
        Send ``messages`` to the chat-completions endpoint and return the reply.

        Also sets ``total_tokens`` (input + output) and ``output_tokens`` to
        locally estimated counts for this call.
        '''
        formatted_messages = [
            {"role": message.get_role(), "content": message.content}
            for message in messages
        ]

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=formatted_messages,
            temperature=temperature,
        )

        # The API may return content=None (e.g. tool-call-only or refusal
        # responses); normalise to "" so token counting cannot crash and the
        # declared str return type holds.
        assistant_message = response.choices[0].message.content or ""

        input_tokens = self._count_tokens(messages)
        output_tokens = self._count_text_tokens(assistant_message)

        self.total_tokens = input_tokens + output_tokens
        self.output_tokens = output_tokens

        return assistant_message

    def _count_tokens(self, messages: list[BaseMessage]) -> int:
        '''Return the summed token count of the messages' text content.'''
        return sum(self._count_text_tokens(message.content) for message in messages)

    def _count_text_tokens(self, text: str) -> int:
        '''Return the number of tokens ``text`` encodes to with this model's tokenizer.'''
        # disallowed_special=() makes strings such as "<|endoftext|>" count as
        # ordinary text; tiktoken's default raises ValueError on them.
        return len(self.encoding.encode(text, disallowed_special=()))

    def get_num_total_tokens(self) -> int:
        '''Return input + output tokens of the most recent call.'''
        return self.total_tokens

    def get_num_output_tokens(self) -> int:
        '''Return output tokens of the most recent call.'''
        return self.output_tokens