from typing import Dict, Any, List
import dgl

from .base_serializer import BaseGraphSerializer, SerializationResult, GlobalIDMapping


class ImageSerpentineSerializer(BaseGraphSerializer):
    """
    Serpentine-scan serializer for image graphs.

    - Even rows are read left-to-right, odd rows right-to-left (rows are 0-indexed).
    - Emits node (pixel) tokens only; edge tokens are disabled.
    - Requires ``graph_data['image_shape']`` as ``(H, W, C)`` (``(H, W)`` is also
      accepted; the channel dimension is not used) and exactly ``H * W`` nodes.
      Node ids are assumed to be row-major, i.e. pixel (r, c) has id ``r * W + c``.
    """

    def __init__(self):
        super().__init__()
        self.name = "image_serpentine"
        # The serialization is a pure walk over pixel nodes; no edge tokens.
        self.include_edge_tokens = False

    def _initialize_serializer(self, dataset_loader, graph_data_list: List[Dict[str, Any]] = None) -> None:
        """No-op: the scan order depends only on each graph's image_shape, so nothing is precomputed."""
        return

    @staticmethod
    def _serpentine_order(height: int, width: int) -> List[int]:
        """Return row-major node ids of a height x width grid in serpentine order."""
        order: List[int] = []
        for r in range(height):
            row = range(r * width, (r + 1) * width)
            # Even rows run left-to-right, odd rows right-to-left.
            order.extend(row if r % 2 == 0 else reversed(row))
        return order

    def _serialize_single_graph(self, graph_data: Dict[str, Any], **kwargs) -> SerializationResult:
        """
        Serialize one image graph by visiting its pixels in serpentine order.

        Args:
            graph_data: graph record; must contain ``image_shape`` as ``(H, W)`` or
                ``(H, W, C)``. Validated via ``self._validate_graph_data``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            SerializationResult holding a single token sequence, its element ids,
            and a ``GlobalIDMapping`` for the graph.

        Raises:
            ValueError: if ``image_shape`` is missing, has the wrong length, has a
                negative dimension, or if the node count differs from ``H * W``.
        """
        dgl_graph = self._validate_graph_data(graph_data)
        shape = graph_data.get('image_shape', None)
        if shape is None:
            raise ValueError("缺少 image_shape (H,W,C)")
        dims = tuple(shape)
        # Only H and W drive the scan; accept a channel-less (H, W) shape too.
        if len(dims) not in (2, 3):
            raise ValueError(f"image_shape 必须为 (H,W) 或 (H,W,C): {dims}")
        H, W = dims[0], dims[1]
        # Negative dims could still satisfy N == H*W (e.g. -2 * -3 == 6) and
        # would silently produce an empty path, so reject them explicitly.
        if H < 0 or W < 0:
            raise ValueError(f"image_shape 中的 H/W 不能为负: H={H}, W={W}")
        if dgl_graph.num_nodes() != H * W:
            raise ValueError(f"节点数与 H*W 不一致: N={dgl_graph.num_nodes()}, H*W={H*W}")

        node_path = self._serpentine_order(H, W)

        token_ids, element_ids = self._convert_path_to_tokens(node_path, graph_data)
        id_map = GlobalIDMapping(dgl_graph)
        return SerializationResult([token_ids], [element_ids], id_map)



