import torch
import numpy as np
from tqdm import tqdm

from Pruner.datasets.ppl_dataset import get_loaders

def PPLMetric(model, tokenizer, datasets, mask_module=None, seq_len=2048, batch_size=4, device="cuda", sparse=False, lm_head=None):
    """Evaluate perplexity of ``model`` on each named dataset.

    For every entry in ``datasets`` the test loader is obtained from
    ``get_loaders`` and handed to one of the two evaluators in this module:
    ``llama_eval_sparse`` when ``sparse`` is truthy, otherwise
    ``llama_eval_instantation`` (which additionally receives ``mask_module``).

    Args:
        model: Model forwarded unchanged to the evaluator.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        datasets: Iterable of dataset identifiers; each one becomes a key of
            the returned dict.
        mask_module: Forwarded to ``llama_eval_instantation`` only.
        seq_len: Forwarded to ``get_loaders``.
        batch_size: Forwarded to ``get_loaders``.
        device: Device the evaluator moves each batch to.
        sparse: Selects the sparse evaluator when truthy.
        lm_head: Optional head forwarded to the evaluator. Whether it is
            provided is also passed as the first argument of ``get_loaders``.

    Returns:
        dict: Maps each dataset identifier to its perplexity.
    """
    has_lm_head = lm_head is not None
    results = {}
    for name in datasets:
        # Only the second loader returned by get_loaders is used here.
        _unused_loader, eval_loader = get_loaders(
            has_lm_head,
            name,
            tokenizer,
            seq_len=seq_len,
            batch_size=batch_size,
        )
        results[name] = (
            llama_eval_sparse(model, eval_loader, device, lm_head)
            if sparse
            else llama_eval_instantation(model, mask_module, eval_loader, device, lm_head)
        )
        # Report the running results after every dataset.
        print(results)
    return results


@torch.no_grad()
def llama_eval_instantation(model, mask_module, test_lodaer, device, lm_head=None):
    """Compute the perplexity of ``model`` over ``test_lodaer`` with ``mask_module``.

    Each batch is moved to ``device`` and passed to ``model.forward`` together
    with ``mask_module`` and ``ppl_during_train=True``. The ``'logits'`` entry
    of the output at position ``t`` is scored against the input token at
    position ``t + 1`` with an unreduced cross-entropy loss.

    Only one scalar loss sum per batch is retained, so memory use does not
    grow with the number of tokens in the dataset, and the final average is
    accumulated in float64 so it is not rounded to a half-precision loss dtype.

    Args:
        model: Object whose ``forward`` returns a mapping with a ``'logits'``
            entry.
        mask_module: Passed through to ``model.forward`` unchanged.
        test_lodaer: Iterable of token-id tensors. Batches are sliced as
            ``[:, 1:]``, so they need at least two dimensions (presumably
            ``(batch, seq_len)`` -- confirm against ``get_loaders``).
        device: Device each batch is moved to before the forward pass.
        lm_head: Optional callable. When given, the model is called with
            ``input_ids=`` (instead of ``tokens=``/``start_pos=``) and the
            ``'logits'`` entry is passed through ``lm_head`` before scoring
            (presumably it then holds hidden states -- confirm with the model).

    Returns:
        float: ``exp`` of the mean per-token cross-entropy over all scored
        positions.

    Raises:
        ValueError: If the loader yields no tokens to score.
    """
    # The criterion is stateless, so build it once instead of once per batch.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    batch_nll_sums = []
    token_count = 0
    for batch in tqdm(test_lodaer):
        batch = batch.to(device)

        if lm_head is None:
            output = model.forward(tokens=batch, start_pos=0, mask_module=mask_module, ppl_during_train=True)
            lm_logits = output['logits']
        else:
            output = model.forward(input_ids=batch, mask_module=mask_module, ppl_during_train=True)
            lm_logits = lm_head(output['logits'])

        # Next-token prediction: drop the last position's prediction and the
        # first token's label so predictions and targets line up.
        shift_logits = lm_logits[:, :-1, :]
        shift_labels = batch[:, 1:]

        token_nll = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )

        # Sum in at least float32 so half-precision losses are not rounded;
        # wider dtypes (e.g. float64) are kept as they are.
        sum_dtype = torch.promote_types(token_nll.dtype, torch.float32)
        batch_nll_sums.append(token_nll.sum(dtype=sum_dtype))
        token_count += token_nll.numel()

    if token_count == 0:
        raise ValueError("cannot compute perplexity: the test loader yielded no tokens to score")

    # Combine the per-batch sums on the CPU in float64 (not every device
    # backend supports float64), with a single host transfer at the end.
    total_nll = torch.stack(batch_nll_sums).cpu().double().sum().item()
    return float(np.exp(total_nll / token_count))

@torch.no_grad()
def llama_eval_sparse(model, test_lodaer, device, lm_head=None):
    """Compute the perplexity of ``model`` over ``test_lodaer`` (sparse path).

    Each batch is moved to ``device`` and run through the inner
    ``model.model``: ``instantation_forward(batch, 0, None, None)`` when no
    ``lm_head`` is given, otherwise ``forward(input_ids=batch)``. The
    ``'logits'`` entry of the output at position ``t`` is scored against the
    input token at position ``t + 1`` with an unreduced cross-entropy loss.

    Only one scalar loss sum per batch is retained, so memory use does not
    grow with the number of tokens in the dataset, and the final average is
    accumulated in float64 so it is not rounded to a half-precision loss dtype.

    Args:
        model: Object exposing a ``model`` attribute whose forward methods
            return a mapping with a ``'logits'`` entry.
        test_lodaer: Iterable of token-id tensors. Batches are sliced as
            ``[:, 1:]``, so they need at least two dimensions (presumably
            ``(batch, seq_len)`` -- confirm against ``get_loaders``).
        device: Device each batch is moved to before the forward pass.
        lm_head: Optional callable. When given, the ``'logits'`` entry is
            passed through ``lm_head`` before scoring (presumably it then
            holds hidden states -- confirm with the model).

    Returns:
        float: ``exp`` of the mean per-token cross-entropy over all scored
        positions.

    Raises:
        ValueError: If the loader yields no tokens to score.
    """
    # The criterion is stateless, so build it once instead of once per batch.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    batch_nll_sums = []
    token_count = 0
    for batch in tqdm(test_lodaer):
        batch = batch.to(device)

        if lm_head is None:
            output = model.model.instantation_forward(batch, 0, None, None)
            lm_logits = output['logits']
        else:
            output = model.model.forward(input_ids=batch)
            lm_logits = lm_head(output['logits'])

        # Next-token prediction: drop the last position's prediction and the
        # first token's label so predictions and targets line up.
        shift_logits = lm_logits[:, :-1, :]
        shift_labels = batch[:, 1:]

        token_nll = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )

        # Sum in at least float32 so half-precision losses are not rounded;
        # wider dtypes (e.g. float64) are kept as they are.
        sum_dtype = torch.promote_types(token_nll.dtype, torch.float32)
        batch_nll_sums.append(token_nll.sum(dtype=sum_dtype))
        token_count += token_nll.numel()

    if token_count == 0:
        raise ValueError("cannot compute perplexity: the test loader yielded no tokens to score")

    # Combine the per-batch sums on the CPU in float64 (not every device
    # backend supports float64), with a single host transfer at the end.
    total_nll = torch.stack(batch_nll_sums).cpu().double().sum().item()
    return float(np.exp(total_nll / token_count))