import torch

def get_m_td(meta_vocab_size, true_adj, true_reach):
    """Build the lookup tables read by the ``get_scores_*`` functions.

    Variant of ``get_m`` whose ``Path_Check`` stores 2 on the step whose next
    token is the target node itself (``get_m`` instead stores 2 on the
    target -> special-token step).

    Token layout: indices 0 and 1 are special (non-node) tokens and graph node
    ``u`` is token ``u + 2``, so there are ``n = meta_vocab_size - 2`` nodes.

    Args:
        meta_vocab_size: Vocabulary size ``V``; every table has shape ``(V, V, V)``.
        true_adj: Matrix indexable as ``true_adj[v][w]`` (nested lists, numpy
            array or tensor). Only entries equal to 1 count, and only the
            top-left ``n x n`` part is read.
        true_reach: Matrix with the same conventions as ``true_adj``.

    Returns:
        ``(Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)``. These
        are float tensors of shape ``(V, V, V)``. Below, ``t``/``c``/``x`` are the
        node tokens of nodes ``u``/``v``/``w``:

        * ``Start_Check[s, t, s] = 1`` for every pair of node tokens ``s``, ``t``.
        * ``Adj_Check[t, c, x] = 1`` if ``c != t`` and ``true_adj[v][w] == 1``.
        * ``Path_Check[t, c, x]`` is nonzero if ``c != t`` and
          ``true_reach[u][v]``, ``true_adj[v][w]`` and ``true_reach[u][w]`` all
          equal 1. Its value is 2 when ``x == t`` and 1 otherwise.
        * ``Reach_Check[t, c, x] = 1`` if ``c != t`` and ``true_reach[u][w] == 1``.
        * Adj/Path/Reach also hold 1 at ``[t, t, 0]`` and ``[t, t, 1]`` for
          every node token ``t``, and at ``[i, 1, 0]``, ``[i, 0, 0]``,
          ``[i, 1, 1]`` for every token ``i``.
        * ``Pad_Check`` holds 1 at ``[i, 1, 0]``, ``[i, 0, 0]``, ``[i, 1, 1]``,
          ``[i, i, 0]`` and ``[i, i, 1]`` for every token ``i``.
    """
    V = meta_vocab_size
    n = max(V - 2, 0)
    node_tokens = torch.arange(2, 2 + n)

    def as_mask(matrix):
        # Boolean (n, n) view of `matrix == 1`, restricted to the part that is read.
        if isinstance(matrix, torch.Tensor):
            return (matrix[:n, :n] == 1).cpu()
        # Generic path: same per-element `== 1` test for lists / numpy arrays.
        rows = [[bool(matrix[r][k] == 1) for k in range(n)] for r in range(n)]
        return torch.tensor(rows, dtype=torch.bool).reshape(n, n)  # keeps (0, 0) when n == 0

    adj = as_mask(true_adj)          # adj[v, w]
    reach = as_mask(true_reach)      # reach[u, w]
    at_target = torch.eye(n, dtype=torch.bool)   # [u, v]: current node == target node
    not_at_target = ~at_target

    def mark_special(check, at_target_value):
        # current == target (node tokens): mark the steps to tokens 0 and 1.
        check[node_tokens, node_tokens, 0] = at_target_value
        check[node_tokens, node_tokens, 1] = at_target_value
        # Special-token transitions, set once for every first index and any V.
        check[:, 1, 0] = 1
        check[:, 0, 0] = 1
        check[:, 1, 1] = 1

    Start_Check = torch.zeros(V, V, V)
    Start_Check[node_tokens[:, None], node_tokens[None, :], node_tokens[:, None]] = 1

    # Node block is indexed [target u, current v, next w].
    Adj_Check = torch.zeros(V, V, V)
    Adj_Check[2:, 2:, 2:] = (not_at_target[:, :, None] & adj[None, :, :]).to(Adj_Check.dtype)
    mark_special(Adj_Check, 1)

    Path_Check = torch.zeros(V, V, V)
    on_path = (
        not_at_target[:, :, None]
        & reach[:, :, None]      # true_reach[u][v]
        & adj[None, :, :]        # true_adj[v][w]
        & reach[:, None, :]      # true_reach[u][w]
    )
    # 2 when the next node is the target itself (w == u), else 1.
    Path_Check[2:, 2:, 2:] = on_path * (1 + at_target[:, None, :].to(Path_Check.dtype))
    mark_special(Path_Check, 1)

    Reach_Check = torch.zeros(V, V, V)
    Reach_Check[2:, 2:, 2:] = (not_at_target[:, :, None] & reach[:, None, :]).to(Reach_Check.dtype)
    mark_special(Reach_Check, 1)

    Pad_Check = torch.zeros(V, V, V)
    all_tokens = torch.arange(V)
    Pad_Check[:, 1, 0] = 1
    Pad_Check[:, 0, 0] = 1
    Pad_Check[:, 1, 1] = 1
    Pad_Check[all_tokens, all_tokens, 1] = 1
    Pad_Check[all_tokens, all_tokens, 0] = 1

    return Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check

