import torch.nn as nn
from quant.quant_layer import QuantModule
from quant.quant_block import BaseQuantBlock, QuantResnetBlock, QuantAttnBlock
from quant_control.block_w_recon import block_w_reconstruction
from quant_control.layer_w_recon import layer_w_reconstruction
import logging
logger = logging.getLogger(__name__)

class recon_Qmodel:
    """Drive weight reconstruction over every quantized unit of a model.

    The model tree is walked depth-first in registration order. Each
    ``QuantModule`` is passed to ``layer_w_reconstruction`` and each
    ``BaseQuantBlock`` to ``block_w_reconstruction``; any other container is
    recursed into. Two resolution levels get special ordering: ``down[1]`` and
    ``up[1]`` have their residual blocks and attention blocks reconstructed
    interleaved (``block[i]`` then ``attn[i]``), followed by their resample
    conv. Children of the ``up`` container are visited in reverse order.

    NOTE(review): the special-cased level index ``'1'`` and the
    ``block``/``attn``/``downsample``/``upsample`` attribute names assume a
    DDIM-style UNet layout -- confirm against the model being quantized.
    """

    def __init__(self, args, qnn, kwargs):
        """Store the run configuration.

        Args:
            args: run arguments; stored on ``self.args`` (not read by this class).
            qnn: the quantized model to reconstruct.
            kwargs: keyword arguments forwarded to every
                ``block_w_reconstruction`` / ``layer_w_reconstruction`` call.
        """
        self.args = args
        self.model = qnn
        self.kwargs = kwargs
        # Progress of the encoder special case: None before a child named
        # 'down' is seen, 'down' after that, 'over' once down[1] is done.
        self.down_name = None

    @staticmethod
    def _level_recon_order(level, resample_attr):
        """Return the ``(kind, module)`` pairs of one resolution level, in order.

        ``kind`` is ``'block'`` for modules handed to ``block_w_reconstruction``
        and ``'layer'`` for the resample conv handed to
        ``layer_w_reconstruction``. Every ``level.block[i]`` is followed by
        ``level.attn[i]`` when that attention entry exists, and
        ``getattr(level, resample_attr).conv`` comes last.
        """
        order = []
        n_attn = len(level.attn)
        for i, res_block in enumerate(level.block):
            order.append(('block', res_block))
            # Levels without attention (empty ``attn``) just run their blocks.
            if i < n_attn:
                order.append(('block', level.attn[i]))
        order.append(('layer', getattr(level, resample_attr).conv))
        return order

    def _recon_level(self, level, resample_attr):
        """Reconstruct one resolution level in the order given by ``_level_recon_order``."""
        for kind, sub in self._level_recon_order(level, resample_attr):
            if kind == 'block':
                block_w_reconstruction(self.model, sub, **self.kwargs)
            else:
                layer_w_reconstruction(self.model, sub, **self.kwargs)

    def recon_w_model(self, module: nn.Module):
        """Recursively reconstruct the quantized children of ``module``.

        Args:
            module: container whose direct children are dispatched in
                registration order (see the class docstring for the rules).
        """
        for name, child in module.named_children():
            if self.down_name is None and name == 'down':
                self.down_name = 'down'
            # A BaseQuantBlock named '1' (presumably an entry of a ``block``
            # ModuleList) is not a resolution level; it takes the block path.
            if self.down_name == 'down' and name == '1' and not isinstance(child, BaseQuantBlock):
                logger.info('reconstruction for down 1 modulelist')
                self._recon_level(child, 'downsample')
                self.down_name = 'over'
            elif isinstance(child, QuantModule):
                logger.info('Reconstruction for layer %s', name)
                layer_w_reconstruction(self.model, child, **self.kwargs)
            elif isinstance(child, BaseQuantBlock):
                logger.info('Reconstruction for block %s', name)
                block_w_reconstruction(self.model, child, **self.kwargs)
            elif name == 'up':
                self.recon_up_w_model(child)
            else:
                self.recon_w_model(child)

    def recon_up_w_model(self, module: nn.Module):
        """Reconstruct the children of the ``up`` container in reverse order.

        The reverse walk presumably mirrors the decoder's forward pass, which
        runs the highest-index level first -- confirm against the model.
        Child ``'1'`` gets the interleaved block/attn treatment.
        """
        for up_name, up_module in reversed(list(module.named_children())):
            if up_name == '1':
                logger.info('reconstruction for up 1 modulelist')
                self._recon_level(up_module, 'upsample')
            elif isinstance(up_module, QuantModule):
                logger.info('Reconstruction for layer %s', up_name)
                layer_w_reconstruction(self.model, up_module, **self.kwargs)
            elif isinstance(up_module, BaseQuantBlock):
                logger.info('Reconstruction for block %s', up_name)
                block_w_reconstruction(self.model, up_module, **self.kwargs)
            else:
                self.recon_w_model(up_module)

    def w_recon(self):
        """Run weight reconstruction over ``self.model`` and return that same object.

        The encoder special-case state is reset first so that repeated calls
        reconstruct in the same order as the first one.
        """
        self.down_name = None
        self.recon_w_model(self.model)
        return self.model
