# import lightning as L
import pytorch_lightning as L
from PIL import Image, ImageFilter, ImageDraw
import numpy as np
from transformers import pipeline
import cv2
import numpy as np
import torch
import torchvision.transforms as T
import os
from tqdm import tqdm
import json
from torchvision.transforms import ToPILImage, ToTensor
import depth_pro.depth_pro as depth_pro

try:
    import wandb
except ImportError:
    wandb = None

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything
from src.flux.pipeline_tools import Camera
from src.flux.pipeline_tools import visualize_masks

def sort_depth(condition_imgs):
    """Rank images from largest to smallest mean over their non-zero entries.

    Args:
        condition_imgs: Tensor whose first axis indexes the images; axes
            1, 2 and 3 are reduced for each image.

    Returns:
        A Python list of image indices ordered by descending mean of the
        non-zero entries. An image without any non-zero entry gets a mean
        of 0.
    """
    reduce_dims = (1, 2, 3)
    nonzero = condition_imgs != 0
    # Zero entries add nothing to the total and are left out of the count.
    totals = (condition_imgs * nonzero).sum(dim=reduce_dims)
    counts = nonzero.sum(dim=reduce_dims).float()
    # An all-zero image gives 0 / 0 = NaN, which is mapped to 0.
    per_image_mean = torch.nan_to_num(totals / counts, nan=0.0)
    order = torch.argsort(per_image_mean, descending=True)
    return order.tolist()

