import torch
from torch import nn
import torch.nn.functional as F
from models.heatmaps import gen_heatmaps
from models.sync_batchnorm import SynchronizedBatchNorm2d
from torch.nn.utils import spectral_norm


class SPADE(nn.Module):
    """Spatially-adaptive normalisation conditioned on embedded heatmaps.

    The input feature map is normalised by a parameter-free
    ``SynchronizedBatchNorm2d`` (``affine=False``) and then modulated by a
    per-pixel scale and shift that are predicted from ``heatmaps`` with a
    small convolutional head.

    Args:
        input_channel: Number of channels of the feature map to normalise.
        n_embeddings: Embedding size attached to every heatmap channel.
        n_keypoints: Number of keypoints; the conditioning input is expected
            to carry ``(n_keypoints + 1) * n_embeddings`` channels.
    """

    def __init__(self, input_channel, n_embeddings, n_keypoints):
        super().__init__()
        condition_channels = (n_keypoints + 1) * n_embeddings
        hidden_channels = 128

        self.norm = SynchronizedBatchNorm2d(input_channel, affine=False)
        # Shared trunk followed by two heads, one for the scale, one for the shift.
        self.conv = nn.Conv2d(condition_channels, hidden_channels, kernel_size=3, padding=1)
        self.conv_gamma = nn.Conv2d(hidden_channels, input_channel, kernel_size=3, padding=1)
        self.conv_beta = nn.Conv2d(hidden_channels, input_channel, kernel_size=3, padding=1)

    def forward(self, x, heatmaps):
        """Return ``(1 + gamma) * norm(x) + beta`` with gamma/beta predicted from ``heatmaps``.

        Args:
            x: Feature map with ``input_channel`` channels.
            heatmaps: Conditioning map with ``(n_keypoints + 1) * n_embeddings``
                channels and a spatial size matching ``x``.
        """
        x_hat = self.norm(x)
        shared = F.leaky_relu(self.conv(heatmaps), 0.2)
        gamma = self.conv_gamma(shared)
        beta = self.conv_beta(shared)
        # The "1 +" keeps the modulation close to identity when gamma is near zero.
        return (1 + gamma) * x_hat + beta


class SPADEResBlk(nn.Module):
    """Two consecutive SPADE -> LeakyReLU(0.2) -> 3x3 conv stages.

    NOTE(review): despite the name, ``forward`` adds no skip connection, and
    ``learn_shortcut`` is stored but never read inside this class.

    Args:
        in_channel: Channels of the incoming feature map.
        out_channel: Channels of the produced feature map.
        n_embeddings: Embedding size forwarded to each ``SPADE`` layer.
        n_keypoints: Number of keypoints forwarded to each ``SPADE`` layer.
    """

    def __init__(self, in_channel, out_channel, n_embeddings, n_keypoints):
        super().__init__()
        # The intermediate width is the narrower of the two ends.
        hidden_channel = min(in_channel, out_channel)
        self.learn_shortcut = in_channel != out_channel

        self.spade1 = SPADE(in_channel, n_embeddings, n_keypoints)
        self.conv1 = nn.Conv2d(in_channel, hidden_channel, kernel_size=3, padding=1)
        self.spade2 = SPADE(hidden_channel, n_embeddings, n_keypoints)
        self.conv2 = nn.Conv2d(hidden_channel, out_channel, kernel_size=3, padding=1)

    def forward(self, x, heatmaps):
        """Run both stages on ``x``, conditioning each on the same ``heatmaps``."""
        stages = ((self.spade1, self.conv1), (self.spade2, self.conv2))
        for spade, conv in stages:
            x = conv(F.leaky_relu(spade(x, heatmaps), 0.2))
        return x


