import torch
import torch.nn as nn
import numpy as np

from .primitives import PRIMITIVES


def running_stat(current_stat: nn.Parameter, batch_stat: torch.Tensor, momentum: float=0.1) -> torch.Tensor:
    """Blend a running statistic with a new batch statistic.

    Computes ``momentum * current_stat + (1 - momentum) * batch_stat``.

    NOTE(review): here ``momentum`` weights the *existing* running value, and the
    batch value receives ``1 - momentum``. This is the reverse of
    ``torch.nn.BatchNorm``'s convention, where ``momentum`` weights the new
    observation. With the default of 0.1, the batch statistic dominates (90%).
    Confirm that this is intended before relying on PyTorch-style semantics.

    Args:
        current_stat: The current running value.
        batch_stat: The statistic computed on the latest batch.
        momentum: Weight given to ``current_stat``. Must lie strictly in (0, 1).

    Returns:
        The blended statistic as a new tensor. Neither input is modified in place.

    Raises:
        ValueError: If ``momentum`` is not strictly between 0 and 1 (this includes NaN).
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if not (0. < momentum < 1.):
        raise ValueError(f"momentum must be in the open interval (0, 1), got {momentum!r}")

    return momentum * current_stat + (1. - momentum) * batch_stat


def from_config(config, dim, op_factory=None, **op_factory_args):
    """Instantiate a primitive op described by ``config`` and attach its tasks.

    Reads three keys from ``config``:
      * ``"parent_expert_id"`` (optional; ``None`` when absent), which is
        passed as the first constructor argument;
      * ``"primitive"``, which is used to look up the constructor in ``PRIMITIVES``;
      * ``"associated_tasks"``, which is handed to ``op.associate_tasks`` after
        the op is built.

    Args:
        config: Mapping that describes the op (see keys above).
        dim: Forwarded unchanged to the primitive constructor.
        op_factory: Forwarded unchanged to the primitive constructor.
        **op_factory_args: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed op, after ``associate_tasks`` has been called on it.
    """
    parent_id = config.get("parent_expert_id")
    build_primitive = PRIMITIVES[config["primitive"]]

    instance = build_primitive(parent_id, dim, op_factory, **op_factory_args)
    instance.associate_tasks(config["associated_tasks"])
    return instance
