"""
support both params with q_group and no q_group, quantize the linear layer
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.nn.functional as F
from torch.nn import Identity

# Quantizer slots built for every linear layer: "inputs" is applied to the
# layer input, "weights" to the weight matrix, and "features" to the output
# of the linear map.
LINEAR_QUANTIZER_KEYS = [
    "inputs",
    "features",
    "weights",
]

class LinearQuantized(nn.Linear):
    """A linear layer whose input, weight and output pass through quantizers.

    One quantizer module per key in ``LINEAR_QUANTIZER_KEYS`` is stored in
    ``self.layer_quant`` (an ``nn.ModuleDict``). The bias is used unquantized.
    ``forward`` accepts an optional ``q_group`` that is passed through to the
    input and output quantizers.
    """

    def __init__(
        self, in_features, out_features, layer_quantizers, bias=True, signed=True,
    ):
        """Create the layer and its quantizer modules.

        Args:
            in_features: size of each input sample (as for ``nn.Linear``).
            out_features: size of each output sample (as for ``nn.Linear``).
            layer_quantizers: two-element sequence of quantizer-factory
                mappings; index 0 is used when ``signed`` is true, index 1
                otherwise. The selected mapping must provide, for every key in
                ``LINEAR_QUANTIZER_KEYS``, a zero-argument callable returning
                an ``nn.Module``.
            bias: whether the layer has a learnable bias.
            signed: selects which entry of ``layer_quantizers`` is used; it is
                also reused by ``reset_quantizers``.
        """
        # nn.Linear.__init__ calls reset_parameters(), which builds the
        # quantizers, so the factories must be set before super().__init__.
        self.signed = signed
        self.layer_quant_fns = layer_quantizers[0] if self.signed else layer_quantizers[1]

        super().__init__(in_features, out_features, bias)

    def _build_quantizers(self):
        """Instantiate fresh quantizer modules from ``self.layer_quant_fns``."""
        quantizers = nn.ModuleDict()
        for key in LINEAR_QUANTIZER_KEYS:
            quantizers[key] = self.layer_quant_fns[key]()
        # Factories create modules on the default device; place them next to
        # the weight so a layer already moved with .to()/.cuda() stays on a
        # single device when its quantizers are rebuilt.
        self.layer_quant = quantizers.to(self.weight.device)

    def reset_parameters(self):
        """Re-initialize weight/bias and rebuild all quantizer modules.

        Any state held by previously built quantizers is discarded.
        """
        super().reset_parameters()
        self._build_quantizers()

    def reset_quantizers(self, layer_quantizers):
        """Replace the quantizer factories and rebuild all quantizer modules.

        Args:
            layer_quantizers: same two-element structure as in ``__init__``;
                the entry is chosen with the ``signed`` flag given at
                construction.
        """
        self.layer_quant_fns = layer_quantizers[0] if self.signed else layer_quantizers[1]
        self._build_quantizers()

    def freeze_quantization_parameters(self):
        """Call ``freeze_quantization_parameters()`` on every quantizer."""
        for key in LINEAR_QUANTIZER_KEYS:
            self.layer_quant[key].freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Call ``unfreeze_quantization_parameters()`` on every quantizer."""
        for key in LINEAR_QUANTIZER_KEYS:
            self.layer_quant[key].unfreeze_quantization_parameters()

    def forward(self, input, q_group=None):
        """Quantize input and weight, apply the linear map, quantize the output.

        Args:
            input: tensor fed to the ``"inputs"`` quantizer.
            q_group: optional value forwarded unchanged to the ``"inputs"``
                and ``"features"`` quantizers (``None`` by default).

        Returns:
            The ``"features"`` quantizer applied to
            ``F.linear(quantized_input, quantized_weight, self.bias)``.
        """
        # NOTE(review): the positional ``True, False`` flags belong to the
        # quantizer call signature defined elsewhere; their meaning is not
        # visible here — confirm against the quantizer implementation.
        input_q = self.layer_quant["inputs"](input, True, False, edge_index=None, q_group=q_group)
        w_q = self.layer_quant["weights"](self.weight)
        # The bias is passed through unquantized.
        out = F.linear(input_q, w_q, self.bias)
        out = self.layer_quant["features"](out, True, False, edge_index=None, q_group=q_group)

        return out


class LinearNotQuantized(nn.Linear):
    """A linear layer that builds quantizers but does not apply them.

    The constructor and ``forward`` signatures mirror ``LinearQuantized``, and
    the quantizer modules are still created in ``self.layer_quant`` (so they
    can be reset, frozen and unfrozen). ``forward``, however, ignores them and
    computes a plain ``F.linear`` on the unquantized input and weight.
    """

    def __init__(
        self, in_features, out_features, layer_quantizers, bias=True, signed=True,
    ):
        """Create the layer and its (unused in forward) quantizer modules.

        Args:
            in_features: size of each input sample (as for ``nn.Linear``).
            out_features: size of each output sample (as for ``nn.Linear``).
            layer_quantizers: two-element sequence of quantizer-factory
                mappings; index 0 is used when ``signed`` is true, index 1
                otherwise. The selected mapping must provide, for every key in
                ``LINEAR_QUANTIZER_KEYS``, a zero-argument callable returning
                an ``nn.Module``.
            bias: whether the layer has a learnable bias.
            signed: selects which entry of ``layer_quantizers`` is used; it is
                also reused by ``reset_quantizers``.
        """
        # nn.Linear.__init__ calls reset_parameters(), which builds the
        # quantizers, so the factories must be set before super().__init__.
        self.signed = signed
        self.layer_quant_fns = layer_quantizers[0] if self.signed else layer_quantizers[1]

        super().__init__(in_features, out_features, bias)

    def _build_quantizers(self):
        """Instantiate fresh quantizer modules from ``self.layer_quant_fns``."""
        quantizers = nn.ModuleDict()
        for key in LINEAR_QUANTIZER_KEYS:
            quantizers[key] = self.layer_quant_fns[key]()
        # Factories create modules on the default device; place them next to
        # the weight so a layer already moved with .to()/.cuda() stays on a
        # single device when its quantizers are rebuilt.
        self.layer_quant = quantizers.to(self.weight.device)

    def reset_parameters(self):
        """Re-initialize weight/bias and rebuild all quantizer modules.

        Any state held by previously built quantizers is discarded.
        """
        super().reset_parameters()
        self._build_quantizers()

    def reset_quantizers(self, layer_quantizers):
        """Replace the quantizer factories and rebuild all quantizer modules.

        Args:
            layer_quantizers: same two-element structure as in ``__init__``;
                the entry is chosen with the ``signed`` flag given at
                construction.
        """
        self.layer_quant_fns = layer_quantizers[0] if self.signed else layer_quantizers[1]
        self._build_quantizers()

    def freeze_quantization_parameters(self):
        """Call ``freeze_quantization_parameters()`` on every quantizer."""
        for key in LINEAR_QUANTIZER_KEYS:
            self.layer_quant[key].freeze_quantization_parameters()

    def unfreeze_quantization_parameters(self):
        """Call ``unfreeze_quantization_parameters()`` on every quantizer."""
        for key in LINEAR_QUANTIZER_KEYS:
            self.layer_quant[key].unfreeze_quantization_parameters()

    def forward(self, input, q_group=None):
        """Apply the plain (unquantized) linear map.

        Args:
            input: tensor passed directly to ``F.linear``.
            q_group: accepted only for signature compatibility with
                ``LinearQuantized``; it is ignored here.

        Returns:
            ``F.linear(input, self.weight, self.bias)``.
        """
        out = F.linear(input, self.weight, self.bias)

        return out