
from __future__ import annotations
import os
import re
import sys
from pathlib import Path
from typing import Optional

import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


# Prepend an entry to the module search path before anything else is resolved.
# NOTE(review): the value is an empty string here; presumably a placeholder for
# the checkout that provides the custom model code — confirm before running.
sys.path.insert(0, "")



# --------------------------- Run configuration ---------------------------
# The three paths below are empty in this file and have to be filled in.
# MODEL_PATH is handed to AutoModel/AutoTokenizer.from_pretrained and is also
# globbed for checkpoint shards when the skeleton renderer weights are loaded.
MODEL_PATH = ""
# Directory that is globbed (non-recursively) for "*.skeleton" files.
SKELETON_ROOT = Path("")
# Result file; main() writes one "<sample_id>\t<answer>" record per sample.
OUTPUT_FILE = Path("")
# Copied onto the renderer module's `enable_nfm` attribute in main(); also
# decides whether checkpoint keys starting with "nfm." are kept or dropped.
SKELETON_ENABLE_NFM = True

# Number of "<image>" placeholders in the prompt; also written to
# model.config.skeleton_target_num_frames.
RENDER_FRAMES = 12   
# Written to model.config.force_image_size (the original note says H=W).
RENDER_SIZE = 448   
# torch_dtype passed to from_pretrained.
DTYPE = torch.bfloat16
# True: load the model and move it to "cuda:0".
# False: load with device_map="auto".
SIMPLE_SINGLE_GPU = False

# Optional: save rendered frames to disk (one folder per sample)
SAVE_RENDERED_FRAMES = False  
# Parent directory for the per-sample frame folders.
RENDER_SAVE_ROOT = Path("")
# Used verbatim as the file extension of each saved frame.
SAVE_FORMAT = "png"  # png/jpg

# Cross-subject split: should_use_for_test() rejects every sample whose
# performer id (the P field of the file name) appears in this list, so only the
# remaining subjects are evaluated.
# NOTE(review): these look like the standard NTU60 cross-subject training
# subjects — confirm against the protocol being reproduced.
NTU60_CROSS_SUBJECT_EXCLUDE = [
    1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38
]

# Action ids (1-based indices into NTU60_CLASSES_FULL) that are evaluated and
# offered to the model as candidate labels. Exactly one assignment should be
# active; the commented-out lines are alternative subsets kept for quick toggling.
# ACTION_IDS = [ 1,  2,  3,  7,  8,  9, 11, 13, 14, 16, 17, 19, 21, 22, 26, 27, 28, 32, 33, 34, 40, 43, 46, 48, 49, 52, 53, 56, 59, 60]
# ACTION_IDS = [1, 13, 14, 15, 16, 17, 18, 23, 24, 27, 30, 31, 32, 36, 37, 43, 44, 49, 57, 58]
ACTION_IDS = [4, 6, 10, 13, 16, 41, 43, 48, 52, 57, 59, 60]
# ACTION_IDS = [11, 12, 20, 27, 57]


# Human-readable names for the 60 action classes, five per row.
# The list is 0-indexed while action ids are 1-based, so the name for action
# id `a` is NTU60_CLASSES_FULL[a - 1]. The strings are inserted verbatim into
# the prompt, so editing one changes what the model is asked.
NTU60_CLASSES_FULL = [
    "drink water", "eat food", "brush teeth", "brush hair", "drop something",
    "pick up", "throw", "sit down", "stand up", "clapping",
    "reading", "writing", "tear up paper", "put on jacket", "take off jacket",
    "put on shoe", "take off shoe", "put on glasses", "take off glasses", "put on hat/cap",
    "take off hat/cap", "cheer (raise arms)", "hand waving", "kicking the air/something", "reach into pocket",
    "hopping", "jump up", "phone call", "play with phone/tablet", "type on keyboard",
    "point at something/directions", "take a selfie", "check time (from watch)", "rub two hands", "bow (bend forward)",
    "shake head", "wipe face", "salute", "pray with hands together", "cross hands in front",
    "sneezing/cough", "staggering", "falling down", "headache", "chest pain",
    "back pain", "neck pain", "nausea/vomiting", "fan self", "punch/slap someone",
    "kicking someone", "pushing someone", "pat someone on the back", "point at someone", "hug each other",
    "giving object", "touch someone's pocket", "shaking hands", "walking towards each other", "walking apart from each other",
]


# Candidate labels shown to the model, in the same order as ACTION_IDS.
# The `idx - 1` converts the 1-based action id into a list index; an id of 0
# would silently wrap around to the last class, so ACTION_IDS must stay >= 1.
ACTION_LABELS = [NTU60_CLASSES_FULL[idx - 1] for idx in ACTION_IDS]


