import os
import time
from typing import Dict, List
import torch
import ray
from accelerate.utils import find_executable_batch_size
from huggingface_hub import login as hf_login
from ..model_utils import load_model_and_tokenizer
from vllm import LLM, SamplingParams
import openai
import anthropic
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

from together import Together

class LanguageModel():
    def __init__(self, model_name):
        self.model_name = model_name
    
    def batched_generate(self, prompts_list: List, max_n_tokens: int, temperature: float):
        """
        Generates responses for a batch of prompts using a language model.
        """
        raise NotImplementedError

    def is_initialized(self):
        print(f"==> Initialized {self.model_name}")

    def get_attribute(self, attr_name):
        return getattr(self, attr_name, None)
    
class HuggingFace(LanguageModel):
    def __init__(self, model_name_or_path, **model_kwargs):        
        model, tokenizer = load_model_and_tokenizer(model_name_or_path, **model_kwargs)
        self.model_name = model_name_or_path
        self.model = model 
        self.tokenizer = tokenizer
        self.eos_token_ids = [self.tokenizer.eos_token_id]
        self.generation_batch_size = 32

        self.eos_token_ids.extend([self.tokenizer.encode("}", add_special_tokens=False)[0]])

    def batch_generate_bs(self, batch_size, inputs, **generation_kwargs):
        if batch_size != self.generation_batch_size:
            self.generation_batch_size = batch_size
            print(f"WARNING: Found OOM, reducing generation batch_size to: {self.generation_batch_size}")

        gen_outputs = []
        for i in range(0, len(inputs), batch_size):
            inputs_b = inputs[i:i+batch_size]
            encoded_b = self.tokenizer(inputs_b, return_tensors='pt', padding='longest')
            with torch.no_grad():
                output_ids = self.model.generate(
                        **encoded_b.to(self.model.device),
                        **generation_kwargs
                    ).cpu()

            if not self.model.config.is_encoder_decoder:
                output_ids = output_ids[:, encoded_b["input_ids"].shape[1]:]
            decoded_outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            gen_outputs.extend(decoded_outputs)
        return gen_outputs
    
    def batched_generate(self, 
                        full_prompts_list,
                        max_n_tokens: int, 
                        temperature: float,
                        stop_tokens: List[str]=[]
                        ):
        if temperature > 0:
            generation_kwargs = dict(max_new_tokens=max_n_tokens, 
                    do_sample=True,
                    temperature=temperature,
                    eos_token_id=self.eos_token_ids)
        else:
            generation_kwargs = dict(max_new_tokens=max_n_tokens, 
                    do_sample=False,
                    eos_token_id=self.eos_token_ids,
            )
        generation_function = find_executable_batch_size(self.batch_generate_bs, self.generation_batch_size)
        outputs = generation_function(full_prompts_list, **generation_kwargs)
        return outputs

class GPT(LanguageModel):
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 0.5
    API_MAX_RETRY = 5
    API_TIMEOUT = 20

    def __init__(self, model_name, token):
        self.model_name = model_name
        openai.api_key = token
        self.api_key = token 


    def generate(self, conv: List[Dict], 
                max_n_tokens: int, 
                temperature: float,
                top_p: float):
        '''
        Args:
            conv: List of dictionaries, OpenAI API format
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: generated response
        '''
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                #OpenAI migration
                client = openai.OpenAI(api_key=self.api_key)

                response = client.chat.completions.create(
                        model = self.model_name,
                        messages = conv,
                        max_tokens = max_n_tokens,
                        temperature = temperature,
                        top_p = top_p,
                        )
                output = response.choices[0].message.content
                break
            except openai.APIConnectionError as e:
                print("The server could not be reached")
                print(e.__cause__)  # an underlying Exception, likely raised within httpx.
                time.sleep(self.API_RETRY_SLEEP)
            except openai.RateLimitError as e:
                print("A 429 status code was received; we should back off a bit.")
                time.sleep(self.API_RETRY_SLEEP)
            except openai.APIStatusError as e:
                print("Another non-200-range status code was received")
                print(e.status_code)
                print(e.response)
                time.sleep(self.API_RETRY_SLEEP)
        
            time.sleep(self.API_QUERY_SLEEP)
        return output 
    
    def batched_generate(self, 
                        convs_list: List[List[Dict]],
                        max_n_tokens: int, 
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]

class TogetherAPI(LanguageModel):
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 0.5
    API_MAX_RETRY = 5
    API_TIMEOUT = 20

    def __init__(self, model_name, api_key):
        #super().__init__(model_name=model_name)
        self.model_name=model_name
        self.client = Together(api_key=api_key)

    def generate(self, input, max_new_tokens: int, temperature: float, top_p: float):
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[{"role": "user", "content": input}],
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=50,
                    repetition_penalty=1.0,
                    stop=["[/INST]", "</s>"],
                    stream=False,  # Adjust this based on your requirements
                )
                output = response.choices[0].message.content
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self, 
                        convs_list: List[List[Dict]],
                        max_n_tokens: int, 
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        return [self.generate(_input, max_n_tokens, temperature, top_p) for _input in convs_list]


