# inference_seg_7b_coco_fused.py
# -*- coding: utf-8 -*-

import os
import sys
import glob
import json
from pathlib import Path
from datetime import datetime
import argparse

import numpy as np
from PIL import Image

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import einops

# project modules
from inference import LocalInferenceModel
from utils import encode_transform  # no need for codebook_similarity here

# silence unrelated TF / TF-TRT logs
# NOTE(review): this statement runs after the imports above; if any of them pulls in
# TensorFlow at import time, the variable may be set too late to take effect — confirm.
# `setdefault` keeps a value that is already present in the environment.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")


# ---------------- JS-divergence fusion helpers ----------------
def calculate_js_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """
    Jensen–Shannon divergence between two probability distributions.

    JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = 0.5 * (p + q).

    p, q: [..., C] with sum=1 on the last dim (floating point; float32 or wider
          is recommended because the 1e-12 clamp underflows in half precision).
    return: [...] (per-element distance, natural-log units, in [0, ln 2]).

    The KL terms are written out explicitly. `F.kl_div(input, target)` computes
    KL(target || exp(input)), so passing `p.log()` as input and `m` as target
    yields KL(m || p), which is not the term the JS divergence needs.
    """
    eps = 1e-12
    m = 0.5 * (p + q)
    log_m = m.clamp_min(eps).log()
    # Where p (or q) is exactly 0 the product is 0 * finite = 0, matching the
    # 0 * log 0 = 0 convention.
    kl_pm = (p * (p.clamp_min(eps).log() - log_m)).sum(-1)
    kl_qm = (q * (q.clamp_min(eps).log() - log_m)).sum(-1)
    # Rounding can push the result marginally below zero; a distance must not be negative.
    return (0.5 * (kl_pm + kl_qm)).clamp_min(0.0)


def get_weights(distances: torch.Tensor, method='softmax', temp=1.0, sigma=1.0) -> torch.Tensor:
    """
    Turn distances (small=good) into normalized weights.

    distances: [K] vector, or [..., K] batch; weights are normalized along the
               last dim, so each trailing vector is handled independently.
    method:    one of
               'softmax'       softmax(-d / temp)
               'gaussian'      exp(-d^2 / (2 * sigma^2))
               'inv'           1 / (d + 1e-8)
               'epanechnikov'  max(1 - (d / sigma)^2, 0)
               'triangular'    max(1 - d / sigma, 0)
               'uniform'       all ones
    temp:      temperature, only used by 'softmax'.
    sigma:     bandwidth, used by 'gaussian', 'epanechnikov' and 'triangular'.
    return:    tensor shaped like `distances`, summing to ~1 on the last dim.
               For the compact kernels ('epanechnikov', 'triangular') every
               weight is 0 when all distances are >= sigma; the result is then
               all zeros rather than a valid distribution.
    raises:    ValueError for an unknown `method`.
    """
    if method == 'softmax':
        w = torch.softmax(-distances / temp, dim=-1)
    elif method == 'gaussian':
        w = torch.exp(-distances**2 / (2 * sigma**2))
    elif method == 'inv':
        w = 1.0 / (distances + 1e-8)
    elif method == 'epanechnikov':
        w = torch.clamp(1 - (distances / sigma) ** 2, min=0.0)
    elif method == 'triangular':
        w = torch.clamp(1 - distances / sigma, min=0.0)
    elif method == 'uniform':
        w = torch.ones_like(distances)
    else:
        raise ValueError(f"Unknown method: {method}")
    # Normalize over the same axis the softmax branch uses, so that batched
    # input is not normalized across unrelated rows. For a 1-D vector this is
    # identical to dividing by the global sum.
    return w / (w.sum(dim=-1, keepdim=True) + 1e-8)


