"""XAI hooks: source or data handlers, iterators and IO handlers."""
__author__ = 'XYZ'

import pdb
import os
import sys

from argparse import Namespace
from pathlib import Path

from ..core._log_ import logger
# Module-level logger obtained from the project's logging helper (imported from ..core._log_).
log = logger(__file__)

# This module object itself; gethook() resolves registries (e.g. SOURCE_HANDLERS)
# by attribute name on it via getattr().
this = sys.modules[__name__]


def gethook(name, registry_name):
  """Return the callable registered as ``name`` in the module-level dict ``registry_name``.

  Args:
    name: key of the hook inside the registry.
    registry_name: attribute name of a dict defined on this module
      (for example ``'SOURCE_HANDLERS'``).

  Raises:
    KeyError: if ``registry_name`` is not a dict attribute of this module,
      or if ``name`` is not a key of that dict.
    TypeError: if the value stored under ``name`` is not callable.
  """
  table = getattr(this, registry_name, None)
  if not isinstance(table, dict):
    raise KeyError(f"Registry '{registry_name}' not found in hooks.")

  if name in table:
    hook = table[name]
    if callable(hook):
      return hook
    raise TypeError(f"Hook '{name}' in registry '{registry_name}' is not callable.")

  # Unknown key: report what *is* available to ease debugging of typos.
  available = list(table.keys())
  raise KeyError(f"Hook '{name}' not found in registry '{registry_name}'. Available: {available}")


## -----------------------------
## Registries
## -----------------------------
# Maps a source key (e.g. 'filesystem', 'dataset') to its handler function.
SOURCE_HANDLERS = {}


def register_source_handler(key):
  """Build a decorator that records the decorated function in SOURCE_HANDLERS under ``key``.

  The decorated function is returned unchanged, so it remains callable by its
  own name. Registering another function under an existing ``key`` replaces
  the previous entry.
  """
  def _register(handler):
    SOURCE_HANDLERS.update({key: handler})
    return handler

  return _register


## -----------------------------
## Source Handlers
## -----------------------------
@register_source_handler('filesystem')
def handle_filesystem_source(args):
  """Return an iterator over the inputs found under ``args.from_path``.

  Args:
    args: argparse-style namespace, or a plain dict (converted to a
      ``Namespace``, matching the 'dataset' handler), providing
      ``from_path``, ``sample_limit`` and ``shuffle``.

  Returns:
    ``iter()`` over whatever ``get_input_files`` returns (presumably a list of
    file paths — confirm against ``core.fro``).
  """
  from ..core.fro import get_input_files

  # Handlers are looked up from the same registry, so accept the same config
  # shapes as the 'dataset' handler: a dict is normalized to a Namespace.
  if isinstance(args, dict):
    args = Namespace(**args)

  return iter(get_input_files(
    args.from_path,
    sample_limit=args.sample_limit,
    shuffle=args.shuffle,
  ))


@register_source_handler('dataset')
def dataset_source_and_input(config):
  """Unified dataset hook for both input and source registration (paths iterator).

  Args:
    config: argparse-style namespace, or a plain dict (converted to a
      ``Namespace``). Must provide ``dataset`` and ``datasetcfg``; may provide
      ``dataset_split`` and/or ``split``, plus whatever ``load_dataset`` reads.

  The split passed to ``load_dataset`` is ``config.dataset_split`` when that
  attribute exists; otherwise ``config.split[0]`` when ``split`` exists and is
  non-empty; otherwise ``"test"``.

  Returns:
    ``iter(dataset.imgs)`` for the loaded split, so ``for image_path in source:``
    works the same as for the filesystem handler.
  """
  from ..dataset import get_splits, load_dataset

  ## normalize
  args_ns = Namespace(**config) if isinstance(config, dict) else config

  # Dataset root comes from the environment; None when __DATASET_ROOT__ is unset.
  dataset_root = os.getenv('__DATASET_ROOT__')
  splits = get_splits(args_ns.dataset, args_ns.datasetcfg, dataset_root)

  # Resolve the split lazily. A getattr() default is evaluated eagerly, so the
  # previous form indexed `split[0]` even when `dataset_split` was set, and
  # raised on an empty or None `split`.
  if hasattr(args_ns, "dataset_split"):
    dataset_split = args_ns.dataset_split
  else:
    split = getattr(args_ns, "split", None)
    # NOTE(review): assumes `split` is a sequence of split names; a plain
    # string would yield only its first character here — confirm with callers.
    dataset_split = split[0] if split else "test"

  # load_dataset returns a 4-tuple; only the dataset object (2nd item) is used.
  _, dataset, _, _ = load_dataset(args_ns, splits, flag=dataset_split)

  ## Always return PATHS for both input_source and source_handler.
  ## Per the original author, dataset.imgs is a list of file paths; iter()
  ## makes the iterator contract explicit.
  return iter(dataset.imgs)
