from __future__ import annotations

from typing import Any, Dict, Optional, Union
from pathlib import Path
from collections.abc import Iterable, Sequence
import argparse
import logging
import math
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from diffusers import DiffusionPipeline
from diffusers.models import FluxTransformer2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    scale_lora_layers,
    unscale_lora_layers,
)

# Module-level logging setup. basicConfig attaches an INFO-level handler to the
# root logger at import time (per the stdlib docs it is a no-op if the root
# logger already has handlers); this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# LoRA utilities
# ---------------------------------------------------------------------------
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` layer with a learnable low-rank (LoRA) update.

    The output is ``base(x) + scaling * (dropout(x) @ lora_a.T @ lora_b.T)`` with
    ``scaling = alpha / rank``. The wrapped layer's weight and bias are frozen in
    ``__init__``; only ``lora_a`` / ``lora_b`` are created as trainable parameters.
    ``lora_b`` is zero-initialised, so a fresh wrapper initially returns the same
    output as ``base``.

    Args:
        base: Linear layer to wrap (stored as ``self.base``).
        rank: Rank of the update. ``rank <= 0`` disables the LoRA path entirely.
        alpha: Numerator of the LoRA scaling factor.
        dropout: Dropout probability applied to the LoRA-path input only.
    """

    def __init__(
        self,
        base: nn.Linear,
        rank: int = 64,
        alpha: float = 128.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.base = base
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank if rank > 0 else 0.0
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Freeze the pretrained layer; only the low-rank factors should learn.
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)

        if rank > 0:
            device = base.weight.device
            # Factors are created in torch's default dtype, which may differ from
            # the base layer's dtype; forward() casts across that boundary.
            self.lora_a = nn.Parameter(torch.zeros(rank, base.in_features, device=device))
            self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank, device=device))
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b)
        else:
            self.register_parameter("lora_a", None)
            self.register_parameter("lora_b", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``base(x)`` plus the scaled low-rank update (if enabled)."""
        out = self.base(x)
        # Test the parameters rather than ``rank == 0``: a negative rank also
        # leaves them unset, and would otherwise crash on ``None.t()``.
        if self.lora_a is None or self.lora_b is None:
            return out

        lead_shape = x.shape[:-1]
        # Flatten leading dims, apply dropout, and move into the factors' dtype
        # so e.g. a float16 base works with float32 LoRA factors.
        x_2d = self.dropout(x.reshape(-1, x.shape[-1])).to(self.lora_a.dtype)
        lora = (x_2d @ self.lora_a.t()) @ self.lora_b.t()
        lora = lora.reshape(*lead_shape, self.base.out_features)
        return out + (lora * self.scaling).to(out.dtype)

    @property
    def weight(self) -> torch.Tensor:
        """The wrapped layer's (frozen) weight, for code that reads ``.weight``."""
        return self.base.weight

    @property
    def bias(self) -> torch.Tensor | None:
        """The wrapped layer's (frozen) bias, or None."""
        return self.base.bias


def _locate_attr(module: nn.Module, path: str) -> tuple[nn.Module, str]:
    parts = path.split(".")
    parent = module
    for name in parts[:-1]:
        parent = getattr(parent, name)
    return parent, parts[-1]


def inject_lora(
    module: nn.Module,
    targets: Sequence[str],
    rank: int = 64,
    alpha: float = 128.0,
    dropout: float = 0.0,
) -> list[LoRALinear]:
    """Replace the ``nn.Linear`` submodules at *targets* with ``LoRALinear`` wrappers.

    All targets are resolved and validated before anything is replaced, so an
    invalid path leaves *module* unmodified.

    Args:
        module: Root module the dotted *targets* are resolved against.
        targets: Dotted attribute paths (e.g. ``"attn.to_q"``). A single string
            is treated as one path.
        rank: Forwarded to every ``LoRALinear``.
        alpha: Forwarded to every ``LoRALinear``.
        dropout: Forwarded to every ``LoRALinear``.

    Returns:
        The newly installed wrappers, in the order of *targets*.

    Raises:
        TypeError: If a target is not an ``nn.Linear`` or is listed twice.
        AttributeError: If a path component does not exist.
    """
    # A bare string is a Sequence[str] too; iterating it would walk characters.
    if isinstance(targets, str):
        targets = [targets]

    # Phase 1: resolve + validate everything without touching the module.
    resolved: list[tuple[nn.Module, str, nn.Linear]] = []
    seen: set[str] = set()
    for path in targets:
        if path in seen:
            # A repeated path would wrap the same layer twice.
            raise TypeError(f"{path} is listed more than once in targets")
        seen.add(path)
        parent, name = _locate_attr(module, path)
        base = getattr(parent, name)
        if not isinstance(base, nn.Linear):
            raise TypeError(f"{path} is not an nn.Linear (got {type(base).__name__})")
        resolved.append((parent, name, base))

    # Phase 2: install the wrappers.
    loras: list[LoRALinear] = []
    for parent, name, base in resolved:
        lora = LoRALinear(base, rank=rank, alpha=alpha, dropout=dropout)
        setattr(parent, name, lora)
        loras.append(lora)
    return loras


def freeze_except_lora(root: nn.Module) -> None:
    """Freeze every parameter under *root* except the LoRA factors.

    Only the parameters registered directly on each ``LoRALinear`` (``lora_a``
    and ``lora_b``) are re-enabled; the wrapped ``base`` layers stay frozen.
    """
    for p in root.parameters():
        p.requires_grad = False
    for m in root.modules():
        if isinstance(m, LoRALinear):
            # recurse=False: the default would also yield m.base.weight / bias
            # and silently unfreeze the pretrained layer. parameters() already
            # skips parameters registered as None (rank <= 0).
            for p in m.parameters(recurse=False):
                p.requires_grad = True


def iter_lora_parameters(root: nn.Module) -> Iterable[nn.Parameter]:
    """Yield the trainable LoRA factors of every ``LoRALinear`` under *root*.

    Only ``lora_a`` / ``lora_b`` with ``requires_grad`` set are produced; the
    wrapped base layers' weights are never yielded, so the result can be handed
    straight to an optimizer.
    """
    for m in root.modules():
        if isinstance(m, LoRALinear):
            # recurse=False keeps m.base.weight / m.base.bias out of the result.
            for p in m.parameters(recurse=False):
                if p.requires_grad:
                    yield p


def find_linear_paths(module: nn.Module, prefix: str = "") -> list[str]:
    """List dotted paths to every ``nn.Linear`` below *module*, in definition order.

    The walk is depth-first over ``named_children``. A Linear is recorded but not
    descended into, and *module* itself is never reported. Paths are prefixed
    with *prefix* when it is non-empty.
    """
    found: list[str] = []
    # Explicit stack of (qualified_name, node); children are pushed in reverse so
    # popping visits them in their original order (pre-order traversal).
    pending = [
        (f"{prefix}.{name}" if prefix else name, child)
        for name, child in reversed(list(module.named_children()))
    ]
    while pending:
        qual_name, node = pending.pop()
        if isinstance(node, nn.Linear):
            found.append(qual_name)
            continue
        pending.extend(
            (f"{qual_name}.{name}", child)
            for name, child in reversed(list(node.named_children()))
        )
    return found


# ---------------------------------------------------------------------------
# Custom forward
# ---------------------------------------------------------------------------
def _learnibridge_checkpoint_forward(module: nn.Module):
    """Return a positional-args wrapper around *module* for ``torch.utils.checkpoint``."""

    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


def _learnibridge_full_pass(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Any,
    joint_attention_kwargs: Optional[Dict[str, Any]],
) -> torch.Tensor:
    """Run every dual-stream and single-stream block (the regular Flux path).

    Side effects on ``self``:
      * ``self.pre``: clone of the concatenated text+image hidden states output
        by single block 36, i.e. the input of block 37, which is the default
        bridge block used on replaced steps.
      * ``self.previous_residual``: image-token output minus the embedded input,
        reused on replaced steps that have no adapter checkpoint.

    Returns:
        Image-token hidden states after the last single-stream block.
    """
    ori_hidden_states = hidden_states.clone()

    # Grad mode and the checkpointing flag cannot change inside this call.
    use_grad_ckpt = torch.is_grad_enabled() and self.gradient_checkpointing
    ckpt_kwargs: Dict[str, Any] = {}
    if use_grad_ckpt and is_torch_version(">=", "1.11.0"):
        ckpt_kwargs["use_reentrant"] = False

    for block in self.transformer_blocks:
        if use_grad_ckpt:
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                _learnibridge_checkpoint_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    for index_block, block in enumerate(self.single_transformer_blocks):
        if use_grad_ckpt:
            hidden_states = torch.utils.checkpoint.checkpoint(
                _learnibridge_checkpoint_forward(block),
                hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # NOTE: hard-coded to cache the input of block 37 (the default
        # "block_idx" in the adapter checkpoints); keep the two in sync.
        if index_block == 36:
            self.pre = hidden_states.clone()

    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
    self.previous_residual = hidden_states - ori_hidden_states
    return hidden_states


def _learnibridge_bridge_pass(
    self,
    ckpt_path: Path,
    temb: torch.Tensor,
    image_rotary_emb: Any,
    joint_attention_kwargs: Optional[Dict[str, Any]],
    txt_len: int,
) -> torch.Tensor:
    """Run only the LoRA-adapted bridge block on the cached ``self.pre``.

    Loads the checkpoint at *ckpt_path* (expects a ``"state_dict"`` entry and an
    optional ``"block_idx"``, default 37), injects ``LoRALinear`` wrappers into
    that single-stream block on first use, loads the weights and runs the block
    once without gradients.

    Returns:
        Image-token hidden states (the first *txt_len* tokens are dropped).
    """
    logger.info("Loading LoRA checkpoint: %s", ckpt_path)
    # NOTE(security): torch.load unpickles arbitrary Python objects unless
    # weights_only=True is used; only point lora_dir at trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=self.device)

    block_idx = int(checkpoint.get("block_idx", 37))
    target_block = self.single_transformer_blocks[block_idx]

    # Inject LoRA modules once; later steps reuse them and only swap weights.
    if not any(isinstance(m, LoRALinear) for m in target_block.modules()):
        available_paths = find_linear_paths(target_block)
        preferred = {"linear1", "linear2", "lin"}
        target_paths = [p for p in available_paths if p.split(".")[-1] in preferred]
        if not target_paths:
            target_paths = available_paths
        logger.info("Injecting LoRA into block %d: %s", block_idx, target_paths)
        inject_lora(target_block, target_paths, rank=32, alpha=64.0, dropout=0.0)

    # Non-strict so partial (e.g. LoRA-only) state dicts load, but report keys
    # the block does not have: strict=False would otherwise drop them silently.
    incompatible = target_block.load_state_dict(checkpoint["state_dict"], strict=False)
    if incompatible.unexpected_keys:
        logger.warning(
            "Checkpoint %s has %d keys not present in block %d (e.g. %s); they were ignored.",
            ckpt_path,
            len(incompatible.unexpected_keys),
            block_idx,
            incompatible.unexpected_keys[:3],
        )
    target_block.eval()
    # Match the cached activations' dtype instead of assuming float16, so the
    # block's weights and its input always agree.
    target_block = target_block.to(self.pre.dtype)

    with torch.no_grad():
        hidden_states = target_block(
            hidden_states=self.pre,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )
    return hidden_states[:, txt_len:, ...]


def learnibridge_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,  # kept for signature compatibility
    lora_dir: Optional[str] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """FluxTransformer2DModel forward with optional step-wise block replacement.

    Meant to be patched onto ``FluxTransformer2DModel.forward``. Per-model state
    is read from / written to ``self``:

    * ``self.cnt``: current step; incremented every call, reset to 0 once it
      reaches ``self.num_steps``.
    * ``self.replace_steps`` (optional, default empty): steps on which the
      full block stack is skipped.
    * ``self.lora_dir`` (optional): adapter directory used when *lora_dir* is
      not passed explicitly.

    On a replaced step, if ``<lora_dir>/blocks37_step<cnt>.pt`` exists, only the
    LoRA-adapted bridge block is run on the activations cached by the last full
    pass; otherwise the last full pass's residual is added to the embedded
    input. If the needed cache is missing, the full pass runs instead.

    Args:
        controlnet_blocks_repeat: Unused; kept for signature compatibility.
        lora_dir: Directory with the per-step adapter checkpoints. ``None`` uses
            ``self.lora_dir`` when set, else ``"adapters"``. The pipeline does
            not pass this argument itself.
        Remaining arguments are the standard Flux transformer inputs.

    Returns:
        ``Transformer2DModelOutput`` when *return_dict* is true, else ``(output,)``.
    """
    # Record before popping: checking after the pop (as before) could never fire.
    scale_requested = joint_attention_kwargs is not None and joint_attention_kwargs.get("scale") is not None
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)
    elif scale_requested:
        logger.warning("LoRA scale in joint_attention_kwargs is ignored without PEFT backend.")

    hidden_states = self.x_embedder(hidden_states)
    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )

    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    # 3-D (batched) ids are reduced to their first batch entry.
    if txt_ids.ndim == 3:
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        img_ids = img_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)

    if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
        ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

    if lora_dir is None:
        # Pipelines call the transformer without lora_dir, so fall back to the
        # value stored on the model before the historical default.
        lora_dir = getattr(self, "lora_dir", "adapters")

    # Read replace steps from the instance to avoid external globals.
    replace_steps = getattr(self, "replace_steps", set())
    txt_len = encoder_hidden_states.shape[1]
    replaced = False

    if self.cnt in replace_steps:
        ckpt_path = Path(lora_dir) / f"blocks37_step{self.cnt}.pt"
        if ckpt_path.exists():
            if getattr(self, "pre", None) is not None:
                hidden_states = _learnibridge_bridge_pass(
                    self, ckpt_path, temb, image_rotary_emb, joint_attention_kwargs, txt_len
                )
                replaced = True
            else:
                logger.warning(
                    "No cached bridge-block input (self.pre); running the full forward pass for step %d.",
                    self.cnt,
                )
        else:
            prev_res = getattr(self, "previous_residual", None)
            if prev_res is not None:
                hidden_states = hidden_states + prev_res
                replaced = True
            else:
                logger.warning(
                    "previous_residual is None; running the full forward pass for step %d.",
                    self.cnt,
                )

    if not replaced:
        hidden_states = _learnibridge_full_pass(
            self, hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs
        )

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        unscale_lora_layers(self, lora_scale)

    # Advance the step counter; wrap after num_steps so the next image restarts at 0.
    self.cnt += 1
    if self.cnt == self.num_steps:
        self.cnt = 0

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)

def build_replace_steps_periodic(
    max_step: int = 49,
    N: int = 8,
    head_excluded: Iterable[int] | None = None,
    periodic_start: int = 10,
) -> set[int]:
    """Return the steps whose forward pass may be replaced.

    Steps ``0..max_step`` (inclusive) are candidates. A step is excluded if it is
    in *head_excluded*, or if it equals ``periodic_start + k * N`` for some
    ``k >= 0`` (up to *max_step*).

    Args:
        max_step: Largest step index considered (inclusive).
        N: Period of the periodic exclusion; must be >= 1.
        head_excluded: Steps always excluded; any iterable of ints. Defaults to
            ``{0, 1, 2}``.
        periodic_start: First periodically excluded step.

    Returns:
        The set of replace steps.

    Raises:
        ValueError: If ``N < 1`` (a non-positive period would either crash in
            ``range`` or silently disable the periodic exclusion).
    """
    if N < 1:
        raise ValueError(f"N must be a positive period, got {N}")

    # set(...) accepts lists/tuples too; the set union below needs a set.
    head = {0, 1, 2} if head_excluded is None else set(head_excluded)
    periodic = set(range(periodic_start, max_step + 1, N))
    excluded = head | periodic
    return {t for t in range(max_step + 1) if t not in excluded}


# ---------------------------------------------------------------------------
# Script entry (safe defaults)
# ---------------------------------------------------------------------------
def _safe_filename(text: str, max_len: int = 80) -> str:
    """Create a filesystem-safe short filename stem."""
    keep = []
    for ch in text:
        if ch.isalnum() or ch in ("-", "_"):
            keep.append(ch)
        elif ch.isspace():
            keep.append("_")
    stem = "".join(keep).strip("_")
    if not stem:
        stem = "prompt"
    return stem[:max_len]


def main() -> None:
    """Command-line entry point: generate one image with the LearniBridge forward.

    Patches ``FluxTransformer2DModel.forward`` with ``learnibridge_forward``
    (binding ``--lora_dir``), loads the pipeline in float16, places it on
    ``--device`` (or enables model CPU offload with ``--cpu_offload``), stores
    the per-step state on the transformer instance, and saves the image under
    ``--output``.
    """
    import functools

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="black-forest-labs/FLUX.1-dev")
    parser.add_argument("--prompt", type=str, default="A woman walking through sunflower field at golden hour")
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--lora_dir", type=str, default="adapters")
    parser.add_argument("--output", type=str, default="outputs")
    parser.add_argument(
        "--cpu_offload",
        action="store_true",
        help="Use model CPU offload instead of placing the whole pipeline on --device.",
    )
    args = parser.parse_args()
    if args.steps < 1:
        parser.error("--steps must be >= 1")

    # The pipeline calls the transformer without a lora_dir argument, so bind it
    # here; otherwise the forward's own default would silently win over --lora_dir.
    FluxTransformer2DModel.forward = functools.partialmethod(learnibridge_forward, lora_dir=args.lora_dir)

    pipeline = DiffusionPipeline.from_pretrained(
        args.model,
        torch_dtype=torch.float16,
    )

    # Offloading and full placement are alternatives: moving an offloaded
    # pipeline onto the GPU throws away the memory savings of offloading.
    if args.cpu_offload:
        pipeline.enable_model_cpu_offload(device=args.device)
    else:
        pipeline.to(args.device)

    # Derive the schedule from --steps instead of assuming 50 steps.
    replace_steps = build_replace_steps_periodic(
        max_step=args.steps - 1,
        N=7,
        head_excluded={0, 1},
        periodic_start=2,
    )

    # Keep the step state on this transformer instance rather than mutating the
    # FluxTransformer2DModel class for every instance in the process.
    transformer = pipeline.transformer
    transformer.cnt = 0
    transformer.num_steps = args.steps
    transformer.replace_steps = replace_steps
    transformer.previous_residual = None
    transformer.lora_dir = args.lora_dir

    image = pipeline(
        args.prompt,
        num_inference_steps=args.steps,
        generator=torch.Generator("cpu").manual_seed(args.seed),
    ).images[0]

    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"learnibridge_{_safe_filename(args.prompt)}.png"
    image.save(out_path)
    logger.info("Saved: %s", out_path)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
