import openai
import anthropic
import os
import time
import torch
from typing import Dict, List
from accelerate.utils import find_executable_batch_size
from vllm import SamplingParams
from ..model_utils import load_model_and_tokenizer, load_vllm_model
import ray
import vertexai  # pip3 install google-cloud-aiplatform
from vertexai.preview.language_models import ChatModel
from vertexai.preview.generative_models import GenerativeModel


class LanguageModel:
    """Base class for language-model backends.

    Stores the model name and provides two small helpers; subclasses are
    expected to override ``batched_generate``.
    """

    def __init__(self, model_name):
        """Remember the name of the wrapped model."""
        self.model_name = model_name

    def batched_generate(self, prompts_list: List, max_n_tokens: int, temperature: float):
        """
        Generate one response per prompt in ``prompts_list``.

        The base class provides no implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def is_initialized(self):
        """Print a one-line notice containing the model name."""
        banner = f"==> Initialized {self.model_name}"
        print(banner)

    def get_attribute(self, attr_name):
        """Return the attribute called ``attr_name``, or None when it is absent."""
        try:
            return getattr(self, attr_name)
        except AttributeError:
            return None

class VLLM(LanguageModel):
    """Backend that generates with the model returned by ``load_vllm_model``."""

    def __init__(self, model_name_or_path, **model_kwargs):
        """Load the model; ``model_kwargs`` are forwarded to ``load_vllm_model``."""
        self.model_name = model_name_or_path
        self.model = load_vllm_model(model_name_or_path, **model_kwargs)

    def batched_generate(self,
                    full_prompts_list,
                    max_n_tokens: int,
                    temperature: float,
                    stop_tokens: List[str]=[]):
        """
        Generate one completion per prompt.

        Args:
            full_prompts_list: prompts handed to the model's ``generate`` call
            max_n_tokens: int, passed as ``max_tokens`` of the sampling params
            temperature: float, sampling temperature
            stop_tokens: strings passed as ``stop`` of the sampling params
        Returns:
            list of str: text of the first output of every request, in order
        """
        params = SamplingParams(
            temperature=temperature,
            max_tokens=max_n_tokens,
            stop=stop_tokens,
        )
        request_outputs = self.model.generate(full_prompts_list, params, use_tqdm=False)

        texts = []
        for request_output in request_outputs:
            # Only the first candidate of each request is kept.
            texts.append(request_output.outputs[0].text)
        return texts
    
class HuggingFace(LanguageModel):
    """Backend that generates with a local model via ``model.generate``.

    Prompts are processed in chunks of ``generation_batch_size``; the chunk
    size is handed to ``find_executable_batch_size`` so that it can be retried
    with a smaller value when a chunk does not fit in memory.
    """

    def __init__(self, model_name_or_path, **model_kwargs):
        """Load model and tokenizer; ``model_kwargs`` go to ``load_model_and_tokenizer``."""
        model, tokenizer = load_model_and_tokenizer(model_name_or_path, **model_kwargs)
        self.model_name = model_name_or_path
        self.model = model
        self.tokenizer = tokenizer
        self.generation_batch_size = 32

        # Generation always stops on the tokenizer's EOS id and on the first
        # token id the tokenizer produces for "}".
        # NOTE(review): presumably so JSON-style outputs end at the closing
        # brace -- confirm with callers.
        self.eos_token_ids = [
            self.tokenizer.eos_token_id,
            self.tokenizer.encode("}", add_special_tokens=False)[0],
        ]

    def batch_generate_bs(self, batch_size, inputs, **generation_kwargs):
        """
        Generate for ``inputs`` in chunks of ``batch_size``.

        This is the function wrapped by ``find_executable_batch_size``, which
        supplies ``batch_size`` as the first argument.

        Args:
            batch_size: int, number of prompts per ``model.generate`` call
            inputs: list of prompt strings
            **generation_kwargs: forwarded unchanged to ``model.generate``
        Returns:
            list of str: decoded generations, in the order of ``inputs``
        """
        if batch_size != self.generation_batch_size:
            # A different size than the stored one means the wrapper retried
            # with a new value; remember it for subsequent calls.
            self.generation_batch_size = batch_size
            print(f"WARNING: Found OOM, reducing generation batch_size to: {self.generation_batch_size}")

        gen_outputs = []
        for i in range(0, len(inputs), batch_size):
            inputs_b = inputs[i:i+batch_size]
            encoded_b = self.tokenizer(inputs_b, return_tensors='pt', padding='longest')
            with torch.no_grad():
                output_ids = self.model.generate(
                        **encoded_b.to(self.model.device),
                        **generation_kwargs
                    ).cpu()

            if not self.model.config.is_encoder_decoder:
                # For non-encoder-decoder models, drop the leading columns that
                # correspond to the (padded) prompt so only new tokens are decoded.
                output_ids = output_ids[:, encoded_b["input_ids"].shape[1]:]
            decoded_outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            gen_outputs.extend(decoded_outputs)
        return gen_outputs

    def _stop_token_ids(self, stop_tokens):
        """Return the default EOS ids plus the id of every single-token stop string."""
        eos_ids = list(self.eos_token_ids)
        for stop in stop_tokens or []:
            ids = self.tokenizer.encode(stop, add_special_tokens=False)
            # ``eos_token_id`` is a flat list of ids, so only stop strings that
            # map to exactly one token can be expressed through it.
            if len(ids) == 1 and ids[0] not in eos_ids:
                eos_ids.append(ids[0])
        return eos_ids

    def batched_generate(self,
                        full_prompts_list,
                        max_n_tokens: int,
                        temperature: float,
                        stop_tokens: List[str]=[]
                        ):
        """
        Generate one completion per prompt.

        Args:
            full_prompts_list: list of prompt strings
            max_n_tokens: int, passed as ``max_new_tokens``
            temperature: float; values > 0 enable sampling at that temperature,
                otherwise sampling is disabled
            stop_tokens: extra stop strings. Each one that the tokenizer encodes
                to a single token is added to the EOS ids; stop strings that
                encode to several tokens are not applied. The list is never
                mutated.
        Returns:
            list of str: decoded generations, in the order of the prompts
        """
        # Previously the ids derived from ``stop_tokens`` were computed but never
        # passed on, so caller-supplied stop tokens had no effect.
        eos_token_id = self._stop_token_ids(stop_tokens)
        if temperature > 0:
            generation_kwargs = dict(max_new_tokens=max_n_tokens,
                    do_sample=True,
                    temperature=temperature,
                    eos_token_id=eos_token_id)
        else:
            generation_kwargs = dict(max_new_tokens=max_n_tokens,
                    do_sample=False,
                    eos_token_id=eos_token_id,
            )
        generation_function = find_executable_batch_size(self.batch_generate_bs, self.generation_batch_size)
        outputs = generation_function(full_prompts_list, **generation_kwargs)
        return outputs

class GPT(LanguageModel):
    """Backend that queries ``openai.ChatCompletion`` with fixed-delay retries."""

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 0.5
    API_MAX_RETRY = 5
    API_TIMEOUT = 45

    def __init__(self, model_name, token, **kwargs):
        """Store the model name and install ``token`` as the module-wide OpenAI key."""
        self.model_name = model_name
        openai.api_key = token

    def generate(self, conv: List[Dict],
                max_n_tokens: int,
                temperature: float,
                top_p: float):
        '''
        Request a single chat completion, retrying on OpenAI errors.

        Args:
            conv: List of dictionaries, OpenAI API format
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: content of the first choice, or API_ERROR_OUTPUT when every
                attempt raised an OpenAI error
        '''
        request = dict(
            model=self.model_name,
            messages=conv,
            max_tokens=max_n_tokens,
            temperature=temperature,
            top_p=top_p,
            request_timeout=self.API_TIMEOUT,
        )
        for _attempt in range(self.API_MAX_RETRY):
            try:
                response = openai.ChatCompletion.create(**request)
                return response["choices"][0]["message"]["content"]
            except openai.error.OpenAIError as err:
                print(type(err), err)
                time.sleep(self.API_RETRY_SLEEP)

            # Only reached after a failed attempt: short extra pause before retrying.
            time.sleep(self.API_QUERY_SLEEP)
        return self.API_ERROR_OUTPUT

    def batched_generate(self,
                        convs_list,
                        max_n_tokens: int,
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        """Call ``generate`` for each conversation in order; extra kwargs are ignored."""
        responses = []
        for conv in convs_list:
            responses.append(self.generate(conv, max_n_tokens, temperature, top_p))
        return responses

class Claude:
    """Backend that queries the Anthropic completions endpoint with retries."""

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    API_TIMEOUT = 30
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, token, **kwargs) -> None:
        """Store the model name and create an Anthropic client from ``token``."""
        self.model_name = model_name
        self.model = anthropic.Anthropic(api_key=token)

    def generate(self, conv,
                max_n_tokens: int,
                temperature: float,
                top_p: float):
        '''
        Request a single completion, retrying when the call raises.

        Args:
            conv: prompt passed unchanged as ``prompt`` to the completions call
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: the completion text; ``default_output`` when the error message
                reports content filtering; API_ERROR_OUTPUT when all attempts fail
        '''
        attempts_left = self.API_MAX_RETRY
        while attempts_left > 0:
            attempts_left -= 1
            try:
                completion = self.model.completions.create(
                    model=self.model_name,
                    max_tokens_to_sample=max_n_tokens,
                    prompt=conv,
                    temperature=temperature,
                    top_p=top_p
                )
                return completion.completion
            except Exception as err:
                message = str(err)
                print(message)
                # This error text signals that the output was blocked; answer
                # with the canned refusal instead of retrying.
                if "Output blocked by content filtering policy" in message:
                    return self.default_output
                time.sleep(self.API_RETRY_SLEEP)

            # Only reached after a failed attempt: extra pause before retrying.
            time.sleep(self.API_QUERY_SLEEP)
        return self.API_ERROR_OUTPUT

    def batched_generate(self,
                        convs_list,
                        max_n_tokens: int,
                        temperature: float,
                        top_p: float = 1.0,  **kwargs):
        """Call ``generate`` for each prompt in order; extra kwargs are ignored."""
        responses = []
        for conv in convs_list:
            responses.append(self.generate(conv, max_n_tokens, temperature, top_p))
        return responses

class Gemini():
    """Backend that queries a Vertex AI ``GenerativeModel`` chat with retries."""

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    # Not currently passed to any API call in this class.
    API_TIMEOUT = 20
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, gcp_credentials, project_id, location) -> None:
        """
        Initialise Vertex AI and create the generative model.

        Sets the GOOGLE_APPLICATION_CREDENTIALS environment variable to
        ``gcp_credentials`` before calling ``vertexai.init``.
        """
        self.model_name = model_name
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = gcp_credentials
        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name)

    def generate(self, conv,
                max_n_tokens: int,
                temperature: float,
                top_p: float):
        '''
        Send ``conv`` as one message in a fresh chat, retrying on errors.

        Args:
            conv: message content passed unchanged to ``chat.send_message``
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling; omitted from the request when None
        Returns:
            str: the response text; ``default_output`` when the response was
                blocked; API_ERROR_OUTPUT when all attempts fail
        '''
        output = self.API_ERROR_OUTPUT
        generation_config = {"temperature": temperature, "max_output_tokens": max_n_tokens}
        if top_p is not None:
            # Previously top_p was accepted but never sent with the request.
            generation_config["top_p"] = top_p
        chat = self.model.start_chat()

        for _ in range(self.API_MAX_RETRY):
            try:
                completion = chat.send_message(conv, generation_config=generation_config)
                output = completion.text
                break
            except (vertexai.generative_models._generative_models.ResponseBlockedError):
                # Prompt was blocked for safety reasons
                output = self.default_output
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
            # Only reached after a failed attempt: extra pause before retrying.
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                        convs_list: List[List[Dict]],
                        max_n_tokens: int,
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        """Call ``generate`` for each entry of ``convs_list`` in order; extra kwargs are ignored."""
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
