#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Offline generator for wrist training tensors:
- Input dataset root contains four subdirectories with matching filenames: ext1/, ext2/, condition/, wrist_rgb/ (source videos are 93 frames at 1280x720; only the first 81 frames are used, each resized to 832x480)
- Output: for each sample, produce a .tensors.pth containing keys:
  - latents: VAE-encoded wrist_rgb latent (used as training x)
  - prompt_emb: {"context": T5-encoded fixed text ("robotic manipulation scene")}
  - image_emb: atomic components assembled online during training:
      - clip_feature: first-frame guidance image feature (1, 257, 1280)
      - y_wrist16: first-frame guidance wrist 16-channel latent aligned with VAE latent
      - control_latents: 16-channel latent of the condition video (concatenated with y_wrist16 along the channel dimension during training to form the 32-channel y)
  - ext_frame_feats: per-frame image encoder CLS features, shape [2, 81, 1280] (ext1/ext2 x 81 frames x 1280). Projection to 4096 and concatenation with the text context happen during training.

Key points:
- Reuse diffsynth's WanVideoPipeline and ModelManager so that preprocessing/encoding is identical to the training side
- First-frame guidance: always encoded offline (no random drop); in training, set to zero with 80% probability online
- Condition: VAE-encoded with encode_video into a 16-channel latent saved offline; concatenated with y_wrist16 along channels during training
- Text: fixed string encoded by T5 (4096-dim), stored in prompt_emb["context"] offline
- ext1/ext2: per-frame image_encoder.encode_image to get tokens (1,257,1280), take CLS (index 0) as each frame feature (1280-dim). Position/view embeddings and projection are applied during training
"""

import os
import argparse
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import torch
from torchvision.transforms import v2
import imageio
from PIL import Image
import multiprocessing as mp
import inspect

from diffsynth import WanVideoPipeline, ModelManager


def list_samples(dataset_root: Path) -> List[Tuple[Path, Path, Path, Path, str]]:
    """Pair up the four per-view videos that share a filename.

    Enumeration is driven by ``wrist_rgb/*.mp4`` in sorted order. A wrist clip is
    kept only when a file with the same name exists in ``ext1/``, ``ext2/`` and
    ``condition/``.

    Returns:
        List of ``(ext1_path, ext2_path, condition_path, wrist_path, sample_id)``
        tuples, where ``sample_id`` is the wrist filename without ``.mp4``.

    Raises:
        AssertionError: if any of the four subdirectories is missing.
    """
    subdirs = [dataset_root / sub for sub in ("ext1", "ext2", "condition", "wrist_rgb")]
    assert all(d.is_dir() for d in subdirs), \
        "Dataset must contain ext1/, ext2/, condition/, wrist_rgb/ subdirectories"
    ext1_dir, ext2_dir, cond_dir, wrist_dir = subdirs

    matched: List[Tuple[Path, Path, Path, Path, str]] = []
    for wrist_video in sorted(wrist_dir.glob("*.mp4")):
        fname = wrist_video.name
        partners = (ext1_dir / fname, ext2_dir / fname, cond_dir / fname)
        if not all(p.exists() for p in partners):
            continue  # incomplete four-way match: ignore this clip
        matched.append((*partners, wrist_video, fname[:-4]))
    return matched


def read_video_frames(
    path: Path,
    max_frames: int = 81,
    size: Tuple[int, int] = (832, 480),
) -> Optional[List[Image.Image]]:
    """Read up to ``max_frames`` leading frames as RGB PIL images resized to ``size``.

    Args:
        path: Video file, decoded with imageio.
        max_frames: Cap on the number of frames read (default 81, the frame count
            used throughout this script). Shorter videos return all their frames.
        size: Target ``(width, height)`` in PIL order, bicubic resampling
            (default 832x480).

    Returns:
        The list of frames. Returns None, after printing a warning, when the
        file cannot be opened or decoded, or yields no frames.
    """
    try:
        reader = imageio.get_reader(str(path))
        try:
            num_frames = min(reader.count_frames(), max_frames)
            frames: List[Image.Image] = [
                Image.fromarray(reader.get_data(i)).convert("RGB").resize(size, resample=Image.BICUBIC)
                for i in range(num_frames)
            ]
        finally:
            # Release the decoder even when count_frames()/get_data() fails mid-read.
            reader.close()
        if not frames:
            raise ValueError(f"Empty video: {path}")
        return frames
    except Exception as e:
        # Best-effort: report and return None so the caller can skip the sample.
        print(f"Warning: failed to read video {path.name}: {e}")
        return None


def frames_to_tensor(frames: List[Image.Image], height: int, width: int) -> torch.Tensor:
    """Turn a list of PIL frames into one normalized video tensor laid out (C, T, H, W).

    Every frame is center-cropped and resized to ``(height, width)``, converted
    to a float tensor, and normalized with mean 0.5 / std 0.5 per channel.
    """
    target = (height, width)
    per_frame = v2.Compose([
        v2.CenterCrop(size=target),
        v2.Resize(size=target, antialias=True),
        v2.ToTensor(),
        v2.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ])
    # Stacking along dim=1 puts time second, giving (C, T, H, W) without a permute.
    return torch.stack([per_frame(frame) for frame in frames], dim=1).contiguous()


def encode_ext_frame_feats(pipe: WanVideoPipeline, frames: List[Image.Image]) -> torch.Tensor:
    """Encode each frame on its own and keep token 0 of the image-encoder output.

    Each frame goes through ``pipe.preprocess_image`` and then
    ``pipe.image_encoder.encode_image``. Token index 0 is kept (the CLS token,
    per the module docstring). The per-frame rows are concatenated along dim 0,
    in ``pipe.torch_dtype`` on ``pipe.device``. The caller expects
    (len(frames), 1280).
    """
    def _first_token(frame: Image.Image) -> torch.Tensor:
        pixels = pipe.preprocess_image(frame).to(pipe.device)
        with torch.no_grad():
            tokens = pipe.image_encoder.encode_image([pixels])
        return tokens[:, 0, :].to(dtype=pipe.torch_dtype, device=pipe.device)

    return torch.cat([_first_token(frame) for frame in frames], dim=0)


def process_one_sample(
    ext1_path: Path,
    ext2_path: Path,
    condition_path: Path,
    wrist_path: Path,
    sample_id: str,
    output_dir: Path,
    pipe: WanVideoPipeline,
    height: int,
    width: int,
    num_frames: int,
    tiled: bool = False,
    tile_size: Tuple[int, int] = (34, 34),
    tile_stride: Tuple[int, int] = (18, 16),
) -> Optional[Dict]:
    """Encode one four-view sample and save it as ``<output_dir>/<sample_id>.tensors.pth``.

    Saved dict layout (T = num_frames, lT = (T - 1) // 4 + 1):
      - "latents": VAE latents of the wrist video, (1, 16, lT, H/8, W/8)
      - "prompt_emb": {"context": T5 embedding of the fixed prompt}
      - "image_emb": {"clip_feature": (1, 257, 1280) for the first wrist frame,
                      "y_wrist16": last 16 channels of encode_image's "y",
                      "control_latents": VAE latents of the condition video}
      - "ext_frame_feats": token-0 image-encoder features of ext1/ext2, (2, T, 1280)
      - "meta": source paths, frame count and resolution

    Args:
        ext1_path, ext2_path, condition_path, wrist_path: The four input videos.
        sample_id: Stem of the output file.
        output_dir: Destination directory (created if missing).
        pipe: Pipeline providing the text/image encoders and the VAE.
        height, width: Required frame resolution, each divisible by 16. A sample
            whose decoded wrist frames have another size is skipped.
        num_frames: Exact number of frames every input video must yield.
        tiled, tile_size, tile_stride: Forwarded to the VAE encode calls.

    Returns:
        One of the following:
          - ``{"out": path}`` after saving.
          - ``{"out": path, "skipped": True}`` when the output already exists.
          - ``None`` when an input is unreadable or mismatched, or any step
            fails. Failures are printed, not raised.
    """

    def _vae_encode(video: torch.Tensor) -> torch.Tensor:
        # encode_video may hand back a list/tuple; keep its first entry for every VAE call.
        with torch.no_grad():
            out = pipe.encode_video(video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return out[0] if isinstance(out, (list, tuple)) else out

    try:
        out_path = output_dir / f"{sample_id}.tensors.pth"
        if out_path.exists():
            print(f"Skip existing file: {sample_id}")
            return {"out": str(out_path), "skipped": True}

        # Read video frames
        ext1_frames = read_video_frames(ext1_path)
        ext2_frames = read_video_frames(ext2_path)
        condition_frames = read_video_frames(condition_path)
        wrist_frames = read_video_frames(wrist_path)
        if any(frames is None for frames in (ext1_frames, ext2_frames, condition_frames, wrist_frames)):
            print(f"Skip corrupted sample: {sample_id}")
            return None

        # Validate frame counts
        for name, frames in (("ext1", ext1_frames), ("ext2", ext2_frames),
                             ("condition", condition_frames), ("wrist", wrist_frames)):
            if len(frames) != num_frames:
                print(f"Warning: {name} frame count mismatch; expect {num_frames}, got {len(frames)}")
                return None

        # Honor the requested resolution instead of silently adopting the decoded
        # frame size. Mismatches are skipped, like frame-count mismatches.
        frame_w, frame_h = wrist_frames[0].size  # PIL size is (width, height)
        if (frame_h, frame_w) != (height, width):
            print(f"Warning: wrist resolution mismatch; expect {height}x{width}, got {frame_h}x{frame_w}")
            return None
        assert height > 0 and width > 0, "Invalid resolution"
        assert height % 16 == 0 and width % 16 == 0, f"Resolution must be divisible by 16, got {height}x{width}"

        latent_T = (num_frames - 1) // 4 + 1
        latent_shape = (1, 16, latent_T, height // 8, width // 8)

        # Wrist video -> training-target latents, expected (1,16,lT,H/8,W/8)
        video_tensor = frames_to_tensor(wrist_frames, height, width)
        video_tensor = video_tensor.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        assert tuple(video_tensor.shape) == (1, 3, num_frames, height, width), \
            f"Unexpected wrist tensor shape: {tuple(video_tensor.shape)}"
        latents = _vae_encode(video_tensor)
        assert latents.ndim == 5 and tuple(latents.shape) == latent_shape, \
            f"Unexpected wrist latents shape: {tuple(latents.shape)}, expect: (1,16,{latent_T},{height//8},{width//8})"
        assert latents.dtype == pipe.torch_dtype, f"latents dtype mismatch: {latents.dtype} vs {pipe.torch_dtype}"

        # Fixed text prompt. Inference only, so don't build an autograd graph through T5.
        with torch.no_grad():
            prompt_emb = pipe.encode_prompt("robotic manipulation scene")
        assert isinstance(prompt_emb, dict) and "context" in prompt_emb, "prompt_emb missing context"

        # First wrist frame -> CLIP feature + latent guidance
        with torch.no_grad():
            image_emb_seed = pipe.encode_image(wrist_frames[0], None, num_frames, height, width,
                                               tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        assert isinstance(image_emb_seed, dict) and "clip_feature" in image_emb_seed, "image_emb_seed missing clip_feature"
        y_seed = image_emb_seed.get("y")
        if y_seed is None:
            raise RuntimeError("encode_image returned no 'y'")
        # Keep only the trailing 16 channels of y; any leading channels are dropped.
        y_wrist16 = y_seed[:, -16:]
        cf = image_emb_seed["clip_feature"]
        assert cf.ndim == 3 and tuple(cf.shape) == (1, 257, 1280), \
            f"Unexpected clip_feature shape: {tuple(cf.shape)} expect (1,257,1280)"
        assert y_seed.ndim == 5 and y_seed.shape[0] == 1 and y_seed.shape[2] == latent_T \
            and y_seed.shape[3] == height // 8 and y_seed.shape[4] == width // 8, \
            f"Unexpected y_seed shape: {tuple(y_seed.shape)}"
        assert tuple(y_wrist16.shape) == latent_shape, \
            f"Unexpected y_wrist16 shape: {tuple(y_wrist16.shape)}"

        # Condition video -> control latents, built as 5D (B,C,T,H,W)
        cond_tensor = frames_to_tensor(condition_frames, height, width)
        cond_tensor = cond_tensor.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        assert tuple(cond_tensor.shape) == (1, 3, num_frames, height, width), \
            f"Unexpected condition tensor shape: {tuple(cond_tensor.shape)}"
        control_latents = _vae_encode(cond_tensor)
        assert tuple(control_latents.shape) == latent_shape, \
            f"Unexpected control_latents shape: {tuple(control_latents.shape)}"
        assert control_latents.dtype == pipe.torch_dtype, "control_latents dtype mismatch"

        # ext1/ext2 per-frame features
        with torch.no_grad():
            ext1_feats = encode_ext_frame_feats(pipe, ext1_frames)
            ext2_feats = encode_ext_frame_feats(pipe, ext2_frames)
        assert tuple(ext1_feats.shape) == (num_frames, 1280), f"Unexpected ext1_feats shape: {tuple(ext1_feats.shape)}"
        assert tuple(ext2_feats.shape) == (num_frames, 1280), f"Unexpected ext2_feats shape: {tuple(ext2_feats.shape)}"
        ext_feats = torch.stack([ext1_feats, ext2_feats], dim=0)
        assert tuple(ext_feats.shape) == (2, num_frames, 1280), f"Unexpected ext_frame_feats shape: {tuple(ext_feats.shape)}"

        data = {
            "latents": latents.to("cpu").detach(),
            "prompt_emb": {"context": prompt_emb["context"].to("cpu").detach()},
            "image_emb": {
                "clip_feature": cf.to("cpu").detach(),
                "y_wrist16": y_wrist16.to("cpu").detach(),
                "control_latents": control_latents.to("cpu").detach(),
            },
            "ext_frame_feats": ext_feats.to("cpu").detach(),
            "meta": {
                "paths": {
                    "ext1": str(ext1_path),
                    "ext2": str(ext2_path),
                    "condition": str(condition_path),
                    "wrist_rgb": str(wrist_path),
                },
                "num_frames": int(num_frames),
                "height": int(height),
                "width": int(width),
            }
        }
        # Final consistency check (training-side concatenation constraints)
        assert data["latents"].shape[2] == data["image_emb"]["y_wrist16"].shape[2] == data["image_emb"]["control_latents"].shape[2], \
            "latent temporal dimension mismatch"

        output_dir.mkdir(parents=True, exist_ok=True)
        # Write to a temp file and atomically rename it. An interrupted run then
        # never leaves a truncated .tensors.pth that the exists() check would skip forever.
        tmp_path = out_path.with_name(out_path.name + ".tmp")
        torch.save(data, str(tmp_path))
        os.replace(tmp_path, out_path)
        return {"out": str(out_path)}
    except Exception as e:
        import traceback
        print(f"[Error] Failed to process sample {wrist_path.name}: {e}\n{traceback.format_exc()}")
        return None


def worker(proc_id: int,
           shard: List[Tuple[Path, Path, Path, Path, str]],
           args: argparse.Namespace,
           result_queue: mp.Queue) -> None:
    """Child-process entry: convert one shard of samples and report back on a queue.

    Loads the text encoder, VAE and image encoder (bfloat16) on
    ``cuda:<proc_id>``, or on CPU when CUDA is unavailable. Outputs are written
    under ``<args.output_root>/wrist_rgb``. The worker then puts a single
    message on ``result_queue``:
      - On success, the list of truthy per-sample records.
      - Otherwise, ``{"error": str, "trace": str}`` if anything inside the try raised.
    """
    import traceback

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    device = "cpu"
    if torch.cuda.is_available():
        device = f"cuda:{proc_id}"

    try:
        weights = [args.text_encoder_path, args.vae_path, args.image_encoder_path]
        manager = ModelManager(torch_dtype=torch.bfloat16, device=device)
        manager.load_models(weights)
        pipe = WanVideoPipeline.from_model_manager(manager)

        save_dir = Path(args.output_root) / "wrist_rgb"
        save_dir.mkdir(parents=True, exist_ok=True)

        records = []
        for ext1_p, ext2_p, cond_p, wrist_p, sample_id in shard:
            print(f"[GPU{proc_id}] processing: {wrist_p.name}")
            outcome = process_one_sample(
                ext1_p, ext2_p, cond_p, wrist_p, sample_id, save_dir, pipe,
                height=480, width=832, num_frames=81,  # fixed clip geometry
                tiled=args.tiled,
                tile_size=(args.tile_size_height, args.tile_size_width),
                tile_stride=(args.tile_stride_height, args.tile_stride_width),
            )
            if outcome:  # None means the sample failed or was rejected
                records.append(outcome)
        result_queue.put(records)
    except Exception as exc:
        result_queue.put({"error": str(exc), "trace": traceback.format_exc()})


def _collect_worker_results(procs: List[mp.Process], result_queue: mp.Queue,
                            poll_timeout: float = 10.0) -> List[Dict]:
    """Collect one message per worker without hanging on workers that died silently.

    Each worker puts either a list of per-sample records or an
    ``{"error": ..., "trace": ...}`` dict on ``result_queue``. A worker killed
    hard (segfault, OOM killer) never reports, so a blocking ``get()`` per worker
    could wait forever. Instead, the queue is polled, and collection stops once
    every worker has exited and nothing more arrives.

    Returns:
        Concatenated records from all workers that reported successfully.
    """
    import queue

    stats: List[Dict] = []
    pending = len(procs)
    while pending > 0:
        try:
            res = result_queue.get(timeout=poll_timeout)
        except queue.Empty:
            if any(p.is_alive() for p in procs):
                continue
            # Every worker has exited. A process flushes its queued data before it
            # terminates, so a short final wait drains it before we give up.
            try:
                res = result_queue.get(timeout=1.0)
            except queue.Empty:
                print(f"Warning: {pending} worker process(es) exited without reporting results")
                break
        pending -= 1
        if isinstance(res, dict) and "error" in res:
            print(f"[子进程错误] {res['error']}\n{res['trace']}")
        else:
            stats.extend(res)
    return stats


def main():
    """CLI entry point: enumerate matched samples and write one .tensors.pth per sample.

    With ``--num_gpus > 1`` and CUDA available, samples are sharded round-robin
    across spawned workers, one per GPU. The worker count is capped by the
    visible GPU count and the number of samples. Otherwise all samples are
    processed sequentially in this process.
    """
    parser = argparse.ArgumentParser(description="Offline generator for wrist training .tensors.pth")
    parser.add_argument("--dataset_root", required=True, help="Dataset root containing ext1/ext2/condition/wrist_rgb")
    parser.add_argument("--output_root", required=True, help="Output root to save .tensors.pth")
    parser.add_argument("--text_encoder_path", required=True, help="Text encoder path")
    parser.add_argument("--vae_path", required=True, help="VAE path")
    parser.add_argument("--image_encoder_path", required=True, help="Image encoder path")
    parser.add_argument("--tiled", action="store_true", help="Enable tiled VAE encoding to save VRAM")
    parser.add_argument("--tile_size_height", type=int, default=34)
    parser.add_argument("--tile_size_width", type=int, default=34)
    parser.add_argument("--tile_stride_height", type=int, default=18)
    parser.add_argument("--tile_stride_width", type=int, default=16)
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of parallel processes (one GPU per process)")
    args = parser.parse_args()

    # Validate numeric options up front with parser.error (asserts vanish under
    # `python -O`), before spending time scanning the dataset.
    if args.num_gpus < 1:
        parser.error("--num_gpus must be >=1")
    if args.tile_size_height <= 0 or args.tile_size_width <= 0:
        parser.error("tile size must be positive")
    if args.tile_stride_height <= 0 or args.tile_stride_width <= 0:
        parser.error("tile stride must be positive")

    dataset_root = Path(args.dataset_root)
    output_root = Path(args.output_root)

    # Pre-enumerate samples
    samples = list_samples(dataset_root)
    if not samples:
        raise RuntimeError("No samples found (four-way matching .mp4)")
    print(len(samples))

    if args.num_gpus > 1 and torch.cuda.is_available():
        # Never start more workers than samples, so every shard is non-empty and
        # no worker is spawned only to be orphaned by a failed check.
        nproc = min(int(args.num_gpus), torch.cuda.device_count(), len(samples))
        shards: List[List[Tuple[Path, Path, Path, Path, str]]] = [samples[i::nproc] for i in range(nproc)]
        ctx = mp.get_context("spawn")
        result_queue: mp.Queue = ctx.Queue()
        procs: List[mp.Process] = []
        for pid, shard in enumerate(shards):
            p = ctx.Process(target=worker, args=(pid, shard, args, result_queue))
            p.start()
            procs.append(p)
        # Drain results before join(): a child with unconsumed queued data blocks on exit.
        stats = _collect_worker_results(procs, result_queue)
        for p in procs:
            p.join()
    else:
        # Single-GPU sequential path
        model_paths = [args.text_encoder_path, args.vae_path, args.image_encoder_path]
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device=device)
        model_manager.load_models(model_paths)
        pipe = WanVideoPipeline.from_model_manager(model_manager)

        stats = []
        out_dir = output_root / "wrist_rgb"
        for idx, (p1, p2, pc, pw, base) in enumerate(samples):
            print(f"[{idx+1}/{len(samples)}] processing: {pw.name}")
            rec = process_one_sample(
                p1, p2, pc, pw,
                base,
                out_dir,
                pipe,
                height=480,  # fixed height
                width=832,  # fixed width
                num_frames=81,  # fixed frame count
                tiled=args.tiled,
                tile_size=(args.tile_size_height, args.tile_size_width),
                tile_stride=(args.tile_stride_height, args.tile_stride_width),
            )
            if rec:  # only append successful samples
                print(f"  -> saved: {rec['out']}")
                stats.append(rec)

    print(f"Done: {len(stats)} samples generated .tensors.pth")


if __name__ == "__main__":
    # Entry guard. main() uses the "spawn" start method, whose worker processes
    # re-import this module under a different __name__, so main() does not re-run in them.
    main()
"""
Example usage (replace XXX with your local paths):
TOKENIZERS_PARALLELISM=false python prepare_wrist_condition_tensors.py --dataset_root XXX/condition_dataset_out \
      --output_root XXX/condition_dataset_out_tensors \
      --text_encoder_path XXX/models--Wan-AI--Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
      --vae_path XXX/models--Wan-AI--Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth \
      --image_encoder_path XXX/models--Wan-AI--Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
      --tiled \
      --tile_size_height 34 \
      --tile_size_width 34 \
      --tile_stride_height 18 \
      --tile_stride_width 16
"""