import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# Architecture of the small fully-connected MNIST classifier (see `Net`).
mnist_config = {
    "input_size": 784,  # 28x28 image, flattened (28 * 28 = 784)
    "hidden_size_1": 13,  # width of the first hidden layer (originally 500)
    "hidden_size_2": 11,  # width of the second hidden layer (originally 500)
    "num_classes": 10,  # output classes, digits 0-9
}

# Global neuron numbering used for patching: the neurons of fc1, fc2 and fc3
# are laid out back to back, so each layer starts where the previous one ends.
mnist_patching_conf = {
    "layer_offset": {
        "fc1": 0,
        "fc2": mnist_config["hidden_size_1"],
        "fc3": mnist_config["hidden_size_1"] + mnist_config["hidden_size_2"],
    },
    # 13 + 11 + 10 = 34 with the widths above
    "total_neurons": sum(
        mnist_config[width_key] for width_key in ("hidden_size_1", "hidden_size_2", "num_classes")
    ),
}

def get_mnist_neuron_global_order(bottom_up: bool = False):
    """Return every global MNIST neuron index, grouped layer by layer.

    Layers are visited fc1 -> fc2 -> fc3, or fc3 -> fc2 -> fc1 when
    ``bottom_up`` is true. Within a layer the indices are ascending and start
    at that layer's entry in ``mnist_patching_conf["layer_offset"]``.

    Raises:
        AssertionError: if the result does not hold exactly
            ``mnist_patching_conf["total_neurons"]`` indices.
    """
    layer_widths = (
        ("fc1", mnist_config["hidden_size_1"]),
        ("fc2", mnist_config["hidden_size_2"]),
        ("fc3", mnist_config["num_classes"]),
    )
    if bottom_up:
        layer_widths = layer_widths[::-1]

    offsets = mnist_patching_conf['layer_offset']
    order = [
        offsets[layer] + local_idx
        for layer, width in layer_widths
        for local_idx in range(width)
    ]

    assert len(order) == mnist_patching_conf["total_neurons"], "constructed order length mismatch"
    return order




# Patchable convolutions of the small CIFAR-10 ResNet (`resnet2b`): the stem
# conv plus both convs of the two blocks in layer1. 8 + 4 * 16 = 72 channels.
cifar10_small_config = {
    "conv_layers": [
        {"name": "conv1", "conv_channels": 8},
        {"name": "layer1.0.conv1", "conv_channels": 16},
        {"name": "layer1.0.conv2", "conv_channels": 16},
        {"name": "layer1.1.conv1", "conv_channels": 16},
        {"name": "layer1.1.conv2", "conv_channels": 16},
    ],
    "total_channels": 8 + 16 + 16 + 16 + 16,
}


# Patchable convolutions of `GTSRBCNN`.
gtsrb_config = {
    "conv_layers": [
        {"name": "conv1", "conv_channels": 16},
        {"name": "conv2", "conv_channels": 32},
    ],
    "total_channels": 16 + 32,
}

# Sizes matching the `BigTaxiNetCNN` defaults; currently unused, kept for reference.
# big_taxinet_config = {
#     "in_ch": 1,
#     "conv1_out": 16,
#     "conv2_out": 32,
#     "fc1_in": 46656,  # flattened features into fc1
#     "fc1_out": 128,
#     "fc2_out": 64,
#     "fc3_out": 1,
#     "conv_layers": [
#         {"name": "conv1", "conv_channels": 16},
#         {"name": "conv2", "conv_channels": 32},
#     ],
#     "total_channels": 16 + 32
# }

# Patchable convolutions of `TaxiNetCNN` (two 4-channel convs).
taxinet_config = {
    "conv_layers": [
        {"name": "conv1", "conv_channels": 4},
        {"name": "conv2", "conv_channels": 4},
    ],
    "total_channels": 4 + 4,
}




def neuron_idx_to_neuron_name():
    """Map each global MNIST neuron index to ``(layer_name, local_index)``.

    Global indices enumerate fc1's neurons first, then fc2's, then fc3's,
    using the layer widths from ``mnist_config``.
    """
    layer_sizes = [
        ("fc1", mnist_config["hidden_size_1"]),
        ("fc2", mnist_config["hidden_size_2"]),
        ("fc3", mnist_config["num_classes"]),
    ]
    flat_neurons = [(layer, local) for layer, size in layer_sizes for local in range(size)]
    return dict(enumerate(flat_neurons))

class Net(nn.Module):
    """Fully-connected classifier: fc1 -> ReLU -> fc2 -> ReLU -> fc3 (logits)."""

    def __init__(self, input_size, hidden_size_1, hidden_size_2, num_classes):
        super().__init__()
        # Attribute names and registration order are kept so that existing
        # state_dicts still load.
        self.fc1 = nn.Linear(input_size, hidden_size_1)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.fc3 = nn.Linear(hidden_size_2, num_classes)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

