import random

from datasets import load_dataset
import numpy as np
import torch
from transformers import AutoTokenizer


def set_seed(seed: int) -> None:
    """
    Seed the global RNGs of ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    # Order matches the libraries' independence: each call only touches its
    # own generator, so the sequence is kept purely for readability.
    seeders = (random.seed, np.random.seed, torch.random.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def get_data(
    dataset_name: str,
    tokenizer: AutoTokenizer,
    max_sequence_length: int,
    num_calibration_samples=None,
    seed: int = 42,
    eval_mode: bool = False
):
    """
    Load token sequences from one of the supported datasets.

    Args:
        dataset_name: One of ``"fineweb-edu"``, ``"c4"`` or ``"wikitext2"``.
        tokenizer: Tokenizer forwarded to the dataset-specific loader.
        max_sequence_length: Length (in tokens) of every returned sequence.
        num_calibration_samples: Number of sequences to return. ``None`` means
            "no limit" for fineweb-edu and "use the loader's own default" for
            c4 / wikitext2.
        seed: Shuffle seed for fineweb-edu. It is NOT used for c4 / wikitext2,
            which draw from the global ``random`` state; seed that beforehand
            if reproducibility is required.
        eval_mode: Select the evaluation split (c4 / wikitext2 only).

    Returns:
        Whatever the selected loader returns (a ``(N, SeqLen)`` tensor
        according to their docstrings).

    Raises:
        ValueError: If ``dataset_name`` is unknown, or if ``eval_mode`` is
            requested for fineweb-edu.
    """
    if dataset_name == "fineweb-edu":
        # Explicit raise instead of `assert`: asserts vanish under `python -O`,
        # which would silently hand back training data for an eval request.
        if eval_mode:
            raise ValueError("Fineweb-edu dataset is not available in eval mode.")
        return get_fineweb_edu(tokenizer, max_sequence_length, num_calibration_samples, seed)

    # The c4 / wikitext2 loaders feed `nsamples` to range()/min() and cannot
    # take None, so leave the argument out and let their default apply.
    sample_kwargs = {}
    if num_calibration_samples is not None:
        sample_kwargs["nsamples"] = num_calibration_samples

    if dataset_name == "c4":
        return get_c4(tokenizer, max_sequence_length, eval_mode=eval_mode, **sample_kwargs)
    if dataset_name == "wikitext2":
        return get_wikitext2(tokenizer, max_sequence_length, eval_mode=eval_mode, **sample_kwargs)
    raise ValueError("Unknown dataset")


def get_fineweb_edu(
    tokenizer: AutoTokenizer,
    max_sequence_length: int,
    num_calibration_samples=None,
    seed: int = 42
) -> torch.Tensor:
    """
    Stream fineweb-edu and pack its tokens into fixed-length sequences.

    Documents are tokenized without special tokens and concatenated back to
    back; the running token stream is then cut into consecutive chunks of
    ``max_sequence_length`` tokens. Tokens left over when the stream ends are
    dropped.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list.
        max_sequence_length: Number of tokens per returned sequence (> 0).
        num_calibration_samples: Stop as soon as this many sequences have been
            collected. ``None`` consumes the entire stream.
        seed: Seed for the streaming shuffle buffer.

    Returns:
        A tensor (N, SeqLen). ``torch.stack`` raises if no full chunk was
        produced.

    Raises:
        ValueError: If ``max_sequence_length`` is not positive.
    """
    # Checked before touching the network: a non-positive length would make
    # the chunking below loop forever or produce meaningless slices.
    if max_sequence_length <= 0:
        raise ValueError(
            f"max_sequence_length must be positive, got {max_sequence_length}"
        )

    stream = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        "sample-10BT",
        split="train",
        streaming=True
    )
    stream = stream.shuffle(seed=seed, buffer_size=1_000)

    token_buffer = []
    encodings = []
    for sample in stream:
        token_buffer.extend(
            tokenizer(
                sample["text"],
                return_attention_mask=False,
                add_special_tokens=False,
            )["input_ids"]
        )

        # Slice chunks out by offset and trim the buffer once per document,
        # instead of re-copying the remainder after every single chunk.
        num_full_chunks = len(token_buffer) // max_sequence_length
        for chunk_idx in range(num_full_chunks):
            begin = chunk_idx * max_sequence_length
            encodings.append(
                torch.tensor(token_buffer[begin:begin + max_sequence_length])
            )
            if (
                num_calibration_samples is not None
                and len(encodings) >= num_calibration_samples
            ):
                return torch.stack(encodings)
        del token_buffer[:num_full_chunks * max_sequence_length]

    return torch.stack(encodings)


def get_wikitext2(tokenizer, seqlen, eval_mode=False, nsamples: int = 256):
    """
    Build fixed-length token sequences from WikiText-2 (raw, v1).

    The chosen split is joined with blank lines and tokenized as one long
    sequence. In training mode ``nsamples`` windows are cut at random offsets
    (using the global ``random`` state); in eval mode the test split is cut
    into consecutive non-overlapping windows, at most ``nsamples`` of them.

    Args:
        tokenizer: Callable whose result exposes ``input_ids`` when invoked
            with ``return_tensors="pt"``.
        seqlen: Number of tokens per window.
        eval_mode: Use the test split with sequential windows instead of the
            train split with random windows.
        nsamples: Number of windows (upper bound in eval mode). ``None`` is
            treated as the default of 256.

    Returns:
        A tensor (N, SeqLen).
    """
    # Callers may forward "no explicit count" as None; range()/min() below
    # cannot take None, so fall back to the signature default.
    if nsamples is None:
        nsamples = 256

    if not eval_mode:
        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        input_ids = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt").input_ids
        # randint is inclusive on both ends; this bound is loop-invariant.
        max_start = input_ids.shape[1] - seqlen - 1
        windows = []
        for _ in range(nsamples):
            start = random.randint(0, max_start)
            windows.append(input_ids[:, start:start + seqlen])
        return torch.vstack(windows)

    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    input_ids = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt").input_ids
    # Only whole windows are kept; the trailing partial window is discarded.
    num_windows = min(nsamples, input_ids.shape[1] // seqlen)
    return torch.vstack(
        [input_ids[:, k * seqlen:(k + 1) * seqlen] for k in range(num_windows)]
    )


def _sample_c4_windows(dataset, tokenizer, seqlen, nsamples):
    """Cut ``nsamples`` random ``seqlen``-token windows out of random documents."""
    windows = []
    last_doc = len(dataset) - 1
    for _ in range(nsamples):
        # Rejection-sample documents until one has at least `seqlen` tokens.
        while True:
            doc_idx = random.randint(0, last_doc)
            input_ids = tokenizer(dataset[doc_idx]["text"], return_tensors="pt").input_ids
            if input_ids.shape[1] >= seqlen:
                break

        num_tokens = input_ids.shape[1]
        if num_tokens == seqlen:
            # Exact fit: randint(0, -1) would raise, so take the whole document.
            windows.append(input_ids)
        else:
            start = random.randint(0, num_tokens - seqlen - 1)
            windows.append(input_ids[:, start:start + seqlen])
    return torch.vstack(windows)


def get_c4(tokenizer, seqlen, eval_mode=False, nsamples: int = 256):
    """
    Build fixed-length token sequences from one shard of C4 (English).

    Documents are picked at random (global ``random`` state) until one with at
    least ``seqlen`` tokens is found, then a random ``seqlen``-token window is
    cut from it. A document with exactly ``seqlen`` tokens is used whole in
    both modes. Sampling never terminates if no document in the shard is long
    enough.

    Args:
        tokenizer: Callable whose result exposes ``input_ids`` when invoked
            with ``return_tensors="pt"``.
        seqlen: Number of tokens per window.
        eval_mode: Sample from the validation shard instead of the train shard.
        nsamples: Number of windows to draw. ``None`` is treated as the
            default of 256.

    Returns:
        A tensor (N, SeqLen).
    """
    # Callers may forward "no explicit count" as None; range() cannot take it.
    if nsamples is None:
        nsamples = 256

    if not eval_mode:
        split = "train"
        shard = "en/c4-train.00000-of-01024.json.gz"
    else:
        split = "validation"
        shard = "en/c4-validation.00000-of-00008.json.gz"

    dataset = load_dataset(
        "allenai/c4",
        "default",
        data_files={split: shard},
        split=split,
        revision="607bd4c8450a42878aa9ddc051a65a055450ef87",
    )
    return _sample_c4_windows(dataset, tokenizer, seqlen, nsamples)
