import os
import abc
import time
from dataclasses import dataclass

from openai import OpenAI
import anthropic


@dataclass(frozen=True)
class LLMResponse:
    """Immutable record of one model call: what was sent and what came back."""

    # The prompt that produced the response (stored exactly as handed in).
    prompt_text: str
    # The text generated by the model.
    response_text: str
    # Generation settings recorded alongside the call.
    prompt_info: dict
    # Token log-probabilities; may be an empty list when none are available.
    logprobs: list
    
    
class LanguageModels(abc.ABC):
    """Abstract base class for a pretrained large language model backend.

    Concrete subclasses must implement :meth:`completion`.
    """

    @abc.abstractmethod
    def completion(self, prompt: str) -> LLMResponse:
        """Generate a response for ``prompt`` and wrap it in an ``LLMResponse``.

        Raises:
            NotImplementedError: always, when reached through ``super()``;
                subclasses are expected to override this method.
        """
        raise NotImplementedError("Override me!")
    

class GPTChatModel(LanguageModels):
    """Chat-completion backend that calls the OpenAI API.

    Failed requests are retried indefinitely with a fixed 10 second pause;
    ``err_cnt`` counts the failures seen over the lifetime of the instance.
    """

    def __init__(self, args):
        """Store generation settings from ``args`` and create the API client.

        Args:
            args: object exposing ``model_name``, ``temperature`` and
                ``max_new_tokens`` attributes.
        """
        self.model_name = args.model_name
        self.temperature = args.temperature
        self.max_new_tokens = args.max_new_tokens

        self.gpt = OpenAI()
        self.err_cnt = 0

    def completion(self, prompt: list) -> LLMResponse:
        """Send ``prompt`` to the chat endpoint and return the first choice.

        Args:
            prompt: list of chat messages, passed through as ``messages``.

        Returns:
            An ``LLMResponse`` holding the prompt, the generated text and the
            generation settings.
        """
        while True:
            try:
                completion = self.gpt.chat.completions.create(
                    model=self.model_name,  # e.g. gpt-4-0125-preview
                    messages=prompt,
                    temperature=self.temperature,
                    max_tokens=self.max_new_tokens
                )
                result = completion.choices[0].message.content
            except Exception as exc:
                # `Exception` (not a bare except) so KeyboardInterrupt and
                # SystemExit can still break out of the retry loop.
                self.err_cnt += 1
                print(f"API error occured, wait for 10 seconds. Error count: {self.err_cnt}")
                # Surface the cause so a permanent failure is not retried blindly.
                print(f"Cause: {exc!r}")
                time.sleep(10)
            else:
                break
        return self._raw_to_llm_response(model_response=result, prompt_text=prompt, max_new_tokens=self.max_new_tokens, temperature=self.temperature)

    @staticmethod
    def _raw_to_llm_response(model_response, prompt_text: list, max_new_tokens: int, temperature: float) -> LLMResponse:
        """Bundle the raw model output and generation settings into an ``LLMResponse``.

        ``logprobs`` is always an empty list here.
        """
        prompt_info = {
            'max_new_tokens': max_new_tokens,
            'temperature': temperature
        }

        return LLMResponse(prompt_text=prompt_text, response_text=model_response, prompt_info=prompt_info, logprobs=[])


class ClaudeModel(LanguageModels):
    """Chat backend that calls the Anthropic Messages API.

    Failed requests are retried indefinitely with a fixed 3 second pause;
    ``err_cnt`` counts the failures seen over the lifetime of the instance.
    """

    def __init__(self, args) -> None:
        """Store generation settings from ``args`` and create the API client.

        Args:
            args: object exposing ``model_name``, ``temperature`` and
                ``max_new_tokens`` attributes.

        The API key is read from the ``ANTHROPIC_API_KEY`` environment variable.
        """
        self.model_name = args.model_name
        self.temperature = args.temperature
        self.max_new_tokens = args.max_new_tokens

        self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
        self.err_cnt = 0

    def completion(self, prompt: list) -> LLMResponse:
        """Send ``prompt`` to the Messages API and return the first content block.

        Args:
            prompt: list of chat messages. The first entry's ``'content'`` is
                sent as the ``system`` prompt and the remaining entries as
                ``messages``. The list itself is left untouched.

        Returns:
            An ``LLMResponse`` whose ``prompt_text`` is the message list
            without the leading system entry.

        Raises:
            IndexError: if ``prompt`` is empty.
        """
        # Slice instead of pop(0): popping mutated the caller's list, so a second
        # call with the same list would send a user message as the system prompt.
        system_prompt, messages = prompt[0], prompt[1:]

        request = {
            'model': self.model_name,
            'max_tokens': self.max_new_tokens,
            'messages': messages,
            'system': system_prompt['content'],
        }
        # Apply the configured temperature so the value recorded in prompt_info
        # matches what was actually requested; skip it when unset.
        if self.temperature is not None:
            request['temperature'] = self.temperature

        # NOTE(review): fixed pause before every request, presumably a crude rate limit.
        time.sleep(1)
        while True:
            try:
                message = self.client.messages.create(**request)
                result = message.content[0].text
            except Exception as exc:
                # `Exception` (not a bare except) so KeyboardInterrupt and
                # SystemExit can still break out of the retry loop.
                self.err_cnt += 1
                print(f"API error occured, wait for 3 seconds. Error count: {self.err_cnt}")
                # Surface the cause so a permanent failure is not retried blindly.
                print(f"Cause: {exc!r}")
                time.sleep(3)
            else:
                break

        return self._raw_to_llm_response(model_response=result, prompt_text=messages, max_new_tokens=self.max_new_tokens, temperature=self.temperature)

    @staticmethod
    def _raw_to_llm_response(model_response,
                             prompt_text: list,
                             max_new_tokens: int,
                             temperature: float) -> LLMResponse:
        """Bundle the raw model output and generation settings into an ``LLMResponse``.

        ``logprobs`` is always an empty list here.
        """
        prompt_info = {
            'max_new_tokens': max_new_tokens,
            'temperature': temperature
        }

        return LLMResponse(prompt_text=prompt_text, response_text=model_response, prompt_info=prompt_info, logprobs=[])

