# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import distributed as dist

from ..linear import LinearBuilder, LinearImpl


class DefaultLinearImpl(LinearImpl):
    """Linear implementation api."""

    def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
        """forward."""
        out = F.linear(x, weight, bias)
        if all_reduce:
            dist.all_reduce(out)
        return out


class DefaultLinearBuilder(LinearBuilder):
    """linear implementation builder."""

    @staticmethod
    def build(in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None):
        """build."""
        return DefaultLinearImpl()
