import torch.nn as nn
from quant.quant_layer import QuantModule
from quant.quant_block import BaseQuantBlock, QuantResnetBlock, QuantAttnBlock
from quant.block_w_recon import block_w_reconstruction
from quant.layer_w_recon import layer_w_reconstruction
import logging
# Module-level logger shared by the reconstruction driver classes below.
logger = logging.getLogger(__name__)

class recon_Qmodel():
    """Walk a quantized model and run weight reconstruction on its parts.

    Children are visited in ``named_children()`` order. Each
    ``QuantModule`` is handed to ``layer_w_reconstruction`` and each
    ``BaseQuantBlock`` to ``block_w_reconstruction``; any other container
    is descended into recursively. Two children get special treatment:

    * the child named ``'1'`` found under the first child named ``'down'``
      has its ``block``/``attn`` entries reconstructed in interleaved
      order, followed by ``downsample.conv``;
    * the child named ``'up'`` is traversed in reverse child order (see
      ``recon_up_w_model``).

    NOTE(review): the attribute names suggest a DDPM-style U-Net
    (``down``/``up`` stages with ``block``, ``attn`` and a resampling
    conv) -- confirm against the model definition.
    """

    def __init__(self, args, qnn, kwargs):
        """Store the model and the settings forwarded to reconstruction.

        Args:
            args: Run configuration; stored on the instance but not read
                by any method of this class.
            qnn: The quantized model to traverse; also passed as the first
                argument of every reconstruction call.
            kwargs: Mapping expanded as keyword arguments into every
                ``block_w_reconstruction`` / ``layer_w_reconstruction`` call.
        """
        self.args = args
        self.model = qnn
        self.kwargs = kwargs
        # Traversal state: None -> 'down' once a child named 'down' has been
        # seen -> 'over' once its special-cased child '1' has been handled.
        self.down_name = None

    def _recon_stage_blockwise(self, stage, num_pairs, resample_name):
        """Reconstruct ``num_pairs`` (block, attn) pairs, then the resampling conv.

        The resampling module is looked up by name only after all pairs are
        done, so attribute/index errors surface at the same point as they
        would with the calls written out one by one.
        """
        for idx in range(num_pairs):
            block_w_reconstruction(self.model, stage.block[idx], **self.kwargs)
            block_w_reconstruction(self.model, stage.attn[idx], **self.kwargs)
        resample = getattr(stage, resample_name)
        layer_w_reconstruction(self.model, resample.conv, **self.kwargs)

    def recon_w_model(self, module: nn.Module):
        """Recursively reconstruct every quantized child of ``module``.

        Args:
            module: Container whose direct children are dispatched by type
                and name as described in the class docstring.
        """
        for name, child in module.named_children():
            if self.down_name is None and name == 'down':
                # Only arms the special case; the 'down' container itself is
                # still dispatched by the chain below.
                self.down_name = 'down'
            if (self.down_name == 'down' and name == '1'
                    and not isinstance(child, BaseQuantBlock)):
                # A child named '1' that is itself a quant block (e.g. the
                # second entry of a block list) must not match here.
                logger.info('reconstruction for down 1 modulelist')
                self._recon_stage_blockwise(child, 2, 'downsample')
                self.down_name = 'over'
            elif isinstance(child, QuantModule):
                logger.info('Reconstruction for layer %s', name)
                layer_w_reconstruction(self.model, child, **self.kwargs)
            elif isinstance(child, BaseQuantBlock):
                logger.info('Reconstruction for block %s', name)
                block_w_reconstruction(self.model, child, **self.kwargs)
            elif name == 'up':
                self.recon_up_w_model(child)
            else:
                self.recon_w_model(child)

    def recon_up_w_model(self, module: nn.Module):
        """Reconstruct the children of the ``'up'`` container, last child first.

        NOTE(review): the reversed order presumably mirrors the order in
        which the model's forward pass runs its up stages -- confirm.

        Args:
            module: The container that was registered under the name ``'up'``.
        """
        for up_name, up_module in reversed(list(module.named_children())):
            if up_name == '1':
                logger.info('reconstruction for up 1 modulelist')
                self._recon_stage_blockwise(up_module, 3, 'upsample')
            elif isinstance(up_module, QuantModule):
                logger.info('Reconstruction for layer %s', up_name)
                layer_w_reconstruction(self.model, up_module, **self.kwargs)
            elif isinstance(up_module, BaseQuantBlock):
                logger.info('Reconstruction for block %s', up_name)
                block_w_reconstruction(self.model, up_module, **self.kwargs)
            else:
                self.recon_w_model(up_module)

    def w_recon(self):
        """Run reconstruction over the whole model and return the model.

        Returns:
            The same model object that was passed to ``__init__``.
        """
        # Reset the traversal state so that a repeated call handles the
        # 'down' -> '1' stage exactly like the first call did instead of
        # finding the stale 'over' marker and taking the generic path.
        self.down_name = None
        self.recon_w_model(self.model)
        return self.model

