import torch
from torch import nn


class DownSampleConv(nn.Module):
    """Strided ``nn.Conv2d`` followed by ``nn.ReLU``.

    With the defaults (kernel 3, stride 2, padding 0, dilation 1) an input of
    height ``H`` produces ``floor((H - 3) / 2) + 1`` output rows (likewise for
    width). With kernel 3, stride 2 and ``padding=1`` the output height is
    ``ceil(H / 2)``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: forwarded to ``nn.Conv2d``.
        stride: forwarded to ``nn.Conv2d``.
        padding: forwarded to ``nn.Conv2d``.
        bias: whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, bias=False):
        super().__init__()
        # The attribute names ``conv`` / ``relu`` (and their order) define the
        # state_dict keys, so checkpoints depend on them.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return self.relu(self.conv(x))

class ChannelConv(nn.Module):
    """Convolution (1x1, stride 1 by default) followed by ``nn.ReLU``.

    With the defaults the spatial size is unchanged and only the channel
    count goes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: forwarded to ``nn.Conv2d``.
        stride: forwarded to ``nn.Conv2d``.
        padding: forwarded to ``nn.Conv2d``.
        bias: whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        # ``conv`` must be registered before ``relu``; both names are part of
        # the state_dict / module tree.
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the convolution, then ReLU."""
        projected = self.conv(x)
        return self.relu(projected)

class Conv3_3(nn.Module):
    """Convolution (3x3, stride 1 by default) followed by ``nn.ReLU``.

    Note the default ``padding=0`` shrinks each spatial dimension by 2 for a
    3x3 kernel; pass ``padding=1`` to keep the spatial size.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: forwarded to ``nn.Conv2d``.
        stride: forwarded to ``nn.Conv2d``.
        padding: forwarded to ``nn.Conv2d``.
        bias: whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=False):
        super().__init__()
        # Keep the ``conv`` then ``relu`` registration order and names: they
        # determine state_dict keys.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=bias)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the convolution and then the ReLU."""
        for layer in (self.conv, self.relu):
            x = layer(x)
        return x

class Aug(nn.Module):
    """Fuse a feature map ``f`` with a higher-resolution map ``r``.

    ``f`` is projected to ``out_channels`` by a 1x1 conv + ReLU. ``r`` is
    projected to ``out_channels`` and downsampled by a 3x3, stride-2,
    padding-1 conv + ReLU, so its height/width become ``ceil(H / 2)`` /
    ``ceil(W / 2)``. The two results are summed element-wise and refined by a
    3x3, stride-1, padding-1 conv + ReLU, which keeps the spatial size.

    Args:
        f_inchannels: channel count of ``f``.
        r_inchannels: channel count of ``r``.
        out_channels: channel count of the intermediate maps and the output.
    """

    def __init__(self, f_inchannels, r_inchannels, out_channels):
        super(Aug, self).__init__()
        # Attribute names are kept as-is: they are the state_dict key prefixes.
        self.channelConv = ChannelConv(f_inchannels, out_channels, 1, 1, 0, False)
        self.downSampleConv = DownSampleConv(r_inchannels, out_channels, 3, 2, 1, False)
        self.conv3_3 = Conv3_3(out_channels, out_channels, 3, 1, 1, False)

    def forward(self, f, r):
        """Project ``f``, downsample ``r``, sum them, and refine the sum.

        Args:
            f: feature map with ``f_inchannels`` channels.
            r: feature map with ``r_inchannels`` channels whose spatial size,
                after the stride-2 downsample, must equal that of ``f``.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``f``.

        Raises:
            ValueError: if the projected ``f`` and downsampled ``r`` differ in
                shape.
        """
        f = self.channelConv(f)     # e.g. (1, 128, 56, 56) -> (1, 64, 56, 56)
        r = self.downSampleConv(r)  # e.g. (1, 64, 112, 112) -> (1, 64, 56, 56)
        # ``f + r`` would silently broadcast mismatched shapes (e.g. a size-1
        # batch or spatial dim on one side), producing a wrong fused map
        # instead of an error, so require an exact match before summing.
        if f.shape != r.shape:
            raise ValueError(
                f"Aug: projected f has shape {tuple(f.shape)} but downsampled r "
                f"has shape {tuple(r.shape)}; they must match to be summed"
            )
        return self.conv3_3(f + r)  # spatial size preserved (padding 1)

if __name__ == '__main__':
    # Smoke test: fuse a 128-channel 56x56 map with a 64-channel 112x112 map.
    f = torch.randn(1, 128, 56, 56)
    r = torch.randn(1, 64, 112, 112)
    aug = Aug(128, 64, 64)
    with torch.no_grad():  # inference-only demo; no autograd graph needed
        res = aug(f, r)
    # Report the actual output shape instead of a placeholder string.
    print('Aug output shape:', tuple(res.shape))
