import torch.nn.functional as F
import torch.nn as nn

import torch
import math


class CompactLinear(nn.Module):
    """Linear layer whose weight and bias share one parameter tensor.

    The single parameter ``compact_weight`` has shape
    ``(out_features, in_features + 1)`` when the bias is enabled; its first
    ``in_features`` columns are the weight matrix and the last column is the
    bias vector. Without a bias it has shape ``(out_features, in_features)``
    and is used as the weight directly.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Truthy to reserve the extra bias column. Note that the
            ``bias`` attribute stores this flag, not a tensor.
        device: Forwarded to ``torch.empty`` when allocating the parameter.
        dtype: Forwarded to ``torch.empty`` when allocating the parameter.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

        # One trailing column is appended to hold the bias when it is enabled.
        n_columns = in_features + 1 if bias else in_features
        storage = torch.empty((out_features, n_columns), device=device, dtype=dtype)
        self.compact_weight = nn.Parameter(storage)

        self.reset_parameters()

    def _split_compact(self):
        """Return the ``(weight, bias)`` slices of ``compact_weight``.

        Only meaningful when the layer was built with a bias column.
        """
        split_at = self.in_features
        packed = self.compact_weight
        return packed[:, :split_at], packed[:, split_at]

    def reset_parameters(self):
        """Re-initialise ``compact_weight`` in place.

        The weight columns get ``kaiming_uniform_`` with ``a=sqrt(5)``; the
        bias column (if any) is drawn uniformly from ``[-bound, bound]`` with
        ``bound = 1 / sqrt(fan_in)`` (or 0 when ``fan_in`` is 0).
        """
        if not self.bias:
            nn.init.kaiming_uniform_(self.compact_weight, a=math.sqrt(5))
            return

        weight_part, bias_part = self._split_compact()

        # Weight first, then bias: keeps the random draw order stable.
        nn.init.kaiming_uniform_(weight_part, a=math.sqrt(5))

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight_part)
        if fan_in > 0:
            bound = 1 / math.sqrt(fan_in)
        else:
            bound = 0
        nn.init.uniform_(bias_part, -bound, bound)

    def forward(self, inputs):
        """Apply ``F.linear`` to ``inputs`` using the packed weight (and bias)."""
        if not self.bias:
            return F.linear(inputs, self.compact_weight)

        weight_part, bias_part = self._split_compact()
        return F.linear(inputs, weight_part, bias_part)

    def extra_repr(self):
        """Return the feature sizes and bias flag shown in the module repr."""
        template = "in_features={}, out_features={}, bias={}"
        return template.format(self.in_features, self.out_features, self.bias)
