import torch
from pprint import pprint as print


def tensor_contains_nan(tensor):
  """Return a boolean 0-d tensor that is True if any element of ``tensor`` is NaN."""
  nan_mask = torch.isnan(tensor)
  return nan_mask.any()


def tensor_contains_large(tensor, threshold=1e8):
  """Return a boolean 0-d tensor: True if any |element| is strictly above ``threshold``.

  NaN elements never count as large (NaN comparisons are False).
  """
  magnitude = tensor.abs()
  return (magnitude > threshold).any()


def tensor_break_on_large(tensor, threshold=1e3):
  """Enter the pdb debugger if any |element| of ``tensor`` exceeds ``threshold``.

  Note: the default threshold here (1e3) is far lower than the 1e8 default
  of ``tensor_contains_large``.
  """
  if not tensor_contains_large(tensor, threshold):
    return
  import pdb
  pdb.set_trace()


def net_print_net_params(net):
  """Pretty-print each named parameter of ``net``.

  For every parameter, emits its name, its value and a ``----------``
  separator, each via ``print`` (which is ``pprint.pprint`` in this module).
  """
  separator = "----------"
  for param_name, tensor in net.named_parameters():
    for item in (param_name, tensor, separator):
      print(item)


def net_print_nan_grad(net):
  """Report every parameter of ``net`` whose gradient contains a NaN.

  Parameters without a gradient (``grad is None``) are skipped.
  """
  for pname, p in net.named_parameters():
    grad = p.grad
    if grad is None:
      continue
    if torch.isnan(grad).any():
      print(f"NaN gradient found in parameter: {pname}")


def net_break_on_nan(net):
  """Enter the pdb debugger once if any parameter of ``net`` contains NaN.

  Every offending parameter is printed (name, then value), followed by a
  single notice and the list of parameters whose gradients contain NaN;
  then one breakpoint is hit. Does nothing if no parameter contains NaN.

  Args:
    net: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
  """
  # Collect offenders first: once NaNs appear they usually spread to many
  # parameters, and breaking/re-reporting per parameter forces the user to
  # `continue` through a long chain of identical breakpoints.
  nan_params = [(name, param) for name, param in net.named_parameters()
                if tensor_contains_nan(param.data)]
  if not nan_params:
    return
  for name, param in nan_params:
    print(name)
    print(param)
  print("Model parameters contain NaN values. Breaking training loop.")
  net_print_nan_grad(net)
  import pdb
  pdb.set_trace()


def check_leaf_dims(d, ndim=2):
  """Recursively check that every tensor leaf of a nested dict has ``ndim`` dims.

  Leaves are expected to be either empty dicts (skipped) or tensors. Each
  key visited is printed, and every tensor whose ``ndim`` differs from the
  requirement is reported.

  Args:
    d: nested dict to inspect.
    ndim: required number of dimensions for every tensor leaf (default 2).
      The same requirement is applied at every nesting level.

  Returns:
    True if every tensor leaf, at any depth, has ``ndim`` dimensions;
    False otherwise. Values that are neither dicts nor tensors are
    reported but do not make the check fail.
  """
  ok = True
  for key, value in d.items():
    if isinstance(value, dict):
      print(key)
      # An empty dict is considered a leaf: nothing to check.
      if value:
        # Always recurse (no short-circuit) so every key gets reported, and
        # propagate ``ndim`` so nested levels use the caller's requirement.
        if not check_leaf_dims(value, ndim):
          ok = False
    elif isinstance(value, torch.Tensor):
      # ``print`` is ``pprint.pprint`` in this module; its second positional
      # argument is the output stream, so always pass a single string.
      print(f"{key}: {value.ndim}")
      if value.ndim != ndim:
        print(f"{key} has dim {value.ndim}")
        ok = False
    else:
      # Encountered a non-dict, non-tensor object - report it only.
      print(f"Unexpected type {type(value)} encountered in the dictionary.")
  return ok

def if_shape_equals(var_list):
  """Check that all non-None entries of ``var_list`` share the same ``.shape``.

  ``None`` entries are ignored wherever they appear, including the first
  position. An empty list, or one containing only ``None``, passes.

  Args:
    var_list: sequence of objects with a ``.shape`` attribute, or ``None``.

  Raises:
    AssertionError: if two non-None entries have different shapes. Raised
      explicitly (not via ``assert``) so the check survives ``python -O``.
  """
  shapes = [v.shape for v in var_list if v is not None]
  for s in shapes[1:]:
    if s != shapes[0]:
      raise AssertionError(f"shape mismatch: expected {shapes[0]}, got {s}")
