import copy

from einops import repeat
from diffusers import __version__
from diffusers.models.modeling_utils import (
    _add_variant, _get_checkpoint_shard_files, _get_model_file,  # diffusers.utils
    _determine_device_map, _fetch_index_file,  # diffusers.models.model_loading_utils
)
from diffusers.models.modeling_utils import *
from diffusers.models.transformers.transformer_sd3 import *

from extensions.diffusers_diffsplat.models.mv_attention import JointMVTransformerBlock


# Default for `low_cpu_mem_usage` in `from_pretrained_new`; low-memory initialization requires torch >= 1.9.0.
_LOW_CPU_MEM_USAGE_DEFAULT = bool(is_torch_version(">=", "1.9.0"))


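# Adapted from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.
# Modifications: `JointTransformerBlock` -> `JointMVTransformerBlock`, optional background-token merging
# in `forward` (`mask_bg`), per-view expansion of `temb` before the output norm, and a custom
# `from_pretrained_new` loader.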
class SD3TransformerMV2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """
    The Transformer model introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        out_channels (`int`, defaults to 16): Number of output channels.

    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[
            int, ...
        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: Optional[str] = None,
    ):
        """
        Build the patch embedding, conditioning embedders, multi-view joint blocks and output head.

        Most arguments are read back through `self.config` (populated by `@register_to_config`). Besides the
        parameters described in the class docstring:
            joint_attention_dim (`int`, defaults to 4096): input width of `context_embedder`, i.e. the last
                dimension of `encoder_hidden_states`.
            pos_embed_max_size (`int`, defaults to 96): forwarded to `PatchEmbed`.
            dual_attention_layers (`Tuple[int, ...]`): indices of the blocks built with `use_dual_attention=True`.
            qk_norm (`str`, *optional*): forwarded unchanged to every `JointMVTransformerBlock`.
        """
        super().__init__()
        # `out_channels=None` falls back to `in_channels`.
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        # Patchify the latent into `inner_dim`-wide tokens and add positional embeddings.
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,
        )

        # Timestep + pooled-text conditioning vector, and projection of the token-level text embeddings.
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)

        # As in upstream SD3, only the last block is built with `context_pre_only=True`.
        self.transformer_blocks = nn.ModuleList(
            [
                JointMVTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=i in dual_attention_layers,
                )
                for i in range(self.config.num_layers)
            ]
        )

        # Output head: maps each token to `patch_size * patch_size * out_channels` values (unpatchified in forward).
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    # Adapted from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Turn on [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) in
        every submodule that exposes `set_chunk_feed_forward`.

        Parameters:
            chunk_size (`int`, *optional*):
                Size of each feed-forward chunk. `None` (or `0`) means chunks of size 1, i.e. the feed-forward
                layer runs separately over each slice along `dim`.
            dim (`int`, *optional*, defaults to `0`):
                Dimension to chunk over: `0` (batch) or `1` (sequence length).

        Raises:
            ValueError: if `dim` is neither 0 nor 1.
        """
        if dim not in (0, 1):
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        effective_chunk_size = chunk_size or 1

        # Pre-order, depth-first walk over all descendants (the model itself is not visited).
        to_visit = list(self.children())[::-1]
        while to_visit:
            current = to_visit.pop()
            if hasattr(current, "set_chunk_feed_forward"):
                current.set_chunk_feed_forward(chunk_size=effective_chunk_size, dim=dim)
            to_visit.extend(list(current.children())[::-1])

    # Adapted from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        """Turn feed-forward chunking off in every submodule that exposes `set_chunk_feed_forward`.

        Each such module receives `set_chunk_feed_forward(chunk_size=None, dim=0)`.
        """
        # Pre-order, depth-first walk over all descendants (the model itself is not visited).
        to_visit = list(self.children())[::-1]
        while to_visit:
            current = to_visit.pop()
            if hasattr(current, "set_chunk_feed_forward"):
                current.set_chunk_feed_forward(chunk_size=None, dim=0)
            to_visit.extend(list(current.children())[::-1])

    @property
    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: every processor in the model, keyed by
            `"<dotted module path>.processor"` of the module that owns it.
        """
        collected: Dict[str, AttentionProcessor] = {}

        # Pre-order, depth-first walk; a stack of (dotted path, module) pairs replaces recursion.
        pending = [(child_name, child) for child_name, child in self.named_children()][::-1]
        while pending:
            path, module = pending.pop()
            if hasattr(module, "get_processor"):
                collected[f"{path}.processor"] = module.get_processor()
            pending.extend(
                [(f"{path}.{sub_name}", sub_module) for sub_name, sub_module in module.named_children()][::-1]
            )

        return collected

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The dict that
                is passed in is left unmodified.

        Raises:
            ValueError: if `processor` is a dict whose size differs from the number of attention processors.
            KeyError: if `processor` is a dict missing the `"<module path>.processor"` key of some attention layer.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # Entries are popped as they are assigned below. Copy first so the caller's dict (for example
            # `self.original_attn_processors`, reused by `unfuse_qkv_projections`) is not emptied.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Fuse the attention projection matrices and switch every attention layer to `FusedJointAttnProcessor2_0`.

        For self-attention modules, the query, key and value projections are fused; for cross-attention modules,
        the key and value projections are fused. The processors in place beforehand are saved in
        `self.original_attn_processors` so that `unfuse_qkv_projections` can restore them.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Raises:
            ValueError: if any current processor's class name contains "Added" (added KV projections).
        """
        # Reset first so that a call failing the check below leaves no stale saved processors behind.
        self.original_attn_processors = None

        current_processors = self.attn_processors
        if any("Added" in str(proc.__class__.__name__) for proc in current_processors.values()):
            raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = current_processors

        attention_layers = [module for module in self.modules() if isinstance(module, Attention)]
        for attention_layer in attention_layers:
            attention_layer.fuse_projections(fuse=True)

        # NOTE(review): this installs one `FusedJointAttnProcessor2_0` on every attention layer, replacing any
        # processor that `JointMVTransformerBlock` may rely on for multi-view attention. Confirm before use.
        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Restores the attention processors saved by `fuse_qkv_projections`. Does nothing if
        `fuse_qkv_projections` has never been called, or if it failed before saving the processors.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` is only created by `fuse_qkv_projections`, so it may not exist yet.
        original_processors = getattr(self, "original_attn_processors", None)
        if original_processors is None:
            return
        if isinstance(original_processors, dict):
            # Pass a copy so the saved mapping survives even if `set_attn_processor` consumes its argument.
            original_processors = dict(original_processors)
        self.set_attn_processor(original_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `module.gradient_checkpointing` to `value`; modules without that attribute are left alone."""
        if not hasattr(module, "gradient_checkpointing"):
            return
        module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        skip_layers: Optional[List[int]] = None,
        mask_bg: Optional[torch.Tensor] = None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`SD3TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                Its `"num_views"` entry (default 1) sets how many times `temb` is repeated before the output norm.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            skip_layers (`list` of `int`, *optional*):
                A list of layer indices to skip during the forward pass.
            mask_bg (`torch.BoolTensor`, *optional*):
                Background mask over the patch tokens (`True` = background). Background tokens are averaged into
                one token before the transformer blocks, and that token is written back to every background
                position afterwards. Only two layouts are recognised; any other first dimension is ignored:
                - first dimension 256: one 1-D mask over the 256 patch tokens, shared by every batch row;
                - first dimension 4: one mask row per view, tiled over the batch (row `i` uses view `i % 4`, so
                  the batch size must be a multiple of 4). Views are first trimmed to the smallest background
                  count so every row merges the same number of tokens.
                The caller's tensor is not modified. When there is no background token, merging is skipped.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.

        Raises:
            ValueError: if `mask_bg` does not match the token count or batch layout described above.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            # Record whether a scale was supplied *before* popping it, so the warning below can actually fire.
            lora_scale_given = joint_attention_kwargs.get("scale", None) is not None
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale_given = False
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif lora_scale_given:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        height, width = hidden_states.shape[-2:]

        # Patchify (B, C, H, W) -> (B, N, D) and add positional embeddings. The PatchEmbed used here returns a pair;
        # per the original author, the second element is the embedding without positional embeddings. It is unused.
        hidden_states, _ = self.pos_embed(hidden_states)

        # Optional background-token merging (see `mask_bg` above).
        token_mask = None
        if mask_bg is not None and mask_bg.shape[0] in (256, 4):
            hidden_states, token_mask = self._merge_background_tokens(hidden_states, mask_bg)

        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)

            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for index_block, block in enumerate(self.transformer_blocks):
            # Skipped layers are not run, but the controlnet residual below is still applied for them.
            is_skip = skip_layers is not None and index_block in skip_layers

            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    joint_attention_kwargs,
                    **ckpt_kwargs,
                )
            elif not is_skip:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            # NOTE(review): after background merging `hidden_states` is shorter than the full token sequence, so
            # full-length controlnet residuals would not line up. Confirm the two are never combined.
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

        # `temb` has one row per prompt, while the image tokens are presumably laid out as `batch * num_views`
        # rows. Repeat it per view for `norm_out`. `joint_attention_kwargs` may be None (the default), in which
        # case a single view is assumed.
        num_views = (joint_attention_kwargs or {}).get("num_views", 1)
        temb = repeat(temb, "b d -> (b v) d", v=num_views)

        if token_mask is not None:
            hidden_states = self._restore_background_tokens(hidden_states, token_mask)

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # unpatchify: (B, N, p * p * C) -> (B, C, H, W)
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    @staticmethod
    def _merge_background_tokens(
        hidden_states: torch.Tensor, mask_bg: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Average each row's background tokens into one token appended after the remaining (object) tokens.

        Args:
            hidden_states: patch tokens of shape `(batch, num_tokens, dim)`.
            mask_bg: boolean background mask (`True` = background). It is either 1-D `(num_tokens,)` and shared by
                every row, or 2-D `(num_views, num_tokens)` and tiled over the batch, so row `i` uses view
                `i % num_views`. For 2-D masks, views with more background than the sparsest view have their
                last excess background positions treated as object tokens, so every row merges the same count.

        Returns:
            `(hidden_states, token_mask)`. After merging, `hidden_states` has shape
            `(batch, num_tokens - num_bg + 1, dim)`, with the object tokens in their original order followed by
            the merged token. `token_mask` is the `(batch, num_tokens)` mask that was applied. If there is no
            background token to merge, the input is returned unchanged together with `None`.

        Raises:
            ValueError: if the mask's token dimension differs from `num_tokens`, or `batch` is not a multiple of
                `num_views`.
        """
        batch, num_tokens, dim = hidden_states.shape
        if mask_bg.shape[-1] != num_tokens:
            raise ValueError(
                f"mask_bg shape {mask_bg.shape} does not match hidden_states shape {hidden_states.shape}"
            )

        if mask_bg.dim() == 1:
            num_bg = int(mask_bg.sum())
            token_mask = mask_bg.unsqueeze(0).expand(batch, -1)
        else:
            num_views = mask_bg.shape[0]
            if batch % num_views != 0:
                raise ValueError(
                    f"Batch size {batch} is not a multiple of the number of view masks {num_views}."
                )
            # Work on a copy so the caller's mask is never modified.
            view_masks = mask_bg.clone()
            bg_counts = view_masks.sum(dim=1).tolist()
            num_bg = min(bg_counts)
            for view_idx, bg_count in enumerate(bg_counts):
                excess = bg_count - num_bg
                if excess > 0:
                    bg_positions = torch.nonzero(view_masks[view_idx], as_tuple=True)[0]
                    view_masks[view_idx, bg_positions[-excess:]] = False
            token_mask = view_masks.repeat(batch // num_views, 1)

        if num_bg == 0:
            # The mean of zero tokens is NaN, and attention would spread it to every token, so skip merging.
            return hidden_states, None

        token_mask = token_mask.to(hidden_states.device)
        # Boolean indexing returns the selected tokens row by row, in ascending position order. Every row selects
        # the same count, so the result reshapes cleanly to (batch, count, dim).
        bg_tokens = hidden_states[token_mask].reshape(batch, num_bg, dim)
        obj_tokens = hidden_states[~token_mask].reshape(batch, num_tokens - num_bg, dim)
        merged_bg_token = bg_tokens.mean(dim=1, keepdim=True)
        return torch.cat([obj_tokens, merged_bg_token], dim=1), token_mask

    @staticmethod
    def _restore_background_tokens(hidden_states: torch.Tensor, token_mask: torch.Tensor) -> torch.Tensor:
        """Undo `_merge_background_tokens`: put the object tokens back and broadcast the merged token.

        Args:
            hidden_states: `(batch, num_obj + 1, dim)` tokens with the merged background token last.
            token_mask: the `(batch, num_tokens)` mask returned by `_merge_background_tokens`.

        Returns:
            `(batch, num_tokens, dim)` tokens with the same dtype and device as `hidden_states`.
        """
        dim = hidden_states.shape[-1]
        obj_tokens = hidden_states[:, :-1, :]
        merged_bg_token = hidden_states[:, -1:, :]
        batch, num_tokens = token_mask.shape
        # Start with the merged token at every position. Clone, because expanded views must not be written in
        # place. Then overwrite the object positions in the same order they were gathered.
        restored = merged_bg_token.expand(batch, num_tokens, dim).clone()
        restored[~token_mask] = obj_tokens.reshape(-1, dim)
        return restored

    # Copied from diffusers.models.modeling_utils.ModelingMixin.from_pretrained
    @classmethod
    @validate_hf_hub_args
    def from_pretrained_new(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],

        sample_size: int = 32,  # `input_res` / 8
        in_channels: int = 16,
        out_channels: int = 16,
        zero_init_conv_in: bool = True,
        view_concat_condition: bool = False,
        input_concat_plucker: bool = False,
        input_concat_binary_mask: bool = False,
        from_scratch: bool = False,  # do not load pretrained parameters

        **kwargs
    ):
        cache_dir = kwargs.pop("cache_dir", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        from_flax = kwargs.pop("from_flax", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                " `device_map=None`. You can install accelerate with `pip install accelerate`."
            )

        # Check if we can handle device_map and dispatching the weights
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        # change device_map into a map if we passed an int, a str or a torch.device
        if isinstance(device_map, torch.device):
            device_map = {"": device_map}
        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            try:
                device_map = {"": torch.device(device_map)}
            except RuntimeError:
                raise ValueError(
                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
                )
        elif isinstance(device_map, int):
            if device_map < 0:
                raise ValueError(
                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
                )
            else:
                device_map = {"": device_map}

        if device_map is not None:
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
            elif not low_cpu_mem_usage:
                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

        if low_cpu_mem_usage:
            if device_map is not None and not is_torch_version(">=", "1.10"):
                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            **kwargs,
        )

        # Modify configs for the multi-view cross-domain diffusion model
        config["_class_name"] = cls.__name__
        config["sample_size"] = sample_size  # training resolution
        config["in_channels"] = in_channels
        config["out_channels"] = out_channels

        config["view_concat_condition"] = view_concat_condition
        config["input_concat_plucker"] = input_concat_plucker
        config["input_concat_binary_mask"] = input_concat_binary_mask

        # Determine if we're loading from a directory of sharded checkpoints.
        is_sharded = False
        index_file = None
        is_local = os.path.isdir(pretrained_model_name_or_path)
        index_file = _fetch_index_file(
            is_local=is_local,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder or "",
            use_safetensors=use_safetensors,
            cache_dir=cache_dir,
            variant=variant,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
        if index_file is not None and index_file.is_file():
            is_sharded = True

        if is_sharded and from_flax:
            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")

        # load model
        model_file = None
        if from_flax:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )
            model = cls.from_config(config, **unused_kwargs)

            # Convert the weights
            from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

            if not from_scratch:
                model = load_flax_checkpoint_in_pytorch_model(model, model_file)
        else:
            if is_sharded:
                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                    pretrained_model_name_or_path,
                    index_file,
                    cache_dir=cache_dir,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder or "",
                )

            elif use_safetensors and not is_sharded:
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                    )

                except IOError as e:
                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                    if not allow_pickle:
                        raise
                    logger.warning(
                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                    )

            if model_file is None and not is_sharded:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )

            if low_cpu_mem_usage:
                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **unused_kwargs)

                if not from_scratch:
                    # if device_map is None, load the state dict and move the params from meta device to the cpu
                    if device_map is None and not is_sharded:
                        param_device = "cpu"
                        state_dict = load_state_dict(model_file, variant=variant)
                        model._convert_deprecated_attention_blocks(state_dict)
                        # move the params from meta device to cpu
                        missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                        if len(missing_keys) > 0:
                            raise ValueError(
                                f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                                f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                                " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                                " those weights or else make sure your checkpoint file is correct."
                            )

                        unexpected_keys = load_model_dict_into_meta(
                            model,
                            state_dict,
                            device=param_device,
                            dtype=torch_dtype,
                            model_name_or_path=pretrained_model_name_or_path,
                        )

                        if cls._keys_to_ignore_on_load_unexpected is not None:
                            for pat in cls._keys_to_ignore_on_load_unexpected:
                                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                        if len(unexpected_keys) > 0:
                            logger.warning(
                                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                            )

                    else:  # else let accelerate handle loading and dispatching.
                        # Load weights and dispatch according to the device_map
                        # by default the device_map is None and the weights are loaded on the CPU
                        force_hook = True
                        device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
                        if device_map is None and is_sharded:
                            # we load the parameters on the cpu
                            device_map = {"": "cpu"}
                            force_hook = False
                        try:
                            accelerate.load_checkpoint_and_dispatch(
                                model,
                                model_file if not is_sharded else index_file,
                                device_map,
                                max_memory=max_memory,
                                offload_folder=offload_folder,
                                offload_state_dict=offload_state_dict,
                                dtype=torch_dtype,
                                force_hooks=force_hook,
                                strict=True,
                            )
                        except AttributeError as e:
                            # When using accelerate loading, we do not have the ability to load the state
                            # dict and rename the weight names manually. Additionally, accelerate skips
                            # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                            # (which look like they should be private variables?), so we can't use the standard hooks
                            # to rename parameters on load. We need to mimic the original weight names so the correct
                            # attributes are available. After we have loaded the weights, we convert the deprecated
                            # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                            # the weights so we don't have to do this again.

                            if "'Attention' object has no attribute" in str(e):
                                logger.warning(
                                    f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                                    " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                                    " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                                    " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                                    " please also re-upload it or open a PR on the original repository."
                                )
                                model._temp_convert_self_to_deprecated_attention_blocks()
                                accelerate.load_checkpoint_and_dispatch(
                                    model,
                                    model_file if not is_sharded else index_file,
                                    device_map,
                                    max_memory=max_memory,
                                    offload_folder=offload_folder,
                                    offload_state_dict=offload_state_dict,
                                    dtype=torch_dtype,
                                    force_hooks=force_hook,
                                    strict=True,
                                )
                                model._undo_temp_convert_self_to_deprecated_attention_blocks()
                            else:
                                raise e

                loading_info = {
                    "missing_keys": [],
                    "unexpected_keys": [],
                    "mismatched_keys": [],
                    "error_msgs": [],
                }
            else:
                model = cls.from_config(config, **unused_kwargs)

                if not from_scratch:
                    state_dict = load_state_dict(model_file, variant=variant)
                    model._convert_deprecated_attention_blocks(state_dict)
                    state_dict_original = copy.deepcopy(state_dict)

                    model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                        model,
                        state_dict,
                        model_file,
                        pretrained_model_name_or_path,
                        ignore_mismatched_sizes=ignore_mismatched_sizes,
                    )

                    loading_info = {
                        "missing_keys": missing_keys,
                        "unexpected_keys": unexpected_keys,
                        "mismatched_keys": mismatched_keys,
                        "error_msgs": error_msgs,
                    }
                else:
                    loading_info = {
                        "missing_keys": [],
                        "unexpected_keys": [],
                        "mismatched_keys": [],
                        "error_msgs": [],
                    }

        if not from_scratch:
            # Handle initilizations for some layers
            ## Patch embedding conv
            pos_embed_proj_weight = state_dict_original["pos_embed.proj.weight"]
            latent_channels = pos_embed_proj_weight.shape[1]
            if model.pos_embed.proj.weight.data.shape[1] != latent_channels:
                # Initialize from the original weights
                model.pos_embed.proj.weight.data[:, :latent_channels] = pos_embed_proj_weight
                # Whether to place all zero to new layers ?
                if zero_init_conv_in:
                    model.pos_embed.proj.weight.data[:, latent_channels:] = 0

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
        elif torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        if output_loading_info:
            return model, loading_info

        return model
