"""SAGE Orchestrator
Declarative multi-model orchestrator for Saliency Attribution for Goal-grounded Evaluation (SAGE).
"""
__author__ = 'XYZ'


import argparse
import ast
import os
import sys
import yaml

from datetime import datetime
from importlib import import_module

from pathlib import Path
from types import SimpleNamespace


from .core._log_ import logger
log = logger(__file__)

from .utils.torchapi import loadmodel, unloadmodel
from .dataset import get_splits


def resolve_weights_path(path):
  """Turn a configured weights path into a usable filesystem path.

  - Empty or ``None`` input -> ``None``.
  - Absolute paths are returned unchanged.
  - Relative paths are joined onto the ``__DATAHUB_ROOT__`` environment
    variable. If that variable is unset or empty, an error is logged and the
    process exits with status 1.
  """
  if path:
    if os.path.isabs(path):
      return path
    datahub_root = os.getenv("__DATAHUB_ROOT__")
    if datahub_root:
      return os.path.join(datahub_root, path)
    log.error("Environment variable __DATAHUB_ROOT__ not set")
    sys.exit(1)
  return None


def load_yaml_config(config_path):
  """Load the YAML experiment config.

  Args:
    config_path: path to the experiment YAML file.

  Returns:
    The parsed document. An empty file yields ``{}`` rather than ``None``, so
    callers can call ``.get`` on the result directly.

  Logs the error and exits with status 1 if the file cannot be read or parsed.
  """
  try:
    # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16 via BOM)
    # instead of relying on the platform's locale encoding.
    with open(config_path, "rb") as f:
      cfg = yaml.safe_load(f)
  except Exception as e:
    log.error(f"Failed to load YAML config: {e}")
    sys.exit(1)
  # safe_load returns None for an empty document.
  return {} if cfg is None else cfg


def load_backend_module(backend):
  """Import and return the backend submodule ``<this package>.<backend>``.

  The import is relative to this module's package (``__package__``). An
  ``ImportError`` is logged and the process exits with status 1.
  """
  relative_name = f".{backend}"
  try:
    backend_mod = import_module(relative_name, package=__package__)
    log.info(f"Loaded backend module: {backend}")
  except ImportError as err:
    log.error(f"Error loading backend {backend}: {str(err)}")
    sys.exit(1)
  return backend_mod


def build_model_lookup(models_cfg):
  """Build a ``model_id -> model spec`` dict from the ``models`` config section.

  Accepted layouts:

  * dict-of-dicts ``{model_id: {arch: ..., ...}}``: used as-is (shallow copy).
  * list-of-dicts ``[{id: ..., arch: ...}, ...]`` (older format): each entry is
    keyed by its ``id`` field.
  * ``None`` (e.g. an empty ``models:`` key): an empty lookup.

  Args:
    models_cfg: the ``models`` section of the experiment config.

  Returns:
    A new dict mapping model id to its spec.

  Logs an error and exits with status 1 in these cases:
    - a list entry that is not a mapping;
    - a list entry with a missing or falsy ``id``;
    - any other ``models_cfg`` type.
  """
  if models_cfg is None:
    return {}

  if isinstance(models_cfg, dict):
    # Already a mapping: {id: {spec}}
    return dict(models_cfg)

  if isinstance(models_cfg, list):
    lookup = {}
    for m in models_cfg:
      # Reject non-mapping entries explicitly instead of crashing on m.get().
      if not isinstance(m, dict):
        log.error(f"Model entry must be a mapping, got {type(m).__name__}: {m}")
        sys.exit(1)
      model_id = m.get("id")
      if not model_id:
        log.error(f"Model entry missing 'id': {m}")
        sys.exit(1)
      lookup[model_id] = m
    return lookup

  log.error(f"Unsupported models format: {type(models_cfg)}")
  sys.exit(1)


def _coerce_input_size(value):
  """Convert an ``input_size`` config value into a tuple.

  Lists and tuples are converted directly. A string such as ``"(224,224)"``
  (what YAML produces for an unquoted parenthesised pair) is first parsed
  with ``ast.literal_eval``; calling ``tuple()`` on the string itself would
  split it into characters.
  """
  if isinstance(value, str):
    value = ast.literal_eval(value)
  return tuple(value)


