from absl import flags, app
import mlxu
from typing import List, Any
import random
import json
import logging
import sys
import functools
from .parquet_sample_and_compose import BuffferedFile
from .merge_datasets import LineFile
from datetime import datetime
import io
import tqdm
import math
import os
if __name__ == "__main__":
    # Flags (and the module-level FLAGS read by get_file_lines/main) are only
    # defined when this module is executed as a script.
    flags.DEFINE_string("dataset_path", None, "Path of the input JSONL dataset.")
    flags.DEFINE_string("content_field", "text", "Record field whose length is counted when splitting.")
    flags.DEFINE_integer("split_every_n_chars", None, "Close an output part once it holds this many content characters.")
    flags.DEFINE_float("split_proportions", None, "Part size as a fraction of the dataset's total content characters.")
    flags.DEFINE_boolean("preshuffle", False, "Load all lines into memory and shuffle them before splitting.")
    flags.DEFINE_string("output_dir", None, "Directory under which the 'splitted/' parts are written.")
    # Neither path has a usable default; fail at flag-parsing time with a clear
    # message instead of deep inside file I/O.
    flags.mark_flag_as_required("dataset_path")
    flags.mark_flag_as_required("output_dir")
    FLAGS = flags.FLAGS


def split_dataset_to_parts(ds_file, ds_path: str, output_dir, split_every_n_chars: int, content_field: str):
    """Split a JSONL stream into consecutive parts of roughly ``split_every_n_chars`` characters.

    Records are read with ``ds_file.readline()`` until it returns ``None`` or an
    empty string. Each one is re-serialized with ``json.dumps`` into
    ``<output_dir>/splitted/<part_id>.<source_name>.<timestamp>.jsonl``, where
    ``source_name`` is the last ``/``-separated component of ``ds_path``.

    A part is closed as soon as the accumulated ``len(record[content_field])``
    reaches ``split_every_n_chars``. The record that crosses the threshold stays
    in the part it was written to. Part 0 is always created (empty if the input
    is empty). No empty trailing part is created when the final record is the
    one that closes a part.

    Args:
        ds_file: File-like object providing ``readline()``, one JSON record per line.
        ds_path: Dataset path, used for logging and for the output file names.
        output_dir: Root directory of the ``splitted/`` output parts.
        split_every_n_chars: Character threshold at which a part is closed.
        content_field: Record key whose value's length is counted.
    """
    print(f"Splitting {ds_path} to {output_dir} every {split_every_n_chars} using field {content_field}")

    # NOTE(review): the format has no %S (seconds) between minutes and
    # microseconds — confirm this is intended before relying on it for ordering.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%f")
    source_name = ds_path.split("/")[-1]

    def get_output_file(file_id):
        return os.path.join(output_dir, f"splitted/{file_id}.{source_name}.{timestamp}.jsonl")

    cur_id = 0
    cur_chars = 0

    cur_file = get_output_file(cur_id)
    print(f"writting from {ds_path} to {cur_file}")
    cur_progress_bar = tqdm.tqdm(total=split_every_n_chars)
    out_file = BuffferedFile(path=cur_file)
    try:
        while True:
            line = ds_file.readline()
            if line is None or len(line) == 0:
                break

            if out_file is None:
                # The previous part was closed. Open the next one only now that
                # a record exists for it, so no empty trailing part is produced.
                cur_file = get_output_file(cur_id)
                cur_progress_bar = tqdm.tqdm(total=split_every_n_chars)
                out_file = BuffferedFile(path=cur_file)

            record = json.loads(line)
            n_chars = len(record[content_field])
            cur_chars += n_chars
            cur_progress_bar.update(n_chars)
            out_file.write(json.dumps(record) + "\n")
            if cur_chars >= split_every_n_chars:
                out_file.flush()
                out_file.close()
                out_file = None
                cur_progress_bar.close()
                print(f"Generated file {cur_file} with {cur_chars} chars")
                cur_chars = 0
                cur_id += 1

        if out_file is not None:
            # Last (partial) part on the normal path.
            out_file.flush()
            out_file.close()
            out_file = None
            print(f"Generated file {cur_file} with {cur_chars} chars")
    finally:
        # tqdm.close() is safe to call on an already-closed bar.
        cur_progress_bar.close()
        if out_file is not None:
            # Only reached when an exception escaped the loop: release the part.
            out_file.close()

    print("Finished")


def get_file_lines(file_path: str):
    """Read every line of ``file_path`` into memory.

    Args:
        file_path: Path passed to ``mlxu.open_file``.

    Returns:
        The list of lines returned by the opened file's ``readlines()``.
    """
    # Read the path we were given rather than the global flag, so the helper
    # also works when FLAGS is not defined (i.e. when imported as a module).
    with mlxu.open_file(file_path, "r") as f:
        lines = f.readlines()

    print(f"Loaded {len(lines)} lines")

    return lines
    


def _iter_readline(f):
    """Yield lines from ``f.readline()`` until it returns ``None`` or an empty string."""
    while True:
        line = f.readline()
        if line is None or len(line) == 0:
            return
        yield line


def _count_content_chars(lines, content_field: str) -> int:
    """Return the summed ``len(json.loads(line)[content_field])`` over ``lines``."""
    return sum(len(json.loads(line)[content_field]) for line in lines)


def main(_):
    """CLI entry point: optionally shuffle the dataset, then split it into parts.

    The part size is either ``--split_every_n_chars`` or is derived from
    ``--split_proportions`` as a fraction of the dataset's total content
    characters. Exactly one of the two must be set.

    Raises:
        app.UsageError: if neither or both of the split flags are set.
    """
    # Fixed seed so --preshuffle yields the same order on every run.
    random.seed(42)

    # Exactly one of the two split flags must be set. Previously, setting
    # neither slipped past the assert and crashed later comparing an int to None.
    if (FLAGS.split_every_n_chars is None) == (FLAGS.split_proportions is None):
        raise app.UsageError("Exactly one of --split_every_n_chars and --split_proportions must be set.")

    lines = None
    if FLAGS.preshuffle:
        lines = get_file_lines(FLAGS.dataset_path)

        print("Shuffling")
        random.shuffle(lines)

    if FLAGS.split_proportions is not None:
        if lines is not None:
            # The dataset is already in memory; don't read it from storage again.
            total_chars = _count_content_chars(lines, FLAGS.content_field)
        else:
            with mlxu.open_file(FLAGS.dataset_path, "r") as f:
                total_chars = _count_content_chars(_iter_readline(f), FLAGS.content_field)

        split_every_n_chars = int(math.ceil(total_chars * FLAGS.split_proportions) + 1)
    else:
        split_every_n_chars = FLAGS.split_every_n_chars

    if lines is not None:
        ds_file = LineFile(line_list=lines)
    else:
        ds_file = mlxu.open_file(FLAGS.dataset_path, "r")

    # Always release the input, even if splitting fails midway.
    try:
        split_dataset_to_parts(
            ds_file=ds_file,
            ds_path=FLAGS.dataset_path,
            content_field=FLAGS.content_field,
            split_every_n_chars=split_every_n_chars,
            output_dir=FLAGS.output_dir,
        )
    finally:
        ds_file.close()


