# data.py
import os
import pickle
import random
import torch
from torch_geometric.data import Data, DataLoader

def load_tcga_graphs(pkl_path: str):
    """Read the pickled graph collection stored at ``pkl_path``.

    The file is opened in binary mode and unpickled in one go; the number of
    loaded graphs is printed before the unpickled object is returned as-is.

    NOTE: unpickling executes arbitrary code embedded in the file, so only
    point this at files from a trusted source.
    """
    with open(pkl_path, 'rb') as handle:
        # Same operation as pickle.load(handle), spelled via the Unpickler.
        unpickler = pickle.Unpickler(handle)
        graphs = unpickler.load()

    print(f"Loaded {len(graphs)} graphs from {pkl_path}.")
    return graphs

def split_graphs(graphs, train_ratio: float = 0.8, *, seed=None):
    """Split graph data into shuffled train/test sets.

    Args:
        graphs: Sized, integer-indexable collection of graphs.
        train_ratio: Fraction of graphs assigned to the training set; must
            lie in ``[0, 1]``. The training size is
            ``int(len(graphs) * train_ratio)`` (truncated) and every
            remaining graph goes to the test set.
        seed: Optional seed for a private random generator, making the split
            reproducible without touching global state. When ``None``
            (the default) the module-level ``random`` generator is used.

    Returns:
        A ``(train_graphs, test_graphs)`` tuple of lists.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    # A negative ratio would yield a negative slice index and silently put
    # the wrong graphs in each split, so reject out-of-range values up front.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(
            f"train_ratio must be between 0 and 1, got {train_ratio!r}"
        )

    num_total = len(graphs)
    indices = list(range(num_total))
    if seed is None:
        # Keep honoring any random.seed(...) the caller set globally.
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)

    train_size = int(num_total * train_ratio)
    train_graphs = [graphs[i] for i in indices[:train_size]]
    test_graphs = [graphs[i] for i in indices[train_size:]]
    print(f"Train graphs: {len(train_graphs)}, Test graphs: {len(test_graphs)}")
    return train_graphs, test_graphs

def get_dataloaders(train_graphs, test_graphs, batch_size: int):
    """Build the train and test DataLoaders for the given graph lists.

    Both loaders use the same ``batch_size``. The training loader is created
    with ``shuffle=True`` and the test loader with ``shuffle=False``.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    shared_kwargs = {"batch_size": batch_size}
    return (
        DataLoader(train_graphs, shuffle=True, **shared_kwargs),
        DataLoader(test_graphs, shuffle=False, **shared_kwargs),
    )

def to_device(data: Data, device: torch.device) -> Data:
    """Transfer every tensor in a Data object to the specified device.

    Uses the object's own ``to`` method so that all stored tensor attributes
    are moved together, not just ``x``, ``edge_index`` and ``edge_attr``.
    Moving only those three would leave any other tensors behind on the
    original device and would fail outright when ``x`` is ``None``.

    After the transfer, a 1-D ``edge_attr`` gets a trailing dimension added
    so it is always at least 2-D.

    Args:
        data: Graph (or batch of graphs) to move.
        device: Target device.

    Returns:
        The object returned by ``data.to(device)``, with ``edge_attr``
        normalized as described above.
    """
    data = data.to(device)

    edge_attr = data.edge_attr
    if edge_attr is not None and edge_attr.dim() == 1:
        data.edge_attr = edge_attr.unsqueeze(-1)
    return data
