import os
import tempfile
import wandb
import json
import numbers
from collections.abc import Mapping, Sequence

def _jsonify(x):
    if x is None or isinstance(x, (bool, str, int, float)):
        return x
    try:
        import numpy as _np
        if isinstance(x, _np.generic):
            return x.item()
    except Exception:
        pass
    
    if isinstance(x, (tuple, list, set)):
        return [_jsonify(v) for v in x]
    
    if isinstance(x, Mapping):
        return {str(k): _jsonify(v) for k, v in x.items()}
    
    try:
        from ml_collections.config_dict import ConfigDict
        if isinstance(x, ConfigDict):
            return _jsonify(x.to_dict())
    except Exception:
        pass
    
    if hasattr(x, "__dict__"):
        try:
            return _jsonify(vars(x))
        except Exception:
            return str(x)
        

def flags_to_dict(FLAGS):
    """Extract absl flag values as a JSON-friendly dict, best-effort.

    Strategies are tried in order; the first one that produces a result wins:

    1. ``FLAGS.flag_values_dict()`` if the object provides it and its
       converted result is a dict.
    2. ``FLAGS.get_key_flags()``: read each returned flag's value back from
       ``FLAGS`` by name, recording ``str(flag)`` when the read fails. Used
       only when it yields at least one entry.
    3. Scan ``dir(FLAGS)`` for public attributes that are neither callable
       nor modules.

    Each strategy deliberately swallows its own errors so that config logging
    never aborts a run. All values are coerced with ``_jsonify``.

    Args:
        FLAGS: an absl ``FlagValues`` instance or any flag-like object.

    Returns:
        A dict mapping flag names to JSON-friendly values (possibly empty).
    """
    import types  # local: only needed for the strategy-3 module filter

    # 1) Best case: FlagValues.flag_values_dict().
    flag_values_dict = getattr(FLAGS, "flag_values_dict", None)
    if flag_values_dict is not None:
        try:
            result = _jsonify(flag_values_dict())
        except Exception:
            result = None
        # Callers index into the result as a dict, so only accept a dict.
        if isinstance(result, dict):
            return result

    # 2) get_key_flags(): Flag objects; read each value back from FLAGS.
    try:
        from_key_flags = {}
        for flag in FLAGS.get_key_flags():
            try:
                from_key_flags[flag.name] = getattr(FLAGS, flag.name)
            except Exception:
                # Some flags (esp. config files) can't be read back; keep a
                # readable placeholder instead of dropping them.
                from_key_flags[flag.name] = str(flag)
        if from_key_flags:
            return _jsonify(from_key_flags)
    except Exception:
        pass

    # 3) Last resort: scan public attribute names (avoids absl internals).
    guessed = {}
    for name in dir(FLAGS):
        if name.startswith("_"):
            continue
        try:
            val = getattr(FLAGS, name)
        except Exception:
            continue
        # Skip methods/functions and modules. Instances of Python-defined
        # classes (ConfigDict, Enum members, ...) are kept; _jsonify knows
        # how to coerce them.
        if callable(val) or isinstance(val, types.ModuleType):
            continue
        guessed[name] = val
    return _jsonify(guessed)
    

def serialize_config_for_wandb(flags, extra: dict):
    """Assemble the config dict logged to W&B.

    Contents:
      * every absl flag (via ``flags_to_dict``) at the top level;
      * ``"algo_config"``: ``flags.config`` converted with ``to_dict()`` or,
        failing that, ``dict()``; ``{}`` when neither works (e.g. there is no
        ``config`` flag);
      * ``"algo_kwargs"``: ``extra``, only when it is truthy.

    ``hidden_dims`` / ``g_hidden_dims`` / ``d_hidden_dims`` are normalized to
    lists when they hold a non-string sequence; other values (``None``,
    strings, scalars) are left untouched.

    Args:
        flags: absl ``FlagValues`` (or a flag-like object).
        extra: optional extra algorithm kwargs; may be empty or ``None``.

    Returns:
        A JSON-friendly dict suitable for ``wandb.init(config=...)``.
    """
    # Base: all absl flags.
    cfg = flags_to_dict(flags)

    # The ml_collections config, if present.
    algo_cfg = {}
    try:
        algo_cfg = flags.config.to_dict()
    except Exception:
        try:
            algo_cfg = dict(flags.config)
        except Exception:
            pass  # no usable `config` flag; log an empty dict
    cfg["algo_config"] = _jsonify(algo_cfg)

    if extra:
        cfg["algo_kwargs"] = _jsonify(extra)

    for name in ("hidden_dims", "g_hidden_dims", "d_hidden_dims"):
        value = cfg.get(name)
        # Only re-wrap real sequences: list("256,256") would split a string
        # flag into characters and list(256) would raise TypeError.
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            cfg[name] = list(value)

    return cfg


