# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

"""
ALOHA Dynamics 评估脚本（多图像时间预测）
任务：[head, left, right][t] + (optional next_head) + action_sequence -> next_frame
1. Format 1: [head, left, right] + action -> next_head
2. Format 2a: [head, left, right, next_head] + action -> next_left_wrist
3. Format 2b: [head, left, right, next_head] + action -> next_right_wrist
"""

import os
import json
import argparse
import sys
sys.path.append('.')

import torch
import torch.distributed as dist
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse

from eval.aloha_eval_utils import setup_models, batch_pred_next_imgs_cfg, set_seeds
try:
    import lpips
    LPIPS_AVAILABLE = True
except ImportError:
    LPIPS_AVAILABLE = False
    print("Warning: lpips not available. LPIPS metric will be skipped.")

try:
    from torchmetrics.image.fid import FrechetInceptionDistance
    FID_AVAILABLE = True
except ImportError:
    FID_AVAILABLE = False
    print("Warning: torchmetrics not available. FID metric will be skipped.")


def compute_image_metrics(generated_image, gt_image, lpips_model=None, device='cuda'):
    """Compute full-reference quality metrics between a generated image and its ground truth.

    Args:
        generated_image: Generated image, either a PIL image or a numpy array.
        gt_image: Ground-truth image, either a PIL image or a numpy array. It is
            resized (LANCZOS) to the generated image's size when the sizes differ.
        lpips_model: Optional LPIPS model, called as ``lpips_model(gen, gt)`` on
            tensors scaled to [-1, 1]. LPIPS is skipped when this is ``None``.
        device: Torch device the LPIPS input tensors are moved to.

    Returns:
        dict: ``mse``, ``psnr`` and ``ssim`` (plus ``lpips`` when a model is
        given). If any metric raises, the message is stored under ``error`` and
        the metrics computed before the failure are kept.
    """
    if isinstance(generated_image, np.ndarray):
        generated_image_pil = Image.fromarray(generated_image)
    else:
        generated_image_pil = generated_image

    # Accept array ground truth the same way as array predictions; without this
    # the size comparison below would compare a (w, h) tuple with ndarray.size
    # (an element count) and then call ndarray.resize with the wrong arguments.
    if isinstance(gt_image, np.ndarray):
        gt_image = Image.fromarray(gt_image)

    if generated_image_pil.size != gt_image.size:
        gt_image = gt_image.resize(generated_image_pil.size, Image.LANCZOS)

    # Both images are scaled by 1/255, so data_range below is 1.0.
    # NOTE(review): assumes 8-bit pixel values - confirm for array inputs.
    gen_np = np.array(generated_image).astype(np.float32) / 255.0
    gt_np = np.array(gt_image).astype(np.float32) / 255.0

    metrics = {}

    # Deliberately best-effort: a failing metric must not abort the whole
    # evaluation run, so the error is recorded in the result instead.
    try:
        metrics['mse'] = float(mse(gt_np, gen_np))
        metrics['psnr'] = float(psnr(gt_np, gen_np, data_range=1.0))

        if gen_np.ndim == 3:
            # NOTE(review): both channel_axis and multichannel are passed,
            # presumably to cover old and new scikit-image releases - confirm
            # against the pinned version before removing either.
            metrics['ssim'] = float(ssim(gt_np, gen_np, data_range=1.0, channel_axis=2, multichannel=True))
        else:
            metrics['ssim'] = float(ssim(gt_np, gen_np, data_range=1.0))

        if lpips_model is not None:
            # HWC in [0, 1] -> 1xCxHxW in [-1, 1].
            gen_tensor = torch.from_numpy(gen_np).permute(2, 0, 1).unsqueeze(0).to(device) * 2 - 1
            gt_tensor = torch.from_numpy(gt_np).permute(2, 0, 1).unsqueeze(0).to(device) * 2 - 1

            with torch.no_grad():
                lpips_value = lpips_model(gen_tensor, gt_tensor)
            metrics['lpips'] = float(lpips_value.item())

    except Exception as e:
        print(f"Error computing metrics: {e}")
        metrics['error'] = str(e)

    return metrics


def image_to_fid_tensor(img, device):
    """Convert an image into the tensor layout used for the FID update calls.

    Args:
        img: PIL image or array-like with layout (H, W), (H, W, 1), (H, W, 3)
            or (H, W, 4).
        device: Torch device the resulting tensor is moved to.

    Returns:
        torch.Tensor: float32 tensor of shape (1, 3, H, W). Values are the
        input values divided by 255, i.e. in [0, 1] for 8-bit input.
    """
    # np.array handles PIL images and arrays alike, so no type branch is needed.
    arr = np.array(img)
    if arr.ndim == 2:
        arr = arr[..., None]
    if arr.shape[-1] == 1:
        # Grayscale: replicate the single channel to get three channels.
        arr = np.repeat(arr, 3, axis=-1)
    elif arr.shape[-1] == 4:
        # RGBA: drop the alpha channel, keep the colour channels.
        arr = arr[..., :3]
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0
    return tensor


