from absl import flags, app
import mlxu
from typing import List, Any
import random
import json
import logging
import sys
import functools
from .parquet_sample_and_compose import BuffferedFile
import io
if __name__ == "__main__":
    # Flags are registered (and FLAGS is bound) only when this file runs as a script;
    # importing the module leaves FLAGS undefined, so main() only works from the script entry point.
    flags.DEFINE_multi_string("dataset_paths", None, "Input JSONL files to merge (repeat the flag once per dataset).")
    flags.DEFINE_multi_string("source_fields", None, "Per-dataset JSON key holding the text; defaults to 'text' for every dataset.")
    flags.DEFINE_multi_string("fields_to_retain", [], "Extra JSON keys copied unchanged from every input record.")
    flags.DEFINE_string("destination_field", "text", "Output key that receives each record's text.")
    flags.DEFINE_string("origin_field_name", "src", "Output key recording which dataset a record came from.")
    flags.DEFINE_multi_string("origin_to_add", None, "Per-dataset value stored under --origin_field_name; defaults to the dataset paths.")
    flags.DEFINE_multi_float("char_proportions", None, "Per-dataset target share of characters (must sum to 1); omit to merge all datasets completely.")
    flags.DEFINE_integer("max_chars", None, "Stop loading once this many characters have been collected.")
    flags.DEFINE_boolean("random_shuffle", True, "Shuffle the merged records before writing them.")
    flags.DEFINE_boolean("preshuffle_files", False, "Load each dataset into memory and shuffle its lines before merging.")
    flags.DEFINE_string("destination_path", None, "Output path prefix; a timestamp and '.jsonl' are appended to it.")
    FLAGS = flags.FLAGS
# Loose alias for the file-like inputs: anything exposing readline()/close() is accepted.
File = Any


# Standalone logger instance (constructed directly rather than via logging.getLogger),
# writing INFO and above to stderr with a "DSM" tag in every line.
LOGGER_HANDLER = logging.StreamHandler(stream=sys.stderr)
LOGGER_HANDLER.setFormatter(
    logging.Formatter(fmt="[%(asctime)s] DSM [%(levelname)s] : %(message)s")
)
LOGGER = logging.Logger(name="Dataset merge", level=logging.INFO)
LOGGER.addHandler(LOGGER_HANDLER)


class LineFile:
    """Read-only, in-memory file substitute that hands out a fixed list of lines."""

    def __init__(self, line_list: List[str]):
        self.line_list = line_list
        # Index of the next line readline() will return.
        self.pointer = 0

    def readline(self):
        """Return the next stored line, or None once every line has been returned."""
        if self.pointer != len(self.line_list):
            line = self.line_list[self.pointer]
            self.pointer += 1
            return line
        return None

    def close(self):
        """Do nothing: there is no underlying resource to release."""


def preshuffle_file(file):
    """Read ``file`` entirely into memory and return its lines in random order.

    The lines keep their line endings. The input handle is not closed here.
    Returns a ``LineFile`` that serves the shuffled lines via ``readline()``.
    """
    with io.StringIO(file.read()) as buffer:
        lines = list(buffer)
    random.shuffle(lines)
    return LineFile(lines)


def preshuffle_files(files: List[File]):
    """Run ``preshuffle_file`` on every input in order, logging each index, and return the results."""
    shuffled_files = []
    for index, source in enumerate(files):
        LOGGER.info(f"Processing file {index}")
        shuffled_files.append(preshuffle_file(source))
    return shuffled_files


