from .attention import Fuse
from .decoder_utils import *
from .encoder_utils import *


class U_SNN(nn.Module):
    """Spiking U-Net style encoder/decoder with multi-scale prediction heads.

    The input is repeated ``args.T`` times along a new leading (time) axis.
    It is encoded by a plain block, five stages of ``SpikingDownBlock`` and a
    two-block residual bottleneck. It is then decoded by five
    ``SpikingUpBlock`` stages, whose outputs are fused (``Fuse``) with encoder
    skip features. Five prediction heads map the fused features of each
    decoder level to ``out_channel`` outputs.

    The class also offers forward-hook based counters (``add_hooks``,
    ``remove_hooks``, ``reset_nz_numel``, ``get_nz_numel``, ``process_nz``)
    that count the non-zero outputs of selected module types.
    """

    def __init__(self, in_channel=3, out_channel=1, args=None):
        """Build the network.

        Args:
            in_channel: Number of channels of the input given to ``forward``.
            out_channel: Number of output channels of every prediction head.
            args: Configuration object. It must provide ``T`` (number of time
                steps), ``step_mode`` and ``backend``. It is also forwarded
                to the plain, down, up and prediction blocks.

        Raises:
            ValueError: If ``args`` is None.
        """
        if args is None:
            # ``args`` defaults to None only for signature compatibility; the
            # network cannot be built without it.
            raise ValueError("U_SNN requires an `args` object providing T, step_mode and backend")
        super().__init__()
        self.T = args.T
        k = 1  # channel width multiplier
        self.down_channels = [32 * k, 64 * k, 64 * k, 128 * k, 256 * k, 512 * k]
        step_mode = args.step_mode
        backend = args.backend

        # Per-module counters filled by the hooks installed in ``add_hooks``.
        self.nz, self.numel = {}, {}
        # Running totals accumulated by ``process_nz``.
        self.all_nnz = 0
        self.all_nnumel = 0
        # Handles of the forward hooks registered by ``add_hooks``.
        self.hooks = {}

        self.plain_conv = SpikingPlainBlock(in_channel, 32 * k, 7, 3, step_mode, backend, args)
        self.encoder_layer1 = nn.Sequential(
            SpikingDownBlock(32 * k, 64 * k, 3, 1, 1, step_mode, backend, False, args),
            SpikingDownBlock(64 * k, 64 * k, 3, 1, 1, step_mode, backend, True, args),
        )
        self.encoder_layer2 = nn.Sequential(
            SpikingDownBlock(64 * k, 128 * k, 3, 1, 2, step_mode, backend, False, args),
            SpikingDownBlock(128 * k, 128 * k, 3, 1, 1, step_mode, backend, True, args),
        )
        self.encoder_layer3 = nn.Sequential(
            SpikingDownBlock(128 * k, 256 * k, 3, 1, 2, step_mode, backend, False, args),
            SpikingDownBlock(256 * k, 256 * k, 3, 1, 1, step_mode, backend, False, args),
            SpikingDownBlock(256 * k, 256 * k, 3, 1, 1, step_mode, backend, True, args),
        )
        self.encoder_layer4 = nn.Sequential(
            SpikingDownBlock(256 * k, 512 * k, 3, 1, 2, step_mode, backend, False, args),
            SpikingDownBlock(512 * k, 512 * k, 3, 1, 1, step_mode, backend, False, args),
            SpikingDownBlock(512 * k, 512 * k, 3, 1, 1, step_mode, backend, True, args),
        )

        self.encoder_layer5 = nn.Sequential(
            SpikingDownBlock(512 * k, 512 * k, 3, 1, 2, step_mode, backend, False, args),
            SpikingDownBlock(512 * k, 512 * k, 3, 1, 1, step_mode, backend, False, args),
            SpikingDownBlock(512 * k, 512 * k, 3, 1, 1, step_mode, backend, True, args),
        )

        self.bottleneck = nn.Sequential(
            OrigResBlock(512 * k, 512 * k, step_mode=step_mode, backend=backend, act=True),
            OrigResBlock(512 * k, 512 * k, step_mode=step_mode, backend=backend, act=False),
        )

        self.decoder_layer5 = SpikingUpBlock(512 * k, 512 * k, 3, 1, 2, step_mode, backend, args)
        self.decoder_layer4 = SpikingUpBlock(512 * k, 256 * k, 3, 1, 2, step_mode, backend, args)
        self.decoder_layer3 = SpikingUpBlock(256 * k, 128 * k, 3, 1, 2, step_mode, backend, args)
        self.decoder_layer2 = SpikingUpBlock(128 * k, 64 * k, 3, 1, 2, step_mode, backend, args)
        self.decoder_layer1 = SpikingUpBlock(64 * k, 32 * k, 3, 1, 1, step_mode, backend, args)

        self.fusion_layer2 = Fuse()
        self.fusion_layer3 = Fuse()
        self.fusion_layer4 = Fuse()
        self.fusion_layer5 = Fuse()

        # The fifth positional argument (8, 4, 2, 1, 1) differs per level;
        # presumably it is an upsampling factor back to full resolution,
        # matching the strides of the encoder stages - TODO confirm.
        self.predict_depth5 = SpikingPredUpBlock(512 * k, out_channel, 7, 3, 8, step_mode=step_mode,
                                                 backend=backend, args=args)
        self.predict_depth4 = SpikingPredUpBlock(256 * k, out_channel, 7, 3, 4, step_mode=step_mode,
                                                 backend=backend, args=args)

        self.predict_depth3 = SpikingPredUpBlock(128 * k, out_channel, 7, 3, 2, step_mode=step_mode,
                                                 backend=backend, args=args)

        self.predict_depth2 = SpikingPredUpBlock(64 * k, out_channel, 7, 3, 1, step_mode=step_mode,
                                                 backend=backend, args=args)

        self.predict_depth1 = SpikingPredUpBlock(32 * k, out_channel, 7, 3, 1, step_mode=step_mode,
                                                 backend=backend, args=args)

    def forward(self, x):
        """Run the encoder/decoder and return five predictions.

        Args:
            x: Input tensor. It is assumed to be [B, C, H, W] (TODO confirm).
                It is repeated ``self.T`` times along a new leading time axis.

        Returns:
            Tuple ``(pred_edge1, ..., pred_edge5)``. ``pred_edge1`` comes from
            the shallowest decoder level (encoder skip + ``decoder_layer1``).
            ``pred_edge5`` comes from the deepest fused level (``fusion5``).
        """
        # Expand along the temporal dimension.
        x = x.repeat(self.T, 1, 1, 1, 1)  # [T, B, C, H, W]

        # Encode. Each stage yields (features passed deeper, skip features).
        encode_down, encode_x = self.plain_conv(x)
        encode_down1, encode_x1 = self.encoder_layer1(encode_down)
        encode_down2, encode_x2 = self.encoder_layer2(encode_down1)
        encode_down3, encode_x3 = self.encoder_layer3(encode_down2)
        encode_down4, encode_x4 = self.encoder_layer4(encode_down3)
        encode_down5, encode_x5 = self.encoder_layer5(encode_down4)

        bottleneck_x = self.bottleneck(encode_down5)

        # Decode and fuse with the matching encoder skip features.
        decode_up5 = self.decoder_layer5(bottleneck_x)
        fusion5 = self.fusion_layer5(encode_x4, decode_up5)
        decode_up4 = self.decoder_layer4(fusion5)

        fusion4 = self.fusion_layer4(encode_x3, decode_up4)
        decode_up3 = self.decoder_layer3(fusion4)

        fusion3 = self.fusion_layer3(encode_x2, decode_up3)
        decode_up2 = self.decoder_layer2(fusion3)

        fusion2 = self.fusion_layer2(encode_x1, decode_up2)
        decode_up1 = self.decoder_layer1(fusion2)

        # Output predictions.
        pred_edge5 = self.predict_depth5(fusion5)

        pred_edge4 = self.predict_depth4(fusion4)

        pred_edge3 = self.predict_depth3(fusion3)

        pred_edge2 = self.predict_depth2(fusion2)

        # The shallowest level uses an additive skip instead of a Fuse layer.
        pred_edge1 = self.predict_depth1(encode_x + decode_up1)

        return pred_edge1, pred_edge2, pred_edge3, pred_edge4, pred_edge5

    # firing rate
    def process_nz(self, nz_numel):
        """Add the per-module counts in ``nz_numel`` to the running totals.

        Args:
            nz_numel: ``(nz, numel)`` pair of dicts keyed by module name, as
                returned by ``get_nz_numel``.

        Modules that recorded zero elements are skipped. The results are
        accumulated into ``self.all_nnz`` and ``self.all_nnumel``; their
        ratio is presumably the overall firing rate.
        """
        nz, numel = nz_numel
        total_nnz, total_nnumel = 0, 0

        for module, nnz in nz.items():
            nnumel = numel[module]
            if nnumel != 0:
                total_nnz += nnz
                total_nnumel += nnumel
        if total_nnumel != 0:
            self.all_nnz += total_nnz
            self.all_nnumel += total_nnumel

    def remove_hooks(self):
        """Detach every forward hook previously registered by ``add_hooks``."""
        for handle in getattr(self, "hooks", {}).values():
            handle.remove()
        self.hooks = {}

    def add_hooks(self, instance):
        """Register forward hooks that count non-zero outputs.

        Args:
            instance: A module class (or tuple of classes) accepted by
                ``isinstance``. Every sub-module matching it gets a hook, and
                its counters in ``self.nz`` / ``self.numel`` are zeroed.

        Hooks from a previous call are removed first. Without this, calling
        the method twice would leave the old hooks attached, and every forward
        pass would count each output more than once.
        """
        def get_nz(name):
            def hook(model, input, output):
                # NOTE(review): assumes ``output`` is a single tensor;
                # modules returning tuples would need separate handling.
                self.nz[name] += torch.count_nonzero(output)
                self.numel[name] += output.numel()

            return hook

        self.remove_hooks()

        for name, module in self.named_modules():
            if isinstance(module, instance):
                self.nz[name], self.numel[name] = 0, 0
                self.hooks[name] = module.register_forward_hook(get_nz(name))

    def reset_nz_numel(self):
        """Zero the per-module counters for every sub-module.

        This includes sub-modules without hooks. ``all_nnz`` and
        ``all_nnumel`` are left untouched.
        """
        for name, module in self.named_modules():
            self.nz[name], self.numel[name] = 0, 0

    def get_nz_numel(self):
        """Return the ``(nz, numel)`` per-module counter dicts."""
        return self.nz, self.numel
