import openai # For GPT-3 API ...
import os,requests
import time
import datetime
import torch
from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM,AutoTokenizer, AutoConfig
from .model_name import llms_path 
from .statlib.stat_mask2 import STAT
from .statlib.utils.read_head_config import get_head


class decoder_for_openai():
    """Decoder that sends prompts to the OpenAI Completion endpoint."""

    def __init__(self, args):
        """Read the API key from the environment and resolve the engine name.

        Args:
            args: Run configuration. ``args.model`` is used as a key into
                ``llms_path`` to obtain the engine name; the object is also
                kept on ``self.args`` so ``decode`` can be called with the
                prompt alone.
        """
        openai.api_key = os.getenv("OPENAI_API_KEY")
        self.args = args
        self.engine = llms_path[args.model]

    def decode(self, args, input=None):
        """Return the text of the first completion for a prompt.

        Two call styles are accepted:

        * ``decode(args, prompt)`` - the original two-argument form.
        * ``decode(prompt)`` - same shape as the other decoders in this
          module; the configuration given to ``__init__`` is used.

        Args:
            args: Run configuration providing ``api_time_interval``,
                ``max_length`` and ``temperature`` - or the prompt itself
                when ``input`` is omitted.
            input: Prompt text.

        Returns:
            The ``text`` field of the first choice in the API response.
        """
        if input is None:
            # Single-argument call: the only positional value is the prompt.
            args, input = self.args, args

        # Pause before every request; the interval comes from the run config.
        time.sleep(args.api_time_interval)
        response = openai.Completion.create(
            engine=self.engine,
            prompt=input,
            max_tokens=args.max_length,
            temperature=args.temperature,
            stop=None,
        )

        return response["choices"][0]["text"]

