
import torch
import numpy as np
from torchtext.datasets import WikiText103

from torch.utils.data import DataLoader
# Preferred compute device: the GPU when CUDA is usable, otherwise the CPU.
# NOTE(review): not referenced anywhere else in this file as visible — confirm it is still needed.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from transformers import BertTokenizer

from torchtext.datasets import AG_NEWS

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import os

# tokenizer = get_tokenizer('basic_english')
# train_iter = AG_NEWS(split='train')
seed = 8888

# Seed every RNG this script may draw from so that runs are reproducible.
import random  # local import keeps this fix self-contained within the seeding block

# Python's built-in `random` module.
random.seed(seed)

# NumPy's legacy global RNG.
np.random.seed(seed)

# PyTorch's RNG (per the PyTorch docs, this seeds all devices, including CUDA).
torch.manual_seed(seed)

# Ask cuDNN for deterministic kernels and disable autotuning, which can pick
# different algorithms from run to run (only relevant when CUDA is present).
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Output location for the subword -> byte lookup table that is saved later on.
subword_bytes_file = './subword_byte_table.pt'

# WikiText-103 training split wrapped in a batching loader.
# NOTE(review): depending on the installed torch/torchtext versions, WikiText103
# returns an iterable-style dataset/datapipe; DataLoader then either rejects
# `shuffle=True` or implements it with a buffered shuffler — confirm which applies.
train_iter_wiki = WikiText103(split='train')
dataloader = DataLoader(
    train_iter_wiki,
    shuffle=True,
    batch_size=16,
)

# Pretrained tokenizer matching the "bert-base-uncased" checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# token -> id mapping; its size decides how many rows the byte table gets.
vocab_dict = tokenizer.get_vocab()
num_embeddings = len(vocab_dict)
print(num_embeddings)

byte_dict_size = 256   # byte ids are drawn from [0, byte_dict_size)
max_byte_seq_len = 8   # every subword is mapped to exactly this many byte ids

# Build the subword -> bytes lookup table: row k holds `max_byte_seq_len`
# byte ids for vocabulary id k.
# NOTE: these ids are random draws from the seeded torch RNG, not the
# subword's actual UTF-8 bytes.
subword2bytes = torch.randint(
    0, byte_dict_size, (num_embeddings, max_byte_seq_len), dtype=torch.long
)

# Ensure the output directory exists before saving.
# os.path.dirname returns '' for a bare filename (meaning the current
# directory), and os.makedirs('') would raise, so only create a non-empty
# parent. exist_ok avoids a check-then-create race.
dirname = os.path.dirname(subword_bytes_file)
if dirname:
    os.makedirs(dirname, exist_ok=True)
torch.save(subword2bytes, subword_bytes_file)

# Pull a single batch from the loader for inspection.
# Presumably a list of raw text lines (default collation of strings) — TODO confirm.
dataloader_iterator = iter(dataloader)
inputs = next(dataloader_iterator)
print(len(inputs), inputs)

# `line` rather than `input`: binding `input` at module scope would shadow the
# builtin for the rest of the script.
for i, line in enumerate(inputs):
    print(i, line)

# Inspect a single sample of the batch.
# NOTE(review): assumes the batch holds more than `sample_idx` items.
sample_idx = 3
inputs_slices = inputs[sample_idx]
print(inputs_slices)

tokenized_inputs = tokenizer(inputs_slices, padding=True, truncation=True, return_tensors='pt')
# for batch in dataloader:
#     # The batch is a tuple containing inputs and labels
#     inputs, labels = batch
#     tokenized_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')

# Print the tokenized inputs
# print(tokenized_inputs['input_ids'])

# NOTE(review): never populated or read in this file; kept so the module namespace is unchanged.
words_set = set()
t_idx = tokenized_inputs['input_ids']

# Recover the subword strings for the tokenized sample (presumably a single
# sequence, so squeezing yields a 1-D id tensor) and collect the distinct ones.
subwords = tokenizer.convert_ids_to_tokens(t_idx.squeeze())
subword_set = {*subwords}

# Gather each token id's row from the byte table and show it.
bytes_input = subword2bytes[t_idx]
print(bytes_input)

# Flatten to 1-D and report which distinct byte ids occur.
bytes_input = bytes_input.view(-1)
set_bytes = {*bytes_input.tolist()}
print(len(set_bytes), set_bytes)

# bytes_input = subword2bytes[6769]
# print(bytes_input)

# Print vocabulary entries 1000-1007 (position, token, and its id)
# for i, (token, index) in enumerate(vocab_dict.items()):
#     if i >= 1000 and i <= 1007:
#         print(f"{i}: {token}\t{index}")


# print('index of this:', vocab_dict['this'])