import torch
from .quantization.linear_quantized import LinearQuantized

class SequentialTableInput(torch.nn.Module):
    """Container that applies a list of sub-modules one after another.

    The children live in a ``torch.nn.ModuleList`` stored as ``self.nn``.
    Maintenance calls (parameter reset, quantizer reset, freezing and
    unfreezing of quantization parameters) are relayed to every child that
    exposes an attribute of the same name; children without it are skipped.
    """

    def __init__(self, nn):
        """Store the sub-modules.

        Args:
            nn: list of modules, in the order in which they are applied.
        """
        super().__init__()
        # Wrapping in a ModuleList registers the children as sub-modules.
        self.nn = torch.nn.ModuleList(nn)

    def _dispatch_to_children(self, method_name, *args):
        """Call ``method_name(*args)`` on each child that has that attribute."""
        for module in self.nn:
            if not hasattr(module, method_name):
                continue
            getattr(module, method_name)(*args)

    def reset_parameters(self):
        """Invoke ``reset_parameters()`` on every child that provides it."""
        self._dispatch_to_children("reset_parameters")

    def reset_quantizers(self, layer_quantizers):
        """Hand ``layer_quantizers`` to ``reset_quantizers`` of each child that provides it."""
        self._dispatch_to_children("reset_quantizers", layer_quantizers)

    def freeze_quantization_parameters(self):
        """Invoke ``freeze_quantization_parameters()`` on every child that provides it."""
        self._dispatch_to_children("freeze_quantization_parameters")

    def unfreeze_quantization_parameters(self):
        """Invoke ``unfreeze_quantization_parameters()`` on every child that provides it."""
        self._dispatch_to_children("unfreeze_quantization_parameters")

    def forward(self, inputs, q_group=None):
        """Run ``inputs`` through the children in order and return the result.

        Args:
            inputs: value fed to the first child; it must support ``.clone()``
                whenever the container is non-empty.
            q_group: forwarded as the ``q_group`` keyword, but only to children
                that are ``LinearQuantized`` instances. Every other child
                (including a nested container) is called with the data alone.

        Returns:
            The output of the last child, or ``inputs`` itself (not a copy)
            when the container holds no children.
        """
        result = inputs
        for module in self.nn:
            # Each child works on a fresh copy of the running value.
            # NOTE(review): presumably this shields earlier values from
            # in-place layers -- confirm before removing the clone.
            copied = result.clone()
            extra = {"q_group": q_group} if isinstance(module, LinearQuantized) else {}
            result = module(copied, **extra)
        return result

