import torch
from datasets import load_dataset, load_from_disk
from tqdm import tqdm

def random_subset(x, num_samples, seed=None):
    """Return ``num_samples`` rows of ``x`` drawn at random without replacement.

    Args:
        x: Tensor to sample from; rows are taken along dimension 0.
        num_samples: Number of rows to draw. Clamped to ``[0, x.size(0)]``.
        seed: Optional seed for reproducible sampling. When given, a private
            ``torch.Generator`` is used so torch's global RNG state is left
            untouched. When ``None``, torch's default (global) generator is used.

    Returns:
        A tensor holding ``min(num_samples, x.size(0))`` rows of ``x`` in
        random order.
    """
    total_elements = x.size(0)

    # Clamp to a valid count; a negative value would otherwise slice as
    # "all but the last k" rather than "nothing".
    num_samples = max(0, min(num_samples, total_elements))

    # A dedicated generator keeps seeding local to this call instead of
    # reseeding the process-wide RNG (which torch.manual_seed would do).
    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    indices = torch.randperm(total_elements, generator=generator)[:num_samples]
    return x[indices]

@torch.no_grad()
def evaluate_model_openwebtext(
    model,
    tokenizer,
    device: str = "cuda:0",
    max_num_samples: int = 1000,
    progress=True,
    shuffle=False,
    shuffle_seed=0,
    max_length: int = 1024,
):
    """Return the mean next-token cross-entropy of ``model`` on OpenWebText.

    Each selected document from the ``train`` split of
    ``Skylion007/openwebtext`` is tokenized on its own and truncated to
    ``max_length`` tokens. It is then scored with a mean-reduced cross-entropy
    over its shifted tokens. The per-document losses are averaged with equal
    weight, so short and long documents count the same. Perplexity, if
    needed, is ``math.exp`` of the returned value.

    Side effect: puts ``model`` into eval mode (the previous mode is not
    restored).

    Args:
        model: Called as ``model(input_ids)``; its output must expose
            ``.logits``. Presumably already on ``device`` — TODO confirm at
            call sites.
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")`` and
            indexed with ``"input_ids"``.
        device: Device the token ids are moved to.
        max_num_samples: Maximum number of documents to draw; ``None`` means
            the whole split. Documents with fewer than two tokens are drawn
            but skipped, so fewer may actually be scored.
        progress: Show a tqdm progress bar.
        shuffle: If true, draw documents at random from the whole split
            (seeded by ``shuffle_seed``). Otherwise take the first ones in
            order.
        shuffle_seed: Seed for the random draw when ``shuffle`` is true.
        max_length: Token ids beyond this length are truncated.

    Returns:
        The average per-document loss as a Python float.

    Raises:
        ValueError: If the tokenizer returns a batch other than one sequence,
            or if no selected document had at least two tokens.
    """
    model.eval()

    # The full dataset download is ~55 GB.
    train_split = load_dataset("Skylion007/openwebtext")["train"]
    dataset_size = len(train_split)

    if max_num_samples is None:
        num_to_draw = dataset_size
    else:
        num_to_draw = min(max_num_samples, dataset_size)

    if shuffle:
        print(f"Shuffling with seed: {shuffle_seed}")
        # Draw from the whole split. Permuting only the first `num_to_draw`
        # indices would score the same documents regardless of the seed.
        dataset_indices = random_subset(
            x=torch.arange(dataset_size), num_samples=num_to_draw, seed=shuffle_seed
        ).tolist()
    else:
        dataset_indices = list(range(num_to_draw))

    # Loop-invariant; build once.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="mean")
    total_loss = 0.0
    num_batches = 0

    for i in tqdm(dataset_indices, desc="Running Eval", disable=not progress):
        input_string = train_split[i]["text"]
        input_ids = tokenizer(input_string, return_tensors="pt")["input_ids"].to(device)

        if input_ids.dim() != 2 or input_ids.shape[0] != 1:
            raise ValueError(f"Unexpected shape: {input_ids.shape}")

        input_ids = input_ids[:, :max_length]

        # At least one (input, target) pair is needed after shifting; this
        # also skips documents that tokenize to nothing.
        if input_ids.shape[1] < 2:
            continue

        logits = model(input_ids).logits

        # Position t predicts token t + 1.
        shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
        shift_labels = input_ids[:, 1:].reshape(-1)

        # Upcast so half-precision models do not lose accuracy in log-softmax,
        # and accumulate in a Python float (double precision) rather than in
        # the model's dtype.
        loss = loss_fct(shift_logits.float(), shift_labels)
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("No document with at least two tokens was evaluated.")

    # Mean over documents of each document's mean token loss.
    return total_loss / num_batches
