import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class PixelWeightedFusionSoftmax(nn.Module):
    r'''
    Per-pixel fusion-score head built from four 1x1 convolutions
    (channel -> 128 -> 32 -> 8 -> 1), with BatchNorm + ReLU after the first
    three and a plain ReLU after the last.

    This corresponds to the following equation mentioned in line 226 of the paper.
    $
    W_j = \frac{\phi_{\text{conv}}\left([B_k; Z_{j \rightarrow k}^\prime]\right)}{\sum_{j=1}^N \phi_{\text{conv}}\left([B_k; Z_{j \rightarrow k}^\prime]\right)}
    \in \mathbb{R}^\mathbf{X \times Y \times 1}
    $
    Only the un-normalised, non-negative score is produced by this module;
    no normalisation over agents happens inside it.

    The docstring is a raw string on purpose: the LaTeX above contains
    backslash sequences that would otherwise be read as string escapes.

    Args:
        channel: number of channels of the (concatenated) input feature map.
    '''
    def __init__(self, channel):
        super().__init__()
        self.conv1_1 = nn.Conv2d(channel, 128, kernel_size=1, stride=1, padding=0)
        self.bn1_1 = nn.BatchNorm2d(128)
        self.conv1_2 = nn.Conv2d(128, 32, kernel_size=1, stride=1, padding=0)
        self.bn1_2 = nn.BatchNorm2d(32)
        self.conv1_3 = nn.Conv2d(32, 8, kernel_size=1, stride=1, padding=0)
        self.bn1_3 = nn.BatchNorm2d(8)
        self.conv1_4 = nn.Conv2d(8, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        '''
        Score every pixel of ``x``.

        Args:
            x: tensor whose last three dimensions are (channel, H, W); all
               leading dimensions (if any) are folded into the batch axis.

        Returns:
            Tensor of shape (batch, 1, H, W) with non-negative entries.
        '''
        # Fold any leading dimensions into one batch axis. reshape (rather
        # than view) also accepts inputs that are not contiguous.
        x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        x_1 = F.relu(self.bn1_1(self.conv1_1(x)))
        x_1 = F.relu(self.bn1_2(self.conv1_2(x_1)))
        x_1 = F.relu(self.bn1_3(self.conv1_3(x_1)))
        x_1 = F.relu(self.conv1_4(x_1))
        return x_1

class CompressNet(nn.Module):
    '''
    Feature compression mentioned in line 204 of the paper.

    Two 1x1 convolutions squeeze a ``channel``-deep feature map down to a
    single channel: channel -> channel // 8 (BatchNorm + ReLU) -> 1 (ReLU).
    '''
    def __init__(self, channel):
        super(CompressNet, self).__init__()
        hidden = channel // 8
        self.conv1_1 = nn.Conv2d(channel, hidden, kernel_size=1, stride=1, padding=0)
        self.bn1_1 = nn.BatchNorm2d(hidden)
        self.conv1_2 = nn.Conv2d(hidden, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        '''Return the single-channel, non-negative compression of ``x``.'''
        reduced = self.conv1_1(x)
        reduced = F.relu(self.bn1_1(reduced))
        squeezed = self.conv1_2(reduced)
        return F.relu(squeezed)

class stack_channel(nn.Conv2d):
    '''
    The implementation of the sliding window used to calculate the discrepancy
    between each position's features and the central features.

    A fixed (non-trainable) bank of one-hot kernels is stored in
    ``square_dis``: output channel ``i`` has a single 1 at row
    ``i // kernel_size``, column ``i % kernel_size`` of its
    kernel_size x kernel_size window. Convolving a single-channel map with
    this bank copies, for every pixel, the value found at that window offset
    into channel ``i`` (zero outside the map).

    Args:
        in_channels, stride, padding, bias: forwarded to ``nn.Conv2d`` only;
            ``forward`` does not use the inherited convolution weights.
        out_channels: number of window positions to extract. Values larger
            than ``kernel_size ** 2`` raise IndexError while the bank is built.
        kernel_size: integer side length of the square window.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, bias=bias)
        square_dis = np.zeros((out_channels, kernel_size, kernel_size))
        for i in range(out_channels):
            # Row-major position of the i-th window cell. The window width is
            # kernel_size (it used to be hard-coded to 7, which indexed out of
            # bounds for any smaller window).
            row, col = divmod(i, kernel_size)
            square_dis[i, row, col] = 1
        # Kept 3-D (out_channels, k, k) so the state_dict entry keeps its shape.
        self.square_dis = nn.Parameter(torch.Tensor(square_dis), requires_grad=False)

    def forward(self, x):
        '''
        Extract the window neighbourhood of every pixel.

        Args:
            x: single-channel input accepted by ``F.conv2d``, e.g. (B, 1, H, W).

        Returns:
            Tensor with ``out_channels`` channels; for odd ``kernel_size`` the
            spatial size equals that of ``x``.
        '''
        # conv2d needs a 4-D weight (out_channels, in_channels / groups, kH, kW).
        kernel = self.square_dis.detach().unsqueeze(1)
        # "Same" padding for odd windows: 3 for the 7x7 case, 1 for 3x3.
        same_padding = kernel.size(-1) // 2
        stack = F.conv2d(x, kernel, stride=1, padding=same_padding, groups=1)
        return stack
class Gaussian_based_interpolation(nn.Conv2d):
    '''
    Learnable Gaussian-based interpolation that infills undefined values in the
    sparse transmitted features (line 221 of the paper).

    The input first goes through the inherited convolution. Positions where
    ``mask`` is 1 keep that result; positions where ``mask`` is 0 receive a
    weighted sum of the kept values inside a 7x7 neighbourhood. The weights
    are exp(-Lambda^2 * d^2), with d^2 the squared offset from the window
    centre, divided by their sum (plus 1e-5). ``Lambda`` is a trainable
    scalar initialised to 3.0. The centre entry of the distance table is set
    to 100000 so that the centre weight is effectively zero.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, interplate='none'):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.interpolate = interplate
        self.r = 7
        self.padding_interpolate = 3
        self.Lambda = nn.Parameter(torch.tensor(3.0))
        # Squared Euclidean distance of every window cell to the window centre.
        centre = self.r // 2
        offsets = np.arange(self.r) - centre
        dist_sq = (offsets[:, None] ** 2 + offsets[None, :] ** 2).astype(np.float64)
        # Huge distance at the centre -> its Gaussian weight underflows to 0.
        dist_sq[centre, centre] = 100000.0
        self.square_dis = nn.Parameter(torch.Tensor(dist_sq), requires_grad=False)

    def forward(self, x, mask):
        '''
        Convolve ``x`` and fill the positions where ``mask`` is 0.

        Args:
            x: input of the inherited convolution.
            mask: tensor broadcastable against the convolution output;
                  1 marks values to keep, 0 marks values to interpolate.

        Returns:
            Tensor shaped like the convolution output.
        '''
        conv_out = super().forward(x)
        # The output size is recorded on the instance on every call.
        self.out_h, self.out_w = conv_out.shape[-2:]
        weights = torch.exp(-(self.Lambda ** 2) * self.square_dis.detach())
        weights = weights / (weights.sum() + 10 ** (-5))
        # One identical kernel per channel, applied depth-wise below.
        weights = weights.expand((self.out_channels, 1, weights.size(0), weights.size(1)))
        kept = conv_out * mask
        filled = F.conv2d(kept, weights, stride=1, padding=self.padding_interpolate, groups=self.out_channels)
        return kept + filled * (1 - mask)
class Complementary_Calculation(nn.Module):
    '''
    The implementation of the complementary calculation for selective
    collaborative interactions, which is mentioned in lines 216 to 218.

    ``forward`` scores every position as (1 - ego_conf) * nb_conf and returns
    the indices of the positions whose score reaches the threshold found at
    rank ``int(w * h * delta)`` of the descending score list.
    '''
    def __init__(self):
        super().__init__()
        self.masker = Gaussian_based_interpolation(64, 64, kernel_size=3, padding=1)

    def interplot_f(self, feature, masker):
        '''
        Keep ``feature`` at the selected positions and interpolate the rest.

        Args:
            feature: (C, H, W) feature map.
            masker: pair of index tensors (rows, cols) of the positions to keep.

        Returns:
            (C', H, W) tensor produced by ``self.masker``.
        '''
        masker_t = torch.zeros_like(feature)
        masker_t[:, masker[0], masker[1]] = 1
        masker_f = masker_t[None, :, :, :].float()
        inter = self.masker(feature.unsqueeze(0), masker_f)
        # Drop only the batch axis; a bare squeeze() would also drop a
        # spatial axis of size 1.
        return inter.squeeze(0)

    def forward(self, ego_conf, nb_conf, delta=0.25):
        '''
        Select the positions the ego agent requests from a neighbour.

        Args:
            ego_conf: ego confidence map; its last two dims are taken as (w, h).
            nb_conf: neighbour confidence map, broadcastable with ``ego_conf``.
            delta: fraction of the w * h positions that sets the threshold rank.

        Returns:
            Tuple of index tensors as returned by ``torch.where`` (one per
            dimension of the score map). Because the comparison is ``>=``,
            ties with the threshold value are all selected.
        '''
        w = ego_conf.shape[-2]
        h = ego_conf.shape[-1]
        ego_request = 1 - ego_conf
        att_map = ego_request * nb_conf
        # torch.sort returns (values, indices); the threshold must be read
        # from the sorted values, not from that 2-tuple.
        sorted_scores, _ = torch.sort(att_map.reshape(-1), descending=True)
        # Clamp so that delta == 1.0 picks the smallest score instead of
        # indexing one past the end.
        rank = min(int(w * h * delta), sorted_scores.numel() - 1)
        self_holder = sorted_scores[rank]
        masker = torch.where(att_map >= self_holder)
        return masker

class SISW(nn.Module):
    '''
    The implementation of Sparse Interaction via Sliding Windows Module.

    For every (batch, time) slot the feature map of agent 0 (the ego agent)
    is fused with the maps of the other agents:

    1. Both maps are compressed to one channel and turned into a per-pixel
       "information volume" (mean sigmoid of the difference between the 3x3
       neighbourhood and the centre value).
    2. ``Complementary_Calculation`` picks the positions to transmit.
    3. The neighbour map is kept at those positions and Gaussian-interpolated
       elsewhere.
    4. A per-pixel softmax over agents of the ``PixelWeightedFusion`` scores
       weights the ego map and the processed neighbour maps.

    The sub-modules are built for 64-channel feature maps.
    '''
    def __init__(self, layer=3):
        super().__init__()
        self.agent_num = 4
        self.layer = layer
        self.PixelWeightedFusion = PixelWeightedFusionSoftmax(64 * 2)
        self.compress = CompressNet(64)
        self.stack = stack_channel(1, 9, kernel_size=3, padding=1)
        self.attcoll = Complementary_Calculation()
        self.masker = Gaussian_based_interpolation(64, 64, kernel_size=3, padding=1)

    def generate_information_volume(self, input_agent):
        '''
        Per-pixel information volume of a compressed single-channel map.

        Args:
            input_agent: tensor whose last two dims are (w, h) and which
                ``self.stack`` accepts (used here with shape (1, 1, w, h)).

        Returns:
            (w, h) tensor: mean over the window of
            sigmoid(window value - centre value).
        '''
        w = input_agent.shape[-2]
        h = input_agent.shape[-1]
        centre = input_agent.reshape(-1, 1, 1, 1)
        windows = self.stack(input_agent)
        n_cells = windows.size(1)
        # One row per pixel, one column per window cell.
        windows = windows.permute(2, 3, 1, 0).contiguous().reshape(-1, n_cells, 1, 1)
        p = torch.sigmoid(windows - centre).mean(dim=1).reshape(w, h)
        return p

    def interplot_f(self, feature, masker):
        '''
        Keep ``feature`` at the selected positions and interpolate the rest.

        Args:
            feature: (C, H, W) feature map.
            masker: pair of index tensors (rows, cols) of the positions to keep.

        Returns:
            (C, H, W) tensor produced by ``self.masker``.
        '''
        masker_t = torch.zeros_like(feature)
        masker_t[:, masker[0], masker[1]] = 1
        masker_f = masker_t[None, :, :, :].float()
        inter = self.masker(feature.unsqueeze(0), masker_f)
        # Drop only the batch axis; a bare squeeze() would also drop a
        # spatial axis of size 1.
        return inter.squeeze(0)

    def generate_information_volume_noise(self, input_agent):
        '''
        Same as ``generate_information_volume`` plus Gaussian noise of
        standard deviation 1e-8 (used on the training path).
        '''
        p = self.generate_information_volume(input_agent)
        noise_level = 1e-8
        # randn_like draws the noise on p's device and in p's dtype;
        # torch.randn(p.size()) would always allocate on the CPU.
        noise = torch.randn_like(p) * noise_level
        return p + noise

    def forward(self, bevs, istrain=True):
        '''
        Fuse the agents' feature maps into the ego (agent 0) view.

        Args:
            bevs: tensor of shape (B_T, N, C, H, W). B_T must be a multiple
                of 3 (three time steps are folded into the first axis) and C
                must match the 64-channel sub-modules.
            istrain: when True the information volumes get a tiny noise term.

        Returns:
            Tuple ``(feat_fuse_mat, bandwidth)``: the fused ego features of
            shape (B_T, C, H, W) and the mean transmitted ratio as a Python
            float (0.0 when nothing was transmitted).

        Raises:
            RuntimeError: if a neighbour feature cannot be concatenated with
                the ego feature (shape mismatch), or if ``bevs`` cannot be
                viewed as (B_T // 3, 3, N, C, H, W).
        '''
        B_T, N, C, H, W = bevs.shape
        T = 3
        batch_size = B_T // T
        local_com_mat = bevs.view(batch_size, T, N, C, H, W)
        # Only the ego slot is ever written and returned, so allocate just
        # that; every (b, t) entry is assigned in the loop below.
        fused = bevs.new_empty((batch_size, T, C, H, W))
        bandwidth = []
        # Never index past the agents actually present in the input.
        num_agent = min(self.agent_num, N)
        ego = 0
        if istrain:
            information_volume = self.generate_information_volume_noise
        else:
            information_volume = self.generate_information_volume

        for b in range(batch_size):
            for t in range(T):
                tg_agent = local_com_mat[b, t, ego]
                neighbor_feat_list = [tg_agent]
                for j in range(num_agent):
                    if j == ego:
                        continue
                    nb_agent = local_com_mat[b, t, j].unsqueeze(0)
                    # An all-zero map gives min + max == 0; such a slot is
                    # passed through untouched and costs no bandwidth.
                    if nb_agent.min() + nb_agent.max() == 0:
                        neighbor_feat_list.append(nb_agent[0])
                        continue
                    tg_agent_com = self.compress(tg_agent.unsqueeze(0))
                    warp_feat_trans_com = self.compress(nb_agent)
                    tg_entropy = information_volume(tg_agent_com)
                    nb_entropy = information_volume(warp_feat_trans_com)
                    selection = self.attcoll(tg_entropy, nb_entropy)
                    # NOTE(review): 40000 is a fixed normaliser, not H * W;
                    # presumably a reference grid size - confirm.
                    bandwidth.append(len(selection[0]) / 40000)
                    neighbor_feat_list.append(self.interplot_f(nb_agent.squeeze(0), selection))

                scores = []
                for feat in neighbor_feat_list:
                    try:
                        cat_feat = torch.cat([tg_agent, feat], dim=0)
                    except RuntimeError as err:
                        # Fail loudly with both shapes instead of printing
                        # them and carrying on with a stale tensor.
                        raise RuntimeError(
                            "cannot concatenate ego feature {} with neighbour feature {}".format(
                                tuple(tg_agent.shape), tuple(feat.shape))
                        ) from err
                    scores.append(self.PixelWeightedFusion(cat_feat)[0, 0])

                # Per-pixel softmax across agents: (K, H, W).
                agent_weights = torch.softmax(torch.stack(scores, dim=0), dim=0)
                feats = torch.stack(neighbor_feat_list, dim=0)
                # Broadcast each agent's (H, W) weight over its C channels.
                fused[b, t] = (agent_weights.unsqueeze(1) * feats).sum(dim=0)

        # weighted feature maps is passed to decoder
        feat_fuse_mat = fused.view(-1, C, H, W)
        mean_bandwidth = float(np.mean(bandwidth)) if bandwidth else 0.0
        return (feat_fuse_mat, mean_bandwidth)
