import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import init_weights, quant, round


# Convolution expanded in space (each pixel emits a ks x ks kernel) and in channels (each input channel emits ch_out outputs)
class ExpandedConv(nn.Module):
    """Expanded convolution whose kernels are produced per pixel by a pointwise MLP.

    Each input channel is processed independently by grouped 1x1 convolutions
    (``groups=ch_in``, 64 hidden units per group). For every pixel of the
    reflect-padded input, the MLP emits ``ch_out * ks * ks`` values. These are
    reshaped into one ``ks x ks`` kernel per output channel and summed over the
    input channels. They are then accumulated back onto the unpadded grid by
    shifted addition, so the output keeps the input's spatial size.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        ks: Spatial extent of the expanded kernel. Must be a positive odd
            integer. Only then does the symmetric padding ``(ks - 1) // 2``
            add exactly ``ks - 1`` rows/columns, which ``forward`` relies on.

    Raises:
        ValueError: If ``ks`` is not a positive odd integer.
    """

    def __init__(self, ch_in, ch_out, ks=3) -> None:
        super(ExpandedConv, self).__init__()

        # forward() reshapes the padded map to (H + ks - 1, W + ks - 1). With
        # symmetric padding of (ks - 1) // 2 that only holds for odd ks, so
        # reject other values here instead of failing later inside forward().
        if ks < 1 or ks % 2 == 0:
            raise ValueError(f"ks must be a positive odd integer, got {ks}")

        self.ks = ks
        self.ch_in = ch_in
        self.ch_out = ch_out
        groups = ch_in

        ch = 64 * groups  # 64 hidden channels for each of the ch_in groups
        padding = (ks - 1) // 2
        self.net = nn.Sequential(
            # Only the first layer pads, so the MLP also runs on the reflected border.
            nn.Conv2d(ch_in, ch, 1, padding=padding, groups=groups, padding_mode="reflect"),
            nn.ReLU(),
            nn.Conv2d(ch, ch, 1, groups=groups),
            nn.ReLU(),
            nn.Conv2d(ch, ch_in * ch_out * ks * ks, 1, groups=groups),
        )

        init_weights(self)

    def forward(self, x):
        """Apply the expanded convolution.

        Args:
            x: Tensor of shape ``(N, ch_in, H, W)``.

        Returns:
            Tensor of shape ``(N, ch_out, H, W)``, passed through ``quant(..., 8)``.
        """
        N, _, H, W = x.shape
        ks = self.ks
        ch_in = self.ch_in
        ch_out = self.ch_out

        # (N, ch_in, H, W) -> (N, ch_in * ch_out * ks * ks, H + ks - 1, W + ks - 1)
        x = self.net(x)
        # Fixed 1/8 scaling of the MLP output before quantization.
        x = x * 0.125

        # LUT quantization (12 is presumably the bit width of utils.quant — confirm).
        x = quant(x, 12)

        # Grouped-conv outputs are ordered group-major, so dim 1 indexes the input
        # channel; sum the contributions of all input channels. reshape (not view)
        # so this does not depend on quant() returning a contiguous tensor.
        x = x.reshape(N, ch_in, ch_out, ks, ks, H + ks - 1, W + ks - 1)
        x = x.sum(dim=1)

        # Shifted accumulation: padded pixel (p + i, q + j) contributes its
        # kernel entry (i, j) to output pixel (p, q).
        y = torch.zeros((N, ch_out, H, W), dtype=x.dtype, device=x.device)
        for i in range(ks):
            for j in range(ks):
                y += x[:, :, i, j, i : H + i, j : W + j]

        # feature quantization (8 is presumably the bit width — confirm).
        y = quant(y, 8)

        return y


class Upscale(nn.Module):
    """Pointwise upsampling head that collapses its channels into one.

    A grouped 1x1 MLP (``groups=ch_in``, 64 hidden units per group) maps each
    pixel to ``ch_out * upscale * upscale`` values. ``nn.PixelShuffle``
    rearranges them into a map that is ``upscale`` times larger with
    ``ch_out`` channels. ``forward`` then quantizes the result and sums it
    over the channel dimension.

    Args:
        upscale: Integer spatial upscaling factor.
        ch_in: Number of input channels.
        ch_out: Channels after pixel shuffle; these are summed away in ``forward``.
    """

    def __init__(self, upscale, ch_in, ch_out) -> None:
        super().__init__()

        self.upscale = upscale
        hidden = 64 * ch_in  # 64 hidden channels for each of the ch_in groups

        layers = [
            nn.Conv2d(ch_in, hidden, 1, groups=ch_in),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, 1, groups=ch_in),
            nn.ReLU(),
            nn.Conv2d(hidden, ch_out * upscale * upscale, 1, groups=ch_in),
            nn.PixelShuffle(upscale),
        ]
        self.net = nn.Sequential(*layers)

        init_weights(self)

    def forward(self, x):
        """Upscale ``x`` and return a single-channel map.

        The MLP output is multiplied by 2048, passed through ``round`` (from
        utils), and divided by 2048 again. This presumably quantizes it to a
        1/2048 grid. The result is then summed over dim 1 with ``keepdim=True``.
        """
        step = 2048
        quantized = round(self.net(x) * step) / step
        return quantized.sum(dim=1, keepdim=True)


# expanded convolutional network
class ECNN(nn.Module):
    """Expanded convolutional network that upscales every channel independently.

    The stack is one ``ExpandedConv(1, ch)``, then ``layers - 1`` blocks of
    ``ExpandedConv(ch, ch)``, then an ``Upscale`` head that produces one
    channel at ``upscale`` times the input resolution.

    Args:
        upscale: Integer spatial upscaling factor.
        ch: Feature channels between the expanded-convolution layers.
        layers: Number of ``ExpandedConv`` layers when ``layers >= 1``.
        ks: Kernel size passed to every ``ExpandedConv``. The default of 3
            matches the previously hard-coded value.
    """

    def __init__(self, upscale=4, ch=8, layers=8, ks=3) -> None:
        super(ECNN, self).__init__()

        self.upscale = upscale

        self.net = nn.ModuleList()
        self.net.append(ExpandedConv(1, ch, ks=ks))
        self.net += [ExpandedConv(ch, ch, ks=ks) for _ in range(layers - 1)]
        self.net.append(Upscale(upscale, ch, ch))

        # Submodules already call init_weights in their constructors;
        # this applies it once more over the whole network.
        init_weights(self)

    def forward(self, x):
        """Upscale each channel of ``x`` independently.

        Args:
            x: Tensor of shape ``(N, C, H, W)``. It may be non-contiguous.

        Returns:
            Tensor of shape ``(N, C, H * upscale, W * upscale)``.
        """
        N, C, H, W = x.shape
        upscale = self.upscale
        out_H, out_W = H * upscale, W * upscale

        # Fold channels into the batch so each channel is processed as a
        # one-channel image. reshape (not view) so that non-contiguous inputs,
        # e.g. permuted or sliced tensors, do not raise.
        x = x.reshape(N * C, 1, H, W)
        for layer in self.net:
            x = layer(x)

        x = x.reshape(N, C, out_H, out_W)

        return x
