import torch
import torch.nn as nn
from torch.nn import functional as F
from opencood.models.comm_modules.mutual_communication_popego import CommunicationPopego


class How2comm(nn.Module):
    """Apply ``CommunicationPopego`` to a batch of stacked per-sample features.

    The incoming feature tensor holds several samples concatenated along
    dim 0. ``record_len`` gives how many rows belong to each sample. The
    tensor is cut into those per-sample chunks and handed to the
    communication module as a sequence. The module's per-chunk outputs are
    stitched back together along dim 0.
    """

    def __init__(self, args):
        """Build the wrapped communication module.

        Args:
            args: config mapping. ``args['dim']`` is stored as
                ``self.channel`` and forwarded as ``in_planes``. The whole
                mapping is also passed to ``CommunicationPopego``.
        """
        super().__init__()
        self.channel = args['dim']
        self.commu_module = CommunicationPopego(args, in_planes=self.channel)

    def regroup(self, x, record_len):
        """Split ``x`` along dim 0 into consecutive chunks sized by ``record_len``.

        Args:
            x: tensor whose dim 0 is the concatenation of all chunks.
            record_len: per-chunk row counts. Presumably a 1-D integer
                tensor summing to ``x.shape[0]``; TODO confirm with callers.

        Returns:
            Tuple of views into ``x``, one per entry of ``record_len``.
        """
        # Exclusive split points: the running total minus its final value.
        # tensor_split requires tensor indices to live on the CPU.
        split_points = torch.cumsum(record_len, dim=0)[:-1]
        return torch.tensor_split(x, split_points.cpu())

    def forward(self, feats, record_len):
        """Run communication per sample and re-concatenate the result.

        Args:
            feats: stacked features for every sample, concatenated on dim 0.
            record_len: per-sample row counts (see ``regroup``).

        Returns:
            Tuple ``(sparse_feats, commu_loss, commu_rate)``:
            - ``sparse_feats``: the module's per-sample outputs joined on dim 0.
            - ``commu_loss`` and ``commu_rate``: passed through unchanged from
              the communication module. Its fourth output (a mask) is discarded.
        """
        chunks = self.regroup(feats, record_len)
        sparse_chunks, loss, rate, _ = self.commu_module(chunks)
        return torch.cat(sparse_chunks, dim=0), loss, rate
