import torch
import torch.nn as nn


class WinterNet(nn.Module):
    """Convolutional + fully connected network over a 12x8x8 input and its mirror.

    Each sample (768 values) is viewed as 12 planes of 8x8. A mirrored copy is
    built by flipping the row axis and rolling the channels by 6, which swaps
    planes 0-5 with planes 6-11. NOTE(review): this looks like a chess board
    with 6 piece planes per side, mirrored to the other side's perspective --
    confirm against the data encoder.

    Both views share a masked 15x15 convolution (``c1`` + per-square bias
    ``b1``). Their activated outputs are concatenated and reduced by an 8x8
    convolution (``out``) to 3 values per sample. A parallel fully connected
    path applies the shared ``f1`` to both flattened views and adds the
    ``fout`` projection of the concatenated results.

    Args:
        d: Conv feature maps per input plane (``c1`` has ``12 * d`` outputs).
        fd: Hidden width of the fully connected path.
        num_inputs: Flattened input size; only 768 (12 * 8 * 8) is supported.
        activation: Elementwise nonlinearity used on both paths.

    Raises:
        ValueError: if ``num_inputs`` is not 768.
    """

    def __init__(self, d=16, fd=64, num_inputs=768, activation=nn.functional.relu):
        # Explicit raise (not assert) so the check survives ``python -O``.
        if num_inputs != 768:
            raise ValueError("Support for other input sizes not implemented in this codebase.")
        super().__init__()
        self.d = d
        self.activation = activation
        # A 15x15 kernel with padding 7 keeps the 8x8 size and gives every
        # output square a view of every input square.
        self.c1 = nn.Conv2d(12, 12 * d, 15, padding=7, bias=False)
        # Bias per channel *and* per square, unlike Conv2d's per-channel bias.
        self.b1 = nn.parameter.Parameter(data=torch.zeros((12 * d, 8, 8)))
        # An 8x8 kernel with no padding collapses each 8x8 map to 1x1.
        self.out = nn.Conv2d(2 * 12 * d, 3, 8, padding=0)
        self.f_dim = num_inputs
        self.f1 = nn.Linear(self.f_dim, fd)
        self.fout = nn.Linear(2 * fd, 3, bias=False)

    def _masked_conv(self, boards):
        """Apply ``c1`` plus ``b1``, zeroing features where the input plane is zero.

        Each input plane is repeated ``d`` times along the channel axis, so
        conv channel ``k`` is multiplied elementwise by input plane ``k // d``.
        """
        mask = torch.repeat_interleave(boards, self.d, dim=1)
        return (self.c1(boards) + self.b1) * mask

    def forward(self, x_in, activate=True):
        """Run the network on a batch of inputs.

        Args:
            x_in: Tensor holding 768 values per sample, either flat
                ``(B, 768)`` or already shaped ``(B, 12, 8, 8)``.
            activate: Unused; accepted only for interface compatibility and
                has no effect.

        Returns:
            Tensor of shape ``(B, 3)``.
        """
        boards = x_in.reshape(-1, 12, 8, 8)
        # Row i of the flip is row 7 - i of the input; rolling the channels
        # by 6 then swaps planes 0-5 with planes 6-11.
        mirrored = torch.roll(torch.flip(boards, dims=(2,)), 6, dims=1)

        conv = torch.cat(
            [self._masked_conv(boards), self._masked_conv(mirrored)], dim=1
        )
        # (B, 3, 1, 1) -> (B, 3); flatten(1) keeps the batch dim even if B == 1.
        x = self.out(self.activation(conv)).flatten(1)

        # Reshape (rather than slice x_in) so both input layouts work here too.
        fx = self.activation(self.f1(boards.reshape(-1, self.f_dim)))
        fxm = self.activation(self.f1(mirrored.reshape(-1, self.f_dim)))

        return x + self.fout(torch.cat([fx, fxm], dim=1))
