import logging
from functools import reduce
from copy import deepcopy
import torch.nn as nn

from .vanilla_class import wrap_linear, wrap_conv
from .INSTANT_class import wrap_linear_compression_layer, wrap_pointwise_conv_compression_layer
from .lbpwht_class import wrap_linear_wht_layer, wrap_pointwise_conv_wht_layer
from .GF_class import wrap_linear_GF_layer, wrap_conv_GF_layer


##########################################################################################################################

def register_normal_linear(module, cfgs):
    """Replace the configured sub-layers of ``module`` with the vanilla wrappers.

    Args:
        module: Root ``nn.Module``; matching sub-layers are replaced in place.
        cfgs: Either ``-1`` (nothing to install) or a dict with keys
            ``"finetuned_layer"`` (mapping dotted sub-module path -> object whose
            ``isinstance`` check against ``nn.Linear`` / ``nn.Conv2d`` selects
            the wrapper), and ``"backward_time"``, ``"forward_time"``,
            ``"inference_time"`` (forwarded unchanged to the wrapper).

    Returns:
        None. ``module`` is modified in place.

    Layers whose configured type is neither ``nn.Linear`` nor ``nn.Conv2d``
    are reported and left untouched.
    """
    if cfgs == -1:
        logging.info("No Filter Required")
        return
    # Install filter
    for name, ref_layer in cfgs["finetuned_layer"].items():
        path_seq = name.split('.')
        target = reduce(getattr, path_seq, module)
        if isinstance(ref_layer, nn.Linear):
            wrapper = wrap_linear
        elif isinstance(ref_layer, nn.Conv2d):
            wrapper = wrap_conv
        else:
            # No wrapper exists for this type; skip it explicitly rather than
            # failing on an unbound name or reusing a wrapper built for a
            # different layer in an earlier iteration.
            print("Do not filter layer: ", name, target.__class__.__name__)
            continue
        upd_layer = wrapper(target, cfgs["backward_time"], cfgs["forward_time"], cfgs["inference_time"])
        # Empty path_seq[:-1] (top-level layer) makes reduce return ``module`` itself.
        parent = reduce(getattr, path_seq[:-1], module)
        setattr(parent, path_seq[-1], upd_layer)

def register_INSTANT(model, cfgs):
    """Replace the configured sub-modules of ``model`` with INSTANT compression wrappers.

    Args:
        model: Root ``nn.Module``; matching sub-modules are replaced in place.
        cfgs: Either ``-1`` (nothing to install) or a dict with keys
            ``"finetuned_layer"`` (container of dotted sub-module names, tested
            with ``in``), and ``"backward_time"``, ``"forward_time"``,
            ``"inference_time"`` (forwarded as keyword arguments to the wrapper).

    Returns:
        ``model`` (the same object, modified in place).

    Supported modules are ``Linear`` / ``NonDynamicallyQuantizableLinear`` and
    ``Conv2d`` with a 1x1 kernel. Other matched modules are reported and left
    untouched.
    """
    logging.info("Registering INSTANT filter")
    if cfgs == -1:
        logging.info("No Filter Required")
        return model
    finetuned = cfgs["finetuned_layer"]
    timing = {
        "backward_time": cfgs["backward_time"],
        "forward_time": cfgs["forward_time"],
        "inference_time": cfgs["inference_time"],
    }
    # Snapshot the matches first so the module tree is not mutated while
    # named_modules() is still walking it. The root ("") cannot be replaced
    # via a parent, so it is excluded.
    targets = [(name, module) for name, module in model.named_modules() if name and name in finetuned]
    for name, module in targets:
        cls_name = module.__class__.__name__
        if cls_name in ("Linear", "NonDynamicallyQuantizableLinear"):
            compressed_module = wrap_linear_compression_layer(module, **timing)
        elif cls_name == "Conv2d" and module.kernel_size == (1, 1):
            compressed_module = wrap_pointwise_conv_compression_layer(module, **timing)
        else:
            print("View compressed module in Register Filter again!!")
            print("Module name: ", module.__class__.__name__ )
            # No wrapper for this module; do not install a stale one.
            continue
        path_seq = name.split('.')
        # Works for top-level names too: empty path_seq[:-1] yields ``model``.
        parent_module = reduce(getattr, path_seq[:-1], model)
        setattr(parent_module, path_seq[-1], compressed_module)
    return model

