"""Annotation pipeline (node-based flows) using YOLO API + PlainDataSet."""
__author__ = "XYZ"

import pdb
import os
import sys
import json

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime

from . import rflow
from .dataset import load_plain_dataset, get_splits
from .core._log_ import logger
from .core.fwo import write_json
from .utils.yoloapi import (
  load_model,
  build_basepaths,
  predict,
  visualize,
  extract_segmented_rgb,
  crop_bbox,
  _select_person_mask,
  _write_annotations,
)

# Module-level logger, created by the project's logger factory from this file's path.
log = logger(__file__)
# Handle to this very module object; main() uses it with getattr() to collect
# the node__* functions that make up a flow.
this = sys.modules[__name__]


def _largest_bbox(boxes):
  """Return the row of `boxes` that covers the largest area.

  Rows are indexed as (x1, y1, x2, y2): columns 0 and 2 give the width,
  columns 1 and 3 give the height.

  Returns None when `boxes` is None or has no rows. On ties the first
  maximal row wins (np.argmax semantics).
  """
  if boxes is None or len(boxes) == 0:
    return None
  widths = boxes[:, 2] - boxes[:, 0]
  heights = boxes[:, 3] - boxes[:, 1]
  winner = int(np.argmax((widths * heights).astype(float)))
  return boxes[winner]

def _mirror_dirs(replica_roots, rel_to_dataset):
  """Mirror the directory part of `rel_to_dataset` under every replica root.

  Args:
    replica_roots: mapping of output kind -> root directory.
    rel_to_dataset: path of one image relative to the dataset root; only its
      directory component (os.path.dirname) is used.

  Returns:
    dict with the same keys as `replica_roots`, each mapped to
    `<root>/<dirname(rel_to_dataset)>`. Every returned directory is created
    if it does not exist yet.
  """
  rel_dir = os.path.dirname(rel_to_dataset)
  out = {}
  # Plain loop rather than a throwaway list comprehension: creating
  # directories is a side effect, not a collection being built.
  for kind, root_dir in replica_roots.items():
    target = os.path.join(root_dir, rel_dir)
    os.makedirs(target, exist_ok=True)
    out[kind] = target
  return out

def _compute_metrics(result, h, w):
  """Derive the per-image ratio metrics reported in the split CSV.

  Args:
    result: per-image result dict; "bbox_area", "seg_ratio", "jaccard_bbox"
      and "jaccard_seg" are read, all optional.
    h, w: image height and width; when either is falsy the full-image area
      is treated as 0 and the bbox-based ratios are left out.

  Returns:
    dict that always holds "jaccard_bbox" and "jaccard_seg" (None when
    absent from `result`) and, when computable, "bbox_full_ratio",
    "seg_full_ratio" and "seg_bbox_ratio".
  """
  full_area = float(h * w) if (h and w) else 0.0
  bbox_area = result.get("bbox_area")
  seg_ratio = result.get("seg_ratio")

  # A zero/absent bbox area is treated the same as "no bbox".
  has_bbox = bool(bbox_area) and full_area > 0
  has_seg = seg_ratio is not None

  metrics = {}
  if has_bbox:
    metrics["bbox_full_ratio"] = bbox_area / full_area
  if has_seg:
    metrics["seg_full_ratio"] = seg_ratio
  if has_bbox and has_seg:
    # Small epsilon keeps the division finite for vanishing bbox ratios.
    metrics["seg_bbox_ratio"] = seg_ratio / (bbox_area / full_area + 1e-8)

  for key in ("jaccard_bbox", "jaccard_seg"):
    metrics[key] = result.get(key, None)
  return metrics

