import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models



class Mul(nn.Module):
    """Multiply the input element-wise by a fixed ``weight``.

    ``weight`` is stored by plain attribute assignment and used as-is in
    ``forward`` (``x * weight``).
    """

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        scaled = x * self.weight
        return scaled

class Flatten(nn.Module):
    """Collapse all dimensions after the first (batch) one into a single axis.

    Uses ``Tensor.view``, so the input must be view-compatible (e.g. contiguous).
    """

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

class Residual(nn.Module):
    """Add an identity skip connection around ``module``: returns ``x + module(x)``."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        branch = self.module(x)
        return x + branch

def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
    """Build a bias-free ``Conv2d`` followed by ``BatchNorm2d`` and ``ReLU``.

    Args:
        channels_in: Input channels of the convolution.
        channels_out: Output channels of the convolution (and BatchNorm features).
        kernel_size, stride, padding, groups: Passed straight to ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` with children 0 (conv), 1 (batch norm), 2 (ReLU).
    """
    conv = nn.Conv2d(
        channels_in,
        channels_out,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,  # the following BatchNorm2d provides the shift term
    )
    layers = [conv, nn.BatchNorm2d(channels_out), nn.ReLU()]
    return nn.Sequential(*layers)

class BasicBlock(nn.Module):
    """Two 3x3 conv/BatchNorm layers plus a residual shortcut.

    Args:
        in_planes: Channels of the incoming tensor.
        planes: Channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution and of the projection shortcut.
        last_block_layer: If True, ``forward`` returns the residual sum without
            the final ReLU; otherwise it returns ``relu(sum)``.
        dropout_rate: Probability for the ``nn.Dropout2d`` applied between the
            two convolutions.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, last_block_layer=False, dropout_rate=0.0):
        super().__init__()
        self.last_block_layer = last_block_layer
        # Channel-wise ("spatial") dropout between the two convolutions.
        self.spatial_dropout = nn.Dropout2d(dropout_rate)

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        # Identity shortcut when shapes already match; otherwise a strided
        # 1x1 projection so the two branches can be summed.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        h = self.spatial_dropout(F.relu(self.bn1(self.conv1(x))))
        h = self.bn2(self.conv2(h))
        h += self.shortcut(x)
        return h if self.last_block_layer else F.relu(h)


class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand residual block (output has 4x ``planes`` channels).

    Args:
        in_planes: Channels of the incoming tensor.
        planes: Width of the inner convolutions; the block outputs
            ``expansion * planes`` channels.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
        last_block_layer: If True, ``forward`` returns the residual sum without
            the final ReLU; otherwise it returns ``relu(sum)``.
        dropout_rate: Probability for the ``nn.Dropout2d`` applied after the
            first and after the second convolution.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, last_block_layer=False, dropout_rate=0.0):
        super().__init__()
        self.last_block_layer = last_block_layer
        # One Dropout2d module, reused after each of the first two convolutions.
        self.dropout = nn.Dropout2d(dropout_rate)

        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity shortcut when shapes already match; otherwise a strided
        # 1x1 projection so the two branches can be summed.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        h = x
        # conv -> BN -> ReLU -> spatial dropout, for the first two stages.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = self.dropout(F.relu(bn(conv(h))))
        h = self.bn3(self.conv3(h))
        h += self.shortcut(x)
        return h if self.last_block_layer else F.relu(h)

class ResNet(nn.Module):
    """ResNet classifier that also returns pooled per-stage features.

    ``forward(x, idx=None, por_neuron=1.0)`` returns ``(logits, intermediates)``.
    ``intermediates`` concatenates, along dim 1, the global max-pooled output
    of each of the four stages. Each is taken before ``forward`` applies the
    stage's ReLU: the last block of every stage is built with
    ``last_block_layer=True``. Each is truncated to its first
    ``int(C * por_neuron)`` channels.

    Args:
        block: Residual block class. It must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride, last_block_layer=...)``.
        num_blocks: Block count for each of the four stages.
        implement_exp_tied_dropout: If True, construct ``ExampleTiedDropout``
            and apply it after stages 2 and 3 whenever ``forward`` receives an
            ``idx``. NOTE(review): ``ExampleTiedDropout`` is not defined in this
            section; it must be in scope when this flag is set.
        num_classes: Output size of the final linear layer.
        p_fixed, p_mem, num_batches, drop_mode: Forwarded to ``ExampleTiedDropout``.
        input_channels: Number of channels of the input image.
        fac: Accepted but not used by this class.
        dropout_rate: Rate for the ``Dropout2d`` applied after stages 1-3 and
            for the ``Dropout`` applied before the classifier. It is not passed
            to the blocks, which are built with their own default rate.
        network: ``'TinyImagenet'`` selects a 7x7/stride-2 stem followed by
            max-pooling; any other value selects a 3x3/stride-1 stem.
        implement_pre_act: Accepted but not used by this class.
    """

    def __init__(self, block, num_blocks, implement_exp_tied_dropout=False, num_classes=10,
                 p_fixed=0.2, p_mem=0.1, num_batches=100, drop_mode="train",
                 input_channels=3, fac=1, dropout_rate=0.0, network=None, implement_pre_act=False):
        super(ResNet, self).__init__()

        self.implement_exp_tied_dropout = implement_exp_tied_dropout
        self.num_classes = num_classes
        # Running channel count, consumed and updated by _make_layer.
        self.in_planes = 64

        # Extra max-pool that forward() applies after the stem when the input height is 224.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if network == 'TinyImagenet':
            self.layer0 = nn.Sequential(
                nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        else:
            self.layer0 = nn.Sequential(
                nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True)
            )

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        # Max-pool used for the recorded intermediates; avg-pool feeds the classifier.
        self.adaptive_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if self.implement_exp_tied_dropout:
            self.exp_tied_dropout = ExampleTiedDropout(p_fixed=p_fixed, p_mem=p_mem,
                                                       num_batches=num_batches, drop_mode=drop_mode)

        # Not read anywhere in this class.
        self.masks = {}
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(dropout_rate)
        self.spatial_dropout = nn.Dropout2d(dropout_rate)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks; only the first uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i, s in enumerate(strides):
            # Ask the stage's last block for its pre-activation
            # (last_block_layer=True); forward() records it and applies the ReLU itself.
            last_block_layer = (i == len(strides) - 1)
            layers.append(block(self.in_planes, planes, s, last_block_layer=last_block_layer))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def apply_mask(self, x, mask):
        """Rescale ``x`` by the fraction of active mask entries, then mask it.

        ``x`` is divided by ``p_on = mask.sum() / mask.numel()``, or by 1 when
        ``p_on`` is not positive. It is then multiplied by ``mask``:
        element-wise for a 2-D ``x``, or broadcast over dim 1 for a 4-D ``x``.
        Not called by ``forward`` in this class.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 4-D.
        """
        p_on = torch.sum(mask) / mask.numel()
        p_on = p_on if p_on > 0 else 1
        x = x / p_on
        if x.dim() == 2:
            return x * mask
        elif x.dim() == 4:
            return x * mask.view(1, -1, 1, 1)
        else:
            raise ValueError(f"Unsupported tensor dimensions: {x.shape}")

    def record_intermediate(self, x, intermediates, por_neuron):
        """Append the global max-pool of ``x``, keeping its first ``int(C * por_neuron)`` channels.

        Mutates and returns ``intermediates``.
        """
        pooled = self.adaptive_pool(x).view(x.size(0), -1)
        neuron_reduced_index = int(pooled.shape[1] * por_neuron)
        intermediates.append(pooled[:, :neuron_reduced_index])
        return intermediates

    def _run_stage(self, layer, x, intermediates, por_neuron):
        """Run one stage, record its pooled pre-activation, and return ReLU of it."""
        pre_act = layer(x)
        self.record_intermediate(pre_act, intermediates, por_neuron)
        # Deliberately NOT in-place. The adaptive max-pool in record_intermediate
        # saves ``pre_act`` for its backward pass. An in-place ReLU would make
        # backward() fail for any loss that depends on the returned intermediates.
        return F.relu(pre_act)

    def forward(self, x, idx=None, por_neuron=1.0):
        """Return ``(logits, intermediates)`` for input batch ``x``.

        Args:
            x: Input images; dim 2 is compared against 224 to decide on the
                extra stem max-pool.
            idx: Optional value passed to ``exp_tied_dropout`` after stages 2
                and 3; the tied dropout is skipped when ``idx`` is None.
            por_neuron: Fraction of each stage's channels kept in ``intermediates``.
        """
        intermediates = []

        out = self.layer0(x)
        # Inputs of height 224 get an extra down-sampling max-pool after the stem.
        if x.size(2) == 224:
            out = self.maxpool(out)

        out = self._run_stage(self.layer1, out, intermediates, por_neuron)
        out = self.spatial_dropout(out)

        out = self._run_stage(self.layer2, out, intermediates, por_neuron)
        out = self.spatial_dropout(out)
        if self.implement_exp_tied_dropout and idx is not None:
            out = self.exp_tied_dropout(out, idx)

        out = self._run_stage(self.layer3, out, intermediates, por_neuron)
        out = self.spatial_dropout(out)
        if self.implement_exp_tied_dropout and idx is not None:
            out = self.exp_tied_dropout(out, idx)

        out = self._run_stage(self.layer4, out, intermediates, por_neuron)

        # Global average pooling and classification.
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.dropout(out)
        out = self.linear(out)

        return out, torch.cat(intermediates, dim=1)
