import click
from datasets import DownloadConfig, load_dataset
from transformers import AutoTokenizer

from ddlm.data.c4 import Mc4
from ddlm.data.pile import ThePile


@click.command()
@click.option("--num-shards", default=1, help="Number of dataset shards to process.")
@click.option("--dataset-name", default="pile",
              help='Dataset to preprocess: "mc4" or "c4"; any other value selects The Pile.')
@click.option("--start-idx", default=0, help="Index of the first shard to process.")
@click.option("--load-num-proc", default=4, help="Number of processes used for downloading.")
@click.option("--num-proc", default=4, help="Number of processes used for tokenizing and splitting.")
@click.option("--max-length", default=128, help="Length, in tokens, of every saved example.")
@click.option("--tokenizer-name", default="mc4-tokenizer",
              help="Name or path passed to AutoTokenizer.from_pretrained.")
def main(num_shards, dataset_name, load_num_proc, num_proc, start_idx, max_length, tokenizer_name):
    """Tokenize a text corpus and save it as fixed-length token chunks.

    Loads the requested shards of the chosen dataset, tokenizes the "text"
    column, wraps every document in BOS/EOS tokens, cuts it into chunks of
    --max-length tokens (an incomplete trailing chunk is dropped) and saves
    the "train" split to ./preprocessed_dataset.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def batched_split(batch, drop_last=True):
        """Split each tokenized document of ``batch`` into ``max_length``-sized chunks.

        Args:
            batch: mapping with an "input_ids" entry holding one list of token
                ids per document.
            drop_last: when true, a trailing chunk shorter than ``max_length``
                is discarded; when false it is kept as a shorter example.

        Returns:
            A dict with a single "input_ids" entry containing all chunks.
        """
        chunks = []
        for token_ids in batch["input_ids"]:
            # NOTE(review): BOS/EOS are added here unconditionally; if the
            # tokenizer call already inserts special tokens they would be
            # duplicated -- confirm against the tokenizer configuration.
            token_ids = [tokenizer.bos_token_id] + token_ids + [tokenizer.eos_token_id]
            if drop_last:
                # The "+ 1" keeps the last start offset that still yields a
                # full chunk; without it a final chunk of exactly max_length
                # tokens would be dropped along with the short tails.
                stop = len(token_ids) - max_length + 1
            else:
                stop = len(token_ids)
            for offset in range(0, stop, max_length):
                chunks.append(token_ids[offset:offset + max_length])
        return {"input_ids": chunks}

    if dataset_name == "c4":
        # C4 is read directly from the hub, one json.gz file per shard.
        dataset = load_dataset(
            "allenai/c4",
            data_files=[f"en/c4-train.{i:05}-of-01024.json.gz" for i in range(start_idx, start_idx+num_shards)]
        )
        dataset = dataset.remove_columns(["url", "timestamp"])
    else:
        # mC4 and The Pile go through the project's own dataset builders;
        # every name other than "mc4"/"c4" falls back to The Pile.
        if dataset_name == "mc4":
            builder = Mc4(languages=["en"], num_shards=num_shards, start_idx=start_idx)
        else:
            builder = ThePile(subsets=["all"], num_shards=num_shards, start_idx=start_idx)
        builder.download_and_prepare(download_config=DownloadConfig(num_proc=load_num_proc, max_retries=3))
        dataset = builder.as_dataset()
        dataset = dataset.remove_columns(["meta"])

    # Tokenize without truncation so that long documents can be chunked below.
    dataset = dataset.map(lambda x: tokenizer(x["text"], truncation=False), batched=True, num_proc=num_proc,
                          remove_columns="text")
    # The split changes the number of rows, so the per-document attention
    # mask column has to be removed in the same map call.
    dataset = dataset.map(batched_split, batched=True, remove_columns=["attention_mask"],
                          num_proc=num_proc)["train"]
    dataset.save_to_disk("preprocessed_dataset")

# Run the click command only when this file is executed as a script;
# click takes the option values from the command line.
if __name__ == "__main__":
    main()
