from absl import flags, app
import logging
import sys
import mlxu
import json
from .stat_utils import basic_stats_from_numeric_list
from .file_utils import DatasetWriter
from .record_loader import load_records
from ..utils import MultiLogger
from typing import Optional
import random
from ..sspt_processing.retrievers import RepoRetrieverCatalog, IndexedData
from typing import List, Dict, Any
import functools

# Base logger for this tool: instantiated directly (not obtained through
# logging.getLogger), with a single handler that writes to stderr and a format
# that tags every message with "DE".
LOGGER = logging.Logger("Data Extractor", level=logging.INFO)
LOGGER_HANDLER = logging.StreamHandler(sys.stderr)
LOGGER_HANDLER.setFormatter(logging.Formatter("[%(asctime)s] DE [%(levelname)s] : %(message)s"))
LOGGER.addHandler(LOGGER_HANDLER)

# Rebind LOGGER to a MultiLogger wrapping both `print` and the logger above.
# NOTE(review): presumably MultiLogger forwards each message to every wrapped
# logger (so messages would show up on stdout and stderr) — confirm in ..utils.
LOGGER = MultiLogger(basic_loggers=[print], advanced_loggers=[LOGGER])

if __name__ == "__main__":
    # Flags are registered only when the module is executed as a script, so
    # FLAGS does not exist when the module is merely imported.

    # Required inputs ("REQ"); they have no usable default.
    flags.DEFINE_string("source_jsonl", None, "REQ json with dataset")
    flags.DEFINE_string("dest_jsonl", None, "REQ json with dataset")
    flags.DEFINE_string("text_field", None, "REQ")
    flags.DEFINE_string("mode", None, "REQ repo|plain - whether to take repos into account")

    # Optional character budgets ("OPT"); None disables the respective limit.
    flags.DEFINE_integer("max_chars_load", None, "OPT")
    flags.DEFINE_integer("max_chars_doc", None, "OPT")
    flags.DEFINE_integer("max_chars_save", None, "OPT")
    flags.DEFINE_integer("max_chars_repo", None, "OPT")
    flags.DEFINE_integer("seed", 42, "")
    flags.DEFINE_boolean("preshuffle", True, "shuffle dataset after loading")
    flags.DEFINE_boolean("postshuffle", True, "shuffle dataset before saving")

    # Let absl reject a missing required flag at parse time with a usage
    # message, instead of failing later on a bare assertion in extract_data.
    flags.mark_flags_as_required(["source_jsonl", "dest_jsonl", "text_field", "mode"])
    FLAGS = flags.FLAGS


def prepare_repo_aware(records: List[Dict[str, Any]], text_field: str, max_chars_repo: Optional[int]):
    """Reorder ``records`` so that files retrieved for the same repo are adjacent.

    The records are indexed with ``IndexedData`` (repository name and path are
    read from the ``max_stars_repo_name`` / ``max_stars_repo_path`` fields).
    For every still-unmarked entry, ``RepoRetrieverCatalog.search`` is queried
    and the returned entries are marked and emitted one after another; each
    such run is a "repo fragment".

    Args:
        records: Source records; the ``id`` of every retrieved entry is used
            as an index into this list.
        text_field: Key of the text whose length is counted against the limit.
        max_chars_repo: Optional per-fragment character budget. A fragment is
            closed as soon as its running character count reaches this value
            (the file that crosses the budget is still included). ``None``
            disables the limit.

    Returns:
        A new list holding the selected records in the order they were visited.
    """
    LOGGER.info(f"Organizing records in repos (files from one repo one after another). " f"Text field is {text_field}")
    indexed_records = IndexedData(
        records,
        text_field_src_alias=text_field,
        repo_name_src_alias="max_stars_repo_name",
        repo_path_src_alias="max_stars_repo_path",
    )
    repo_retriever = RepoRetrieverCatalog(indexed_data=indexed_records)
    docs_per_repo_frag = []

    result_records_ids = []
    for ir in indexed_records.iter_unmarked():
        # NOTE(review): presumably returns the not-yet-marked files that share
        # a repository with `ir` — confirm against RepoRetrieverCatalog.
        whole_repo = repo_retriever.search(ir, None)

        repo_chars = 0
        taken_files = 0
        for r in whole_repo:
            indexed_records.mark_data(r)
            repo_chars += len(records[r["id"]][text_field])
            result_records_ids.append(r["id"])

            taken_files += 1
            if max_chars_repo is not None and repo_chars >= max_chars_repo:
                # Entries not reached before this break are not marked here.
                break
        docs_per_repo_frag.append(taken_files)

    # Drop the local references to the index before materialising the output.
    # The retriever was constructed from the index, so presumably it keeps a
    # reference as well; release both names.
    del indexed_records, repo_retriever
    result_records = [records[rid] for rid in result_records_ids]

    stats = basic_stats_from_numeric_list(docs_per_repo_frag)
    LOGGER.info("Repo aware preparation finished\n" f"The stats are:\n {json.dumps(stats, indent=2)}")
    return result_records


