from datasets import load_dataset, load_from_disk
import os
from collections import Counter
from tqdm import tqdm
import numpy as np
import importlib.util
from utils import Vocabulary

# Import FileManager from the expt-core submodule
from FileManager import FileManager 

# Full wikipedia is 6407814 articles
# In 200k articles: mean article length is 678, max is 50k
#                   [0.5, 0.9, 0.95, 0.99] quantiles = [364, 1470, 2233, 5362.01]

# Ids 0..vocab_sz-1 are vocabulary tokens; the EOF id equals vocab_sz.
# Both are stored as uint16 in the bin file, so vocab_sz must be <= 65535.
vocab_sz = 20000
min_length = 500    # articles with fewer than this many cleaned tokens are skipped
dirname = "min500"  # passed to data_fm.set_filepath before outputs are saved

# Fail fast: otherwise this only surfaces at the EOF assert after the whole
# corpus has been processed.
if vocab_sz >= 2**16:
    raise ValueError(f"vocab_sz must be at most 65535 to fit uint16 ids, got {vocab_sz}")


data_dir = os.path.join("qwem") # TODO: set DATASETPATH in environment
# exist_ok replaces the racy exists()-then-makedirs() pattern.
os.makedirs(data_dir, exist_ok=True)
data_fm = FileManager(data_dir)

# Reuse the tokenized corpus if it was cached by a previous run; otherwise
# tokenize the raw English Wikipedia dump once and cache it.
# NOTE: an existing cache is reused as-is, even if it was produced by an older
# version of process_document -- delete it to regenerate after tokenizer changes.
try:
    cleaned_ds = load_from_disk(data_fm.get_filename("cleaned_wikipedia"))
except FileNotFoundError:
    # Load only the train split directly to avoid unnecessary split processing
    train_ds = load_dataset("wikimedia/wikipedia", "20231101.en", split="train")

    print("Cleaning dataset... ", flush=True)

    def process_document(article):
        """Lower-case an article's text and split it into word and integer tokens.

        Runs of ASCII letters become word tokens and runs of ASCII digits become
        integer tokens, e.g. "In the 1990s" -> ["in", "the", "1990", "s"].
        Everything else (punctuation, whitespace, non-ASCII characters) acts as
        a separator. The previous pattern, "[a-z]", matched single letters, so
        it produced a character stream with no digit tokens at all.
        """
        import re
        cleaned = re.findall(r"[a-z]+|[0-9]+", article["text"].lower())
        return {"cleaned": cleaned}

    cleaned_ds = train_ds.map(process_document, num_proc=8)
    cleaned_ds.save_to_disk(data_fm.get_filename("cleaned_wikipedia"))
    print("done.")
else:
    print('loaded cleaned dataset from disk.')

# Point data_fm at this configuration's output location before anything is
# saved (presumably a subdirectory named `dirname` -- confirm in FileManager).
data_fm.set_filepath(dirname)

# Pass 1 over the corpus. Keep only articles with at least `min_length` cleaned
# tokens and remember their dataset indices. Count token frequencies over the
# kept articles. Tally an upper bound on the output array length: each kept
# article contributes its tokens plus one EOF slot.
counter = Counter()
article_idxs = []
arr_len_upperbound = 0
for doc_idx in tqdm(range(len(cleaned_ds))):
    tokens = cleaned_ds[doc_idx]["cleaned"]
    if len(tokens) < min_length:
        continue
    article_idxs.append(doc_idx)
    arr_len_upperbound += len(tokens) + 1
    counter.update(tokens)
# From here on a plain dict (same keys, counts and order as the Counter).
counter = dict(counter)

# The integers 0..2022 always get a vocabulary entry, whether or not they
# occur in the corpus.
required_integers = list(map(str, range(2023)))

# Partition the unigram counts. Digit strings of at most 4 characters count as
# "integers"; everything else, including longer digit strings, counts as a "word".
integer_counts = [(tok, n) for tok, n in counter.items()
                  if tok.isdigit() and len(tok) <= 4]
word_counts = [(tok, n) for tok, n in counter.items()
               if not (tok.isdigit() and len(tok) <= 4)]

