import sys  
sys.path.append("/workspace")

import pandas as pd
import torch  
from torch.utils.data import random_split  
from torch_geometric.loader import DataLoader  
from dataloaders.common import generate_full_path, seed_worker  

def load_metadata_and_embeddings(load_path, cod_basepath="/cif"):
    """Read the metadata CSV and attach a resolved CIF path to every row.

    Despite the function name, only the metadata CSV is read here; no
    embeddings are loaded.

    Rows whose ``title`` is missing are discarded (the original index labels
    are kept, not reset). Each value of the ``file`` column is converted to
    ``str`` and passed to ``generate_full_path`` with
    ``base_path=cod_basepath``; the results are stored in a new ``cif_path``
    column.

    Args:
        load_path: Source accepted by ``pandas.read_csv``.
        cod_basepath: Base directory forwarded to ``generate_full_path``.

    Returns:
        The filtered DataFrame with the added ``cif_path`` column.
    """
    def _to_cif_path(file_id):
        # Resolve a single ``file`` entry against the configured base path.
        return generate_full_path(file_id, base_path=cod_basepath)

    frame = pd.read_csv(load_path).dropna(subset=['title'])
    frame['cif_path'] = frame['file'].astype(str).apply(_to_cif_path)
    return frame
  

def prepare_data_loaders(batch_size, dataset, *, num_workers=8, seed=42,
                         train_frac=0.8, val_frac=0.1, pin_memory=True):
    """Split ``dataset`` into train/val/test subsets and wrap each in a DataLoader.

    The split is drawn with a ``torch.Generator`` seeded by ``seed``, so the
    same dataset and seed always yield the same partition. Subset sizes are
    ``int(train_frac * n)`` and ``int(val_frac * n)``; the test subset gets the
    remainder, so every sample lands in exactly one subset. With the default
    arguments the behavior matches the original 80/10/10 split, seed 42 and
    8 persistent workers per loader.

    Args:
        batch_size: Samples per batch for all three loaders.
        dataset: Dataset accepted by ``torch.utils.data.random_split``
            (must support ``len()``).
        num_workers: Worker processes per loader. ``persistent_workers`` is
            enabled only when this is positive, because PyTorch rejects
            ``persistent_workers=True`` with zero workers.
        seed: Seed for the random train/val/test partition.
        train_frac: Fraction of samples assigned to the training split.
        val_frac: Fraction of samples assigned to the validation split.
        pin_memory: Forwarded to every DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader shuffles
        and drops the last incomplete batch; the val/test loaders keep order
        and yield every sample.

    Raises:
        ValueError: If a fraction is negative, the fractions sum to more than
            1, or the training split would be empty.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            f"invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )

    dataset_size = len(dataset)
    train_size = int(train_frac * dataset_size)
    val_size = int(val_frac * dataset_size)
    # The test split absorbs rounding so the three sizes always sum to the total.
    test_size = dataset_size - train_size - val_size

    if train_size == 0:
        # shuffle=True uses a RandomSampler, which cannot sample an empty set;
        # fail early with a message that names the real cause.
        raise ValueError(
            f"dataset of size {dataset_size} yields an empty training split "
            f"(train_frac={train_frac})"
        )

    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )

    # Settings shared by all three loaders.
    # NOTE(review): no ``generator`` is passed to the loaders, so the training
    # shuffle order comes from the global torch RNG and is reproducible only if
    # the caller seeds it.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,
        worker_init_fn=seed_worker,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **loader_kwargs)

    return train_loader, val_loader, test_loader