import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# class CIFAR10Decoder(nn.Module):
#     def __init__(self, hidden_size):
#         super(CIFAR10Decoder, self).__init__()
#         self.D4 = DecoderBlock(128, 64, 64, nn.BatchNorm2d)
#         self.D3 = DecoderBlock(64, 64, 64, nn.BatchNorm2d)
#         self.D2 = DecoderBlock(64, 64, 64, nn.BatchNorm2d)
#         self.D1 = DecoderBlock(64, 64, 64, nn.BatchNorm2d)
#         self.D0 = DecoderBlock(64, 64, 3, nn.BatchNorm2d, islast=True)
#         self.hidden_size = hidden_size
#         self._init_weight()
#
#     def forward(self, h):
#         # -1, hidden_size, 2, 2
#         D4_in = h.view(-1, self.hidden_size, 2, 2)
#
#         D4_out = self.D4(D4_in)
#         D3_in = F.interpolate(D4_out, size=(4,4), mode='bilinear', align_corners=True)
#
#         D3_out = self.D3(D3_in)
#         D2_in = F.interpolate(D3_out, size=(8,8), mode='bilinear', align_corners=True)
#
#         D2_out = self.D2(D2_in)
#         D1_in = F.interpolate(D2_out, size=(16,16), mode='bilinear', align_corners=True)
#
#         D1_out = self.D1(D1_in)
#         D0_in = F.interpolate(D1_out, size=(32,32), mode='bilinear', align_corners=True)
#
#         D_out = self.D0(D0_in)
#
#         return D_out
#
#     def _init_weight(self):
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 torch.nn.init.kaiming_normal_(m.weight)
#             elif isinstance(m, nn.BatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()
#
# class DecoderBlock(nn.Module):
#     def __init__(self, inChannel, midChannel, outChannel, BatchNorm, islast = False):
#         super(DecoderBlock, self).__init__()
#         self.islast = islast
#         self.conv1 = nn.Conv2d(inChannel, midChannel, kernel_size=3, stride=1, padding=1, bias=False)
#         self.conv2 = nn.Conv2d(midChannel, midChannel, kernel_size=3, stride=1, padding=1, bias=False)
#         self.conv3 = nn.Conv2d(midChannel, outChannel, kernel_size=3, stride=1, padding=1, bias=False)
#
#         self.bn1 = BatchNorm(midChannel)
#         self.bn2 = BatchNorm(midChannel)
#         self.bn3 = BatchNorm(outChannel)
#
#         self.relu = nn.ReLU(inplace=True)
#
#         self.dropout1 = nn.Dropout(0.25)
#         self.dropout2 = nn.Dropout(0.25)
#         self.dropout3 = nn.Dropout(0.25)
#
#     def forward(self, x):
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#         out = self.dropout1(out)
#
#         out = self.conv2(out)
#         out = self.bn2(out)
#         out = self.relu(out)
#         out = self.dropout2(out)
#
#         if self.islast:
#             out = self.conv3(out)
#         else:
#             out = self.conv3(out)
#             out = self.bn3(out)
#             out = self.relu(out)
#             out = self.dropout3(out)
#
#         return out

