# inspect_minari.py
import os, csv, json, hashlib, re
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional

import h5py
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from omegaconf import OmegaConf

def ensure_dir(p: str):
    """Create directory ``p`` together with any missing parents.

    Succeeds silently when the directory already exists.
    """
    target = Path(p)
    target.mkdir(exist_ok=True, parents=True)

def as_uint8_rgb(img: np.ndarray) -> np.ndarray:
    """Convert an image array to ``uint8`` so it can be written as a PNG.

    - ``uint8`` input is returned unchanged (same object).
    - Otherwise, if the maximum value is <= 1.0, the image is treated as
      normalized intensities and multiplied by 255.
    - Anything else is assumed to already be on a 0-255 scale.

    Values are clipped to [0, 255] before the cast, so out-of-range pixels
    saturate instead of wrapping around (e.g. -1 -> 255). Empty arrays are
    cast directly (``max()`` is undefined for them).
    """
    img = np.asarray(img)
    if img.dtype == np.uint8:
        return img
    if img.size == 0:
        return img.astype(np.uint8)
    if img.max() <= 1.0:
        img = img * 255.0
    return np.clip(img, 0, 255).astype(np.uint8)

def upscale(img: np.ndarray, scale: int = 1) -> np.ndarray:
    """Nearest-neighbour upscale: repeat every row and column ``scale`` times.

    For ``scale <= 1`` the input array itself is returned.
    """
    if scale <= 1:
        return img
    taller = np.repeat(img, scale, axis=0)
    return np.repeat(taller, scale, axis=1)

def save_png(img: np.ndarray, path: str, scale: int = 1):
    """Write ``img`` to ``path`` as an image file via ``plt.imsave``.

    The parent directory is created first. The image is upscaled by ``scale``
    and then converted to ``uint8`` with ``as_uint8_rgb``.
    """
    parent = os.path.dirname(path)
    ensure_dir(parent)
    pixels = as_uint8_rgb(upscale(img, scale))
    plt.imsave(path, pixels)

def layout_hash(img: np.ndarray) -> str:
    """Return the first 12 hex digits of the MD5 of the array's C-order bytes.

    Used as a layout fingerprint when an episode has no recorded seed. Only
    the raw bytes are hashed, not the shape or dtype.
    """
    raw = np.ascontiguousarray(img).tobytes()
    return hashlib.md5(raw).hexdigest()[:12]

def try_decode(x):
    """Best-effort conversion of HDF5 string payloads to ``str``.

    - ``bytes`` / ``np.bytes_`` are decoded as UTF-8. If decoding fails,
      ``str(x)`` (the ``b'...'`` repr) is returned instead.
    - numpy arrays of bytes / unicode / object dtype collapse to their first
      element in C order, decoded recursively. Empty arrays give "". 0-d
      arrays are supported as well.
    - Anything else is returned unchanged.
    """
    if isinstance(x, (bytes, np.bytes_)):
        try:
            return x.decode("utf-8")
        except UnicodeDecodeError:
            return str(x)
    if isinstance(x, np.ndarray) and x.dtype.kind in ("S", "O", "U"):
        # `.flat[0]` also works for 0-d arrays, where `x[0]` raises IndexError.
        return try_decode(x.flat[0]) if x.size > 0 else ""
    return x

def _extract_id_from_jsonlike(s: str) -> Optional[str]:
    """Pull an environment id out of a JSON-ish string.

    If ``s`` parses as a JSON object, its ``id`` (or, failing that, ``name``)
    is returned when it is a string, and None otherwise. If ``s`` does not
    parse to an object, a regex looks for a ``"id": "..."`` pair anywhere in
    the text. None is returned when nothing matches.
    """
    try:
        parsed = json.loads(s)
    except Exception:
        parsed = None
    if isinstance(parsed, dict):
        candidate = parsed.get("id") or parsed.get("name")
        return candidate if isinstance(candidate, str) else None
    match = re.search(r'"id"\s*:\s*"([^"]+)"', s)
    return None if match is None else match.group(1)

