from typing import Any
from openai import OpenAI
import openai
from tenacity import retry, wait_chain, wait_fixed
import google.generativeai as genai
import boto3
import json
import time


class GPT:
    """Wrapper around the OpenAI chat-completions endpoint.

    Every prompt is sent as a single ``user`` message using the model name,
    temperature and seed fixed at construction time.
    """

    # Upper bound on attempts made by ``eval_call_wrapper``.
    EVAL_MAX_ATTEMPTS = 5

    def __init__(self, model_name, temperature=0, seed=0, api_key=None):
        """Store the sampling settings and build the OpenAI client.

        Args:
            model_name: Model identifier passed as ``model`` on every request.
            temperature: Sampling temperature passed on every request.
            seed: Seed passed on every request.
            api_key: Key handed to ``OpenAI(api_key=...)``.
        """
        self.model_name = model_name
        self.client = OpenAI(api_key=api_key)
        self.T = temperature
        self.seed = seed

    def __call__(self, prompt, n: int = 1, debug=False, **kwargs: Any) -> Any:
        """Send ``prompt`` as one user message and return the raw response.

        Args:
            prompt: Content of the single user message.
            n: Number of completions requested.
            debug: When true, call the client directly so errors surface
                immediately; otherwise go through the retrying
                ``call_wrapper``.
            **kwargs: Extra arguments forwarded to ``chat.completions.create``.

        Returns:
            Whatever ``chat.completions.create`` returns.
        """
        messages = [{'role': 'user', 'content': prompt}]
        send = self.client.chat.completions.create if debug else self.call_wrapper
        return send(messages=messages, n=n, model=self.model_name,
                    temperature=self.T, seed=self.seed, **kwargs)

    # Waits 3s three times, 5s twice, then 10s between attempts. No ``stop``
    # is configured on the decorator. Not used on the debug path.
    @retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +
                [wait_fixed(5) for i in range(2)] +
                [wait_fixed(10)]))
    def call_wrapper(self, **kwargs):
        """Forward ``kwargs`` to ``chat.completions.create`` under ``@retry``."""
        return self.client.chat.completions.create(**kwargs)

    def resp_parse(self, response) -> list:
        """Return the message content of every choice in ``response``."""
        return [choice.message.content for choice in response.choices]

    @staticmethod
    def _bad_request_message(error) -> str:
        """Extract a readable message from a ``BadRequestError``.

        Falls back to ``str(error)`` when ``error.body`` is missing, is not a
        dict, or has no string ``'message'`` entry, so the caller's handler
        never raises while inspecting the error.
        """
        body = getattr(error, 'body', None)
        message = body.get('message') if isinstance(body, dict) else None
        return message if isinstance(message, str) else str(error)

    def eval_call_wrapper(self, prompt, **kwargs: Any) -> Any:
        """Call the model with bounded retries for evaluation requests.

        A ``BadRequestError`` mentioning 'repetitive pattern' is terminal:
        ``None`` is returned at once so the caller can fall back to its
        default "failed" evaluation. Any other error, including other bad
        requests, consumes one attempt and is followed by a backoff of
        ``2**attempt + 1`` seconds.

        Returns:
            The raw response, or ``None`` when the request was rejected for a
            repetitive pattern or all attempts failed.
        """
        for attempt in range(self.EVAL_MAX_ATTEMPTS):
            try:
                return self(prompt, debug=True, **kwargs)
            except openai.BadRequestError as e:
                err = self._bad_request_message(e)
                if 'repetitive pattern' in err:
                    print(f"error: {err}, use default fail eval")
                    return None
                # Other bad requests must also use up an attempt; otherwise
                # the loop would spin forever without any delay.
                print(f"error exception: {e}")
            except Exception as e:
                print(f"error exception: {e}")
            print(kwargs)
            # No point in sleeping once the retry budget is exhausted.
            if attempt < self.EVAL_MAX_ATTEMPTS - 1:
                time.sleep(2 ** attempt + 1)
        return None

    # only for handling eval errors on repetitive patterns
    def eval_call(self, prompt, n: int = 1, debug=False, **kwargs: Any) -> Any:
        """Evaluation entry point.

        With ``debug`` the request is made directly and errors propagate;
        otherwise it goes through ``eval_call_wrapper``, which may return
        ``None``.
        """
        if debug:
            return self(prompt, debug=True, n=n, **kwargs)
        return self.eval_call_wrapper(prompt, n=n, **kwargs)
