import argparse
import os
import sys
import math
import random
import shutil
from datetime import datetime

# ==========================
# 内联 utils.py
# ==========================
import torch
import torch.distributions as D
import yaml
import numpy as np
from torch.optim import lr_scheduler
import torch.nn.init as init
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge


def compute_mcc(z1, z2):
    """Compute mean-correlation-coefficient (MCC) metrics between two component sets.

    The absolute cross-correlation between the columns of ``z1`` and the
    columns of ``z2`` is computed (Pearson and Spearman), the columns are
    matched one-to-one with the Hungarian algorithm so that the total
    correlation is maximal, and the matched correlations are reported.

    Args:
        z1: 2-D CPU torch tensor of shape (samples, components); ``.size()``
            and ``.T`` are called on it.
        z2: array-like with the same shape as ``z1``.

    Returns:
        dict mapping ``'Pearson/MCC_<k>'`` / ``'Spearman/MCC_<k>'`` to the
        k-th matched correlation, plus the averages under
        ``'Pearson/MCC_avg'`` and ``'Spearman/MCC_avg'``. The historical,
        misspelled key ``'SPearman/MCC_avg'`` is kept as an alias so existing
        consumers keep working.
    """
    from scipy.stats import spearmanr
    from scipy.optimize import linear_sum_assignment

    n_comp = z1.size()[-1]

    def _matched(cross):
        # Negate so that the minimum-cost assignment maximises correlation.
        rows, cols = linear_sum_assignment(-1 * cross)
        return cross[rows, cols]

    # corrcoef stacks the rows of both inputs; the upper-right block holds
    # the z1-vs-z2 correlations.
    pearson_cross = np.abs(np.corrcoef(z1.T, z2.T))[:n_comp, n_comp:]
    pearson_vec = _matched(pearson_cross)

    rho, _ = spearmanr(z1, z2)
    rho = np.abs(np.asarray(rho))
    if rho.ndim == 0:
        # With a single component per side spearmanr returns a scalar
        # instead of a correlation matrix.
        spearman_cross = rho.reshape(1, 1)
    else:
        spearman_cross = rho[:n_comp, n_comp:]
    spearman_vec = _matched(spearman_cross)

    metric_dict = {}
    for k, value in enumerate(pearson_vec, start=1):
        metric_dict['Pearson/MCC_%d' % k] = value
    metric_dict['Pearson/MCC_avg'] = pearson_vec.mean()
    for k, value in enumerate(spearman_vec, start=1):
        metric_dict['Spearman/MCC_%d' % k] = value
    mcc_spearman = spearman_vec.mean()
    metric_dict['Spearman/MCC_avg'] = mcc_spearman
    # Legacy spelling, kept for backward compatibility.
    metric_dict['SPearman/MCC_avg'] = mcc_spearman
    return metric_dict

def moving_average(model, model_test, beta=0.999):
    """Blend the parameters of ``model`` into ``model_test`` in place.

    Each parameter of ``model_test`` becomes
    ``lerp(model_param, model_test_param, beta)``, i.e. ``beta`` is the weight
    kept from the current ``model_test`` value. Parameters are paired in the
    order returned by ``parameters()``.
    """
    paired = zip(model.parameters(), model_test.parameters())
    for live, averaged in paired:
        blended = torch.lerp(live.data, averaged.data, beta)
        averaged.data = blended

def write_loss(iterations, output_dict, train_writer):
    """Log every entry whose key contains 'loss' or 'elbo' under ``train/``.

    ``train_writer`` only needs an ``add_scalar(tag, value, step)`` method;
    the step is recorded as ``iterations + 1``.
    """
    step = iterations + 1
    for key in output_dict:
        if 'loss' not in key and 'elbo' not in key:
            continue
        train_writer.add_scalar('train/' + key, output_dict[key], step)

def write_metric(iterations, metric_dict, train_writer):
    """Log every entry whose key contains 'loss' or 'elbo' under ``validation/``.

    ``train_writer`` only needs an ``add_scalar(tag, value, step)`` method;
    the step is recorded as ``iterations + 1``.
    """
    step = iterations + 1
    for key in metric_dict:
        if 'loss' not in key and 'elbo' not in key:
            continue
        train_writer.add_scalar('validation/' + key, metric_dict[key], step)

def prepare_sub_folder(output_directory):
    """Ensure the ``images``, ``checkpoints`` and ``logs`` sub-folders exist.

    Returns:
        Tuple ``(checkpoint_directory, image_directory, log_directory)``.
    """
    created = {}
    # Same creation order as before: images, checkpoints, logs.
    for sub_name in ('images', 'checkpoints', 'logs'):
        target = os.path.join(output_directory, sub_name)
        if not os.path.exists(target):
            os.makedirs(target)
        created[sub_name] = target
    return created['checkpoints'], created['images'], created['logs']

def get_model_list(dirname):
    """Return the lexicographically last file in ``dirname`` whose name contains ".pt".

    Returns ``None`` when the directory does not exist or holds no such file.
    Only regular files are considered; the returned path includes ``dirname``.
    """
    if not os.path.exists(dirname):
        return None
    candidates = []
    for entry in os.listdir(dirname):
        full_path = os.path.join(dirname, entry)
        if ".pt" in entry and os.path.isfile(full_path):
            candidates.append(full_path)
    if not candidates:
        return None
    return max(candidates)

def set_seed(seed):
    """Seed NumPy, PyTorch and Python's ``random`` with the same value."""
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed)

def get_config(config):
    """Parse the YAML file at path ``config`` with ``yaml.safe_load`` and return the result."""
    with open(config, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed

def perform_baseline_causal_discovery(sublabels, save_dir, config=None):
    """Run the baseline PC causal discovery on the 11 labelled variables (graph 1).

    Args:
        sublabels: torch tensor; ``sublabels.numpy()`` is used as the data
            matrix. It is expected to have one column per variable name
            listed below (11 names are passed to PC) — confirm with callers.
        save_dir: existing directory that receives the graph image, the
            adjacency/correlation matrices and the edge list.
        config: optional mapping or attribute-style object; its
            ``normalization_method`` ('none', 'standard', 'minmax' or
            'robust') selects the scaler. Missing or unknown values fall back
            to standard scaling.

    Returns:
        ``(cg, adj_matrix, corr_matrix, edge_list)``, or four ``None`` values
        when an ``ImportError`` is raised (e.g. causal-learn is unavailable).
        ``edge_list`` holds ``(name_i, name_j, adj_matrix[i, j])`` for every
        non-zero adjacency entry.
    """
    try:
        # Imported lazily so that a missing optional dependency is reported
        # instead of breaking module import.
        from causallearn.search.ConstraintBased.PC import pc
        from causallearn.utils.GraphUtils import GraphUtils
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler

        print("基于11个sublabels进行基线因果发现...")

        # YAML configs are plain dicts, for which getattr() would silently
        # return the default; support both mappings and attribute objects.
        if isinstance(config, dict):
            normalization_method = config.get('normalization_method', 'standard')
        else:
            normalization_method = getattr(config, 'normalization_method', 'standard')

        scaler_classes = {
            'standard': StandardScaler,
            'minmax': MinMaxScaler,
            'robust': RobustScaler,
        }
        raw_values = sublabels.numpy()
        if normalization_method == 'none':
            sublabels_scaled = raw_values
        else:
            # Unknown method names fall back to standard scaling.
            scaler = scaler_classes.get(normalization_method, StandardScaler)()
            sublabels_scaled = scaler.fit_transform(raw_values)

        print(f"data sublabels_scaled shape: {sublabels_scaled.shape}")

        # 7 author concepts, 3 reviewer concepts, then RatingChange.
        var_names = ['Clarity', 'Directness', 'Attitude', 'Openness', 'Evidence',
                     'Rigor', 'De-escalation', 'ReviewQuality', 'ReviewerOpenness',
                     'ConcernSeverity', 'RatingChange']

        print("运行PC算法（基线）...")
        # Background knowledge: RatingChange must not cause any variable.
        knowledge = BackgroundKnowledge()
        knowledge.add_forbidden_by_pattern('RatingChange', '.*')
        print("禁止: RatingChange -> 所有变量")
        cg = pc(sublabels_scaled, node_names=var_names, indep_test='rcit', alpha=0.03, background_knowledge=knowledge)

        print("保存基线因果图...")
        pyd = GraphUtils.to_pydot(cg.G)
        pyd.write_png(os.path.join(save_dir, 'baseline_causal_graph.png'))

        adj_matrix = cg.G.graph
        np.save(os.path.join(save_dir, 'baseline_adjacency_matrix.npy'), adj_matrix)

        n_vars = len(var_names)
        edge_list = [
            (var_names[i], var_names[j], adj_matrix[i, j])
            for i in range(n_vars)
            for j in range(n_vars)
            if adj_matrix[i, j] != 0
        ]

        # NOTE(review): causal-learn documents graph[j, i] == 1 together with
        # graph[i, j] == -1 as the edge i -> j. The label below inspects a
        # single endpoint only, so "undirected" rows may be the tail of a
        # directed edge — confirm before relying on this file.
        with open(os.path.join(save_dir, 'baseline_causal_graph_edges.txt'), 'w') as f:
            f.write("Source\tTarget\tEdge_Type\n")
            for source, target, mark in edge_list:
                edge_type = "directed" if mark == 1 else "undirected"
                f.write(f"{source}\t{target}\t{edge_type}\n")

        # Correlation matrix of the (scaled) inputs, saved for reference.
        corr_matrix = np.corrcoef(sublabels_scaled.T)
        np.save(os.path.join(save_dir, 'baseline_correlation_matrix.npy'), corr_matrix)

        return cg, adj_matrix, corr_matrix, edge_list

    except ImportError as e:
        print(f"警告: 无法导入causallearn库: {e}")
        print("将跳过基线因果发现")
        return None, None, None, None

def perform_causal_discovery(latents, save_dir, config=None):
    """Run PC causal discovery on the learned concepts plus RatingChange.

    The last column of ``./labeled_data_labels_11features.npy`` (relative to
    the working directory) is appended to the latents as ``RatingChange``.

    Args:
        latents: torch tensor; ``latents.numpy()`` is used as the data
            matrix. 12 concept columns are expected (13 variable names are
            passed to PC once RatingChange is appended) — confirm with
            callers.
        save_dir: existing directory that receives the graph image, the
            adjacency/correlation matrices and the edge list.
        config: optional mapping or attribute-style object; its
            ``normalization_method`` ('none', 'standard', 'minmax' or
            'robust') selects the scaler. Missing or unknown values fall back
            to standard scaling.

    Returns:
        ``(cg, adj_matrix, corr_matrix, edge_list)``, or four ``None`` values
        when an ``ImportError`` is raised. ``corr_matrix`` covers the scaled
        latents only (RatingChange excluded).

    Raises:
        ValueError: if the label file and ``latents`` have different numbers
            of rows.
    """
    try:
        from causallearn.search.ConstraintBased.PC import pc
        from causallearn.utils.GraphUtils import GraphUtils
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler

        print("使用PC算法进行因果发现...")

        # YAML configs are plain dicts, for which getattr() would silently
        # return the default; support both mappings and attribute objects.
        if isinstance(config, dict):
            normalization_method = config.get('normalization_method', 'standard')
        else:
            normalization_method = getattr(config, 'normalization_method', 'standard')

        scaler_classes = {
            'standard': StandardScaler,
            'minmax': MinMaxScaler,
            'robust': RobustScaler,
        }

        def _scale(values):
            # Each array is fitted independently; unknown methods fall back
            # to standard scaling.
            if normalization_method == 'none':
                return values
            return scaler_classes.get(normalization_method, StandardScaler)().fit_transform(values)

        latents_scaled = _scale(latents.numpy())
        rating_change_data = np.load("./labeled_data_labels_11features.npy")[:, -1:]
        if rating_change_data.shape[0] != latents_scaled.shape[0]:
            raise ValueError(
                "row mismatch: latents have %d rows but the label file has %d"
                % (latents_scaled.shape[0], rating_change_data.shape[0])
            )
        rating_change_scaled = _scale(rating_change_data)

        print(f"data latents_scaled shape: {latents_scaled.shape}")

        # Learned concepts followed by RatingChange as the final column.
        all_data = np.concatenate([latents_scaled, rating_change_scaled], axis=1)
        print(f"combined data shape: {all_data.shape}")

        # 7 known author concepts, the unknown author concept, 3 known
        # reviewer concepts, the unknown reviewer concept, then RatingChange.
        var_names = ['Clarity', 'Directness', 'Attitude', 'Openness', 'Evidence',
                     'Rigor', 'De-escalation', 'Author_Unknown',
                     'ReviewQuality', 'ReviewerOpenness', 'ConcernSeverity',
                     'Reviewer_Unknown', 'RatingChange']

        print("运行PC算法...")
        # Background knowledge: RatingChange must not cause any variable.
        knowledge = BackgroundKnowledge()
        knowledge.add_forbidden_by_pattern('RatingChange', '.*')
        print("禁止: RatingChange -> 所有变量")
        cg = pc(all_data, node_names=var_names, indep_test='rcit', alpha=0.03, background_knowledge=knowledge)

        print("保存因果图...")
        pyd = GraphUtils.to_pydot(cg.G)
        pyd.write_png(os.path.join(save_dir, 'pc_causal_graph.png'))

        adj_matrix = cg.G.graph
        np.save(os.path.join(save_dir, 'pc_adjacency_matrix.npy'), adj_matrix)

        n_vars = len(var_names)
        edge_list = [
            (var_names[i], var_names[j], adj_matrix[i, j])
            for i in range(n_vars)
            for j in range(n_vars)
            if adj_matrix[i, j] != 0
        ]

        # NOTE(review): causal-learn documents graph[j, i] == 1 together with
        # graph[i, j] == -1 as the edge i -> j. The label below inspects a
        # single endpoint only, so "undirected" rows may be the tail of a
        # directed edge — confirm before relying on this file.
        with open(os.path.join(save_dir, 'pc_causal_graph_edges.txt'), 'w') as f:
            f.write("Source\tTarget\tEdge_Type\n")
            for source, target, mark in edge_list:
                edge_type = "directed" if mark == 1 else "undirected"
                f.write(f"{source}\t{target}\t{edge_type}\n")

        # Correlation matrix of the scaled latents, saved for reference.
        corr_matrix = np.corrcoef(latents_scaled.T)
        np.save(os.path.join(save_dir, 'correlation_matrix.npy'), corr_matrix)

        return cg, adj_matrix, corr_matrix, edge_list

    except ImportError as e:
        print(f"警告: 无法导入causallearn库: {e}")
        print("将跳过CRL因果发现")
        return None, None, None, None

def perform_llm_inferred_causal_discovery(latents, save_dir, config=None):
    """Run PC on LLM-inferred concepts combined with the CRL unknown concepts.

    The first 10 columns of ``./labeled_data_labels_11features.npy`` (relative
    to the working directory) are used as the LLM-inferred concepts and its
    last column as ``RatingChange``. Columns 7 and 11 of ``latents`` are taken
    as the unknown author and unknown reviewer concepts.

    Args:
        latents: 2-D torch tensor with at least 12 columns (columns 7 and 11
            are sliced and converted with ``.numpy()``).
        save_dir: existing directory that receives the graph image, the
            adjacency/correlation matrices and the edge list.
        config: optional mapping or attribute-style object; its
            ``normalization_method`` ('none', 'standard', 'minmax' or
            'robust') selects the scaler. Missing or unknown values fall back
            to standard scaling.

    Returns:
        ``(cg, adj_matrix, corr_matrix, edge_list)``, or four ``None`` values
        when an ``ImportError`` is raised.

    Raises:
        ValueError: if the label file and ``latents`` have different numbers
            of rows.
    """
    try:
        from causallearn.search.ConstraintBased.PC import pc
        from causallearn.utils.GraphUtils import GraphUtils
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler

        print("基于LLM-inferred概念和CRL未知概念进行因果发现...")

        # YAML configs are plain dicts, for which getattr() would silently
        # return the default; support both mappings and attribute objects.
        if isinstance(config, dict):
            normalization_method = config.get('normalization_method', 'standard')
        else:
            normalization_method = getattr(config, 'normalization_method', 'standard')

        llm_data = np.load("./labeled_data_labels_11features.npy")
        llm_inferred_concepts = llm_data[:, :10]
        rating_change_data = llm_data[:, -1:]

        # Slices keep the arrays 2-D so they can be concatenated column-wise.
        author_unknown = latents[:, 7:8].numpy()
        reviewer_unknown = latents[:, 11:12].numpy()
        if author_unknown.shape[0] != llm_data.shape[0]:
            raise ValueError(
                "row mismatch: latents have %d rows but the label file has %d"
                % (author_unknown.shape[0], llm_data.shape[0])
            )

        scaler_classes = {
            'standard': StandardScaler,
            'minmax': MinMaxScaler,
            'robust': RobustScaler,
        }

        def _scale(values):
            # Each array is fitted independently; unknown methods fall back
            # to standard scaling.
            if normalization_method == 'none':
                return values
            return scaler_classes.get(normalization_method, StandardScaler)().fit_transform(values)

        llm_inferred_scaled = _scale(llm_inferred_concepts)
        author_unknown_scaled = _scale(author_unknown)
        reviewer_unknown_scaled = _scale(reviewer_unknown)
        rating_change_scaled = _scale(rating_change_data)

        print(f"LLM-inferred concepts shape: {llm_inferred_scaled.shape}")
        print(f"Author unknown shape: {author_unknown_scaled.shape}")
        print(f"Reviewer unknown shape: {reviewer_unknown_scaled.shape}")
        print(f"Rating change shape: {rating_change_scaled.shape}")

        # 10 LLM-inferred concepts + 2 CRL unknown concepts + RatingChange.
        all_data = np.concatenate([llm_inferred_scaled, author_unknown_scaled, reviewer_unknown_scaled, rating_change_scaled], axis=1)
        print(f"Combined LLM+CRL data shape: {all_data.shape}")

        llm_feature_names = ['Clarity', 'Directness', 'Attitude', 'Openness', 'Evidence',
                             'Rigor', 'De-escalation', 'ReviewQuality', 'ReviewerOpenness',
                             'ConcernSeverity']
        var_names = list(llm_feature_names) + ['Author_Unknown_CRL', 'Reviewer_Unknown_CRL', 'RatingChange']

        print("运行PC算法（LLM+CRL）...")
        # Background knowledge: RatingChange must not cause any variable.
        knowledge = BackgroundKnowledge()
        knowledge.add_forbidden_by_pattern('RatingChange', '.*')
        print("禁止: RatingChange -> 所有变量")
        cg = pc(all_data, node_names=var_names, indep_test='rcit', alpha=0.03, background_knowledge=knowledge)

        print("保存LLM+CRL因果图...")
        pyd = GraphUtils.to_pydot(cg.G)
        pyd.write_png(os.path.join(save_dir, 'llm_crl_causal_graph.png'))

        adj_matrix = cg.G.graph
        np.save(os.path.join(save_dir, 'llm_crl_adjacency_matrix.npy'), adj_matrix)

        n_vars = len(var_names)
        edge_list = [
            (var_names[i], var_names[j], adj_matrix[i, j])
            for i in range(n_vars)
            for j in range(n_vars)
            if adj_matrix[i, j] != 0
        ]

        # NOTE(review): causal-learn documents graph[j, i] == 1 together with
        # graph[i, j] == -1 as the edge i -> j. The label below inspects a
        # single endpoint only, so "undirected" rows may be the tail of a
        # directed edge — confirm before relying on this file.
        with open(os.path.join(save_dir, 'llm_crl_causal_graph_edges.txt'), 'w') as f:
            f.write("Source\tTarget\tEdge_Type\n")
            for source, target, mark in edge_list:
                edge_type = "directed" if mark == 1 else "undirected"
                f.write(f"{source}\t{target}\t{edge_type}\n")

        # Correlation matrix over all 13 combined variables, for reference.
        corr_matrix = np.corrcoef(all_data.T)
        np.save(os.path.join(save_dir, 'llm_crl_correlation_matrix.npy'), corr_matrix)

        print(f"LLM+CRL因果发现完成！")
        print(f"变量数量: {len(var_names)}")
        print(f"边数量: {len(edge_list)}")
        print(f"LLM概念: {llm_feature_names}")
        print(f"CRL未知概念: Author_Unknown_CRL, Reviewer_Unknown_CRL")
        print(f"保存的文件:")
        print(f"  - llm_crl_adjacency_matrix.npy")
        print(f"  - llm_crl_correlation_matrix.npy")
        print(f"  - llm_crl_causal_graph_edges.txt")
        print(f"  - llm_crl_causal_graph.png")

        return cg, adj_matrix, corr_matrix, edge_list

    except ImportError as e:
        print(f"警告: 无法导入causallearn库: {e}")
        print("将跳过LLM+CRL因果发现")
        return None, None, None, None

def calculate_shd(baseline_edges, crl_edges):
    """Count the adjacency entries that differ between two edge lists.

    Both edge lists are rasterised into 11x11 matrices indexed by the 11
    known variable names (RatingChange included). Edges touching any other
    variable are ignored. An entry is 1 when the edge mark equals 1 and -1
    for any other mark.

    Args:
        baseline_edges: iterable of ``(source, target, mark)`` triples.
        crl_edges: iterable of ``(source, target, mark)`` triples.

    Returns:
        ``(shd, baseline_matrix, crl_matrix)`` where ``shd`` is the number of
        matrix cells whose values differ.
    """
    known_names = ['Clarity', 'Directness', 'Attitude', 'Openness', 'Evidence',
                   'Rigor', 'De-escalation', 'ReviewQuality', 'ReviewerOpenness',
                   'ConcernSeverity', 'RatingChange']
    index_of = dict(zip(known_names, range(len(known_names))))

    def _rasterise(edges):
        matrix = np.zeros((11, 11))
        for source, target, edge_type in edges:
            # Skip edges that involve a variable outside the known set.
            if source not in index_of or target not in index_of:
                continue
            matrix[index_of[source], index_of[target]] = 1 if edge_type == 1 else -1
        return matrix

    baseline_matrix = _rasterise(baseline_edges)
    crl_matrix = _rasterise(crl_edges)

    shd = np.sum(baseline_matrix != crl_matrix)
    return shd, baseline_matrix, crl_matrix

def compare_causal_graphs(baseline_edges, crl_edges, save_dir, iteration=None):
    """Compare the baseline and CRL causal graphs and write a text report.

    Edges are compared as exact ``(source, target, mark)`` triples.

    Args:
        baseline_edges: iterable of ``(source, target, mark)`` triples from
            the baseline graph.
        crl_edges: iterable of ``(source, target, mark)`` triples from the
            CRL graph.
        save_dir: existing directory that receives ``graph_comparison.txt``
            (or ``graph_comparison_iter_<iteration>.txt``).
        iteration: optional label appended to the report file name and
            written into the report.

    Returns:
        dict with the edge counts, ``similarity`` (shared edges divided by
        baseline edges), ``shd`` from ``calculate_shd``, ``precision``,
        ``recall``, ``f1_score``, ``new_variable_edges`` and
        ``total_possible_edges``.
    """
    baseline_edge_set = {(source, target, edge_type) for source, target, edge_type in baseline_edges}
    crl_edge_set = {(source, target, edge_type) for source, target, edge_type in crl_edges}

    common_edges = baseline_edge_set & crl_edge_set
    baseline_only = baseline_edge_set - crl_edge_set
    crl_only = crl_edge_set - baseline_edge_set

    shd, baseline_matrix, crl_matrix = calculate_shd(baseline_edges, crl_edges)

    similarity = len(common_edges) / len(baseline_edge_set) if baseline_edge_set else 0.0

    # Edges that touch a newly discovered variable. The comparison is
    # case-insensitive because the variables are named 'Author_Unknown' /
    # 'Reviewer_Unknown'; a lower-case-only match never fired.
    new_variable_edges = [
        edge for edge in crl_only
        if 'unknown' in str(edge[0]).lower() or 'unknown' in str(edge[1]).lower()
    ]

    filename = 'graph_comparison.txt'
    if iteration is not None:
        filename = f'graph_comparison_iter_{iteration}.txt'

    sections = (
        ("共同边:\n", common_edges),
        ("\n仅在基线图中的边:\n", baseline_only),
        ("\n仅在CRL图中的边:\n", crl_only),
        ("\n涉及新变量的边:\n", new_variable_edges),
    )

    # The report contains non-ASCII text, so pin the encoding instead of
    # relying on the platform default.
    with open(os.path.join(save_dir, filename), 'w', encoding='utf-8') as f:
        f.write("因果图对比分析结果\n")
        if iteration is not None:
            f.write(f"迭代: {iteration}\n")
        f.write("="*50 + "\n\n")
        f.write(f"基线图（11个变量）边数: {len(baseline_edge_set)}\n")
        f.write(f"CRL图（13个变量）边数: {len(crl_edge_set)}\n")
        f.write(f"共同边数: {len(common_edges)}\n")
        f.write(f"仅在基线图中的边数: {len(baseline_only)}\n")
        f.write(f"仅在CRL图中的边数: {len(crl_only)}\n")
        f.write(f"Structural Hamming Distance (SHD): {shd}\n")
        f.write(f"图相似度: {similarity:.4f} ({similarity*100:.2f}%)\n")
        f.write(f"涉及新变量的边数: {len(new_variable_edges)}\n\n")

        for header, edges in sections:
            f.write(header)
            for edge in sorted(edges):
                f.write(f"  {edge[0]} -> {edge[1]} ({edge[2]})\n")

    # 11 variables, each of which may point at the 10 others.
    total_possible_edges = 11 * 10
    precision = len(common_edges) / len(crl_edge_set) if crl_edge_set else 0.0
    recall = len(common_edges) / len(baseline_edge_set) if baseline_edge_set else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        'baseline_edges': len(baseline_edge_set),
        'crl_edges': len(crl_edge_set),
        'common_edges': len(common_edges),
        'baseline_only': len(baseline_only),
        'crl_only': len(crl_only),
        'similarity': similarity,
        'new_variable_edges': len(new_variable_edges),
        'shd': shd,
        'precision': precision,
        'recall': recall,
        'f1_score': f1_score,
        'total_possible_edges': total_possible_edges
    }

