from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from vllm import LLM, SamplingParams

import deepspeed
import torch
import yaml
import os
import mii

def load_embedding_model(path_to_yml, model_path='models/jina-embeddings-v2-base-zh'):
    """Load the embedding model and move it onto the configured CUDA device.

    Args:
        path_to_yml: Path to the YAML config file. Its
            ``hybrid_searcher.device`` entry is interpolated into the device
            string ``cuda:{device}`` (presumably a GPU index such as ``0`` --
            TODO confirm against the config file).
        model_path: Local directory or hub id passed to
            ``AutoModel.from_pretrained``. Defaults to the path that used to be
            hard-coded, so existing callers behave exactly as before.

    Returns:
        The loaded model, moved to the configured device.

    Raises:
        KeyError: if the config has no ``hybrid_searcher.device`` entry.
    """
    # YAML files are UTF-8 by spec; don't depend on the platform locale encoding.
    with open(path_to_yml, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # trust_remote_code lets the checkpoint's own modeling code run; only point
    # model_path at sources you trust.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    device = torch.device(f"cuda:{config['hybrid_searcher']['device']}")
    model.to(device)

    print("Embedding model loaded successfully")
    return model

# NOTE: VLLM_Model below was written against vLLM 0.5.0; please use that version.

class Model:
    """Hugging Face ``transformers`` causal LM configured from a YAML file.

    The YAML file at ``path_to_yml`` must have a top-level section named
    ``model_name``. This class reads ``base_model``, ``device`` and a
    ``params`` mapping of generation settings from that section.
    """

    # Generation settings copied verbatim from the ``params`` section by load_config.
    _PARAM_KEYS = ("max_new_tokens", "do_sample", "temperature", "top_p", "system_prompt")

    def __init__(self, model_name, path_to_yml):
        # model_name: key of this model's section in the YAML file.
        self.model_name = model_name
        # path_to_yml: path to the YAML config; re-read on every load_* call.
        self.path_to_yml = path_to_yml

    def _model_config(self):
        """Read the YAML file and return the ``self.model_name`` section.

        Raises:
            KeyError: if the file has no section named ``self.model_name``.
        """
        # YAML is UTF-8 by spec; an explicit encoding keeps non-ASCII values
        # (e.g. a non-English system prompt) intact regardless of the locale.
        with open(self.path_to_yml, "r", encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config[self.model_name]

    def load_tokenizer(self):
        """Return the tokenizer for ``base_model``, configured for left padding."""
        # Left padding is the usual setting for batched generation with
        # decoder-only models.
        return AutoTokenizer.from_pretrained(self._model_config()['base_model'], padding_side="left")

    def load_model(self):
        """Load ``base_model`` in bfloat16, placed according to the ``device`` entry.

        The ``device`` value is passed through unchanged as ``device_map``.
        """
        cfg = self._model_config()
        return AutoModelForCausalLM.from_pretrained(
            cfg['base_model'],
            torch_dtype=torch.bfloat16,
            device_map=cfg['device'],
        )

    def load_config(self):
        """Return the generation settings from the ``params`` section.

        Returns:
            dict with the keys ``max_new_tokens``, ``do_sample``,
            ``temperature``, ``top_p`` and ``system_prompt``.

        Raises:
            KeyError: if ``params`` or any of those keys is missing.
        """
        params = self._model_config()["params"]
        return {key: params[key] for key in self._PARAM_KEYS}
    
class DS_Model:
    """Causal LM wrapped with ``deepspeed.init_inference``, configured from YAML.

    The ``model_name`` section of the YAML file supplies ``base_model``,
    ``world_size`` (tensor-parallel degree), ``replace_with_kernel_inject``
    and a ``params`` mapping of generation settings.
    """

    # Generation settings copied verbatim from the ``params`` section by load_config.
    _PARAM_KEYS = ("max_new_tokens", "do_sample", "temperature", "top_p", "system_prompt")

    def __init__(self, model_name, path_to_yml):
        # model_name: key of this model's section in the YAML file.
        self.model_name = model_name
        # path_to_yml: path to the YAML config; re-read on every load_* call.
        self.path_to_yml = path_to_yml

    def _model_config(self):
        """Read the YAML file and return the ``self.model_name`` section.

        Raises:
            KeyError: if the file has no section named ``self.model_name``.
        """
        # YAML is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(self.path_to_yml, "r", encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config[self.model_name]

    def load_tokenizer(self):
        """Return the tokenizer for ``base_model``, configured for left padding."""
        return AutoTokenizer.from_pretrained(self._model_config()['base_model'], padding_side="left")

    def load_model(self):
        """Load ``base_model`` and return the DeepSpeed-inference-wrapped module.

        Returns:
            ``ds_model.module``, the module returned inside DeepSpeed's
            inference engine.
        """
        cfg = self._model_config()
        # Load directly in bf16. DeepSpeed is configured for bf16 below anyway,
        # and otherwise from_pretrained may first materialise an fp32 copy,
        # which doubles host memory.
        model = AutoModelForCausalLM.from_pretrained(
            cfg['base_model'],
            torch_dtype=torch.bfloat16,
        )
        infer_config = dict(
            tensor_parallel={'tp_size': cfg["world_size"]},
            dtype=torch.bfloat16,
            replace_method="auto",
            replace_with_kernel_inject=cfg["replace_with_kernel_inject"],
        )
        ds_model = deepspeed.init_inference(model, config=infer_config)
        return ds_model.module

    def load_config(self):
        """Return the generation settings from the ``params`` section.

        Returns:
            dict with the keys ``max_new_tokens``, ``do_sample``,
            ``temperature``, ``top_p`` and ``system_prompt``.

        Raises:
            KeyError: if ``params`` or any of those keys is missing.
        """
        params = self._model_config()["params"]
        return {key: params[key] for key in self._PARAM_KEYS}

class lora_Model:
    """``transformers`` causal LM with a LoRA adapter attached, configured from YAML.

    The ``model_name`` section of the YAML file supplies ``base_model``,
    ``device``, ``lora_path`` and a ``params`` mapping of generation settings.
    """

    # Generation settings copied verbatim from the ``params`` section by load_config.
    _PARAM_KEYS = ("max_new_tokens", "do_sample", "temperature", "top_p", "system_prompt")

    def __init__(self, model_name, path_to_yml):
        # model_name: key of this model's section in the YAML file.
        self.model_name = model_name
        # path_to_yml: path to the YAML config; re-read on every load_* call.
        self.path_to_yml = path_to_yml

    def _model_config(self):
        """Read the YAML file and return the ``self.model_name`` section.

        Raises:
            KeyError: if the file has no section named ``self.model_name``.
        """
        # YAML is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(self.path_to_yml, "r", encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config[self.model_name]

    def load_tokenizer(self):
        """Return the tokenizer for ``base_model``, configured for left padding."""
        return AutoTokenizer.from_pretrained(self._model_config()['base_model'], padding_side="left")

    def load_model(self):
        """Load ``base_model`` in bfloat16 and activate the adapter at ``lora_path``.

        The adapter is registered under the name ``"adapter_1"`` and made active.
        """
        cfg = self._model_config()
        model = AutoModelForCausalLM.from_pretrained(
            cfg['base_model'],
            torch_dtype=torch.bfloat16,
            device_map=cfg['device'],
        )
        # load_adapter/set_adapter come from transformers' PEFT integration
        # (needs the `peft` package installed).
        model.load_adapter(cfg['lora_path'], adapter_name="adapter_1")
        model.set_adapter("adapter_1")
        return model

    def load_config(self):
        """Return the generation settings from the ``params`` section.

        Returns:
            dict with the keys ``max_new_tokens``, ``do_sample``,
            ``temperature``, ``top_p`` and ``system_prompt``.

        Raises:
            KeyError: if ``params`` or any of those keys is missing.
        """
        params = self._model_config()["params"]
        return {key: params[key] for key in self._PARAM_KEYS}

class DS2_Model:
    """DeepSpeed-MII pipeline configured from a YAML file.

    The ``model_name`` section of the YAML file supplies ``base_model`` and a
    ``params`` mapping of generation settings.
    """

    # Generation settings copied verbatim from the ``params`` section by load_config.
    _PARAM_KEYS = ("max_new_tokens", "do_sample", "temperature", "top_p", "system_prompt")

    def __init__(self, model_name, path_to_yml):
        # model_name: key of this model's section in the YAML file.
        self.model_name = model_name
        # path_to_yml: path to the YAML config; re-read on every load_* call.
        self.path_to_yml = path_to_yml

    def _model_config(self):
        """Read the YAML file and return the ``self.model_name`` section.

        Raises:
            KeyError: if the file has no section named ``self.model_name``.
        """
        # YAML is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(self.path_to_yml, "r", encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config[self.model_name]

    def load_tokenizer(self):
        """Return the tokenizer for ``base_model``.

        Unlike the transformers-based loaders, no ``padding_side`` is set here.
        """
        return AutoTokenizer.from_pretrained(self._model_config()['base_model'])

    def load_model(self):
        """Create and return an MII pipeline for ``base_model``."""
        return mii.pipeline(self._model_config()['base_model'])

    def load_config(self):
        """Return the generation settings from the ``params`` section.

        Returns:
            dict with the keys ``max_new_tokens``, ``do_sample``,
            ``temperature``, ``top_p`` and ``system_prompt``.

        Raises:
            KeyError: if ``params`` or any of those keys is missing.
        """
        params = self._model_config()["params"]
        return {key: params[key] for key in self._PARAM_KEYS}

class VLLM_Model:
    """vLLM ``LLM`` engine (optionally LoRA-enabled), configured from a YAML file.

    The ``model_name`` section of the YAML file supplies ``base_model``,
    ``world_size``, ``enable_lora`` and a ``params`` mapping of generation
    settings. It may also supply the optional keys ``lora_path``,
    ``gpu_memory_utilization`` (default 0.9) and ``max_lora_rank`` (default 64).
    """

    # Generation settings copied verbatim from the ``params`` section by load_config.
    _PARAM_KEYS = ("max_new_tokens", "do_sample", "temperature", "top_p", "system_prompt")

    def __init__(self, model_name, path_to_yml):
        # model_name: key of this model's section in the YAML file.
        self.model_name = model_name
        # path_to_yml: path to the YAML config; re-read on every load_* call.
        self.path_to_yml = path_to_yml

    def _model_config(self):
        """Read the YAML file and return the ``self.model_name`` section.

        Raises:
            KeyError: if the file has no section named ``self.model_name``.
        """
        # YAML is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(self.path_to_yml, "r", encoding="utf-8") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return config[self.model_name]

    def load_tokenizer(self):
        """Return the tokenizer for ``base_model``."""
        return AutoTokenizer.from_pretrained(self._model_config()['base_model'])

    def load_model(self):
        """Build a bfloat16 vLLM engine for ``base_model``.

        Tensor parallelism is set from ``world_size`` and LoRA support from
        ``enable_lora``.
        """
        cfg = self._model_config()
        return LLM(
            model=cfg['base_model'],
            dtype="bfloat16",
            tensor_parallel_size=cfg['world_size'],
            # Optional overrides; the defaults are the values previously hard-coded.
            gpu_memory_utilization=cfg.get('gpu_memory_utilization', 0.9),
            enable_lora=cfg['enable_lora'],
            max_lora_rank=cfg.get('max_lora_rank', 64),
            tokenizer=cfg['base_model'],
        )

    def load_config(self):
        """Return the generation settings plus the LoRA settings.

        Returns:
            dict with the keys ``max_new_tokens``, ``do_sample``,
            ``temperature``, ``top_p``, ``system_prompt``, ``lora_path``
            (``None`` when absent from the config), ``enable_lora`` and
            ``repetition_penalty``.

        Raises:
            KeyError: if ``enable_lora``, ``params`` or a required generation
                setting is missing.
        """
        cfg = self._model_config()
        sampling = cfg["params"]
        params = {key: sampling[key] for key in self._PARAM_KEYS}
        # lora_path only matters when enable_lora is set, so allow omitting it.
        params["lora_path"] = cfg.get("lora_path")
        params["enable_lora"] = cfg["enable_lora"]
        params["repetition_penalty"] = sampling["repetition_penalty"]
        return params