# Legacy ACTION_IDS subsets for quick toggling are kept as commented-out lines next to ACTION_IDS above.


# Keyword arguments handed to model.chat() as `generation_config`.
# Sampling is switched off (do_sample=False), so decoding is deterministic.
# NOTE(review): max_new_tokens=16 allows only a very short reply, while the
# prompt built by build_question() asks for a thought process followed by a
# "Label: ..." line — confirm the budget is intentional.
# NOTE(review): temperature/top_p are presumably ignored while do_sample is
# False — confirm against the installed transformers version.
GEN_CONFIG = {
    "max_new_tokens": 16,
    "do_sample": False,
    "temperature": 0.0,
    "top_p": None,
    "repetition_penalty": 1.05,
}


# --------------------------- Helpers ---------------------------
# Sample names start with five 3-digit fields tagged S, C, P, R and A.
_NAME_RE = re.compile(r"S(\d{3})C(\d{3})P(\d{3})R(\d{3})A(\d{3})")


def parse_ids(name: str):
    """Extract the five numeric ids from the start of a sample name.

    The pattern is anchored at the beginning of ``name`` only, so trailing
    text such as a file extension is ignored.

    Args:
        name: Sample identifier, e.g. ``"S001C002P003R002A013"``.

    Returns:
        A tuple ``(s, c, p, r, a)`` of ints taken from the S/C/P/R/A fields in
        that order, or ``None`` when ``name`` does not start with the pattern.
    """
    matched = _NAME_RE.match(name)
    if matched is None:
        return None
    s_id, c_id, p_id, r_id, a_id = map(int, matched.groups())
    return s_id, c_id, p_id, r_id, a_id


# No path resolver needed; we iterate .skeleton files under SKELETON_ROOT directly
def should_use_for_test(sample_id: str) -> bool:
    """Decide whether a sample belongs to the evaluated subset.

    A sample is kept only when all of the following hold:
      * its name parses with ``parse_ids``;
      * its setup id (S field) is below 18;
      * its performer id (P field) is not in ``NTU60_CROSS_SUBJECT_EXCLUDE``;
      * its action id (A field) is one of ``ACTION_IDS``.

    Args:
        sample_id: File stem of the skeleton file.

    Returns:
        ``True`` if the sample should be evaluated, ``False`` otherwise.
    """
    parsed = parse_ids(sample_id)
    if parsed is None:
        return False

    setup_id, _camera, subject_id, _replication, action_id = parsed
    # NOTE(review): the `< 18` bound presumably restricts evaluation to the
    # NTU60 setups — confirm.
    return (
        setup_id < 18
        and subject_id not in NTU60_CROSS_SUBJECT_EXCLUDE
        and action_id in ACTION_IDS
    )


def build_question(num_frames: int) -> str:
    """Assemble the text prompt sent with every sample.

    The prompt consists of one ``"Frame<k>: <image>"`` line per frame
    (k starting at 1), the fixed task instructions, and the candidate labels
    from ``ACTION_LABELS``, one per line.

    Args:
        num_frames: Number of frame placeholders to emit.

    Returns:
        The complete prompt string.
    """
    frame_lines = [f"Frame{idx}: <image>\n" for idx in range(1, num_frames + 1)]
    # NOTE(review): the two literals below are joined without a separating
    # space ("...kinematics.You need..."). Kept byte-for-byte because the
    # prompt may have to match what the model was trained on — confirm.
    instructions = (
        "The clip you see is rendered from 3D skeleton data into abstract pseudo-images. Each frame encodes joint positions, bone connections, depth, and motion cues, and may look different from natural RGB videos. Treat these frames as a faithful visualization of human body kinematics."
        "You need to identify the action in the skeleton video and select one from the action labels below as the classification result. You are required to output your thought process, and finally, provide the specific classification label on a separate line in the format 'Label: <action>'. \n"
    )
    candidate_labels = "\n".join(ACTION_LABELS)
    return "".join(frame_lines) + instructions + candidate_labels


# Checkpoint keys that belong to the inner skeleton renderer start with this.
_RENDERER_KEY_PREFIX = "_skeleton_renderer_module._renderer."


def _collect_skeleton_files():
    """Return the sorted ``*.skeleton`` files under SKELETON_ROOT, or exit(1)."""
    if not SKELETON_ROOT.is_dir():
        print(f"[Error] SKELETON_ROOT not found: {SKELETON_ROOT}", file=sys.stderr)
        sys.exit(1)
    skeleton_files = sorted(SKELETON_ROOT.glob("*.skeleton"))
    if not skeleton_files:
        print(f"[Error] No .skeleton files under {SKELETON_ROOT}", file=sys.stderr)
        sys.exit(1)
    return skeleton_files