def run_baseline_causal_discovery_once(dataset_dict, save_dir, config=None):
    """Run the baseline causal discovery once on the raw label file.

    The data is read from ``./labeled_data_labels_11features.npy`` (relative
    to the working directory); ``dataset_dict`` is accepted but not used.
    ``save_dir`` is created when missing.

    Returns:
        The edge list produced by ``perform_baseline_causal_discovery``
        (``None`` when that function reports a missing dependency).
    """
    os.makedirs(save_dir, exist_ok=True)

    data_file = "./labeled_data_labels_11features.npy"
    sublabels = np.load(data_file)

    print(f"基线因果发现使用原始数据文件: {sublabels.shape[0]} 个样本, {sublabels.shape[1]} 个特征")

    # The discovery routine expects a float torch tensor.
    as_tensor = torch.from_numpy(sublabels).float()
    _, _, _, baseline_edges = perform_baseline_causal_discovery(as_tensor, save_dir, config)
    return baseline_edges

def _collect_latents(model, loader, device, collected, sublabel_dim):
    """Run ``model`` over ``loader`` and append the CPU outputs to ``collected``.

    Batches of four tensors carry sublabels, which are forwarded to the
    model; batches of three tensors are unlabeled and get an all-zero
    sublabel placeholder of width ``sublabel_dim``.
    """
    for batch in loader:
        if len(batch) == 4:
            x_reviewer, x_author, y, sublabels = (t.to(device) for t in batch)
            with torch.no_grad():
                output_dict = model(x_reviewer, x_author, y, threshold=0.1, sublabels=sublabels)
            batch_sublabels = sublabels.cpu()
        else:
            x_reviewer, x_author, y = (t.to(device) for t in batch)
            with torch.no_grad():
                output_dict = model(x_reviewer, x_author, y, threshold=0.1)
            # Unlabeled data has no supervision labels; store zeros instead.
            batch_sublabels = torch.zeros(len(y), sublabel_dim)

        collected['reviewer'].append(output_dict.z_reviewer.cpu())
        collected['author'].append(output_dict.z_author.cpu())
        # all_concepts holds every learned concept, not only the known ones.
        collected['combined'].append(output_dict.all_concepts.cpu())
        collected['domain'].append(y.cpu())
        collected['sublabels'].append(batch_sublabels)


def save_learned_latents(model, dataset_dict, device, run_dir):
    """Encode all train/validation data, save the latents and compare causal graphs.

    The model is put in eval mode, run over the labeled training data, the
    unlabeled training data and the combined validation data, and the
    collected tensors are written as ``.npy`` files to
    ``<run_dir>/learned_latents``. Baseline and CRL causal discovery are then
    run and, when both succeed, compared. The model is switched back to
    train mode on exit, even when an error occurs.

    Args:
        model: callable as ``model(x_reviewer, x_author, y, threshold=...,
            sublabels=...)`` returning an object with ``z_reviewer``,
            ``z_author`` and ``all_concepts`` tensors.
        dataset_dict: object exposing the ``train_*`` / ``val_*`` tensors
            read below as attributes.
        device: device the batches are moved to before inference.
        run_dir: directory under which ``learned_latents`` is created.
    """
    print("保存学习到的潜在变量...")

    model.eval()
    try:
        collected = {'reviewer': [], 'author': [], 'combined': [], 'domain': [], 'sublabels': []}

        # Width of the zero placeholders follows the labeled data instead of
        # a hard-coded 10, so concatenation works for any label width.
        train_sublabel_dim = dataset_dict.train_labeled_sublabels.shape[1]
        val_sublabel_dim = dataset_dict.val_labeled_sublabels.shape[1]

        print("处理训练数据...")

        train_labeled_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(
                dataset_dict.train_labeled_data_reviewer,
                dataset_dict.train_labeled_data_author,
                dataset_dict.train_labeled_domain_labels,
                dataset_dict.train_labeled_sublabels
            ),
            batch_size=512, shuffle=False, drop_last=False
        )
        _collect_latents(model, train_labeled_loader, device, collected, train_sublabel_dim)

        train_unlabeled_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(
                dataset_dict.train_unlabeled_data_reviewer,
                dataset_dict.train_unlabeled_data_author,
                dataset_dict.train_unlabeled_domain_labels
            ),
            batch_size=512, shuffle=False, drop_last=False
        )
        _collect_latents(model, train_unlabeled_loader, device, collected, train_sublabel_dim)

        print("处理验证数据...")
        # Unlabeled validation rows get zero sublabels so that labeled and
        # unlabeled rows can share one dataset.
        val_unlabeled_sublabels = torch.zeros(
            len(dataset_dict.val_unlabeled_data_reviewer), val_sublabel_dim
        ).to(dataset_dict.val_labeled_sublabels.device)
        val_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(
                torch.cat([dataset_dict.val_labeled_data_reviewer, dataset_dict.val_unlabeled_data_reviewer], 0),
                torch.cat([dataset_dict.val_labeled_data_author, dataset_dict.val_unlabeled_data_author], 0),
                torch.cat([dataset_dict.val_labeled_domain_labels, dataset_dict.val_unlabeled_domain_labels], 0),
                torch.cat([dataset_dict.val_labeled_sublabels, val_unlabeled_sublabels], 0)
            ),
            batch_size=512, shuffle=False, drop_last=False
        )
        _collect_latents(model, val_loader, device, collected, val_sublabel_dim)

        all_latents_reviewer = torch.cat(collected['reviewer'], dim=0)
        all_latents_author = torch.cat(collected['author'], dim=0)
        all_combined_latents = torch.cat(collected['combined'], dim=0)
        all_domain_labels = torch.cat(collected['domain'], dim=0)
        all_sublabels = torch.cat(collected['sublabels'], dim=0)

        latents_dir = os.path.join(run_dir, 'learned_latents')
        os.makedirs(latents_dir, exist_ok=True)

        np.save(os.path.join(latents_dir, 'latents_reviewer.npy'), all_latents_reviewer.numpy())
        np.save(os.path.join(latents_dir, 'latents_author.npy'), all_latents_author.numpy())
        np.save(os.path.join(latents_dir, 'latents_combined.npy'), all_combined_latents.numpy())
        np.save(os.path.join(latents_dir, 'domain_labels.npy'), all_domain_labels.numpy())
        np.save(os.path.join(latents_dir, 'sublabels.npy'), all_sublabels.numpy())

        print(f"保存完成！")
        print(f"总样本数: {len(all_combined_latents)}")
        print(f"审稿人潜在变量形状: {all_latents_reviewer.shape}")
        print(f"作者潜在变量形状: {all_latents_author.shape}")
        print(f"组合潜在变量形状: {all_combined_latents.shape}")
        print(f"领域标签形状: {all_domain_labels.shape}")
        print(f"监督标签形状: {all_sublabels.shape}")
        print(f"保存路径: {latents_dir}")

        print(f"\n统计信息:")
        print(f"审稿人潜在变量均值: {all_latents_reviewer.mean(dim=0)}")
        print(f"审稿人潜在变量标准差: {all_latents_reviewer.std(dim=0)}")
        print(f"作者潜在变量均值: {all_latents_author.mean(dim=0)}")
        print(f"作者潜在变量标准差: {all_latents_author.std(dim=0)}")
        print(f"组合潜在变量均值: {all_combined_latents.mean(dim=0)}")
        print(f"组合潜在变量标准差: {all_combined_latents.std(dim=0)}")

        # Step 1: baseline causal discovery on the sublabels (graph 1).
        # NOTE(review): rows from unlabeled data carry all-zero sublabels
        # here — confirm that including them in the baseline is intended.
        print(f"\n步骤1: 基于10个sublabels进行基线因果发现...")
        baseline_cg, baseline_adj, baseline_corr, baseline_edges = perform_baseline_causal_discovery(all_sublabels, latents_dir)

        # Step 2: CRL causal discovery on the learned concepts (graph 2).
        print(f"\n步骤2: 基于12个真实概念进行CRL因果发现...")
        crl_cg, crl_adj, crl_corr, crl_edges = perform_causal_discovery(all_combined_latents, latents_dir)

        # Step 3: compare both graphs when both discoveries produced edges.
        if baseline_edges is not None and crl_edges is not None:
            print(f"\n步骤3: 对比基线图和CRL图...")
            comparison_results = compare_causal_graphs(baseline_edges, crl_edges, latents_dir)

            summary_lines = [
                f"图相似度: {comparison_results['similarity']:.4f} ({comparison_results['similarity']*100:.2f}%)",
                f"基线图边数: {comparison_results['baseline_edges']}",
                f"CRL图边数: {comparison_results['crl_edges']}",
                f"共同边数: {comparison_results['common_edges']}",
                f"新增边数: {comparison_results['crl_only']}",
                f"涉及新变量的边数: {comparison_results['new_variable_edges']}",
            ]

            print(f"\n" + "="*60)
            print("关键指标总结")
            print("="*60)
            for line in summary_lines:
                print(line)

            # The summary contains non-ASCII text, so pin the encoding.
            metrics_path = os.path.join(latents_dir, 'key_metrics.txt')
            with open(metrics_path, 'w', encoding='utf-8') as f:
                f.write("关键指标总结\n")
                f.write("="*30 + "\n")
                for line in summary_lines:
                    f.write(line + "\n")

            print(f"\n关键指标已保存到: {metrics_path}")
        else:
            print("警告: 无法进行图对比，因为因果发现失败")
    finally:
        # Always hand the model back in training mode.
        model.train()

