# train/data_utils/data_cache.py
from __future__ import annotations

import os, json, hashlib, time
from typing import Optional, Dict, Any

import torch
import torch.distributed as dist
from datasets import load_from_disk, Dataset as HFDataset

# Cache-invalidation knob: this string is hashed into every cache key, so
# bumping it (e.g. "v1" -> "v2") whenever the dataset-build logic changes
# forces all previously cached datasets to be rebuilt.
MTP_DATASET_BUILD_VERSION: str = "v1"

def _is_dist_main(pg=None) -> bool:
    """Return True if this process is rank 0 of ``pg``.

    When torch.distributed has not been initialized (single-process run),
    the current process is always considered the main one.
    """
    return (not dist.is_initialized()) or dist.get_rank(pg) == 0

def _barrier(pg=None):
    """Synchronize all ranks of ``pg``; do nothing if torch.distributed is not initialized."""
    if not dist.is_initialized():
        return
    dist.barrier(pg)

def _file_fingerprint(path: str) -> Dict[str, Any]:
    """Describe ``path`` by absolute path, modification time and size.

    A missing file is not an error: it is reported with ``mtime=0`` and
    ``size=0``.

    NOTE(review): ``mtime`` is truncated to whole seconds, so an edit that keeps
    the file size unchanged within the same second is not reflected here.
    """
    absolute = os.path.abspath(path)
    try:
        info = os.stat(path)
    except FileNotFoundError:
        mtime, size = 0, 0
    else:
        mtime, size = int(info.st_mtime), int(info.st_size)
    return {"path": absolute, "mtime": mtime, "size": size}

def _tokenizer_fingerprint(tokenizer) -> Dict[str, Any]:
    """Summarize the tokenizer properties that influence the built dataset.

    Kept deliberately light: only what matters for reproducibility of the cache.

    Args:
        tokenizer: presumably a Hugging Face tokenizer (TODO confirm). It must
            expose ``vocab_size`` and ``convert_tokens_to_ids``;
            ``name_or_path`` and ``added_tokens_decoder`` are optional.

    Returns:
        Dict with ``name_or_path``, ``vocab_size``, ``added_tokens`` (count) and
        ``mask_id``. ``mask_id`` is ``-1`` when the tokenizer resolves
        ``"<mask>"`` to ``None``.
    """
    # convert_tokens_to_ids may return None for an unknown token when the
    # tokenizer has no unk token; int(None) would raise TypeError, so map it
    # to a sentinel that can never be a real token id.
    raw_mask_id = tokenizer.convert_tokens_to_ids("<mask>")
    return {
        "name_or_path": getattr(tokenizer, "name_or_path", "<unknown>"),
        "vocab_size": tokenizer.vocab_size,
        "added_tokens": len(getattr(tokenizer, "added_tokens_decoder", {})),
        "mask_id": int(raw_mask_id) if raw_mask_id is not None else -1,
    }

def _cache_key(meta: Dict[str, Any]) -> str:
    """Hash ``meta`` into a short, key-order-independent identifier.

    The dict is serialized to JSON with sorted keys, hashed with SHA-1, and the
    first 16 hex characters of the digest are returned.
    """
    canonical = json.dumps(meta, sort_keys=True)
    digest = hashlib.sha1(canonical.encode("utf-8"))
    return digest.hexdigest()[:16]

def _cache_dir_for(train_data_path: str, tokenizer, draft_length: int, shuffle_seed: int, cache_root: str) -> str:
    """Return the cache directory for a dataset built from these inputs.

    The directory is ``<cache_root>/mtp_ds_<key>``. ``key`` hashes the build
    version, the training file's fingerprint (path/mtime/size), the tokenizer
    fingerprint, ``draft_length`` and ``shuffle_seed``. Changing any of them
    therefore selects a different directory.
    """
    build_inputs = dict(
        version=MTP_DATASET_BUILD_VERSION,
        train_file=_file_fingerprint(train_data_path),
        tokenizer=_tokenizer_fingerprint(tokenizer),
        draft_length=int(draft_length),
        shuffle_seed=int(shuffle_seed),
    )
    return os.path.join(cache_root, "mtp_ds_" + _cache_key(build_inputs))

def load_cached_dataset(cache_dir: str) -> Optional[HFDataset]:
    """Load a previously cached dataset, or return None if there is no valid cache.

    A cache is valid only if ``cache_dir`` is a directory containing the
    ``_SUCCESS`` marker. Formatting is not re-applied here.
    """
    marker = os.path.join(cache_dir, "_SUCCESS")
    if os.path.isdir(cache_dir) and os.path.exists(marker):
        return load_from_disk(cache_dir)
    return None

def save_dataset_to_cache(ds: HFDataset, cache_dir: str, meta: Dict[str, Any]):
    """Persist ``ds`` and ``meta`` into ``cache_dir`` and mark the cache as complete.

    Writes, in order: the dataset (``save_to_disk``), ``meta.json``, and
    finally a ``_SUCCESS`` marker holding the current timestamp.
    ``load_cached_dataset`` only accepts directories with that marker.

    The dataset is saved in its plain (unformatted) form so that no torch
    formatting is recorded on disk. The caller's ``ds`` keeps its original
    format afterwards, even if saving fails.

    Args:
        ds: dataset to save.
        cache_dir: target directory; created if missing.
        meta: JSON-serializable description written to ``meta.json``.
    """
    os.makedirs(cache_dir, exist_ok=True)
    ok_flag = os.path.join(cache_dir, "_SUCCESS")
    # Drop any stale completion marker first, so an interrupted re-save never
    # leaves a half-written directory that still looks valid.
    try:
        os.remove(ok_flag)
    except FileNotFoundError:
        pass
    # formatted_as() with no arguments temporarily resets to the plain format
    # and restores the caller's format on exit (the old code reset it for good).
    with ds.formatted_as():
        ds.save_to_disk(cache_dir)
    with open(os.path.join(cache_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    # Written last: its presence signals that everything above completed.
    with open(ok_flag, "w", encoding="utf-8") as f:
        f.write(f"{time.time()}\n")