"""
Ultralytics-backed API (detect/segment) aligned with torchapi naming.
"""
__author__ = 'XYZ'

import os
import time
from typing import Any, Dict, List, Optional

import cv2
import numpy as np

try:
  import torch
except ImportError:
  raise ImportError("torch is required")

try:
  from ultralytics import YOLO
except ImportError:
  raise ImportError("ultralytics is required: pip install ultralytics")


from .annonutils import (
  save_yolo_txt,
  save_via_json,
  bbox_xyxy_to_via_region,
  mask_to_via_region,
  masked_rgb,
  jaccard_target_dilate,
  visualize,
  extract_segmented_rgb,
  crop_bbox,
)


def _ensure_dir(p: str) -> None:
  os.makedirs(p, exist_ok=True)


def select_device(device: str = "cpu"):
  """
  Resolve a requested device name to one that can actually be used.

  Returns "cuda" only when `device` equals "cuda" (case-insensitive) and
  torch reports CUDA as available. Every other request resolves to "cpu";
  a warning is printed when CUDA was asked for but is not available.
  """
  wants_cuda = device.lower() == "cuda"
  if not wants_cuda:
    return "cpu"
  if not torch.cuda.is_available():
    print("[warn] CUDA requested but not available, falling back to CPU")
    return "cpu"
  return "cuda"


def load_model(weights_path: str = "yolov8s.pt",
               device: str = "cpu",
               conf_thres: Optional[float] = None,
               iou_thres: Optional[float] = None):
  """
  Load an Ultralytics YOLO model and record optional threshold overrides.

  Args:
    weights_path: value handed to `YOLO(...)`.
    device: requested device; resolved through `select_device`.
    conf_thres: when given, stored as `model.overrides['conf']`.
    iou_thres: when given, stored as `model.overrides['iou']`.

  Returns:
    A 3-tuple `(model, device, model_info)`. `device` is the resolved device
    string and `model_info` is a dict with keys "names", "nc" and "type",
    read from the model with fallbacks when an attribute is missing.
  """
  t0 = time.time()
  device = select_device(device)
  print(f"[load_model] loading: {weights_path} on {device}")

  model = YOLO(weights_path)
  model.overrides = model.overrides or {}
  for key, value in (("conf", conf_thres), ("iou", iou_thres)):
    if value is not None:
      model.overrides[key] = float(value)

  elapsed = time.time() - t0
  print(f"[load_model] ready in {elapsed:.2f}s")

  names = getattr(model, "names", {})
  model_info = {
    "names": names,
    # falls back to the number of class names when the model has no `nc`
    "nc": getattr(model, "nc", len(names)),
    "type": getattr(model, "task", "detect"),
  }
  return model, device, model_info


def _to_numpy(x):
  return x.detach().cpu().numpy() if hasattr(x, 'detach') else np.asarray(x)


def predict(model, img: np.ndarray, device="cpu", imgsz=640,
            conf=0.25, iou=0.45, classes_filter=None):
  """
  Run `model.predict` on one image and unpack the results into NumPy arrays.

  Returns an empty list when `img` is None. Otherwise returns one dict per
  result object with the keys:
    - "boxes":   `boxes.xyxy` as an array, or an empty (0, 4) array
    - "classes": `boxes.cls` as an array, or an empty (0,) array
    - "masks":   `masks.data` as an array, or None when the result has no masks
  """
  if img is None:
    return []

  raw_results = model.predict(
    img,
    device=device,
    imgsz=imgsz,
    conf=conf,
    iou=iou,
    classes=classes_filter,
    verbose=False,
  )

  parsed = []
  for res in raw_results:
    if res.boxes:
      xyxy = res.boxes.xyxy.cpu().numpy()
      cls_ids = res.boxes.cls.cpu().numpy()
    else:
      # no (or empty) detections: keep the array shapes predictable
      xyxy = np.zeros((0,4))
      cls_ids = np.zeros((0,))
    seg = None if res.masks is None else res.masks.data.cpu().numpy()
    parsed.append({"boxes": xyxy, "classes": cls_ids, "masks": seg})
  return parsed


