import torch
import torch.nn as nn
from copy import deepcopy

# Adjust this import to your actual path.

#%% ---------- Utilities ----------

def _ensure_tensor_idx(idx, device=None):
    if isinstance(idx, torch.Tensor):
        return idx.to(device) if device else idx
    t = torch.as_tensor(idx, dtype=torch.long)
    return t.to(device) if device else t

def clone_conv2d_subset(conv: nn.Conv2d, in_keep_idx, out_keep_idx):
    """Build a new Conv2d that holds only the selected in/out channels.

    The new layer copies kernel size, stride, padding, dilation and padding
    mode from ``conv``, lives on the same device and in the same dtype as
    ``conv.weight``, and inherits ``conv``'s train/eval flag.

    Args:
        conv: Source convolution. Only ``groups == 1`` is supported.
        in_keep_idx: 1-D indices (tensor or sequence) of input channels to
            keep, i.e. positions along dim 1 of ``conv.weight``.
        out_keep_idx: 1-D indices (tensor or sequence) of output channels to
            keep, i.e. positions along dim 0 of ``conv.weight``.

    Returns:
        A new ``nn.Conv2d`` with ``len(in_keep_idx)`` input channels and
        ``len(out_keep_idx)`` output channels.

    Raises:
        ValueError: If ``conv.groups != 1``. With grouped convolutions dim 1
            of the weight is ``in_channels // groups``, so a flat input
            channel selection cannot be applied to it.
    """
    if conv.groups != 1:
        raise ValueError(
            f"clone_conv2d_subset only supports groups=1, got groups={conv.groups}"
        )

    weight = conv.weight
    # Put the indices on the weight's device so index_select works for
    # models that live on an accelerator.
    in_keep_idx = torch.as_tensor(in_keep_idx, dtype=torch.long, device=weight.device)
    out_keep_idx = torch.as_tensor(out_keep_idx, dtype=torch.long, device=weight.device)

    new_conv = nn.Conv2d(
        in_channels=in_keep_idx.numel(),
        out_channels=out_keep_idx.numel(),
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
        # Without these the clone would always be a float32 CPU module.
        device=weight.device,
        dtype=weight.dtype,
    )

    with torch.no_grad():
        # Select output filters first (dim 0), then input channels (dim 1).
        new_conv.weight.copy_(weight.index_select(0, out_keep_idx)[:, in_keep_idx])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias.index_select(0, out_keep_idx))

    # Fresh modules start in training mode; mirror the source instead.
    new_conv.train(conv.training)
    return new_conv

