import torch
from torch_geometric.datasets import Planetoid
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Device configuration: use CUDA when it is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset loading: Planetoid 'Cora' rooted at ./Cora; its first graph is moved to `device`
dataset = Planetoid(root='./Cora', name='Cora')
data = dataset[0].to(device)

# Load pre-saved data.x
# map_location remaps the stored tensor onto `device`, so a file that was
# saved from a GPU still loads when CUDA is not available on this machine.
# NOTE: torch.load unpickles the file; only load files from a trusted source.
data_x = torch.load('data_x.pt', map_location=device)

# Helpers for the latent-space visualization
def _as_numpy(tensor):
    """Return *tensor* as a NumPy array on the CPU, detached from autograd."""
    # detach() first: Tensor.numpy() raises on tensors that require grad.
    return tensor.detach().cpu().numpy()


def _draw_class_scatter(fig, ax, embedding, labels, title):
    """Draw one 2-D scatter panel coloured by class, with its own colorbar."""
    points = ax.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='viridis', s=5)
    fig.colorbar(points, ax=ax, label='Class')
    ax.set_title(title, fontsize=18)
    ax.set_xlabel('Component 1', fontsize=15)
    ax.set_ylabel('Component 2', fontsize=15)


# Visualize the latent space
def visualize_latent_space(data, labels, method, save_path=None, *, loaded_features=None):
    """Plot 2-D t-SNE embeddings of the original and the loaded features side by side.

    The left panel embeds ``data.x``; the right panel embeds the loaded
    features. Both panels are coloured by ``labels``. Each embedding is fit
    independently (random seeds 42 and 43 respectively).

    Args:
        data: Object exposing the original feature tensor as ``data.x``.
        labels: Tensor with one class label per sample, used to colour points.
        method: Name shown (upper-cased) in the panel titles. It only affects
            the titles; the reduction itself is always t-SNE.
        save_path: If given, the figure is written to this path and then
            closed; otherwise it is shown with ``plt.show()``.
        loaded_features: Feature tensor for the right panel. When omitted,
            the module-level ``data_x`` is used.

    Raises:
        ValueError: If either feature tensor does not have exactly one row
            per label.
    """
    if loaded_features is None:
        # Backward-compatible default: fall back to the module-level tensor.
        loaded_features = data_x

    # Fail fast on mismatched sample counts, before the expensive t-SNE fits;
    # otherwise the error only surfaces later when the points are coloured.
    n_labels = labels.shape[0]
    for name, features in (('data.x', data.x), ('loaded features', loaded_features)):
        if features.shape[0] != n_labels:
            raise ValueError(
                f'{name} has {features.shape[0]} rows but {n_labels} labels were given'
            )

    # Convert data and labels to numpy arrays
    original_np = _as_numpy(data.x)
    loaded_np = _as_numpy(loaded_features)
    labels_np = _as_numpy(labels)

    # Use TSNE algorithm for dimensionality reduction
    raw_embedding = TSNE(n_components=2, random_state=42).fit_transform(original_np)
    loaded_embedding = TSNE(n_components=2, random_state=43).fit_transform(loaded_np)

    # Visualize original features and loaded features
    fig, (raw_ax, loaded_ax) = plt.subplots(1, 2, figsize=(20, 8))
    _draw_class_scatter(
        fig, raw_ax, raw_embedding, labels_np,
        f'Raw Data Visualization using {method.upper()}',
    )
    _draw_class_scatter(
        fig, loaded_ax, loaded_embedding, labels_np,
        f'Data Augmentation Space Visualization using {method.upper()}',
    )

    if save_path is not None:
        fig.savefig(save_path)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

# Run the visualization with the dataset's labels and save the figure to disk
visualize_latent_space(
    data=data,
    labels=data.y,
    method='t-sne',
    save_path='latent_space_visualization.png',
)