def generate_split_stats(split, split_stats, split_base, dataset):
  """Write the per-split CSV reports and return a summary dict.

  Results flagged "missed" are skipped. Three files are written into
  `split_base`: imgs_list.csv (classname, img), labels.csv (de-duplicated
  classname, id) and `<split>.csv` (one metrics row per image).

  Args:
    split: split name; used for the metrics CSV filename and echoed back.
    split_stats: iterable of per-image result dicts; each non-missed entry
      must carry "src" and may carry "label_idx" (default 0) and "shape"
      (default (0, 0)).
    split_base: directory the CSV files are written to (not created here).
    dataset: object exposing `index_to_label` with a dict-style .get();
      indices it does not know are reported as "unknown".

  Returns:
    dict with keys num_imgs, num_labels, split, classes, imgs_list, labels.
  """
  imgs_list = []
  labels = []
  rows = []
  for res in split_stats:
    if not res.get("missed"):
      stem, _ = os.path.splitext(os.path.basename(res["src"]))
      label_idx = res.get("label_idx", 0)
      label = dataset.index_to_label.get(label_idx, "unknown")
      imgs_list.append([label, stem])
      labels.append([label, label_idx])
      h, w = res.get("shape", (0, 0))
      row = {"classname": label, "img": stem}
      row.update(_compute_metrics(res, h, w))
      rows.append(row)

  imgs_frame = pd.DataFrame(imgs_list, columns=["classname", "img"])
  imgs_frame.to_csv(os.path.join(split_base, "imgs_list.csv"), index=False)

  labels_frame = pd.DataFrame(labels, columns=["classname", "id"]).drop_duplicates()
  labels_frame.to_csv(os.path.join(split_base, "labels.csv"), index=False)

  metrics_frame = pd.DataFrame(rows)
  metrics_frame.to_csv(os.path.join(split_base, f"{split}.csv"), index=False)

  # Distinct label indices seen in this split, in ascending order.
  unique_labels = sorted({idx for _, idx in labels})
  class_map = {}
  for idx in unique_labels:
    class_map[idx] = dataset.index_to_label.get(idx, "unknown")

  summary = {
    "num_imgs": len(imgs_list),
    "num_labels": len(unique_labels),
    "split": split,
    "classes": class_map,
    "imgs_list": imgs_list,
    "labels": labels,
  }
  return summary


def generate_modelinfo(model_info, **kwargs):
  """Assemble the run/model description that main() stores as modelinfo.json.

  Args:
    model_info: dict-like with optional "names", "nc" and "type" entries.
    **kwargs: run settings; author, version, timestamp, weights_path, device,
      imgsz, conf_thres and iou_thres are copied through (None when absent),
      and `all_classes` is inverted into "only_class".

  Returns:
    dict with the run settings followed by "only_class", "names", "nc" and
    "model_type".
  """
  passthrough = (
    "author", "version", "timestamp", "weights_path",
    "device", "imgsz", "conf_thres", "iou_thres",
  )
  info = {key: kwargs.get(key) for key in passthrough}
  info["only_class"] = not kwargs.get("all_classes")
  info["names"] = model_info.get("names", {})
  info["nc"] = model_info.get("nc", 0)
  info["model_type"] = model_info.get("type", None)
  return info


def _run_predict(ctx, img_array):
  """Run one prediction on `img_array` and summarise the first result.

  Args:
    ctx: flow context; reads "model", "device", "imgsz", "conf_thres",
      "iou_thres", "only_class" and "prefer_class".
    img_array: image passed straight to `predict`.

  Returns:
    None when `predict` yields nothing; otherwise a dict with "preds" (the
    first prediction), its "boxes", "classes" and "masks", the largest box as
    "tight_bbox", and "fg_mask" (the mask with the largest sum, or None when
    there are no masks or all of them sum to zero).
  """
  # Restrict to a single class only when asked to and a class is configured.
  restrict_to_class = ctx["only_class"] and ctx["prefer_class"] is not None
  classes_filter = [int(ctx["prefer_class"])] if restrict_to_class else None

  preds = predict(
    ctx["model"],
    img_array,
    device=ctx["device"],
    imgsz=ctx["imgsz"],
    conf=ctx["conf_thres"],
    iou=ctx["iou_thres"],
    classes_filter=classes_filter,
  )
  if not preds:
    return None

  first = preds[0]
  masks = first.get("masks")

  fg_mask = None
  if masks is not None and len(masks) > 0:
    areas = [int(m.sum()) for m in masks]
    if areas and max(areas) > 0:
      fg_mask = masks[int(np.argmax(areas))]

  boxes = first.get("boxes")
  return {
    "preds": first,
    "boxes": boxes,
    "classes": first.get("classes"),
    "masks": masks,
    "tight_bbox": _largest_bbox(boxes),
    "fg_mask": fg_mask,
  }


def node__activate_first_roots(ctx):
  """Flow node: make the first-pass output roots the active ones."""
  ctx.update(active_roots=ctx["first_roots"])
  return ctx