def canonical_env_id(raw: Any) -> Optional[str]:
    """Normalize an env-id-like value to a plain string id.

    - str: the id found by ``_extract_id_from_jsonlike``, else the string itself.
    - dict: its ``id`` (or ``name``) if that is a string, else None.
    - any other type: None.
    """
    if isinstance(raw, str):
        return _extract_id_from_jsonlike(raw) or raw
    if not isinstance(raw, dict):
        return None
    candidate = raw.get("id") or raw.get("name")
    return candidate if isinstance(candidate, str) else None

def ensure_versioned_id(eid: Optional[str]) -> Optional[str]:
    """Append ``-v0`` to ``eid`` unless it already ends in ``-v<digits>``.

    Empty or None input yields None.
    """
    if not eid:
        return None
    has_version = re.search(r"-v\d+$", eid) is not None
    return eid if has_version else f"{eid}-v0"

def safe_env_tag(env_id: Any, fallback: str = "env", max_len: int = 48) -> str:
    """Turn an env-id-like value into a short, filename-safe tag.

    The id is taken from a dict's ``id``/``name`` (only if it is a string) or
    from a str via ``_extract_id_from_jsonlike`` (falling back to the str
    itself). ``fallback`` is used when no string id is available. ``/`` and
    spaces become ``_``, any other run of characters outside
    ``[A-Za-z0-9._-]`` collapses to a single ``_``, and the result is
    truncated to ``max_len`` characters.
    """
    tag: Optional[str] = None
    if isinstance(env_id, dict):
        candidate = env_id.get("id") or env_id.get("name")
        # Non-string ids (ints, nested dicts) used to crash on `.replace`.
        tag = candidate if isinstance(candidate, str) else None
    elif isinstance(env_id, str):
        tag = _extract_id_from_jsonlike(env_id) or env_id
    tag = (tag or fallback).replace("/", "_").replace(" ", "_")
    tag = re.sub(r"[^A-Za-z0-9._-]+", "_", tag)
    return tag[:max_len]

def resolve_minari_dir_from_id(dataset_id: str) -> str:
    """Map a Minari dataset id to its local folder.

    The root is ``$MINARI_DATASETS_PATH`` when set, else
    ``~/.minari/datasets``. The id is joined onto that root as a relative path.
    """
    default_root = os.path.expanduser("~/.minari/datasets")
    datasets_root = os.environ.get("MINARI_DATASETS_PATH", default_root)
    return os.path.join(datasets_root, dataset_id)

def read_env_id_from_metadata(meta_path: str) -> Optional[str | Dict[str, Any]]:
    try:
        with open(meta_path, "r") as f: raw = f.read()
        meta = json.loads(raw)
        if isinstance(meta, str):
            try: meta = json.loads(meta)
            except Exception: return meta
        for k in ("env_id","environment_name","env_name","name"):
            v = meta.get(k)
            if isinstance(v, str) and v: return v
        env_spec = meta.get("env_spec")
        if isinstance(env_spec, dict): return env_spec.get("id") or env_spec.get("name") or env_spec
        if isinstance(env_spec, str):  return env_spec
    except Exception:
        pass
    return None

def find_episode_groups(h5: h5py.File) -> List[h5py.Group]:
    """Return the episode groups in ``h5`` that carry image observations.

    A group counts as an episode when it has an ``observations`` member plus
    ``actions`` or ``rewards``. Only episodes whose ``observations`` is a
    group containing ``image``, ``observation``, or any key with "image" in
    its name (case-insensitive) are kept.

    Groups are ordered by a natural sort of their HDF5 path, so
    ``episode_2`` precedes ``episode_10``. Plain string order would put
    ``episode_10`` first and misalign positional episode indices.
    """
    found: List[h5py.Group] = []

    def visit(_name: str, obj: Any) -> None:
        # Must return None: any other return value stops `visititems` early.
        if isinstance(obj, h5py.Group):
            keys = set(obj.keys())
            if "observations" in keys and ("actions" in keys or "rewards" in keys):
                found.append(obj)

    h5.visititems(visit)

    def has_image(g: h5py.Group) -> bool:
        obs = g["observations"] if "observations" in g else None
        # A flat observations *dataset* has no named sub-keys (and no `.keys()`).
        if not isinstance(obs, h5py.Group):
            return False
        return ("image" in obs) or ("observation" in obs) or any("image" in k.lower() for k in obs.keys())

    def natural_key(g: h5py.Group) -> list:
        # re.split with a capture group puts the digit runs at odd positions,
        # so ints are only ever compared with ints and strs with strs.
        parts = re.split(r"(\d+)", g.name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)]

    eps = [g for g in found if has_image(g)]
    eps.sort(key=natural_key)
    return eps

