import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Flatten a ``{label: samples}`` mapping into an indexable dataset.

    Every sample in ``data_dict[label]`` becomes one item paired with ``label``.
    Items are ordered by the mapping's iteration order, then by position in
    each group.

    Args:
        data_dict: mapping from label to a sequence of samples. Samples must be
            NumPy arrays (they are converted with ``torch.from_numpy``). Labels
            must be convertible to an integer tensor. Presumably the labels are
            class indices; confirm against the feature-extraction script.
    """

    def __init__(self, data_dict):
        self.data_dict = data_dict
        # Walk the mapping once so each sample is stored next to its own label.
        self.samples = []
        self.labels = []
        for label, group in data_dict.items():
            for sample in group:
                self.samples.append(sample)
                self.labels.append(label)

    def __len__(self):
        """Return the total number of samples across all labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, target)``: a float32 tensor and an int64 scalar tensor."""
        features = torch.from_numpy(self.samples[idx]).float()
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, target


class MyNet(nn.Module):
    """Single linear layer that maps feature vectors to class probabilities.

    Args:
        in_features: size of each input feature vector (default 768).
        num_classes: number of output classes (default 10).
    """

    def __init__(self, in_features=768, num_classes=10):
        super().__init__()
        # The attribute name ``fc3`` is kept as-is: it prefixes the parameter
        # keys ("fc3.weight", "fc3.bias") that saved state dicts are loaded by.
        self.fc3 = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return softmax probabilities over dim 1 for a batch ``x`` of features."""
        # NOTE(review): this returns probabilities, not logits. Argmax-based
        # evaluation is unaffected. However, if this module was trained with
        # nn.CrossEntropyLoss, the softmax would be applied twice during
        # training. Confirm against the training script.
        return F.softmax(self.fc3(x), dim=1)



# Use the GPU when one is available; the model and evaluation batches go here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyNet().to(device)
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only
# machine. Without it, torch.load tries to restore each tensor onto the device
# it was saved from.
model.load_state_dict(torch.load('cifar10_best_model.pth', map_location=device))
model.eval()

# Features for the CIFAR-10 test split (the names previously said "cifar100"
# while loading the CIFAR-10 file).
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
cifar10_test_features_file = "../cifar10_test_features.pkl"
with open(cifar10_test_features_file, "rb") as f:
    cifar10_test_features = pickle.load(f)

# Presumably a {class_index: [feature ndarray, ...]} mapping; confirm against
# the feature-extraction script.
test_dataset = MyDataset(cifar10_test_features)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)


def test_per_class_accuracy(model, dataloader, device, num_classes):
    """Evaluate ``model`` on ``dataloader`` and compute accuracy per class.

    The predicted class is the argmax over dim 1 of the model output, so the
    model may return either logits or probabilities.

    Args:
        model: module mapping a batch of inputs to per-class scores. The output
            is assumed to have shape (batch, num_classes); confirm for other models.
        dataloader: iterable of ``(inputs, labels)`` batches, where labels are
            integer class indices.
        device: device that inputs, labels, and the counters are placed on.
        num_classes: number of classes; every label must lie in ``[0, num_classes)``.

    Returns:
        ``(class_accuracies, average_accuracy)``:
            class_accuracies: list with the fraction of correctly classified
                samples for each class (``0`` for a class with no samples).
            average_accuracy: unweighted mean of that list over all
                ``num_classes`` entries, so empty classes count as 0.

    Raises:
        ValueError: if ``num_classes`` is not positive or a label is out of range.
    """
    if num_classes <= 0:
        raise ValueError(f"num_classes must be positive, got {num_classes}")

    # Counters live on ``device`` and are updated with vectorised bincounts.
    # Calling ``.item()`` per sample would force one host/device sync per element.
    class_correct = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_total = torch.zeros(num_classes, dtype=torch.long, device=device)
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            if labels.numel() == 0:
                continue
            # Negative labels used to index silently from the end of a list.
            # Reject both out-of-range directions explicitly.
            if labels.min() < 0 or labels.max() >= num_classes:
                raise ValueError(f"labels must lie in [0, {num_classes})")
            predicted = model(inputs).argmax(dim=1)
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(
                labels[predicted == labels], minlength=num_classes
            )

    correct_counts = class_correct.tolist()
    total_counts = class_total.tolist()
    class_accuracies = [
        correct / total if total > 0 else 0
        for correct, total in zip(correct_counts, total_counts)
    ]
    average_accuracy = sum(class_accuracies) / num_classes
    return class_accuracies, average_accuracy


# The name starts with ``test_``; this stops pytest from collecting it as a test
# (it would fail on missing fixtures) if it is ever imported into a test module.
test_per_class_accuracy.__test__ = False


# Number of CIFAR-10 classes. It is defined once so the evaluation call and
# the printed message cannot drift apart.
NUM_CLASSES = 10

class_accuracies, average_accuracy = test_per_class_accuracy(
    model, test_dataloader, device, num_classes=NUM_CLASSES
)

# average_accuracy is the unweighted mean of the per-class accuracies.
print(f"Average Classification Accuracy for {NUM_CLASSES} Classes: {average_accuracy:.4f}")
