import torch
import math
from torch import nn
import torchvision
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class DetectorBlock(nn.Module):
    """Residual convolution block with optional 2x spatial downsampling.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. When ``downsample`` is
    True, a stride-2 conv3x3 -> BN follows. Shortcut path: a 1x1 conv whose
    stride (2 or 1) matches the main path, so both yield the same spatial size
    (ceil(H/2) x ceil(W/2) when downsampling). The two paths are summed and
    passed through ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        downsample: Whether to halve the spatial resolution.
    """

    def __init__(self, in_channels, out_channels, downsample=True):
        super().__init__()
        shortcut_stride = 2 if downsample else 1
        # Projection shortcut: always a 1x1 conv, even when channel counts match.
        self.conv_res = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=shortcut_stride)

        main_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.net = nn.Sequential(*main_layers)

        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=2),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential returns its input unchanged (no parameters).
            self.downsample = nn.Sequential()

    def forward(self, x):
        shortcut = self.conv_res(x)
        main = self.downsample(self.net(x))
        # The sum is a fresh tensor, so the in-place ReLU cannot clobber x.
        return F.relu(main + shortcut, inplace=True)


class DetectorTransposedBlock(nn.Module):
    """Upsampling block: transposed conv -> BatchNorm -> ReLU.

    With kernel_size=4, stride=2, padding=1 the transposed conv produces
    (H - 1) * 2 - 2 + 4 = 2 * H rows (and likewise for width), i.e. it exactly
    doubles the spatial size.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.net = nn.Sequential(upsample, norm, activation)

    def forward(self, x):
        return self.net(x)


class Detector(nn.Module):
    """Encoder-decoder network mapping an image batch to per-keypoint heatmaps.

    Encoder: eight ``DetectorBlock``s with four stride-2 stages
    (channels 64 -> 128 -> 256 -> 512). Decoder: four
    ``DetectorTransposedBlock``s, each doubling the spatial size, the last
    producing ``n_keypoints`` channels.

    Each encoder stage maps a size S to ceil(S / 2), and each decoder stage
    doubles it. The heatmap therefore matches the input's height/width only
    when both are divisible by 16. The last decoder block ends in
    BatchNorm + ReLU, so heatmap values are non-negative.

    Args:
        hyper_paras: Mapping that must contain ``'n_keypoints'``, the number of
            output heatmap channels.
    """

    def __init__(self, hyper_paras):
        super().__init__()
        self.n_keypoints = hyper_paras['n_keypoints']

        # Trailing comments give the spatial size for a 128x128 input.
        self.conv = nn.Sequential(
            DetectorBlock(3, 64, downsample=True),  # 64
            DetectorBlock(64, 64, downsample=False),  # 64
            DetectorBlock(64, 128, downsample=True),  # 32
            DetectorBlock(128, 128, downsample=False),  # 32
            DetectorBlock(128, 256, downsample=True),  # 16
            DetectorBlock(256, 256, downsample=False),  # 16
            DetectorBlock(256, 512, downsample=True),  # 8
            DetectorBlock(512, 512, downsample=False),  # 8
            DetectorTransposedBlock(512, 256),   # 16
            DetectorTransposedBlock(256, 128),   # 32
            DetectorTransposedBlock(128, 64),  # 64
            DetectorTransposedBlock(64, self.n_keypoints),  # 128
            )

        self._init_weights()

    def _init_weights(self):
        """Apply Kaiming-normal init to Conv2d/Linear weights and zero their biases.

        ``nn.ConvTranspose2d`` is not a subclass of ``nn.Conv2d``, so the
        decoder's transposed convolutions keep PyTorch's default initialization.
        """
        # self.modules() walks the tree in pre-order; keeping that traversal
        # (rather than self.apply, which is post-order) preserves the order of
        # random draws, so seeded runs initialize identically.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, input_dict):
        """Compute heatmaps for ``input_dict['img']``.

        Args:
            input_dict: Mapping with key ``'img'``: an image tensor of shape
                (N, 3, H, W). The first conv expects 3 channels, and
                BatchNorm2d requires 4-D input.

        Returns:
            ``{'heatmap': tensor}`` of shape (N, n_keypoints, H', W'), with
            H' == H and W' == W when H and W are divisible by 16.
        """
        heatmap = self.conv(input_dict['img'])
        return {'heatmap': heatmap}


if __name__ == '__main__':
    # Smoke test: a 32 x 3 x 128 x 128 batch should print torch.Size([32, 10, 128, 128]).
    hparams = {'n_keypoints': 10}
    detector = Detector(hparams)
    batch = {'img': torch.randn(32, 3, 128, 128)}
    outputs = detector(batch)
    print(outputs['heatmap'].shape)
