import random
import sys
sys.path.append('..')
from pathlib import Path
from typing import Dict, List
from tqdm import tqdm
import json


def iter_jsonl(path):
    '''
    Lazily yield one parsed JSON value per line of the UTF-8 JSONL file at `path`.

    Lines that are empty or contain only whitespace are skipped rather than
    passed to `json.loads` (which would raise on them).

    The file is opened on first iteration and closed as soon as the generator
    is exhausted, closed explicitly, or garbage-collected. It does not stay
    open for the lifetime of the process.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    '''
    with open(path, 'r', encoding='utf8') as fin:
        for line in fin:
            if not line.strip():
                continue
            yield json.loads(line)


def load_paths(data_dir: Path) -> Dict[str, List[Path]]:
    '''
    Collect the JSONL shard paths for each dataset split under `data_dir`.

    Expected layout:
    - <split>/            one of 'test', 'train', 'validation'
        - chunk*/         any number of chunk directories
            - *.jsonl     shard files

    Chunk directories are visited in lexicographic name order, and the shards
    inside each chunk are also sorted lexicographically. A missing split
    directory simply yields an empty list.

    Returns:
        A dict with keys 'test', 'train', 'validation' (in that order). Each
        maps to the list of JSONL file paths for that split.
    '''
    root = Path(data_dir)

    def _jsonl_files(split_name: str) -> List[Path]:
        # Flatten chunk*/ *.jsonl for one split, keeping sorted order.
        found: List[Path] = []
        for chunk in sorted(root.joinpath(split_name).glob('chunk*')):
            found.extend(sorted(chunk.glob('*.jsonl')))
        return found

    return {name: _jsonl_files(name) for name in ('test', 'train', 'validation')}


def _append_text_file(src_path: Path, fout, chunk_chars: int = 1 << 20) -> None:
    '''
    Stream the UTF-8 text file `src_path` into the open text file `fout`.

    The copy goes `chunk_chars` characters at a time, so large shards are
    never held in memory whole. An empty file contributes nothing.

    Raises:
        ValueError: if `src_path` is non-empty and does not end with a
            newline. Concatenating such a file would glue its final record
            onto the first record of the next file, producing an invalid
            JSONL line. Its content has already been written to `fout` when
            this is detected, so the output must be discarded.
    '''
    last_char = ''
    with open(src_path, 'r', encoding='utf8') as fin:
        while chunk := fin.read(chunk_chars):
            fout.write(chunk)
            last_char = chunk[-1]
    # Use a real exception instead of `assert`: asserts vanish under `python -O`.
    if last_char not in ('', '\n'):
        raise ValueError(f"{src_path} does not end with a newline")


def main(
    *,
    src_dir: Path = Path("/home/test/data/slimpajama"),
    dst_dir: Path = Path('../../data/slim-pajama'),
    seed: int = 0,
) -> None:
    '''
    Concatenate every JSONL shard under `src_dir` into `dst_dir / 'data.jsonl'`.

    Shards from all splits returned by `load_paths` are included. They are
    written in an order shuffled deterministically by `seed`.

    The defaults reproduce the original hard-coded locations and seed. A
    relative `dst_dir` is resolved against the current working directory.

    Raises:
        ValueError: if a non-empty shard does not end with a newline.
    '''
    src_dir = Path(src_dir)
    dst_dir = Path(dst_dir)
    print(f"{src_dir = }")
    print(f"{dst_dir = }")
    # Flatten the per-split lists. Split membership is discarded on purpose,
    # because all shards are shuffled together below.
    paths = [
        path
        for split_paths in load_paths(src_dir).values()
        for path in split_paths
    ]
    # Seeded shuffle so the output order is reproducible across runs.
    random.seed(seed)
    random.shuffle(paths)

    # Concatenate
    dst_dir.mkdir(exist_ok=True, parents=True)
    data_path = dst_dir / 'data.jsonl'
    with open(data_path, 'w', encoding='utf8') as fout:
        for path in tqdm(paths):
            _append_text_file(path, fout)


if __name__ == '__main__':
    # Run the concatenation only when executed as a script, not on import.
    main()
