#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
架构字符串到6元组转换器
用于将NAS-Bench-201架构字符串转换为预测器需要的6元组格式
"""

import re
from typing import Tuple, Optional
import logging

logger = logging.getLogger(__name__)

class ArchitectureConverter:
    """NAS-Bench-201架构字符串到6元组转换器"""
    
    def __init__(self):
        """初始化转换器"""
        # NAS-Bench-201操作映射
        self.operation_mapping = {
            'none': 0,
            'skip_connect': 1, 
            'nor_conv_1x1': 2,
            'nor_conv_3x3': 3,
            'avg_pool_3x3': 4
        }
        
        # 反向映射
        self.index_to_operation = {v: k for k, v in self.operation_mapping.items()}
    
    def arch_str_to_tuple(self, arch_str: str) -> Optional[Tuple[int, int, int, int, int, int]]:
        """
        将架构字符串转换为6元组格式
        
        Args:
            arch_str: NAS-Bench-201架构字符串，格式如：
                     |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
        
        Returns:
            6元组，每个元素0-4，表示6条边上的操作
            None如果转换失败
        """
        try:
            if not arch_str or not isinstance(arch_str, str):
                return None
            
            # 解析架构字符串
            # 架构字符串格式：|op0~0|+|op1~0|op2~1|+|op3~0|op4~1|op5~2|
            # 分割为节点
            node_strs = arch_str.split('+')
            
            if len(node_strs) != 3:  # NAS-Bench-201有3个中间节点
                logger.warning(f"架构字符串格式错误，期望3个节点，得到{len(node_strs)}: {arch_str}")
                return None
            
            # 提取6条边的操作
            operations = []
            
            # 第1个节点：1条边 (0->1)
            node1_ops = self._extract_node_operations(node_strs[0])
            if len(node1_ops) != 1:
                logger.warning(f"节点1应该有1条边，得到{len(node1_ops)}: {node_strs[0]}")
                return None
            operations.extend(node1_ops)
            
            # 第2个节点：2条边 (0->2, 1->2)  
            node2_ops = self._extract_node_operations(node_strs[1])
            if len(node2_ops) != 2:
                logger.warning(f"节点2应该有2条边，得到{len(node2_ops)}: {node_strs[1]}")
                return None
            operations.extend(node2_ops)
            
            # 第3个节点：3条边 (0->3, 1->3, 2->3)
            node3_ops = self._extract_node_operations(node_strs[2])
            if len(node3_ops) != 3:
                logger.warning(f"节点3应该有3条边，得到{len(node3_ops)}: {node_strs[2]}")
                return None
            operations.extend(node3_ops)
            
            # 验证总共6条边
            if len(operations) != 6:
                logger.warning(f"期望6条边，得到{len(operations)}条边")
                return None
            
            # 转换为数字索引
            tuple_ops = []
            for op in operations:
                if op not in self.operation_mapping:
                    logger.warning(f"未知操作: {op}")
                    return None
                tuple_ops.append(self.operation_mapping[op])
            
            return tuple(tuple_ops)
            
        except Exception as e:
            logger.error(f"转换架构字符串失败: {arch_str}, 错误: {e}")
            return None
    
    def _extract_node_operations(self, node_str: str) -> list:
        """从节点字符串中提取操作列表"""
        # 移除开头和结尾的|符号，然后按|分割
        node_str = node_str.strip('|')
        if not node_str:
            return []
        
        # 分割操作
        operation_strs = node_str.split('|')
        operations = []
        
        for op_str in operation_strs:
            if '~' in op_str:
                op_name, connection = op_str.split('~')
                operations.append(op_name)
        
        return operations
    
    def tuple_to_arch_str(self, arch_tuple: Tuple[int, int, int, int, int, int]) -> str:
        """
        将6元组转换回架构字符串（用于验证）
        
        Args:
            arch_tuple: 6元组，每个元素0-4
        
        Returns:
            NAS-Bench-201架构字符串
        """
        try:
            if len(arch_tuple) != 6:
                raise ValueError(f"期望6元组，得到{len(arch_tuple)}元组")
            
            # 转换为操作名
            ops = []
            for idx in arch_tuple:
                if idx not in self.index_to_operation:
                    raise ValueError(f"无效的操作索引: {idx}")
                ops.append(self.index_to_operation[idx])
            
            # 构建架构字符串
            # 格式：|op0~0|+|op1~0|op2~1|+|op3~0|op4~1|op5~2|
            arch_str = f"|{ops[0]}~0|+|{ops[1]}~0|{ops[2]}~1|+|{ops[3]}~0|{ops[4]}~1|{ops[5]}~2|"
            
            return arch_str
            
        except Exception as e:
            logger.error(f"转换6元组失败: {arch_tuple}, 错误: {e}")
            return ""
    
    def validate_conversion(self, arch_str: str) -> bool:
        """验证转换是否正确（往返测试）"""
        try:
            # 字符串 -> 元组 -> 字符串
            arch_tuple = self.arch_str_to_tuple(arch_str)
            if arch_tuple is None:
                return False
            
            converted_str = self.tuple_to_arch_str(arch_tuple)
            
            # 比较（忽略空格差异）
            original_normalized = arch_str.replace(' ', '')
            converted_normalized = converted_str.replace(' ', '')
            
            return original_normalized == converted_normalized
            
        except Exception as e:
            logger.error(f"验证转换失败: {arch_str}, 错误: {e}")
            return False

# 创建全局转换器实例
arch_converter = ArchitectureConverter()

def arch_str_to_tuple(arch_str: str) -> Optional[Tuple[int, int, int, int, int, int]]:
    """全局函数：架构字符串转6元组"""
    return arch_converter.arch_str_to_tuple(arch_str)

def tuple_to_arch_str(arch_tuple: Tuple[int, int, int, int, int, int]) -> str:
    """全局函数：6元组转架构字符串"""
    return arch_converter.tuple_to_arch_str(arch_tuple)

def validate_arch_conversion(arch_str: str) -> bool:
    """全局函数：验证转换正确性"""
    return arch_converter.validate_conversion(arch_str)


if __name__ == "__main__":
    # 测试转换器
    converter = ArchitectureConverter()
    
    # 测试示例
    test_archs = [
        "|avg_pool_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|avg_pool_3x3~1|nor_conv_3x3~2|",
        "|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|",
        "|skip_connect~0|+|nor_conv_3x3~0|skip_connect~1|+|nor_conv_1x1~0|avg_pool_3x3~1|none~2|"
    ]
    
    print("架构转换器测试")
    print("=" * 50)
    
    for i, arch_str in enumerate(test_archs, 1):
        print(f"测试 {i}: {arch_str}")
        
        # 转换为元组
        arch_tuple = converter.arch_str_to_tuple(arch_str)
        print(f"  6元组: {arch_tuple}")
        
        # 转换回字符串
        if arch_tuple:
            converted_str = converter.tuple_to_arch_str(arch_tuple)
            print(f"  转换回: {converted_str}")
            
            # 验证
            is_valid = converter.validate_conversion(arch_str)
            print(f"  验证: {'✅' if is_valid else '❌'}")
        
        print()
