from dataclasses import dataclass

import torch
from torch import Tensor, nn
from typing import Dict
from flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

def flux_forward(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    guidance: Tensor | None = None,
    *args,
    **kwargs,
) -> Tensor:
    """Run the Flux transformer forward pass with per-block cache bookkeeping.

    Intended as a replacement ``forward`` for a Flux model. Before each
    transformer block it records the active stream name and layer index in
    ``self.current``. It then hands ``self.cache_dic`` and ``self.current``
    to the block alongside the usual activations.

    Args:
        self: Model instance. It must expose the sub-modules ``img_in``,
            ``time_in``, ``guidance_in``, ``vector_in``, ``txt_in``,
            ``pe_embedder``, ``double_blocks``, ``single_blocks`` and
            ``final_layer``. It must also expose ``params.guidance_embed``,
            a ``cal_type()`` method, a ``cache_dic`` attribute, and a
            dict-like ``current`` attribute.
        img: Image token sequence; must be a 3-D tensor.
        img_ids: Positional ids for the image tokens.
        txt: Text token sequence; must be a 3-D tensor.
        txt_ids: Positional ids for the text tokens.
        timesteps: Timesteps, embedded via ``timestep_embedding(..., 256)``
            and passed through ``time_in``.
        y: Vector conditioning, projected by ``vector_in`` and added to the
            modulation vector.
        guidance: Guidance strength. Required whenever
            ``self.params.guidance_embed`` is truthy.
        *args: Accepted and ignored (presumably for call-site
            compatibility -- confirm against callers).
        **kwargs: Accepted and ignored, as above.

    Returns:
        The result of ``self.final_layer`` applied to the image tokens only
        (the text-token prefix is removed first).

    Raises:
        ValueError: If ``img`` or ``txt`` is not 3-D, or if the model is
            guidance-embedded and ``guidance`` is None.
    """
    if not (img.ndim == 3 and txt.ndim == 3):
        raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # Input embeddings and the shared modulation vector. The sub-module call
    # order here is kept deliberately.
    img_tokens = self.img_in(img)

    modulation = self.time_in(timestep_embedding(timesteps, 256))
    use_guidance = self.params.guidance_embed
    if use_guidance and guidance is None:
        raise ValueError("Didn't get guidance strength for guidance distilled model.")
    if use_guidance:
        modulation = modulation + self.guidance_in(timestep_embedding(guidance, 256))
    modulation = modulation + self.vector_in(y)

    txt_tokens = self.txt_in(txt)

    # Text ids come first, matching the [txt, img] token concatenation
    # performed before the single-stream blocks.
    positional = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))

    # Called once per forward, before any block runs. Presumably it updates
    # the caching decision for this step -- confirm in the model class.
    self.cal_type()

    # Double-stream blocks: image and text tokens are carried separately.
    self.current['stream'] = 'double_stream'
    for layer_idx, double_block in enumerate(self.double_blocks):
        self.current['layer'] = layer_idx
        img_tokens, txt_tokens = double_block(
            img=img_tokens,
            txt=txt_tokens,
            vec=modulation,
            pe=positional,
            cache_dic=self.cache_dic,
            current=self.current,
        )

    # Single-stream blocks run on one joint sequence with text tokens first.
    n_txt = txt_tokens.shape[1]
    joint = torch.cat((txt_tokens, img_tokens), dim=1)

    self.current['stream'] = 'single_stream'
    for layer_idx, single_block in enumerate(self.single_blocks):
        self.current['layer'] = layer_idx
        joint = single_block(
            joint,
            vec=modulation,
            pe=positional,
            cache_dic=self.cache_dic,
            current=self.current,
        )

    # Remove the text prefix so that only image tokens reach the final layer.
    img_tokens = joint[:, n_txt:, ...]

    # Output presumably (N, T, patch_size ** 2 * out_channels), as in
    # upstream Flux -- confirm against final_layer.
    return self.final_layer(img_tokens, modulation)