import torch
from layers import get_act


class ScoreNetworkReduced(torch.nn.Module):
    """Small convolutional score network conditioned on a time value.

    Takes an input image and time and returns the score function. The time value
    is broadcast to a constant extra channel and concatenated with the image.
    The result goes through two encoder stages, then two decoder stages; the last
    decoder stage also receives the first encoder stage's output as a skip
    connection. Every convolution uses kernel_size=3 / padding=1, so the spatial
    size is preserved end to end.

    Expects 3-channel images (3 image channels + 1 time channel = 4 inputs to the
    first conv) and returns a 3-channel tensor of the same spatial size.
    """

    # Number of image channels consumed and produced; forward() adds one time
    # channel on top of these before the first convolution.
    _IMAGE_CHANNELS = 3

    def __init__(self, config):
        """Build the encoder/decoder layers.

        Args:
            config: configuration object. Reads ``config.model.nf`` (hidden
                channel width) and ``config.data.image_size``, and is passed to
                ``get_act`` to select the decoder activation.
        """
        super().__init__()
        nf = config.model.nf

        # Kept for external callers; forward() derives H and W from its input.
        self.image_size = config.data.image_size

        # get_act's result is placed inside nn.Sequential, so it must be an
        # nn.Module. The same instance is reused in both decoder stages, so any
        # parameters it has are shared between them.
        self.act = get_act(config)
        chs = [nf, nf]
        self.chs = chs
        in_ch = self._IMAGE_CHANNELS + 1  # image channels + time channel

        # NOTE: module positions inside each Sequential are part of the
        # state_dict keys (e.g. "_tconvs.1.2.weight"); keep the order stable.
        self._convs = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Conv2d(in_ch, chs[0], kernel_size=3, padding=1),  # (B, chs[0], H, W)
                torch.nn.Tanh(),
            ),
            torch.nn.Sequential(
                torch.nn.Conv2d(chs[0], chs[1], kernel_size=3, padding=1),  # (B, chs[1], H, W)
                torch.nn.Tanh(),
            ),
        ])
        self._tconvs = torch.nn.ModuleList([
            torch.nn.Sequential(
                # Input: output of the last encoder stage (chs[-1] channels).
                torch.nn.Conv2d(chs[-1], chs[0], kernel_size=3, padding=1),  # (B, chs[0], H, W)
                self.act,
            ),
            torch.nn.Sequential(
                # Input: previous decoder output concatenated with _convs[0]'s output.
                torch.nn.Conv2d(chs[0] * 2, chs[0], kernel_size=3, padding=1),  # (B, chs[0], H, W)
                self.act,
                torch.nn.Conv2d(chs[0], self._IMAGE_CHANNELS, kernel_size=3, padding=1),  # (B, 3, H, W)
            ),
        ])

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the network.

        Args:
            x: image batch of shape (B, 3, H, W).
            t: time, either one value per sample or a single value shared by the
                whole batch. Shapes (), (1,), (B,) and (B, 1) are accepted. It is
                cast to x's dtype and device before use.

        Returns:
            Tensor of shape (B, 3, H, W).
        """
        batch = x.shape[0]
        height, width = x.shape[-2:]

        # Broadcast t to a constant (B, 1, H, W) feature map. Flattening to
        # (-1, 1, 1, 1) handles scalar, (B,) and (B, 1) time inputs alike, and
        # matching x's dtype avoids promoting the conv input to a different
        # dtype than the layer weights.
        t_map = t.to(dtype=x.dtype, device=x.device).reshape(-1, 1, 1, 1)
        t_map = t_map.expand(batch, 1, height, width)
        signal = torch.cat((x, t_map), dim=-3)

        # Encoder: keep every stage's output except the last for skip connections.
        skips = []
        for i, conv in enumerate(self._convs):
            signal = conv(signal)
            if i < len(self._convs) - 1:
                skips.append(signal)

        # Decoder: stage i > 0 first concatenates the matching encoder output,
        # walking the skip list from the deepest stored stage outward.
        for i, tconv in enumerate(self._tconvs):
            if i > 0:
                signal = torch.cat((signal, skips[-i]), dim=-3)
            signal = tconv(signal)
        return signal