#%% Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import math

#%% Prune + Apply Function
def prune_segnet_model(model, prune_ratio=0.3):
    """Return a channel-pruned deep copy of a SegNet-style model.

    For every (conv, bn) pair in ``ALL_LAYERS`` (processed in list order):

    * the conv keeps the output channels with the largest L1 weight norm
      (``floor(out_channels * (1 - prune_ratio))`` of them, at least one),
      in their original relative order;
    * the conv's input channels are sliced to the channels kept by the
      previous conv in the list (all input channels for ``ConvEn11``);
    * the paired BatchNorm is sliced to the conv's kept output channels.

    ``ConvDe11`` keeps all of its output channels. New modules inherit the
    train/eval mode of the modules they replace. The input ``model`` itself
    is not modified (a deep copy is pruned).

    NOTE(review): the chaining assumes each listed conv consumes exactly the
    output of the conv listed before it. If the model also pairs decoder
    features with encoder tensors by channel position (e.g. max-unpool with
    encoder pooling indices — not visible here), only the channel *counts*
    stay consistent, not the channel identities. Confirm against the SegNet
    definition.

    Args:
        model: module exposing ``ConvEn11`` ... ``ConvDe11`` (``nn.Conv2d``)
            and the matching ``BNEn11`` ... ``BNDe11`` (``nn.BatchNorm2d``)
            attributes named in ``ALL_LAYERS``.
        prune_ratio: fraction of output channels to remove per layer, in
            ``[0, 1]``.

    Returns:
        The pruned deep copy of ``model``.

    Raises:
        ValueError: if ``prune_ratio`` is outside ``[0, 1]``.
        NotImplementedError: if a listed conv uses ``groups != 1``.
    """
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio!r}")

    def get_l1_norms(conv: nn.Conv2d):
        """L1 norm of each output filter of ``conv`` (one value per out channel)."""
        # reshape (not view) so non-contiguous weights are handled as well.
        return conv.weight.detach().abs().reshape(conv.out_channels, -1).sum(dim=1)

    def prune_conv_layer(conv, keep_out_idx, keep_in_idx=None):
        """Build a new Conv2d holding only the selected out/in channels.

        Returns the new conv and ``keep_out_idx`` as a LongTensor.
        """
        if conv.groups != 1:
            raise NotImplementedError(
                f"grouped convolutions are not supported (groups={conv.groups})")

        old_weight = conv.weight.detach()
        device = old_weight.device
        # as_tensor does not copy (or warn) when the indices are already a
        # LongTensor, which is the case from the second layer onwards.
        keep_out_idx = torch.as_tensor(keep_out_idx, dtype=torch.long, device=device)
        if keep_in_idx is None:
            keep_in_idx = torch.arange(conv.in_channels, device=device)
        else:
            keep_in_idx = torch.as_tensor(keep_in_idx, dtype=torch.long, device=device)

        new_conv = nn.Conv2d(len(keep_in_idx), len(keep_out_idx), conv.kernel_size,
                             stride=conv.stride, padding=conv.padding,
                             dilation=conv.dilation, bias=conv.bias is not None,
                             padding_mode=conv.padding_mode)

        # Assigning .data also adopts the source tensor's device and dtype.
        new_conv.weight.data = old_weight[keep_out_idx][:, keep_in_idx].clone()
        if conv.bias is not None:
            new_conv.bias.data = conv.bias.detach()[keep_out_idx].clone()

        # Match the replaced module's train/eval mode.
        new_conv.train(conv.training)
        return new_conv, keep_out_idx

    def prune_bn_layer(bn, keep_idx):
        """Build a new BatchNorm2d holding only the channels in ``keep_idx``."""
        ref = bn.weight if bn.weight is not None else bn.running_mean
        device = ref.device if ref is not None else None
        keep_idx = torch.as_tensor(keep_idx, dtype=torch.long, device=device)

        new_bn = nn.BatchNorm2d(len(keep_idx), eps=bn.eps, momentum=bn.momentum,
                                affine=bn.affine,
                                track_running_stats=bn.track_running_stats)
        if bn.weight is not None and new_bn.weight is not None:
            new_bn.weight.data = bn.weight.detach()[keep_idx].clone()
            new_bn.bias.data = bn.bias.detach()[keep_idx].clone()
        if bn.running_mean is not None:
            new_bn.running_mean = bn.running_mean[keep_idx].clone()
            new_bn.running_var = bn.running_var[keep_idx].clone()
        if getattr(bn, "num_batches_tracked", None) is not None:
            new_bn.num_batches_tracked = bn.num_batches_tracked.clone()

        new_bn.train(bn.training)
        return new_bn

    # (conv, bn) attribute pairs in the order their channels are chained.
    ALL_LAYERS = [
        ('ConvEn11', 'BNEn11'), ('ConvEn12', 'BNEn12'),
        ('ConvEn21', 'BNEn21'), ('ConvEn22', 'BNEn22'),
        ('ConvEn31', 'BNEn31'), ('ConvEn32', 'BNEn32'), ('ConvEn33', 'BNEn33'),
        ('ConvEn41', 'BNEn41'), ('ConvEn42', 'BNEn42'), ('ConvEn43', 'BNEn43'),
        ('ConvEn51', 'BNEn51'), ('ConvEn52', 'BNEn52'), ('ConvEn53', 'BNEn53'),

        ('ConvDe53', 'BNDe53'), ('ConvDe52', 'BNDe52'), ('ConvDe51', 'BNDe51'),
        ('ConvDe43', 'BNDe43'), ('ConvDe42', 'BNDe42'), ('ConvDe41', 'BNDe41'),
        ('ConvDe33', 'BNDe33'), ('ConvDe32', 'BNDe32'), ('ConvDe31', 'BNDe31'),
        ('ConvDe22', 'BNDe22'), ('ConvDe21', 'BNDe21'),
        ('ConvDe12', 'BNDe12'), ('ConvDe11', 'BNDe11')
    ]

    model = deepcopy(model)
    prev_keep = None  # first conv keeps all of its input channels

    for conv_name, bn_name in ALL_LAYERS:
        conv = getattr(model, conv_name)
        bn = getattr(model, bn_name)

        if conv_name == 'ConvDe11':
            # Final layer: keep every class output channel.
            keep_out_idx = list(range(conv.out_channels))
        else:
            n_keep_out = max(1, math.floor(conv.out_channels * (1 - prune_ratio)))
            top = torch.topk(get_l1_norms(conv), k=n_keep_out, largest=True).indices
            # topk returns indices by descending norm; sort them so surviving
            # channels keep their original relative order.
            keep_out_idx = sorted(top.tolist())

        conv_new, keep_out_idx = prune_conv_layer(conv, keep_out_idx, prev_keep)
        bn_new = prune_bn_layer(bn, keep_out_idx)

        setattr(model, conv_name, conv_new)
        setattr(model, bn_name, bn_new)

        # The next conv consumes exactly the channels this conv kept.
        prev_keep = keep_out_idx

    return model

#%% Example Usage
if __name__ == '__main__':
    # Smoke test: prune a SegNet and check a single forward pass succeeds.
    from models.SegNet import SegNet
    model = SegNet(n_classes=2, n_channels=4, layer_num=64, LAM=False)
    pruned_model = prune_segnet_model(model, prune_ratio=0.3)
    # Run in eval mode so BatchNorm uses its stored running statistics and the
    # forward pass below does not update them on the returned model.
    pruned_model.eval()
    # assumes SegNet stores its input channel count as `in_chn` — TODO confirm
    dummy_input = torch.randn(1, model.in_chn, 256, 256)
    with torch.no_grad():
        out = pruned_model(dummy_input)
    print("Pruned model forward pass OK. Output shape:", out.shape)
