import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

def is_dist_avail_and_initialized():
    """Report whether torch.distributed can be used right now.

    Returns:
        True only if the distributed package is available in this torch
        build AND a process group has been initialized; False otherwise.
    """
    # `and` short-circuits, so is_initialized() is only queried when the
    # distributed package is actually available.
    return dist.is_available() and dist.is_initialized()

def calculate_mu_sig(x: torch.Tensor, eps=1e-6):
    """Compute the mean and standard deviation of ``x`` along axis 1.

    Args:
        x: a rank-2 or rank-3 tensor.
        eps: added to the variance before the square root for numerical
            stability.

    Returns:
        ``(mu, sig)``, both with axis 1 kept as size 1 (``keepdim=True``) so
        they broadcast back against ``x``. ``sig`` uses torch's default
        (unbiased, Bessel-corrected) variance; if axis 1 has length 1 that
        variance is NaN.

    Raises:
        ValueError: if ``x`` is not rank 2 or rank 3.
    """
    ndim = x.dim()
    if ndim not in (2, 3):
        raise ValueError(f"input dim is {ndim}, which is invalid")

    # For rank-2 input the last axis (-1) *is* axis 1, so one reduction axis
    # covers both supported ranks.
    axis = 1
    mu = x.mean(dim=axis, keepdim=True)
    sig = x.var(dim=axis, keepdim=True).add(eps).sqrt()
    return mu, sig

def featureDestylization(feat: torch.Tensor,
                         cur_mu: torch.Tensor,
                         cur_sig: torch.Tensor,):
    """Remove the given statistics from ``feat``: ``(feat - cur_mu) / cur_sig``.

    Args:
        feat: input tensor.
        cur_mu: mean to subtract; must broadcast against ``feat``.
        cur_sig: standard deviation to divide by; must broadcast against
            ``feat``.

    Returns:
        A new tensor; ``feat`` is not modified in place.
    """
    centered = feat - cur_mu
    return centered / cur_sig

def style_transfer(feat: torch.Tensor,
                   cur_mu: torch.Tensor,
                   cur_sig: torch.Tensor,
                   trans_mu: torch.Tensor,
                   trans_sig: torch.Tensor,):
    """Swap the statistics of ``feat`` from (cur_mu, cur_sig) to (trans_mu, trans_sig).

    Computes ``(feat - cur_mu) / cur_sig * trans_sig + trans_mu``.

    Args:
        feat: input tensor.
        cur_mu, cur_sig: statistics to remove; must broadcast against ``feat``.
        trans_mu, trans_sig: statistics to apply; must broadcast against
            ``feat``.

    Returns:
        A new tensor; ``feat`` is not modified in place.
    """
    # First strip the current statistics via the shared helper...
    normalized = featureDestylization(feat, cur_mu, cur_sig)
    # ...then re-scale and re-shift with the target statistics.
    restyled = normalized * trans_sig + trans_mu
    return restyled

def freeze_model(m):
    """Turn off gradient tracking on ``m`` by calling ``m.requires_grad_(False)``.

    For an ``nn.Module`` this applies to all of its parameters. Returns None.
    """
    m.requires_grad_(requires_grad=False)
    return None

def hot_model(m):
    """Turn on gradient tracking on ``m`` by calling ``m.requires_grad_(True)``.

    For an ``nn.Module`` this applies to all of its parameters. Returns None.
    """
    m.requires_grad_(requires_grad=True)
    return None
