# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig


@dataclass
class FluxVAEArchConfig(VAEArchConfig):
    """Architecture hyperparameters for the Flux VAE.

    Every field here is a plain default; this class performs no validation
    and derives nothing itself. ``FluxVAEConfig.post_init`` later mutates
    instances of this class: it sets a ``vae_scale_factor`` attribute and
    overwrites ``spatial_compression_ratio`` with the same value.
    """

    # Placeholder default; FluxVAEConfig.post_init replaces it with the
    # derived vae_scale_factor.
    spatial_compression_ratio: int = 1

    # NOTE(review): the fields below look like network sizing knobs (base
    # width, latent channels, per-stage multipliers, ...) — their exact
    # meaning is defined by the model code that consumes them; confirm there.
    base_dim: int = 96
    decoder_base_dim: int | None = None
    z_dim: int = 16
    # When non-empty (and the instance has no truthy ``block_out_channels``
    # attribute), FluxVAEConfig.post_init computes vae_scale_factor as
    # 2 ** (len(dim_mult) - 1), which is 8 for this default.
    dim_mult: tuple[int, ...] = (1, 2, 4, 4)
    num_res_blocks: int = 2
    attn_scales: tuple[float, ...] = ()
    # Spelling (sic) — presumably "temporal". Left as-is because the field
    # name is part of the config interface.
    temperal_downsample: tuple[bool, ...] = (False, True, True)
    dropout: float = 0.0

    is_residual: bool = False
    in_channels: int = 3
    out_channels: int = 3
    patch_size: int | None = None
    scale_factor_temporal: int = 4
    # Last-resort value for vae_scale_factor in FluxVAEConfig.post_init, used
    # only when both ``block_out_channels`` and ``dim_mult`` are missing/empty.
    scale_factor_spatial: int = 8
    clip_output: bool = True


@dataclass
class Flux2VAEArchConfig(FluxVAEArchConfig):
    """Architecture config type for the Flux2 VAE.

    Adds no fields and overrides nothing: all defaults are inherited from
    ``FluxVAEArchConfig``. It exists as a distinct type and is the default
    ``arch_config`` of ``Flux2VAEConfig``.
    """

    pass


@dataclass
class FluxVAEConfig(VAEConfig):
    """Top-level VAE configuration for Flux models.

    Holds a ``FluxVAEArchConfig`` plus this class's defaults for the feature
    cache and tiling switches. Two hooks are defined:

    * ``__post_init__`` runs automatically after dataclass construction and
      sets ``blend_num_frames``.
    * ``post_init`` must be called explicitly; it derives the spatial scale
      factor and stores it on ``arch_config``.
    """

    arch_config: FluxVAEArchConfig = field(default_factory=FluxVAEArchConfig)

    # Feature caching defaults to enabled.
    use_feature_cache: bool = True

    # Every tiling mode defaults to disabled.
    use_tiling: bool = False
    use_temporal_tiling: bool = False
    use_parallel_tiling: bool = False

    def __post_init__(self):
        """Set ``blend_num_frames`` to twice the min-frames/stride-frames gap.

        ``tile_sample_min_num_frames`` and ``tile_sample_stride_num_frames``
        are not declared in this class; they are expected to come from the
        base config.
        """
        frame_gap = (
            self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
        )
        self.blend_num_frames = frame_gap * 2

    def post_init(self):
        """Derive ``vae_scale_factor`` and mirror it on the arch config.

        Source of the factor, in priority order:

        1. a truthy ``block_out_channels`` attribute on ``arch_config``
           -> ``2 ** (len(block_out_channels) - 1)``
        2. a truthy ``dim_mult`` -> ``2 ** (len(dim_mult) - 1)``
        3. otherwise ``scale_factor_spatial`` is used unchanged.

        The result is written to ``arch_config.vae_scale_factor`` and then
        copied into ``arch_config.spatial_compression_ratio``.
        """
        arch = self.arch_config

        # A missing attribute is treated the same as an empty one.
        stage_widths = getattr(arch, "block_out_channels", None) or arch.dim_mult

        if stage_widths:
            arch.vae_scale_factor = 2 ** (len(stage_widths) - 1)
        else:
            arch.vae_scale_factor = arch.scale_factor_spatial

        arch.spatial_compression_ratio = arch.vae_scale_factor


@dataclass
class Flux2VAEConfig(FluxVAEConfig):
    """VAE configuration for Flux2 models.

    Identical to ``FluxVAEConfig`` except that ``arch_config`` defaults to a
    freshly constructed ``Flux2VAEArchConfig``.
    """

    arch_config: Flux2VAEArchConfig = field(default_factory=Flux2VAEArchConfig)
