# data.py
import os
import pickle
import random
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool

def load_tcga_graphs(graphs_path: str):
    """Deserialize the pickled TCGA graph collection stored at ``graphs_path``.

    Prints how many graphs were read and returns the unpickled object as-is.

    Security note: ``pickle.load`` can execute arbitrary code embedded in the
    file, so only call this on files you produced or otherwise trust.
    """
    with open(graphs_path, mode='rb') as handle:
        loaded = pickle.load(handle)
    print(f"Loaded {len(loaded)} graphs from {graphs_path}.")
    return loaded

def split_graphs(graphs, train_ratio: float = 0.8, seed=None):
    """Shuffle graph data and split into train/test sets.

    Args:
        graphs: Sequence (or any iterable) of graph objects. It is NOT
            modified; a shuffled copy is split instead.
        train_ratio: Fraction of graphs placed in the train split, in [0, 1].
            The train size is ``int(train_ratio * len(graphs))`` (rounded down).
        seed: Optional seed for a private ``random.Random`` instance, making the
            split reproducible without touching global RNG state. When None,
            the module-level ``random`` generator is used, so a prior
            ``random.seed(...)`` still controls the split.

    Returns:
        Tuple ``(train_graphs, test_graphs)`` of lists.

    Raises:
        ValueError: If ``train_ratio`` lies outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    # Shuffle a copy so the caller's list is left untouched (the previous
    # in-place shuffle silently reordered the caller's data).
    shuffled = list(graphs)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)
    num_train = int(train_ratio * len(shuffled))
    train_graphs = shuffled[:num_train]
    test_graphs = shuffled[num_train:]
    print(f"Train graphs: {len(train_graphs)}, Test graphs: {len(test_graphs)}")
    return train_graphs, test_graphs

def get_dataloaders(train_graphs, test_graphs, batch_size: int):
    """Wrap the train and test graph collections in PyG DataLoaders.

    The train loader is built with ``shuffle=True`` and the test loader with
    ``shuffle=False``, so evaluation iterates in a fixed order.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    split_specs = ((train_graphs, True), (test_graphs, False))
    return tuple(
        DataLoader(split, batch_size=batch_size, shuffle=is_train)
        for split, is_train in split_specs
    )

def apply_virtual_knockdown(graph: Data, knockdown_gene_idx: int) -> Data:
    """
    Perform virtual knockdown by setting features of specified gene to 0 and removing connected edges

    Args:
        graph: Input graph. It is not modified: ``x`` is cloned and the
            surviving edges are gathered into new tensors.
        knockdown_gene_idx: Node index of the gene to knock down. Negative
            values count from the end, as in Python indexing.

    Returns:
        A new ``Data`` holding only ``x``, ``edge_index`` and ``edge_attr``
        (a 1-D ``edge_attr`` is reshaped to ``(num_edges, 1)``). Other
        attributes of ``graph`` (e.g. ``y``) are not carried over.

    Raises:
        IndexError: If ``knockdown_gene_idx`` is outside the node range.
    """
    num_nodes = graph.x.size(0)
    idx = int(knockdown_gene_idx)
    # Normalize negative indices. x[-1] would zero the last node, but the edge
    # mask below compares against edge_index, which holds non-negative node
    # ids, so without this step that node's edges would silently survive.
    if idx < 0:
        idx += num_nodes
    if not 0 <= idx < num_nodes:
        raise IndexError(
            f"knockdown_gene_idx {knockdown_gene_idx} is out of range "
            f"for a graph with {num_nodes} nodes"
        )
    x = graph.x.clone()
    x[idx] = 0
    # Keep only edges where neither endpoint is the knocked-down node.
    mask = (graph.edge_index[0] != idx) & (graph.edge_index[1] != idx)
    edge_index = graph.edge_index[:, mask]
    edge_attr = graph.edge_attr[mask] if graph.edge_attr is not None else None
    if edge_attr is not None and edge_attr.dim() == 1:
        edge_attr = edge_attr.unsqueeze(-1)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

def to_device(data: Data, device: torch.device) -> Data:
    """Transfer tensors in Data object to specified device.

    The object is modified in place and also returned for convenience. Only
    ``x``, ``edge_index``, ``edge_attr`` and ``batch`` are moved; any other
    attributes are left where they are. A 1-D ``edge_attr`` is reshaped to
    ``(num_edges, 1)``, matching the form ``apply_virtual_knockdown`` produces.

    Args:
        data: Graph (or batched graph) whose tensors should be moved.
        device: Target device.

    Returns:
        The same ``data`` object.
    """
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)
    if data.edge_attr is not None:
        data.edge_attr = data.edge_attr.to(device)
        if data.edge_attr.dim() == 1:
            data.edge_attr = data.edge_attr.unsqueeze(-1)
    # Recent PyG releases expose `batch` as a property that returns None on an
    # un-batched Data, so hasattr() is True there and `.to()` would fail on
    # None. Check the value instead; getattr's default also covers older
    # releases where a missing attribute raises AttributeError.
    batch = getattr(data, 'batch', None)
    if batch is not None:
        data.batch = batch.to(device)
    return data

def get_graph_embedding(node_embeddings, batch):
    """Average node embeddings per graph, yielding one vector per graph.

    Args:
        node_embeddings: Per-node feature tensor, presumably
            ``(num_nodes, dim)`` (TODO confirm against callers).
        batch: Graph-assignment vector, forwarded unchanged to
            ``global_mean_pool``.

    Returns:
        The result of ``torch_geometric.nn.global_mean_pool``.
    """
    pooled = global_mean_pool(node_embeddings, batch)
    return pooled