def _load_model_and_tokenizer():
    """Load model and tokenizer from MODEL_PATH and align their pad token id."""
    # NOTE(review): some transformers releases are understood to reject
    # device_map together with low_cpu_mem_usage=False — the arguments are
    # kept exactly as before; confirm against the installed version.
    load_kwargs = dict(
        torch_dtype=DTYPE,
        low_cpu_mem_usage=False,
        use_flash_attn=True,
        trust_remote_code=True,
    )
    if SIMPLE_SINGLE_GPU:
        model = AutoModel.from_pretrained(MODEL_PATH, **load_kwargs).eval().to("cuda:0")
    else:
        model = AutoModel.from_pretrained(MODEL_PATH, device_map="auto", **load_kwargs).eval()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, use_fast=False)
    if tokenizer.pad_token_id is None:
        # Fall back to EOS so generation has a pad token.
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer


def _collect_renderer_state(model_path):
    """Read the inner-renderer weights out of the checkpoint shards.

    ``pytorch_model*.bin`` shards are preferred; ``model*.safetensors`` shards
    are used only when no ``.bin`` shard exists.

    Returns:
        ``(renderer_state, nfm_param_names)`` where ``renderer_state`` maps
        parameter names (renderer prefix removed) to tensors and
        ``nfm_param_names`` lists those names that start with ``"nfm."``.
    """
    root = Path(model_path)
    ckpt_files = sorted(root.glob("pytorch_model*.bin"))
    if not ckpt_files:
        ckpt_files = sorted(root.glob("model*.safetensors"))

    renderer_state = {}
    nfm_param_names = []
    for ckpt_file in ckpt_files:
        if str(ckpt_file).endswith(".bin"):
            # NOTE(review): torch.load unpickles the file; only point
            # MODEL_PATH at trusted checkpoints (or pass weights_only=True if
            # the installed torch supports it).
            state = torch.load(ckpt_file, map_location="cpu")
        else:
            from safetensors.torch import load_file
            state = load_file(str(ckpt_file))

        for key, tensor in state.items():
            if not key.startswith(_RENDERER_KEY_PREFIX):
                continue
            # Strip only the leading prefix (str.replace would also remove
            # any later occurrence inside the key).
            param_name = key[len(_RENDERER_KEY_PREFIX):]
            renderer_state[param_name] = tensor
            if param_name.startswith("nfm."):
                nfm_param_names.append(param_name)
    return renderer_state, nfm_param_names


