# data.py
import os
import pickle
import random
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

def load_tcga_graphs(graphs_path: str):
    """Unpickle and return the object stored at *graphs_path*.

    The object is expected to be a sized collection of graphs (``len`` is
    called on it); a short summary line with its length is printed.

    Security note: unpickling can execute arbitrary code, so only call this
    on files that come from a trusted source.
    """
    with open(graphs_path, "rb") as handle:
        loaded = pickle.load(handle)
    count = len(loaded)
    print(f"Loaded {count} graphs from {graphs_path}.")
    return loaded

def split_graphs(graphs, train_ratio: float = 0.8, seed=None):
    """Randomly split *graphs* into a train part and a test part.

    The input is copied before shuffling, so the caller's sequence is left
    untouched (any sequence works, not just a list).

    Args:
        graphs: Sequence of graphs to split.
        train_ratio: Fraction in [0, 1] of items assigned to the train split;
            the train size is ``int(train_ratio * len(graphs))``.
        seed: Optional integer seed for a reproducible split. When ``None``
            the module-level ``random`` state is used, so a prior
            ``random.seed(...)`` still controls the outcome.

    Returns:
        A ``(train, test)`` tuple of lists.

    Raises:
        ValueError: If *train_ratio* lies outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    # Shuffle a copy: shuffling the argument in place silently reordered the caller's data.
    shuffled = list(graphs)
    if seed is None:
        random.shuffle(shuffled)
    else:
        # Dedicated generator so a seeded split does not disturb global RNG state.
        random.Random(seed).shuffle(shuffled)
    n_train = int(train_ratio * len(shuffled))
    return shuffled[:n_train], shuffled[n_train:]

def get_dataloaders(train_graphs, test_graphs, batch_size: int):
    """Wrap the train and test graph collections in PyG ``DataLoader``s.

    Both loaders use the same *batch_size*; only the train loader shuffles.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    settings = ((train_graphs, True), (test_graphs, False))
    return tuple(
        DataLoader(graphs, batch_size=batch_size, shuffle=do_shuffle)
        for graphs, do_shuffle in settings
    )

def ensure_edge_attr(graph: Data) -> Data:
    """Make sure *graph* carries a 2-D ``edge_attr``.

    If ``edge_attr`` is missing or ``None``, it is filled with float32 ones of
    shape ``(num_edges, 1)`` on the same device as ``edge_index``. If it is
    1-D, a trailing axis is added. The graph is modified in place and returned.
    """
    num_edges = graph.edge_index.size(1)
    current = getattr(graph, "edge_attr", None)
    if current is None:
        graph.edge_attr = torch.ones(
            (num_edges, 1),
            dtype=torch.float32,
            device=graph.edge_index.device,
        )
        return graph
    if graph.edge_attr.dim() == 1:
        graph.edge_attr = graph.edge_attr.unsqueeze(-1)
    return graph

def to_device(data: Data, device: torch.device) -> Data:
    """Move *data* to *device* and make a 1-D ``edge_attr`` 2-D.

    ``Data.to`` moves every tensor attribute of the graph/batch and returns
    the same object. The previous hand-picked list (``x``, ``edge_index``,
    ``edge_attr``, ``batch``) left any other tensor attribute, such as labels
    in ``y`` if present, on the original device, and it crashed when ``x``
    was ``None``.

    Returns:
        The same *data* object, now on *device*.
    """
    data = data.to(device)
    edge_attr = getattr(data, "edge_attr", None)
    if edge_attr is not None and edge_attr.dim() == 1:
        # Downstream layers are fed (num_edges, feat) edge features.
        data.edge_attr = edge_attr.unsqueeze(-1)
    return data