def load_model_from_spec(model_spec, args, base_logs, model_id=None):
  """Build a per-model args namespace from *model_spec* and load the model.

  Spec fields take precedence. When a field is absent, the value falls back
  to the CLI value on *args*:
    - ``dataset``
    - ``arch`` -> ``net``
    - ``weights_path`` (also used when the spec's value is falsy)
    - ``num_classes`` -> ``num_class``
    - ``input_size``

  Args:
    model_spec: dict describing one model from the ``models`` config section.
    args: parsed CLI namespace. It is shallow-copied, so the original is not
      modified.
    base_logs: root output directory. This model's outputs go to
      ``<base_logs>/<model_id>``, which is created if missing.
    model_id: sub-directory name. Defaults to the spec's ``id`` field, or
      ``"unknown_model"``.

  Returns:
    ``(args_obj, model, device, splits)``:
      - ``args_obj``: the per-model namespace.
      - ``model``: the loaded model, moved to ``device``.
      - ``device``: ``"cuda"`` when ``gpu`` is set and CUDA is available,
        else ``"cpu"``.
      - ``splits``: whatever ``get_splits`` returns for the dataset.
  """
  args_obj = SimpleNamespace(**vars(args))
  args_obj.dataset = model_spec.get("dataset", args.dataset)
  args_obj.net = model_spec.get("arch", args.net)
  args_obj.weights_path = resolve_weights_path(model_spec.get("weights_path") or args.weights_path)
  args_obj.num_class = model_spec.get("num_classes", args.num_class)
  args_obj.input_size = _coerce_input_size(model_spec.get("input_size", args.input_size))

  # Tag the output directory with the model id. YAML mapping keys may be ints
  # (e.g. ``1: {...}``), and os.path.join requires str.
  if model_id is None:
    model_id = model_spec.get("id", "unknown_model")
  args_obj.to_path = os.path.join(base_logs, str(model_id))
  Path(args_obj.to_path).mkdir(parents=True, exist_ok=True)

  # torch is only imported here when a GPU is requested (short-circuit).
  device = "cuda" if args_obj.gpu and __import__("torch").cuda.is_available() else "cpu"
  dataset_root = os.getenv("__DATASET_ROOT__")
  splits = get_splits(args_obj.dataset, args_obj.datasetcfg, dataset_root)

  model = loadmodel(args_obj)
  return args_obj, model.to(device), device, splits



def _inject_model(context, model_id, model_lookup, args, base_logs):
  """Swap the model in *context* for ``model_lookup[model_id]`` and set per-model keys.

  Logs an error and exits with status 1 if *model_id* is not in *model_lookup*.
  """
  if not model_lookup or model_id not in model_lookup:
    log.error(f"Model id {model_id} not found in config")
    sys.exit(1)
  if context.get("model") is not None:
    unloadmodel(context["model"])
    # Drop the stale reference so it is not reused or unloaded again if the
    # next load fails.
    context["model"] = None

  model_spec = model_lookup[model_id]
  args_obj, model, device, splits = load_model_from_spec(model_spec, args, base_logs, model_id=model_id)

  # Backends read the per-model args namespace from context["args"].
  context["args"] = args_obj
  context["model"] = model
  context["device"] = device
  context["splits"] = splits
  context["model_id"] = model_id
  context["to_path"] = args_obj.to_path

  # Canonical output locations expected by the saliency backend, placed under
  # the per-model directory <base_logs>/<model_id>.
  model_base = Path(args_obj.to_path)
  saliency_dir = model_base / "saliency"
  context["out_root"] = str(saliency_dir)       # used by saliency._write_indices_and_manifest etc.
  context["saliency_dir"] = str(saliency_dir)   # friendly alias
  context["groups_base_dir"] = str(model_base)  # used by _collect_groups_from_context as base for groups
  saliency_dir.mkdir(parents=True, exist_ok=True)

  # Containers generate_saliency expects. They are created only if absent and
  # are NOT reset when the model changes, so contents accumulate across models.
  # NOTE(review): confirm cross-model accumulation is intended.
  context.setdefault("_indices_maps", [])
  context.setdefault("_indices_tokens", [])
  context.setdefault("_stats", {"by_group": {}, "total_images": 0})

  log.info(f"Injected model into context: model_id={model_id}, to_path={args_obj.to_path}, saliency_dir={saliency_dir}")


def _resolve_step_function(fn, stage_name):
  """Return ``(module_name, func_name, func)`` for a step's ``fn`` string.

  ``fn`` can take two forms:
    - ``"module.func"`` (or ``"pkg.module.func"``) names the backend module
      explicitly.
    - A bare ``"func"`` is looked up in the backend module named after
      *stage_name*.

  Logs an error and exits with status 1 if the function does not exist.
  """
  module_name, sep, func_name = fn.rpartition(".")
  if not sep:
    module_name = stage_name

  log.debug(f"About to load backend module '{module_name}' for function '{func_name}'")
  mod = load_backend_module(module_name)

  if not hasattr(mod, func_name):
    log.error(f"Function `{func_name}` not found in module `{module_name}`")
    sys.exit(1)
  return module_name, func_name, getattr(mod, func_name)


