import torch as t
import torch.nn as nn


class NormalisePixels(nn.Module):
    """Stateless layer that divides its input by 255.

    Presumably used to map 8-bit pixel intensities into the [0, 1] range;
    the input range itself is not checked here.
    """

    def forward(self, x):
        """Return ``x / 255.0``."""
        scaled = x / 255.0
        return scaled


class ChannelFirst(nn.Module):
    """Move the last axis of a 4-D input to position 1.

    The axes are reordered from (0, 1, 2, 3) to (0, 3, 1, 2); presumably
    this converts NHWC image batches to NCHW -- confirm against callers.
    """

    def forward(self, x):
        """Return a contiguous copy of ``x`` with its axes permuted."""
        axis_order = (0, 3, 1, 2)
        reordered = x.permute(*axis_order)
        # permute only returns a strided view; materialise it contiguously.
        return reordered.contiguous()


class Conv(nn.Module):
    """A single 2-D convolution followed by a ReLU activation.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel: Kernel size forwarded to ``nn.Conv2d`` (default 1).
        stride: Stride forwarded to ``nn.Conv2d`` (default 1).
        padding: Padding forwarded to ``nn.Conv2d`` (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel=1, stride=1, padding=0):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return t.relu(self.conv(x))


class ResidualLinear(nn.Module):
    """Two-layer fully connected block with a skip connection.

    Computes ``relu(x + lin2(relu(lin1(x))))``; both linear layers map
    ``dims`` features to ``dims`` features so the sum is well defined.

    Args:
        dims: Feature size of the input and of both linear layers.
    """

    def __init__(self, dims):
        super().__init__()

        self.lin1 = nn.Linear(dims, dims)
        self.lin2 = nn.Linear(dims, dims)

    def forward(self, x):
        """Apply the residual branch, add the input back and apply ReLU."""
        hidden = t.relu(self.lin1(x))
        branch = self.lin2(hidden)
        return t.relu(x + branch)


class ResidualConv(nn.Module):
    """Two-layer convolutional block with a skip connection.

    Computes ``relu(x + conv2(relu(conv1(x))))``. Both convolutions use a
    3x3 kernel with stride 1 and padding 1 and keep the channel count, so
    the branch output has the same shape as the input.

    Args:
        channels: Number of input and output channels of both convolutions.
    """

    def __init__(self, channels):
        super().__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Apply the residual branch, add the input back and apply ReLU."""
        hidden = t.relu(self.conv1(x))
        branch = self.conv2(hidden)
        return t.relu(x + branch)
