import torch.nn as nn

from .modules import TQLinear, TQLoRALinear
from .quant_utils import vectorwise_dequant



def make_quant(
    module, names, name='', config="./lpmm/configs/default.yml", is_cuda=True, device=None, dtype=None,
    r=64, lora_alpha=16, lora_dropout=0.1, q_trainable=False,
):
    """Recursively replace the named layers of ``module`` with ``TQLoRALinear``.

    The model is modified in place; nothing is returned.

    Args:
        module: Root module to walk.
        names: Container of dotted layer names (e.g. the keys of the dict
            returned by ``find_layers``) whose layers are replaced.
        name: Dotted prefix of ``module`` within the root model; ``''`` at the root.
        config: Config path forwarded to ``TQLoRALinear``.
        is_cuda: Forwarded to ``TQLoRALinear``.
        device: Forwarded to ``TQLoRALinear``.
        dtype: Forwarded to ``TQLoRALinear``.
        r: LoRA rank forwarded to ``TQLoRALinear`` (default 64).
        lora_alpha: Forwarded to ``TQLoRALinear`` (default 16).
        lora_dropout: Forwarded to ``TQLoRALinear`` (default 0.1).
        q_trainable: Forwarded to ``TQLoRALinear`` (default False); presumably
            controls whether the quantized weights are trained — confirm in
            ``TQLoRALinear``.
    """
    # Already-converted layers are left alone (also stops recursion into them).
    if isinstance(module, TQLoRALinear):
        return
    for attr in dir(module):
        full_name = name + '.' + attr if name != '' else attr
        # Filter by name first so getattr is only evaluated for layers we replace,
        # not for every attribute/property that dir() lists.
        if full_name not in names:
            continue
        layer = getattr(module, attr)
        setattr(
            module, attr, TQLoRALinear(
                name=full_name,
                weight=layer.weight,
                bias=layer.bias,
                in_features=layer.in_features,
                out_features=layer.out_features,
                config=config,
                is_cuda=is_cuda,
                device=device,
                dtype=dtype,
                r=r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                q_trainable=q_trainable,
            )
        )
    for child_name, child in module.named_children():
        make_quant(
            child,
            names,
            name=name + '.' + child_name if name != '' else child_name,
            config=config,
            is_cuda=is_cuda,
            device=device,
            dtype=dtype,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            q_trainable=q_trainable,
        )


def make_dequant(
    module, names, name='', device=None, dtype=None
):
    """Recursively replace named quantized layers of ``module`` with dense ``nn.Linear``.

    For each matched layer, the weight is reconstructed with
    ``vectorwise_dequant`` from the layer's ``q_weight``, ``q_scales``,
    ``q_biases``, ``qmap``, ``w_shape`` and ``q_metadata`` attributes and copied
    into a freshly created ``nn.Linear``. The model is modified in place;
    nothing is returned.

    Args:
        module: Root module to walk.
        names: Container of dotted layer names whose layers are dequantized.
        name: Dotted prefix of ``module`` within the root model; ``''`` at the root.
        device: Device for the new ``nn.Linear`` layers.
        dtype: Dtype for the new ``nn.Linear`` layers.
    """
    # Plain Linear layers have nothing to dequantize (also stops recursion).
    if isinstance(module, nn.Linear):
        return
    for attr in dir(module):
        full_name = name + '.' + attr if name != '' else attr
        # Filter by name first so getattr is only evaluated for target layers.
        if full_name not in names:
            continue
        qlayer = getattr(module, attr)
        dequant_weight = vectorwise_dequant(
            qlayer.q_weight, qlayer.q_scales, qlayer.q_biases,
            qmap=qlayer.qmap, shape=qlayer.w_shape, **qlayer.q_metadata,
        )
        # make_quant hands the original bias to the quantized layer; carry it
        # over instead of leaving nn.Linear's randomly initialised bias in place.
        # When there is no bias, build the Linear without one.
        q_bias = getattr(qlayer, 'bias', None)
        linear = nn.Linear(
            in_features=qlayer.in_features,
            out_features=qlayer.out_features,
            bias=q_bias is not None,
            device=device,
            dtype=dtype,
        )
        # copy_ converts dtype/device as needed.
        linear.weight.data.copy_(dequant_weight)
        if q_bias is not None:
            linear.bias.data.copy_(q_bias.detach())
        # NOTE(review): if qlayer is a TQLoRALinear, any LoRA adapter weights it
        # holds are not folded into the dense weight here — confirm intended.
        setattr(module, attr, linear)

    for child_name, child in module.named_children():
        make_dequant(
            child,
            names,
            name=name + '.' + child_name if name != '' else child_name,
            device=device,
            dtype=dtype
        )


def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Return ``{dotted_name: submodule}`` for every submodule whose type is in ``layers``.

    Matching uses the exact type (``type(module) in layers``), so subclasses of
    the listed classes are not matched. Recursion stops at a matching module,
    so nothing nested inside a match is reported.

    Args:
        module: Module to search.
        layers: Sequence of module classes to collect (tuple or list).
        name: Dotted prefix of ``module``; ``''`` at the root.
    """
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        child_path = name + '.' + child_name if name != '' else child_name
        res.update(find_layers(child, layers=layers, name=child_path))
    return res


def find_layers_for_TQ(module, layers=(TQLinear,), name=''):  # pass (TQLoRALinear,) to find LoRA layers
    """Return ``{dotted_name: submodule}`` for every TQ layer whose type is in ``layers``.

    Same traversal as ``find_layers`` (exact-type match, recursion stops at a
    match), but defaults to collecting ``TQLinear`` layers.

    Args:
        module: Module to search.
        layers: Sequence of module classes to collect (tuple or list).
        name: Dotted prefix of ``module``; ``''`` at the root.
    """
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        child_path = name + '.' + child_name if name != '' else child_name
        # Recurse into this function itself (not find_layers) so the traversal
        # stays self-consistent with this function's own defaults.
        res.update(find_layers_for_TQ(child, layers=layers, name=child_path))
    return res