# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch
from torch import Tensor, nn
import random
from .modules.layers_mask import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding, RefInject

def zero_module(module):
    """
    Overwrite every parameter of ``module`` with zeros, in place.

    The module object itself is returned so the call can be used inline,
    e.g. ``self.proj = zero_module(nn.Linear(a, b))``.
    """
    # Zeroing under no_grad keeps the write out of the autograd graph, which is
    # what detaching each parameter before the in-place write also achieves.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

@dataclass
class FluxParams:
    """Hyper-parameters used to build a ``Flux`` transformer."""

    # Feature dim of each input image token; also used as the output channel count.
    in_channels: int
    # Input dim of the pooled conditioning vector ``y`` fed to ``vector_in``.
    vec_in_dim: int
    # Feature dim of the raw text-encoder tokens fed to ``txt_in``.
    context_in_dim: int
    # Model width; must be divisible by ``num_heads``.
    hidden_size: int
    # MLP expansion ratio passed to the double/single stream blocks.
    mlp_ratio: float
    # Number of attention heads.
    num_heads: int
    # Number of DoubleStreamBlocks.
    depth: int
    # Number of SingleStreamBlocks.
    depth_single_blocks: int
    # Per-axis rotary dims for EmbedND; must sum to hidden_size // num_heads.
    axes_dim: list[int]
    # Base frequency passed to EmbedND.
    theta: int
    # Whether the double-stream blocks use a bias in their qkv projection.
    qkv_bias: bool
    # If True, a guidance MLP is built and ``forward`` requires a ``guidance`` tensor.
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Extends the base Flux layout with:
      * a global concept MLP that maps raw text-encoder tokens to per-token
        modulation shifts, and
      * reference-image injectors (``RefInject``) that refresh those shifts
        every ``_REF_INJECT_EVERY`` blocks from the current reference tokens.

    ``set_MLP`` must be called before ``forward``; it creates the extra modules.
    """

    _supports_gradient_checkpointing = True

    # A RefInject module is attached to every N-th double/single block. This
    # value is shared by set_MLP (module creation) and forward (module lookup).
    _REF_INJECT_EVERY = 5

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads

        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )

        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        # NOTE(review): this flag is toggled by _set_gradient_checkpointing but is
        # not read anywhere in forward() as written.
        self.gradient_checkpointing = False

    def set_MLP(self, device):
        """Create the concept MLP and reference injectors on ``device``.

        The concept MLP's output layer is zero-initialised, so it contributes a
        zero shift until it is trained.
        """
        every = self._REF_INJECT_EVERY
        with torch.device(device):
            # tokenverse modulate: consumes the raw (pre-``txt_in``) text tokens,
            # so its input width must match ``context_in_dim``.
            self.concept_MLP_global = MLPEmbedder(
                in_dim=self.params.context_in_dim, hidden_dim=self.hidden_size
            )
            torch.nn.init.constant_(self.concept_MLP_global.out_layer.weight, 0)
            torch.nn.init.constant_(self.concept_MLP_global.out_layer.bias, 0)

            self.double_ref_injects = nn.ModuleList(
                [
                    RefInject(
                        num_heads=block.num_heads,
                        hidden_size=block.hidden_size,
                        mlp_ratio=self.params.mlp_ratio,
                    )
                    for index_block, block in enumerate(self.double_blocks)
                    if index_block % every == 0
                ]
            )
            self.single_ref_injects = nn.ModuleList(
                [
                    RefInject(
                        num_heads=block.num_heads,
                        hidden_size=block.hidden_size,
                        mlp_ratio=self.params.mlp_ratio,
                    )
                    for index_block, block in enumerate(self.single_blocks)
                    if index_block % every == 0
                ]
            )

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on ``module`` if it exposes that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        """Map ``"<module path>.processor"`` to the processor of every submodule with ``set_processor``."""
        processors = {}  # type: dict[str, nn.Module]

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
                The caller's dict is not modified.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # Entries are consumed with pop() below; work on a copy so the
            # caller's dict is left intact.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    @staticmethod
    def _apply_concept_shift(vec_txt: Tensor, shift: Tensor, token_ranges) -> Tensor:
        """Return ``vec_txt`` with ``shift`` added at the text positions in each range.

        Each entry of ``token_ranges`` must be a non-empty, contiguous, ascending
        run of token indices (e.g. ``range(a, b)``). The result is rebuilt by
        concatenating [before, shifted range, after] along dim 1, so a
        non-contiguous index set would change the sequence length.
        """
        for token_range in token_ranges:
            vec_txt = torch.cat(
                [
                    vec_txt[:, : token_range[0], :],
                    vec_txt[:, token_range, :] + shift[:, token_range, :],
                    vec_txt[:, token_range[-1] + 1 :, :],
                ],
                dim=1,
            )
        return vec_txt

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        ref_img: Tensor | list[Tensor] | tuple[Tensor, ...] | None = None,
        ref_img_ids: Tensor | list[Tensor] | tuple[Tensor, ...] | None = None,
        concept_tokens_range=None,
        concept_tokens_range_list=None,
        _scale_global=1,
        mask_ref: bool = False,
        attn_mask: Tensor | None = None,
    ) -> Tensor:
        """Run the transformer and return predictions for the target image tokens only.

        Args:
            img, txt: 3-D token tensors (validated below). Presumably
                (batch, seq, channels); TODO confirm.
            img_ids, txt_ids: position ids concatenated and fed to ``pe_embedder``.
            timesteps, guidance: scalars embedded with ``timestep_embedding(., 256)``;
                ``guidance`` is required when ``params.guidance_embed`` is set.
            y: pooled conditioning vector for ``vector_in``.
            ref_img, ref_img_ids: optional reference tokens and their ids, either
                a single tensor pair or parallel lists/tuples. They are appended
                after the target image tokens.
            concept_tokens_range / concept_tokens_range_list: contiguous text-token
                ranges that receive the reference-driven modulation shift. ``None``
                or empty ranges are ignored.
            _scale_global: multiplier applied to the concept MLP output.
            mask_ref, attn_mask: forwarded unchanged to every block.

        Raises:
            ValueError: if ``img``/``txt`` are not 3-D, or guidance is missing
                for a guidance-distilled model.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")
        img = self.img_in(img)

        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)

        # The concept MLP consumes the raw text-encoder tokens, before txt_in.
        t5_token = txt
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        txt_end = txt.shape[1]
        img_end = img.shape[1]

        # Append reference tokens (and ids) after the target image tokens.
        if ref_img is not None:
            if isinstance(ref_img, (tuple, list)):
                img = torch.cat([img] + [self.img_in(ref) for ref in ref_img], dim=1)
                ids = torch.cat([ids] + list(ref_img_ids), dim=1)
            else:
                img = torch.cat((img, self.img_in(ref_img)), dim=1)
                ids = torch.cat((ids, ref_img_ids), dim=1)

        vec_txt_ori = vec[:, None, :].repeat(1, txt_ids.size(1), 1)
        # One modulation vector per image-stream token, reference tokens
        # included. Sizing from img (not img_ids) keeps the length correct for
        # single-tensor references too.
        vec_img = vec[:, None, :].expand(-1, img.size(1), -1)

        pe = self.pe_embedder(ids)
        # Positional embeddings for text + reference tokens, with the target
        # image span removed (e.g. (1, 1, 2560, 64, 2, 2) in one observed config).
        ref_pe = torch.cat(
            [pe[:, :, :txt_end, :, :, :], pe[:, :, (txt_end + img_end):, :, :, :]],
            dim=2,
        )

        if concept_tokens_range_list is None:
            concept_tokens_range_list = [concept_tokens_range]
        # None / empty ranges add no shift; filter them once instead of per block.
        active_ranges = [
            token_range
            for token_range in concept_tokens_range_list
            if token_range is not None and len(token_range) > 0
        ]

        modulate_shift = _scale_global * self.concept_MLP_global(t5_token)
        every = self._REF_INJECT_EVERY

        # ====================double block====================
        for index_block, block in enumerate(self.double_blocks):
            if index_block % every == 0:
                # Refresh the shift from the current reference tokens; it is
                # reused by the following blocks until the next refresh.
                ref_tokens = img[:, img_end:, :]
                ref_modulate_shift = self.double_ref_injects[index_block // every](
                    ref_tokens, modulate_shift, ref_pe
                )
            vec_txt_cur_block = self._apply_concept_shift(vec_txt_ori, ref_modulate_shift, active_ranges)
            img, txt = block(
                img=img,
                txt=txt,
                vec=[vec_txt_cur_block, vec_img],
                pe=pe,
                mask_ref=mask_ref,
                attn_mask=attn_mask,
            )
        img = torch.cat((txt, img), 1)
        # ====================double block====================

        # ====================single block====================
        for index_block, block in enumerate(self.single_blocks):
            if index_block % every == 0:
                # The sequence is now [txt | target img | refs].
                ref_tokens = img[:, (txt_end + img_end):, :]
                ref_modulate_shift = self.single_ref_injects[index_block // every](
                    ref_tokens, modulate_shift, ref_pe
                )
            vec_txt_cur_block = self._apply_concept_shift(vec_txt_ori, ref_modulate_shift, active_ranges)
            img = block(
                img,
                vec=[vec_txt_cur_block, vec_img],
                pe=pe,
                mask_ref=mask_ref,
                attn_mask=attn_mask,
            )
        # Drop the text tokens, then the reference tokens: keep the target image only.
        img = img[:, txt.shape[1] :, ...]
        img = img[:, :img_end, ...]
        # ====================single block====================

        img = self.final_layer(img, vec)
        return img
        