def create_comparison_image(input_images, generated_image, gt_image, prediction_type="head_camera", use_generated_head=False):
    """Build a labelled, side-by-side strip: inputs | generated | ground truth.

    Every tile is resized to 256x192 and pasted under a 40 px label band.

    Args:
        input_images: List of PIL input images, one tile each.
        generated_image: Generated image (PIL image or numpy array).
        gt_image: Ground-truth image (PIL image or numpy array), or ``None``.
            With ``None`` the last column keeps its label but stays blank.
        prediction_type: One of ``head_camera``, ``wrist_camera``,
            ``left_wrist_camera``, ``right_wrist_camera``; any other value
            falls back to generic "Generated" / "Ground Truth" labels.
        use_generated_head: For wrist predictions with two or three inputs,
            label the first input as a generated head frame instead of the
            head frame at time t.

    Returns:
        PIL.Image.Image: The composed RGB comparison image.
    """
    if isinstance(generated_image, np.ndarray):
        generated_image = Image.fromarray(generated_image)
    if isinstance(gt_image, np.ndarray):
        gt_image = Image.fromarray(gt_image)

    target_width, target_height = 256, 192
    tile_size = (target_width, target_height)
    label_height = 40
    font_size = 16

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType support unavailable: fall
        # back to the built-in font. A bare except would also swallow
        # KeyboardInterrupt / SystemExit.
        font = ImageFont.load_default()

    num_inputs = len(input_images)

    # Tiles in left-to-right paste order: inputs, generated, then GT if given.
    tiles = [img.resize(tile_size, Image.LANCZOS) for img in input_images]
    tiles.append(generated_image.resize(tile_size, Image.LANCZOS))
    if gt_image is not None:
        tiles.append(gt_image.resize(tile_size, Image.LANCZOS))

    # The canvas always reserves the GT column, even when gt_image is None.
    canvas_size = (target_width * (num_inputs + 2), target_height + label_height)
    comparison = Image.new('RGB', canvas_size, (255, 255, 255))
    for column, tile in enumerate(tiles):
        comparison.paste(tile, (column * target_width, label_height))

    # Input labels depend on how many camera views were supplied.
    if num_inputs == 2:
        # LIBERO: head + wrist
        head_is_generated = use_generated_head and prediction_type in ['wrist_camera']
        input_labels = ["Gen Head[t+k]" if head_is_generated else "Head[t]", "Wrist[t]"]
    elif num_inputs == 3:
        # ALOHA: head + left + right
        head_is_generated = use_generated_head and prediction_type in ['left_wrist_camera', 'right_wrist_camera']
        input_labels = ["Gen Head[t+k]" if head_is_generated else "Head[t]", "Left[t]", "Right[t]"]
    elif num_inputs == 4:
        # ALOHA with next_head: head + left + right + next_head
        input_labels = ["Head[t]", "Left[t]", "Right[t]", "Head[t+k]"]
    else:
        input_labels = [f"Input {i+1}" for i in range(num_inputs)]

    camera_names = {
        "head_camera": "Head",
        "wrist_camera": "Wrist",
        "left_wrist_camera": "Left",
        "right_wrist_camera": "Right",
    }
    camera = camera_names.get(prediction_type)
    if camera is None:
        pred_label, gt_label = "Generated", "Ground Truth"
    else:
        pred_label = f"Generated {camera}[t+k]"
        gt_label = f"GT {camera}[t+k]"

    draw = ImageDraw.Draw(comparison)
    for column, label in enumerate(input_labels + [pred_label, gt_label]):
        # Centre each label horizontally above its tile.
        bbox = draw.textbbox((0, 0), label, font=font)
        text_width = bbox[2] - bbox[0]
        text_x = target_width * column + (target_width - text_width) // 2
        draw.text((text_x, 10), label, fill=(0, 0, 0), font=font)

    return comparison


