import logging

from yacs.config import CfgNode as CN

_C = CN()
_C.GLOBALS = CN()
_C.GLOBALS.CONSOLE_VERBOSE_LEVEL = logging.INFO  # 控制台日志级别
_C.GLOBALS.DEVICE = "cuda"
_C.GLOBALS.EXPR_NAME = "default_name"  # 实验名称
_C.GLOBALS.FILE_VERBOSE_LEVEL = logging.DEBUG  # 归档日志级别
_C.GLOBALS.SEED = 1024  # 随机数种子

_C.DATASET = CN()
_C.DATASET.ADD_STRAINS = False # 是否添加污点噪声
_C.DATASET.BUILD_FROM_EXIST_DATASET = False  # 是否从已分割的碎片数据集中建立
_C.DATASET.CONTOUR_IMAGE_SIZE = 540  # 作为边缘输入的图片宽高
_C.DATASET.CONTOUR_MAX_LEN = 1050  # 边缘的最大点数
_C.DATASET.CUTTINGLINE_EDGE_DISTANCE = 100  # 切割线和边缘的最小允许距离
_C.DATASET.FRAGMENT_AREA_THRES = 0.15  # 最大碎片面积占比阈值
_C.DATASET.FRAGMENT_DISCARD_THRES = 50  # 最小碎片尺寸阈值
_C.DATASET.FRAGMENT_PATH = "./dataset/fragment_debug"  # 碎片数据集路径
_C.DATASET.MAX_SCRIPTS = 40  # 一张图片最多分成了多少碎片
_C.DATASET.MAX_EDGES = 100 # 一张图片最多分有多少个碎片对
_C.DATASET.RAW_PATH = "../data/raw20"  # 原始数据集路径
_C.DATASET.TEXTURE_IMAGE_SIZE = 540  # 数据集原始碎片的最大宽高
_C.DATASET.TRAIN_VALID_PERCENT = (0.7, 0.1)  # 数据集分割比

_C.DATASET.CLASSIFY = CN()
_C.DATASET.CLASSIFY.FACTOR = 0.5 # 分类数据集正例占比

_C.NET = CN()
_C.NET.BLOCKS = 14  # 特征提取层的堆叠数(ViT: 6)
_C.NET.FEATURE_EXTRACT_DIM = 64  # 特征提取出的向量长度 
_C.NET.PATCH_SIZE = 7  # pairingnet matching 训练的输入 patch 大小

_C.NET.DECODER = CN() # 交叉注意力解码器设置
_C.NET.DECODER.BLOCKS = 2  # 交叉注意力解码器层数
_C.NET.DECODER.NUM_HEADS = 8 # 注意力头的数量

_C.NET.RESGCN = CN() # ResGCN 特征提取设置
_C.NET.RESGCN.ENCODER_TYPE = "l" # 输入边缘编码器类型 (l, io, ilo)
_C.NET.RESGCN.GCN_BLOCK_TYPE = "res+"  # graph backbone block type {plain, res, dense}
_C.NET.RESGCN.GCN_FILTERS = 64  # number of channels of deep features
_C.NET.RESGCN.GCN_GAT_HEADS = 1  # GATConv head numbers

_C.NET.VIT = CN() # ViT 编码器特征提取设置
_C.NET.VIT.NUM_HEADS = 8 # 注意力头的数量

_C.TEST = CN()
_C.TEST.BATCH_SIZE = 1
_C.TEST.RES_SAVE_PATH = "./checkpoints/pairingnet_matching_1000"
_C.TEST.STAT_DICT_PATH = "./checkpoints/pairingnet_matching_1000/matching/checkpoint_best.tar"
_C.TEST.TEST_DATA_PATH = "./dataset/1000_all/build_dataset_1000_pairing_test_set.pkl"
_C.TEST.TYPE = "pairing_all" # 测试哪个网络（pairing / matching / pairing_all）

_C.TEST.MATCHING = CN()  # 局部特征匹配（Matching）测试设置
_C.TEST.MATCHING.CONV_THRES = 0.006  # pairingnet matching conv 阈值
_C.TEST.MATCHING.DECODER = "CrossAttn" # Matching 网络的解码器模块（CrossAttn）
_C.TEST.MATCHING.FEATURE_EXTRACT = "ResGCN" # Matching 网络的特征提取模块（ResGCN / ViT）
_C.TEST.MATCHING.FEATURE_FUSE = "Attention" # Matching 网络的特征融合模块（Attention / SelfGateV1 / SelfGateV2）