def execute_function(step, context, stage_name=None, model_lookup=None, args=None, base_logs=None):
  """Execute a single flowplan step and return the backend's result.

  Args:
    step: mapping with ``fn`` and optional ``params``.
      - ``fn`` is ``"func"`` or ``"module.func"``.
      - ``params`` holds keyword arguments for the function. A ``model``
        entry in it is not forwarded; that model is loaded and injected into
        *context* first.
    context: run-state dict passed as the first argument to the function.
    stage_name: backend module used when ``fn`` has no module prefix.
    model_lookup: ``model_id -> spec`` mapping.
    args: parsed CLI namespace, used as defaults when loading a model.
    base_logs: root output directory for per-model outputs.

  Returns:
    Whatever the backend function returns (expected to be the context).

  Logs an error and exits with status 1 on any of:
    - a missing or non-string ``fn``;
    - an unknown model id;
    - a missing function;
    - any exception raised by the function.
  """
  fn = step.get("fn")
  if not isinstance(fn, str) or not fn:
    log.error(f"Step in stage `{stage_name}` has no valid 'fn': {step}")
    sys.exit(1)

  # Copy so popping "model" doesn't mutate the config (also kept in
  # context["cfg"]); otherwise re-running the same step dict, e.g. via a
  # YAML alias, would silently lose its model. `or {}` handles `params: null`.
  params = dict(step.get("params") or {})

  model_id = params.pop("model", None)
  if model_id is not None:
    _inject_model(context, model_id, model_lookup, args, base_logs)

  module_name, func_name, func = _resolve_step_function(fn, stage_name)
  log.info(f"Executing: {module_name}.{func_name} with params={params}")

  try:
    return func(context, **params)
  except Exception as e:
    import traceback  # only needed on the failure path
    log.error(f"Execution failed in {module_name}.{func_name}: {e}")
    # Keep the stack trace; sys.exit would otherwise discard it.
    log.error(traceback.format_exc())
    sys.exit(1)


def run_stage(stage_name, stage_steps, run_flags, context, model_lookup, args, base_logs):
  """Run every step of one flowplan stage if it is enabled; return the context.

  A stage runs only when ``run_flags[stage_name]`` is truthy. Otherwise it is
  logged as skipped and *context* is returned unchanged. Each step receives
  the context returned by the previous step.

  Null values are tolerated:
    - ``run_flags`` of ``None`` (e.g. an empty ``run:`` key) disables every
      stage.
    - ``stage_steps`` of ``None`` (e.g. an empty stage key) runs nothing.
  """
  if not (run_flags or {}).get(stage_name, False):
    log.info(f"Skipping stage: {stage_name}")
    return context

  log.info(f"Running stage: {stage_name}")
  for step in stage_steps or ():
    context = execute_function(step, context, stage_name,
                               model_lookup=model_lookup, args=args, base_logs=base_logs)
  return context


def main(args):
  """Run the SAGE experiment described by the YAML config at ``args.config``.

  Config keys read:
    - ``id``: experiment id; defaults to a timestamped ``sage-exp-*`` name.
    - ``run``: stage name -> enable flag.
    - ``flowplan``: list of ``{stage_name: [steps...]}`` mappings, executed in
      order.
    - ``models``: model specs, resolved via ``build_model_lookup``.

  Missing or null sections are treated as empty. Outputs go under
  ``args.to_path``, or a timestamped ``logs/sage-*`` directory when that is
  empty. The last loaded model is unloaded at the end.
  """
  cfg = load_yaml_config(args.config) or {}
  # One timestamp so the default experiment id and log directory agree.
  stamp = datetime.now().strftime('%d%m%y_%H%M%S')
  exp_id = cfg.get("id", f"sage-exp-{stamp}")
  log.info(f"Starting SAGE experiment: {exp_id}")

  base_logs = args.to_path or os.path.join("logs", f"sage-{stamp}")
  Path(base_logs).mkdir(parents=True, exist_ok=True)

  # `or` guards against keys present in YAML but left null.
  run_flags = cfg.get("run") or {}
  flowplan = cfg.get("flowplan") or []
  model_lookup = build_model_lookup(cfg.get("models") or [])

  context = {
    "exp_id": exp_id,
    "cfg": cfg,
    "base_dir": base_logs,
  }

  for stage in flowplan:
    if not isinstance(stage, dict) or not stage:
      log.error(f"Invalid flowplan entry (expected a mapping of stage -> steps): {stage}")
      sys.exit(1)
    # Run every stage in the entry, in file order; previously all but the
    # first key were silently dropped.
    for stage_name, stage_steps in stage.items():
      context = run_stage(stage_name, stage_steps, run_flags, context,
                          model_lookup=model_lookup, args=args, base_logs=base_logs)

  if context.get("model") is not None:
    unloadmodel(context["model"])

  log.info(f"SAGE experiment {exp_id} completed successfully.")