def register_LBPWHT(module, cfgs):
    """Replace the configured sub-layers of ``module`` with LBP-WHT wrappers.

    Args:
        module: Root ``nn.Module``; matching sub-layers are replaced in place.
        cfgs: Either ``-1`` (nothing to install) or a dict with keys
            ``"finetuned_layer"`` (mapping dotted sub-module path -> object whose
            ``isinstance`` check against ``nn.Linear`` / ``nn.Conv2d`` selects
            the wrapper), and ``"backward_time"``, ``"forward_time"`` (forwarded
            as keyword arguments to the wrapper).

    Returns:
        None. ``module`` is modified in place.

    Every configured layer has ``requires_grad`` turned off on its existing
    parameters before the type check. Layers of an unsupported type are
    therefore also frozen, then reported and left in place.
    """
    logging.info("Registering LBPWHT filter")
    if cfgs == -1:
        logging.info("No Filter Required")
        return
    # Install filter
    for name, ref_layer in cfgs["finetuned_layer"].items():
        path_seq = name.split('.')
        target = reduce(getattr, path_seq, module)

        for param in target.parameters(): # Turn off gradient of previous version
            param.requires_grad = False

        # The positional arguments (8, True, 4) are wrapper hyper-parameters.
        # Their meaning is defined in lbpwht_class and is not visible here.
        if isinstance(ref_layer, nn.Linear):
            upd_layer = wrap_linear_wht_layer(target, 8, True, 4, backward_time=cfgs["backward_time"], forward_time=cfgs["forward_time"])
        elif isinstance(ref_layer, nn.Conv2d):
            # NOTE(review): unlike register_INSTANT, no 1x1 kernel check is made
            # before using the *pointwise* wrapper. Confirm it accepts any Conv2d.
            upd_layer = wrap_pointwise_conv_wht_layer(target, 8, True, 4, backward_time=cfgs["backward_time"], forward_time=cfgs["forward_time"])
        else:
            # Report the skipped layer itself, not the root model.
            print("Do not filter layer: ", name, target.__class__.__name__)
            continue
        parent = reduce(getattr, path_seq[:-1], module)
        setattr(parent, path_seq[-1], upd_layer)

def register_GF(module, cfgs):
    """Replace the configured sub-layers of ``module`` with gradient-filtering wrappers.

    Args:
        module: Root ``nn.Module``; matching sub-layers are replaced in place.
        cfgs: Either ``-1`` (nothing to install) or a dict with keys
            ``"finetuned_layer"`` (mapping dotted sub-module path -> object whose
            ``isinstance`` check against ``nn.Linear`` / ``nn.Conv2d`` selects
            the wrapper), and ``"backward_time"``, ``"forward_time"`` (forwarded
            as keyword arguments to the wrapper).

    Returns:
        None. ``module`` is modified in place.

    Every configured layer has ``requires_grad`` turned off on its existing
    parameters before the type check. Layers of an unsupported type are
    frozen, reported and left in place, matching register_LBPWHT.
    """
    logging.info("Registering Gradient Filtering!")
    if cfgs == -1:
        logging.info("No Filter Required")
        return
    # Install filter
    for name, ref_layer in cfgs["finetuned_layer"].items():
        path_seq = name.split('.')
        target = reduce(getattr, path_seq, module)

        for param in target.parameters():
            param.requires_grad = False

        # The positional arguments (8, True) are wrapper hyper-parameters.
        # Their meaning is defined in GF_class and is not visible here.
        if isinstance(ref_layer, nn.Linear):
            upd_layer = wrap_linear_GF_layer(target, 8, True, backward_time=cfgs["backward_time"], forward_time=cfgs["forward_time"])
        elif isinstance(ref_layer, nn.Conv2d):
            upd_layer = wrap_conv_GF_layer(target, 8, True, backward_time=cfgs["backward_time"], forward_time=cfgs["forward_time"])
        else:
            # No wrapper for this type; skip it rather than failing on an
            # unbound name or installing a wrapper built for another layer.
            print("Do not filter layer: ", name, target.__class__.__name__)
            continue

        parent = reduce(getattr, path_seq[:-1], module)
        setattr(parent, path_seq[-1], upd_layer)