from pathlib import Path

import clip
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm  # For progress bar

# CLIP image-normalization statistics, shared by every transform below.
_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Directory the extracted feature files are written to.
_FEATURES_DIR = Path("./features")


def _extract_features(model, data_loader, device, desc="Extracting features"):
    """Encode every image in ``data_loader`` with ``model`` and L2-normalize.

    Args:
        model: A CLIP model exposing ``encode_image``.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the image batches are moved to before encoding.
        desc: Label shown on the tqdm progress bar.

    Returns:
        ``(features, labels)`` as NumPy arrays concatenated over all batches,
        in the order the loader yielded them. Each feature row is float32
        with unit L2 norm.

    Raises:
        ValueError: from ``np.concatenate`` if the loader yields no batches.
    """
    features, labels = [], []
    with torch.no_grad():
        for images, y in tqdm(data_loader, desc=desc):
            # Cast to float32 before normalizing: on GPU the CLIP weights may
            # be half precision, which would otherwise make the saved dtype
            # (and the precision of the norm) depend on the device.
            batch_features = model.encode_image(images.to(device)).float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
            features.append(batch_features.cpu().numpy())
            labels.append(y.numpy())
    return np.concatenate(features, axis=0), np.concatenate(labels)


def _save_features(features, labels, filename):
    """Save ``(features, labels)`` with ``torch.save`` into ``_FEATURES_DIR``.

    The directory is created if missing; the original script crashed with
    FileNotFoundError when ``./features`` did not already exist.

    Returns:
        The path the tuple was written to.
    """
    _FEATURES_DIR.mkdir(parents=True, exist_ok=True)
    out_path = _FEATURES_DIR / filename
    torch.save((features, labels), out_path)
    return out_path


def main(*, include_mnist=False):
    """Extract CLIP ViT-B/32 image features for the CIFAR-10 training set.

    Features are L2-normalized and saved together with their labels, as a
    ``(features, labels)`` tuple of NumPy arrays, to
    ``./features/cifar10_features_clip.tar`` via ``torch.save``.

    Args:
        include_mnist: If True, first extract and save features for the MNIST
            and FashionMNIST training sets as well (``mnist_features_clip.tar``
            and ``fmnist_features_clip.tar`` in the same directory). Off by
            default, so those datasets are not downloaded/loaded unless needed
            (they were unused in the original script).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The preprocess pipeline returned by clip.load is not used; the
    # transforms below are built by hand instead.
    model, _ = clip.load("ViT-B/32", device)
    batch_size = 128  # adjust to fit available memory

    if include_mnist:
        # MNIST/FashionMNIST are single-channel; replicate to 3 channels
        # because the CLIP image encoder expects RGB input.
        gray_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
        ])
        for name, dataset_cls in (("mnist", datasets.MNIST), ("fmnist", datasets.FashionMNIST)):
            dataset = dataset_cls(root="./data", train=True, download=True, transform=gray_transform)
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
            print(f"Extracting {name.upper()} features...")
            features, labels = _extract_features(model, loader, device, desc=name)
            _save_features(features, labels, f"{name}_features_clip.tar")

    # CIFAR-10 pipeline. ToTensor precedes Resize as in the original, so the
    # resize runs on the tensor. Resize(224) scales the shorter edge to 224;
    # CIFAR-10 images are square, so the result is 224x224.
    cifar_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
    ])
    cifar10_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=cifar_transform)
    # shuffle=False: order does not matter for feature extraction, and a fixed
    # order keeps saved row i aligned with dataset index i across runs.
    cifar10_loader = DataLoader(cifar10_data, batch_size=batch_size, shuffle=False, num_workers=0)

    features, labels = _extract_features(model, cifar10_loader, device)
    _save_features(features, labels, "cifar10_features_clip.tar")

    print("Feature extraction complete.")


if __name__ == "__main__":
    main()