import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from MiniResNetPlus import MiniResNetPlus  


# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): torch.load unpickles the checkpoint file, so both paths below
# should only ever point at trusted files -- confirm, and consider passing
# weights_only=True if the installed torch version supports it.

# Reference network: 100-class MiniResNetPlus restored from the saved
# "model_epoch20" checkpoint, then switched to evaluation mode.
model = MiniResNetPlus(num_classes=100)
model.load_state_dict(
    torch.load(
        "checkpoints_relu_prob_sigmoid/model_epoch20.pth",
        map_location=device,
    )
)
model.eval()

# Reconstructed network: same architecture, weights taken from the
# "reconstructed" checkpoint, also put in evaluation mode.
reconstructed_model = MiniResNetPlus(num_classes=100)
reconstructed_model.load_state_dict(
    torch.load(
        "reconstructed_modelplus_layers_new_20.pth",
        map_location=device,
    )
)
reconstructed_model.eval()


# Test-time preprocessing: convert each image to a tensor, then normalize
# every channel with the fixed per-channel mean / std given below.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    ]
)

# CIFAR-100 test split (train=False), downloaded into ./data when missing.
testset = torchvision.datasets.CIFAR100(
    root='./data',
    train=False,
    transform=transform_test,
    download=True,
)

# Fixed-order batches of 100 samples; no shuffling.
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=100,
    shuffle=False,
)





def prediction_agreement(model1, model2, dataloader, device) -> float:
    """Return the percentage of samples on which two models predict the same class.

    Both models are moved to ``device`` and switched to evaluation mode; they
    are left in that state when the function returns. Every batch is fed to
    both models under ``torch.no_grad()`` and the arg-max over dimension 1 of
    each output is compared element-wise.

    Args:
        model1: First model (must support ``.to(device)``, ``.eval()`` and
            being called on a batch).
        model2: Second model, evaluated on exactly the same inputs.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs. The labels
            are ignored; ``inputs.size(0)`` is taken as the batch size.
        device: Device the models and inputs are moved to.

    Returns:
        The agreement rate as a percentage (``100 * matches / samples``).
        If ``dataloader`` yields no samples the rate is undefined and NaN is
        returned instead of raising ``ZeroDivisionError``.
    """
    model1.to(device).eval()
    model2.to(device).eval()
    total = 0
    same_prediction = 0
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            preds1 = model1(inputs).argmax(dim=1)
            preds2 = model2(inputs).argmax(dim=1)
            same_prediction += (preds1 == preds2).sum().item()
            total += inputs.size(0)
    if total == 0:
        # No samples seen: the ratio is undefined, so report NaN rather than
        # crashing on a division by zero.
        return float("nan")
    return 100 * same_prediction / total


def average_cosine_similarity(model1, model2, dataloader, device) -> float:
    """Return the mean cosine similarity between two models' softmax outputs.

    Both models are moved to ``device`` and switched to evaluation mode; they
    are left in that state when the function returns. For every batch, each
    model's output is passed through a softmax over dimension 1, and the
    per-sample cosine similarity between the two resulting vectors is summed.

    Args:
        model1: First model (must support ``.to(device)``, ``.eval()`` and
            being called on a batch).
        model2: Second model, evaluated on exactly the same inputs.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs. The labels
            are ignored; ``inputs.size(0)`` is taken as the batch size.
        device: Device the models and inputs are moved to.

    Returns:
        The summed per-sample cosine similarity divided by the number of
        samples. If ``dataloader`` yields no samples the mean is undefined
        and NaN is returned instead of raising ``ZeroDivisionError``.
    """
    model1.to(device).eval()
    model2.to(device).eval()
    cosine_sim_total = 0
    total = 0
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            out1 = F.softmax(model1(inputs), dim=1)
            out2 = F.softmax(model2(inputs), dim=1)
            sim = F.cosine_similarity(out1, out2, dim=1)
            cosine_sim_total += sim.sum().item()
            total += inputs.size(0)
    if total == 0:
        # No samples seen: the mean is undefined, so report NaN rather than
        # crashing on a division by zero.
        return float("nan")
    return cosine_sim_total / total


# Compare the reference model with the reconstructed one on the test loader:
# first the share of identical top-1 predictions, then the mean cosine
# similarity of their softmax outputs. Each metric makes its own pass over
# the loader.
agreement = prediction_agreement(
    model1=model,
    model2=reconstructed_model,
    dataloader=testloader,
    device=device,
)
cos_sim = average_cosine_similarity(
    model1=model,
    model2=reconstructed_model,
    dataloader=testloader,
    device=device,
)

# Report both metrics once both have been computed.
print(f"Prediction consistency rate: {agreement:.2f}%")
print(f"Average cosine similarity: {cos_sim:.4f}")