class Claude():
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    API_TIMEOUT = 20
    default_output = "I'm sorry, but I cannot assist with that request."

   
    def __init__(self, model_name, token) -> None:
        self.model_name = model_name
        self.API_KEY = token

        self.model= anthropic.Anthropic(
            api_key=self.API_KEY,
            )

    def generate(self, conv: List, 
                max_n_tokens: int, 
                temperature: float,
                top_p: float):
        '''
        Args:
            conv: List of conversations 
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: generated response
        '''
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                completion = self.model.completions.create(
                    model=self.model_name,
                    max_tokens_to_sample=max_n_tokens,
                    prompt=conv,
                    temperature=temperature,
                    top_p=top_p
                )
                output = completion.completion
                break
            except Exception as e:
                # as of Jan 2023, this show the output has been blocked
                if "Output blocked by content filtering policy" in str(e):
                    output = self.default_output
                    break
                time.sleep(self.API_RETRY_SLEEP)
        
            time.sleep(self.API_QUERY_SLEEP)
        return output
    
    def batched_generate(self, 
                        convs_list: List[List[Dict]],
                        max_n_tokens: int, 
                        temperature: float,
                        top_p: float = 1.0,  **kwargs):
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
        
class Gemini():
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    API_TIMEOUT = 20
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, token) -> None:
        self.model_name = model_name
        genai.configure(api_key=token)

        self.safety_settings={
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

        self.model = genai.GenerativeModel(model_name)

    def generate(self, conv: List, 
                max_n_tokens: int, 
                temperature: float,
                top_p: float):
        '''
        Args:
            conv: List of dictionaries, 
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: generated response
        '''
        output = self.API_ERROR_OUTPUT
        generation_config=genai.types.GenerationConfig(
                max_output_tokens=max_n_tokens,
                temperature=temperature,
                top_p=top_p)
        chat = self.model.start_chat(history=[])
        safety_settings = self.safety_settings
        
        for _ in range(self.API_MAX_RETRY):
            try:

                completion = chat.send_message(conv, generation_config=generation_config, safety_settings=safety_settings)
                output = completion.text
                break
            except (genai.types.BlockedPromptException, genai.types.StopCandidateException, ValueError):
                # Prompt was blocked for safety reasons
                output = self.default_output
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
            time.sleep(1)
        return output
    
    def batched_generate(self, 
                        convs_list: List[List[Dict]],
                        max_n_tokens: int, 
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]

@ray.remote
class VLLM:
    def __init__(self, model_name_or_path, num_gpus=1, **model_kwargs):
        self.model_name = model_name_or_path

        # In https://github.com/vllm-project/vllm/blob/main/vllm/engine/ray_utils.py
        # VLLM will manually allocate placement_groups for ray actors if using tensor_parallel (num_gpus) > 1
        # So we will set CUDA_VISIBLE_DEVICES to all and let vllm.ray_utils handle this
        if num_gpus > 1:
            resources = ray.cluster_resources()
            available_devices = ",".join([str(i) for i in range(int(resources.get("GPU", 0)))])
            os.environ['CUDA_VISIBLE_DEVICES'] = available_devices
            
        self.model = self.load_vllm_model(model_name_or_path, num_gpus=num_gpus, **model_kwargs)
    
    def batched_generate(self, 
                    full_prompts_list,
                    max_n_tokens: int, 
                    temperature: float,
                    stop_tokens: List[str]=[]):
        
        sampling_params = SamplingParams(temperature=temperature, 
                                    max_tokens=max_n_tokens, 
                                    stop=stop_tokens)
        outputs = self.model.generate(full_prompts_list, sampling_params, use_tqdm=False)
        outputs = [o.outputs[0].text for o in outputs]
        return outputs
    
    def load_vllm_model(self,
        model_name_or_path,
        dtype='auto',
        trust_remote_code=False,
        download_dir=None,
        revision=None,
        token=None,
        quantization=None,
        num_gpus=1,
        ## tokenizer_args
        use_fast_tokenizer=True,
        pad_token=None,
        eos_token=None,
        **kwargs
    ):
        if token:
            hf_login(token=token)
        
        model = LLM(model=model_name_or_path, 
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                    download_dir=download_dir,
                    revision=revision,
                    quantization=quantization,
                    tokenizer_mode="auto" if use_fast_tokenizer else "slow",
                    tensor_parallel_size=num_gpus)
        
        if pad_token:
            model.llm_engine.tokenizer.tokenizer.pad_token = pad_token
        if eos_token:
            model.llm_engine.tokenizer.tokenizer.eos_token = eos_token

        return model 

    def is_initialized(self):
        print(f"==> Initialized {self.model_name}")
