import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet101, resnet152
from tqdm import tqdm

# Step 1: Load and preprocess the CIFAR-100 dataset.
# Per-channel (R, G, B) normalisation statistics, named as the CIFAR-100
# training-set mean and standard deviation.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Convert each image to a float tensor, then standardise every channel with
# the statistics above.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=CIFAR100_TRAIN_MEAN, std=CIFAR100_TRAIN_STD),
    ]
)

# Build the train split first, then the test split. Both live under ./data
# (downloaded if missing) and share the same preprocessing.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR100(
        root='./data', train=split_is_train, download=True, transform=transform
    )
    for split_is_train in (True, False)
)

# shuffle=False keeps batches in dataset order, so the position of a sample in
# the extracted outputs lines up with its index in the dataset.
train_loader, test_loader = (
    DataLoader(split, batch_size=256, shuffle=False, num_workers=4)
    for split in (train_dataset, test_dataset)
)

# Step 2: Load a ResNet-101 with pretrained weights.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=ResNet101_Weights.IMAGENET1K_V1`; it is kept here so the script
# still runs on older torchvision releases.
# NOTE(review): these are ImageNet-trained weights, while Step 1 normalises the
# CIFAR-100 images with CIFAR-100 statistics -- confirm that is intended.
model = resnet101(pretrained=True)

# Step 3: Turn the classifier into a feature extractor. The last child module
# (the final `fc` classification layer) is dropped and the remaining output is
# flattened, so the model returns one feature vector per image (the output of
# the layer preceding `fc`) instead of class logits.
model = torch.nn.Sequential(*list(model.children())[:-1], torch.nn.Flatten())

# Switch to inference mode *after* wrapping, so the Sequential wrapper itself
# (and not only the layers it reuses) reports eval mode via `model.training`.
model.eval()

# Step 4: Iterate over each dataset and collect features, sample indices, and labels
def extract(train_dataset, train_loader, model):
    """Run ``model`` over every batch of ``train_loader`` and gather the results.

    Used for both the train and the test split, despite the parameter names.

    Args:
        train_dataset: The dataset behind ``train_loader``. Not read by this
            function; kept so existing call sites keep working.
        train_loader: Iterable of ``(inputs, targets)`` batches (a DataLoader
            in this script).
        model: Module applied to each input batch. It is moved to CUDA when
            available and switched to eval mode (both affect the passed
            module in place).

    Returns:
        A ``(features, image_indices, labels)`` tuple:

        * ``features``: model outputs for all batches, concatenated along
          dim 0, on the CPU.
        * ``image_indices``: int64 tensor ``0 .. N-1`` with one entry per
          sample, giving its position in iteration order. This equals the
          dataset index when the loader does not shuffle.
        * ``labels``: the loader's targets, concatenated along dim 0.

    ``torch.cat`` raises if the loader yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Inference only: make layers such as BatchNorm/Dropout use eval behaviour
    # regardless of the mode the caller left the model in.
    model.eval()

    features = []
    image_indices = []
    labels = []
    # Position of the first sample of the current batch. Counting actual batch
    # sizes keeps indices correct for a short final batch.
    next_index = 0

    with torch.no_grad():
        # Wrapping the loader (not enumerate(loader)) lets tqdm show a total.
        for inputs, targets in tqdm(train_loader):
            batch_size = inputs.size(0)

            # Extract the features and keep them on the CPU.
            outputs = model(inputs.to(device))
            features.append(outputs.cpu())

            # Exactly one index per sample in this batch. torch.range would
            # include the end value and yield floats.
            image_indices.append(torch.arange(next_index, next_index + batch_size))
            labels.append(targets)
            next_index += batch_size

    return torch.cat(features, 0), torch.cat(image_indices, 0), torch.cat(labels, 0)

# Run the extractor over both splits. Each call returns
# (features, image indices, labels) for one split.
# NOTE(review): this runs at import time and the loaders use num_workers=4. On
# platforms that start DataLoader workers with "spawn" (Windows, and macOS by
# default) the script needs an `if __name__ == "__main__":` guard -- confirm
# the target platform.
train_features, train_image_idices, train_labels = extract(
    train_dataset, train_loader, model
)
test_features, test_image_idices, test_labels = extract(
    test_dataset, test_loader, model
)

print("Feature extraction completed.")

# Bundle every extracted tensor into a single dict and write it with torch.save.
save_dict = {
    'train_features': train_features,
    'train_image_idices': train_image_idices,
    'train_labels': train_labels,
    'test_features': test_features,
    'test_image_idices': test_image_idices,
    'test_labels': test_labels,
}
torch.save(save_dict, "cifar100-resnet101-features.pt")

# Step 5: Read the saved features, indices, and labels back from disk.
save_dict = torch.load("cifar100-resnet101-features.pt")
train_features, train_image_idices, train_labels = (
    save_dict['train_features'],
    save_dict['train_image_idices'],
    save_dict['train_labels'],
)
test_features, test_image_idices, test_labels = (
    save_dict['test_features'],
    save_dict['test_image_idices'],
    save_dict['test_labels'],
)
