"""
Annotation utilities: format conversion, mask/bbox ops, Jaccard-controlled dilation, visualization.
"""
__author__ = 'XYZ'

import pdb
import json
import os
from typing import Any, Dict, List, Tuple, Optional

import cv2
import numpy as np


# ----------------------------
# Formats: YOLO TXT  <->  VIA JSON
# ----------------------------

def save_yolo_txt(txt_path: str, boxes_xyxy: np.ndarray, cls: np.ndarray, img_w: int, img_h: int) -> None:
  """
  Save boxes as a YOLO txt label file, one "class cx cy w h" line per box.

  Center/size values are normalized by the image size and written with 6
  decimals; lines are newline-separated with no trailing newline.

  Args:
    txt_path: output file path; missing parent directories are created. A bare
      filename (no directory component) is written to the current directory.
    boxes_xyxy: pixel boxes as (x1, y1, x2, y2): an (N, 4) array, a sequence
      of 4-element boxes, or a single 4-element box.
    cls: one class id per box (each is cast with int()).
    img_w: image width in pixels (normalizes cx and w).
    img_h: image height in pixels (normalizes cy and h).

  Raises:
    ValueError: if the number of boxes and class ids differ.
  """
  parent = os.path.dirname(txt_path)
  ## os.makedirs('') raises FileNotFoundError, so only create a real parent dir
  if parent:
    os.makedirs(parent, exist_ok=True)

  boxes = np.asarray(boxes_xyxy, dtype=float)
  if boxes.ndim == 1:  ## single (4,) box, or an empty sequence
    boxes = boxes.reshape(-1, 4)
  classes = np.asarray(cls).reshape(-1)
  ## zip() would silently drop the surplus entries on a mismatch
  if len(boxes) != len(classes):
    raise ValueError(f"got {len(boxes)} boxes but {len(classes)} class ids")

  lines = []
  for (x1, y1, x2, y2), c in zip(boxes.tolist(), classes.tolist()):
    w = x2 - x1
    h = y2 - y1
    cx = x1 + w / 2.0
    cy = y1 + h / 2.0
    lines.append(f"{int(c)} {cx / img_w:.6f} {cy / img_h:.6f} {w / img_w:.6f} {h / img_h:.6f}")
  with open(txt_path, 'w') as f:
    f.write("\n".join(lines))


def yolo_txt_to_xyxy(label_line: str, img_w: int, img_h: int) -> Tuple[int, float, float, float, float]:
  """
  Parse one YOLO label line "class cx cy w h" (normalized) into pixel corners.

  Returns (cls, x1, y1, x2, y2), where the center/size values are scaled by
  img_w / img_h. Exactly four numbers must follow the class id; otherwise the
  unpacking raises ValueError.
  """
  tokens = label_line.split()
  class_id = int(tokens[0])
  norm_cx, norm_cy, norm_w, norm_h = [float(t) for t in tokens[1:]]
  center_x = norm_cx * img_w
  center_y = norm_cy * img_h
  half_w = norm_w * img_w / 2.0
  half_h = norm_h * img_h / 2.0
  return class_id, center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h


def save_via_json(json_path: str,
                  img_filename: str,
                  img_size: int,
                  regions: List[Dict[str, Any]]) -> None:
  """
  Write a single-image VIA (VGG Image Annotator) JSON file.

  The top-level key is `img_filename` immediately followed by `img_size`
  (e.g. "img.jpg12345"). Its entry holds "filename", "size", "regions" and an
  empty "file_attributes" dict. Each region is expected to carry
  'shape_attributes' and 'region_attributes', for example as built by
  bbox_xyxy_to_via_region / mask_to_via_region.

  Args:
    json_path: output path; missing parent directories are created. A bare
      filename (no directory component) is written to the current directory.
    img_filename: image file name stored in the entry and used in the key.
    img_size: stored as "size" and appended to the key. Presumably this is the
      image file size in bytes, per VIA convention; confirm with callers.
    regions: list of region dicts, serialized as-is.
  """
  parent = os.path.dirname(json_path)
  ## os.makedirs('') raises FileNotFoundError, so only create a real parent dir
  if parent:
    os.makedirs(parent, exist_ok=True)
  key = f"{img_filename}{img_size}"
  data = {
    key: {
      "filename": img_filename,
      "size": img_size,
      "regions": regions,
      "file_attributes": {}
    }
  }
  with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(data, f)


def bbox_xyxy_to_via_region(x1: float, y1: float, x2: float, y2: float, label: str = "person") -> Dict[str, Any]:
  """
  Convert a corner-format (x1, y1, x2, y2) box into a VIA "rect" region.

  VIA rectangles are stored as the top-left corner plus width/height, all cast
  to float. The label is placed under region_attributes["label"].
  """
  shape_attributes = {"name": "rect", "x": float(x1), "y": float(y1)}
  shape_attributes["width"] = float(x2 - x1)
  shape_attributes["height"] = float(y2 - y1)
  return {"shape_attributes": shape_attributes, "region_attributes": {"label": label}}


