import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch.nn as nn

from llm_logger import main_logger
import pickle


def prepare_model(model_name):
    """Load a causal language model and its tokenizer, placing the model on the GPU.

    ``model_name`` is first tried as a PEFT checkpoint through
    ``AutoPeftModelForCausalLM``. If that attempt raises, the same name is
    loaded as a plain ``AutoModelForCausalLM`` in float16 instead.

    Args:
        model_name: Identifier or path handed to the ``from_pretrained`` loaders.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to ``"cuda"``.
    """
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(model_name).to("cuda")
        model.print_trainable_parameters()
    except Exception as err:
        # Catch ordinary errors only: a bare ``except`` would also swallow
        # KeyboardInterrupt/SystemExit and start a second (slow) model load
        # when the user is trying to abort.
        main_logger.debug(
            f"Loading {model_name} as a PEFT model failed ({err!r}); "
            "falling back to AutoModelForCausalLM"
        )
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    main_logger.debug(f"Model {model_name} loaded")
    return model, tokenizer


def batch_prediction(model, tokens, batch_size, disable_tqdm=False, **kwargs):
    """Run ``model`` over ``tokens`` in batches, yielding outputs with shifted targets.

    For each batch of rows, the model receives every column except the last
    of ``input_ids`` and ``attention_mask``; the yielded targets are the same
    rows of ``input_ids`` with the first column dropped, so position ``t`` of
    the output lines up with token ``t + 1``.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keywords.
        tokens: Object exposing 2-D ``input_ids`` and ``attention_mask``.
        batch_size: Number of rows per batch.
        disable_tqdm: When true, the progress bar is disabled.
        **kwargs: Forwarded unchanged to every ``model`` call.

    Yields:
        ``(model_output, targets)`` with ``targets`` moved to ``"cuda"``.
    """
    n_rows = len(tokens.input_ids)
    for start in tqdm(range(0, n_rows, batch_size), disable=disable_tqdm):
        rows = slice(start, start + batch_size)
        ids = tokens.input_ids[rows, :-1].to("cuda")
        mask = tokens.attention_mask[rows, :-1].to("cuda")
        # Yield the expressions directly so this frame keeps no extra
        # reference to the model output while the consumer works on it.
        yield (
            model(input_ids=ids, attention_mask=mask, **kwargs),
            tokens.input_ids[rows, 1:].to("cuda"),
        )
        # Release the device copies before the next batch is transferred.
        del ids
        del mask


@torch.no_grad()
def get_losses(model, tokens, batch_size, disable_tqdm=False, ignore_index=0):
    """Compute the unreduced next-token cross-entropy loss for every position.

    Batches come from ``batch_prediction`` (called with ``use_cache=False``);
    each batch's logits are scored against its shifted targets and the
    per-token losses are gathered on the CPU.

    Args:
        model: Model whose output exposes ``.logits``.
        tokens: Tokenized inputs exposing ``input_ids`` and ``attention_mask``.
        batch_size: Number of rows per forward pass.
        disable_tqdm: When true, the progress bar is disabled.
        ignore_index: Target id that ``CrossEntropyLoss`` skips. Defaults to
            0, the value that used to be hard-coded; pass the tokenizer's
            actual padding id when it differs.

    Returns:
        A CPU tensor of per-token losses, shaped like the concatenated
        targets (one row per input row, one column per predicted position).

    Raises:
        RuntimeError: From ``torch.cat`` when ``tokens`` yields no batches.
    """
    loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
    token_loss = []
    for output, labels in batch_prediction(model, tokens, batch_size, disable_tqdm=disable_tqdm, use_cache=False):
        logits = output.logits
        # ``reshape`` (not ``view``) so non-contiguous logits do not raise.
        flat_loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        token_loss.append(flat_loss.reshape(labels.shape).cpu())
    return torch.cat(token_loss)