class Gemini:
    """Wrapper around the ``gemini-pro`` generative model.

    Only single-candidate generation is supported; every request is sent with
    all four safety categories set to ``block_none``.
    """

    # Text returned by ``resp_parse`` when no text can be read from a response.
    FAILED_OUTPUT = "Sorry, I can't assist with that."

    def __init__(self, temperature=0, api_key=None) -> None:
        """Configure the SDK with ``api_key`` and create the model handle."""
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel('gemini-pro')
        self.T = temperature

    def __call__(self, prompt, n=1, debug=False, **kwargs: Any) -> Any:
        """Generate content for ``prompt`` and return the raw response.

        Args:
            prompt: Passed as-is to ``generate_content``.
            n: Must be 1.
            debug: When true, call the model directly; otherwise use the
                retrying ``call_wrapper``.
            **kwargs: Extra arguments forwarded to ``generate_content``.

        Raises:
            ValueError: If ``n`` is not 1.
        """
        if n != 1:
            raise ValueError(f"n should be 1, but got {n}")

        # Both code paths send exactly the same request arguments.
        request_args = {
            'safety_settings': {
                'HARASSMENT': 'block_none',
                "HATE_SPEECH": 'block_none',
                "DANGEROUS": 'block_none',
                "SEXUAL": 'block_none',
            },
            'generation_config': genai.types.GenerationConfig(
                # Only one candidate for now.
                candidate_count=1,
                temperature=self.T,
            ),
        }
        generate = self.model.generate_content if debug else self.call_wrapper
        return generate(prompt, **request_args, **kwargs)

    # Waits 3s three times, 5s twice, then 10s between attempts.
    # Not used on the debug path.
    @retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +
                [wait_fixed(5) for i in range(2)] +
                [wait_fixed(10)]))
    def call_wrapper(self, prompt, **kwargs):
        """Forward the request to ``generate_content`` under ``@retry``."""
        return self.model.generate_content(prompt, **kwargs)

    def resp_parse(self, response) -> list:
        """Return a one-element list with the first candidate's first text part.

        ``response`` is a
        ``google.generativeai.types.generation_types.GenerateContentResponse``.
        If the text cannot be read for any reason, the error is printed and
        ``[FAILED_OUTPUT]`` is returned instead.
        """
        try:
            return [response.candidates[0].content.parts[0].text]
        except Exception as e:
            print(f"error: {e}")
            return [self.FAILED_OUTPUT]


class Claude:
    """Wrapper around Anthropic Claude models served via the Bedrock runtime."""

    def __init__(self, model_name, temperature=0) -> None:
        """Create the Bedrock client and resolve the Bedrock model id.

        Args:
            model_name: ``'claude-v2'`` or ``'claude-v2.1'``.
            temperature: Sampling temperature sent with every request.

        Raises:
            NotImplementedError: For ``'claude-v1'``.
            ValueError: For any other unrecognised name.
        """
        self.T = temperature
        self.model = boto3.client(service_name='bedrock-runtime')

        # Alias -> Bedrock model id, checked in this order.
        supported = (
            ('claude-v2', 'anthropic.claude-v2'),
            ('claude-v2.1', 'anthropic.claude-v2:1'),
        )
        for alias, bedrock_id in supported:
            if model_name == alias:
                self.model_name = bedrock_id
                return

        if model_name == 'claude-v1':
            raise NotImplementedError
        raise ValueError(f"model_name should be claude-v2 or claude-v2.1, but got {model_name}")

    def __call__(self, prompt, n=1, debug=False, **kwargs: Any) -> Any:
        """Invoke the model on ``prompt`` and return the decoded JSON body.

        Args:
            prompt: User text, wrapped in a Human/Assistant template.
            n: Accepted but not used by this method.
            debug: When true, call ``invoke_model`` directly; otherwise use
                the retrying ``call_wrapper``.
            **kwargs: Extra arguments forwarded to ``invoke_model``.
        """
        templated = f"Human:{prompt}\n\nAssistant:"
        payload = json.dumps({
            "prompt": templated,
            "max_tokens_to_sample": 2048,
            "temperature": self.T,
        })

        invoke = self.model.invoke_model if debug else self.call_wrapper
        response = invoke(
            body=payload,
            modelId=self.model_name,
            accept='application/json',
            contentType='application/json',
            **kwargs
        )

        raw_body = response.get('body').read()
        return json.loads(raw_body)

    # Waits 3s three times, 5s twice, then 10s between attempts.
    # Not used on the debug path.
    @retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +
                [wait_fixed(5) for i in range(2)] +
                [wait_fixed(10)]))
    def call_wrapper(self, **kwargs):
        """Forward ``kwargs`` to ``invoke_model`` under ``@retry``."""
        return self.model.invoke_model(**kwargs)

    def resp_parse(self, response) -> list:
        """Return a one-element list holding the ``'completion'`` entry."""
        completion = response.get('completion')
        return [completion]