def fuse_patchwise_probs_js(prob_list, temp=1.0, alpha=0.5, weight_method='softmax'):
    """
    Fuse K sequences' codebook probabilities patch-wise with JS-divergence weights.

    prob_list:     List[Tensor [P, C]] — each is a per-patch code distribution
                   (sum=1 on the last dim); the caller in this file passes
                   [256, 8192] tensors.
    temp:          if != 1.0, every input distribution is first re-tempered as
                   softmax(log(p) / temp); the same value is also forwarded to
                   `get_weights` as the softmax temperature.
                   NOTE(review): confirm this double use of `temp` is intended.
    alpha:         mixing strength of the neighbors (0 keeps only the base).
    weight_method: forwarded to `get_weights` (called with sigma=0.5).
    return:        Tensor [P, C] float32 on the inputs' device — fused
                   probabilities. With a single input, that tensor is returned
                   unchanged (no dtype conversion, no tempering).

    Strategy:
      For each patch:
        - choose the base distribution with the lowest entropy
        - compute JS distance from the base to every other sequence
        - weights = get_weights(distances)
        - fused = (1-alpha)*base + alpha*sum(w_i * neighbor_i)
        - renormalize; if the mix has no mass (e.g. alpha=1 with a compact
          kernel whose weights are all zero) fall back to the base distribution
          instead of producing NaNs.
    """
    assert len(prob_list) >= 1
    if len(prob_list) == 1:
        return prob_list[0]

    probs = [p.to(torch.float32) for p in prob_list]
    if temp != 1.0:
        probs = [torch.softmax(p.clamp_min(1e-12).log() / temp, dim=-1) for p in probs]

    stack = torch.stack(probs, dim=0)  # [K, P, C]
    K, P, C = stack.shape

    entropy = -(stack * stack.clamp_min(1e-12).log()).sum(-1)  # [K, P]
    base_idx = entropy.argmin(dim=0)                            # [P]
    base_rows = stack[base_idx, torch.arange(P, device=stack.device)]  # [P, C]
    # One host transfer for all patches instead of one `.item()` per patch.
    base_ids = base_idx.tolist()

    # Allocate on the inputs' device so no implicit cross-device copies occur.
    fused = torch.empty(P, C, dtype=stack.dtype, device=stack.device)
    for t, b in enumerate(base_ids):
        base = base_rows[t]  # [C]
        # K >= 2 here, so there is always at least one neighbor.
        neighbors = torch.stack([stack[j, t] for j in range(K) if j != b], dim=0)  # [K-1, C]

        dists = torch.stack([calculate_js_divergence(base, nb) for nb in neighbors])  # [K-1]
        w = get_weights(dists, method=weight_method, temp=temp, sigma=0.5)        # [K-1]

        weighted = (w.unsqueeze(-1) * neighbors).sum(dim=0)  # [C]
        fused[t] = (1.0 - alpha) * base + alpha * weighted   # unnormalized mix

    totals = fused.sum(dim=-1, keepdim=True)  # [P, 1]
    # Renormalize each patch; patches whose mix has no mass keep the base distribution.
    return torch.where(totals > 0, fused / totals.clamp_min(1e-12), base_rows)


# ---------------- Dataset ----------------
class Test_Input(Dataset):
    """
    Dataset over the `*.jpg` files found directly inside `root`.

    root:       directory that is globbed (non-recursively) for `*.jpg`.
    transform:  optional callable applied to the RGB PIL image.
    validation: optional directory; when given, an image `<stem>.jpg` is kept
                only if `<validation>/<stem>_0.jpg` exists.

    Each item is `(image, file_name)` where `file_name` is the basename.

    The file list is sorted. `glob` returns entries in an unspecified order,
    and every distributed rank builds its own instance, so index-based
    sharding is only consistent when all ranks see the same ordering.
    """

    def __init__(self, root, transform=None, validation=''):
        self.root = root
        self.transform = transform
        imgs_path_tmp = sorted(glob.glob(os.path.join(root, '*.jpg')))

        if validation:
            valid_imgs_path = []
            for img_path in imgs_path_tmp:
                img_name = os.path.basename(img_path)
                p_name, p_ext = os.path.splitext(img_name)
                # Companion file expected in `validation`: "<stem>_0<ext>".
                img_prompts_name = p_name + '_' + str(0) + p_ext
                prompts_img_path = os.path.join(validation, img_prompts_name)
                if os.path.exists(prompts_img_path):
                    valid_imgs_path.append(img_path)
            self.imgs_path = valid_imgs_path
        else:
            self.imgs_path = imgs_path_tmp

    def __getitem__(self, index):
        path = self.imgs_path[index]
        img_name = os.path.basename(path)
        # Close the file handle deterministically; `convert` returns a new image.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, img_name

    def __len__(self):
        return len(self.imgs_path)