def create_rollout_comparison_image(gt_frames, generated_frames, prediction_type="head_camera", sequence_info=None):
    """Build a two-row rollout overview: ground-truth frames above generated frames.

    Every frame is resized to 256x192. The header band holds a title (first
    line) and one time label per column (second line); a 100 px left margin
    holds the "GT" / "Gen" row labels.

    Args:
        gt_frames: Ground-truth frames (PIL images or numpy arrays), starting
            with the initial frame followed by one target frame per chunk.
        generated_frames: Generated frames in the same order as ``gt_frames``.
        prediction_type: Selects the title; unknown values get
            "Rollout Comparison".
        sequence_info: Optional dict. ``start_frame`` (default 0) and
            ``chunk_size`` (default 10) are used to label column ``i`` as
            ``t = start_frame + i * chunk_size``. Without it columns are
            labelled "Frame i".

    Returns:
        PIL.Image.Image: The composed RGB comparison image.
    """
    target_width, target_height = 256, 192
    tile_size = (target_width, target_height)

    def _as_tiles(frames):
        # Convert arrays to PIL images and resize everything to one tile size.
        return [
            (Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame).resize(tile_size, Image.LANCZOS)
            for frame in frames
        ]

    gt_tiles = _as_tiles(gt_frames)
    generated_tiles = _as_tiles(generated_frames)

    label_height = 50
    row_label_width = 100
    font_size = 16

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType support unavailable: fall
        # back to the built-in font. A bare except would also swallow
        # KeyboardInterrupt / SystemExit.
        font = ImageFont.load_default()

    # Size the canvas for the longer row so no pasted frame is clipped when
    # the two lists differ in length.
    num_frames = max(len(gt_tiles), len(generated_tiles))
    total_width = row_label_width + target_width * num_frames
    total_height = label_height + target_height * 2

    comparison = Image.new('RGB', (total_width, total_height), (255, 255, 255))
    draw = ImageDraw.Draw(comparison)

    # Title and per-frame labels get separate lines inside the header band;
    # drawing both near the top made the centred title overlap the labels.
    title_y = 5
    frame_label_y = 28
    # Vertical offset that centres a row label on its image row.
    row_label_offset = target_height // 2 - font_size // 2

    type_labels = {
        "head_camera": "Head Camera Rollout",
        "wrist_camera": "Wrist Camera Rollout",
        "left_wrist_camera": "Left Wrist Camera Rollout",
        "right_wrist_camera": "Right Wrist Camera Rollout"
    }
    type_label = type_labels.get(prediction_type, "Rollout Comparison")
    bbox = draw.textbbox((0, 0), type_label, font=font)
    title_width = bbox[2] - bbox[0]
    draw.text(((total_width - title_width) // 2, title_y), type_label, fill=(0, 0, 0), font=font)

    # One time label per column, shared by both rows.
    for i in range(num_frames):
        if sequence_info:
            frame_time = sequence_info.get('start_frame', 0) + i * sequence_info.get('chunk_size', 10)
            frame_label = f"t={frame_time}"
        else:
            frame_label = f"Frame {i}"

        bbox = draw.textbbox((0, 0), frame_label, font=font)
        text_width = bbox[2] - bbox[0]
        text_x = row_label_width + i * target_width + (target_width - text_width) // 2
        draw.text((text_x, frame_label_y), frame_label, fill=(0, 0, 0), font=font)

    # Top row: ground truth.
    draw.text((10, label_height + row_label_offset), "GT", fill=(0, 0, 0), font=font)
    for i, tile in enumerate(gt_tiles):
        comparison.paste(tile, (row_label_width + i * target_width, label_height))

    # Bottom row: generated frames.
    draw.text((10, label_height + target_height + row_label_offset), "Gen", fill=(0, 0, 0), font=font)
    for i, tile in enumerate(generated_tiles):
        comparison.paste(tile, (row_label_width + i * target_width, label_height + target_height))

    return comparison


def organize_rollout_sequences(data, rollout_chunks):
    """Group samples into rollout sequences of back-to-back chunks.

    Samples are grouped by ``(episode_id, prediction_type)`` and ordered by
    ``start_frame``. Within a group, every sample is tried as the start of a
    rollout; the rollout is extended by repeatedly picking the sample whose
    ``start_frame`` equals the previous sample's ``end_frame``. Only rollouts
    that reach the full length are returned.

    Chaining on frame boundaries (rather than on list positions with a chunk
    size read from one fixed group) means the function does not require a
    particular episode to exist, does not assume one sample per frame, and
    keeps valid rollouts that start near the end of an episode.

    Args:
        data: Iterable of sample dicts with the keys ``episode_id``,
            ``prediction_type``, ``start_frame`` and ``end_frame``.
        rollout_chunks: Number of consecutive chunks per rollout (>= 1).

    Returns:
        list[list[dict]]: Rollout sequences, each holding exactly
        ``rollout_chunks`` samples. Groups appear in first-seen order and
        sequences within a group are ordered by starting frame.

    Raises:
        ValueError: If ``rollout_chunks`` is smaller than 1.
    """
    if rollout_chunks < 1:
        raise ValueError(f"rollout_chunks must be >= 1, got {rollout_chunks}")

    grouped = {}
    for item in data:
        grouped.setdefault((item['episode_id'], item['prediction_type']), []).append(item)

    rollout_sequences = []
    for items in grouped.values():
        items.sort(key=lambda x: x['start_frame'])

        # start_frame -> sample; the first sample wins if a start frame repeats.
        by_start = {}
        for item in items:
            by_start.setdefault(item['start_frame'], item)

        for first in items:
            sequence = [first]
            while len(sequence) < rollout_chunks:
                following = by_start.get(sequence[-1]['end_frame'])
                if following is None:
                    # No sample continues from here: the rollout is incomplete.
                    break
                sequence.append(following)

            if len(sequence) == rollout_chunks:
                rollout_sequences.append(sequence)

    return rollout_sequences


def gather_image_list(local_list):
    """Collect the per-rank image lists and return them merged into one list.

    With an initialised process group, every rank contributes its list through
    ``dist.all_gather_object`` and the non-empty per-rank lists are
    concatenated in rank order. Without an initialised process group
    (single-process run) a copy of the local list is returned.

    Args:
        local_list: This rank's list of picklable objects; may be empty or None.

    Returns:
        list: Concatenation of all ranks' lists.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Nothing to gather from: behave like a world of size one.
        return list(local_list) if local_list else []

    world_size = dist.get_world_size()
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_list)

    merged = []
    for rank_list in gathered:
        if rank_list:
            merged.extend(rank_list)
    return merged




def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--jsonl_path', type=str, required=True)
    parser.add_argument('--prompt_path', type=str, required=True)
    parser.add_argument('--image_dir', type=str, required=True)
    parser.add_argument('--cfg_text_scale', type=float, default=4.0)
    parser.add_argument('--cfg_img_scale', type=float, default=2.0)
    parser.add_argument('--num_samples', type=int, default=50)
    parser.add_argument('--max_mem_per_gpu', type=str, default="80GiB")
    parser.add_argument('--rollout_chunks', type=int, default=1, help='跨chunk rollout的步数，1表示不rollout，2表示rollout两个chunk')
    parser.add_argument('--save_intermediate', action='store_true', help='保存rollout过程中的中间帧')
    args = parser.parse_args()
    
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank()
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(device)
    
    set_seeds(0)
    
    if local_rank == 0:
        print("Loading models...")
        if args.rollout_chunks > 1:
            print(f"   Rollout模式已启用: {args.rollout_chunks} chunks")
            print(f"   评估将只在最后一帧进行")
            if args.save_intermediate:
                print(f"   将保存中间帧")
        else:
            print("单步预测模式")
    model, vae_model, tokenizer, new_token_ids, vae_transform, vit_transform = setup_models(
        args.base_dir, args.model_path, device
    )
    
    with open(args.prompt_path, 'r') as f:
        prompt_text = f.read().strip()
    
    with open(args.jsonl_path, 'r') as f:
        data = [json.loads(line) for line in f]
    
    if local_rank == 0:
        print(f"Total samples in dataset: {len(data)}")
    
    lpips_model = None
    if LPIPS_AVAILABLE and local_rank == 0:
        try:
            lpips_model = lpips.LPIPS(net='alex').to(device)
            print("LPIPS model initialized successfully.")
        except Exception as e:
            print(f"Warning: Failed to initialize LPIPS model: {e}")
            print("LPIPS metric will be skipped.")
            lpips_model = None
    
    rollout_sequences = organize_rollout_sequences(data, args.rollout_chunks)
    
    if local_rank == 0:
        print(f"总共创建了 {len(rollout_sequences)} 个rollout序列")
    
    head_sequences = [seq for seq in rollout_sequences if seq[0].get('prediction_type') == 'head_camera']
    wrist_sequences = [seq for seq in rollout_sequences if seq[0].get('prediction_type') in ['wrist_camera', 'left_wrist_camera', 'right_wrist_camera']]
    
    if local_rank == 0:
        print(f"Head序列: {len(head_sequences)}, Wrist序列: {len(wrist_sequences)}")
    
    head_seq_dict = {}
    for seq in head_sequences:
        key = (seq[0].get('episode_id', 0), seq[0].get('start_frame', 0))
        head_seq_dict[key] = seq
    
    wrist_seq_dict = {}
    for seq in wrist_sequences:
        key = (seq[0].get('episode_id', 0), seq[0].get('start_frame', 0))
        if key not in wrist_seq_dict:
            wrist_seq_dict[key] = []
        wrist_seq_dict[key].append(seq)
    
    import random
    random.seed(42)
    
    paired_keys = [key for key in head_seq_dict.keys() if key in wrist_seq_dict]
    num_samples = min(args.num_samples, len(paired_keys))
    
    if local_rank == 0:
        print(f"从 {len(paired_keys)} 个配对序列中采样 {num_samples} 个...")
    
    sampled_keys = random.sample(paired_keys, num_samples)
    
    head_samples = [head_seq_dict[key] for key in sampled_keys]
    wrist_samples = []
    for key in sampled_keys:
        wrist_samples.extend(wrist_seq_dict[key])
    
    if local_rank == 0:
        print(f"Sampled {len(head_samples)} head samples and {len(wrist_samples)} wrist samples")
    
    os.makedirs(args.output_dir, exist_ok=True)
    
    temp_head_dir = os.path.join(args.output_dir, 'temp_generated_heads')
    os.makedirs(temp_head_dir, exist_ok=True)
    
    if local_rank == 0:
        print("\n【阶段1】生成 Head 图像...")
    
    head_samples_per_gpu = len(head_samples) // dist.get_world_size()
    head_start_idx = local_rank * head_samples_per_gpu
    head_end_idx = head_start_idx + head_samples_per_gpu if local_rank < dist.get_world_size() - 1 else len(head_samples)
    local_head_data = head_samples[head_start_idx:head_end_idx]
    
    generated_head_paths = {}
    
    head_final_frames = {}
    
    local_generated_images_for_fid = []
    local_gt_images_for_fid = []
    
    for seq_idx, sequence in enumerate(local_head_data):
        current_generated = None
        intermediate_frames = []
        
        rollout_gt_frames = []
        rollout_generated_frames = []
        
        first_item = sequence[0]
        first_images_field = first_item['images']
        if isinstance(first_images_field[0], list):
            first_input_filenames = first_images_field[0]
        else:
            first_input_filenames = first_images_field[:-1]
        
        initial_gt_frame = Image.open(os.path.join(args.image_dir, first_input_filenames[0])).convert('RGB')
        rollout_gt_frames.append(initial_gt_frame)
        rollout_generated_frames.append(initial_gt_frame.copy())
        
        for chunk_idx, item in enumerate(sequence):
            images_field = item['images']
            if isinstance(images_field[0], list):
                input_image_filenames = images_field[0]
                output_image_filename = images_field[1][0]
            else:
                input_image_filenames = images_field[:-1]
                output_image_filename = images_field[-1]
            
            target_gt_frame = Image.open(os.path.join(args.image_dir, output_image_filename)).convert('RGB')
            rollout_gt_frames.append(target_gt_frame)
            
            if chunk_idx == 0:
                input_images_list = [
                    Image.open(os.path.join(args.image_dir, img_file)).convert('RGB')
                    for img_file in input_image_filenames
                ]
            else:
                input_images_list = [current_generated] + [
                    Image.open(os.path.join(args.image_dir, img_file)).convert('RGB')
                    for img_file in input_image_filenames[1:]
                ]
                # input_images_list = [current_generated]
            
            action_text = item['action_sequence'][0]
            
            generated_images = batch_pred_next_imgs_cfg(
                model, vae_model, tokenizer, new_token_ids, vae_transform, vit_transform,
                prompt_text, [input_images_list], [action_text], 
                num_timesteps=50,
                cfg_text_scale=args.cfg_text_scale,
                cfg_img_scale=args.cfg_img_scale,
                cfg_type="parallel",
                cfg_interval=[0.4, 1.0],
                cfg_renorm_min=0.0,
                cfg_renorm_type="text_channel",
                timestep_shift=4.0,
                enable_taylorseer=True,
                device=device
            )
            current_generated = generated_images[0]
            
            if isinstance(current_generated, np.ndarray):
                current_generated = Image.fromarray(current_generated)
            
            rollout_generated_frames.append(current_generated.copy())
            
            if args.save_intermediate and chunk_idx < len(sequence) - 1:
                intermediate_frames.append(current_generated.copy())
                if local_rank == 0:
                    inter_path = os.path.join(args.output_dir, f"head_seq{seq_idx:04d}_chunk{chunk_idx:02d}.png")
                    current_generated.save(inter_path)
            
            key = (item.get('episode_id', 0), item.get('start_frame', 0), item.get('end_frame', 0))
            head_filename = f"head_ep{item.get('episode_id', 0):06d}_s{item.get('start_frame', 0):06d}_e{item.get('end_frame', 0):06d}.png"
            head_path = os.path.join(temp_head_dir, head_filename)
            current_generated.save(head_path)
            generated_head_paths[key] = head_path
        
        if local_rank == 0 and args.rollout_chunks > 1:
            sequence_info = {
                'episode_id': sequence[0].get('episode_id', 0),
                'start_frame': sequence[0].get('start_frame', 0),
                'chunk_size': sequence[0].get('end_frame', 0) - sequence[0].get('start_frame', 0) if len(sequence) > 0 else 10
            }
            rollout_comparison = create_rollout_comparison_image(
                rollout_gt_frames, rollout_generated_frames, 
                prediction_type='head_camera',
                sequence_info=sequence_info
            )
            rollout_comparison.save(os.path.join(args.output_dir, f"head_seq{seq_idx:04d}_rollout_comparison.png"))
        
        final_item = sequence[-1]
        images_field = final_item['images']
        if isinstance(images_field[0], list):
            output_image_filename = images_field[1][0]
        else:
            output_image_filename = images_field[-1]
        
        head_final_frames[seq_idx] = {
            'generated': current_generated,
            'gt_filename': output_image_filename,
            'item': final_item,
            'input_images_first_chunk': [
                Image.open(os.path.join(args.image_dir, img_file)).convert('RGB')
                for img_file in (sequence[0]['images'][0] if isinstance(sequence[0]['images'][0], list) else sequence[0]['images'][:-1])
            ]
        }
        
        if FID_AVAILABLE:
            local_generated_images_for_fid.append(current_generated)
            gt_image_for_fid = Image.open(os.path.join(args.image_dir, output_image_filename)).convert('RGB')
            local_gt_images_for_fid.append(gt_image_for_fid)
        
        if local_rank == 0 and seq_idx % 10 == 0:
            print(f"Generated {seq_idx}/{len(local_head_data)} head rollout sequences ({args.rollout_chunks} chunks)")
    
    dist.barrier()
    
    if local_rank == 0:
        print(f"All head images saved to {temp_head_dir}")
    
    if local_rank == 0:
        print("\n【阶段2】生成 Wrist 图像（使用生成的 Head）...")
    
    wrist_samples_per_gpu = len(wrist_samples) // dist.get_world_size()
    wrist_start_idx = local_rank * wrist_samples_per_gpu
    wrist_end_idx = wrist_start_idx + wrist_samples_per_gpu if local_rank < dist.get_world_size() - 1 else len(wrist_samples)
    local_wrist_data = wrist_samples[wrist_start_idx:wrist_end_idx]
    
    all_metrics = []
    wrist_final_frames = {}
    
    local_generated_wrist_images_for_fid = []
    local_gt_wrist_images_for_fid = []
    
    for seq_idx, sequence in enumerate(local_wrist_data):
        current_generated = None
        prediction_type = sequence[0].get('prediction_type', 'wrist_camera')
        
        rollout_gt_frames = []
        rollout_generated_frames = []
        
        first_item = sequence[0]
        first_images_field = first_item['images']
        if isinstance(first_images_field[0], list):
            first_input_filenames = first_images_field[0]
        else:
            first_input_filenames = first_images_field[:-1]
        
        if len(first_input_filenames) > 1:
            initial_gt_frame = Image.open(os.path.join(args.image_dir, first_input_filenames[1])).convert('RGB')
        else:
            initial_gt_frame = Image.open(os.path.join(args.image_dir, first_input_filenames[0])).convert('RGB')
        rollout_gt_frames.append(initial_gt_frame)
        rollout_generated_frames.append(initial_gt_frame.copy())
        
        for chunk_idx, item in enumerate(sequence):
            images_field = item['images']
            if isinstance(images_field[0], list):
                input_image_filenames = images_field[0]
                output_image_filename = images_field[1][0]
            else:
                input_image_filenames = images_field[:-1]
                output_image_filename = images_field[-1]
            
            target_gt_frame = Image.open(os.path.join(args.image_dir, output_image_filename)).convert('RGB')
            rollout_gt_frames.append(target_gt_frame)
            
            key = (item.get('episode_id', 0), item.get('start_frame', 0), item.get('end_frame', 0))
            head_filename = f"head_ep{item.get('episode_id', 0):06d}_s{item.get('start_frame', 0):06d}_e{item.get('end_frame', 0):06d}.png"
            head_path = os.path.join(temp_head_dir, head_filename)
            
            if os.path.exists(head_path):
                generated_head = Image.open(head_path).convert('RGB')
            else:
                if local_rank == 0:
                    print(f"Warning: Generated head not found at {head_path}")
                generated_head = Image.open(os.path.join(args.image_dir, input_image_filenames[0])).convert('RGB')
            
            if chunk_idx == 0:
                wrist_input_images_list = [
                    Image.open(os.path.join(args.image_dir, img_file)).convert('RGB')
                    for img_file in input_image_filenames[1:]
                ]
                input_images_list = [generated_head] + wrist_input_images_list
                # wrist_input_images_list = [
                #     Image.open(os.path.join(args.image_dir, img_file)).convert('RGB')
                #     for img_file in input_image_filenames
                # ]
                # input_images_list = wrist_input_images_list
            else:
                wrist_input_images_list = [
                    Image.open(os.path.join(args.image_dir, img_file)).convert('RGB')
                    for img_file in input_image_filenames[1:]
                ]
                if len(wrist_input_images_list) > 0:
                    wrist_input_images_list[0] = current_generated
                input_images_list = [generated_head] + wrist_input_images_list
                # input_images_list = [current_generated]
            
            action_text = item['action_sequence'][0] if item['action_sequence'] else ""
            
            generated_images = batch_pred_next_imgs_cfg(
                model, vae_model, tokenizer, new_token_ids, vae_transform, vit_transform,
                prompt_text, [input_images_list], [action_text], 
                num_timesteps=20,
                cfg_text_scale=args.cfg_text_scale,
                cfg_img_scale=args.cfg_img_scale,
                cfg_type="parallel",
                cfg_interval=[0.4, 1.0],
                cfg_renorm_min=0.0,
                cfg_renorm_type="text_channel",
                timestep_shift=4.0,
                enable_taylorseer=True,
                device=device
            )
            current_generated = generated_images[0]
            
            if isinstance(current_generated, np.ndarray):
                current_generated = Image.fromarray(current_generated)
            
            rollout_generated_frames.append(current_generated.copy())
            
            if args.save_intermediate and chunk_idx < len(sequence) - 1 and local_rank == 0:
                inter_path = os.path.join(args.output_dir, f"wrist_seq{seq_idx:04d}_chunk{chunk_idx:02d}_{prediction_type}.png")
                current_generated.save(inter_path)
        
        if local_rank == 0 and args.rollout_chunks > 1:
            sequence_info = {
                'episode_id': sequence[0].get('episode_id', 0),
                'start_frame': sequence[0].get('start_frame', 0),
                'chunk_size': sequence[0].get('end_frame', 0) - sequence[0].get('start_frame', 0) if len(sequence) > 0 else 10
            }
            rollout_comparison = create_rollout_comparison_image(
                rollout_gt_frames, rollout_generated_frames, 
                prediction_type=prediction_type,
                sequence_info=sequence_info
            )
            rollout_comparison.save(os.path.join(args.output_dir, f"wrist_seq{seq_idx:04d}_rollout_comparison_{prediction_type}.png"))
        
        final_item = sequence[-1]
        images_field = final_item['images']
        if isinstance(images_field[0], list):
            output_image_filename = images_field[1][0]
        else:
            output_image_filename = images_field[-1]
        
        gt_image = Image.open(os.path.join(args.image_dir, output_image_filename)).convert('RGB')
        
        metrics = compute_image_metrics(current_generated, gt_image, lpips_model, device)
        metrics['sample_id'] = final_item['id']
        metrics['prediction_type'] = prediction_type
        metrics['rollout_chunks'] = args.rollout_chunks
        all_metrics.append(metrics)
        
        if FID_AVAILABLE:
            local_generated_wrist_images_for_fid.append(current_generated)
            local_gt_wrist_images_for_fid.append(gt_image)
        
        if local_rank == 0:
            first_item = sequence[0]
            first_images_field = first_item['images']
            if isinstance(first_images_field[0], list):
                first_input_filenames = first_images_field[0]
            else:
                first_input_filenames = first_images_field[:-1]
            
            first_input_images = [
                Image.open(os.path.join(args.image_dir, img_file)).convert('RGB')
                for img_file in first_input_filenames
            ]
            
            comparison = create_comparison_image(
                first_input_images, current_generated, gt_image, prediction_type, use_generated_head=True
            )
            comparison.save(os.path.join(args.output_dir, f"wrist_seq{seq_idx:04d}_comparison.png"))
        
        if local_rank == 0 and seq_idx % 10 == 0:
            print(f"Generated {seq_idx}/{len(local_wrist_data)} wrist rollout sequences ({args.rollout_chunks} chunks, type: {prediction_type})")
    
    head_metrics = []
    for seq_idx, frame_info in head_final_frames.items():
        generated_head = frame_info['generated']
        gt_filename = frame_info['gt_filename']
        final_item = frame_info['item']
        input_images_first = frame_info['input_images_first_chunk']
        
        gt_image = Image.open(os.path.join(args.image_dir, gt_filename)).convert('RGB')
        
        metrics = compute_image_metrics(generated_head, gt_image, lpips_model, device)
        metrics['sample_id'] = final_item['id']
        metrics['prediction_type'] = final_item.get('prediction_type', 'head_camera')
        metrics['rollout_chunks'] = args.rollout_chunks
        head_metrics.append(metrics)
        
        if local_rank == 0:
            comparison = create_comparison_image(
                input_images_first, generated_head, gt_image, 
                final_item.get('prediction_type', 'head_camera')
            )
            comparison.save(os.path.join(args.output_dir, f"head_seq{seq_idx:04d}_comparison.png"))
    
    all_metrics.extend(head_metrics)
    
    all_generated_head_images = []
    all_gt_head_images = []
    all_generated_wrist_images = []
    all_gt_wrist_images = []
    all_generated_images = []
    all_gt_images = []
    
    if FID_AVAILABLE:
        all_generated_head_images = gather_image_list(local_generated_images_for_fid)
        all_gt_head_images = gather_image_list(local_gt_images_for_fid)
        
        all_generated_wrist_images = gather_image_list(local_generated_wrist_images_for_fid)
        all_gt_wrist_images = gather_image_list(local_gt_wrist_images_for_fid)
        
        all_generated_images = all_generated_head_images + all_generated_wrist_images
        all_gt_images = all_gt_head_images + all_gt_wrist_images
    
    if local_rank == 0:
        import shutil
        shutil.rmtree(temp_head_dir)
        
    
    if local_rank == 0:
        avg_metrics = {}
        
        avg_metrics['rollout_chunks'] = args.rollout_chunks
        avg_metrics['evaluation_mode'] = 'rollout' if args.rollout_chunks > 1 else 'single_chunk'
        
        for key in ['mse', 'psnr', 'ssim', 'lpips']:
            values = [m[key] for m in all_metrics if key in m]
            if values:
                avg_metrics[f'avg_{key}'] = float(np.mean(values))
                avg_metrics[f'std_{key}'] = float(np.std(values))
        
        avg_metrics['num_samples'] = len(all_metrics)
        
        if FID_AVAILABLE and len(all_generated_images) > 0 and len(all_gt_images) > 0:
            print("\n计算 FID 指标...")
            fid_metric = FrechetInceptionDistance(normalize=True).to(device)
            
            print(f"  处理 {len(all_gt_images)} 个真实图像和 {len(all_generated_images)} 个生成图像...")
            for gt_img in all_gt_images:
                gt_tensor = image_to_fid_tensor(gt_img, device)
                fid_metric.update(gt_tensor, real=True)
            
            for gen_img in all_generated_images:
                gen_tensor = image_to_fid_tensor(gen_img, device)
                fid_metric.update(gen_tensor, real=False)
            
            fid_value = fid_metric.compute()
            avg_metrics['fid'] = float(fid_value.item())
            print(f"  FID: {avg_metrics['fid']:.4f}")
        else:
            if not FID_AVAILABLE:
                print("  FID 不可用（torchmetrics 未安装）")
            else:
                print(f"  FID 跳过：生成图像数={len(all_generated_images) if FID_AVAILABLE else 0}, GT图像数={len(all_gt_images) if FID_AVAILABLE else 0}")
            avg_metrics['fid'] = None
        
        pred_types = {}
        for m in all_metrics:
            pt = m.get('prediction_type', 'unknown')
            if pt not in pred_types:
                pred_types[pt] = []
            pred_types[pt].append(m)
        
        avg_metrics['by_prediction_type'] = {}
        for pt, metrics_list in pred_types.items():
            pt_metrics = {}
            for key in ['mse', 'psnr', 'ssim', 'lpips']:
                values = [m[key] for m in metrics_list if key in m]
                if values:
                    pt_metrics[f'avg_{key}'] = float(np.mean(values))
                    pt_metrics[f'std_{key}'] = float(np.std(values))
            pt_metrics['num_samples'] = len(metrics_list)
            avg_metrics['by_prediction_type'][pt] = pt_metrics
        
        if FID_AVAILABLE and len(all_generated_images) > 0:
            # Head camera FID
            if len(all_generated_head_images) > 0 and len(all_gt_head_images) > 0:
                fid_head = FrechetInceptionDistance(normalize=True).to(device)
                for gt_img in all_gt_head_images:
                    gt_tensor = image_to_fid_tensor(gt_img, device)
                    fid_head.update(gt_tensor, real=True)
                for gen_img in all_generated_head_images:
                    gen_tensor = image_to_fid_tensor(gen_img, device)
                    fid_head.update(gen_tensor, real=False)
                head_fid_value = float(fid_head.compute().item())
                if 'head_camera' not in avg_metrics['by_prediction_type']:
                    avg_metrics['by_prediction_type']['head_camera'] = {}
                avg_metrics['by_prediction_type']['head_camera']['fid'] = head_fid_value
                print(f"  Head FID: {head_fid_value:.4f}")
            
            # Wrist camera FID
            if len(all_generated_wrist_images) > 0 and len(all_gt_wrist_images) > 0:
                fid_wrist = FrechetInceptionDistance(normalize=True).to(device)
                for gt_img in all_gt_wrist_images:
                    gt_tensor = image_to_fid_tensor(gt_img, device)
                    fid_wrist.update(gt_tensor, real=True)
                for gen_img in all_generated_wrist_images:
                    gen_tensor = image_to_fid_tensor(gen_img, device)
                    fid_wrist.update(gen_tensor, real=False)
                wrist_fid_value = float(fid_wrist.compute().item())
                wrist_types = [pt for pt in avg_metrics['by_prediction_type'].keys() if 'wrist' in pt.lower()]
                if wrist_types:
                    wrist_type = wrist_types[0]
                else:
                    wrist_type = 'wrist_camera'
                if wrist_type not in avg_metrics['by_prediction_type']:
                    avg_metrics['by_prediction_type'][wrist_type] = {}
                avg_metrics['by_prediction_type'][wrist_type]['fid'] = wrist_fid_value
                print(f"  Wrist FID ({wrist_type}): {wrist_fid_value:.4f}")
        
        with open(os.path.join(args.output_dir, 'evaluation_results.json'), 'w') as f:
            json.dump(avg_metrics, f, indent=2)
        
        with open(os.path.join(args.output_dir, 'detailed_results.json'), 'w') as f:
            json.dump(all_metrics, f, indent=2)
        
        print("\n评估完成！")
        print(f"结果保存在: {args.output_dir}")
        print(json.dumps(avg_metrics, indent=2))
    
    dist.destroy_process_group()


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

