import random
from typing import List

import torch
from torch import nn

from mmdet3d.models.builder import FUSERS

__all__ = ["AddFuser"]


@FUSERS.register_module()
class AddFuser(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout = dropout

        self.transforms = nn.ModuleList()
        for k in range(len(in_channels)):
            self.transforms.append(
                nn.Sequential(
                    nn.Conv2d(in_channels[k], out_channels, 3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(True),
                )
            )

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        features = []
        for transform, input in zip(self.transforms, inputs):
            features.append(transform(input))

        weights = [1] * len(inputs)
        if self.training and random.random() < self.dropout:
            index = random.randint(0, len(inputs) - 1)
            weights[index] = 0

        return sum(w * f for w, f in zip(weights, features)) / sum(weights)
