# tools/metrics.py
import os, re, argparse, ast
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import torch

# Optional LPIPS (AlexNet) backend. This is deliberately best-effort: if
# importing the package or building the model fails for any reason, the script
# keeps running and LPIPS values are reported as NaN (use_lpips=False).
# `device` is defined unconditionally so the name always exists at module level.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    import lpips
    lpips_fn = lpips.LPIPS(net='alex').to(device).eval()
    use_lpips = True
except Exception as _lpips_err:  # best-effort: never abort metrics over LPIPS
    # keep original message style
    print("⚠️ LPIPS fail to install LPIPS")
    # Surface the real cause: the failure may come from model construction,
    # not only from a missing package.
    print(f"   reason: {_lpips_err!r}; LPIPS values will be written as NaN")
    lpips_fn, use_lpips = None, False


def parse_exps_arg(exps_arg: str):
    """Parse the ``--exps`` CLI value into a list of experiment names.

    Accepted forms:
      * comma-separated:          ``a,b,c``
      * Python list literal:      ``["a","b"]`` / ``['a','b']``
      * unquoted bracketed list:  ``[a, b]`` (not a valid literal, handled
        by stripping the brackets and any quotes around each item)

    Names are whitespace-stripped and empty names are dropped in every form.

    Returns:
        list of str experiment names (possibly empty).
    """
    s = exps_arg.strip()
    if s.startswith('[') and s.endswith(']'):
        try:
            val = ast.literal_eval(s)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a valid Python literal (e.g. "[a, b]"); parse it manually below.
            val = None
        if isinstance(val, (list, tuple)):
            stripped = (str(x).strip() for x in val)
            return [name for name in stripped if name]
        # Drop the outer brackets so they don't stick to the first/last name.
        s = s[1:-1]
    # Comma-separated fallback; also removes quotes left by half-quoted lists.
    names = (e.strip().strip('\'"').strip() for e in s.split(','))
    return [name for name in names if name]


def to_lpips_tensor(img_np01):
    """Convert an (H, W, C) image array into a (1, C, H, W) float32 tensor.

    The input is presumably scaled to [0, 1] (callers divide by 255); values
    are mapped linearly via ``x * 2 - 1``, so [0, 1] becomes [-1, 1].
    """
    channels_first = torch.from_numpy(img_np01).permute(2, 0, 1)
    batched = channels_first[None].float()
    return batched * 2.0 - 1.0

def safe_ssim(a01, b01):
    """SSIM between two channel-last images with data range 1.0.

    Tries the ``channel_axis=2`` keyword first. If that raises ``TypeError``
    (e.g. an older scikit-image without that keyword), it retries with the
    legacy ``multichannel=True`` flag.
    """
    shared_kwargs = {"data_range": 1.0}
    try:
        score = ssim(a01, b01, channel_axis=2, **shared_kwargs)
    except TypeError:
        score = ssim(a01, b01, multichannel=True, **shared_kwargs)
    return score


def _collect_pairs(future_dir, cam, fids):
    """Return ``(fid, render_path, gt_path)`` for every frame id where both files exist.

    Missing files are reported on stdout and their frame is skipped. Output
    order follows ``fids``.
    """
    pairs = []
    for fid in fids:
        rpath = os.path.join(future_dir, f"future_cam{cam}_frame_{fid:03d}.png")
        gpath = os.path.join(future_dir, f"gt_cam{cam}_frame_{fid:03d}.png")
        ok = True
        if not os.path.exists(rpath):
            print(f"❌ no: {rpath}")
            ok = False
        if not os.path.exists(gpath):
            print(f"❌ no:  {gpath}")
            ok = False
        if ok:
            pairs.append((fid, rpath, gpath))
    return pairs


def _evaluate_pair(rpath, gpath):
    """Score one rendered frame against its ground truth.

    Returns:
        (psnr, ssim, lpips, render_img). ``lpips`` is NaN when the LPIPS
        backend is unavailable (``use_lpips`` is False). ``render_img`` is the
        RGB render, resized to the ground-truth size if the sizes differed.
    """
    # Context managers close the source files; convert() returns a new image
    # that stays usable after the files are closed.
    with Image.open(rpath) as r_src, Image.open(gpath) as g_src:
        render_img = r_src.convert("RGB")
        gt_img = g_src.convert("RGB")
    if render_img.size != gt_img.size:
        # Metrics need pixel-aligned arrays: resample the render onto the GT grid.
        render_img = render_img.resize(gt_img.size, Image.BILINEAR)

    r_np = np.array(render_img).astype(np.float32) / 255.0
    g_np = np.array(gt_img).astype(np.float32) / 255.0

    p = psnr(g_np, r_np, data_range=1.0)
    s = safe_ssim(g_np, r_np)
    if use_lpips:
        with torch.no_grad():
            r_t = to_lpips_tensor(r_np).to(device)
            g_t = to_lpips_tensor(g_np).to(device)
            l = lpips_fn(g_t, r_t).item()
    else:
        l = np.nan
    return p, s, l, render_img


