import torch

def style_inject(feat, styles, style_index=None, eps=1e-5):
    """Re-normalize ``feat`` so each channel carries donor statistics from ``styles``.

    The per-sample, per-channel mean and variance of ``feat`` are computed over
    every dimension after the first two (population variance, ``unbiased=False``).
    They are removed, and the donor variance/mean from ``styles`` are applied in
    their place (AdaIN-style re-normalization). Donor statistics are detached,
    so no gradient flows back into them.

    Args:
        feat: Feature tensor. Its statistics are reshaped to ``(B, C, 1, 1)``
            before broadcasting, so this presumably expects a 4-D
            ``(B, C, H, W)`` input -- TODO confirm for other ranks.
        styles: ``(style_var, style_mu)`` pair, indexed along dim 0
            (e.g. the pair returned by ``get_styles``).
        style_index: Optional index used to select donor statistics from
            ``styles``. When ``None``, donors are drawn at random,
            independently for the variance and the mean.
        eps: Added to both variances before taking the square root.

    Returns:
        The normalized features rescaled by ``sqrt(style_var + eps)`` and
        shifted by ``style_mu``.

    Raises:
        NotImplementedError: If ``styles`` is ``None``.
    """
    if styles is None:
        # `NotImplemented` is a sentinel value, not an exception class; raising
        # it yields a TypeError. Raise the proper exception type instead.
        raise NotImplementedError("style_inject requires a (style_var, style_mu) pair")

    style_var, style_mu = styles
    style_var = style_var.detach()
    style_mu = style_mu.detach()

    var = torch.var(feat.flatten(2), dim=2, unbiased=False, keepdim=True)
    mu = torch.mean(feat.flatten(2), dim=2, keepdim=True)

    # (B, C, 1) -> (B, C, 1, 1) so the stats broadcast over the trailing dims.
    var = var.unsqueeze(2)
    mu = mu.unsqueeze(2)

    # TODO: Is sampling independent indices for var and mu better than sharing one index?
    if style_index is None:
        if feat.shape[0] > style_var.shape[0]:
            # More features than donors: a permutation wrapped with modulo
            # reuses donors so every feature still gets one.
            random_index1 = torch.randperm(feat.shape[0]) % style_var.shape[0]
            random_index2 = torch.randperm(feat.shape[0]) % style_mu.shape[0]
        else:
            # Enough donors: pick a random subset without replacement.
            random_index1 = torch.randperm(style_var.shape[0])[:feat.shape[0]]
            random_index2 = torch.randperm(style_mu.shape[0])[:feat.shape[0]]
        style_var = style_var[random_index1, :]
        style_mu = style_mu[random_index2, :]

    else:
        style_var = style_var[style_index, :]
        style_mu = style_mu[style_index, :]

    stylized_feat = ((feat - mu) / (var + eps).sqrt()) * (style_var + eps).sqrt() + style_mu

    return stylized_feat

def normalize_style(feat, eps=1e-5):
    """Strip per-sample, per-channel statistics from ``feat``.

    Subtracts the mean and divides by ``sqrt(var + eps)``. Both statistics are
    computed over every dimension after the first two (population variance,
    ``unbiased=False``) and reshaped to ``(B, C, 1, 1)`` before broadcasting,
    so this presumably expects a 4-D ``(B, C, H, W)`` tensor -- TODO confirm.
    """
    flat = feat.flatten(2)
    channel_var = torch.var(flat, dim=2, unbiased=False, keepdim=True)[..., None]
    channel_mu = torch.mean(flat, dim=2, keepdim=True)[..., None]
    centered = feat - channel_mu
    return centered / (channel_var + eps).sqrt()

def get_styles(feat):
    """Return the ``(var, mu)`` statistics of ``feat`` per sample and channel.

    Both are reduced over every dimension after the first two (the variance is
    the population variance, ``unbiased=False``) and shaped ``(B, C, 1, 1)``.
    The order matches the ``(style_var, style_mu)`` pair that ``style_inject``
    unpacks.
    """
    flattened = feat.flatten(start_dim=2)
    style_var = torch.var(flattened, dim=2, unbiased=False, keepdim=True).unsqueeze(-1)
    style_mu = torch.mean(flattened, dim=2, keepdim=True).unsqueeze(-1)
    return style_var, style_mu

def get_style_list(styles):
    """Group a flat indexable sequence into consecutive ``(styles[i], styles[i + 1])`` pairs.

    Elements at even positions become the first item of each pair. For a
    list or tuple of odd length, the final unpaired element triggers an
    ``IndexError``.
    """
    return [(styles[pos], styles[pos + 1]) for pos in range(0, len(styles), 2)]

