import openai
import anthropic
import os
import time
import torch
from typing import Dict, List
import google.generativeai as genai
from accelerate.utils import find_executable_batch_size
from vllm import SamplingParams
from ..model_utils import load_model_and_tokenizer, load_vllm_model, get_template
import ray


def load_indiv_model(model_name_or_path, use_vllm=False, token=None, **model_kwargs):
    """Build the language-model backend that matches ``model_name_or_path``.

    Names containing ``'gpt-'``, ``'claude-'`` or ``'gemini-'`` (checked in
    that order) are served by the corresponding API wrapper, constructed
    in-process with ``token``. Every other name is treated as an open-source
    model and wrapped in a ray actor: ``VLLM`` when ``use_vllm`` is true,
    ``HuggingFace`` otherwise.

    NOTE: the match is a plain substring test, so any path that happens to
    contain one of those markers is routed to the API wrapper.

    Args:
        model_name_or_path: Model identifier; also selects the backend.
        use_vllm: Choose the ``VLLM`` actor over ``HuggingFace``.
        token: Forwarded to the selected backend as ``token``.
        **model_kwargs: Forwarded to the open-source backends only.
            ``num_gpus`` (default 1) is additionally read here to size the
            ray actor.

    Returns:
        A ``GPT``/``Claude``/``Gemini`` instance, or a ray actor handle whose
        ``is_initialized`` call has already completed.
    """
    # Hosted API models: the first marker found in the name wins.
    api_backends = (('gpt-', GPT), ('claude-', Claude), ('gemini-', Gemini))
    for marker, backend_cls in api_backends:
        if marker in model_name_or_path:
            return backend_cls(model_name_or_path, token=token)

    # We use ray to manage open-source models here.
    gpu_count = model_kwargs.get("num_gpus", 1)
    if use_vllm:
        # vLLM models must be created as ray actors so that several of them can
        # be loaded in one environment
        # (related issue: https://github.com/vllm-project/vllm/issues/1676).
        # For multi-GPU runs vLLM allocates the GPUs for its own ray workers
        # (https://github.com/vllm-project/vllm/blob/main/vllm/engine/ray_utils.py#L115),
        # so num_gpus is only requested for the actor in the single-GPU case.
        actor_cls = VLLM
        actor_options = {"num_gpus": 1} if gpu_count == 1 else {}
    else:
        actor_cls = HuggingFace
        actor_options = {"num_gpus": gpu_count}

    lm = ray.remote(actor_cls).options(**actor_options).remote(
        model_name_or_path, token=token, **model_kwargs
    )
    # Wait for the remote call to finish so the actor is initialized on return.
    ray.get(lm.is_initialized.remote())
    return lm

class LanguageModel():
    """Common base for the model wrappers in this module.

    Stores the model name and declares the `batched_generate` entry point
    that concrete wrappers override.
    """

    def __init__(self, model_name):
        self.model_name = model_name

    def batched_generate(self, prompts_list: List, max_new_tokens: int, temperature: float):
        """
        Generate one response per prompt in `prompts_list`.

        Not implemented on the base class: calling it raises
        `NotImplementedError`.
        """
        raise NotImplementedError()

    def is_initialized(self):
        """Print a confirmation line naming the model; returns nothing."""
        banner = f"==> Initialized {self.model_name}"
        print(banner)

    def get_attribute(self, attr_name):
        """Return the attribute named `attr_name`, or `None` if it is absent."""
        fallback = None
        return getattr(self, attr_name, fallback)