class BigTaxiNetCNN(nn.Module):
    """TaxiNet-style regressor: two 3x3 convs followed by a three-layer MLP.

    Both convs use stride 1 and padding 1, so they preserve the spatial size,
    and each is followed by ReLU. The flattened conv output (``fc1_in``
    features) feeds fc1 -> ReLU -> fc2 -> ReLU -> fc3.

    All sizes are constructor arguments. According to the original notes, the
    defaults were taken from ONNX weights. 46656 = 32 * 27 * 54, which suggests
    a 27x54 single-channel input; confirm against the data.
    """

    def __init__(self,
                 in_ch: int = 1,
                 conv1_out: int = 16,
                 conv2_out: int = 32,
                 fc1_out: int = 128,
                 fc2_out: int = 64,
                 fc3_out: int = 1,
                 fc1_in: int = 46656):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(in_ch, conv1_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(conv1_out, conv2_out, kernel_size=3, stride=1, padding=1, bias=True)
        # MLP head on the flattened features
        self.fc1 = nn.Linear(fc1_in, fc1_out, bias=True)
        self.fc2 = nn.Linear(fc1_out, fc2_out, bias=True)
        self.fc3 = nn.Linear(fc2_out, fc3_out, bias=True)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.act(conv(x))
        x = torch.flatten(x, 1)
        for fc in (self.fc1, self.fc2):
            x = self.act(fc(x))
        return self.fc3(x)

class TaxiNetCNN(nn.Module):
    """Small TaxiNet regressor: two 4-channel 3x3 convs, then a 3-layer MLP.

    The convs use stride 1 and padding 1, so the spatial size is preserved.
    ``fc1`` expects 5832 = 4 * 27 * 54 features, e.g. a single-channel 27x54
    input. ``self.act`` (ReLU) is applied only after fc1 and fc2.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1, bias=True)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1, bias=True)

        self.fc1 = nn.Linear(5832, 20, bias=True)  # 4 * 27 * 54 = 5832
        self.fc2 = nn.Linear(20, 10, bias=True)
        self.fc3 = nn.Linear(10, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): the convs are applied back to back with no activation.
        # DupPatchTaxiNet mirrors this, so it looks deliberate (presumably it
        # matches the exported weights). Confirm before adding a ReLU here.
        features = self.conv2(self.conv1(x)).flatten(1)  # (B, 5832)
        hidden = self.act(self.fc2(self.act(self.fc1(features))))
        return self.fc3(hidden)  # (B, 1)


class DupTaxiNet(nn.Module):
    """Two independent TaxiNetCNN copies evaluated on the same input.

    ``forward`` returns ``[net1(x), net2(x)]`` concatenated along dim 1.
    """

    def __init__(self):
        super().__init__()
        self.net1 = TaxiNetCNN()
        self.net2 = TaxiNetCNN()

    def forward(self, x):
        # TODO: consider returning the difference net1(x) - net2(x) instead.
        predictions = [self.net1(x), self.net2(x)]
        return torch.cat(predictions, dim=1)




class BasicBlock(nn.Module):
    """Residual block: 3x3 conv-BN-ReLU, 3x3 conv-BN, add shortcut, ReLU.

    The shortcut (``downsample``) is an empty ``nn.Sequential``, i.e. the
    identity, unless the stride or channel count changes. In that case it is
    a strided 1x1 conv followed by BatchNorm. Callers inspect
    ``len(downsample)`` to tell the two apart.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == planes:
            self.downsample = nn.Sequential()
        else:
            width = self.expansion * planes
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(branch + self.downsample(x))


class CifarResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem (16 channels) plus three stages of ``block``.

    The stages have 16, 32 and 64 planes. The second and third halve the
    spatial size through a stride-2 first block. The final maps are
    average-pooled with an 8x8 window and fed to the linear head ``fc``. The
    head was renamed from ``linear``, and the attribute name determines the
    state_dict keys.

    Args:
        block: residual block class with an ``expansion`` attribute and the
            signature ``block(in_planes, planes, stride=1)``.
        num_blocks: number of blocks in each of the three stages.
        num_classes: output width of ``fc``.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``.

        ``self.in_planes`` is advanced so that the next stage is wired to this
        stage's output width.
        """
        head = block(self.in_planes, planes, stride)
        self.in_planes = planes * block.expansion
        tail = [block(self.in_planes, planes) for _ in range(num_blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        pooled = F.avg_pool2d(out, 8)
        return self.fc(pooled.view(pooled.size(0), -1))

class DupCifarResNet(nn.Module):
    """Two independently built copies of a CIFAR-10 model run side by side.

    Args:
        cifar10_model: zero-argument factory (e.g. ``cifar10_big``), called
            once for each copy.
    """

    def __init__(self, cifar10_model):
        super().__init__()
        self.net1 = cifar10_model()
        self.net2 = cifar10_model()

    def forward(self, x):
        """Return ``[net1(x), net2(x)]`` concatenated along dim 1."""
        outputs = [self.net1(x), self.net2(x)]
        return torch.cat(outputs, dim=1)


class DupGtsrbNet(nn.Module):
    """Two independent GTSRBCNN copies evaluated on the same input.

    ``forward`` returns ``[net1(x), net2(x)]`` concatenated along dim 1.
    """

    def __init__(self):
        super().__init__()
        self.net1 = GTSRBCNN()
        self.net2 = GTSRBCNN()

    def forward(self, x):
        return torch.cat([net(x) for net in (self.net1, self.net2)], dim=1)


class TripledAdvCifarNet(nn.Module):
    """Three independent CifarResNet copies evaluated on the same input.

    ``forward`` returns ``[net1(x), net2(x), net3(x)]`` concatenated along dim 1.
    """

    def __init__(self, block, num_blocks, num_classes):
        super().__init__()
        self.net1 = CifarResNet(block, num_blocks, num_classes)
        self.net2 = CifarResNet(block, num_blocks, num_classes)
        self.net3 = CifarResNet(block, num_blocks, num_classes)

    def forward(self, x):
        copies = (self.net1, self.net2, self.net3)
        return torch.cat([net(x) for net in copies], dim=1)


class DupNet(nn.Module):
    """Two independent MNIST ``Net`` copies evaluated on the same input.

    ``forward`` returns ``[net1(x), net2(x)]`` concatenated along dim 1.
    """

    def __init__(self, input_size, hidden_size_1, hidden_size_2, num_classes):
        super().__init__()
        sizes = (input_size, hidden_size_1, hidden_size_2, num_classes)
        self.net1 = Net(*sizes)
        self.net2 = Net(*sizes)

    def forward(self, x):
        return torch.cat([net(x) for net in (self.net1, self.net2)], dim=1)


class TripledAdvMnistNet(nn.Module):
    """Three independent MNIST ``Net`` copies evaluated on the same input.

    ``forward`` returns ``[net1(x), net2(x), net3(x)]`` concatenated along dim 1.
    """

    def __init__(self, input_size, hidden_size_1, hidden_size_2, num_classes):
        super().__init__()
        sizes = (input_size, hidden_size_1, hidden_size_2, num_classes)
        self.net1 = Net(*sizes)
        self.net2 = Net(*sizes)
        self.net3 = Net(*sizes)

    def forward(self, x):
        copies = (self.net1, self.net2, self.net3)
        return torch.cat([net(x) for net in copies], dim=1)


class TripledPatchNetMnist(nn.Module):
    """Full MNIST net plus a mask-controlled patched/pruned combination.

    The input row is ``[X | Y]``, with ``input_size`` features each.
    ``patch_net`` runs on X and keeps its own unmixed activations.
    ``pruned_net`` runs on Y. After every layer, its activations are blended
    with the patch activations as ``pruned * (1 - Z) + patch * Z`` using that
    layer's slice of ``Z_mask``, and the blend feeds its next layer.
    ``full_net`` runs on Y alone. ``forward`` returns
    ``[full_net(Y), combined logits]`` concatenated along dim 1.

    ``Z_mask`` holds one value per neuron of fc1, fc2 and fc3, in that order.
    """

    def __init__(self, input_size, hidden_size_1, hidden_size_2, num_classes, Z_mask):
        super().__init__()
        self.patch_net = Net(input_size, hidden_size_1, hidden_size_2, num_classes)  # patching network
        self.pruned_net = Net(input_size, hidden_size_1, hidden_size_2, num_classes)  # pruned network
        self.full_net = Net(input_size, hidden_size_1, hidden_size_2, num_classes)  # full network

        # Plain tensor attributes rather than buffers: they are not part of the
        # state_dict and are moved to the input's device on every forward call.
        flat_mask = torch.tensor(Z_mask).view(1, -1)
        widths = [hidden_size_1, hidden_size_2, num_classes]
        self.Z1, self.Z2, self.Z3 = torch.split(flat_mask, widths, dim=1)
        self.input_size = input_size

    def forward(self, inp):
        X, Y = torch.split(inp, [self.input_size, self.input_size], dim=1)
        masks = [z.to(inp.device) for z in (self.Z1, self.Z2, self.Z3)]

        patch, mixed = X, Y
        # (layer name, apply ReLU?) pairs; the last layer produces logits.
        schedule = (("fc1", True), ("fc2", True), ("fc3", False))
        for (name, use_relu), z in zip(schedule, masks):
            p = getattr(self.patch_net, name)(patch)
            q = getattr(self.pruned_net, name)(mixed)
            if use_relu:
                p, q = F.relu(p), F.relu(q)
            patch = p
            mixed = q * (1 - z) + p * z

        full_out = self.full_net(Y)
        return torch.cat((full_out, mixed), dim=1)

class DupPatchNetMnist(nn.Module):
    """Mask-controlled combination of a patch MNIST net and a pruned one.

    The input row is ``[X | Y]``, with ``input_size`` features each.
    ``patch_net`` runs on X and keeps its own unmixed activations.
    ``pruned_net`` runs on Y. After every layer, its activations are blended
    with the patch activations as ``pruned * (1 - Z) + patch * Z`` using that
    layer's slice of ``Z_mask``, and the blend feeds its next layer. Returns
    the combined fc3 logits.

    ``Z_mask`` holds one value per neuron of fc1, fc2 and fc3, in that order.
    """

    def __init__(self, input_size, hidden_size_1, hidden_size_2, num_classes, Z_mask):
        super().__init__()
        self.patch_net = Net(input_size, hidden_size_1, hidden_size_2, num_classes)  # patching network
        self.pruned_net = Net(input_size, hidden_size_1, hidden_size_2, num_classes)  # pruned network

        # Plain tensor attributes rather than buffers: they are not part of the
        # state_dict and are moved to the input's device on every forward call.
        flat_mask = torch.tensor(Z_mask).view(1, -1)
        widths = [hidden_size_1, hidden_size_2, num_classes]
        self.Z1, self.Z2, self.Z3 = torch.split(flat_mask, widths, dim=1)
        self.input_size = input_size

    def forward(self, inp):
        X, Y = torch.split(inp, [self.input_size, self.input_size], dim=1)
        masks = [z.to(inp.device) for z in (self.Z1, self.Z2, self.Z3)]

        patch, mixed = X, Y
        # (layer name, apply ReLU?) pairs; the last layer produces logits.
        schedule = (("fc1", True), ("fc2", True), ("fc3", False))
        for (name, use_relu), z in zip(schedule, masks):
            p = getattr(self.patch_net, name)(patch)
            q = getattr(self.pruned_net, name)(mixed)
            if use_relu:
                p, q = F.relu(p), F.relu(q)
            patch = p
            mixed = q * (1 - z) + p * z
        return mixed


class DupPatchGtsrbNet(nn.Module):
    """GTSRB patch/pruned combination with per-channel masks on both convs.

    The input stacks two images along the channel axis. Channels 0-2 feed
    ``patch_net`` and the remaining channels feed ``pruned_net``. After each
    conv + ReLU, the pruned stream becomes ``patch * m + pruned * (1 - m)``
    for that conv's per-channel mask ``m``. The patch stream is never mixed.
    The classifier head (fc1, fc2) comes from ``pruned_net``.

    Args:
        Z_mask: mapping with per-channel mask values under ``"conv1"`` and
            ``"conv2"``.
    """

    def __init__(self, Z_mask):
        super().__init__()
        self.patch_net = GTSRBCNN()
        self.pruned_net = GTSRBCNN()

        # Shaped (1, C, 1, 1) to broadcast over (B, C, H, W). As buffers they
        # follow the module across devices and appear in the state_dict.
        conv1_mask = torch.tensor(Z_mask["conv1"], dtype=torch.float32).view(1, -1, 1, 1)
        conv2_mask = torch.tensor(Z_mask["conv2"], dtype=torch.float32).view(1, -1, 1, 1)
        self.register_buffer("mask_conv1", conv1_mask)
        self.register_buffer("mask_conv2", conv2_mask)

    @staticmethod
    def _blend(patch, pruned, mask):
        """Linear blend: ``patch`` weighted by ``mask``, ``pruned`` by ``1 - mask``."""
        return patch * mask + pruned * (1 - mask)

    def forward(self, inp):
        patch, pruned = self.patch_net, self.pruned_net
        X_patch, Y_pruned = inp[:, :3], inp[:, 3:]

        # Stage 1: conv1 + ReLU in both nets, blend into the pruned stream, pool each.
        patch_act1 = patch.act(patch.conv1(X_patch))
        mixed1 = self._blend(patch_act1, pruned.act(pruned.conv1(Y_pruned)), self.mask_conv1)
        mixed_pooled1 = pruned.pool1(mixed1)
        patch_pooled1 = patch.pool1(patch_act1)

        # Stage 2: the same for conv2; the patch stream continues unmixed.
        patch_act2 = patch.act(patch.conv2(patch_pooled1))
        mixed2 = self._blend(patch_act2, pruned.act(pruned.conv2(mixed_pooled1)), self.mask_conv2)
        features = pruned.pool2(mixed2)

        # Classifier head of the pruned network.
        features = features.view(features.size(0), -1)
        return pruned.fc2(pruned.act(pruned.fc1(features)))


class TripledPatchResNetConvHeads(nn.Module):
    """Full CIFAR-10 ResNet-20 plus a patched/pruned combination with layer-wise masks.

    The input stacks two 3-channel images along dim 1:
    ``X = inp[:, :3]`` feeds ``patch_net``'s stem, and ``Y = inp[:, 3:]``
    feeds ``pruned_net``'s stem and ``full_net``. At every conv, the two
    networks' outputs are blended as ``p * m + q * (1 - m)``. Each ``m`` is a
    single scalar per conv taken from ``Z_mask`` and broadcast over all
    channels. Residual shortcuts always use ``pruned_net``'s downsample.
    ``forward`` returns ``[full_net(Y), combined logits]`` concatenated along
    dim 1.

    NOTE(review): inside the residual blocks both ``pb`` and ``qb`` read the
    same mixed activation ``out``. The patch network therefore does not keep
    its own unmixed stream past the stem, unlike ``DupPatchCifar10SmallNet``
    and the MNIST patch models. Confirm this is intended.
    """
    def __init__(self, Z_mask: dict):
        super().__init__()
        self.patch_net = cifar10_big()
        self.pruned_net = cifar10_big()
        self.full_net = cifar10_big()
        # Plain attribute (not a buffer), so the masks are not in the state_dict.
        # Keys used by forward: 'conv1' and '<layerN>.<i>.conv1' / '<layerN>.<i>.conv2'.
        # Each value is passed to new_full as a fill value, i.e. one scalar per conv.
        self.Z_mask = Z_mask

    def forward(self, inp):
        X = inp[:, :3, :, :]
        Y = inp[:, 3:, :, :]

        # --- conv1 head: patch stem on X, pruned stem on Y, blended by one scalar
        p = F.relu(self.patch_net.bn1(self.patch_net.conv1(X)))
        q = F.relu(self.pruned_net.bn1(self.pruned_net.conv1(Y)))
        m = self.Z_mask['conv1']
        mask = p.new_full((1, p.size(1), 1, 1), m)  # same dtype/device as p; broadcasts over B, H, W
        out  = p * mask + q * (1 - mask)

        # --- now each BasicBlock in layer1, layer2, layer3 -----------
        for layer_name in ('layer1', 'layer2', 'layer3'):
            patch_seq  = getattr(self.patch_net,  layer_name)
            pruned_seq = getattr(self.pruned_net, layer_name)
            for i, (pb, qb) in enumerate(zip(patch_seq, pruned_seq)):
                prefix = f"{layer_name}.{i}"  # matches the Z_mask key layout, e.g. "layer2.0"

                # conv1 of the block: both nets read the (already mixed) `out`
                p1 = F.relu(pb.bn1(pb.conv1(out)))
                q1 = F.relu(qb.bn1(qb.conv1(out)))
                m1 = self.Z_mask[f"{prefix}.conv1"]
                mask1 = p1.new_full((1, p1.size(1), 1, 1), m1)
                x1 = p1 * mask1 + q1 * (1 - mask1)

                # downsample: **no mixing**, always use pruned_net's skip.
                # BasicBlock.downsample is either empty (identity) or conv+BN (length 2);
                # both nets share the architecture, so pb's length also describes qb's.
                if len(pb.downsample) == 2:
                    res = qb.downsample(out)
                else:
                    res = out

                # conv2 of the block (BN but no ReLU before the residual add)
                p2 = pb.bn2(pb.conv2(x1))
                q2 = qb.bn2(qb.conv2(x1))
                m2 = self.Z_mask[f"{prefix}.conv2"]
                mask2 = p2.new_full((1, p2.size(1), 1, 1), m2)
                x2 = p2 * mask2 + q2 * (1 - mask2)

                out = F.relu(x2 + res)

        # --- final heads -----------------------------------------
        # Same pooling/head as CifarResNet.forward (8x8 average pool, then fc);
        # presumably assumes 32x32 inputs so the last maps are 8x8. Confirm.
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        comb_out = self.pruned_net.fc(out)
        full_out = self.full_net(Y)
        return torch.cat([full_out, comb_out], dim=1)

class BigDupPatchTaxiNet(nn.Module):
    """Patch/pruned TaxiNet combination that applies ReLU after each conv.

    The input's channels are split in half: the first half feeds
    ``patch_net`` and the second feeds ``pruned_net``. After each conv + ReLU,
    the pruned stream becomes ``patch * m + pruned * (1 - m)`` for the
    per-channel mask ``m``. The patch stream stays unmixed. The MLP head
    (fc1-fc3) comes from ``pruned_net``.

    NOTE(review): both sub-networks are ``TaxiNetCNN``, whose own ``forward``
    applies no activation to its convs, yet this class applies ``act`` after
    them. The "Big" name and the ReLU placement match ``BigTaxiNetCNN``.
    Confirm which backbone is intended before relying on this model.
    """
    def __init__(self, Z_mask):
        super().__init__()
        self.patch_net = TaxiNetCNN()
        self.pruned_net = TaxiNetCNN()
        # masks for conv layers (channel-wise), shaped (1, C, 1, 1) to broadcast over (B, C, H, W)
        self.register_buffer("mask_conv1", torch.tensor(Z_mask["conv1"], dtype=torch.float32).view(1, -1, 1, 1))
        self.register_buffer("mask_conv2", torch.tensor(Z_mask["conv2"], dtype=torch.float32).view(1, -1, 1, 1))

    def forward(self, inp):
        # split along channels into [patch | pruned]
        C_half = inp.size(1) // 2
        X_patch = inp[:, :C_half]
        Y_pruned = inp[:, C_half:]

        # conv1: mask selects, per channel, patch (1) vs pruned (0) activations
        patch_conv1 = self.patch_net.act(self.patch_net.conv1(X_patch))
        pruned_conv1 = self.pruned_net.act(self.pruned_net.conv1(Y_pruned))
        mixed_conv1 = patch_conv1 * self.mask_conv1 + pruned_conv1 * (1 - self.mask_conv1)

        # conv2: patch net continues from its own (unmixed) conv1 output
        patch_conv2 = self.patch_net.act(self.patch_net.conv2(patch_conv1))
        pruned_conv2 = self.pruned_net.act(self.pruned_net.conv2(mixed_conv1))
        mixed_conv2 = patch_conv2 * self.mask_conv2 + pruned_conv2 * (1 - self.mask_conv2)

        # FC head from pruned_net
        z = torch.flatten(mixed_conv2, 1)
        z = self.pruned_net.act(self.pruned_net.fc1(z))
        z = self.pruned_net.act(self.pruned_net.fc2(z))
        out = self.pruned_net.fc3(z)
        return out

class DupPatchTaxiNet(nn.Module):
    """TaxiNet patch/pruned combination with per-channel masks on both convs.

    The input's channels are split in half: the first half feeds
    ``patch_net`` and the second feeds ``pruned_net``. After each conv, the
    pruned stream becomes ``patch * m + pruned * (1 - m)`` for the per-channel
    mask ``m``; no activation is applied, as in ``TaxiNetCNN.forward``. The
    patch stream stays unmixed. The MLP head (fc1-fc3) comes from
    ``pruned_net``.
    """

    def __init__(self, Z_mask):
        super().__init__()
        self.patch_net = TaxiNetCNN()
        self.pruned_net = TaxiNetCNN()

        # Per-channel masks shaped (1, C, 1, 1) so they broadcast over (B, C, H, W).
        self.register_buffer("mask_conv1", torch.tensor(Z_mask["conv1"], dtype=torch.float32).view(1, -1, 1, 1))
        self.register_buffer("mask_conv2", torch.tensor(Z_mask["conv2"], dtype=torch.float32).view(1, -1, 1, 1))

    def forward(self, inp):
        half = inp.size(1) // 2
        patch_feat = inp[:, :half, :, :]
        mixed_feat = inp[:, half:, :, :]

        conv_stages = (
            (self.patch_net.conv1, self.pruned_net.conv1, self.mask_conv1),
            (self.patch_net.conv2, self.pruned_net.conv2, self.mask_conv2),
        )
        for patch_conv, pruned_conv, mask in conv_stages:
            patch_feat = patch_conv(patch_feat)  # patch stream: never mixed
            pruned_feat = pruned_conv(mixed_feat)
            mixed_feat = patch_feat * mask + pruned_feat * (1 - mask)

        head = self.pruned_net
        hidden = torch.flatten(mixed_feat, 1)
        for fc in (head.fc1, head.fc2):
            hidden = head.act(fc(hidden))
        return head.fc3(hidden)

def dup_cifar10_net(model_type, **kwargs):
    """Build a DupCifarResNet holding two copies of the requested CIFAR-10 model.

    Args:
        model_type: ``"cifar10-big"`` (``cifar10_big``) or ``"cifar10-small"``
            (``cifar10_small``).
        **kwargs: forwarded to the per-copy model factory. ``DupCifarResNet``
            itself accepts no extra arguments, so passing them there raised
            ``TypeError``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    if model_type == "cifar10-big":
        cifar10_net = cifar10_big
    elif model_type == "cifar10-small":
        cifar10_net = cifar10_small
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    # DupCifarResNet calls this zero-argument factory once per copy.
    return DupCifarResNet(lambda: cifar10_net(**kwargs))


def dup_gtsrb_net():
    """Return a fresh DupGtsrbNet (two independent GTSRBCNN copies)."""
    model = DupGtsrbNet()
    return model

def dup_taxinet_net():
    """Return a fresh DupTaxiNet (two independent TaxiNetCNN copies)."""
    model = DupTaxiNet()
    return model


def tripled_adv_mnist_model(input_size, hidden_size_1, hidden_size_2, num_classes):
    """Return a TripledAdvMnistNet with three identically sized MNIST ``Net`` copies."""
    sizes = (input_size, hidden_size_1, hidden_size_2, num_classes)
    return TripledAdvMnistNet(*sizes)


def tripled_adv_cifar_model():
    """Return three ResNet-20 copies (BasicBlock, [3, 3, 3], 10 classes) run side by side."""
    stage_depths = [3, 3, 3]
    return TripledAdvCifarNet(BasicBlock, stage_depths, num_classes=10)

def cifar10_resnet20(**kwargs):
    """Build the CIFAR ResNet-20: three stages of three ``BasicBlock``s.

    ``num_classes`` defaults to 10 and may be overridden via ``kwargs``.
    Passing it used to raise a duplicate-argument ``TypeError``. Remaining
    keywords are forwarded to ``CifarResNet`` unchanged.
    """
    # kwargs is a fresh dict for every call, so mutating it is safe.
    kwargs.setdefault("num_classes", 10)
    return CifarResNet(BasicBlock, [3, 3, 3], **kwargs)

def cifar10_big(**kwargs):
    """Alias for the "big" CIFAR-10 model, which is ``cifar10_resnet20``.

    All keyword arguments are passed through to ``cifar10_resnet20``.
    """
    model = cifar10_resnet20(**kwargs)
    return model

def tripled_patch_mnist_model(Z_mask):
    """Build a TripledPatchNetMnist sized from ``mnist_config`` with the given mask."""
    cfg = mnist_config
    return TripledPatchNetMnist(
        cfg['input_size'],
        cfg['hidden_size_1'],
        cfg['hidden_size_2'],
        cfg['num_classes'],
        Z_mask,
    )

def dup_patch_mnist_model(Z_mask):
    """Build a DupPatchNetMnist sized from ``mnist_config`` with the given mask."""
    cfg = mnist_config
    return DupPatchNetMnist(
        cfg['input_size'],
        cfg['hidden_size_1'],
        cfg['hidden_size_2'],
        cfg['num_classes'],
        Z_mask,
    )

def dup_patch_gtsrb_model(Z_mask):
    """Return a DupPatchGtsrbNet using the per-conv masks in ``Z_mask``."""
    model = DupPatchGtsrbNet(Z_mask)
    return model

def dup_patch_taxinet_model(Z_mask):
    """Return a DupPatchTaxiNet using the per-conv masks in ``Z_mask``."""
    model = DupPatchTaxiNet(Z_mask)
    return model

######################### SMALL CIFAR-10 MODEL

class SmallBasicBlock(nn.Module):
    """Residual block used by the small CIFAR-10 ResNets (``ResNet5``/``ResNet9``).

    Main path: conv1 (with ``stride``) -> [BN] -> ReLU -> conv2 -> [BN]. The
    shortcut is then added and a final ReLU applied. With ``bn=False`` the
    convs carry a bias instead of being followed by BatchNorm.

    ``kernel`` selects 3x3, 2x2 or 1x1 convolutions. The paddings are chosen
    so that a stride-1 block preserves the spatial size; for 2x2, conv1 pads 1
    and conv2 pads 0. The shortcut is the identity (an empty ``nn.Sequential``)
    unless the stride or width changes. In that case a strided 1x1 conv, plus
    BN when ``bn``, projects the input.

    Raises:
        ValueError: if ``kernel`` is not 1, 2 or 3.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=True, kernel=3):
        super().__init__()
        self.bn = bn
        # kernel -> (kernel_size, conv1 padding, conv2 padding)
        conv_specs = {3: (3, 1, 1), 2: (2, 1, 0), 1: (1, 0, 0)}
        if kernel not in conv_specs:
            # Raise instead of calling exit(): exit() is an interactive-interpreter
            # helper and would terminate the whole process.
            raise ValueError(f"kernel not supported! (got {kernel!r}; expected 1, 2 or 3)")
        ksize, pad1, pad2 = conv_specs[kernel]
        use_bias = not self.bn
        out_planes = self.expansion * planes

        # Creation order (conv1, bn1, conv2, bn2, shortcut) matches the previous
        # implementation, so parameter order and RNG consumption are unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=ksize, stride=stride,
                               padding=pad1, bias=use_bias)
        if self.bn:
            self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=ksize, stride=1,
                               padding=pad2, bias=use_bias)
        if self.bn:
            self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            projection = [nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=use_bias)]
            if self.bn:
                projection.append(nn.BatchNorm2d(out_planes))
            self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        if self.bn:
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
        else:
            out = F.relu(self.conv1(x))
            out = self.conv2(out)
        out += self.shortcut(x)
        return F.relu(out)

class ResNet5(nn.Module):
    """Small CIFAR ResNet: a strided 3x3 stem and one stage of ``num_blocks`` blocks.

    The stem has ``in_planes`` channels and stride 2. The stage doubles the
    width and uses stride 2 on its first block. ``last_layer`` selects the
    head: ``"avg"`` (4x4 average pool + linear) or ``"dense"`` (flatten,
    100-unit ReLU layer, linear).

    NOTE(review): the head input sizes (``in_planes * 8 * expansion``, times
    16 for "dense") equal the real feature size only for 32x32 inputs such as
    CIFAR-10. Confirm before using other resolutions.

    Raises:
        ValueError: if ``last_layer`` is neither ``"avg"`` nor ``"dense"``.
    """

    def __init__(self, block, num_blocks=2, num_classes=10, in_planes=64, bn=True, last_layer="avg"):
        super().__init__()
        self.in_planes = in_planes
        self.bn = bn
        self.last_layer = last_layer
        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3,
                               stride=2, padding=1, bias=not self.bn)
        if self.bn:
            self.bn1 = nn.BatchNorm2d(in_planes)
        self.layer1 = self._make_layer(block, in_planes*2, num_blocks, stride=2, bn=bn, kernel=3)
        if self.last_layer == "avg":
            self.avg2d = nn.AvgPool2d(4)
            self.linear = nn.Linear(in_planes * 8 * block.expansion, num_classes)
        elif self.last_layer == "dense":
            self.linear1 = nn.Linear(in_planes * 8 * block.expansion * 16, 100)
            self.linear2 = nn.Linear(100, num_classes)
        else:
            # Raise instead of exit(): callers get a catchable error and the
            # process is not terminated.
            raise ValueError(
                f"last_layer type not supported! (got {last_layer!r}; expected 'avg' or 'dense')"
            )

    def _make_layer(self, block, planes, num_blocks, stride, bn, kernel):
        """Build one stage; only the first block uses ``stride``. Advances ``self.in_planes``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, bn, kernel))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        if self.bn:
            out = F.relu(self.bn1(self.conv1(x)))
        else:
            out = F.relu(self.conv1(x))
        out = self.layer1(out)
        if self.last_layer == "avg":
            out = self.avg2d(out)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
        elif self.last_layer == "dense":
            out = torch.flatten(out, 1)
            out = F.relu(self.linear1(out))
            out = self.linear2(out)
        return out


