import os
import random
import torch
import sys
from datasets import load_dataset
from torch.utils.data.dataset import Dataset

def get_calib_train_data(name, tokenizer, nsamples, seqlen=2048, seed=3, batch_size=1, dataset_cache_dir=None):
    """
    Build a small calibration dataset for post-training compression / distillation.

    Strategy:
      1) Load the raw text split for the given dataset (`c4`, `ptb`, `wikitext2`).
      2) Concatenate into one long string `tot_text`.
      3) Uniformly sample random character windows from `tot_text`:
         - For each sample, take a slice [i : i + seqlen*10], tokenize it, and keep
           the first `seqlen` tokens as one example.
         - Windows that tokenize to fewer than `seqlen` tokens are rejected and
           redrawn, so exactly `nsamples` examples are collected.
      4) Pack examples into mini-batches of `batch_size` (the last batch is smaller
         when `nsamples` is not a multiple of `batch_size`), attach an all-ones
         `attention_mask`, and collect dicts {"input_ids": BxT, "attention_mask": BxT}
         in a Python list.
      5) Cache the resulting list to disk under ./cache for reuse.

    Args:
      name: dataset key, one of "c4", "ptb", "wikitext2".
      tokenizer: callable returning an object with `.input_ids`
        (presumably a HF tokenizer; `name_or_path` is used for the cache key when present).
      nsamples: number of calibration sequences to draw.
      seqlen: number of tokens kept per sequence.
      seed: seed for the private RNG that picks window offsets.
      batch_size: number of sequences packed per list entry (must be >= 1).
      dataset_cache_dir: forwarded as `cache_dir` to `load_dataset` for ptb/wikitext2.

    Returns:
      A Python list of dicts {"input_ids", "attention_mask"} (NOT a torch Dataset).

    Raises:
      ValueError: if `batch_size < 1`, or if the text is shorter than `seqlen + 1` characters.
      NotImplementedError: if `name` is not a supported dataset.
      RuntimeError: if too many consecutive windows tokenize to fewer than `seqlen` tokens.

    Notes:
      - The cache key depends on (name, tokenizer, nsamples, seqlen, seed, batch_size).
      - The global `random` state is left untouched; a private `random.Random(seed)`
        produces the same offsets that `random.seed(seed)` would.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Token ids are only meaningful for the tokenizer that produced them, so the
    # tokenizer identity is part of the cache key (sanitized to be filename-safe).
    tok_id = str(getattr(tokenizer, "name_or_path", "") or type(tokenizer).__name__)
    tok_tag = "".join(c if c.isalnum() or c in "-._" else "_" for c in tok_id)
    cache_dir = "cache"
    cache_file = os.path.join(
        cache_dir, f"{name}_{tok_tag}_{nsamples}_{seqlen}_{seed}_{batch_size}.pt"
    )
    os.makedirs(cache_dir, exist_ok=True)
    # Fast path: reuse a previously built calibration set. The file is a pickle
    # written by this function itself; do not point this at untrusted files.
    if os.path.exists(cache_file):
        return torch.load(cache_file)

    # -----------------------
    # Load raw training split
    # -----------------------
    if name == "c4":
        # Expect a pre-filtered json stored under utils/
        traindata = load_dataset("json", data_files="utils/c4-train.json")['train']
        tot_text = "\n\n".join(traindata["text"])
    elif name == "ptb":
        traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir=dataset_cache_dir)
        tot_text = "\n\n".join(traindata["sentence"])
    elif name == "wikitext2":
        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", cache_dir=dataset_cache_dir)
        tot_text = "\n\n".join(traindata["text"])
    else:
        raise NotImplementedError(f"Unsupported calibration dataset: {name!r}")

    # ------------------------------------
    # Random slicing + tokenization
    # ------------------------------------
    rng = random.Random(seed)
    max_consecutive_rejects = 1000  # guard against looping forever on tiny texts
    samples = []
    rejects = 0
    while len(samples) < nsamples:
        # Random starting character; randint raises ValueError if the text is too short.
        i = rng.randint(0, len(tot_text) - seqlen - 1)
        j = i + seqlen * 10  # long character span so at least `seqlen` tokens usually survive
        trainenc = tokenizer(tot_text[i:j], return_tensors="pt")
        if trainenc.input_ids.shape[1] < seqlen:
            # Too few tokens (e.g. window near the end of the text): redraw.
            rejects += 1
            if rejects >= max_consecutive_rejects:
                raise RuntimeError(
                    f"Could not find a window with >= {seqlen} tokens after "
                    f"{max_consecutive_rejects} consecutive attempts"
                )
            continue
        rejects = 0
        samples.append(trainenc.input_ids[:, :seqlen])  # [1, seqlen]

    # ---------------------------
    # Pack into [B, T] batches
    # ---------------------------
    traindataset = []
    for start in range(0, len(samples), batch_size):
        inp = torch.cat(samples[start:start + batch_size], dim=0)
        traindataset.append({"input_ids": inp, "attention_mask": torch.ones_like(inp)})

    # Persist the list of dicts for reuse
    torch.save(traindataset, cache_file)
    return traindataset

def get_test_data(name, tokenizer, seq_len=2048, batch_size = 4, dataset_cache_dir=None):
    """
    Construct a tokenized, fixed-length test set DataLoader for LM evaluation.

    Steps:
      1) Load the dataset split for evaluation ("wikitext2", "ptb" or "c4" must
         appear in `name`; c4 uses the first 2000 validation records).
      2) Concatenate all samples with "\\n\\n" and tokenize once.
      3) Split the token stream into contiguous chunks of length `seq_len`
         (the trailing remainder shorter than `seq_len` is dropped).
      4) Wrap the [N, seq_len] tensor in a light-weight Dataset and return a DataLoader.

    Args:
      name: dataset selector; matched by substring.
      tokenizer: callable returning an object with `.input_ids`
        (presumably a HF tokenizer).
      seq_len: tokens per evaluation chunk (must be >= 1).
      batch_size: DataLoader batch size.
      dataset_cache_dir: forwarded as `cache_dir` to `load_dataset` for wikitext2/ptb.

    Returns:
      torch.utils.data.DataLoader yielding batches of shape [B, seq_len] (input_ids only),
      in order (no shuffling). It is empty if the text has fewer than `seq_len` tokens.

    Raises:
      ValueError: if `seq_len < 1`.
      NotImplementedError: if `name` matches no supported dataset.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    # Minimal map-style dataset over the rows of a [N, seq_len] tensor
    class IndexDataset(Dataset):
        def __init__(self, tensors):
            self.tensors = tensors
        def __getitem__(self, index):
            return self.tensors[index]
        def __len__(self):
            return len(self.tensors)

    # Convert raw texts into a [N, seq_len] token-id tensor
    def process_data(samples, tokenizer, seq_len, field_name):
        # Tokenize the concatenated text once for speed
        test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0]
        nsamples = test_ids.numel() // seq_len  # drop remainder
        # Row i == test_ids[i*seq_len:(i+1)*seq_len]; unlike torch.stack this also
        # works when nsamples == 0 (yields an empty [0, seq_len] tensor).
        test_ids_batch = test_ids[: nsamples * seq_len].reshape(nsamples, seq_len)
        return IndexDataset(tensors=test_ids_batch)

    # -----------------------
    # Load evaluation split
    # -----------------------
    if 'wikitext2' in name:
        test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test', cache_dir=dataset_cache_dir)
        test_dataset = process_data(test_data, tokenizer, seq_len, 'text')
    elif 'ptb' in name:
        test_data = load_dataset('ptb_text_only', 'penn_treebank', split='test', cache_dir=dataset_cache_dir)
        test_dataset = process_data(test_data, tokenizer, seq_len, 'sentence')
    elif 'c4' in name:
        # Use a small subset of validation for speed
        test_data = load_dataset("json", data_files="utils/c4-validation.json")['train']
        test_dataset = process_data(test_data[0:2000], tokenizer, seq_len, 'text')
    else:
        raise NotImplementedError(f"Unsupported evaluation dataset: {name!r}")
    # No shuffle for deterministic evaluation
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_loader