class decoder_for_huggingface():
    """Decoder backed by a local Hugging Face causal LM, with optional STAT steering.

    ``decode`` runs plain sampling; ``decode_with_stat`` runs the same
    generation inside the context manager returned by ``STAT.apply_steering``.
    Both return ``""`` when generation fails (the error is printed).
    """

    def __init__(self, args):
        """Load config, model, tokenizer and build the STAT helper.

        Args:
            args: Run configuration. Reads ``model`` (key into ``llms_path``),
                ``stat_head_config``, ``stat_mode``, ``scale_static`` and
                ``scale_dynamic`` here; ``query_type`` and ``max_tokens`` are
                read later by ``make_query`` / the decode methods.
        """
        self.args = args
        model_name = llms_path[self.args.model]
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.config = AutoConfig.from_pretrained(model_name)
        # NOTE(review): device_map="auto" already places the weights; the extra
        # .to(device) is kept as in the original - confirm it is still wanted.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype="auto",
            attn_implementation="eager",
            return_dict_in_generate=True,
            output_attentions=True).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Stop on the tokenizer's EOS and, when the vocabulary really has it,
        # on "<|eot_id|>". A tokenizer without that token may answer None or
        # its unk id; neither should become a stop token.
        eot_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        self.terminators = []
        if self.tokenizer.eos_token_id is not None:
            self.terminators.append(self.tokenizer.eos_token_id)
        if eot_id is not None and eot_id != unk_id:
            self.terminators.append(eot_id)

        self.head_config = get_head(args.stat_head_config)

        self.stat = STAT(
            model=self.model,
            tokenizer=self.tokenizer,
            stat_mode=self.args.stat_mode,
            head_config=self.head_config,
            scale_static=self.args.scale_static,
            scale_dynamic=self.args.scale_dynamic,
        )

    def make_query(self, user_input, system_prompt, examples):
        """Render the prompt string for the template named by ``args.query_type``.

        Args:
            user_input: Text of the final user turn.
            system_prompt: System text placed at the start of the prompt.
            examples: Iterable of dicts with ``'user'`` and ``'assistant'``
                keys, rendered as earlier turns (an empty iterable adds none).

        Returns:
            The assembled prompt string.

        Raises:
            NotImplementedError: If ``query_type`` is not a known template.
        """
        query_type = self.args.query_type

        if query_type == 'concat':
            shots = "".join(
                f"{example['user']} {example['assistant']}\n" for example in examples
            )
            return system_prompt + shots + user_input

        if query_type in ('chat', 'chat-add'):
            def chat_turn(role, message):
                return f"<|im_start|>{role}\n{message}<|im_end|>\n"

            pieces = [chat_turn('system', system_prompt)]
            for example in examples:
                pieces.append(chat_turn('user', example['user']))
                pieces.append(chat_turn('assistant', example['assistant']))
            pieces.append(chat_turn('user', user_input))
            if query_type == 'chat-add':
                # 'chat-add' additionally opens the assistant turn.
                pieces.append("<|im_start|>assistant\n")
            return "".join(pieces)

        # The historical 'lamma*' spellings stay valid; the correctly spelled
        # names are accepted as aliases.
        if query_type in ('lamma2-chat', 'llama2-chat', 'codellama-instruct'):
            # ref: llama2-chat: https://gpus.llm-utils.org/llama-2-prompt-template/
            # ref: codellama-instruct: https://github.com/facebookresearch/codellama/blob/main/llama/generation.py#L319-L361
            pieces = [f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]  # tokenizer adds the first <s>
            for example in examples:
                pieces.append(f"{example['user']} [/INST] {example['assistant']}</s><s>[INST] ")
            pieces.append(f"{user_input} [/INST]")
            return "".join(pieces)

        if query_type in ('lamma3-instruct', 'llama3-instruct'):
            # ref: https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202
            # ref: https://huggingface.co/blog/llama3#how-to-prompt-llama-3
            # NOTE(review): the tokenizer may prepend its own <|begin_of_text|>,
            # which would duplicate the one written here - confirm.
            pieces = [
                f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            ]
            for example in examples:
                pieces.append(
                    f"{example['user']} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n {example['assistant']} <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
                )
            pieces.append(f"{user_input} <|eot_id|>")
            pieces.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
            return "".join(pieces)

        if query_type in ('mistral-instruct',):
            # ref: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
            pieces = [f"[INST] {system_prompt} "]
            for example in examples:
                pieces.append(f"{example['user']} [/INST] {example['assistant']}</s> [INST] ")
            pieces.append(f"{user_input} [/INST]")
            return "".join(pieces)

        raise NotImplementedError

    def _decode_new_tokens(self, outputs, prompt_len):
        """Decode only the tokens generated after the first ``prompt_len`` positions."""
        # The model is loaded with return_dict_in_generate=True, so generate()
        # may hand back an output object instead of a bare tensor; unwrap it.
        sequences = getattr(outputs, "sequences", outputs)
        return self.tokenizer.decode(sequences[0, prompt_len:], skip_special_tokens=True)

    def decode(self, input):
        """Sample a completion for ``input`` without steering.

        Args:
            input: User text; wrapped by ``make_query`` with an empty system
                prompt and no examples.

        Returns:
            The decoded continuation, or ``""`` if anything fails.
        """
        try:
            query = self.make_query(input, "", "")
            encoded = self.tokenizer(query, return_tensors="pt").to(self.model.device)
            start_time = time.time()
            outputs = self.model.generate(
                encoded.input_ids,
                do_sample=True,
                use_cache=True,
                attention_mask=encoded.attention_mask,
                eos_token_id=self.terminators,
                temperature=0.1,
                max_new_tokens=self.args.max_tokens)
            end_time = time.time()
            print("time: {}".format(end_time-start_time))
            return self._decode_new_tokens(outputs, encoded.input_ids.shape[-1])
        except Exception as err:
            # Best effort: callers get "" on failure, but the cause is reported
            # and KeyboardInterrupt/SystemExit are no longer swallowed.
            print("decode failed: {!r}".format(err))
            return ""

    def decode_with_stat(self, input, parse_groups):
        """Sample a completion for ``input`` while STAT steering is applied.

        Args:
            input: User text; wrapped by ``make_query`` with an empty system
                prompt and no examples.
            parse_groups: Passed to ``STAT.apply_steering`` as
                ``substring_groups``.

        Returns:
            The decoded continuation, or ``""`` if anything fails.
        """
        inputs = None
        try:
            query = self.make_query(input, "", "")
            inputs, offset_mapping = self.stat.inputs_from_batch(query)
            with self.stat.apply_steering(
                model=self.model,
                strings=query,
                substring_groups=parse_groups,
                model_input=inputs,
                offsets_mapping=offset_mapping
            ) as steered_model:
                start_time = time.time()
                outputs = steered_model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=self.args.max_tokens,
                    do_sample=True,
                    use_cache=True,
                    temperature=0.01,
                    # repetition_penalty = 1.0,
                    eos_token_id=self.terminators,
                )
                end_time = time.time()
                print("time: {}".format(end_time-start_time))
            return self._decode_new_tokens(outputs, inputs.input_ids.shape[-1])
        except Exception as err:
            # Same best-effort contract as decode(): report, then return "".
            print("decode_with_stat failed: {!r}".format(err))
            return ""
        finally:
            # Drop the input tensors and release cached GPU memory on both the
            # success and the failure path.
            inputs = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

