import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.layers import SubnetConv, SubnetLinear

class BasicBlockWRNSubnet(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3.

    All convolutions are created through ``conv_layer`` so the caller decides
    which convolution implementation is used.

    Args:
        conv_layer: Callable/class used to build every convolution. It is
            called like ``nn.Conv2d`` with the extra keyword arguments
            ``prune_reg`` and ``task_mode``.
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is created).
        dropRate: Dropout probability applied between the two convolutions;
            dropout is skipped entirely when this is not positive.
        prune_reg: Forwarded unchanged to ``conv_layer``.
        task_mode: Forwarded unchanged to ``conv_layer``.
    """

    expansion = 1

    def __init__(
        self,
        conv_layer,
        in_planes,
        out_planes,
        stride=1,
        dropRate=0.0,
        prune_reg='weight',
        task_mode='harp_prune',
    ):
        super().__init__()
        # Pre-activation: BN + ReLU
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        # First convolution (carries the stride of the block)
        self.conv1 = conv_layer(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=1,
            bias=False, prune_reg=prune_reg, task_mode=task_mode
        )
        # BN + ReLU before second conv
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv_layer(
            out_planes, out_planes, kernel_size=3, stride=1, padding=1,
            bias=False, prune_reg=prune_reg, task_mode=task_mode
        )
        self.droprate = dropRate
        # Shortcut
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = None
        # An identity shortcut is only shape-compatible when neither the
        # channel count nor the spatial resolution changes. Checking the
        # width alone would make an equal-width, strided block add tensors
        # of different spatial sizes in forward().
        if not self.equalInOut or stride != 1:
            self.convShortcut = conv_layer(
                in_planes, out_planes, kernel_size=1, stride=stride,
                bias=False, prune_reg=prune_reg, task_mode=task_mode
            )

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)`` for the input tensor ``x``."""
        # pre-activation of input
        preact = self.relu1(self.bn1(x))
        # residual branch
        out = self.relu2(self.bn2(self.conv1(preact)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # The projection shortcut consumes the pre-activated tensor; the
        # identity shortcut passes the raw input through untouched.
        skip = x if self.convShortcut is None else self.convShortcut(preact)
        return skip + out

class NetworkBlockSubnet(nn.Module):
    """Sequential stack of ``nb_layers`` residual blocks of type ``block``.

    Only the first block receives ``in_planes`` and ``stride``; every later
    block maps ``out_planes`` to ``out_planes`` with stride 1.

    Args:
        nb_layers: Number of blocks to stack (converted with ``int``).
        in_planes: Input channel count of the first block.
        out_planes: Output channel count of every block.
        block: Block class/callable, invoked positionally as
            ``block(conv_layer, in_planes, out_planes, stride, dropRate,
            prune_reg, task_mode)``.
        conv_layer: Passed through to each block.
        stride: Stride handed to the first block.
        dropRate: Passed through to each block.
        prune_reg: Passed through to each block.
        task_mode: Passed through to each block.
    """

    def __init__(
        self,
        nb_layers,
        in_planes,
        out_planes,
        block,
        conv_layer,
        stride,
        dropRate=0.0,
        prune_reg='weight',
        task_mode='harp_prune',
    ):
        super().__init__()
        stages = []
        # Settings for the next block; after the first one they collapse to
        # (out_planes, stride 1) and stay there.
        width, step = in_planes, stride
        for _ in range(int(nb_layers)):
            stages.append(
                block(
                    conv_layer, width, out_planes, step,
                    dropRate, prune_reg, task_mode
                )
            )
            width, step = out_planes, 1
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked blocks in order."""
        return self.layer(x)

class WideResNetSubnet(nn.Module):
    """Wide residual network assembled from caller-supplied layer types.

    Layout: a 3x3 stem convolution on 3-channel input, three stages of
    ``BasicBlockWRNSubnet`` blocks (strides 1, 2, 2), BN + ReLU, average
    pooling with window 8, and a linear classifier.

    Args:
        conv_layer: Class/callable used for every convolution; receives the
            extra keyword arguments ``prune_reg`` and ``task_mode``.
        linear_layer: Class/callable used for the classifier; receives the
            same two extra keyword arguments.
        depth: Network depth; must satisfy ``depth = 6n + 4``, which yields
            ``n`` blocks per stage.
        n_cls: Number of classifier outputs.
        widen_factor: Multiplier applied to the stage widths 16, 32 and 64.
        dropRate: Dropout rate passed to every residual block.
        prune_reg: Forwarded unchanged to the layer constructors.
        task_mode: Forwarded unchanged to the layer constructors.
        normalize_features: If true, pooled features are scaled to unit L2
            norm before the classifier.
        normalize_logits: If true, logits are mean-centred and scaled to unit
            L2 norm before being returned.
    """

    def __init__(
        self,
        conv_layer,
        linear_layer,
        depth=28,
        n_cls=10,
        widen_factor=4,
        dropRate=0.0,
        prune_reg='weight',
        task_mode='harp_prune',
        normalize_features=False,
        normalize_logits=False,
    ):
        super().__init__()
        assert (depth - 4) % 6 == 0, "Depth must be 6n+4"
        blocks_per_stage = (depth - 4) // 6
        # Stem width followed by the three (widened) stage widths.
        widths = [16] + [base * widen_factor for base in (16, 32, 64)]
        # Keyword arguments shared by every prunable layer constructor.
        layer_kwargs = dict(prune_reg=prune_reg, task_mode=task_mode)

        def make_stage(stage_idx, stride):
            # Stage ``stage_idx`` maps widths[stage_idx] -> widths[stage_idx + 1].
            return NetworkBlockSubnet(
                blocks_per_stage, widths[stage_idx], widths[stage_idx + 1],
                BasicBlockWRNSubnet, conv_layer, stride=stride,
                dropRate=dropRate, **layer_kwargs
            )

        # Stem convolution
        self.conv1 = conv_layer(
            3, widths[0], kernel_size=3, stride=1, padding=1, bias=False,
            **layer_kwargs
        )
        # Residual stages
        self.block1 = make_stage(0, 1)
        self.block2 = make_stage(1, 2)
        self.block3 = make_stage(2, 2)
        # Head: BN + ReLU, then the classifier
        self.bn1 = nn.BatchNorm2d(widths[3])
        self.relu = nn.ReLU(inplace=True)
        self.nChannels = widths[3]
        self.fc = linear_layer(self.nChannels, n_cls, **layer_kwargs)
        self.normalize_features = normalize_features
        self.normalize_logits = normalize_logits

        # Weight initialization
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Normal init with std derived from kernel area * out_channels.
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                # Linear weights keep their constructor init; only an
                # existing bias is reset.
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x):
        """Compute logits for ``x``, applying the optional normalizations."""
        feats = self.conv1(x)
        for stage in (self.block1, self.block2, self.block3):
            feats = stage(feats)
        feats = F.avg_pool2d(self.relu(self.bn1(feats)), 8)
        feats = feats.view(feats.size(0), -1)
        if self.normalize_features:
            # Clamp so an all-zero feature vector does not divide by zero.
            feat_norm = feats.norm(dim=-1, keepdim=True).clamp(min=1e-10)
            feats = feats / feat_norm
        logits = self.fc(feats)
        if not self.normalize_logits:
            return logits
        centered = logits - logits.mean(dim=-1, keepdim=True)
        scale = centered.norm(dim=-1, keepdim=True).clamp(min=1e-10)
        return centered / scale


def wrn_28_4_subnet(
    n_cls,
    prune_reg='weight',
    task_mode='harp_prune',
    widen_factor=4,
    dropRate=0.0,
    normalize_features=False,
    normalize_logits=False
):
    """
    Build a depth-28 ``WideResNetSubnet`` from ``SubnetConv`` / ``SubnetLinear``.

    Subnet-prunable WideResNet-28-4 for CIFAR10 when the defaults are kept;
    ``widen_factor`` defaults to 4 but can be overridden by the caller.

    Args:
        n_cls: Number of classifier outputs.
        prune_reg: Forwarded to ``WideResNetSubnet``.
        task_mode: Forwarded to ``WideResNetSubnet``.
        widen_factor: Width multiplier forwarded to ``WideResNetSubnet``.
        dropRate: Dropout rate forwarded to ``WideResNetSubnet``.
        normalize_features: Forwarded to ``WideResNetSubnet``.
        normalize_logits: Forwarded to ``WideResNetSubnet``.

    Returns:
        The constructed ``WideResNetSubnet`` instance.
    """
    # Depth is fixed by this factory; everything else comes from the caller.
    settings = {
        "depth": 28,
        "n_cls": n_cls,
        "widen_factor": widen_factor,
        "dropRate": dropRate,
        "prune_reg": prune_reg,
        "task_mode": task_mode,
        "normalize_features": normalize_features,
        "normalize_logits": normalize_logits,
    }
    return WideResNetSubnet(SubnetConv, SubnetLinear, **settings)
