# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type


@dataclass
class AttentionProcessorMetadata:
    """Metadata stored in `AttentionProcessorRegistry` for a single attention processor class."""

    # Callable registered as the "skip output" function of the processor class. The functions registered in this
    # module take the call's `*args, **kwargs` and return `hidden_states` (optionally together with
    # `encoder_hidden_states`) picked from those inputs.
    # NOTE(review): presumably invoked in place of the processor when its computation is skipped; confirm with callers.
    skip_processor_output_fn: Callable[[Any], Any]


@dataclass
class TransformerBlockMetadata:
    """Metadata stored in `TransformerBlockRegistry` for a single transformer block class.

    Attributes:
        return_hidden_states_index: Index associated with `hidden_states` in the block's return value.
            NOTE(review): the consumers of this index live outside this module; confirm exact semantics there.
        return_encoder_hidden_states_index: Same, for `encoder_hidden_states`. Registered as `None` for several
            blocks in this module (presumably those that do not return it; confirm with callers).
        _cls: The block class this metadata describes. Filled in by `TransformerBlockRegistry.register`.
        _cached_parameter_indices: Lazily built mapping from `forward` parameter name (excluding `self`) to its
            positional index.
    """

    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    _cls: Type = None
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Return the value passed for parameter `identifier` in a call to the block's `forward`.

        A keyword argument takes precedence. Otherwise the positional index of `identifier` is looked up in the
        signature of `self._cls.forward` (with `self` dropped) and the value is read from `args`. The
        name-to-index mapping is computed once and cached on the instance.

        Args:
            identifier: Name of the `forward` parameter to fetch.
            args: Positional arguments of the call (without `self`).
            kwargs: Keyword arguments of the call; `None` is treated as empty.

        Returns:
            The argument value bound to `identifier`.

        Raises:
            ValueError: If the model class is not set, if `identifier` is not a parameter of `forward`, or if
                `args` is too short to contain it. The same errors are raised on cached and uncached lookups.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]

        if self._cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            parameters = list(inspect.signature(self._cls.forward).parameters.keys())
            parameters = parameters[1:]  # skip `self`
            self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        # Validate on every call, not only on the one that builds the cache, so that later lookups do not leak
        # a bare KeyError/IndexError.
        index = self._cached_parameter_indices.get(identifier)
        if index is None:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        if index >= len(args):
            raise ValueError(
                f"Expected at least {index + 1} positional arguments to read '{identifier}' but got {len(args)}."
            )
        return args[index]