def clone_bn_subset(bn: nn.BatchNorm2d, out_keep_idx):
    """Build a new BatchNorm2d that holds only the selected features.

    The clone mirrors the source layer's ``eps``, ``momentum``, ``affine``
    and ``track_running_stats`` settings, its device/dtype, and its
    train/eval flag. Affine parameters and running statistics are copied
    only when the source actually has them.

    Args:
        bn: Source batch-norm layer.
        out_keep_idx: 1-D indices (tensor or sequence) of features to keep.

    Returns:
        A new ``nn.BatchNorm2d`` with ``len(out_keep_idx)`` features.
    """
    # Use whichever tensor the source has as the device/dtype reference;
    # a BN with neither affine params nor running stats has none.
    ref = next((t for t in (bn.weight, bn.running_mean) if t is not None), None)
    device = None if ref is None else ref.device
    dtype = None if ref is None else ref.dtype
    out_keep_idx = torch.as_tensor(out_keep_idx, dtype=torch.long, device=device)

    new_bn = nn.BatchNorm2d(
        out_keep_idx.numel(),
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.weight is not None,
        track_running_stats=bn.running_mean is not None,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        if bn.weight is not None:
            new_bn.weight.copy_(bn.weight.index_select(0, out_keep_idx))
            new_bn.bias.copy_(bn.bias.index_select(0, out_keep_idx))
        if bn.running_mean is not None:
            new_bn.running_mean.copy_(bn.running_mean.index_select(0, out_keep_idx))
            new_bn.running_var.copy_(bn.running_var.index_select(0, out_keep_idx))
            # The batch counter drives the cumulative average used when
            # momentum is None, so carry it over as well.
            if bn.num_batches_tracked is not None:
                new_bn.num_batches_tracked.copy_(bn.num_batches_tracked)

    # Fresh modules start in training mode, which would make an eval-mode
    # model normalise with batch statistics; mirror the source instead.
    new_bn.train(bn.training)
    return new_bn

def select_out_channels_by_l1(conv: nn.Conv2d, keep_ratio: float):
    """Return the output channels of ``conv`` to keep, ranked by L1 norm.

    Each output filter is scored by the sum of absolute values of its
    weights; the highest-scoring ``round(out_channels * keep_ratio)``
    filters are kept. The count is clamped to ``[1, out_channels]`` so that
    at least one channel survives and ratios above 1 keep everything.

    Args:
        conv: Convolution whose ``weight`` is ranked along dim 0.
        keep_ratio: Fraction of output channels to keep.

    Returns:
        A 1-D ``torch.long`` tensor of kept channel indices in ascending
        order, on the same device as ``conv.weight``.
    """
    weight = conv.weight.detach()
    n_out = weight.size(0)
    # flatten() copies when needed, so this also works for channels_last
    # weights, where view(n_out, -1) would raise.
    l1 = weight.abs().flatten(1).sum(dim=1)
    k = min(n_out, max(1, int(round(n_out * keep_ratio))))
    kept = torch.topk(l1, k=k, largest=True, sorted=True).indices
    # Ascending order preserves the original relative channel order.
    return kept.sort().values

def build_1x1_with_in_subset(conv1x1: nn.Conv2d, in_keep_idx):
    """Rebuild a conv so it accepts only the selected input channels.

    Output channels, bias and all other hyper-parameters are unchanged. The
    new layer lives on the same device and in the same dtype as
    ``conv1x1.weight`` and inherits its train/eval flag.

    Args:
        conv1x1: Source convolution. Only ``groups == 1`` is supported.
        in_keep_idx: 1-D indices (tensor or sequence) of input channels to
            keep, i.e. positions along dim 1 of ``conv1x1.weight``.

    Returns:
        A new ``nn.Conv2d`` with ``len(in_keep_idx)`` input channels and
        ``conv1x1.out_channels`` output channels.

    Raises:
        ValueError: If ``conv1x1.groups != 1``. With grouped convolutions
            dim 1 of the weight is ``in_channels // groups``, so a flat
            input channel selection cannot be applied to it.
    """
    if conv1x1.groups != 1:
        raise ValueError(
            f"build_1x1_with_in_subset only supports groups=1, got groups={conv1x1.groups}"
        )

    weight = conv1x1.weight
    in_keep_idx = torch.as_tensor(in_keep_idx, dtype=torch.long, device=weight.device)

    new_conv = nn.Conv2d(
        in_channels=in_keep_idx.numel(),
        out_channels=conv1x1.out_channels,
        kernel_size=conv1x1.kernel_size,
        stride=conv1x1.stride,
        padding=conv1x1.padding,
        dilation=conv1x1.dilation,
        groups=conv1x1.groups,
        bias=conv1x1.bias is not None,
        padding_mode=conv1x1.padding_mode,
        # Without these the rebuilt head would always be float32 on CPU.
        device=weight.device,
        dtype=weight.dtype,
    )

    with torch.no_grad():
        new_conv.weight.copy_(weight[:, in_keep_idx])
        if conv1x1.bias is not None:
            new_conv.bias.copy_(conv1x1.bias)

    # Fresh modules start in training mode; mirror the source instead.
    new_conv.train(conv1x1.training)
    return new_conv

#%% ---------- Stage pruning ----------

def prune_stage_sequential(stage: nn.Sequential, prune_ratio: float, in_keep_idx=None):
    """Prune every Conv2d (and the BatchNorm2d that follows it) in a stage.

    The expected layout is ``(Conv -> BN -> ReLU) * n + MaxPool2d``. Each
    Conv2d keeps its highest-L1 output filters; the input channels of the
    next Conv2d and the features of the next BatchNorm2d are then restricted
    to exactly those kept channels so weights stay aligned. Any other module
    is carried over unchanged.

    Args:
        stage: The stage to prune; it is not modified.
        prune_ratio: Fraction in (0, 1) of output channels to remove from
            each Conv2d.
        in_keep_idx: Indices of the input channels kept by the previous
            stage, or ``None`` when the stage input is unpruned.

    Returns:
        ``(new_stage, last_out_idx)`` where ``last_out_idx`` holds the
        original output-channel indices kept by the last Conv2d. If the
        stage contains no Conv2d, the incoming ``in_keep_idx`` is returned
        (as a tensor, or ``None``).
    """
    keep_ratio = 1.0 - prune_ratio
    # Channels, in the ORIGINAL numbering, still flowing into the next layer.
    cur_idx = None if in_keep_idx is None else _ensure_tensor_idx(in_keep_idx)

    new_layers = []
    for module in stage.children():
        if isinstance(module, nn.Conv2d):
            if cur_idx is None:
                cur_idx = torch.arange(
                    module.in_channels, dtype=torch.long, device=module.weight.device
                )
            out_keep_idx = select_out_channels_by_l1(module, keep_ratio=keep_ratio)
            replacement = clone_conv2d_subset(
                module, in_keep_idx=cur_idx, out_keep_idx=out_keep_idx
            )
            # Downstream layers must select these same channels. Using
            # arange(k) here would pick the first k channels instead of
            # the ones that actually survived the ranking.
            cur_idx = out_keep_idx
        elif isinstance(module, nn.BatchNorm2d) and cur_idx is not None:
            replacement = clone_bn_subset(module, out_keep_idx=cur_idx)
        else:
            # ReLU, pooling, etc. do not depend on the channel count.
            new_layers.append(module)
            continue

        # Fresh modules start in training mode; mirror the source layer.
        replacement.train(module.training)
        new_layers.append(replacement)

    new_stage = nn.Sequential(*new_layers)
    # Set only the container flag; children already carry their own mode.
    new_stage.training = stage.training
    return new_stage, cur_idx

#%% ---------- End-to-end FCN8s pruning ----------

def prune_fcn8s_model(model: nn.Module, prune_ratio: float = 0.3):
    """Return a structurally pruned deep copy of an FCN8s model.

    ``stage1`` .. ``stage5`` are pruned in order, each receiving the channel
    indices kept by the previous stage. The ``score_pool3``, ``score_pool4``
    and ``score`` heads are then rebuilt so their input channels match the
    channels kept by ``stage3``, ``stage4`` and ``stage5`` respectively.
    NOTE(review): this assumes each head consumes its stage's output
    directly, as the original wiring here implies; confirm against the
    model definition.

    Args:
        model: Model exposing ``LAM``, ``stage1``..``stage5``,
            ``score_pool3``, ``score_pool4`` and ``score``. It is not
            modified.
        prune_ratio: Fraction of output channels removed from every
            convolution in the stages.

    Returns:
        The pruned copy of ``model``.

    Raises:
        AssertionError: If ``model`` has no ``LAM`` attribute or
            ``model.LAM`` is not ``False``.
    """
    # Raised explicitly so the check also runs under ``python -O``. The
    # exception type matches the assert statement this replaces.
    if not (hasattr(model, "LAM") and model.LAM is False):
        raise AssertionError("This pruner assumes LAM=False.")

    m = deepcopy(model)

    # Prune the backbone stage by stage, threading the kept indices through.
    kept_by_stage = {}
    keep_idx = None
    for stage_name in ("stage1", "stage2", "stage3", "stage4", "stage5"):
        pruned_stage, keep_idx = prune_stage_sequential(
            getattr(m, stage_name), prune_ratio, in_keep_idx=keep_idx
        )
        setattr(m, stage_name, pruned_stage)
        kept_by_stage[stage_name] = keep_idx

    # Rebuild the heads with the indices the stages actually kept. Using
    # arange(k) would select the first k input channels and misalign the
    # head weights with the surviving feature maps.
    head_sources = (
        ("score_pool3", "stage3"),
        ("score_pool4", "stage4"),
        ("score", "stage5"),
    )
    for head_name, stage_name in head_sources:
        rebuilt = build_1x1_with_in_subset(
            getattr(m, head_name), in_keep_idx=kept_by_stage[stage_name]
        )
        setattr(m, head_name, rebuilt)

    return m