# scripts/inspect_indexed_dataset.py
# Inspect an indexed dataset built by megatron.data.indexed_dataset
# - Fixed config at the top (no CLI args)
# - Prints impl, dtype, num_samples, per-item lengths stats, and a quick peek

import os
import sys
import numpy as np

# ----------------------------
# Fixed configuration (edit me)
# ----------------------------
# Dataset path prefix WITHOUT the .bin/.idx extension.
PREFIX: str = "/root/trainbin1/document1"
# Index implementation to load: "infer", "mmap", or "cached".
IMPL: str = "infer"
# Forwarded to make_dataset(skip_warmup=...) so mmap datasets load faster.
SKIP_WARMUP: bool = True
# How many leading per-item lengths to print.
PEEK: int = 10
# Cap on how many leading lengths feed the uniqueness/sum statistics.
SAMPLE_FOR_STATS: int = 200_000
# When dtype cannot be read from metadata, fall back to inspecting item 0.
PRINT_ONE_ITEM_DTYPE_FALLBACK: bool = False
# ----------------------------

def _import_indexed():
    """Load Megatron's indexed-dataset factory functions.

    Returns:
        tuple: ``(make_dataset, infer_dataset_impl)`` taken from
        ``megatron.data.indexed_dataset``.

    If the import fails for any reason, the error is printed and the process
    exits with status 1.
    """
    try:
        from megatron.data.indexed_dataset import (
            make_dataset,
            infer_dataset_impl,
        )
    except Exception as err:
        print("ERROR: cannot import megatron.data.indexed_dataset:", err)
        sys.exit(1)
    return make_dataset, infer_dataset_impl

def _resolve_impl(prefix, impl, infer_fn):
    """Turn the configured impl into a concrete one.

    Any value other than ``"infer"`` is returned untouched. For ``"infer"``,
    ``infer_fn(prefix)`` is consulted; a ``None`` result maps to ``"unknown"``.
    """
    if impl != "infer":
        return impl
    detected = infer_fn(prefix)
    return "unknown" if detected is None else detected

def _get_dtype(ds):
    """Return the dataset's element dtype, or ``None`` if it cannot be found.

    Lookup order:
      1. ``ds._index.dtype`` (the mmap layout keeps dtype on its index);
      2. ``ds.dtype`` (the cached layout keeps it on the dataset);
      3. only if ``PRINT_ONE_ITEM_DTYPE_FALLBACK`` is set, the dtype of
         ``np.asarray(ds[0])``. Any error while reading item 0 yields ``None``.
    """
    # Each lookup runs lazily and in order; the first non-None result wins.
    lookups = (
        lambda: getattr(getattr(ds, "_index", None), "dtype", None),
        lambda: getattr(ds, "dtype", None),
    )
    for lookup in lookups:
        found = lookup()
        if found is not None:
            return found

    if not PRINT_ONE_ITEM_DTYPE_FALLBACK:
        return None
    try:
        return np.asarray(ds[0]).dtype
    except Exception:
        return None

def _get_per_item_lengths(ds, impl):
    """Return per-item element counts for ``ds`` plus a provenance note.

    - ``impl == "mmap"``: uses the ``sizes`` array found on ``ds._index``
      (or on ``ds`` itself when it has no ``_index``).
    - any other impl: takes consecutive differences of ``ds.data_offsets``.
      If that attribute is missing or malformed, it falls back to reading up
      to ``min(1000, SAMPLE_FOR_STATS)`` leading items, which is slow.

    Returns:
        tuple: ``(lengths, note)``. ``lengths`` is a 1-D numpy array, or
        ``None`` when lengths could not be determined. ``note`` is a short
        ``"<impl>:<source>"`` string describing where they came from.
    """
    if impl == "mmap":
        lens = getattr(getattr(ds, "_index", ds), "sizes", None)
        if isinstance(lens, np.ndarray):
            return lens, "mmap:sizes"
        return None, "mmap:missing_sizes"

    # Offsets are differenced pairwise, so N items need N + 1 offsets. An
    # empty dataset therefore presumably has a single offset; accept size 1
    # (yielding an empty array) instead of misreporting it via the slow probe.
    data_offsets = getattr(ds, "data_offsets", None)
    if (
        isinstance(data_offsets, np.ndarray)
        and data_offsets.ndim == 1
        and data_offsets.size >= 1
    ):
        return np.diff(data_offsets), f"{impl}:data_offsets"

    # Best-effort fallback: read a bounded number of leading items directly.
    # Any failure (including an object without len/indexing) is reported
    # through the note rather than raised.
    try:
        k = min(len(ds), 1000, SAMPLE_FOR_STATS)
        probe = np.array([np.asarray(ds[i]).size for i in range(k)], dtype=np.int64)
    except Exception:
        return None, f"{impl}:failed_to_compute_lengths"
    return probe, f"{impl}:fallback_probe_first_{k}(slow)"

def _np_dtype_str(dt):
    """Render ``dt`` as a canonical numpy dtype name.

    If numpy does not accept ``dt`` as a dtype, it falls back to ``str(dt)``.
    """
    try:
        normalized = np.dtype(dt)
    except Exception:
        normalized = dt
    return str(normalized)