class ResNet9(nn.Module):
    """Small CIFAR ResNet: a strided 3x3 stem and two stages of ``num_blocks`` blocks.

    The stem has ``in_planes`` channels and stride 2. Both stages use
    ``in_planes * 2`` planes and stride 2 on their first block.
    ``last_layer`` selects the head: ``"avg"`` (4x4 average pool + linear) or
    ``"dense"`` (flatten, 100-unit ReLU layer, linear).

    NOTE(review): the head input sizes (``in_planes * 2 * expansion``, times
    16 for "dense") equal the real feature size only for 32x32 inputs such as
    CIFAR-10. Confirm before using other resolutions.

    Raises:
        ValueError: if ``last_layer`` is neither ``"avg"`` nor ``"dense"``.
    """

    def __init__(self, block, num_blocks=2, num_classes=10, in_planes=64, bn=True, last_layer="avg"):
        super().__init__()
        self.in_planes = in_planes
        self.bn = bn
        self.last_layer = last_layer
        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3,
                               stride=2, padding=1, bias=not self.bn)
        if self.bn:
            self.bn1 = nn.BatchNorm2d(in_planes)
        self.layer1 = self._make_layer(block, in_planes*2, num_blocks, stride=2, bn=bn, kernel=3)
        self.layer2 = self._make_layer(block, in_planes*2, num_blocks, stride=2, bn=bn, kernel=3)
        if self.last_layer == "avg":
            self.avg2d = nn.AvgPool2d(4)
            self.linear = nn.Linear(in_planes * 2 * block.expansion, num_classes)
        elif self.last_layer == "dense":
            self.linear1 = nn.Linear(in_planes * 2 * block.expansion * 16, 100)
            self.linear2 = nn.Linear(100, num_classes)
        else:
            # Raise instead of exit(): callers get a catchable error and the
            # process is not terminated.
            raise ValueError(
                f"last_layer type not supported! (got {last_layer!r}; expected 'avg' or 'dense')"
            )

    def _make_layer(self, block, planes, num_blocks, stride, bn, kernel):
        """Build one stage; only the first block uses ``stride``. Advances ``self.in_planes``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, bn, kernel))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        if self.bn:
            out = F.relu(self.bn1(self.conv1(x)))
        else:
            out = F.relu(self.conv1(x))
        out = self.layer1(out)
        out = self.layer2(out)
        if self.last_layer == "avg":
            out = self.avg2d(out)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
        elif self.last_layer == "dense":
            out = torch.flatten(out, 1)
            out = F.relu(self.linear1(out))
            out = self.linear2(out)
        return out


class GTSRBCNN(nn.Module):
    """GTSRB traffic-sign classifier ("Table 5" architecture).

    conv(3->16, 3x3, pad 1) -> ReLU -> 2x2 avg-pool -> conv(16->32, 3x3,
    pad 1) -> ReLU -> 2x2 avg-pool -> fc(2048->128) -> ReLU -> fc(128->43).
    fc1's 2048 inputs are 32 channels * 8 * 8, i.e. 32x32 input images.
    Returns raw logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool1 = nn.AvgPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool2 = nn.AvgPool2d(2)
        self.fc1  = nn.Linear(2048, 128)
        self.fc2  = nn.Linear(128, 43)
        self.act  = nn.ReLU()

    def forward(self, x):
        x = self.pool1(self.act(self.conv1(x)))  # (B, 16, 16, 16) for 32x32 inputs
        x = self.pool2(self.act(self.conv2(x)))  # (B, 32, 8, 8)
        x = x.view(x.size(0), -1)                # (B, 2048) = 32 * 8 * 8 features
        x = self.act(self.fc1(x))
        return self.fc2(x)                       # logits; softmax is left to the loss



