import torch
from diffusers.pipelines import FluxPipeline
from typing import List, Union, Optional, Dict, Any, Callable
from .block import block_forward, single_block_forward
from .lora_controller import enable_lora, module_active_adapters
from .pipeline_tools import process_entity_masks, construct_mask_camera
from diffusers.models.transformers.transformer_flux import (
    FluxTransformer2DModel,
    Transformer2DModelOutput,
    USE_PEFT_BACKEND,
    # is_torch_version,
    scale_lora_layers,
    unscale_lora_layers,
    logger,
)
import numpy as np
from einops import rearrange, repeat


def prepare_params(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    **kwargs: dict,
):
    """Normalise Flux transformer forward arguments into a positional tuple.

    The eleven named forward arguments are returned in declaration order,
    followed by a dict holding every extra keyword argument that was not one
    of them (empty when there were none).
    """
    named_args = (
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
    )
    # Extra keyword arguments always travel as the final element.
    return (*named_args, kwargs)


def tranformer_forward(
    transformer: FluxTransformer2DModel,
    condition_latents: torch.Tensor,
    condition_ids: torch.Tensor,
    condition_type_ids: torch.Tensor,
    condition_types=None,
    model_config: Optional[Dict[str, Any]] = None,
    c_t=0,
    return_x_embedder_output: bool = False,
    **params: dict,
):
    """Run a FluxTransformer2DModel forward pass with extra condition tokens.

    Re-implements the transformer's forward loop so that condition latents can
    be embedded and threaded through every double/single transformer block via
    ``block_forward`` / ``single_block_forward``.

    Args:
        transformer: The Flux transformer whose sub-modules are used.
        condition_latents: Condition inputs, or ``None`` to run without
            conditions. How they are embedded depends on ``condition_type_ids``:
            all ``3`` -> tokens gathered from the embedded ``hidden_states`` at
            each entity mask's positions plus ``transformer.cam_embedder(...)``;
            all ``2`` -> ``x_embedder`` under the condition LoRA, optional
            ``inter_controller``, summed over dim 0, and optionally fused into
            ``hidden_states`` by ``loose_embedder``; otherwise -> plain
            ``x_embedder``.
        condition_ids: Position ids for the condition tokens (fed to
            ``pos_embed``). In the type-3 branch they are replaced by the
            gathered rows of ``img_ids``.
        condition_type_ids: Tensor of condition type ids (see above).
        condition_types: Forwarded unchanged to the block functions.
        model_config: Optional configuration dict. Keys read here:
            ``latent_lora``, ``condition_lora``, ``eligen_depth_attn`` and
            ``eligen_camera_attn``; the dict is also passed to every block.
        c_t: Timestep value used to build ``cond_temb`` (scaled by 1000 like
            the main timestep).
        return_x_embedder_output: If true, additionally return a clone of the
            embedded condition latents (``None`` when no separate condition
            tokens exist).
        **params: Standard Flux forward arguments (``hidden_states``,
            ``encoder_hidden_states``, ``pooled_projections``, ``timestep``,
            ``img_ids``, ``txt_ids``, ``guidance``, ``joint_attention_kwargs``,
            ``controlnet_block_samples``, ``controlnet_single_block_samples``,
            ``return_dict``). Any other keys (e.g. ``eligen_kwargs``,
            ``cam_entity_idx``) are collected and read by this function.

    Returns:
        ``Transformer2DModelOutput(sample=output)`` when ``return_dict`` is
        true, otherwise ``(output,)``. With ``return_x_embedder_output`` the
        embedded condition latents are returned as well (second tuple element,
        or second return value next to the output object).

    Raises:
        ValueError: If ``cam_entity_idx`` is given together with conditions
            that did not go through the type-3 branch, so no per-entity token
            offsets exist for the camera attention mask.
    """
    self = transformer
    if model_config is None:
        # Fresh dict per call instead of a shared mutable default argument.
        model_config = {}
    use_condition = condition_latents is not None

    # Values produced only on some branches; bound up front so later reads
    # never hit an unbound local.
    x_embedder_output = None
    cam_seq_len = None

    (
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
        img_ids,
        txt_ids,
        guidance,
        joint_attention_kwargs,
        controlnet_block_samples,
        controlnet_single_block_samples,
        return_dict,
        kwargs,
    ) = prepare_params(**params)

    # Normalise img_ids to 2-D *before* the type-3 branch indexes its rows.
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        img_ids = img_ids[0]

    # Record whether a LoRA scale was supplied before it is popped below;
    # otherwise the non-PEFT warning could never trigger.
    scale_passed = (
        joint_attention_kwargs is not None
        and joint_attention_kwargs.get("scale", None) is not None
    )
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
        scale_lora_layers(self, lora_scale)
    elif scale_passed:
        logger.warning(
            "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
        )

    with enable_lora([self.x_embedder], model_config.get("latent_lora", [])):
        hidden_states = self.x_embedder(hidden_states)

    if use_condition:
        if torch.all(condition_type_ids == 3):
            cam_entity_masks = kwargs["eligen_kwargs"].get("entity_masks", None)[
                kwargs["cam_entity_idx"]
            ]
            # Fold each 2x2 spatial patch into one token position; a token is
            # selected for an entity when its patch sum is positive.
            cam_entity_masks = rearrange(
                cam_entity_masks, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2
            )
            cam_token_masks = torch.sum(cam_entity_masks, dim=-1) > 0
            instance_idxs = [
                token_mask.nonzero(as_tuple=True)[0] for token_mask in cam_token_masks
            ]
            # Exclusive prefix sums: start offset of each entity's tokens in
            # the concatenated condition sequence built below.
            token_counts = [instance_idx.numel() for instance_idx in instance_idxs]
            cam_seq_len = [sum(token_counts[:i]) for i in range(len(token_counts))]

            image_latents = []
            for condition_latent, instance_idx in zip(condition_latents, instance_idxs):
                image_latent = hidden_states[:, instance_idx.tolist(), :].clone()
                image_latent += self.cam_embedder(condition_latent)
                image_latents.append(image_latent)
            condition_latents = torch.cat(image_latents, dim=1)

            condition_ids = torch.cat(
                [img_ids[instance_idx.tolist(), :].clone() for instance_idx in instance_idxs],
                dim=0,
            )
        elif torch.all(condition_type_ids == 2):
            with enable_lora(
                [self.x_embedder], model_config.get("condition_lora", ["default"])
            ):
                # Shape note from the original author: [n_mask, N, 64] -> [n_mask, N, 3072].
                condition_latents = self.x_embedder(condition_latents)
                cond_rotary_emb = self.pos_embed(condition_ids)

                if hasattr(self, "inter_controller"):
                    condition_latents = self.inter_controller(
                        condition_latents.unsqueeze(0),
                        cond_rotary_emb,
                    ).squeeze(0)

                # Collapse dim 0 (the mask dimension per the note above) into
                # a single condition sequence.
                condition_latents = condition_latents.sum(dim=0, keepdim=True)

                if hasattr(self, "loose_embedder"):
                    # Fuse conditions into the image tokens along the feature
                    # dim; no separate condition tokens remain afterwards.
                    hidden_states = self.loose_embedder(
                        torch.cat([hidden_states, condition_latents], dim=-1)
                    )
                    condition_latents = None
        else:
            condition_latents = self.x_embedder(condition_latents)

        if return_x_embedder_output and condition_latents is not None:
            # Differentiable clone, so the returned tensor stays in the graph
            # but is independent of later updates to condition_latents.
            x_embedder_output = condition_latents.clone()
    else:
        condition_latents = None

    # True when separate condition tokens flow through the blocks.
    has_cond_tokens = use_condition and (not hasattr(self, "loose_embedder"))

    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    cond_timestep = torch.ones_like(timestep) * c_t * 1000
    cond_temb = (
        self.time_text_embed(cond_timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(cond_timestep, guidance, pooled_projections)
    )

    if "eligen_kwargs" in kwargs:
        encoder_hidden_states, txt_ids, attention_mask = process_entity_masks(
            self,
            hidden_states,
            encoder_hidden_states,
            kwargs["eligen_kwargs"],
            txt_ids,
            use_condition=has_cond_tokens,
            # condition_latents is None both without conditions and after the
            # loose_embedder fusion; fall back to the image sequence length.
            condition_length=(
                condition_latents.shape[1]
                if condition_latents is not None
                else hidden_states.shape[1]
            ),
            eligen_depth_attn=model_config.get("eligen_depth_attn", True),
        )
    else:
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        attention_mask = None

    if "cam_entity_idx" in kwargs and use_condition:
        if cam_seq_len is None:
            raise ValueError(
                "`cam_entity_idx` requires camera conditions "
                "(all condition_type_ids == 3)."
            )
        cam_entity_masks = kwargs["eligen_kwargs"].get("entity_masks", None)[
            kwargs["cam_entity_idx"]
        ]
        # NOTE(review): 512 is passed positionally; presumably the text
        # sequence length expected by construct_mask_camera — confirm.
        attention_mask = construct_mask_camera(
            attention_mask,
            cam_entity_masks,
            512,
            hidden_states.shape[1],
            cam_seq_len,
            eligen_camera_attn=model_config.get("eligen_camera_attn", True),
        )
        attention_mask = attention_mask.to(
            device=hidden_states.device, dtype=hidden_states.dtype
        )

    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        txt_ids = txt_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)
    if use_condition:
        cond_rotary_emb = self.pos_embed(condition_ids)

    use_ckpt = self.training and self.gradient_checkpointing
    # The blocks are checkpointed with keyword arguments, which only the
    # non-reentrant torch.utils.checkpoint implementation accepts.
    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}

    for index_block, block in enumerate(self.transformer_blocks):
        block_kwargs = dict(
            model_config=model_config,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            condition_latents=condition_latents if use_condition else None,
            condition_types=condition_types if use_condition else None,
            temb=temb,
            cond_temb=cond_temb if use_condition else None,
            cond_rotary_emb=cond_rotary_emb if use_condition else None,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
        )
        if use_ckpt:
            encoder_hidden_states, hidden_states, condition_latents = (
                torch.utils.checkpoint.checkpoint(
                    block_forward, self=block, **block_kwargs, **ckpt_kwargs
                )
            )
        else:
            encoder_hidden_states, hidden_states, condition_latents = block_forward(
                block, **block_kwargs
            )

        # ControlNet residual.
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(
                controlnet_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            hidden_states = (
                hidden_states
                + controlnet_block_samples[index_block // interval_control]
            )

    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    for index_block, block in enumerate(self.single_transformer_blocks):
        single_kwargs = dict(
            model_config=model_config,
            hidden_states=hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
        )
        if use_condition:
            single_kwargs.update(
                condition_latents=condition_latents,
                cond_temb=cond_temb,
                cond_rotary_emb=cond_rotary_emb,
                condition_types=condition_types,
            )
        if use_ckpt:
            result = torch.utils.checkpoint.checkpoint(
                single_block_forward, self=block, **single_kwargs, **ckpt_kwargs
            )
        else:
            result = single_block_forward(block, **single_kwargs)

        if has_cond_tokens:
            hidden_states, condition_latents = result
        else:
            hidden_states = result

        # ControlNet residual (applied to the image part only).
        if controlnet_single_block_samples is not None:
            interval_control = len(self.single_transformer_blocks) / len(
                controlnet_single_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                + controlnet_single_block_samples[index_block // interval_control]
            )

    # Drop the text tokens prepended above.
    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        # Remove `lora_scale` from each PEFT layer.
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        if return_x_embedder_output:
            return (output, x_embedder_output)
        return (output,)
    if return_x_embedder_output:
        return Transformer2DModelOutput(sample=output), x_embedder_output
    return Transformer2DModelOutput(sample=output)
