import torch.nn as nn
import torch
import numpy as np


class fixed_watermark_decoder_d3(nn.Module):
    """Convolutional watermark decoder with three conv layers.

    Pipeline: 4x average pooling, then three stride-2 convolutions
    (256, 512, 512 output channels) with LeakyReLU + InstanceNorm after the
    first two, global average pooling to 1x1, and a linear layer mapping the
    512-d feature to ``num_bits`` raw scores (no output activation).

    The convolutions use no padding, so each spatial side of the input must
    be at least 92 pixels for the last convolution to produce a non-empty map.

    Args:
        input_channel: Number of channels of the input image.
        num_bits: Number of scores produced per image.
    """

    def __init__(self, input_channel=3, num_bits=3):
        super().__init__()

        # Layer order is significant: Sequential indices become state_dict
        # keys (``convs.<i>.*``) and fix the parameter-initialisation order.
        self.convs = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(input_channel, 256, kernel_size=(7, 7), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(256),

            nn.Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2)),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.fc = nn.Linear(512, num_bits)

    def forward(self, input):
        """Decode images into per-bit scores.

        Args:
            input: Image tensor of shape ``(B, input_channel, H, W)``. An
                unbatched ``(input_channel, H, W)`` tensor also works if the
                installed PyTorch layers accept unbatched input.

        Returns:
            Tensor of shape ``(B, num_bits)``. For backward compatibility a
            batch of one (or an unbatched input) yields shape ``(num_bits,)``.
        """
        features = self.convs(input)  # (B, 512, 1, 1) or (512, 1, 1)
        # Flatten only the trailing (C, 1, 1) dims so each sample keeps its
        # own 512-d feature vector. A plain ``view(-1)`` would fold the batch
        # dimension into the features and make ``self.fc`` fail for B > 1.
        features = features.flatten(start_dim=-3)
        scores = self.fc(features)
        if scores.dim() == 2 and scores.size(0) == 1:
            # Preserve the legacy 1-D output for single-image batches.
            scores = scores.squeeze(0)
        return scores
    


class fixed_watermark_decoder_d4(nn.Module):
    """Convolutional watermark decoder with four conv layers.

    Pipeline: 4x average pooling, then convolutions of 7x7/stride 2 (256 ch),
    5x5/stride 2 (512 ch), 5x5/stride 1 (512 ch) each followed by LeakyReLU +
    InstanceNorm, a final 3x3/stride 2 convolution (512 ch), global average
    pooling to 1x1, and a linear layer mapping the 512-d feature to
    ``num_bits`` raw scores (no output activation).

    The convolutions use no padding, so each spatial side of the input must
    be at least 156 pixels for the last convolution to produce a non-empty map.

    Args:
        input_channel: Number of channels of the input image.
        num_bits: Number of scores produced per image.
    """

    def __init__(self, input_channel=3, num_bits=3):
        super().__init__()

        # Layer order is significant: Sequential indices become state_dict
        # keys (``convs.<i>.*``) and fix the parameter-initialisation order.
        self.convs = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(input_channel, 256, kernel_size=(7, 7), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(256),

            nn.Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2)),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.fc = nn.Linear(512, num_bits)

    def forward(self, input):
        """Decode images into per-bit scores.

        Args:
            input: Image tensor of shape ``(B, input_channel, H, W)``. An
                unbatched ``(input_channel, H, W)`` tensor also works if the
                installed PyTorch layers accept unbatched input.

        Returns:
            Tensor of shape ``(B, num_bits)``. For backward compatibility a
            batch of one (or an unbatched input) yields shape ``(num_bits,)``.
        """
        features = self.convs(input)  # (B, 512, 1, 1) or (512, 1, 1)
        # Flatten only the trailing (C, 1, 1) dims so each sample keeps its
        # own 512-d feature vector. A plain ``view(-1)`` would fold the batch
        # dimension into the features and make ``self.fc`` fail for B > 1.
        features = features.flatten(start_dim=-3)
        scores = self.fc(features)
        if scores.dim() == 2 and scores.size(0) == 1:
            # Preserve the legacy 1-D output for single-image batches.
            scores = scores.squeeze(0)
        return scores
    

class fixed_watermark_decoder_d5(nn.Module):
    """Convolutional watermark decoder with five conv layers.

    Pipeline: 4x average pooling, then convolutions of 7x7/stride 2 (256 ch),
    5x5/stride 2 (512 ch), and two 5x5/stride 1 (512 ch), each followed by
    LeakyReLU + InstanceNorm, a final 3x3/stride 2 convolution (512 ch),
    global average pooling to 1x1, and a linear layer mapping the 512-d
    feature to ``num_bits`` raw scores (no output activation).

    The convolutions use no padding, so each spatial side of the input must
    be at least 220 pixels for the last convolution to produce a non-empty map.

    Args:
        input_channel: Number of channels of the input image.
        num_bits: Number of scores produced per image.
    """

    def __init__(self, input_channel=3, num_bits=3):
        super().__init__()

        # Layer order is significant: Sequential indices become state_dict
        # keys (``convs.<i>.*``) and fix the parameter-initialisation order.
        self.convs = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(input_channel, 256, kernel_size=(7, 7), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(256),

            nn.Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2)),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.fc = nn.Linear(512, num_bits)

    def forward(self, input):
        """Decode images into per-bit scores.

        Args:
            input: Image tensor of shape ``(B, input_channel, H, W)``. An
                unbatched ``(input_channel, H, W)`` tensor also works if the
                installed PyTorch layers accept unbatched input.

        Returns:
            Tensor of shape ``(B, num_bits)``. For backward compatibility a
            batch of one (or an unbatched input) yields shape ``(num_bits,)``.
        """
        features = self.convs(input)  # (B, 512, 1, 1) or (512, 1, 1)
        # Flatten only the trailing (C, 1, 1) dims so each sample keeps its
        # own 512-d feature vector. A plain ``view(-1)`` would fold the batch
        # dimension into the features and make ``self.fc`` fail for B > 1.
        features = features.flatten(start_dim=-3)
        scores = self.fc(features)
        if scores.dim() == 2 and scores.size(0) == 1:
            # Preserve the legacy 1-D output for single-image batches.
            scores = scores.squeeze(0)
        return scores
    