def weights_init(init_type='gaussian', gain=math.sqrt(2)):
    """Build an initialiser suitable for ``module.apply(...)``.

    The returned function initialises the ``weight`` of modules whose class
    name starts with ``Conv``, ``Linear`` or ``Embedding`` according to
    ``init_type``, and resets modules whose class name contains ``Norm`` to
    weight ~ N(1, 0.02) and bias 0 (when those parameters exist). All other
    modules are left untouched.

    Args:
        init_type: one of 'gaussian', 'xavier', 'xavier_uniform', 'kaiming',
            'orthogonal' or 'default' (leave weights as they are).
        gain: standard deviation used by the 'gaussian' scheme.

    Raises:
        AssertionError: from the returned function when ``init_type`` is not
            supported and a Conv/Linear/Embedding module is visited.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        # str.find() returns -1 (truthy) on a miss, so it must not be used as
        # a boolean; startswith expresses the intended prefix test.
        is_target = classname.startswith(('Conv', 'Linear', 'Embedding'))
        if is_target and hasattr(m, 'weight') and ('Norm' not in classname):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=0.02)
            elif init_type == 'xavier_uniform':
                init.xavier_uniform_(m.weight.data)
            elif init_type == 'kaiming':
                init.kaiming_uniform_(m.weight.data)
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                # Raised explicitly so the check survives ``python -O``.
                raise AssertionError("Unsupported initialization: {}".format(init_type))
        elif 'Norm' in classname:
            # Norm layers built without affine parameters have weight/bias
            # set to None (or missing altogether).
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
    return init_fun

class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """Plateau scheduler extended with an early-stopping counter and callbacks.

    ``early_stopping`` is the number of consecutive non-improving steps after
    which :meth:`step` reports ``True``; with ``None`` it never does.
    """

    def __init__(self, *args, early_stopping=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_stopping_counter = 0
        self.early_stopping = early_stopping

    def step(self, metrics, epoch=None, callback_best=None, callback_reduce=None):
        """Record one metric value and adjust the learning rate if needed.

        Args:
            metrics: Monitored value for this step.
            epoch: Explicit epoch index; defaults to ``last_epoch + 1``.
            callback_best: Called (no arguments) when ``metrics`` is a new best.
            callback_reduce: Called (no arguments) right before the LR is reduced.

        Returns:
            Whether the early-stopping counter equals ``early_stopping``.
        """
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        epoch = self.last_epoch

        if not self.is_better(metrics, self.best):
            self.num_bad_epochs += 1
            self.early_stopping_counter += 1
        else:
            self.best = metrics
            self.num_bad_epochs = 0
            self.early_stopping_counter = 0
            if callback_best is not None:
                callback_best()

        # Bad epochs seen during cooldown do not count towards patience.
        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0

        plateau_reached = self.num_bad_epochs > self.patience
        if plateau_reached:
            if callback_reduce is not None:
                callback_reduce()
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        return self.early_stopping_counter == self.early_stopping

# ==========================
# 内联 flow_blocks.py / networks.py 需要的类与函数
# ==========================
import torch.nn as nn
import torch.nn.functional as F
from munch import Munch
from nflows import transforms, distributions, flows

class XTanh(nn.Module):
    """Activation computing ``tanh(x) + 0.1 * x``."""

    def forward(self, x):
        squashed = torch.tanh(x)
        return squashed + 0.1 * x

class LinearBlock(nn.Module):
    """Linear layer followed by an optional normalisation and activation.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output features.
        norm: 'bn', 'in', 'ln' add a norm layer after the linear map; 'sn' and
            'wn' wrap the linear layer itself; 'none' adds nothing.
        activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh', 'sigmoid',
            'xtanh' or 'none'.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super().__init__()

        # The linear layer is always created first, with a bias term.
        linear = nn.Linear(input_dim, output_dim, bias=True)
        if norm == 'sn':
            linear = nn.utils.spectral_norm(linear)
        elif norm == 'wn':
            linear = nn.utils.weight_norm(linear)
        self.fc = linear

        norm_layers = {
            'bn': nn.BatchNorm1d,
            'in': nn.InstanceNorm1d,
            'ln': nn.LayerNorm,
        }
        if norm in norm_layers:
            self.norm = norm_layers[norm](output_dim)
        elif norm in ('none', 'sn', 'wn'):
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Factories are lazy so only the requested activation is instantiated.
        activation_layers = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'sigmoid': lambda: nn.Sigmoid(),
            'xtanh': lambda: XTanh(),
        }
        if activation in activation_layers:
            self.activation = activation_layers[activation]()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        """Apply linear -> (norm) -> (activation) to ``x``."""
        result = self.fc(x)
        for stage in (self.norm, self.activation):
            if stage is not None:
                result = stage(result)
        return result

class MLP(nn.Module):
    """Stack of ``LinearBlock`` layers ending in a plain linear projection.

    The stack has one input block, ``n_blk - 2`` hidden blocks of width ``dim``
    and one output block without normalisation or activation.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super().__init__()
        blocks = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        blocks.extend(
            LinearBlock(dim, dim, norm=norm, activation=activ)
            for _ in range(n_blk - 2)
        )
        blocks.append(LinearBlock(dim, output_dim, norm='none', activation='none'))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Flatten everything after the first axis and run the stack."""
        flattened = x.view(x.size(0), -1)
        return self.model(flattened)

def covariance_z_mean(z_mean):
    """Return the (biased) covariance matrix of ``z_mean`` over its first axis.

    Computed as E[z z^T] - E[z] E[z]^T, with expectations taken as means over
    axis 0.
    """
    outer_products = z_mean[:, :, None] * z_mean[:, None, :]
    second_moment = outer_products.mean(dim=0)
    first_moment = z_mean.mean(dim=0)
    mean_outer = first_moment[:, None] * first_moment[None, :]
    return second_moment - mean_outer

class LearnableAdjMat(nn.Module):
    """Holds a trainable ``latent_dim x latent_dim`` matrix.

    Entries are drawn uniformly from [0.5 - 1e-10, 0.5 + 1e-10], i.e. they
    start at (approximately) 0.5.
    """

    def __init__(self, latent_dim):
        super().__init__()
        initial = torch.empty(latent_dim, latent_dim)
        torch.nn.init.uniform_(initial, a=0.5 - 1e-10, b=0.5 + 1e-10)
        self.trainable_parameters = nn.Parameter(initial)

    def forward(self):
        """Return the raw parameter matrix."""
        return self.trainable_parameters

# BNAF components (from flow_blocks.py)
import numpy as np
class Sequential(torch.nn.Sequential):
    """Sequential container for modules that return ``(output, log_det)``.

    The outputs are chained and the log-determinant terms are summed.
    """

    def forward(self, inputs: torch.Tensor):
        total_log_det = 0.0
        current = inputs
        for layer in self._modules.values():
            current, layer_log_det = layer(current)
            total_log_det = total_log_det + layer_log_det
        return current, total_log_det

class BNAF(torch.nn.Sequential):
    """One block neural autoregressive flow with an optional residual connection.

    Every contained module is called as ``module(outputs, grad)`` and must
    return ``(outputs, grad)``; ``grad`` carries the running log-Jacobian terms.

    Args:
        *args: The modules making up the flow, in application order.
        res: ``"normal"`` for ``inputs + outputs``, ``"gated"`` for a learned
            sigmoid-gated mix of inputs and outputs, anything else for no
            residual connection.
    """

    def __init__(self, *args, res: str = None):
        super(BNAF, self).__init__(*args)
        self.res = res
        if res == "gated":
            self.gate = torch.nn.Parameter(torch.nn.init.normal_(torch.Tensor(1)))

    def forward(self, inputs: torch.Tensor):
        """Apply the flow.

        Returns:
            ``(outputs, log_det)`` where ``log_det`` is the log-Jacobian term
            summed over the last remaining axis.
        """
        outputs = inputs
        grad = None
        for module in self._modules.values():
            outputs, grad = module(outputs, grad)
            if grad.dim() != 4:
                # ``grad.shape + [1, 1]`` raised TypeError (torch.Size + list);
                # append the two singleton axes explicitly instead.
                grad = grad.view(*grad.shape, 1, 1)
        assert inputs.shape[-1] == outputs.shape[-1]

        # Drop only the two trailing axes.  A bare ``squeeze()`` would also
        # remove the batch axis for a batch of one (or a size-1 ``dim`` axis),
        # making the final sum collapse the wrong dimension.
        grad = grad.squeeze(-1).squeeze(-1)

        if self.res == "normal":
            return inputs + outputs, torch.nn.functional.softplus(grad).sum(-1)
        if self.res == "gated":
            gate = self.gate.sigmoid()
            log_det = (
                torch.nn.functional.softplus(grad + self.gate)
                - torch.nn.functional.softplus(self.gate)
            ).sum(-1)
            return gate * outputs + (1 - gate) * inputs, log_det
        return outputs, grad.sum(-1)

    def _get_name(self):
        return "BNAF(res={})".format(self.res)

class Permutation(torch.nn.Module):
    """Reorder the feature (second) axis of a batch.

    Args:
        in_features: Number of features being permuted.
        p: ``None`` for a random permutation drawn with NumPy, the string
            ``"flip"`` to reverse the feature order, or an explicit index
            sequence (list or array) used as-is.
    """

    def __init__(self, in_features: int, p: list = None):
        super(Permutation, self).__init__()
        self.in_features = in_features
        if p is None:
            self.p = np.random.permutation(in_features)
        elif isinstance(p, str) and p == "flip":
            # The isinstance guard matters: comparing a NumPy index array to a
            # string is element-wise on recent NumPy, and using that array in
            # an ``elif`` raises "truth value of an array is ambiguous".
            self.p = list(reversed(range(in_features)))
        else:
            self.p = p

    def forward(self, inputs: torch.Tensor):
        """Return the permuted inputs and a log-determinant of 0."""
        return inputs[:, self.p], 0

    def __repr__(self):
        return "Permutation(in_features={}, p={})".format(self.in_features, self.p)

