"""
General networks for PyTorch.

Algorithm-specific networks should go elsewhere.
"""
import torch
from rlkit.torch.networks.basic import (
    Clamp,
    ConcatTuple,
    Detach,
    Flatten,
    FlattenEach,
    Split,
    Reshape,
)
from rlkit.torch.networks.normalization import LayerNorm
from rlkit.torch.networks.mlp import (
    Mlp,
    ConcatMlp,
    ConcatMultiHeadedMlp,
    ParallelMlp
)


class LinearTransform(torch.nn.Module):
    """Apply the element-wise affine map ``m * t + b``.

    ``m`` and ``b`` are stored exactly as given, via ordinary attribute
    assignment.  They are not explicitly registered as buffers, so plain
    numbers or plain tensors passed here will not show up in
    ``state_dict()``.

    Args:
        m: Multiplicative factor (anything that supports ``m * t``).
        b: Additive offset (anything that supports ``(m * t) + b``).
    """

    def __init__(self, m, b):
        super().__init__()
        self.m = m
        self.b = b

    def forward(self, t):
        """Return ``self.m * t + self.b``.

        The computation lives in ``forward`` rather than in an overridden
        ``__call__`` so that ``torch.nn.Module.__call__`` stays in charge
        of dispatch and registered forward (pre-)hooks are honoured.
        Calling the instance directly, ``transform(t)``, still works.
        """
        return self.m * t + self.b


# Public API of this package: the names exported by
# ``from <this package> import *``.  Includes both the re-exported network
# classes imported above and the classes defined in this module.
__all__ = [
    "Clamp",
    "ConcatTuple",
    "Detach",
    "Flatten",
    "FlattenEach",
    "Split",
    "Reshape",
    "LayerNorm",
    "LinearTransform",
    "Mlp",
    "ConcatMlp",
    "ConcatMultiHeadedMlp",
    "ParallelMlp",
]
