import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import models.resnet_models
import data_loaders
# Restrict this process to physical GPU 1. This must happen before CUDA is
# first initialised (the is_available() call below is the first CUDA query).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Build the model from the local `models` package and load trained weights.
# NOTE: `import models.resnet_models` rebinds the name `models` to the local
# package, shadowing the `torchvision.models` alias imported earlier.
model = models.resnet_models.resnet19(num_classes=100)
# map_location lets a checkpoint saved on a GPU load on whatever device was
# selected above (including the CPU fallback).
model.load_state_dict(torch.load('base.pth', map_location=device))
model = model.to(device)
model.eval()  # evaluation mode (affects layers such as dropout/batch-norm)

# Layer name -> detached output tensor from that layer's most recent forward pass.
activations = {}


def get_activation(name):
    """Build a forward hook that records a module's output under ``name``.

    The returned callable has the ``(module, input, output)`` signature
    expected by ``nn.Module.register_forward_hook``. Each call overwrites
    ``activations[name]`` with ``output.detach()`` and returns ``None``, so
    the module's output is left unmodified.
    """
    def _store_output(_module, _inputs, result):
        activations[name] = result.detach()

    return _store_output

# Attach an output-recording forward hook to each layer we want to inspect.
for _layer_name in ('conv1_s', 'fc1_s'):
    getattr(model, _layer_name).register_forward_hook(get_activation(_layer_name))

# Data loading. Both loaders share batch size, worker count and pinning;
# only the shuffle flag differs. (train_loader is not used in the code
# visible here.)
train_dataset, val_dataset = data_loaders.build_cifar()
_loader_kwargs = dict(batch_size=256, num_workers=16, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# Pass one batch of data through the model.

# Take a single batch from the test set (labels are not needed).
images, _ = next(iter(test_loader))
images = images.to(device)
# Run the model. The registered forward hooks fill `activations` as a side
# effect. Autograd is disabled because this is inspection only: no graph is
# built, which saves memory.
with torch.no_grad():
    output = model(images)
print(output)
# Plot the activation distribution of each hooked layer as a pie chart: the
# share of elements strictly inside (RANGE_LOW, RANGE_HIGH) versus the rest.
# The bounds are defined once so the mask, labels and title cannot disagree.
RANGE_LOW, RANGE_HIGH = 0, 2
for layer, act in activations.items():
    plt.figure()
    print(act)
    # Boolean mask of elements strictly between the bounds (exclusive).
    in_range = (act > RANGE_LOW) & (act < RANGE_HIGH)

    # Number of in-range elements and total number of elements.
    in_range_count = in_range.sum().item()
    total_count = act.numel()

    # Fraction of elements in range; guard against an empty tensor.
    in_range_ratio = in_range_count / total_count if total_count else 0.0
    out_of_range_ratio = 1 - in_range_ratio

    # Pie-chart data.
    labels = f'In Range ({RANGE_LOW}-{RANGE_HIGH})', 'Out of Range'
    sizes = [in_range_ratio, out_of_range_ratio]
    colors = ['lightblue', 'lightcoral']
    explode = (0.1, 0)  # pull out only the first (in-range) wedge
    plt.pie(sizes, explode=explode, labels=labels, colors=colors,
            autopct='%1.1f%%', shadow=True, startangle=140)
    plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
    plt.title(f'Proportion of Elements within the Range {RANGE_LOW} to {RANGE_HIGH}')
    plt.savefig(f'{layer}bing.png')
    plt.show()
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close()