# Quantizer factory functions
import torch.nn as nn
from .quantizer_single import IntegerQuantizer
from .quantizer_rowwise import RowwiseQuantizer

def define_single(q_group, qtype, ste, momentum, signed, symmetric, percentile=-1):
    """Return a zero-argument factory that builds an ``IntegerQuantizer``.

    Every invocation of the returned factory constructs a new quantizer from
    the arguments captured by this call.

    Args:
        q_group: Accepted for signature parity with the other ``define_*``
            helpers; not used here.
        qtype: Forwarded as ``num_bits``.
        ste: Forwarded as ``use_ste``.
        momentum: Forwarded as ``use_momentum``.
        signed: Forwarded as ``signed``.
        symmetric: Forwarded as ``symmetric``.
        percentile: Forwarded as ``percentile`` (default ``-1``).

    Returns:
        A callable taking no arguments that returns a new ``IntegerQuantizer``.
    """

    def _build_integer_quantizer():
        # Construction is deferred until the factory itself is invoked.
        return IntegerQuantizer(
            num_bits=qtype,
            use_ste=ste,
            use_momentum=momentum,
            signed=signed,
            symmetric=symmetric,
            percentile=percentile,
        )

    return _build_integer_quantizer

def define_columnwise(q_group, qtype, ste, momentum, signed, symmetric, percentile=-1):
    """Return a zero-argument factory for a column-wise ``RowwiseQuantizer``.

    The returned factory builds a new ``RowwiseQuantizer`` with
    ``columnwise=True`` each time it is called.

    Args:
        q_group: Accepted for signature parity with the other ``define_*``
            helpers; not used here.
        qtype: Forwarded as ``num_bits``.
        ste: Forwarded as ``use_ste``.
        momentum: Forwarded as ``use_momentum``.
        signed: Forwarded as ``signed``.
        symmetric: Forwarded as ``symmetric``.
        percentile: Forwarded as ``percentile`` (default ``-1``).

    Returns:
        A callable taking no arguments that returns a new ``RowwiseQuantizer``.
    """
    # Keyword arguments are captured once; ``**`` unpacking below hands the
    # constructor a fresh mapping on every call.
    quantizer_kwargs = dict(
        num_bits=qtype,
        use_ste=ste,
        use_momentum=momentum,
        signed=signed,
        symmetric=symmetric,
        columnwise=True,
        percentile=percentile,
    )

    def _make_columnwise_quantizer():
        return RowwiseQuantizer(**quantizer_kwargs)

    return _make_columnwise_quantizer

def define_rowwise(q_group, qtype, ste, momentum, signed, symmetric, percentile=-1):
    """Return a zero-argument factory for a row-wise ``RowwiseQuantizer``.

    The returned factory builds a new ``RowwiseQuantizer`` with
    ``columnwise=False`` each time it is called.

    Args:
        q_group: Accepted for signature parity with the other ``define_*``
            helpers; not used here.
        qtype: Forwarded as ``num_bits``.
        ste: Forwarded as ``use_ste``.
        momentum: Forwarded as ``use_momentum``.
        signed: Forwarded as ``signed``.
        symmetric: Forwarded as ``symmetric``.
        percentile: Forwarded as ``percentile`` (default ``-1``).

    Returns:
        A callable taking no arguments that returns a new ``RowwiseQuantizer``.
    """

    def _make_rowwise_quantizer():
        settings = {
            "num_bits": qtype,
            "use_ste": ste,
            "use_momentum": momentum,
            "signed": signed,
            "symmetric": symmetric,
            "columnwise": False,
            "percentile": percentile,
        }
        return RowwiseQuantizer(**settings)

    return _make_rowwise_quantizer