class recon_layer_Qmodel():
    """Walk a quantized model and run layer-wise weight reconstruction.

    The traversal is the same as a block-wise walk (children in
    ``named_children()`` order, the ``'down'`` -> ``'1'`` child handled
    with interleaved ``block``/``attn`` entries, the ``'up'`` child
    traversed in reverse), but quant blocks are never reconstructed as a
    unit. Instead every block is broken up by ``recon_w_block`` and each
    of its layers is handed to ``layer_w_reconstruction`` individually.

    NOTE(review): the attribute names suggest a DDPM-style U-Net
    (``down``/``up`` stages with ``block``, ``attn`` and a resampling
    conv) -- confirm against the model definition.
    """

    def __init__(self, args, qnn, kwargs):
        """Store the model and the settings forwarded to reconstruction.

        Args:
            args: Run configuration; stored on the instance but not read
                by any method of this class.
            qnn: The quantized model to traverse; also passed as the first
                argument of every ``layer_w_reconstruction`` call.
            kwargs: Mapping expanded as keyword arguments into every
                ``layer_w_reconstruction`` call.
        """
        self.args = args
        self.model = qnn
        self.kwargs = kwargs
        # Traversal state: None -> 'down' once a child named 'down' has been
        # seen -> 'over' once its special-cased child '1' has been handled.
        self.down_name = None

    def w_recon(self):
        """Run reconstruction over the whole model and return the model.

        Returns:
            The same model object that was passed to ``__init__``.
        """
        # Reset the traversal state so that a repeated call handles the
        # 'down' -> '1' stage exactly like the first call did instead of
        # finding the stale 'over' marker and taking the generic path.
        self.down_name = None
        self.recon_w_model(self.model)
        return self.model

    def _recon_stage_layerwise(self, stage, num_pairs, resample_name):
        """Reconstruct ``num_pairs`` (block, attn) pairs, then the resampling conv.

        The resampling module is looked up by name only after all pairs are
        done, so attribute/index errors surface at the same point as they
        would with the calls written out one by one.
        """
        for idx in range(num_pairs):
            self.recon_w_block(stage.block[idx])
            self.recon_w_block(stage.attn[idx])
        resample = getattr(stage, resample_name)
        layer_w_reconstruction(self.model, resample.conv, **self.kwargs)

    def recon_w_model(self, module: nn.Module):
        """Recursively reconstruct every quantized child of ``module``.

        Args:
            module: Container whose direct children are dispatched by type
                and name: ``QuantModule`` -> layer reconstruction,
                ``BaseQuantBlock`` -> ``recon_w_block``, ``'up'`` ->
                ``recon_up_w_model``, anything else -> recursion.
        """
        for name, child in module.named_children():
            if self.down_name is None and name == 'down':
                # Only arms the special case; the 'down' container itself is
                # still dispatched by the chain below.
                self.down_name = 'down'
            if (self.down_name == 'down' and name == '1'
                    and not isinstance(child, BaseQuantBlock)):
                # A child named '1' that is itself a quant block (e.g. the
                # second entry of a block list) must not match here.
                logger.info('reconstruction for down 1 modulelist')
                self._recon_stage_layerwise(child, 2, 'downsample')
                self.down_name = 'over'
            elif isinstance(child, QuantModule):
                logger.info('Reconstruction for layer %s', name)
                layer_w_reconstruction(self.model, child, **self.kwargs)
            elif isinstance(child, BaseQuantBlock):
                logger.info('Reconstruction for block %s', name)
                self.recon_w_block(child)
            elif name == 'up':
                self.recon_up_w_model(child)
            else:
                self.recon_w_model(child)

    def recon_up_w_model(self, module: nn.Module):
        """Reconstruct the children of the ``'up'`` container, last child first.

        NOTE(review): the reversed order presumably mirrors the order in
        which the model's forward pass runs its up stages -- confirm.

        Args:
            module: The container that was registered under the name ``'up'``.
        """
        for up_name, up_module in reversed(list(module.named_children())):
            if up_name == '1':
                logger.info('reconstruction for up 1 modulelist')
                self._recon_stage_layerwise(up_module, 3, 'upsample')
            elif isinstance(up_module, QuantModule):
                logger.info('Reconstruction for layer %s', up_name)
                layer_w_reconstruction(self.model, up_module, **self.kwargs)
            elif isinstance(up_module, BaseQuantBlock):
                logger.info('Reconstruction for block %s', up_name)
                self.recon_w_block(up_module)
            else:
                self.recon_w_model(up_module)

    def recon_w_block(self, block: nn.Module):
        """Reconstruct the layers of one quant block according to its type.

        Args:
            block: A ``QuantResnetBlock`` or ``QuantAttnBlock``. Any other
                type is left untouched; a warning is logged so the skipped
                block does not go unnoticed.
        """
        if isinstance(block, QuantResnetBlock):
            self.recon_w_QuantResnetBlock_block(block)
        elif isinstance(block, QuantAttnBlock):
            self.recon_w_QuantAttnBlock_block(block)
        else:
            logger.warning(
                'No layer-wise reconstruction for block type %s; skipped',
                type(block).__name__,
            )

    def recon_w_QuantResnetBlock_block(self, module: nn.Module):
        """Reconstruct every ``QuantModule`` nested anywhere inside ``module``.

        Children that are not ``QuantModule`` instances are descended into
        recursively, depth-first, in ``named_children()`` order.
        """
        for name, child in module.named_children():
            if isinstance(child, QuantModule):
                logger.info('Reconstruction for layer %s', name)
                layer_w_reconstruction(self.model, child, **self.kwargs)
            else:
                self.recon_w_QuantResnetBlock_block(child)

    def recon_w_QuantAttnBlock_block(self, module: nn.Module):
        """Reconstruct the ``q``, ``k``, ``v`` and ``proj_out`` layers, in that order."""
        logger.info('Reconstruction for Attn')
        for layer in (module.q, module.k, module.v, module.proj_out):
            layer_w_reconstruction(self.model, layer, **self.kwargs)