import torch.nn as nn
import torch

from lbmqt.nn.layers import QLinear, QConv1d, QConv2d, QConv3d, \
    QConvTranspose1d, QConvTranspose2d, QConvTranspose3d
from lbmqt.functional import is_quantifiable
from lbmqt.qscheme import QScheme
from lbmqt.conf import config


def get_param_by_name(module, param_name: str):
    """Resolve a dotted attribute path such as ``'layer1.0.weight'`` on ``module``.

    Every segment except the last is looked up with ``getattr`` to walk down
    the attribute tree; the attribute named by the last segment is returned.

    Args:
        module: root object to start from (typically an ``nn.Module``).
        param_name: dot-separated path to the attribute.

    Returns:
        The object stored at the final segment of the path.

    Raises:
        ValueError: if an intermediate segment is missing or ``None`` (reported
            as a missing submodule), or if the final attribute is missing or
            ``None`` (reported as a missing param).
    """
    *prefix, leaf = param_name.split('.')
    owner = module
    for depth, segment in enumerate(prefix, start=1):
        owner = getattr(owner, segment, None)
        if owner is None:
            raise ValueError('module {} does not have submodule {}'.format(
                module.__class__.__name__, '.'.join(prefix[:depth])
            ))

    found = getattr(owner, leaf, None)
    if found is None:
        raise ValueError('module {} does not have param {}'.format(
            module.__class__.__name__, param_name
        ))
    return found


def set_param_by_name(module, param_name: str, param):
    """Replace the attribute at dotted path ``param_name`` with ``param``.

    The owner of the last path segment is found by walking ``getattr`` through
    the preceding segments. The existing attribute is then deleted and
    ``param`` is assigned in its place. Deleting first presumably lets a
    non-``nn.Parameter`` value replace a registered parameter on an
    ``nn.Module``, which a plain ``setattr`` would reject.

    Args:
        module: root object to start from (typically an ``nn.Module``).
        param_name: dot-separated path to the attribute to replace.
        param: the new value.

    Raises:
        ValueError: if an intermediate segment is missing or ``None``.
        AttributeError: (from ``delattr``) if the final attribute does not exist.
    """
    *prefix, leaf = param_name.split('.')
    owner = module
    for depth, segment in enumerate(prefix, start=1):
        owner = getattr(owner, segment, None)
        if owner is None:
            raise ValueError('module {} does not have submodule {}'.format(
                module.__class__.__name__, '.'.join(prefix[:depth])
            ))

    delattr(owner, leaf)
    setattr(owner, leaf, param)