def first_frame_from_group(g: h5py.Group) -> np.ndarray:
    """Return step 0 of the episode's image-like observation as a numpy array.

    Preference order inside ``g["observations"]``: ``image``, then
    ``observation``, then the first key whose name contains "image"
    (case-insensitive).

    Raises:
        KeyError: if no such dataset exists.
    """
    obs = g["observations"]
    for name in ("image", "observation"):
        if name in obs:
            return np.array(obs[name][0])
    fallback = next((k for k in obs.keys() if "image" in k.lower()), None)
    if fallback is not None:
        return np.array(obs[fallback][0])
    raise KeyError(f"No image-like dataset under {obs.name}")

def get_episode_seed(g: h5py.Group) -> Optional[int]:
    """Return the reset seed recorded for episode group ``g``, or None.

    Locations are checked in this order; the first one holding a value that
    converts to ``int`` wins:

    1. the ``seed`` attribute of ``g``;
    2. for each of ``metadata``, ``episode_metadata``, ``info``, ``infos`` that
       is a sub-group of ``g``: its ``seed`` attribute, then a ``seed`` dataset
       inside it;
    3. a ``seed`` dataset directly under ``g``.

    Missing, unreadable or non-integer values (for example a string
    attribute, or an array with several elements) are skipped instead of
    raising, so a single odd episode cannot abort a whole dataset scan.
    """

    def _attr_seed(obj: Any) -> Optional[int]:
        # `obj.attrs["seed"]` as int, or None if absent / not convertible.
        try:
            if "seed" not in obj.attrs:
                return None
            return int(np.array(obj.attrs["seed"]).squeeze())
        except Exception:
            return None

    def _dataset_seed(group: Any) -> Optional[int]:
        # Scalar `group["seed"]` dataset as int, or None if absent / unreadable.
        try:
            if "seed" not in group:
                return None
            return int(np.array(group["seed"][()]).squeeze())
        except Exception:
            return None

    seed = _attr_seed(g)
    if seed is not None:
        return seed
    for k in ("metadata", "episode_metadata", "info", "infos"):
        try:
            sub = g[k] if k in g else None
        except Exception:
            sub = None  # e.g. a dangling link
        if not isinstance(sub, h5py.Group):
            continue
        seed = _attr_seed(sub)
        if seed is None:
            seed = _dataset_seed(sub)
        if seed is not None:
            return seed
    return _dataset_seed(g)

def get_episode_mission(g: h5py.Group) -> str:
    """Return the episode's mission text, or "" if none is stored.

    The first entry of a ``mission`` member under ``observations``, ``infos``
    or ``metadata`` is tried in that order (lookup errors skip to the next
    candidate), then a ``mission`` attribute on ``g`` itself. Values are
    passed through ``try_decode`` and ``str``.
    """
    candidates = (("observations", "mission"), ("infos", "mission"), ("metadata", "mission"))
    for parent_key, child_key in candidates:
        try:
            parent = g[parent_key]
            if child_key in parent:
                return str(try_decode(parent[child_key][0]))
        except Exception:
            continue
    if "mission" in g.attrs:
        return str(try_decode(g.attrs["mission"]))
    return ""

def build_minari_dataset_id_from_dir(dataset_dir: str, env_id: Optional[str | Dict[str, Any]]) -> Optional[str]:
    env_id_str = canonical_env_id(env_id) or safe_env_tag(env_id) if env_id is not None else None
    if env_id_str is None: return None
    env_family = re.sub(r"-v\d+$", "", env_id_str)
    dataset_name = Path(dataset_dir).name
    return f"minigrid/{env_family}/{dataset_name}"