# ---------------- Utilities ----------------
def resolve_annotation_path(prompt_root, stem, prefer_exts=(".png", ".jpg")):
    """
    Return the first existing `<prompt_root>/<stem><ext>` path.

    Extensions are tried in the order given by `prefer_exts`.
    Raises FileNotFoundError (listing every path tried) when none exists.
    """
    tried = [os.path.join(prompt_root, "{}{}".format(stem, ext)) for ext in prefer_exts]
    hit = next((path for path in tried if os.path.exists(path)), None)
    if hit is None:
        raise FileNotFoundError(f"Annotation not found for {stem} (tried: {tried})")
    return hit


def load_prompt_candidates_from_json(json_path):
    """
    Read the prompt mapping JSON and return `{test_stem: [prompt_file, ...]}`.

    The JSON is a dict of key -> list of tokens. For the key and for every
    token, only the text before the first "__" is kept; each token stem is
    turned into a file name by appending ".jpg". List order is preserved.
    If two keys share a stem, the later one wins.
    """
    with open(json_path, "r") as fh:
        raw = json.load(fh)

    def _stem(text):
        # Everything before the first "__" (the whole string if there is none).
        return text.partition("__")[0]

    return {
        _stem(key): [_stem(token) + ".jpg" for token in tokens]
        for key, tokens in raw.items()
    }


def _ddp_env():
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return rank, world_size, local_rank


