import transformers
import torch
from openai import OpenAI


def load_LLM(args):
    """Build the text-generation handler selected by ``args.llm_model``.

    Args:
        args: Namespace-like object. ``llm_model`` is always read. Llama
            models additionally read ``max_new_tokens`` and ``temperature``;
            GPT models additionally read ``openai_api_key``.

    Returns:
        A ``LlamaHandler`` when ``llm_model`` contains ``"Llama"`` (loaded from
        ``meta-llama/Meta-<llm_model>-Instruct``), otherwise a ``GPTHandler``
        when it contains ``"gpt"``. Both checks are case-sensitive and the
        Llama check wins if both substrings appear.

    Raises:
        ValueError: if ``llm_model`` matches neither model family.
    """
    model_name = args.llm_model
    if "Llama" in model_name:
        print("You are using Llama model.")
        model_id = f"meta-llama/Meta-{model_name}-Instruct"
        print(f"Using LLM model: {model_id}")
        print(f"Max new tokens: {args.max_new_tokens}")
        print(f"Temperature: {args.temperature}")
        return LlamaHandler(args.max_new_tokens, args.temperature, model_id)
    if "gpt" in model_name:
        print("You are using GPT model.")
        print(f"Using GPT model: {model_name}")
        return GPTHandler(args.openai_api_key, model_name)
    # Without this, an unrecognised name fell through to an unbound local and
    # surfaced as an opaque UnboundLocalError.
    raise ValueError(
        f"Unsupported llm_model {model_name!r}: expected a name containing 'Llama' or 'gpt'."
    )

class LLMHandlerBase:
    """Shared base for the LLM handlers; records which model family is in use.

    After construction ``LLM_type`` is ``"Llama"`` if ``model_id`` contains
    ``"Llama"``, else ``"GPT"`` if it contains ``"gpt"``, else ``None``.
    Matching is case-sensitive, and the Llama check wins when both appear.
    """

    # (substring searched for in model_id, family label), tried in order.
    _FAMILY_MARKERS = (("Llama", "Llama"), ("gpt", "GPT"))

    def __init__(self, model_id):
        self.LLM_type = next(
            (label for marker, label in self._FAMILY_MARKERS if marker in model_id),
            None,
        )

class LlamaHandler(LLMHandlerBase):
    """Generate chat replies with a local Llama model via a transformers pipeline."""

    def __init__(self, max_new_tokens, temperature, model_id="meta-llama/Meta-Llama-3.1-8B-Instruct"):
        """Load ``model_id`` into a ``text-generation`` pipeline.

        Args:
            max_new_tokens: cap on the number of generated tokens per call.
            temperature: sampling temperature used by ``gen_text``. A value
                ``<= 0`` selects greedy decoding instead of sampling.
            model_id: model id/path passed to ``transformers.pipeline``.
        """
        super().__init__(model_id)
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def _terminators(self):
        """Return the token ids that should stop generation (possibly empty)."""
        tokenizer = self.pipeline.tokenizer
        # "<|eot_id|>" ends a turn in the Llama 3 chat template. On a tokenizer
        # without that token, convert_tokens_to_ids can return None or the unk
        # id; neither is a meaningful stop token, so drop it.
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        candidates = [tokenizer.eos_token_id]
        if eot_id != tokenizer.unk_token_id:
            candidates.append(eot_id)
        return [token_id for token_id in candidates if token_id is not None]

    def gen_text(self, llm_prompt):
        """Return the model's reply to ``llm_prompt``, without the prompt text.

        The prompt is wrapped in a system + user chat exchange and rendered
        with the tokenizer's chat template before generation.
        """
        messages = [
            {"role": "system", "content": "You are helpful AI"},
            {"role": "user", "content": llm_prompt},
        ]
        prompt = self.pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        generate_kwargs = {"max_new_tokens": self.max_new_tokens}
        terminators = self._terminators()
        if terminators:
            generate_kwargs["eos_token_id"] = terminators

        temperature = self.temperature
        if temperature is not None and temperature <= 0:
            # transformers rejects a non-positive temperature when sampling;
            # treat it as a request for deterministic (greedy) decoding.
            generate_kwargs["do_sample"] = False
        else:
            # The transformers temperature warper requires a float, so an int
            # such as 2 would otherwise raise.
            if temperature is not None:
                temperature = float(temperature)
            generate_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)

        # return_full_text=False makes the pipeline strip the prompt itself,
        # replacing manual slicing by len(prompt).
        outputs = self.pipeline(prompt, return_full_text=False, **generate_kwargs)
        return outputs[0]["generated_text"]

class GPTHandler(LLMHandlerBase):
    """Generate chat replies through the OpenAI Chat Completions API."""

    def __init__(self, api_key, model_id="gpt-4o"):
        """Create the OpenAI client.

        Args:
            api_key: key passed to ``OpenAI(api_key=...)``.
            model_id: chat model name sent with every request.
        """
        super().__init__(model_id)
        self.client = OpenAI(api_key=api_key)
        self.model_id = model_id

    def gen_text(self, llm_prompt):
        """Send ``llm_prompt`` as a single user turn and return the reply text.

        Returns:
            The first choice's message content, or ``""`` when the response
            carries no text content.
        """
        completion = self.client.chat.completions.create(
            model=self.model_id,
            messages=[
                # System message sets the assistant's context for this request.
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": llm_prompt},
            ],
        )
        content = completion.choices[0].message.content
        # ``content`` is optional in the response schema (e.g. refusals or tool
        # calls); normalise to "" so callers always receive a str.
        return content if content is not None else ""
