import torch
from typing import Any, List, Tuple, Optional, Union, Dict
from hyvideo.modules.attenion import get_cu_seqlens

def hy_forward(
    self,
    x: torch.Tensor,
    t: torch.Tensor,  # Should be in range(0, 1000).
    text_states: torch.Tensor = None,
    text_mask: torch.Tensor = None,  # Used by the text refiner and for cu_seqlens.
    text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
    freqs_cos: Optional[torch.Tensor] = None,
    freqs_sin: Optional[torch.Tensor] = None,
    guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
    current_step: int = 0,
    return_dict: bool = True,
    **kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run the dual-stream / single-stream DiT forward pass with cache bookkeeping.

    Intended to be bound as a method of the transformer model (it reads many
    ``self`` attributes: ``patch_size``, ``time_in``, ``vector_in``,
    ``guidance_embed``, ``guidance_in``, ``img_in``, ``txt_in``,
    ``text_projection``, ``use_attention_mask``, ``double_blocks``,
    ``single_blocks``, ``final_layer``, ``unpatchify``, ``cache_dic``,
    ``current`` and ``cal_type``).

    Args:
        x: 5-D latent tensor; the last three dims are divided by
            ``self.patch_size`` to get the token grid. Presumably laid out as
            (B, C, T, H, W) — TODO confirm against the caller.
        t: Timestep tensor fed to ``self.time_in`` and, for the
            ``"single_refiner"`` projection, to ``self.txt_in``.
        text_states: Token-level text embeddings, projected by ``self.txt_in``.
        text_mask: Text attention mask. Passed to ``get_cu_seqlens`` and, when
            ``self.use_attention_mask`` is set, to the single-refiner ``txt_in``.
        text_states_2: Text embedding added into the modulation vector via
            ``self.vector_in``.
        freqs_cos, freqs_sin: Rotary-embedding tables. When ``freqs_cos`` is
            None, every block receives ``None`` instead of a tuple.
        guidance: Guidance strength; required when ``self.guidance_embed``.
        current_step: Written to ``self.current['step']`` before the blocks run.
        return_dict: If True, return ``{"x": output}``; else the tensor itself.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The unpatchified image-token output, either bare or wrapped in a dict
        under key ``"x"``.

    Raises:
        ValueError: ``self.guidance_embed`` is set but ``guidance`` is None.
        NotImplementedError: ``self.text_projection`` is not ``"linear"`` or
            ``"single_refiner"``.

    Side effects:
        Mutates ``self.current`` (keys ``'step'``, ``'stream'``, ``'layer'``)
        and calls ``self.cal_type()``. ``self.cache_dic`` and ``self.current``
        are handed to every block.
    """
    img = x
    txt = text_states
    _, _, ot, oh, ow = x.shape
    tt, th, tw = (
        ot // self.patch_size[0],
        oh // self.patch_size[1],
        ow // self.patch_size[2],
    )

    # Prepare modulation vectors: timestep embedding plus text modulation.
    vec = self.time_in(t)
    vec = vec + self.vector_in(text_states_2)

    # Guidance modulation (guidance-distilled models only).
    if self.guidance_embed:
        if guidance is None:
            raise ValueError(
                "Didn't get guidance strength for guidance distilled model."
            )
        # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
        vec = vec + self.guidance_in(guidance)

    # Embed image and text tokens.
    img = self.img_in(img)
    if self.text_projection == "linear":
        txt = self.txt_in(txt)
    elif self.text_projection == "single_refiner":
        txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
    else:
        raise NotImplementedError(
            f"Unsupported text_projection: {self.text_projection}"
        )

    txt_seq_len = txt.shape[1]
    img_seq_len = img.shape[1]

    # Compute cu_seqlens and max_seqlen for flash attention; q and kv share them
    # because attention runs over the joint img+txt sequence.
    cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
    cu_seqlens_kv = cu_seqlens_q
    max_seqlen_q = img_seq_len + txt_seq_len
    max_seqlen_kv = max_seqlen_q

    # Use the same value for both block types: None when no rotary tables are
    # given, so a block never receives a truthy (None, None) tuple.
    freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

    # --------------------- Pass through DiT blocks ------------------------
    self.current['step'] = current_step

    self.cal_type()

    self.current['stream'] = 'double_stream'
    for layer_idx, block in enumerate(self.double_blocks):
        self.current['layer'] = layer_idx
        img, txt = block(
            img,
            txt,
            vec,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
            self.cache_dic,
            self.current,
        )

    # Merge img and txt (img first) to pass through single stream blocks.
    x = torch.cat((img, txt), 1)

    # Iterating an empty container is a no-op, so no length guard is needed.
    self.current['stream'] = 'single_stream'
    for layer_idx, block in enumerate(self.single_blocks):
        self.current['layer'] = layer_idx
        x = block(
            x,
            vec,
            txt_seq_len,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
            self.cache_dic,
            self.current,
        )

    # Image tokens come first in the concatenation above; drop the text tokens.
    img = x[:, :img_seq_len, ...]

    # ---------------------------- Final layer ------------------------------
    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

    img = self.unpatchify(img, tt, th, tw)
    if return_dict:
        return {"x": img}
    return img