def _init_skeleton_renderer(model):
    """Create the model's inner skeleton renderer and load its weights.

    Does nothing when the model has no skeleton support, has no renderer
    module, or the renderer is already initialised. Any failure is reported
    on stderr (with traceback) and otherwise ignored, so evaluation continues.
    """
    if not (getattr(model, "use_skeleton", False) and hasattr(model, "_skeleton_renderer_module")):
        return
    renderer_mod = model._skeleton_renderer_module
    if renderer_mod is None or getattr(renderer_mod, "_renderer", None) is not None:
        return

    try:
        device = next(model.parameters()).device
        renderer_mod.enable_nfm = SKELETON_ENABLE_NFM
        if not hasattr(renderer_mod, "_ensure_renderer"):
            return
        inner = renderer_mod._ensure_renderer(device)

        renderer_state, nfm_param_names = _collect_renderer_state(MODEL_PATH)
        checkpoint_has_nfm = bool(nfm_param_names)
        if SKELETON_ENABLE_NFM and not checkpoint_has_nfm:
            print("  ❌ WARNING: enable_nfm=True but checkpoint has no NFM parameters. Using random init.", file=sys.stderr)
        elif not SKELETON_ENABLE_NFM and checkpoint_has_nfm:
            print(f"  ❌ WARNING: enable_nfm=False but checkpoint contains NFM parameters ({len(nfm_param_names)} params). NFM will NOT be used.", file=sys.stderr)
            # Drop the NFM weights so they are not loaded into the renderer.
            renderer_state = {k: v for k, v in renderer_state.items() if not k.startswith("nfm.")}

        if renderer_state:
            inner.load_state_dict(renderer_state, strict=False)
            renderer_mod._renderer = inner.to(device=device, dtype=torch.float32)
            print("  ✅ Successfully loaded and initialized inner renderer!", file=sys.stderr)
        else:
            print("  ⚠️ No renderer weights found in checkpoint!", file=sys.stderr)
    except Exception as e:
        print(f"  ❌ Failed to create/load inner renderer: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()


def _apply_render_config(model):
    """Push RENDER_FRAMES / RENDER_SIZE into model.config (best effort)."""
    overrides = (
        ("skeleton_target_num_frames", RENDER_FRAMES),
        ("force_image_size", RENDER_SIZE),
    )
    for attr, raw_value in overrides:
        try:
            setattr(model.config, attr, int(raw_value))
        except Exception as e:
            # Still best effort, but no longer silent.
            print(f"[Warn] Could not set model.config.{attr}: {e}", file=sys.stderr)


def _save_rendered_frames(frames, sample_id):
    """Write each rendered frame of one sample to RENDER_SAVE_ROOT/<sample_id>/.

    NOTE(review): per the original author's note ``frames`` is expected to be
    (B, T, H, W, 3) with values in [0, 1] — confirm against the renderer.
    """
    out_dir = RENDER_SAVE_ROOT / sample_id
    out_dir.mkdir(parents=True, exist_ok=True)
    # Imported lazily (once per sample) so PIL is only needed when saving.
    from PIL import Image

    frames_np = (
        frames.squeeze(0)
        .detach()
        .clamp(0.0, 1.0)
        .mul(255.0)
        .round()
        .byte()
        .cpu()
        .numpy()
    )  # (T,H,W,3) uint8
    for t, frame in enumerate(frames_np, start=1):
        img = Image.fromarray(frame, mode="RGB")
        img.save(out_dir / f"frame_{t:02d}.{SAVE_FORMAT}")


def main():
    """Run the model over every selected skeleton file and record its answer.

    Steps: collect the ``*.skeleton`` files (exit code 1 if the directory or
    the files are missing), load model and tokenizer, initialise the skeleton
    renderer, then query the model once per sample accepted by
    ``should_use_for_test``. Each result is written to OUTPUT_FILE as
    ``<sample_id>\\t<answer>``; a failing sample is recorded as
    ``<sample_id>\\tERROR: <message>`` and processing continues.
    """
    # Validate the dataset location first so a bad path fails before the
    # expensive model load instead of after it.
    skeleton_files = _collect_skeleton_files()

    print("[Info] Loading InternVL3 GS model …", file=sys.stderr)
    model, tokenizer = _load_model_and_tokenizer()
    _init_skeleton_renderer(model)
    # Must happen BEFORE the first chat/generate call.
    _apply_render_config(model)

    print(f"[Info] Found {len(skeleton_files)} skeleton files. Starting inference …", file=sys.stderr)
    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

    question_template = build_question(RENDER_FRAMES)
    per_frame_counts = [1] * RENDER_FRAMES

    used = 0
    skipped = 0

    with OUTPUT_FILE.open("w", encoding="utf-8") as fout:
        for sk_file in tqdm(skeleton_files, desc="Evaluating", unit="video"):
            sample_id = sk_file.stem

            if not should_use_for_test(sample_id):
                skipped += 1
                continue

            try:
                # By default the model receives the file path.
                pixel_values_input = str(sk_file)

                if SAVE_RENDERED_FRAMES:
                    try:
                        frames = model._maybe_render_skeleton_input(pixel_values_input)
                        if frames is not None and torch.is_tensor(frames):
                            _save_rendered_frames(frames, sample_id)
                            # Reuse the rendered tensor for the chat call.
                            pixel_values_input = frames
                    except Exception as e:
                        print(f"[Warn] Failed to save rendered frames for {sample_id}: {e}", file=sys.stderr)
                        pixel_values_input = str(sk_file)

                answer = model.chat(
                    tokenizer,
                    pixel_values=pixel_values_input,
                    question=question_template,
                    generation_config=GEN_CONFIG,
                    num_patches_list=per_frame_counts,
                )
                # NOTE(review): the answer is written verbatim; if it contains
                # newlines a record spans several lines — confirm the
                # downstream parser copes with that.
                fout.write(f"{sample_id}\t{answer}\n")
                # Flush per sample so partial results survive a crash.
                fout.flush()
                used += 1
            except Exception as e:
                err = f"ERROR: {e}"
                print(f"[Warning] Failed on {sample_id}: {e}", file=sys.stderr)
                fout.write(f"{sample_id}\t{err}\n")
                fout.flush()

    print(
        f"[Info] Finished. Results saved to {OUTPUT_FILE}. "
        f"Used={used}, SkippedByRule={skipped}",
        file=sys.stderr,
    )


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()