class LocalModel(LanguageModels):
    """Chat backend for a server reached through the OpenAI client at ``http://<base_url>/v1``.

    Requests are issued once; any error raised by the client propagates to the caller.
    """

    def __init__(self, args) -> None:
        """Store generation settings from ``args`` and point the client at the server.

        Args:
            args: object exposing ``model_name``, ``temperature``,
                ``max_new_tokens`` and ``base_url`` attributes.
        """
        self.model_name = args.model_name
        self.temperature = args.temperature
        self.max_new_tokens = args.max_new_tokens

        endpoint = f"http://{args.base_url}/v1"
        # A placeholder key is passed; the client is aimed at the custom endpoint.
        self.model = OpenAI(api_key="EMPTY", base_url=endpoint)

    def completion(self, prompt: list) -> LLMResponse:
        """Send ``prompt`` as chat messages and return the first choice's content."""
        reply = self.model.chat.completions.create(
            model=self.model_name,
            messages=prompt,
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
        )
        generated_text = reply.choices[0].message.content

        return self._raw_to_llm_response(
            model_response=generated_text,
            prompt_text=prompt,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
        )

    @staticmethod
    def _raw_to_llm_response(model_response, prompt_text: str, max_new_tokens: int, temperature: float) -> LLMResponse:
        """Bundle the raw model output and generation settings into an ``LLMResponse``.

        ``logprobs`` is always an empty list here.
        """
        settings = dict(max_new_tokens=max_new_tokens, temperature=temperature)
        return LLMResponse(
            prompt_text=prompt_text,
            response_text=model_response,
            prompt_info=settings,
            logprobs=[],
        )
        

class LanguageModelInterface:
    """Facade that picks the backend class for a model name and forwards calls to it."""

    # Supported model names, grouped by backend family. Lookup uses the
    # lower-cased last path segment of the requested model name.
    model_family = {
        'gpt_chat_models': [
            'gpt-4-turbo-2024-04-09',
            'gpt-4o-2024-08-06',
        ],
        'local_models': [
            "meta-llama-3.1-70b-instruct",
            'meta-llama-3.1-8b-instruct',
            '846357c7ee5e3f50575fd4294edb3d898c8ea100',
            'provergen-augment-10000',
            'provergen-augment-20000',
            "provergen-final-5000",
            "checkpoint-38",
            "checkpoint-77",
            "checkpoint-114",
            "checkpoint-155",
            "checkpoint-231",
            "checkpoint-116",
            "checkpoint-233",
            "checkpoint-348",
            "e0bc86c23ce5aae1db576c8cca6f06f1f73af2db",
            "41bd4c9e7e4fb318ca40e721131d4933966c2cc1",
            "provergen-final-5000-no_rc",
            "provergen-5000-final",
            "checkpoint-1056",
            "checkpoint-35",
            "provergen-final-5000-no_rc",
        ],
        'claude_models': [
            'claude-3-opus-20240229',
            'claude-3-5-sonnet-20240620',
        ],
    }

    # Backend class that serves each family.
    model_mapping = {
        'gpt_chat_models': GPTChatModel,
        'local_models': LocalModel,
        'claude_models': ClaudeModel,
    }

    def __init__(self, args) -> None:
        """Resolve ``args.model_name`` to a family and build its backend with ``args``.

        Raises:
            ValueError: if the model name is not listed in any family.
        """
        self.model_name = args.model_name
        short_name = self.model_name.split('/')[-1].lower()

        # First family (in declaration order) listing the name, or '' if none does.
        self.model_type = next(
            (family for family, members in self.model_family.items() if short_name in members),
            '',
        )
        if self.model_type == '':
            raise ValueError(f"The interface for {self.model_name} is not implemented.")

        backend_cls = self.model_mapping[self.model_type]
        self.model = backend_cls(args)

    def completion(self, prompt: list) -> LLMResponse:
        """Forward ``prompt`` to the selected backend and return its ``LLMResponse``."""
        return self.model.completion(prompt)