def recover_env_via_minari_path(dataset_dir: str, seed: int):
    """Rebuild the environment of a local Minari dataset folder and reset it.

    Args:
        dataset_dir: folder whose ``data/`` subfolder holds the Minari storage
            (``metadata.json`` / ``main_data.hdf5``).
        seed: seed passed to ``env.reset``.

    Returns:
        The recovered environment, already reset. The caller must ``close()`` it.

    ``render_mode="rgb_array"`` is requested so that ``env.render()`` returns
    an image array. If this Minari version's ``recover_environment`` rejects
    keyword arguments, the stored spec is used unchanged. The environment is
    closed again if ``reset`` fails.
    """
    import minari

    data_path = os.path.join(dataset_dir, "data")
    # `MinariDataset` takes the storage path in its constructor; a `from_path`
    # factory is only used if the installed Minari version defines one.
    dataset_cls = minari.MinariDataset
    open_dataset = getattr(dataset_cls, "from_path", dataset_cls)
    ds = open_dataset(data_path)
    try:
        env = ds.recover_environment(render_mode="rgb_array")
    except TypeError:
        env = ds.recover_environment()
    try:
        env.reset(seed=int(seed))
    except BaseException:
        env.close()
        raise
    return env

def recover_env_via_minari_id(dataset_id: str, seed: int):
    """Rebuild the environment of an installed Minari dataset (by id) and reset it.

    Args:
        dataset_id: id passed to ``minari.load_dataset``.
        seed: seed passed to ``env.reset``.

    Returns:
        The recovered environment, already reset. The caller must ``close()`` it.

    ``render_mode="rgb_array"`` is requested so that ``env.render()`` returns
    an image array. If this Minari version's ``recover_environment`` rejects
    keyword arguments, the stored spec is used unchanged. The environment is
    closed again if ``reset`` fails.
    """
    import minari

    ds = minari.load_dataset(dataset_id)
    try:
        env = ds.recover_environment(render_mode="rgb_array")
    except TypeError:
        env = ds.recover_environment()
    try:
        env.reset(seed=int(seed))
    except BaseException:
        env.close()
        raise
    return env

def _episode_index(group_name: str, fallback: int) -> int:
    """Return N for an HDF5 group path ending in ``episode_N``, else ``fallback``."""
    match = re.search(r"(?:^|/)episode_(\d+)$", group_name)
    return int(match.group(1)) if match else fallback


def _make_gym_env(gym_id: str, seed: int):
    """``gym.make(gym_id)`` with ``render_mode="rgb_array"``, reset with ``seed``.

    The environment is closed again if ``reset`` fails.
    """
    env = gym.make(gym_id, render_mode="rgb_array")
    try:
        env.reset(seed=int(seed))
    except BaseException:
        env.close()
        raise
    return env


def _render_seed_png(
    png_path: str,
    seed: int,
    dataset_dir: str,
    minari_id: Optional[str],
    gym_id: Optional[str],
) -> bool:
    """Render the environment state after ``reset(seed)`` and save it to ``png_path``.

    Sources are tried in order: the local Minari folder, then the Minari
    dataset id (if known), then ``gym.make(gym_id)`` (if known). The first
    source that yields an image wins. Each attempt is best-effort; a failure
    moves on to the next source. Every environment that gets created is
    closed, even when rendering fails.

    Returns:
        True if a PNG was written, False if every source failed.
    """
    makers = [lambda: recover_env_via_minari_path(dataset_dir, int(seed))]
    if minari_id:
        makers.append(lambda: recover_env_via_minari_id(minari_id, int(seed)))
    if gym_id:
        makers.append(lambda: _make_gym_env(gym_id, int(seed)))

    for make_env in makers:
        try:
            env = make_env()
        except Exception:
            continue
        try:
            img = env.render()
        except Exception:
            img = None
        finally:
            try:
                env.close()
            except Exception:
                pass
        if img is None:
            continue  # e.g. an env without an rgb_array render mode
        try:
            save_png(img, png_path)
            return True
        except Exception:
            continue
    return False