def unload_model(model: Any) -> None:
  """
  Best-effort release of a model's memory.

  `del model` only removes this function's own reference; the caller must
  drop its references too for the model to become collectable. A garbage
  collection pass is run before emptying the CUDA cache so that objects kept
  alive only by reference cycles are freed first.
  """
  import gc  # local import: the module's top-level imports stay untouched

  del model
  gc.collect()
  if torch.cuda.is_available():
    torch.cuda.empty_cache()


def _write_annotations(img_path, out_dir, boxes, classes, masks, w, h):
  os.makedirs(out_dir, exist_ok=True)
  stem = os.path.splitext(os.path.basename(img_path))[0]
  ann_path = os.path.join(out_dir, f"{stem}.txt")

  lines = []
  if boxes is not None and len(boxes) > 0:
    for i, box in enumerate(boxes):
      cls_id = int(classes[i]) if classes is not None else 0
      x1, y1, x2, y2 = box
      cx = (x1 + x2) / 2.0 / w
      cy = (y1 + y2) / 2.0 / h
      bw = (x2 - x1) / float(w)
      bh = (y2 - y1) / float(h)
      line = f"{cls_id} {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}"

      # only if masks exist and align with boxes
      if masks is not None and len(masks) > i:
        m = masks[i]
        # optional: flatten to polygon or save mask ID
        line += f" mask:{np.count_nonzero(m)}"

      lines.append(line)

  with open(ann_path, "w") as f:
    f.write("\n".join(lines))


def _largest_mask(masks: Optional[List[np.ndarray]]) -> Optional[np.ndarray]:
  if not masks:
    return None
  areas = [int(m.sum()) for m in masks]
  if not areas or max(areas) == 0:
    return None
  return masks[int(np.argmax(areas))]


def _mask_from_bbox(b: np.ndarray, h: int, w: int) -> np.ndarray:
  m = np.zeros((h, w), dtype=np.uint8)
  x1, y1, x2, y2 = [int(v) for v in b.tolist()]
  x1, y1 = max(0, x1), max(0, y1)
  x2, y2 = min(w - 1, x2), min(h - 1, y2)
  if x2 > x1 and y2 > y1:
    m[y1:y2, x1:x2] = 1
  return m


def _select_person_mask(
  masks: Optional[List[np.ndarray]],
  boxes: Optional[np.ndarray],
  classes: Optional[np.ndarray],
  prefer_class: Optional[int],
  h: int,
  w: int
) -> Optional[np.ndarray]:
  if masks:
    if prefer_class is not None and classes is not None and len(classes) > 0:
      idxs = np.where(classes.astype(int) == int(prefer_class))[0]
      if len(idxs) > 0:
        areas = [int(masks[i].sum()) for i in idxs]
        if areas and max(areas) > 0:
          return masks[idxs[int(np.argmax(areas))]]
    lm = _largest_mask(masks)
    if lm is not None:
      return lm
  if boxes is not None and boxes.shape[0] > 0:
    areas = ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])).astype(float)
    k = int(np.argmax(areas))
    return _mask_from_bbox(boxes[k], h, w)
  return None


def _expand_bbox_to_ratio(b: np.ndarray, r: float, h: int, w: int) -> np.ndarray:
  """
  Expand bbox so that A'=(1+r)A, clipped to image.
  """
  x1, y1, x2, y2 = [float(v) for v in b.tolist()]
  cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
  bw, bh = max(1.0, x2 - x1), max(1.0, y2 - y1)
  scale = float(np.sqrt(max(1e-6, 1.0 + r)))
  nw, nh = bw * scale, bh * scale
  nx1, ny1 = cx - nw / 2.0, cy - nh / 2.0
  nx2, ny2 = cx + nw / 2.0, cy + nh / 2.0
  nx1, ny1 = max(0.0, nx1), max(0.0, ny1)
  nx2, ny2 = min(float(w - 1), nx2), min(float(h - 1), ny2)
  return np.array([nx1, ny1, nx2, ny2], dtype=np.float32)