def node__activate_second_roots(ctx):
  """Flow node: make the second-pass output roots the active ones.

  Raises:
    ValueError: when the context has no (or a None) "second_roots" entry.
  """
  second = ctx.get("second_roots")
  if second is None:
    raise ValueError("second_roots not initialized for this flow")
  ctx["active_roots"] = second
  return ctx


def node__save_annotations(ctx):
  """Flow node: write the annotation output for the current image.

  Does nothing when the context is flagged "missed". Otherwise the mirrored
  "annotation" directory under the active roots is created and the boxes,
  classes and masks are handed to `_write_annotations` together with the
  image width and height (in that order).
  """
  if not ctx.get("missed"):
    h, w = ctx["shape"]
    mirrored = _mirror_dirs(ctx["active_roots"], ctx["rel_to_dataset"])
    _write_annotations(
      ctx["img_path"],
      mirrored["annotation"],
      ctx["boxes"],
      ctx["classes"],
      ctx["masks"],
      w,
      h,
    )
  return ctx

def node__save_bbox(ctx):
  """Flow node: save the image cropped to the tight bbox as `<stem>.jpg`.

  Skipped when the context is flagged "missed" or carries no "tight_bbox".
  The file goes into the mirrored "bbox" directory under the active roots.
  """
  tight = ctx.get("tight_bbox")
  if ctx.get("missed") or tight is None:
    return ctx
  image = ctx["img_bgr"]
  target_dir = _mirror_dirs(ctx["active_roots"], ctx["rel_to_dataset"])["bbox"]
  filename = os.path.splitext(os.path.basename(ctx["img_path"]))[0] + ".jpg"
  cv2.imwrite(os.path.join(target_dir, filename), crop_bbox(image, tight))
  return ctx

def node__save_seg(ctx):
  """Flow node: save the segmented image and, optionally, the foreground mask.

  Skipped when the context is flagged "missed". Writes `<stem>.jpg` into the
  mirrored "seg" directory; when "save_mask_png" is set and an "fg_mask"
  exists, also writes the mask scaled by 255 as `<stem>.png` into "mask".

  NOTE(review): ctx["masks"] is forwarded to extract_segmented_rgb without a
  None check — confirm the helper tolerates a prediction without masks.
  """
  if ctx.get("missed"):
    return ctx
  dirs = _mirror_dirs(ctx["active_roots"], ctx["rel_to_dataset"])
  stem = os.path.splitext(os.path.basename(ctx["img_path"]))[0]

  fg_mask = ctx.get("fg_mask")
  if ctx.get("save_mask_png") and fg_mask is not None:
    mask_path = os.path.join(dirs["mask"], f"{stem}.png")
    cv2.imwrite(mask_path, (fg_mask * 255).astype(np.uint8))

  seg_rgb = extract_segmented_rgb(ctx["img_bgr"], ctx["masks"])
  seg_path = os.path.join(dirs["seg"], f"{stem}.jpg")
  cv2.imwrite(seg_path, seg_rgb)
  return ctx

def node__save_viz(ctx):
  """Flow node: save a debug visualisation of the tight bbox and masks.

  Skipped when the context is flagged "missed". Writes `<stem>.jpg` into the
  mirrored "viz" directory under the active roots.

  NOTE(review): the "viz" flag that main() stores in the context is not
  consulted here, so the visualisation is written on every run — confirm
  whether that is intended.
  """
  if ctx.get("missed"):
    return ctx
  target_dir = _mirror_dirs(ctx["active_roots"], ctx["rel_to_dataset"])["viz"]
  stem, _ = os.path.splitext(os.path.basename(ctx["img_path"]))
  overlay = visualize(
    ctx["img_bgr"],
    boxes_xyxy=ctx.get("tight_bbox"),
    masks=ctx.get("masks"),
  )
  cv2.imwrite(os.path.join(target_dir, f"{stem}.jpg"), overlay)
  return ctx

def node__compute_metrics(ctx):
  """Flow node: fill ctx["result"] with box/mask statistics.

  Skipped when the context is flagged "missed". Otherwise records:
    bbox_area  - area of the largest box (only when boxes exist)
    seg_ratio  - summed mask values over all masks divided by h * w of
                 ctx["shape"] (only when masks exist)
    num_masks  - number of masks (only when masks exist)
    num_boxes  - number of boxes, 0 when there is no boxes entry
    missed     - set to False
  """
  if ctx.get("missed"):
    return ctx

  res = ctx.setdefault("result", {})
  boxes, masks = ctx.get("boxes"), ctx.get("masks")
  h, w = ctx["shape"]

  has_boxes = boxes is not None and len(boxes) > 0
  if has_boxes:
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    res["bbox_area"] = float((widths * heights).astype(float).max())

  if masks is not None and len(masks) > 0:
    res["seg_ratio"] = sum(int(m.sum()) for m in masks) / float(h * w)
    res["num_masks"] = len(masks)

  res["num_boxes"] = 0 if boxes is None else int(boxes.shape[0])
  res["missed"] = False
  return ctx


