from .thestack_preprocess import get_records_from_dataset, save_jsonl_efficient
from .merge_datasets import preshuffle_files, load_proportionally, BuffferedFile
from absl import flags, app
import functools
from typing import Union, List, Tuple, Optional
from hashlib import sha1
from uuid import uuid4
import logging
import sys
import mlxu
import json

import random

from datetime import datetime
import os
from .utils import MultiLogger


# Dedicated logger instance (created directly, not via logging.getLogger) that
# writes INFO-and-above records to stderr with a timestamped "SSTP&M" prefix.
LOGGER = logging.Logger(name="SSTP & Mixture", level=logging.INFO)
LOGGER_HANDLER = logging.StreamHandler(stream=sys.stderr)
LOGGER_HANDLER.setFormatter(
    logging.Formatter(fmt="[%(asctime)s] SSTP&M [%(levelname)s] : %(message)s")
)
LOGGER.addHandler(LOGGER_HANDLER)

# Re-bind LOGGER to the project's MultiLogger, handing it `print` as a basic
# logger and the stderr logger above as an advanced logger.
# NOTE(review): presumably MultiLogger fans every message out to both sinks --
# confirm in .utils.
LOGGER = MultiLogger(basic_loggers=[print], advanced_loggers=[LOGGER])


# ---------------------------------------------------------------------------
# SSTP sources. Every multi-flag in this group is per-source: the i-th value
# configures the i-th --sstp_sources entry (main() forwards the lists to
# get_sstp_data_paths, which zips them together).
# ---------------------------------------------------------------------------
flags.DEFINE_multi_string(name="sstp_sources", default=[], help="")
flags.DEFINE_multi_string(name="sstp_source_fields", default=None, help="")
flags.DEFINE_multi_string(name="sstp_modes", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_topks", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_topk_sample", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_max_doc_chars", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_prefilter_max_doc_chars", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_max_total_chars", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_take_char_prefix", default=None, help="")
flags.DEFINE_multi_float(name="sstp_stop_fraction", default=None, help="")