def get_m(meta_vocab_size, true_adj, true_reach):
    """Build the lookup tables read by the ``get_scores_*`` functions.

    Same as ``get_m_td`` except for ``Path_Check``. Here every valid step
    stores 1, and the target -> special-token steps ``[t, t, 0]`` and
    ``[t, t, 1]`` store 2.

    Token layout: indices 0 and 1 are special (non-node) tokens and graph node
    ``u`` is token ``u + 2``, so there are ``n = meta_vocab_size - 2`` nodes.

    Args:
        meta_vocab_size: Vocabulary size ``V``; every table has shape ``(V, V, V)``.
        true_adj: Matrix indexable as ``true_adj[v][w]`` (nested lists, numpy
            array or tensor). Only entries equal to 1 count, and only the
            top-left ``n x n`` part is read.
        true_reach: Matrix with the same conventions as ``true_adj``.

    Returns:
        ``(Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)``. These
        are float tensors of shape ``(V, V, V)``. Below, ``t``/``c``/``x`` are the
        node tokens of nodes ``u``/``v``/``w``:

        * ``Start_Check[s, t, s] = 1`` for every pair of node tokens ``s``, ``t``.
        * ``Adj_Check[t, c, x] = 1`` if ``c != t`` and ``true_adj[v][w] == 1``.
        * ``Path_Check[t, c, x] = 1`` if ``c != t`` and ``true_reach[u][v]``,
          ``true_adj[v][w]`` and ``true_reach[u][w]`` all equal 1.
          ``Path_Check[t, t, 0] = Path_Check[t, t, 1] = 2``.
        * ``Reach_Check[t, c, x] = 1`` if ``c != t`` and ``true_reach[u][w] == 1``.
        * Adj/Reach hold 1 at ``[t, t, 0]`` and ``[t, t, 1]``. Adj/Path/Reach hold
          1 at ``[i, 1, 0]``, ``[i, 0, 0]``, ``[i, 1, 1]`` for every token ``i``.
        * ``Pad_Check`` holds 1 at ``[i, 1, 0]``, ``[i, 0, 0]``, ``[i, 1, 1]``,
          ``[i, i, 0]`` and ``[i, i, 1]`` for every token ``i``.
    """
    V = meta_vocab_size
    n = max(V - 2, 0)
    node_tokens = torch.arange(2, 2 + n)

    def as_mask(matrix):
        # Boolean (n, n) view of `matrix == 1`, restricted to the part that is read.
        if isinstance(matrix, torch.Tensor):
            return (matrix[:n, :n] == 1).cpu()
        # Generic path: same per-element `== 1` test for lists / numpy arrays.
        rows = [[bool(matrix[r][k] == 1) for k in range(n)] for r in range(n)]
        return torch.tensor(rows, dtype=torch.bool).reshape(n, n)  # keeps (0, 0) when n == 0

    adj = as_mask(true_adj)          # adj[v, w]
    reach = as_mask(true_reach)      # reach[u, w]
    not_at_target = ~torch.eye(n, dtype=torch.bool)  # [u, v]: current node != target node

    def mark_special(check, at_target_value):
        # current == target (node tokens): mark the steps to tokens 0 and 1.
        check[node_tokens, node_tokens, 0] = at_target_value
        check[node_tokens, node_tokens, 1] = at_target_value
        # Special-token transitions, set once for every first index and any V.
        check[:, 1, 0] = 1
        check[:, 0, 0] = 1
        check[:, 1, 1] = 1

    Start_Check = torch.zeros(V, V, V)
    Start_Check[node_tokens[:, None], node_tokens[None, :], node_tokens[:, None]] = 1

    # Node block is indexed [target u, current v, next w].
    Adj_Check = torch.zeros(V, V, V)
    Adj_Check[2:, 2:, 2:] = (not_at_target[:, :, None] & adj[None, :, :]).to(Adj_Check.dtype)
    mark_special(Adj_Check, 1)

    Path_Check = torch.zeros(V, V, V)
    on_path = (
        not_at_target[:, :, None]
        & reach[:, :, None]      # true_reach[u][v]
        & adj[None, :, :]        # true_adj[v][w]
        & reach[:, None, :]      # true_reach[u][w]
    )
    Path_Check[2:, 2:, 2:] = on_path.to(Path_Check.dtype)
    mark_special(Path_Check, 2)  # 2 marks the step out of the target node

    Reach_Check = torch.zeros(V, V, V)
    Reach_Check[2:, 2:, 2:] = (not_at_target[:, :, None] & reach[:, None, :]).to(Reach_Check.dtype)
    mark_special(Reach_Check, 1)

    Pad_Check = torch.zeros(V, V, V)
    all_tokens = torch.arange(V)
    Pad_Check[:, 1, 0] = 1
    Pad_Check[:, 0, 0] = 1
    Pad_Check[:, 1, 1] = 1
    Pad_Check[all_tokens, all_tokens, 1] = 1
    Pad_Check[all_tokens, all_tokens, 0] = 1

    return Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check
    