def node__load_image(ctx):
  """Flow node: read ctx["img_path"] with cv2.imread into the context.

  When imread returns None the context is only flagged "missed". Otherwise
  the image is stored as "img_bgr", its (height, width) as both "shape" and
  "orig_shape", and "missed" is cleared.
  """
  image = cv2.imread(ctx["img_path"])
  if image is None:
    ctx["missed"] = True
    return ctx

  height, width = image.shape[:2]
  ctx["img_bgr"] = image
  ctx["shape"] = (height, width)
  # Kept separately so later nodes can change "shape" and still restore it.
  ctx["orig_shape"] = (height, width)
  ctx["missed"] = False
  return ctx


def node__restore_original_shape(ctx):
  """Flow node: reset "shape" to "orig_shape" when the latter is present."""
  if "orig_shape" not in ctx:
    return ctx
  ctx["shape"] = ctx["orig_shape"]
  return ctx

def node__predict(ctx):
  """Flow node: predict on the full image and merge the outcome into ctx.

  Skipped when the context is already flagged "missed". An empty outcome
  from `_run_predict` flags the context as "missed" instead.
  """
  if ctx.get("missed"):
    return ctx
  outcome = _run_predict(ctx, ctx["img_bgr"])
  if outcome:
    ctx.update(outcome)
  else:
    ctx["missed"] = True
  return ctx

def node__predict_crop(ctx):
  """Flow node: predict on the crop and, on success, switch ctx to the crop.

  Skipped when "crop_unavailable" is set. When the prediction yields
  something, its fields are merged into ctx and "img_bgr" / "shape" are
  replaced by the crop and the crop's (height, width). When it yields
  nothing the context is left exactly as it was.
  """
  if ctx.get("crop_unavailable", False):
    return ctx
  crop = ctx["crop_array"]
  outcome = _run_predict(ctx, crop)
  if outcome:
    ctx.update(outcome)
    ctx["img_bgr"] = crop
    ctx["shape"] = crop.shape[:2]
  return ctx


def node__crop_to_bbox(ctx):
  """Flow node: cut the tight bbox out of "img_bgr" into "crop_array".

  Sets "crop_unavailable" to True (and stores no crop) when the context is
  flagged "missed", has no "tight_bbox", or the box is empty after rounding
  and clamping to the image bounds. Otherwise stores a copy of the cropped
  region and sets "crop_unavailable" to False.
  """
  def _skip():
    ctx["crop_unavailable"] = True
    return ctx

  if ctx.get("missed") or ctx.get("tight_bbox") is None:
    return _skip()

  image = ctx["img_bgr"]
  height, width = image.shape[:2]
  left, top, right, bottom = (int(round(v)) for v in ctx["tight_bbox"].tolist())

  # Clamp the box to the image bounds.
  left, top = max(0, left), max(0, top)
  right, bottom = min(width, right), min(height, bottom)

  # A box with no extent on either axis cannot be cropped.
  if right <= left or bottom <= top:
    return _skip()

  crop = image[top:bottom, left:right].copy()
  if crop.size == 0:
    return _skip()

  ctx["crop_unavailable"] = False
  ctx["crop_array"] = crop
  return ctx


