# inference_seg_7b_coco.py
# -*- coding: utf-8 -*-

# ========== imports ==========
import os
import sys
import glob
import json
import hashlib
import argparse
from pathlib import Path
from datetime import datetime

import torch
import torch.distributed as dist
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# project modules
from inference import LocalInferenceModel
from utils import encode_transform  # no need for codebook_similarity here

# silence unrelated TF / TF-TRT logs
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")


# ========== Dataset ==========
class Test_Input(Dataset):
    """
    Flat image-list dataset over the ``*.jpg`` files directly inside ``root``.

    Args:
        root: Directory that is globbed (non-recursively) for ``*.jpg`` files.
        transform: Optional callable applied to each loaded RGB PIL image.
        validation: Optional directory. When non-empty, an image ``<stem>.jpg``
            is kept only if ``<stem>_0.jpg`` exists inside this directory.

    Each item is a ``(image, file_name)`` pair, where ``image`` is the RGB PIL
    image (or whatever ``transform`` returns) and ``file_name`` is the base
    name of the file.
    """
    def __init__(self, root, transform=None, validation=''):
        self.root = root
        self.transform = transform
        # glob returns entries in arbitrary (filesystem) order. Sort so that an
        # index means the same image in every process; index-based sharding
        # across ranks relies on that.
        imgs_path_tmp = sorted(glob.glob(os.path.join(root, '*.jpg')))

        if validation:
            valid_imgs_path = []
            for img_path in imgs_path_tmp:
                p_name, p_ext = os.path.splitext(os.path.basename(img_path))
                # Companion file checked for existence: "<stem>_0<ext>".
                prompts_img_path = os.path.join(validation, f"{p_name}_0{p_ext}")
                if os.path.exists(prompts_img_path):
                    valid_imgs_path.append(img_path)
            self.imgs_path = valid_imgs_path
        else:
            self.imgs_path = imgs_path_tmp

    def __getitem__(self, index):
        path = self.imgs_path[index]
        img_name = os.path.basename(path)
        # Context manager releases the file handle as soon as the pixel data
        # has been converted, same as the image loading in the main loop.
        with Image.open(path) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, img_name

    def __len__(self):
        return len(self.imgs_path)


# ========== Utilities ==========
def generate_images(input_images, n_new_frames, n_candidates, temperature=1.0, top_p=0.9):
    """
    Run the module-level ``model`` on a sequence of PIL images.

    Every input is converted to RGB, resized to 256x256 and scaled to float32
    in [0, 1]; the frames are stacked along a new leading axis and passed to
    ``model`` as a single-element batch. For each returned candidate, its
    frames are joined along axis 1 (side by side), multiplied by 255 and cast
    to uint8.

    Returns:
        A list with one PIL image per candidate.
    """
    assert len(input_images) > 0

    frames = []
    for pil_img in input_images:
        resized = pil_img.convert('RGB').resize((256, 256))
        frames.append(np.array(resized, dtype=np.float32) / 255.0)
    sequence = np.stack(frames, axis=0)

    # The model takes a batch of sequences; only one is sent, so take entry 0.
    candidates = model([sequence], n_new_frames, n_candidates, temperature, top_p)[0]

    results = []
    for cand in candidates:
        strip = np.concatenate(list(cand), axis=1)
        results.append(Image.fromarray((strip * 255).astype(np.uint8)))
    return results


def resolve_annotation_path(prompt_root, stem, prefer_exts=(".png", ".jpg")):
    """
    Return the first existing ``<prompt_root>/<stem><ext>`` path.

    Extensions are tried in the order given by ``prefer_exts``.

    Raises:
        FileNotFoundError: if none of the candidate paths exists.
    """
    tried = [os.path.join(prompt_root, f"{stem}{ext}") for ext in prefer_exts]
    hit = next((path for path in tried if os.path.exists(path)), None)
    if hit is None:
        raise FileNotFoundError(f"Annotation not found for {stem} (tried: {tried})")
    return hit


def deterministic_choice(candidates, key_str):
    """
    Pick one element of ``candidates``, selected by the MD5 digest of ``key_str``.

    The same (candidates, key_str) pair always yields the same element.
    """
    digest = hashlib.md5(key_str.encode('utf-8')).digest()
    # Big-endian integer value of the digest bytes equals int(hexdigest, 16).
    position = int.from_bytes(digest, 'big') % len(candidates)
    return candidates[position]


def deterministic_k_choices(candidates, key_str, k):
    """
    Pick up to ``k`` distinct elements of ``candidates``, reproducibly.

    The MD5 digest of ``key_str`` (reduced modulo 2**32) seeds a NumPy
    ``RandomState``; indices are then drawn without replacement. At most
    ``len(candidates)`` elements are returned.
    """
    digest_hex = hashlib.md5(key_str.encode('utf-8')).hexdigest()
    seed = int(digest_hex, 16) % (2**32)
    rng = np.random.RandomState(seed)

    n_total = len(candidates)
    n_pick = min(k, n_total)
    picked = rng.choice(n_total, size=n_pick, replace=False)
    return [candidates[i] for i in picked]