def define_table_default(table, q_group, qtype, ste, momentum, signed, symmetric,
                         dependency, device, norm_name=None, percentile=-1):
    """Return a zero-argument factory that instantiates ``table``.

    ``table`` is called with keyword arguments only. A new object is created
    on every invocation of the returned factory.

    Args:
        table: Callable (e.g. a quantizer class) invoked by the factory.
        q_group: Accepted for signature parity with the other ``define_*``
            helpers; not used here.
        qtype: Forwarded as ``num_bits``.
        ste: Forwarded as ``use_ste``.
        momentum: Forwarded as ``use_momentum``.
        signed: Forwarded as ``signed``.
        symmetric: Forwarded as ``symmetric``.
        dependency: Forwarded as ``dependency``.
        device: Forwarded as ``device``.
        norm_name: Forwarded as ``norm_name`` (default ``None``).
        percentile: Forwarded as ``percentile`` (default ``-1``).

    Returns:
        A callable taking no arguments that returns ``table(**kwargs)``.
    """

    def _instantiate_table():
        return table(
            norm_name=norm_name,
            num_bits=qtype,
            use_ste=ste,
            use_momentum=momentum,
            signed=signed,
            symmetric=symmetric,
            dependency=dependency,
            device=device,
            percentile=percentile,
        )

    return _instantiate_table

## Clippers with gradient-update support
def define_clip(single, bothside, operate):
    """Return a zero-argument factory that builds a ``ClipAct`` clipper.

    Args:
        single: Forwarded as ``single``.
        bothside: Forwarded as ``bothside``.
        operate: Forwarded as ``operate``.

    Returns:
        A callable taking no arguments that returns a new ``ClipAct``.

    NOTE(review): ``ClipAct`` is neither imported nor defined in the visible
    part of this module, so invoking the returned factory raises ``NameError``
    unless the name is supplied elsewhere. Confirm where it should be
    imported from.
    """

    def _make_clipper():
        return ClipAct(single=single, bothside=bothside, operate=operate)

    return _make_clipper

def define_clip_with_singleq(single, bothside, operate,
                             q_group, qtype, ste, momentum, signed, symmetric):
    """Return a zero-argument factory for a clipper paired with an ``IntegerQuantizer``.

    The factory builds a ``ClipActQuantization``. It passes the
    ``IntegerQuantizer`` class as ``quantizer`` together with the
    quantizer keyword arguments.

    Args:
        single: Forwarded as ``single``.
        bothside: Forwarded as ``bothside``.
        operate: Forwarded as ``operate``.
        q_group: Accepted for signature parity with the other ``define_*``
            helpers; not used here.
        qtype: Forwarded as ``num_bits``.
        ste: Forwarded as ``use_ste``.
        momentum: Forwarded as ``use_momentum``.
        signed: Forwarded as ``signed``.
        symmetric: Forwarded as ``symmetric``.

    Returns:
        A callable taking no arguments that returns a new ``ClipActQuantization``.

    NOTE(review): ``ClipActQuantization`` is neither imported nor defined in the
    visible part of this module, so invoking the returned factory raises
    ``NameError`` unless the name is supplied elsewhere. Confirm its source.
    """

    def _make_quantized_clipper():
        # Both class names are looked up only when the factory runs.
        return ClipActQuantization(
            single=single,
            bothside=bothside,
            operate=operate,
            quantizer=IntegerQuantizer,
            # keyword arguments intended for the quantizer
            num_bits=qtype,
            use_ste=ste,
            use_momentum=momentum,
            signed=signed,
            symmetric=symmetric,
        )

    return _make_quantized_clipper

## Placeholder for modules whose outputs are not quantized
class IdentityModule(nn.Module):
    """Pass-through module for positions where outputs are not quantized.

    Constructor arguments are accepted and discarded. ``forward`` returns its
    first argument unchanged. The freeze/unfreeze hooks do nothing.
    """

    def __init__(self, *args, **kwargs):
        # Any extra constructor arguments are deliberately ignored.
        super().__init__()

    def forward(self, input, *args, **kwargs):
        """Return ``input`` as-is; trailing arguments are ignored."""
        return input

    def freeze_quantization_parameters(self):
        """No-op: this module holds no quantization parameters."""

    def unfreeze_quantization_parameters(self):
        """No-op: this module holds no quantization parameters."""
    
def no_quantization():
    """Return a zero-argument factory that produces an ``IdentityModule``.

    The factory shape matches the one returned by the ``define_*`` helpers,
    so "no quantization" can be requested through the same interface.
    """

    def _make_identity():
        return IdentityModule()

    return _make_identity