# Node plans keyed by flow name; the --flow CLI choices are taken from these keys.
# Each step's "fn" matches a node__<fn> function defined in this module, and
# main() hands the selected plan to rflow.run together with the per-image context
# (presumably rflow resolves "fn" to the node__<fn> callable — confirm in rflow).
FLOW_PLANS = {
  # Single pass: predict on the full image and write every output under the first roots.
  "simple": [
    {"fn": "activate_first_roots"},
    {"fn": "load_image"},
    {"fn": "predict"},
    {"fn": "save_annotations"},
    {"fn": "save_bbox"},
    {"fn": "save_seg"},
    {"fn": "save_viz"},
    {"fn": "restore_original_shape"},   ## no-op in this flow: "shape" is not changed after load_image
    {"fn": "compute_metrics"},
  ],
  # Two passes: predict on the full image, crop to the largest box, then predict
  # again on the crop and write seg/viz outputs under the second roots.
  "bboxseg": [
    {"fn": "activate_first_roots"},
    {"fn": "load_image"},
    {"fn": "predict"},
    {"fn": "save_annotations"},
    {"fn": "save_bbox"},
    {"fn": "save_viz"},
    {"fn": "crop_to_bbox"},
    {"fn": "activate_second_roots"},
    {"fn": "predict_crop"},
    {"fn": "save_seg"},
    {"fn": "save_viz"},
    {"fn": "restore_original_shape"},   ## <- ensures metrics use the original HxW, not the crop's
    {"fn": "compute_metrics"},
  ],
}


def main(args):
  """Run the selected annotation flow over every requested dataset split.

  For each split: load the image list, push every image through the node
  plan FLOW_PLANS[args.flow], collect the per-image results, and write CSV
  and JSON reports under `args.to_path`.

  Args:
    args: namespace as produced by parse_args(); the attributes read here
      are the ones unpacked at the top of the function.

  Returns:
    dict mapping split name -> list of per-image result dicts for the images
    that were not missed. Splits absent from the dataset config are skipped
    and do not appear in it.
  """
  dataset_name  = args.dataset
  datasetcfg    = args.datasetcfg
  splits_to_run = args.splits
  weights_path  = args.weights_path
  device        = args.device
  imgsz         = args.imgsz
  conf_thres    = args.conf_thres
  iou_thres     = args.iou_thres
  prefer_class  = args.prefer_class
  all_classes   = args.all_classes
  viz           = args.viz
  sample_limit  = args.sample_limit
  save_mask_png = args.save_mask_png
  to_path       = args.to_path
  flow_name     = args.flow
  timestamp     = args.timestamp

  ## dataset root comes from the environment; None when the variable is unset
  __dataset_root__ = os.getenv("__DATASET_ROOT__")
  splits = get_splits(dataset_name, datasetcfg, __dataset_root__)

  ## `device` is rebound to whatever load_model returns and used from here on
  model, device, model_info = load_model(weights_path, device)
  first_roots = build_basepaths(to_path, dataset_name)
  ## second set of output roots exists only for the two-pass "bboxseg" flow
  second_roots = build_basepaths(to_path, f"{dataset_name}-bboxseg") if flow_name == "bboxseg" else None

  info = generate_modelinfo(model_info, **{
    "author": __author__,
    "version": "0.0.1",
    "timestamp": timestamp,
    "weights_path": weights_path,
    "device": device,
    "imgsz": imgsz,
    "conf_thres": conf_thres,
    "iou_thres": iou_thres,
    "all_classes": all_classes,
  })
  write_json(os.path.join(to_path, "modelinfo.json"), info)

  summary = {
    "splits": {},
    "missed": {},
  }
  all_stats = {}
  all_missed_stats = {}

  for split in splits_to_run:
    ## the split's list file is looked up under the key "<split>loadertxt"
    if f"{split}loadertxt" not in splits:
      log.warning(f"Split {split} not found in dataset config.")
      continue

    split_file  = splits[f"{split}loadertxt"]
    images_root = splits["loadertxt"]
    loader, dataset = load_plain_dataset(
      root=images_root,
      split_file=split_file,
      sample_limit=sample_limit,
      batch_size=1,
      shuffle=False,
      num_workers=0,
      prefix_depth=2
    )

    split_stats = []
    missed_split_stats = []
    for batch in tqdm(loader, desc=f"{split}", unit="img"):
      img_path, label_idx, rel_from_images_root = batch

      ## batch_size is 1: unwrap single-element containers / scalar holders
      if isinstance(img_path, (list,tuple)): img_path = img_path[0]
      if hasattr(label_idx, "item"): label_idx = label_idx.item()
      if isinstance(rel_from_images_root, (list,tuple)): rel_from_images_root = rel_from_images_root[0]

      ## rel_to_dataset = os.path.join(split, rel_from_images_root)
      ## outputs are saved without a per-split folder, inside the respective class
      ## folder (author's note: as in sf3d), so the split is not prepended here
      rel_to_dataset = rel_from_images_root

      ## per-image context
      ctx = {
        "module": this,  ## <-- lets rflow look up node__* here as a fallback
        "fns": { f"node__{n}": getattr(this, f"node__{n}") for n in {
          "activate_first_roots","activate_second_roots","load_image","predict",
          "save_annotations","save_bbox","save_seg","save_viz","compute_metrics",
          "crop_to_bbox","predict_crop","restore_original_shape"
        }},
        "model": model,
        "device": device,
        "imgsz": imgsz,
        "conf_thres": conf_thres,
        "iou_thres": iou_thres,
        "prefer_class": prefer_class,
        "only_class": (not all_classes),
        "save_mask_png": save_mask_png,
        ## NOTE(review): no node__* function in this file reads "viz" — confirm
        ## whether node__save_viz was meant to be gated on it
        "viz": viz,
        ##
        "first_roots": first_roots,
        "second_roots": second_roots,
        "active_roots": first_roots,
        ##
        "img_path": img_path,
        "rel_to_dataset": rel_to_dataset,
        ## stays "missed" unless node__compute_metrics runs through and clears the flag
        "result": {"missed": True},
      }

      ctx = rflow.run(FLOW_PLANS[flow_name], ctx)

      ## normalize result back to original src + original HxW for reporting
      res = ctx.get("result", {})
      res["src"] = img_path
      res["label_idx"] = label_idx

      ## only images that produced a response go into the split stats;
      if res.get("missed"):
        ## missed images are logged separately
        missed_split_stats.append(res)
        continue

      res["shape"] = ctx.get("orig_shape", res.get("shape"))
      split_stats.append(res)

    all_stats[split] = split_stats
    all_missed_stats[split] = missed_split_stats
    split_base = os.path.join(to_path, split)
    os.makedirs(split_base, exist_ok=True)
    split_summary = generate_split_stats(split, split_stats, split_base, dataset)
    summary["splits"][split] = split_summary
    summary["missed"] = all_missed_stats

    ## NOTE(review): this writes the cumulative dict of all splits processed so
    ## far into each split's folder, not just this split's missed list — confirm
    ## whether `missed_split_stats` was intended
    write_json(os.path.join(split_base, "summary.missed.json"), all_missed_stats)
    write_json(os.path.join(split_base, "summary.json"), split_summary)

  ## run-level summary across all processed splits
  summary["dataset"] = dataset_name
  summary["timestamp"] = timestamp
  summary["flow"] = flow_name
  write_json(os.path.join(to_path, "summary.json"), summary)
  return all_stats