def filter_length(records, text_field, max_doc_char_length: Optional[int]):
    """Drop records whose text is longer than ``max_doc_char_length`` characters.

    Args:
        records: Iterable of mapping-like records.
        text_field: Key under which each record stores its text.
        max_doc_char_length: Maximum allowed ``len(record[text_field])``
            (inclusive). ``None`` disables filtering.

    Returns:
        A new list with the surviving records in their original order.
    """
    LOGGER.info(
        f"Length filtering using: "
        f"text_field: {text_field}, "
        f"max_doc_char_length: {max_doc_char_length}"
    )

    # The limit never changes while iterating, so decide once. Without a limit
    # the text field is not inspected at all; the input is just copied.
    if max_doc_char_length is None:
        return list(records)

    return [r for r in records if len(r[text_field]) <= max_doc_char_length]



def filter_total_chars(records, text_field, max_total_char_length: Optional[int]):
    """Keep the leading records until their cumulative text length hits a budget.

    Records are taken in order. The record that makes the running total reach
    or exceed ``max_total_char_length`` is still kept, so the returned total
    may overshoot the budget by up to one record.

    Args:
        records: Iterable of mapping-like records.
        text_field: Key under which each record stores its text.
        max_total_char_length: Character budget; ``None`` keeps every record.

    Returns:
        A ``(kept_records, total_chars)`` tuple, where ``total_chars`` is the
        summed text length of ``kept_records``.
    """
    LOGGER.info(
        f"Filtering using: "
        f"text_field: {text_field}, "
        f"max_total_char_length: {max_total_char_length}"
    )
    total_chars = 0
    result = []
    for r in records:
        total_chars += len(r[text_field])
        result.append(r)
        # Checked after appending: the record crossing the budget is included.
        if max_total_char_length is not None and total_chars >= max_total_char_length:
            LOGGER.info(f"Reached char limit {total_chars}/{max_total_char_length}")
            break

    return result, total_chars


def extract_data(
    source_jsonl: str,
    dest_jsonl: str,
    text_field: str,
    mode: str,
    max_chars_load: Optional[int],
    max_chars_doc: Optional[int],
    max_chars_save: Optional[int],
    max_chars_repo: Optional[int],
    seed: int,
    preshuffle: bool,
    postshuffle: bool,
):
    """Load a JSONL dataset, filter and reorder it, and write the result.

    Pipeline, in order: load records, drop over-long documents, optionally
    shuffle, optionally group files by repository (``mode == "repo"``), cut
    the dataset to a total character budget, optionally shuffle again, write.

    Args:
        source_jsonl: Path of the input dataset.
        dest_jsonl: Destination handed to ``DatasetWriter``.
        text_field: Record key that holds the document text.
        mode: ``"repo"`` enables repo-aware ordering; any other value leaves
            the record order untouched.
        max_chars_load: Passed to ``load_records`` as ``max_chars``.
        max_chars_doc: Per-document length limit (``None`` = unlimited).
        max_chars_save: Total character budget for the output (``None`` =
            unlimited).
        max_chars_repo: Per-repo-fragment character budget, used only in
            ``"repo"`` mode.
        seed: Seed for the global ``random`` generator; it is set even when
            both shuffles are disabled.
        preshuffle: Shuffle right after length filtering.
        postshuffle: Shuffle right before saving.
    """
    assert isinstance(source_jsonl, str), "source_jsonl is required and must be a string"
    assert isinstance(dest_jsonl, str), "dest_jsonl is required and must be a string"
    assert isinstance(text_field, str), "text_field is required and must be a string"
    assert isinstance(mode, str), "mode is required and must be a string"

    records = load_records(jsonl_path=source_jsonl, text_field=text_field, max_chars=max_chars_load)
    LOGGER.info(f"Loaded {len(records)} records")

    records = filter_length(records=records, text_field=text_field, max_doc_char_length=max_chars_doc)
    LOGGER.info(f"Filtered out longer than {max_chars_doc} and have {len(records)} records")

    # Seed once so both shuffles below are reproducible for a given seed.
    random.seed(seed)
    if preshuffle:
        LOGGER.info("Preshuffling")
        random.shuffle(records)

    if mode == "repo":
        records = prepare_repo_aware(records=records, text_field=text_field, max_chars_repo=max_chars_repo)

    records, num_chars = filter_total_chars(
        records=records, text_field=text_field, max_total_char_length=max_chars_save
    )

    if postshuffle:
        LOGGER.info("Postshuffling")
        random.shuffle(records)

    with DatasetWriter(destination=dest_jsonl, expected_chars=num_chars, text_field=text_field) as dsw:
        for r in records:
            dsw.add(r)


def main(_):
    """Run ``extract_data`` with arguments taken from the module-level ``FLAGS``.

    Every ``extract_data`` parameter is read from the flag of the same name.
    The single positional argument is ignored.
    """
    forwarded_flags = (
        "source_jsonl",
        "dest_jsonl",
        "text_field",
        "mode",
        "max_chars_load",
        "max_chars_doc",
        "max_chars_save",
        "max_chars_repo",
        "seed",
        "preshuffle",
        "postshuffle",
    )
    # Flag names and keyword names coincide, so forward them one-to-one.
    extract_data(**{flag_name: getattr(FLAGS, flag_name) for flag_name in forwarded_flags})

