from __future__ import annotations

import logging
from dataclasses import dataclass
import math
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from diffusers import StableDiffusion3Pipeline
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput
from torchvision import transforms
from torchvision.transforms import InterpolationMode


LOGGER = logging.getLogger(__name__)

# Module-wide defaults. DEFAULT_IMAGE_SIZE and DEFAULT_COMPUTE_DTYPE are the
# default `image_size` / `compute_dtype` arguments of NaviEditPipeline;
# DEFAULT_SEED is presumably the default RNG seed used elsewhere in the file.
DEFAULT_SEED: int = 42
DEFAULT_IMAGE_SIZE: int = 1024
DEFAULT_COMPUTE_DTYPE: torch.dtype = torch.float16

# --------------------------------------------------------------------------- #
# NaviSD3 editor (self-contained NaviEdit implementation)                   #
# --------------------------------------------------------------------------- #


class NaviSD3(nn.Module):
    """
    NaviEdit core algorithm for SD3/MM-DiT backbones with adaptive scale navigation (ASN)
    and optional auto-masking. Debug / logging paths are stripped for standalone use.

    Starting from the source latent ``x_src``, the editor runs ``t_edit`` iterations.
    Each iteration integrates the velocity difference ``v_tar - v_src`` predicted by
    the wrapped transformer. The integration follows a descending path of scheduler
    timesteps, indexed by a progress variable ``p`` in ``[0, 1]`` that the ASN policy
    advances adaptively. An optional spatial mask (user-supplied or derived
    automatically) gates where the update is applied.

    Args:
        unet: Transformer called as ``unet(hidden_states=..., timestep=...,
            encoder_hidden_states=..., pooled_projections=...,
            joint_attention_kwargs=None, return_dict=False)[0]``; must expose ``.device``.
        scheduler: Object exposing ``timesteps`` and optionally ``set_timesteps``,
            ``scale_model_input`` and ``config.num_train_timesteps``.
        model_cfg: Mapping of hyper-parameters read with ``.get`` (see ``__init__``).
        **kwargs: Ignored; accepted for interface compatibility.
    """

    def __init__(self, unet, scheduler, model_cfg, **kwargs):
        super().__init__()
        self.dit = unet
        self.scheduler = scheduler
        self.device = unet.device

        self.debug_print = bool(model_cfg.get("debug_print", False))

        self.num_inference_steps = int(model_cfg.get("n_steps", 28))
        # Number of independent noise draws whose velocity differences are averaged per step.
        self.n_avg = int(model_cfg.get("noise_samples", 1))

        self.src_guidance_scale = float(model_cfg.get("src_guidance_scale", 2.0))
        self.tar_guidance_scale = float(model_cfg.get("tar_guidance_scale", 6.0))

        # t_edit: number of edit iterations.
        # t_ref: the edit starts t_ref entries before the end of the timestep path.
        self.t_edit = int(model_cfg.get("t_edit", 28))
        self.t_ref = int(model_cfg.get("t_ref", 22))

        self.edit_dt = float(model_cfg.get("edit_dt", 1.0))
        self.use_equiv_gain = bool(model_cfg.get("use_equiv_gain", True))

        # Mask config
        self.mask_mode = str(model_cfg.get("mask_mode", "none"))
        self.mask_pow = float(model_cfg.get("mask_pow", 1.0))
        self.clamp_strength = float(model_cfg.get("clamp_strength", 0.0))
        self.mask_ema = float(model_cfg.get("mask_ema", 0.85))
        self.mask_quantile = float(model_cfg.get("mask_quantile", 0.90))
        self.mask_min_area = float(model_cfg.get("mask_min_area", 0.02))
        self.mask_max_area = float(model_cfg.get("mask_max_area", 0.65))
        self.mask_grow = bool(model_cfg.get("mask_grow", True))
        self.mask_grow_r_th = float(model_cfg.get("mask_grow_r_th", 0.35))
        self.mask_grow_patience = int(model_cfg.get("mask_grow_patience", 2))
        self.mask_grow_quantile = float(model_cfg.get("mask_grow_quantile", 0.97))
        self.mask_dilate_max_k = int(model_cfg.get("mask_dilate_max_k", 7))
        self.mask_blur_k = int(model_cfg.get("mask_blur_k", 5))

        # CFL safety
        self.use_cfl = bool(model_cfg.get("use_cfl", True))
        self.cfl_tau = float(model_cfg.get("cfl_tau", 8.0))

        # Auto-mask backend
        self._repr_hooks: list = []
        self._repr_enabled = False
        self._repr_block_ids = []
        self._repr_accum = None
        self._repr_count = 0
        self._mask_hw = None
        self._mask_batch_slices = None

        if self.mask_mode == "auto_repr":
            self._try_install_repr_hooks()

        # ASN constants
        self._asn_ema_beta = 0.90
        self._asn_tail_frac = 0.20
        self._asn_eps = 1e-8
        self._asn_p_clamp_eps = 1e-6

    # -------------------------
    # helpers
    # -------------------------
    def _unpack_embed(
        self, embed: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Split an embedding bundle into ``(embeds, pooled, neg_embeds, neg_pooled)``.

        Accepts a bare tensor, a 2-sequence ``(embeds, pooled)`` or a 4-sequence
        ``(embeds, pooled, neg_embeds, neg_pooled)``; missing entries are ``None``.

        Raises:
            ValueError: if a list/tuple of any other length is given.
        """
        if isinstance(embed, (list, tuple)):
            if len(embed) == 2:
                prompt_embeds, pooled_prompt_embeds = embed
                return prompt_embeds, pooled_prompt_embeds, None, None
            if len(embed) == 4:
                prompt_embeds, pooled_prompt_embeds, neg_prompt_embeds, neg_pooled_prompt_embeds = embed
                return prompt_embeds, pooled_prompt_embeds, neg_prompt_embeds, neg_pooled_prompt_embeds
            raise ValueError(
                f"Prompt embedding sequence must have 2 or 4 elements, got {len(embed)}."
            )
        return embed, None, None, None

    def _retrieve_timesteps(self, device):
        """Configure the scheduler for ``num_inference_steps`` (if supported) and return its timesteps."""
        if hasattr(self.scheduler, "set_timesteps"):
            self.scheduler.set_timesteps(self.num_inference_steps, device=device)
            return self.scheduler.timesteps
        return getattr(self.scheduler, "timesteps", [])

    def _normalized_timestep(self, t, timestep_scale: float, device):
        """Return ``t`` as float32 on ``device``, divided by ``timestep_scale`` when it exceeds 1."""
        if torch.is_tensor(t):
            t_tensor = t.to(device=device, dtype=torch.float32)
        else:
            t_tensor = torch.tensor(t, device=device, dtype=torch.float32)
        if timestep_scale <= 1.0:
            return t_tensor
        return t_tensor / float(timestep_scale)

    def _mix_with_noise(self, x_src, noise, t_scalar: torch.Tensor):
        """Linear interpolation ``(1 - t) * x_src + t * noise``."""
        return (1.0 - t_scalar) * x_src + t_scalar * noise

    def _calc_v_sd3(
        self,
        latent_model_input: torch.Tensor,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: Optional[torch.Tensor],
        t,
        do_cfg: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the transformer once and return ``(v_src, v_tar)``.

        The batch is laid out as ``[src, tar]`` without CFG, or
        ``[uncond_src, text_src, uncond_tar, text_tar]`` with CFG. In the CFG case
        each branch is combined with its own guidance scale.
        """
        model_dtype = getattr(self.dit, "dtype", latent_model_input.dtype)
        model_input = latent_model_input.to(dtype=model_dtype)
        if hasattr(self.scheduler, "scale_model_input"):
            model_input = self.scheduler.scale_model_input(model_input, t)

        if not torch.is_tensor(t):
            timestep = torch.tensor([t], device=model_input.device, dtype=torch.float32)
        else:
            timestep = t.to(device=model_input.device, dtype=torch.float32)
        if timestep.dim() == 0:
            timestep = timestep.reshape(1)
        timestep = timestep.expand(model_input.shape[0])

        model_output = self.dit(
            hidden_states=model_input,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            pooled_projections=pooled_prompt_embeds,
            joint_attention_kwargs=None,
            return_dict=False,
        )[0]

        if do_cfg:
            v_uncond_src, v_text_src, v_uncond_tar, v_text_tar = model_output.chunk(4, dim=0)
            v_src = v_uncond_src + self.src_guidance_scale * (v_text_src - v_uncond_src)
            v_tar = v_uncond_tar + self.tar_guidance_scale * (v_text_tar - v_uncond_tar)
            return v_src, v_tar

        v_src, v_tar = model_output.chunk(2, dim=0)
        return v_src, v_tar

    def _prepare_prompt_embeds(
        self,
        src_embed: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
        tar_embed: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
        reference_tensor: torch.Tensor,
        device,
    ) -> Tuple[bool, torch.Tensor, Optional[torch.Tensor]]:
        """Build the batched prompt / pooled embeddings matching ``_calc_v_sd3``'s layout.

        CFG is enabled when either guidance scale exceeds 1. Missing negative
        embeddings are replaced with zeros. Pooled embeddings are only batched when
        both source and target provide them; otherwise ``None`` is returned for them.

        Returns:
            ``(do_cfg, prompt_embeds, pooled_prompt_embeds_or_None)``.
        """
        src_prompt_embeds, src_pooled_prompt_embeds, src_neg_prompt_embeds, src_neg_pooled_prompt_embeds = (
            self._unpack_embed(src_embed)
        )
        tar_prompt_embeds, tar_pooled_prompt_embeds, tar_neg_prompt_embeds, tar_neg_pooled_prompt_embeds = (
            self._unpack_embed(tar_embed)
        )

        model_dtype = getattr(self.dit, "dtype", reference_tensor.dtype)
        src_prompt_embeds = src_prompt_embeds.to(device=device, dtype=model_dtype)
        tar_prompt_embeds = tar_prompt_embeds.to(device=device, dtype=model_dtype)
        if src_pooled_prompt_embeds is not None:
            src_pooled_prompt_embeds = src_pooled_prompt_embeds.to(device=device, dtype=model_dtype)
        if tar_pooled_prompt_embeds is not None:
            tar_pooled_prompt_embeds = tar_pooled_prompt_embeds.to(device=device, dtype=model_dtype)

        do_cfg = (self.src_guidance_scale > 1.0) or (self.tar_guidance_scale > 1.0)
        if do_cfg:
            if src_neg_prompt_embeds is None:
                src_neg_prompt_embeds = torch.zeros_like(src_prompt_embeds)
            else:
                src_neg_prompt_embeds = src_neg_prompt_embeds.to(device=device, dtype=model_dtype)
            if tar_neg_prompt_embeds is None:
                tar_neg_prompt_embeds = torch.zeros_like(tar_prompt_embeds)
            else:
                tar_neg_prompt_embeds = tar_neg_prompt_embeds.to(device=device, dtype=model_dtype)

            if src_pooled_prompt_embeds is not None:
                if src_neg_pooled_prompt_embeds is None:
                    src_neg_pooled_prompt_embeds = torch.zeros_like(src_pooled_prompt_embeds)
                else:
                    src_neg_pooled_prompt_embeds = src_neg_pooled_prompt_embeds.to(device=device, dtype=model_dtype)
            if tar_pooled_prompt_embeds is not None:
                if tar_neg_pooled_prompt_embeds is None:
                    tar_neg_pooled_prompt_embeds = torch.zeros_like(tar_pooled_prompt_embeds)
                else:
                    tar_neg_pooled_prompt_embeds = tar_neg_pooled_prompt_embeds.to(device=device, dtype=model_dtype)

            src_tar_prompt_embeds = torch.cat(
                [src_neg_prompt_embeds, src_prompt_embeds, tar_neg_prompt_embeds, tar_prompt_embeds],
                dim=0,
            )
            if src_pooled_prompt_embeds is not None and tar_pooled_prompt_embeds is not None:
                src_tar_pooled_prompt_embeds = torch.cat(
                    [src_neg_pooled_prompt_embeds, src_pooled_prompt_embeds, tar_neg_pooled_prompt_embeds, tar_pooled_prompt_embeds],
                    dim=0,
                )
            else:
                src_tar_pooled_prompt_embeds = None
        else:
            src_tar_prompt_embeds = torch.cat([src_prompt_embeds, tar_prompt_embeds], dim=0)
            if src_pooled_prompt_embeds is not None and tar_pooled_prompt_embeds is not None:
                src_tar_pooled_prompt_embeds = torch.cat([src_pooled_prompt_embeds, tar_pooled_prompt_embeds], dim=0)
            else:
                src_tar_pooled_prompt_embeds = None

        return do_cfg, src_tar_prompt_embeds, src_tar_pooled_prompt_embeds

    def _prepare_timesteps(self, device) -> Tuple[torch.Tensor, float]:
        """Return the scheduler timesteps (float32) and the scale that maps them into ``[0, 1]``.

        The scale is ``scheduler.config.num_train_timesteps`` when available, else the
        max timestep; it is forced to 1 when the timesteps are already ``<= 1``.

        Raises:
            ValueError: if the scheduler yields no timesteps.
        """
        timesteps = self._retrieve_timesteps(device=device)
        if torch.is_tensor(timesteps):
            timesteps_list = timesteps.to(device=device, dtype=torch.float32)
        else:
            timesteps_list = torch.tensor(list(timesteps), device=device, dtype=torch.float32)

        if timesteps_list.numel() == 0:
            raise ValueError("Scheduler timesteps are empty.")

        max_timestep = float(timesteps_list.max().item())
        scheduler_config = getattr(self.scheduler, "config", None)
        timestep_scale = float(getattr(scheduler_config, "num_train_timesteps", max_timestep))
        if max_timestep <= 1.0:
            timestep_scale = 1.0
        return timesteps_list, timestep_scale

    def _build_monotone_path(self, timesteps_list: torch.Tensor, timestep_scale: float, device):
        """Sort timesteps by descending normalized value; return ``(raw_path, normalized_path)``."""
        t_list = timesteps_list.to(device=device, dtype=torch.float32)
        u_list = self._normalized_timestep(t_list, timestep_scale, device=device)  # [N]
        order = torch.argsort(u_list, descending=True)
        t_path = t_list[order]
        u_path = u_list[order]
        return t_path, u_path

    def _interp_from_p(
        self,
        t_path: torch.Tensor,
        u_path: torch.Tensor,
        p: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Linearly interpolate both paths at fractional position ``p`` (clamped to ``[0, 1]``).

        ``p = 0`` maps to the first entry and ``p = 1`` to the last entry.
        """
        N = int(t_path.numel())
        if N <= 1:
            return t_path[0], u_path[0]

        p = p.clamp(0.0, 1.0)
        idx_f = p * float(N - 1)
        idx0 = torch.floor(idx_f).to(torch.long)
        idx1 = (idx0 + 1).clamp(max=N - 1)
        frac = (idx_f - idx0.to(torch.float32))

        t0 = t_path[idx0]
        t1 = t_path[idx1]
        u0 = u_path[idx0]
        u1 = u_path[idx1]

        t_val = t0 * (1.0 - frac) + t1 * frac
        u_val = u0 * (1.0 - frac) + u1 * frac
        return t_val, u_val

    # -------------------------
    # Mask utilities
    # -------------------------
    @staticmethod
    def _odd(k: int) -> int:
        """Round ``k`` up to an odd kernel size (minimum 1)."""
        k = int(k)
        if k <= 1:
            return 1
        return k if (k % 2 == 1) else (k + 1)

    def _blur2d(self, x: torch.Tensor, k: int) -> torch.Tensor:
        """Box blur with an odd ``k x k`` kernel; spatial size is preserved."""
        k = self._odd(k)
        if k <= 1:
            return x
        return F.avg_pool2d(x, kernel_size=k, stride=1, padding=k // 2)

    def _dilate2d(self, x: torch.Tensor, k: int) -> torch.Tensor:
        """Grey-scale dilation (max pool) with an odd ``k x k`` kernel; spatial size is preserved."""
        k = self._odd(k)
        if k <= 1:
            return x
        return F.max_pool2d(x, kernel_size=k, stride=1, padding=k // 2)

    def _normalize_map(self, m: torch.Tensor) -> torch.Tensor:
        """Min-max normalize each map over its last two dims to roughly ``[0, 1]``."""
        m = m.float()
        m_min = m.amin(dim=(2, 3), keepdim=True)
        m_max = m.amax(dim=(2, 3), keepdim=True)
        return (m - m_min) / (m_max - m_min + 1e-6)

    def _quantile_binarize(self, m: torch.Tensor, q: float) -> torch.Tensor:
        """Per sample, return 1 where ``m`` is at or above its ``q``-quantile, else 0."""
        B = m.shape[0]
        flat = m.reshape(B, -1)
        thr = torch.quantile(flat, q=q, dim=1, keepdim=True)
        thr = thr.view(B, 1, 1, 1)
        return (m >= thr).float()

    def _enforce_area_bounds(self, M: torch.Tensor) -> torch.Tensor:
        """Re-threshold ``M`` per sample, adjusting the quantile when its area is out of bounds.

        Samples whose mean exceeds ``mask_max_area`` use ``mask_quantile + 0.08``.
        Samples below ``mask_min_area`` use ``mask_quantile - 0.12``. The quantile is
        clamped to ``[0.5, 0.995]``. Each sample gets its own threshold, so this works
        for any batch size.
        """
        B = M.shape[0]
        flat = M.reshape(B, -1)
        area = flat.mean(dim=1)  # [B]

        q = self.mask_quantile
        q_hi = (q + 0.08)
        q_lo = (q - 0.12)

        q_use = torch.full((B,), q, device=M.device, dtype=flat.dtype)
        q_use = torch.where(area > self.mask_max_area, torch.full_like(q_use, q_hi), q_use)
        q_use = torch.where(area < self.mask_min_area, torch.full_like(q_use, q_lo), q_use)
        q_use = q_use.clamp(0.5, 0.995)

        # torch.quantile with a vector q returns one row per q *and* per sample
        # ([B, B]); evaluate each sample against its own quantile instead.
        thr = torch.stack([torch.quantile(row, q_b) for row, q_b in zip(flat, q_use)])
        return (M >= thr.view(B, 1, 1, 1)).float()

    # -------------------------
    # Auto-mask backend (repr hooks)
    # -------------------------
    def _try_install_repr_hooks(self):
        """Register forward hooks on the middle third of the transformer blocks.

        The hooks average 4-D block outputs over channels and pool them to the mask
        resolution. They add the conditioned source/target slices into
        ``_repr_accum``, and only do so while an accumulator is active (see
        ``_reset_repr_accum``). If no ``transformer_blocks``/``blocks`` attribute
        exists, nothing is installed.
        """
        if self._repr_enabled:
            return

        blocks = None
        if hasattr(self.dit, "transformer_blocks"):
            blocks = getattr(self.dit, "transformer_blocks")
        elif hasattr(self.dit, "blocks"):
            blocks = getattr(self.dit, "blocks")

        if blocks is None:
            self._repr_enabled = False
            return

        n_blocks = len(blocks)
        if n_blocks <= 0:
            self._repr_enabled = False
            return

        lo = max(0, n_blocks // 3)
        hi = max(lo + 1, (2 * n_blocks) // 3)
        self._repr_block_ids = list(range(lo, hi))

        def _hook_fn(module, inputs, output):
            # NOTE(review): only plain 4-D tensor outputs are used. Blocks returning
            # tuples or token sequences are ignored, in which case no heat is
            # accumulated and the ΔV fallback mask is used — confirm for the backbone.
            if not isinstance(output, torch.Tensor) or output.dim() < 4:
                return
            if self._repr_accum is None or self._mask_hw is None or self._mask_batch_slices is None:
                return
            src_slice, tar_slice = self._mask_batch_slices
            if output.shape[0] < tar_slice.stop:
                # Batch does not match the layout set up by _reset_repr_accum.
                return
            feat = output.float().mean(dim=1, keepdim=True)
            feat = F.adaptive_avg_pool2d(feat, self._mask_hw)
            # Only the conditioned source / target branches contribute.
            self._repr_accum += feat[src_slice]
            self._repr_accum += feat[tar_slice]
            self._repr_count += 2

        for idx in self._repr_block_ids:
            handle = blocks[idx].register_forward_hook(_hook_fn, with_kwargs=False)
            self._repr_hooks.append(handle)

        self._repr_enabled = True

    def _reset_repr_accum(self, H: int, W: int, B: int, do_cfg: bool):
        """Start a fresh ``[B, 1, H, W]`` accumulator and record which batch slices to read."""
        self._mask_hw = (H, W)
        self._repr_accum = torch.zeros((B, 1, H, W), device=self.device, dtype=torch.float32)
        self._repr_count = 0
        if do_cfg:
            self._mask_batch_slices = (slice(B, 2 * B), slice(3 * B, 4 * B))
        else:
            self._mask_batch_slices = (slice(0, B), slice(B, 2 * B))

    def _clear_repr_state(self) -> None:
        """Drop accumulator state so installed hooks become no-ops outside an edit call."""
        self._repr_accum = None
        self._repr_count = 0
        self._mask_hw = None
        self._mask_batch_slices = None

    def _finalize_repr_mask(self, prev_M: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Turn the accumulated heat map into a mask (EMA-blended with ``prev_M``); ``None`` if empty."""
        if (self._repr_accum is None) or (self._repr_count <= 0):
            return None
        heat = self._repr_accum / float(self._repr_count)
        heat = self._normalize_map(heat)
        heat = self._blur2d(heat, self.mask_blur_k)
        M = self._quantile_binarize(heat, self.mask_quantile)
        M = self._enforce_area_bounds(M)
        if prev_M is not None:
            M = self.mask_ema * prev_M + (1.0 - self.mask_ema) * M
            M = M.clamp(0.0, 1.0)
        return M

    # -------------------------
    # Fallback auto mask: ΔV-saliency
    # -------------------------
    def _dv_mask(self, dv: torch.Tensor, prev_M: Optional[torch.Tensor]) -> torch.Tensor:
        """Mask from the channel-mean squared velocity difference, EMA-blended with ``prev_M``."""
        sal = dv.float().pow(2).mean(dim=1, keepdim=True)
        sal = self._normalize_map(sal)
        sal = self._blur2d(sal, self.mask_blur_k)
        M = self._quantile_binarize(sal, self.mask_quantile)
        M = self._enforce_area_bounds(M)
        if prev_M is not None:
            M = self.mask_ema * prev_M + (1.0 - self.mask_ema) * M
            M = M.clamp(0.0, 1.0)
        return M

    # -------------------------
    # Active-set growth
    # -------------------------
    def _maybe_grow_mask(
        self,
        M: torch.Tensor,
        dv: torch.Tensor,
        grow_counter: int,
    ) -> Tuple[torch.Tensor, int]:
        """Grow ``M`` when too much update energy falls outside it for several steps.

        ``r`` is the ratio of the outside-mask ΔV norm to the total ΔV norm. After
        ``mask_grow_patience`` consecutive steps with ``r > mask_grow_r_th``, the
        highest-pressure outside pixels are dilated and OR-ed into the mask.

        Returns:
            ``(mask, updated_counter)``; the counter resets to 0 after growth.
        """
        if not self.mask_grow:
            return M, 0

        outside = (1.0 - M) * dv
        num = outside.float().pow(2).mean(dim=1).sum(dim=(1, 2)).sqrt()
        den = dv.float().pow(2).mean(dim=1).sum(dim=(1, 2)).sqrt().clamp_min(1e-8)
        r = (num / den).mean().item()

        if r > self.mask_grow_r_th:
            grow_counter += 1
        else:
            grow_counter = 0

        if grow_counter < self.mask_grow_patience:
            return M, grow_counter

        pressure_map = outside.float().pow(2).mean(dim=1, keepdim=True)
        pressure_map = self._normalize_map(pressure_map)

        B = pressure_map.shape[0]
        flat = pressure_map.view(B, -1)
        thr = torch.quantile(flat, q=self.mask_grow_quantile, dim=1).view(B, 1, 1, 1)
        add = (pressure_map >= thr).float()

        k = self._odd(max(1, self.mask_dilate_max_k // 2))
        add = self._dilate2d(add, k)

        M_new = torch.maximum(M, add)
        M_new = self._enforce_area_bounds(M_new)
        return M_new, 0

    # -------------------------
    # ASN policy: pick next p
    # -------------------------
    def _asn_pick_p_next(
        self,
        p: float,
        i: int,
        t_edit: int,
        dv_eff: torch.Tensor,
        prev_dv_eff: Optional[torch.Tensor],
        M: Optional[torch.Tensor],
        mag_ema: Optional[float],
    ) -> Tuple[float, float, float, float]:
        """Choose the next path position ``p_next >= p``.

        The base increment is the remaining distance spread evenly over the
        remaining steps. It is scaled by ``0.5 + speed``, where
        ``speed = sigmoid(-log r)`` and ``r`` grows with:
        - the normalized update magnitude (relative to its EMA),
        - oscillation (1 - cosine to the previous update),
        - the outside-mask pressure.
        During the final ``_asn_tail_frac`` of steps, ``speed`` is 0. The last step
        always lands on ``p = 1``.

        Returns:
            ``(p_next, updated_mag_ema, r, osc)``.
        """
        eps = self._asn_eps

        steps_left = max(1, t_edit - i)
        rem = max(0.0, 1.0 - p)
        dp_min = rem / float(steps_left)

        mag = (dv_eff.float().pow(2).mean(dim=1).mean(dim=(1, 2)).sqrt().mean()).item()

        if mag_ema is None:
            mag_ema_new = mag
        else:
            mag_ema_new = self._asn_ema_beta * mag_ema + (1.0 - self._asn_ema_beta) * mag

        mag_norm = mag / (mag_ema_new + eps)

        osc = 0.0
        if prev_dv_eff is not None:
            a = dv_eff.float().flatten(1)
            b = prev_dv_eff.float().flatten(1)
            a_n = a / (a.norm(dim=1, keepdim=True) + eps)
            b_n = b / (b.norm(dim=1, keepdim=True) + eps)
            cos = (a_n * b_n).sum(dim=1).mean().clamp(-1.0, 1.0).item()
            osc = float(1.0 - cos)

        pressure = 0.0
        if M is not None:
            outside = (1.0 - M) * dv_eff
            num = outside.float().pow(2).mean(dim=1).sum(dim=(1, 2)).sqrt().mean().item()
            den = dv_eff.float().pow(2).mean(dim=1).sum(dim=(1, 2)).sqrt().mean().item()
            pressure = float(num / (den + eps))

        r = float(mag_norm * (1.0 + osc) * (1.0 + pressure))
        speed = float(torch.sigmoid(torch.tensor(-math.log(r + eps))).item())

        tail_steps = max(2, int(round(self._asn_tail_frac * t_edit)))
        if i >= (t_edit - tail_steps):
            speed = 0.0

        mult = 0.5 + speed
        dp = dp_min * mult

        if i < (t_edit - 1):
            p_next = min(1.0 - self._asn_p_clamp_eps, p + dp)
        else:
            p_next = 1.0

        p_next = max(p, p_next)
        p_next = min(1.0, p_next)
        return p_next, mag_ema_new, r, osc

    # -------------------------
    # forward helpers
    # -------------------------
    def _prepare_user_mask(
        self, mask: Optional[torch.Tensor], B: int, H: int, W: int, device
    ) -> Optional[torch.Tensor]:
        """Normalize an optional user mask to ``[B, 1, H, W]`` float in ``[0, 1]`` on ``device``.

        Accepts ``[H, W]``, ``[B, H, W]`` or ``[B, 1, H, W]``. A batch of 1 is
        broadcast to ``B``. Other spatial sizes are resized with nearest
        interpolation.

        Raises:
            ValueError: if the mask batch size is neither 1 nor ``B``.
        """
        if mask is None:
            return None
        mask = mask.to(device=device)
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)
        if mask.shape[0] == 1 and B > 1:
            mask = mask.expand(B, *mask.shape[1:])
        if mask.shape[0] != B:
            raise ValueError("mask batch size mismatch.")
        if mask.shape[-2:] != (H, W):
            mask = F.interpolate(mask.float(), size=(H, W), mode="nearest")
        return mask.float().clamp(0.0, 1.0)

    def _single_draw_delta(
        self,
        x_src: torch.Tensor,
        zt_edit: torch.Tensor,
        u_i: torch.Tensor,
        u_norm_i: torch.Tensor,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: Optional[torch.Tensor],
        do_cfg: bool,
    ) -> torch.Tensor:
        """Draw one noise sample and return ``v_tar - v_src`` at timestep ``u_i``."""
        step_noise = torch.randn_like(x_src)
        zt_src = self._mix_with_noise(x_src, step_noise, u_norm_i)
        zt_tar = zt_edit + zt_src - x_src
        # zt_edit is kept in float32; give both halves the same dtype before cat.
        zt_src = zt_src.to(zt_tar.dtype)

        if do_cfg:
            latent_model_input = torch.cat([zt_src, zt_src, zt_tar, zt_tar], dim=0)
        else:
            latent_model_input = torch.cat([zt_src, zt_tar], dim=0)

        v_src, v_tar = self._calc_v_sd3(
            latent_model_input,
            prompt_embeds,
            pooled_prompt_embeds,
            u_i,
            do_cfg,
        )
        return v_tar - v_src

    def _estimate_velocity_delta(
        self,
        x_src: torch.Tensor,
        zt_edit: torch.Tensor,
        u_i: torch.Tensor,
        u_norm_i: torch.Tensor,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: Optional[torch.Tensor],
        do_cfg: bool,
        n_avg: int,
    ) -> torch.Tensor:
        """Average ``v_tar - v_src`` over ``n_avg`` independent noise draws.

        With ``n_avg == 1`` the single draw is returned unchanged. Otherwise the sum
        is taken in float32 and cast back to the per-draw dtype.
        """
        if n_avg == 1:
            return self._single_draw_delta(
                x_src, zt_edit, u_i, u_norm_i, prompt_embeds, pooled_prompt_embeds, do_cfg
            )
        total: Optional[torch.Tensor] = None
        dv_k = None
        for _ in range(n_avg):
            dv_k = self._single_draw_delta(
                x_src, zt_edit, u_i, u_norm_i, prompt_embeds, pooled_prompt_embeds, do_cfg
            )
            total = dv_k.float() if total is None else total + dv_k.float()
        return (total / float(n_avg)).to(dv_k.dtype)

    def _auto_mask_step(
        self, M: Optional[torch.Tensor], dv: torch.Tensor, u_norm_i: torch.Tensor
    ) -> torch.Tensor:
        """Compute this step's automatic mask and dilate it.

        The mask comes from representation heat when available, otherwise from ΔV
        saliency. The dilation kernel grows with the normalized timestep.
        """
        prev_M = M
        M_new = None

        if self.mask_mode == "auto_repr" and self._repr_enabled:
            M_new = self._finalize_repr_mask(prev_M)

        if M_new is None:
            M_new = self._dv_mask(dv, prev_M)

        un = float(u_norm_i.mean().item())
        k_dil = 1 + int(round((self.mask_dilate_max_k - 1) * un))
        k_dil = self._odd(max(1, min(k_dil, self.mask_dilate_max_k)))
        return self._dilate2d(M_new, k_dil).clamp(0.0, 1.0)

    # -------------------------
    # forward (quiet version)
    # -------------------------
    @torch.no_grad()
    def forward_noprint(
        self,
        x_src: torch.Tensor,
        src_embed,
        edit_embed,
        noise: Optional[torch.Tensor],
        mask: Optional[torch.Tensor] = None,
    ):
        """Run the NaviEdit loop and return ``(x0_pred, u_hat_list, energy_scale)``.

        Args:
            x_src: Source latent; assumed ``[B, C, H, W]`` (it is unpacked as such).
            src_embed: Source prompt embeddings (tensor, or a 2/4-tuple; see ``_unpack_embed``).
            edit_embed: Target prompt embeddings, same format as ``src_embed``.
            noise: Unused; kept for interface compatibility.
            mask: Optional user mask (``[H, W]``, ``[B, H, W]`` or ``[B, 1, H, W]``).

        Returns:
            - ``x0_pred``: the edited latent, in the dtype of the last effective
              update ``dv_eff``.
            - ``u_hat_list``: the per-step effective updates.
            - ``energy_scale``: a list of ``B`` zeros.

        Raises:
            ValueError: on empty timesteps, a mask batch mismatch, or
                ``edit_dt == 0`` with ``use_equiv_gain``.
        """
        device = x_src.device
        _ = noise  # unused; kept for interface compatibility

        do_cfg, src_tar_prompt_embeds, src_tar_pooled_prompt_embeds = self._prepare_prompt_embeds(
            src_embed, edit_embed, x_src, device
        )
        timesteps_list, timestep_scale = self._prepare_timesteps(device=device)

        total_steps = int(timesteps_list.numel())
        if total_steps <= 0:
            raise ValueError("Scheduler timesteps are empty.")

        n_avg = max(1, int(self.n_avg))
        t_edit = max(1, int(self.t_edit))

        t_path, u_path = self._build_monotone_path(timesteps_list, timestep_scale, device=device)
        N = int(t_path.numel())
        if N <= 0:
            raise ValueError("Empty timestep path.")

        if N <= 1:
            p = 0.0
        else:
            # Start t_ref entries before the end of the path; clamp so t_ref <= 0
            # cannot push p past 1.
            start_idx = min(N - 1, max(0, N - int(self.t_ref)))
            p = float(start_idx) / float(N - 1)

        ds = float(self.edit_dt)
        if self.use_equiv_gain and ds == 0.0:
            raise ValueError("edit_dt must be non-zero when use_equiv_gain is enabled.")

        B, _, H, W = x_src.shape
        M = self._prepare_user_mask(mask, B, H, W, device)

        if self.mask_mode == "auto_repr" and not self._repr_enabled:
            self._try_install_repr_hooks()

        # Accumulate the edit in float32 across all steps. It is cast once at the end
        # instead of being rounded back to the model dtype after every update.
        zt_edit = x_src.to(torch.float32)
        out_dtype = x_src.dtype

        grow_counter = 0
        u_hat_list = []

        mag_ema = None
        prev_dv_eff = None

        try:
            for i in range(t_edit):
                p_t = torch.tensor(p, device=device, dtype=torch.float32)
                u_i, u_norm_i = self._interp_from_p(t_path, u_path, p_t)

                if self.mask_mode == "auto_repr" and self._repr_enabled:
                    # One reset per step, so the heat map averages over all noise draws.
                    self._reset_repr_accum(H=H, W=W, B=B, do_cfg=do_cfg)

                dv = self._estimate_velocity_delta(
                    x_src,
                    zt_edit,
                    u_i,
                    u_norm_i,
                    src_tar_prompt_embeds,
                    src_tar_pooled_prompt_embeds,
                    do_cfg,
                    n_avg,
                )

                # "none" and "user" keep M fixed; every other mode derives it automatically.
                if self.mask_mode not in ("none", "user"):
                    M = self._auto_mask_step(M, dv, u_norm_i)

                if M is not None:
                    M, grow_counter = self._maybe_grow_mask(M, dv, grow_counter)

                dv_eff = dv
                if M is not None:
                    dv_eff = (M.clamp(0.0, 1.0) ** self.mask_pow) * dv

                p_next, mag_ema, _r, _osc = self._asn_pick_p_next(
                    p=p,
                    i=i,
                    t_edit=t_edit,
                    dv_eff=dv_eff,
                    prev_dv_eff=prev_dv_eff,
                    M=M,
                    mag_ema=mag_ema,
                )

                # NOTE(review): p = 1 maps to the smallest scheduler timestep (u_path[-1]),
                # not to 0, so the integration stops there — confirm this is intended.
                p_next_t = torch.tensor(p_next, device=device, dtype=torch.float32)
                _u_next, u_norm_next = self._interp_from_p(t_path, u_path, p_next_t)

                dt_u = (u_norm_next - u_norm_i)

                if self.use_equiv_gain:
                    gain = dt_u / ds
                else:
                    gain = torch.ones_like(dt_u)

                step_scale = 1.0
                if self.use_cfl:
                    # Shrink the step when |dt| * ||update|| would exceed cfl_tau.
                    upd_mag = (dv_eff.float().pow(2).mean(dim=1).sum(dim=(1, 2)).sqrt().mean()).item()
                    dt_mag = float(dt_u.abs().mean().item())
                    if (upd_mag > 1e-8) and (dt_mag > 0.0):
                        step_scale = min(1.0, self.cfl_tau / (dt_mag * upd_mag + 1e-6))

                zt_edit = zt_edit + (ds * gain * step_scale) * dv_eff.to(torch.float32)
                out_dtype = dv_eff.dtype

                if (M is not None) and (self.clamp_strength > 0.0):
                    # Pull outside-mask regions back toward the source latent.
                    alpha = float(self.clamp_strength)
                    zt_edit = zt_edit + alpha * (1.0 - M) * (x_src - zt_edit)

                u_hat_list.append(dv_eff)

                prev_dv_eff = dv_eff.detach()
                p = float(p_next)
        finally:
            # Make hooks inert between calls so unrelated transformer passes are not accumulated.
            self._clear_repr_state()

        x0_pred = zt_edit.to(out_dtype)
        energy_scale = [0.0] * x0_pred.shape[0]
        return x0_pred, u_hat_list, energy_scale


# --------------------------------------------------------------------------- #
# Pipeline output container                                                   #
# --------------------------------------------------------------------------- #


@dataclass
class NaviEditPipelineOutput(BaseOutput):
    """Result container returned by the NaviEdit pipeline.

    Fields (in positional order):
        images: the edited images, either as a list of PIL images or as a tensor.
        latents: the latent tensor associated with the edited images.
    """

    images: Union[List[Image.Image], torch.Tensor]
    latents: torch.Tensor


class _CenterSquareCropTransform:
    """Crop an image to a centred square whose side equals its shorter edge.

    Square inputs are returned as-is (the same object). Non-square inputs go
    through ``ImageOps.fit`` with centred placement and LANCZOS resampling.
    """

    def __call__(self, image: Image.Image) -> Image.Image:
        w, h = image.size
        if w != h:
            side = min(w, h)
            # Pillow >= 9.1 exposes the filters under Image.Resampling; older
            # releases only have the module-level constant.
            lanczos = getattr(Image, "Resampling", Image).LANCZOS
            return ImageOps.fit(
                image,
                (side, side),
                method=lanczos,
                centering=(0.5, 0.5),
            )
        return image

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"


# --------------------------------------------------------------------------- #
# NaviEdit Pipeline                                                           #
# --------------------------------------------------------------------------- #


class NaviEditPipeline(DiffusionPipeline):
    """Standalone NaviEdit pipeline for SD3/SD3.5 editing.

    Wraps an already-loaded ``StableDiffusion3Pipeline``. Its transformer,
    scheduler, VAE, text encoders and tokenizers are re-registered on this
    pipeline. Each call encodes an image to latents, runs :class:`NaviSD3`
    with a source/target prompt pair, and decodes the edited latents.
    """

    def __init__(
        self,
        sd_pipe: StableDiffusion3Pipeline,
        *,
        default_edit_config: Optional[Dict[str, Any]] = None,
        image_size: int = DEFAULT_IMAGE_SIZE,
        device: Optional[str | torch.device] = None,
        compute_dtype: torch.dtype = DEFAULT_COMPUTE_DTYPE,
        text_dtype: Optional[torch.dtype] = None,
        use_center_crop: bool = True,
        negative_prompt: str | Sequence[str] = "",
    ) -> None:
        """Wrap ``sd_pipe`` and move its modules to the target device/dtypes.

        Args:
            sd_pipe: Loaded SD3 pipeline supplying all sub-modules and
                ``encode_prompt``.
            default_edit_config: Base NaviSD3 config. A per-call
                ``edit_config`` is merged on top of it.
            image_size: Side length (pixels) inputs are resized to before
                encoding.
            device: Target device. Defaults to CUDA when available, else CPU.
            compute_dtype: Dtype for the transformer and VAE. Half precision
                falls back to float32 on CPU.
            text_dtype: Dtype for the text encoders. Defaults to the compute
                dtype and gets the same CPU fallback.
            use_center_crop: Center-crop non-square inputs before resizing.
            negative_prompt: Negative prompt(s) used for classifier-free
                guidance.
        """
        super().__init__()

        self._device = torch.device(
            device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self._compute_dtype = self._pick_compute_dtype(compute_dtype)
        # Explicit text dtypes get the same CPU half-precision fallback; otherwise
        # fp16 embeddings would be fed into an fp32 transformer on CPU.
        self._text_dtype = (
            self._pick_compute_dtype(text_dtype) if text_dtype is not None else self._compute_dtype
        )
        self._use_center_crop = bool(use_center_crop)
        self.image_size = int(image_size)
        # Copy so later mutation of the caller's dict cannot change our defaults.
        self.default_edit_config = dict(default_edit_config or {})
        self._negative_prompt = negative_prompt
        self._sd_pipe = sd_pipe

        # Keep module references consistent with the original SD3 pipeline.
        self.register_modules(
            transformer=sd_pipe.transformer,
            scheduler=sd_pipe.scheduler,
            vae=sd_pipe.vae,
            text_encoder=sd_pipe.text_encoder,
            text_encoder_2=getattr(sd_pipe, "text_encoder_2", None),
            text_encoder_3=getattr(sd_pipe, "text_encoder_3", None),
            tokenizer=sd_pipe.tokenizer,
            tokenizer_2=getattr(sd_pipe, "tokenizer_2", None),
            tokenizer_3=getattr(sd_pipe, "tokenizer_3", None),
        )

        self._set_module_precisions()
        self.transformer.eval()
        self.vae.eval()
        for enc in (self.text_encoder, self.text_encoder_2, self.text_encoder_3):
            if enc is not None:
                enc.eval()

        self._vae_transform = self._build_vae_transform()
        # NaviSD3 editor is built lazily and rebuilt only when the config changes.
        self._navi_model: Optional[NaviSD3] = None
        self._last_model_cfg: Optional[Dict[str, Any]] = None

    # ------------------------------------------------------------------ #
    # Construction helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def from_pretrained_sd3(
        cls,
        model_path: str,
        *,
        default_edit_config: Optional[Dict[str, Any]] = None,
        device: Optional[str | torch.device] = None,
        torch_dtype: torch.dtype = DEFAULT_COMPUTE_DTYPE,
        text_dtype: Optional[torch.dtype] = None,
        image_size: int = DEFAULT_IMAGE_SIZE,
        use_center_crop: bool = True,
        negative_prompt: str | Sequence[str] = "",
        local_files_only: bool = True,
    ) -> "NaviEditPipeline":
        """Instantiate the pipeline directly from an SD3/SD3.5 directory.

        ``model_path`` is expanded (``~``) and resolved to an absolute path.
        It is then loaded with ``StableDiffusion3Pipeline.from_pretrained``,
        using ``torch_dtype`` and ``local_files_only``. The progress bar is
        disabled, and the remaining keyword arguments are forwarded to the
        constructor.
        """
        model_path = str(Path(model_path).expanduser().resolve())
        pipe = StableDiffusion3Pipeline.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
        )
        pipe.set_progress_bar_config(disable=True)
        return cls(
            sd_pipe=pipe,
            default_edit_config=default_edit_config,
            image_size=image_size,
            device=device,
            compute_dtype=torch_dtype,
            text_dtype=text_dtype,
            use_center_crop=use_center_crop,
            negative_prompt=negative_prompt,
        )

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image | torch.Tensor,
        *,
        source_prompt: str,
        target_prompt: str,
        edit_config: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
        mask: Optional[torch.Tensor] = None,
        output_type: str = "pil",
    ) -> NaviEditPipelineOutput:
        """Run NaviEdit once on a single image.

        Args:
            image: PIL image, or a tensor in ``[0, 1]`` / ``[0, 255]``
                (see ``_prepare_image_tensor``).
            source_prompt: Prompt describing the input image.
            target_prompt: Prompt describing the desired edit.
            edit_config: Per-call overrides merged onto ``default_edit_config``.
            seed: Seed for the global torch RNGs (CPU and, if present, all
                CUDA devices). Defaults to ``DEFAULT_SEED``.
            mask: Optional mask forwarded unchanged to the NaviSD3 editor.
            output_type: ``"pil"`` returns a list of PIL images. Any other
                value returns the decoded tensor in ``[0, 1]``.

        Returns:
            ``NaviEditPipelineOutput`` with the decoded images and the edited
            latents (detached, on CPU).
        """
        cfg = self._prepare_edit_config(edit_config)
        self._ensure_navi_model(cfg)

        # NOTE: seeds the process-global RNGs, not a local generator.
        seed_value = int(seed) if seed is not None else DEFAULT_SEED
        torch.manual_seed(seed_value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed_value)

        pixel_values = self._prepare_image_tensor(image)
        latents = self._encode_image_to_latent(pixel_values)

        src_embed = self.encode_prompt([source_prompt])
        tgt_embed = self.encode_prompt([target_prompt])

        x0_pred, _, _ = self._navi_model.forward_noprint(  # type: ignore[arg-type]
            x_src=latents,
            src_embed=src_embed,
            edit_embed=tgt_embed,
            noise=None,
            mask=mask,
        )

        decoded = self._decode_latent_to_image(x0_pred)
        images = self._tensor_to_pil(decoded) if output_type == "pil" else decoded

        return NaviEditPipelineOutput(
            images=images,
            latents=x0_pred.detach().cpu(),
        )

    def encode_prompt(self, prompts: Sequence[str]) -> Any:
        """Encode text prompts using the underlying SD3 text encoders.

        Returns ``(prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds,
        negative_pooled_prompt_embeds)``. A bare string is treated as one
        prompt.
        """
        return self._encode_prompts(prompts)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _pick_compute_dtype(self, dtype: torch.dtype) -> torch.dtype:
        """Return ``dtype``, or float32 if half precision was requested on CPU."""
        if self._device.type == "cpu" and dtype in {torch.float16, torch.bfloat16}:
            LOGGER.warning("Requested %s on CPU; falling back to float32 for stability.", dtype)
            return torch.float32
        return dtype

    def _set_module_precisions(self) -> None:
        """Move transformer/VAE to the compute dtype and text encoders to the text dtype."""
        model_dtype = self._compute_dtype
        for module in (self.transformer, self.vae):
            if module is not None:
                module.to(device=self._device, dtype=model_dtype)

        text_dtype = self._text_dtype or model_dtype
        for enc in (self.text_encoder, self.text_encoder_2, self.text_encoder_3):
            if enc is not None:
                enc.to(device=self._device, dtype=text_dtype)

    def _prepare_edit_config(self, edit_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a fresh dict: defaults overlaid with ``edit_config`` (if any)."""
        cfg = dict(self.default_edit_config)
        if edit_config:
            cfg.update(edit_config)
        return cfg

    def _ensure_navi_model(self, cfg: Dict[str, Any]) -> None:
        """(Re)build the NaviSD3 editor when missing or when ``cfg`` changed."""
        if self._navi_model is None or self._last_model_cfg != cfg:
            self._navi_model = NaviSD3(
                unet=self.transformer,
                scheduler=self.scheduler,
                model_cfg=cfg,
            )
            self._navi_model.eval()
            self._last_model_cfg = dict(cfg)

    def _prepare_image_tensor(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
        """Convert ``image`` to a 4-D tensor in ``[-1, 1]`` on the compute device.

        PIL images go through ``self._vae_transform`` (optional center crop,
        resize, normalization). Tensors are expected to be ``(C, H, W)`` or
        ``(B, C, H, W)``. Values above 1 are treated as 8-bit ``[0, 255]``
        data, otherwise as ``[0, 1]``. Tensors are center-cropped (if
        enabled) and resized to ``image_size`` so both input kinds yield the
        same spatial size.

        Raises:
            TypeError: ``image`` is neither a PIL image nor a tensor.
            ValueError: the tensor is not 3-D or 4-D.
        """
        if isinstance(image, Image.Image):
            vae_tensor = self._vae_transform(image.convert("RGB"))
        elif torch.is_tensor(image):
            tensor = image.float()
            # Heuristic: any value above 1 means the tensor holds 8-bit [0, 255] data.
            if tensor.max() > 1.0:
                tensor = tensor / 255.0
            vae_tensor = tensor * 2.0 - 1.0
        else:
            raise TypeError("image must be a PIL.Image or a torch.Tensor.")

        if vae_tensor.ndim == 3:
            vae_tensor = vae_tensor.unsqueeze(0)
        if vae_tensor.ndim != 4:
            raise ValueError(
                f"image tensor must have 3 (C, H, W) or 4 (B, C, H, W) dims, got {vae_tensor.ndim}."
            )

        height, width = vae_tensor.shape[-2:]
        if self._use_center_crop and height != width:
            side = min(height, width)
            top = (height - side) // 2
            left = (width - side) // 2
            vae_tensor = vae_tensor[:, :, top : top + side, left : left + side]

        # Match the PIL path: always feed image_size x image_size pixels so the
        # latent grid has the size the VAE/transformer patching expects.
        target = (self.image_size, self.image_size)
        if tuple(vae_tensor.shape[-2:]) != target:
            vae_tensor = F.interpolate(
                vae_tensor, size=target, mode="bilinear", align_corners=False, antialias=True
            )

        return vae_tensor.to(device=self._device, dtype=self._compute_dtype)

    def _vae_latent_norm(self) -> Tuple[float, float]:
        """Return the VAE's ``(scaling_factor, shift_factor)`` latent normalization."""
        config = self.vae.config
        scaling_factor = getattr(config, "scaling_factor", 1.0)
        # SD3/SD3.5 VAEs define a shift_factor; other AutoencoderKL configs may
        # leave it unset (None), which is equivalent to no shift.
        shift_factor = getattr(config, "shift_factor", None) or 0.0
        return scaling_factor, shift_factor

    def _vae_autocast(self, dtype: torch.dtype) -> torch.autocast:
        """Autocast context for VAE calls, enabled only for half precision on CUDA/MPS."""
        device_type = self._device.type
        enabled = device_type in {"cuda", "mps"} and dtype in {torch.float16, torch.bfloat16}
        return torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled)

    def _encode_image_to_latent(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``[-1, 1]`` pixels into normalized transformer-space latents.

        Uses the posterior mode (deterministic). It then applies
        ``(z - shift_factor) * scaling_factor``, the normalization diffusers'
        SD3 pipelines use before the transformer.
        """
        scaling_factor, shift_factor = self._vae_latent_norm()
        dtype = self.vae.dtype
        pixel_values = pixel_values.to(device=self._device, dtype=dtype)
        with self._vae_autocast(dtype):
            latents = self.vae.encode(pixel_values).latent_dist.mode()
        latents = (latents - shift_factor) * scaling_factor
        return latents.to(device=self._device, dtype=self._compute_dtype)

    def _decode_latent_to_image(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalized latents to an image tensor in ``[0, 1]``.

        Inverts ``_encode_image_to_latent``'s normalization
        (``z / scaling_factor + shift_factor``) before calling the VAE decoder.
        """
        scaling_factor, shift_factor = self._vae_latent_norm()
        dtype = self.vae.dtype
        latents = latents.to(device=self._device, dtype=dtype)
        with self._vae_autocast(dtype):
            decoded = self.vae.decode(latents / scaling_factor + shift_factor).sample
        decoded = decoded.clamp(-1.0, 1.0)
        decoded = (decoded + 1.0) / 2.0
        return decoded.to(dtype=self._compute_dtype)

    def _encode_prompts(self, prompts: Sequence[str]) -> Any:
        """Run SD3 ``encode_prompt`` with CFG enabled and the configured negative prompt.

        Returns ``(prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds,
        negative_pooled_prompt_embeds)``. This order differs from the order
        ``encode_prompt`` itself returns.
        """
        # A bare str is a Sequence[str] too; list("abc") would split it into characters.
        prompt_list = [prompts] if isinstance(prompts, str) else list(prompts)
        negative = self._negative_prompt
        if isinstance(negative, str):
            negative = [negative] * len(prompt_list)
        else:
            # SD3's encode_prompt rejects a negative prompt whose type differs
            # from the prompt's (e.g. tuple vs list), so normalize to a list.
            negative = list(negative)
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
            self._sd_pipe.encode_prompt(
                prompt=prompt_list,
                prompt_2=None,
                prompt_3=None,
                negative_prompt=negative,
                do_classifier_free_guidance=True,
                device=self._device,
                num_images_per_prompt=1,
            )
        )
        return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds

    def _tensor_to_pil(self, tensor: torch.Tensor) -> List[Image.Image]:
        """Convert a batch of ``[0, 1]`` image tensors into PIL images (one per sample)."""
        tensor = tensor.detach().cpu().clamp(0.0, 1.0)
        to_pil = transforms.ToPILImage()
        return [to_pil(sample) for sample in tensor]

    def _build_vae_transform(self) -> transforms.Compose:
        """Create the PIL -> ``[-1, 1]`` tensor preprocessing transform.

        Optionally center-crops to a square (then resizes with LANCZOS).
        Otherwise it resizes directly with BILINEAR. Images are always resized
        to ``image_size`` x ``image_size``.
        """
        ops: List[Any] = []
        if self._use_center_crop:
            ops.append(_CenterSquareCropTransform())
            resize_interp = InterpolationMode.LANCZOS
        else:
            resize_interp = InterpolationMode.BILINEAR
        ops.append(
            transforms.Resize(
                (self.image_size, self.image_size),
                interpolation=resize_interp,
            )
        )
        ops.extend(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        return transforms.Compose(ops)
