import inspect
import os
import random
from typing import Iterable

import numpy as np
import torch


def set_seed(seed=123, verbose=True):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Seeds the ``random`` module, NumPy's global RNG (``np.random.seed``) and
    PyTorch on the CPU and on every CUDA device. It also forces cuDNN into
    deterministic, non-benchmarking mode and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to every generator.
        verbose: If True, print the calling function's name and the seed.

    Note:
        ``PYTHONHASHSEED`` is only read when an interpreter starts. Setting it
        here does NOT change ``hash()`` randomisation in the current process.
        It only affects Python subprocesses started afterwards, which inherit
        ``os.environ``.
    """
    if verbose:
        here = inspect.currentframe()
        # currentframe() returns None on interpreters without stack-frame
        # support; degrade to a placeholder instead of raising AttributeError.
        caller_frame = here.f_back if here is not None else None
        caller = caller_frame.f_code.co_name if caller_frame is not None else "<unknown>"
        del here, caller_frame  # release frame references promptly
        print(f"{caller}: setting {seed=}")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; this is a no-op when
    # CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Only effective for child processes (see Note above).
    os.environ['PYTHONHASHSEED'] = str(seed)


def get_default_device():
    """Return ``"cuda"`` when PyTorch reports a usable CUDA device, else ``"cpu"``."""
    if not torch.cuda.is_available():
        return "cpu"
    return "cuda"


def print_model_parameter_count(model):
    """Print per-parameter element counts of *model*, split by ``requires_grad``.

    Three sections are printed in order:
    1. trainable parameters, each with its element count, then a subtotal;
    2. non-trainable parameters, in the same layout;
    3. the total element count over ``model.parameters()``.

    Args:
        model: Object exposing ``named_parameters()`` and ``parameters()``,
            e.g. a ``torch.nn.Module``.
    """
    def emit_rows(rows):
        # One line per parameter: name padded to 30 columns, count with separators.
        for label, count in rows:
            print(f"{label:30s}  |  {count:,d}")

    counted = [(label, tensor.numel(), tensor.requires_grad)
               for label, tensor in model.named_parameters()]
    trainable = [(label, count) for label, count, grad in counted if grad]
    frozen = [(label, count) for label, count, grad in counted if not grad]
    trainable_total = sum(count for _, count in trainable)
    frozen_total = sum(count for _, count in frozen)

    print("\n")

    print("=== Trainable parameters ===")
    emit_rows(trainable)
    print(f"\n→ Total trainable parameters: {trainable_total:,d}\n\n")

    print("=== Non-trainable parameters ===")
    emit_rows(frozen)
    print(f"\n→ Total non-trainable parameters: {frozen_total:,d}\n\n")

    print("=== Total: all parameters ===")
    grand_total = sum(tensor.numel() for tensor in model.parameters())
    print(f"\n→ All parameters: {grand_total:,d}")


def get_arg_names(fn):
    """Return the names of *fn*'s explicitly named parameters, in declaration order.

    Positional-only, positional-or-keyword and keyword-only parameters are
    included. The variadic ``*args`` / ``**kwargs`` collectors are excluded.

    Args:
        fn: Any callable accepted by :func:`inspect.signature`.

    Returns:
        list[str]: The parameter names.

    Raises:
        ValueError, TypeError: Propagated from :func:`inspect.signature` when
            no signature can be obtained for *fn*.
    """
    # Built once per call instead of via a lambda re-evaluated per parameter.
    named_kinds = {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    return [p.name for p in inspect.signature(fn).parameters.values() if p.kind in named_kinds]


def _named_arrays(*names, include_globals=True):
    # allow named_arrays(["A","B"]) or named_arrays(("A","B"))
    if len(names) == 1 and isinstance(names[0], Iterable) and not isinstance(names[0], (str, bytes)):
        names = tuple(names[0])

    # Start at the immediate caller, then walk up until we leave this module.
    frame = inspect.currentframe().f_back
    try:
        this_mod = __name__
        while frame and frame.f_globals.get('__name__') == this_mod:
            frame = frame.f_back

        scope = {}
        if frame:
            scope = dict(frame.f_locals)
            if include_globals:
                scope = {**frame.f_globals, **scope}

        # keep only string names present in scope
        return {k: scope[k] for k in names if isinstance(k, str) and k in scope}
    finally:
        del frame  # avoid reference cycles

def _print_shapes(**arrays):
    print("\nshapes:\n")
    for name, M in arrays.items():
        print(f"{name:<12}\t{getattr(M, 'shape', None)}")
    print("\n" + "-"*25)

def print_array_shapes(*names, include_globals=True):
    """Print the ``.shape`` of each named variable from the caller's scope.

    Names may be given as separate strings or as a single list/tuple of
    strings. Lookup is delegated to ``_named_arrays``. It checks the locals,
    and optionally the globals, of the first frame outside this module.
    Names that cannot be resolved are omitted from the printed table.
    """
    resolved = _named_arrays(*names, include_globals=include_globals)
    _print_shapes(**resolved)