def load_proportionally(
    files: List[File],
    source_fields: List[str],
    destination_field: str,
    char_proportions: List[float],
    max_chars: int,
    origin_field_name: str,
    origin_to_add: List[str],
    fields_to_retain: List[str] = None,
):
    """Interleave JSON-lines records from several sources toward target character shares.

    On every step one line is read from the unfinished source whose current share of
    loaded characters is furthest below its requested proportion. The line is parsed
    with ``json.loads`` and turned into a new record holding the text under
    ``destination_field``, the source's origin value under ``origin_field_name``, and
    copies of every key in ``fields_to_retain``.

    Args:
        files: Objects with a ``readline()`` method; a return of ``None`` or ``""``
            means that source is exhausted.
        source_fields: Per-file JSON key holding the text.
        destination_field: Output key that receives the text.
        char_proportions: Per-file target share of characters, or ``None`` to merge
            every file completely ("full merge" mode).
        max_chars: Stop once at least this many characters have been loaded;
            ``None`` means no limit.
        origin_field_name: Output key that records where a record came from.
        origin_to_add: Per-file value stored under ``origin_field_name``.
        fields_to_retain: Extra keys copied verbatim from each input record
            (``None`` means none).

    Returns:
        The list of output records in the order they were loaded.

    Raises:
        AssertionError: if the per-file argument lists differ in length.
        KeyError: if an input record lacks its source field or a retained field.
    """
    if fields_to_retain is None:
        fields_to_retain = []

    assert len(files) == len(source_fields)
    assert char_proportions is None or len(files) == len(char_proportions)
    assert len(files) == len(origin_to_add)

    full_merge_mode = char_proportions is None
    if full_merge_mode:
        # Every source "asks" for 100%, so the selector below always picks the unfinished
        # source with the smallest current share: sources are interleaved until all run dry.
        char_proportions = [1.0] * len(files)

    LOGGER.info(
        f"Starting proportional loading with source_fields: {source_fields}, destination_field: {destination_field}, char_proportions {char_proportions}, max_chars {max_chars}, origin_field_name {origin_field_name}, origin_to_add {origin_to_add}"
    )
    chars_per_file = [0] * len(files)
    finished_file = [False] * len(files)
    total_loaded_chars = 0

    def get_underrepresented_file_id():
        """Return the index of the unfinished source furthest below its requested share."""
        smallest_diff = 1.0
        prop_sum = 0.0
        smallest_id = None
        for file_id, (cpf, req_prop, fin_file) in enumerate(zip(chars_per_file, char_proportions, finished_file)):
            curr_prop = cpf / max(total_loaded_chars, 1)
            prop_sum += curr_prop
            diff = curr_prop - req_prop
            # Strict '<' keeps the lowest index on ties.
            if diff < smallest_diff and not fin_file:
                smallest_id = file_id
                smallest_diff = diff

        # Sanity checks: the chosen source is at or below target and the current shares
        # sum to 1 (both are skipped before any characters have been loaded).
        assert smallest_diff <= 1e-4 or prop_sum == 0.0
        assert abs(prop_sum - 1.0) <= 1e-4 or prop_sum == 0.0
        assert smallest_id is not None
        return smallest_id

    record_list = []
    while True:
        fid = get_underrepresented_file_id()

        assert not finished_file[fid]

        raw_record = files[fid].readline()
        if not raw_record:  # None (LineFile) or "" (real file) both mean end of source
            if not full_merge_mode:
                # In proportional mode one exhausted source ends loading for all sources.
                LOGGER.warning(f"Finished prematurely due to end of source {origin_to_add[fid]}")
                break
            finished_file[fid] = True
            if all(finished_file):
                LOGGER.info("All files finished")
                break
            continue

        dict_raw_record = json.loads(raw_record)
        record = {
            destination_field: dict_raw_record[source_fields[fid]],
            origin_field_name: origin_to_add[fid],
        }
        for ftr in fields_to_retain:
            record[ftr] = dict_raw_record[ftr]

        num_chars = len(record[destination_field])

        record_list.append(record)
        chars_per_file[fid] += num_chars
        total_loaded_chars += num_chars
        if max_chars is not None and total_loaded_chars >= max_chars:
            LOGGER.info(f"Finished due to char limit {max_chars} vs {total_loaded_chars}")
            break

    # max(..., 1) avoids ZeroDivisionError when nothing (or only empty texts) was loaded.
    final_props = [cpf / max(total_loaded_chars, 1) for cpf in chars_per_file]
    LOGGER.info(f"Requested proportions {char_proportions} vs actual {final_props}")
    LOGGER.info(f"Loaded {len(record_list)} records")

    return record_list

from datetime import datetime
def main(_):
    """Merge the datasets named by the command-line flags into one JSONL file.

    Opens every ``--dataset_paths`` input, optionally loads and shuffles each one in
    memory, interleaves their records with ``load_proportionally``, optionally
    shuffles the merged list, and writes it to ``<--destination_path><timestamp>.jsonl``.

    Args:
        _: Unused positional argument (presumably the argv passed by absl's app.run).
    """
    random.seed(42)  # fixed seed so preshuffling and the final shuffle are reproducible
    assert FLAGS.destination_path is not None
    assert FLAGS.dataset_paths, "--dataset_paths must name at least one dataset"
    assert FLAGS.char_proportions is None or abs(sum(FLAGS.char_proportions) - 1.0) <= 1e-4
    LOGGER.info(f"Opening files {FLAGS.dataset_paths}")
    raw_files = [mlxu.open_file(p, "r") for p in FLAGS.dataset_paths]

    try:
        if FLAGS.preshuffle_files:
            LOGGER.info(f"Loading to memory and preshuffling")
            files = preshuffle_files(raw_files)
        else:
            LOGGER.warning(f"NO PRESHUFFLING APPLIED, make sure that the source datasets are shuffled")
            files = raw_files

        source_fields = ["text"] * len(files) if FLAGS.source_fields is None else FLAGS.source_fields
        origin_to_add = FLAGS.dataset_paths if FLAGS.origin_to_add is None else FLAGS.origin_to_add

        record_list = load_proportionally(
            files=files,
            source_fields=source_fields,
            destination_field=FLAGS.destination_field,
            char_proportions=FLAGS.char_proportions,
            max_chars=FLAGS.max_chars,
            origin_field_name=FLAGS.origin_field_name,
            origin_to_add=origin_to_add,
            fields_to_retain=FLAGS.fields_to_retain,
        )
    finally:
        # Close the handles returned by mlxu.open_file. After preshuffling, `files` holds
        # in-memory LineFile objects, so closing only `files` would leak these originals.
        for f in raw_files:
            f.close()

    if FLAGS.random_shuffle:
        LOGGER.info(f"Shuffling")
        random.shuffle(record_list)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")

    destination_path = FLAGS.destination_path + f"{timestamp}.jsonl"

    LOGGER.info(f"Saving to {destination_path}")
    print(f"Saving to {destination_path}")

    dest_f = BuffferedFile(destination_path)

    for r in record_list:
        dest_f.write(json.dumps(r) + "\n")

    dest_f.close()

    LOGGER.info(f"Done")