class DupPatchCifar10SmallNet(nn.Module):
    """Patch/pruned combination of two ``resnet2b`` models with per-channel masks.

    The input stacks two 3-channel images along dim 1. Channels 0-2 feed
    ``patch_net`` and the remaining channels feed ``pruned_net``. After every
    conv (the stem conv1 and both convs of each ``layer1`` block), the pruned
    stream is replaced by ``patch * m + pruned * (1 - m)`` with that conv's
    per-channel mask ``m``. The patch stream is never mixed. The dense head
    (``linear1``/``linear2``) comes from ``pruned_net``.

    The forward pass applies no BatchNorm, which matches ``resnet2b`` (built
    with ``bn=False``).

    Args:
        Z_mask: mapping from conv names ("conv1", "layer1.0.conv1",
            "layer1.0.conv2", "layer1.1.conv1", "layer1.1.conv2") to
            per-channel mask values.
    """
    def __init__(self, Z_mask):
        super().__init__()
        self.patch_net  = resnet2b()
        self.pruned_net = resnet2b()

        # masks, shaped (1, C, 1, 1) to broadcast over (B, C, H, W)
        self.register_buffer("mask_conv1",    torch.tensor(Z_mask["conv1"],        dtype=torch.float32).view(1,-1,1,1))
        self.register_buffer("mask_l1_0_c1",  torch.tensor(Z_mask["layer1.0.conv1"],dtype=torch.float32).view(1,-1,1,1))
        self.register_buffer("mask_l1_0_c2",  torch.tensor(Z_mask["layer1.0.conv2"],dtype=torch.float32).view(1,-1,1,1))
        self.register_buffer("mask_l1_1_c1",  torch.tensor(Z_mask["layer1.1.conv1"],dtype=torch.float32).view(1,-1,1,1))
        self.register_buffer("mask_l1_1_c2",  torch.tensor(Z_mask["layer1.1.conv2"],dtype=torch.float32).view(1,-1,1,1))

        # Shortcuts of the pruned blocks. An empty Sequential (identity shortcut)
        # is replaced by nn.Identity.
        # NOTE(review): assigning an existing submodule (e.g. blk0.shortcut) as an
        # attribute registers it a second time, so the same parameters are reachable
        # under two names (and presumably two state_dict keys). Keep this in mind
        # when loading checkpoints with strict=True.
        blk0 = self.pruned_net.layer1[0]
        self.shortcut0 = blk0.shortcut if len(blk0.shortcut)>0 else nn.Identity()
        blk1 = self.pruned_net.layer1[1]
        self.shortcut1 = blk1.shortcut if len(blk1.shortcut)>0 else nn.Identity()
        # patch shortcuts (only block 0 is needed: the patch stream's block-1
        # output is never consumed)
        blk0_p = self.patch_net.layer1[0]
        self.shortcut0_patch = blk0_p.shortcut if len(blk0_p.shortcut)>0 else nn.Identity()

    def forward(self, inp):
        X_patch, Y_pruned = inp[:, :3], inp[:, 3:]

        # initial conv1 (stem + ReLU, no BN)
        patch_conv1  = F.relu(self.patch_net.conv1(X_patch))
        pruned_conv1 = F.relu(self.pruned_net.conv1(Y_pruned))
        mixed_conv1  = patch_conv1 * self.mask_conv1 + pruned_conv1 * (1 - self.mask_conv1)

        # ---- block0 ----
        # conv1: patch net reads its own unmixed stem output
        patch0_c1  = F.relu(self.patch_net.layer1[0].conv1(patch_conv1))
        pruned0_c1 = F.relu(self.pruned_net.layer1[0].conv1(mixed_conv1))
        mixed0_c1  = patch0_c1 * self.mask_l1_0_c1 + pruned0_c1 * (1 - self.mask_l1_0_c1)

        # conv2 (no ReLU before the residual add, as in SmallBasicBlock)
        patch0_c2  = self.patch_net.layer1[0].conv2(patch0_c1)
        pruned0_c2 = self.pruned_net.layer1[0].conv2(mixed0_c1)
        mixed0_c2  = patch0_c2 * self.mask_l1_0_c2 + pruned0_c2 * (1 - self.mask_l1_0_c2)

        # residual adds: each stream uses its own shortcut on its own block input
        out_patch_block0 = F.relu(patch0_c2 + self.shortcut0_patch(patch_conv1))
        out_pruned_block0 = F.relu(mixed0_c2 + self.shortcut0(mixed_conv1))

        # ---- block1 ----
        # conv1
        patch1_c1  = F.relu(self.patch_net.layer1[1].conv1(out_patch_block0))
        pruned1_c1 = F.relu(self.pruned_net.layer1[1].conv1(out_pruned_block0))
        mixed1_c1  = patch1_c1 * self.mask_l1_1_c1 + pruned1_c1 * (1 - self.mask_l1_1_c1)

        # conv2
        patch1_c2  = self.patch_net.layer1[1].conv2(patch1_c1)
        pruned1_c2 = self.pruned_net.layer1[1].conv2(mixed1_c1)
        mixed1_c2  = patch1_c2 * self.mask_l1_1_c2 + pruned1_c2 * (1 - self.mask_l1_1_c2)

        # only the pruned stream's block-1 output is needed downstream
        pruned_out_block1 = F.relu(mixed1_c2 + self.shortcut1(out_pruned_block0))

        # final head ("dense" head of ResNet5: flatten -> linear1 -> ReLU -> linear2)
        feat   = torch.flatten(pruned_out_block1, 1)
        hidden = F.relu(self.pruned_net.linear1(feat))
        logits = self.pruned_net.linear2(hidden)
        return logits


