from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from model_impl.llama import LlamaForCausalLM_SALS
from model_impl.mistral import MistralForCausalLM_SALS
from model_impl.llama_kivi import LlamaForCausalLM_KIVI
from model_impl.mistral_kivi import MistralForCausalLM_KIVI
import os
import torch

def get_model_path(model_name_or_path, config):
    """Return the local directory used for a compressed checkpoint.

    The directory name encodes the compression hyper-parameters so that
    different settings never overwrite each other::

        ./compressed_model/{name}_{k_high_rank}_{sparse_rank}_{nsamples}_{method}[_withsoftmax][_finetuned]

    Args:
        model_name_or_path: Hub id or local path of the base model; only its
            last path component is used as ``name``.
        config: Object exposing ``k_high_rank``, ``sparse_rank``, ``nsamples``
            and ``method``. The optional attributes ``withsoftmax`` and
            ``finetuned`` are treated as False when absent.

    Returns:
        The relative path string of the compressed checkpoint directory.
    """
    # Strip trailing separators so "org/model/" yields "model" instead of "".
    model_name = model_name_or_path.rstrip("/").split("/")[-1]
    finetuned = getattr(config, "finetuned", False)
    # The "_withsoftmax" tag is only encoded for fine-tuned checkpoints.
    withsoftmax = getattr(config, "withsoftmax", False) and finetuned
    suffix = ("_withsoftmax" if withsoftmax else "") + ("_finetuned" if finetuned else "")
    return (
        f"./compressed_model/{model_name}_{config.k_high_rank}_"
        f"{config.sparse_rank}_{config.nsamples}_{config.method}{suffix}"
    )

def load_model_and_tokenizer(model_name_or_path, config, compressed=False, preprocessing=False):
    """Load a causal LM (original, SALS or KIVI variant) and its tokenizer.

    Model class selection:
      * Path containing "mistral" (case-insensitive):
          - ``method == "kivi"``  -> ``MistralForCausalLM_KIVI``
          - ``preprocessing`` or ``compressed`` or ``method == "mean_sparse"``
            -> ``MistralForCausalLM_SALS``
          - otherwise -> ``AutoModelForCausalLM`` with flash_attention_2
        The tokenizer is loaded with ``LlamaTokenizer``.
      * Any other path:
          - ``method == "origin"`` -> ``AutoModelForCausalLM`` with flash_attention_2
          - ``method == "kivi"``   -> ``LlamaForCausalLM_KIVI``
          - otherwise -> ``LlamaForCausalLM_SALS``
        The tokenizer is loaded with ``AutoTokenizer``.

    All models are loaded in float16 with ``device_map="auto"`` and put in
    eval mode.

    Args:
        model_name_or_path: Hub id or local path of the base model.
        config: Model config; ``config.method`` selects the implementation.
            For KIVI models ``config.use_flash`` is set to True in place.
        compressed: If True and the method contains "sals"
            (case-insensitive), load from the path built by
            ``get_model_path`` instead.
        preprocessing: If True, never redirect to the compressed path.

    Returns:
        ``(model, tokenizer)``.
    """
    # Only SALS checkpoints live under a derived ./compressed_model path; every
    # other method loads from the path it was given.
    if "sals" in config.method.lower() and compressed and not preprocessing:
        model_name_or_path = get_model_path(model_name_or_path, config)
    print("loading model of path:", model_name_or_path)

    # Shared from_pretrained arguments. The dict holds a reference to
    # `config`, so the use_flash mutation below is seen by the loader.
    load_kwargs = dict(
        config=config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )

    if "mistral" in model_name_or_path.lower():
        # NOTE(review): unlike the Llama branch, a non-compressed, non-preprocessing
        # SALS method falls through to the stock model here — confirm this is intended.
        if config.method == "kivi":
            config.use_flash = True
            model = MistralForCausalLM_KIVI.from_pretrained(model_name_or_path, **load_kwargs)
        elif preprocessing or compressed or config.method == "mean_sparse":
            model = MistralForCausalLM_SALS.from_pretrained(model_name_or_path, **load_kwargs)
        else:
            print("loading using origin modeling")
            model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path, attn_implementation="flash_attention_2", **load_kwargs)
        tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    else:
        if config.method == "origin":
            model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path, attn_implementation="flash_attention_2", **load_kwargs)
        elif config.method == "kivi":
            config.use_flash = True
            model = LlamaForCausalLM_KIVI.from_pretrained(model_name_or_path, **load_kwargs)
        else:
            model = LlamaForCausalLM_SALS.from_pretrained(model_name_or_path, **load_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)

    model.eval()
    if tokenizer.pad_token is None:
        # Prefer the unk token (the original behaviour). Some tokenizers define
        # no unk token, in which case fall back to eos so padding still works.
        tokenizer.pad_token = (
            tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
        )
    return model, tokenizer

def save_model_and_tokenizer(model, tokenizer, model_name_or_path, config):
    """Persist a compressed model and its tokenizer under ./compressed_model.

    The target directory is derived from ``model_name_or_path`` and the
    compression settings in ``config`` via ``get_model_path``. Both the model
    and the tokenizer are written there with ``save_pretrained``.
    """
    # Make sure the shared parent directory exists before resolving the target.
    os.makedirs("./compressed_model", exist_ok=True)
    save_dir = get_model_path(model_name_or_path, config)
    print("saving model of path:", save_dir)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_dir)