# Integer-encoded booleans: main() turns each value into `value > 0`.
flags.DEFINE_multi_integer(name="sstp_precompute", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_random_shuffle_tree", default=None, help="")
flags.DEFINE_multi_integer(name="sstp_reverse_tree", default=0, help="")

# NOTE(review): declared as an integer flag although get_sstp_data_paths
# annotates the value as Optional[str] -- looks like it should be a string
# flag; confirm against get_records_from_dataset before changing.
flags.DEFINE_multi_integer(name="sstp_tokenizer_path", default=None, help="")

# ---------------------------------------------------------------------------
# Additional, already prepared sources that are mixed in as-is.
# ---------------------------------------------------------------------------
flags.DEFINE_multi_string(name="other_sources", default=[], help="")
flags.DEFINE_multi_string(name="other_source_fields", default=[], help="")

# ---------------------------------------------------------------------------
# Mixture settings. Proportions/origins cover SSTP sources first, then the
# other sources (main() concatenates the file lists in that order).
# ---------------------------------------------------------------------------
flags.DEFINE_multi_float(name="all_data_proportions", default=None, help="")
flags.DEFINE_multi_string(name="all_data_origins", default=None, help="")
flags.DEFINE_integer(name="min_len_in_chars", default=128, help="")
flags.DEFINE_integer(name="mixture_max_chars", default=None, help="")

# NOTE(review): this flag is not read anywhere in the visible part of the
# file -- main() passes a hard-coded origin field name instead; confirm intent.
flags.DEFINE_string(name="output_origin_field", default="src", help="")

# Directory receiving both the temporary SSTP files and the final mixture.
flags.DEFINE_string(name="result_dir", default=None, help="")

FLAGS = flags.FLAGS




def get_sstp_data_paths(
    sstp_sources: List[Union[str, Tuple[str, str, str]]],
    sstp_source_fields: List[str],
    result_dir: str,
    topk: List[int],
    max_docs_chars: List[int],
    aggregate_modes: List[str],
    min_len_in_chars: int,
    max_total_chars: List[int],
    precompute: List[bool],
    stop_fraction: List[float],
    prefilter_max_docs_chars: List[int],
    sstp_random_shuffle_tree: List[bool],
    sstp_topk_sample: List[int],
    sstp_reverse_tree: List[bool],
    sstp_take_char_prefix: List[int],
    sstp_tokenizer_path: List[Optional[str]],
):
    """Build one temporary JSONL file per SSTP source and return the file paths.

    For every source the records are produced by ``get_records_from_dataset``
    and written with ``save_jsonl_efficient`` to
    ``<result_dir>/tmp/<source name>/<timestamp>_<sha1>.jsonl``.

    All list arguments are per-source: element ``i`` of each list configures
    ``sstp_sources[i]`` and is forwarded to ``get_records_from_dataset``.

    Args:
        sstp_sources: Either a path string (loaded with ``ds_mode="file"``) or
            a ``(ds_path, ds_subset, ds_split)`` tuple (loaded with
            ``ds_mode="hf"``).
        sstp_source_fields: Forwarded as ``content_field``.
        result_dir: Directory under which the ``tmp`` folder is placed.
        topk: Forwarded as ``topk``.
        max_docs_chars: Forwarded as ``max_doc_chars``.
        aggregate_modes: Forwarded as ``aggregate_mode``.
        min_len_in_chars: Single value shared by all sources, forwarded as
            ``min_len_in_chars``.
        max_total_chars: Forwarded as ``max_total_chars``.
        precompute: Forwarded as ``precompute``.
        stop_fraction: Forwarded as ``stop_fraction``.
        prefilter_max_docs_chars: Forwarded as ``prefilter_max_doc_chars``.
        sstp_random_shuffle_tree: Forwarded as ``random_shuffle_tree``.
        sstp_topk_sample: Forwarded as ``topk_sample``.
        sstp_reverse_tree: Forwarded as ``reverse_tree``.
        sstp_take_char_prefix: Forwarded as ``take_char_prefix``.
        sstp_tokenizer_path: Forwarded as ``tokenizer_path``.

    Returns:
        The paths of the written JSONL files, in the order of ``sstp_sources``.

    Raises:
        AssertionError: If any per-source list does not have exactly one entry
            per element of ``sstp_sources``.
    """
    # Fragments end with a space so that consecutive fields do not run together.
    LOGGER.info(
        f"SSTP sources {sstp_sources} fields {sstp_source_fields} topks {topk} max_chars_per_doc {max_docs_chars} "
        f"aggregate_modes {aggregate_modes} max_total_chars {max_total_chars} precompute {precompute} "
        f"stop_fraction {stop_fraction} prefilter_max_docs_chars {prefilter_max_docs_chars} sstp_random_shuffle_tree {sstp_random_shuffle_tree} "
        f"sstp_topk_sample {sstp_topk_sample} sstp_reverse_tree {sstp_reverse_tree} sstp_take_char_prefix {sstp_take_char_prefix} "
        f"sstp_tokenizer_path {sstp_tokenizer_path}"
    )

    # Every per-source list is zipped with the sources below. zip() silently
    # truncates to the shortest input, so a short list (including `precompute`
    # and `stop_fraction`) would drop sources without any error.
    per_source_lists = {
        "sstp_source_fields": sstp_source_fields,
        "topk": topk,
        "max_docs_chars": max_docs_chars,
        "aggregate_modes": aggregate_modes,
        "max_total_chars": max_total_chars,
        "precompute": precompute,
        "stop_fraction": stop_fraction,
        "prefilter_max_docs_chars": prefilter_max_docs_chars,
        "sstp_random_shuffle_tree": sstp_random_shuffle_tree,
        "sstp_topk_sample": sstp_topk_sample,
        "sstp_reverse_tree": sstp_reverse_tree,
        "sstp_take_char_prefix": sstp_take_char_prefix,
        "sstp_tokenizer_path": sstp_tokenizer_path,
    }
    for list_name, values in per_source_lists.items():
        assert len(sstp_sources) == len(values), (
            f"{list_name} has {len(values)} entries but there are {len(sstp_sources)} sstp_sources"
        )

    def prepare_one_sstp(
        source: Union[str, Tuple[str, str, str]],
        source_field: str,
        topk: int,
        max_doc_chars: int,
        aggregate_mode: str,
        max_total_chars: int,
        pre_comp: bool,
        stop_frac: float,
        prefilter_max_doc_chars: int,
        random_shuffle_tree: bool,
        topk_sample: Optional[int],
        reverse_tree: bool,
        take_char_prefix: int,
        tokenizer_path: Optional[str],
    ):
        """Process a single source and return the path of its temporary JSONL file."""
        # Only the dataset location differs between the two source kinds; all
        # remaining options are forwarded identically.
        if isinstance(source, str):
            ds_mode, ds_path, ds_subset, ds_split = "file", source, None, None
            tmp_name = source.split("/")[-1]
        else:
            ds_path, ds_subset, ds_split = source
            ds_mode = "hf"
            tmp_name = ds_path.replace("/", "_")

        records = get_records_from_dataset(
            ds_mode=ds_mode,
            ds_path=ds_path,
            ds_subset=ds_subset,
            ds_split=ds_split,
            topk=topk,
            max_doc_chars=max_doc_chars,
            aggregate_mode=aggregate_mode,
            content_field=source_field,
            min_len_in_chars=min_len_in_chars,
            max_total_chars=max_total_chars,
            precompute=pre_comp,
            stop_fraction=stop_frac,
            prefilter_max_doc_chars=prefilter_max_doc_chars,
            random_shuffle_tree=random_shuffle_tree,
            topk_sample=topk_sample,
            reverse_tree=reverse_tree,
            take_char_prefix=take_char_prefix,
            tokenizer_path=tokenizer_path,
        )

        # Seconds are part of the stamp so that file names sort chronologically
        # (same layout as the timestamp used for the final mixture file).
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
        source_hash = sha1(
            (str(source) + str(topk) + str(max_doc_chars) + str(aggregate_mode) + str(max_total_chars)).encode("utf8")
        ).hexdigest()
        result_path = os.path.join(result_dir, "tmp", tmp_name, stamp + "_" + source_hash + ".jsonl")
        save_jsonl_efficient(records, result_path)

        return result_path

    result_dirs = []
    for s, sf, tk, mdc, am, mtc, pre_comp, stop_frac, pmdc, sstp_rst, topk_sample, reverse_tree, take_char_prefix, tokenizer_path in zip(
        sstp_sources,
        sstp_source_fields,
        topk,
        max_docs_chars,
        aggregate_modes,
        max_total_chars,
        precompute,
        stop_fraction,
        prefilter_max_docs_chars,
        sstp_random_shuffle_tree,
        sstp_topk_sample,
        sstp_reverse_tree,
        sstp_take_char_prefix,
        sstp_tokenizer_path,
    ):
        LOGGER.info(f"Processing {s} {sf} {tk} {mdc} {am} {mtc} {pmdc} {sstp_rst} {topk_sample} {reverse_tree} {take_char_prefix}")
        result_dirs.append(
            prepare_one_sstp(
                source=s,
                source_field=sf,
                topk=tk,
                max_doc_chars=mdc,
                aggregate_mode=am,
                max_total_chars=mtc,
                pre_comp=pre_comp,
                stop_frac=stop_frac,
                prefilter_max_doc_chars=pmdc,
                random_shuffle_tree=sstp_rst,
                topk_sample=topk_sample,
                reverse_tree=reverse_tree,
                take_char_prefix=take_char_prefix,
                tokenizer_path=tokenizer_path,
            )
        )
        LOGGER.info(f"Saved to tmp {result_dirs[-1]}")

    return result_dirs





def main(_):
    """Prepare the SSTP sources, mix them with the other sources and save the result.

    Steps:
      1. Validate that the proportions sum to 1 and that the per-source flag
         lists agree in length.
      2. Turn every ``--sstp_sources`` entry into a temporary JSONL file via
         ``get_sstp_data_paths`` (skipped when no SSTP source is given).
      3. Load records from all files with ``load_proportionally``, shuffle
         them and write them as JSON lines to
         ``<result_dir>/mix_<origin tag>.<timestamp>.jsonl``.

    The positional argument is the leftover argv handed over by ``absl.app``
    and is ignored.
    """
    assert abs(sum(FLAGS.all_data_proportions) - 1.0) <= 1e-4, "all_data_proportions must sum to 1"
    assert len(FLAGS.other_sources) == len(FLAGS.other_source_fields), (
        "other_sources and other_source_fields must have the same length"
    )
    assert len(FLAGS.sstp_sources) + len(FLAGS.other_sources) == len(FLAGS.all_data_origins), (
        "all_data_origins needs one entry per SSTP source plus one per other source"
    )

    # The per-source SSTP flags default to None, so iterating them without any
    # --sstp_sources would raise TypeError; skip SSTP handling in that case.
    has_sstp_sources = bool(FLAGS.sstp_sources)
    if has_sstp_sources:
        sstp_files = get_sstp_data_paths(
            sstp_sources=FLAGS.sstp_sources,
            sstp_source_fields=FLAGS.sstp_source_fields,
            result_dir=FLAGS.result_dir,
            topk=FLAGS.sstp_topks,
            max_docs_chars=FLAGS.sstp_max_doc_chars,
            aggregate_modes=FLAGS.sstp_modes,
            min_len_in_chars=FLAGS.min_len_in_chars,
            max_total_chars=FLAGS.sstp_max_total_chars,
            # Integer flags are used as booleans: any value > 0 means True.
            precompute=[p > 0 for p in FLAGS.sstp_precompute],
            sstp_random_shuffle_tree=[p > 0 for p in FLAGS.sstp_random_shuffle_tree],
            stop_fraction=FLAGS.sstp_stop_fraction,
            prefilter_max_docs_chars=FLAGS.sstp_prefilter_max_doc_chars,
            sstp_topk_sample=FLAGS.sstp_topk_sample,
            sstp_reverse_tree=[p > 0 for p in FLAGS.sstp_reverse_tree],
            sstp_take_char_prefix=FLAGS.sstp_take_char_prefix,
            sstp_tokenizer_path=FLAGS.sstp_tokenizer_path,
        )
    else:
        sstp_files = []

    # SSTP temp files come first, then the other sources; the proportions and
    # origins are matched to the files in this order. The temp files are read
    # back through the "text" field.
    all_files = sstp_files + FLAGS.other_sources
    source_fields = ["text"] * len(sstp_files) + FLAGS.other_source_fields

    LOGGER.info(f"Source Files are {all_files}")

    files = []
    try:
        # Open one by one so that already opened handles are still closed if a
        # later open fails.
        for fp in all_files:
            files.append(mlxu.open_file(fp, "r"))

        # NOTE(review): the origin field name is hard-coded here; the
        # --output_origin_field flag is not consulted -- confirm intent.
        record_list = load_proportionally(
            files=files,
            source_fields=source_fields,
            destination_field="text",
            char_proportions=FLAGS.all_data_proportions,
            max_chars=FLAGS.mixture_max_chars,
            origin_field_name="source_dataset",
            origin_to_add=FLAGS.all_data_origins,
        )

        LOGGER.info("Shuffling")
        random.shuffle(record_list)

        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")

        # File-name tag: "<origin>-<proportion>" per data source, followed by
        # one ".<settings>." chunk per SSTP source.
        tag_parts = [f"{o}-{p}" for o, p in zip(FLAGS.all_data_origins, FLAGS.all_data_proportions)]
        if has_sstp_sources:
            for sstpm, sstprst, sstptopk, sstptopks, sstprt, sstp_tok_path in zip(
                FLAGS.sstp_modes,
                FLAGS.sstp_random_shuffle_tree,
                FLAGS.sstp_topks,
                FLAGS.sstp_topk_sample,
                FLAGS.sstp_reverse_tree,
                FLAGS.sstp_tokenizer_path,
            ):
                tag_parts.append(f".{sstpm}_tk{sstptopk}s{sstptopks}_rst{sstprst}_revt{sstprt}_tok{sstp_tok_path}.")
        origin_tag = "".join(tag_parts)

        destination_path = os.path.join(FLAGS.result_dir, f"mix_{origin_tag}.") + f"{timestamp}.jsonl"

        LOGGER.info(f"Saving to {destination_path}")
        print(f"Saving to {destination_path}")

        dest_f = BuffferedFile(destination_path)
        try:
            for r in record_list:
                dest_f.write(json.dumps(r) + "\n")
        finally:
            # Close the output even if serialisation or writing fails.
            dest_f.close()
    finally:
        for f in files:
            f.close()

    LOGGER.info("Done")