class decoder_for_vllm():
    """Decoder that generates completions with a vLLM engine."""

    def __init__(self, args):
        """Build the sampling parameters and load the vLLM engine.

        Args:
            args: Run configuration. Reads ``model`` (key into ``llms_path``),
                ``temperature``, ``top_p`` and ``max_length`` here;
                ``query_type`` is read later by ``make_query``.
        """
        self.args = args
        model_path = llms_path[self.args.model]
        self.sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_length)
        self.llm = LLM(model=model_path)

    def make_query(self, user_input, system_prompt, examples):
        """Render the prompt string for the template named by ``args.query_type``.

        Args:
            user_input: Text of the final user turn.
            system_prompt: System text placed at the start of the prompt.
            examples: Iterable of dicts with ``'user'`` and ``'assistant'``
                keys, rendered as earlier turns (an empty iterable adds none).

        Returns:
            The assembled prompt string.

        Raises:
            NotImplementedError: If ``query_type`` is not a known template.
        """
        query_type = self.args.query_type

        if query_type == 'concat':
            shots = "".join(
                f"{example['user']} {example['assistant']}\n" for example in examples
            )
            return system_prompt + shots + user_input

        if query_type in ('chat', 'chat-add'):
            def chat_turn(role, message):
                return f"<|im_start|>{role}\n{message}<|im_end|>\n"

            pieces = [chat_turn('system', system_prompt)]
            for example in examples:
                pieces.append(chat_turn('user', example['user']))
                pieces.append(chat_turn('assistant', example['assistant']))
            pieces.append(chat_turn('user', user_input))
            if query_type == 'chat-add':
                # 'chat-add' additionally opens the assistant turn.
                pieces.append("<|im_start|>assistant\n")
            return "".join(pieces)

        # The historical 'lamma*' spellings stay valid; the correctly spelled
        # names are accepted as aliases.
        if query_type in ('lamma2-chat', 'llama2-chat', 'codellama-instruct'):
            # ref: llama2-chat: https://gpus.llm-utils.org/llama-2-prompt-template/
            # ref: codellama-instruct: https://github.com/facebookresearch/codellama/blob/main/llama/generation.py#L319-L361
            pieces = [f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]  # tokenizer adds the first <s>
            for example in examples:
                pieces.append(f"{example['user']} [/INST] {example['assistant']}</s><s>[INST] ")
            pieces.append(f"{user_input} [/INST]")
            return "".join(pieces)

        if query_type in ('lamma3-instruct', 'llama3-instruct'):
            # ref: https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202
            # ref: https://huggingface.co/blog/llama3#how-to-prompt-llama-3
            # NOTE(review): the engine's tokenizer may prepend its own
            # <|begin_of_text|>, duplicating the one written here - confirm.
            pieces = [
                f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
            ]
            for example in examples:
                pieces.append(
                    f"{example['user']} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n {example['assistant']} <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
                )
            pieces.append(f"{user_input} <|eot_id|>")
            pieces.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
            return "".join(pieces)

        if query_type in ('mistral-instruct',):
            # ref: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
            pieces = [f"[INST] {system_prompt} "]
            for example in examples:
                pieces.append(f"{example['user']} [/INST] {example['assistant']}</s> [INST] ")
            pieces.append(f"{user_input} [/INST]")
            return "".join(pieces)

        raise NotImplementedError

    def decode(self, input):
        """Generate a completion for ``input``.

        Args:
            input: User text; wrapped by ``make_query`` with an empty system
                prompt and no examples.

        Returns:
            The text of the first output of the first request result.
        """
        query = self.make_query(input, "", "")
        output = self.llm.generate(query, self.sampling_params)
        # output[0].prompt holds the prompt echoed back; only the text is used.
        return output[0].outputs[0].text


def create_generate_decoder(args):
    """Instantiate the decoder selected by ``args.call_mode``.

    ``"vllm"`` selects the vLLM decoder and ``"gpt"`` the OpenAI decoder;
    every other value falls back to the Hugging Face decoder. ``args`` is
    forwarded unchanged to the chosen constructor.
    """
    mode = args.call_mode
    if mode == "gpt":
        return decoder_for_openai(args)
    if mode == "vllm":
        return decoder_for_vllm(args)
    # Default back end for any unrecognised call mode.
    return decoder_for_huggingface(args)