def mask_to_via_region(mask: np.ndarray, label: str = "person") -> Dict[str, Any]:
  """
  Convert a binary mask into a VIA "polygon" region traced from its outline.

  Only the largest external contour (by cv2.contourArea) is kept, using
  CHAIN_APPROX_SIMPLE point compression. Any nonzero mask value counts as
  foreground.

  Args:
    mask: single-channel 2-D mask; nonzero pixels are foreground.
    label: stored under region_attributes["label"].

  Returns:
    VIA region dict with integer all_points_x / all_points_y, or {} when no
    contour is found.
  """
  ## Binarize with > 0 (as mask_to_bbox does). A plain uint8 cast would zero
  ## fractional values and wrap integers >= 256.
  binary = (mask > 0).astype(np.uint8)
  ## findContours returns (contours, hierarchy) in OpenCV 4 but
  ## (image, contours, hierarchy) in OpenCV 3; [-2] selects contours in both.
  contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
  if len(contours) == 0:
    return {}
  ## pick largest contour
  largest = max(contours, key=cv2.contourArea)
  pts = largest.reshape(-1, 2)
  return {
    "shape_attributes": {
      "name": "polygon",
      "all_points_x": [int(x) for x in pts[:, 0]],
      "all_points_y": [int(y) for y in pts[:, 1]]
    },
    "region_attributes": {
      "label": label
    }
  }


# ----------------------------
# Mask / BBox utilities
# ----------------------------

def mask_to_bbox(mask: np.ndarray) -> Tuple[int, int, int, int]:
  """
  Return the tight bounding box (x1, y1, x2, y2) of the nonzero pixels of a 2-D mask.

  The coordinates are inclusive pixel indices: x2/y2 are the last occupied
  column/row. An empty mask yields (0, 0, 0, 0), which is indistinguishable
  from a single foreground pixel at the origin.
  """
  rows, cols = np.nonzero(mask > 0)
  if rows.size == 0:
    return 0, 0, 0, 0
  return int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())


def expand_bbox(x1: int, y1: int, x2: int, y2: int, img_w: int, img_h: int, factor: float) -> Tuple[int, int, int, int]:
  """
  Grow a box about its center by `factor` of its size, clipped to the image.

  The width grows by `factor * w` in total, which is `int(factor * w / 2)`
  pixels on each side (truncated toward zero); the height is handled the same
  way. So factor=0.2 gives roughly +10% per side (+20% overall), not +20% per
  side. The result is clipped to the inclusive pixel range
  [0, img_w - 1] x [0, img_h - 1].
  """
  dx = int((x2 - x1) * factor / 2.0)
  dy = int((y2 - y1) * factor / 2.0)
  return (
    max(0, x1 - dx),
    max(0, y1 - dy),
    min(img_w - 1, x2 + dx),
    min(img_h - 1, y2 + dy),
  )


def jaccard_target_dilate_cv(fg_mask: np.ndarray, target_jaccard: float, max_iter: int = 64) -> np.ndarray:
  """
  Grow `fg_mask` by repeated 3x3 elliptical dilation while J(F, R) stays >= target.

  Every dilation contains F, so J(F, R) = |F ∩ R| / |F ∪ R| = |F| / |R|. This
  value starts at 1 and shrinks as R grows. The loop keeps the largest R, over
  at most `max_iter` single-step dilations, whose Jaccard with F is still
  >= `target_jaccard`. It stops early once a dilation adds no pixels.

  Args:
    fg_mask: mask whose nonzero pixels form the foreground F.
    target_jaccard: lower bound on J(F, R); values >= 1 return F itself.
    max_iter: maximum number of dilation steps.

  Returns:
    uint8 0/1 mask R containing F; an all-zero mask if F is empty.
  """
  F = (fg_mask > 0).astype(np.uint8)
  n_fg = int(F.sum())
  if n_fg == 0:
    return F
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
  X = float(target_jaccard)
  R = F
  n_r = n_fg
  for _ in range(max_iter):
    candidate = cv2.dilate(R, kernel, iterations=1)
    n_c = int(candidate.sum())
    if n_c == n_r:  ## dilation saturated (e.g. filled the image)
      break
    ## candidate ⊇ F, so its Jaccard with F is |F| / |candidate|
    if n_fg / n_c < X:
      break
    R, n_r = candidate, n_c
  return R