def load_prompt_candidates_from_json(json_path):
    """
    Read similarity JSON and map to { test_stem: [prompt_img_filename, ...] }.

    JSON example:
      {
        "ADE_val_00000001__abbey": ["ADE_train_00000984__abbey", ...],
        ...
      }
    Anything from the first '__' onwards is dropped, for keys and list entries
    alike, and '.jpg' is appended to every list entry, producing:
      { "ADE_val_00000001": ["ADE_train_00000984.jpg", ...], ... }

    If two keys share the same stem before '__', the later one in the file
    replaces the earlier one.

    Args:
        json_path: Path to the JSON file (read as UTF-8).

    Returns:
        dict mapping each test stem to its list of prompt image file names,
        in the order given in the JSON.
    """
    # Explicit UTF-8: the default text encoding of open() is locale-dependent,
    # which can fail or mis-decode non-ASCII stems on non-UTF-8 systems.
    with open(json_path, "r", encoding="utf-8") as f:
        mp = json.load(f)

    return {
        key.split("__", 1)[0]: [f"{token.split('__', 1)[0]}.jpg" for token in lst]
        for key, lst in mp.items()
    }


def _ddp_env():
    """Read torchrun-provided DDP env variables."""
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return rank, world_size, local_rank


# ========== Main ==========
if __name__ == '__main__':
    # Script entry point: for every test image, build an in-context sequence of
    # (prompt image, prompt annotation) pairs followed by the test image, run
    # the model once, and save the generated frame under the test image's name.
    # Under torchrun (WORLD_SIZE > 1) the test set is sharded across ranks.

    # --------- CLI args ---------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--random",
        action="store_true",
        help="If set, choose prompts via deterministic hash; if not set, load prompts from JSON."
    )
    parser.add_argument(
        "--k",
        type=int,
        default=1,
        help="Number of prompts to use per test sample (take first-k in order or k hashed choices)."
    )
    parser.add_argument(
        "--sim_json",
        type=str,
        default="/path/to/dataset/top_50-similarity.json",
        help="Similarity JSON path used when --random is NOT set."
    )
    # parse_known_args: unrecognized arguments are ignored, not rejected.
    args, _ = parser.parse_known_args()

    # Reject non-positive k up front: k == 0 selects no prompts (every image
    # would be skipped), and a negative k slices the JSON list from the wrong
    # end and makes the random sampler raise.
    if args.k < 1:
        parser.error(f"--k must be >= 1 (got {args.k})")

    # --------- Paths ---------
    TEST_IMG_DIR = "/path/to/dataset/images/validation"
    TRAIN_IMG_DIR = "/path/to/dataset/images/training"
    TRAIN_ANN_DIR = "/path/to/dataset/annotations/training"

    # NOTE(review): save_path is fixed here, before the sim_json existence
    # check below. If that check falls back to random prompts, results are
    # still written under "plr_k*" — confirm this is intended.
    if args.random:
        save_path = f"/path/to/results/random_k{args.k}"
    else:
        save_path = f"/path/to/results/plr_k{args.k}"
    os.makedirs(save_path, exist_ok=True)

    # --------- DDP init ---------
    rank, world_size, local_rank = _ddp_env()
    use_ddp = world_size > 1

    if use_ddp:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")

    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if rank == 0:
        print(f"[MAIN] world_size={world_size} | device={device} | random={args.random} | k={args.k}")

    # --------- Model ---------
    # `model` must stay a module-level name: generate_images() looks it up as
    # a global.
    lvm_path = "/path/to/weights/lvm"
    model = LocalInferenceModel(
        checkpoint=lvm_path,
        torch_device=device,
        dtype='float16',
        context_frames=16,
        use_lock=False,
    )

    # generation params
    gen_length = 1
    n_candidates = 1
    temperature = 1.0
    top_p = 1.0

    # --------- Data ---------
    test_data = Test_Input(TEST_IMG_DIR, transform=encode_transform)
    if len(test_data) == 0:
        if rank == 0:
            print(f"[ERROR] No test images in {TEST_IMG_DIR}")
        sys.exit(1)

    # Sorted so the random prompt selection indexes the same list on every run.
    train_imgs = sorted(glob.glob(os.path.join(TRAIN_IMG_DIR, "*.jpg")))
    if len(train_imgs) == 0:
        if rank == 0:
            print(f"[ERROR] No training images in {TRAIN_IMG_DIR}")
        sys.exit(1)

    # Load mapping from JSON only if random=False
    stem2prompts = None
    if not args.random:
        if not os.path.exists(args.sim_json):
            if rank == 0:
                print(f"[WARN] sim_json not found: {args.sim_json}. Falling back to --random behavior.")
            args.random = True
        else:
            if rank == 0:
                print(f"[INFO] Loading prompt mapping from: {args.sim_json}")
            stem2prompts = load_prompt_candidates_from_json(args.sim_json)

    # NOTE(review): DistributedSampler presumably pads with repeated samples
    # when the dataset size is not divisible by world_size, so a few images
    # may be processed (and overwritten) by two ranks — verify against the
    # sampler's drop_last documentation.
    sampler = torch.utils.data.distributed.DistributedSampler(
        test_data, num_replicas=world_size, rank=rank, shuffle=False
    ) if use_ddp else None

    data_loader = DataLoader(
        test_data,
        batch_size=1,
        sampler=sampler,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # per-rank mapping log (optional); opened in append mode, so repeated runs
    # add a new header section to the same file.
    map_file = os.path.join(save_path, f"prompt_map_rank{rank}.tsv")
    with open(map_file, "a") as mf:
        mf.write(f"# run_at\t{datetime.now().isoformat()}\n")
        mf.write("test_image\tprompt_image\tprompt_annotation\n")

    # --------- Inference loop ---------
    if rank == 0:
        print(f"[INFO] #test_total={len(test_data)} | #train_total={len(train_imgs)} | save_path={save_path}")

    processed = 0
    for batch in data_loader:
        # Only the file name is used; the transformed image from the dataset
        # is discarded and the test image is re-opened from disk below.
        _, img_name = batch
        img_name = img_name[0]
        test_stem = os.path.splitext(img_name)[0]
        test_path = os.path.join(TEST_IMG_DIR, img_name)

        # Select prompts:
        # - If --random is set: use deterministic_k_choices (stable hash).
        # - Else: read from JSON list by order and take first-k.
        prompt_img_paths = []
        if args.random:
            prompt_img_paths = deterministic_k_choices(train_imgs, test_stem, args.k)
        else:
            # JSON path
            cand_files = stem2prompts.get(test_stem, []) if stem2prompts is not None else []
            for fname in cand_files[:args.k]:
                abs_p = os.path.join(TRAIN_IMG_DIR, fname)
                if os.path.exists(abs_p):
                    prompt_img_paths.append(abs_p)
                else:
                    # Reported on every rank: each rank handles different test
                    # images, so a rank-0-only warning would hide the rest.
                    print(f"[rank{rank}] [WARN] Missing prompt image from JSON: {abs_p}")
            # Fallback if JSON has no entry or all missing
            if not prompt_img_paths:
                print(f"[rank{rank}] [WARN] No JSON prompts for {test_stem}. Falling back to --random.")
                prompt_img_paths = deterministic_k_choices(train_imgs, test_stem, args.k)

        # Resolve annotations; drop any prompt without annotation
        prompt_pairs = []
        for p_img in prompt_img_paths:
            p_stem, _ = os.path.splitext(os.path.basename(p_img))
            try:
                p_ann = resolve_annotation_path(TRAIN_ANN_DIR, p_stem, prefer_exts=(".png", ".jpg"))
            except FileNotFoundError as e:
                print(f"[rank{rank}] [WARN] {e} -> drop prompt {p_stem}")
            else:
                prompt_pairs.append((p_img, p_ann))

        if not prompt_pairs:
            print(f"[rank{rank}] [WARN] No valid prompts for {img_name}, skip.")
            continue

        print("prompt_pairs:", prompt_pairs)
        seq_prompt = []
        for (prompt_img_path, prompt_ann_path) in prompt_pairs:
            # keep the original order in prompt_pairs: image first, then its
            # annotation
            with Image.open(prompt_img_path) as im:
                seq_prompt.append(im.convert('RGB'))
            with Image.open(prompt_ann_path) as im:
                seq_prompt.append(im.convert('RGB'))
        # append the query at the end
        with Image.open(test_path) as im:
            seq_prompt.append(im.convert('RGB'))

        # Inference once for this (k prompts + query) sequence; only the first
        # candidate is kept.
        with torch.no_grad():
            gen_img = generate_images(
                seq_prompt,
                n_new_frames=gen_length,
                n_candidates=n_candidates,
                temperature=temperature,
                top_p=top_p,
            )[0]

        # Save one result per test image (regardless of k)
        out_path = os.path.join(save_path, f"{img_name}")
        gen_img.save(out_path)

        # Log all prompts used for this test image (one line per prompt pair)
        with open(map_file, "a") as mf:
            for (prompt_img_path, prompt_ann_path) in prompt_pairs:
                mf.write(
                    f"{img_name}\t"
                    f"{os.path.relpath(prompt_img_path, start=TRAIN_IMG_DIR)}\t"
                    f"{os.path.relpath(prompt_ann_path, start=TRAIN_ANN_DIR)}\n"
                )

        # optional VRAM cleanup
        del gen_img
        torch.cuda.empty_cache()

        processed += 1
        if processed % 20 == 0:
            print(f"[rank{rank}] processed {processed} (k={args.k}, random={args.random})", flush=True)

    # --------- Finalize ---------
    if use_ddp:
        # Wait for every rank to finish before tearing the group down.
        dist.barrier()
        dist.destroy_process_group()

    if rank == 0:
        print("[DONE] Results saved to:", save_path)
