import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torchvision import models

class QNetBimanual(nn.Module):
    """Fully-convolutional Q-network over two (observation, goal) image pairs.

    The four inputs are concatenated along the channel axis, encoded by a
    strided convolutional trunk, decoded by a small upsampling head into a
    2-channel map, and finally resized bilinearly to ``im_w x im_w``.

    Args:
        inchannels: Currently unused; the first convolution is hard-wired to
            12 input channels. Kept for interface compatibility.
        im_w: Side length of the square output map produced by ``forward``.
    """

    def __init__(self, inchannels, im_w):
        super().__init__()
        # NOTE(review): 12 presumably = 4 inputs x 3 channels each — confirm
        # against callers; ``inchannels`` is not consulted here.
        self.state_trunk = nn.Sequential(nn.Conv2d(12, 32, 5, 2),
                                         nn.ReLU(True),
                                         nn.Conv2d(32, 32, 5, 2),
                                         nn.ReLU(True),
                                         nn.Conv2d(32, 32, 5, 2),
                                         nn.ReLU(True),
                                         nn.Conv2d(32, 32, 5, 1),
                                         nn.ReLU(True))
        # Head upsamples with align_corners=True (UpsamplingBilinear2d semantics).
        self.head = nn.Sequential(nn.UpsamplingBilinear2d(scale_factor=2),
                                  nn.Conv2d(32, 32, 3, 1),
                                  nn.ReLU(True),
                                  nn.UpsamplingBilinear2d(scale_factor=2),
                                  nn.Conv2d(32, 2, 3, 1))
        self.im_w = im_w

    def forward(self, obs1, goal1, obs2, goal2, inedgemasks=None):
        """Compute the 2-channel output map for both arms.

        Args:
            obs1, goal1, obs2, goal2: Image tensors concatenated along dim 1;
                their channel counts must sum to 12 (first conv's input).
            inedgemasks: Not supported by this model; must be ``None``.

        Returns:
            Tensor of shape ``(batch, 2, im_w, im_w)``.

        Raises:
            NotImplementedError: If ``inedgemasks`` is provided.
        """
        if inedgemasks is not None:
            # There is no defined channel layout for masks in the two-arm
            # variant, and the trunk's first conv is fixed at 12 channels;
            # fail loudly instead of guessing how masks should be interleaved.
            raise NotImplementedError(
                "QNetBimanual does not support inedgemasks; the first "
                "convolution expects exactly the 12 channels of "
                "obs1, goal1, obs2, goal2"
            )
        x = torch.cat([obs1, goal1, obs2, goal2], dim=1)
        x = self.state_trunk(x)
        out = self.head(x)
        # Same result as nn.Upsample(size=..., mode="bilinear") (whose
        # align_corners default resolves to False), without constructing a
        # new Module on every call.
        return F.interpolate(out, size=(self.im_w, self.im_w),
                             mode="bilinear", align_corners=False)


class QNetBimanualSingleScale(nn.Module):
    """Fully-convolutional Q-network over a single (observation, goal) pair.

    Inputs (optionally with edge masks between them) are concatenated along
    the channel axis, encoded by a strided convolutional trunk, decoded by a
    small upsampling head into a 2-channel map, and resized bilinearly to
    ``im_w x im_w``.

    Args:
        inchannels: Currently unused; the first convolution is hard-wired to
            6 input channels. Kept for interface compatibility.
        im_w: Side length of the square output map produced by ``forward``.
    """

    def __init__(self, inchannels, im_w):
        super().__init__()
        # NOTE(review): 6 presumably = obs + goal at 3 channels each — confirm
        # against callers; ``inchannels`` is not consulted here.
        self.state_trunk = nn.Sequential(nn.Conv2d(6, 32, 5, 2),
                                         nn.ReLU(True),
                                         nn.Conv2d(32, 32, 5, 2),
                                         nn.ReLU(True),
                                         nn.Conv2d(32, 32, 5, 2),
                                         nn.ReLU(True),
                                         nn.Conv2d(32, 32, 5, 1),
                                         nn.ReLU(True))
        # Head upsamples with align_corners=True (UpsamplingBilinear2d semantics).
        self.head = nn.Sequential(nn.UpsamplingBilinear2d(scale_factor=2),
                                  nn.Conv2d(32, 32, 3, 1),
                                  nn.ReLU(True),
                                  nn.UpsamplingBilinear2d(scale_factor=2),
                                  nn.Conv2d(32, 2, 3, 1))
        self.im_w = im_w

    def forward(self, obs, goal, inedgemasks=None):
        """Compute the 2-channel output map.

        Args:
            obs, goal: Image tensors concatenated along dim 1.
            inedgemasks: Optional tensor inserted between ``obs`` and ``goal``
                along dim 1. In every case the concatenated channel count must
                equal 6, the first convolution's input width.

        Returns:
            Tensor of shape ``(batch, 2, im_w, im_w)``.
        """
        if inedgemasks is not None:
            x = torch.cat([obs, inedgemasks, goal], dim=1)
        else:
            x = torch.cat([obs, goal], dim=1)
        x = self.state_trunk(x)
        out = self.head(x)
        # Same result as nn.Upsample(size=..., mode="bilinear") (whose
        # align_corners default resolves to False), without constructing a
        # new Module on every call.
        return F.interpolate(out, size=(self.im_w, self.im_w),
                             mode="bilinear", align_corners=False)