def jaccard_target_dilate(fg_mask: np.ndarray, target_jaccard: float) -> np.ndarray:
  """
  Grow `fg_mask` to the largest distance-ball region R with J(F, R) >= target.

  The candidates are R_r = {p : dist(p, F) <= r}, using the Euclidean distance
  transform. Each R_r contains F, so J(F, R_r) = |F ∩ R_r| / |F ∪ R_r| = |F| / |R_r|,
  which does not increase as r grows. The best r is therefore found exactly
  by scanning the sorted distinct distances with cumulative pixel counts,
  with no approximate search needed.

  Args:
    fg_mask: mask whose nonzero pixels form the foreground F.
    target_jaccard: lower bound on J(F, R). Values >= 1 return F itself;
      values <= 0 return the whole image.

  Returns:
    uint8 0/1 mask R containing F; an all-zero mask if F is empty.
  """
  import scipy.ndimage as ndi
  F = (fg_mask > 0).astype(np.uint8)
  n_fg = int(F.sum())
  if n_fg == 0:
    return F
  ## Distance from each pixel to the nearest foreground pixel (0 inside F).
  dist = ndi.distance_transform_edt(1 - F)
  ## radii ascending; region_sizes[k] = number of pixels with dist <= radii[k]
  radii, counts = np.unique(dist, return_counts=True)
  region_sizes = np.cumsum(counts)
  jaccards = n_fg / region_sizes
  ## jaccards is non-increasing, so the qualifying radii form a prefix;
  ## jaccards[0] == 1 because dist == 0 exactly on F.
  ok = np.flatnonzero(jaccards >= float(target_jaccard))
  radius = radii[ok[-1]] if ok.size else 0.0
  return (dist <= radius).astype(np.uint8)


def masked_rgb(img_bgr: np.ndarray, mask: np.ndarray, crop: bool = False, pad: int = 5) -> np.ndarray:
  """
  Black out every pixel of `img_bgr` outside `mask`, optionally cropping to the mask.

  Args:
    img_bgr: image array (any channel order); the mask indexes its first two axes.
    mask: array whose nonzero entries mark pixels to keep; its shape must
      match the image's leading (H, W) axes for boolean indexing.
    crop: if True, crop the result to the mask's bounding box grown by `pad`
      pixels on every side (clipped to the image).
    pad: padding in pixels used when cropping (default 5).

  Returns:
    The masked image, full size or cropped. If either input is None,
    `img_bgr` is returned unchanged. With crop=True and an empty mask, the
    full-size all-black image is returned.
  """
  if img_bgr is None or mask is None:
    return img_bgr

  ## Ensure mask is boolean for logical indexing
  mask_bool = mask > 0

  ## Output image with black background, foreground copied through
  output_img = np.zeros_like(img_bgr)
  output_img[mask_bool] = img_bgr[mask_bool]

  if not crop:
    return output_img

  y_coords, x_coords = np.where(mask_bool)
  if y_coords.size == 0:
    return output_img  # Return black image if mask is empty

  H, W = img_bgr.shape[:2]
  ## min()/max() are inclusive pixel indices, so the exclusive slice end needs
  ## +1; this keeps the padding symmetric on both sides of the box.
  y1 = max(0, int(y_coords.min()) - pad)
  x1 = max(0, int(x_coords.min()) - pad)
  y2 = min(H, int(y_coords.max()) + pad + 1)
  x2 = min(W, int(x_coords.max()) + pad + 1)
  return output_img[y1:y2, x1:x2]


# TODO: update _materialize_variant so that it calls this function correctly.
def _crop_with_mask(img_bgr: np.ndarray, mask: np.ndarray) -> np.ndarray:
  """
  Crop `img_bgr` to the tight bounding box of `mask` and black out non-mask pixels.

  Args:
    img_bgr: image array; the mask indexes its first two axes.
    mask: array whose nonzero entries mark pixels to keep, same (H, W) as the image.

  Returns:
    The cropped, masked image. If either input is None, `img_bgr` is returned
    unchanged. If the mask is empty, a full-size all-black image is returned.
  """
  if img_bgr is None or mask is None:
    return img_bgr

  ## Find the bounding box of the mask
  y_coords, x_coords = np.where(mask > 0)
  if y_coords.size == 0:
    return np.zeros_like(img_bgr)  # Return a black image if mask is empty

  ## max() is the last occupied row/col (inclusive); +1 makes the slice keep it
  y1, y2 = int(y_coords.min()), int(y_coords.max()) + 1
  x1, x2 = int(x_coords.min()), int(x_coords.max()) + 1

  ## Crop the image and mask, then keep only the mask pixels
  cropped_img = img_bgr[y1:y2, x1:x2]
  keep = mask[y1:y2, x1:x2] > 0
  masked_cropped_img = np.zeros_like(cropped_img)
  masked_cropped_img[keep] = cropped_img[keep]
  return masked_cropped_img


