import time

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

class IndexDataset(Dataset):
    """Map-style dataset that exposes an indexable container item by item."""

    def __init__(self, tensors):
        """Keep a reference to ``tensors``; no copy is made."""
        self.tensors = tensors

    def __len__(self):
        """Return the number of items, as reported by ``len(self.tensors)``."""
        return len(self.tensors)

    def __getitem__(self, index):
        """Return ``self.tensors[index]`` unchanged."""
        item = self.tensors[index]
        return item

def process_data(samples, tokenizer, seq_len, field_name):
    """Tokenize a text column and cut it into overlapping fixed-length windows.

    Every entry of ``samples[field_name]`` is joined with a blank line,
    tokenized in a single call, and sliced into windows of ``seq_len`` tokens
    whose start positions advance by ``seq_len // 2`` (50% overlap). Trailing
    tokens that cannot fill a complete window are dropped.

    Args:
        samples: Object indexable by ``field_name`` that yields an iterable
            of strings.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``; the first row is used as the
            token sequence.
        seq_len: Window length in tokens. Must be at least 2 so that the
            stride ``seq_len // 2`` is non-zero.
        field_name: Key of the text column inside ``samples``.

    Returns:
        An ``IndexDataset`` wrapping the stacked windows (one window per row).

    Raises:
        ValueError: If ``seq_len`` is smaller than 2, or if the tokenized text
            is shorter than ``seq_len`` so that no full window exists.
    """
    # Validate before tokenizing: a zero stride would otherwise surface as a
    # ZeroDivisionError far from its cause.
    if seq_len < 2:
        raise ValueError(f"seq_len must be at least 2, got {seq_len}")

    text = "\n\n".join(samples[field_name])
    test_ids = tokenizer(text, return_tensors='pt').input_ids[0]

    stride = seq_len // 2
    total_tokens = test_ids.numel()

    # Start offsets of every window that fits entirely inside the sequence.
    starts = range(0, total_tokens - seq_len + 1, stride)
    if not starts:
        raise ValueError(
            f"tokenized text has {total_tokens} tokens, which is fewer than "
            f"seq_len={seq_len}; cannot build any evaluation window"
        )

    test_ids_batch = torch.stack([test_ids[begin:begin + seq_len] for begin in starts])
    return IndexDataset(tensors=test_ids_batch)

def get_test_loaders(name, tokenizer, seq_len=2048, batch_size=8):
    """Build a non-shuffling ``DataLoader`` over a tokenized evaluation corpus.

    Args:
        name: Dataset identifier. A name containing ``'ptb'`` selects the
            ``validation`` split of ``ptb_text_only``/``penn_treebank``; a
            name containing ``'wikitext2'`` selects the ``test`` split of
            ``wikitext``/``wikitext-2-raw-v1``.
        tokenizer: Tokenizer forwarded to ``process_data``.
        seq_len: Window length in tokens forwarded to ``process_data``.
        batch_size: Number of windows per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` iterating the windows in order.

    Raises:
        ValueError: If ``name`` matches none of the supported datasets.
    """
    # 'ptb' is tested first so that a name containing both markers still
    # resolves to PTB, without loading and tokenizing a second corpus.
    if 'ptb' in name:
        test_data = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
        test_dataset = process_data(test_data, tokenizer, seq_len, 'sentence')
    elif 'wikitext2' in name:
        test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        test_dataset = process_data(test_data, tokenizer, seq_len, 'text')
    else:
        raise ValueError(
            f"Unsupported dataset name {name!r}: expected it to contain "
            "'wikitext2' or 'ptb'"
        )

    return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

@torch.no_grad()
def llama_eval(model, dataset, test_lodaer, device, **kwargs):
    """Return the perplexity of ``model`` over every batch in ``test_lodaer``.

    Each batch is moved to ``device`` and passed to ``model`` together with
    ``kwargs``; the ``logits`` attribute of the output is scored against the
    batch shifted by one position. Per-token losses from all batches are
    concatenated, averaged, and exponentiated.

    Args:
        model: Callable whose result exposes a ``logits`` attribute.
        dataset: Dataset name, used only in the progress-bar label.
        test_lodaer: Iterable of token-id batches.
        device: Device each batch is moved to before the forward pass.
        **kwargs: Extra keyword arguments forwarded to ``model``.

    Returns:
        The perplexity as a Python float.
    """
    progress_fmt = "Calculating PPL for " + dataset + ":" + "{l_bar}{bar}{r_bar}"
    # reduction="none" keeps one loss value per predicted token, so the final
    # mean is taken over tokens rather than over batches.
    token_loss = torch.nn.CrossEntropyLoss(reduction="none")
    per_token_nll = []

    for input_ids in tqdm(test_lodaer, bar_format=progress_fmt, ncols=None):
        input_ids = input_ids.to(device)
        logits = model(input_ids, **kwargs).logits

        # The logits at position t are scored against the token at t + 1.
        predictions = logits[:, :-1, :].contiguous()
        targets = input_ids[:, 1:].contiguous()

        vocab_size = predictions.size(-1)
        per_token_nll.append(
            token_loss(predictions.reshape(-1, vocab_size), targets.view(-1))
        )

    mean_nll = torch.cat(per_token_nll, dim=-1).mean().item()
    return np.exp(mean_nll).item()

def PPLMetric(model, tokenizer, datasets, seq_len=128, batch_size=4, device="cuda", **kwargs):
    """Evaluate ``model`` on each named dataset and collect the perplexities.

    Args:
        model: Model forwarded to ``llama_eval``.
        tokenizer: Tokenizer forwarded to ``get_test_loaders``.
        datasets: Iterable of dataset names; each becomes a key of the result.
        seq_len: Window length in tokens used when building the loaders.
        batch_size: Batch size used when building the loaders.
        device: Device forwarded to ``llama_eval``.
        **kwargs: Extra keyword arguments forwarded to ``llama_eval``.

    Returns:
        A dict mapping each dataset name to the value returned by
        ``llama_eval`` for it.
    """
    return {
        dataset_name: llama_eval(
            model,
            dataset_name,
            get_test_loaders(dataset_name, tokenizer, seq_len=seq_len, batch_size=batch_size),
            device,
            **kwargs,
        )
        for dataset_name in datasets
    }