class VLLM(LanguageModel):
    """Wrapper that generates text with a vLLM model."""

    def __init__(self, model_name_or_path, **model_kwargs):
        """Load the vLLM model and the prompt template for `model_name_or_path`.

        `model_kwargs` is forwarded unchanged to both `load_vllm_model` and
        `get_template`.
        """
        self.model_name = model_name_or_path
        self.model = load_vllm_model(model_name_or_path, **model_kwargs)
        template_spec = get_template(model_name_or_path, **model_kwargs)
        self.template = template_spec['prompt']

    def batched_generate(self,
                    inputs: List[str],
                    max_new_tokens: int,
                    temperature: float,
                    stop_tokens: List[str]=[],
                    use_tqdm: bool=False,
                    init_msg: str=None,
                    **gen_kwargs):
        """Generate one completion per entry of `inputs`.

        Args:
            inputs: Instructions; each one fills the `instruction` slot of the
                prompt template.
            max_new_tokens: Passed to `SamplingParams` as `max_tokens`.
            temperature: Passed to `SamplingParams` as `temperature`.
            stop_tokens: Passed to `SamplingParams` as `stop`.
            use_tqdm: Forwarded to the model's `generate` call.
            init_msg: Optional text appended after every formatted prompt.
            **gen_kwargs: Extra keyword arguments for `SamplingParams`.

        Returns:
            The text of the first output of each request, in input order.
        """
        suffix = init_msg or ''
        prompts = []
        for instruction in inputs:
            prompts.append(self.template.format(instruction=instruction) + suffix)

        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_new_tokens,
            stop=stop_tokens,
            **gen_kwargs,
        )

        request_outputs = self.model.generate(prompts, sampling_params, use_tqdm=use_tqdm)
        return [request_output.outputs[0].text for request_output in request_outputs]
    
class HuggingFace(LanguageModel):
    """Wrapper that generates text with a Hugging Face model/tokenizer pair."""

    def __init__(self, model_name_or_path, **model_kwargs):
        """Load the model, tokenizer and prompt template.

        `model_kwargs` is forwarded unchanged to both
        `load_model_and_tokenizer` and `get_template`.
        """
        model, tokenizer = load_model_and_tokenizer(model_name_or_path, **model_kwargs)
        self.model_name = model_name_or_path
        self.model = model
        self.tokenizer = tokenizer
        self.template = get_template(model_name_or_path, **model_kwargs)['prompt']

        self.eos_token_ids = [self.tokenizer.eos_token_id]
        # Starting batch size; `batch_generate_bs` overwrites it whenever it is
        # invoked with a different value.
        self.generation_batch_size = 32

    def batch_generate_bs(self, batch_size, inputs, **generation_kwargs):
        """Generate completions for `inputs` in chunks of `batch_size`.

        Invoked through `find_executable_batch_size` (see `batched_generate`),
        which supplies `batch_size` as the first argument.

        Args:
            batch_size: Number of prompts per `model.generate` call.
            inputs: Fully formatted prompt strings.
            **generation_kwargs: Forwarded to `model.generate`.

        Returns:
            Decoded generations (special tokens skipped), in input order.
        """
        if batch_size != self.generation_batch_size:
            # Remember the new size so later calls start from it.
            self.generation_batch_size = batch_size
            print(f"WARNING: Found OOM, reducing generation batch_size to: {self.generation_batch_size}")

        gen_outputs = []
        for i in range(0, len(inputs), batch_size):
            inputs_b = inputs[i:i+batch_size]
            encoded_b = self.tokenizer(inputs_b, return_tensors='pt', padding='longest')
            with torch.no_grad():
                output_ids = self.model.generate(
                        **encoded_b.to(self.model.device),
                        **generation_kwargs
                    ).cpu()

            # Decoder-only models return prompt + continuation; drop the prompt part.
            if not self.model.config.is_encoder_decoder:
                output_ids = output_ids[:, encoded_b["input_ids"].shape[1]:]
            decoded_outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            gen_outputs.extend(decoded_outputs)
        return gen_outputs

    def _eos_ids_with_stops(self, stop_strings):
        """Return the EOS ids plus the id of every stop string that encodes to exactly one token."""
        eos_ids = list(self.eos_token_ids)
        for stop in stop_strings:
            ids = self.tokenizer.encode(stop, add_special_tokens=False)
            # Multi-token stops cannot be expressed as an EOS id: adding their
            # pieces would end generation on any single piece.
            if len(ids) == 1 and ids[0] not in eos_ids:
                eos_ids.append(ids[0])
        return eos_ids

    @staticmethod
    def _truncate_at_stop(text, stop_strings):
        """Cut `text` just before the earliest occurrence of any stop string."""
        cut = len(text)
        for stop in stop_strings:
            idx = text.find(stop)
            if idx != -1 and idx < cut:
                cut = idx
        return text[:cut]

    def batched_generate(self,
                        inputs,
                        max_new_tokens: int,
                        temperature: float,
                        stop_tokens: List[str]=[],
                        init_msg: str=None,
                        **gen_kwargs,
                        ):
        """Generate one completion per entry of `inputs`.

        Args:
            inputs: Instructions; each one fills the `instruction` slot of the
                prompt template.
            max_new_tokens: Forwarded to `model.generate`.
            temperature: Values > 0 enable sampling at that temperature;
                otherwise sampling is disabled (`do_sample=False`).
            stop_tokens: Strings that end a completion. Stops that encode to a
                single token also end generation early (added to
                `eos_token_id`); every returned text is cut before the first
                stop string it contains. `None` and empty strings are ignored.
                The default list is never mutated.
            init_msg: Optional text appended after every formatted prompt.
            **gen_kwargs: Extra keyword arguments for `model.generate`.

        Returns:
            Decoded completions, in input order.
        """
        # Empty stop strings are dropped: "" would match at index 0 and erase every output.
        stop_strings = [stop for stop in (stop_tokens or []) if stop]

        if temperature > 0:
            sampling_kwargs = dict(do_sample=True, temperature=temperature)
        else:
            sampling_kwargs = dict(do_sample=False)
        # Previously the stop-token ids were computed but never passed on, so
        # `stop_tokens` had no effect on generation.
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            eos_token_id=self._eos_ids_with_stops(stop_strings),
            **sampling_kwargs,
            **gen_kwargs,
        )

        init_msg = init_msg if init_msg else ''
        inputs = [self.template.format(instruction=_input) + init_msg for _input in inputs]

        generation_function = find_executable_batch_size(self.batch_generate_bs, self.generation_batch_size)
        outputs = generation_function(inputs, **generation_kwargs)
        if stop_strings:
            outputs = [self._truncate_at_stop(text, stop_strings) for text in outputs]
        return outputs

