
import numpy as np
import torch
from torch_geometric.loader import DataLoader, NeighborLoader

from data.sampling import collect_subgraphs, ego_graphs_sampler


def subsampling(data, config, sampler='rw'):
    """Sample one subgraph per node of each split and wrap each split in a DataLoader.

    Args:
        data: graph object exposing ``train_mask`` / ``val_mask`` / ``test_mask``
            tensors (presumably 1-D boolean, one entry per node — confirm
            against callers).
        config: object providing ``batch_size`` plus, depending on ``sampler``,
            ``walk_steps`` and ``restart`` (for ``'rw'``) or ``k`` (for ``'khop'``).
        sampler: ``'rw'`` samples with ``collect_subgraphs(walk_steps=...,
            restart_ratio=...)``; ``'khop'`` samples with
            ``ego_graphs_sampler(hop=config.k)``.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the train loader shuffles.

    Raises:
        ValueError: if ``sampler`` is neither ``'rw'`` nor ``'khop'``.
    """
    # Resolve the sampler first. Previously an unknown name fell through both
    # branches and later crashed with an UnboundLocalError on ``train_graphs``.
    if sampler == 'rw':
        def sample(idx):
            return collect_subgraphs(idx, data, walk_steps=config.walk_steps, restart_ratio=config.restart)
    elif sampler == 'khop':
        def sample(idx):
            return ego_graphs_sampler(idx, data, hop=config.k)
    else:
        raise ValueError(f"Unknown sampler {sampler!r}; expected 'rw' or 'khop'")

    # For a 1-D mask, nonzero() returns shape (N, 1). squeeze(-1) always gives
    # (N,). A bare squeeze() would collapse a single-node split into a 0-d tensor.
    train_idx = data.train_mask.nonzero().squeeze(-1)
    val_idx = data.val_mask.nonzero().squeeze(-1)
    test_idx = data.test_mask.nonzero().squeeze(-1)

    train_graphs = sample(train_idx)
    val_graphs = sample(val_idx)
    test_graphs = sample(test_idx)

    kwargs = {'batch_size': config.batch_size, 'num_workers': 6, 'persistent_workers': True}
    train_loader = DataLoader(train_graphs, shuffle=True, **kwargs)
    val_loader = DataLoader(val_graphs, **kwargs)
    test_loader = DataLoader(test_graphs, **kwargs)

    return train_loader, val_loader, test_loader

def _clean_data_fields(config, data):
    """Remove or convert attributes of ``data`` in place before it is handed to NeighborLoader.

    * deletes the ``train_masks`` / ``val_masks`` / ``test_masks`` attributes if present;
    * deletes the ``n_asin`` field;
    * converts every list-valued field to a tensor via ``torch.tensor``, deleting
      (with a printed warning) any field that cannot be converted;
    * for ``config.dataset == "ogbn-products"``, deletes ``edge_index`` and ``n_id``.
    """
    for key in ['train_masks', 'val_masks', 'test_masks']:
        if hasattr(data, key):
            delattr(data, key)
    # Snapshot the keys: the loop body adds/removes attributes on ``data``.
    for key in list(data.keys()):
        if key == "n_asin":
            delattr(data, key)
            continue
        val = data[key]
        if isinstance(val, list):
            try:
                data[key] = torch.tensor(val)
            except Exception:
                # Best effort: ragged or non-numeric lists cannot become a tensor.
                print(f"⚠️ Removing incompatible list-type field '{key}'")
                delattr(data, key)
    if config.dataset == "ogbn-products":
        # NOTE(review): the graph structure for this dataset presumably lives in
        # another attribute (e.g. a sparse adjacency) — confirm with the loader.
        delattr(data, "edge_index")
        delattr(data, "n_id")


def prepare_dataloader(config, data, split_idx):
    """Build a neighbor-sampling train loader and a full-neighborhood inference loader.

    ``data`` is cleaned in place first (see ``_clean_data_fields``).

    Args:
        config: object providing ``dataset``, ``layer_num`` (at most 4),
            ``batch_size`` and ``eval_batch_size``.
        data: graph object passed to ``NeighborLoader``; mutated in place.
        split_idx: mapping whose ``"train"`` entry gives the training seed nodes.

    Returns:
        ``(train_loader, subgraph_loader)``. ``train_loader`` shuffles the train
        seeds and samples ``num_neighbors[:config.layer_num]`` neighbors per hop.
        ``subgraph_loader`` iterates over all nodes (``input_nodes=None``) with
        every 1-hop neighbor (``num_neighbors=[-1]``).

    Raises:
        ValueError: if ``config.layer_num`` exceeds the number of configured fan-outs.
    """
    # Per-hop fan-out for training; only the first ``config.layer_num`` entries are used.
    num_neighbors = [15, 10, 5, 5]
    # Validate before touching ``data`` so that a bad config does not leave it
    # half-mutated. A raise is used instead of assert, which is stripped under -O.
    if config.layer_num > len(num_neighbors):
        raise ValueError(
            f"config.layer_num={config.layer_num} exceeds the {len(num_neighbors)} "
            f"supported hops (fan-outs {num_neighbors})"
        )

    _clean_data_fields(config, data)

    if hasattr(data, "edge_index") and hasattr(data.edge_index, "is_contiguous"):
        # .contiguous() returns the tensor itself when it is already contiguous.
        data.edge_index = data.edge_index.contiguous()

    # Sampling runs in the main process; persistent workers require num_workers > 0.
    num_workers = 0
    train_loader = NeighborLoader(
        data,
        input_nodes=split_idx["train"],
        num_neighbors=num_neighbors[: config.layer_num],
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )

    subgraph_loader = NeighborLoader(
        data,
        input_nodes=None,
        num_neighbors=[-1],
        batch_size=config.eval_batch_size,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )

    return train_loader, subgraph_loader