import torch
import torch.nn as nn 
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vgg16
import numpy as np

# 1. Data preprocessing and DataLoader
# The VGG16 weights loaded below were trained on ImageNet images normalized with
# the ImageNet channel statistics, so inputs are normalized the same way for the
# pretrained filters to see the input distribution they were trained on.
# (The previously used CIFAR-10 std (0.2023, 0.1994, 0.2010) is also a widely
# copied incorrect value.)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize(224),  # Upscale the square 32x32 CIFAR images to 224x224 (VGG16's training resolution)
    transforms.ToTensor(),   # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

_CIFAR_ROOT = '/public/torchvision_datasets'


def _make_cifar10(is_train):
    """Build the CIFAR-10 train or test split using the shared ``transform``."""
    return torchvision.datasets.CIFAR10(
        root=_CIFAR_ROOT,
        train=is_train,
        download=True,
        transform=transform,
    )


train_set = _make_cifar10(True)
test_set = _make_cifar10(False)

# Both loaders iterate in dataset order (shuffle=False), so the extracted
# features line up with the dataset's sample indices.
_loader_kwargs = dict(batch_size=64, shuffle=False, num_workers=4)
train_loader = torch.utils.data.DataLoader(train_set, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, **_loader_kwargs)

# 2. Select the compute device: CUDA when a GPU is visible to PyTorch, else CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 3. Load pretrained VGG16 and build the feature extraction model
#    - VGG16 consists of features, avgpool, and classifier
#    - Here we keep only the features part, use global average pooling, then flatten
try:
    # torchvision >= 0.13 exposes an explicit weights enum; `pretrained=True`
    # is deprecated there and maps to these same IMAGENET1K_V1 weights.
    from torchvision.models import VGG16_Weights
except ImportError:
    VGG16_Weights = None  # Older torchvision without the multi-weight API

if VGG16_Weights is not None:
    vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
else:
    vgg = vgg16(pretrained=True)
vgg.eval()  # Set to inference mode

# Custom feature extractor, output shape will be (N, 512)
feature_extractor = nn.Sequential(
    vgg.features,                     # Convolution + pooling layers
    nn.AdaptiveAvgPool2d((1, 1)),     # Global average pooling -> (N, 512, 1, 1)
    nn.Flatten(),                     # Flatten -> (N, 512)
).to(device)

# Inference only: no optimizer/loss is needed, and gradient tracking is
# disabled at extraction time with torch.no_grad().

# 4. Define a function to perform inference and collect features and labels
def extract_features(data_loader, model=None, target_device=None):
    """Run a feature extractor over a data loader and gather features and labels.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` tensor batches that also
            supports ``len()`` (used for the progress display).
        model: Module applied to each input batch. Defaults to the module-level
            ``feature_extractor``. It must already be on ``target_device``.
        target_device: Device each input batch is moved to before the forward
            pass. Defaults to the module-level ``device``.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays. Each is the per-batch
        output concatenated along axis 0. For the VGG16 extractor in this
        script, ``features`` is ``(num_samples, 512)`` and ``labels`` is
        ``(num_samples,)``. An empty loader yields ``np.empty((0, 0), float32)``
        features and ``np.empty((0,), int64)`` labels instead of an error from
        ``np.concatenate``.
    """
    if model is None:
        model = feature_extractor
    if target_device is None:
        target_device = device

    feat_batches = []
    label_batches = []
    model.eval()  # Ensure eval mode

    total_batches = len(data_loader)  # Total number of batches
    with torch.no_grad():  # Inference only: skip building the autograd graph
        for batch_idx, (inputs, labels) in enumerate(data_loader):
            inputs = inputs.to(target_device)
            feats = model(inputs)
            feat_batches.append(feats.cpu().numpy())
            label_batches.append(labels.numpy())

            # "\r" rewrites the same terminal line; flush so each update shows
            # immediately rather than waiting for a newline to flush stdout.
            print(f"\rProcessing batch {batch_idx + 1}/{total_batches}", end="", flush=True)

    print()  # Newline after progress
    if not feat_batches:
        # np.concatenate([]) raises; the feature width is unknown without a batch.
        return np.empty((0, 0), dtype=np.float32), np.empty((0,), dtype=np.int64)
    all_feats = np.concatenate(feat_batches, axis=0)     # [num_samples, feat_dim]
    all_labels = np.concatenate(label_batches, axis=0)   # [num_samples]
    return all_feats, all_labels


# Run extraction only when executed as a script. The DataLoaders use
# num_workers=4. With the "spawn" start method (the default on Windows and
# macOS), each worker re-imports this module. Without this guard, every worker
# would try to re-run the extraction and fail during bootstrapping.
if __name__ == "__main__":
    # 5. Extract features for training and test sets
    train_feats, train_labels = extract_features(train_loader)
    test_feats, test_labels = extract_features(test_loader)

    # 6. Save to npz files. The "images" key holds the extracted feature
    # vectors, not raw pixels; the key name is kept for existing readers.
    np.savez("cifar_train.npz", images=train_feats, labels=train_labels)
    np.savez("cifar_test.npz",  images=test_feats,  labels=test_labels)

    print("Feature extraction complete!")
    print(f"Train feature shape: {train_feats.shape}, Test feature shape: {test_feats.shape}")


