import torch
import openai
from tqdm import tqdm
import os

from utils import API_MODELS


class Model:
    """Uniform ``predict`` front-end over local HuggingFace models and API models.

    A model is selected by ``model_name``. Names listed in ``API_MODELS`` are
    routed to ``_pred_by_api``; every other name is loaded with
    ``transformers`` in ``_create_model`` and run by
    ``_predict_by_local_model``.

    Attributes:
        model_name: Identifier used to pick both the loader and the
            generation settings.
        model_path: Directory that holds locally stored checkpoints; it is
            joined with ``model_name`` for the llama and vicuna families.
        token_ratio: Generation budget expressed as a multiple of the prompt
            length in tokens.
        tokenizer: Tokenizer returned by ``_create_model`` (``None`` for API
            models).
        model: Model returned by ``_create_model`` (``None`` for API models).
    """

    # Locally stored checkpoints. Names are compared case-insensitively, both
    # when loading and when generating.
    _LLAMA_MODELS = (
        "llama-13b", "llama2-70b", "llama2-70b-chat", "llama2-7b",
        "llama2-7b-chat", "llama2-13b", "llama2-13b-ft", "llama2-13b-chat",
    )
    _VICUNA_MODELS = ("vicuna-13b", "vicuna-13b-v1.3")

    # Hub checkpoints that load with the plain Auto* classes and no extra
    # keyword arguments.
    _AUTO_CAUSAL_MODELS = (
        "WizardLM/WizardMath-13B-V1.0",
        "WizardLM/WizardMath-70B-V1.0",
        "Xwin-LM/Xwin-LM-13B-V0.1",
        "Xwin-LM/Xwin-LM-70B-V0.1",
    )

    # Branches of the generation dispatch that share identical settings.
    _EOS_PADDED_MODELS = (
        "databricks/dolly-v1-6b",
        "cerebras/Cerebras-GPT-13B",
        "tiiuae/falcon-40b-instruct",
    )

    # Short names accepted by this class -> OpenAI model identifiers.
    _OPENAI_MODEL_IDS = {
        "chatgpt": "gpt-3.5-turbo",
        "chatgpt-16k": "gpt-3.5-turbo-16k",
        "gpt4": "gpt-4",
        "gpt4-32k": "gpt-4-32k",
    }

    def __init__(self, model_name, model_path, token_ratio) -> None:
        """Store the configuration and load the tokenizer/model pair.

        Raises:
            NotImplementedError: If ``model_name`` is neither an API model nor
                a local model known to ``_create_model``.
        """
        self.model_name = model_name
        self.model_path = model_path
        self.token_ratio = token_ratio
        self.tokenizer, self.model = self._create_model()

    def predict(self, data):
        """Return one prediction per element of ``data``.

        Args:
            data: Iterable of prompt strings.

        Returns:
            List of predictions produced by the API path or the local path,
            depending on whether ``model_name`` is in ``API_MODELS``.
        """
        if self.model_name in API_MODELS:
            return self._pred_by_api(data)
        return self._predict_by_local_model(data)

    def _predict_by_local_model(self, data):
        """Generate a completion for every prompt with the local model.

        Args:
            data: Iterable of prompt strings.

        Returns:
            List of decoded completions, in input order. A prompt whose model
            has no generation settings yields the fixed error string returned
            by ``_generate_text``.
        """
        return [self._generate_text(prompt) for prompt in tqdm(data)]

    def _generate_text(self, inputs):
        """Tokenize one prompt, generate with model-specific settings, decode.

        The generation budget is ``int(prompt_tokens * token_ratio)``.
        Branches that slice with ``len(inputs)`` drop the echoed prompt from
        the decoded text.
        """
        model_name = self.model_name
        lowered = model_name.lower()

        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to("cuda")
        prompt_len = len(input_ids[0])
        max_new_tokens = int(prompt_len * self.token_ratio)

        if "t5" in model_name or "ul2" in model_name:
            generated = self.model.generate(input_ids,
                                            max_length=max_new_tokens,
                                            early_stopping=True)
            return self.tokenizer.decode(generated[0])

        if "wizard" in lowered:
            generated = self.model.generate(input_ids,
                                            max_new_tokens=max_new_tokens,
                                            temperature=0)
            text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
            return text[len(inputs):]

        if "xwin" in lowered:
            generated = self.model.generate(input_ids,
                                            max_new_tokens=max_new_tokens,
                                            temperature=0.000001)
            text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
            return text[len(inputs):]

        if model_name == "microsoft/phi-1_5":
            # ``max_length`` counts prompt tokens as well, so the budget must
            # be added to the prompt's token count. ``len(input_ids)`` would
            # be the batch dimension (1), which caps generation below the
            # prompt length.
            generated = self.model.generate(input_ids,
                                            temperature=0,
                                            max_length=max_new_tokens + prompt_len)
            text = self.tokenizer.decode(generated[0])
            return text[len(inputs):]

        if model_name == "EleutherAI/gpt-neox-20b":
            generated = self.model.generate(input_ids,
                                            temperature=0.000001,
                                            max_new_tokens=max_new_tokens,
                                            early_stopping=True,
                                            pad_token_id=self.tokenizer.eos_token_id)
            return self.tokenizer.decode(generated[0])

        # Lower-cased so that any name accepted by ``_create_model`` for these
        # families is also accepted here.
        if lowered in self._LLAMA_MODELS + self._VICUNA_MODELS:
            # NOTE(review): ``len(input_ids)`` is the batch dimension (1 for a
            # single prompt), not the prompt length, so this adds one token to
            # the budget. Kept as-is because ``max_new_tokens`` already
            # excludes the prompt -- confirm the intended budget.
            generated = self.model.generate(input_ids,
                                            temperature=0,
                                            max_new_tokens=max_new_tokens + len(input_ids),
                                            early_stopping=True)
            text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
            # remove the input sentence
            return text[len(inputs):]

        if model_name in self._EOS_PADDED_MODELS:
            generated = self.model.generate(input_ids,
                                            temperature=0,
                                            max_new_tokens=max_new_tokens,
                                            pad_token_id=self.tokenizer.eos_token_id,
                                            early_stopping=True)
            return self.tokenizer.decode(generated[0], skip_special_tokens=True)

        # No generation settings are defined for this model name.
        return "Error! The model did not predict anything."

    def _pred_by_api(self, data):
        """Dispatch to the API backend that serves ``model_name``.

        Raises:
            NotImplementedError: If no backend is implemented for the model.
        """
        if self.model_name in self._OPENAI_MODEL_IDS:
            return self.__pred_by_openai_api(data)
        raise NotImplementedError

    def __pred_by_openai_api(self, data):
        """Template for querying OpenAI models; the request logic is a stub.

        As written, no request is sent and the returned list is empty.
        ``api_infer`` and the loop body below are the places to implement.
        """
        # max_new_tokens = self.max_new_tokens

        # OpenAI identifier for the configured short name, for use by api_infer.
        model_name = self._OPENAI_MODEL_IDS.get(self.model_name)

        # implement api_infer for OpenAI models here
        def api_infer(input_texts):
            pass

        data_len = len(data)
        input_texts = []
        preds = []

        for idx in tqdm(range(data_len)):
            pass

        return preds

    def _create_model(self):
        """Load the tokenizer and model for ``model_name``.

        Returns:
            ``(tokenizer, model)``. Both are ``None`` when ``model_name`` is
            in ``API_MODELS``.

        Raises:
            NotImplementedError: If ``model_name`` is a local model with no
                loader defined below.
        """
        model_name = self.model_name
        if model_name in API_MODELS:
            return None, None

        # Add your own local model here.
        lowered = model_name.lower()

        if model_name == "google/flan-t5-large":
            from transformers import T5Tokenizer, T5ForConditionalGeneration

            tokenizer = T5Tokenizer.from_pretrained(model_name, device_map="auto")
            model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")

        elif model_name in self._AUTO_CAUSAL_MODELS:
            from transformers import AutoTokenizer, AutoModelForCausalLM

            model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
            tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")

        elif model_name == "microsoft/phi-1_5":
            from transformers import AutoModelForCausalLM, AutoTokenizer

            model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto", device_map="auto")
            tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        elif lowered in self._LLAMA_MODELS:
            from transformers import LlamaForCausalLM, LlamaTokenizer

            model_path = os.path.join(self.model_path, model_name)
            tokenizer = LlamaTokenizer.from_pretrained(model_path, device_map="auto", use_default_system_prompt=False)
            # Checked on the lower-cased name so it agrees with the
            # case-insensitive membership test above.
            if "70b" in lowered:
                model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto", load_in_8bit=True, torch_dtype=torch.float16)
            else:
                model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)

        elif lowered in self._VICUNA_MODELS:
            from transformers import AutoModelForCausalLM, AutoTokenizer

            model_path = os.path.join(self.model_path, model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_path, device_map="auto", use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)

        elif model_name == "google/flan-ul2":
            from transformers import T5ForConditionalGeneration, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)

        elif model_name == "tiiuae/falcon-40b-instruct":
            from transformers import AutoTokenizer, AutoModelForCausalLM

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)

        else:
            raise NotImplementedError("The model is not implemented!")

        return tokenizer, model