import os
import time
from typing import Dict, List, Optional

import anthropic
import google.generativeai as genai
import openai
import ray
import torch
from accelerate.utils import find_executable_batch_size
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from huggingface_hub import login as hf_login
from vllm import LLM, SamplingParams

from ..model_utils import load_model_and_tokenizer


class LanguageModel:
    """Minimal base class shared by the language-model backends in this module."""

    def __init__(self, model_name):
        """Remember the name under which this model is identified."""
        self.model_name = model_name

    def batched_generate(self, prompts_list: List, max_n_tokens: int, temperature: float):
        """Generate one response per prompt in ``prompts_list``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def is_initialized(self):
        """Print a one-line confirmation naming this model."""
        print(f"==> Initialized {self.model_name}")

    def get_attribute(self, attr_name):
        """Return the attribute called ``attr_name``, or ``None`` if it is absent."""
        try:
            return getattr(self, attr_name)
        except AttributeError:
            return None
    
class HuggingFace(LanguageModel):
    """Locally loaded Hugging Face model with OOM-adaptive batched generation.

    The instance keeps three usage counters (``n_input_tokens``,
    ``n_output_tokens``, ``n_fwd_passes``) that are updated on every
    generation call and can be read or cleared through
    ``return_logged_variables`` / ``reset_logged_variables``.
    """

    def __init__(self, model_name_or_path, **model_kwargs):
        """Load the model and tokenizer and initialise generation state.

        Args:
            model_name_or_path: Identifier forwarded to
                ``load_model_and_tokenizer``; also stored as ``model_name``.
            **model_kwargs: Forwarded unchanged to ``load_model_and_tokenizer``.
        """
        super().__init__(model_name_or_path)
        model, tokenizer = load_model_and_tokenizer(model_name_or_path, **model_kwargs)
        self.model = model
        self.tokenizer = tokenizer
        self.generation_batch_size = 32

        self.n_input_tokens = 0
        self.n_output_tokens = 0
        self.n_fwd_passes = 0

        # Generation stops on the tokenizer's EOS token or on the first token
        # that "}" encodes to.
        # NOTE(review): presumably "}" is a stop token so that generation ends
        # once a JSON-style object is closed -- confirm with callers.
        closing_brace_id = self.tokenizer.encode("}", add_special_tokens=False)[0]
        self.eos_token_ids = [self.tokenizer.eos_token_id, closing_brace_id]

    def batch_generate_bs(self, batch_size, inputs, **generation_kwargs):
        """Generate completions for ``inputs`` in chunks of ``batch_size``.

        This is the callable handed to ``find_executable_batch_size``, which
        supplies ``batch_size`` itself (hence it must stay the first
        parameter).

        Args:
            batch_size: Number of prompts tokenized and generated per chunk.
            inputs: Sequence of prompt strings.
            **generation_kwargs: Forwarded to ``self.model.generate``.

        Returns:
            List of decoded completions, in the same order as ``inputs``.
        """
        if batch_size != self.generation_batch_size:
            # A different batch size than the remembered one means the caller
            # retried with a smaller value; remember it for subsequent calls.
            self.generation_batch_size = batch_size
            print(
                f"WARNING: Found OOM, reducing generation batch_size to: {self.generation_batch_size}"
            )

        gen_outputs = []
        for start in range(0, len(inputs), batch_size):
            inputs_b = inputs[start : start + batch_size]
            encoded_b = self.tokenizer(inputs_b, return_tensors="pt", padding="longest")
            n_rows, prompt_len = encoded_b["input_ids"].shape[:2]
            # Counted from the padded tensor shape, so padding positions are
            # included in the input-token total.
            self.n_input_tokens += prompt_len * n_rows
            # Incremented once per prompt in the chunk.
            self.n_fwd_passes += n_rows

            with torch.no_grad():
                output_ids = self.model.generate(
                    **encoded_b.to(self.model.device), **generation_kwargs
                ).cpu()

            if not self.model.config.is_encoder_decoder:
                # Drop the prompt positions so only newly generated tokens
                # are counted and decoded.
                output_ids = output_ids[:, prompt_len:]
            # Also shape-based: includes any padding after an early stop.
            self.n_output_tokens += output_ids.shape[1] * output_ids.shape[0]
            gen_outputs.extend(
                self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            )
        return gen_outputs

    def batched_generate(self,
                        full_prompts_list,
                        max_n_tokens: int,
                        temperature: float,
                        stop_tokens: Optional[List[str]] = None
                        ):
        """Generate one completion per prompt, shrinking the batch size on OOM.

        Args:
            full_prompts_list: Prompt strings to complete.
            max_n_tokens: Maximum number of new tokens per completion.
            temperature: Sampling temperature; values ``<= 0`` select greedy
                decoding.
            stop_tokens: Accepted for interface parity with the other
                backends; this backend does not use it (stopping is governed
                by ``self.eos_token_ids``).

        Returns:
            List of decoded completions, aligned with ``full_prompts_list``.
        """
        sample = temperature > 0
        generation_kwargs = dict(
            max_new_tokens=max_n_tokens,
            do_sample=sample,
            eos_token_id=self.eos_token_ids,
        )
        if sample:
            # Temperature is only passed when sampling is enabled.
            generation_kwargs["temperature"] = temperature

        generation_function = find_executable_batch_size(
            self.batch_generate_bs, self.generation_batch_size
        )
        return generation_function(full_prompts_list, **generation_kwargs)

    def return_logged_variables(self):
        """Return the current usage counters as a dict."""
        return {
            "n_input_tokens": self.n_input_tokens,
            "n_output_tokens": self.n_output_tokens,
            "n_fwd_passes": self.n_fwd_passes,
        }

    def reset_logged_variables(self):
        """Reset all usage counters to zero."""
        self.n_input_tokens = 0
        self.n_output_tokens = 0
        self.n_fwd_passes = 0


class GPT(LanguageModel):
    """Chat-completion backend that talks to OpenAI through the ``openai`` module."""

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 0.5
    API_MAX_RETRY = 5
    API_TIMEOUT = 20

    def __init__(self, model_name, token):
        """Store the model name and install ``token`` as the module-wide API key."""
        openai.api_key = token
        self.model_name = model_name

    def generate(self, conv: List[Dict],
                max_n_tokens: int,
                temperature: float,
                top_p: float):
        '''
        Query the chat-completion endpoint for a single conversation.

        Args:
            conv: List of dictionaries, OpenAI API format
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: generated response, or ``API_ERROR_OUTPUT`` when every
            attempt raised an ``OpenAIError``
        '''
        request = dict(
            model=self.model_name,
            messages=conv,
            max_tokens=max_n_tokens,
            temperature=temperature,
            top_p=top_p,
            request_timeout=self.API_TIMEOUT,
        )
        for _attempt in range(self.API_MAX_RETRY):
            try:
                reply = openai.ChatCompletion.create(**request)
                # A successful call returns right away, skipping both sleeps.
                return reply["choices"][0]["message"]["content"]
            except openai.error.OpenAIError as err:
                print(type(err), err)
                time.sleep(self.API_RETRY_SLEEP)

            # Only reached after a failed attempt.
            time.sleep(self.API_QUERY_SLEEP)
        return self.API_ERROR_OUTPUT

    def batched_generate(self,
                        convs_list: List[List[Dict]],
                        max_n_tokens: int,
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        """Call ``generate`` sequentially for each conversation; extra kwargs are ignored."""
        responses = []
        for conversation in convs_list:
            responses.append(self.generate(conversation, max_n_tokens, temperature, top_p))
        return responses

class Claude():
    """Backend that queries Anthropic's completions endpoint with simple retries."""

    API_RETRY_SLEEP = 10  # seconds slept after a failed request
    API_ERROR_OUTPUT = "$ERROR$"  # returned when every attempt fails
    API_QUERY_SLEEP = 1  # extra seconds slept after a failed attempt
    API_MAX_RETRY = 5  # maximum number of attempts per prompt
    API_TIMEOUT = 20  # not referenced anywhere in this class
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, token) -> None:
        """Store the model name and API key and create the Anthropic client."""
        self.model_name = model_name
        self.API_KEY = token

        self.model = anthropic.Anthropic(
            api_key=self.API_KEY,
        )

    def generate(self, conv: List,
                max_n_tokens: int,
                temperature: float,
                top_p: float):
        '''
        Request a single completion, retrying on errors.

        Args:
            conv: prompt forwarded unchanged as ``prompt`` to the completions
                endpoint (NOTE(review): annotated as ``List`` but presumably
                an already formatted prompt string -- confirm with callers)
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: the generated response; ``default_output`` if the error
            message reports that the output was blocked by the content
            filter; ``API_ERROR_OUTPUT`` if every attempt failed
        '''
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                completion = self.model.completions.create(
                    model=self.model_name,
                    max_tokens_to_sample=max_n_tokens,
                    prompt=conv,
                    temperature=temperature,
                    top_p=top_p
                )
                output = completion.completion
                break
            except Exception as e:
                # A content-filter rejection is treated as a refusal rather
                # than a transient failure, so it is not retried.
                if "Output blocked by content filtering policy" in str(e):
                    output = self.default_output
                    break
                # Report the failure (as the GPT and Gemini backends do)
                # instead of swallowing it silently before retrying.
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)

            # Only reached after a failed attempt; both exits above `break`.
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                        convs_list: List[List[Dict]],
                        max_n_tokens: int,
                        temperature: float,
                        top_p: float = 1.0,  **kwargs):
        """Call ``generate`` sequentially for each entry of ``convs_list``.

        Extra keyword arguments are accepted and ignored.
        """
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
        
