import torch
import torch.nn as nn
import torch.nn.functional as F



def get_activation_function(activation: str = 'relu') -> torch.nn.Module:
    """Build a fresh activation module from its name.

    Args:
        activation: One of ``'relu'``, ``'leaky_relu'``, ``'sigmoid'``,
            ``'identity'`` or ``'tanh'``. Defaults to ``'relu'``.

    Returns:
        A new ``torch.nn.Module`` instance implementing the activation.
        ``'identity'`` returns ``torch.nn.Identity()`` rather than a lambda,
        so that it can be used in ``nn.Sequential``, be registered as a
        submodule, and be pickled like every other activation returned here.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # Map names to classes (not instances) so each call builds a new module;
    # sharing one module instance across layers would be surprising.
    activation_classes = {
        'relu': torch.nn.ReLU,
        'leaky_relu': torch.nn.LeakyReLU,
        'sigmoid': torch.nn.Sigmoid,
        'identity': torch.nn.Identity,
        'tanh': torch.nn.Tanh,
    }
    activation_cls = activation_classes.get(activation)
    if activation_cls is None:
        raise ValueError(f"Invalid activation type {activation}")
    return activation_cls()

def initialize_layer_w_zero(layer: nn.Module) -> nn.Module:
    """Set every parameter of ``layer`` to zero, in place.

    The original Parameter objects are kept and only their data is
    overwritten. References held elsewhere (e.g. by an optimizer) therefore
    stay valid, and each parameter's ``requires_grad`` flag is left
    unchanged. Parameters of submodules are included, because
    ``Module.parameters()`` recurses.

    Args:
        layer: The module whose parameters should be zeroed.

    Returns:
        The same ``layer`` object, for call chaining.
    """
    # Rebinding a loop variable to a new nn.Parameter never touches the
    # module. Zero the existing tensors in place instead. no_grad() is needed
    # because in-place ops on leaf tensors that require grad are otherwise
    # rejected by autograd.
    with torch.no_grad():
        for param in layer.parameters():
            param.zero_()
    return layer