import torch
import torch.nn as nn
import numpy as np
import random

class Where2comm(nn.Module):
    """Select which spatial locations to communicate from per-pixel confidence maps.

    For each sample, the confidence map is the max over the ``L`` input
    channels of the sigmoid of the input. It is optionally smoothed with a
    fixed Gaussian kernel. The top ``K`` pixels are then kept, where ``K`` is
    a random fraction of ``H * W`` drawn uniformly from
    ``args['compression_ratio']``.

    Args:
        args: Config dict with keys
            ``'smooth'``: truthy to smooth the confidence map before top-k
                selection.
            ``'compression_ratio'``: ``(low, high)`` bounds of the fraction of
                pixels kept per sample.
            ``'gaussian_smooth'`` (optional): ``{'k_size': int, 'c_sigma': float}``
                describing the smoothing kernel. Required when ``'smooth'`` is
                truthy.

    Raises:
        ValueError: if ``'smooth'`` is truthy but ``'gaussian_smooth'`` is absent.
    """

    def __init__(self, args):
        super(Where2comm, self).__init__()

        self.smooth = args['smooth']
        self.compression_ratio = args['compression_ratio']
        if 'gaussian_smooth' in args:
            kernel_size = args['gaussian_smooth']['k_size']
            c_sigma = args['gaussian_smooth']['c_sigma']
            self.gaussian_filter = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2)
            self.init_gaussian_filter(kernel_size, c_sigma)
            # Freeze the kernel. Assigning ``module.requires_grad = False`` only
            # creates a plain attribute; the parameters themselves must be frozen.
            self.gaussian_filter.requires_grad_(False)
        elif self.smooth:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError("args['smooth'] is set but args['gaussian_smooth'] is missing")

    def init_gaussian_filter(self, k_size=5, sigma=1):
        """Load a ``k_size x k_size`` Gaussian kernel into ``self.gaussian_filter``.

        The kernel is the 2-D normal density with standard deviation ``sigma``,
        sampled on an integer grid centred on the kernel. The bias is zeroed.

        Args:
            k_size: Kernel side length (must match the Conv2d kernel size).
            sigma: Standard deviation of the Gaussian, in pixels.
        """
        center = k_size // 2
        x, y = np.mgrid[-center:k_size - center, -center:k_size - center]
        variance = np.square(sigma)
        kernel = np.exp(-(np.square(x) + np.square(y)) / (2 * variance)) / (2 * np.pi * variance)

        weight = self.gaussian_filter.weight
        kernel_tensor = torch.as_tensor(kernel, dtype=weight.dtype, device=weight.device)
        # In-place copy keeps the existing Parameter objects (and their dtype/device).
        with torch.no_grad():
            weight.copy_(kernel_tensor.view_as(weight))
            self.gaussian_filter.bias.zero_()

    def forward(self, batch_confidence_maps):
        """Build per-sample communication masks from confidence maps.

        Args:
            batch_confidence_maps: 4-D tensor of shape ``(B, L, H, W)``
                holding confidence logits.

        Returns:
            A tuple ``(batch_communication_maps, communication_masks, communication_rates)``:
            - ``batch_communication_maps``: list of ``B`` tensors of shape
              ``(1, H, W)``. Each is the unsmoothed confidence map multiplied
              by its mask.
            - ``communication_masks``: ``(B, H, W)`` tensor with 1 at the
              selected pixels and 0 elsewhere.
            - ``communication_rates``: 0-dim tensor, the mean over the batch
              of the fraction of pixels selected.
        """
        B, L, H, W = batch_confidence_maps.shape
        num_pixels = H * W
        low, high = self.compression_ratio[0], self.compression_ratio[1]

        communication_masks = []
        communication_rates = []
        batch_communication_maps = []
        for b in range(B):
            # (1, H, W): per-pixel max over the L channels of the sigmoid confidence.
            ori_communication_maps = batch_confidence_maps[b].sigmoid().max(dim=0)[0].unsqueeze(0)

            if self.smooth:
                # Give Conv2d an explicit (N=1, C=1, H, W) batch instead of relying
                # on unbatched-input support.
                communication_maps = self.gaussian_filter(ori_communication_maps.unsqueeze(0)).squeeze(0)
            else:
                communication_maps = ori_communication_maps

            # Number of pixels to keep. Clamp it so an out-of-range ratio
            # cannot make topk fail.
            K = int(num_pixels * random.uniform(low, high))
            K = min(max(K, 0), num_pixels)

            flat_maps = communication_maps.reshape(1, num_pixels)
            _, indices = torch.topk(flat_maps, k=K, sorted=False)
            # Constant 0/1 mask: topk indices carry no gradient.
            communication_mask = torch.zeros_like(flat_maps).scatter_(-1, indices, 1.0).reshape(1, H, W)

            communication_rate = communication_mask[0].sum() / num_pixels

            communication_masks.append(communication_mask)
            communication_rates.append(communication_rate)
            batch_communication_maps.append(ori_communication_maps * communication_mask)
        communication_rates = sum(communication_rates) / B
        communication_masks = torch.cat(communication_masks, dim=0)
        return batch_communication_maps, communication_masks, communication_rates