from typing import Any, Dict, Optional, Union
import os
import torch
import numpy as np

from diffusers import DiffusionPipeline
from diffusers.models import FluxTransformer2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)

logger = logging.get_logger(__name__)


def _detach_to_cpu(obj):
    """Return ``obj`` with every tensor inside it detached and copied to CPU.

    Recurses into dicts, lists and tuples, rebuilding each container with the
    same kind (lists stay lists, namedtuples stay namedtuples, dicts become plain
    dicts). Any non-tensor, non-container object is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {k: _detach_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_detach_to_cpu(v) for v in obj]
    if isinstance(obj, tuple):
        items = [_detach_to_cpu(v) for v in obj]
        # namedtuples are constructed from positional fields; plain tuples from one iterable.
        return type(obj)(*items) if hasattr(obj, "_fields") else tuple(items)
    return obj


def save_feature(feature, save_path: str, feature_name: str, time_step: int) -> None:
    """
    Save an intermediate feature to ``{save_path}/{feature_name}_step_{time_step}.pt``.

    Every tensor in ``feature`` is detached and copied to CPU before saving,
    including tensors nested inside dicts, lists and tuples. A top-level list or
    tuple is saved as a plain tuple, and a top-level dict as a plain dict. Any
    other object is saved as-is after logging a warning.

    Args:
        feature: tensor, list/tuple, dict, or arbitrary picklable object.
        save_path: output directory; created if it does not exist.
        feature_name: prefix of the output file name.
        time_step: step index embedded in the output file name.
    """
    os.makedirs(save_path, exist_ok=True)
    save_file = os.path.join(save_path, f"{feature_name}_step_{time_step}.pt")

    if isinstance(feature, torch.Tensor):
        payload = _detach_to_cpu(feature)
    elif isinstance(feature, (tuple, list)):
        # Top-level sequences are normalized to a plain tuple (historical file format).
        payload = tuple(_detach_to_cpu(f) for f in feature)
    elif isinstance(feature, dict):
        payload = {k: _detach_to_cpu(v) for k, v in feature.items()}
    else:
        logger.warning(f"Unexpected feature type {type(feature)}, saving raw object.")
        payload = feature

    torch.save(payload, save_file)


def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    pooled_projections: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    img_ids: Optional[torch.Tensor] = None,
    txt_ids: Optional[torch.Tensor] = None,
    guidance: Optional[torch.Tensor] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,
    save_path: str = "features",
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    Flux transformer forward pass that also dumps intermediate features.

    Intended to be monkey-patched onto ``FluxTransformer2DModel``. Besides
    computing the output, it writes (via ``save_feature``, tagged with the step
    counter ``self.cnt``):

    * ``temb`` and ``image_rotary_emb``;
    * ``joint_attention_kwargs``, when given (after ``scale`` is removed and any
      IP-adapter embeddings are projected);
    * the output of every single-stream block, as ``singleblock_{i}_output``.

    Args:
        hidden_states: image latents, passed through ``self.x_embedder``.
        encoder_hidden_states: text states, passed through ``self.context_embedder``.
        pooled_projections: pooled text embedding for ``self.time_text_embed``.
        timestep: diffusion timestep; multiplied by 1000 before embedding.
        img_ids / txt_ids: position ids; a 3-D tensor is reduced to its first
            entry along dim 0.
        guidance: optional guidance value; multiplied by 1000 before embedding.
        joint_attention_kwargs: extra attention kwargs. ``"scale"`` is used as the
            LoRA scale; ``"ip_adapter_image_embeds"`` is projected with
            ``self.encoder_hid_proj``. The caller's dict is never mutated.
        controlnet_block_samples / controlnet_single_block_samples: optional
            residuals added after dual- / single-stream blocks.
        return_dict: return a ``Transformer2DModelOutput`` instead of a tuple.
        controlnet_blocks_repeat: cycle through ``controlnet_block_samples``
            instead of spreading them evenly over the blocks.
        save_path: dump directory, used only when ``self.save_path`` is unset.

    Returns:
        ``Transformer2DModelOutput(sample=output)``, or ``(output,)`` when
        ``return_dict`` is False.

    Side effects:
        Writes feature files and advances ``self.cnt`` by one. The counter wraps
        modulo ``self.num_steps`` when that attribute is set and non-zero.
    """
    import functools

    # Record this before "scale" is popped, so the non-PEFT warning can actually fire.
    scale_requested = joint_attention_kwargs is not None and "scale" in joint_attention_kwargs
    if joint_attention_kwargs is not None:
        # Copy so the pops below never mutate the caller's dict.
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)
    elif scale_requested:
        logger.warning("LoRA scale is ignored without PEFT backend.")

    hidden_states = self.x_embedder(hidden_states)

    # NOTE(review): the *1000 presumably undoes a /1000 normalization done by the
    # calling pipeline — confirm before changing either side.
    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000
        temb = self.time_text_embed(timestep, guidance, pooled_projections)
    else:
        temb = self.time_text_embed(timestep, pooled_projections)

    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    # Batched (3-D) ids are accepted, but only the first entry is used.
    if txt_ids.ndim == 3:
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        img_ids = img_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)

    if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
        ip_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
        joint_attention_kwargs["ip_hidden_states"] = self.encoder_hid_proj(ip_embeds)

    step = getattr(self, "cnt", 0)
    # A per-instance save_path (set by the driver script) overrides the argument.
    save_path = getattr(self, "save_path", save_path)
    save_feature(temb, save_path, "temb", step)
    save_feature(image_rotary_emb, save_path, "image_rotary_emb", step)
    if joint_attention_kwargs is not None:
        # Same file name as before; save_feature also detaches and moves tensors to CPU.
        save_feature(joint_attention_kwargs, save_path, "joint_attention_kwargs", step)

    use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
    ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

    for index_block, block in enumerate(self.transformer_blocks):
        if use_checkpointing:
            # Bind this specific block now. A closure over the loop variable would be
            # re-evaluated when checkpoint recomputes during backward, after `block`
            # has been rebound to a later module. Also forward the attention kwargs,
            # as the non-checkpointed path does.
            run_block = functools.partial(block, joint_attention_kwargs=joint_attention_kwargs)
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                run_block,
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        if controlnet_block_samples is not None:
            interval = int(np.ceil(len(self.transformer_blocks) / len(controlnet_block_samples)))
            if controlnet_blocks_repeat:
                idx = index_block % len(controlnet_block_samples)
            else:
                idx = index_block // interval
            hidden_states = hidden_states + controlnet_block_samples[idx]

    # Single-stream blocks operate on text tokens followed by image tokens.
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    num_text_tokens = encoder_hidden_states.shape[1]

    for index_block, block in enumerate(self.single_transformer_blocks):
        if use_checkpointing:
            # See the dual-stream loop: bind the block explicitly.
            run_block = functools.partial(block, joint_attention_kwargs=joint_attention_kwargs)
            hidden_states = torch.utils.checkpoint.checkpoint(
                run_block,
                hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # Dumped before the controlnet residual is applied.
        save_feature(hidden_states, save_path, f"singleblock_{index_block}_output", step)

        if controlnet_single_block_samples is not None:
            interval = int(np.ceil(len(self.single_transformer_blocks) / len(controlnet_single_block_samples)))
            # In-place add, restricted to the image-token slice.
            hidden_states[:, num_text_tokens:] += controlnet_single_block_samples[index_block // interval]

    hidden_states = hidden_states[:, num_text_tokens:]
    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        unscale_lora_layers(self, lora_scale)

    num_steps = getattr(self, "num_steps", None)
    self.cnt = (step + 1) % num_steps if num_steps else step + 1

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)

# Install the feature-dumping forward on every FluxTransformer2DModel.
FluxTransformer2DModel.forward = forward

# =========================
# Inference entry
# =========================

num_inference_steps = 50
seed = 42

# NOTE(review): FLUX.1 is commonly run in torch.bfloat16; float16 is kept as in the
# original, but check the dumped features / images for NaNs or black outputs.
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
# The patched forward wraps its step counter modulo this value.
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.to("cuda")

prompts_file = "train_prompts.txt"
if not os.path.exists(prompts_file):
    # Report the file that was actually checked.
    raise FileNotFoundError(f"{prompts_file} not found.")

with open(prompts_file, "r", encoding="utf-8") as f:
    prompts = [line.strip() for line in f if line.strip()]

for i, prompt in enumerate(prompts):
    save_dir = os.path.join("features", str(i))
    os.makedirs(save_dir, exist_ok=True)

    # Reset on the instance. The patched forward assigns self.cnt, and that instance
    # attribute shadows any class-level value, so resetting __class__.cnt would be a
    # no-op after the first transformer call.
    pipeline.transformer.cnt = 0
    pipeline.transformer.save_path = save_dir

    image = pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(seed),
    ).images[0]

    image.save(f"output_{i}.png")