def _assign_splits(rows: List[Dict[str, Any]], train_frac: float, split_seed: int) -> List[str]:
    """Set ``row["split"]`` to "train"/"evaluate", keeping rows with one key together.

    The sorted distinct ``seed_or_layout`` keys are permuted with
    ``np.random.default_rng(split_seed)``. The first
    ``round(train_frac * n_keys)`` keys of that permutation are "train".
    Returns the sorted distinct keys.
    """
    unique_keys = sorted({row["seed_or_layout"] for row in rows})
    shuffled = np.random.default_rng(split_seed).permutation(unique_keys)
    split_idx = int(round(train_frac * len(shuffled)))
    train_keys = {str(k) for k in shuffled[:split_idx]}
    for row in rows:
        row["split"] = "train" if row["seed_or_layout"] in train_keys else "evaluate"
    return unique_keys


def inspect_minari_folder(
    dataset_dir: str = "",
    out_dir: str = "viz/minari",
    limit_visuals: int = 0,
    scale: int = 4,
    train_frac: float = 1.0,
    split_seed: int = 123,
    split_csv_name: Optional[str] = None,
    game_code: Optional[str] = None,
    cfg_dataset_id: Optional[str] = None,
) -> str:
    """Scan a Minari dataset, save preview PNGs, and write a per-episode split CSV.

    Each episode group with image observations (see ``find_episode_groups``)
    gets a key: its recorded seed, or else an MD5 fingerprint of its first
    frame. The first ``limit_visuals`` distinct keys get a PNG
    ``<out_dir>/<env_tag>_<key>.png``. For seeded episodes the environment is
    re-created and rendered; otherwise, or if every render attempt fails, the
    stored first frame is saved, upscaled by ``scale``. Distinct keys are
    split into "train"/"evaluate" (see ``_assign_splits``), so all episodes
    sharing a seed/layout land in the same split.

    Args:
        dataset_dir: Minari dataset folder (containing ``data/main_data.hdf5``).
            If empty, it is resolved from ``cfg_dataset_id``.
        out_dir: directory for the PNGs and the CSV.
        limit_visuals: maximum number of distinct keys to save a PNG for.
        scale: upscale factor for saved raw observation frames.
        train_frac: fraction of distinct keys assigned to "train", in [0, 1].
        split_seed: seed for the reproducible key permutation.
        split_csv_name: CSV file name. Defaults to
            ``<game_code>_episode_splits.csv``, or ``episode_splits.csv``
            without a game code.
        game_code: prefix for the default CSV name.
        cfg_dataset_id: Minari dataset id used when ``dataset_dir`` is empty.

    Returns:
        Path of the written CSV. Its columns are episode_idx, split,
        seed_or_layout and mission; ``episode_idx`` is N from the group name
        ``episode_N`` when present.

    Raises:
        ValueError: ``train_frac`` is outside [0, 1], or neither
            ``dataset_dir`` nor ``cfg_dataset_id`` is given.
        FileNotFoundError: ``<dataset_dir>/data/main_data.hdf5`` is missing.
    """
    # `not (a <= x <= b)` also rejects NaN.
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")
    if not dataset_dir:
        if not cfg_dataset_id:
            raise ValueError("Either dataset_dir must be provided or cfg_dataset_id must be set.")
        dataset_dir = resolve_minari_dir_from_id(cfg_dataset_id)

    h5_path = os.path.join(dataset_dir, "data", "main_data.hdf5")
    meta_path = os.path.join(dataset_dir, "data", "metadata.json")
    if not os.path.isfile(h5_path):
        raise FileNotFoundError(f"Couldn't find hdf5 at {h5_path}")

    raw_env_id = read_env_id_from_metadata(meta_path) if os.path.isfile(meta_path) else None
    env_tag = safe_env_tag(raw_env_id)
    gym_id = ensure_versioned_id(canonical_env_id(raw_env_id))
    minari_id_from_dir = build_minari_dataset_id_from_dir(dataset_dir, raw_env_id)

    ensure_dir(out_dir)
    ep_rows: List[Dict[str, Any]] = []
    unique_key_to_png: Dict[str, str] = {}

    with h5py.File(h5_path, "r") as h5:
        for idx, g in enumerate(find_episode_groups(h5)):
            try:
                frame = first_frame_from_group(g)
            except Exception:
                continue  # no readable first frame: the episode is left out of the CSV

            seed = get_episode_seed(g)
            key = str(seed) if seed is not None else layout_hash(frame)

            if key not in unique_key_to_png and len(unique_key_to_png) < limit_visuals:
                png_path = os.path.join(out_dir, f"{env_tag}_{key}.png")
                rendered = seed is not None and _render_seed_png(
                    png_path, seed, dataset_dir, minari_id_from_dir, gym_id
                )
                if not rendered:
                    save_png(frame, png_path, scale=scale)
                unique_key_to_png[key] = png_path

            ep_rows.append({
                # Use the Minari episode id, not the position in the scan.
                "episode_idx": _episode_index(g.name, idx),
                "seed_or_layout": key,
                "mission": get_episode_mission(g),
            })

    ep_rows.sort(key=lambda row: row["episode_idx"])
    unique_keys = _assign_splits(ep_rows, train_frac, split_seed)

    if split_csv_name is None:
        split_csv_name = f"{game_code}_episode_splits.csv" if game_code else "episode_splits.csv"

    split_csv_path = os.path.join(out_dir, split_csv_name)
    with open(split_csv_path, "w", newline="", encoding="utf-8") as f:
        fieldnames = ["episode_idx", "split", "seed_or_layout", "mission"]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(ep_rows)

    num_train = sum(1 for row in ep_rows if row["split"] == "train")
    num_eval = len(ep_rows) - num_train
    print(f"Consolidated Split CSV: {split_csv_path}")
    print(f"  Total episodes: {len(ep_rows)} (train={num_train}, evaluate={num_eval})")
    print(f"  Total unique keys: {len(unique_keys)}")

    return split_csv_path
    