def parse_args():
  """Build the CLI parser, parse the command line and return the namespace.

  A run timestamp (ddmmyy_HHMMSS) is generated here; it seeds the default
  output directory `logs/annotate-<timestamp>` and is attached to the
  returned namespace as `timestamp`.
  """
  import argparse

  parser = argparse.ArgumentParser(description="Annotation builder using node-based flows")
  add = parser.add_argument

  ## dataset
  add("--dataset", required=True, type=str)
  add("--datasetcfg", required=True, type=str)
  add("--splits", nargs="+", default=["train","val","test"])

  ## model
  add("--weights_path", type=str, default="yolov8s.pt")
  add("--device", type=str, default="cuda")
  add("--imgsz", type=int, default=640)

  ## outputs
  timestamp = datetime.now().strftime("%d%m%y_%H%M%S")
  add("--to", dest="to_path", type=str, default=os.path.join("logs", f"annotate-{timestamp}"))

  ## flow
  add("--flow", type=str, default="simple", choices=list(FLOW_PLANS))

  ## thresholds / filtering
  add("--conf_thres", type=float, default=0.25)
  add("--iou_thres", type=float, default=0.45)
  add("--prefer-class", dest="prefer_class", type=int, default=0)
  add("--all-classes", action="store_true", default=False)

  ## misc saves
  add("--viz", action="store_true", default=False)
  add("--sample_limit", type=int, default=0)
  add("--save-mask-png", action="store_true", default=False)

  parsed = parser.parse_args()
  parsed.timestamp = timestamp
  return parsed


def print_args(args):
  """Log every attribute of `args` as `name: value`, one log line each."""
  log.info("Arguments:")
  for name, value in vars(args).items():
    log.info(f"{name}: {value}")


# Script entry point: parse the CLI arguments, log them, then run the pipeline.
# The relative imports at the top of this file mean it has to be executed as a
# module inside its package (python -m ...), not as a bare script path.
if __name__ == "__main__":
  args = parse_args()
  print_args(args)
  main(args)