class AttentionProcessorRegistry:
    """Class-level registry mapping attention processor classes to their `AttentionProcessorMetadata`.

    The built-in entries are not registered at import time; they are added lazily the first time `register` or
    `get` is called.
    """

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Associate `metadata` with `model_class`, replacing any previous entry for that class."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` has not been registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        """Run the built-in registrations once."""
        if cls._is_registered:
            return
        # The flag must be set before registering: the registration function calls `register`, which re-enters
        # this method and would otherwise recurse.
        cls._is_registered = True
        try:
            _register_attention_processors_metadata()
        except BaseException:
            # Do not leave a partially filled registry marked as complete; a later call retries and re-raises
            # the real failure instead of reporting "not registered".
            cls._is_registered = False
            raise


class TransformerBlockRegistry:
    """Class-level registry mapping transformer block classes to their `TransformerBlockMetadata`.

    The built-in entries are not registered at import time; they are added lazily the first time `register` or
    `get` is called.
    """

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Associate `metadata` with `model_class`, replacing any previous entry for that class.

        `metadata._cls` is set to `model_class` as a side effect, so each class needs its own metadata instance.
        """
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` has not been registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        """Run the built-in registrations once."""
        if cls._is_registered:
            return
        # The flag must be set before registering: the registration function calls `register`, which re-enters
        # this method and would otherwise recurse.
        cls._is_registered = True
        try:
            _register_transformer_blocks_metadata()
        except BaseException:
            # Do not leave a partially filled registry marked as complete; a later call retries and re-raises
            # the real failure instead of reporting "not registered".
            cls._is_registered = False
            raise


def _register_attention_processors_metadata():
    """Register the skip-output function of every built-in attention processor class.

    The model imports stay inside the function body because importing them at module scope leads to circular
    imports.
    """
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
    from ..models.transformers.transformer_flux import FluxAttnProcessor
    from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
    from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
    from ..models.transformers.transformer_wan import WanAttnProcessor2_0
    from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor

    # (processor class, skip-output function), listed in registration order.
    processor_skip_fns = (
        (AttnProcessor2_0, _skip_proc_output_fn_Attention_AttnProcessor2_0),
        (CogView4AttnProcessor, _skip_proc_output_fn_Attention_CogView4AttnProcessor),
        (WanAttnProcessor2_0, _skip_proc_output_fn_Attention_WanAttnProcessor2_0),
        (FluxAttnProcessor, _skip_proc_output_fn_Attention_FluxAttnProcessor),
        (QwenDoubleStreamAttnProcessor2_0, _skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0),
        (HunyuanImageAttnProcessor, _skip_proc_output_fn_Attention_HunyuanImageAttnProcessor),
        (ZSingleStreamAttnProcessor, _skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor),
    )

    for processor_cls, skip_fn in processor_skip_fns:
        AttentionProcessorRegistry.register(
            model_class=processor_cls,
            metadata=AttentionProcessorMetadata(skip_processor_output_fn=skip_fn),
        )


def _register_transformer_blocks_metadata():
    """Register the return-value layout of every built-in transformer block class.

    The model imports stay inside the function body because importing them at module scope leads to circular
    imports.
    """
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_bria import BriaTransformerBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_hunyuanimage import (
        HunyuanImageSingleTransformerBlock,
        HunyuanImageTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock
    from ..models.transformers.transformer_z_image import ZImageTransformerBlock

    # (block class, return_hidden_states_index, return_encoder_hidden_states_index), in registration order.
    block_output_layouts = (
        # BasicTransformerBlock
        (BasicTransformerBlock, 0, None),
        (BriaTransformerBlock, 0, None),
        # CogVideoX
        (CogVideoXBlock, 0, 1),
        # CogView4
        (CogView4TransformerBlock, 0, 1),
        # Flux
        (FluxTransformerBlock, 1, 0),
        (FluxSingleTransformerBlock, 1, 0),
        # HunyuanVideo
        (HunyuanVideoTransformerBlock, 0, 1),
        (HunyuanVideoSingleTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceSingleTransformerBlock, 0, 1),
        # LTXVideo
        (LTXVideoTransformerBlock, 0, None),
        # Mochi
        (MochiTransformerBlock, 0, 1),
        # Wan
        (WanTransformerBlock, 0, None),
        # QwenImage
        (QwenImageTransformerBlock, 1, 0),
        # HunyuanImage2.1
        (HunyuanImageTransformerBlock, 0, 1),
        (HunyuanImageSingleTransformerBlock, 0, 1),
        # ZImage
        (ZImageTransformerBlock, 0, None),
    )

    for block_cls, hidden_states_index, encoder_hidden_states_index in block_output_layouts:
        # A new metadata object per class: `register` stores the class on the metadata instance.
        block_metadata = TransformerBlockMetadata(
            return_hidden_states_index=hidden_states_index,
            return_encoder_hidden_states_index=encoder_hidden_states_index,
        )
        TransformerBlockRegistry.register(model_class=block_cls, metadata=block_metadata)


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states


# Per-processor aliases used by `_register_attention_processors_metadata`. Each name binds one attention
# processor class to one of the two generic skip functions defined above.
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# TODO: the right skip output for FluxAttnProcessor has not been settled yet; for now it returns only
# `hidden_states`.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