def _print_header(prefix, impl, ds, dtype):
    """Print the report banner followed by prefix, impl, item count and dtype.

    A ``None`` dtype is shown as ``unknown``; otherwise it is formatted
    with ``_np_dtype_str``.
    """
    dtype_missing = dtype is None
    print("=== Indexed Dataset Inspect ===")
    print(f"prefix       : {prefix}")
    print(f"impl         : {impl}")
    print(f"num_samples  : {len(ds)}")
    if dtype_missing:
        print("dtype        : unknown")
    else:
        print(f"dtype        : {_np_dtype_str(dtype)}")

def _print_length_stats(lens, note):
    """Print summary statistics for per-item lengths.

    Only the first ``SAMPLE_FOR_STATS`` entries feed the uniqueness check,
    the unique-length listing and the sum. These are a leading prefix, not
    a random sample; the cap keeps ``np.unique`` cheap on very large datasets.

    Args:
        lens: 1-D numpy array of per-item lengths, or ``None`` if unavailable.
        note: Provenance string produced by ``_get_per_item_lengths``.
    """
    if lens is None:
        print("per-item-len : unavailable ({})".format(note))
        return
    n = lens.shape[0]
    print(f"len_source   : {note}")
    print(f"len_count    : {n}")

    k = min(n, SAMPLE_FOR_STATS)
    sample = lens[:k]
    uniq = np.unique(sample)
    # An empty sample has zero unique values and is reported as not fixed-length.
    is_fixed = uniq.size == 1
    print(f"is_fixed_len : {is_fixed}")
    if is_fixed:
        print(f"sample_len   : {int(uniq[0])}")
    else:
        # Show at most 10 distinct lengths (np.unique returns them sorted).
        to_show = uniq[:10]
        print(f"size_unique  : {to_show.tolist()}" + (" ..." if uniq.size > 10 else ""))

    peek_n = min(PEEK, k)
    print(f"peek_first_{peek_n}: {sample[:peek_n].astype(int).tolist()}")
    # Accumulate in int64 explicitly. mmap `sizes` is int32, and numpy only
    # widens small-int sums to the *platform* int (32-bit on Windows with
    # NumPy < 2), so large token totals could otherwise wrap silently.
    print(f"tokens_sum(sampled {k}) : {int(sample.sum(dtype=np.int64))}")

def _print_misc(ds, impl):
    """Print optional extra metadata when the dataset object exposes it.

    - ``element_size`` (printed in bytes), only when ``impl == "cached"``.
    - The length of ``doc_idx``, when it is a numpy array with at least
      one dimension.
    """
    if impl == "cached":
        esz = getattr(ds, "element_size", None)
        if esz is not None:
            print(f"element_size : {esz} bytes")

    # doc_idx is informational only. A 0-d array has no first dimension, so
    # check ndim explicitly rather than silently swallowing every exception.
    doc_idx = getattr(ds, "doc_idx", None)
    if isinstance(doc_idx, np.ndarray) and doc_idx.ndim >= 1:
        print(f"doc_idx_len  : {doc_idx.shape[0]}")

def main():
    """Load the dataset described by the module-level config and print a report.

    Exits with status 1 when:
    - the Megatron import fails;
    - PREFIX is blank;
    - the ``.idx``/``.bin`` files are missing;
    - the impl cannot be inferred;
    - ``make_dataset`` returns ``None``.
    """
    make_dataset, infer_dataset_impl = _import_indexed()

    # Use the stripped prefix everywhere, not just for the emptiness check.
    prefix = (PREFIX or "").strip()
    if not prefix:
        print("ERROR: Please set PREFIX at the top of this script to your dataset prefix (no extension).")
        sys.exit(1)

    # Check the files up front. Otherwise a missing file, or a PREFIX that
    # accidentally includes ".bin"/".idx", surfaces as the misleading
    # "cannot infer impl" error below.
    missing = [p for p in (prefix + ".idx", prefix + ".bin") if not os.path.isfile(p)]
    if missing:
        print("ERROR: dataset file(s) not found: {}".format(", ".join(missing)))
        print("       PREFIX must be the dataset path WITHOUT the .bin/.idx extension.")
        sys.exit(1)

    impl = _resolve_impl(prefix, IMPL, infer_dataset_impl)
    if impl == "unknown":
        print("ERROR: cannot infer impl; set IMPL to 'mmap' or 'cached' at the top.")
        sys.exit(1)

    ds = make_dataset(prefix, impl, skip_warmup=SKIP_WARMUP)
    if ds is None:
        print("ERROR: failed to load dataset. Check PREFIX/IMPL.")
        sys.exit(1)

    dtype = _get_dtype(ds)
    _print_header(prefix, impl, ds, dtype)

    lens, note = _get_per_item_lengths(ds, impl)
    _print_length_stats(lens, note)

    _print_misc(ds, impl)
    print("=== End ===")

if __name__ == "__main__":
    # Run the report only when executed as a script, not when imported.
    main()