if __name__ == "__main__":
    import argparse

    # CLI: load <config_dir>/<game_code>.yaml, take the dataset from
    # --dataset_dir (or the config's data.dataset_id), and write the split CSV.
    p = argparse.ArgumentParser(description="Inspect a Minari dataset and write a per-episode train/evaluate split CSV.")
    p.add_argument("--dataset_dir", default="", help="Folder containing data/main_data.hdf5 and data/metadata.json. If empty, we resolve from the YAML dataset_id.")
    p.add_argument("--config_dir", default="configs", help="Directory containing YAML config files")
    p.add_argument("--game_code", required=True, help="Game code (e.g., 'bosslevel')")
    p.add_argument("--out_dir", default="seeds", help="Directory for preview PNGs and the split CSV.")
    p.add_argument("--limit_visuals", type=int, default=30, help="Maximum number of distinct seeds/layouts to save a preview PNG for.")
    p.add_argument("--scale", type=int, default=4, help="Upscale factor applied when saving a raw observation frame.")
    p.add_argument("--train_frac", type=float, default=1.0, help="Fraction of unique keys that go to the train split.")
    p.add_argument("--split_seed", type=int, default=123, help="Random seed for reproducible splitting.")
    p.add_argument("--split_csv_name", type=str, default=None, help="Optional custom filename for the split CSV.")
    args = p.parse_args()

    cfg_path = os.path.join(args.config_dir, f"{args.game_code}.yaml")
    if not os.path.isfile(cfg_path):
        raise FileNotFoundError(f"Config file not found: {cfg_path}")
    cfg = OmegaConf.load(cfg_path)

    # `select` yields None for a missing key instead of raising, so a config
    # without data.dataset_id still works when --dataset_dir is given;
    # inspect_minari_folder raises a clear ValueError if neither is available.
    cfg_dataset_id = OmegaConf.select(cfg, "data.dataset_id")

    inspect_minari_folder(
        dataset_dir=args.dataset_dir,
        out_dir=args.out_dir,
        limit_visuals=args.limit_visuals,
        scale=args.scale,
        train_frac=args.train_frac,
        split_seed=args.split_seed,
        split_csv_name=args.split_csv_name,
        game_code=args.game_code,
        cfg_dataset_id=cfg_dataset_id,
    )