def loose_condition_data(data_path, image_id, entities, device, dtype, condition_size=512):
    """Load per-entity depth renders and package them in ``sort_depth`` order.

    For every entity ``i`` the file
    ``{data_path}/{image_id}/render_depth_{i}.png`` is opened, resized to
    ``condition_size`` x ``condition_size`` and converted to RGB.

    Args:
        data_path: Directory that contains one sub-directory per image id.
        image_id: Name of the sub-directory holding the renders.
        entities: Sequence of dicts; each must provide an ``'entity'`` prompt.
        device: Device the model-input masks are moved to.
        dtype: Dtype the model-input masks are cast to.
        condition_size: Side length of the resized renders. Model-input
            masks use ``condition_size // 8``.

    Returns:
        Dict with the stacked render tensors (``"condition"``), the entity
        prompts, the downsampled binary masks and full-resolution PIL masks,
        all reordered by the index list returned from ``sort_depth``.
    """
    full_res = (condition_size, condition_size)
    low_res = (condition_size // 8, condition_size // 8)

    render_paths = [
        f'{data_path}/{image_id}/render_depth_{idx}.png'
        for idx in range(len(entities))
    ]
    renders = [
        Image.open(render_path).resize(full_res).convert("RGB")
        for render_path in render_paths
    ]
    prompts = [entity['entity'] for entity in entities]

    model_masks = []
    preview_masks = []
    for render in renders:
        # Downsampled binary mask (non-zero pixel -> 1) used as model input.
        binary_small = np.where(np.array(render.resize(low_res)) > 0, 1, 0).astype(np.uint8)
        small_tensor = torch.from_numpy(binary_small).to(device=device, dtype=dtype)
        model_masks.append(small_tensor.unsqueeze(0))

        # Full-resolution 0/255 mask kept as a PIL image for visualization.
        binary_full = np.where(np.array(render) > 0, 1, 0).astype(np.uint8)
        preview_masks.append(Image.fromarray(binary_full * 255))

    to_tensor = T.ToTensor()
    stacked = torch.stack([to_tensor(render) for render in renders])
    order = sort_depth(stacked)

    return {
        "condition": stacked[order],
        "eligen_entity_prompts": [prompts[pos] for pos in order],
        "eligen_entity_masks": [model_masks[pos] for pos in order],
        'eligen_entity_masks_pil': [preview_masks[pos] for pos in order],
    }

def loose_2d_condition_data(data_path, image_id, entities, device, dtype, condition_size=512):
    """Build mask-only (2D) condition data from per-entity depth renders.

    For every entity ``i`` the file
    ``{data_path}/{image_id}/render_depth_{i}.png`` is opened, resized to
    ``condition_size`` x ``condition_size`` and converted to RGB. Only the
    binary masks derived from the renders are returned; ``"condition"`` is
    always ``None`` and the entities keep their input order.

    Args:
        data_path: Directory that contains one sub-directory per image id.
        image_id: Name of the sub-directory holding the renders.
        entities: Sequence of dicts; each must provide an ``'entity'`` prompt.
        device: Device the model-input masks are moved to.
        dtype: Dtype the model-input masks are cast to.
        condition_size: Side length of the resized renders. Model-input
            masks use ``condition_size // 8``.

    Returns:
        Dict with ``"condition"`` set to ``None``, the entity prompts, the
        downsampled binary masks and the full-resolution PIL masks.
    """
    full_res = (condition_size, condition_size)
    low_res = (condition_size // 8, condition_size // 8)

    condition_imgs = [
        Image.open(f'{data_path}/{image_id}/render_depth_{i}.png')
        .resize(full_res)
        .convert("RGB")
        for i in range(len(entities))
    ]
    eligen_entity_prompts = [entity['entity'] for entity in entities]

    eligen_entity_masks = []
    eligen_entity_masks_pil = []
    for img in condition_imgs:
        # Downsampled binary mask (non-zero pixel -> 1) used as model input.
        mask = np.array(img.resize(low_res))
        mask = np.where(mask > 0, 1, 0).astype(np.uint8)
        mask_tensor = torch.from_numpy(mask).to(device=device, dtype=dtype)
        eligen_entity_masks.append(mask_tensor.unsqueeze(0))

        # Full-resolution 0/255 mask kept as a PIL image for visualization.
        mask_pil = np.where(np.array(img) > 0, 1, 0).astype(np.uint8)
        eligen_entity_masks_pil.append(Image.fromarray(mask_pil * 255))

    # The renders themselves are not passed on: this variant conditions on
    # the masks only, so no image tensor is built.
    return {
        "condition": None,
        "eligen_entity_prompts": eligen_entity_prompts,
        "eligen_entity_masks": eligen_entity_masks,
        'eligen_entity_masks_pil': eligen_entity_masks_pil,
    }

def eligen_camera_data(data_path, image_id, entities, device, dtype, condition_size=512):
    """Build per-entity masks plus a flattened rotation matrix for each pose.

    For every entity ``i`` the file
    ``{data_path}/{image_id}/render_depth_{i}.png`` is opened, resized to
    ``condition_size`` x ``condition_size`` and converted to RGB, and masks
    are derived from it. The entity's ``'pose'`` (three angles in degrees,
    default ``[45, 90, 0]``) is turned into a 9-element rotation vector.

    Args:
        data_path: Directory that contains one sub-directory per image id.
        image_id: Name of the sub-directory holding the renders.
        entities: Sequence of dicts with an ``'entity'`` prompt and an
            optional ``'pose'`` triple ``(phi, theta, delta)``.
        device: Device the masks and pose tensor are moved to.
        dtype: Dtype the masks and pose tensor are cast to.
        condition_size: Side length of the resized renders. Model-input
            masks use ``condition_size // 8``.

    Returns:
        Dict with ``"condition"`` (one flattened 3x3 rotation per entity,
        stacked along the first axis), the entity prompts, the downsampled
        masks, the full-resolution PIL masks and ``"cam_entity_idx"``
        (``0 .. len(entities) - 1``).
    """

    def pose_latent(phi, theta, delta):
        """Return ``R_camera @ R_polar @ R_azimuth`` flattened; angles in degrees."""
        azimuth = torch.deg2rad(torch.tensor(phi))
        polar = torch.deg2rad(torch.tensor(theta))
        roll = torch.deg2rad(torch.tensor(delta))

        # Azimuth: rotation about the z-axis.
        R_azimuth = torch.tensor([
            [torch.cos(azimuth), -torch.sin(azimuth), 0],
            [torch.sin(azimuth), torch.cos(azimuth), 0],
            [0, 0, 1],
        ])
        # Polar angle: rotation about the y-axis.
        R_polar = torch.tensor([
            [torch.cos(polar), 0, torch.sin(polar)],
            [0, 1, 0],
            [-torch.sin(polar), 0, torch.cos(polar)],
        ])
        # Camera rotation: rotation about the x-axis.
        R_camera = torch.tensor([
            [1, 0, 0],
            [0, torch.cos(roll), -torch.sin(roll)],
            [0, torch.sin(roll), torch.cos(roll)],
        ])

        polar_azimuth = torch.mm(R_polar, R_azimuth)
        return torch.mm(R_camera, polar_azimuth).flatten()

    full_res = (condition_size, condition_size)
    low_res = (condition_size // 8, condition_size // 8)

    render_paths = [
        f'{data_path}/{image_id}/render_depth_{idx}.png'
        for idx in range(len(entities))
    ]
    renders = [
        Image.open(render_path).resize(full_res).convert("RGB")
        for render_path in render_paths
    ]
    prompts = [entity['entity'] for entity in entities]

    pose_vectors = []
    model_masks = []
    preview_masks = []
    for entity, render in zip(entities, renders):
        # Downsampled binary mask (non-zero pixel -> 1) used as model input.
        binary_small = np.where(np.array(render.resize(low_res)) > 0, 1, 0).astype(np.uint8)
        small_tensor = torch.from_numpy(binary_small).to(device=device, dtype=dtype)
        model_masks.append(small_tensor.unsqueeze(0))

        # Full-resolution 0/255 mask kept as a PIL image for visualization.
        binary_full = np.where(np.array(render) > 0, 1, 0).astype(np.uint8)
        preview_masks.append(Image.fromarray(binary_full * 255))

        phi, theta, delta = entity.get('pose', [45, 90, 0])
        pose_vectors.append(pose_latent(phi, theta, delta))

    poses = torch.stack(pose_vectors).to(device=device, dtype=dtype)

    return {
        "condition": poses,
        "eligen_entity_prompts": prompts,
        "eligen_entity_masks": model_masks,
        'eligen_entity_masks_pil': preview_masks,
        "cam_entity_idx": list(range(len(entities))),
    }

def eligen_pose_data(data_path, image_id, entities, device, dtype, condition_size=512):
    """Build pose-conditioned data, returning only the last entity's entries.

    For every entity ``i`` the file
    ``{data_path}/{image_id}/render_depth_{i}.png`` is opened, resized to
    ``condition_size`` x ``condition_size`` and converted to RGB, and masks
    are derived from it. Each entity's ``'pose'`` triple (default
    ``[45, 0, 0]``) is printed and recorded. All entities are processed, but
    every returned list is sliced to its final element.

    Args:
        data_path: Directory that contains one sub-directory per image id.
        image_id: Name of the sub-directory holding the renders.
        entities: Sequence of dicts with an ``'entity'`` prompt and an
            optional ``'pose'`` triple ``(phi, theta, delta)``.
        device: Device the model-input masks are moved to.
        dtype: Dtype the model-input masks are cast to.
        condition_size: Side length of the resized renders. Model-input
            masks use ``condition_size // 8``.

    Returns:
        Dict with ``"condition"`` set to ``None`` and single-element lists
        for the prompt (prefixed with ``<extra_id_0>``), the downsampled
        mask, the PIL mask, the entity index and the pose (``"orient"``).
    """
    full_res = (condition_size, condition_size)
    low_res = (condition_size // 8, condition_size // 8)

    render_paths = [
        f'{data_path}/{image_id}/render_depth_{idx}.png'
        for idx in range(len(entities))
    ]
    renders = [
        Image.open(render_path).resize(full_res).convert("RGB")
        for render_path in render_paths
    ]
    prompts = ['<extra_id_0>' + entity['entity'] for entity in entities]

    poses = []
    model_masks = []
    preview_masks = []
    for entity, render in zip(entities, renders):
        # Downsampled binary mask (non-zero pixel -> 1) used as model input.
        binary_small = np.where(np.array(render.resize(low_res)) > 0, 1, 0).astype(np.uint8)
        small_tensor = torch.from_numpy(binary_small).to(device=device, dtype=dtype)
        model_masks.append(small_tensor.unsqueeze(0))

        # Full-resolution 0/255 mask kept as a PIL image for visualization.
        binary_full = np.where(np.array(render) > 0, 1, 0).astype(np.uint8)
        preview_masks.append(Image.fromarray(binary_full * 255))

        phi, theta, delta = entity.get('pose', [45, 0, 0])
        print(phi, theta, delta)
        poses.append([phi, theta, delta])

    entity_indices = list(range(len(entities)))

    # Only the final entity is handed on; the [-1:] slices keep list types.
    return {
        "condition": None,
        "eligen_entity_prompts": prompts[-1:],
        "eligen_entity_masks": model_masks[-1:],
        'eligen_entity_masks_pil': preview_masks[-1:],
        "cam_entity_idx": entity_indices[-1:],
        "orient": poses[-1:],
    }

def depth_condition_data(depth_model, depth_transform, data_path, image_id, entities, device, dtype, condition_size=512):
    """Build a depth-map condition plus bounding-box masks for each entity.

    The image ``{data_path}/{image_id}/{image_id}.png`` is resized to
    ``condition_size`` x ``condition_size``, converted to RGB and passed
    through ``depth_transform`` and ``depth_model.infer``. The predicted
    depth is min-max normalised and inverted, then used as the single
    condition image. Each entity contributes one prompt and one rectangular
    mask drawn from its ``"bbox"``.

    Args:
        depth_model: Object exposing ``infer(tensor, f_px=None)`` that
            returns a mapping with a ``"depth"`` entry.
        depth_transform: Callable applied to the image array; its result
            must support ``.to(device=..., dtype=...)``.
        data_path: Directory that contains one sub-directory per image id.
        image_id: Sub-directory name, also used as the image file stem.
        entities: Sequence of dicts with an ``'entity'`` prompt and a
            ``"bbox"`` of four values ``(x_min, y_min, x_max, y_max)`` that
            are multiplied by ``condition_size`` (presumably fractions in
            [0, 1] -- confirm against the data).
        device: Device for the model input and the masks.
        dtype: Dtype for the model input and the masks.
        condition_size: Side length of the condition image and PIL masks.
            Model-input masks use ``condition_size // 8``.

    Returns:
        Dict with ``"condition"`` (the depth image as a stacked tensor with
        one entry), and one prompt, one downsampled mask and one PIL mask
        per entity.
    """
    source_img = (
        Image.open(f"{data_path}/{image_id}/{image_id}.png")
        .resize((condition_size, condition_size))
        .convert("RGB")
    )
    model_input = depth_transform(np.array(source_img)).to(device=device, dtype=dtype)
    depth = depth_model.infer(model_input, f_px=None)["depth"][None]

    # Min-max normalise and invert: the largest depth maps to 0 and the
    # smallest to 1.
    depth_max = depth.max()
    inverted_depth = (depth_max - depth[0]) / (depth_max - depth.min())
    depth_img = np.array(ToPILImage()(inverted_depth))
    condition_imgs = [Image.fromarray(depth_img).convert("RGB")]

    # Exactly one prompt per entity, so prompts and masks stay index-aligned.
    # (Appending the prompt again while building the masks would return
    # twice as many prompts as masks.)
    eligen_entity_prompts = [entity['entity'] for entity in entities]

    eligen_entity_masks = []
    eligen_entity_masks_pil = []
    for entity in entities:
        coordinates = entity["bbox"]
        # Scale the box coordinates to pixel positions.
        x_min = int(condition_size * coordinates[0])
        y_min = int(condition_size * coordinates[1])
        x_max = int(condition_size * coordinates[2])
        y_max = int(condition_size * coordinates[3])

        # Filled rectangle on a black canvas, kept as an RGB PIL image.
        mask = Image.new("L", (condition_size, condition_size), 0)
        draw = ImageDraw.Draw(mask)
        draw.rectangle([x_min, y_min, x_max, y_max], fill=255)
        mask = mask.convert("RGB")
        eligen_entity_masks_pil.append(mask)

        # Downsampled binary mask (non-zero pixel -> 1) used as model input.
        mask = np.array(mask.resize((condition_size // 8, condition_size // 8)))
        mask = np.where(mask > 0, 1, 0).astype(np.uint8)
        mask_tensor = torch.from_numpy(mask).to(device=device, dtype=dtype)
        eligen_entity_masks.append(mask_tensor.unsqueeze(0))

    # Convert the depth image to a tensor with a leading image axis.
    condition_imgs = torch.stack([T.ToTensor()(img) for img in condition_imgs])

    return {
        "condition": condition_imgs,
        "eligen_entity_prompts": eligen_entity_prompts,
        "eligen_entity_masks": eligen_entity_masks,
        'eligen_entity_masks_pil': eligen_entity_masks_pil,
    }

class TrainingCallback(L.Callback):
    def __init__(self, run_name, training_config: dict = None):
        """Store the run name and read callback settings from the config.

        Args:
            run_name: Name used in checkpoint and sample output paths.
            training_config: Optional settings dict. Recognised keys are
                ``"print_every_n_steps"`` (default 1), ``"save_interval"``
                (default 1000), ``"sample_interval"`` (default 1000),
                ``"save_path"`` (default ``"./output"``) and ``"wandb"``
                (default ``None``). ``None`` is treated as an empty dict.
        """
        # A fresh dict per instance: a shared ``{}`` default would be stored
        # on every instance and leak mutations between them.
        if training_config is None:
            training_config = {}
        self.run_name, self.training_config = run_name, training_config

        self.print_every_n_steps = training_config.get("print_every_n_steps", 1)
        self.save_interval = training_config.get("save_interval", 1000)
        self.sample_interval = training_config.get("sample_interval", 1000)
        self.save_path = training_config.get("save_path", "./output")

        self.wandb_config = training_config.get("wandb", None)
        # Logging to wandb needs both the package and an API key in the
        # environment.
        self.use_wandb = (
            wandb is not None and os.environ.get("WANDB_API_KEY") is not None
        )

        # Number of training batches seen by this callback.
        self.total_steps = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log gradient statistics and periodically save weights and samples.

        After each training batch this method:

        * computes the mean and maximum L2 gradient norm over all parameters
          that currently hold a gradient,
        * increments ``self.total_steps``,
        * logs step, epoch, gradient size, loss and ``pl_module.last_t`` to
          wandb when ``self.use_wandb`` is set,
        * prints progress every ``print_every_n_steps`` steps,
        * saves LoRA weights (and any condition-specific modules) every
          ``save_interval`` steps,
        * generates a sample every ``sample_interval`` steps.

        Args:
            trainer: The running trainer; ``current_epoch`` and
                ``accumulate_grad_batches`` are read from it.
            pl_module: The module being trained. It must provide
                ``named_parameters``, ``log_loss``, ``last_t``,
                ``condition_type``, ``flux_pipe`` and the ``save_*`` methods
                used below.
            outputs: Mapping whose ``"loss"`` entry supports ``.item()``.
            batch: Current batch; ``batch["condition_type"][0]`` selects the
                sample type.
            batch_idx: Index of the batch within the epoch.
        """
        gradient_size = 0
        max_gradient_size = 0
        count = 0
        for _, param in pl_module.named_parameters():
            if param.grad is None:
                continue
            # Compute the norm once; every ``.item()`` call synchronises
            # with the device.
            grad_norm = param.grad.norm(2).item()
            gradient_size += grad_norm
            max_gradient_size = max(max_gradient_size, grad_norm)
            count += 1
        if count > 0:
            gradient_size /= count

        self.total_steps += 1

        # Report metrics to wandb on every step.
        if self.use_wandb:
            # "steps" carries the callback's global step counter. A second
            # "steps" entry holding batch_idx used to precede it in the
            # literal and was silently overwritten, so it has been dropped.
            report_dict = {
                "steps": self.total_steps,
                "epoch": trainer.current_epoch,
                "gradient_size": gradient_size,
            }
            # Scale the loss by the gradient-accumulation factor.
            loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
            report_dict["loss"] = loss_value
            report_dict["t"] = pl_module.last_t
            wandb.log(report_dict)

        # Print training progress every n steps
        if self.total_steps % self.print_every_n_steps == 0:
            loss_message = "".join(
                f"{k}: {v:.4f}, " for k, v in pl_module.log_loss.items()
            )
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, {loss_message}Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
            )

        # Save LoRA weights at specified intervals
        if self.total_steps % self.save_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
            )
            pl_module.save_lora(
                f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
            )

            # Extra modules are saved only for the condition types that use
            # them and, where checked, only if the transformer has them.
            condition_type = pl_module.condition_type
            if "camera" in condition_type or "pose" in condition_type:
                pl_module.save_camera(
                    f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}/cam_embedder.pth"
                )
            if "loose" in condition_type and hasattr(pl_module.flux_pipe.transformer, "inter_controller"):
                pl_module.save_loose_condition(
                    f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}/inter_controller.pth"
                )
            if "flux" in condition_type and hasattr(pl_module.flux_pipe.transformer, "loose_embedder"):
                pl_module.save_flux(
                    f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}/loose_embedder.pth"
                )

        # Generate and save a sample image at specified intervals
        if self.total_steps % self.sample_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
            )
            self.generate_a_sample(
                trainer,
                pl_module,
                f"{self.save_path}/{self.run_name}/output",
                f"lora_{self.total_steps}",
                # Use the condition type from the current batch
                batch["condition_type"][0],
            )

    @torch.no_grad()
    def generate_a_sample(
        self,
        trainer,
        pl_module,
        save_path,
        file_name,
        condition_type="super_resolution",
    ):
        # TODO: change this two variables to parameters
        condition_size = trainer.training_config["dataset"]["condition_size"]
        target_size = trainer.training_config["dataset"]["target_size"]

        # generator = torch.Generator(device=pl_module.device)
        # generator.manual_seed(42)

        test_list = []

        if condition_type == "subject":
            test_list.extend(
                [
                    (
                        Image.open("assets/test_in.jpg"),
                        [0, -32],
                        "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
                    ),
                    (
                        Image.open("assets/test_out.jpg"),
                        [0, -32],
                        "In a bright room. It is placed on a table.",
                    ),
                ]
            )
        elif condition_type == "canny":
            condition_img = Image.open("assets/vase_hq.jpg").resize(
                (condition_size, condition_size)
            )
            condition_img = np.array(condition_img)
            condition_img = cv2.Canny(condition_img, 100, 200)
            condition_img = Image.fromarray(condition_img).convert("RGB")
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "coloring":
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("L")
                .convert("RGB")
            )
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "depth":
            if not hasattr(self, "deepth_pipe"):
                self.depth_pipe = pipeline(
                    task="depth-estimation",
                    model="LiheYoung/depth-anything-small-hf",
                    device="cpu",
                )
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            condition_img = self.depth_pipe(condition_img)["depth"].convert("RGB")
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "depth_pred":
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "deblurring":
            blur_radius = 5
            image = Image.open("./assets/vase_hq.jpg")
            condition_img = (
                image.convert("RGB")
                .resize((condition_size, condition_size))
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .convert("RGB")
            )
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "fill":
            condition_img = (
                Image.open("./assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            mask = Image.new("L", condition_img.size, 0)
            draw = ImageDraw.Draw(mask)
            a = condition_img.size[0] // 4
            b = a * 3
            draw.rectangle([a, a, b, b], fill=255)
            condition_img = Image.composite(
                condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
            )
            test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif condition_type == "sr":
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
            test_list.append((condition_img, [0, -16], "A beautiful vase on a table."))
        elif "camera" == condition_type:
            poses = [
                "0.532139961 0.946026558 0.500000000 0.500000000 0.000000000 0.000000000 0.978989959 -0.010294991 -0.203648433 -0.000762398 -0.007398812 0.996273518 -0.085932352 -0.031535059 0.203774214 0.085633665 0.975265563 -0.153683138",
                "0.591079453 1.050807901 0.500000000 0.500000000 0.000000000 0.000000000 0.999850392 0.008867561 -0.014851474 0.359474093 -0.009553785 0.998859048 -0.046790831 0.313312825 0.014419609 0.046925720 0.998794317 -0.689018860",
                "0.501289166 0.891180703 0.500000000 0.500000000 0.000000000 0.000000000 0.998422921 0.009509329 0.055328209 -0.088786468 -0.006474003 0.998477280 -0.054783192 -0.013369506 -0.055764910 0.054338600 0.996964216 -1.103852641",
            ]
            for pose in poses:
                pose = pose.strip().split(' ')
                cam_param = Camera([float(x) for x in pose])
                intrinsics = torch.tensor(
                    [
                        cam_param.fx * target_size,
                        cam_param.fy * target_size,
                        cam_param.cx * target_size,
                        cam_param.cy * target_size
                    ], 
                    device=pl_module.flux_pipe.device, dtype=pl_module.flux_pipe.dtype
                )[None, None]
                
                c2w = torch.tensor(
                    cam_param.c2w_mat, 
                    device=pl_module.flux_pipe.device, 
                    dtype=pl_module.flux_pipe.dtype
                )[None, None]
                condition_img = dict(
                    K=intrinsics, c2w=c2w, H=target_size, W=target_size, device=pl_module.flux_pipe.device, dtype=pl_module.flux_pipe.dtype
                )
                test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
        elif "loose_condition" in condition_type:
            device=pl_module.flux_pipe.device
            dtype = pl_module.flux_pipe.dtype

            # data_path = "test/human_data"
            # condition_path = 'src/train/test.json'
            # # condition_path = 'src/train/test_.json'
            data_path = "/mnt/workspace/workgroup/zheliu.lzy/vision_cot/3d_box/datasets_new"
            condition_path = 'src/evaluate/condition.json'
            with open(condition_path, 'r') as f:
                conditions = json.load(f)
                print(f"Load {len(conditions)} conditions successfully")
            for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
                if idx not in [0,1,6,7,8,9,10,11,14,15,19]:
                    continue
                image_id = condition.get('image_id', int(idx))
                # image_id = str(idx)
                caption = condition['caption']
                entities = condition['entities']

                condition_data = loose_condition_data(data_path, image_id, entities, device, dtype)
                test_list.append((condition_data, [0, 0], caption))

            data_path = "test/human_data_1"
            condition_path = 'src/train/test.json'
            with open(condition_path, 'r') as f:
                conditions = json.load(f)
                print(f"Load {len(conditions)} conditions successfully")
            for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
                image_id = condition.get('image_id', int(idx))
                # image_id = str(idx)
                caption = condition['caption']
                entities = condition['entities']

                condition_data = loose_condition_data(data_path, image_id, entities, device, dtype)

                test_list.append((condition_data, [0, 0], caption))

            if hasattr(pl_module.flux_pipe.transformer, "inter_controller"):
                for name, param in pl_module.flux_pipe.transformer.inter_controller.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "loose_embedder"):
                for name, param in pl_module.flux_pipe.transformer.loose_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
        elif "eligen_loose_2d" in condition_type:
            device=pl_module.flux_pipe.device
            dtype = pl_module.flux_pipe.dtype

            data_path = "test/human_data_3"
            condition_path = 'test/human_data_3/condition.json'
            with open(condition_path, 'r') as f:
                conditions = json.load(f)
                print(f"Load {len(conditions)} conditions successfully")
            for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
                image_id = condition.get('image_id', int(idx))
                # image_id = str(idx)
                if image_id not in [1,2,3,4,5,6,7,32,33,38,39,48,49,54,55]:
                    continue
                caption = condition['caption']
                entities = condition['entities']

                condition_data = loose_2d_condition_data(data_path, image_id, entities, device, dtype)
                test_list.append((condition_data, [0, 0], caption))

            if hasattr(pl_module.flux_pipe.transformer, "inter_controller"):
                for name, param in pl_module.flux_pipe.transformer.inter_controller.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "loose_embedder"):
                for name, param in pl_module.flux_pipe.transformer.loose_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
        elif "eligen_loose" in condition_type:
            device=pl_module.flux_pipe.device
            dtype = pl_module.flux_pipe.dtype

            # data_path = "test/human_data_1"
            # condition_path = 'src/train/test.json'
            data_path = "test/human_data_3"
            condition_path = 'test/human_data_3/condition.json'
            with open(condition_path, 'r') as f:
                conditions = json.load(f)
                print(f"Load {len(conditions)} conditions successfully")
            for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
                image_id = condition.get('image_id', int(idx))
                # image_id = str(idx)
                if image_id not in [1,2,3,4,5,6,7,32,33,38,39,48,49,54,55]:
                    continue
                caption = condition['caption']
                entities = condition['entities']

                condition_data = loose_condition_data(data_path, image_id, entities, device, dtype)
                test_list.append((condition_data, [0, 0], caption))

            if hasattr(pl_module.flux_pipe.transformer, "inter_controller"):
                for name, param in pl_module.flux_pipe.transformer.inter_controller.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "loose_embedder"):
                for name, param in pl_module.flux_pipe.transformer.loose_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
        elif "eligen_depth" in condition_type:
            device=pl_module.flux_pipe.device
            dtype = pl_module.flux_pipe.dtype

            data_path = "/mnt/workspace/workgroup/zheliu.lzy/vision_cot/3d_box/datasets"
            condition_path = 'src/evaluate/condition.json'
            with open(condition_path, 'r') as f:
                conditions = json.load(f)
                print(f"Load {len(conditions)} conditions successfully")
            for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
                if idx not in [0,1,6,7,8,9,10,11,14,15,19]:
                    continue
                image_id = condition.get('image_id', int(idx))
                # image_id = str(idx)
                caption = condition['caption']
                entities = condition['entities']

                condition_data = depth_condition_data(pl_module.depth_model, pl_module.depth_transform, data_path, image_id, entities, device, dtype)
                test_list.append((condition_data, [0, 0], caption))

            if hasattr(pl_module.flux_pipe.transformer, "inter_controller"):
                for name, param in pl_module.flux_pipe.transformer.inter_controller.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "loose_embedder"):
                for name, param in pl_module.flux_pipe.transformer.loose_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
        elif "eligen_camera" in condition_type:
            device=pl_module.flux_pipe.device
            dtype = pl_module.flux_pipe.dtype

            data_path = "test/human_data_3"
            condition_path = 'test/human_data_3/condition.json'
            with open(condition_path, 'r') as f:
                conditions = json.load(f)
                print(f"Load {len(conditions)} conditions successfully")
            for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
                image_id = condition.get('image_id', int(idx))
                # image_id = str(idx)
                if image_id not in [1,2,3,4,5,6,7]:
                    continue
                caption = condition['caption']
                entities = condition['entities']

                condition_data = eligen_camera_data(data_path, image_id, entities, device, dtype)
                test_list.append((condition_data, [0, 0], caption))

            if hasattr(pl_module.flux_pipe.transformer, "inter_controller"):
                for name, param in pl_module.flux_pipe.transformer.inter_controller.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "loose_embedder"):
                for name, param in pl_module.flux_pipe.transformer.loose_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "cam_embedder"):
                for name, param in pl_module.flux_pipe.transformer.cam_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
        elif "pose" in condition_type:
            device=pl_module.flux_pipe.device
            dtype = pl_module.flux_pipe.dtype

            data_path = "test/human_data_4"
            condition_path = 'test/human_data_4/condition.json'
            with open(condition_path, 'r') as f:
                conditions = json.load(f)
                print(f"Load {len(conditions)} conditions successfully")
            for idx, condition in tqdm(enumerate(conditions), desc="🚀 Loading conditions", total=len(conditions)):
                image_id = condition.get('image_id', int(idx))
                # image_id = str(idx)
                if image_id not in [1,2,3,4,5,6,7]:
                    continue
                directions = [
                    "facing front", 
                    "facing right-front", 
                    "facing right", 
                    "facing right-back",
                    "facing back", 
                    "facing left-back", 
                    "facing left", 
                    "facing left-front", 
                ]
                caption = condition['caption']
                caption = f'a photo of a teddy bear by a riverside with wildflowers blooming nearby'
                entities = condition['entities']

                condition_data = eligen_pose_data(data_path, image_id, entities, device, dtype)
                test_list.append((condition_data, [0, 0], caption))
                print(image_id)
                caption = f'a photo of a teddy bear {directions[image_id]} by a riverside with wildflowers blooming nearby'
                test_list.append((condition_data, [0, 0], caption))

            if hasattr(pl_module.flux_pipe.transformer, "inter_controller"):
                for name, param in pl_module.flux_pipe.transformer.inter_controller.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "loose_embedder"):
                for name, param in pl_module.flux_pipe.transformer.loose_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
            if hasattr(pl_module.flux_pipe.transformer, "cam_embedder"):
                for name, param in pl_module.flux_pipe.transformer.cam_embedder.named_parameters():
                    param.data = param.to(device=device, dtype=dtype)
        else:
            raise NotImplementedError

        # import ipdb; ipdb.set_trace() 
        # print(pl_module.flux_pipe.transformer.single_transformer_blocks)
        # print(pl_module.flux_pipe.get_list_adapters())
        # print(pl_module.flux_pipe.transformer.single_transformer_blocks[0].proj_out.scaling)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        for i, (condition_data, position_delta, prompt) in enumerate(test_list):
            seed_everything(42)
            condition = Condition(
                condition_type=condition_type,
                condition=condition_data.resize((condition_size, condition_size)).convert("RGB") if isinstance(condition_data, Image.Image) else condition_data,
                position_delta=position_delta,
            )
            kwargs = {}
            if condition_data.get("orient", None) is not None:
                kwargs["orient"] = condition_data.get("orient", None)
                print(condition_data.get("orient", None))
                # kwargs["cam_entity_idx"] = condition_data.get("cam_entity_idx", None)
            if condition_data.get("eligen_entity_prompts", None) is not None and condition_data.get("eligen_entity_masks", None) is not None:
                kwargs["eligen_entity_prompts"] = condition_data.get("eligen_entity_prompts", None)
                kwargs["eligen_entity_masks"] = condition_data.get("eligen_entity_masks", None)
            if "schnell" in pl_module.model_config.get("flux_path", "black-forest-labs/FLUX.1-dev"):
                kwargs["num_inference_steps"] = 4
                pl_module.model_config['latent_lora'] = [module for module in pl_module.model_config['latent_lora'] if module != 'adapter']

            res = generate(
                pl_module.flux_pipe,
                prompt=prompt,
                conditions=[condition] if condition_data.get("condition", None) is not None else None,
                height=target_size,
                width=target_size,
                # generator=generator,
                model_config=pl_module.model_config,
                default_lora=True,
                **kwargs,
            )
            res.images[0].save(
                os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
            )
            if "eligen" in condition_type or "loose" in condition_type:
                mask_path = os.path.join(save_path, f"{file_name}_{condition_type}_{i}_mask.png")
                visualize_masks(res.images[0], condition.condition["eligen_entity_masks_pil"],  condition.condition["eligen_entity_prompts"], mask_path)
        # pl_module.model_config['latent_lora'] = ['eligen', 'adapter', 'default']
        pl_module.model_config['latent_lora'].append('adapter')
        print(pl_module.model_config['latent_lora'])
