from openai import OpenAI
import anthropic
import os
import time
import torch
import gc
from typing import Dict, List
import google.generativeai as genai
import urllib3
from copy import deepcopy
import openai
import requests

from config import LLAMA_API_LINK, VICUNA_API_LINK, MISTRAL_API_LINK
# Credentials (e.g. OPENAI_API_KEY) must be supplied via the environment; never commit API keys to source.

# Module-wide tally of successful chat-completion calls; incremented by
# APIModel.generate and GPT.generate.
llm_count = 0

class LanguageModel:
    """Minimal interface shared by the text-generation backends in this module."""

    def __init__(self, model_name):
        # Name of the underlying model; subclasses read it when issuing requests.
        self.model_name = model_name

    def batched_generate(self, prompts_list: List, max_n_tokens: int, temperature: float):
        """Return one generated response per entry of ``prompts_list``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
        
class HuggingFace(LanguageModel):
    """Local Hugging Face ``transformers`` model exposed through ``LanguageModel``."""

    # Number of new tokens requested from ``model.generate``. This is a class
    # attribute so a subclass/instance can override it.
    # NOTE(review): ``batched_generate`` ignores its ``max_n_tokens`` argument
    # in favour of this fixed budget; confirm that is intended.
    MAX_NEW_TOKENS = 500

    def __init__(self, model_name, model, tokenizer):
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        # Token ids meant to end generation; extended by ``extend_eos_tokens``.
        # NOTE(review): ``batched_generate`` does not pass these to
        # ``model.generate``, so they currently have no effect there.
        self.eos_token_ids = [self.tokenizer.eos_token_id]

    def batched_generate(self,
                         full_prompts_list,
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Generate one completion per prompt with the wrapped model.

        Args:
            full_prompts_list: prompts to tokenize as a single padded batch.
            max_n_tokens: accepted for interface compatibility; not used
                (``MAX_NEW_TOKENS`` is sent instead).
            temperature: accepted for interface compatibility; not used.
            top_p: accepted for interface compatibility; not used.
        Returns:
            list[str]: decoded completions, with special tokens stripped and,
            for decoder-only models, the prompt tokens removed.
        """
        inputs = self.tokenizer(full_prompts_list, return_tensors='pt', padding=True)
        # Use the full device object (type + index) so CPU/MPS models work too;
        # ``device.index`` alone drops the device type.
        device = self.model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # For all target models the official default generation parameters
        # are used; only the new-token budget is set explicitly.
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=self.MAX_NEW_TOKENS,
        )
        # Decoder-only models echo the prompt; slice it off.
        if not self.model.config.is_encoder_decoder:
            output_ids = output_ids[:, inputs["input_ids"].shape[1]:]

        outputs_list = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        # Dropping the references is what lets the device memory be reclaimed.
        # ``Tensor.to('cpu')`` returns a new tensor and never moves one in
        # place, so calling it without keeping the result only wasted a copy.
        del inputs, output_ids
        gc.collect()
        torch.cuda.empty_cache()
        print("*********************")
        print(outputs_list)
        print("*********************")
        return outputs_list

    def extend_eos_tokens(self):
        """Append closing-brace token ids to ``self.eos_token_ids``.

        Adds closing braces for Vicuna/Llama eos when using the attacker model.
        """
        self.eos_token_ids.extend([
            # Index 1 presumably skips a BOS token prepended by ``encode`` —
            # TODO confirm for each tokenizer this is used with.
            self.tokenizer.encode("}")[1],
            # Hard-coded vocabulary ids; presumably other "}" variants in the
            # Llama/Vicuna vocabulary — unverified.
            29913,
            9092,
            16675])