# ---------------- Main ----------------
def main():
    """
    Run fused in-context inference over every test image.

    For each test image:
      1. take its ordered prompt candidates from the similarity JSON,
      2. resolve them to (image, annotation) pairs on disk,
      3. split the pairs into up to `--k` sequences of `--prompts-per-seq`
         pairs, append the test image to each sequence and ask the model for
         per-patch codebook probabilities,
      4. fuse the per-sequence probabilities with `fuse_patchwise_probs_js`,
      5. take the argmax code per patch, decode with the model's tokenizer
         and save the result as `<stem>.png`.

    Results go to `<save-root>/plr_k<K>` (or `<save-root>/random_k<K>` with
    `--random`). Each rank also appends the prompts it used to
    `prompt_map_rank<rank>.tsv` in that directory. Images whose output file
    already exists are skipped, so an interrupted run can be resumed.

    When WORLD_SIZE > 1 the test set is sharded across ranks with a
    DistributedSampler and an NCCL process group is created.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--k", type=int, default=2,
                        help="Number of prompt sequences fused per test image.")
    parser.add_argument("--prompts-per-seq", type=int, default=7,
                        help="How many (image,annotation) pairs per sequence.")
    parser.add_argument("--sim-json", type=str,
                        default="/path/to/dataset/top_50-similarity.json",
                        help="Similarity JSON path used when --random is NOT set.")
    # `args.random` is read below, so the flag has to be declared; without it
    # every run failed with AttributeError.
    # NOTE(review): this function only uses the flag to pick the output
    # directory; prompts are still taken from --sim-json.
    parser.add_argument("--random", action="store_true",
                        help="Save under random_k<K> instead of plr_k<K>.")
    parser.add_argument("--save-root", type=str,
                        default="/path/to/results",
                        help="Where to save results.")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--fuse-temp", type=float, default=1.0,
                        help="Temperature in distance->weight (softmax).")
    parser.add_argument("--fuse-alpha", type=float, default=0.5,
                        help="Fusion strength alpha.")
    parser.add_argument("--fuse-weight", type=str, default="softmax",
                        choices=["softmax", "gaussian", "inv", "epanechnikov", "triangular", "uniform"],
                        help="Weighting method for distances.")
    args = parser.parse_args()

    # Dataset paths
    TEST_IMG_DIR = "/path/to/dataset/images/validation"
    TRAIN_IMG_DIR = "/path/to/dataset/images/training"
    TRAIN_ANN_DIR = "/path/to/dataset/annotations/training"

    # Save path: honour --save-root (its default reproduces the previous fixed location).
    run_dir = f"random_k{args.k}" if args.random else f"plr_k{args.k}"
    save_path = os.path.join(args.save_root, run_dir)
    # The map file and the output images are written here; every rank may race
    # to create the directory, which `exist_ok=True` tolerates.
    os.makedirs(save_path, exist_ok=True)

    rank, world_size, local_rank = _ddp_env()
    use_ddp = world_size > 1
    if use_ddp:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")

    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if rank == 0:
        print(f"[MAIN] world_size={world_size} | device={device} | "
              f"k={args.k} | prompts_per_seq={args.prompts_per_seq}")

    # Model checkpoint path
    lvm_path = "/path/to/weights/lvm"
    model = LocalInferenceModel(
        checkpoint=lvm_path,
        torch_device=device,
        dtype='float16',
        context_frames=16,
        use_lock=False,
    )

    gen_length = 1  # number of new frames requested from the model per sequence

    test_data = Test_Input(TEST_IMG_DIR, transform=encode_transform)
    if len(test_data) == 0:
        if rank == 0:
            print(f"[ERROR] No test images in {TEST_IMG_DIR}")
        sys.exit(1)

    if not os.path.exists(args.sim_json):
        if rank == 0:
            print(f"[ERROR] sim_json not found: {args.sim_json}")
        sys.exit(1)
    if rank == 0:
        print(f"[INFO] Loading prompt mapping from: {args.sim_json}")
    stem2prompts = load_prompt_candidates_from_json(args.sim_json)

    sampler = torch.utils.data.distributed.DistributedSampler(
        test_data, num_replicas=world_size, rank=rank, shuffle=False
    ) if use_ddp else None

    data_loader = DataLoader(
        test_data,
        batch_size=1,
        sampler=sampler,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Per-rank log of which prompts were used for which test image (append mode,
    # so a resumed run adds a new header instead of truncating).
    map_file = os.path.join(save_path, f"prompt_map_rank{rank}.tsv")
    with open(map_file, "a") as mf:
        mf.write(f"# run_at\t{datetime.now().isoformat()}\n")
        mf.write("test_image\tseq_id\tprompt_image\tprompt_annotation\n")

    if rank == 0:
        print(f"[INFO] #test_total={len(test_data)} | save_path={save_path}")

    processed = 0
    for batch in data_loader:
        # Only the file name is used; the test image is re-opened from disk below.
        _, img_name = batch
        img_name = img_name[0]
        test_stem = os.path.splitext(img_name)[0]
        test_path = os.path.join(TEST_IMG_DIR, img_name)

        # ---- skip if output already exists (allows resuming) ----
        out_path = os.path.join(save_path, f"{Path(img_name).stem}.png")
        if os.path.exists(out_path):
            print(f"[rank{rank}] skip existing: {out_path}")
            processed += 1
            continue

        # 1) candidates in order
        cand_files = stem2prompts.get(test_stem, [])
        if len(cand_files) == 0:
            print(f"[rank{rank}] [WARN] No JSON prompts for {test_stem}, skip.")
            continue

        # Keep only as many candidates as k full sequences need. Truncation
        # happens before the existence checks, so missing files reduce the
        # number of usable pairs rather than being replaced by later candidates.
        need_total = args.k * args.prompts_per_seq
        cand_files = cand_files[:need_total]

        # 2) resolve to absolute prompt images & annotations
        prompt_pairs = []
        for fname in cand_files:
            abs_img = os.path.join(TRAIN_IMG_DIR, fname)
            if not os.path.exists(abs_img):
                print(f"[rank{rank}] [WARN] Missing prompt image: {abs_img}")
                continue
            p_stem, _ = os.path.splitext(os.path.basename(abs_img))
            try:
                abs_ann = resolve_annotation_path(TRAIN_ANN_DIR, p_stem, prefer_exts=(".png", ".jpg"))
                prompt_pairs.append((abs_img, abs_ann))
            except FileNotFoundError as e:
                print(f"[rank{rank}] [WARN] {e} -> drop prompt {p_stem}")

        if len(prompt_pairs) == 0:
            print(f"[rank{rank}] [WARN] No valid prompts (with annotations) for {img_name}, skip.")
            continue

        # Only complete sequences are used; leftover pairs are ignored.
        max_seq = len(prompt_pairs) // args.prompts_per_seq
        if max_seq == 0:
            print(f"[rank{rank}] [WARN] Not enough prompts for even 1 sequence for {img_name}, skip.")
            continue
        num_seqs = min(args.k, max_seq)

        # 3) run each sequence, collect codebook probs (expected [256,8192] per sequence)
        seq_prob_list = []
        for s_idx in range(num_seqs):
            beg = s_idx * args.prompts_per_seq
            end = beg + args.prompts_per_seq
            group = prompt_pairs[beg:end]

            # Sequence layout: img_1, ann_1, ..., img_n, ann_n, test image.
            seq_prompt = []
            for (p_img, p_ann) in group:
                with Image.open(p_img) as im:
                    seq_prompt.append(im.convert('RGB'))
                with Image.open(p_ann) as im:
                    seq_prompt.append(im.convert('RGB'))
            with Image.open(test_path) as im:
                seq_prompt.append(im.convert('RGB'))

            # All frames resized to 256x256 and scaled to [0, 1].
            np_seq = np.stack(
                [np.array(im.resize((256, 256)), dtype=np.float32) / 255.0 for im in seq_prompt],
                axis=0
            )

            with torch.no_grad():
                _, code_probs = model.generate_once_with_probs(
                    np_seq,
                    n_new_frames=gen_length,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    codebook_size=8192,
                    store_dtype=torch.float16,
                    renorm_to_codebook=True,
                )
            # First entry of the returned probabilities is used; upcast for fusion.
            seq_prob_list.append(code_probs[0].to(torch.float32))

            with open(map_file, "a") as mf:
                for (p_img, p_ann) in group:
                    mf.write(
                        f"{img_name}\t{s_idx+1}\t"
                        f"{os.path.relpath(p_img, start=TRAIN_IMG_DIR)}\t"
                        f"{os.path.relpath(p_ann, start=TRAIN_ANN_DIR)}\n"
                    )

        # 4) fuse probabilities across sequences
        if len(seq_prob_list) == 1:
            fused_prob = seq_prob_list[0]
        else:
            fused_prob = fuse_patchwise_probs_js(
                seq_prob_list,
                temp=args.fuse_temp,
                alpha=args.fuse_alpha,
                weight_method=args.fuse_weight
            )  # [256, 8192]

        # 5) decode: argmax -> tokens -> VQ tokenizer decode
        tokens = fused_prob.argmax(dim=-1).to(torch.int64).view(1, 256)  # [1,256]
        fused_img = einops.rearrange(
            torch.clamp(model.tokenizer.decode_code(tokens.to(device)), 0.0, 1.0).detach().cpu().numpy(),
            'b c h w -> b h w c'
        )[0]

        # save (reusing the precomputed out_path)
        Image.fromarray((fused_img * 255).astype(np.uint8)).save(out_path)

        processed += 1
        if processed % 20 == 0:
            print(f"[rank{rank}] processed {processed} (num_seqs={num_seqs}, pps={args.prompts_per_seq})", flush=True)

        del fused_img
        torch.cuda.empty_cache()

    if use_ddp:
        # Wait for every rank to finish before tearing the group down.
        dist.barrier()
        dist.destroy_process_group()
    if rank == 0:
        print("[DONE] Results saved to:", save_path)


# Script entry point: run the pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