def _dilate_mask_to_ratio(fg: np.ndarray, r: float) -> np.ndarray:
  """
  Grow mask so that (Area(R) - Area(F)) / Area(F) ≈ r.
  jaccard_target_dilate uses X = 1/(1+r).
  """
  fg = (fg > 0).astype(np.uint8)
  area_f = int(fg.sum())
  if area_f == 0 or r <= 0:
    return fg
  X = float(1.0 / (1.0 + r))
  return jaccard_target_dilate(fg, X)


def _crop_by_bbox(img: np.ndarray, b: np.ndarray) -> np.ndarray:
  x1, y1, x2, y2 = [int(round(v)) for v in b.tolist()]
  x1, y1 = max(0, x1), max(0, y1)
  x2, y2 = min(img.shape[1] - 1, x2), min(img.shape[0] - 1, y2)
  if x2 <= x1 or y2 <= y1:
    return img
  return img[y1:y2, x1:x2]


def _materialize_variant(
    mode: str,
    img_bgr: np.ndarray,
    fg_mask: Optional[np.ndarray],
    tight_bbox: Optional[np.ndarray],
    bg_ratio: float,
    crop_to_mask: bool,
) -> np.ndarray:
  out = img_bgr
  if mode == 'B0':
    out = img_bgr
  elif mode == 'MASK':
    out = img_bgr if fg_mask is None else masked_rgb(img_bgr, fg_mask, crop_to_mask)
  elif mode == 'MASK_C':
    if fg_mask is None:
      out = img_bgr
    else:
      out_mask = _dilate_mask_to_ratio(fg_mask, bg_ratio)
      out = masked_rgb(img_bgr, out_mask, crop_to_mask)
  elif mode == 'BBOX':
    out = img_bgr if tight_bbox is None else _crop_by_bbox(img_bgr, tight_bbox)
  elif mode == 'BBOX_C':
    if tight_bbox is None:
      out = img_bgr
    else:
      h, w = img_bgr.shape[:2]
      eb = _expand_bbox_to_ratio(tight_bbox, bg_ratio, h, w)
      out = _crop_by_bbox(img_bgr, eb)
  return out


def build_basepaths(to_path: str, dataset_name: str) -> Dict[str, str]:
  """
  Create the per-output root directories for a dataset and return their paths.

  For each output kind ("annotation", "bbox", "mask", "seg", "viz") the
  directory `<to_path>/<dataset_name>-<kind>` is created if missing.

  Returns:
    Dict mapping each output kind to its directory path.
  """
  kinds = ("annotation", "bbox", "mask", "seg", "viz")
  outputs = {kind: os.path.join(to_path, f"{dataset_name}-{kind}") for kind in kinds}
  # plain loop: directory creation is a side effect, not a value to collect
  for path in outputs.values():
    os.makedirs(path, exist_ok=True)
  return outputs