def get_scores_0(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Path + Adj {0, 0.5, 1}
    """Score each step by adjacency validity plus a whole-row path bonus.

    ``train`` is a ``(b, t)`` integer token tensor with ``t >= 3``. Step ``s``
    (output column ``s >= 1``) is looked up at
    ``[train[:, 1], train[:, 1 + s], train[:, 2 + s]]`` in each table.

    Returns:
        A ``(b, t - 2)`` float tensor.

        * Column 0 is ``Start_Check[train[:, 0], train[:, 1], train[:, 2]]``.
        * Columns ``1:`` are ``max((adj + bonus) / 2, pad)``. Here ``adj`` and
          ``pad`` are the per-step ``Adj_Check`` / ``Pad_Check`` lookups.
        * ``bonus`` is 1 on every step of a row when its ``Path_Check``
          lookups multiply to exactly 2. For this product, positions from the
          first zero onward and step 1 itself count as 1. Otherwise ``bonus``
          is 0.

    ``Reach_Check`` is accepted for interface compatibility and is unused.
    """
    b, t = train.shape
    device = train.device
    Start_Check = Start_Check.to(device)
    Adj_Check = Adj_Check.to(device)
    Path_Check = Path_Check.to(device)
    Pad_Check = Pad_Check.to(device)

    new_score_vector = torch.zeros((b, t - 2), device=device)
    new_score_vector[:, 0] = Start_Check[train[:, 0], train[:, 1], train[:, 2]]

    source_indices = train[:, 2:-1]              # (b, t-3) current tokens
    target_indices = train[:, 1].unsqueeze(-1)   # (b, 1), broadcast over steps
    next_indices = train[:, 3:]                  # (b, t-3) next tokens

    path_check_b = Path_Check[target_indices, source_indices, next_indices]
    # Cumulative count of zeros is > 0 at the first zero and every later step.
    after_first_zero = torch.cumsum((path_check_b == 0).float(), dim=1) > 0
    # Neutralise those positions (set to 1) so they do not affect the product.
    path_check_b_masked = torch.where(after_first_zero, torch.ones_like(path_check_b), path_check_b)
    # The first step never contributes; slicing keeps t == 3 (no steps) from raising.
    path_check_b_masked[:, :1] = 1
    path_product = torch.prod(path_check_b_masked, dim=1)

    # Row-level bonus broadcast to every step of the row.
    loss2 = torch.zeros((b, t - 3), device=device)
    mask = (path_product == 2).unsqueeze(1).expand(-1, t - 3)
    loss2[mask] = 1

    loss1 = Adj_Check[target_indices, source_indices, next_indices]
    loss3 = Pad_Check[target_indices, source_indices, next_indices]
    new_score_vector[:, 1:] = torch.maximum((loss1 + loss2) / 2, loss3)
    return new_score_vector


def get_scores_1(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Adj & Reach {0, 1}
    """Score each step as ``Reach_Check * Adj_Check``.

    ``train`` is a ``(b, t)`` integer token tensor with ``t >= 3``.

    Returns:
        A ``(b, t - 2)`` float tensor.

        * Column 0 is ``Start_Check[train[:, 0], train[:, 1], train[:, 2]]``.
        * Column ``s >= 1`` is the product of ``Reach_Check`` and ``Adj_Check``
          at ``[train[:, 1], train[:, 1 + s], train[:, 2 + s]]``.

    ``Path_Check`` and ``Pad_Check`` are accepted for interface compatibility
    and are unused, so they are not copied to the device.
    """
    b, t = train.shape
    device = train.device
    Start_Check = Start_Check.to(device)
    Adj_Check = Adj_Check.to(device)
    Reach_Check = Reach_Check.to(device)

    new_score_vector = torch.zeros((b, t - 2), device=device)
    new_score_vector[:, 0] = Start_Check[train[:, 0], train[:, 1], train[:, 2]]

    source_indices = train[:, 2:-1]              # (b, t-3) current tokens
    target_indices = train[:, 1].unsqueeze(-1)   # (b, 1), broadcast over steps
    next_indices = train[:, 3:]                  # (b, t-3) next tokens
    step_index = (target_indices, source_indices, next_indices)

    new_score_vector[:, 1:] = Reach_Check[step_index] * Adj_Check[step_index]
    return new_score_vector


def get_scores_2(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # (Adj + Reach) / 2 {0, 0.5, 1}
    """Score each step as the mean of ``Reach_Check`` and ``Adj_Check``.

    ``train`` is a ``(b, t)`` integer token tensor with ``t >= 3``.

    Returns:
        A ``(b, t - 2)`` float tensor.

        * Column 0 is ``Start_Check[train[:, 0], train[:, 1], train[:, 2]]``.
        * Column ``s >= 1`` is ``(Reach_Check + Adj_Check) / 2`` looked up at
          ``[train[:, 1], train[:, 1 + s], train[:, 2 + s]]``.

    ``Path_Check`` and ``Pad_Check`` are accepted for interface compatibility
    and are unused, so they are not copied to the device.
    """
    b, t = train.shape
    device = train.device
    Start_Check = Start_Check.to(device)
    Adj_Check = Adj_Check.to(device)
    Reach_Check = Reach_Check.to(device)

    new_score_vector = torch.zeros((b, t - 2), device=device)
    new_score_vector[:, 0] = Start_Check[train[:, 0], train[:, 1], train[:, 2]]

    source_indices = train[:, 2:-1]              # (b, t-3) current tokens
    target_indices = train[:, 1].unsqueeze(-1)   # (b, 1), broadcast over steps
    next_indices = train[:, 3:]                  # (b, t-3) next tokens
    step_index = (target_indices, source_indices, next_indices)

    new_score_vector[:, 1:] = (Reach_Check[step_index] + Adj_Check[step_index]) / 2
    return new_score_vector


def get_scores_3(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Adj & Reach {-0.5, 0, 0.5}
    """Return the ``get_scores_2`` scores shifted down by 0.5.

    Every column is shifted, including the start column.
    """
    return get_scores_2(
        train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check
    ) - 0.5


def get_scores_4(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # start-gated Path {-1, 0}
    """Return the ``get_scores_6`` scores shifted down by 1.

    ``get_scores_6`` scores whole-row path success from ``Path_Check``, gated
    by ``Start_Check``; it does not read ``Adj_Check`` or ``Reach_Check``.
    With a 0/1 ``Start_Check`` the result is therefore in {-1, 0}.
    """
    score = get_scores_6(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)
    return score - 1

def get_scores_5(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # start-gated Path {-0.2, 0.8}
    """Return the ``get_scores_6`` scores shifted down by 0.2.

    ``get_scores_6`` scores whole-row path success from ``Path_Check``, gated
    by ``Start_Check``; it does not read ``Adj_Check`` or ``Reach_Check``.
    With a 0/1 ``Start_Check`` the result is therefore in {-0.2, 0.8}.
    """
    score = get_scores_6(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)
    return score - 0.2

def get_scores_7(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # start-gated Path {0.2, 1.2}
    """Return the ``get_scores_6`` scores shifted up by 0.2.

    ``get_scores_6`` scores whole-row path success from ``Path_Check``, gated
    by ``Start_Check``; it does not read ``Adj_Check`` or ``Reach_Check``.
    With a 0/1 ``Start_Check`` the result is therefore in {0.2, 1.2}.
    """
    score = get_scores_6(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)
    return score + 0.2

def get_scores_8(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # start-gated Path {1, 2}
    """Return the ``get_scores_6`` scores shifted up by 1.

    ``get_scores_6`` scores whole-row path success from ``Path_Check``, gated
    by ``Start_Check``; it does not read ``Adj_Check`` or ``Reach_Check``.
    With a 0/1 ``Start_Check`` the result is therefore in {1, 2}.
    """
    score = get_scores_6(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)
    return score + 1

def get_scores_6(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # start-gated Path {0, 1}
    """Score whole-row path success, gated by the start check.

    ``train`` is a ``(b, t)`` integer token tensor. Step ``s >= 1`` is looked up
    in ``Path_Check`` at ``[train[:, 1], train[:, 1 + s], train[:, 2 + s]]``.

    Within a row, positions from the first zero onward and step 1 itself are
    treated as 1. The row succeeds when the product of these values is
    exactly 2.

    Returns:
        A ``(b, t - 2)`` float tensor. Every column of a row equals
        ``start * success``. ``start`` is
        ``Start_Check[train[:, 0], train[:, 1], train[:, 2]]`` and ``success``
        is 1 or 0. Column 0 is copied from column 1, so ``t >= 4`` is required.

    ``Adj_Check``, ``Reach_Check`` and ``Pad_Check`` are accepted for interface
    compatibility and are unused, so they are not copied to the device.
    """
    b, t = train.shape
    device = train.device
    Start_Check = Start_Check.to(device)
    Path_Check = Path_Check.to(device)

    new_score_vector = torch.zeros((b, t - 2), device=device)
    new_score_vector[:, 0] = Start_Check[train[:, 0], train[:, 1], train[:, 2]]

    source_indices = train[:, 2:-1]              # (b, t-3) current tokens
    target_indices = train[:, 1].unsqueeze(-1)   # (b, 1), broadcast over steps
    next_indices = train[:, 3:]                  # (b, t-3) next tokens

    # train[:, 1] serves as the per-row (batch-level) first index for every step.
    path_check_b = Path_Check[target_indices, source_indices, next_indices]
    # Cumulative count of zeros is > 0 at the first zero and every later step.
    after_first_zero = torch.cumsum((path_check_b == 0).float(), dim=1) > 0
    # Set those positions to 1 so they do not affect the product.
    path_check_b_masked = torch.where(after_first_zero, torch.ones_like(path_check_b), path_check_b)
    path_check_b_masked[:, 0] = 1
    path_product = torch.prod(path_check_b_masked, dim=1)

    loss2 = torch.zeros((b, t - 3), device=device)  # initialised to 0; 1 marks success
    # Broadcast the per-row success flag to shape (b, t-3).
    mask = (path_product == 2).unsqueeze(1).expand(-1, t - 3)
    loss2[mask] = 1

    # Gate every step by the row's start check, then mirror it into column 0.
    new_score_vector[:, 1:] = loss2 * new_score_vector[:, [0]]
    new_score_vector[:, 0] = new_score_vector[:, 1]
    return new_score_vector

def get_scores_9(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Adj & Reach {-1, 1}
    """Map the 0/1 ``get_scores_1`` scores affinely onto {-1, 1} (x -> 2x - 1)."""
    base = get_scores_1(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)
    doubled = base * 2
    return doubled - 1

def get_scores_10(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Adj & Reach {1, 2}
    """Return the ``get_scores_1`` scores shifted up by 1 (0/1 -> 1/2)."""
    return get_scores_1(
        train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check
    ) + 1

def get_scores_11(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Adj + Path {-1, 0, 1}
    """Map the {0, 0.5, 1} ``get_scores_0`` scores onto {-1, 0, 1} (x -> 2x - 1)."""
    base = get_scores_0(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)
    doubled = base * 2
    return doubled - 1

def get_scores_12(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Adj + Path {1, 3/2, 2}
    """Return the ``get_scores_0`` scores shifted up by 1 ({0, 0.5, 1} -> {1, 1.5, 2})."""
    return get_scores_0(
        train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check
    ) + 1

def get_scores_13(train, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check): # Adj penalty + Path bonus {-10, 0, 90, 100}
    """Score each step with an adjacency penalty plus a large path-completion bonus.

    ``train`` is a ``(b, t)`` integer token tensor with ``t >= 3``. Step ``s``
    (output column ``s >= 1``) is looked up at
    ``[train[:, 1], train[:, 1 + s], train[:, 2 + s]]``.

    Returns:
        A ``(b, t - 2)`` float tensor.

        * Column 0 is ``Start_Check[train[:, 0], train[:, 1], train[:, 2]] * 10 - 10``.
        * Columns ``1:`` are ``Adj_Check * 10 - 10``, plus ``50 * Path_Check``
          on steps whose ``Path_Check`` value exceeds 1.
        * The bonus applies only when the row's ``Path_Check`` values, with
          positions from the first zero onward treated as 1, multiply to
          exactly 2. Unlike ``get_scores_0``/``get_scores_6``, step 1 counts
          toward this product.
        * With 0/1 ``Adj_Check`` and 0/1/2 ``Path_Check`` tables the step
          scores are -10, 0, 90 or 100.

    ``Reach_Check`` and ``Pad_Check`` are accepted for interface compatibility
    and are unused.
    """
    b, t = train.shape
    device = train.device
    Start_Check = Start_Check.to(device)
    Adj_Check = Adj_Check.to(device)
    Path_Check = Path_Check.to(device)

    new_score_vector = torch.zeros((b, t - 2), device=device)
    new_score_vector[:, 0] = Start_Check[train[:, 0], train[:, 1], train[:, 2]]

    source_indices = train[:, 2:-1]              # (b, t-3) current tokens
    target_indices = train[:, 1].unsqueeze(-1)   # (b, 1), broadcast over steps
    next_indices = train[:, 3:]                  # (b, t-3) next tokens

    path_check_b = Path_Check[target_indices, source_indices, next_indices]
    # Cumulative count of zeros is > 0 at the first zero and every later step.
    after_first_zero = torch.cumsum((path_check_b == 0).float(), dim=1) > 0
    path_check_b_masked = torch.where(after_first_zero, torch.ones_like(path_check_b), path_check_b)
    path_product = torch.prod(path_check_b_masked, dim=1)

    # Keep only the lookups greater than 1; all other entries become 0.
    path_check_b = path_check_b * (path_check_b > 1)
    mask = (path_product == 2).unsqueeze(1).expand(-1, t - 3)
    loss2 = mask * path_check_b * 50

    loss1 = Adj_Check[target_indices, source_indices, next_indices]
    loss1 = loss1 * 10 - 10

    new_score_vector[:, 1:] = loss1 + loss2
    new_score_vector[:, 0] = new_score_vector[:, 0] * 10 - 10
    return new_score_vector