import torch
import torch.nn as nn
import torch.ao.quantization as tq

class SegNetDown(nn.Module):
    """SegNet encoder stage: two Conv-BN-ReLU layers followed by 2x2 max pooling.

    ``forward`` returns ``(pooled, indices)``. The indices must be handed to the
    matching decoder stage so it can max-unpool back to this stage's resolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Positions "0".."5" are the names fuse_model() refers to; the pool
        # stays at position 6 so existing indexing into ``block`` keeps working.
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, return_indices=True),
        )
        # Pooling that also returns indices is, to our knowledge, not
        # implemented for quantized tensors (NOTE(review): re-check on newer
        # PyTorch). After tq.convert the pool therefore runs on a dequantized
        # tensor and its result is re-quantized. Both stubs are identities
        # until prepare/convert and add nothing to the state_dict.
        self.dequant = tq.DeQuantStub()
        self.quant = tq.QuantStub()

    def forward(self, x):
        """Return ``(pooled, indices)`` for input ``x``.

        ``pooled`` has ``out_ch`` channels and floor-halved spatial size;
        ``indices`` has the same shape and is meant for ``nn.MaxUnpool2d``.
        """
        *convs, pool = self.block
        for layer in convs:
            x = layer(x)
        pooled, indices = pool(self.dequant(x))
        return self.quant(pooled), indices

    def fuse_model(self):
        """Fuse both Conv-BN-ReLU triples of ``block`` in place."""
        tq.fuse_modules(self.block, [["0", "1", "2"], ["3", "4", "5"]], inplace=True)


class SegNetUp(nn.Module):
    """SegNet decoder stage: 2x2 max-unpooling followed by two Conv-BN-ReLU layers."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(2, stride=2)
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        # Max-unpooling is, to our knowledge, not implemented for quantized
        # tensors (NOTE(review): re-check on newer PyTorch). After tq.convert
        # the unpool runs in float and the result is re-quantized for the
        # convs. Both stubs are identities until prepare/convert.
        self.dequant = tq.DeQuantStub()
        self.quant = tq.QuantStub()

    def forward(self, x, indices, output_size):
        """Unpool ``x`` with the encoder's ``indices`` to ``output_size``, then convolve.

        Args:
            x: feature map from the previous decoder stage (or the bottleneck).
            indices: indices returned by the matching encoder's max pool; must
                have the same shape as ``x``.
            output_size: size of the encoder feature map before pooling; only
                its trailing spatial dims are used by ``MaxUnpool2d``.

        Returns:
            Feature map with ``out_ch`` channels at ``output_size`` resolution.
        """
        x = self.unpool(self.dequant(x), indices, output_size=output_size)
        return self.block(self.quant(x))

    def fuse_model(self):
        """Fuse both Conv-BN-ReLU triples of ``block`` in place."""
        tq.fuse_modules(self.block, [["0", "1", "2"], ["3", "4", "5"]], inplace=True)


class QuantizableSegNet(nn.Module):
    """Four-stage SegNet with QuantStub/DeQuantStub for eager-mode static quantization.

    Each encoder stage's max-pool indices are reused by the mirrored decoder
    stage, and a final unpool with ``down1``'s indices restores the input
    resolution before the 1x1 classifier ``up1``.
    """

    def __init__(self, n_channels=3, n_classes=2, layer_num=64):
        super().__init__()
        self.down1 = SegNetDown(n_channels, layer_num)
        self.down2 = SegNetDown(layer_num, layer_num * 2)
        self.down3 = SegNetDown(layer_num * 2, layer_num * 4)
        self.down4 = SegNetDown(layer_num * 4, layer_num * 8)

        self.up4 = SegNetUp(layer_num * 8, layer_num * 4)
        self.up3 = SegNetUp(layer_num * 4, layer_num * 2)
        self.up2 = SegNetUp(layer_num * 2, layer_num)
        self.up1 = nn.Conv2d(layer_num, n_classes, 1)

        self.quant = tq.QuantStub()
        self.dequant = tq.DeQuantStub()

        # Undoes down1's pool; parameter-free, so the state_dict is unchanged.
        self.unpool1 = nn.MaxUnpool2d(2, stride=2)
        # Max-unpooling is, to our knowledge, not implemented for quantized
        # tensors (NOTE(review): re-check on newer PyTorch), so the final
        # unpool runs in float between these stubs after tq.convert.
        self.unpool1_dequant = tq.DeQuantStub()
        self.unpool1_quant = tq.QuantStub()

    def forward(self, x):
        """Return per-pixel class logits for ``x``.

        Args:
            x: input laid out as (N, n_channels, H, W) — assumed NCHW.

        Returns:
            Float logits of shape (N, n_classes, H, W).
        """
        x = self.quant(x)

        # Each encoder stage returns (pooled features, pool indices).
        x1, idx1 = self.down1(x)
        x2, idx2 = self.down2(x1)
        x3, idx3 = self.down3(x2)
        x4, idx4 = self.down4(x3)

        # The 3x3 convs use padding=1, so each encoder's pre-pool spatial size
        # equals its input's size. That input size is therefore the right
        # output_size for the mirrored unpool, and it also handles odd H/W.
        y = self.up4(x4, idx4, x3.size())
        y = self.up3(y, idx3, x2.size())
        y = self.up2(y, idx2, x1.size())
        y = self.unpool1(self.unpool1_dequant(y), idx1, output_size=x.size())
        y = self.up1(self.unpool1_quant(y))

        return self.dequant(y)

    def fuse_model(self):
        """Fuse the Conv-BN-ReLU triples of every encoder/decoder stage in place.

        ``up1`` is a bare 1x1 conv, so there is nothing to fuse for it.
        """
        for stage in (
            self.down1,
            self.down2,
            self.down3,
            self.down4,
            self.up4,
            self.up3,
            self.up2,
        ):
            stage.fuse_model()
