from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, Optional

import cv2
import numpy as np


@dataclass(init=True, repr=True, eq=True)
class CompositeDomain:
    """Container for the per-pixel maps produced for one vessel edit.

    ``VesselCompositeDomainBuilder.build`` fills ``stack`` with ``mask``,
    ``masked_edges`` and ``boundary`` stacked along a new trailing axis.
    """

    mask: np.ndarray  # binarized edit region (0.0 / 1.0)
    masked_edges: np.ndarray  # edge map restricted to the edit region
    boundary: np.ndarray  # morphological outline of the edit region
    stack: np.ndarray  # the three maps above, stacked on the last axis


def _to_gray(image: np.ndarray) -> np.ndarray:
    if image.ndim == 2:
        return image
    if image.shape[-1] == 1:
        return image[..., 0]
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


def _normalize_mask(mask: np.ndarray) -> np.ndarray:
    if mask.dtype != np.float32:
        mask = mask.astype(np.float32)
    if mask.max() > 1.0:
        mask = mask / 255.0
    return (mask > 0.5).astype(np.float32)


class VesselCompositeDomainBuilder:
    """Build the mask / edge / boundary maps that condition a vessel edit.

    Args:
        edge_low: Lower hysteresis threshold passed to ``cv2.Canny``.
        edge_high: Upper hysteresis threshold passed to ``cv2.Canny``.
        boundary_kernel: Side length of the elliptical structuring element
            used to extract the mask outline.
    """

    def __init__(
        self,
        edge_low: int = 50,
        edge_high: int = 150,
        boundary_kernel: int = 3,
    ) -> None:
        self.edge_low = edge_low
        self.edge_high = edge_high
        self.boundary_kernel = boundary_kernel

    def compute_edges(self, source: np.ndarray) -> np.ndarray:
        """Return a binary float32 Canny edge map of ``source``.

        Input whose maximum is <= 1.0 is treated as [0, 1] intensities and
        scaled to [0, 255]. Values are clipped to [0, 255] before the uint8
        cast. Otherwise, out-of-range pixels (e.g. negatives from an image in
        [-1, 1], or 16-bit intensities) would wrap around or hit NumPy's
        undefined float->uint8 conversion and create spurious edges.
        """
        gray = _to_gray(source)
        if gray.max() <= 1.0:
            gray = gray * 255.0
        gray_u8 = np.clip(gray, 0, 255).astype(np.uint8)
        edges = cv2.Canny(gray_u8, self.edge_low, self.edge_high)
        return (edges > 0).astype(np.float32)

    def compute_boundary(self, mask: np.ndarray) -> np.ndarray:
        """Return the binary outline of a 0/1 ``mask`` as float32.

        Uses OpenCV's morphological gradient (dilation minus erosion) with an
        elliptical ``boundary_kernel`` x ``boundary_kernel`` element.
        """
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (self.boundary_kernel, self.boundary_kernel)
        )
        mask_u8 = (mask * 255.0).astype(np.uint8)
        boundary = cv2.morphologyEx(mask_u8, cv2.MORPH_GRADIENT, kernel)
        return (boundary > 0).astype(np.float32)

    def build(
        self,
        edited_mask: np.ndarray,
        edge_source: np.ndarray,
    ) -> CompositeDomain:
        """Assemble a ``CompositeDomain`` from an edit mask and an edge source.

        Args:
            edited_mask: Edit region, on a 0/1 or 0-255 scale; an (H, W, 1)
                mask is accepted and squeezed to (H, W).
            edge_source: Image whose Canny edges are kept inside the mask.

        Returns:
            The binarized mask, edges restricted to the mask, the mask outline,
            and those three maps stacked along a new last axis.

        Raises:
            ValueError: If the mask and the edge map differ in shape.
        """
        mask = _normalize_mask(edited_mask)
        if mask.ndim == 3 and mask.shape[-1] == 1:
            # Keep the mask 2-D so it multiplies element-wise with the edges
            # instead of broadcasting to a bogus 3-D shape.
            mask = mask[..., 0]
        edges = self.compute_edges(edge_source)
        if mask.shape != edges.shape:
            raise ValueError(
                f"edited_mask shape {mask.shape} does not match "
                f"edge map shape {edges.shape}"
            )
        masked_edges = edges * mask
        boundary = self.compute_boundary(mask)
        stack = np.stack([mask, masked_edges, boundary], axis=-1).astype(np.float32)
        return CompositeDomain(
            mask=mask,
            masked_edges=masked_edges,
            boundary=boundary,
            stack=stack,
        )


