# -*- coding: utf-8 -*-
"""inkcoder.py

InkCoder = a deterministic, training-friendly proxy encoder for handwriting evidence.

Key ideas:
- Build two evidence proxies from the image: intensity-inkness and edge-inkness.
- Enforce local consistency (spatial density gating) to suppress isolated noise.
- Fuse evidence into a per-pixel scalar D in [0,1].
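- Produce time-gates gate[t] = sigmoid(alpha[t]*(D[t] - theta[t])) for Temporal Entry, where D[t] is either D itself or, with time-varying fusion, a per-step re-weighting of the intensity/edge proxies.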

"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _rgb_to_gray(x: torch.Tensor) -> torch.Tensor:
    """Collapse an image batch to a single luminance channel.

    A [B,1,H,W] input is returned unchanged (the same tensor object); a
    [B,3,H,W] input is reduced with BT.601 luma weights to [B,1,H,W].
    Values are expected (not enforced) to lie in [0,1].

    Raises:
        ValueError: if ``x`` is not 4-D, or its channel count is neither 1 nor 3.
    """
    if x.dim() != 4:
        raise ValueError(f"Expected 4D tensor [B,C,H,W], got {tuple(x.shape)}")
    channels = x.size(1)
    if channels not in (1, 3):
        raise ValueError(f"Expected C=1 or 3, got C={x.size(1)}")
    if channels == 1:
        return x
    # ITU-R BT.601 luma weights for the R, G, B planes (slices keep dim 1).
    red, green, blue = x[:, 0:1], x[:, 1:2], x[:, 2:3]
    return 0.299 * red + 0.587 * green + 0.114 * blue


def _same_pad_2d(x: torch.Tensor, k: int, mode: str = "reflect") -> torch.Tensor:
    """Pad H and W so a stride-1 window of size ``k`` preserves spatial size.

    ``k <= 1`` is a no-op and returns ``x`` itself. Even ``k`` is rejected,
    because symmetric 'same' padding needs an odd window.

    Raises:
        ValueError: if ``k > 1`` and ``k`` is even.
    """
    if k > 1:
        if not k % 2:
            raise ValueError("Kernel size must be odd for 'same' padding")
        half = k // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        x = F.pad(x, (half,) * 4, mode=mode)
    return x


def _box_blur(x: torch.Tensor, k: int) -> torch.Tensor:
    """Mean filter over a ``k x k`` window that keeps H and W unchanged.

    Borders are handled by reflect padding (via ``_same_pad_2d``). ``k <= 1``
    returns ``x`` untouched; an even ``k`` raises ValueError from the padding
    helper.
    """
    return x if k <= 1 else F.avg_pool2d(
        _same_pad_2d(x, k, mode="reflect"), kernel_size=k, stride=1
    )


def _sobel_mag(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Per-pixel Sobel gradient magnitude of a single-channel batch.

    Args:
        x: [B,1,H,W] tensor (float32/float16 expected).
        eps: added under the square root, so the output is strictly positive
            for ``eps > 0`` and the sqrt is never evaluated at exactly zero.

    Returns:
        [B,1,H,W] tensor ``sqrt(gx*gx + gy*gy + eps)``. Borders use reflect
        padding.

    Raises:
        ValueError: if ``x`` does not have exactly one channel.
    """
    if x.size(1) != 1:
        raise ValueError("_sobel_mag expects [B,1,H,W]")

    # Horizontal-derivative Sobel kernel; the vertical kernel is its transpose.
    sobel_x = torch.tensor(
        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=x.device, dtype=x.dtype
    )
    sobel_y = sobel_x.t().contiguous()

    padded = F.pad(x, (1, 1, 1, 1), mode="reflect")
    grad_x = F.conv2d(padded, sobel_x[None, None])
    grad_y = F.conv2d(padded, sobel_y[None, None])
    return torch.sqrt(grad_x * grad_x + grad_y * grad_y + eps)


def _quantile_per_batch(x: torch.Tensor, q: float, eps: float = 1e-6) -> torch.Tensor:
    """Per-sample quantile of a [B,1,H,W] map, returned as [B,1,1,1].

    The quantile is computed in float32 with ``torch.quantile`` (default linear
    interpolation), cast back to ``x.dtype``, and then offset by ``eps``. For
    non-negative maps, the offset keeps the result usable as a divisor.

    Args:
        x: [B,1,H,W] tensor. It does not need to be contiguous.
        q: quantile level in [0, 1].
        eps: constant added to every returned quantile.

    Returns:
        [B,1,1,1] tensor in ``x.dtype``.

    Raises:
        ValueError: if ``x`` is not shaped [B,1,H,W].
    """
    if x.dim() != 4 or x.size(1) != 1:
        raise ValueError("_quantile_per_batch expects [B,1,H,W]")
    batch = x.size(0)
    # reshape, not view: `.float()` on a float32 tensor returns it as-is, so
    # non-contiguous strides (e.g. a transposed input that reaches this point
    # unblurred) survive, and `.view` would raise on them.
    flat = x.float().reshape(batch, -1)
    # NOTE(review): some PyTorch versions reject very large inputs to
    # torch.quantile; consider subsampling if big batches of large images occur.
    qv = torch.quantile(flat, q, dim=1).view(batch, 1, 1, 1)
    return qv.to(dtype=x.dtype) + eps


@dataclass
class InkCoderAux:
    """Optional typed container for the auxiliary maps of InkCoder/TemporalCoderInk.

    Each field corresponds to a key of the aux dict returned by
    ``InkCoder.forward`` / ``TemporalCoderInk.forward`` (with
    ``return_aux=True``), so ``InkCoderAux(**aux)`` accepts those dicts.
    All fields default to None.
    """

    # Common proxies ([B,1,H,W] maps produced by InkCoder.forward)
    g0: Optional[torch.Tensor] = None
    g_int: Optional[torch.Tensor] = None
    I_int: Optional[torch.Tensor] = None
    g_edge: Optional[torch.Tensor] = None
    edge_mag: Optional[torch.Tensor] = None
    I_edge: Optional[torch.Tensor] = None
    dens_edge: Optional[torch.Tensor] = None
    keep_edge: Optional[torch.Tensor] = None

    # Fused evidence
    D: Optional[torch.Tensor] = None
    dens_D: Optional[torch.Tensor] = None
    keep_D: Optional[torch.Tensor] = None

    # Temporal
    theta_t: Optional[torch.Tensor] = None  # [T]
    alpha_t: Optional[torch.Tensor] = None  # [T]
    w_edge_t: Optional[torch.Tensor] = None  # [T]
    gate_t: Optional[torch.Tensor] = None  # [T,B,1,H,W]
    delta_gate_t: Optional[torch.Tensor] = None  # [T,B,1,H,W]

    # Per-step fused evidence ("D_t" key of TemporalCoderInk aux), [T,B,1,H,W].
    # Declared last so positional construction of the fields above is unchanged.
    D_t: Optional[torch.Tensor] = None


class InkCoder(nn.Module):
    """Turn an image batch into a fused per-pixel ink-evidence map D.

    Pipeline (deterministic; the module holds no learnable parameters):
      1. gray = BT.601 luma (or the single input channel), clamped to [0,1];
      2. intensity inkness I_int: per-sample quantile (q_low/q_high)
         normalization of a box-blurred gray image, inverted so dark pixels
         score high;
      3. edge inkness I_edge: Sobel magnitude (optionally max-combined with a
         magnitude computed on a downsampled copy), divided by a per-sample
         quantile, then optionally gated by I_int and by local edge density;
      4. D = clamp(I_int + I_edge, 0, 1), optionally multiplied by a
         density-based speckle-suppression gate.

    Used by TemporalCoderInk, but it also works standalone.
    """

    def __init__(
        self,
        int_blur_ks: int = 5,
        edge_blur_ks: int = 3,
        q_low: float = 0.02,
        q_high: float = 0.98,
        q_edge: float = 0.95,
        # Edge consistency (density gate)
        edge_consistency: bool = True,
        edge_cons_ks: int = 5,
        edge_cons_kappa: float = 10.0,
        edge_cons_tau: float = 0.10,
        # Intensity-guided edge gating (internal; no API change)
        edge_int_gate: bool = True,
        edge_int_tau: float = 0.12,
        edge_int_kappa: float = 12.0,
        # D speckle suppression
        D_speckle_suppress: bool = True,
        D_cons_ks: int = 7,
        D_cons_kappa: float = 12.0,
        D_cons_tau: float = 0.12,
        D_cons_power: float = 1.2,
        # Multiscale edge (optional)
        use_multiscale_edge: bool = True,
        edge_ms_down: int = 2,
        eps: float = 1e-6,
    ):
        super().__init__()

        # Blur window sizes and quantile levels for the two proxies.
        self.int_blur_ks, self.edge_blur_ks = int_blur_ks, edge_blur_ks
        self.q_low, self.q_high, self.q_edge = q_low, q_high, q_edge

        # Edge-density gate: sigmoid(kappa * (local_mean(I_edge) - tau)).
        self.edge_consistency = edge_consistency
        self.edge_cons_ks, self.edge_cons_kappa, self.edge_cons_tau = (
            edge_cons_ks,
            edge_cons_kappa,
            edge_cons_tau,
        )

        # Intensity-guided edge gate: sigmoid(kappa * (I_int - tau)).
        self.edge_int_gate = edge_int_gate
        self.edge_int_tau, self.edge_int_kappa = edge_int_tau, edge_int_kappa

        # Speckle gate on D (same density form, raised to D_cons_power).
        self.D_speckle_suppress = D_speckle_suppress
        self.D_cons_ks, self.D_cons_kappa, self.D_cons_tau, self.D_cons_power = (
            D_cons_ks,
            D_cons_kappa,
            D_cons_tau,
            D_cons_power,
        )

        # Extra edge scale computed on an avg-pooled copy of the image.
        self.use_multiscale_edge, self.edge_ms_down = use_multiscale_edge, edge_ms_down

        self.eps = eps

    # ---- building blocks (kept as methods so demos can call them) ----

    def _to_gray(self, x: torch.Tensor) -> torch.Tensor:
        return _rgb_to_gray(x)

    def _box_blur(self, x: torch.Tensor, which: str = "int") -> torch.Tensor:
        if which not in ("int", "edge"):
            raise ValueError(f"Unknown blur kind: {which}")
        window = self.int_blur_ks if which == "int" else self.edge_blur_ks
        return _box_blur(x, window)

    def _edge_mag(self, g: torch.Tensor) -> torch.Tensor:
        return _sobel_mag(g, eps=self.eps)

    def _density_gate(
        self, v: torch.Tensor, ks: int, kappa: float, tau: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (local ks x ks mean of v, soft keep-mask sigmoid(kappa*(mean - tau)))."""
        density = _box_blur(v, ks)
        return density, torch.sigmoid(kappa * (density - tau))

    def _intensity_inkness(self, g_int: torch.Tensor) -> torch.Tensor:
        """Quantile-normalize the blurred gray map to [0,1] and invert it (dark -> high)."""
        # NOTE(review): for a near-uniform image (q_high ~= q_low), the
        # normalized map clamps to 0 everywhere, so I_int -> 1 and the whole
        # image reads as ink. Confirm that callers never feed blank pages.
        lo = _quantile_per_batch(g_int, self.q_low, eps=self.eps)
        hi = _quantile_per_batch(g_int, self.q_high, eps=self.eps)
        normalized = ((g_int - lo) / (hi - lo + self.eps)).clamp(0.0, 1.0)
        return (1.0 - normalized).clamp(0.0, 1.0)

    def _edge_strength(self, g_edge: torch.Tensor) -> torch.Tensor:
        """Sobel magnitude, optionally max-combined with a coarser-scale magnitude."""
        magnitude = self._edge_mag(g_edge)
        down = self.edge_ms_down
        if self.use_multiscale_edge and down and down > 1:
            # The coarse scale is meant to pick up faint strokes; it is upsampled
            # back to full resolution before taking the pixel-wise max.
            coarse = self._edge_mag(F.avg_pool2d(g_edge, kernel_size=down, stride=down))
            coarse = F.interpolate(
                coarse, size=magnitude.shape[-2:], mode="bilinear", align_corners=False
            )
            magnitude = torch.maximum(magnitude, coarse)
        return magnitude

    # ---- proxies ----

    @torch.no_grad()
    def extract_proxies(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Convenience: return proxy maps (no grad) for visualization."""
        _, maps = self.forward(x, return_aux=True)
        return {name: t for name, t in maps.items() if isinstance(t, torch.Tensor)}

    def forward(self, x: torch.Tensor, return_aux: bool = False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the fused evidence map D.

        Args:
            x: [B,1,H,W] or [B,3,H,W] image batch (expected in [0,1]; the gray
                conversion is clamped to [0,1]).
            return_aux: when True, the second return value holds the
                intermediate maps (g0, g_int, I_int, g_edge, edge_mag, I_edge,
                dens_edge, keep_edge, D, dens_D, keep_D). Otherwise it is an
                empty dict.

        Returns:
            (D, aux), where D is [B,1,H,W] with values in [0,1].
        """
        gray = self._to_gray(x).clamp(0.0, 1.0)

        # Blurred copies: one for robust intensity quantiles, one (by default a
        # narrower window) for edge detection.
        g_int = self._box_blur(gray, which="int")
        g_edge = self._box_blur(gray, which="edge")

        I_int = self._intensity_inkness(g_int)
        edge_mag = self._edge_strength(g_edge)

        q_e = _quantile_per_batch(edge_mag, self.q_edge, eps=self.eps)
        I_edge = (edge_mag / (q_e + self.eps)).clamp(0.0, 1.0)

        # Edges only count where the pixel already looks ink-like.
        if self.edge_int_gate:
            I_edge = I_edge * torch.sigmoid(self.edge_int_kappa * (I_int - self.edge_int_tau))

        # Damp isolated edge responses that have low local edge density.
        if self.edge_consistency and self.edge_cons_ks > 1:
            dens_edge, keep_edge = self._density_gate(
                I_edge, self.edge_cons_ks, self.edge_cons_kappa, self.edge_cons_tau
            )
            I_edge = I_edge * keep_edge
        else:
            dens_edge, keep_edge = torch.zeros_like(I_edge), torch.ones_like(I_edge)

        # Both proxies are in [0,1]; the clamp keeps their sum in [0,1].
        D = (I_int + I_edge).clamp(0.0, 1.0)

        # Speckle suppression: the same density gate, sharpened by a power.
        if self.D_speckle_suppress and self.D_cons_ks > 1:
            dens_D, keep_D = self._density_gate(D, self.D_cons_ks, self.D_cons_kappa, self.D_cons_tau)
            keep_D = keep_D.pow(self.D_cons_power)
            D = D * keep_D
        else:
            dens_D, keep_D = torch.zeros_like(D), torch.ones_like(D)

        aux: Dict[str, torch.Tensor] = (
            dict(
                g0=gray,
                g_int=g_int,
                I_int=I_int,
                g_edge=g_edge,
                edge_mag=edge_mag,
                I_edge=I_edge,
                dens_edge=dens_edge,
                keep_edge=keep_edge,
                D=D,
                dens_D=dens_D,
                keep_D=keep_D,
            )
            if return_aux
            else {}
        )
        return D, aux


class TemporalCoderInk(nn.Module):
    """Temporal Entry module that produces T gate maps from a single image.

    Interface note:
    - forward(x) returns gate_t: [T,B,1,H,W]
    - forward(x, return_aux=True) returns (gate_t, aux_dict)

    gate_t[t] = sigmoid(alpha_t[t] * (D_t[t] - theta_t[t])). alpha_t is scaled
    and shifted by two learnable scalars (alpha_scale, alpha_bias). D_t[t] is
    either the InkCoder evidence D, or (with time-varying fusion) a per-step
    re-weighting of the intensity/edge proxies.

    Class name must stay the same for compatibility.
    """

    def __init__(
        self,
        T: int = 4,
        int_blur_ks: int = 5,
        edge_blur_ks: int = 3,
        q_low: float = 0.02,
        q_high: float = 0.98,
        q_edge: float = 0.95,
        # gate sharpness schedule
        base_alpha: float = 6.0,
        alpha_decay: float = 0.30,
        # evidence fusion schedule
        use_time_varying_fusion: bool = True,
        fuse_bias: float = -0.2,
        fuse_slope: float = 3.0,
        # IMPORTANT: time-varying fusion requires the intensity/edge proxies.
        # If True, we will internally compute proxies even when return_aux=False,
        # so the feature is truly enabled during training/inference.
        force_aux_for_fusion: bool = True,
        # theta schedule
        theta_min: float = 0.0,
        theta_max: float = 1.0,
        theta_gamma: float = 1.0,
        # speckle suppression on D
        D_speckle_suppress: bool = True,
        D_cons_ks: int = 7,
        D_cons_kappa: float = 12.0,
        D_cons_tau: float = 0.12,
        D_cons_power: float = 1.2,
        # edge consistency
        edge_consistency: bool = True,
        edge_cons_ks: int = 5,
        edge_cons_kappa: float = 10.0,
        edge_cons_tau: float = 0.10,
        # intensity-guided edge gate
        edge_int_gate: bool = True,
        edge_int_tau: float = 0.12,
        edge_int_kappa: float = 12.0,

        # multiscale edge
        use_multiscale_edge: bool = True,
        edge_ms_down: int = 2,
        eps: float = 1e-6,
    ):
        """Build the module.

        Raises:
            ValueError: if ``T < 1``. Previously such a module could be built,
                but every forward() failed inside torch.stack.
        """
        super().__init__()

        self.T = int(T)
        if self.T < 1:
            raise ValueError(f"T must be >= 1, got {T}")
        self.base_alpha = float(base_alpha)
        self.alpha_decay = float(alpha_decay)
        self.use_time_varying_fusion = bool(use_time_varying_fusion)
        self.fuse_bias = float(fuse_bias)
        self.fuse_slope = float(fuse_slope)
        self.force_aux_for_fusion = bool(force_aux_for_fusion)

        self.theta_min = float(theta_min)
        self.theta_max = float(theta_max)
        self.theta_gamma = float(theta_gamma)

        # Learnable scalars to adapt alpha schedule during training
        self.alpha_scale = nn.Parameter(torch.tensor(1.0))
        self.alpha_bias = nn.Parameter(torch.tensor(0.0))

        # Underlying proxy coder
        self.ink = InkCoder(
            int_blur_ks=int_blur_ks,
            edge_blur_ks=edge_blur_ks,
            q_low=q_low,
            q_high=q_high,
            q_edge=q_edge,
            edge_consistency=edge_consistency,
            edge_cons_ks=edge_cons_ks,
            edge_cons_kappa=edge_cons_kappa,
            edge_cons_tau=edge_cons_tau,
            edge_int_gate=edge_int_gate,
            edge_int_tau=edge_int_tau,
            edge_int_kappa=edge_int_kappa,
            D_speckle_suppress=D_speckle_suppress,
            D_cons_ks=D_cons_ks,
            D_cons_kappa=D_cons_kappa,
            D_cons_tau=D_cons_tau,
            D_cons_power=D_cons_power,
            use_multiscale_edge=use_multiscale_edge,
            edge_ms_down=edge_ms_down,
            eps=eps,
        )

    # ---- expose a few helpers for old demos (keeps compatibility) ----

    def _to_gray(self, x: torch.Tensor) -> torch.Tensor:
        return self.ink._to_gray(x)

    def _box_blur(self, x: torch.Tensor, which: str = "int") -> torch.Tensor:
        return self.ink._box_blur(x, which=which)

    def _edge_mag(self, g: torch.Tensor) -> torch.Tensor:
        return self.ink._edge_mag(g)

    # ---- schedules ----

    def _time_axis(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Normalized step positions: [0] for T == 1, else T evenly spaced values in [0, 1]."""
        if self.T <= 1:
            return torch.zeros(1, device=device, dtype=dtype)
        return torch.linspace(0, 1, self.T, device=device, dtype=dtype)

    def _theta_schedule(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # Power-law ramp from theta_min to theta_max. It is linear at the default
        # theta_gamma=1.0 and convex (slow start) for theta_gamma > 1.
        t = self._time_axis(device, dtype).pow(self.theta_gamma)
        return self.theta_min + (self.theta_max - self.theta_min) * t

    def _alpha_schedule(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # With alpha_decay > 0, alpha shrinks over steps so late gates are softer.
        # For T == 1 the time axis is [0], so the result is base_alpha itself.
        t = self._time_axis(device, dtype)
        a = self.base_alpha * (1.0 - self.alpha_decay * t)
        # Learnable affine adjustment; gradients flow through .to().
        return a * self.alpha_scale.to(device=device, dtype=dtype) + self.alpha_bias.to(device=device, dtype=dtype)

    def _w_edge_schedule(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        if not self.use_time_varying_fusion:
            return torch.full((self.T,), 0.5, device=device, dtype=dtype)
        # With fuse_slope > 0 (default 3.0), the edge weight grows over steps.
        t = self._time_axis(device, dtype)
        return torch.sigmoid(self.fuse_slope * (t + self.fuse_bias))

    # ---- fusion ----

    def _fused_evidence_stack(
        self, I_int: torch.Tensor, I_edge: torch.Tensor, w_edge_t: torch.Tensor
    ) -> torch.Tensor:
        """Re-fuse the [B,1,H,W] proxies at every step; returns [T,B,1,H,W].

        Step t uses (1 - w[t]) * I_int + w[t] * I_edge, clamped to [0,1]. The
        result then receives the same speckle gate that InkCoder applies to D.
        """
        w = w_edge_t.view((-1,) + (1,) * I_int.dim())
        stack = ((1.0 - w) * I_int + w * I_edge).clamp(0.0, 1.0)
        ink = self.ink
        if ink.D_speckle_suppress and ink.D_cons_ks > 1:
            # Fold T into the batch axis: the box blur is per image, so this
            # matches blurring each step separately.
            dens = _box_blur(stack.reshape(-1, *stack.shape[2:]), ink.D_cons_ks).view_as(stack)
            keep = torch.sigmoid(ink.D_cons_kappa * (dens - ink.D_cons_tau)).pow(ink.D_cons_power)
            stack = stack * keep
        return stack

    def forward(self, x: torch.Tensor, return_aux: bool = False):
        """Compute temporal gates.

        Args:
            x: [B,C,H,W] in [0,1]
            return_aux: if True, returns (gate_t, aux)

        Returns:
            gate_t: [T,B,1,H,W]
            aux (optional): dict with the InkCoder proxy maps plus theta_t,
                alpha_t, w_edge_t (float32, detached), D_t, gate_t and
                delta_gate_t (detached).
        """
        # Build proxies. If time-varying fusion is enabled, we *must* obtain
        # I_int/I_edge (and related maps) even when caller doesn't request aux.
        need_aux = bool(return_aux) or (self.use_time_varying_fusion and self.force_aux_for_fusion)
        D, aux0 = self.ink(x, return_aux=need_aux)

        device, dtype = D.device, D.dtype
        theta_t = self._theta_schedule(device, dtype)
        alpha_t = self._alpha_schedule(device, dtype)
        w_edge_t = self._w_edge_schedule(device, dtype)

        # Optional time-varying fusion: re-fuse intensity/edge at each t.
        # NOTE: if force_aux_for_fusion=True, this runs during training/inference
        # even when return_aux=False.
        if self.use_time_varying_fusion and ("I_int" in aux0) and ("I_edge" in aux0):
            D_stack = self._fused_evidence_stack(aux0["I_int"], aux0["I_edge"], w_edge_t)
        else:
            D_stack = D.unsqueeze(0).expand(self.T, *D.shape).contiguous()

        # Gates for all steps at once: broadcast the per-step scalars over
        # [T,B,1,H,W].
        per_step = (-1,) + (1,) * D.dim()
        gate_t = torch.sigmoid(alpha_t.view(per_step) * (D_stack - theta_t.view(per_step)))

        if not return_aux:
            return gate_t

        # delta gate (temporal novelty): step 0 itself, then step-to-step change.
        g = gate_t.detach()
        delta = torch.cat([g[:1], g[1:] - g[:-1]], dim=0)

        # Only expose proxies if caller requested aux.
        aux: Dict[str, torch.Tensor] = dict(aux0)
        aux.update(
            {
                "theta_t": theta_t.detach().float(),
                "alpha_t": alpha_t.detach().float(),
                "w_edge_t": w_edge_t.detach().float(),
                "D_t": D_stack.detach(),
                "gate_t": gate_t.detach(),
                "delta_gate_t": delta,
            }
        )
        return gate_t, aux
