#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
coco_caption_val2014_eval.py
----------------------------
Generate captions for MS-COCO val2014 with UnifiedPipeline and
compute the CIDEr score of the generated captions directly.

Example
-------
python coco_caption_val2014_eval.py \
       --ann      /data/coco2014/annotations/captions_val2014.json \
       --img-dir  /data/coco2014/val2014 \
       --save-pred ./predictions.json \
       --bs 4 --device cuda
"""
import argparse, json, math, re, string
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from src.pipeline import UnifiedPipeline
from src.transformer import SymmetricTransformer2DModel
from src.scheduler import Scheduler
from train.trainer_utils import load_images_to_tensor
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import VQModel

def build_pipe(model_path, ckpt_path, device):
    """Assemble a UnifiedPipeline from pretrained parts and move it to *device*.

    Every component is loaded from ``model_path`` except the transformer,
    which is loaded from ``ckpt_path`` when that is non-empty and from
    ``model_path`` otherwise.
    """
    transformer_src = ckpt_path or model_path
    parts = dict(
        tokenizer=CLIPTokenizer.from_pretrained(
            model_path, subfolder="tokenizer"),
        text_encoder=CLIPTextModelWithProjection.from_pretrained(
            model_path, subfolder="text_encoder"),
        vqvae=VQModel.from_pretrained(model_path, subfolder="vqvae"),
        transformer=SymmetricTransformer2DModel.from_pretrained(
            transformer_src, subfolder="transformer"),
        scheduler=Scheduler.from_pretrained(model_path, subfolder="scheduler"),
    )
    pipeline = UnifiedPipeline(**parts)
    return pipeline.to(device)

try:
    from pycocoevalcap.cider.cider import Cider
    USE_COCOCAP = True
    print("[INFO] use official pycocoevalcap")
except ModuleNotFoundError:
    USE_COCOCAP = False
    print("[WARN] pycocoevalcap not found → use NumPy fallback")


# --- NumPy/stdlib CIDEr-D fallback -----------------------------------------
# Re-implementation of pycocoevalcap's ``CiderScorer`` (CIDEr-D: TF-IDF
# cosine similarity with clipping and a Gaussian length penalty, sigma=6).
# Defined unconditionally so it can also be used or tested when the official
# package is installed; ``compute_cider`` only calls it when it is not.

def _ngram_counts(sentence, n=4):
    """Count every 1..n-gram of a whitespace-tokenised sentence."""
    words = sentence.split()
    counts = Counter()
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return counts


def _tfidf(counts, doc_freq, log_num_docs, n=4):
    """Convert n-gram counts into per-order TF-IDF vectors.

    Returns ``(vecs, norms, length)``:
    - ``vecs[k]`` maps (k+1)-grams to tf * idf;
    - ``norms[k]`` is the L2 norm of ``vecs[k]``;
    - ``length`` is the number of bigrams, used by the length penalty
      exactly as in pycocoevalcap.

    An n-gram absent from the reference corpus gets df = 1, i.e. the maximal
    IDF, so it still counts toward the hypothesis norm.
    """
    vecs = [{} for _ in range(n)]
    sq_norms = [0.0] * n
    length = 0
    for ngram, tf in counts.items():
        order = len(ngram) - 1
        idf = log_num_docs - math.log(max(1.0, doc_freq.get(ngram, 0)))
        weight = float(tf) * idf
        vecs[order][ngram] = weight
        sq_norms[order] += weight * weight
        if order == 1:
            length += tf
    return vecs, [math.sqrt(s) for s in sq_norms], length


def _similarity(hyp, ref, n=4, sigma=6.0):
    """Per-order clipped cosine similarity times a Gaussian length penalty."""
    vec_h, norm_h, len_h = hyp
    vec_r, norm_r, len_r = ref
    penalty = math.exp(-((len_h - len_r) ** 2) / (2 * sigma ** 2))
    sims = np.zeros(n)
    for k in range(n):
        ref_k = vec_r[k]
        # Clipping: a hypothesis cannot earn more weight than the reference has.
        val = sum(min(w, ref_k.get(g, 0.0)) * ref_k.get(g, 0.0)
                  for g, w in vec_h[k].items())
        if norm_h[k] != 0 and norm_r[k] != 0:
            val /= norm_h[k] * norm_r[k]
        sims[k] = val * penalty
    return sims


def cider_np(hyps, refs, n=4, sigma=6.0):
    """Corpus CIDEr-D of ``hyps`` against ``refs`` (mean over images).

    Args:
        hyps: ``{image_id: caption}``; every id must also be a key of ``refs``.
        refs: ``{image_id: [reference captions]}``.
        n: maximum n-gram order (4 in the official scorer).
        sigma: standard deviation of the Gaussian length penalty (6.0).

    Returns:
        The mean per-image score on the same scale as pycocoevalcap's
        ``Cider().compute_score`` (x10 included). Returns 0.0 for no images.

    Captions are split on whitespace only; normalise them beforehand.
    """
    ids = list(hyps)
    if not ids:
        return 0.0
    ref_counts = {i: [_ngram_counts(r, n) for r in refs[i]] for i in ids}

    # Document frequency counts *images* whose references contain the
    # n-gram, not individual references.
    doc_freq = Counter()
    for i in ids:
        doc_freq.update({g for counts in ref_counts[i] for g in counts})
    log_num_docs = math.log(float(len(ids)))

    scores = []
    for i in ids:
        hyp_vec = _tfidf(_ngram_counts(hyps[i], n), doc_freq, log_num_docs, n)
        total = np.zeros(n)
        for counts in ref_counts[i]:
            ref_vec = _tfidf(counts, doc_freq, log_num_docs, n)
            total += _similarity(hyp_vec, ref_vec, n, sigma)
        # Average over n-gram orders and references, then apply the x10 scale.
        scores.append(np.mean(total) / max(len(ref_counts[i]), 1) * 10.0)
    return float(np.mean(scores))

# Maps every ASCII punctuation character to a space (built once).
_PUNCT_TO_SPACE = str.maketrans({c: " " for c in string.punctuation})


def normalize_caption(caption):
    """Lower-case, replace ASCII punctuation with spaces, collapse whitespace.

    A lightweight stand-in for the PTBTokenizer step of the official COCO
    caption evaluation. Both CIDEr implementations split captions on
    whitespace only, so without this "A dog." and "a dog" share no n-grams.
    """
    return " ".join(caption.lower().translate(_PUNCT_TO_SPACE).split())


def compute_cider(refs, hyps, normalize=True):
    """Corpus CIDEr of ``hyps`` against ``refs`` over the ids present in both.

    Args:
        refs: ``{image_id: [reference captions]}``.
        hyps: ``{image_id: generated caption}``.
        normalize: apply ``normalize_caption`` to every caption first.

    Returns:
        A float: from pycocoevalcap when it is installed, otherwise from the
        NumPy fallback ``cider_np``. Returns 0.0 when no ids overlap.

    The inputs are not modified.
    """
    common = sorted(refs.keys() & hyps.keys())
    print(f"Scoring {len(common):,} images …")
    if not common:
        print("[WARN] no image id has both a prediction and references")
        return 0.0

    prep = normalize_caption if normalize else (lambda s: s)
    gts = {i: [prep(r) for r in refs[i]] for i in common}
    res = {i: prep(hyps[i]) for i in common}

    if USE_COCOCAP:
        # The official scorer expects a single-element list per hypothesis.
        score, _ = Cider().compute_score(gts, {i: [c] for i, c in res.items()})
        return float(score)
    return cider_np(res, gts)


@torch.no_grad()
def main(cfg):
    """Caption the val2014 images, optionally save predictions, print CIDEr.

    Reads these attributes of ``cfg``: ``ann`` (COCO captions JSON),
    ``img_dir``, ``model``, ``ckpt``, ``device``, ``bs``, ``res``, ``cfg``,
    ``steps``, ``save_pred``, and optionally ``limit``. ``limit`` caps the
    number of images, taken in ascending id order; when it is absent or None,
    every image is evaluated.
    """
    with open(cfg.ann, encoding="utf-8") as f:
        coco = json.load(f)
    id2file = {img["id"]: img["file_name"] for img in coco["images"]}
    img_ids = sorted(id2file)
    print(f"Total images: {len(img_ids):,}")

    # Opt-in subset for quick smoke tests; getattr keeps programmatic callers
    # whose cfg has no ``limit`` attribute working.
    limit = getattr(cfg, "limit", None)
    if limit is not None:
        img_ids = img_ids[:max(0, limit)]
        print(f"Evaluating the first {len(img_ids):,} images (--limit)")

    refs = defaultdict(list)
    for ann in coco["annotations"]:
        refs[ann["image_id"]].append(ann["caption"])

    pipe = build_pipe(cfg.model, cfg.ckpt, cfg.device)
    hyps = {}
    img_dir = Path(cfg.img_dir)

    batch = []
    for img_id in tqdm(img_ids, desc="Generate"):
        batch.append((img_id, img_dir / id2file[img_id]))
        if len(batch) == cfg.bs:
            gen_batch(batch, pipe, cfg, hyps)
            batch.clear()
    if batch:  # trailing partial batch
        gen_batch(batch, pipe, cfg, hyps)

    if cfg.save_pred:
        with open(cfg.save_pred, "w", encoding="utf-8") as f:
            json.dump([{"image_id": int(i), "caption": c} for i, c in hyps.items()],
                      f, indent=2)
        print(f"[✓] predictions saved to {cfg.save_pred}")

    cider = compute_cider(refs, hyps)
    print(f"\nCIDEr: {cider*100:.2f}")


def gen_batch(batch, pipe, cfg, hyps):
    """Caption one batch of ``(image_id, path)`` pairs into ``hyps`` in place.

    All images get the same fixed prompt. Captions are read from
    ``outs.prompts``. NOTE(review): this assumes one string per input, in
    input order — confirm against UnifiedPipeline. An id already in ``hyps``
    is not overwritten.
    """
    paths = [p for _, p in batch]
    imgs  = load_images_to_tensor(paths,
                                  target_size=(cfg.res, cfg.res)).to(cfg.device)
    outs = pipe(prompt=["Describe this image."]*len(batch),
                image=imgs, height=cfg.res, width=cfg.res,
                guidance_scale=cfg.cfg, num_inference_steps=cfg.steps)
    for (img_id, _), cap in zip(batch, outs.prompts):
        if img_id not in hyps:
            hyps[img_id] = cap

# ========== CLI ==========
if __name__ == "__main__":
    pa = argparse.ArgumentParser()
    pa.add_argument("--ann",      required=True, help="captions_val2014.json")
    pa.add_argument("--img-dir",  required=True, help="COCO val2014")
    pa.add_argument("--model",    default="MeissonFlow/Meissonic")
    pa.add_argument("--ckpt",     default="")
    pa.add_argument("--device",   default="cuda")
    pa.add_argument("--bs",   type=int, default=4)
    pa.add_argument("--res",  type=int, default=512)
    pa.add_argument("--steps",type=int, default=64)
    pa.add_argument("--cfg",  type=float, default=9)
    pa.add_argument("--save-pred", default=None,
                    help=" predictions.json")
    pa.add_argument("--limit", type=int, default=None,
                    help="only caption the first N images (by id); default: all")
    main(pa.parse_args())
