"""
Download, preprocess and serve the TinyStories dataset as a DataLoader.
"""

import argparse
import glob
import json
import os
import random
from typing import List
from concurrent.futures import ProcessPoolExecutor
from functools import partial

import numpy as np
import requests
import sentencepiece as spm
import torch
import torch.distributed as dist
from tqdm import tqdm

from tokenizer import Tokenizer
from random import shuffle
from collections import Counter

import matplotlib.pylab as plt


DATA_CACHE_DIR = "./TinyStories"
DATA_PROCESS_DIR = "./TinyStories/TinyStories_processing_file"

def download_file(url: str, fname: str, chunk_size=1024):
    """Stream the resource at ``url`` into the local file ``fname``.

    The body is fetched in pieces of ``chunk_size`` bytes and written as it
    arrives, with a tqdm progress bar whose total comes from the
    ``content-length`` header (0 when the server does not send one).

    The data is first written to ``fname + ".part"`` and only renamed to
    ``fname`` once the transfer has finished, so an interrupted download never
    leaves a truncated file under the final name.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    part_name = fname + ".part"
    # the context manager closes the streamed connection even on error
    with requests.get(url, stream=True) as resp:
        # fail before touching the disk so an error page is never saved
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", 0))
        with open(part_name, "wb") as file, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    # publish the finished file under its final name in one step
    os.replace(part_name, fname)


def download():
    """Download and unpack the TinyStories dataset into DATA_CACHE_DIR.

    Both steps are skipped when their output already exists: the archive is
    not re-downloaded if the .tar.gz file is present, and it is not unpacked
    again if the TinyStories_all_data directory is present. Finally the first
    story of the first shard is printed as a quick sanity check.

    Raises:
        subprocess.CalledProcessError: if ``tar`` exits with a non-zero status.
        FileNotFoundError: if no .json shard is found after unpacking.
    """
    # local import: subprocess is not among this module's top-level imports
    import subprocess

    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
    data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
    data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    # unpack the tar.gz file into all the data shards (json files)
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)
        print(f"Unpacking {data_filename}...")
        # argument list instead of a shell string: no quoting issues with the
        # paths, and check=True surfaces a failed extraction instead of
        # silently continuing
        subprocess.run(["tar", "-xzf", data_filename, "-C", data_dir], check=True)
    else:
        print(f"{data_dir} already exists, skipping unpacking...")

    # print a single example just for debugging and such
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(
            f"No .json shards found in {data_dir}; remove the directory and run the download stage again"
        )
    with open(shard_filenames[0], "r", encoding="utf-8") as f:
        data = json.load(f)
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
    print(f"Example story:\n{data[0]}")


def train_vocab(vocab_size, target_tokens, num_stories=2000):
    """
    Trains a custom sentencepiece BPE tokenizer on the TinyStories dataset.

    The first ``num_stories`` stories of the first shard are written, one per
    line, to DATA_CACHE_DIR/tiny.txt, and sentencepiece is trained on that
    file. The model files are saved with the prefix DATA_PROCESS_DIR/tok{N},
    where N is the vocab size (so the model is DATA_PROCESS_DIR/tok{N}.model).

    Args:
        vocab_size: size of the vocabulary to train; must be positive.
        target_tokens: list of strings handed to sentencepiece as
            ``user_defined_symbols``.
        num_stories: number of stories taken from each shard used.
    """
    assert vocab_size > 0, "Vocab size must be positive"

    # sentencepiece does not create the output directory for us
    os.makedirs(DATA_PROCESS_DIR, exist_ok=True)

    # output file prefix path for sentencepiece
    prefix = os.path.join(DATA_PROCESS_DIR, f"tok{vocab_size}")

    # how many shards we'll use for vocab training, kept low for efficiency
    num_shards = 1

    # 1) export a chunk of text as a single text file tiny.txt
    tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
    with open(tiny_file, "w", encoding="utf-8") as of:
        for shard in tqdm(shard_filenames[:num_shards]):
            # read with the same explicit encoding the text file is written in
            with open(shard, "r", encoding="utf-8") as f:
                data = json.load(f)

            for example in tqdm(data[:num_stories]):
                of.write(example["story"].strip() + "\n")

    print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")

    # 2) train the sentencepiece model
    print("Will now train the vocab...")

    spm.SentencePieceTrainer.train(input=tiny_file,
                                   model_prefix=prefix,
                                   model_type="bpe",
                                   vocab_size=vocab_size,
                                   user_defined_symbols=target_tokens)

    print(f"Trained tokenizer is in {prefix}.model")
    print("Done.")


def tokenize_one(vocab_size, json_file, num_stories=None):
    """Tokenize the stories of one JSON shard and write them to a .bin file.

    Every story is stripped, encoded with BOS and without EOS, and followed by
    the separator id 999. The concatenated ids are stored as raw uint16.

    Args:
        vocab_size: selects the tokenizer model via get_tokenizer_model_path.
            When it is 0 the .bin file is written next to ``json_file``,
            otherwise into DATA_PROCESS_DIR.
        json_file: path of the shard, a JSON list of objects with a "story" key.
        num_stories: how many stories from the start of the shard to use;
            ``None`` (the default) uses all of them.
    """
    tokenizer_model = get_tokenizer_model_path(vocab_size)
    enc = Tokenizer(tokenizer_model)

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    all_tokens = []

    # slicing with None as the upper bound keeps the whole list
    for example in tqdm(data[:num_stories]):
        text = example["story"].strip()  # get rid of leading/trailing whitespace
        tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
        all_tokens.extend(tokens)
        # 999 marks the end of a story; PretokDataset looks for this id.
        # NOTE(review): this collides with a real token once the vocabulary
        # has more than 999 entries - confirm before using larger vocabularies.
        all_tokens.append(999)

    # convert to uint16 nparray
    all_tokens = np.array(all_tokens, dtype=np.uint16)
    # calculate the output filename
    if vocab_size == 0:
        # if we're using Llama 2, just save the tokenized file in the same dir
        tokenized_filename = json_file.replace(".json", ".bin")
    else:
        # save the .bin file into the processing directory
        bin_basename = os.path.basename(json_file).replace(".json", ".bin")
        tokenized_filename = os.path.join(DATA_PROCESS_DIR, bin_basename)
    # write the bytes
    with open(tokenized_filename, "wb") as f:
        f.write(all_tokens.tobytes())
    # calculate the average sequence length (assumes the BOS id is 1)
    num_bos = int((all_tokens == 1).sum())
    # avoid dividing by zero when no BOS token was written (e.g. empty shard)
    avg_seq_len = all_tokens.size / num_bos if num_bos else float("nan")
    print(f"Saved {tokenized_filename}, average seqlen: {avg_seq_len:.2f}")


def process_shard(args, vocab_size):
    """Tokenize every story of one shard and dump the ids as a uint16 .bin file.

    Args:
        args: a ``(shard_id, shard)`` pair; ``shard_id`` is used as the tqdm
            bar position and ``shard`` is the path of the JSON shard.
        vocab_size: selects the tokenizer model via get_tokenizer_model_path.
            When it is 0 the .bin file is written next to the shard, otherwise
            into DATA_CACHE_DIR/tok{vocab_size}.
    """
    shard_id, shard = args
    enc = Tokenizer(get_tokenizer_model_path(vocab_size))
    with open(shard, "r") as f:
        stories = json.load(f)

    # encode each stripped story with BOS and without EOS, back to back
    token_ids = []
    for story in tqdm(stories, position=shard_id):
        token_ids += enc.encode(story["story"].strip(), bos=True, eos=False)
    token_buf = np.array(token_ids, dtype=np.uint16)

    # decide where the .bin file goes
    if vocab_size != 0:
        # custom vocab: the file lives in the tok{N} directory
        out_name = os.path.basename(shard).replace(".json", ".bin")
        tokenized_filename = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}", out_name)
    else:
        # Llama 2 vocab: the file sits right beside the .json shard
        tokenized_filename = shard.replace(".json", ".bin")

    with open(tokenized_filename, "wb") as f:
        f.write(token_buf.tobytes())

    # average sequence length, counting sequences by the id 1 (BOS)
    bos_count = (token_buf == 1).sum()
    avg_seq_len = token_buf.size / bos_count
    print(f"Saved {tokenized_filename}, average seqlen: {avg_seq_len:.2f}")


def pretokenize(vocab_size):
    """Tokenize the first ``num_shards`` JSON shards, one after another.

    All stories of each selected shard are tokenized with tokenize_one.
    The shards are processed sequentially; the process-pool variant built on
    process_shard is not used here.

    Args:
        vocab_size: forwarded to tokenize_one to select the tokenizer.

    Raises:
        FileNotFoundError: if no .json shard exists in the data directory.
    """
    # how many shards we'll tokenize
    num_shards = 1

    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(f"No .json shards found in {data_dir}; run the download stage first")

    for shard in shard_filenames[:num_shards]:
        # tokenize_one needs a story limit; None slices to the whole shard
        tokenize_one(vocab_size, shard, None)
    print("Done.")


def pretokenize_target_stories(vocab_size, num_stories):
    """Tokenize the first ``num_stories`` stories of shard data00.json.

    Args:
        vocab_size: forwarded to tokenize_one to select the tokenizer.
        num_stories: number of stories, counted from the start of the shard,
            that are tokenized.
    """
    # only this one shard is ever used here, so there is no need to list the
    # directory first
    shard = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data", "data00.json")
    tokenize_one(vocab_size, shard, num_stories)
    print("Done.")

class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors.

    Every .bin shard is a flat uint16 token stream in which each story is
    followed by the separator id 999. A sample is a window of
    ``max_seq_len + 1`` consecutive tokens that lies completely inside one
    story: the first ``max_seq_len`` tokens are the input ``x`` and the last
    token is the target ``y``.

    ``indices_to_sample_from`` holds one integer per valid window start. The
    integers index the concatenation of all shards (in the order they were
    opened), so with a single shard they are plain offsets into that shard.
    """

    # id that terminates every story in the .bin files
    SEPARATOR = 999

    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
        """Open all .bin shards and collect every valid window start.

        Args:
            split: stored on the instance; the train/test split is currently
                disabled, so it does not influence which shards are read.
            max_seq_len: number of input tokens per sample.
            vocab_size: stored on the instance.
            vocab_source: "llama2" reads the .bin files next to the .json
                shards, "custom" reads them from DATA_PROCESS_DIR.

        Raises:
            ValueError: if ``vocab_source`` is neither "llama2" nor "custom".
        """
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.vocab_source = vocab_source
        self.iteration = 0
        self.num_batches = 0
        self._sample_stream = None  # generator backing __next__

        # get worker info within a DataLoader; this is None in the main
        # process, where datasets are normally constructed
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # remember the rank part of the seed so that a DataLoader worker can
        # derive its own seed later, when its id is actually known
        self._rank_seed = 42 + 1337 * rank
        self._seeded_worker = worker_info.id if worker_info else None
        seed = self._rank_seed + worker_id
        self.rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")

        if self.vocab_source == "llama2":
            # the .bin files are right along the .json files
            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
        elif self.vocab_source == "custom":
            bin_dir = DATA_PROCESS_DIR
        else:
            raise ValueError(f"Unknown vocab_source {self.vocab_source!r}, expected 'llama2' or 'custom'")
        shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        assert len(shard_filenames)>0, f"No bin files found in {bin_dir}"

        self.rng.shuffle(shard_filenames)
        self._shards = []      # one read-only memmap per shard
        shard_bases = []       # global index of the first token of each shard
        self.indices_to_sample_from = []
        next_base = 0
        for shard in shard_filenames:
            # open the dataset for reading but keep it on disk with memmap
            tokens = np.memmap(shard, dtype=np.uint16, mode="r")
            starts = self._window_starts(tokens, self.max_seq_len)
            assert len(starts) // self.max_seq_len > 0, "this shard is way too small? investigate."
            self._shards.append(tokens)
            shard_bases.append(next_base)
            # accumulate instead of overwrite so that every shard is served
            self.indices_to_sample_from.extend(next_base + s for s in starts)
            next_base += len(tokens)
        self._shard_bases = np.asarray(shard_bases, dtype=np.int64)
        # kept for code that reads .m directly; it is the last opened shard
        self.m = self._shards[-1]
        self.num_batches = len(self.indices_to_sample_from) // self.max_seq_len

    @staticmethod
    def _window_starts(tokens, max_seq_len):
        """Return every offset at which a full window fits inside one story.

        A window started at ``s`` covers ``tokens[s : s + max_seq_len + 1]``
        and must end before the separator that closes the story. Tokens after
        the last separator are ignored.
        """
        separators = np.flatnonzero(np.asarray(tokens) == PretokDataset.SEPARATOR).tolist()
        starts = []
        # -1 acts as the separator "before" the first story, so that the first
        # story starts at offset 0 like every other story starts right after
        # its preceding separator
        previous = -1
        for position in separators:
            starts.extend(range(previous + 1, position - max_seq_len))
            previous = position
        return starts

    def _worker_rng(self):
        """Return the rng to use, reseeding once inside a DataLoader worker.

        The worker id is only known inside the worker process, so the
        per-worker seed cannot be fixed in __init__ when the dataset is built
        in the main process.
        """
        info = torch.utils.data.get_worker_info()
        if info is not None and getattr(self, "_seeded_worker", None) != info.id:
            self._seeded_worker = info.id
            self.rng = random.Random(self._rank_seed + info.id)
        return self.rng

    def __iter__(self):
        """Yield ``(x, y)`` for every window start, in a freshly shuffled order.

        ``x`` is an int64 tensor with ``max_seq_len`` tokens and ``y`` is the
        token that follows them.
        """
        rng = self._worker_rng()
        print("Total number of training samples: " + str(len(self.indices_to_sample_from)))

        rng.shuffle(self.indices_to_sample_from)
        for ix in self.indices_to_sample_from:
            # map the global index back to (shard, offset inside the shard)
            shard_no = int(np.searchsorted(self._shard_bases, ix, side="right")) - 1
            start = ix - int(self._shard_bases[shard_no])
            end = start + self.max_seq_len + 1
            # calling .astype will copy the data into a new numpy array, now in RAM
            chunk = torch.from_numpy((self._shards[shard_no][start:end]).astype(np.int64))

            x = chunk[:-1]
            y = chunk[-1]
            yield x, y

    def __next__(self):
        """Return the next ``(x, y)`` sample; stop after ``num_batches`` samples.

        When the limit is reached the counter is reset, the indices are
        reshuffled and StopIteration is raised, so the following call starts
        a new pass.
        """
        if self.iteration < self.num_batches:
            if getattr(self, "_sample_stream", None) is None:
                self._sample_stream = iter(self)
            self.iteration += 1
            # hand out an actual sample, not the generator object itself
            return next(self._sample_stream)
        else:
            # Raise StopIteration when the iteration is complete
            self.rng.shuffle(self.indices_to_sample_from)
            self.iteration = 0
            self._sample_stream = None
            raise StopIteration

                

