# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for OT-Bridge. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Tuple

import cv2
import numpy as np

from ot_bridge.pipeline import VesselEditPipeline


def _read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Load an image from disk with OpenCV.

    Args:
        path: Location of the image file.
        grayscale: When True, decode as a single-channel grayscale image;
            otherwise decode as a 3-channel colour image in RGB order.

    Returns:
        The decoded image array as produced by ``cv2.imread`` (channel order
        converted to RGB for colour images).

    Raises:
        FileNotFoundError: If OpenCV cannot read/decode the file.
    """
    read_mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), read_mode)
    if image is None:
        raise FileNotFoundError(f"Cannot read image: {path}")
    if grayscale:
        return image
    # OpenCV decodes colour images as BGR; hand RGB to the rest of the script.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _to_float01(img: np.ndarray) -> np.ndarray:
    """Convert an image to ``float32`` with values scaled into [0, 1].

    Inputs whose maximum exceeds 1.0 are treated as 8-bit-range data and
    divided by 255. Inputs already within [0, 1] (e.g. a 0/1 mask) are only
    cast. As a consequence, an 8-bit image whose pixels are all 0 or 1 is
    *not* rescaled.

    Args:
        img: Image array of any numeric or boolean dtype.

    Returns:
        A ``float32`` array of the same shape. When the input is already
        ``float32`` and not rescaled, the input object itself is returned.
        Zero-size inputs are returned cast to ``float32`` without scaling.
    """
    if img.dtype != np.float32:
        img = img.astype(np.float32)
    # ``ndarray.max()`` raises on a zero-size array, so skip the range check there.
    if img.size and img.max() > 1.0:
        img = img / 255.0
    return img


def _save_image(path: Path, img: np.ndarray) -> None:
    """Write a [0, 1]-valued image to ``path`` as 8-bit, creating parent dirs.

    Args:
        path: Destination file; its extension selects the encoder used by
            ``cv2.imwrite``.
        img: Image with values nominally in [0, 1] (out-of-range values are
            clipped). An array with a trailing axis of size 3 is treated as
            RGB and written in OpenCV's BGR order.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Quantize before the colour conversion: cv2.cvtColor rejects float64
    # input, while uint8 is always accepted. Round to nearest rather than
    # truncate so a value that was read as k/255 is written back as k.
    img_u8 = np.clip(np.rint(np.asarray(img) * 255.0), 0, 255).astype(np.uint8)
    if img_u8.ndim == 3 and img_u8.shape[-1] == 3:
        img_u8 = cv2.cvtColor(img_u8, cv2.COLOR_RGB2BGR)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(str(path), img_u8):
        raise OSError(f"Cannot write image: {path}")


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the vessel-edit demo.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            callers are unaffected.

    Returns:
        Namespace with ``input_cag``, ``input_seg``, ``edited_mask`` and
        ``out_dir`` as ``Path`` objects, ``edge_source`` (``"seg"`` or
        ``"image"``) and the boolean ``save_intermediates``.
    """
    p = argparse.ArgumentParser(description="OT-Bridge vessel edit pipeline demo")
    p.add_argument("--input-cag", type=Path, required=True, help="Input CAG image x0")
    p.add_argument("--input-seg", type=Path, required=True, help="Input segmentation mask m0")
    p.add_argument("--edited-mask", type=Path, required=True, help="Edited mask m")
    p.add_argument("--out-dir", type=Path, required=True, help="Output directory")
    p.add_argument(
        "--edge-source",
        choices=["seg", "image"],
        default="seg",
        help="Edge source passed to the pipeline (default: %(default)s)",
    )
    p.add_argument("--save-intermediates", action="store_true", help="Save composite domain images")
    return p.parse_args(argv)


def main() -> None:
    """Run the vessel-edit pipeline end to end from command-line arguments.

    Reads the CAG image (colour) and both masks (grayscale), normalizes each
    with ``_to_float01``, runs ``VesselEditPipeline``, and writes
    ``output.png`` into ``--out-dir``. With ``--save-intermediates`` it also
    writes ``mask.png``, ``masked_edges.png``, ``boundary.png`` and
    ``composite_stack.npy`` to the same directory.
    """
    args = parse_args()

    # Inputs: the CAG frame in colour, the original and edited masks in grayscale.
    cag = _to_float01(_read_image(args.input_cag, grayscale=False))
    seg_mask = _to_float01(_read_image(args.input_seg, grayscale=True))
    target_mask = _to_float01(_read_image(args.edited_mask, grayscale=True))

    results = VesselEditPipeline().run(
        x0=cag, m0=seg_mask, edited_mask=target_mask, edge_source=args.edge_source
    )

    destination: Path = args.out_dir
    _save_image(destination / "output.png", results["output"])
    if not args.save_intermediates:
        return

    for key in ("mask", "masked_edges", "boundary"):
        _save_image(destination / f"{key}.png", results[key])
    np.save(destination / "composite_stack.npy", results["composite_stack"])


if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not when imported.
    main()

