import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.quantization as tq
try:
    from torch.ao.nn.quantized import FloatFunctional
except ImportError:
    from torch.nn.quantized import FloatFunctional


class QuantizableFCN8s(nn.Module):
    """FCN-8s semantic-segmentation network prepared for eager-mode quantization.

    The encoder is five stages of Conv3x3-BN-ReLU blocks (2, 2, 3, 3 and 3
    convolutions), each stage ending in a ceil-mode 2x2 max-pool, so stage k
    works at ceil(H / 2**k) x ceil(W / 2**k). Class scores from stage 5 are
    upsampled x2 and added to scores from stage 4, upsampled x2 again and added
    to scores from stage 3, then upsampled x8 and cropped back to the input
    size. The output has ``n_classes`` channels and the input's spatial size.

    ``QuantStub``/``DeQuantStub`` delimit the quantized region. Each skip
    addition uses its own ``FloatFunctional`` so each sum gets its own
    observer (and hence its own scale/zero_point).

    Args:
        n_classes: number of output score channels.
        n_channels: number of input image channels.
        layer_num: width of stage 1; stages 2-5 use 2x, 4x, 8x and 8x this.

    NOTE(review): eager-mode ``prepare`` is known to reject per-channel weight
    observers on ``ConvTranspose2d``; the ``upscore*`` layers may need a
    per-tensor weight qconfig -- confirm against the installed torch version.
    """

    def __init__(self, n_classes=2, n_channels=3, layer_num=64):
        super().__init__()

        # ===== Encoder =====
        self.stage1 = self._make_stage(n_channels, layer_num, 2)
        self.stage2 = self._make_stage(layer_num, layer_num * 2, 2)
        self.stage3 = self._make_stage(layer_num * 2, layer_num * 4, 3)
        self.stage4 = self._make_stage(layer_num * 4, layer_num * 8, 3)
        self.stage5 = self._make_stage(layer_num * 8, layer_num * 8, 3)

        # ===== Classifier + skips =====
        self.score_pool3 = nn.Conv2d(layer_num * 4, n_classes, kernel_size=1)
        self.score_pool4 = nn.Conv2d(layer_num * 8, n_classes, kernel_size=1)
        self.score       = nn.Conv2d(layer_num * 8, n_classes, kernel_size=1)

        self.upscore2_1 = nn.ConvTranspose2d(n_classes, n_classes, 4, stride=2, bias=False)
        self.upscore2_2 = nn.ConvTranspose2d(n_classes, n_classes, 4, stride=2, bias=False)
        self.upscore8   = nn.ConvTranspose2d(n_classes, n_classes, 16, stride=8, bias=False)

        # ===== Quantization stubs =====
        self.quant = tq.QuantStub()
        self.dequant = tq.DeQuantStub()
        # One FloatFunctional per addition: each instance owns one output
        # observer, so sharing a single instance would force both skip sums
        # onto the same scale/zero_point after conversion.
        self.add = FloatFunctional()        # s5_up + s4_score (name kept for compatibility)
        self.add_pool3 = FloatFunctional()  # fuse4_up + s3_score

    @staticmethod
    def _make_stage(in_ch, out_ch, n_convs):
        """Build ``n_convs`` x (Conv3x3 -> BN -> ReLU) followed by a ceil-mode 2x2 max-pool."""
        layers = []
        for i in range(n_convs):
            layers += [
                nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.MaxPool2d(2, stride=2, ceil_mode=True))
        return nn.Sequential(*layers)

    def _crop(self, t: torch.Tensor, offset: int, height: int, width: int) -> torch.Tensor:
        """Return the spatial window ``t[:, :, offset:offset+height, offset:offset+width]``."""
        return t[:, :, offset:offset + height, offset:offset + width]

    def forward(self, x):
        """Compute per-pixel class scores.

        Args:
            x: float input; assumed NCHW with ``n_channels`` channels.

        Returns:
            Float tensor of shape (N, n_classes, H, W), matching x's H and W.
        """
        in_h, in_w = x.size(2), x.size(3)
        x = self.quant(x)

        s1 = self.stage1(x)
        s2 = self.stage2(s1)
        s3 = self.stage3(s2)
        s4 = self.stage4(s3)
        s5 = self.stage5(s4)

        # The unpadded ConvTranspose2d(k, stride=s) layers emit s*n + (k - s)
        # pixels, which is larger than the skip map. Cropping with offset
        # (k - s) // 2 is equivalent to giving them padding=(k - s) // 2.
        s5_score = self.score(s5)
        s5_up    = self.upscore2_1(s5_score)

        s4_score = self.score_pool4(s4)
        s5_up    = self._crop(s5_up, 1, s4_score.size(2), s4_score.size(3))
        fuse4    = self.add.add(s5_up, s4_score)  # quantization-friendly s5_up + s4_score

        fuse4_up = self.upscore2_2(fuse4)

        s3_score = self.score_pool3(s3)
        fuse4_up = self._crop(fuse4_up, 1, s3_score.size(2), s3_score.size(3))
        fuse3    = self.add_pool3.add(fuse4_up, s3_score)  # quantization-friendly fuse4_up + s3_score

        out = self.upscore8(fuse3)
        out = self.dequant(out)
        # 8 * ceil(H / 8) + 8 >= H + 8, so the window always fits.
        return self._crop(out, 4, in_h, in_w)

    def fuse_model(self):
        """Fuse every Conv2d -> BatchNorm2d -> ReLU triple in the encoder stages, in place.

        Call this before quantization ``prepare``. Triples that were already
        fused are no longer plain ``nn.Conv2d`` and are skipped, so a second
        call does nothing.

        NOTE(review): for post-training quantization torch folds BN into the
        conv only in eval mode, so call ``model.eval()`` first; QAT flows use a
        different fuser -- confirm for your torch version.
        """
        def fuse_seq(seq: nn.Sequential):
            i = 0
            while i + 2 < len(seq):
                if (isinstance(seq[i], nn.Conv2d)
                        and isinstance(seq[i + 1], nn.BatchNorm2d)
                        and isinstance(seq[i + 2], nn.ReLU)):
                    tq.fuse_modules(seq, [str(i), str(i + 1), str(i + 2)], inplace=True)
                    i += 3
                else:
                    i += 1

        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            fuse_seq(stage)