def _group_stats(rows, fid_range=None):
    """Mean PSNR/SSIM/LPIPS over ``rows`` (``(fid, psnr, ssim, lpips)`` tuples).

    Only rows whose fid is in ``fid_range`` are used (all rows if None).
    Returns None when no row qualifies. The LPIPS mean ignores NaN entries and
    is NaN only if every entry is NaN.
    """
    vals = [(p, s, l) for (fid, p, s, l) in rows
            if fid_range is None or fid in fid_range]
    if not vals:
        return None
    ps = np.array([x[0] for x in vals], float)
    ss = np.array([x[1] for x in vals], float)
    ls = np.array([x[2] for x in vals], float)
    return {
        "PSNR_mean": float(np.mean(ps)),
        "SSIM_mean": float(np.mean(ss)),
        "LPIPS_mean": float(np.nanmean(ls)) if not np.isnan(ls).all() else np.nan,
    }


def _fmt_lpips(value):
    """Format an LPIPS value with 6 decimals, or ``NaN`` when it is missing."""
    return "NaN" if np.isnan(value) else f"{value:.6f}"


def _write_metrics(metrics_txt, future_dir, rows, first30_range):
    """Write per-frame metrics and group averages for one experiment to ``metrics_txt``."""
    gname = "future"
    with open(metrics_txt, "w", encoding="utf-8") as f:
        f.write(f"# Future dir : {future_dir}\n")
        # Header matches the rows written below (there is no group column).
        f.write("# Columns: frame_id, PSNR(dB), SSIM, LPIPS(alex)\n")

        for (fid, p, s, l) in sorted(rows, key=lambda r: (r[0], r[1])):
            f.write(f"{fid:06d}, {p:.6f}, {s:.6f}, {_fmt_lpips(l)}\n")

        f.write("\n# Group Averages\n")
        st_all = _group_stats(rows)
        if st_all is None:
            f.write(f"{gname}: no data\n")
        else:
            f.write(f"{gname} (all): PSNR_mean={st_all['PSNR_mean']:.6f}, "
                    f"SSIM_mean={st_all['SSIM_mean']:.6f}, "
                    f"LPIPS_mean={_fmt_lpips(st_all['LPIPS_mean'])}\n")

        st_30 = _group_stats(rows, fid_range=first30_range)
        if st_30 is not None:
            f.write(f"{gname} (first30): PSNR_mean={st_30['PSNR_mean']:.6f}, "
                    f"SSIM_mean={st_30['SSIM_mean']:.6f}, "
                    f"LPIPS_mean={_fmt_lpips(st_30['LPIPS_mean'])}\n")


def main():
    """CLI entry point: evaluate future-frame renders for each experiment.

    For every experiment ``EXP`` it reads ``<log_dir>/EXP_future/training_render``
    and writes ``<log_dir>/EXP_future_only.txt`` (metrics) and
    ``<log_dir>/EXP_future_future.gif`` (1/4-scale renders). The camera id and
    frame window default to the values previously hard-coded (``00``, 240–269).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_dir", required=True, help="root logs directory")
    parser.add_argument("--exps", required=True, help="experiment names, e.g. a,b or ['a','b']")
    parser.add_argument("--cam", default="00",
                        help="camera id used in frame file names (default: 00)")
    parser.add_argument("--fid_start", type=int, default=240,
                        help="first future frame id, inclusive (default: 240)")
    parser.add_argument("--fid_end", type=int, default=270,
                        help="end of future frame ids, exclusive (default: 270)")
    args = parser.parse_args()
    if args.fid_end <= args.fid_start:
        parser.error("--fid_end must be greater than --fid_start")

    root = args.log_dir
    exps = parse_exps_arg(args.exps)
    fid_future_range = range(args.fid_start, args.fid_end)
    # "first30" summary: the first 30 frame ids of the future window.
    first30_range = range(args.fid_start, args.fid_start + 30)

    exp_dict = {exp: [args.cam] for exp in exps}
    print(exp_dict)

    for exp, cams in exp_dict.items():
        cam_future = cams[0]
        future_exp = f"{exp}_future"
        future_dir = f"{root}/{future_exp}/training_render"
        metrics_txt = f"{root}/{exp}_future_only.txt"
        output_gif_fut = f"{root}/{future_exp}_future.gif"

        os.makedirs(os.path.dirname(metrics_txt), exist_ok=True)

        per_frame_rows = []  # (fid, psnr, ssim, lpips)
        gif_fut = []
        for fid, rpath, gpath in _collect_pairs(future_dir, cam_future, fid_future_range):
            p, s, l, render_img = _evaluate_pair(rpath, gpath)
            per_frame_rows.append((fid, p, s, l))

            # Quarter-resolution copy keeps the preview GIF small.
            w, h = render_img.size
            gif_fut.append(render_img.resize((max(1, w // 4), max(1, h // 4)), Image.BILINEAR))

        if gif_fut:
            gif_fut[0].save(output_gif_fut, save_all=True, append_images=gif_fut[1:],
                            duration=100, loop=0)
            print(f"✅ GIF saved to: {output_gif_fut}")

        _write_metrics(metrics_txt, future_dir, per_frame_rows, first30_range)
        print(f"✅ metrics written in: {metrics_txt}")


if __name__ == "__main__":
    main()