import os

import torch
from torchvision import models

from activation_optimization_by_channel import validate_model

def zero_output_channel(conv: torch.nn.Conv2d, channel: int) -> None:
    """Silence one output channel of a ``Conv2d`` layer in place.

    ``Conv2d.weight`` is laid out as ``(out_channels, in_channels, kH, kW)``,
    so zeroing ``weight[channel]`` (plus ``bias[channel]`` when the layer has
    a bias) makes output channel ``channel`` identically zero for any input.

    Args:
        conv: Convolution layer to modify; its parameters are mutated.
        channel: Index of the output channel to silence.

    Raises:
        IndexError: If ``channel`` is out of range for the layer's outputs.
    """
    # In-place edits of leaf parameters must not be recorded by autograd.
    with torch.no_grad():
        # Dim 0 of the weight is the output channel.
        conv.weight[channel].zero_()
        # Layers built with bias=False have ``bias is None``.
        if conv.bias is not None:
            conv.bias[channel] = 0


model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
layer = model.features[10]
channel_index = 0  # Replace with the index of the channel you want to modify

# Silence output channel `channel_index` of the chosen layer (weights and bias).
zero_output_channel(layer, channel_index)

# Requires a CUDA-capable device.
model = model.to('cuda')
model.eval()
# NOTE(review): the meaning of the positional `True` flag is defined by
# validate_model in activation_optimization_by_channel — confirm.
accuracy = validate_model(model, True)

print('Accuracy: ', accuracy)
output_path = 'alexnet/blocked_channel_accuracy.pt'
# Ensure the output directory exists so saving cannot fail after validation.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save(accuracy, output_path)
