import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import models.resnet_models
import data_loaders
# Restrict this process to physical GPU 3. CUDA reads this variable when it
# initialises, so it must be set before the first CUDA call in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Build the project's resnet19 (100 output classes) and load weights from a
# local checkpoint. NOTE: `models` here is the project package bound by
# `import models.resnet_models`, which rebinds (shadows) the earlier
# `import torchvision.models as models`.
model = models.resnet_models.resnet19(num_classes=100)
# map_location lets a checkpoint saved from GPU be loaded when only CPU is
# available (and places tensors directly on the target device otherwise).
model.load_state_dict(torch.load('resnet-ssssss.pth', map_location=device))
model = model.to(device)
model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Maps a layer name to the most recent output captured by its forward hook.
activations = dict()


def get_activation(name):
    """Return a forward hook that stores a layer's output under ``name``.

    The stored tensor is detached from the autograd graph, so it can be
    inspected or plotted without keeping the graph alive.
    """
    def _store(module, inputs, out):
        activations[name] = out.detach()

    return _store

# Register a forward hook on each layer whose output should be captured.
for _hooked_layer in ('conv1_s', 'fc1_s'):
    getattr(model, _hooked_layer).register_forward_hook(get_activation(_hooked_layer))

# Data loading: both loaders share batch size, worker count and pinning;
# only the training loader shuffles.
train_dataset, val_dataset = data_loaders.build_cifar()
_loader_kwargs = dict(batch_size=256, num_workers=16, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# Pass one batch through the model; the registered forward hooks fill
# `activations` as a side effect.

# Take a single batch from the test set.
images, _ = next(iter(test_loader))
images = images.to(device)
# Inference only: no_grad avoids building an autograd graph for this batch.
with torch.no_grad():
    output = model(images)
print(output)

# Open interval whose share of activation values is reported per layer.
# NOTE(review): an older comment mentioned 0.6-1.4 while the code used 0 and 2;
# these bounds follow the code's values - confirm the intended range.
range_low, range_high = 0.0, 2.0

# Plot the activation distribution of every hooked layer.
for layer, act in activations.items():
    plt.figure()
    print(act)
    # Elements strictly inside (range_low, range_high). The previous test
    # `(act > 2) & (act < 0)` had its bounds swapped and could never be true.
    in_range = (act > range_low) & (act < range_high)
    in_range_count = in_range.sum().item()
    total_count = act.numel()
    # Guard the division in case a layer produced an empty tensor.
    in_range_ratio = in_range_count / total_count if total_count else 0.0
    out_of_range_ratio = 1 - in_range_ratio
    print(f'{layer}: ratio in ({range_low}, {range_high}) = {in_range_ratio:.4f}, '
          f'outside = {out_of_range_ratio:.4f}')
    plt.hist(act.cpu().numpy().ravel(), bins=50)
    plt.title(f'Activation Distribution for {layer}')
    plt.xlabel('Activation Value')
    plt.ylabel('Frequency')
    plt.savefig(f'{layer}.png')
    plt.show()
    # Release the figure so repeated layers don't accumulate open figures.
    plt.close()