class APIModel(LanguageModel):
    """Target model served behind an OpenAI-compatible chat-completions endpoint.

    The success log message suggests the endpoint is together.ai.
    """

    # NOTE(review): subclasses define ``API_HOST_LINK`` and ``MODEL_API_KEY``,
    # but ``generate`` builds its client from the placeholder credentials below
    # and does not read those attributes.
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 3
    API_MAX_RETRY = 20

    API_TIMEOUT = 100

    def choosen_model(self, model_name):
        """Map a short model alias to the provider's full model id.

        Aliases are matched by substring, in the order below. An unknown alias
        prints ``ERROR`` and yields ``""``.
        """
        if "llama-2" in model_name:
            model="meta-llama/Llama-2-7b-chat-hf"
        elif "llama-3" in model_name:
            model="meta-llama/Llama-3-8b-chat-hf"
        elif "vicuna" in model_name:
            model="lmsys/vicuna-7b-v1.5"
        elif "mistral" in model_name:
            model="mistralai/Mistral-7B-Instruct-v0.1"
        elif "qwen" in model_name:
            model="Qwen/Qwen1.5-7B-Chat"
        elif "gemma" in model_name:
            model="google/gemma-7b-it"
        elif "claude" in model_name:
            model="claude-3-sonnet-20240229"
        else:
            model=""
            print("ERROR")

        return model

    def generate(self, conv: List[Dict],
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Query the endpoint for the model selected by ``self.model_name``.

        Args:
            conv: sent as the ``content`` of a single user message.
                NOTE(review): so it is presumably a prompt string, despite
                the ``List[Dict]`` annotation.
            max_n_tokens: int, max number of tokens to generate (only used
                when ``temperature > 0``)
            temperature: float, temperature for sampling
            top_p: float, accepted for interface compatibility; not sent
        Returns:
            str: generated response, or "I can't assist your question" if
            every attempt fails with an OpenAI connection/status error.
        '''
        global llm_count
        client = OpenAI(
            # NOTE(review): placeholder credentials — fill in or load from the
            # environment before use.
            api_key="",
            base_url=''
        )
        # Route to the model the caller asked for. Previously a Mistral id was
        # hard-coded and the mapped name was ignored. Unknown aliases fall
        # back to the raw name.
        model_name = self.choosen_model(self.model_name) or self.model_name
        if temperature > 0:
            sampling = {"temperature": temperature, "max_tokens": max_n_tokens}
        else:
            # NOTE(review): a non-positive temperature does not request greedy
            # decoding; it sends temperature=1 and 1024 tokens instead. Kept
            # as-is — confirm intent.
            sampling = {"temperature": 1, "max_tokens": 1024}

        for _ in range(self.API_MAX_RETRY):
            try:
                print(model_name)
                response = client.chat.completions.create(
                    messages=[{"role": "user", "content": conv}],
                    model=model_name,
                    **sampling,
                )
                output = response.choices[0].message.content
            except openai.APIConnectionError as e:
                print("The server could not be reached")
                print(e.__cause__)  # an underlying Exception, likely raised within httpx.
            except openai.RateLimitError:
                print("A 429 status code was received; we should back off a bit.")
            except openai.APIStatusError as e:
                print("Another non-200-range status code was received")
                print(e.status_code)
                print(e.response)
            else:
                # Log line reads "this is together.ai's answer".
                print("这是togetherai的回答")
                print(output)
                llm_count += 1
                return output
            time.sleep(self.API_QUERY_SLEEP)
        return "I can't assist your question"

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Call ``generate`` sequentially for each entry of ``convs_list``."""
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]

class ClaudeModel(LanguageModel):
    """Claude 3 Sonnet queried through an OpenAI-compatible chat endpoint.

    Unlike ``APIModel``/``GPT``, successful calls here do not touch
    ``llm_count``.
    """

    def generate(self, conv: List[Dict], max_n_tokens: int, temperature: float, top_p: float):
        """Send ``conv`` as one user message and return the reply text.

        With ``temperature > 0`` the caller's temperature and token budget are
        used; otherwise temperature 1 and 1024 tokens are sent. ``top_p`` is
        not forwarded. Up to 5 attempts are made, sleeping 2 s after each
        failed one. If all fail, the string "I can't assist your question" is
        returned.
        """
        client = openai.OpenAI(
            api_key="",
            base_url=""
        )
        if temperature > 0:
            request_options = {"temperature": temperature, "max_tokens": max_n_tokens}
        else:
            request_options = {"temperature": 1, "max_tokens": 1024}

        for _attempt in range(5):
            try:
                reply = client.chat.completions.create(
                    model="claude-3-sonnet-20240229",
                    messages=[{"role": "user", "content": conv}],
                    **request_options,
                )
                return reply.choices[0].message.content
            except openai.APIConnectionError as err:
                print("The server could not be reached")
                print(err.__cause__)  # underlying transport exception
            except openai.RateLimitError:
                print("A 429 status code was received; we should back off a bit.")
            except openai.APIStatusError as err:
                print("Another non-200-range status code was received")
                print(err.status_code)
                print(err.response)
            time.sleep(2)
        return "I can't assist your question"

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Run ``generate`` on each conversation, in order."""
        results = []
        for conversation in convs_list:
            results.append(self.generate(conversation, max_n_tokens, temperature, top_p))
        return results

class APIModelLlama7B(APIModel):
    """Llama endpoint settings: host link from ``config``, key from ``$LLAMA_API_KEY``.

    NOTE(review): ``APIModel.generate`` does not currently read
    ``API_HOST_LINK`` or ``MODEL_API_KEY``; confirm where they are consumed.
    """

    MODEL_API_KEY = os.environ.get("LLAMA_API_KEY")
    API_HOST_LINK = LLAMA_API_LINK

class APIModelVicuna13B(APIModel):
    """Vicuna endpoint settings: host link from ``config``, key from ``$VICUNA_API_KEY``.

    NOTE(review): ``APIModel.generate`` does not currently read
    ``API_HOST_LINK`` or ``MODEL_API_KEY``; confirm where they are consumed.
    """

    MODEL_API_KEY = os.environ.get("VICUNA_API_KEY")
    API_HOST_LINK = VICUNA_API_LINK

class APIModelMistral7BInstruct(APIModel):
    """Mistral endpoint settings: host link from ``config``, key from ``$MISTRAL_API_KEY``.

    NOTE(review): ``APIModel.generate`` does not currently read
    ``API_HOST_LINK`` or ``MODEL_API_KEY``; confirm where they are consumed.
    """

    MODEL_API_KEY = os.environ.get("MISTRAL_API_KEY")
    API_HOST_LINK = MISTRAL_API_LINK

class GPT(LanguageModel):
    """OpenAI chat model addressed by ``self.model_name``."""

    API_RETRY_SLEEP = 3
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 2
    API_MAX_RETRY = 5
    API_TIMEOUT = 20

    # Credentials must be supplied to the client in ``generate`` (or via the
    # environment); never hard-code API keys in this file.

    def generate(self, conv: List[Dict],
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Request one chat completion and return its text.

        Args:
            conv: list of message dicts in OpenAI format, passed as ``messages``
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, accepted for interface compatibility; not sent
        Returns:
            str: generated response, or "I can't assist your question" after
            8 failed attempts (note: 8, not ``API_MAX_RETRY``).
        '''
        global llm_count
        client = OpenAI(
            api_key="",
            base_url=''
        )
        for _attempt in range(8):
            try:
                completion = client.chat.completions.create(
                    messages=conv,
                    model=self.model_name,
                    temperature=temperature,
                    max_tokens=max_n_tokens,
                )
                text = completion.choices[0].message.content
            except openai.APIConnectionError as err:
                print("The server could not be reached")
                print(err.__cause__)  # underlying transport exception (httpx)
            except openai.RateLimitError:
                print("A 429 status code was received; we should back off a bit.")
            except openai.APIStatusError as err:
                print("Another non-200-range status code was received")
                print(err.status_code)
                print(err.response)
            else:
                llm_count += 1
                return text
            time.sleep(self.API_QUERY_SLEEP)
        return "I can't assist your question"

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Run ``generate`` on every conversation, preserving order."""
        return list(map(
            lambda conversation: self.generate(conversation, max_n_tokens, temperature, top_p),
            convs_list,
        ))
     
class PaLM:
    """Google PaLM chat model accessed through ``google.generativeai``."""

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    API_TIMEOUT = 20
    default_output = "I'm sorry, but I cannot assist with that request."
    API_KEY = os.getenv("PALM_API_KEY")

    def __init__(self, model_name) -> None:
        self.model_name = model_name
        genai.configure(api_key=self.API_KEY)

    def generate(self, conv: List,
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Send ``conv`` to ``genai.chat`` and return the last reply.

        Args:
            conv: passed unchanged as ``messages``
            max_n_tokens: reply is truncated to ``4 * max_n_tokens`` characters,
                since PaLM chat has no max-token parameter
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: the (truncated) reply; ``default_output`` when PaLM returns
            no reply (``None``); ``API_ERROR_OUTPUT`` when every attempt raises.
        '''
        reply = self.API_ERROR_OUTPUT
        for _attempt in range(self.API_MAX_RETRY):
            try:
                reply = genai.chat(
                    messages=conv,
                    temperature=temperature,
                    top_p=top_p,
                ).last
                # ~4 characters per token approximates the requested budget.
                reply = self.default_output if reply is None else reply[:(max_n_tokens * 4)]
                break
            except Exception as err:
                print(type(err), err)
                time.sleep(self.API_RETRY_SLEEP)
            time.sleep(self.API_QUERY_SLEEP)
        return reply

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Run ``generate`` on each conversation, in order."""
        replies = []
        for conversation in convs_list:
            replies.append(self.generate(conversation, max_n_tokens, temperature, top_p))
        return replies


class GeminiPro():
    """Google Gemini model accessed through ``google.generativeai``."""

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    API_TIMEOUT = 20
    default_output = "I'm sorry, but I cannot assist with that request."
    API_KEY = os.getenv("PALM_API_KEY")

    def __init__(self, model_name) -> None:
        self.model_name = model_name
        genai.configure(api_key=self.API_KEY)

    def _response_text(self, response):
        """Return the reply text of ``response``, or ``default_output`` if it has none."""
        if response is None:
            return self.default_output
        try:
            return response.text
        except ValueError:
            # google.generativeai raises ValueError from ``.text`` when the
            # response has no valid text part (e.g. it was blocked); treat that
            # as a refusal rather than retrying and leaking the raw response.
            return self.default_output

    def generate(self, conv: List,
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Send ``conv`` to Gemini and return the reply text.

        Args:
            conv: passed unchanged as ``contents``
            max_n_tokens: int, sent as ``max_output_tokens``
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: the reply text; ``default_output`` when Gemini returns no
            usable text; ``API_ERROR_OUTPUT`` when every attempt raises. The
            result is always a string, never the raw response object.
        '''
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                model = genai.GenerativeModel(self.model_name)
                # Keep the response separate from ``output`` so a failure after
                # this call cannot leave the raw response object as the result.
                response = model.generate_content(
                    contents=conv,
                    generation_config=genai.GenerationConfig(
                        candidate_count=1,
                        temperature=temperature,
                        top_p=top_p,
                        max_output_tokens=max_n_tokens,
                    )
                )
                output = self._response_text(response)
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)

            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Call ``generate`` sequentially for each entry of ``convs_list``."""
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]


class Claude():
    """Model wrapper named for Claude.

    NOTE(review): despite the name, this class configures and calls Google's
    ``genai`` chat API (same code path as ``PaLM``); ``MODEL_NAME`` is never
    used. Confirm whether it should use the ``anthropic`` client instead.
    """

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    API_TIMEOUT = 20
    default_output = "I'm sorry, but I cannot assist with that request."
    # Read the key from the environment (matching the file's <MODEL>_API_KEY
    # convention). A literal key was committed here and should be rotated.
    API_KEY = os.getenv("CLAUDE_API_KEY")
    MODEL_NAME="claude-3-sonnet-20240229"

    def __init__(self, model_name) -> None:
        self.model_name = model_name
        genai.configure(api_key=self.API_KEY)

    def generate(self, conv: List,
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Send ``conv`` to ``genai.chat`` and return the last reply.

        Args:
            conv: passed unchanged as ``messages``
            max_n_tokens: reply is truncated to ``4 * max_n_tokens`` characters
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: the (truncated) reply; ``default_output`` when the reply is
            ``None``; ``API_ERROR_OUTPUT`` when every attempt raises.
        '''
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                completion = genai.chat(
                    messages=conv,
                    temperature=temperature,
                    top_p=top_p
                )
                output = completion.last

                if output is None:
                    # A ``None`` reply (refusal) is replaced with the default output.
                    output = self.default_output
                else:
                    # genai.chat takes no max-token argument; approximate the
                    # budget at ~4 characters per token.
                    output = output[:(max_n_tokens*4)]
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)

            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Call ``generate`` sequentially for each entry of ``convs_list``."""
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]