class Gemini():
    """Backend that queries a Google Generative AI model through a chat session."""

    API_RETRY_SLEEP = 10  # seconds slept after an unexpected error
    API_ERROR_OUTPUT = "$ERROR$"  # returned when every attempt fails
    API_QUERY_SLEEP = 1  # extra seconds slept after a failed attempt
    API_MAX_RETRY = 5  # maximum number of attempts per prompt
    API_TIMEOUT = 20  # not referenced anywhere in this class
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, token) -> None:
        """Configure the ``genai`` module with ``token`` and build the model handle."""
        self.model_name = model_name
        genai.configure(api_key=token)

        # Every listed harm category is set to BLOCK_NONE.
        self.safety_settings={
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

        self.model = genai.GenerativeModel(model_name)

    def generate(self, conv: List,
                max_n_tokens: int,
                temperature: float,
                top_p: float):
        '''
        Send a single message in a fresh chat and return the reply text.

        Args:
            conv: message content passed unchanged to ``chat.send_message``
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: the generated response; ``default_output`` when the request
            is reported as blocked; ``API_ERROR_OUTPUT`` if every attempt
            failed with another error
        '''
        output = self.API_ERROR_OUTPUT
        generation_config=genai.types.GenerationConfig(
                max_output_tokens=max_n_tokens,
                temperature=temperature,
                top_p=top_p)
        # The chat session is created once, so retries reuse the same session.
        chat = self.model.start_chat(history=[])
        safety_settings = self.safety_settings

        for _ in range(self.API_MAX_RETRY):
            try:
                completion = chat.send_message(conv, generation_config=generation_config, safety_settings=safety_settings)
                output = completion.text
                break
            except (genai.types.BlockedPromptException, genai.types.StopCandidateException, ValueError):
                # Prompt was blocked for safety reasons; answer with the
                # canned refusal instead of retrying.
                # NOTE(review): ValueError is presumably what `completion.text`
                # raises when the response has no usable text -- confirm.
                output = self.default_output
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
            # Only reached after a failed attempt. Uses the class constant
            # (instead of a hard-coded 1) like the other API backends.
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                        convs_list: List[List[Dict]],
                        max_n_tokens: int,
                        temperature: float,
                        top_p: float = 1.0, **kwargs):
        """Call ``generate`` sequentially for each entry of ``convs_list``.

        Extra keyword arguments are accepted and ignored.
        """
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]