def resnet2b():
    """Small CIFAR-10 ResNet5: 8-channel stem, two BN-free blocks, dense head."""
    return ResNet5(
        SmallBasicBlock,
        num_blocks=2,
        in_planes=8,
        bn=False,
        last_layer="dense",
    )

def resnet4b():
    """Small CIFAR-10 ResNet9: 16-channel stem, two stages of two BN-free blocks, dense head."""
    return ResNet9(
        SmallBasicBlock,
        num_blocks=2,
        in_planes=16,
        bn=False,
        last_layer="dense",
    )

def cifar10_small(**kwargs):
    """Return the "small" CIFAR-10 model (``resnet2b``).

    ``kwargs`` is accepted for signature parity with ``cifar10_big`` but is
    ignored.
    """
    model = resnet2b()
    return model

def dup_patch_cifar10_model(Z_mask):
    """Return a DupPatchCifar10SmallNet using the per-conv masks in ``Z_mask``."""
    model = DupPatchCifar10SmallNet(Z_mask)
    return model

### loading

def tripled_patch_cifar10_model(Z_mask):
    """Return a TripledPatchResNetConvHeads using the per-conv masks in ``Z_mask``."""
    model = TripledPatchResNetConvHeads(Z_mask)
    return model

def FC_model(input_size, hidden_size_1, hidden_size_2, num_classes):
    """Return a DupNet: two identically sized fully-connected MNIST nets side by side."""
    sizes = (input_size, hidden_size_1, hidden_size_2, num_classes)
    return DupNet(*sizes)