_C.TEST.PAIRING_ALL = CN()  # 拼接（Pairing All）测试参数
_C.TEST.PAIRING_ALL.ASSEMBLER_TYPE = "HLM" # 全局拼接器类型（None / CO / Kruskal / HLM）
_C.TEST.PAIRING_ALL.CLUSTER = True # 是否使用聚类
_C.TEST.PAIRING_ALL.DESISION_THRES = 1. # 决策阈值
_C.TEST.PAIRING_ALL.K = 5.0  # 边选择阈值
_C.TEST.PAIRING_ALL.K_DENOISE = 20.0 # 渐进匹配的候选边选择阈值
_C.TEST.PAIRING_ALL.MATCHING_DECODER = "CrossAttn" # Pairing All 局部特征匹配网络的解码器模块（CrossAttn）
_C.TEST.PAIRING_ALL.MATCHING_FEATURE_EXTRACT = "ResGCN" # Pairing All 局部特征匹配网络的特征提取模块（ResGCN / ViT）
_C.TEST.PAIRING_ALL.MATCHING_FEATURE_FUSE = "Attention" # Pairing All 局部特征匹配网络的特征融合模块（Attention / SelfGateV1 / SelfGateV2）
_C.TEST.PAIRING_ALL.MATCHING_STAT_DICT_PATH = "./checkpoints/pairingnet_matching_1000/matching/checkpoint_best.tar"
_C.TEST.PAIRING_ALL.PAIRING_FEATURE_EXTRACT = "ResGCN" # Pairing All 全局特征匹配网络的特征提取模块（ResGCN / ViT）
_C.TEST.PAIRING_ALL.PAIRING_FEATURE_FUSE = "Attention" # Pairing All 全局特征匹配的特征融合模块（Attention / SelfGateV1 / SelfGateV2）
_C.TEST.PAIRING_ALL.PAIRING_STAT_DICT_PATH = "./checkpoints/pairingnet_matching_1000/matching/checkpoint_best.tar"
_C.TEST.PAIRING_ALL.PICS_NUM = 3 # 一次测试几张图
_C.TEST.PAIRING_ALL.SCORE_EVAL_TYPE = "CNN" # 分数计算网络类型 CNN / MaxDiag

_C.TRAIN = CN()
_C.TRAIN.BATCH_SIZE = 10 
_C.TRAIN.CHECKPOINT_PATH = "./results"
_C.TRAIN.EPOCH = 128
_C.TRAIN.LOAD_FROM_CHECKPOINTS = False # 是否断点续训练
_C.TRAIN.LR = 1e-3
# 训练时使用的训练集路径
_C.TRAIN.TRAIN_DATA_PATH = "./dataset/200_pairing/build_dataset_200_pairing_train_set.pkl"
# 训练时使用的训练集路径
_C.TRAIN.VALID_DATA_PATH = "./dataset/200_pairing/build_dataset_200_pairing_valid_set.pkl"
_C.TRAIN.SAVE_INTERVAL = 100 # 保存 checkpoints 的频率
# 训练哪个网络 (pairing / matching)
_C.TRAIN.TYPE = "pairing"
_C.TRAIN.WEIGHT_DECAY = 5e-4  # llmco4mr pairing 训练时的 l2 正则化

_C.TRAIN.CLASSIFY = CN() # 局部特征匹配后分类训练设置
_C.TRAIN.CLASSIFY.CONV_THRES = 0.006  # pairingnet matching conv 阈值
_C.TRAIN.CLASSIFY.DECODER = "CrossAttn" # 分类网络的解码器模块（CrossAttn）
_C.TRAIN.CLASSIFY.FEATURE_EXTRACT = "ResGCN" # 分类网络的特征提取模块（ResGCN / ViT）
_C.TRAIN.CLASSIFY.FEATURE_FUSE = "Attention" # 分类网络的特征融合模块（Attention / SelfGateV1 / SelfGateV2）
_C.TRAIN.CLASSIFY.STAT_DICT = "" # 分类网络冻结的参数（之前已训练好）位置

_C.TRAIN.MATCHING = CN() # 局部特征匹配（Matching）训练设置
_C.TRAIN.MATCHING.DECODER = "CrossAttn" # Matching 网络的解码器模块（CrossAttn）
_C.TRAIN.MATCHING.FEATURE_EXTRACT = "ResGCN" # Matching 网络的特征提取模块（ResGCN / ViT）
_C.TRAIN.MATCHING.FEATURE_FUSE = "Attention" # Matching 网络的特征融合模块（Attention / SelfGateV1 / SelfGateV2）

_C.TRAIN.PAIRING = CN()  # 全局特征匹配（Pairing）训练设置
_C.TRAIN.PAIRING.FEATURE_EXTRACT = "ResGCN" # PAIRING 网络的特征提取模块（ResGCN / ViT）
_C.TRAIN.PAIRING.FEATURE_FUSE = "Attention" # PAIRING 网络的特征融合模块（Attention / SelfGateV1 / SelfGateV2）
_C.TRAIN.PAIRING.INFONCE_TEMPERATURE = 0.12

cfg = _C


def get_cfg_defaults():
    return _C.clone()


if __name__ == "__main__":
    print(cfg)
