import matplotlib.pyplot as plt
import torch
import torchvision.models as models
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor


class FeatureHook:
    """Remember the most recent output of the module this hook is attached to.

    Register ``instance.hook`` via ``module.register_forward_hook``. After a
    forward pass, ``features`` holds that module's output. Each new forward
    pass replaces the previously stored value.
    """

    def __init__(self):
        # Stays None until the hooked module has run at least once.
        self.features = None

    def hook(self, module, input, output):
        """Forward-hook callback that stores ``output``.

        ``module`` and ``input`` are part of the forward-hook signature but
        are not used. Returns None, so the module's output is left unchanged.
        """
        self.features = output

    def get_features(self):
        """Return the last stored output, or None if nothing was captured."""
        captured = self.features
        return captured


# Load an ImageNet-pretrained ResNet-18.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=models.ResNet18_Weights.DEFAULT`; kept here for compatibility
# with older torchvision versions.
model = models.resnet18(pretrained=True)

# Create an instance of FeatureHook
hook = FeatureHook()

# Capture the output of the last residual stage (layer4).
target_layer = model.layer4
hook_handle = target_layer.register_forward_hook(hook.hook)

# Load the CIFAR10 test split. The pretrained weights expect inputs normalized
# with the ImageNet per-channel mean/std, so apply the same normalization.
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=100, shuffle=False)

# Extract features for every batch. The hook overwrites its stored output on
# each forward pass, so the features must be read (and pooled) inside the
# loop. Reading them once afterwards would keep only the final batch.
feature_batches = []
label_batches = []
model.eval()
try:
    with torch.no_grad():
        for images, batch_labels in dataloader:
            model(images)  # run only for the hook's side effect
            # Global average pooling over the spatial dims (2, 3) of the
            # 4-D layer4 output, leaving one vector per image.
            pooled = torch.mean(hook.get_features(), dim=(2, 3))
            feature_batches.append(pooled.cpu())
            label_batches.append(batch_labels)
finally:
    # Detach the hook even if extraction fails part-way through.
    hook_handle.remove()

# Stack the per-batch results into single arrays, keeping features and labels
# aligned row-for-row.
features = torch.cat(feature_batches).flatten(start_dim=1).numpy()
labels = torch.cat(label_batches).numpy()

# Apply t-SNE to reduce dimensionality
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
embedded_features = tsne.fit_transform(features)

# Plot the t-SNE visualization, coloured by CIFAR10 class index.
plt.scatter(embedded_features[:, 0], embedded_features[:, 1], c=labels, cmap='tab10')
plt.colorbar()
plt.show()