import torch
import torch.nn as nn
import math
from copy import deepcopy
from models.unet import DoubleConv, Down, Up, OutConv  # 需要你本地的 UNet 组件

def prune_unet_model(model, prune_ratio=0.3):
    """Structure-aware L1-norm channel pruning for a U-Net.

    Each encoder ``DoubleConv`` (``inc`` and ``down1``..``down4``) keeps, for
    both of its convolutions, the output filters with the largest L1 weight
    norm. The input side of every pruned convolution is sliced with the
    channel indices actually kept by the layer feeding it, and the matching
    BatchNorm affine parameters and running statistics are carried over, so
    the pruned encoder reuses its trained weights consistently.

    The decoder (``up1``..``up4``) and ``outc`` are rebuilt with freshly
    initialised weights sized for the pruned encoder, so the returned model
    needs fine-tuning before use.

    Args:
        model: U-Net to prune. It is deep-copied; the argument is not modified.
        prune_ratio: Fraction of output channels to remove from each pruned
            convolution, in ``[0, 1)`` (0.3 = remove 30%, keep 70%). At least
            one channel is always kept.

    Returns:
        A pruned deep copy of ``model``.

    Raises:
        ValueError: if ``prune_ratio`` is outside ``[0, 1)``, or if an encoder
            convolution uses grouped convolution (``groups != 1``).
    """
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1), got {prune_ratio!r}")

    def l1_norms(conv):
        # One L1 norm per output filter (flattened over in_ch x kH x kW).
        return conv.weight.detach().abs().reshape(conv.out_channels, -1).sum(dim=1)

    def select_channels(conv):
        # Indices of the highest-L1 output filters, sorted ascending so the
        # surviving channels keep their original relative order.
        n_keep = max(1, math.floor(conv.out_channels * (1 - prune_ratio)))
        top = torch.topk(l1_norms(conv), k=n_keep, largest=True).indices
        return top.sort().values

    def prune_conv(conv, out_idx, in_idx):
        # Copy of `conv` restricted to the given output/input channels, keeping
        # every other hyper-parameter of the original layer.
        if conv.groups != 1:
            raise ValueError("grouped convolutions cannot be channel-pruned here")
        new = nn.Conv2d(
            len(in_idx),
            len(out_idx),
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
        ).to(device=conv.weight.device, dtype=conv.weight.dtype)
        with torch.no_grad():
            new.weight.copy_(
                conv.weight.index_select(0, out_idx).index_select(1, in_idx)
            )
            if conv.bias is not None:
                new.bias.copy_(conv.bias.index_select(0, out_idx))
        return new

    def prune_bn(bn, idx, ref):
        # BatchNorm restricted to `idx`, carrying over the trained affine
        # parameters and running statistics (a fresh BN would reset them).
        new = nn.BatchNorm2d(
            len(idx),
            eps=bn.eps,
            momentum=bn.momentum,
            affine=bn.affine,
            track_running_stats=bn.track_running_stats,
        ).to(device=ref.device, dtype=ref.dtype)
        with torch.no_grad():
            if bn.affine:
                new.weight.copy_(bn.weight.index_select(0, idx))
                new.bias.copy_(bn.bias.index_select(0, idx))
            if bn.running_mean is not None and new.running_mean is not None:
                new.running_mean.copy_(bn.running_mean.index_select(0, idx))
                new.running_var.copy_(bn.running_var.index_select(0, idx))
                new.num_batches_tracked.copy_(bn.num_batches_tracked)
        return new

    def prune_double_conv(double_conv, in_idx):
        """Prune a (conv, bn, relu, conv, bn, relu) Sequential.

        `in_idx` are the channels of the incoming tensor that survived pruning
        upstream. Returns the new Sequential and the indices of its kept
        output channels.
        """
        conv1, bn1, relu1, conv2, bn2, relu2 = double_conv
        keep1 = select_channels(conv1)
        keep2 = select_channels(conv2)
        conv1_new = prune_conv(conv1, keep1, in_idx)
        conv2_new = prune_conv(conv2, keep2, keep1)  # conv2 only sees conv1's survivors
        pruned = nn.Sequential(
            conv1_new, prune_bn(bn1, keep1, conv1_new.weight), relu1,
            conv2_new, prune_bn(bn2, keep2, conv2_new.weight), relu2,
        )
        pruned.train(double_conv.training)
        return pruned, keep2

    # === Pruning ===
    model = deepcopy(model)

    first_conv = model.inc.double_conv[0]
    device, dtype = first_conv.weight.device, first_conv.weight.dtype

    # Encoder: each block's input must be indexed by the channels the previous
    # block actually kept. The raw input (image) channels are never pruned.
    keep = torch.arange(first_conv.in_channels, device=device)
    model.inc.double_conv, keep = prune_double_conv(model.inc.double_conv, keep)
    widths = [len(keep)]
    for name in ("down1", "down2", "down3", "down4"):
        down_conv = getattr(model, name).maxpool_conv[1]
        down_conv.double_conv, keep = prune_double_conv(down_conv.double_conv, keep)
        widths.append(len(keep))
    c1, c2, c3, c4, c5 = widths

    # Decoder (rebuilt, freshly initialised): each up block consumes the deeper
    # features concatenated with the matching skip connection and outputs the
    # skip width, e.g. up1: in = c5 + c4, out = c4.
    # NOTE(review): Up(c_deep + c_skip, c_skip) assumes Up's upsampling keeps
    # the channel count (bilinear mode) — confirm against models.unet.Up.
    n_classes = model.outc.conv.out_channels  # keep the original class count
    model.up1 = Up(c5 + c4, c4)
    model.up2 = Up(c4 + c3, c3)
    model.up3 = Up(c3 + c2, c2)
    model.up4 = Up(c2 + c1, c1)
    model.outc = OutConv(c1, n_classes)

    # Newly constructed modules use the default device/dtype and training mode;
    # align them with the rest of the model.
    for name in ("up1", "up2", "up3", "up4", "outc"):
        getattr(model, name).to(device=device, dtype=dtype).train(model.training)

    return model