class Llama_vllm:
    """Placeholder for a vLLM-backed Llama model; construction always fails."""

    def __init__(self, model_name, temperature=0) -> None:
        # Nothing is implemented yet, so refuse to build an instance.
        raise NotImplementedError()
   
class Llama_api:
    """Wrapper around Llama-2 chat models served by DeepInfra's
    OpenAI-compatible endpoint."""

    # Ordered (size tag, DeepInfra model id) pairs; the first tag found in
    # the requested model name wins.
    _MODEL_IDS = (
        ('7b', "meta-llama/Llama-2-7b-chat-hf"),
        ('13b', "meta-llama/Llama-2-13b-chat-hf"),
        ('70b', "meta-llama/Llama-2-70b-chat-hf"),
    )

    def __init__(self, model_name, temperature=0.6, api_key=None) -> None:
        """Build the client and resolve the DeepInfra model id.

        Args:
            model_name: Must contain ``'llama-api'`` and one of ``'7b'``,
                ``'13b'`` or ``'70b'``.
            temperature: Sampling temperature sent with every request.
            api_key: DeepInfra API key. When omitted, the placeholder
                ``"your-api-keys"`` is used, as before this parameter existed.

        Raises:
            AssertionError: If ``'llama-api'`` is not in ``model_name``.
            ValueError: If no known size tag is found in ``model_name``.
        """
        assert 'llama-api' in model_name
        self.T = temperature
        self.client = OpenAI(
            api_key="your-api-keys" if api_key is None else api_key,
            base_url="https://api.deepinfra.com/v1/openai")

        for size_tag, deepinfra_id in self._MODEL_IDS:
            if size_tag in model_name:
                self.model_DI = deepinfra_id
                break
        else:
            raise ValueError(f"model_name should be llama-api-7b, llama-api-13b or llama-api-70b, but got {model_name}")

    def __call__(self, prompt, n=1, debug=False, top_p=0.9, **kwargs: Any) -> Any:
        """Send ``prompt`` as one user message and return the raw response.

        Args:
            prompt: Content of the single user message.
            n: Number of completions requested.
            debug: When true, call the client directly; otherwise use the
                retrying ``call_wrapper``.
            top_p: Nucleus-sampling parameter sent with the request.
            **kwargs: Extra arguments forwarded to ``chat.completions.create``.
        """
        messages = [{'role': 'user', 'content': prompt}]
        send = self.client.chat.completions.create if debug else self.call_wrapper
        return send(messages=messages, n=n, model=self.model_DI,
                    temperature=self.T, top_p=top_p, **kwargs)

    # Waits 3s three times, 5s twice, then 10s between attempts.
    # Not used on the debug path.
    @retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +
                [wait_fixed(5) for i in range(2)] +
                [wait_fixed(10)]))
    def call_wrapper(self, **kwargs):
        """Forward ``kwargs`` to ``chat.completions.create`` under ``@retry``."""
        return self.client.chat.completions.create(**kwargs)

    def resp_parse(self, response) -> list:
        """Return the message content of every choice in ``response``."""
        return [choice.message.content for choice in response.choices]

def load_model(model_name, api_key, **kwargs):
    """Build the model wrapper selected by substrings of ``model_name``.

    Checks are made in this order: ``gpt2``, ``gpt``, ``gemini``, ``claude``,
    ``llama-vllm``, ``llama``.

    Args:
        model_name: Name used for dispatch; also passed to every wrapper
            except ``Gemini``.
        api_key: Forwarded to the GPT, GPT2 and Gemini wrappers only.
        **kwargs: Extra constructor arguments for the selected wrapper.

    Raises:
        ValueError: If ``model_name`` matches none of the known families.
    """
    if "gpt2" in model_name:
        # NOTE(review): GPT2 is not defined in the visible part of this file;
        # confirm it is provided elsewhere, otherwise this is a NameError.
        return GPT2(model_name, api_key=api_key, **kwargs)
    if "gpt" in model_name:
        return GPT(model_name, api_key=api_key, **kwargs)
    if "gemini" in model_name:
        return Gemini(api_key=api_key, **kwargs)
    if "claude" in model_name:
        return Claude(model_name, **kwargs)
    if "llama-vllm" in model_name:
        return Llama_vllm(model_name, **kwargs)
    if "llama" in model_name:
        return Llama_api(model_name, **kwargs)
    # Name the rejected value so the failure can be diagnosed from the message.
    raise ValueError(f"model_name invalid: {model_name!r}")


