import mlxu

from typing import Optional, Dict, Any
import logging
import sys
import json
from ..utils import MultiLogger
import numpy as np
from .stat_utils import basic_stats_from_numeric_list

# Module logger that writes to stderr with a "DE" prefix. It is constructed
# with logging.Logger(...) directly (not logging.getLogger(...)), so it is not
# registered in the logging hierarchy and has no parent to propagate to.
LOGGER_HANDLER = logging.StreamHandler(stream=sys.stderr)
LOGGER_HANDLER.setFormatter(
    logging.Formatter(fmt="[%(asctime)s] DE [%(levelname)s] : %(message)s")
)
LOGGER = logging.Logger(name="Data Extractor", level=logging.INFO)
LOGGER.addHandler(LOGGER_HANDLER)
# Rebind LOGGER to the project MultiLogger, handing it both `print` and the
# stdlib logger above; presumably it forwards each message to both — see
# ..utils.MultiLogger to confirm.
LOGGER = MultiLogger(basic_loggers=[print], advanced_loggers=[LOGGER])

def read_one_record(reader) -> Optional[Dict[str, Any]]:
    """Read and decode the next JSON record from a JSONL text stream.

    Args:
        reader: Object with a ``readline()`` method, e.g. an open text file.
            Standard file objects return ``""`` at end of input; a ``None``
            return is accepted as end of input too.

    Returns:
        The decoded JSON value of the next line, or ``None`` once the end of
        input is reached. Blank (whitespace-only) lines at the very end of the
        stream are treated as end of input.

    Raises:
        ValueError: If one or more blank lines are followed by more data, i.e.
            the JSONL has an empty line in the middle. ``json.JSONDecodeError``
            (a ``ValueError`` subclass) is raised for a line that is not JSON.
    """
    raw_record = reader.readline()
    # readline() signals EOF with "" (falsy); a blank line is "\n" (truthy).
    if not raw_record:
        return None
    if raw_record.strip():
        return json.loads(raw_record)

    # Blank lines are only allowed as trailing whitespace. Consume every
    # consecutive blank line so that two or more blank lines in a row are not
    # mistaken for end of input (which would silently drop the rest of the file).
    while True:
        raw_record = reader.readline()
        if not raw_record:
            return None
        if raw_record.strip():
            raise ValueError("Empty line inside jsonl")


def _iter_jsonl_records(reader, strict: bool):
    """Yield the decoded JSON value of each record in an open JSONL stream.

    Args:
        reader: Open text stream positioned at the first record.
        strict: If False, whitespace-only lines anywhere in the stream are
            skipped. If True, records are pulled with ``read_one_record``,
            which returns None at end of input and may raise ValueError on
            blank lines inside the file.
    """
    if strict:
        while True:
            record = read_one_record(reader)
            if record is None:
                return
            yield record
    else:
        # Stream line by line instead of readlines() so the raw text of the
        # whole file is never held in memory at once.
        for line in reader:
            if line.strip():
                yield json.loads(line)


def load_records(
    jsonl_path: str,
    text_field: str,
    max_chars: Optional[int],
    *,
    progress_every_chars: int = 10_000_000,
):
    """Load JSONL records from ``jsonl_path``, optionally up to a char budget.

    Args:
        jsonl_path: Path passed to ``mlxu.open_file``.
        text_field: Key of each record whose ``len()`` is counted as its chars.
        max_chars: If None, every non-blank line of the file is loaded.
            Otherwise records are read with ``read_one_record`` until the
            cumulative length of ``record[text_field]`` reaches ``max_chars``;
            the record that crosses the limit is kept.
        progress_every_chars: With ``max_chars`` set, log a progress line each
            time at least this many chars were loaded since the last one.

    Returns:
        List of decoded records in file order.

    Raises:
        KeyError: If a (dict) record lacks ``text_field``.
    """
    limited = max_chars is not None
    gathered_chars = 0
    gathered_bytes = 0
    recently_gathered_chars = 0
    char_lens = []
    records = []
    with mlxu.open_file(jsonl_path, "r") as f:
        for record in _iter_jsonl_records(f, strict=limited):
            text_len = len(record[text_field])
            # Length of the re-serialised record, not the raw on-disk byte count.
            gathered_bytes += len(json.dumps(record))
            gathered_chars += text_len
            char_lens.append(text_len)
            records.append(record)
            if not limited:
                continue

            recently_gathered_chars += text_len
            if recently_gathered_chars >= progress_every_chars:
                recently_gathered_chars = 0
                # max_chars == 0 means the budget is already exhausted; avoid
                # a ZeroDivisionError in the progress message.
                percent = gathered_chars / max_chars * 100.0 if max_chars else 100.0
                LOGGER.info(
                    f"Loaded {gathered_chars}/{max_chars} chars, that is {percent}%."
                )
            if gathered_chars >= max_chars:
                LOGGER.info(f"Char limit reached {gathered_chars}/{max_chars}")
                break
        else:
            # Loop ran to completion (no char-limit break).
            if limited:
                LOGGER.info("Finished due to end of file")

    len_stats = basic_stats_from_numeric_list(char_lens)

    stats = {
        "gathered_records": len(records),
        "gathered_chars": gathered_chars,
        "gathered_bytes": gathered_bytes,
        "len_stats": len_stats,
    }
    LOGGER.info(f"load_records: Stats\n {json.dumps(stats, indent=2)}")

    return records
