# data.py
import os
import pickle
import random
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj

# --- Data loading & preprocessing ---
def load_tcga_graphs(pkl_path: str):
    """Load TCGA graphs from a pickle file and attach a normalized dense adjacency.

    For every graph, compute the symmetrically normalized adjacency
    ``D^{-1/2} A D^{-1/2}``, where ``A`` is the dense adjacency built from
    ``edge_index`` (duplicate edges are counted) and ``D`` holds the row sums
    of ``A``. Store it as ``data.dense_adj`` with shape ``[N, N]``, where
    ``N = data.x.size(0)``. A 1-D ``edge_attr`` is reshaped to ``[E, 1]``.
    Graphs are modified in place, and the loaded collection is returned.

    NOTE(review): ``dense_adj`` is a square per-graph tensor. The default PyG
    collate concatenates along dim 0, so batching presumably requires all
    graphs to share the same N. Confirm this against the dataset.

    Args:
        pkl_path: Path to a pickle containing PyG ``Data`` objects. It must
            support ``len`` and iteration; presumably it is a list.

    Returns:
        The unpickled graph collection, with ``dense_adj`` set on each graph.
    """
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(pkl_path, 'rb') as f:
        graphs = pickle.load(f)
    print(f"Loaded {len(graphs)} graphs from {pkl_path}.")

    for data in graphs:
        num_nodes = data.x.size(0)
        # An all-zero batch vector sized by x fixes the dense matrix at [N, N].
        # This holds even when trailing nodes are isolated; without a batch,
        # PyG sizes the matrix from edge_index.
        # Allocate it on edge_index's device so the two always match.
        batch = torch.zeros(num_nodes, dtype=torch.long, device=data.edge_index.device)
        dense = to_dense_adj(data.edge_index, batch=batch)[0]  # [N, N]
        deg = dense.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree (isolated) nodes give inf; zero them so they contribute nothing.
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
        # Scaling rows and columns by broadcasting equals diag(d) @ A @ diag(d).
        # It costs O(N^2) instead of two O(N^3) dense matmuls.
        data.dense_adj = deg_inv_sqrt.unsqueeze(1) * dense * deg_inv_sqrt.unsqueeze(0)
        # Give a 1-D edge_attr ([E]) a trailing feature dimension ([E, 1]).
        if getattr(data, 'edge_attr', None) is not None and data.edge_attr.dim() == 1:
            data.edge_attr = data.edge_attr.unsqueeze(-1)
    return graphs


def split_graphs(graphs, train_ratio: float = 0.8, seed: "int | None" = None):
    """Randomly split a collection of graphs into train and test lists.

    The input is copied before shuffling, so the caller's list is left
    untouched.

    Args:
        graphs: Iterable of graphs to split.
        train_ratio: Fraction of graphs assigned to the training split, in
            ``[0, 1]``. The train size is ``int(train_ratio * len(graphs))``.
        seed: Optional seed for a private ``random.Random``, which makes the
            split reproducible. When ``None``, the module-level ``random``
            generator is used, so an earlier ``random.seed()`` still applies.

    Returns:
        ``(train_graphs, test_graphs)`` as two lists.

    Raises:
        ValueError: If ``train_ratio`` is not within ``[0, 1]`` (including NaN).
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    shuffled = list(graphs)  # copy: never reorder the caller's list
    if seed is None:
        random.shuffle(shuffled)
    else:
        random.Random(seed).shuffle(shuffled)

    num_train = int(train_ratio * len(shuffled))
    train_graphs = shuffled[:num_train]
    test_graphs = shuffled[num_train:]
    print(f"Train graphs: {len(train_graphs)}, Test graphs: {len(test_graphs)}")
    return train_graphs, test_graphs


def get_dataloaders(train_graphs, test_graphs, batch_size: int):
    """Wrap the train/test graph lists in PyG DataLoaders.

    Both loaders use the same ``batch_size``. The training loader is created
    with ``shuffle=True``; the test loader uses ``shuffle=False``, so
    evaluation order is fixed.

    Returns:
        ``(train_loader, test_loader)``.
    """
    from torch_geometric.loader import DataLoader as _GraphLoader

    def _build(graph_list, shuffle_flag):
        return _GraphLoader(graph_list, batch_size=batch_size, shuffle=shuffle_flag)

    return _build(train_graphs, True), _build(test_graphs, False)


def to_device(data: Data, device: torch.device) -> Data:
    """Move a PyG ``Data``/``Batch`` to ``device`` and make ``edge_attr`` 2-D.

    ``data.to(device)`` relocates the tensor attributes. Afterwards, a 1-D
    ``edge_attr`` (shape ``[E]``) is reshaped to ``[E, 1]`` so it always
    carries a trailing feature dimension.

    Returns:
        The object produced by ``data.to(device)``.
    """
    moved = data.to(device)
    edge_features = getattr(moved, 'edge_attr', None)
    if edge_features is not None and edge_features.dim() == 1:
        moved.edge_attr = edge_features.unsqueeze(-1)  # [E] -> [E, 1]
    return moved