class MaskedWeight(torch.nn.Module):
    """Linear layer with a block lower-triangular weight, used inside BNAF.

    The weight matrix is viewed as a ``dim x dim`` grid of blocks. Blocks on
    the diagonal are exponentiated (hence strictly positive), blocks below the
    diagonal are used as-is, and blocks above the diagonal are masked out.
    """
    def __init__(self, in_features: int, out_features: int, dim: int, bias: bool = True):
        """
        Args:
            in_features: Number of input features.
            out_features: Number of output features.
            dim: Number of blocks per side of the block grid.
            bias: If False, the bias is the integer constant 0.
        """
        super(MaskedWeight, self).__init__()
        self.in_features, self.out_features, self.dim = in_features, out_features, dim
        weight = torch.zeros(out_features, in_features)
        # Xavier-initialise, for each block row i, the columns up to and including
        # the i-th diagonal block; everything to the right stays zero.
        for i in range(dim):
            weight[i * out_features // dim : (i + 1) * out_features // dim, 0 : (i + 1) * in_features // dim] = torch.nn.init.xavier_uniform_(torch.Tensor(out_features // dim, (i + 1) * in_features // dim))
        self._weight = torch.nn.Parameter(weight)
        # Per-output-row log-scale, initialised as log(U) with U drawn by uniform_.
        self._diag_weight = torch.nn.Parameter(torch.nn.init.uniform_(torch.Tensor(out_features, 1)).log())
        self.bias = (torch.nn.Parameter(torch.nn.init.uniform_(torch.Tensor(out_features), -1 / math.sqrt(out_features), 1 / math.sqrt(out_features))) if bias else 0)
        # mask_d selects the diagonal blocks.
        mask_d = torch.zeros_like(weight)
        for i in range(dim):
            mask_d[i * (out_features // dim) : (i + 1) * (out_features // dim), i * (in_features // dim) : (i + 1) * (in_features // dim),] = 1
        self.register_buffer("mask_d", mask_d)
        # mask_o selects the blocks strictly below the diagonal: for block row i,
        # every column from the start of the i-th block onwards is zeroed.
        mask_o = torch.ones_like(weight)
        for i in range(dim):
            mask_o[i * (out_features // dim) : (i + 1) * (out_features // dim), i * (in_features // dim) :,] = 0
        self.register_buffer("mask_o", mask_o)
    def get_weights(self):
        """Return the effective weight and the log of its diagonal blocks.

        Returns:
            A pair ``(w, wpl)``: ``w`` is the transposed, row-normalised weight
            of shape ``(in_features, out_features)``; ``wpl`` holds the
            diagonal-block log-weights, reshaped to
            ``(dim, in_features // dim, out_features // dim)``.
        """
        # exp() on diagonal blocks, raw values on the blocks below the diagonal.
        w = torch.exp(self._weight) * self.mask_d + self._weight * self.mask_o
        w_squared_norm = (w ** 2).sum(-1, keepdim=True)
        # Normalise each output row to unit L2 norm, then scale by exp(_diag_weight).
        w = self._diag_weight.exp() * w / w_squared_norm.sqrt()
        # Log of the normalised weight; only the diagonal-block entries are kept below.
        wpl = self._diag_weight + self._weight - 0.5 * torch.log(w_squared_norm)
        return w.t(), wpl.t()[self.mask_d.bool().t()].view(self.dim, self.in_features // self.dim, self.out_features // self.dim)
    def forward(self, inputs, grad: torch.Tensor = None):
        """Apply the masked linear map and update the running log-Jacobian.

        Args:
            inputs: Batch multiplied by the weight; its first axis is the batch size.
            grad: Log-Jacobian accumulated by the preceding layers, or None for
                the first layer.

        Returns:
            ``(inputs @ w + bias, grad_out)``. Without ``grad``, ``grad_out`` is
            this layer's block log-weights repeated over the batch; otherwise the
            two are combined with a logsumexp over the last axis.
        """
        w, wpl = self.get_weights()
        g = wpl.transpose(-2, -1).unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1)
        return (inputs.matmul(w) + self.bias, torch.logsumexp(g.unsqueeze(-2) + grad.transpose(-2, -1).unsqueeze(-3), -1) if grad is not None else g)
    def __repr__(self):
        # bias is the int 0 when disabled, so "not int" means a bias parameter exists.
        return "MaskedWeight(in_features={}, out_features={}, dim={}, bias={})".format(self.in_features, self.out_features, self.dim, not isinstance(self.bias, int))

class Tanh(torch.nn.Tanh):
    """Tanh that also propagates the log-derivative of the activation.

    The log-derivative ``log(1 - tanh(x)^2)`` is evaluated in the form
    ``-2 * (x - log(2) + softplus(-2x))``.
    """

    def forward(self, inputs, grad: torch.Tensor = None):
        log_derivative = -2 * (inputs - math.log(2) + torch.nn.functional.softplus(-2 * inputs))
        if grad is not None:
            # Add to the log-Jacobian accumulated so far, in its shape.
            log_derivative = log_derivative.view(grad.shape) + grad
        return torch.tanh(inputs), log_derivative

class BNAFModel(torch.nn.Module):
    """Stack of ``n_flows`` BNAF blocks separated by feature-flip permutations.

    Each block is: MaskedWeight(n_dims -> n_dims*hidden_dim), Tanh,
    ``n_layers - 1`` hidden MaskedWeight/Tanh pairs, and a final
    MaskedWeight(n_dims*hidden_dim -> n_dims). Every block except the last uses
    the ``residual`` connection and is followed by a flip permutation.
    """

    def __init__(self, n_dims=4, n_flows=5, n_layers=1, hidden_dim=10, residual='gated'):
        super().__init__()
        width = n_dims * hidden_dim
        final_index = n_flows - 1
        stages = []
        for flow_index in range(n_flows):
            is_final = flow_index == final_index
            # Hidden layers are built before the entry/exit layers so that the
            # random initialisation order is unchanged.
            hidden = []
            for _ in range(n_layers - 1):
                hidden.append(MaskedWeight(width, width, dim=n_dims))
                hidden.append(Tanh())
            entry = [MaskedWeight(n_dims, width, dim=n_dims), Tanh()]
            exit_layer = [MaskedWeight(width, n_dims, dim=n_dims)]
            stages.append(
                BNAF(*(entry + hidden + exit_layer), res=None if is_final else residual)
            )
            if not is_final:
                stages.append(Permutation(n_dims, "flip"))
        self.model = Sequential(*stages)
        self.base_distribution = distributions.StandardNormal(shape=[n_dims])

    def forward(self, x):
        """Return whatever the wrapped ``Sequential`` returns for ``x``."""
        return self.model(x)

# Sigmoid flow utils (subset needed by MarkovFlowVAE)
from torch.autograd import Variable
# Numerical helpers used by the sigmoid-flow code.
delta = 1e-6  # small positive constant used as an offset/margin
logsigmoid = lambda x: -F.softplus(-x)  # log(sigmoid(x)) expressed through softplus
log = lambda x: torch.log(x*1e2)-np.log(1e2)  # log(x), evaluated as log(100 * x) - log(100)
softplus_ = nn.Softplus()
softplus = lambda x: softplus_(x) + delta  # softplus shifted up by delta
def softmax(x, dim=-1):
    """Softmax of ``x`` along ``dim``, with the maximum subtracted before exp."""
    peak, _ = x.max(dim=dim, keepdim=True)
    weights = torch.exp(x - peak)
    return weights / weights.sum(dim=dim, keepdim=True)

class DenseSigmoidFlow(nn.Module):
    """One dense sigmoidal flow layer whose parameters are supplied per call.

    The layer owns only the base mixing logits ``u_`` and ``w_``; the slopes,
    offsets and logit shifts come from ``dsparams`` in ``forward``.
    """
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(DenseSigmoidFlow, self).__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.act_a = lambda x: softplus(x)
        self.act_b = lambda x: x
        self.act_w = lambda x: softmax(x, dim=3)
        self.act_u = lambda x: softmax(x, dim=3)
        self.u_ = torch.nn.Parameter(torch.Tensor(hidden_dim, in_dim))
        self.w_ = torch.nn.Parameter(torch.Tensor(out_dim, hidden_dim))
        # Width of the dsparams slice consumed by this layer:
        # hidden_dim each for a, b and the w-shift, plus in_dim for the u-shift.
        self.num_params = 3 * hidden_dim + in_dim
        self.reset_parameters()
    def reset_parameters(self):
        """Initialise ``u_`` and ``w_`` uniformly in [-0.001, 0.001]."""
        self.u_.data.uniform_(-0.001, 0.001)
        self.w_.data.uniform_(-0.001, 0.001)
    def forward(self, x, dsparams, logdet):
        """Transform ``x`` and update the running log-determinant.

        Args:
            x: Input; a trailing axis is appended when it is 2-D.
            dsparams: 3-D tensor whose last axis is laid out as
                [a | b | w-shift | ... u-shift (last in_dim entries)].
            logdet: Running log-determinant from the previous layer
                (indexed as a 4-D tensor here).

        Returns:
            ``(xnew, logdet)``: the transformed values and the updated
            log-determinant.
        """
        if len(x.size()) == 2:
            x = x.unsqueeze(-1)
        # Softplus-inverse of (1 - delta): a zero raw slope maps to a == 1.
        inv = np.log(np.exp(1 - delta) - 1)
        ndim = self.hidden_dim
        # Mixing logits = learned base + per-sample shift taken from dsparams.
        pre_u = self.u_[None, None, :, :] + dsparams[:, :, -self.in_dim:][:, :, None, :]
        pre_w = self.w_[None, None, :, :] + dsparams[:, :, 2 * ndim:3 * ndim][:, :, None, :]
        a = self.act_a(dsparams[:, :, 0 * ndim:1 * ndim] + inv)  # positive slopes
        b = self.act_b(dsparams[:, :, 1 * ndim:2 * ndim])  # offsets (identity activation)
        w = self.act_w(pre_w)  # softmax over the hidden axis
        u = self.act_u(pre_u)  # softmax over the input axis
        # Convex combination of inputs, scaled and shifted, then squashed.
        pre_sigm = torch.sum(u * a[:, :, :, None] * x[:, :, None, :], 3) + b
        sigm = torch.sigmoid(pre_sigm)
        # Convex combination of sigmoids, pulled slightly away from 0 and 1
        # so the logit below stays finite.
        x_pre = torch.sum(w * sigm[:, :, None, :], dim=3)
        x_pre_clipped = x_pre * (1 - delta) + delta * 0.5
        x_ = log(x_pre_clipped) - log(1 - x_pre_clipped)  # logit
        xnew = x_
        # Log-Jacobian of this layer, assembled in log space:
        # log w + log sigmoid'(pre_sigm) + log a, combined with log u.
        logj = F.log_softmax(pre_w, dim=3) + logsigmoid(pre_sigm[:, :, None, :]) + logsigmoid(-pre_sigm[:, :, None, :]) + log(a[:, :, None, :])
        logj = logj[:, :, :, :, None] + F.log_softmax(pre_u, dim=3)[:, :, None, :, :]
        logj = log_sum_exp(logj, 3).sum(3)
        # Account for the clipping scale and the derivative of the logit.
        logdet_ = logj + np.log(1 - delta) - (log(x_pre_clipped) + log(-x_pre_clipped + 1))[:, :, :, None]
        # Chain with the incoming log-determinant via a log-space matrix product.
        logdet = log_sum_exp(logdet_[:, :, :, :, None] + logdet[:, :, None, :, :], 3).sum(3)
        return xnew, logdet

def log_sum_exp(A, axis=-1, sum_op=torch.sum):
    """Compute ``log(sum_op(exp(A)))`` along ``axis``, keeping that axis.

    The maximum along ``axis`` is subtracted before exponentiating and added
    back afterwards.
    """
    def _axis_max(t):
        return t.max(axis)[0]

    A_max = oper(A, _axis_max, axis, True)

    def _shifted_exp_sum(t):
        return sum_op(torch.exp(t - A_max), axis)

    return torch.log(oper(A, _shifted_exp_sum, axis, True)) + A_max

def oper(array,oper,axis=-1,keepdims=False):
    """Apply the reduction ``oper`` to ``array``.

    With ``keepdims`` the result is viewed in the shape of ``array`` with
    ``axis`` replaced by an inferred (-1) dimension.
    """
    reduced = oper(array)
    if not keepdims:
        return reduced
    target_shape = list(array.size())
    target_shape[axis] = -1
    return reduced.view(*target_shape)

class DDSF(nn.Module):
    """Chain of ``DenseSigmoidFlow`` layers mapping one scalar to one scalar.

    ``num_params`` is the total width of the parameter vector that
    ``forward`` expects, i.e. the sum of each layer's ``num_params``.
    """

    def __init__(self, n_blocks=1, hidden_dim=16):
        super().__init__()
        if n_blocks == 1:
            layer_dims = [(1, 1)]
        else:
            layer_dims = (
                [(1, hidden_dim)]
                + [(hidden_dim, hidden_dim)] * (n_blocks - 2)
                + [(hidden_dim, 1)]
            )
        layers = [
            DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=d_in, out_dim=d_out)
            for d_in, d_out in layer_dims
        ]
        self.model = nn.Sequential(*layers)
        self.num_params = sum(layer.num_params for layer in self.model)

    def forward(self, x, dsparams):
        """Run every layer on ``x``, feeding each its slice of ``dsparams``.

        Returns:
            The transformed values and the log-determinant flattened to two
            dimensions.
        """
        x = x.unsqueeze(2)
        logdet = torch.zeros([len(x), 1, 1, 1]).to(x.device)
        offset = 0
        for layer in self.model:
            stop = offset + layer.num_params
            x, logdet = layer(x, dsparams[:, :, offset:stop], logdet)
            offset = stop
        return x.squeeze(2), logdet.view(len(logdet), -1)

# ==========================
# 内联 networks.py 两个模型
# ==========================
class MarkovMLPVAE(nn.Module):
    """VAE whose latent prior is a domain-conditioned linear structural model.

    The encoder/decoder are MLPs. Per-domain MLPs produce a bias and a
    ``latent_dim x latent_dim`` scale matrix; the strictly upper-triangular part
    of that matrix (masked by a thresholded, domain-averaged magnitude) is used
    to turn latents ``z`` into residuals ``epsilons`` that are scored under a
    standard Normal or Laplace base distribution.
    """
    def __init__(self, config, data_config):
        """Build all sub-networks from ``config`` and ``data_config``.

        Reads from ``config``: embed_dim, data_dim, latent_dim, hidden_dim,
        n_layers, norm, activ, bias_dim, bias_hidden_dim, bias_n_layers,
        bias_norm, bias_activ, alpha_threshold, x_std, init, noise.
        Reads from ``data_config``: n_domains.
        """
        super(MarkovMLPVAE, self).__init__()
        self.domain_embedding = nn.Embedding(data_config.n_domains, config.embed_dim)
        self.encoder_mu = MLP(config.data_dim, config.latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.encoder_logvar = MLP(config.data_dim, config.latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.decoder = MLP(config.latent_dim, config.data_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.domain_bias_net = MLP(config.embed_dim, config.bias_dim, config.bias_hidden_dim, config.bias_n_layers, config.bias_norm, config.bias_activ)
        self.domain_scale_net = MLP(config.embed_dim, config.latent_dim**2, config.bias_hidden_dim, config.bias_n_layers, config.bias_norm, config.bias_activ)
        self.domain_self_scale_net = MLP(config.embed_dim, config.latent_dim, config.bias_hidden_dim, config.bias_n_layers, config.bias_norm, config.bias_activ)
        self.alpha_net = LearnableAdjMat(config.latent_dim)
        self.alpha_threshold = config.alpha_threshold
        self.x_std = config.x_std
        # NOTE(review): this flow is constructed but not used by any method of this class.
        self.flow = BNAFModel(4,5,0,10)
        self.domain_embedding.apply(weights_init(config.init))
        self.encoder_mu.apply(weights_init(config.init))
        self.encoder_logvar.apply(weights_init(config.init))
        self.decoder.apply(weights_init(config.init))
        self.domain_bias_net.apply(weights_init(config.init))
        self.domain_scale_net.apply(weights_init(config.init))
        self.domain_self_scale_net.apply(weights_init(config.init))
        self.n_domains = data_config.n_domains
        self.noise = config.noise
        # Zero the last layer of the scale and bias nets so both start at exactly 0.
        torch.nn.init.zeros_(self.domain_scale_net.model[-1].fc.weight.data)
        torch.nn.init.zeros_(self.domain_scale_net.model[-1].fc.bias.data)
        torch.nn.init.zeros_(self.domain_bias_net.model[-1].fc.weight.data)
        torch.nn.init.zeros_(self.domain_bias_net.model[-1].fc.bias.data)
    def encode(self, x, y):
        """Return ``(mu, logvar)`` computed from ``x`` only; ``y`` is ignored."""
        xy = x
        mu = self.encoder_mu(xy)
        logvar = self.encoder_logvar(xy)
        return mu, logvar
    def decode(self, x):
        """Map latents back to data space with the decoder MLP."""
        return self.decoder(x)
    def reparameterize(self, mu, logvar, noise):
        """Draw a reparameterised sample with scale ``exp(logvar / 2)``.

        ``noise == 'gaussian'`` samples from a Normal; any other value samples
        from a Laplace via the inverse CDF of a uniform on (-0.5, 0.5).
        """
        if noise == 'gaussian':
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps*std + mu
        else:
            scale = torch.exp(logvar/2)
            u = Variable(torch.rand(mu.size()).type_as(mu.data)) - 0.5
            # 1e-8 keeps the log argument positive when |u| is close to 0.5.
            sample = mu - scale * torch.sign(u) * torch.log(1 - 2 * torch.abs(u) + 1e-8)
            return sample
    def forward(self, x, y, threshold):
        """Compute the ELBO and auxiliary losses for a batch.

        Args:
            x: Input batch.
            y: Domain index per sample (cast to long here).
            threshold: Cut-off applied to the domain-averaged |scale| to obtain
                the binary adjacency mask.

        Returns:
            A ``Munch`` with elbo, loss_sparsity, mu, z, epsilons, adj_mat,
            loss_kld, loss_rec, scale, self_scale, bias, loss_dip and ele_kld.
        """
        y2 = F.one_hot(y.long(), self.n_domains).float()
        mu, logvar = self.encode(x, y2)
        y = self.domain_embedding(y.long().squeeze())
        # NOTE(review): the sample is always Laplace here, regardless of
        # self.noise, while q_dist below follows self.noise — confirm intended.
        hidden = self.reparameterize(mu, logvar, 'laplace') if self.training else mu
        z, logdet_flow = hidden, 0
        xhat = self.decode(z)
        # NOTE(review): unmasked_alpha is not used further in this method.
        unmasked_alpha = self.alpha_net()
        domain_embeddings = y
        bias = self.domain_bias_net(domain_embeddings)
        scale = (self.domain_scale_net(domain_embeddings))
        self_scale = (self.domain_self_scale_net(domain_embeddings))
        latent_dim = z.size()[-1]
        B = len(z)
        bias = bias.unsqueeze(1)
        # Per-sample scale matrix: add 1, keep the strictly upper triangle.
        scale = scale.view(len(scale), latent_dim, latent_dim)
        scale = (torch.triu(scale+1, 1))
        # Evaluate the scale matrix for every domain to build a shared mask.
        y_all = torch.arange(self.n_domains).long().to(x.device)
        y_all_emb = self.domain_embedding(y_all)
        scale_all = self.domain_scale_net(y_all_emb).view(self.n_domains, latent_dim, latent_dim)
        scale_all = scale_all + 1.
        diag_scale_all = torch.triu(scale_all, 1)
        mean_scale_all = torch.mean(torch.abs(diag_scale_all), 0, keepdim=True)
        # Binary mask: keep an edge when its mean magnitude exceeds the threshold.
        alpha = (mean_scale_all>threshold).float()
        # Residuals: z minus its masked linear combination and the domain bias.
        # presumably config.bias_dim equals latent_dim so this broadcasts — TODO confirm
        epsilons = (z - torch.bmm(z.unsqueeze(1), scale*alpha).squeeze(1) -bias.squeeze(1))
        # NOTE(review): q_dist/base_dist are undefined for any other noise value.
        if self.noise == 'gaussian':
            q_dist = D.Normal(mu, torch.exp(logvar / 2))
            base_dist = D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        elif self.noise == 'laplace':
            q_dist = D.Laplace(mu, torch.exp(logvar/2))
            base_dist = D.Laplace(0, 1)
        log_qz = q_dist.log_prob(hidden)
        batch_log_qz = torch.sum(log_qz, 1)
        log_p_eps = base_dist.log_prob(epsilons.squeeze())
        # No log-determinant correction is applied in this model.
        log_pz = log_p_eps - 0
        batch_log_pz = torch.sum(log_pz, 1)
        # Per-latent-dimension KL estimate, averaged over the batch.
        ele_kld = torch.mean(log_qz-log_pz, 0)
        # Squared deviation of each dimension's KL from the smallest one.
        loss_dip = ((ele_kld-ele_kld.min().detach())**2).sum()
        loss_kld = torch.sum((ele_kld))
        self.x_dist = D.Normal(xhat, self.x_std)
        log_px = self.x_dist.log_prob(x).sum(-1)
        elbo = (log_px + batch_log_pz - batch_log_qz).mean()
        # Mean absolute value of the (shifted by +1) scale matrices of all domains.
        loss_sparsity = torch.mean(torch.abs(scale_all))
        loss_rec = -log_px.mean()
        output_dict = Munch(elbo=elbo, loss_sparsity=loss_sparsity, mu=mu, z=z, epsilons=epsilons, adj_mat=alpha, loss_kld=loss_kld, loss_rec=loss_rec, scale=mean_scale_all, self_scale=self_scale, bias=bias, loss_dip=loss_dip, ele_kld=ele_kld)
        return output_dict

class MarkovFlowVAE(nn.Module):
    """VAE whose latent prior is a conditional sigmoid flow over each latent.

    A learnable adjacency matrix (strictly lower-triangular, thresholded)
    selects which latents condition each latent's flow; the flow parameters
    are produced from those masked latents and the one-hot domain label.
    """
    def __init__(self, config, data_config):
        """Build all sub-networks from ``config`` and ``data_config``.

        Reads from ``config``: embed_dim, data_dim, latent_dim, hidden_dim,
        n_layers, norm, activ, alpha_threshold, x_std, init, noise.
        Reads from ``data_config``: n_domains.
        """
        super(MarkovFlowVAE, self).__init__()
        self.domain_embedding = nn.Embedding(data_config.n_domains, config.embed_dim)
        # The encoders take the data concatenated with the one-hot domain label.
        self.encoder_mu = MLP(config.data_dim+data_config.n_domains, config.latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.encoder_logvar = MLP(config.data_dim+data_config.n_domains, config.latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.decoder = MLP(config.latent_dim, config.data_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.alpha_net = LearnableAdjMat(config.latent_dim)
        self.alpha_threshold = config.alpha_threshold
        self.x_std = config.x_std
        self.flow = DDSF(1, 64)
        # transformer should take latent-size vectors (cond_input has shape [B*latent_dim, latent_dim])
        self.transformer = MLP(config.latent_dim, config.embed_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        # Maps [transformed parents | one-hot domain] to the flow's parameter vector.
        self.flow_param_net = MLP(config.embed_dim+data_config.n_domains, self.flow.num_params, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.domain_embedding.apply(weights_init(config.init))
        self.encoder_mu.apply(weights_init(config.init))
        self.encoder_logvar.apply(weights_init(config.init))
        self.decoder.apply(weights_init(config.init))
        self.transformer.apply(weights_init(config.init))
        self.flow_param_net.apply(weights_init(config.init))
        self.n_domains = data_config.n_domains
        self.noise = config.noise
    def encode(self, x, y):
        """Return ``(mu, logvar)`` from ``x`` concatenated with ``y``.

        NaN/inf entries of ``x`` are replaced by 0 / +-1e6 first.
        """
        x = torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
        xy = torch.cat([x,y], 1)
        mu = self.encoder_mu(xy)
        logvar = self.encoder_logvar(xy)
        return mu, logvar
    def decode(self, x):
        """Map latents back to data space with the decoder MLP."""
        return self.decoder(x)
    def reparameterize(self, mu, logvar, noise):
        """Draw a reparameterised sample; ``logvar`` is clamped to [-10, 10].

        ``noise == 'gaussian'`` samples from a Normal; any other value samples
        from a Laplace via the inverse CDF of a uniform on (-0.5, 0.5).
        """
        logvar = torch.clamp(logvar, -10.0, 10.0)
        if noise == 'gaussian':
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps*std + mu
        else:
            scale = torch.exp(logvar/2)
            u = Variable(torch.rand(mu.size()).type_as(mu.data)) - 0.5
            # 1e-8 keeps the log argument positive when |u| is close to 0.5.
            sample = mu - scale * torch.sign(u) * torch.log(1 - 2 * torch.abs(u) + 1e-8)
            return sample
    def forward(self, x, y, threshold):
        """Compute the ELBO and auxiliary losses for a batch.

        Args:
            x: Input batch.
            y: Domain index per sample (cast to long here).
            threshold: Not used in this method; the mask uses
                ``self.alpha_threshold`` instead.

        Returns:
            A ``Munch`` with elbo, loss_sparsity, mu, z, epsilons, adj_mat,
            loss_kld, loss_rec, loss_dip and ele_kld.
        """
        y2 = F.one_hot(y.long(), self.n_domains).float()
        mu, logvar = self.encode(x, y2)
        # NOTE(review): the embedding is computed but not used after being
        # assigned to domain_embeddings below.
        y = self.domain_embedding(y.long().squeeze())
        hidden = self.reparameterize(mu, logvar, self.noise) if self.training else mu
        z, logdet_flow = hidden, 0
        xhat = self.decode(z)
        unmasked_alpha = self.alpha_net()
        domain_embeddings = y
        latent_dim = z.size()[-1]
        B = len(z)
        # Adjacency: strictly lower triangle, entries <= alpha_threshold zeroed.
        alpha = unmasked_alpha
        alpha = torch.tril(alpha, -1)
        alpha = alpha * (alpha>self.alpha_threshold).float()
        reshape_z = z.view(-1, 1)
        # For every sample, repeat z once per latent and weight it by that
        # latent's adjacency row, giving one conditioning vector per (sample, latent).
        cond_input = torch.repeat_interleave(z, latent_dim, dim=0).view(B, latent_dim, latent_dim)
        cond_input = alpha.unsqueeze(0) * cond_input
        cond_input = cond_input.view(B*latent_dim, latent_dim)
        cond_input = torch.nan_to_num(cond_input, nan=0.0, posinf=1e6, neginf=-1e6)
        cond_input = self.transformer(cond_input)
        cond_embedding = torch.repeat_interleave(y2, latent_dim, dim=0)
        cond_input = torch.cat([cond_input, cond_embedding], 1)
        flow_params = self.flow_param_net(cond_input).unsqueeze(1)
        # Each scalar latent is pushed through the flow with its own parameters.
        epsilons, neg_logdet = self.flow(reshape_z, flow_params)
        logdet = -neg_logdet
        epsilons = epsilons.view(B, latent_dim)
        logdet = logdet.view(B, latent_dim)
        # NOTE(review): q_dist uses the unclamped logvar whereas reparameterize
        # clamps it to [-10, 10]; the two disagree outside that range.
        # q_dist/base_dist are undefined for any other noise value.
        if self.noise == 'gaussian':
            q_dist = D.Normal(mu, torch.exp(logvar / 2))
            base_dist = D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        elif self.noise == 'laplace':
            q_dist = D.Laplace(mu, torch.exp(logvar/2))
            base_dist = D.Laplace(0, 1)
        log_qz = q_dist.log_prob(hidden)
        batch_log_qz = torch.sum(log_qz, 1)
        log_p_eps = base_dist.log_prob(epsilons.squeeze())
        # Prior log-density of z: base density of epsilon minus the (negated) flow term.
        log_pz = log_p_eps - logdet
        batch_log_pz = torch.sum(log_pz, 1)
        # Per-latent-dimension KL estimate, averaged over the batch.
        ele_kld = torch.mean(log_qz-log_pz, 0)
        # Squared deviation of each dimension's KL from the smallest one.
        loss_dip = ((ele_kld-ele_kld.min().detach())**2).sum()
        loss_kld = torch.sum((ele_kld))
        self.x_dist = D.Normal(xhat, self.x_std)
        log_px = self.x_dist.log_prob(x).sum(-1)
        elbo = (log_px + batch_log_pz - batch_log_qz).mean()
        # Sparsity penalty on the full (unmasked) adjacency parameters.
        loss_sparsity = torch.mean(torch.abs(unmasked_alpha))
        loss_rec = -log_px.mean()
        output_dict = Munch(elbo=elbo, loss_sparsity=loss_sparsity, mu=mu, z=z, epsilons=epsilons, adj_mat=alpha, loss_kld=loss_kld, loss_rec=loss_rec, loss_dip=loss_dip, ele_kld=ele_kld)
        return output_dict

class MultiModalMarkovFlowVAE(nn.Module):
    """Two-modality (reviewer / author) flow-based VAE.

    Each modality has its own MLP encoders (mean and log-variance) and MLP
    decoder. A linear head per modality maps the latents to supervised
    concept predictions, and a DDSF flow conditioned on adjacency-masked
    latents plus a one-hot domain code is used to evaluate the latent prior.
    """

    def __init__(self, config, data_config):
        """Build encoders, decoders, supervision heads, adjacency net and flow.

        Args:
            config: model config; this constructor reads ``embed_dim``,
                ``data_dim``, ``hidden_dim``, ``n_layers``, ``norm``,
                ``activ``, ``alpha_threshold``, ``x_std``, ``init`` and
                ``noise`` from it.
            data_config: data config; only ``n_domains`` is read here.
        """
        super(MultiModalMarkovFlowVAE, self).__init__()
        self.domain_embedding = nn.Embedding(data_config.n_domains, config.embed_dim)
        
        # Define latent dimensions for each modality
        self.author_latent_dim = 8  # Author modality gets 8 latent variables (7 known + 1 unknown)
        self.reviewer_latent_dim = 4  # Reviewer modality gets 4 latent variables (3 known + 1 unknown)
        self.total_latent_dim = self.author_latent_dim + self.reviewer_latent_dim  # Total 12
        
        # Two encoders for two modalities with different latent dimensions
        # (encoder inputs are the data concatenated with the one-hot domain code, hence data_dim + n_domains)
        self.encoder_mu_reviewer = MLP(config.data_dim+data_config.n_domains, self.reviewer_latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.encoder_logvar_reviewer = MLP(config.data_dim+data_config.n_domains, self.reviewer_latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.encoder_mu_author = MLP(config.data_dim+data_config.n_domains, self.author_latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.encoder_logvar_author = MLP(config.data_dim+data_config.n_domains, self.author_latent_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        
        # Two decoders for two modalities
        self.decoder_reviewer = MLP(self.reviewer_latent_dim, config.data_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.decoder_author = MLP(self.author_latent_dim, config.data_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        
        # Weak supervision layers: linear transformations from latents to 10 features
        # Note: We have 8 author latents but only supervise 7 known concepts, 1 unknown
        # We have 4 reviewer latents but only supervise 3 known concepts, 1 unknown
        self.author_supervision_layer = nn.Linear(self.author_latent_dim, 7)  # 7 known features from author
        self.reviewer_supervision_layer = nn.Linear(self.reviewer_latent_dim, 3)  # 3 known features from reviewer
        
        # Use the larger latent dimension for alpha_net
        self.alpha_net = LearnableAdjMat(max(self.author_latent_dim, self.reviewer_latent_dim))
        self.alpha_threshold = config.alpha_threshold
        self.x_std = config.x_std
        self.flow = DDSF(1, 64)
        
        # transformer should take latent-size vectors (cond_input has shape [B*latent_dim, latent_dim])
        # For multi-modal, we need to handle total_latent_dim input
        self.transformer = MLP(self.total_latent_dim, config.embed_dim, config.hidden_dim, config.n_layers, config.norm, config.activ)
        self.flow_param_net = MLP(config.embed_dim+data_config.n_domains, self.flow.num_params, config.hidden_dim, config.n_layers, config.norm, config.activ)
        
        # Initialize weights
        self.domain_embedding.apply(weights_init(config.init))
        self.encoder_mu_reviewer.apply(weights_init(config.init))
        self.encoder_logvar_reviewer.apply(weights_init(config.init))
        self.encoder_mu_author.apply(weights_init(config.init))
        self.encoder_logvar_author.apply(weights_init(config.init))
        self.decoder_reviewer.apply(weights_init(config.init))
        self.decoder_author.apply(weights_init(config.init))
        self.author_supervision_layer.apply(weights_init(config.init))
        self.reviewer_supervision_layer.apply(weights_init(config.init))
        self.transformer.apply(weights_init(config.init))
        self.flow_param_net.apply(weights_init(config.init))
        
        self.n_domains = data_config.n_domains
        self.noise = config.noise
        
    def encode(self, x_reviewer, x_author, y):
        """Encode both modalities into posterior means and log-variances.

        Non-finite input values are replaced (NaN -> 0, +/-inf -> +/-1e6)
        before each input is concatenated with ``y`` along dim 1.

        Returns:
            ``(mu_reviewer, logvar_reviewer, mu_author, logvar_author)``;
            both log-variances are clamped to ``[-10, 10]``.
        """
        x_reviewer = torch.nan_to_num(x_reviewer, nan=0.0, posinf=1e6, neginf=-1e6)
        x_author = torch.nan_to_num(x_author, nan=0.0, posinf=1e6, neginf=-1e6)
        
        xy_reviewer = torch.cat([x_reviewer, y], 1)
        xy_author = torch.cat([x_author, y], 1)
        
        mu_reviewer = self.encoder_mu_reviewer(xy_reviewer)
        logvar_reviewer = self.encoder_logvar_reviewer(xy_reviewer)
        mu_author = self.encoder_mu_author(xy_author)
        logvar_author = self.encoder_logvar_author(xy_author)
        
        # Keep logvar within a reasonable range
        logvar_reviewer = torch.clamp(logvar_reviewer, -10.0, 10.0)
        logvar_author = torch.clamp(logvar_author, -10.0, 10.0)
        
        return mu_reviewer, logvar_reviewer, mu_author, logvar_author
        
    def decode(self, z_reviewer, z_author):
        """Decode each modality's latents with its own decoder; returns ``(xhat_reviewer, xhat_author)``."""
        xhat_reviewer = self.decoder_reviewer(z_reviewer)
        xhat_author = self.decoder_author(z_author)
        return xhat_reviewer, xhat_author
        
    def reparameterize(self, mu, logvar, noise):
        """Draw a reparameterized sample from the posterior.

        ``noise == 'gaussian'`` returns ``mu + eps * exp(0.5 * logvar)`` with
        standard-normal ``eps``. Any other value takes the second branch,
        which transforms a uniform draw on ``[-0.5, 0.5)`` with scale
        ``exp(logvar / 2)`` (looks like inverse-CDF Laplace sampling; confirm).
        """
        logvar = torch.clamp(logvar, -10.0, 10.0)
        if noise == 'gaussian':
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps*std + mu
        else:
            scale = torch.exp(logvar/2)
            u = Variable(torch.rand(mu.size()).type_as(mu.data)) - 0.5
            # 1e-8 keeps the log argument strictly positive when |u| reaches 0.5
            sample = mu - scale * torch.sign(u) * torch.log(1 - 2 * torch.abs(u) + 1e-8)
            return sample
            
    def forward(self, x_reviewer, x_author, y, threshold, sublabels=None):
        """Run encode -> sample -> decode -> flow prior and assemble all losses.

        Args:
            x_reviewer: reviewer-modality input batch.
            x_author: author-modality input batch.
            y: domain indices; one-hot encoded with ``self.n_domains`` classes.
                Presumably shape ``[B]`` so the one-hot is ``[B, n_domains]``;
                confirm against the caller.
            threshold: NOTE(review): not read anywhere in this method;
                ``self.alpha_threshold`` is applied instead.
            sublabels: optional targets for the weak-supervision MSE, compared
                against the ``[B, 10]`` noisy concept predictions.

        Returns:
            Munch with the ELBO, the individual loss terms (``loss_sparsity``,
            ``loss_kld``, ``loss_rec``, ``loss_dip``, ``loss_supervision``),
            per-modality means / latents / flow noises / reconstructions, the
            thresholded adjacency matrix and the concept tensors.
        """
        y2 = F.one_hot(y.long(), self.n_domains).float()
        mu_reviewer, logvar_reviewer, mu_author, logvar_author = self.encode(x_reviewer, x_author, y2)
        y = self.domain_embedding(y.long().squeeze())
        
        # Reparameterize both modalities
        # (sampling only happens in training mode; in eval mode the posterior mean is used)
        hidden_reviewer = self.reparameterize(mu_reviewer, logvar_reviewer, self.noise) if self.training else mu_reviewer
        hidden_author = self.reparameterize(mu_author, logvar_author, self.noise) if self.training else mu_author
        
        # Keep separate latents for decoding
        z_reviewer, z_author = hidden_reviewer, hidden_author
        
        # Decode both modalities
        xhat_reviewer, xhat_author = self.decode(z_reviewer, z_author)
        
        # Weak supervision: transform latents to true concepts c, then add noise to get noisy estimates c̃
        author_concepts = self.author_supervision_layer(z_author)  # [B, 7] - true author concepts c
        reviewer_concepts = self.reviewer_supervision_layer(z_reviewer)  # [B, 3] - true reviewer concepts c
        
        # Add noise to get noisy estimates c̃ (for supervision)
        # Note: In practice, the noise might come from the model's uncertainty or be learned
        # Here we use a small amount of noise to simulate the noisy nature of estimates
        # NOTE(review): this noise is drawn regardless of self.training, so eval-mode
        # predicted_features are stochastic as well.
        noise_std = 0.1  # You can make this learnable or configurable
        author_noise = torch.randn_like(author_concepts) * noise_std
        reviewer_noise = torch.randn_like(reviewer_concepts) * noise_std
        
        # Noisy estimates c̃ for supervision
        author_features_noisy = author_concepts + author_noise  # c̃_author = c_author + η_author
        reviewer_features_noisy = reviewer_concepts + reviewer_noise  # c̃_reviewer = c_reviewer + η_reviewer
        
        # Extract unknown variables (last dimension of each modality)
        author_unknown = z_author[:, -1:]  # [B, 1] - unknown author concept
        reviewer_unknown = z_reviewer[:, -1:]  # [B, 1] - unknown reviewer concept
        
        # Combine all 12 concepts: 7 true author + 1 unknown author + 3 true reviewer + 1 unknown reviewer
        all_concepts = torch.cat([author_concepts, author_unknown, reviewer_concepts, reviewer_unknown], dim=1)  # [B, 12]
        
        # For supervision loss, use the noisy estimates c̃ (only 10 features for training)
        predicted_features = torch.cat([author_features_noisy, reviewer_features_noisy], dim=1)  # [B, 10]
        
        # Weak supervision loss (MSE)
        # Stays a plain Python float (0.0) when no sublabels are supplied.
        loss_supervision = 0.0
        if sublabels is not None:
            loss_supervision = F.mse_loss(predicted_features, sublabels)
        
        unmasked_alpha = self.alpha_net()
        # NOTE(review): the domain embedding is stored here but not used later in this method.
        domain_embeddings = y
        B = len(z_reviewer)
        
        # Process through flow using combined latents
        alpha = unmasked_alpha
        # Keep only the strictly lower-triangular part, then zero every entry
        # that is not above self.alpha_threshold.
        alpha = torch.tril(alpha, -1)
        alpha = alpha * (alpha>self.alpha_threshold).float()
        
        # Concatenate latents for flow processing
        # (order is reviewer first, then author; the slices below rely on it)
        z_combined = torch.cat([z_reviewer, z_author], dim=1)  # [B, total_latent_dim]
        
        # Reshape combined latents for flow
        reshape_z_combined = z_combined.view(-1, 1)
        
        # Create conditional input for flow
        # (each sample's latent vector is repeated once per latent dimension)
        cond_input = torch.repeat_interleave(z_combined, self.total_latent_dim, dim=0).view(B, self.total_latent_dim, self.total_latent_dim)
        # Create a larger alpha matrix for combined latents
        alpha_combined = torch.zeros(self.total_latent_dim, self.total_latent_dim).to(alpha.device)
        # Copy the original alpha to appropriate positions
        # NOTE(review): both diagonal blocks are read from the top-left corner of the
        # same alpha matrix, so the reviewer block equals the first rows/columns of the
        # author block; the cross-modality blocks stay zero.
        alpha_combined[:self.reviewer_latent_dim, :self.reviewer_latent_dim] = alpha[:self.reviewer_latent_dim, :self.reviewer_latent_dim]
        alpha_combined[self.reviewer_latent_dim:, self.reviewer_latent_dim:] = alpha[:self.author_latent_dim, :self.author_latent_dim]
        alpha_expanded = alpha_combined.unsqueeze(0).expand(B, self.total_latent_dim, self.total_latent_dim)
        cond_input = alpha_expanded * cond_input
        cond_input = cond_input.view(B*self.total_latent_dim, self.total_latent_dim)
        cond_input = torch.nan_to_num(cond_input, nan=0.0, posinf=1e6, neginf=-1e6)
        cond_input = self.transformer(cond_input)
        # Repeat the one-hot domain code so there is one row per (sample, latent dimension) pair
        cond_embedding = torch.repeat_interleave(y2, self.total_latent_dim, dim=0)
        cond_input = torch.cat([cond_input, cond_embedding], 1)
        flow_params = self.flow_param_net(cond_input).unsqueeze(1)
        epsilons_combined, neg_logdet = self.flow(reshape_z_combined, flow_params)
        logdet = -neg_logdet
        epsilons_combined = epsilons_combined.view(B, self.total_latent_dim)
        logdet = logdet.view(B, self.total_latent_dim)
        
        # Split epsilons back to individual modalities
        epsilons_reviewer = epsilons_combined[:, :self.reviewer_latent_dim]
        epsilons_author = epsilons_combined[:, self.reviewer_latent_dim:]
        
        # Compute losses for both modalities
        # NOTE(review): only 'gaussian' and 'laplace' define the distributions below;
        # any other self.noise value leaves q_dist_* / base_dist_* unbound.
        if self.noise == 'gaussian':
            q_dist_reviewer = D.Normal(mu_reviewer, torch.exp(logvar_reviewer / 2))
            q_dist_author = D.Normal(mu_author, torch.exp(logvar_author / 2))
            base_dist_reviewer = D.Normal(torch.zeros_like(mu_reviewer), torch.ones_like(logvar_reviewer))
            base_dist_author = D.Normal(torch.zeros_like(mu_author), torch.ones_like(logvar_author))
        elif self.noise == 'laplace':
            q_dist_reviewer = D.Laplace(mu_reviewer, torch.exp(logvar_reviewer/2))
            q_dist_author = D.Laplace(mu_author, torch.exp(logvar_author/2))
            base_dist_reviewer = D.Laplace(0, 1)
            base_dist_author = D.Laplace(0, 1)
            
        log_qz_reviewer = q_dist_reviewer.log_prob(hidden_reviewer)
        log_qz_author = q_dist_author.log_prob(hidden_author)
        batch_log_qz_reviewer = torch.sum(log_qz_reviewer, 1)
        batch_log_qz_author = torch.sum(log_qz_author, 1)
        
        # Prior log-density of z: base log-prob of the flow output minus logdet (= -neg_logdet from the flow)
        log_p_eps_reviewer = base_dist_reviewer.log_prob(epsilons_reviewer.squeeze())
        log_p_eps_author = base_dist_author.log_prob(epsilons_author.squeeze())
        log_pz_reviewer = log_p_eps_reviewer - logdet[:, :self.reviewer_latent_dim]
        log_pz_author = log_p_eps_author - logdet[:, self.reviewer_latent_dim:]
        
        batch_log_pz_reviewer = torch.sum(log_pz_reviewer, 1)
        batch_log_pz_author = torch.sum(log_pz_author, 1)
        
        # Combined KLD for both modalities
        # (per-dimension log q - log p, averaged over the batch)
        ele_kld_reviewer = torch.mean(log_qz_reviewer-log_pz_reviewer, 0)
        ele_kld_author = torch.mean(log_qz_author-log_pz_author, 0)
        # Pad the smaller one to match the larger one
        # (zero padding, then an element-wise average of the two vectors)
        if ele_kld_reviewer.size(0) < ele_kld_author.size(0):
            ele_kld_reviewer_padded = torch.cat([ele_kld_reviewer, torch.zeros(ele_kld_author.size(0) - ele_kld_reviewer.size(0)).to(ele_kld_reviewer.device)])
            ele_kld = (ele_kld_reviewer_padded + ele_kld_author) / 2
        else:
            ele_kld_author_padded = torch.cat([ele_kld_author, torch.zeros(ele_kld_reviewer.size(0) - ele_kld_author.size(0)).to(ele_kld_author.device)])
            ele_kld = (ele_kld_reviewer + ele_kld_author_padded) / 2
        
        # Sum of squared gaps between each per-dimension KLD and the (detached) minimum
        loss_dip = ((ele_kld-ele_kld.min().detach())**2).sum()
        loss_kld = torch.sum(ele_kld)
        
        # Reconstruction losses for both modalities
        # (log-likelihood is evaluated on the raw inputs, not the nan_to_num-cleaned copies made in encode)
        x_dist_reviewer = D.Normal(xhat_reviewer, self.x_std)
        x_dist_author = D.Normal(xhat_author, self.x_std)
        log_px_reviewer = x_dist_reviewer.log_prob(x_reviewer).sum(-1)
        log_px_author = x_dist_author.log_prob(x_author).sum(-1)
        
        # Combined ELBO
        elbo = (log_px_reviewer + log_px_author + batch_log_pz_reviewer + batch_log_pz_author - batch_log_qz_reviewer - batch_log_qz_author).mean()
        # Sparsity penalty uses the raw (unmasked, unthresholded) adjacency matrix
        loss_sparsity = torch.mean(torch.abs(unmasked_alpha))
        loss_rec = -(log_px_reviewer + log_px_author).mean()
        
        output_dict = Munch(
            elbo=elbo, 
            loss_sparsity=loss_sparsity, 
            mu_reviewer=mu_reviewer, 
            mu_author=mu_author,
            z_reviewer=z_reviewer, 
            z_author=z_author,
            epsilons_reviewer=epsilons_reviewer, 
            epsilons_author=epsilons_author,
            adj_mat=alpha, 
            loss_kld=loss_kld, 
            loss_rec=loss_rec, 
            loss_dip=loss_dip, 
            loss_supervision=loss_supervision,
            ele_kld=ele_kld,
            xhat_reviewer=xhat_reviewer,
            xhat_author=xhat_author,
            # True concepts c (clean, no noise)
            author_concepts=author_concepts,  # [B, 7] - true author concepts c
            reviewer_concepts=reviewer_concepts,  # [B, 3] - true reviewer concepts c
            # Noisy estimates c̃ (for supervision)
            predicted_features=predicted_features,  # [B, 10] - noisy estimates c̃ for supervision
            # All concepts including unknown
            all_concepts=all_concepts,  # [B, 12] - all concepts including unknown
            author_unknown=author_unknown,  # [B, 1] - unknown author concept
            reviewer_unknown=reviewer_unknown  # [B, 1] - unknown reviewer concept
        )
        return output_dict

# ==========================
# 内联 trainer.py
# ==========================
class ToyTrainer(nn.Module):
    """Trainer for the single-modality Markov VAE.

    Owns the model, an AdamW optimizer and a constant-factor ``LambdaLR``
    scheduler, and exposes one optimisation step, checkpoint save/resume and
    an ELBO-based evaluation pass.
    """

    def __init__(self, config, device):
        """Instantiate the model selected by ``config.assumption`` on ``device``.

        ``'flow'`` selects ``MarkovFlowVAE``; any other value selects
        ``MarkovMLPVAE``. The optimizer learning rate is
        ``config.base_lr * config.batch_size``.
        """
        super().__init__()
        if config.assumption == 'flow':
            self.model = MarkovFlowVAE(config.model, config.data).to(device)
        else:
            self.model = MarkovMLPVAE(config.model, config.data).to(device)
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=config.base_lr*config.batch_size)
        self.config = config

        def lambda_rule(epoch):
            # Constant multiplier: the learning rate is never decayed.
            return 1.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda=lambda_rule)

    def train_step(self, x, y, iterations):
        """Run one optimisation step and return the model's output dict.

        The total loss is the weighted sum of the sparsity, reconstruction,
        KLD and DIP terms. The reconstruction weight ramps linearly from 50%
        to 100% of ``config.lambda_rec`` over the first 5000 iterations. With
        ``sparsity_policy == 'slow'`` the sparsity weight is scaled by
        ``min(1, num_nonzero / 6)``, where ``num_nonzero`` counts the
        positive adjacency entries. The weights actually used are attached to
        the returned dict as ``lambda_*`` entries.
        """
        threshold = self.config.model.alpha_threshold
        output_dict = self.model(x, y, threshold)
        num_nonzero = torch.sum((output_dict.adj_mat)>0)
        if self.config.sparsity_policy == 'slow':
            lambda_sparsity = self.config.lambda_sparsity * min(1., num_nonzero/6)
        else:
            lambda_sparsity = self.config.lambda_sparsity
        lambda_kld = self.config.lambda_kld
        lambda_rec = self.config.lambda_rec * (0.5+0.5*min(1, iterations/5000))
        lambda_dip = self.config.lambda_dip
        self.optim.zero_grad()
        loss = lambda_sparsity * output_dict.loss_sparsity + lambda_rec * output_dict.loss_rec + lambda_kld * output_dict.loss_kld + lambda_dip * output_dict.loss_dip
        loss.backward()
        # Clip the global gradient norm to 3 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3)
        self.optim.step()
        output_dict.lambda_sparsity = lambda_sparsity
        output_dict.lambda_rec = lambda_rec
        output_dict.lambda_kld = lambda_kld
        output_dict.lambda_dip = lambda_dip
        return output_dict

    def save(self, checkpoint_dir):
        """Write the model state dict to ``<checkpoint_dir>/best.pt``."""
        filename = os.path.join(checkpoint_dir,'best.pt')
        tmp_dict = {'model': self.model.state_dict()}
        torch.save(tmp_dict, filename)

    def resume(self, checkpoint_dir):
        """Load the model state dict from ``<checkpoint_dir>/best.pt``.

        Tensors are mapped onto the device the model currently lives on, so a
        checkpoint written on a GPU can be restored on a CPU-only host.
        """
        filename = os.path.join(checkpoint_dir, 'best.pt')
        first_param = next(self.model.parameters(), None)
        # A parameter-less model gives no device hint; fall back to torch's default.
        map_location = first_param.device if first_param is not None else None
        tmp_dict = torch.load(filename, map_location=map_location)
        self.model.load_state_dict(tmp_dict['model'])

    @torch.no_grad()
    def evaluate(self, model, val_dataloader, iterations, plot=True, filename=None):
        """Compute the mean ELBO of ``model`` over ``val_dataloader``.

        ``iterations``, ``plot`` and ``filename`` are accepted for interface
        compatibility but are not used by this method.

        Returns:
            Munch with ``elbo`` (mean over batches) and ``adj_mat`` (taken
            from the last batch's output).
        """
        threshold = self.config.model.alpha_threshold
        total_elbo = []
        model.eval()
        try:
            for batch in val_dataloader:
                # Only the first two fields of a batch are consumed.
                x, y = batch[0], batch[1]
                output_dict = model(x, y, threshold)
                total_elbo.append(output_dict.elbo)
            total_elbo = torch.stack(total_elbo).mean()
            metric_dict = {'elbo': total_elbo, 'adj_mat': output_dict.adj_mat}
        finally:
            # Restore training mode even when the evaluation loop raises.
            model.train()
        # No GT: skip MCC plots against true latents/eps
        from munch import Munch as _M
        metric_dict = _M.fromDict(metric_dict)
        return metric_dict

class MultiModalToyTrainer(nn.Module):
    """Trainer for the two-modality (reviewer / author) flow VAE.

    Owns the model, an AdamW optimizer and a constant-factor ``LambdaLR``
    scheduler. Evaluation reports the mean ELBO and, when a run directory and
    dataset are supplied, runs causal discovery on the recovered concepts.
    """

    # Fields copied from the graph-comparison result into the metric dict.
    _COMPARISON_KEYS = (
        'shd', 'similarity', 'baseline_edges', 'common_edges', 'precision',
        'recall', 'f1_score', 'baseline_only', 'crl_only',
    )

    def __init__(self, config, device):
        """Instantiate ``MultiModalMarkovFlowVAE`` on ``device``.

        Raises:
            NotImplementedError: if ``config.assumption`` is not ``'flow'``.
        """
        super().__init__()
        if config.assumption == 'flow':
            self.model = MultiModalMarkovFlowVAE(config.model, config.data).to(device)
        else:
            raise NotImplementedError("MultiModalMarkovMLPVAE not implemented yet")
        self.optim = torch.optim.AdamW(self.model.parameters(), lr=config.base_lr*config.batch_size)
        self.config = config

        def lambda_rule(epoch):
            # Constant multiplier: the learning rate is never decayed.
            return 1.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda=lambda_rule)

    def train_step(self, x_reviewer, x_author, y, iterations, sublabels=None):
        """Run one optimisation step and return the model's output dict.

        The total loss is the weighted sum of the sparsity, reconstruction,
        KLD, DIP and weak-supervision terms. The reconstruction weight ramps
        linearly from 50% to 100% of ``config.lambda_rec`` over the first
        5000 iterations; ``lambda_supervision`` defaults to 1.0 when the
        config does not define it. The weights actually used are attached to
        the returned dict as ``lambda_*`` entries.
        """
        threshold = self.config.model.alpha_threshold
        output_dict = self.model(x_reviewer, x_author, y, threshold, sublabels)
        num_nonzero = torch.sum((output_dict.adj_mat)>0)
        if self.config.sparsity_policy == 'slow':
            lambda_sparsity = self.config.lambda_sparsity * min(1., num_nonzero/6)
        else:
            lambda_sparsity = self.config.lambda_sparsity
        lambda_kld = self.config.lambda_kld
        lambda_rec = self.config.lambda_rec * (0.5+0.5*min(1, iterations/5000))
        lambda_dip = self.config.lambda_dip
        lambda_supervision = getattr(self.config, 'lambda_supervision', 1.0)  # Default to 1.0 if not specified
        self.optim.zero_grad()
        loss = lambda_sparsity * output_dict.loss_sparsity + lambda_rec * output_dict.loss_rec + lambda_kld * output_dict.loss_kld + lambda_dip * output_dict.loss_dip + lambda_supervision * output_dict.loss_supervision
        loss.backward()
        # Clip the global gradient norm to 3 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3)
        self.optim.step()
        output_dict.lambda_sparsity = lambda_sparsity
        output_dict.lambda_rec = lambda_rec
        output_dict.lambda_kld = lambda_kld
        output_dict.lambda_dip = lambda_dip
        output_dict.lambda_supervision = lambda_supervision
        return output_dict

    def save(self, checkpoint_dir):
        """Write the model state dict to ``<checkpoint_dir>/best.pt``."""
        filename = os.path.join(checkpoint_dir,'best.pt')
        tmp_dict = {'model': self.model.state_dict()}
        torch.save(tmp_dict, filename)

    def resume(self, checkpoint_dir):
        """Load the model state dict from ``<checkpoint_dir>/best.pt``.

        Tensors are mapped onto the device the model currently lives on, so a
        checkpoint written on a GPU can be restored on a CPU-only host.
        """
        filename = os.path.join(checkpoint_dir, 'best.pt')
        first_param = next(self.model.parameters(), None)
        # A parameter-less model gives no device hint; fall back to torch's default.
        map_location = first_param.device if first_param is not None else None
        tmp_dict = torch.load(filename, map_location=map_location)
        self.model.load_state_dict(tmp_dict['model'])

    @torch.no_grad()
    def get_labeled_concepts(self, model, threshold, dataset_dict):
        """Return the concepts recovered by ``model`` on the labeled validation split.

        Reads ``val_labeled_data_reviewer``, ``val_labeled_data_author`` and
        ``val_labeled_domain_labels`` from ``dataset_dict``, runs the model in
        eval mode in batches of 32 and concatenates each batch's
        ``all_concepts`` on the CPU. (The original note mentions 2393 samples
        x 12 features; the actual size depends on the dataset.)
        The model is left in eval mode.
        """
        model.eval()
        labeled_concepts_list = []

        # Run inference on the labeled data, in dataset order.
        labeled_dataset = torch.utils.data.TensorDataset(
            dataset_dict.val_labeled_data_reviewer,
            dataset_dict.val_labeled_data_author,
            dataset_dict.val_labeled_domain_labels
        )
        labeled_dataloader = torch.utils.data.DataLoader(labeled_dataset, batch_size=32, shuffle=False)

        for batch in labeled_dataloader:
            x_reviewer, x_author, y = batch
            output_dict = model(x_reviewer, x_author, y, threshold)
            labeled_concepts_list.append(output_dict.all_concepts.cpu())

        return torch.cat(labeled_concepts_list, dim=0)

    @staticmethod
    def _edge_stats(edge_list, marker):
        """Count edges in ``edge_list``: (total, edge[2] == 1, edge[2] == -1, endpoints containing ``marker``)."""
        if edge_list is None:
            return 0, 0, 0, 0
        total = len(edge_list)
        directed = sum(1 for edge in edge_list if edge[2] == 1)
        undirected = sum(1 for edge in edge_list if edge[2] == -1)
        marked = sum(1 for edge in edge_list if marker in edge[0] or marker in edge[1])
        return total, directed, undirected, marked

    def _causal_discovery_metrics(self, model, threshold, iterations, baseline_edges, run_dir, dataset_dict, config):
        """Run both causal-discovery passes and return their summary metrics.

        Returns an empty dict when discovery fails or yields no edge list, so
        the caller falls back to the base metrics only.
        """
        # Causal discovery runs on the labeled split.
        labeled_concepts = self.get_labeled_concepts(model, threshold, dataset_dict)

        # Results are kept permanently under <run_dir>/causal_discovery/iter_<n>.
        causal_save_dir = os.path.join(run_dir, 'causal_discovery', f'iter_{iterations}')
        os.makedirs(causal_save_dir, exist_ok=True)

        try:
            _, _, _, edge_list = perform_causal_discovery(labeled_concepts, causal_save_dir, config)

            print("运行LLM+CRL因果发现...")
            _, _, _, llm_crl_edge_list = perform_llm_inferred_causal_discovery(labeled_concepts, causal_save_dir, config)

            if edge_list is None:
                return {}

            num_edges, num_directed, num_undirected, new_var_edges = self._edge_stats(edge_list, 'unknown')
            llm_crl_edges, llm_crl_directed, llm_crl_undirected, llm_crl_new_var_edges = self._edge_stats(llm_crl_edge_list, 'CRL')
            metrics = {
                'causal_edges': num_edges,
                'causal_directed': num_directed,
                'causal_undirected': num_undirected,
                'causal_new_var_edges': new_var_edges,
                'llm_crl_edges': llm_crl_edges,
                'llm_crl_directed': llm_crl_directed,
                'llm_crl_undirected': llm_crl_undirected,
                'llm_crl_new_var_edges': llm_crl_new_var_edges,
            }

            # Comparison against the baseline graph is optional.
            if baseline_edges is not None:
                comparison_results = compare_causal_graphs(baseline_edges, edge_list, causal_save_dir, iteration=iterations)
                metrics.update({key: comparison_results[key] for key in self._COMPARISON_KEYS})
            return metrics
        except Exception as e:
            # Deliberate best-effort: a failed discovery run must not abort evaluation.
            print(f"因果发现失败: {e}")
            import traceback
            traceback.print_exc()
            return {}

    @torch.no_grad()
    def evaluate(self, model, val_dataloader, iterations, plot=True, filename=None, baseline_edges=None, run_dir=None, dataset_dict=None, config=None):
        """Evaluate ``model`` on ``val_dataloader`` and optionally run causal discovery.

        Args:
            model: model to evaluate; called as
                ``model(x_reviewer, x_author, y, threshold)``.
            val_dataloader: iterable of batches whose first three fields are
                ``(x_reviewer, x_author, y)``.
            iterations: current iteration, used to name the output directory.
            plot, filename: accepted for interface compatibility; unused here.
            baseline_edges: optional baseline edge list to compare against.
            run_dir: root directory for causal-discovery outputs.
            dataset_dict: dataset holding the labeled validation split.
            config: forwarded to the causal-discovery helpers.

        Returns:
            Munch with ``elbo`` (mean over batches) and ``adj_mat`` (from the
            last batch), plus causal-discovery statistics when discovery ran
            and produced an edge list. Causal discovery is skipped when
            ``run_dir`` or ``dataset_dict`` is ``None``.
        """
        threshold = self.config.model.alpha_threshold
        total_elbo = []
        model.eval()
        try:
            for batch in val_dataloader:
                # Only the first three fields of a batch are consumed.
                x_reviewer, x_author, y = batch[0], batch[1], batch[2]
                output_dict = model(x_reviewer, x_author, y, threshold)
                total_elbo.append(output_dict.elbo)

            total_elbo = torch.stack(total_elbo).mean()
            metric_dict = {'elbo': total_elbo, 'adj_mat': output_dict.adj_mat}

            # Without an output directory and the labeled data there is nothing
            # to run discovery on; the declared None defaults used to crash here.
            if run_dir is not None and dataset_dict is not None:
                metric_dict.update(self._causal_discovery_metrics(
                    model, threshold, iterations, baseline_edges, run_dir, dataset_dict, config))
        finally:
            # Restore training mode even when evaluation raises.
            model.train()
        # No GT: skip MCC plots against true latents/eps
        from munch import Munch as _M
        metric_dict = _M.fromDict(metric_dict)
        return metric_dict

# ==========================
# 内联 data.py（只保留 get_data/generate_toy）
# ==========================
from scipy import linalg as la
from scipy.stats import ortho_group
from munch import Munch as _M

def sample_weight(start, end, shape):
    """Sample an array of the given shape with magnitudes in [start, end) and random sign.

    Each entry is drawn either from U(start, end) or from U(-end, -start);
    an independent 0/1 draw per entry decides which one is kept.
    """
    positive_draw = np.random.uniform(start, end, shape)
    negative_draw = np.random.uniform(-end, -start, shape)
    keep_positive = np.random.randint(0, 2, shape)
    keep_negative = 1 - keep_positive
    return positive_draw * keep_positive + negative_draw * keep_negative

def get_data(config, device):
    """Build the dataset selected by ``config.name``.

    A name containing ``'toy'`` selects the synthetic generator; the integer
    that follows the last ``'toy'`` in the name picks the toy case. Any other
    name loads the rebuttal data from disk.
    """
    settings = config.data
    if 'toy' not in config.name:
        return load_rebuttal_data(device=device, val_ratio=settings.val_ratio)

    toy_case = int(config.name.split('toy')[-1])
    return generate_toy(
        settings.n_samples,
        settings.n_domains,
        settings.n_layers,
        noise=settings.noise,
        case=toy_case,
        device=device,
        val_ratio=settings.val_ratio,
    )

def generate_toy(n_samples, n_domains, n_layers, case, device, noise, val_ratio=0.3):
    """Generate a synthetic 4-latent, multi-domain dataset with a train/val split.

    For every domain four latents are built from independent noise plus a
    domain-specific bias; the causal links between them depend on ``case``.
    Observations are obtained by applying, ``n_layers`` times, a leaky ReLU
    (slope 0.2) followed by a random orthogonal 4x4 mixing matrix that is
    shared across domains.

    Args:
        n_samples: number of samples per domain.
        n_domains: number of domains to generate.
        n_layers: number of leaky-ReLU + orthogonal-mixing layers.
        case: structural case; 1 (z1 -> z3, z2 -> z3, z3 -> z4) or
            5 (z3 -> z4 only).
        device: device for the observations and domain labels. Latents and
            epsilons are not moved to ``device``.
        noise: ``'gaussian'`` or ``'laplace'``.
        val_ratio: fraction of all samples held out for validation.

    Returns:
        Munch with ``train_*`` / ``val_*`` entries for data, domain labels,
        latents and epsilons, plus ``true_adj`` where ``true_adj[i, j] == 1``
        marks an edge from latent ``i`` to latent ``j``. Epsilons keep the
        dtype of the sampled noise arrays (they are not cast to float32).

    Raises:
        NotImplementedError: if ``case`` is not 1 or 5.
        ValueError: if ``noise`` is not ``'gaussian'`` or ``'laplace'``.
    """
    # Validate the case up front; the adjacency matrix is the same for every domain.
    if case == 1:
        ground_truth = np.array([[0, 0, 1, 0],[0, 0, 1, 0],[0, 0, 0, 1],[0, 0, 0, 0]])
    elif case == 5:
        ground_truth = np.array([[0, 0, 0, 0],[0, 0, 0, 0],[0, 0, 0, 1],[0, 0, 0, 0]])
    else:
        raise NotImplementedError('Only case 1 and 5 are wired in the single file for brevity.')

    # Noise is always sampled for at least 200 domains and in the order
    # eps1..eps4, so the random stream is the same for any n_domains <= 200.
    total_domains = max(200, n_domains)
    eps_shape = [total_domains, n_samples, 1]
    if noise == 'gaussian':
        eps1, eps2, eps3, eps4 = (np.random.normal(0, 1, eps_shape) for _ in range(4))
    elif noise == 'laplace':
        dist = torch.distributions.Laplace(0, 1)
        eps1, eps2, eps3, eps4 = (dist.sample(eps_shape).numpy() for _ in range(4))
    else:
        raise ValueError(f"Unsupported noise {noise!r}; expected 'gaussian' or 'laplace'.")

    inv_mats = [ortho_group.rvs(4) for _ in range(n_layers)]
    scale = sample_weight(0.5, 2, [n_domains, 4, 4])
    bias = sample_weight(0.5, 2, [n_domains, 4])

    data = []
    ground_truth_latents = []
    ground_truth_epsilons = []
    domain_labels = []
    for i in range(n_domains):
        z1 = eps1[i] + bias[i, 0]
        z2 = eps2[i] + bias[i, 1]
        if case == 1:
            z3 = eps3[i] + bias[i, 2] + scale[i, 2, 0] * z1 + scale[i, 2, 1] * z2
        else:  # case == 5: z3 has no parents
            z3 = eps3[i] + bias[i, 2]
        z4 = eps4[i] + bias[i, 3] + scale[i, 3, 2] * z3
        z = np.concatenate([z1,z2,z3,z4], 1)

        # Mixing function: leaky ReLU followed by an orthogonal matrix, per layer.
        x = z
        for mixing in inv_mats:
            x = np.where(x > 0, x, x * 0.2)
            x = x @ mixing
        x = torch.from_numpy(x).float()

        data.append(x.to(device).detach())
        domain_labels.append(torch.ones([len(z)]).to(device) * i)
        ground_truth_latents.append(torch.from_numpy(z).float())
        ground_truth_epsilons.append(torch.from_numpy(np.concatenate([eps1[i], eps2[i], eps3[i], eps4[i]], 1)))

    all_data = torch.cat(data, 0)
    all_domain_labels = torch.cat(domain_labels, 0)
    all_latents = torch.cat(ground_truth_latents, 0)
    all_epsilons = torch.cat(ground_truth_epsilons, 0)

    # Random split over the pooled samples of all domains.
    n_val_samples = int(len(all_data) * val_ratio)
    n_train_samples = len(all_data) - n_val_samples
    perm_indices = np.random.permutation(len(all_data))
    train_indices = perm_indices[:n_train_samples]
    val_indices = perm_indices[n_train_samples:]
    dataset_dict = _M(
        train_data=all_data[train_indices],
        train_domain_labels=all_domain_labels[train_indices],
        train_latents=all_latents[train_indices],
        train_epsilons=all_epsilons[train_indices],
        val_data=all_data[val_indices],
        val_domain_labels=all_domain_labels[val_indices],
        val_latents=all_latents[val_indices],
        val_epsilons=all_epsilons[val_indices],
        true_adj=ground_truth,
    )
    return dataset_dict

def load_rebuttal_data(device, val_ratio=0.1, data_dir='dataset'):
    """Load the two-modality (reviewer / author) dataset from ``.npy`` files.

    The following files are read from ``data_dir``:
    ``unlabeled_primary_areas.npy``, ``labeled_primary_areas.npy``,
    ``unlabeled_data_reviewer.npy``, ``labeled_data_reviewer.npy``,
    ``unlabeled_data_author.npy``, ``labeled_data_author.npy`` and
    ``labeled_data_labels_10features.npy``.

    Processing steps:
      * NaN / +-Inf entries of the feature and sub-label arrays are replaced
        by 0 / +-1e6 before conversion to tensors.
      * Primary-area values (labeled and unlabeled together) are mapped to
        contiguous integer domain ids with ``np.unique``.
      * Each modality is standardized per feature with mean / std computed
        over the unlabeled and labeled rows together. The sub-label
        supervision tensor is NOT standardized.

    No train/validation split is made: every ``val_*`` entry refers to the
    same tensor as the matching ``train_*`` entry.

    Args:
        device: torch device (or device string) the tensors are moved to.
        val_ratio: Unused; kept so existing callers keep working.
        data_dir: Directory containing the ``.npy`` files. Defaults to
            ``'dataset'`` (relative to the current working directory).

    Returns:
        ``_M`` container with ``train_*`` / ``val_*`` tensors for both
        modalities, the domain labels, the labeled sub-labels and
        ``n_domains``.
    """

    def load_array(filename, clean=True):
        # Load one array; optionally scrub NaN/Inf so torch never sees them.
        arr = np.load(os.path.join(data_dir, filename))
        if clean:
            arr = np.nan_to_num(arr, nan=0.0, posinf=1e6, neginf=-1e6)
        return arr

    def to_float_tensor(arr):
        return torch.from_numpy(arr).float().to(device)

    def to_long_tensor(arr):
        return torch.from_numpy(arr).long().to(device)

    def standardize(unlabeled_x, labeled_x):
        # Per-feature statistics over ALL rows of the modality.
        stacked = torch.cat([unlabeled_x, labeled_x], 0)
        feat_mean = stacked.mean(0, keepdim=True)
        # clamp_min keeps constant features from causing a division by ~0.
        feat_std = stacked.std(0, keepdim=True).clamp_min(1e-6)

        def normalize(x):
            x = (x - feat_mean) / feat_std
            return torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)

        return normalize(unlabeled_x), normalize(labeled_x)

    # Primary areas are category values, so they are not NaN-cleaned.
    unlabeled_primary_areas = load_array('unlabeled_primary_areas.npy', clean=False)
    labeled_primary_areas = load_array('labeled_primary_areas.npy', clean=False)

    # Map primary-area values to contiguous domain ids shared by both subsets.
    all_primary = np.concatenate([unlabeled_primary_areas, labeled_primary_areas], 0)
    unique_vals, mapped = np.unique(all_primary, return_inverse=True)
    # NumPy >= 2.0 returns the inverse with the input's shape for N-D input;
    # flatten so the length-based split below always yields 1-D labels.
    mapped = mapped.reshape(-1)
    n_domains = len(unique_vals)
    n_unlabeled = len(unlabeled_primary_areas)
    unlabeled_y = to_long_tensor(mapped[:n_unlabeled])
    labeled_y = to_long_tensor(mapped[n_unlabeled:])

    unlabeled_x_reviewer = to_float_tensor(load_array('unlabeled_data_reviewer.npy'))
    labeled_x_reviewer = to_float_tensor(load_array('labeled_data_reviewer.npy'))
    unlabeled_x_author = to_float_tensor(load_array('unlabeled_data_author.npy'))
    labeled_x_author = to_float_tensor(load_array('labeled_data_author.npy'))
    labeled_sublabels = to_float_tensor(load_array('labeled_data_labels_10features.npy'))

    # Standardize each modality once; the results are reused for train and val.
    unlabeled_reviewer, labeled_reviewer = standardize(unlabeled_x_reviewer, labeled_x_reviewer)
    unlabeled_author, labeled_author = standardize(unlabeled_x_author, labeled_x_author)

    dataset_dict = _M(
        train_unlabeled_data_reviewer=unlabeled_reviewer,
        train_unlabeled_data_author=unlabeled_author,
        train_unlabeled_domain_labels=unlabeled_y,
        train_labeled_data_reviewer=labeled_reviewer,
        train_labeled_data_author=labeled_author,
        train_labeled_domain_labels=labeled_y,
        train_labeled_sublabels=labeled_sublabels,
        # The validation set is the full data set (used for causal discovery).
        val_unlabeled_data_reviewer=unlabeled_reviewer,
        val_unlabeled_data_author=unlabeled_author,
        val_unlabeled_domain_labels=unlabeled_y,
        val_labeled_data_reviewer=labeled_reviewer,
        val_labeled_data_author=labeled_author,
        val_labeled_domain_labels=labeled_y,
        val_labeled_sublabels=labeled_sublabels,
        n_domains=n_domains,
    )
    return dataset_dict

# ==========================
# Merged entry point (equivalent to train.py)
# ==========================
def _single_run_tag(config):
    """Build the run tag string that encodes the main model hyper-parameters."""
    mcfg = config.model
    if hasattr(mcfg, 'scale_norm'):
        # Extended tag: separate settings for the scale / bias sub-networks,
        # falling back to the shared setting when one is missing.
        return '%s-BS%d-NORM_%s_%s_%s-DIM_%d_%d_%d-Layers_%d_%d_%d' % (
            config.assumption, config.batch_size,
            mcfg.norm, getattr(mcfg, 'scale_norm', 'na'), getattr(mcfg, 'bias_norm', 'na'),
            mcfg.hidden_dim,
            getattr(mcfg, 'scale_hidden_dim', mcfg.hidden_dim),
            getattr(mcfg, 'bias_hidden_dim', mcfg.hidden_dim),
            mcfg.n_layers,
            getattr(mcfg, 'scale_n_layers', mcfg.n_layers),
            getattr(mcfg, 'bias_n_layers', mcfg.n_layers))
    return '%s-BS%d-NORM_%s-DIM_%d-Layers_%d' % (
        config.assumption, config.batch_size, mcfg.norm, mcfg.hidden_dim, mcfg.n_layers)


def _print_eval_metrics(metric_dict, iterations, max_iter, is_new_best):
    """Print the causal-discovery metrics produced by one evaluation round."""
    # The marker is appended right after the similarity percentage.
    best_marker = ' ⭐ (新最佳)' if is_new_best else ''
    print('Eval: %08d/%08d | SHD: %d, 相似度: %.2f%%%s, 边数: %d/%d, 新变量边: %d' % (
        iterations, max_iter,
        metric_dict.shd, metric_dict.similarity*100, best_marker,
        metric_dict.causal_edges, metric_dict.baseline_edges,
        metric_dict.causal_new_var_edges))
    print('    准确率: P=%.2f%%, R=%.2f%%, F1=%.2f%% | 共同边: %d, 基线独有: %d, CRL独有: %d' % (
        metric_dict.precision*100, metric_dict.recall*100, metric_dict.f1_score*100,
        metric_dict.common_edges, metric_dict.baseline_only, metric_dict.crl_only))

    # LLM+CRL results are only reported when the evaluator provides them.
    if hasattr(metric_dict, 'llm_crl_edges'):
        print('    LLM+CRL因果发现: 边数=%d (有向=%d, 无向=%d), CRL新变量边=%d' % (
            metric_dict.llm_crl_edges, metric_dict.llm_crl_directed,
            metric_dict.llm_crl_undirected, metric_dict.llm_crl_new_var_edges))

    print(f'    因果图已保存到: causal_discovery/iter_{iterations}/')


def _print_final_summary(run_dir, best_iteration, best_similarity, best_metric_dict):
    """Print the end-of-training report: best evaluation result and output paths."""
    print("="*80)
    print("训练完成！")
    print("="*80)

    if best_metric_dict is not None:
        print("🏆 最佳因果图相似度结果:")
        print(f"  最佳迭代步数: {best_iteration}")
        print(f"  最佳相似度: {best_similarity*100:.2f}%")
        print(f"  SHD: {best_metric_dict.shd}")
        print(f"  边数: {best_metric_dict.causal_edges}/{best_metric_dict.baseline_edges}")
        print(f"  新变量边数: {best_metric_dict.causal_new_var_edges}")
        print(f"  准确率: P={best_metric_dict.precision*100:.2f}%, R={best_metric_dict.recall*100:.2f}%, F1={best_metric_dict.f1_score*100:.2f}%")
        print(f"  共同边: {best_metric_dict.common_edges}, 基线独有: {best_metric_dict.baseline_only}, CRL独有: {best_metric_dict.crl_only}")
        print(f"  最佳因果图保存路径: {run_dir}/causal_discovery/iter_{best_iteration}/")
        print("="*80)
    else:
        print("⚠️  未找到有效的因果发现结果")
        print("="*80)

    print("结果保存路径:")
    print(f"  主目录: {run_dir}")
    print(f"  基线因果发现: {run_dir}/baseline_causal_discovery/")
    print(f"  因果发现结果: {run_dir}/causal_discovery/")
    print(f"  模型检查点: {run_dir}/checkpoints/")
    print(f"  日志文件: {run_dir}/logs/")
    print("="*80)


def main_single():
    """Command-line entry point: train a model and periodically run evaluation.

    Steps:
      1. Parse ``--config``, load the YAML config, seed the RNGs, load the data.
      2. Create a timestamped run directory under ``config.result_dir``.
      3. Pick the trainer: ``MultiModalToyTrainer`` when the dataset carries
         ``train_unlabeled_data_reviewer`` (the "real" two-modality data),
         otherwise ``ToyTrainer``.
      4. For the real data, run the baseline causal discovery once.
      5. Train until ``config.max_iter`` iterations, printing losses every
         ``config.log_iter`` steps and evaluating every ``config.eval_iter``
         steps while tracking the evaluation with the highest similarity.

    For the real data each training batch mixes roughly 1:9 labeled:unlabeled
    samples; unlabeled samples receive all-zero sub-labels.

    Raises:
        ValueError: if a training loader would yield no batches (the data set
            is smaller than its batch size while ``drop_last=True``).
        SystemExit: with status 0 once training has finished.
    """
    from munch import Munch
    import tensorboardX

    parser = argparse.ArgumentParser()
    # Alternative config used previously: config_rebuttal.yaml
    parser.add_argument('--config', type=str, default="supervision_test/supervision_5.0_config.yaml")
    args = parser.parse_args()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("="*80)
    print("开始训练 - 因果发现实验")
    print("="*80)
    config = Munch.fromDict(get_config(args.config))
    set_seed(config.seed)
    dataset_dict = get_data(config, args.device)

    # Build the tag and the timestamped run directory.
    tag = _single_run_tag(config)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(config.result_dir, f"{config.name}-{tag}-{timestamp}")
    os.makedirs(run_dir, exist_ok=True)

    print(f"实验配置: {config.name}")
    print(f"模型标签: {tag}")
    print(f"运行目录: {run_dir}")
    print("结果将保存到以下路径:")
    print(f"  基线因果发现: {run_dir}/baseline_causal_discovery/")
    print(f"  因果发现结果: {run_dir}/causal_discovery/")
    print(f"  模型检查点: {run_dir}/checkpoints/")
    print(f"  日志文件: {run_dir}/logs/")
    print("="*80)

    # The real data set is recognised by its modality-specific keys.
    is_real = hasattr(dataset_dict, 'train_unlabeled_data_reviewer')
    if is_real:
        n_domains = dataset_dict.n_domains
    else:
        train_dataset = torch.utils.data.TensorDataset(dataset_dict.train_data, dataset_dict.train_domain_labels)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True)
        val_dataset = torch.utils.data.TensorDataset(dataset_dict.val_data, dataset_dict.val_domain_labels)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, drop_last=False)
        n_domains = int(torch.max(dataset_dict.train_domain_labels).item()) + 1
        if len(train_dataloader) == 0:
            # With drop_last=True an undersized data set yields no batch and
            # the training loop below would never make progress.
            raise ValueError('Training set is smaller than batch_size=%d; no batch can be formed.' % config.batch_size)

    # Make sure the config carries n_domains for model construction.
    config.data.n_domains = n_domains

    # Choose the trainer based on whether we have multi-modal data.
    if is_real:
        trainer = MultiModalToyTrainer(config=config, device=args.device)
    else:
        trainer = ToyTrainer(config=config, device=args.device)
    trainer.to(args.device)

    # Run the baseline causal discovery (only once).
    baseline_edges = None
    if is_real:
        print("运行基线因果发现（11个变量）...")
        baseline_save_dir = os.path.join(run_dir, 'baseline_causal_discovery')
        baseline_edges = run_baseline_causal_discovery_once(dataset_dict, baseline_save_dir, config)
        if baseline_edges is not None:
            print(f"基线因果发现完成: {len(baseline_edges)} 条边")
        else:
            print("基线因果发现失败")

    # Track the evaluation round with the best similarity.
    best_similarity = -1.0
    best_iteration = -1
    best_metric_dict = None

    _, image_dir, log_dir = prepare_sub_folder(run_dir)
    shutil.copy(args.config, os.path.join(run_dir, 'config.yaml'))
    train_writer = tensorboardX.SummaryWriter(log_dir)
    iterations = 0

    def cycle(loader):
        # Endlessly re-iterate the loader (reshuffles on every pass).
        while True:
            for batch in loader:
                yield batch

    if is_real:
        # Separate loaders so each batch holds ~1/10 labeled samples.
        labeled_batch_size = max(1, config.batch_size // 10)
        unlabeled_batch_size = config.batch_size - labeled_batch_size
        labeled_train_ds = torch.utils.data.TensorDataset(
            dataset_dict.train_labeled_data_reviewer,
            dataset_dict.train_labeled_data_author,
            dataset_dict.train_labeled_domain_labels,
            dataset_dict.train_labeled_sublabels
        )
        unlabeled_train_ds = torch.utils.data.TensorDataset(
            dataset_dict.train_unlabeled_data_reviewer,
            dataset_dict.train_unlabeled_data_author,
            dataset_dict.train_unlabeled_domain_labels
        )
        labeled_loader = torch.utils.data.DataLoader(labeled_train_ds, batch_size=labeled_batch_size, shuffle=True, drop_last=True)
        unlabeled_loader = torch.utils.data.DataLoader(unlabeled_train_ds, batch_size=unlabeled_batch_size, shuffle=True, drop_last=True)
        if len(labeled_loader) == 0 or len(unlabeled_loader) == 0:
            # cycle() over an empty loader would spin forever without yielding.
            raise ValueError(
                'Not enough samples for one batch: labeled batch size %d, unlabeled batch size %d.'
                % (labeled_batch_size, unlabeled_batch_size))
        labeled_iter = cycle(labeled_loader)
        unlabeled_iter = cycle(unlabeled_loader)

        # Validation runs on the merged labeled + unlabeled validation sets;
        # unlabeled rows get all-zero sub-labels of the same width.
        val_x_reviewer = torch.cat([dataset_dict.val_labeled_data_reviewer, dataset_dict.val_unlabeled_data_reviewer], 0)
        val_x_author = torch.cat([dataset_dict.val_labeled_data_author, dataset_dict.val_unlabeled_data_author], 0)
        val_y = torch.cat([dataset_dict.val_labeled_domain_labels, dataset_dict.val_unlabeled_domain_labels], 0)
        val_zero_sublabels = torch.zeros(
            len(dataset_dict.val_unlabeled_data_reviewer),
            dataset_dict.val_labeled_sublabels.shape[1],
            device=val_x_reviewer.device)
        val_sublabels = torch.cat([dataset_dict.val_labeled_sublabels, val_zero_sublabels], 0)
        val_dataset = torch.utils.data.TensorDataset(val_x_reviewer, val_x_author, val_y, val_sublabels)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, drop_last=False)

    while True:
        if is_real:
            x_l_reviewer, x_l_author, y_l, sublabels_l = next(labeled_iter)
            x_u_reviewer, x_u_author, y_u = next(unlabeled_iter)
            x_reviewer = torch.cat([x_l_reviewer, x_u_reviewer], 0)
            x_author = torch.cat([x_l_author, x_u_author], 0)
            y = torch.cat([y_l, y_u], 0)
            # Unlabeled samples get all-zero sub-labels of matching width.
            sublabels_u = torch.zeros(len(x_u_reviewer), sublabels_l.shape[1], device=x_reviewer.device)
            sublabels = torch.cat([sublabels_l, sublabels_u], 0)
            # Shuffle within the batch to avoid ordering bias.
            perm = torch.randperm(len(x_reviewer))
            batches = [(x_reviewer[perm], x_author[perm], y[perm], sublabels[perm])]
        else:
            # One pass of this loop is one epoch over the synthetic data.
            batches = train_dataloader

        for batch in batches:
            if is_real:
                x_reviewer, x_author, y, sublabels = batch
                output_dict = trainer.train_step(x_reviewer, x_author, y, iterations, sublabels)
            else:
                x, y = batch
                output_dict = trainer.train_step(x, y, iterations)

            if iterations % config.log_iter == 0:
                # NOTE(review): the LR scheduler is stepped only on logging
                # iterations (once every config.log_iter steps) — confirm
                # that this is intended.
                trainer.scheduler.step()
                print('Train: %08d/%08d elbo: %.4f, sparsity: %.4f, kld: %.4f, rec: %.4f, dip: %.4f, supervision: %.4f' % (iterations, config.max_iter, output_dict.elbo.item(), output_dict.loss_sparsity.item(), output_dict.loss_kld.item(), output_dict.loss_rec.item(), output_dict.loss_dip.item(), output_dict.loss_supervision.item()))

            if iterations % config.eval_iter == 0:
                filename = os.path.join(image_dir, '%08d.jpg' % iterations)
                metric_dict = trainer.evaluate(trainer.model, val_dataloader, iterations, filename=filename, baseline_edges=baseline_edges, run_dir=run_dir, dataset_dict=dataset_dict, config=config)
                write_metric(iterations, metric_dict, train_writer)

                # Causal-discovery metrics are present only when 'shd' is set.
                if hasattr(metric_dict, 'shd'):
                    is_new_best = metric_dict.similarity > best_similarity
                    if is_new_best:
                        best_similarity = metric_dict.similarity
                        best_iteration = iterations
                        best_metric_dict = metric_dict
                    _print_eval_metrics(metric_dict, iterations, config.max_iter, is_new_best)
                else:
                    print('Eval: %08d/%08d' % (iterations, config.max_iter))

            iterations += 1
            if iterations >= config.max_iter:
                # NOTE: saving the learned latents via
                # save_learned_latents(trainer.model, dataset_dict, args.device, run_dir)
                # is currently disabled.
                _print_final_summary(run_dir, best_iteration, best_similarity, best_metric_dict)
                # Flush pending TensorBoard events before leaving.
                train_writer.close()
                # sys.exit() with a string argument exits with status 1, which
                # would report a successful run as a failure; print the
                # message to stderr and exit with status 0 instead.
                print('Finished Training', file=sys.stderr)
                sys.exit(0)

# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main_single()