@ray.remote
class VLLM:
    """Ray actor that owns a vLLM engine and tracks simple usage counters."""

    def __init__(self, model_name_or_path, num_gpus=1, **model_kwargs):
        """Load the vLLM model inside the actor.

        Args:
            model_name_or_path: Model identifier passed to ``load_vllm_model``.
            num_gpus: Tensor-parallel size used when building the engine.
            **model_kwargs: Forwarded to ``load_vllm_model``.
        """
        self.model_name = model_name_or_path

        # In https://github.com/vllm-project/vllm/blob/main/vllm/engine/ray_utils.py
        # VLLM will manually allocate gpu placement_groups for ray actors if using tensor_parallel (num_gpus) > 1
        # So we will set CUDA_VISIBLE_DEVICES to all and let vllm.engine.ray_utils handle this
        if num_gpus > 1:
            resources = ray.cluster_resources()
            available_devices = ",".join([str(i) for i in range(int(resources.get("GPU", 0)))])
            os.environ["CUDA_VISIBLE_DEVICES"] = available_devices

        self.model = self.load_vllm_model(model_name_or_path, num_gpus=num_gpus, **model_kwargs)

        self.n_input_tokens = 0
        self.n_output_tokens = 0
        self.n_fwd_passes = 0

    def batched_generate(
        self,
        full_prompts_list,
        max_n_tokens: int,
        temperature: float,
        stop_tokens: Optional[List[str]] = None,
    ):
        """Generate one completion per prompt and update the usage counters.

        Args:
            full_prompts_list: Prompt strings to complete.
            max_n_tokens: Maximum number of tokens generated per prompt.
            temperature: Sampling temperature.
            stop_tokens: Strings at which generation stops; ``None`` (the
                default) means no stop strings.

        Returns:
            List with the text of the first completion for each prompt.
        """
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_n_tokens,
            stop=stop_tokens if stop_tokens is not None else [],
        )
        request_outputs = self.model.generate(full_prompts_list, sampling_params, use_tqdm=False)
        texts = [req.outputs[0].text for req in request_outputs]

        # Count tokens from the ids the engine actually consumed/produced.
        # Re-tokenizing the decoded text can add special tokens and need not
        # round-trip, so it does not reliably match the generated length.
        self.n_input_tokens += sum(len(req.prompt_token_ids) for req in request_outputs)
        self.n_fwd_passes += len(full_prompts_list)
        self.n_output_tokens += sum(len(req.outputs[0].token_ids) for req in request_outputs)

        return texts

    def load_vllm_model(
        self,
        model_name_or_path,
        dtype="auto",
        trust_remote_code=False,
        download_dir=None,
        revision=None,
        token=None,
        quantization=None,
        num_gpus=1,
        ## tokenizer_args
        use_fast_tokenizer=True,
        pad_token=None,
        eos_token=None,
        **kwargs,
    ):
        """Build the ``LLM`` engine and apply optional tokenizer overrides.

        Args:
            model_name_or_path: Model identifier passed to ``LLM``.
            dtype, trust_remote_code, download_dir, revision, quantization:
                Forwarded unchanged to ``LLM``.
            token: If truthy, used to log in to the Hugging Face Hub first.
            num_gpus: Passed to ``LLM`` as ``tensor_parallel_size``.
            use_fast_tokenizer: Selects ``tokenizer_mode`` "auto" (True) or
                "slow" (False).
            pad_token: If truthy, overrides the tokenizer's pad token.
            eos_token: If truthy, overrides the tokenizer's EOS token.
            **kwargs: Accepted so callers can pass a shared config; ignored.

        Returns:
            The constructed ``LLM`` instance.
        """
        if token:
            hf_login(token=token)

        model = LLM(
            model=model_name_or_path,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            download_dir=download_dir,
            revision=revision,
            quantization=quantization,
            tokenizer_mode="auto" if use_fast_tokenizer else "slow",
            tensor_parallel_size=num_gpus,
        )

        if pad_token:
            model.llm_engine.tokenizer.tokenizer.pad_token = pad_token
        if eos_token:
            model.llm_engine.tokenizer.tokenizer.eos_token = eos_token

        return model

    def is_initialized(self):
        """Print a one-line confirmation naming this model."""
        print(f"==> Initialized {self.model_name}")

    def return_logged_variables(self):
        """Return the current usage counters as a dict."""
        return {
            "n_input_tokens": self.n_input_tokens,
            "n_output_tokens": self.n_output_tokens,
            "n_fwd_passes": self.n_fwd_passes,
        }

    def reset_logged_variables(self):
        """Reset all usage counters to zero."""
        self.n_input_tokens = 0
        self.n_fwd_passes = 0
        self.n_output_tokens = 0