def init_wandb_or_disable(flags, fname, *, model_name: str, extra_cfg: dict):
    """Initialize a W&B run with a serialized config and sensible defaults.

    - ``mode`` comes from ``flags.wandb_mode`` but is forced to
      ``'disabled'`` when it is ``'online'`` and ``WANDB_API_KEY`` is not set
      in the environment.
    - When ``WANDB_API_KEY`` is set, run files go to a fresh temporary
      directory (``tempfile.mkdtemp``) to avoid cluttering the working tree.
    - ``project`` defaults to ``flags.env_name`` (or ``'Debug'`` when
      ``flags.debug``) unless ``flags.wandb_project`` is set.
    - ``group`` defaults to ``model_name`` unless ``flags.wandb_group`` is set.
    - A ``model:<model_name>`` tag is added unless already present.

    Args:
        flags: absl ``FlagValues`` providing the ``wandb_*``, ``env_name``,
            ``debug``, ``video_interval``, ``save_dir`` and ``save_video``
            flags.
        fname: not used by this function; kept for call-site compatibility.
        model_name: stored in the config and used as the default group/tag.
        extra_cfg: extra algorithm kwargs recorded under ``"algo_kwargs"``.

    Returns:
        The run object returned by ``wandb.init``.

    Raises:
        ValueError: if ``flags.wandb_run_name`` is ``None``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if flags.wandb_run_name is None:
        raise ValueError("Please provide a wandb_run_name for easier tracking")
    run_name = flags.wandb_run_name

    # Accept list/tuple tags, or a comma-separated string (absl list syntax).
    raw_tags = flags.wandb_tags
    if isinstance(raw_tags, (list, tuple)):
        tags = list(raw_tags)
    elif isinstance(raw_tags, str):
        tags = [t.strip() for t in raw_tags.split(",") if t.strip()]
    else:
        tags = []

    model_tag = f"model:{model_name}"
    if model_tag not in tags:
        tags.append(model_tag)

    # NOTE(review): a user authenticated via `wandb login` (~/.netrc) without
    # the env var is still treated as keyless here - confirm this is intended.
    has_api_key = "WANDB_API_KEY" in os.environ

    # The temp dir is intentionally not removed here: the run writes into it
    # for its whole lifetime.
    wandb_output_dir = tempfile.mkdtemp(prefix="wandb_") if has_api_key else None

    # Respect the requested mode, but don't attempt 'online' without a key.
    mode = flags.wandb_mode
    if mode == "online" and not has_api_key:
        mode = "disabled"

    if flags.wandb_project is not None:
        project = flags.wandb_project
    elif flags.debug:
        project = "Debug"
    else:
        project = flags.env_name

    group = flags.wandb_group or model_name

    cfg = serialize_config_for_wandb(flags, extra=extra_cfg)
    cfg.update({
        "video": {
            "interval": int(flags.video_interval),
            "eval_save_root": os.path.join(flags.save_dir, 'video', 'eval_videos'),
            "save_video": bool(flags.save_video),
        },
        "model_name": model_name,
    })

    run = wandb.init(
        project=project,
        entity=flags.wandb_entity or None,
        group=group,
        name=run_name,
        mode=mode,
        tags=tags,
        config=cfg,
        dir=wandb_output_dir,
        job_type="train",
    )

    # Snapshot .py/.sh sources; deliberately best-effort (e.g. may not be
    # supported by a disabled run), so failures never abort training.
    try:
        run.log_code(root=".", include_fn=lambda path: path.endswith((".py", ".sh")))
    except Exception:
        pass

    return run

