import openai
from openai import OpenAI
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from prompts import identity
import time


def get_model(args):
    model_name, temperature = args.model, args.temperature
    if 'gpt' in model_name:
        model = GPT(args.api_key, model_name, temperature)
        return model
    elif 'llama' in model_name.lower():
        return LLaMA(model_name, temperature)
    elif 'mistral' in model_name.lower():
        return Mistral(model_name, temperature)
    elif 'mixtral' in model_name.lower():
        return Mixtral(model_name, temperature)


class Model(object):
    def __init__(self):
        self.post_process_fn = identity
    
    def set_post_process_fn(self, post_process_fn):
        self.post_process_fn = post_process_fn


class GPT(Model):
    def __init__(self, api_key, model_name, temperature):
        super().__init__()
        self.model_name = model_name
        self.temperature = temperature
        self.client = OpenAI(api_key=api_key)

    def get_response(self, **kwargs):
        try:
            res = self.client.chat.completions.create(**kwargs)
            return res
        except openai.APIConnectionError as e:
            print('APIConnectionError')
            time.sleep(30)
            return self.get_response(**kwargs)
        except openai.RateLimitError as e:
            print('RateLimitError')
            time.sleep(10)
            return self.get_response(**kwargs)
        except openai.APITimeoutError as e:
            print('APITimeoutError')
            time.sleep(30)
            return self.get_response(**kwargs)
        except openai.BadRequestError as e:
            print('BadRequestError')
            kwargs['model'] = 'gpt-3.5-turbo-16k'
            return self.get_response(**kwargs)

    def forward(self, head, prompts, max_new_tokens=20, single_turn=True,):
        del single_turn, max_new_tokens
        messages = [
            {"role": "system", "content": head}
        ]
        info = {}
        for i, prompt in enumerate(prompts):
            messages.append(
                {"role": "user", "content": prompt}
            )
            response = self.get_response(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
            )
            messages.append(
                {"role": "assistant", "content": response.choices[0].message.content}
            )
            info = dict(response.usage)  # completion_tokens, prompt_tokens, total_tokens
            info['response'] = messages[-1]["content"]
            info['message'] = messages
        return self.post_process_fn(info['response']), info


class LLaMA(Model):
    def __init__(self, model_name, temperature):
        super().__init__()
        self.model_name = model_name
        self.temperature = temperature

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = "[PAD]"
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer
        self.pipeline = pipeline(
            "text-generation",
            model=model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            tokenizer=tokenizer,
            temperature=temperature,
        )

    def forward(self, head, prompts, max_new_tokens=20):
        del head
        prompt = prompts[0]
        sequences = self.pipeline(
            prompt,
            do_sample=False,
            top_k=1,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
        )
        response = sequences[0]['generated_text']
        info = {
            'message': prompt,
            'response': response
        }
        return self.post_process_fn(info['response']), info
    

class Mistral(Model):
    def __init__(self, model_name, temperature):
        super().__init__()
        self.model_name = model_name
        self.temperature = temperature

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer
        self.pipeline = pipeline(
            "text-generation",
            model=model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            tokenizer=tokenizer,
            temperature=temperature,
        )

    def forward(self, head, prompts, max_new_tokens):
        del head
        prompt = prompts[0]
        sequences = self.pipeline(
            prompt,
            do_sample=False,
            top_k=1,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
        )
        response = sequences[0]['generated_text']
        info = {
            'message': prompt,
            'response': response
        }
        return self.post_process_fn(info['response']), info


class Mixtral(Model):
    def __init__(self, model_name, temperature):
        super().__init__()
        self.model_name = model_name
        self.temperature = temperature

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer
        quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
        )
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                    quantization_config=quantization_config,
                                                    device_map="auto",)
        self.pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            temperature=temperature,
        )

    def forward(self, head, prompts, max_new_tokens, single_turn=True):
        del head, single_turn
        prompt = prompts[0]
        sequences = self.pipeline(
            prompt,
            do_sample=False,
            top_k=1,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
        )
        response = sequences[0]['generated_text']
        info = {
            'message': prompt,
            'response': response
        }
        return self.post_process_fn(info['response']), info
