import torch
import torch.nn as nn
import fnmatch

# Import get_loaders function from data module within the same directory
from .data import get_loaders

# Function to evaluate perplexity (ppl) on a specified model and tokenizer
def eval_ppl(model, tokenizer, device=torch.device("cuda:0")):
    """
    Compute the perplexity of ``model`` on the wikitext2 test split.

    The test split is obtained from ``get_loaders`` (seed 0, sequence length
    taken from ``model.seqlen``) and scored by ``eval_ppl_wikitext`` with a
    batch size of 1 while autograd is disabled.

    Args:
        model (torch.nn.Module): Language model to score; must expose a
            ``seqlen`` attribute.
        tokenizer (Tokenizer): Tokenizer forwarded to ``get_loaders`` to encode the text.
        device (torch.device): Device the input batches are moved to
            (e.g., 'cuda:0' or 'cpu').

    Returns:
        float: Perplexity on the wikitext2 test split (lower is better).
    """
    dataset_name = "wikitext2"
    print(f"evaluating on {dataset_name}")

    # Only the test split is needed; the first (calibration/train) loader is discarded.
    _train_loader, test_loader = get_loaders(
        dataset_name, seed=0, seqlen=model.seqlen, tokenizer=tokenizer
    )

    # Pure inference: no autograd graph needs to be recorded.
    with torch.no_grad():
        return eval_ppl_wikitext(model, test_loader, bs=1, device=device)

# Function to evaluate perplexity (ppl) specifically on the wikitext dataset
def eval_ppl_wikitext(model, testenc, bs=1, device=torch.device("cuda:0"), truncated_nsamples=None):
    """
    Evaluate perplexity (ppl) specifically on the wikitext dataset.

    The token tensor ``testenc.input_ids`` is cut into consecutive,
    non-overlapping windows of ``model.seqlen`` tokens; a trailing partial
    window is dropped. Windows are scored ``bs`` at a time with a next-token
    cross-entropy loss, and the perplexity is ``exp`` of the mean per-window
    loss over exactly the windows that were evaluated.

    Args:
        model (torch.nn.Module): The language model to be evaluated. Must expose
            ``seqlen`` and, when called on a batch of token ids, return an
            object with a ``logits`` attribute.
        testenc (TokenizerWrapper): Encoded test set; its ``input_ids`` tensor is
            sliced as ``[:, start:end]`` (presumably shape ``(1, num_tokens)`` —
            confirm against ``get_loaders``).
        bs (int): Number of windows per forward pass.
        device (torch.device): Device to move data onto (e.g., 'cuda:0' or 'cpu').
        truncated_nsamples (int | None): If given, evaluate at most this many
            windows (clamped to the number available). ``None`` evaluates all.

    Returns:
        float: The perplexity of the language model on the wikitext test dataset.

    Raises:
        ValueError: If no window is evaluated (the test set is shorter than
            ``model.seqlen`` tokens, or ``truncated_nsamples`` is <= 0).
    """
    seqlen = model.seqlen
    input_ids = testenc.input_ids

    nsamples = input_ids.numel() // seqlen
    # Clamp the truncation so we never evaluate windows that do not exist and
    # the last batch never runs past the requested number of windows.
    if truncated_nsamples is None:
        limit = nsamples
    else:
        limit = min(truncated_nsamples, nsamples)

    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    evaluated = 0  # number of windows actually scored; used as the denominator

    # Inference only: avoid building an autograd graph even when the caller
    # has not already disabled gradients.
    with torch.no_grad():
        for i in range(0, limit, bs):
            if i % 50 == 0:
                print(f"sample {i}")

            j = min(i + bs, limit)

            inputs = input_ids[:, (i * seqlen):(j * seqlen)].to(device)
            inputs = inputs.reshape(j - i, seqlen)

            lm_logits = model(inputs).logits

            # Logits at position t are scored against the token at t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = inputs[:, 1:]
            loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

            # `loss` is a mean over all tokens in the batch; rescale so every
            # window contributes with weight seqlen to the running sum.
            nlls.append(loss.float() * seqlen * (j - i))
            evaluated += j - i

    if evaluated == 0:
        raise ValueError(
            f"no evaluation windows: test set has {input_ids.numel()} tokens, "
            f"seqlen={seqlen}, truncated_nsamples={truncated_nsamples}"
        )

    ppl = torch.exp(torch.stack(nlls).sum() / (evaluated * seqlen))
    torch.cuda.empty_cache()

    return ppl.item()