def parse_args():
  """Parse command-line arguments for the SAGE orchestrator.

  ``--input_size`` is parsed with ``ast.literal_eval``. It must be a tuple
  or list of exactly two integers (e.g. ``'(224,224)'``) and is stored as a
  tuple. An invalid value is logged and the process exits with status 1.

  ``--gpu`` and ``--cache`` default to on. They can be turned off with
  ``--no_gpu`` and ``--no_cache``.
  """
  parser = argparse.ArgumentParser(description="SAGE Orchestrator")

  parser.add_argument('--from', dest='from_path', type=str, help="Input path(s), comma-separated")
  parser.add_argument('--to', dest='to_path', type=str, default=os.path.join("logs", f"sage-{datetime.now().strftime('%d%m%y_%H%M%S')}"), help="Output directory")
  parser.add_argument('--config', type=str, required=True, help="Path to experiment YAML config")

  ## Dataset-related
  parser.add_argument('--dataset', type=str, help="Dataset ID override")
  parser.add_argument('--datasetcfg', type=str, default="data/ddd-datasets.yml", help="Dataset config file")
  parser.add_argument('--split', type=str, default="test", help="Dataset split (train/val/test)")

  ## Model-related
  parser.add_argument('--weights_path', type=str, help="Model weights path override")
  parser.add_argument('--net', type=str, help="Model architecture override")
  parser.add_argument('--pretrain', action='store_true', default=False, help="Use pretrained weights")
  parser.add_argument('--num_class', type=int, help="Number of classes")

  ## Training/eval parameters
  parser.add_argument('--loss', type=str, default="CrossEntropyLoss")
  parser.add_argument('--input_size', type=str, default="(224,224)")
  parser.add_argument('--batch_size', type=int, default=64)
  parser.add_argument('--num_workers', type=int, default=4)
  parser.add_argument('--sample_limit', type=int, default=None)
  # store_true with default=True can never be switched off, so a
  # store_false twin sharing the same dest provides the "off" switch.
  parser.add_argument('--gpu', action='store_true', default=True, help="Use CUDA when available (default: on)")
  parser.add_argument('--no_gpu', dest='gpu', action='store_false', help="Force CPU execution")
  parser.add_argument('--num_iterations', type=int, default=100)

  parser.add_argument('--cache', action='store_true', default=True, help="Enable application-level caching")
  parser.add_argument('--no_cache', dest='cache', action='store_false', help="Disable application-level caching")

  ## which parts to emit to disk (keeps compatibility with standalone saliency CLI)
  parser.add_argument('--parts', nargs='+', default=['whole', 'fg', 'bg'],
                      help='Which parts to emit (default: whole fg bg)')

  args = parser.parse_args()

  ## normalize input_size
  try:
    size = ast.literal_eval(args.input_size)
    # bool is a subclass of int; exclude it so '(True, False)' is rejected.
    if (not isinstance(size, (tuple, list)) or len(size) != 2
        or not all(isinstance(v, int) and not isinstance(v, bool) for v in size)):
      raise ValueError
    args.input_size = tuple(size)
  except (ValueError, TypeError, SyntaxError):
    # literal_eval raises TypeError for inputs such as unhashable dict keys.
    log.error("Error: --input_size should be a tuple of two integers, e.g., '(224,224)'")
    sys.exit(1)

  return args


def print_args(args):
  """Print an ``Arguments:`` header, then one ``name: value`` line per attribute of *args*."""
  lines = ["Arguments:"]
  lines.extend(f"{name}: {value}" for name, value in vars(args).items())
  print("\n".join(lines))


if __name__ == "__main__":
  # Script entry point: parse the CLI, echo it, then run the experiment.
  cli_args = parse_args()
  print_args(cli_args)
  main(cli_args)