def process_one(
  src_path: str,
  rel_to_dataset: str,        # e.g. "imgs/train/c0/img_123.jpg" (or "Day/Cam1/c0/..")
  replica_roots: Dict[str, str],
  model,
  device: str,
  imgsz: Optional[int],
  conf_thres: float,
  iou_thres: float,
  prefer_class: Optional[int],
  only_class: bool,
  mode: str,
  bg_ratio: float,
  viz: bool,
  save_mask_png: bool,
  crop_to_mask: bool,
) -> Dict[str, Any]:
  """
  Run the model on one image and write its derived outputs.

  The directory part of `rel_to_dataset` is mirrored under every root in
  `replica_roots`; the keys 'annotation', 'bbox', 'mask', 'seg' and 'viz' are
  used when predictions exist. For the first prediction this writes:
    - the annotation text file,
    - a crop of the largest-area box (when there is a box),
    - the selected foreground mask as PNG (when `save_mask_png` is set),
    - the segmented image and the visualisation image.

  Returns:
    Dict with "missed" (True when the image is unreadable or there are no
    predictions), "bbox_area" (largest box area or None), "seg_ratio" (sum of
    all mask pixel sums divided by h * w, or None), "num_boxes", "num_masks".
  """
  result = {"missed": True, "bbox_area": None, "seg_ratio": None, "num_boxes": 0, "num_masks": 0}

  img_bgr = cv2.imread(src_path)
  if img_bgr is None:
    print(f"[warn] unreadable: {src_path}")
    return result

  h, w = img_bgr.shape[:2]
  classes_filter = [int(prefer_class)] if (only_class and prefer_class is not None) else None
  preds = predict(
    model, img_bgr, device=device, imgsz=imgsz,
    conf=conf_thres, iou=iou_thres, classes_filter=classes_filter
  )

  # mirror the directory structure of the source under each replica root
  rel_dir = os.path.dirname(rel_to_dataset)  # no file name
  basepaths = {k: os.path.join(root_dir, rel_dir) for k, root_dir in replica_roots.items()}
  for dir_path in basepaths.values():
    os.makedirs(dir_path, exist_ok=True)

  if not preds:
    return result

  pred = preds[0]
  boxes   = pred.get('boxes',   np.zeros((0, 4), dtype=np.float32))
  classes = pred.get('classes', np.zeros((0,),    dtype=np.float32))
  masks   = pred.get('masks',   None)

  # `predict` returns masks as one NumPy array; truth-testing such an array
  # (`if masks:`) raises ValueError, so emptiness is checked explicitly and a
  # plain list of per-instance masks is used for mask selection.
  has_masks = masks is not None and len(masks) > 0
  mask_list = list(masks) if has_masks else None

  _write_annotations(src_path, basepaths['annotation'], boxes, classes, masks, w, h)

  tight_bbox = None
  bbox_areas = None
  if boxes is not None and boxes.shape[0] > 0:
    bbox_areas = ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])).astype(float)
    tight_bbox = boxes[int(np.argmax(bbox_areas))]

  fg_mask = _select_person_mask(mask_list, boxes, classes, prefer_class, h, w)
  # NOTE(review): the materialised variant is computed but never saved or
  # returned, so `mode`/`bg_ratio`/`crop_to_mask` do not affect any output.
  # Kept as in the original; confirm where this image is meant to go.
  _materialize_variant(mode, img_bgr, fg_mask, tight_bbox, bg_ratio, crop_to_mask=crop_to_mask)

  stem = os.path.splitext(os.path.basename(src_path))[0]
  if tight_bbox is not None:
    cv2.imwrite(os.path.join(basepaths['bbox'], f"{stem}.jpg"), crop_bbox(img_bgr, tight_bbox))
  if save_mask_png and fg_mask is not None:
    cv2.imwrite(os.path.join(basepaths['mask'], f"{stem}.png"), (fg_mask * 255).astype(np.uint8))
  cv2.imwrite(os.path.join(basepaths['seg'], f"{stem}.jpg"), extract_segmented_rgb(img_bgr, masks))
  # NOTE(review): the `viz` flag is not consulted; the visualisation is
  # always written. Left unchanged - confirm whether it should gate this.
  cv2.imwrite(os.path.join(basepaths['viz'], f"{stem}.jpg"), visualize(img_bgr, boxes_xyxy=tight_bbox, masks=masks))

  if bbox_areas is not None:
    result["bbox_area"] = float(max(bbox_areas))
  if has_masks:
    fg_area = sum(int(m.sum()) for m in masks)
    # NOTE(review): assumes the masks have the source image's h x w
    # resolution - TODO confirm against the model's mask output size.
    result["seg_ratio"] = fg_area / float(h * w)
    result["num_masks"] = len(masks)

  result["num_boxes"] = 0 if boxes is None else int(boxes.shape[0])
  result["missed"] = False

  return result
