import torch

from ppl_eval import ppl_metric
from transformers.models.mistral import modeling_mistral
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def get_model(model_dir):
    """Load a causal LM and its slow tokenizer from a local directory.

    The model is loaded in ``torch.bfloat16`` with ``local_files_only=True``.
    Missing special tokens on the tokenizer are filled with defaults:
    BOS -> "<s>", EOS -> "</s>", and PAD/UNK fall back to the EOS token.

    Args:
        model_dir: Path to a local Hugging Face model directory.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Random weight init is wasted work here because the checkpoint overwrites
    # every weight. Stub the init functions out only while the model is built.
    # Restore them afterwards (even on failure) so later code that creates
    # fresh layers still gets real initialisation.
    init_names = ("kaiming_uniform_", "uniform_", "normal_")
    saved_inits = {name: getattr(torch.nn.init, name) for name in init_names}

    def skip(*args, **kwargs):
        # Mirror the real API, which returns the (in-place) tensor argument.
        return args[0] if args else kwargs.get("tensor")

    for name in init_names:
        setattr(torch.nn.init, name, skip)
    try:
        # NOTE: trust_remote_code=True executes Python code shipped inside
        # model_dir; only point this at directories you trust.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            config=config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            local_files_only=True,
        )
    finally:
        for name, original in saved_inits.items():
            setattr(torch.nn.init, name, original)

    tokenizer = AutoTokenizer.from_pretrained(
        model_dir,
        trust_remote_code=True,
        local_files_only=True,
        use_fast=False,
    )
    # Fill in special tokens the checkpoint's tokenizer may not define.
    if tokenizer.bos_token is None:
        tokenizer.bos_token = "<s>"
    if tokenizer.eos_token is None:
        tokenizer.eos_token = "</s>"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.unk_token is None:
        tokenizer.unk_token = tokenizer.eos_token
    return model, tokenizer


if __name__ == "__main__":
    # Evaluation context length. The same value is stored on the model
    # (model.seqlen) and passed to ppl_metric; keeping a single variable
    # prevents the two from drifting apart. Presumably ppl_metric's 4th
    # positional argument is the sequence length — TODO confirm.
    seq_len = 128

    print('load model...')
    model, tokenizer = get_model("../mistral-7b-local")
    model.seqlen = seq_len
    model.eval()

    print('start evaluate...')
    # get_model loads weights as bfloat16; they are cast to float16 here
    # before evaluation.
    model = model.cuda().half()
    # NOTE(review): the trailing 8 is presumably the eval batch size —
    # confirm against ppl_metric's signature.
    ppl_metric(model, tokenizer, ['wikitext2'], seq_len, 8)

    # layer_num / skipped_layer_count are not part of stock transformers;
    # presumably they are counters added by a locally patched modeling_mistral
    # and updated during the forward passes above — confirm.
    print(f"[LOG] layer_num: {modeling_mistral.layer_num}")
    print(f"[LOG] Skipped layers: {modeling_mistral.skipped_layer_count}")
    if modeling_mistral.layer_num > 0:
        print(f"[LOG] skipped ratio: {modeling_mistral.skipped_layer_count/modeling_mistral.layer_num}")