from typing import List, Optional, Type, Union

import torch.nn as nn


def parse_norm_layer(norm_layer_str: str) -> Type[nn.Module]:
    """Parses a string indicating a norm layer into a torch layer class.

    Args:
        norm_layer_str: Name of the norm layer. One of "none", "layer_norm"
            or "batch_norm".

    Returns:
        The (uninstantiated) layer class: ``nn.Identity``, ``nn.LayerNorm``
        or ``nn.BatchNorm1d`` respectively. Callers construct it themselves.

    Raises:
        ValueError: If ``norm_layer_str`` is not one of the names above.
    """
    norm_layers = {
        "none": nn.Identity,
        "layer_norm": nn.LayerNorm,
        "batch_norm": nn.BatchNorm1d,
    }
    layer_cls = norm_layers.get(norm_layer_str)
    if layer_cls is None:
        raise ValueError(f"Norm layer class not implemented: {norm_layer_str}")
    return layer_cls


def build_fc_network(input_size: int,
                     layer_sizes: List[int],
                     norm_layer: Optional[Union[str, Type[nn.Module]]] = None,
                     use_final_layer_activation: bool = False) -> nn.Module:
    """Builds a fully-connected network.

    For each entry of ``layer_sizes`` the network contains, in order: an
    optional norm layer (constructed with that layer's input size), an
    ``nn.Linear``, and an ``nn.ReLU``. The ReLU after the last linear layer is
    only added when ``use_final_layer_activation`` is True.

    Args:
        input_size: The number of input features to the first linear layer.
        layer_sizes: The output sizes of the linear layers, in order.
        norm_layer: Optional norm layer inserted before each linear layer.
            Either a name accepted by ``parse_norm_layer`` ("none",
            "layer_norm", "batch_norm") or a module class that is called with
            the layer's input size.
        use_final_layer_activation: Whether to apply a ReLU to the output.

    Returns:
        An ``nn.Sequential`` containing the network (empty if
        ``layer_sizes`` is empty).

    Raises:
        ValueError: If ``norm_layer`` is a string not recognized by
            ``parse_norm_layer``.
    """
    if isinstance(norm_layer, str):
        # The documented interface takes a string name; resolve it to a class
        # so it can be instantiated below (calling the string would raise
        # TypeError).
        norm_layer = parse_norm_layer(norm_layer)

    layers = []
    size_in = input_size
    last_index = len(layer_sizes) - 1
    for i, layer_size in enumerate(layer_sizes):
        if norm_layer is not None:
            layers.append(norm_layer(size_in))
        layers.append(nn.Linear(size_in, layer_size))
        # Hidden layers always get a ReLU; the output layer only on request.
        if i < last_index or use_final_layer_activation:
            layers.append(nn.ReLU())
        size_in = layer_size
    return nn.Sequential(*layers)