class GPGProjector:
    """Blend a proposed state into the previous one inside an edit region.

    Args:
        blend_radius: Side length of the elliptical structuring element used
            to widen ``boundary`` into a blending band. It is used as a
            kernel size, not a true radius.
    """

    def __init__(self, blend_radius: int = 7) -> None:
        self.blend_radius = blend_radius

    def apply(
        self,
        x_prev: np.ndarray,
        x_hat: np.ndarray,
        mask: np.ndarray,
        boundary: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Return ``x_prev * (1 - alpha) + x_hat * alpha``.

        ``alpha`` is the binarized ``mask``. When ``boundary`` is given, it is
        dilated by a ``blend_radius`` ellipse and added to the mask, and the
        sum is clipped to [0, 1], so the update also covers a band around
        the outline. Trailing axes are appended to ``alpha`` until it has
        ``x_prev.ndim`` dimensions, so it broadcasts over channel axes.

        ``boundary`` may be on a [0, 1] or a 0-255 scale. The same convention
        as for masks applies: a maximum above 1.0 means divide by 255. Values
        are clipped before the uint8 cast so a 0-255 input is not multiplied
        by 255 again and overflowed.
        """
        mask = _normalize_mask(mask)
        if boundary is None:
            alpha = mask
        else:
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE, (self.blend_radius, self.blend_radius)
            )
            band = np.asarray(boundary)
            if band.size and band.max() > 1.0:
                band = band / 255.0
            boundary_u8 = np.clip(band * 255.0, 0.0, 255.0).astype(np.uint8)
            soft = cv2.dilate(boundary_u8, kernel, iterations=1).astype(np.float32) / 255.0
            alpha = np.clip(mask + soft, 0.0, 1.0)

        while alpha.ndim < x_prev.ndim:
            alpha = alpha[..., None]
        return x_prev * (1.0 - alpha) + x_hat * alpha


# Signature of one bridge update:
#   (current state, composite domain, context arrays keyed by name)
#   -> proposed next state.
BridgeStepFn = Callable[
    [np.ndarray, CompositeDomain, Dict[str, np.ndarray]],
    np.ndarray,
]


def default_bridge_step(
    x_k: np.ndarray,
    composite: CompositeDomain,
    context: Dict[str, np.ndarray],
) -> np.ndarray:
    """Identity bridge step: return an independent copy of ``x_k``.

    ``composite`` and ``context`` are accepted only to match the
    ``BridgeStepFn`` signature; they are not used.
    """
    del composite, context  # unused by the identity step
    proposal = x_k.copy()
    return proposal


class VesselEditPipeline:
    """Run one bridge step of a mask-guided vessel edit and project it back.

    Args:
        bridge_step_fn: Produces the proposed state ``x_hat``; defaults to
            ``default_bridge_step``.
        composite_builder: Builds the mask / edge / boundary maps; defaults to
            a ``VesselCompositeDomainBuilder`` with its default settings.
        projector: Blends ``x_hat`` into the previous state; defaults to a
            ``GPGProjector`` with its default settings.

    Only ``None`` selects a default. Any other object, even one that happens
    to be falsy (e.g. defines ``__len__`` returning 0), is used as given.
    """

    def __init__(
        self,
        bridge_step_fn: Optional[BridgeStepFn] = None,
        composite_builder: Optional[VesselCompositeDomainBuilder] = None,
        projector: Optional[GPGProjector] = None,
    ) -> None:
        self.bridge_step_fn = (
            default_bridge_step if bridge_step_fn is None else bridge_step_fn
        )
        self.composite_builder = (
            VesselCompositeDomainBuilder()
            if composite_builder is None
            else composite_builder
        )
        self.projector = GPGProjector() if projector is None else projector

    def run(
        self,
        x0: np.ndarray,
        m0: np.ndarray,
        edited_mask: np.ndarray,
        edge_source: str = "seg",
    ) -> Dict[str, np.ndarray]:
        """Build the composite domain, take one bridge step, and project it.

        Args:
            x0: Current state; copied, never modified by this method.
            m0: Used as the edge source when ``edge_source == "seg"``.
            edited_mask: Edit region handed to the composite builder.
            edge_source: ``"seg"`` to take edges from ``m0``, ``"image"`` to
                take them from ``x0``.

        Returns:
            A dict with ``output`` (projected state), ``composite_stack``,
            ``mask``, ``masked_edges``, ``boundary`` and ``x_hat``.

        Raises:
            ValueError: If ``edge_source`` is not ``"seg"`` or ``"image"``.
        """
        if edge_source not in {"seg", "image"}:
            raise ValueError("edge_source must be 'seg' or 'image'")

        edge_input = m0 if edge_source == "seg" else x0
        composite = self.composite_builder.build(edited_mask=edited_mask, edge_source=edge_input)

        context = {
            "x0": x0,
            "m0": m0,
            "edited_mask": edited_mask,
        }
        # Snapshot taken before the step runs, so the step cannot alter it
        # through ``context["x0"]`` either.
        x_prev = x0.copy()
        # Give the step its own copy. A step that updates its input in place
        # must not also overwrite x_prev; otherwise the projection would blend
        # x_hat with itself and lose the region outside the mask.
        x_hat = self.bridge_step_fn(x_prev.copy(), composite, context)
        x_k1 = self.projector.apply(x_prev, x_hat, composite.mask, composite.boundary)

        return {
            "output": x_k1,
            "composite_stack": composite.stack,
            "mask": composite.mask,
            "masked_edges": composite.masked_edges,
            "boundary": composite.boundary,
            "x_hat": x_hat,
        }

