from scipy.io import savemat
from sentence_transformers import SentenceTransformer
from torchtext.datasets import IMDB
import torch
import os

# Suppress torchtext deprecation warnings.
# NOTE(review): `from torchtext.datasets import IMDB` above executes before
# this call. If torchtext emits the deprecation warning when its submodules
# are first imported, the call would have to run before that import to have
# any effect -- TODO confirm against the installed torchtext version.
import torchtext
torchtext.disable_torchtext_deprecation_warning()

def _embed_split(encoder, samples, label_map):
    """Embed one dataset split and return ``(features, labels)`` tensors.

    Args:
        encoder: Object exposing ``encode(list_of_texts, convert_to_numpy=True)``
            (here, a ``SentenceTransformer``).
        samples: Iterable of ``(label, text)`` pairs.
        label_map: Mapping from raw labels to class indices; labels missing
            from the map are passed through unchanged.

    Returns:
        A ``(features, labels)`` pair: a float32 tensor of embeddings and a
        long tensor of class indices, in the same order as ``samples``.
    """
    labels, texts = [], []
    for label, text in samples:
        labels.append(label_map.get(label, label))
        texts.append(text)

    # One list-based encode() call lets the encoder batch the texts itself
    # instead of running a separate forward pass per review. Results should
    # match per-text encoding up to floating-point noise.
    embeddings = encoder.encode(texts, convert_to_numpy=True)

    # Convert the returned array directly; building a tensor from a Python
    # list of numpy arrays is slow and triggers a PyTorch warning.
    features = torch.as_tensor(embeddings, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.long)
    return features, targets


def create_sbert_embedded_imdb_dataset(
    model_name="all-mpnet-base-v2",
    train_cache_path="imdb_embeddings_train.pt",
    test_cache_path="imdb_embeddings_test.pt"
):
    """
    Loads the IMDB dataset and uses SBERT to embed each review.
    If cached embeddings exist, they are loaded from disk.

    Args:
        model_name: Name of the SentenceTransformer model used for embedding.
        train_cache_path: File the train split ``(X, y)`` is cached in.
        test_cache_path: File the test split ``(X, y)`` is cached in.

    Note:
        The cache is looked up by path only, not by ``model_name``. If both
        cache files already exist they are returned as-is, even when they
        were produced with a different model. Use distinct paths per model.

    Returns:
        X_train, y_train, X_test, y_test as PyTorch tensors.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings from disk...")
        # NOTE: torch.load is pickle-based; only point these paths at cache
        # files this script wrote (or that you otherwise trust).
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing SBERT embeddings...")
    sbert = SentenceTransformer(model_name)
    sbert.eval()

    # Map labels to 0 (negative) / 1 (positive). Some torchtext versions use
    # the integers 1 and 2; string labels "neg"/"pos" are mapped as well.
    # Labels that are already 0/1 pass through unchanged.
    label_map = {1: 0, 2: 1, "neg": 0, "pos": 1}

    X_train, y_train = _embed_split(sbert, IMDB(split="train"), label_map)
    X_test, y_test = _embed_split(sbert, IMDB(split="test"), label_map)

    # Save embeddings for future runs
    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("SBERT embeddings computed and saved to disk.")

    return X_train, y_train, X_test, y_test

if __name__ == '__main__':
    # Cache locations for this run (edit here to point at different files).
    cache_files = {
        "train_cache_path": "imdb_embeddings_trainBig.pt",
        "test_cache_path": "imdb_embeddings_testBig.pt",
    }

    # Load the SBERT-embedded IMDB splits, computing them if no cache exists.
    # The result is ordered (X_train, y_train, X_test, y_test).
    split_tensors = create_sbert_embedded_imdb_dataset(**cache_files)

    # savemat expects numpy arrays, so move every tensor to host memory and
    # convert it. The key order mirrors the order of the returned tensors.
    mat_keys = ('x_train', 'y_train', 'x_test', 'y_test')
    mat_payload = {
        key: tensor.cpu().numpy()
        for key, tensor in zip(mat_keys, split_tensors)
    }

    # Write all four arrays into a single MATLAB .mat file.
    savemat('BayesianNN_IMDB_data.mat', mat_payload)