# Most frequent first; ties broken alphabetically so the order is deterministic.
word_counts.sort(key=lambda item: (-item[1], item[0]))

# The vocabulary starts with the required integers, each paired with its corpus
# count (0 if never seen). Short integers outside 0..2022 (e.g. "2023", "0042")
# are not added anywhere.
integer_dict = dict(integer_counts)
final_vocab = [(num, integer_dict.get(num, 0)) for num in required_integers]

# Whatever is left of vocab_sz goes to words.
remaining_slots = vocab_sz - len(required_integers)
if remaining_slots < 0:
    raise ValueError(f"Required integers ({len(required_integers)}) exceed vocab_sz ({vocab_sz})")

# # Add other integers that appeared in data but aren't in 0-2022 range
# other_integers = [(word, count) for word, count in integer_counts 
#                   if word not in required_integers]
# other_integers = sorted(other_integers, key=lambda x: (-x[1], x[0]))

# # Add other integers first (if any slots remain)
# for word, count in other_integers:
#     if remaining_slots > 0:
#         final_vocab.append((word, count))
#         remaining_slots -= 1

# Fill the remaining vocabulary space with the most frequent words
# (word_counts is already sorted by descending count).
chosen_words = word_counts[:remaining_slots]
final_vocab.extend(chosen_words)
remaining_slots -= len(chosen_words)

word_counts = final_vocab
# Any all-digit token counts as an integer here, including digit strings
# longer than 4 characters that entered the vocabulary as words.
num_integers = sum(1 for token, _ in word_counts if token.isdigit())
num_words = len(word_counts) - num_integers
print(f"Vocabulary contains {num_integers} integers and {num_words} words")
data_fm.save(word_counts, "word_counts.pickle")
# NOTE(review): token ids presumably follow list order (integers first, then
# words) -- confirm against utils.Vocabulary.
vocab = Vocabulary(word_counts)

# Construct bin file.
# Pass 2 writes every kept article's token ids into one flat uint16 file,
# followed by an EOF token per article. Out-of-vocabulary tokens are dropped,
# so the file is first sized to the pass-1 upper bound and then truncated to
# the real length.
print(f"Using {100*len(article_idxs)/6407814:0.2f}% of articles, min length {min_length}.", flush=True)
print(f"Creating bin file (vocab size = {vocab_sz})... ", flush=True)
if not article_idxs:
    # np.memmap cannot map a zero-length file; fail with a clear message instead.
    raise ValueError(f"No article has at least {min_length} tokens; nothing to write.")
filename = data_fm.get_filename("enwiki.bin")
arr = np.memmap(filename, dtype=np.uint16, mode='w+', shape=(arr_len_upperbound,))
EOF_token = vocab_sz  # first id past the vocabulary; must fit in uint16
assert EOF_token >= len(vocab.words)
assert EOF_token < 2**16
word2token = vocab.word2token  # hoisted: consulted twice per token below
idx = 0
article_arr_idxs = []  # start offset of each article in `arr`, then one final end offset
for i in tqdm(article_idxs):
    article = cleaned_ds[i]["cleaned"]
    assert len(article) >= min_length
    corpus = np.array([word2token[word] for word in article
                       if word in word2token] + [EOF_token],
                      dtype=np.uint16)
    assert idx + len(corpus) <= arr_len_upperbound
    arr[idx : idx + len(corpus)] = corpus
    article_arr_idxs.append(idx)
    idx += len(corpus)
article_arr_idxs.append(idx)
arr.flush()
# Release the memory map before shrinking the file. Windows refuses to resize
# a file that still has a mapped view, and on POSIX any later access past the
# new end of a live mapping would fault.
del arr
with open(filename, 'r+b') as f:
    f.truncate(idx * np.dtype(np.uint16).itemsize)
# Use int64 explicitly: offsets can exceed 2**31 on a large corpus, and the
# default integer dtype is platform-dependent.
data_fm.save(np.array(article_arr_idxs, dtype=np.int64), "article_arr_idxs.npy")

print("done.")
