import os
from PIL import Image, ImageFilter, ImageDraw
import itertools
import cv2
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as T
import json
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torchvision.transforms import ToPILImage, ToTensor
from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Lambda,
    Normalize,
    ToTensor,
)

from decord import VideoReader
from torch.utils.data.dataset import Dataset
from packaging import version as pver

from transformers import pipeline
import depth_pro.depth_pro as depth_pro

# from src.flux.pipeline_tools import encode_images, encode_poses
from src.train.template import object_templates, prompt_templates

class Subject200KDateset(Dataset):
    """Subject-driven pairs cut from side-by-side composite images.

    Each ``base_dataset`` item provides a PIL ``"image"`` holding two
    ``image_size`` x ``image_size`` panels (with ``padding`` pixels around and
    between them) and a ``"description"`` mapping with ``"description_0"`` and
    ``"description_1"`` captions. Every base item yields two samples: an even
    index uses the left panel as the target and the right one as the
    condition; an odd index swaps the roles.
    """

    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        image_size: int = 512,
        padding: int = 0,
        condition_type: str = "subject",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset

        # Output resolutions of the condition and target tensors.
        self.condition_size = condition_size
        self.target_size = target_size

        # Layout of the composite source image.
        self.image_size = image_size
        self.padding = padding

        self.condition_type = condition_type

        # Per-sample probabilities of blanking the caption / condition image.
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob

        self.return_pil_image = return_pil_image
        self.to_tensor = T.ToTensor()

    def __len__(self):
        # Two samples per composite: left-as-target and right-as-target.
        return 2 * len(self.base_dataset)

    def __getitem__(self, idx):
        """Return one (target, condition, caption) sample as a dict."""
        target = idx % 2
        item = self.base_dataset[idx // 2]
        image = item["image"]

        # Crop boxes (left, upper, right, lower) of the two panels.
        size, pad = self.image_size, self.padding
        left_box = (pad, pad, size + pad, size + pad)
        right_box = (size + 2 * pad, pad, 2 * size + 2 * pad, size + pad)
        left_img = image.crop(left_box)
        right_img = image.crop(right_box)

        if target == 0:
            target_image, condition_img = left_img, right_img
        else:
            target_image, condition_img = right_img, left_img

        condition_img = condition_img.resize((self.condition_size,) * 2).convert("RGB")
        target_image = target_image.resize((self.target_size,) * 2).convert("RGB")

        caption_key = "description_1" if target else "description_0"
        description = item["description"][caption_key]

        # Both draws always happen, whatever their outcome.
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        sample = {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # 16 is the downscale factor of the image.
            "position_delta": np.array([0, -self.condition_size // 16]),
        }
        if self.return_pil_image:
            sample["pil_image"] = image
        return sample


class ImageConditionDataset(Dataset):
    """Pair each image with a condition image derived from it.

    ``base_dataset[idx]`` must provide a PIL image under ``"jpg"`` and a dict
    under ``"json"`` whose ``"prompt"`` is the caption. ``condition_type``
    selects how the condition is built from the target image: ``"canny"``,
    ``"coloring"``, ``"deblurring"``, ``"depth"``, ``"depth_pred"``, ``"fill"``
    or ``"sr"``.
    """

    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "canny",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    @property
    def depth_pipe(self):
        """Depth-estimation pipeline (depth-anything-small on CPU), created on first use."""
        if not hasattr(self, "_depth_pipe"):
            from transformers import pipeline

            self._depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cpu",
            )
        return self._depth_pipe

    def _get_canny_edge(self, img):
        """Return a 3-channel Canny edge map (thresholds 100/200) of ``img``.

        ``img`` is first scaled so that its longer side equals ``condition_size``.
        """
        resize_ratio = self.condition_size / max(img.size)
        img = img.resize(
            (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
        )
        img_np = np.array(img)
        img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(img_gray, 100, 200)
        return Image.fromarray(edges).convert("RGB")

    def _get_fill_condition(self, image):
        """Return ``image`` at condition resolution with part of it blacked out.

        A random axis-aligned rectangle is drawn. About half the time only the
        rectangle is kept (black outside it); otherwise the rectangle itself is
        blacked out.
        """
        condition_img = image.resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        # Work at condition resolution so the output size matches the other
        # condition types and the blank image used for image dropout.
        w, h = condition_img.size
        x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
        y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
        mask = Image.new("L", condition_img.size, 0)
        ImageDraw.Draw(mask).rectangle([x1, y1, x2, y2], fill=255)
        if random.random() > 0.5:
            mask = Image.eval(mask, lambda a: 255 - a)
        return Image.composite(
            condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
        )

    def __getitem__(self, idx):
        """Return the target tensor, its condition tensor, caption and metadata.

        Raises:
            ValueError: if ``condition_type`` is not a supported type.
        """
        record = self.base_dataset[idx]
        image = record["jpg"].resize((self.target_size, self.target_size)).convert("RGB")
        description = record["json"]["prompt"]

        # Get the condition image
        position_delta = np.array([0, 0])
        if self.condition_type == "canny":
            condition_img = self._get_canny_edge(image)
        elif self.condition_type == "coloring":
            condition_img = (
                image.resize((self.condition_size, self.condition_size))
                .convert("L")
                .convert("RGB")
            )
        elif self.condition_type == "deblurring":
            blur_radius = random.randint(1, 10)
            condition_img = (
                image.convert("RGB")
                .resize((self.condition_size, self.condition_size))
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .convert("RGB")
            )
        elif self.condition_type == "depth":
            # Resize to condition_size like every other condition type.
            condition_img = (
                self.depth_pipe(image)["depth"]
                .convert("RGB")
                .resize((self.condition_size, self.condition_size))
            )
        elif self.condition_type == "depth_pred":
            # Inverse task: the RGB image is the condition, its depth the target.
            condition_img = image
            image = self.depth_pipe(condition_img)["depth"].convert("RGB")
            description = f"[depth] {description}"
        elif self.condition_type == "fill":
            condition_img = self._get_fill_condition(image)
        elif self.condition_type == "sr":
            condition_img = image.resize(
                (self.condition_size, self.condition_size)
            ).convert("RGB")
            position_delta = np.array([0, -self.condition_size // 16])
        else:
            raise ValueError(f"Condition type {self.condition_type} not implemented")

        # Randomly drop text or image (both draws always happen).
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
        }

class LooseConditionDataset(Dataset):
    """Target image plus one rendered depth map per entity ("loose" condition).

    Samples are listed in a JSON-lines file (``json_path``); each record must
    contain ``"image_id"``, ``"caption"`` and ``"entities"``. The target is read
    from ``<data_path>/<image_id>.png`` (falling back to
    ``<data_path>/<image_id>/<image_id>.png``), and entity ``i``'s condition is
    read from ``<depth_path>/<image_id>/render_depth_{i}.png``.
    """

    # Consecutive failed samples __getitem__ tolerates before raising instead of
    # looping forever (e.g. when depth_path is wrong or left as None).
    max_retries = 100

    def __init__(
        self,
        data_path,
        json_path,
        depth_path=None,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "loose_condition",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        max_entity_len: int = 2,
        aug: bool = False,
    ):
        self.data_path = data_path
        self.depth_path = depth_path
        self.json_path = json_path
        self.base_dataset = []
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        # -1 (any negative value) means "keep every entity".
        self.max_entity_len = max_entity_len
        # Enables random horizontal flips in get_batch.
        self.aug = aug

        self.to_tensor = T.ToTensor()
        self.load_dataset()

        self.depth_transform = Compose(
            [
                ToTensor(),
                Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

    def load_dataset(self):
        """Append every JSON record (one per line) of ``json_path`` to ``base_dataset``."""
        with open(self.json_path, 'r', encoding='utf-8') as file:
            for line in file:
                self.base_dataset.append(json.loads(line))

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``get_batch(idx)``, resampling a random index whenever loading fails.

        Raises:
            RuntimeError: if ``max_retries`` consecutive attempts all fail.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                return self.get_batch(idx)
            except Exception as e:  # best-effort: skip unreadable/incomplete samples
                print(f'sample idx:{idx} failed:', e)
                last_error = e
                idx = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"{type(self).__name__}: {self.max_retries} consecutive samples failed to load"
        ) from last_error

    @staticmethod
    def _rank_by_nonzero_mean(condition_imgs, max_len):
        """Return indices of ``condition_imgs`` ordered by mean non-zero value, largest first.

        The mean is taken over the non-zero elements of dims 1..3 of each item
        (so a 4-D stack is expected); all-zero items rank with mean 0. The list
        is truncated to ``max_len`` entries unless ``max_len`` is negative.
        """
        nonzero = condition_imgs != 0
        sums = (condition_imgs * nonzero).sum(dim=(1, 2, 3))
        counts = nonzero.sum(dim=(1, 2, 3)).float()
        # 0/0 for all-zero images gives NaN; treat those as mean 0.
        means = torch.nan_to_num(sums / counts, nan=0.0)
        order = torch.argsort(means, descending=True).tolist()
        if max_len > -1:
            order = order[:max_len]
        return order

    def get_batch(self, idx):
        """Load sample ``idx``; raises if a file is missing or the record has no entities."""
        # Draw happens even when aug is off; flip_flag is then False.
        flip_flag = random.random() > 0.5 and self.aug

        record = self.base_dataset[idx]
        image_id = record['image_id']
        try:
            raw = Image.open(f"{self.data_path}/{image_id}.png")
        except OSError:  # not in the flat layout: try the per-id sub-directory
            raw = Image.open(f"{self.data_path}/{image_id}/{image_id}.png")
        with raw:
            image = raw.resize((self.target_size, self.target_size)).convert("RGB")
        if flip_flag:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        description = record["caption"]
        entities = record["entities"]
        position_delta = np.array([0, 0])

        # One condition image per entity.
        cond_wh = (self.condition_size, self.condition_size)
        condition_imgs = []
        for i in range(len(entities)):
            depth_file = f"{self.depth_path}/{image_id}/render_depth_{i}.png"
            with Image.open(depth_file) as depth_img:
                condition_img = depth_img.resize(cond_wh).convert("RGB")
            if flip_flag:
                condition_img = condition_img.transpose(Image.FLIP_LEFT_RIGHT)
            condition_imgs.append(condition_img)

        # Randomly drop text or image (both draws always happen).
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_imgs = [Image.new("RGB", cond_wh, (0, 0, 0))] * len(entities)

        # torch.stack raises RuntimeError for a record with no entities, which
        # makes __getitem__ resample.
        condition_imgs = torch.stack([self.to_tensor(img) for img in condition_imgs])
        sorted_indices = self._rank_by_nonzero_mean(condition_imgs, self.max_entity_len)

        return {
            "image": self.to_tensor(image),
            "condition": condition_imgs[sorted_indices],
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
        }

class EligenDepthDataset(Dataset):
    """Target image, a whole-image condition tensor and per-entity bbox masks.

    Records come from the JSON-lines file ``json_path`` and need
    ``"image_id"``, ``"caption"`` and ``"entities"``; each entity provides an
    ``"entity"`` prompt and a ``"bbox"`` given as ``[x_min, y_min, x_max,
    y_max]`` fractions of the image side.
    """

    # Consecutive failed samples __getitem__ tolerates before raising.
    max_retries = 100

    def __init__(
        self,
        data_path,
        json_path,
        depth_path=None,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "eligen_depth",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        max_entity_len: int = 2,
    ):
        self.data_path = data_path
        self.depth_path = depth_path
        self.json_path = json_path
        self.base_dataset = []
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        # -1 (any negative value) means "keep every entity".
        self.max_entity_len = max_entity_len

        self.to_tensor = T.ToTensor()
        self.load_dataset()
        # ToTensor -> [0, 1]; Normalize(0.5, 0.5) -> [-1, 1]; then cast to bfloat16.
        self.depth_transform = Compose(
            [
                ToTensor(),
                Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ConvertImageDtype(torch.bfloat16),
            ]
        )

    def load_dataset(self):
        """Append every JSON record (one per line) of ``json_path`` to ``base_dataset``."""
        with open(self.json_path, 'r', encoding='utf-8') as file:
            for line in file:
                self.base_dataset.append(json.loads(line))

    def __len__(self):
        return len(self.base_dataset)

    @property
    def depth_pipe(self):
        """Depth-estimation pipeline (depth-anything-small on CPU), created on first use."""
        if not hasattr(self, "_depth_pipe"):
            self._depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cpu",
            )

        return self._depth_pipe

    def __getitem__(self, idx):
        """Return ``get_batch(idx)``, resampling a random index whenever loading fails.

        Raises:
            RuntimeError: if ``max_retries`` consecutive attempts all fail.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                return self.get_batch(idx)
            except Exception as e:  # best-effort: skip unreadable/incomplete samples
                print(f'sample idx:{idx} failed:', e)
                last_error = e
                idx = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"{type(self).__name__}: {self.max_retries} consecutive samples failed to load"
        ) from last_error

    def _limit_entities(self, items):
        """Return the first ``max_entity_len`` items, or all of them when the limit is negative."""
        if self.max_entity_len > -1:
            return items[:self.max_entity_len]
        return items

    def get_batch(self, idx):
        """Load sample ``idx``; raises if the target image cannot be opened."""
        record = self.base_dataset[idx]
        image_id = record['image_id']
        try:
            raw = Image.open(f"{self.data_path}/{image_id}.png")
        except OSError:  # not in the flat layout: try the per-id sub-directory
            raw = Image.open(f"{self.data_path}/{image_id}/{image_id}.png")
        with raw:
            image = raw.resize((self.target_size, self.target_size)).convert("RGB")
        description = record["caption"]
        entities = record["entities"]

        position_delta = np.array([0, 0])

        # NOTE(review): depth estimation is not called here, so the "condition"
        # is the target image itself passed through depth_transform — confirm
        # this is intended.
        condition_imgs = self.depth_transform(np.array(image))

        # Entity prompts and (condition_size//8, condition_size//8, 3) uint8 0/1 box masks.
        size = self.condition_size
        mask_wh = (size // 8, size // 8)
        eligen_entity_prompts = []
        eligen_entity_masks = []
        for entity in entities:
            eligen_entity_prompts.append(entity["entity"])
            coordinates = entity["bbox"]
            # Convert relative coordinates to pixel coordinates.
            x_min = int(size * coordinates[0])
            y_min = int(size * coordinates[1])
            x_max = int(size * coordinates[2])
            y_max = int(size * coordinates[3])

            mask = Image.new("L", (size, size), 0)
            ImageDraw.Draw(mask).rectangle([x_min, y_min, x_max, y_max], fill=255)
            mask = np.array(mask.convert("RGB").resize(mask_wh))
            eligen_entity_masks.append(np.where(mask > 0, 1, 0).astype(np.uint8))

        # Randomly drop text or image (both draws always happen).
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            # A single blank image is enough; this also works for records
            # without entities.
            blank = Image.new("RGB", (size, size), (0, 0, 0))
            condition_imgs = self.depth_transform(np.array(blank))

        return {
            "image": self.to_tensor(image),
            "condition": condition_imgs,
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            "eligen_entity_prompts": self._limit_entities(eligen_entity_prompts),
            "eligen_entity_masks": self._limit_entities(eligen_entity_masks),
        }

class EligenLooseDataset(Dataset):
    """Target image, per-entity depth renders, entity prompts and entity masks.

    Records come from the JSON-lines file ``json_path`` and need
    ``"image_id"``, ``"caption"`` and ``"entities"`` (each entity with an
    ``"entity"`` prompt). Entity ``i``'s condition is
    ``<depth_path>/<image_id>/render_depth_{i}.png`` and its mask marks the
    non-zero pixels of that render at 1/8 resolution.
    """

    # Consecutive failed samples __getitem__ tolerates before raising.
    max_retries = 100

    def __init__(
        self,
        data_path,
        json_path,
        depth_path=None,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "eligen_loose",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        max_entity_len: int = 2,
        aug: bool = False,
    ):
        self.data_path = data_path
        self.depth_path = depth_path
        self.json_path = json_path
        self.base_dataset = []
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        # -1 (any negative value) means "keep every entity".
        self.max_entity_len = max_entity_len
        # Enables random horizontal flips in get_batch.
        self.aug = aug

        self.to_tensor = T.ToTensor()
        self.load_dataset()

    def load_dataset(self):
        """Append every JSON record (one per line) of ``json_path`` to ``base_dataset``."""
        with open(self.json_path, 'r', encoding='utf-8') as file:
            for line in file:
                self.base_dataset.append(json.loads(line))

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``get_batch(idx)``, resampling a random index whenever loading fails.

        Raises:
            RuntimeError: if ``max_retries`` consecutive attempts all fail.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                return self.get_batch(idx)
            except Exception as e:  # best-effort: skip unreadable/incomplete samples
                print(f'sample idx:{idx} failed:', e)
                last_error = e
                idx = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"{type(self).__name__}: {self.max_retries} consecutive samples failed to load"
        ) from last_error

    @staticmethod
    def _rank_by_nonzero_mean(condition_imgs, max_len):
        """Return indices of ``condition_imgs`` ordered by mean non-zero value, largest first.

        The mean is taken over the non-zero elements of dims 1..3 of each item
        (so a 4-D stack is expected); all-zero items rank with mean 0. The list
        is truncated to ``max_len`` entries unless ``max_len`` is negative.
        """
        nonzero = condition_imgs != 0
        sums = (condition_imgs * nonzero).sum(dim=(1, 2, 3))
        counts = nonzero.sum(dim=(1, 2, 3)).float()
        # 0/0 for all-zero images gives NaN; treat those as mean 0.
        means = torch.nan_to_num(sums / counts, nan=0.0)
        order = torch.argsort(means, descending=True).tolist()
        if max_len > -1:
            order = order[:max_len]
        return order

    def get_batch(self, idx):
        """Load sample ``idx``; raises if a file is missing or the record has no entities.

        Entities (conditions, prompts, masks) are reordered by the mean of the
        non-zero pixels of their condition image, largest first, and truncated
        to ``max_entity_len``.
        """
        # Draw happens even when aug is off; flip_flag is then False.
        flip_flag = random.random() > 0.5 and self.aug

        record = self.base_dataset[idx]
        image_id = record['image_id']
        try:
            raw = Image.open(f"{self.data_path}/{image_id}.png")
        except OSError:  # not in the flat layout: try the per-id sub-directory
            raw = Image.open(f"{self.data_path}/{image_id}/{image_id}.png")
        with raw:
            image = raw.resize((self.target_size, self.target_size)).convert("RGB")
        if flip_flag:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        description = record["caption"]
        entities = record["entities"]
        position_delta = np.array([0, 0])

        cond_wh = (self.condition_size, self.condition_size)
        mask_wh = (self.condition_size // 8, self.condition_size // 8)
        condition_imgs = []
        eligen_entity_prompts = []
        eligen_entity_masks = []
        for i, entity in enumerate(entities):
            depth_file = f"{self.depth_path}/{image_id}/render_depth_{i}.png"
            with Image.open(depth_file) as depth_img:
                condition_img = depth_img.resize(cond_wh).convert("RGB")
            if flip_flag:
                condition_img = condition_img.transpose(Image.FLIP_LEFT_RIGHT)
            condition_imgs.append(condition_img)

            eligen_entity_prompts.append(entity["entity"])
            # Mask = non-zero pixels of the (possibly flipped) render, as a
            # (condition_size//8, condition_size//8, 3) uint8 array of 0/1.
            mask = np.array(condition_img.resize(mask_wh))
            eligen_entity_masks.append(np.where(mask > 0, 1, 0).astype(np.uint8))

        # Randomly drop text or image (both draws always happen).
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            # Masks above are still derived from the real renders.
            condition_imgs = [Image.new("RGB", cond_wh, (0, 0, 0))] * len(entities)

        # torch.stack raises RuntimeError for a record with no entities, which
        # makes __getitem__ resample.
        condition_imgs = torch.stack([self.to_tensor(img) for img in condition_imgs])
        order = self._rank_by_nonzero_mean(condition_imgs, self.max_entity_len)

        return {
            "image": self.to_tensor(image),
            "condition": condition_imgs[order],
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            "eligen_entity_prompts": [eligen_entity_prompts[k] for k in order],
            "eligen_entity_masks": [eligen_entity_masks[k] for k in order],
        }

class EligenLoose2DDataset(Dataset):
    """Target image with per-entity prompts and masks taken from depth renders.

    Records come from the JSON-lines file ``json_path`` and need
    ``"image_id"``, ``"caption"`` and ``"entities"``. Entity ``i``'s mask marks
    the non-zero pixels of ``<depth_path>/<image_id>/render_depth_{i}.png`` at
    1/8 resolution. No condition tensor is returned.
    """

    # Consecutive failed samples __getitem__ tolerates before raising.
    max_retries = 100

    def __init__(
        self,
        data_path,
        json_path,
        depth_path=None,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "eligen_depth",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        max_entity_len: int = 2,
        aug: bool = False,
    ):
        self.data_path = data_path
        self.depth_path = depth_path
        self.json_path = json_path
        self.base_dataset = []
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        # -1 (any negative value) means "keep every entity".
        self.max_entity_len = max_entity_len
        # Enables random horizontal flips in get_batch.
        self.aug = aug

        self.to_tensor = T.ToTensor()
        self.load_dataset()

    def load_dataset(self):
        """Append every JSON record (one per line) of ``json_path`` to ``base_dataset``."""
        with open(self.json_path, 'r', encoding='utf-8') as file:
            for line in file:
                self.base_dataset.append(json.loads(line))

    def __len__(self):
        return len(self.base_dataset)

    @property
    def depth_pipe(self):
        """Depth-estimation pipeline (depth-anything-small on CPU), created on first use."""
        if not hasattr(self, "_depth_pipe"):
            self._depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cpu",
            )

        return self._depth_pipe

    def __getitem__(self, idx):
        """Return ``get_batch(idx)``, resampling a random index whenever loading fails.

        Raises:
            RuntimeError: if ``max_retries`` consecutive attempts all fail.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                return self.get_batch(idx)
            except Exception as e:  # best-effort: skip unreadable/incomplete samples
                print(f'sample idx:{idx} failed:', e)
                last_error = e
                idx = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"{type(self).__name__}: {self.max_retries} consecutive samples failed to load"
        ) from last_error

    def _limit_entities(self, items):
        """Return the first ``max_entity_len`` items, or all of them when the limit is negative."""
        if self.max_entity_len > -1:
            return items[:self.max_entity_len]
        return items

    def get_batch(self, idx):
        """Load sample ``idx``; raises if a file is missing or the record has no entities."""
        # Draw happens even when aug is off; flip_flag is then False.
        flip_flag = random.random() > 0.5 and self.aug

        record = self.base_dataset[idx]
        image_id = record['image_id']
        try:
            raw = Image.open(f"{self.data_path}/{image_id}.png")
        except OSError:  # not in the flat layout: try the per-id sub-directory
            raw = Image.open(f"{self.data_path}/{image_id}/{image_id}.png")
        with raw:
            image = raw.resize((self.target_size, self.target_size)).convert("RGB")
        if flip_flag:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        description = record["caption"]
        entities = record["entities"]
        position_delta = np.array([0, 0])

        cond_wh = (self.condition_size, self.condition_size)
        mask_wh = (self.condition_size // 8, self.condition_size // 8)
        eligen_entity_prompts = []
        eligen_entity_masks = []
        for i, entity in enumerate(entities):
            depth_file = f"{self.depth_path}/{image_id}/render_depth_{i}.png"
            with Image.open(depth_file) as depth_img:
                condition_img = depth_img.resize(cond_wh).convert("RGB")
            if flip_flag:
                condition_img = condition_img.transpose(Image.FLIP_LEFT_RIGHT)

            eligen_entity_prompts.append(entity["entity"])
            # Mask = non-zero pixels of the (possibly flipped) render, as a
            # (condition_size//8, condition_size//8, 3) uint8 array of 0/1.
            mask = np.array(condition_img.resize(mask_wh))
            eligen_entity_masks.append(np.where(mask > 0, 1, 0).astype(np.uint8))

        # Randomly drop text or image (both draws always happen). The image
        # draw is kept for an unchanged RNG stream; no condition is returned.
        drop_text = random.random() < self.drop_text_prob
        random.random() < self.drop_image_prob
        if drop_text:
            description = ""

        # Records without entities are rejected so __getitem__ resamples.
        if not entities:
            raise RuntimeError(f"sample {image_id} has no entities")

        return {
            "image": self.to_tensor(image),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            "eligen_entity_prompts": self._limit_entities(eligen_entity_prompts),
            "eligen_entity_masks": self._limit_entities(eligen_entity_masks),
        }

class EligenPoseDataset(Dataset):
    """Target image with per-entity prompts, bbox masks and camera-pose annotations.

    Records come from the JSON-lines file ``json_path`` and need
    ``"image_id"``, ``"caption"``, ``"entities"`` and ``"cam_entity_idx"``.
    Each entity provides an ``"entity"`` prompt and a ``"bbox"`` given as
    ``[x_min, y_min, x_max, y_max]`` fractions of the image side; entities
    listed in ``"cam_entity_idx"`` also provide a ``"pose"``
    ``(phi, theta, delta, confidence)``.
    """

    # Consecutive failed samples __getitem__ tolerates before raising.
    max_retries = 100

    def __init__(
        self,
        data_path,
        json_path,
        depth_path=None,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "eligen_loose",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        max_entity_len: int = 2,
        aug: bool = False,
    ):
        self.data_path = data_path
        self.depth_path = depth_path
        self.json_path = json_path
        self.base_dataset = []
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        # -1 (any negative value) means "keep every entity".
        self.max_entity_len = max_entity_len
        # Enables random horizontal flips in get_batch.
        self.aug = aug

        self.to_tensor = T.ToTensor()
        self.load_dataset()

    def load_dataset(self):
        """Append every JSON record (one per line) of ``json_path`` to ``base_dataset``."""
        with open(self.json_path, 'r', encoding='utf-8') as file:
            for line in file:
                self.base_dataset.append(json.loads(line))

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``get_batch(idx)``, resampling a random index whenever loading fails.

        Raises:
            RuntimeError: if ``max_retries`` consecutive attempts all fail.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                return self.get_batch(idx)
            except Exception as e:  # best-effort: skip unreadable/incomplete samples
                print(f'sample idx:{idx} failed:', e)
                last_error = e
                idx = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"{type(self).__name__}: {self.max_retries} consecutive samples failed to load"
        ) from last_error

    def get_batch(self, idx):
        """Load sample ``idx`` with entity prompts, box masks and camera poses.

        Only the first ``max_entity_len`` entities are used (all when the limit
        is negative). Posed entities get an ``<extra_id_0>`` prompt prefix and
        contribute ``[phi, theta, delta]`` to ``orient`` plus their pixel box
        to ``orient_bboxs``. When the sample is flipped, boxes are mirrored and
        ``phi`` becomes ``360 - phi``. Raises if the image cannot be opened or
        no kept entity has a pose.
        """
        # Draw happens even when aug is off; flip_flag is then False.
        flip_flag = random.random() > 0.5 and self.aug

        record = self.base_dataset[idx]
        image_id = record['image_id']
        try:
            raw = Image.open(f"{self.data_path}/{image_id}.png")
        except OSError:  # not in the flat layout: try the per-id sub-directory
            raw = Image.open(f"{self.data_path}/{image_id}/{image_id}.png")
        with raw:
            image = raw.resize((self.target_size, self.target_size)).convert("RGB")
        if flip_flag:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        description = record["caption"]
        entities = record["entities"]
        cam_entity_idxs = record["cam_entity_idx"]
        position_delta = np.array([0, 0])

        size = self.condition_size
        mask_wh = (size // 8, size // 8)
        condition_imgs = []
        cam_entity_idx = []
        orient = []
        orient_bboxs = []
        eligen_entity_prompts = []
        eligen_entity_masks = []
        for i, entity in enumerate(entities):
            if self.max_entity_len > -1 and i == self.max_entity_len:
                break

            has_pose = i in cam_entity_idxs
            prompt = entity["entity"]
            eligen_entity_prompts.append('<extra_id_0>' + prompt if has_pose else prompt)

            coordinates = entity["bbox"]
            # Convert relative coordinates to pixel coordinates.
            x_min = int(size * coordinates[0])
            y_min = int(size * coordinates[1])
            x_max = int(size * coordinates[2])
            y_max = int(size * coordinates[3])
            if flip_flag:
                # Mirror the 2D box horizontally to match the flipped image.
                x_min, x_max = size - x_max, size - x_min

            # (size//8, size//8, 3) uint8 0/1 box mask.
            mask = Image.new("L", (size, size), 0)
            ImageDraw.Draw(mask).rectangle([x_min, y_min, x_max, y_max], fill=255)
            mask = np.array(mask.convert("RGB").resize(mask_wh))
            eligen_entity_masks.append(np.where(mask > 0, 1, 0).astype(np.uint8))

            if has_pose:
                phi, theta, delta, _confidence = entity["pose"]
                if flip_flag:
                    phi = 360 - phi
                condition_imgs.append(torch.tensor([phi, theta, delta]))
                cam_entity_idx.append(i)
                orient.append([phi, theta, delta])
                orient_bboxs.append([x_min, y_min, x_max, y_max])

        # torch.stack raises RuntimeError when no kept entity has a pose, which
        # makes __getitem__ resample.
        condition_imgs = torch.stack(condition_imgs)

        # Randomly drop text or image (both draws always happen).
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            # NOTE(review): condition_imgs is not part of the returned dict, so
            # image dropout currently has no visible effect — confirm intent.
            condition_imgs = torch.zeros_like(condition_imgs)

        # The loop above already stops at max_entity_len, so no truncation is needed.
        return {
            "image": self.to_tensor(image),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            "cam_entity_idx": cam_entity_idx,
            "orient": orient,
            "orient_bboxs": orient_bboxs,
        }

class EligenCameraDataset(Dataset):
    """Images paired with per-entity box masks and rotation-matrix pose conditions.

    ``json_path`` is read as JSON Lines (one record per non-blank line).
    ``get_batch`` reads these fields from a record: ``image_id``, ``caption``,
    ``entities`` (each with ``entity`` text, ``bbox`` as four fractions of the
    image side, and ``pose`` unpacked as ``(phi, theta, delta, confidence)``)
    and ``cam_entity_idx`` (indices of entities that receive a pose condition).
    """

    def __init__(
        self,
        data_path,
        json_path,
        depth_path=None,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "eligen_loose",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        max_entity_len: int = 2,
        aug: bool = False,
    ):
        self.data_path = data_path
        self.depth_path = depth_path  # stored only; not read by get_batch
        self.json_path = json_path
        self.base_dataset = []
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image  # stored only; not read by get_batch
        self.max_entity_len = max_entity_len  # values <= -1 disable the entity cap
        self.aug = aug  # enables random horizontal flips in get_batch

        self.to_tensor = T.ToTensor()
        self.load_dataset()

    def load_dataset(self):
        """Append one parsed JSON record per non-blank line of ``self.json_path``."""
        with open(self.json_path, 'r', encoding='utf-8') as file:
            for line in file:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing with a JSONDecodeError.
                if not line.strip():
                    continue
                self.base_dataset.append(json.loads(line))

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # Best-effort loading: any failing sample (missing file, malformed
        # record, no posed entity) is replaced by a random index. Note that
        # this retries indefinitely if every sample fails.
        while True:
            try:
                return self.get_batch(idx)
            except Exception as e:
                print(f'sample idx:{idx} failed:', e)
                idx = random.randint(0, len(self) - 1)

    @staticmethod
    def _pose_latent(phi, theta, delta):
        """Return the rotation ``R_camera @ R_polar @ R_azimuth`` flattened to 9 values.

        Args:
            phi: azimuth in degrees, applied as a rotation about the z-axis.
            theta: polar angle in degrees, applied as a rotation about the y-axis.
            delta: camera roll in degrees, applied as a rotation about the x-axis.

        Returns:
            1-D tensor of length 9 (the 3x3 matrix in row-major order).
        """
        def rot_z(a):
            c, s = torch.cos(a), torch.sin(a)
            return torch.tensor([[c, -s, 0], [s, c, 0], [0, 0, 1]])

        def rot_y(a):
            c, s = torch.cos(a), torch.sin(a)
            return torch.tensor([[c, 0, s], [0, 1, 0], [-s, 0, c]])

        def rot_x(a):
            c, s = torch.cos(a), torch.sin(a)
            return torch.tensor([[1, 0, 0], [0, c, -s], [0, s, c]])

        R_azimuth = rot_z(torch.deg2rad(torch.tensor(phi)))
        R_polar = rot_y(torch.deg2rad(torch.tensor(theta)))
        R_camera = rot_x(torch.deg2rad(torch.tensor(delta)))
        return torch.mm(R_camera, torch.mm(R_polar, R_azimuth)).flatten()

    def _open_image(self, image_id):
        """Open ``<data_path>/<id>.png``, falling back to ``<data_path>/<id>/<id>.png``."""
        try:
            return Image.open(f"{self.data_path}/{image_id}.png")
        except OSError:
            # FileNotFoundError / unreadable image: try the per-id subfolder layout.
            return Image.open(f"{self.data_path}/{image_id}/{image_id}.png")

    def _box_mask(self, x_min, y_min, x_max, y_max):
        """Rasterize a filled box at ``condition_size`` and downsample it by 8.

        Returns a uint8 array of shape
        ``(condition_size // 8, condition_size // 8, 3)`` that is 1 wherever
        the downsampled pixel is non-zero and 0 elsewhere. The factor 8
        presumably matches the latent resolution -- TODO confirm.
        """
        mask = Image.new("L", (self.condition_size, self.condition_size), 0)
        ImageDraw.Draw(mask).rectangle([x_min, y_min, x_max, y_max], fill=255)
        mask = mask.convert("RGB")
        small = np.array(mask.resize((self.condition_size // 8, self.condition_size // 8)))
        return np.where(small > 0, 1, 0).astype(np.uint8)

    def get_batch(self, idx):
        """Build the training sample for record ``idx``.

        Returns a dict with the target image tensor, the stacked pose
        conditions (one flattened rotation per posed entity), the caption,
        per-entity prompts and box masks, and the indices, angles and pixel
        boxes of the posed entities.

        Raises:
            ValueError: if no entity within the entity cap carries a pose
                (``__getitem__`` then resamples another index).
        """
        # The RNG is drawn even when aug is off, so the random stream does not
        # depend on the flag.
        flip_flag = random.random() > 0.5 and self.aug

        record = self.base_dataset[idx]
        image_id = record['image_id']
        image = self._open_image(image_id)
        image = image.resize((self.target_size, self.target_size)).convert("RGB")
        if flip_flag:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        description = record["caption"]
        entities = record["entities"]
        cam_entity_idxs = set(record["cam_entity_idx"])

        position_delta = np.array([0, 0])

        size = self.condition_size
        limit = self.max_entity_len if self.max_entity_len > -1 else None

        condition_imgs = []
        cam_entity_idx = []
        orient = []
        orient_bboxs = []
        eligen_entity_prompts = []
        eligen_entity_masks = []
        for i, entity in enumerate(entities[:limit]):
            eligen_entity_prompts.append(entity["entity"])

            # bbox values are fractions of the side length; convert to pixels.
            bbox = entity["bbox"]
            x_min = int(size * bbox[0])
            y_min = int(size * bbox[1])
            x_max = int(size * bbox[2])
            y_max = int(size * bbox[3])
            if flip_flag:
                # Mirror the 2D box horizontally to match the flipped image.
                x_min, x_max = size - x_max, size - x_min

            eligen_entity_masks.append(self._box_mask(x_min, y_min, x_max, y_max))

            if i in cam_entity_idxs:
                phi, theta, delta, _confidence = entity["pose"]
                if flip_flag:
                    phi = 360 - phi  # mirror the azimuth together with the image
                condition_imgs.append(self._pose_latent(phi, theta, delta))
                cam_entity_idx.append(i)
                orient.append([phi, theta, delta])
                orient_bboxs.append([x_min, y_min, x_max, y_max])

        if not condition_imgs:
            raise ValueError(f"sample idx:{idx} has no entity with a pose condition")
        condition_imgs = torch.stack(condition_imgs)

        # Randomly drop text or image (classifier-free-guidance style dropout).
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_imgs = torch.zeros_like(condition_imgs)

        return {
            "image": self.to_tensor(image),
            "condition": condition_imgs,
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            "eligen_entity_prompts": eligen_entity_prompts,
            "eligen_entity_masks": eligen_entity_masks,
            "cam_entity_idx": cam_entity_idx,
            "orient": orient,
            "orient_bboxs": orient_bboxs,
        }

class PoseDataset(Dataset):
    """Text-only dataset of templated multi-object prompts with facing directions.

    Samples are generated in memory by ``load_dataset``; no files are read.
    ``get_batch`` returns the plain prompt, the direction-annotated prompt,
    the object names and a variant index. ``data_path``, ``json_path``,
    ``depth_path``, ``condition_size``, ``target_size``, ``drop_image_prob``,
    ``return_pil_image`` and ``aug`` are stored but not read by this class.
    """

    def __init__(
        self,
        data_path,
        json_path,
        depth_path=None,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "eligen_loose",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        max_entity_len: int = 2,
        aug: bool = False,
    ):
        self.data_path = data_path
        self.depth_path = depth_path
        self.json_path = json_path
        self.base_dataset = []
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.max_entity_len = max_entity_len
        self.aug = aug

        self.load_dataset()

    @staticmethod
    def _format_objects(obj_list, dir_list=None):
        """Render objects, optionally each followed by a direction, as an English list.

        ("cat",) -> "a cat"; ("cat", "dog") -> "a cat and a dog";
        ("cat", "dog", "lion") -> "a cat, a dog and a lion".
        With ``dir_list`` each object phrase is followed by its direction.
        """
        if dir_list is None:
            phrases = [f"a {obj}" for obj in obj_list]
        else:
            phrases = [f"a {obj} {direction}" for obj, direction in zip(obj_list, dir_list)]
        if not phrases:
            return ""
        if len(phrases) == 1:
            return phrases[0]
        # Previously the leading phrases were joined with "" which produced
        # "a cata dog and a lion" for three or more objects.
        return ", ".join(phrases[:-1]) + f" and {phrases[-1]}"

    def load_dataset(self):
        """Enumerate every (prompt, object permutation, direction permutation) sample.

        For each prompt in ``prompt_templates`` and each ordered selection of
        1..``max_entity_len`` distinct objects from ``train_obj``, every
        ordered assignment of distinct directions is added 10 times (variant
        index 0-9). Each entry is ``(plain_prompt, prompt_with_directions,
        objects, variant)`` where ``<subject>`` in the template is replaced
        by the object phrase.
        """
        directions = [
            "facing front", "facing back", "facing left", "facing right",
            "facing left-front", "facing left-back",
            "facing right-front", "facing right-back",
        ]

        train_obj = ["helicopter", "cat", "dog", "jeep", "teddy bear", "lion", "sedan", "horse", "motorbike", "sofa", "person"]
        # Alternative object set:
        # train_obj = ["ostrich", "helicopter", "shoe", "jeep", "teddy bear", "lion", "sedan", "horse", "motorbike", "sofa"]

        # Both of these depend only on the object list / entity count, not on
        # the scene or prompt, so build them once.
        all_permutations = [
            p for r in range(1, self.max_entity_len + 1)
            for p in itertools.permutations(train_obj, r)
        ]
        dirs_by_len = {
            n: list(itertools.permutations(directions, n))
            for n in range(1, self.max_entity_len + 1)
        }

        self.base_dataset = []
        for scene in prompt_templates.keys():
            for prompt in prompt_templates[scene]:
                for objs in all_permutations:
                    plain = prompt.replace("<subject>", self._format_objects(objs))
                    for dirs in dirs_by_len[len(objs)]:
                        with_dirs = prompt.replace("<subject>", self._format_objects(objs, dirs))
                        for variant in range(10):
                            self.base_dataset.append((plain, with_dirs, objs, variant))

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # Best-effort loading: a failing sample is replaced by a random index
        # (retries indefinitely).
        while True:
            try:
                return self.get_batch(idx)
            except Exception as e:
                print(f'sample idx:{idx} failed:', e)
                idx = random.randint(0, len(self) - 1)

    def get_batch(self, idx):
        """Return the prompts, entity names and variant index of sample ``idx``.

        The plain ``description`` is blanked with probability
        ``drop_text_prob``; ``description_dir`` is never dropped.
        """
        description, description_dir, objs, variant = self.base_dataset[idx]
        eligen_entity_prompts = list(objs)

        position_delta = np.array([0, 0])

        # Randomly drop text
        if random.random() < self.drop_text_prob:
            description = ""
        return {
            "condition_type": self.condition_type,
            "description": description,
            "description_dir": description_dir,
            "position_delta": position_delta,
            "eligen_entity_prompts": eligen_entity_prompts,
            "idx": str(variant),
        }

class Camera(object):
    """Pinhole camera parsed from one flat numeric pose entry.

    ``entry[0:4]`` are the intrinsics ``fx, fy, cx, cy``; ``entry[4:6]`` are
    skipped; ``entry[6:]`` must hold the 12 values of a row-major 3x4
    world-to-camera matrix. Both the 4x4 world-to-camera matrix and its
    inverse (camera-to-world) are stored.
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[:4]
        # Promote the 3x4 extrinsic to homogeneous 4x4 so it can be inverted.
        world_to_cam = np.eye(4)
        world_to_cam[:3, :] = np.array(entry[6:]).reshape(3, 4)
        self.w2c_mat = world_to_cam
        self.c2w_mat = np.linalg.inv(world_to_cam)

class RealEstate10KPose_image(Dataset):
    """Single RealEstate10K frames, each paired with a Plücker ray-map condition.

    ``data_path/train`` holds one pose ``.txt`` per clip; its first line is
    skipped and every remaining row is ``<frame name> <camera entry ...>``
    (the entry is parsed by ``Camera``). Frames are read from
    ``root_path/train/<clip>/<frame>.png``. Optionally, caption-only images
    from the JSON file(s) in ``mix_data_path`` (records with ``image`` and
    ``prompt`` keys) are mixed in; those samples carry no ``"condition"``.
    """

    def __init__(
            self,
            root_path,
            data_path,
            caption_path,
            # dataset config
            condition_size: int = 512,
            target_size: int = 512,
            condition_type: str = "camera",
            drop_image_prob: float = 0.1,
            drop_text_prob: float = 0.1,
            mix_data_rate: float = 0.0,
            mix_data_path=None,
            # video config
            sample_stride=4,
            minimum_sample_stride=1,
            sample_n_frames=16,
            flip_rate=0.,
    ):
        self.root_path = root_path
        self.data_path = data_path
        self.caption_path = caption_path

        self.condition_size = condition_size  # stored for reference; not read by this class
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob

        self.sample_stride = sample_stride
        self.minimum_sample_stride = minimum_sample_stride
        self.sample_n_frames = sample_n_frames
        self.flip_rate = flip_rate

        self.mix_data_rate = mix_data_rate
        self.mix_data_path = mix_data_path

        self.init_dataset()
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.dataset)

    def _sample_rows(self, rows):
        """Pick up to ``sample_n_frames`` rows at a fixed stride from a random start.

        If the clip is too short for ``sample_stride``, a smaller stride is
        drawn at random (clamped so ``minimum_sample_stride`` can never exceed
        the largest stride that fits, which previously raised ValueError).
        """
        total_frames = len(rows)
        stride = self.sample_stride
        if total_frames < self.sample_n_frames * stride:
            maximum_sample_stride = max(int(total_frames // self.sample_n_frames), 1)
            minimum_sample_stride = min(self.minimum_sample_stride, maximum_sample_stride)
            stride = random.randint(minimum_sample_stride, maximum_sample_stride)
        cropped_length = self.sample_n_frames * stride
        start = random.randint(0, max(0, total_frames - cropped_length - 1))
        return rows[start:start + cropped_length:stride]

    def _add_mix_data(self):
        """Append ``mix_data_rate * len(dataset)`` random caption-only samples.

        The count is clamped to the size of the mix pool so ``random.sample``
        no longer raises when the pool is too small.
        """
        if isinstance(self.mix_data_path, list):
            self.mix_dataset = []
            for json_file in self.mix_data_path:
                with open(json_file, 'r') as f:
                    self.mix_dataset.extend(json.load(f))
        else:
            with open(self.mix_data_path, 'r') as f:
                self.mix_dataset = json.load(f)
        n_mix = min(len(self.mix_dataset), int(len(self.dataset) * self.mix_data_rate))
        self.mix_dataset = random.sample(self.mix_dataset, n_mix)
        self.dataset.extend((data["image"], data["prompt"]) for data in self.mix_dataset)

    def init_dataset(self):
        """Build ``self.dataset`` as ``(image_path, caption, Camera)`` tuples.

        Clips without a caption or without an extracted frame folder are
        skipped. Caption-only ``(image_path, caption)`` tuples are appended
        when ``mix_data_rate`` is non-zero.
        """
        self.dataset = []

        with open(self.caption_path, 'r') as f:
            captions = json.load(f)

        pose_dir = os.path.join(self.data_path, 'train')
        for data_file in os.listdir(pose_dir):
            caption = captions.get(data_file.replace('.txt', '.mp4'), None)
            video_path = os.path.join(self.root_path, 'train', data_file[:-len('.txt')])
            if not caption or not os.path.exists(video_path):
                continue

            with open(os.path.join(pose_dir, data_file), 'r') as f:
                # The first line is skipped; each remaining row is "<frame> <camera entry...>".
                video_datas = [line.strip().split(' ') for line in f.readlines()[1:]]

            for row in self._sample_rows(video_datas):
                frame, pose = row[0], row[1:]
                image_path = os.path.join(video_path, frame + '.png')
                self.dataset.append((image_path, caption[0], Camera([float(x) for x in pose])))

        if self.mix_data_rate != 0:
            self._add_mix_data()

    def get_batch(self, idx):
        """Load sample ``idx`` as a dict of image tensor, caption and ray-map condition."""
        data = self.dataset[idx]
        if len(data) == 3:
            image_path, caption, cam_param = data
        else:
            image_path, caption = data
            cam_param = None
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = resize_and_crop(image, (self.target_size, self.target_size))

        position_delta = np.array([0, 0])

        if cam_param is None:
            # Caption-only mix-in sample: no camera condition.
            return {
                "image": self.to_tensor(image),
                "condition_type": self.condition_type,
                "description": caption,
                "position_delta": position_delta,
            }

        # Bool tensor of shape (1,): one flip decision for the single view.
        flip_flag = torch.rand(1) < self.flip_rate
        if flip_flag:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # NOTE(review): fx/cx and fy/cy are both scaled by target_size, which
        # assumes the normalized intrinsics map directly onto the square crop.
        # resize_and_crop trims the long side of non-square frames, so confirm
        # whether fx/cx (or fy/cy) should be rescaled for the crop.
        intrinsics = np.asarray(
            [
                cam_param.fx * self.target_size,
                cam_param.fy * self.target_size,
                cam_param.cx * self.target_size,
                cam_param.cy * self.target_size
            ],
            dtype=np.float32
        )

        intrinsics = torch.as_tensor(intrinsics)[None, None]                  # (1, 1, 4)
        c2w_poses = np.array(cam_param.c2w_mat, dtype=np.float32)
        c2w = torch.as_tensor(c2w_poses)[None, None]                          # (1, 1, 4, 4)

        # encode_poses -> (1, 1, H, W, 6); [0] -> (1, H, W, 6) -> (1, 6, H, W) -> (6, H, W)
        plucker_embedding = encode_poses(
            intrinsics,
            c2w,
            self.target_size,
            self.target_size,
            device='cpu',
            flip_flag=flip_flag
        )[0].permute(0, 3, 1, 2).contiguous().squeeze(0).squeeze(0)

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            caption = ""
        if drop_image:
            # zeros_like already preserves device and dtype.
            plucker_embedding = torch.zeros_like(plucker_embedding)

        return {
            "image": self.to_tensor(image),
            "condition": plucker_embedding,
            "condition_type": self.condition_type,
            "description": caption,
            "position_delta": position_delta,
        }

    def __getitem__(self, idx):
        # Best-effort loading: a failing sample is replaced by a random index
        # (retries indefinitely).
        while True:
            try:
                return self.get_batch(idx)
            except Exception as e:
                print(f'sample idx:{idx} failed:', e)
                idx = random.randint(0, len(self) - 1)

def resize_and_crop(image, output_size):
    """Resize ``image`` to cover ``output_size`` and center-crop to it.

    The image is scaled with the LANCZOS filter, preserving aspect ratio, so
    that both sides are at least the target size. The central window of
    ``output_size`` is then cropped. For square targets this is the usual
    "resize the short side, then center crop".

    Args:
        image: PIL image to process.
        output_size: ``(width, height)`` of the returned image.

    Returns:
        PIL.Image.Image: the resized and cropped image.
    """
    original_width, original_height = image.size
    out_width, out_height = output_size

    # Compare aspect ratios with integer cross-multiplication. The bounding
    # side maps exactly onto its target and the other side uses floor
    # division. The previous int(x * (t / x)) form could come out one pixel
    # short through float rounding, and using min(output_size) under-covered
    # non-square targets.
    if original_width * out_height < original_height * out_width:
        new_width = out_width
        new_height = int(original_height * out_width // original_width)
    else:
        new_height = out_height
        new_width = int(original_width * out_height // original_height)

    img_resized = image.resize((new_width, new_height), Image.LANCZOS)

    # Center crop; PIL rounds the (possibly half-pixel) box coordinates.
    left = (new_width - out_width) / 2
    top = (new_height - out_height) / 2
    right = left + out_width
    bottom = top + out_height

    return img_resized.crop((left, top, right, bottom))

def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with matrix ('ij') indexing across torch versions.

    The ``indexing`` keyword is only passed on torch >= 1.10; older releases
    do not accept it.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html
    """
    supports_indexing_kw = pver.parse(torch.__version__) >= pver.parse('1.10')
    if not supports_indexing_kw:
        return torch.meshgrid(*args)
    return torch.meshgrid(*args, indexing='ij')

def encode_poses(K, c2w, H, W, device, dtype=None, flip_flag=None):
    """Compute per-pixel Plücker ray embeddings for a batch of cameras.

    Args:
        K: intrinsics of shape (B, V, 4), ordered (fx, fy, cx, cy) in pixels.
        c2w: camera-to-world matrices of shape (B, V, 4, 4).
        H, W: output height and width in pixels.
        device: device on which the pixel grid is built.
        dtype: dtype of the pixel grid; defaults to ``c2w.dtype``.
        flip_flag: optional horizontal-flip selector per view. Accepts a bool
            tensor of shape (V,), a single bool (or 1-element tensor) applied
            to every view, or a sequence of V bools. Flipped views sample
            pixel columns right-to-left.

    Returns:
        Tensor of shape (B, V, H, W, 6) holding ``cat([o x d, d])``, where
        ``d`` is the unit ray direction in world space and ``o`` the camera
        center taken from ``c2w``.
    """
    dtype = dtype if dtype is not None else c2w.dtype

    B, V = K.shape[:2]

    # Pixel-center coordinates in row-major (H, W) order, flattened to H*W.
    # Equivalent to an 'ij' meshgrid of the two linspaces.
    cols = torch.linspace(0, W - 1, W, device=device, dtype=dtype)
    rows = torch.linspace(0, H - 1, H, device=device, dtype=dtype)
    i = cols.view(1, W).expand(H, W).reshape(1, 1, H * W).expand(B, V, H * W) + 0.5  # [B, V, HxW]
    j = rows.view(H, 1).expand(H, W).reshape(1, 1, H * W).expand(B, V, H * W) + 0.5  # [B, V, HxW]

    if flip_flag is not None:
        flip = torch.as_tensor(flip_flag, dtype=torch.bool, device=i.device).reshape(-1)
        if flip.numel() == 1:
            flip = flip.expand(V)
        if flip.any():
            # Flipping only reverses the column coordinate; rows are unchanged.
            cols_flip = torch.linspace(W - 1, 0, W, device=device, dtype=dtype)
            i_flip = cols_flip.view(1, W).expand(H, W).reshape(1, 1, H * W) + 0.5
            i = torch.where(flip.view(1, V, 1), i_flip, i)

    fx, fy, cx, cy = K.chunk(4, dim=-1)     # each [B, V, 1]

    zs = torch.ones_like(i)                 # [B, V, HxW]
    xs = (i - cx) / fx
    ys = (j - cy) / fy

    directions = torch.stack((xs, ys, zs), dim=-1)                      # [B, V, HW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)     # unit length

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)            # rotate into world: [B, V, HW, 3]
    rays_o = c2w[..., :3, 3][:, :, None].expand_as(rays_d)              # camera centers: [B, V, HW, 3]
    moment = torch.linalg.cross(rays_o, rays_d)                         # o x d: [B, V, HW, 3]
    plucker = torch.cat([moment, rays_d], dim=-1)
    return plucker.reshape(B, c2w.shape[1], H, W, 6)                    # [B, V, H, W, 6]

if __name__ == "__main__":
    # Smoke test: index the local RealEstate10K copy and load one sample.
    dataset = RealEstate10KPose_image(
        root_path="/mnt/nas_jianchong/datasets/RealEstate10K/copy1/dataset",
        data_path="/mnt/nas_jianchong/datasets/RealEstate10K/copy1/RealEstate10K",
        caption_path="RealEstate10K/train_captions.json",
    )

    first_sample = dataset[0]
    print("Dataset length:", len(dataset))