class QModule(nn.Module):
    """Wrap a float model so that its conv/linear parameters are held quantized.

    Construction proceeds in four steps:

    1. Classify parameters. Parameters of Conv/ConvTranspose/Linear layers that
       pass ``is_quantifiable`` are collected in ``q_names``. Parameters of
       BatchNorm/LayerNorm layers are flagged ``True`` in
       ``use_autograd_names`` and keep ordinary autograd.
    2. Build a ``QScheme`` over ``q_names``.
    3. Move every non-norm parameter to ``config.device``. Quantifiable ones
       are replaced by quantized, non-grad ``nn.Parameter`` objects; the rest
       get ``requires_grad`` switched off. Every parameter is tagged with
       ``trainable`` (its original ``requires_grad``) and ``pname`` (its
       dotted name).
    4. Replace conv/linear layers with their ``Q*`` counterparts. Install a
       backward hook on BatchNorm layers that forwards their gradients to the
       optimizer.

    ``set_optimizer`` must be called before the first backward pass; the
    BatchNorm hooks assert that ``m.optimizer`` is set.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

        # name -> True for parameters that keep ordinary autograd (norm layers),
        # False for parameters handled by the quantization machinery.
        self.use_autograd_names = {name: False for name, _ in self.model.named_parameters()}
        # dotted names of parameters the scheme will quantize
        self.q_names = []

        # fill q_names / use_autograd_names
        self._configure_param_names()

        # debugging quantization differences forces simulated quantization
        simulate = True if config.debug_quantization_difference else config.simulate
        self.scheme = QScheme(
            name='param',
            param_names=self.q_names,
            bits=config.compression_bits_p,
            group_size=config.group_size,
            enable=config.enable_quantize_p,
            num_mode=config.numerical_mode_p,
            stochastic=config.stochastic,
            simulate=simulate,
        )

        # move to device, quantize and tag parameters
        self._configure_quantization_and_reset()

        # swap layers for their quantized counterparts
        self._configure_layers()

    def _configure_param_names(self):
        """Populate ``q_names`` and ``use_autograd_names`` from the model's layers."""
        # temporary reverse map Parameter -> dotted name, used by _pick_param_names
        self.param2name = {param: name for name, param in self.model.named_parameters()}
        self._pick_param_names(self.model)
        del self.param2name

        print('Module configuring parameter names ended')

    def _pick_param_names(self, module):
        """Recursively classify the parameters under ``module``.

        - Conv/ConvTranspose/Linear children: each of their parameters
          (including nested ones) that passes
          ``is_quantifiable(param.data, config.quantifiable_lower_bound)`` is
          appended to ``q_names``.
        - BatchNorm1d/2d/3d and LayerNorm children: each parameter is flagged
          ``True`` in ``use_autograd_names``. These parameters are never
          quantified.
        - Any other child is recursed into.
        """
        for child in module.children():
            if isinstance(child, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
                                  nn.Linear)):
                for param in child.parameters():
                    if is_quantifiable(param.data, config.quantifiable_lower_bound):
                        self.q_names.append(self.param2name[param])
            elif isinstance(child, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
                # NOTE(review): LayerNorm is flagged here, but _convert_layers and
                # _set_common_attribute only handle BatchNorm (device move, gradient
                # hook, optimizer). Confirm LayerNorm gradients reach the optimizer
                # some other way.
                for param in child.parameters():
                    self.use_autograd_names[self.param2name[param]] = True
            else:
                self._pick_param_names(child)

    def _configure_quantization_and_reset(self):
        """Move, quantize and tag every parameter of the wrapped model.

        Sets ``self.param_names`` to the model's dotted parameter names. Every
        parameter ends up with ``trainable`` (original ``requires_grad``) and
        ``pname`` (dotted name) attributes.

        Parameters not flagged in ``use_autograd_names`` are moved to
        ``config.device`` here, because the model is not moved as a whole.
        Quantifiable ones are then replaced in the model by a new non-grad
        ``nn.Parameter`` holding ``scheme.quantize_data(...)``; the others just
        have ``requires_grad`` switched off.

        Norm-layer parameters are left untouched apart from the tags.
        BatchNorm layers are moved to the device later in ``_convert_layers``.
        """
        self.param_names = [name for name, _ in self.model.named_parameters()]
        for param_name in self.param_names:
            p = get_param_by_name(self.model, param_name)
            trainable = p.requires_grad

            if self.use_autograd_names[param_name]:
                # Norm-layer parameter: keep autograd, only tag it.
                # NOTE(review): the tags survive a later Module.to() as long as
                # torch updates `param.data` in place (its default behaviour).
                setattr(p, 'trainable', trainable)
                setattr(p, 'pname', param_name)
                continue

            p.data = p.to(config.device)
            if self.scheme.is_quantifiable(param_name):
                new_p = nn.Parameter(
                    self.scheme.quantize_data(p.data, param_name, dilate=config.params_initial_dilation),
                    requires_grad=False,
                )
                setattr(new_p, 'trainable', trainable)
                setattr(new_p, 'pname', param_name)
                set_param_by_name(self.model, param_name, new_p)
            else:
                setattr(p, 'trainable', trainable)
                setattr(p, 'pname', param_name)
                p.requires_grad = False
        print('Module configuring initial quantization ended')

    def _configure_layers(self):
        """Convert the model's layers in place and record each module's name."""
        QModule._convert_layers(self.model)
        self._set_module_names()
        print('Module configuring layers ended')

    @staticmethod
    def _convert_layers(module):
        """Recursively swap float layers under ``module`` for quantized ones.

        - Children that are already ``Q*`` layers are left alone.
        - Conv/ConvTranspose/Linear children are replaced by the matching
          ``Q*`` layer. It is built from the original's hyper-parameters, with
          the original layer passed as the last argument.
        - BatchNorm children are moved to ``config.device`` and get a full
          backward hook that forwards their parameter gradients to the
          optimizer.
        - Any other child is recursed into.
        """

        def backward_hook_for_unquantized_layers(m, grad_inputs, grad_outputs):
            # BatchNorm parameters keep requires_grad, so autograd fills `.grad`.
            # Hand each gradient to the optimizer under the parameter's `pname`,
            # then clear `.grad`.
            if config.debug_layers_backward:
                print(f'in backward hook for unquantized layer {m.name}')

            optimizer = getattr(m, 'optimizer', None)
            assert optimizer is not None, 'QModule.set_optimizer() must be called before backward'

            # affine=False layers register weight/bias as None. Frozen
            # (requires_grad=False) parameters never receive a gradient.
            # There is nothing to forward in either case.
            weight, bias = m.weight, m.bias
            if weight is not None and weight.requires_grad:
                assert weight.grad is not None
                optimizer.accumulate_gradient(weight.pname, weight.grad)
                weight.grad = None
            if bias is not None and bias.requires_grad:
                optimizer.accumulate_gradient(bias.pname, bias.grad)
                bias.grad = None

        for name, child in module.named_children():
            # Do not convert layers that are already quantized
            if isinstance(child, (QConv1d, QConv2d, QConv3d, QConvTranspose1d, QConvTranspose2d, QConvTranspose3d,
                                  QLinear)):
                continue

            if isinstance(child, nn.Conv1d):
                setattr(module, name, QConv1d(child.in_channels, child.out_channels,
                    child.kernel_size, child.stride, child.padding, child.dilation,
                    child.groups, child.bias is not None, child.padding_mode, child))
            elif isinstance(child, nn.Conv2d):
                setattr(module, name, QConv2d(child.in_channels, child.out_channels,
                    child.kernel_size, child.stride, child.padding, child.dilation,
                    child.groups, child.bias is not None, child.padding_mode, child))
            elif isinstance(child, nn.Conv3d):
                setattr(module, name, QConv3d(child.in_channels, child.out_channels,
                    child.kernel_size, child.stride, child.padding, child.dilation,
                    child.groups, child.bias is not None, child.padding_mode, child))
            # ConvTranspose: `bias` is a bool flag (as in the branches above and in
            # torch's nn.ConvTranspose*d). Passing the bias tensor itself breaks on
            # truth-testing a multi-element tensor.
            elif isinstance(child, nn.ConvTranspose1d):
                setattr(module, name, QConvTranspose1d(child.in_channels, child.out_channels,
                    child.kernel_size, child.stride, child.padding, child.output_padding,
                    child.groups, child.bias is not None, child.dilation, child.padding_mode, child))
            elif isinstance(child, nn.ConvTranspose2d):
                setattr(module, name, QConvTranspose2d(child.in_channels, child.out_channels,
                    child.kernel_size, child.stride, child.padding, child.output_padding,
                    child.groups, child.bias is not None, child.dilation, child.padding_mode, child))
            elif isinstance(child, nn.ConvTranspose3d):
                setattr(module, name, QConvTranspose3d(child.in_channels, child.out_channels,
                    child.kernel_size, child.stride, child.padding, child.output_padding,
                    child.groups, child.bias is not None, child.dilation, child.padding_mode, child))
            elif isinstance(child, nn.Linear):
                setattr(module, name, QLinear(child.in_features, child.out_features,
                    child.bias is not None, child))
            elif isinstance(child, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                # moves buffers (running stats) as well as parameters
                child.to(config.device)
                child.register_full_backward_hook(backward_hook_for_unquantized_layers)
            else:
                QModule._convert_layers(child)

    def _set_module_names(self):
        """Store each submodule's dotted name (root is ``''``) on its ``name`` attribute."""
        for name, child in self.model.named_modules():
            setattr(child, 'name', name)

    @staticmethod
    def _set_common_attribute(module, attr_name, attr):
        """Recursively set ``attr_name = attr`` on every Q* layer and BatchNorm layer.

        Recursion stops at those layers; all other modules are descended into.
        NOTE(review): LayerNorm layers are not included here.
        """
        for child in module.children():
            if isinstance(child, (QConv1d, QConv2d, QConv3d, QConvTranspose1d, QConvTranspose2d, QConvTranspose3d,
                                  QLinear, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                setattr(child, attr_name, attr)
            else:
                QModule._set_common_attribute(child, attr_name, attr)

    def set_optimizer(self, optimizer):
        """Attach this wrapper's scheme and ``optimizer`` to every converted layer."""
        QModule._set_common_attribute(self.model, 'scheme', self.scheme)
        QModule._set_common_attribute(self.model, 'optimizer', optimizer)

    def get_named_model_parameters(self):
        """Return ``[(name, param), ...]`` for the model's current parameters.

        Parameters are looked up by name each call, so quantized replacements
        installed by ``_configure_quantization_and_reset`` are returned.
        """
        return [(param_name, get_param_by_name(self.model, param_name))
                for param_name in self.param_names]

    # not used
    def state_dict(self):
        """Return ``{'model': model state dict, 'scheme': QScheme object}``.

        NOTE(review): this overrides ``nn.Module.state_dict`` with a narrower
        signature (no ``destination``/``prefix``/``keep_vars``). Presumably a
        parent module's ``state_dict()`` would fail if a QModule were nested
        inside it.
        """
        ret = {
            'model': self.model.state_dict(),
            'scheme': self.scheme
        }
        return ret

    def load_state_dict(self, state_dict, strict=True):
        """Restore a dict produced by ``state_dict`` and re-attach the scheme.

        Args:
            state_dict: dict with ``'model'`` and ``'scheme'`` entries.
            strict: forwarded to ``nn.Module.load_state_dict`` for the model.
        """
        self.model.load_state_dict(state_dict['model'], strict=strict)
        self.scheme = state_dict['scheme']
        QModule._set_common_attribute(self.model, 'scheme', self.scheme)
