import os
import json
import numpy as np
from tqdm import tqdm

def summarize_depth_stat(depth_stat):
    """Collapse every entry of ``depth_stat`` into its min, max and element count.

    Values may be NumPy arrays (reduced with ``ndarray.min``/``ndarray.max`` and
    counted with ``.size``, so multi-dimensional arrays are fully flattened) or
    any other sized sequence of numbers (reduced with the builtin ``min``/``max``
    and counted with ``len``). Entries with zero elements are left out entirely.

    Returns:
        A tuple ``(mins, maxs, counts)`` of dicts keyed like ``depth_stat``:
        ``mins``/``maxs`` hold Python ``float`` values, ``counts`` holds ``int``.
    """
    mins, maxs, counts = {}, {}, {}
    for key, values in depth_stat.items():
        is_array = isinstance(values, np.ndarray)
        count = values.size if is_array else len(values)
        if not count:
            # Nothing to summarize; the key is omitted from all three dicts.
            continue
        if is_array:
            lo, hi = values.min(), values.max()
        else:
            # Builtins avoid converting plain lists into arrays.
            lo, hi = min(values), max(values)
        mins[key] = float(lo)
        maxs[key] = float(hi)
        counts[key] = int(count)
    return mins, maxs, counts

ref_dir = '../bellman_explore/depth_stat/all_grids'

ref_files = sorted(os.listdir(ref_dir))
thresholds = (3.0, 5.0, 10.0)


def _file_coverage(ref_depth_stat, eval_depth_stat, thresholds):
    """Compute per-threshold depth-range coverage of eval vs. ref for one file.

    For each key with a non-empty ref entry, the ref range is
    ``[ref_min, min(ref_max, thres)]``. A key is *valid* when its ref entry has
    at least 3 samples and that clipped range has positive length. Its coverage
    is the length of the overlap with the eval range ``[eval_min, eval_max]``
    divided by the clipped ref length. The coverage is 0 when the key is absent
    or empty in eval. Eval-only keys are ignored.

    Returns:
        ``None`` when no ref key is non-empty. Otherwise returns
        ``(total_num, results)``, where ``total_num`` is the number of non-empty
        ref keys and ``results`` is a list of ``(thres, valid_num, mean_percent)``
        in ``thresholds`` order. ``mean_percent`` is NaN when no key is valid.
    """
    ref_min_d, ref_max_d, ref_len_d = summarize_depth_stat(ref_depth_stat)
    eval_min_d, eval_max_d, _ = summarize_depth_stat(eval_depth_stat)

    # Align over ref keys only.
    keys = list(ref_min_d)
    n = len(keys)
    if n == 0:
        return None

    ref_min = np.fromiter((ref_min_d[k] for k in keys), dtype=np.float64, count=n)
    ref_max = np.fromiter((ref_max_d[k] for k in keys), dtype=np.float64, count=n)
    ref_len = np.fromiter((ref_len_d[k] for k in keys), dtype=np.int32, count=n)

    # Keys missing from eval get NaN placeholders. They are never read,
    # because every use is masked by `present`.
    present = np.fromiter((k in eval_min_d for k in keys), dtype=bool, count=n)
    eval_min = np.fromiter((eval_min_d.get(k, np.nan) for k in keys),
                           dtype=np.float64, count=n)
    eval_max = np.fromiter((eval_max_d.get(k, np.nan) for k in keys),
                           dtype=np.float64, count=n)

    results = []
    for thres in thresholds:
        # Only the ref range is clipped; the overlap computation clips eval implicitly.
        ref_max_clip = np.minimum(ref_max, thres)
        denom = ref_max_clip - ref_min
        valid = (ref_len >= 3) & (denom > 0)

        # Overlap length. It stays 0 for keys that are invalid or absent in eval.
        numer = np.zeros(n, dtype=np.float64)
        vp = valid & present
        if vp.any():
            inter_hi = np.minimum(ref_max_clip[vp], eval_max[vp])
            inter_lo = np.maximum(ref_min[vp], eval_min[vp])
            numer[vp] = np.maximum(0.0, inter_hi - inter_lo)

        valid_num = int(valid.sum())
        if valid_num:
            mean_percent = (numer[valid] / denom[valid]).mean()
        else:
            mean_percent = np.nan
        results.append((thres, valid_num, mean_percent))
    return n, results


eval_methods = ['DP', 'DSS', 'FME', 'Greedy', 'human']
for method in eval_methods:
    print(f'==== Evaluating method: {method} ====')
    eval_dir = f'../bellman_explore/depth_stat/{method}'
    all_res = {str(t): {} for t in thresholds}

    for ref_file in tqdm(ref_files, desc="Files"):
        ref_file_path = os.path.join(ref_dir, ref_file)
        eval_file_path = os.path.join(eval_dir, ref_file)
        if not os.path.exists(eval_file_path):
            print(f"Warning: eval file {eval_file_path} does not exist, skipping.")
            continue

        # NOTE: allow_pickle=True lets unpickling run arbitrary code; only
        # point these directories at trusted files.
        ref_depth_stat = np.load(ref_file_path, allow_pickle=True).item()
        eval_depth_stat = np.load(eval_file_path, allow_pickle=True).item()

        coverage = _file_coverage(ref_depth_stat, eval_depth_stat, thresholds)
        if coverage is None:
            # No non-empty ref keys: record NaN without printing per-threshold lines.
            for thres in thresholds:
                all_res[str(thres)][ref_file] = np.nan
            continue

        total_num, results = coverage
        for thres, valid_num, mean_percent in results:
            print(f"{ref_file}, thres = {thres}, total_num = {total_num}, "
                  f"valid_num = {valid_num}, mean percent = {mean_percent}")
            all_res[str(thres)][ref_file] = float(mean_percent)

    print('================ Summary ================')
    for thres in map(str, thresholds):
        vals = np.asarray(list(all_res[thres].values()), dtype=np.float64)
        # Guard the empty or all-NaN case explicitly. In that case np.nanmean
        # would only emit a "Mean of empty slice" RuntimeWarning and return NaN.
        mean_percent = np.nanmean(vals) if np.any(~np.isnan(vals)) else np.nan
        print(f"thres = {thres}, mean percent = {mean_percent}")

    # Files without a valid key are stored as NaN. json.dump writes these as
    # the bare token `NaN`, which Python's json module reads back but strict
    # JSON parsers reject.
    with open(f'eval_depth_{method}.json', 'w') as f:
        json.dump(all_res, f, indent=4)
