
import torch
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from colorcubenet import CustomEfficientNet  # Import the CustomEfficientNet model
from colorcube import ColorCubeTransform  # Import the ColorCubeTransform for preprocessing
from torchvision import transforms


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Build the two-class ColorCubeNet (CustomEfficientNet) and restore its trained
# weights onto `device`, then switch it to inference mode.
model_path = 'Model_Path'  # Path to your trained model
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (consider weights_only=True if the file is a plain state_dict).
state_dict = torch.load(model_path, map_location=device)
model = CustomEfficientNet(num_classes=2)
model.load_state_dict(state_dict)
# .to() and .eval() both return the module itself, so they can be chained.
model = model.to(device).eval()

# ColorCube preprocessing: resize the shorter side to 256, centre-crop to
# 224x224, expand with ColorCubeTransform, then normalise 9 channels.
# NOTE(review): there is no ToTensor step, so ColorCubeTransform presumably
# turns the PIL image into a 9-channel tensor — confirm.
# NOTE(review): 0.485 / 0.229 look like the ImageNet red-channel statistics
# replicated across all 9 channels — confirm they match training.
_NUM_COLORCUBE_CHANNELS = 9
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        ColorCubeTransform(),
        transforms.Normalize(
            mean=[0.485 for _ in range(_NUM_COLORCUBE_CHANNELS)],
            std=[0.229 for _ in range(_NUM_COLORCUBE_CHANNELS)],
        ),
    ]
)

# Path to a sample image
image_path = "xxxxxxxxxxxxxxx"  # Replace with the path to your image

# Decode inside a context manager so the file handle is closed. Normalise to
# 3-channel RGB as well: greyscale / palette / RGBA files would otherwise reach
# the ColorCube transform and the later heatmap overlay with the wrong number
# of channels. convert() returns a fully loaded, independent image, so `img`
# stays valid after the file is closed.
# NOTE(review): assumes ColorCubeTransform expects RGB input — confirm.
with Image.open(image_path) as raw_img:
    img = raw_img.convert("RGB")

# Preprocess the image, add a batch dimension and move it to the model's device.
img_tensor = preprocess(img).unsqueeze(0).to(device)

# Grad-CAM target layer, taken from the layer named in the original notes.
# NOTE(review): confirm CustomEfficientNet has `residual_bn` and that forward()
# routes through it; otherwise the hooks below never fire.
target_layer = model.residual_bn

# Filled in by the hooks below during the forward and backward passes.
captured = {}


def _save_activations_and_gradients(module, inputs, output):
    """Forward hook: keep the layer output and attach a hook for its gradient."""
    captured["activations"] = output.detach()

    def _save_gradients(grad):
        # Receives d(score)/d(output) when score.backward() runs.
        captured["gradients"] = grad.detach()

    output.register_hook(_save_gradients)


# Hooks must be registered BEFORE the forward pass so they see this image.
hook_handle = target_layer.register_forward_hook(_save_activations_and_gradients)
try:
    # Forward pass through the model.
    output = model(img_tensor)
    probabilities = torch.nn.functional.softmax(output, dim=1)
    confidence_score, predicted_class = torch.max(probabilities, dim=1)
    # Grad-CAM back-propagates the raw logit of the predicted class.
    score = output[0, predicted_class.item()]

    model.zero_grad()
    score.backward()
finally:
    # Always detach the hook so later forward passes are unaffected.
    hook_handle.remove()

if "activations" not in captured or "gradients" not in captured:
    raise RuntimeError(
        "Grad-CAM hooks did not fire; check that target_layer is used in model.forward()"
    )

gradients = captured["gradients"]
activations = captured["activations"]

# Channel weights: average the gradients over the batch and spatial dims
# (dims 0, 2, 3 of an (N, C, H, W) tensor).
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
# Weight every activation channel by broadcasting, rather than mutating in place.
weighted_activations = activations * pooled_gradients.view(1, -1, 1, 1)

# Average over channels and take batch item 0. Indexing, unlike squeeze(),
# keeps a 1x1 spatial map 2-D.
heatmap = torch.mean(weighted_activations, dim=1)[0].cpu().numpy()
heatmap = np.maximum(heatmap, 0)  # ReLU: keep only positive evidence
max_val = heatmap.max()
if max_val > 0:
    # Scale into [0, 1]. An all-zero map is left as zeros instead of
    # becoming NaN from a 0/0 division.
    heatmap /= max_val

# Upsample the Grad-CAM map (values expected in [0, 1]) to the input image's
# resolution; cv2.resize takes the size as (width, height). nan_to_num guards
# against NaN in the incoming map before the uint8 cast below.
heatmap_resized = np.nan_to_num(cv2.resize(heatmap, (img.width, img.height)))

# Compute the SNR on the Grad-CAM intensities themselves, not on the
# colour-mapped pixels. JET spreads intensity non-monotonically across three
# channels, so mean/std of those colours does not describe the map.
mean_signal = float(np.mean(heatmap_resized))
std_noise = float(np.std(heatmap_resized))
snr = mean_signal / std_noise if std_noise != 0 else float('inf')

# Colourise the map. applyColorMap returns BGR, so convert to RGB to match the
# PIL image and matplotlib's channel order.
heatmap = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

# Blend with the original image. Force a 3-channel uint8 RGB array, because
# cv2.addWeighted needs both inputs to have identical shape and dtype.
base_img = np.array(img.convert("RGB"))
superimposed_img = cv2.addWeighted(heatmap, 0.4, base_img, 0.6, 0)

# Print SNR and confidence score
print(f"Confidence Score: {confidence_score.item()}")
print(f"SNR: {snr:.4f}")

# Display the superimposed image
plt.figure(figsize=(8, 8))
plt.imshow(superimposed_img)
plt.axis('off')
plt.title(f'SNR: {snr:.4f}')
plt.show()
