import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
import os
from loguru import logger
import json
from peft import get_peft_model, get_peft_config, PeftModel
from trl import AutoModelForCausalLMWithValueHead
from auto_gptq import AutoGPTQForCausalLM
from transformers import LlamaTokenizer

# Local directory under which model names starting with "llama" are resolved.
LLAMA_PATH = "/home/ubuntu/llama/weights/converted"
# Local directory under which model names starting with "vicuna" are resolved.
VICUNA_PATH = "/home/ubuntu/"

def load_tokenizer(dir_or_model):
    """Load the tokenizer for a model name, model directory, or LoRA adapter directory.

    If ``dir_or_model`` is a directory containing ``adapter_config.json``, the
    tokenizer of the adapter's ``base_model_name_or_path`` is loaded instead of
    one from the adapter directory itself. Names starting with ``llama`` or
    ``vicuna`` are resolved to local paths under ``LLAMA_PATH`` /
    ``VICUNA_PATH``.

    Args:
        dir_or_model: Model identifier or path to a model / adapter directory.

    Returns:
        The loaded tokenizer. When the tokenizer defines no pad token, the pad
        token (and its id) are set to the EOS token (and its id).
    """
    logger.debug(f"Loading tokenizer for {dir_or_model}")

    adapter_config_path = os.path.join(dir_or_model, "adapter_config.json")

    if os.path.isfile(adapter_config_path):
        # LoRA adapter directory: the tokenizer belongs to the base model.
        # Use a context manager so the config file handle is always closed.
        with open(adapter_config_path, "r", encoding="utf-8") as config_file:
            model_name = json.load(config_file)["base_model_name_or_path"]
    else:
        model_name = dir_or_model

    # Short local aliases are mapped onto their on-disk locations.
    if model_name.startswith("llama"):
        model_name = os.path.join(LLAMA_PATH, model_name)
    elif model_name.startswith("vicuna"):
        model_name = os.path.join(VICUNA_PATH, model_name)

    if model_name.startswith("decapoda-research"):
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    llama_prefixes = (
        "llama",
        os.path.join(LLAMA_PATH, "llama"),
        "decapoda-research",
    )
    if model_name.startswith(llama_prefixes):
        # The EOS token is not picked up when loading these checkpoints, so it
        # is set by hand.
        # NOTE(review): LLaMA's EOS piece is conventionally "</s>"; the literal
        # below may be a typo - confirm before changing it.
        tokenizer.eos_token = "</>"
        tokenizer.eos_token_id = 2

    if tokenizer.pad_token is None:
        logger.debug("Setting pad token to eos token")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    return tokenizer

def load_model(
    dir_or_model,
    classification=False,
    token_classification=False,
    return_tokenizer=False,
    dtype=torch.bfloat16,
    load_dtype=True,
    rl=False,
    peft_config=None,
):
    """Load a model from a model name, model directory, or LoRA adapter directory.

    If ``dir_or_model`` is a directory containing ``adapter_config.json``, the
    base model named by ``base_model_name_or_path`` is loaded first and the
    adapter is then applied on top with ``PeftModel.from_pretrained``. Names
    starting with ``llama`` or ``vicuna`` are resolved to local paths under
    ``LLAMA_PATH`` / ``VICUNA_PATH``.

    The mode flags are checked in this order, and the first one that is true
    wins: ``classification``, ``token_classification``, ``rl``. With none set,
    a causal LM is loaded (through AutoGPTQ when the resolved name ends with
    ``"GPTQ"``, in which case ``dtype`` is not passed to the loader).

    Args:
        dir_or_model: Model identifier or path to a model / adapter directory.
        classification: Load a sequence-classification model.
        token_classification: Load a token-classification model.
        return_tokenizer: Also return the tokenizer for the model.
        dtype: Torch dtype passed as ``torch_dtype`` to ``from_pretrained``.
        load_dtype: When false, ``dtype`` is overridden with ``torch.float32``.
        rl: Load a causal LM with a value head (TRL).
        peft_config: Forwarded to the value-head loader; only used when ``rl``.

    Returns:
        The model, or ``(model, tokenizer)`` when ``return_tokenizer`` is true.

    Raises:
        Exception: Whatever ``load_tokenizer`` raises, but only when
            ``return_tokenizer`` is true; otherwise tokenizer problems are
            logged and the model is returned anyway.
    """
    logger.debug(f"Loading model for {dir_or_model} with {classification}, {dtype}, {load_dtype}")

    adapter_config_path = os.path.join(dir_or_model, "adapter_config.json")
    is_lora_dir = os.path.isfile(adapter_config_path)

    if not load_dtype:
        dtype = torch.float32

    if is_lora_dir:
        # Use a context manager so the config file handle is always closed.
        with open(adapter_config_path, "r", encoding="utf-8") as config_file:
            model_name = json.load(config_file)["base_model_name_or_path"]
    else:
        model_name = dir_or_model

    # Keep the unresolved name: load_tokenizer performs its own path resolution.
    original_model_name = model_name

    if model_name.startswith("llama"):
        model_name = os.path.join(LLAMA_PATH, model_name)
    elif model_name.startswith("vicuna"):
        model_name = os.path.join(VICUNA_PATH, model_name)

    # Keyword arguments shared by every transformers / TRL loader below.
    hf_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": dtype,
        "use_auth_token": True,
        "device_map": "auto",
    }

    if classification:
        # to investigate: calling torch_dtype here fails.
        model = AutoModelForSequenceClassification.from_pretrained(model_name, **hf_kwargs)
    elif token_classification:
        model = AutoModelForTokenClassification.from_pretrained(model_name, **hf_kwargs)
    elif rl:
        model = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_name, peft_config=peft_config, **hf_kwargs
        )
    elif model_name.endswith("GPTQ"):
        model = AutoGPTQForCausalLM.from_quantized(
            model_name,
            use_safetensors=True,
            trust_remote_code=True,
            use_triton=True,
            quantize_config=None,
            device_map="auto",
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs)

    if is_lora_dir:
        model = PeftModel.from_pretrained(model, dir_or_model)

    # Best effort: align the model's pad token id with the tokenizer's. A
    # failure here must not prevent returning the model, but it is logged
    # rather than silently discarded.
    tokenizer = None
    try:
        tokenizer = load_tokenizer(original_model_name)
        model.config.pad_token_id = tokenizer.pad_token_id
    except Exception as exc:
        logger.warning(f"Could not set pad_token_id for {original_model_name}: {exc!r}")

    if return_tokenizer:
        if tokenizer is None:
            # The first load failed; retry so the caller sees the real error.
            tokenizer = load_tokenizer(original_model_name)
        return model, tokenizer
    return model