class fixed_watermark_decoder_d6(nn.Module):
    """Convolutional watermark decoder with six conv layers.

    Pipeline: 4x average pooling, then convolutions of 7x7/stride 2 (256 ch),
    5x5/stride 2 (512 ch), two 5x5/stride 1 (512 ch) and one 3x3/stride 1
    (512 ch), each followed by LeakyReLU + InstanceNorm, a final 3x3/stride 2
    convolution (512 ch), global average pooling to 1x1, and a linear layer
    mapping the 512-d feature to ``num_bits`` raw scores (no output
    activation).

    The convolutions use no padding, so each spatial side of the input must
    be at least 252 pixels for the last convolution to produce a non-empty map.

    Args:
        input_channel: Number of channels of the input image.
        num_bits: Number of scores produced per image.
    """

    def __init__(self, input_channel=3, num_bits=3):
        super().__init__()

        # Layer order is significant: Sequential indices become state_dict
        # keys (``convs.<i>.*``) and fix the parameter-initialisation order.
        self.convs = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(input_channel, 256, kernel_size=(7, 7), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(256),

            nn.Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(512),

            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2)),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(512, num_bits)

    def forward(self, input):
        """Decode images into per-bit scores.

        Args:
            input: Image tensor of shape ``(B, input_channel, H, W)``. An
                unbatched ``(input_channel, H, W)`` tensor also works if the
                installed PyTorch layers accept unbatched input.

        Returns:
            Tensor of shape ``(B, num_bits)``. For backward compatibility a
            batch of one (or an unbatched input) yields shape ``(num_bits,)``.
        """
        features = self.convs(input)  # (B, 512, 1, 1) or (512, 1, 1)
        # Flatten only the trailing (C, 1, 1) dims so each sample keeps its
        # own 512-d feature vector. A plain ``view(-1)`` would fold the batch
        # dimension into the features and make ``self.fc`` fail for B > 1.
        features = features.flatten(start_dim=-3)
        scores = self.fc(features)
        if scores.dim() == 2 and scores.size(0) == 1:
            # Preserve the legacy 1-D output for single-image batches.
            scores = scores.squeeze(0)
        return scores


class fixed_watermark_decoder_DADW(nn.Module):
    """Narrower three-conv decoder used in sec. 5.5 for comparison against DADW.

    Same layout as ``fixed_watermark_decoder_d3`` but with half the channel
    widths. Pipeline: 4x average pooling, then three stride-2 convolutions
    (128, 256, 256 output channels) with LeakyReLU + InstanceNorm after the
    first two, global average pooling to 1x1, and a linear layer mapping the
    256-d feature to ``num_bits`` raw scores (no output activation).

    The convolutions use no padding, so each spatial side of the input must
    be at least 92 pixels for the last convolution to produce a non-empty map.

    Args:
        input_channel: Number of channels of the input image.
        num_bits: Number of scores produced per image.
    """

    def __init__(self, input_channel=3, num_bits=3):
        super().__init__()

        # Layer order is significant: Sequential indices become state_dict
        # keys (``convs.<i>.*``) and fix the parameter-initialisation order.
        self.convs = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(input_channel, 128, kernel_size=(7, 7), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(128),

            nn.Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2)),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.InstanceNorm2d(256),

            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2)),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.fc = nn.Linear(256, num_bits)

    def forward(self, input):
        """Decode images into per-bit scores.

        Args:
            input: Image tensor of shape ``(B, input_channel, H, W)``. An
                unbatched ``(input_channel, H, W)`` tensor also works if the
                installed PyTorch layers accept unbatched input.

        Returns:
            Tensor of shape ``(B, num_bits)``. For backward compatibility a
            batch of one (or an unbatched input) yields shape ``(num_bits,)``.
        """
        features = self.convs(input)  # (B, 256, 1, 1) or (256, 1, 1)
        # Flatten only the trailing (C, 1, 1) dims so each sample keeps its
        # own 256-d feature vector. A plain ``view(-1)`` would fold the batch
        # dimension into the features and make ``self.fc`` fail for B > 1.
        features = features.flatten(start_dim=-3)
        scores = self.fc(features)
        if scores.dim() == 2 and scores.size(0) == 1:
            # Preserve the legacy 1-D output for single-image batches.
            scores = scores.squeeze(0)
        return scores