class GPT(LanguageModel):
    """Wrapper that queries `openai.ChatCompletion` one prompt at a time."""

    API_RETRY_SLEEP = 10          # seconds slept after a failed request
    API_ERROR_OUTPUT = "$ERROR$"  # returned when every attempt fails
    API_QUERY_SLEEP = 0.5         # extra pause that follows a failed attempt
    API_MAX_RETRY = 5             # attempts per prompt
    API_TIMEOUT = 20              # passed to the API as `request_timeout`

    def __init__(self, model_name, token):
        """Store the model name and set `openai.api_key` to `token`."""
        self.model_name = model_name
        openai.api_key = token

    def generate(self,
                input,
                max_new_tokens: int,
                temperature: float,
                top_p: float):
        """Send `input` as a single user message and return the reply text.

        The request is tried up to `API_MAX_RETRY` times; an
        `openai.error.OpenAIError` is printed and followed by a sleep before
        the next attempt. If no attempt succeeds, `API_ERROR_OUTPUT` is
        returned.
        """
        fallback = self.API_ERROR_OUTPUT
        for _attempt in range(self.API_MAX_RETRY):
            try:
                response = openai.ChatCompletion.create(
                    model=self.model_name,
                    messages=[{"role": "user", "content": input}],
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    request_timeout=self.API_TIMEOUT,
                )
                return response["choices"][0]["message"]["content"]
            except openai.error.OpenAIError as err:
                print(type(err), err)
                time.sleep(self.API_RETRY_SLEEP)

            # Only reached after a failed attempt.
            time.sleep(self.API_QUERY_SLEEP)
        return fallback

    def batched_generate(self,
                        inputs,
                        max_new_tokens: int,
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        """Call `generate` on each entry of `inputs` in order; `kwargs` is ignored."""
        replies = []
        for prompt in inputs:
            replies.append(self.generate(prompt, max_new_tokens, temperature, top_p))
        return replies

class Claude():
    """Wrapper that queries the Anthropic completions endpoint one prompt at a time."""

    API_RETRY_SLEEP = 10          # seconds slept after a failed request
    API_ERROR_OUTPUT = "$ERROR$"  # returned when every attempt fails
    API_QUERY_SLEEP = 1           # extra pause that follows a failed attempt
    API_MAX_RETRY = 5             # attempts per prompt
    API_TIMEOUT = 20              # not referenced by the methods of this class

    def __init__(self, model_name, token) -> None:
        """Create the Anthropic client and load the "claude" prompt template."""
        self.model_name = model_name
        self.model = anthropic.Anthropic(api_key=token)
        claude_template = get_template(fschat_template="claude")
        self.template = claude_template['prompt']

    def generate(self, input,
                max_new_tokens: int,
                temperature: float,
                top_p: float):
        """Format `input` into the prompt template and return the completion text.

        The request is tried up to `API_MAX_RETRY` times; an
        `anthropic.APIError` is printed and followed by a sleep before the
        next attempt. If no attempt succeeds, `API_ERROR_OUTPUT` is returned.
        """
        fallback = self.API_ERROR_OUTPUT
        for _attempt in range(self.API_MAX_RETRY):
            try:
                completion = self.model.completions.create(
                    model=self.model_name,
                    max_tokens_to_sample=max_new_tokens,
                    prompt=self.template.format(instruction=input),
                    temperature=temperature,
                    top_p=top_p,
                )
                return completion.completion
            except anthropic.APIError as err:
                print(type(err), err)
                time.sleep(self.API_RETRY_SLEEP)

            # Only reached after a failed attempt.
            time.sleep(self.API_QUERY_SLEEP)
        return fallback

    def batched_generate(self,
                        inputs: List[List[Dict]],
                        max_new_tokens: int,
                        temperature: float,
                        top_p: float = 1.0,  **kwargs):
        """Call `generate` on each entry of `inputs` in order; `kwargs` is ignored.

        Each entry is handed to `generate` as its `input`, i.e. it becomes the
        `instruction` value of the prompt template.
        """
        replies = []
        for prompt in inputs:
            replies.append(self.generate(prompt, max_new_tokens, temperature, top_p))
        return replies
        
class Gemini():
    """Wrapper that queries a Google Generative AI model one prompt at a time."""

    API_RETRY_SLEEP = 10          # seconds slept after a failed request
    API_ERROR_OUTPUT = "$ERROR$"  # returned when every attempt fails
    API_QUERY_SLEEP = 1           # extra pause that follows a failed attempt
    API_MAX_RETRY = 5             # attempts per prompt
    API_TIMEOUT = 20              # not referenced by the methods of this class
    # Returned when the request is rejected with a blocked-prompt or
    # stop-candidate exception.
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, token) -> None:
        """Configure the `genai` client with `token` and create the model handle."""
        self.model_name = model_name
        genai.configure(api_key=token)

        self.model = genai.GenerativeModel(model_name)

    def generate(self, input,
                max_new_tokens: int,
                temperature: float,
                top_p: float):
        """Send `input` as a chat message and return the reply text.

        Returns:
            The reply text on success; `default_output` when the API raises
            `BlockedPromptException` or `StopCandidateException`;
            `API_ERROR_OUTPUT` when all `API_MAX_RETRY` attempts fail with any
            other exception.
        """
        output = self.API_ERROR_OUTPUT
        generation_config = genai.types.GenerationConfig(
                max_output_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p)
        # One chat session (empty history) is shared by all retry attempts.
        chat = self.model.start_chat(history=[])

        for _ in range(self.API_MAX_RETRY):
            try:
                completion = chat.send_message(input, generation_config=generation_config)
                output = completion.text
                break
            except (genai.types.BlockedPromptException, genai.types.StopCandidateException):
                # Prompt was blocked for safety reasons
                output = self.default_output
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
            # Only reached after a failed attempt. Uses the class constant
            # (previously a hard-coded 1) so the pause stays configurable.
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                        inputs,
                        max_new_tokens: int,
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        """Call `generate` on each entry of `inputs` in order; `kwargs` is ignored."""
        return [self.generate(_input, max_new_tokens, temperature, top_p) for _input in inputs]
