from tqdm import tqdm
import math
from dataset import get_dataset
import numpy as np
from models import *
from config import mprint
from utils import *
from learn import *
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import TensorDataset, DataLoader
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_auc_score
import scipy.stats as stats
import pdb
import sys
from torch.autograd import Function

# Gradient reversal layer for adversarial learning
class GradReverse(Function):
    """Autograd function that is the identity going forward and flips
    (and scales) the gradient going backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        # Remember the scale; backward multiplies the incoming gradient by it.
        ctx.lambd = lambd
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        # One entry per forward input: reversed gradient for `x`, none for `lambd`.
        flipped = grad_output.neg().mul(ctx.lambd)
        return flipped, None

def grad_reverse(x, lambd=0.2):
    """Apply `GradReverse` to `x`, scaling the reversed gradient by `lambd`."""
    reversed_x = GradReverse.apply(x, lambd)
    return reversed_x

def condition_info_nce_for_embeddings(x, z, s, p, alpha=0.1, tau=0.5, chunk_size=1024):
    """Conditional InfoNCE-style loss between `x` and `z + alpha * s`.

    Both `x` and `z + alpha * s` are L2-normalised with `F.normalize`.  For
    row ``i`` the positive score is ``exp(<x_i, y_i> / tau)`` and the negative
    score is the sum of ``exp(<x_i, y_j> / tau)`` over every row ``j`` that
    shares the condition ``p[j] == p[i]`` (row ``i`` itself included).

    Args:
        x: 2-D tensor of anchor embeddings, one row per sample.
        z: tensor of the same shape as `x`.
        s: tensor added to `z` after scaling by `alpha`.
        p: 1-D tensor with one condition id per row; negatives are drawn only
            from rows with the same id.
        alpha: weight of `s` in ``z + alpha * s``.
        tau: temperature dividing every similarity.
        chunk_size: number of anchor rows per similarity block; bounds the
            size of the temporary ``[chunk, group]`` matrix.

    Returns:
        Scalar tensor: the mean over rows of ``-log(pos / neg)``.

    Note:
        The negative scores are detached on purpose (memory trade-off kept
        from the original implementation), so gradients reach the inputs only
        through the positive score.
    """
    x_norm = F.normalize(x)
    y_norm = F.normalize(z + alpha * s)

    pos_score = torch.exp(torch.sum(x_norm * y_norm, dim=1) / tau)
    # Same dtype/device as the positive scores so indexed assignment below
    # works for any floating dtype, not just float32.
    neg_score = torch.zeros_like(pos_score)

    for cat in set(p.tolist()):
        idx = (p == cat).nonzero(as_tuple=True)[0]
        x_given = x_norm[idx]
        y_given_t = y_norm[idx].T

        chunk_sums = []
        for start in range(0, x_given.size(0), chunk_size):
            block = x_given[start:start + chunk_size] @ y_given_t  # [b, n_given]
            # detach() drops the autograd graph of the block, which is what
            # keeps memory bounded; no host round trip is needed for that.
            chunk_sums.append(torch.exp(block / tau).sum(dim=1).detach())

        neg_score[idx] = torch.cat(chunk_sums, dim=0)

    return -torch.log(pos_score / neg_score).mean()

@torch.no_grad()
def build_mi_cross_graph(
    args, source_data, target_data, 
    inter_enc, yhat_t, conf_t,
    sens_enc, sens_hat_t, conf_sens_t, 
    chunk_size=1024

):
    """
    Build the merged cross-domain graph used for the MI objectives.

      - Node order: [all target nodes] + [the selected subset of source nodes].
      - Cross-domain edges: every target node is linked to its top-k most
        similar source nodes (dot product of L2-normalised `inter_enc`
        embeddings); each link is added in both directions.
      - Task labels: target nodes use `yhat_t`, overwritten with -1 wherever
        the target training mask is False; source nodes use `source_data.y`.
      - Sensitive labels: target nodes use the supplied `sens_hat_t`; source
        nodes use `source_data.sens_labels`.
      - Training mask: a target node is True only when
        `conf_t >= args.y_pseudo_threshold` AND
        `conf_sens_t >= args.s_pseudo_threshold`; source nodes keep
        `source_data.train_mask` (all True if that attribute is missing).

    Reads `args.device`, `args.top_k_mi`, `args.y_pseudo_threshold` and
    `args.s_pseudo_threshold`.  Both encoders are switched to eval mode and
    left that way; `sens_enc` is not otherwise used here.

    Returns:
      new_x, edge_index, new_sens_labels, new_y, new_train_mask, id_map
      where `id_map` maps a global id (target: 0..N_target-1, source:
      N_target + source index) to its node index in the new graph.
    """
    device = args.device
    top_k = args.top_k_mi
    thr_y = args.y_pseudo_threshold
    thr_s = args.s_pseudo_threshold
    inter_enc.eval()
    sens_enc.eval()

    N_source, N_target = source_data.x.size(0), target_data.x.size(0)

    # Global ids (only used to build cross-domain edges; new-graph ids come from id_map)
    source_ids = torch.arange(N_source, device=device) + N_target
    target_ids = torch.arange(N_target, device=device)

    # === interest encoder: embed both graphs for the similarity search ===
    src_emb, _ = inter_enc(source_data.x.to(device), source_data.edge_index.to(device))     # [Ns, d]
    tgt_emb, _ = inter_enc(target_data.x.to(device), target_data.edge_index.to(device))     # [Nt, d]
    src_emb = F.normalize(src_emb, dim=1)
    tgt_emb = F.normalize(tgt_emb, dim=1)

    yhat_t = yhat_t.to(device).view(-1).long()
    conf_t = conf_t.to(device).view(-1).float()
    sens_hat_t = sens_hat_t.to(device).view(-1).float()
    conf_sens_t = conf_sens_t.to(device).view(-1).float()

    # === chunked similarity + top-k ===
    cross_edges_list = []
    selected_source_set = set()

    for t_start in tqdm(range(0, N_target, chunk_size), desc='MI cross-domain block'):
        t_end = min(N_target, t_start + chunk_size)
        sim_block = torch.matmul(tgt_emb[t_start:t_end], src_emb.t())    # [chunk, Ns]
        # per-chunk top-k (for every target row)
        topk_val, topk_idx = torch.topk(sim_block, k=min(top_k, sim_block.size(1)), dim=1)  # [chunk,k]
        tgt_ids_blk = target_ids[t_start:t_end].unsqueeze(1).expand_as(topk_idx)             # [chunk,k]
        src_ids_blk = source_ids[topk_idx]                                                   # [chunk,k]

        # Record the edges (source->target first, then target->source) and the chosen source nodes
        cross_edges_list.extend(
            [(int(s), int(t)) for s, t in zip(src_ids_blk.reshape(-1).tolist(), tgt_ids_blk.reshape(-1).tolist())]
        )
        cross_edges_list.extend(
            [(int(t), int(s)) for s, t in zip(src_ids_blk.reshape(-1).tolist(), tgt_ids_blk.reshape(-1).tolist())]
        )
        selected_source_set.update(topk_idx.reshape(-1).detach().cpu().tolist())

    print(f"build {len(cross_edges_list)} cross edges")
    selected_source_indices = sorted(list(selected_source_set))
    selected_source_ids = [int(source_ids[i].item()) for i in selected_source_indices]

    # === id mapping for the new graph (global id -> new-graph id) ===
    all_node_ids = list(range(N_target)) + selected_source_ids  # targets (0..Nt-1) + global ids of the selected sources
    id_map = {nid: i for i, nid in enumerate(all_node_ids)}

    new_x = torch.cat(
        [target_data.x.to(device), source_data.x[selected_source_indices].to(device)],
        dim=0
    )
    new_sens_labels = torch.cat(
        [sens_hat_t.to(device), source_data.sens_labels[selected_source_indices].to(device)],
        dim=0
    )

    # === edges (original target edges + cross-domain edges + source sub-graph edges) ===
    tgt_edge_index = target_data.edge_index.clone().to(device)

    if len(cross_edges_list) > 0:
        cross_edge_index = torch.tensor(
            [[id_map[src], id_map[tgt]] for src, tgt in cross_edges_list],
            dtype=torch.long, device=device
        ).t()
    else:
        cross_edge_index = torch.empty((2, 0), dtype=torch.long, device=device)

    # Keep only source-internal edges whose two endpoints are both in the selected source subset
    edge_s = source_data.edge_index.to(device)
    if edge_s.numel() > 0 and len(selected_source_indices) > 0:
        sel_mask = torch.zeros(N_source, dtype=torch.bool, device=device)
        sel_mask[torch.tensor(selected_source_indices, device=device)] = True
        es0, es1 = edge_s[0], edge_s[1]
        keep = sel_mask[es0] & sel_mask[es1]
        if keep.any():
            # Map source index (0..N_source-1) -> position inside the selected subset (-1 if not selected)
            pos_src = torch.full((N_source,), -1, dtype=torch.long, device=device)
            pos_src[torch.tensor(selected_source_indices, device=device)] = torch.arange(len(selected_source_indices), device=device)
            mapped_src = N_target + pos_src[es0[keep]]
            mapped_dst = N_target + pos_src[es1[keep]]
            source_sub_edge_index = torch.stack([mapped_src, mapped_dst], dim=0)
        else:
            source_sub_edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
    else:
        source_sub_edge_index = torch.empty((2, 0), dtype=torch.long, device=device)

    edge_index = torch.cat([tgt_edge_index, cross_edge_index, source_sub_edge_index], dim=1).to(device)

    # === task pseudo-labels and training mask (a target node must pass both confidence thresholds) ===
    pseudo_y = yhat_t.clone().detach()
    mask_y   = (conf_t >= thr_y)
    mask_s   = (conf_sens_t >= thr_s)
    mask_tgt = (mask_y & mask_s)                # target nodes are trained on only when both thresholds are met
    pseudo_y[~mask_tgt] = -1                    # -1 marks a node that is not used as a supervised sample

    new_y = torch.cat([pseudo_y.to(device), source_data.y[selected_source_indices].to(device)], dim=0)

    src_train_mask = source_data.train_mask.to(device) if hasattr(source_data, "train_mask") \
                     else torch.ones(N_source, dtype=torch.bool, device=device)
    new_train_mask = torch.cat([mask_tgt.to(device), src_train_mask[selected_source_indices]], dim=0)

    # NOTE: the "only_*" figures are the full counts of mask_y / mask_s, i.e. they
    # include the nodes that pass both thresholds.
    print(f"target_pseudo_num (both-conf): {int(mask_tgt.sum().item())} | "
          f"only_y_conf: {int(mask_y.sum().item())} | only_s_conf: {int(mask_s.sum().item())}")

    return new_x, edge_index, new_sens_labels, new_y, new_train_mask, id_map

@torch.no_grad()
def pseudo_from_logits(logits, method="prob", thr=0.5):
    """Convert binary logits into hard pseudo-labels and a confidence score.

    Args:
        logits: tensor that is flattened with ``view(-1)`` before use.
        method: only ``"prob"`` is implemented.
        thr: probability threshold at or above which the label is 1.

    Returns:
        y_hat: long tensor of 0/1 labels, one per flattened logit.
        conf: ``max(p, 1 - p)`` for the sigmoid probability ``p``.

    Raises:
        NotImplementedError: for any `method` other than ``"prob"``.
    """
    probs = torch.sigmoid(logits.view(-1))
    if method != "prob":
        raise NotImplementedError

    # Confidence grows as the probability moves away from 0.5.
    complement = 1.0 - probs
    conf = torch.maximum(probs, complement)
    y_hat = (probs >= thr).long()
    return y_hat, conf

@torch.no_grad()
def propagate_labels_from_source(
    encoder, source_data, target_data,
    K=5, tau=0.07, chunk_size=2048, use_half=False, num_classes=2
):
    """One-step label propagation from source nodes to target nodes.

    Every target node keeps its `K` most similar source nodes (dot product of
    L2-normalised encoder embeddings), weights them with
    ``softmax(sim / tau)`` and averages their one-hot labels.

    Args:
        encoder: callable returning ``(embeddings, _)`` for ``(x, edge_index)``;
            it is switched to eval mode.
        source_data: object exposing ``x``, ``edge_index`` and ``y`` (true labels).
        target_data: object exposing ``x`` and ``edge_index``.
        K: number of source neighbours kept per target (capped at the number
            of source nodes).
        tau: softmax temperature for the neighbour weights.
        chunk_size: number of target rows per similarity block.
        use_half: compute the similarity matrix in fp16.
        num_classes: width of the one-hot source label matrix.

    Returns:
        P_t:   ``[N_t, num_classes]`` class distribution of each target node.
        y_hat: ``[N_t]`` arg-max pseudo-label.
        conf:  ``[N_t]`` maximum probability.
    """
    encoder.eval()
    z_s, _ = encoder(source_data.x, source_data.edge_index)
    z_t, _ = encoder(target_data.x, target_data.edge_index)
    z_s = F.normalize(z_s, dim=1)
    z_t = F.normalize(z_t, dim=1)

    if use_half:
        z_s_sim, z_t_sim = z_s.half(), z_t.half()
    else:
        z_s_sim, z_t_sim = z_s, z_t

    Ns, Nt = z_s.size(0), z_t.size(0)
    Y_s = F.one_hot(source_data.y.long(), num_classes=num_classes).float().to(z_s.device)
    k = min(K, Ns)

    # Row-normalised weights of the sparse top-K bipartite graph W_ts,
    # collected as COO triplets on the host.
    rows, cols, vals = [], [], []
    for t0 in range(0, Nt, chunk_size):
        t1 = min(Nt, t0 + chunk_size)
        S = z_t_sim[t0:t1] @ z_s_sim.T                     # [B, Ns]
        topv, topi = torch.topk(S, k=k, dim=1)             # [B, k]
        # Weights are computed in fp32 so that W always matches the fp32
        # one-hot matrix in the sparse product, including when use_half=True.
        w = torch.softmax(topv.float() / tau, dim=1)       # [B, k]
        r = torch.arange(t0, t1, device=topi.device).unsqueeze(1).expand_as(topi)

        rows.append(r.reshape(-1).cpu())
        cols.append(topi.reshape(-1).cpu())
        vals.append(w.reshape(-1).cpu())
        del S, topv, topi, w, r

    rows = torch.cat(rows)
    cols = torch.cat(cols)
    vals = torch.cat(vals)
    W = torch.sparse_coo_tensor(
        torch.stack([rows, cols]), vals, size=(Nt, Ns)
    ).coalesce().to(z_s.device)

    P_t = torch.sparse.mm(W, Y_s)          # [Nt, num_classes]
    conf, y_hat = P_t.max(dim=1)           # confidence / pseudo-label
    return P_t, y_hat, conf

def initialize_models(args):
    """Create the two encoders, the CLUB estimator and their optimisers/schedulers.

    Every module gets its own Adam optimiser (``lr=args.lr``,
    ``weight_decay=args.lr2_reg``) and a StepLR scheduler (``step_size=50``,
    ``gamma=0.9``).

    Returns:
        ``(sens_enc, optimizer_sens, scheduler_sens,
           inter_enc, optimizer_inter, scheduler_inter,
           club, optimizer_club, scheduler_club)``
    """
    def _adam_with_step_lr(module):
        optimizer = torch.optim.Adam(
            module.parameters(), lr=args.lr, weight_decay=args.lr2_reg
        )
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
        return optimizer, scheduler

    # Modules are created in this order: sensitive encoder, interest encoder, CLUB.
    sens_enc = Encoder(args, encoder_type=args.sens_encoder).to(args.device)
    inter_enc = Encoder(args, encoder_type=args.inter_encoder).to(args.device)
    # NOTE(review): unlike the encoders, `club` is not moved to args.device
    # here — presumably CLUBSample handles placement itself; confirm.
    club = CLUBSample(args)

    optimizer_sens, scheduler_sens = _adam_with_step_lr(sens_enc)
    optimizer_inter, scheduler_inter = _adam_with_step_lr(inter_enc)
    optimizer_club, scheduler_club = _adam_with_step_lr(club)

    return (
        sens_enc, optimizer_sens, scheduler_sens,
        inter_enc, optimizer_inter, scheduler_inter,
        club, optimizer_club, scheduler_club,
    )

def train_sensitive_encoder(args, source_data, sens_enc, optimizer_sens, scheduler_sens, sens_labels, criterion):
    """Fit the sensitive-attribute encoder on the source graph and save it.

    Runs `args.train_epochs` full-graph steps minimising
    ``criterion(logits.view(-1), sens_labels)``, stepping the optimiser and
    the scheduler once per epoch, then writes ``{'encoder': state_dict}`` to
    ``args.sens_train_path.format(args.sens_encoder, args.dataset, args.inid)``.

    Returns:
        The trained `sens_enc`.
    """
    epochs = tqdm(range(args.train_epochs), desc='Train Sensitive Encoder')
    for _epoch in epochs:
        # The embeddings returned alongside the logits are not needed here.
        _, sens_logits = sens_enc(source_data.x, source_data.edge_index)
        loss = criterion(sens_logits.view(-1), sens_labels)

        optimizer_sens.zero_grad()
        loss.backward()
        optimizer_sens.step()
        scheduler_sens.step()

    checkpoint_path = args.sens_train_path.format(args.sens_encoder, args.dataset, args.inid)
    checkpoint = {'encoder': sens_enc.state_dict()}
    torch.save(checkpoint, checkpoint_path)
    return sens_enc

def load_pretrained_encoder(pre_train_path, encoder_type):
    """Build the pre-trained encoder and load its weights from a checkpoint.

    NOTE(review): `args` is not a parameter of this function, so it has to
    resolve as a module-level name.  None is defined in the visible part of
    this file (it may arrive through one of the star imports — confirm);
    otherwise calling this function raises NameError.  `load_encoder` covers
    the same job with `args` passed explicitly.

    Returns:
        The encoder with the checkpoint's ``'encoder'`` state dict loaded,
        placed on ``args.device``.
    """
    # NOTE(review): the caller-supplied `pre_train_path` is overwritten here,
    # and `encoder_type` is never read — both arguments are effectively ignored.
    pre_train_path = args.pre_train_path.format(args.pre_train_encoder, args.dataset, args.inid)
    ex_enc = Encoder(args, args.pre_train_encoder).to(args.device)
    checkpoint = torch.load(pre_train_path)
    ex_enc.load_state_dict(checkpoint['encoder'])
    ex_enc = ex_enc.to(args.device)
    return ex_enc

def load_encoder(args, pre_train_path, encoder_type, encoder = None):
    """Load encoder weights from a checkpoint file.

    Args:
        args: configuration object; `args.device` is the target device.
        pre_train_path: path of a checkpoint holding the state dict under
            the ``'encoder'`` key.
        encoder_type: encoder type used to build a fresh `Encoder` when
            `encoder` is None.
        encoder: optional existing module to load the weights into.

    Returns:
        The encoder with the checkpoint weights, on `args.device`.
    """
    if encoder is None:
        encoder = Encoder(args, encoder_type).to(args.device)
    # map_location makes checkpoints saved on another device (e.g. a CUDA
    # checkpoint opened on a CPU-only host) loadable.
    checkpoint = torch.load(pre_train_path, map_location=args.device)
    encoder.load_state_dict(checkpoint['encoder'])
    return encoder.to(args.device)

def train_interest_encoder(
    args,
    source_data,
    target_data,
    inter_enc,
    optimizer_inter,
    scheduler_inter,
    sens_enc,
    club,
    optimizer_club,
    scheduler_club,
    criterion
):
    """Train the interest encoder on the source graph and save its weights.

    Each of the `args.train_epochs` epochs:
      1. minimises ``cls_loss + args.lreg * lb + args.ureg * up`` where
         `cls_loss` sums `criterion` over the four (label, sensitive) groups
         of training nodes, `lb` is `condition_info_nce_for_embeddings`
         between the pre-trained and the interest embeddings, and `up` is
         ``club.forward`` between the interest and sensitive embeddings;
      2. runs `args.train_step` updates of `club` via ``club.learning_loss``
         on detached embeddings.

    The state dict is then written to
    ``args.inter_train_path.format(args.inter_encoder, args.dataset, args.inid)``.

    Args:
        target_data: accepted for interface compatibility; not used here.
        sens_enc: sensitive encoder; only its detached source embeddings are used.
        (remaining arguments: models, optimisers and schedulers as named.)

    Returns:
        The trained `inter_enc`.
    """
    # The sensitive and pre-trained embeddings are constants for this stage,
    # so they are computed once without building an autograd graph.
    with torch.no_grad():
        e_s, _ = sens_enc(source_data.x, source_data.edge_index)
    e_s = e_s.detach().to(args.device)

    p_s = conditional_samples(e_s.cpu().numpy())
    p_s = torch.tensor(p_s).to(args.device)

    pre_train_path = args.pre_train_path.format(args.pre_train_encoder, args.dataset, args.inid)
    ex_enc = load_encoder(args, pre_train_path, args.pre_train_encoder)
    with torch.no_grad():
        e_x, _ = ex_enc(source_data.x, source_data.edge_index)
    e_x = e_x.detach().to(args.device)

    # Training nodes split by (label, sensitive attribute), in the order
    # y0s0, y0s1, y1s0, y1s1.  Empty groups are dropped: a mean-reduced
    # criterion returns NaN on an empty selection and would poison the loss.
    labels, sens = source_data.y, source_data.sens_labels
    group_masks = [
        source_data.train_mask & (labels == y_val) & (sens == s_val)
        for y_val in (0, 1)
        for s_val in (0, 1)
    ]
    group_masks = [mask for mask in group_masks if mask.any()]
    group_targets = [labels[mask].float() for mask in group_masks]

    for _ in tqdm(range(args.train_epochs), desc='Train Interest Encoder'):
        e_z, logits = inter_enc(source_data.x, source_data.edge_index)

        cls_loss = logits.new_zeros(())
        for mask, target in zip(group_masks, group_targets):
            cls_loss = cls_loss + criterion(logits[mask].view(-1), target)

        # equation 12: information between the fairness-aware embedding and
        # the pre-trained embedding
        lb = condition_info_nce_for_embeddings(e_x, e_z, e_s, p_s)
        # equation 9: information between the fairness-aware embedding and
        # the sensitive embedding
        up = club.forward(e_z, e_s)

        loss = cls_loss + args.lreg * lb + args.ureg * up
        optimizer_inter.zero_grad()
        loss.backward()
        optimizer_inter.step()
        scheduler_inter.step()

        # Refresh the interest embeddings after the update; the estimator
        # only ever sees them detached.
        with torch.no_grad():
            x_samples, _ = inter_enc(source_data.x, source_data.edge_index)
        y_samples = e_s
        for _ in range(args.train_step):
            mi_loss = club.learning_loss(x_samples, y_samples)
            optimizer_club.zero_grad()
            mi_loss.backward()
            optimizer_club.step()
            scheduler_club.step()

    inter_train_path = args.inter_train_path.format(args.inter_encoder, args.dataset, args.inid)
    torch.save({'encoder': inter_enc.state_dict(),}, inter_train_path)
    return inter_enc

def sensitive_consistency_loss(args, feat_t, feat_s, y, s, train_mask):
    """Contrastive disentanglement and group-consistency losses on training nodes.

    Only rows selected by `train_mask` take part.  With ``tt`` the
    target–target and ``ts`` the target–sensitive cosine similarities divided
    by `args.tau`:

      * `loss_dis`: for every anchor, the mean over its same-label,
        same-sensitive partners of ``tt - log(sum exp(ts))`` (the sum runs
        over those same partners), negated and averaged over anchors.
      * `loss_con`: for every anchor, the mean over its same-label partners of
        ``tt - log(sum exp(tt))`` (the sum runs over different-label,
        same-sensitive nodes); the per-anchor values are averaged inside each
        of the four (sensitive, label) groups and the group means are summed
        and negated.  A group without any anchor contributes 0.

    Args:
        args: configuration; reads `args.device` and `args.tau`.
        feat_t: embedding matrix, one row per node.
        feat_s: sensitive embedding matrix with the same number of rows.
        y: per-node labels (assumed 1-D; the groups compare against 0 and 1).
        s: per-node sensitive labels (assumed 1-D, values 0/1).
        train_mask: boolean mask selecting the nodes to use.

    Returns:
        ``(loss_dis, loss_con)`` as scalar tensors.
    """
    feat_t = F.normalize(feat_t, dim=1)[train_mask]
    feat_s = F.normalize(feat_s, dim=1)[train_mask]
    y_tr = y[train_mask]
    s_tr = s[train_mask]

    y_col = y_tr.view(-1, 1)
    s_col = s_tr.view(-1, 1)
    same_label = torch.eq(y_col, y_col.T).to(args.device)   # label mask
    same_sens = torch.eq(s_col, s_col.T).to(args.device)    # sensitive mask

    intra_group = (same_label & same_sens).float()              # same label, same sensitive
    target_inter_group = ((~same_label) & same_sens).float()    # different label, same sensitive

    # The target-target similarity is shared by both losses; compute it once.
    sim_tt = torch.div(torch.matmul(feat_t, feat_t.T), args.tau)   # [b, b]
    sim_ts = torch.div(torch.matmul(feat_t, feat_s.T), args.tau)   # [b, b]

    # disentanglement
    exp_sum = (torch.exp(sim_ts) * intra_group).sum(1, keepdim=True)      # [b, 1]
    # A zero denominator is replaced by 1 so that log() stays finite.
    log_prob = sim_tt - torch.log(exp_sum + (exp_sum == 0) * 1)
    mean_log_prob = (log_prob * intra_group).sum(1) / intra_group.sum(1)  # [b]
    loss_dis = -mean_log_prob.mean()

    # group acquired
    exp_sum_2 = (torch.exp(sim_tt) * target_inter_group).sum(1, keepdim=True)
    log_prob_2 = sim_tt - torch.log(exp_sum_2 + (exp_sum_2 == 0) * 1)
    mean_log_prob_2 = (same_label * log_prob_2).sum(1) / same_label.sum(1)

    # Sum of per-group means, groups visited as (y=1,s=0), (y=1,s=1),
    # (y=0,s=0), (y=0,s=1).  clamp_min(1) keeps an empty group at 0 instead
    # of producing 0/0 = NaN.
    loss_con = mean_log_prob_2.new_zeros(())
    for y_val, s_val in ((1, 0), (1, 1), (0, 0), (0, 1)):
        in_group = (s_tr == s_val) & (y_tr == y_val)
        group_size = in_group.sum().clamp_min(1)
        loss_con = loss_con - (in_group * mean_log_prob_2).sum() / group_size

    return loss_dis, loss_con

@torch.no_grad()
def build_bma_pairs(
    z_s, y_s,              # source embeddings / labels
    z_t, y_t_hat, conf_t,  # target embeddings / pseudo-labels / confidence
    conf_th, topk=20, tau=0.5, chunk=2048
):
    """
    Build same-label cross-domain (target, source) pairs for the BMA loss.

    For every target node with ``conf_t >= conf_th`` the `topk` most similar
    source nodes carrying the target's pseudo-label are selected, and the
    selected neighbours of that target are weighted by ``softmax(sim / tau)``.

    Returns:
      idx_t: [M] target indices
      idx_s: [M] source indices
      w    : [M] pair weights (the weights of one target sum to 1)
    Notes:
      - embeddings are L2-normalised first, so the similarity is a cosine;
      - low-confidence targets are skipped entirely;
      - targets are processed in blocks of `chunk` rows to bound memory;
      - three empty tensors are returned when no pair can be formed.
    """
    device = z_t.device

    def _no_pairs():
        no_index = torch.empty(0, dtype=torch.long, device=device)
        no_weight = torch.empty(0, dtype=torch.float, device=device)
        return no_index, no_index, no_weight

    z_s = F.normalize(z_s, dim=1)
    z_t = F.normalize(z_t, dim=1)

    # Indices of the target nodes that clear the confidence threshold.
    confident = torch.arange(z_t.size(0), device=device)[conf_t >= conf_th]
    if confident.numel() == 0:
        return _no_pairs()

    # Binary labels use the fixed pair {0, 1}; otherwise take the labels present in y_s.
    if y_s.max().item() <= 1:
        label_values = [0, 1]
    else:
        label_values = torch.unique(y_s).tolist()

    tgt_parts, src_parts, weight_parts = [], [], []
    for label in label_values:
        src_pool = torch.where(y_s == label)[0]
        if src_pool.numel() == 0:
            continue
        src_emb = z_s[src_pool]  # [|S|, d]

        tgt_pool = confident[(y_t_hat[confident] == label)]
        n_tgt = tgt_pool.numel()
        if n_tgt == 0:
            continue

        for lo in range(0, n_tgt, chunk):
            hi = min(n_tgt, lo + chunk)
            batch = tgt_pool[lo:hi]                    # [b]
            scores = z_t[batch] @ src_emb.T            # [b, |S|]
            k = min(topk, scores.size(1))
            if k <= 0:
                continue
            top_scores, top_pos = torch.topk(scores, k=k, dim=1)   # [b, k]
            weights = F.softmax(top_scores / tau, dim=1)            # rows sum to 1

            tgt_parts.append(batch.unsqueeze(1).expand(-1, k).reshape(-1))
            src_parts.append(src_pool[top_pos.reshape(-1)])
            weight_parts.append(weights.reshape(-1))

    if len(tgt_parts) == 0:
        return _no_pairs()

    return (
        torch.cat(tgt_parts, dim=0),
        torch.cat(src_parts, dim=0),
        torch.cat(weight_parts, dim=0),
    )


def bma_loss_bipartite(
    P_t, P_s,
    idx_t, idx_s,
    w=None,
    use_weighted_degree=False,
    eps=1e-8,
    norm_edge="E",     # "E" = mean over edges; any other value = divide by Nt*Ns
    norm_gram="pair"   # "pair" = divide by Nt*Ns; any other value = mean over edges
):
    """
    Cross-domain bipartite alignment loss (rectangular block only).

    Objective:  L = -2 * <P_t P_s^T, W_hat> + ||P_t^T P_s||_F^2
    (the constant ||W_hat||_F^2 is omitted), where
    W_hat = D_t^{-1/2} W D_s^{-1/2} is the degree-normalised bipartite
    adjacency built from the edge list.

    Args:
      P_t: [Nt, C] target class probabilities.
      P_s: [Ns, C] source class probabilities.
      idx_t: [M] target index of every edge, in [0, Nt).
      idx_s: [M] source index of every edge, in [0, Ns).
      w: [M] edge weights, or None for all ones.
      use_weighted_degree: True sums `w` to obtain the node degrees; False
        counts edges instead.
      eps: lower bound / additive constant guarding the degree normalisation.
      norm_edge: divisor of the cross term ("E": number of edges, otherwise Nt*Ns).
      norm_gram: divisor of the Gram term ("pair": Nt*Ns, otherwise number of edges).

    Returns:
      Scalar loss, differentiable with respect to P_t and P_s.  With no edges
      a zero scalar with ``requires_grad=True`` is returned.
    """
    device = P_t.device
    Nt = P_t.shape[0]
    Ns = P_s.shape[0]

    if idx_t is None or idx_t.numel() == 0:
        return P_t.new_zeros((), requires_grad=True)

    # Clamp first, then renormalise by the sum of the *clamped* values: every
    # row sums to 1 afterwards and an all-zero row cannot divide by zero.
    P_t = P_t.clamp(1e-6, 1 - 1e-6)
    P_t = P_t / P_t.sum(dim=1, keepdim=True)
    P_s = P_s.clamp(1e-6, 1 - 1e-6)
    P_s = P_s / P_s.sum(dim=1, keepdim=True)

    if w is None:
        w = torch.ones(idx_t.numel(), device=device)
    else:
        w = w.to(device)

    # --- degrees on both sides of the bipartite graph ---
    # Buffers follow w's dtype because index_add needs matching dtypes.
    degree_src = w if use_weighted_degree else torch.ones_like(w)
    deg_t = torch.zeros(Nt, dtype=w.dtype, device=device).index_add(0, idx_t, degree_src)
    deg_s = torch.zeros(Ns, dtype=w.dtype, device=device).index_add(0, idx_s, degree_src)

    # Normalised edge weight: w / sqrt(d_t * d_s)
    dt = deg_t[idx_t].clamp_min(eps).sqrt()
    ds = deg_s[idx_s].clamp_min(eps).sqrt()
    w_norm = w / (dt * ds + eps)  # [M]
    num_edges = max(1, w_norm.numel())
    num_pairs = max(1.0, float(Nt) * float(Ns))

    # --- cross term: -2 * sum_edges w_norm * <p_t[i], p_s[j]> ---
    sim_ts = (P_t[idx_t] * P_s[idx_s]).sum(dim=1)         # [M]
    cross_term = -2.0 * (w_norm * sim_ts).sum()
    cross_term = cross_term / (num_edges if norm_edge == "E" else num_pairs)

    # --- Gram term: ||P_t^T P_s||_F^2 = tr(P_t^T P_t . P_s^T P_s) ---
    Gt = P_t.t() @ P_t      # [C, C]
    Gs = P_s.t() @ P_s      # [C, C]
    gram_term = torch.sum(Gt * Gs)  # Frobenius inner product
    gram_term = gram_term / (num_pairs if norm_gram == "pair" else num_edges)

    return cross_term + gram_term

def train(args, source_data, target_data):
    """Run the full pipeline `args.runs` times and collect target-domain metrics.

    Per run:
      1. build fresh models/optimisers with `initialize_models`;
      2. load the sensitive and interest encoders from their checkpoints, or
         train them when the checkpoint is missing or `args.overwrite` is set;
      3. derive target pseudo-labels, build the merged cross-domain graph
         (`build_mi_cross_graph`) and the same-label cross-domain pairs
         (`build_bma_pairs`);
      4. adapt for `args.adaption_epochs` epochs on
         ``L_sup + lambda_bma * L_bma + lambda_cons * L_cons
           + lambda_dis * L_dis + lreg * lb + ureg * up``,
         followed by `args.train_step` updates of the CLUB estimator;
      5. evaluate on source and target after every epoch.

    Returns:
        ``(acc, auc_roc, parity, equality)``: arrays of shape
        ``[args.runs, 1]`` holding the target metrics (the ``'all'`` entry of
        `evaluate_per_class`) of the *last* adaptation epoch of each run.
    """
    # `os` is needed for the checkpoint checks below and is not imported
    # explicitly at file level, so bring it into scope here.
    import os

    pbar = tqdm(range(args.runs), unit='run')
    acc, auc_roc, parity, equality = np.zeros([args.runs, 1]), np.zeros([args.runs, 1]), np.zeros([args.runs, 1]), np.zeros([args.runs, 1])
    source_data = source_data.to(args.device)
    target_data = target_data.to(args.device)
    sens_labels = torch.tensor(source_data.sens_labels).to(torch.float).to(args.device)
    criterion = nn.BCEWithLogitsLoss()

    for count in pbar:
        # initialize models and optimizer
        (sens_enc, optimizer_sens, scheduler_sens,
         inter_enc, optimizer_inter, scheduler_inter,
         club, optimizer_club, scheduler_club) = initialize_models(args)

        # train sensitive encoder
        sens_train_path = args.sens_train_path.format(args.sens_encoder, args.dataset, args.inid)
        if os.path.exists(sens_train_path) and not args.overwrite:
            print(f"Loading encoder from {sens_train_path}")
            sens_enc = load_encoder(args, sens_train_path, args.sens_encoder, sens_enc)
        else:
            print(f"{sens_train_path} not found. Training sensitive encoder...")
            sens_enc = train_sensitive_encoder(
                args, source_data, sens_enc, optimizer_sens, scheduler_sens, sens_labels, criterion
            )

        # train interest encoder
        inter_train_path = args.inter_train_path.format(args.inter_encoder, args.dataset, args.inid)
        if os.path.exists(inter_train_path) and not args.overwrite:
            print(f"Loading encoder from {inter_train_path}")
            inter_enc = load_encoder(args, inter_train_path, args.inter_encoder, inter_enc)
        else:
            print(f"{inter_train_path} not found. Training interest encoder...")
            inter_enc = train_interest_encoder(
                args, source_data, target_data, inter_enc, optimizer_inter, scheduler_inter, sens_enc, club, optimizer_club, scheduler_club, criterion
            )

        t1, t2, t3, t4 = evaluate_per_class(args, source_data, inter_enc)
        print(f"test on source data: {t1}, {t2}, {t3}, {t4}")
        accs, auc_rocs, tmp_parity, tmp_equality = evaluate_per_class(args, target_data, inter_enc)
        print(f"test on target data: {accs}, {auc_rocs}, {tmp_parity}, {tmp_equality}")

        # Pseudo-labels, merged graph and cross-domain pairs are built once
        # per run from the encoders as they stand before adaptation.
        with torch.no_grad():
            inter_enc.eval()
            sens_enc.eval()
            z_s, _ = inter_enc(source_data.x, source_data.edge_index)
            z_t, logit_t = inter_enc(target_data.x, target_data.edge_index)
            yhat_t, conf_t = pseudo_from_logits(logit_t, thr=0.5)
            _, logit_sens_t = sens_enc(target_data.x, target_data.edge_index)
            sens_hat_t, conf_sens_t = pseudo_from_logits(logit_sens_t, thr=0.5)

            new_x, edge_index, new_sens, new_y, new_train_mask, id_map = build_mi_cross_graph(
                args, source_data, target_data,
                inter_enc, yhat_t, conf_t,
                sens_enc, sens_hat_t, conf_sens_t,
                chunk_size=1024
            )

            idx_t, idx_s, w_bma = build_bma_pairs(
                z_s=z_s, y_s=source_data.y,
                z_t=z_t, y_t_hat=yhat_t, conf_t=conf_t,
                conf_th=args.y_pseudo_threshold
            )

        # Supervised nodes of the merged graph split by (label, sensitive
        # attribute) in the order y0s0, y0s1, y1s0, y1s1.  Empty groups are
        # dropped because a mean-reduced BCE over no samples is NaN.
        sup_masks = [
            new_train_mask & (new_y == y_val) & (new_sens == s_val)
            for y_val in (0, 1)
            for s_val in (0, 1)
        ]
        sup_masks = [mask for mask in sup_masks if mask.any()]
        sup_targets = [new_y[mask].float() for mask in sup_masks]

        inter_enc.train()
        sens_enc.train()
        e_s, _ = sens_enc.forward(source_data.x, source_data.edge_index)
        e_s = e_s.detach().to(args.device)

        p_s = conditional_samples(e_s.detach().cpu().numpy())
        p_s = torch.tensor(p_s).to(args.device)

        pre_train_path = args.pre_train_path.format(args.pre_train_encoder, args.dataset, args.inid)
        ex_enc = load_encoder(args, pre_train_path, args.pre_train_encoder)
        e_x, _ = ex_enc.forward(source_data.x, source_data.edge_index)
        e_x = e_x.detach().to(args.device)

        # NOTE(review): train mode is set once above; if evaluate_per_class
        # switches the encoder to eval mode, every epoch after the first runs
        # in eval mode — confirm this is intended.
        for epoch in range(1, args.adaption_epochs+1):
            print(f"================={epoch}=================")
            e_z, logits = inter_enc.forward(source_data.x, source_data.edge_index)
            # equation 12
            # for information between fairness aware embedding and pre train embedding
            lb = condition_info_nce_for_embeddings(e_x, e_z, e_s, p_s)
            # equation 9
            # for information between fairness aware embedding and sensitive embedding
            up = club.forward(e_z, e_s)

            new_z, new_logits = inter_enc(new_x.to(args.device), edge_index.to(args.device))  # logits: [N_new, 1]
            new_s, _ = sens_enc(new_x.to(args.device), edge_index.to(args.device))

            L_sup = new_logits.new_zeros(())
            for mask, target in zip(sup_masks, sup_targets):
                L_sup = L_sup + criterion(new_logits[mask].view(-1), target)

            loss_dis, loss_cons = sensitive_consistency_loss(args, new_z, new_s, new_y.to(args.device), new_sens.to(args.device), new_train_mask.to(args.device))

            z_s, logit_s_full = inter_enc(source_data.x, source_data.edge_index)
            z_t, logit_t_full = inter_enc(target_data.x, target_data.edge_index)
            p1_s = torch.sigmoid(logit_s_full.view(-1))
            p1_t = torch.sigmoid(logit_t_full.view(-1))
            P_s  = torch.stack([p1_s, 1 - p1_s], dim=1)
            P_t  = torch.stack([p1_t, 1 - p1_t], dim=1)

            # Bipartite alignment between the domains over the (target,
            # source) pairs produced by build_bma_pairs before the loop.
            if idx_t is None or idx_t.numel() == 0:
                loss_bma = torch.zeros((), device=args.device)
            else:
                loss_bma = bma_loss_bipartite(
                    P_t, P_s, idx_t, idx_s,
                    w=w_bma,
                    use_weighted_degree=False,  # w acts as a confidence/similarity weight
                    norm_edge="E",
                    norm_gram="pair"
                )

            loss = L_sup + args.lambda_bma * loss_bma +  args.lambda_cons * loss_cons + args.lambda_dis * loss_dis + args.lreg * lb + args.ureg * up

            optimizer_inter.zero_grad()
            optimizer_sens.zero_grad()
            loss.backward()
            optimizer_inter.step()
            optimizer_sens.step()
            scheduler_inter.step()
            scheduler_sens.step()

            # Update the CLUB estimator on the refreshed, detached embeddings.
            e_z, _ = inter_enc.forward(source_data.x, source_data.edge_index)
            x_samples = e_z.detach()
            y_samples = e_s.detach()
            for _ in range(args.train_step):
                mi_loss = club.learning_loss(x_samples, y_samples)
                optimizer_club.zero_grad()
                mi_loss.backward()
                optimizer_club.step()
                scheduler_club.step()

            # Fields are separated by two spaces so the log line stays readable.
            print(f"[Epoch {epoch:03d}] "
                f"L_sup={L_sup.item():.4f}  "
                f"L_bma={loss_bma.item():.4f}  "
                f"L_fair_cons={loss_cons.item():.4f}  "
                f"L_fair_dis={loss_dis.item():.4f}  "
                f"Total={loss.item():.4f}")

            t1, t2, t3, t4 = evaluate_per_class(args, source_data, inter_enc)
            print(f"test on source data: {t1}, {t2}, {t3}, {t4}")
            accs, auc_rocs, tmp_parity, tmp_equality = evaluate_per_class(args, target_data, inter_enc)
            print(f"test on target data: {accs}, {auc_rocs}, {tmp_parity}, {tmp_equality}")
            # Overwritten every epoch: the value left after the loop is the
            # last epoch's target result (no model selection is performed).
            acc[count], auc_roc[count] = accs['all'], auc_rocs['all']
            parity[count], equality[count] = tmp_parity['all'], tmp_equality['all']

    return acc, auc_roc, parity, equality