import warnings

import torch
from peft import PeftModel
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForPreTraining

from models import VariationalIRT, IRTGenerator


def load_gpt_model(peft_config, ckpt_path, use_fp16=True, model_name='gpt2-large', num_param_tokens=5, beta=0.1, device=0):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.add_tokens(['[UNK]'])

    model = AutoModelForPreTraining.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    model = IRTGenerator(model, num_param_tokens, beta)
    if use_fp16:
        model.half()
    model = PeftModel(model, peft_config)

    state_dict = {}
    with safe_open(ckpt_path, framework="pt", device=device) as f:
        for k in f.keys():
            if 'wte' in k or 'wpe' in k or 'lm_head' in k:
                mod_k = k[:-7] + '.modules_to_save.default.weight'
            elif 'param_embed' in k or 'transform.' in k:
                segs = k.split('.')
                mod_k = '.'.join(segs[:-2]) + '.modules_to_save.default.' + '.'.join(segs[-2:])
            else:
                mod_k = k[:-7] + '.default.weight'
            state_dict[mod_k] = f.get_tensor(k)
    model.load_state_dict(state_dict, strict=False)
    model.eval().cuda()
    return model, tokenizer


def load_llama_model(peft_config, ckpt_path, use_fp16=True, model_name='llama_path', num_param_tokens=5, beta=0.1, device=0):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.add_tokens(['[UNK]'])

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    model = IRTGenerator(model, num_param_tokens, beta)
    if use_fp16:
        model.half()
    model = PeftModel(model, peft_config)

    state_dict = {}
    with safe_open(ckpt_path, framework="pt", device=device) as f:
        for k in f.keys():
            if 'embed_tokens' in k or 'lm_head' in k:
                mod_k = k[:-7] + '.modules_to_save.default.weight'
            elif 'param_embed' in k or 'transform.' in k:
                segs = k.split('.')
                mod_k = '.'.join(segs[:-2]) + '.modules_to_save.default.' + '.'.join(segs[-2:])
            else:
                mod_k = k[:-7] + '.default.weight'
            state_dict[mod_k] = f.get_tensor(k)
    model.load_state_dict(state_dict, strict=False)
    model.eval().cuda()
    return model, tokenizer


def load_virt_model(ckpt_path, device=0):
    """Build a ``VariationalIRT`` on ``device`` and restore its weights from ``ckpt_path``.

    Args:
        ckpt_path: Path to a ``torch.save``-d state dict for ``VariationalIRT``.
        device: Device the model is placed on (an int is a CUDA ordinal).

    Returns:
        The ``VariationalIRT`` with weights loaded. This function does not
        switch it to eval mode.
    """
    model = VariationalIRT().to(device=device)
    # Deserialize onto the CPU. Without ``map_location``, tensors are restored
    # to the device they were saved from (e.g. cuda:0), which fails on hosts
    # without that GPU. ``load_state_dict`` then copies them into the model's
    # parameters on ``device``.
    # NOTE: ``torch.load`` unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model
