"""
Extracts CIFAR-10 embeddings using a pre-trained ResNet-50 and saves the dataset to a MAT file.
"""

import os
import sys
import time
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import numpy as np
from scipy.io import savemat

# Run on the first CUDA GPU if PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def _load_resnet50_feature_extractor():
    """Return ImageNet-pretrained ResNet-50 minus its final fc layer, in eval mode on `device`."""
    try:
        # torchvision >= 0.13 deprecates `pretrained=True` in favour of weight enums;
        # IMAGENET1K_V1 is exactly what the legacy flag loads.
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    except AttributeError:
        # Older torchvision has no weight enums, so use the legacy flag.
        backbone = models.resnet50(pretrained=True)
    # Dropping the last child (the fc classifier) leaves the global average pool as the
    # final layer, so the output per batch is (batch, 2048, 1, 1).
    feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
    feature_extractor.eval()
    return feature_extractor.to(device)


def _extract_features(feature_extractor, loader):
    """Run `feature_extractor` over every batch of `loader`; return (features, labels) on the CPU."""
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            features = feature_extractor(inputs.to(device))
            # (batch, 2048, 1, 1) -> (batch, 2048)
            feature_batches.append(torch.flatten(features, 1).cpu())
            label_batches.append(targets)
    return torch.cat(feature_batches, dim=0), torch.cat(label_batches, dim=0)


def create_resnet50_embedded_cifar10_dataset(train_cache_path="cifar10_train_embeddings.pt",
                                             test_cache_path="cifar10_test_embeddings.pt",
                                             data_root='./data',
                                             batch_size=128,
                                             num_workers=2):
    """
    Load CIFAR-10 and embed it with an ImageNet-pretrained ResNet-50.

    Each image is passed through ResNet-50 with its final fully-connected layer removed,
    giving the 2048-dimensional pooled feature vector. Features and labels are cached to
    disk with ``torch.save`` so subsequent runs skip the forward passes.

    Args:
        train_cache_path: File holding the cached ``(X_train, y_train)`` tuple.
        test_cache_path: File holding the cached ``(X_test, y_test)`` tuple.
        data_root: Directory where torchvision stores (and downloads, if absent) CIFAR-10.
        batch_size: Number of images per forward pass during extraction.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. Freshly computed features have shape
        ``(N, 2048)`` and labels shape ``(N,)``, all on the CPU.

    The cache is only used when BOTH files exist; otherwise both splits are recomputed
    and both cache files are overwritten.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing ResNet-50 embeddings for CIFAR-10...")

    # CIFAR-10 images are 32x32: upsample to 224x224 and normalise with the ImageNet
    # channel statistics the pretrained weights were trained with.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)

    # shuffle=False keeps the embeddings aligned with the datasets' original order.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    feature_extractor = _load_resnet50_feature_extractor()
    X_train, y_train = _extract_features(feature_extractor, train_loader)
    X_test, y_test = _extract_features(feature_extractor, test_loader)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")

    return X_train, y_train, X_test, y_test

if __name__ == "__main__":
    # Obtain the embeddings (computed once, then served from the on-disk cache).
    splits = create_resnet50_embedded_cifar10_dataset()
    mat_keys = ("x_train", "y_train", "x_test", "y_test")

    # scipy.io.savemat expects NumPy arrays, so bring every tensor to the CPU and convert.
    mat_contents = {key: tensor.cpu().numpy() for key, tensor in zip(mat_keys, splits)}

    out_path = "BayesianNN_CIFAR10_data.mat"
    savemat(out_path, mat_contents)
    print(f"CIFAR-10 data saved to {out_path}")