def load_mnist_batch(spec, mnist_batch_path, patching=False, patch_eps=None, verify_gold_label=False):
    """Load a preprocessed MNIST batch stored as a pickled dict in an ``.npy`` file.

    The dict must contain ``X`` (flattened images, presumably ``(B, 784)``),
    ``label``, ``predicted``, ``runner_up`` and ``winner_logit``.

    Args:
        spec: mapping holding the perturbation radius under ``"epsilon"``.
        mnist_batch_path: path to the ``.npy`` file. It is read with
            ``allow_pickle=True``, so only load trusted files.
        patching: if true, the input is duplicated as ``[X | X]`` and the
            radii become ``[patch_eps | epsilon]`` per feature.
        patch_eps: radius for the first (patch) copy; required when ``patching``.
        verify_gold_label: use the dataset labels as verification targets
            instead of the network's own predictions.

    Returns:
        ``(inp, verification_labels, ones_like(inp), zeros_like(inp), eps,
        runner_up, winner_logit)``. The per-sample vectors are ``(B, 1)``.
        ``eps`` is ``(1, 1)`` without patching and shaped like ``inp`` with
        patching.
    """
    eps = spec["epsilon"]
    assert eps is not None, "You must specify an epsilon"

    batch = np.load(mnist_batch_path, allow_pickle=True).item()
    raw_images, raw_labels, raw_predicted, raw_runner, raw_target = [
        batch[key] for key in ('X', 'label', 'predicted', 'runner_up', 'winner_logit')
    ]

    def as_column(arr):
        # (B,) numpy array -> (B, 1) tensor, dtype preserved
        return torch.from_numpy(arr).view(-1, 1)

    images = torch.from_numpy(raw_images).float()  # (B, 784)
    gold = as_column(raw_labels)
    runner = as_column(raw_runner)
    target = as_column(raw_target)
    predicted = as_column(raw_predicted)

    # Verify against the gold labels, or against the full network's predictions.
    verification_labels = gold if verify_gold_label else predicted

    if not patching:
        ret_eps = torch.tensor(eps).view(1, 1)  # broadcasts over the batch
        inp = images  # (B, 784)
    else:
        assert patch_eps is not None
        duplicate = images.clone().detach()
        duplicate_eps = torch.full_like(duplicate, eps)
        patch_radius = torch.full_like(images, patch_eps)
        # Both halves are concatenated along the feature axis: (B, 2*784)
        ret_eps = torch.cat((patch_radius, duplicate_eps), dim=1)
        inp = torch.cat((images, duplicate), dim=1)

    return inp, verification_labels, torch.ones_like(inp), torch.zeros_like(inp), ret_eps, runner, target

