import torch
import torch.nn as nn

class QNet(nn.Module):
    """Fully-convolutional net mapping an image to a one-channel ``im_w x im_w`` map.

    A strided convolutional trunk downsamples the input. A small head then
    upsamples, convolves and emits a single output channel. That channel is
    finally resized bilinearly to ``(im_w, im_w)``.

    Args:
        in_channels: number of channels of the input tensor.
        im_w: side length of the square output map.

    Input must be a 4-D ``(N, in_channels, H, W)`` tensor. Bilinear resizing
    needs 4-D input. With the kernel sizes and strides below, each spatial
    side must be at least 69; smaller inputs shrink to zero size in the head.

    NOTE(review): the name suggests the output is a per-pixel Q-value map;
    confirm against the training/inference code.
    """

    def __init__(self, in_channels: int, im_w: int):
        super().__init__()
        # Layer order (and hence Sequential indices) is kept stable so that
        # existing state_dict checkpoints ("state_trunk.0.weight", ...) load.
        self.state_trunk = nn.Sequential(
            nn.Conv2d(in_channels, 32, 5, 2),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 5, 2),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 5, 2),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 5, 1),
            nn.ReLU(True),
        )
        # UpsamplingBilinear2d uses align_corners=True (unlike the final
        # resize in forward); kept unchanged to preserve trained behaviour.
        self.head = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 32, 3, 1),
            nn.ReLU(True),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 1, 3, 1),
        )
        self.im_w = im_w

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run trunk and head, then resize the result to ``(im_w, im_w)``.

        Args:
            x: input batch of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, 1, im_w, im_w)``.
        """
        features = self.state_trunk(x)
        out = self.head(features)
        # Use the functional API rather than constructing a throwaway
        # nn.Upsample module on every call. align_corners=False is exactly
        # what nn.Upsample(mode="bilinear") used implicitly, so the output is
        # unchanged; stating it explicitly avoids the default-behaviour warning.
        return nn.functional.interpolate(
            out,
            size=(self.im_w, self.im_w),
            mode="bilinear",
            align_corners=False,
        )