def visualize(
  image: np.ndarray,
  boxes_xyxy: Optional[np.ndarray] = None,
  masks: Optional[List[np.ndarray]] = None,
  color=(0, 255, 0)
) -> np.ndarray:
  """
  Draw boxes and tinted mask overlays on a copy of `image`.

  Args:
    image: image to annotate; it is copied, not modified.
    boxes_xyxy: (x1, y1, x2, y2) boxes: an (N, 4) array, a single (4,) box,
      or an empty sequence. Coordinates are truncated to int and drawn as
      2-px rectangles in `color`.
    masks: masks to overlay. Masks whose H x W differ from the image are
      resized with nearest-neighbour interpolation.
    color: drawing/tint color, in the image's channel order.

  Returns:
    The annotated image. Mask pixels are blended 50/50 with `color` on an
    overlay, and the overlay is then mixed in at 30% opacity.
  """
  out = image.copy()

  ## Draw bounding boxes
  if boxes_xyxy is not None:
    boxes = np.asarray(boxes_xyxy, dtype=np.float32)
    if boxes.size == 0:
      ## an empty list would otherwise become shape (1, 0) and fail to unpack
      boxes = boxes.reshape(0, 4)
    elif boxes.ndim == 1:  ## single (4,)
      boxes = boxes[None, :]
    for b in boxes:
      x1, y1, x2, y2 = map(int, b.tolist())
      cv2.rectangle(out, (x1, y1), (x2, y2), color, 2)

  ## Overlay masks
  if masks is not None and len(masks) > 0:
    H, W = out.shape[:2]
    overlay = out.copy()
    tint = np.array(color)
    for m in masks:
      if m.shape[0] != H or m.shape[1] != W:
        m = cv2.resize(m.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST)
      m_bool = m.astype(bool)
      overlay[m_bool] = (overlay[m_bool] * 0.5 + tint * 0.5).astype(np.uint8)
    out = cv2.addWeighted(out, 0.7, overlay, 0.3, 0)

  return out


def resize_masks_to_image(masks, image_shape):
  """
  Resize masks (e.g. YOLO segmentation output) to the original image's H x W.

  Each mask is first cast with astype(np.uint8), so fractional values truncate
  toward 0. It is then resized with nearest-neighbour interpolation. Masks
  that are already exactly (H, W) are only cast, which gives the same result
  as a same-size resize.

  Args:
    masks: iterable of 2-D masks, e.g. an (N, h, w) array.
    image_shape (tuple): (H, W, ...) shape of the original image.
  Returns:
    np.ndarray: uint8 masks of shape (N, H, W); shape (0, H, W) when `masks` is empty.
  """
  H, W = image_shape[:2]
  resized = []
  for m in masks:
    m8 = m.astype(np.uint8)
    if m8.shape != (H, W):
      m8 = cv2.resize(m8, (W, H), interpolation=cv2.INTER_NEAREST)
    resized.append(m8)
  if not resized:
    ## np.array([]) would have shape (0,), which breaks callers expecting (N, H, W)
    return np.zeros((0, H, W), dtype=np.uint8)
  return np.array(resized)


def extract_segmented_rgb(image: np.ndarray, masks: List[np.ndarray]) -> np.ndarray:
  """
  Keep only the pixels of `image` covered by the union of `masks`.

  The masks are first resized to the image's H x W with resize_masks_to_image.
  Pixels outside every mask become zero. With no masks (None or empty), an
  all-zero array shaped like `image` is returned.
  """
  if masks is None or len(masks) == 0:
    return np.zeros_like(image)

  ## masks may come at model resolution; align them with the image first
  aligned = resize_masks_to_image(masks, image.shape)
  union = np.any(aligned.astype(bool), axis=0)

  result = np.zeros_like(image)
  result[union] = image[union]
  return result


def crop_bbox(image: np.ndarray, boxes_xyxy: Optional[np.ndarray]) -> np.ndarray:
  """
  Crop `image` to the box with the largest area among `boxes_xyxy`.

  Boxes are (x1, y1, x2, y2) and are cast to int32, which truncates fractional
  parts. A single (4,) box is accepted. The crop is image[y1:y2, x1:x2] after
  clamping the corners to the image bounds. With no boxes (None or empty), an
  all-zero array shaped like `image` is returned.
  """
  if boxes_xyxy is None or len(boxes_xyxy) == 0:
    return np.zeros_like(image)

  boxes = np.atleast_2d(np.array(boxes_xyxy, dtype=np.int32))
  widths = boxes[:, 2] - boxes[:, 0]
  heights = boxes[:, 3] - boxes[:, 1]
  ## For now, only the largest box is cropped
  x1, y1, x2, y2 = boxes[int(np.argmax(widths * heights))]

  img_w, img_h = image.shape[1], image.shape[0]
  return image[max(0, y1):min(img_h, y2), max(0, x1):min(img_w, x2)]