def load_vision_batch(dataset, spec, batch_path, patching=False, patch_eps=None, verify_gold_label=False):
    """Load a preprocessed CIFAR-10 or GTSRB batch from an ``.npy`` file.

    The file must hold a pickled dict (read via ``np.load(...).item()``) with the
    keys ``'X'``, ``'label'``, ``'predicted'``, ``'runner_up'`` and ``'winner_logit'``.

    Args:
        dataset: ``'cifar10'`` or ``'gtsrb'``; selects the per-channel input box.
        spec: mapping whose ``"epsilon"`` entry is the perturbation radius.
        batch_path: path to the ``.npy`` file.
        patching: if True, the returned input is ``cat((X, X), dim=1)`` (channel
            axis); the matching ``ret_eps`` holds ``patch_eps`` for the first copy
            and ``epsilon`` for the second, and the box bounds are repeated.
        patch_eps: radius for the first copy; required when ``patching`` is True.
        verify_gold_label: if True return the gold labels as verification
            labels, otherwise the network's predicted labels.

    Returns:
        ``(inp, verification_labels, data_max, data_min, ret_eps, runner, target)``.
        ``data_max`` / ``data_min`` are ``expand_as(inp)`` views (not writable
        copies). ``ret_eps`` is always float32: a ``(1, 1, 1, 1)`` broadcastable
        scalar when not patching, a tensor shaped like ``inp`` when patching.

    Raises:
        ValueError: if ``dataset`` is not supported.
        AssertionError: if ``epsilon`` is None, or ``patch_eps`` is None while patching.
    """
    if dataset not in ('cifar10', 'gtsrb'):
        raise ValueError("Unsupported dataset. Use 'cifar10' or 'gtsrb'.")

    eps = spec["epsilon"]
    assert eps is not None, "you must specify an epsilon"
    # NOTE: allow_pickle is needed because the file stores a dict; only load trusted files.
    data = np.load(batch_path, allow_pickle=True).item()
    X_np = data['X']  # (B, 3, 32, 32)
    labels_np = data['label']  # (B,)
    predicted_np = data['predicted']  # (B,)
    runner_np = data['runner_up']  # (B,)
    target_np = data['winner_logit']

    X = torch.from_numpy(X_np).float()  # (B,3,32,32)
    labels = torch.from_numpy(labels_np).view(-1, 1)  # (B,1)
    runner = torch.from_numpy(runner_np).view(-1, 1)  # (B,1)
    target = torch.from_numpy(target_np).view(-1, 1)  # (B,1)
    predicted = torch.from_numpy(predicted_np).view(-1, 1)  # (B,1)

    # Verify against the full network's predictions, or against the gold labels.
    verification_labels = labels if verify_gold_label else predicted

    if dataset == 'cifar10':
        # Per-channel min/max of the (normalized) CIFAR-10 inputs.
        base_min = torch.tensor([-2.42907, -2.41825, -2.22139], dtype=torch.float32).view(1, 3, 1, 1)
        base_max = torch.tensor([2.51409, 2.59679, 2.75373], dtype=torch.float32).view(1, 3, 1, 1)
    else:  # gtsrb
        # TODO: placeholder [0, 1] box; replace with the real GTSRB channel bounds.
        base_min = torch.zeros(1, 3, 1, 1, dtype=torch.float32)
        base_max = torch.ones(1, 3, 1, 1, dtype=torch.float32)

    if patching:
        assert patch_eps is not None
        Y = X.clone()
        Y_eps = torch.full_like(Y, eps)
        X_eps = torch.full_like(X, patch_eps)
        ret_eps = torch.cat((X_eps, Y_eps), dim=1)  # (B,6,32,32)
        inp = torch.cat((X, Y), dim=1)  # (B,6,32,32)
        # One copy of the 3-channel box per input copy -> (1, 6, 1, 1).
        base_min = base_min.repeat(1, 2, 1, 1)
        base_max = base_max.repeat(1, 2, 1, 1)
    else:
        # Explicit float32 so an integer epsilon does not yield an int64 radius.
        # This matches the patching branch and load_taxinet_batch.
        ret_eps = torch.tensor(eps, dtype=torch.float32).view(1, 1, 1, 1)  # broadcast
        inp = X  # (B,3,32,32)

    data_min = base_min.expand_as(inp)
    data_max = base_max.expand_as(inp)

    return inp, verification_labels, data_max, data_min, ret_eps, runner, target