class CIFAR10Decoder(nn.Module):
    """Transposed-convolution decoder from a latent code to a 3x32x32 image.

    Each input sample must hold ``hidden_size * 4`` values. It is reshaped to a
    ``(hidden_size, 2, 2)`` feature map and upsampled 2x by each of four
    ``ConvTranspose2d(kernel_size=4, stride=2, padding=1)`` layers:
    2x2 -> 4x4 -> 8x8 -> 16x16 -> 32x32.

    The last layer has no output activation, so the decoded values are
    unbounded.

    Args:
        hidden_size: Channel count of the latent feature map and of every
            intermediate layer.
    """

    def __init__(self, hidden_size):
        super(CIFAR10Decoder, self).__init__()
        self.hidden_size = hidden_size

        self.decoder = nn.Sequential(
            # -1, hidden_size, 2, 2  ->  -1, hidden_size, 4, 4
            nn.ConvTranspose2d(hidden_size, hidden_size, 4, 2, 1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(inplace=True),
            # -> -1, hidden_size, 8, 8
            nn.ConvTranspose2d(hidden_size, hidden_size, 4, 2, 1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(inplace=True),
            # -> -1, hidden_size, 16, 16
            nn.ConvTranspose2d(hidden_size, hidden_size, 4, 2, 1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(inplace=True),
            # -> -1, 3, 32, 32 (raw values, no output activation)
            nn.ConvTranspose2d(hidden_size, 3, 4, 2, 1),
        )

    def forward(self, h):
        """Decode latent codes ``h`` into images.

        Args:
            h: Tensor holding ``hidden_size * 4`` values per sample, e.g. of
                shape ``(batch, hidden_size * 4)`` or
                ``(batch, hidden_size, 2, 2)``.

        Returns:
            Tensor of shape ``(batch, 3, 32, 32)``.
        """
        # reshape (not view): also accepts non-contiguous latents such as
        # transposed or sliced encoder outputs, where view would raise.
        h = h.reshape(-1, self.hidden_size, 2, 2)
        return self.decoder(h)

#
# class CIFAR10Decoder_backup(nn.Module):
#     def __init__(self, hidden_size):
#         super(CIFAR10Decoder_backup, self).__init__()
#         self.hidden_size = hidden_size
#
#         self.decoder = nn.Sequential(
#             # -1, hidden_size, 4, 4
#             nn.ConvTranspose2d(hidden_size, hidden_size, 4, 2, 1),
#             nn.BatchNorm2d(hidden_size),
#             nn.ReLU(inplace=True),
#             # -1, hidden_size, 7, 7
#             nn.ConvTranspose2d(hidden_size, hidden_size, 4, 2, 1),
#             nn.BatchNorm2d(hidden_size),
#             nn.ReLU(inplace=True),
#             # -1, hidden_size, 14, 14
#             nn.ConvTranspose2d(hidden_size, hidden_size, 4, 2, 1),
#             nn.BatchNorm2d(hidden_size),
#             nn.ReLU(inplace=True),
#             # -1, 1, 28, 28
#             nn.ConvTranspose2d(hidden_size, 3, 4, 2, 1),
#             nn.Sigmoid()
#         )
#
#     def forward(self, h):
#         # -1, hidden_size, 2, 2
#         h = h.view(-1, self.hidden_size, 2, 2)
#         # -1, 1, 28, 28
#         x = self.decoder(h)
#         return x


# The below code is adapted from github.com/juho-lee/set_transformer/blob/master/modules.py
class MAB(nn.Module):
    """Multihead Attention Block used by the set-attention modules below.

    Projects queries/keys/values, and for each head computes
    ``Q' + softmax(Q' K'^T / sqrt(dim_V)) V'``. It then concatenates the heads,
    optionally applies LayerNorm, adds ``relu(fc_o(.))`` as a residual, and
    optionally applies a second LayerNorm.

    Inputs may be unbatched ``(n, dim)`` sets or carry leading batch
    dimensions ``(..., n, dim)``. Heads are always split along the last
    (feature) dimension.

    Args:
        dim_Q: Feature size of the query inputs.
        dim_K: Feature size of the key/value inputs.
        dim_V: Output feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        ln: If True, apply LayerNorm after the attention residual and after
            the feed-forward residual.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``dim_V``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        # Fail fast: an uneven head split would otherwise only surface later
        # as an opaque torch.stack size mismatch inside forward().
        if num_heads <= 0 or dim_V % num_heads != 0:
            raise ValueError(
                f"dim_V ({dim_V}) must be divisible by num_heads ({num_heads})"
            )
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from the query set ``Q`` to the key/value set ``K``.

        Args:
            Q: Queries of shape ``(..., n_q, dim_Q)``.
            K: Keys/values of shape ``(..., n_k, dim_K)`` with the same leading
                dimensions as ``Q``.

        Returns:
            Tensor of shape ``(..., n_q, dim_V)``.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        # Split the feature axis into heads: (heads, ..., n, dim_split).
        Q_ = torch.stack(Q.split(dim_split, -1), 0)
        K_ = torch.stack(K.split(dim_split, -1), 0)
        V_ = torch.stack(V.split(dim_split, -1), 0)

        # NOTE: scaled by sqrt(dim_V) rather than the per-head sqrt(dim_split)
        # of standard multi-head attention; kept so existing weights behave
        # identically.
        A = torch.softmax(
            Q_.matmul(K_.transpose(-2, -1)) / math.sqrt(self.dim_V), -1
        )
        # Residual on the projected queries, then merge heads back along the
        # feature axis in head order.
        O = torch.cat((Q_ + A.matmul(V_)).unbind(0), -1)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class SAB(nn.Module):
    """Set Attention Block: a ``MAB`` in which a set attends to itself.

    Args:
        dim_in: Feature size of the input set (used for both queries and
            keys/values).
        dim_out: Feature size produced by the inner ``MAB``.
        num_heads: Number of attention heads of the inner ``MAB``.
        ln: Whether the inner ``MAB`` applies LayerNorm.
    """

    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super().__init__()
        # Queries and keys/values come from the same set, so both input
        # widths are dim_in.
        self.mab = MAB(
            dim_Q=dim_in,
            dim_K=dim_in,
            dim_V=dim_out,
            num_heads=num_heads,
            ln=ln,
        )

    def forward(self, X):
        """Return the inner ``MAB`` applied with ``X`` as both queries and keys."""
        queries = keys = X
        return self.mab(queries, keys)