""" Distributed training/validation utils

Author:  
Adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/distributed.py
"""

import re
import torch
from torch import distributed as dist
from timm.utils.model import unwrap_model


def reduce_tensor(tensor, n, reduce="sum"):
    """Sum ``tensor`` over every rank of the default process group.

    The input is cloned before the collective, so the caller's tensor is not
    modified.

    Args:
        tensor: local tensor contribution from this rank.
        n: divisor applied when ``reduce == "mean"`` (presumably the world
            size -- confirm at call sites).
        reduce: ``"mean"`` divides the summed result by ``n`` in place; any
            other value returns the plain sum.

    Returns:
        A new tensor holding the reduced value.
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    if reduce != "mean":
        return reduced
    # In-place true division, same as ``reduced /= n``.
    reduced.div_(n)
    return reduced


def distribute_stats(model, stats, world_size, reduce=False):
    """Synchronise selected model buffers across the default process group.

    A buffer is selected when its fully qualified name, as yielded by
    ``named_buffers(recurse=True)`` on ``unwrap_model(model)``, matches from
    the start the regex ``<stat>.\\d+.*`` for at least one entry of ``stats``.
    Each ``stat`` is interpolated unescaped, so it acts as a regex fragment, and
    the ``.`` after it matches any single character.

    Args:
        model: model whose buffers are synchronised. It is passed through
            ``unwrap_model`` before its buffers are listed.
        stats: iterable of buffer-name prefixes / regex fragments to select.
        world_size: number of ranks; divisor for the average when ``reduce``
            is True.
        reduce: if True, replace each selected buffer with its mean over all
            ranks (in place); otherwise overwrite it with rank 0's copy.

    Every selected buffer triggers one collective call, so all ranks must call
    this with the same model structure and ``stats`` for the collectives to line
    up.
    """
    # Compile once, outside the buffer loop. Materialising a list also means a
    # one-shot iterator passed as ``stats`` is not exhausted after the first
    # buffer. The raw f-string keeps ``\d`` a regex escape rather than an
    # invalid Python string escape (SyntaxWarning on Python 3.12+).
    patterns = [re.compile(rf"{stat}.\d+.*") for stat in stats]

    for buf_name, buf in unwrap_model(model).named_buffers(recurse=True):
        if not any(pattern.match(buf_name) for pattern in patterns):
            continue
        if reduce:
            # Average across the whole group: sum in place, then divide.
            dist.all_reduce(buf, op=dist.ReduceOp.SUM)
            buf /= float(world_size)
        else:
            # Overwrite every rank's copy with rank 0's values.
            dist.broadcast(buf, 0)