class LinearSpectralNormLeakyReLU(nn.Module):
    """Fully connected layer followed by LeakyReLU with slope 0.2.

    NOTE(review): no spectral normalisation is applied here even though the
    class name mentions it; the layer is a plain ``nn.Linear``.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        linear = nn.Linear(in_channels, out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2)
        # Kept inside a Sequential named ``model`` so parameter names stay
        # ``model.0.weight`` / ``model.0.bias``.
        self.model = nn.Sequential(linear, activation)

    def forward(self, x):
        """Apply the linear layer and the activation to ``x``."""
        return self.model(x)


class Generator(nn.Module):
    """Keypoint-conditioned image generator with progressive output stages.

    Three noise vectors are mapped to keypoint coordinates, a per-sample
    keypoint embedding and a background embedding. Keypoints are rendered to
    heatmaps by ``gen_heatmaps`` at every resolution, combined with the
    embeddings, and used to condition a stack of ``SPADEResBlk`` layers that
    grows a learned constant from 4x4 up to ``image_sizes[stage]``.

    ``hyper_paras`` must provide the keys ``'z_dim'``, ``'n_keypoints'``,
    ``'n_embedding'`` and ``'tau'``.

    The misspelled public method names (``gen_atrributes``,
    ``use_atrributes*``) are kept as they are so existing callers keep working.
    """

    def __init__(self, hyper_paras):
        super().__init__()
        self.z_dim = hyper_paras['z_dim']
        self.n_keypoints = hyper_paras['n_keypoints']
        self.n_embedding = hyper_paras['n_embedding']
        self.tau = hyper_paras['tau']
        self.noise_shapes = [(self.z_dim,)] * 3

        self.keypoints_embedding = nn.Embedding(self.n_keypoints, self.n_embedding)

        # Creation order is kept stable so seeded initialisation is reproducible.
        self.gen_keypoints_embedding_noise = self._make_mlp(self.n_embedding)
        self.gen_keypoints_layer = self._make_mlp(self.n_keypoints * 2)
        self.gen_background_embedding = self._make_mlp(self.n_embedding)

        # Learned constant the synthesis network starts from.
        self.start = nn.Parameter(torch.randn(1, 512, 4, 4), requires_grad=True)

        self.pre_image_sizes = [4, 8, 16, 32]
        self.pre_spade_blocks = nn.ModuleList([
            SPADEResBlk(512, 512, self.n_embedding, self.n_keypoints)
            for _ in self.pre_image_sizes
        ])

        self.image_sizes = [64, 128, 256, 512, 1024]
        # channels[i] -> channels[i + 1] is the block used at image_sizes[i].
        channels = [512, 256, 128, 64, 32, 16]
        self.spade_blocks = nn.ModuleList([
            SPADEResBlk(c_in, c_out, self.n_embedding, self.n_keypoints)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])
        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(c_out, 3, kernel_size=1) for c_out in channels[1:]
        ])

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0.2)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _make_mlp(self, out_features):
        """Build three ``z_dim``-wide LeakyReLU layers plus a linear projection."""
        return nn.Sequential(
            LinearSpectralNormLeakyReLU(self.z_dim, self.z_dim),
            LinearSpectralNormLeakyReLU(self.z_dim, self.z_dim),
            LinearSpectralNormLeakyReLU(self.z_dim, self.z_dim),
            nn.Linear(self.z_dim, out_features),
        )

    def forward(self, input_dict, stage, alpha, requires_penalty=False):
        """Generate an image and, optionally, a background-invariance penalty.

        Args:
            input_dict: Dict with ``'input_noise0'`` (keypoints),
                ``'input_noise1'`` (background) and ``'input_noise2'``
                (keypoint embedding) noise tensors whose last dim is ``z_dim``.
            stage: Index into ``image_sizes`` selecting the output resolution.
            alpha: Blend factor between the previous and current RGB heads.
            requires_penalty: When true, re-render with a freshly sampled
                background embedding and penalise differences inside the
                keypoint heatmap region.

        Returns:
            Dict with ``'img'`` and ``'penalty_on_keypoints'`` (a zero scalar
            when ``requires_penalty`` is false).
        """
        attributes = self.gen_atrributes(input_dict)
        rgb_ori = self.use_atrributes(attributes, stage, alpha)['img']

        penalty_on_keypoints = torch.tensor(0.0, device=rgb_ori.device)

        if requires_penalty:
            bg_embed_noise = input_dict['input_noise1']
            keypoint_heatmap = gen_heatmaps(
                attributes['keypoints'], heatmap_size=self.image_sizes[stage], tau=self.tau)
            # Union of all keypoint heatmaps, used as a soft foreground mask.
            heatmap = keypoint_heatmap.max(dim=1, keepdim=True)[0]

            # Same keypoints and keypoint embedding, different random background.
            background_embeddings = self.gen_background_embedding(torch.randn_like(bg_embed_noise))
            attributes2 = {
                'keypoints': attributes['keypoints'],
                'kp_emb': attributes['kp_emb'],
                'bg_emb': background_embeddings,
            }
            rgb = self.use_atrributes(attributes2, stage, alpha)['img']
            penalty_on_keypoints = F.mse_loss(rgb_ori * heatmap, rgb * heatmap) * 100

        return {'img': rgb_ori, 'penalty_on_keypoints': penalty_on_keypoints}

    def gen_keypoints(self, input_dict):
        """Map ``input_dict['input_noise0']`` to keypoints of shape ``(-1, n_keypoints, 2)``.

        The raw output is divided by 20 and squashed with ``tanh``, so the
        coordinates lie in ``[-1, 1]``.
        """
        z = input_dict['input_noise0']
        raw = self.gen_keypoints_layer(z).reshape(-1, self.n_keypoints, 2)
        return torch.tanh(raw / 20)

    def gen_keypoints_feat(self, kp_embed_noise):
        """Return per-keypoint embeddings of shape ``(batch, n_keypoints, n_embedding)``.

        The learned per-keypoint embedding is multiplied element-wise by a
        per-sample vector produced from ``kp_embed_noise``.
        """
        # The fixed table is the same for every sample: look it up once and let
        # broadcasting expand it over the batch instead of repeating indices.
        indices = torch.arange(self.n_keypoints, device=kp_embed_noise.device)
        keypoints_fixed_embeddings = self.keypoints_embedding(indices).unsqueeze(0)
        sample_embeddings = self.gen_keypoints_embedding_noise(kp_embed_noise)
        return keypoints_fixed_embeddings * sample_embeddings.unsqueeze(1)

    def gen_atrributes(self, input_dict):
        """Turn the three noise inputs into ``'keypoints'``, ``'kp_emb'`` and ``'bg_emb'``."""
        keypoints = self.gen_keypoints(input_dict)
        bg_embed_noise, kp_embed_noise = input_dict['input_noise1'], input_dict['input_noise2']
        keypoints_embeddings = self.gen_keypoints_feat(kp_embed_noise)
        background_embeddings = self.gen_background_embedding(bg_embed_noise)
        return {'keypoints': keypoints, 'kp_emb': keypoints_embeddings, 'bg_emb': background_embeddings}

    @staticmethod
    def _embed_heatmaps(keypoint_heatmap, embeddings, size):
        """Combine keypoint heatmaps with embeddings into a SPADE conditioning map.

        A background channel ``1 - max_k heatmap_k`` is appended, every channel
        is multiplied by its embedding vector, and the (channel, embedding)
        axes are flattened into one channel axis of spatial size ``size``.
        """
        inv_heatmaps = 1 - keypoint_heatmap.max(dim=1, keepdim=True)[0]
        keypoint_heatmap = torch.cat((keypoint_heatmap, inv_heatmaps), dim=1)
        heatmaps = keypoint_heatmap.unsqueeze(2) * embeddings.unsqueeze(-1).unsqueeze(-1)
        return heatmaps.reshape(
            keypoint_heatmap.shape[0], keypoint_heatmap.shape[1] * embeddings.shape[2], size, size)

    def _synthesize(self, input_dict, stage, alpha, heatmap_fn):
        """Shared synthesis loop behind every public ``use_atrributes*`` method.

        Args:
            input_dict: Dict with ``'keypoints'``, ``'kp_emb'`` and ``'bg_emb'``.
            stage: Index into ``image_sizes`` of the resolution to output.
            alpha: Blend factor; ignored when ``stage == 0``. ``0`` returns the
                previous stage's RGB head applied to the upsampled features,
                ``1`` returns the current head, anything else blends the two.
            heatmap_fn: Callable mapping an image size to the keypoint heatmaps
                for that size (expected to have one channel per keypoint).

        Returns:
            Dict with the single key ``'img'``.

        Raises:
            IndexError: If ``stage`` is not a valid index into ``image_sizes``.
        """
        n_stages = len(self.image_sizes)
        if not 0 <= stage < n_stages:
            # Fail before running the network; previously this surfaced only
            # at the end as an opaque list-index error.
            raise IndexError(
                'stage must be in [0, {}], got {}'.format(n_stages - 1, stage))

        keypoints = input_dict['keypoints']
        embeddings = torch.cat(
            (input_dict['kp_emb'], input_dict['bg_emb'].unsqueeze(1)), dim=1)

        x = self.start.expand(keypoints.shape[0], -1, -1, -1)

        for image_size, block in zip(self.pre_image_sizes, self.pre_spade_blocks):
            heatmaps = self._embed_heatmaps(heatmap_fn(image_size), embeddings, x.shape[-1])
            x = block(x, heatmaps)
            x = F.interpolate(x, size=image_size * 2, mode='bilinear', align_corners=False)

        for i, (image_size, block) in enumerate(zip(self.image_sizes, self.spade_blocks)):
            heatmaps = self._embed_heatmaps(heatmap_fn(image_size), embeddings, x.shape[-1])
            # Always evaluated, even when only the old head is returned, so
            # normalisation layers inside the block see the same forward passes.
            x_new = block(x, heatmaps)

            if i < stage:
                x = F.interpolate(
                    x_new, size=self.image_sizes[i + 1], mode='bilinear', align_corners=False)
                continue

            # i == stage from here on.
            if stage == 0:
                return {'img': self.to_rgbs[0](x_new)}
            if alpha == 0:
                return {'img': self.to_rgbs[stage - 1](x)}
            img_new = self.to_rgbs[stage](x_new)
            if alpha == 1:
                return {'img': img_new}
            img_old = self.to_rgbs[stage - 1](x)
            return {'img': alpha * img_new + (1 - alpha) * img_old}

    def use_atrributes(self, input_dict, stage, alpha):
        """Render an image from generated attributes.

        Args:
            input_dict: Dict with ``'keypoints'``, ``'kp_emb'`` and ``'bg_emb'``
                as returned by ``gen_atrributes``.
            stage: Index into ``image_sizes`` of the resolution to output.
            alpha: Blend factor between the previous and the current RGB head.

        Returns:
            Dict with the single key ``'img'``.
        """
        keypoints = input_dict['keypoints']

        def heatmap_fn(image_size):
            return gen_heatmaps(keypoints, heatmap_size=image_size, tau=self.tau)

        return self._synthesize(input_dict, stage, alpha, heatmap_fn)

    def use_atrributes_two_object(self, input_dict, dist, stage, alpha):
        """Render two copies of the keypoint set, offset by ``-dist`` and ``+dist``.

        The offset is applied to the second keypoint coordinate; the two
        resulting heatmaps are merged with an element-wise maximum.
        """
        keypoints = input_dict['keypoints']
        first, second = keypoints[:, :, 0:1], keypoints[:, :, 1:2]
        # The shifted keypoints do not depend on resolution: build them once.
        keypoints_minus = torch.cat([first, second - dist], dim=2)
        keypoints_plus = torch.cat([first, second + dist], dim=2)

        def heatmap_fn(image_size):
            heatmap_minus = gen_heatmaps(keypoints_minus, heatmap_size=image_size, tau=self.tau)
            heatmap_plus = gen_heatmaps(keypoints_plus, heatmap_size=image_size, tau=self.tau)
            return torch.stack([heatmap_minus, heatmap_plus], dim=2).max(dim=2)[0]

        return self._synthesize(input_dict, stage, alpha, heatmap_fn)

    def use_atrributes_multi_parts(self, input_dict, input_dict2, kp_idices, kp_n_pos, stage, alpha):
        """Render with extra copies of selected keypoints at additional positions.

        Args:
            input_dict: Dict with ``'keypoints'``, ``'kp_emb'`` and ``'bg_emb'``.
            input_dict2: Unused; kept for signature compatibility.
            kp_idices: Keypoint channel indices to augment.
            kp_n_pos: For each entry of ``kp_idices``, the positions passed to
                ``gen_heatmaps`` whose heatmaps are max-merged into that channel.
            stage: Index into ``image_sizes`` of the resolution to output.
            alpha: Blend factor between the previous and the current RGB head.
        """
        keypoints = input_dict['keypoints']

        def heatmap_fn(image_size):
            keypoint_heatmap = gen_heatmaps(keypoints, heatmap_size=image_size, tau=self.tau)
            for i_kp, kp_idx in enumerate(kp_idices):
                extra = gen_heatmaps(kp_n_pos[i_kp], heatmap_size=image_size, tau=self.tau)
                channel = keypoint_heatmap[:, kp_idx:kp_idx + 1, :, :]
                merged = torch.cat([extra, channel], dim=1).max(dim=1, keepdim=True)[0]
                # Written back in place, as before.
                keypoint_heatmap[:, kp_idx:kp_idx + 1, :, :] = merged
            return keypoint_heatmap

        return self._synthesize(input_dict, stage, alpha, heatmap_fn)

    def use_atrributes_removing_parts(self, input_dict, kp_indices, stage, alpha):
        """Render with the heatmap channels listed in ``kp_indices`` zeroed out."""
        keypoints = input_dict['keypoints']

        def heatmap_fn(image_size):
            keypoint_heatmap = gen_heatmaps(keypoints, heatmap_size=image_size, tau=self.tau)
            for kp_idx in kp_indices:
                # In-place erase of the selected channel, as before.
                keypoint_heatmap[:, kp_idx:kp_idx + 1, :, :] = 0
            return keypoint_heatmap

        return self._synthesize(input_dict, stage, alpha, heatmap_fn)


if __name__ == '__main__':
    # Smoke test: build a generator, report its trainable parameter count,
    # then list the top-level submodules while freezing the three noise MLPs.
    hyper_paras = {'z_dim': 256, 'n_keypoints': 10, 'n_embedding': 32, 'tau': 0.01}
    model = Generator(hyper_paras)

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(n_trainable)

    frozen_children = ('gen_keypoints_embedding_noise', 'gen_keypoints_layer', 'gen_background_embedding')
    for name, layer in model.named_children():
        print(name)
        if name not in frozen_children:
            continue
        for parameter in layer.parameters():
            parameter.requires_grad = False