# -----------------------------------------------------------------------------
# public interface functions

def get_tokenizer_model_path(vocab_size):
    """
    Maps a vocab size to the path of its sentencepiece tokenizer model.

    A custom vocabulary of size N resolves to DATA_PROCESS_DIR/tok{N}.model.
    The value 0 stands for the default Llama 2 tokenizer and yields None.
    """
    if vocab_size != 0:
        model_name = f"tok{vocab_size}.model"
        return os.path.join(DATA_PROCESS_DIR, model_name)
    return None


def extract_contexted_stories(vocab_size, target_contexts):
    """Collect the stories that contain at least one target context.

    Each story of the first ``num_shards`` shards is split into sentencepiece
    pieces, the pieces are joined with single spaces, and the story is kept
    when any context occurs as a substring of that joined string. The kept
    story objects are written to DATA_PROCESS_DIR/target_stories.json.

    Args:
        vocab_size: selects the tokenizer model via get_tokenizer_model_path.
        target_contexts: iterable of context strings; for a dict its keys are
            used. Each context is matched against the space-joined pieces.
    """
    # how many shards we'll use to find target contexts
    num_shards = 1

    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    # retrieve the tokenizer for this vocab size
    tokenizer_model = get_tokenizer_model_path(vocab_size)
    enc = Tokenizer(tokenizer_model)
    target_stories = []

    for shard in tqdm(shard_filenames[:num_shards]):
        with open(shard, "r", encoding="utf-8") as f:
            data_list = json.load(f)

        for data in tqdm(data_list, desc='Processing'):
            text = data["story"].strip()  # get rid of leading/trailing whitespace
            pieces = " ".join(enc.sp_model.encode_as_pieces(text))
            # any() keeps a story once, even when several contexts match it
            if any(context in pieces for context in target_contexts):
                target_stories.append(data)

    print("_" * 100)
    # report how many stories were kept, not how many contexts were searched
    print("Extracted " + str(len(target_stories)) + " with relevant contexts.")
    print("_" * 100)
    # the directory may not exist yet when no custom tokenizer was trained
    os.makedirs(DATA_PROCESS_DIR, exist_ok=True)
    with open(os.path.join(DATA_PROCESS_DIR, "target_stories.json"), 'w') as fout:
        json.dump(target_stories, fout)

            