def load_taxinet_batch(spec, taxinet_batch_path, patching=False, patch_eps=None, verify_gold_label=False):
    """
    Load a TaxiNet batch saved via np.save(dict) from prepare_taxinet_batch.
    Returns (inp, verification_labels, data_max, data_min, ret_eps, runner, target)
    to match the CIFAR/GTSRB loader signature.

    Dict keys expected in the .npy:
      - "X":        (B, C, H, W) float32
      - "y_true":   (B,) float32
      - "y_pred":   (B,) float32
      - "abs_error":(B,) float32   (not used here, but handy to keep)

    Args:
        spec: mapping whose ``"epsilon"`` entry is the perturbation radius.
        taxinet_batch_path: path to the ``.npy`` file.
        patching: if True, the returned input is ``cat((X, X), dim=1)``; the
            matching ``ret_eps`` holds ``patch_eps`` for the first copy and
            ``epsilon`` for the second.
        patch_eps: radius for the first copy; required when ``patching`` is True.
        verify_gold_label: accepted for signature compatibility with the
            classification loaders; ignored (regression has a single output).

    Returns:
        ``verification_labels`` and ``runner`` are int64 zeros of shape ``(B, 1)``.
        ``target`` is ``y_pred`` as float32 ``(B, 1)``. The input box is
        ``[0, 1]`` per element, returned as ``expand_as(inp)`` views.

    Raises:
        AssertionError: if ``epsilon`` is None, or ``patch_eps`` is None while patching.
    """
    eps = spec["epsilon"]
    assert eps is not None, "you must specify an epsilon"

    # NOTE: allow_pickle is needed because the file stores a dict; only load trusted files.
    data = np.load(taxinet_batch_path, allow_pickle=True).item()
    X_np = data["X"]            # (B,C,H,W)
    y_true_np = data["y_true"]  # (B,)
    y_pred_np = data["y_pred"]  # (B,)

    X = torch.from_numpy(X_np).float()                        # (B,C,H,W)
    y_true = torch.from_numpy(y_true_np).view(-1, 1).float()  # (B,1)
    y_pred = torch.from_numpy(y_pred_np).view(-1, 1).float()  # (B,1)

    # Regression has a single output, so every sample's "label" is output index 0.
    # Use int64: index ops such as torch.gather / scatter / F.one_hot require
    # int64 indices.
    verification_labels = torch.zeros(y_true.size(), dtype=torch.long)

    C = X.size(1)
    # Base box, assuming inputs are already normalized to roughly [0, 1].
    base_min = torch.zeros(1, C, 1, 1, dtype=torch.float32)
    base_max = torch.ones(1, C, 1, 1, dtype=torch.float32)

    # TODO: consider doing the entire TaxiNet pipeline without a channel dimension.
    if patching:
        assert patch_eps is not None, "patch_eps required when patching=True"
        Y = X.clone()
        # eps tensors for each half
        X_eps = torch.full_like(X, patch_eps)
        Y_eps = torch.full_like(Y, eps)

        inp = torch.cat((X, Y), dim=1)  # (B, 2C, H, W)
        ret_eps = torch.cat((X_eps, Y_eps), dim=1)
        # expand bounds for 2C channels
        base_min = base_min.repeat(1, 2, 1, 1)
        base_max = base_max.repeat(1, 2, 1, 1)
    else:
        inp = X  # (B,C,H,W)
        ret_eps = torch.tensor(eps, dtype=torch.float32).view(1, 1, 1, 1)

    data_min = base_min.expand_as(inp)
    data_max = base_max.expand_as(inp)

    # For API compatibility with the vision loader: there is no runner-up in
    # regression, so return zeros; the target is the model's scalar prediction.
    runner = torch.zeros_like(verification_labels)
    target = y_pred

    return inp, verification_labels, data_max, data_min, ret_eps, runner, target

def load_gtsrb_batch(spec, gtsrb_batch_path, patching=False, patch_eps=None, verify_gold_label=False):
    """Load a GTSRB batch from an ``.npy`` file.

    Thin wrapper around ``load_vision_batch`` with ``dataset='gtsrb'``. All other
    arguments are forwarded unchanged, and its 7-tuple is returned as-is.
    """
    return load_vision_batch(
        dataset='gtsrb',
        spec=spec,
        batch_path=gtsrb_batch_path,
        patching=patching,
        patch_eps=patch_eps,
        verify_gold_label=verify_gold_label,
    )


def load_cifar10_batch(spec, cifar10_batch_path, patching=False, patch_eps=None, verify_gold_label=False):
    """Load a CIFAR-10 batch from an ``.npy`` file.

    Thin wrapper around ``load_vision_batch`` with ``dataset='cifar10'``. All other
    arguments are forwarded unchanged, and its 7-tuple is returned as-is.
    """
    return load_vision_batch(
        dataset='cifar10',
        spec=spec,
        batch_path=cifar10_batch_path,
        patching=patching,
        patch_eps=patch_eps,
        verify_gold_label=verify_gold_label,
    )
