import random
import numpy as np
import torch
import pickle

from datasets import load_dataset
from torch.utils.data import DataLoader

def get_c4(tokenizer, n_samples, seq_len, batch_size, name):
    """Build a DataLoader of random fixed-length token windows from C4.

    Draws ``n_samples`` distinct documents at random from the first C4
    training shard, keeps only documents that tokenize to at least
    ``seq_len`` tokens, and cuts one random window of ``seq_len`` tokens
    out of each. Sampling uses the global ``random`` module state, so
    seed it with ``random.seed`` for reproducible output.

    Args:
        tokenizer: Object exposing ``encode(text, bos=False, eos=False)``
            that returns a sequence of token ids.
        n_samples: Number of windows to draw; must be positive.
        seq_len: Number of tokens in each window.
        batch_size: Batch size handed to the DataLoader. Because
            ``drop_last=True``, a trailing incomplete batch is discarded.
        name: Unused; kept so existing callers keep working.

    Returns:
        A ``DataLoader`` over a tensor built by concatenating the
        ``(1, seq_len)`` windows along dimension 0.

    Raises:
        ValueError: If ``n_samples`` is not positive, or exceeds the
            number of documents in the shard (distinct documents could
            never be found and the sampling loop would not terminate).
    """
    # Fail fast, before the (expensive) dataset load; torch.cat would
    # otherwise reject the empty sample list with a less helpful message.
    if n_samples <= 0:
        raise ValueError(f'n_samples must be positive, got {n_samples}')

    traindata = load_dataset(
        'allenai/c4', data_files={'train': 'c4-train.00000-of-01024.json.gz'}, split='train'
    )
    n_docs = len(traindata)
    if n_samples > n_docs:
        raise ValueError(
            f'n_samples ({n_samples}) exceeds the number of documents ({n_docs})'
        )

    tokenized_samples = []
    used_docs = set()  # set gives O(1) duplicate checks
    for _ in range(n_samples):
        # Rejection sampling: redraw until an unused, long-enough document
        # is found.
        # NOTE: this still loops forever if fewer than n_samples documents
        # have at least seq_len tokens.
        while True:
            doc_idx = random.randint(0, n_docs - 1)
            if doc_idx in used_docs:
                # Skip before tokenizing; tokenization is the costly step.
                continue
            token_ids = tokenizer.encode(traindata[doc_idx]['text'], bos=False, eos=False)
            if len(token_ids) >= seq_len:
                used_docs.add(doc_idx)
                break

        # Random window start such that the window fits inside the document.
        start = random.randint(0, len(token_ids) - seq_len)
        token_row = torch.tensor(token_ids).view(1, -1)
        tokenized_samples.append(token_row[:, start:start + seq_len])
    train_data = torch.cat(tokenized_samples, dim=0)

    dataloader = DataLoader(dataset=train_data,
                            batch_size=batch_size,
                            drop_last=True,
                            num_workers=4)

    print('================== DATA PREPARE DONE ==================')
    return dataloader


def get_wanda_c4(tokenizer, n_samples=128, seq_len=128, model_name='llama2_7b'): # TODO fix opt encoder
    """Sample random fixed-length token windows from C4 as a list of tensors.

    Draws ``n_samples`` distinct documents at random from the first C4
    training shard, keeps only documents that tokenize to at least
    ``seq_len`` tokens, and cuts one random window of ``seq_len`` tokens
    out of each. Sampling uses the global ``random`` module state, so
    seed it with ``random.seed`` for reproducible output.

    Args:
        tokenizer: For model names containing ``'llama'``, an object
            exposing ``encode(text, bos=False, eos=False)``; otherwise
            (names containing ``'opt'``) a callable whose result supports
            ``['input_ids']``.
        n_samples: Number of windows to draw.
        seq_len: Number of tokens in each window.
        model_name: Selects the tokenizer calling convention; must
            contain ``'llama'`` or ``'opt'`` (``'llama'`` wins if both
            are present).

    Returns:
        A list of ``n_samples`` tensors, each of shape ``(1, seq_len)``.

    Raises:
        ValueError: If ``model_name`` matches neither supported family,
            or ``n_samples`` exceeds the number of documents in the
            shard (distinct documents could never be found and the
            sampling loop would not terminate).
    """
    # Validate before loading the dataset: previously an unsupported name
    # only surfaced as an unbound ``tokenized_sample`` inside the loop.
    is_llama = 'llama' in model_name
    if not is_llama and 'opt' not in model_name:
        raise ValueError(
            f"Unsupported model_name {model_name!r}: expected it to contain 'llama' or 'opt'"
        )

    traindata = load_dataset(
        'allenai/c4', data_files={'train': 'c4-train.00000-of-01024.json.gz'}, split='train'
    )
    n_docs = len(traindata)
    if n_samples > n_docs:
        raise ValueError(
            f'n_samples ({n_samples}) exceeds the number of documents ({n_docs})'
        )

    tokenized_samples = []
    used_docs = set()  # set gives O(1) duplicate checks
    for _ in range(n_samples):
        # Rejection sampling: redraw until an unused, long-enough document
        # is found.
        # NOTE: this still loops forever if fewer than n_samples documents
        # have at least seq_len tokens.
        while True:
            doc_idx = random.randint(0, n_docs - 1)
            if doc_idx in used_docs:
                # Skip before tokenizing; tokenization is the costly step.
                continue
            text = traindata[doc_idx]['text']
            if is_llama:
                token_ids = tokenizer.encode(text, bos=False, eos=False)
            else:
                token_ids = tokenizer(text)['input_ids']
            if len(token_ids) >= seq_len:
                used_docs.add(doc_idx)
                break

        # Random window start such that the window fits inside the document.
        start = random.randint(0, len(token_ids) - seq_len)
        token_row = torch.tensor(token_ids).view(1, -1)
        tokenized_samples.append(token_row[:, start:start + seq_len])

    return tokenized_samples