# -----------------------------------------------------------------------------
# CLI for constructing the dataset

if __name__ == "__main__":
    """
    These stages are designed to be run in order.

    To build the dataset with a custom tokenizer we train ourselves with sentencepiece, e.g.:
    python tinystories.py download
    python tinystories.py train_vocab --vocab_size=64 --num_stories=100

    The train_vocab stage trains the tokenizer and then pretokenizes the same
    number of stories; there is no separate pretokenize stage on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("stage", type=str, choices=["download", "train_vocab"])
    parser.add_argument("--vocab_size", type=int, default=64, help="vocab size of the custom tokenizer to train; must be > 0")
    parser.add_argument("--num_stories", type=int, default=100, help="Number of stories to use")
    # parsed for command-line compatibility; no stage in this script reads it
    parser.add_argument("--context_length", type=int, default=6, help="Context length (not used by the stages in this script)")
    args = parser.parse_args()

    # depending on the stage call the appropriate function
    if args.stage == "download":
        download()
    elif args.stage == "train_vocab":
        train_vocab(vocab_size=args.vocab_size, target_tokens=[], num_stories=args.num_stories)
        pretokenize_target_stories(vocab_size=args.vocab_size, num_stories=args.num_stories)
    else:
        # not reachable through argparse because of `choices`; kept as a guard
        raise ValueError(f"Unknown stage {args.stage}")
