import torch
import math
from .batch_ops import batch_matmul, batch_bias_add
from .masked_module import MaskedModule


class Linear(MaskedModule):
    """Affine layer computed through the project's batch ops.

    Holds a ``weight`` parameter of shape ``(out_features, in_features)`` and,
    when ``bias`` is true, a ``bias`` parameter of shape ``(out_features,)``.
    The forward pass delegates the product and the bias addition to
    ``batch_matmul`` and ``batch_bias_add``.

    NOTE(review): how ``MaskedModule`` and the batch ops treat the parameters
    (e.g. masking, extra batch axes) is not visible from this class -- confirm
    against those modules.

    Args:
        in_features: size of the last dimension of the input.
        out_features: number of rows of ``weight``.
        bias: if false, no bias parameter is created and ``self.bias`` is
            ``None``.
    """

    __constants__ = ['in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialized; reset_parameters() fills both in below.
        self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.Tensor(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize ``weight`` and, if present, ``bias`` in place.

        ``weight`` is drawn with ``kaiming_uniform_`` using ``a=sqrt(5)``;
        ``bias`` is drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            # A layer with in_features == 0 has fan_in == 0; use a zero bound
            # instead of dividing by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the layer along the last dimension of ``input``.

        All leading dimensions are collapsed into one before the batch ops
        are called and restored on the result.

        Args:
            input: tensor with one or more dimensions.
                # assumes the last dimension equals in_features -- this is
                # not checked here; TODO confirm batch_matmul validates it.

        Returns:
            Tensor whose leading dimensions equal ``input.shape[:-1]`` and
            whose last dimension is the last dimension of the batch op result.
        """
        front_shape = input.shape[:-1]
        if input.dim() == 1:
            # flatten(end_dim=-2) is out of range for a 1-D tensor; add a
            # leading axis so an unbatched vector takes the same 2-D path.
            flat = input.unsqueeze(0)
        else:
            flat = input.flatten(end_dim=-2)

        res = batch_matmul(flat, self.weight.transpose(-1, -2))
        if self.bias is not None:
            res = batch_bias_add(res, self.bias)

        # reshape (not view) so a non-contiguous result from the batch ops
        # is still accepted; it returns a view whenever one is possible.
        return res.reshape(*front_shape, res.shape[-1])

    def extra_repr(self):
        """Return the layer configuration shown in the module's repr."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )