import numpy as np

def lower_triangle(text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech, value=1, top=5, low=1):
    """Build a plain causal (lower-triangular) attention mask (mask0.png).

    The four segment arguments are lists of ``(name, length)`` pairs. They
    are concatenated in argument order; only the lengths are used here.

    Args:
        text_side_prompt: ``(name, length)`` pairs of the first segment group.
        text_side_tgt: ``(name, length)`` pairs of the second segment group.
        speech_side_prompt: ``(name, length)`` pairs of the third segment group.
        speech_side_speech: ``(name, length)`` pairs of the fourth segment group.
        value: Weight written on and below the main diagonal.
        top: Unused by this function.
        low: Unused by this function.

    Returns:
        A ``(total_len, total_len)`` array built from a ``float32``
        lower-triangular matrix of ones multiplied by ``value``; entries
        above the diagonal are 0.
    """
    seq_list = text_side_prompt + text_side_tgt + speech_side_prompt + speech_side_speech
    # Only the overall sequence length matters for a pure causal mask.
    total_len = sum(length for _, length in seq_list)
    return np.tril(np.ones((total_len, total_len), dtype=np.float32)) * value

def tgt_st_paired_emo_and_all_tt(text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech, top=5, low=1):
    """Build the mask1.png attention mask.

    All segments start out causal with weight ``top``. Each target-speech
    segment is then down-weighted (``low``) towards every prompt pair other
    than its own; every other causal position keeps ``top``.

    The four segment arguments are lists of ``(name, length)`` pairs and are
    concatenated in argument order. ``speech_side_prompt`` and
    ``speech_side_speech`` are paired positionally (``zip``). The pair key is
    the second ``"_"``-separated field of the prompt-speech name, and the
    down-weighted segments are looked up as ``s_<key>`` and ``t_<key>``.
    NOTE(review): this assumes prompt speech/text are named ``s_<key>`` /
    ``t_<key>`` - confirm against the caller.

    Args:
        text_side_prompt: ``(name, length)`` pairs, placed first.
        text_side_tgt: ``(name, length)`` pairs, placed second.
        speech_side_prompt: ``(name, length)`` pairs, placed third.
        speech_side_speech: ``(name, length)`` pairs, placed last.
        top: Weight of ordinary causal positions.
        low: Weight of the down-weighted (unpaired) blocks.

    Returns:
        A square ``float32`` array whose side is the sum of all lengths.
        Low blocks are written as whole rectangles, so they stay within the
        causal region only when the key segment precedes the query segment.
    """
    seq_list = text_side_prompt + text_side_tgt + speech_side_prompt + speech_side_speech

    # Pair key of every prompt-speech segment, e.g. "s_00" -> "00".
    prompt_keys = [name.split("_")[1] for name, _ in speech_side_prompt]

    # Target-speech name -> names of the segments it should down-weight.
    dependency_map = {}
    for (prompt_name, _), (tgt_name, _) in zip(speech_side_prompt, speech_side_speech):
        paired_key = prompt_name.split("_")[1]
        unpaired = [key for key in prompt_keys if key != paired_key]
        dependency_map[tgt_name] = [f"s_{key}" for key in unpaired] + [f"t_{key}" for key in unpaired]

    # Step 1: start/end position of every segment in the concatenation.
    offsets = {}
    current = 0
    for name, length in seq_list:
        offsets[name] = (current, current + length)
        current += length
    total_len = current

    # Step 2: causal mask with every allowed position set to `top`.
    mask = np.tril(np.ones((total_len, total_len), dtype=np.float32)) * top

    # Step 3: overwrite the unpaired blocks with `low`.
    # Every query name comes from speech_side_speech, so it is always in offsets.
    for query_name, key_names in dependency_map.items():
        q_start, q_end = offsets[query_name]
        for key_name in key_names:
            if key_name not in offsets:
                continue  # the referenced segment is not part of this sequence
            k_start, k_end = offsets[key_name]
            mask[q_start:q_end, k_start:k_end] = low
    return mask

def all_paired_emo(text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech, top=5, low=1):
    """Build the mask2.png attention mask ("all paired emo").

    The original author's note describes this as the most basic CosyVoice2
    variant. Every segment starts out causal with weight ``top``; blocks
    towards segments carrying a different index are then set to ``low``.

    Segment names are matched by prefix and the remainder is parsed as an
    integer index ``k``. With ``n = len(text_side_prompt)``:

    * ``t_0<k>``: low on ``t_0<i>`` for ``i < k``.
    * ``t_1<k>``: low on ``t_0<i>`` for ``i != k`` and ``t_1<i>`` for ``i < k``.
    * ``s_0<k>``: low on ``t_0<i>`` and ``t_1<i>`` for ``i != k`` and
      ``s_0<i>`` for ``i < k``.
    * ``s_1<k>``: low on ``t_0<i>``, ``s_0<i>`` and ``t_1<i>`` for ``i != k``
      and ``s_1<i>`` for ``i < k``.

    ``i`` ranges over ``range(n)`` for the ``!=`` rules. Names that do not
    exist in the sequence are ignored.

    Args:
        text_side_prompt: ``(name, length)`` pairs, placed first; its length
            also bounds the indices considered.
        text_side_tgt: ``(name, length)`` pairs, placed second.
        speech_side_prompt: ``(name, length)`` pairs, placed third.
        speech_side_speech: ``(name, length)`` pairs, placed last.
        top: Weight of ordinary causal positions.
        low: Weight of the down-weighted blocks.

    Returns:
        A square ``float32`` array whose side is the sum of all lengths.

    Raises:
        ValueError: If a name with one of the four prefixes has a
            non-integer remainder.
    """
    seq_list = text_side_prompt + text_side_tgt + speech_side_prompt + speech_side_speech
    iter_num = len(text_side_prompt)

    def others(prefix, idx):
        # Same-family names whose index differs from the query's index.
        return [f"{prefix}{i}" for i in range(iter_num) if i != idx]

    def earlier(prefix, idx):
        # Same-family names with a strictly smaller index.
        return [f"{prefix}{i}" for i in range(idx)]

    # Segment name -> names of the segments it should down-weight.
    dependency_map = {}
    for name, _ in seq_list:
        label = str(name)
        if label.startswith("t_0"):
            idx = int(label[len("t_0"):])
            keys = earlier("t_0", idx)
        elif label.startswith("t_1"):
            idx = int(label[len("t_1"):])
            keys = others("t_0", idx) + earlier("t_1", idx)
        elif label.startswith("s_0"):
            idx = int(label[len("s_0"):])
            keys = others("t_0", idx) + earlier("s_0", idx) + others("t_1", idx)
        elif label.startswith("s_1"):
            idx = int(label[len("s_1"):])
            keys = others("t_0", idx) + others("s_0", idx) + others("t_1", idx) + earlier("s_1", idx)
        else:
            keys = []
        dependency_map[name] = keys

    # Step 1: start/end position of every segment in the concatenation.
    offsets = {}
    current = 0
    for name, length in seq_list:
        offsets[name] = (current, current + length)
        current += length
    total_len = current

    # Step 2: causal mask with every allowed position set to `top`.
    mask = np.tril(np.ones((total_len, total_len), dtype=np.float32)) * top

    # Step 3: overwrite the selected blocks with `low`.
    # Every query name comes from seq_list, so it is always in offsets.
    for query_name, key_names in dependency_map.items():
        q_start, q_end = offsets[query_name]
        for key_name in key_names:
            if key_name not in offsets:
                continue  # the referenced segment is not part of this sequence
            k_start, k_end = offsets[key_name]
            mask[q_start:q_end, k_start:k_end] = low
    return mask


def all_st_paired_emo(text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech, top=5, low=1):
    """Build the mask3.png attention mask ("all st paired emo").

    Every segment starts out causal with weight ``top``. Only segments whose
    name starts with ``s_0`` or ``s_1`` are modified: blocks towards segments
    carrying a different index are set to ``low``. All other rows stay
    purely causal.

    The remainder of the name after the prefix is parsed as an integer index
    ``k``. With ``n = len(text_side_prompt)``:

    * ``s_0<k>``: low on ``t_0<i>`` and ``t_1<i>`` for ``i != k`` and
      ``s_0<i>`` for ``i < k``.
    * ``s_1<k>``: low on ``t_0<i>``, ``s_0<i>`` and ``t_1<i>`` for ``i != k``
      and ``s_1<i>`` for ``i < k``.

    ``i`` ranges over ``range(n)`` for the ``!=`` rules. Names that do not
    exist in the sequence are ignored.

    Args:
        text_side_prompt: ``(name, length)`` pairs, placed first; its length
            also bounds the indices considered.
        text_side_tgt: ``(name, length)`` pairs, placed second.
        speech_side_prompt: ``(name, length)`` pairs, placed third.
        speech_side_speech: ``(name, length)`` pairs, placed last.
        top: Weight of ordinary causal positions.
        low: Weight of the down-weighted blocks.

    Returns:
        A square ``float32`` array whose side is the sum of all lengths.

    Raises:
        ValueError: If an ``s_0``/``s_1`` name has a non-integer remainder.
    """
    seq_list = text_side_prompt + text_side_tgt + speech_side_prompt + speech_side_speech
    iter_num = len(text_side_prompt)

    def others(prefix, idx):
        # Same-family names whose index differs from the query's index.
        return [f"{prefix}{i}" for i in range(iter_num) if i != idx]

    def earlier(prefix, idx):
        # Same-family names with a strictly smaller index.
        return [f"{prefix}{i}" for i in range(idx)]

    # Segment name -> names of the segments it should down-weight.
    dependency_map = {}
    for name, _ in seq_list:
        label = str(name)
        if label.startswith("s_0"):
            idx = int(label[len("s_0"):])
            keys = others("t_0", idx) + earlier("s_0", idx) + others("t_1", idx)
        elif label.startswith("s_1"):
            idx = int(label[len("s_1"):])
            keys = others("t_0", idx) + others("s_0", idx) + others("t_1", idx) + earlier("s_1", idx)
        else:
            keys = []
        dependency_map[name] = keys

    # Step 1: start/end position of every segment in the concatenation.
    offsets = {}
    current = 0
    for name, length in seq_list:
        offsets[name] = (current, current + length)
        current += length
    total_len = current

    # Step 2: causal mask with every allowed position set to `top`.
    mask = np.tril(np.ones((total_len, total_len), dtype=np.float32)) * top

    # Step 3: overwrite the selected blocks with `low`.
    # Every query name comes from seq_list, so it is always in offsets.
    for query_name, key_names in dependency_map.items():
        q_start, q_end = offsets[query_name]
        for key_name in key_names:
            if key_name not in offsets:
                continue  # the referenced segment is not part of this sequence
            k_start, k_end = offsets[key_name]
            mask[q_start:q_end, k_start:k_end] = low

    return mask


def tgt_st_paired_emo(text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech, top=5, low=1):
    """Build the mask6.png attention mask ("tgt st look at only paired emo").

    Every segment starts out causal with weight ``top``. Only segments whose
    name starts with ``s_1`` are modified; all other rows stay purely causal.

    The remainder of the name after ``s_1`` is parsed as an integer index
    ``k``. With ``n = len(text_side_prompt)``, segment ``s_1<k>`` gets ``low``
    on ``t_0<i>``, ``s_0<i>`` and ``t_1<i>`` for every ``i`` in ``range(n)``
    with ``i != k``, and on ``s_1<i>`` for ``i < k``. Names that do not exist
    in the sequence are ignored.

    Args:
        text_side_prompt: ``(name, length)`` pairs, placed first; its length
            also bounds the indices considered.
        text_side_tgt: ``(name, length)`` pairs, placed second.
        speech_side_prompt: ``(name, length)`` pairs, placed third.
        speech_side_speech: ``(name, length)`` pairs, placed last.
        top: Weight of ordinary causal positions.
        low: Weight of the down-weighted blocks.

    Returns:
        A square ``float32`` array whose side is the sum of all lengths.

    Raises:
        ValueError: If an ``s_1`` name has a non-integer remainder.
    """
    seq_list = text_side_prompt + text_side_tgt + speech_side_prompt + speech_side_speech
    iter_num = len(text_side_prompt)

    # Segment name -> names of the segments it should down-weight.
    dependency_map = {}
    for name, _ in seq_list:
        label = str(name)
        keys = []
        if label.startswith("s_1"):
            idx = int(label[len("s_1"):])
            unpaired = [i for i in range(iter_num) if i != idx]
            keys = (
                [f"t_0{i}" for i in unpaired]
                + [f"s_0{i}" for i in unpaired]
                + [f"t_1{i}" for i in unpaired]
                + [f"s_1{i}" for i in range(idx)]
            )
        dependency_map[name] = keys

    # Step 1: start/end position of every segment in the concatenation.
    offsets = {}
    current = 0
    for name, length in seq_list:
        offsets[name] = (current, current + length)
        current += length
    total_len = current

    # Step 2: causal mask with every allowed position set to `top`.
    mask = np.tril(np.ones((total_len, total_len), dtype=np.float32)) * top

    # Step 3: overwrite the selected blocks with `low`.
    # Every query name comes from seq_list, so it is always in offsets.
    for query_name, key_names in dependency_map.items():
        q_start, q_end = offsets[query_name]
        for key_name in key_names:
            if key_name not in offsets:
                continue  # the referenced segment is not part of this sequence
            k_start, k_end = offsets[key_name]
            mask[q_start:q_end, k_start:k_end] = low

    return mask

def st_paired_emo_tgt_st_all_tt(text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech, top=5, low=1):
    """Build the mask9.png attention mask.

    Every segment starts out causal with weight ``top``. Two rules then set
    blocks to ``low``:

    * Segments named ``s_0<k>`` (``k`` parsed as an integer) get ``low`` on
      ``t_0<i>`` and ``t_1<i>`` for every ``i`` in
      ``range(len(text_side_prompt))`` with ``i != k``, and on ``s_0<i>``
      for ``i < k``.
    * Each entry of ``speech_side_speech`` is paired positionally (``zip``)
      with an entry of ``speech_side_prompt``. With ``<key>`` being the
      second ``"_"``-separated field of a prompt-speech name, the target
      segment gets ``low`` on ``s_<key>`` and ``t_<key>`` for every prompt
      key other than its partner's. This rule replaces any ``s_0`` rule for
      the same name.

    Names that do not exist in the sequence are ignored.
    NOTE(review): the second rule assumes prompt speech/text are named
    ``s_<key>`` / ``t_<key>`` - confirm against the caller.

    Args:
        text_side_prompt: ``(name, length)`` pairs, placed first; its length
            also bounds the indices considered by the ``s_0`` rule.
        text_side_tgt: ``(name, length)`` pairs, placed second.
        speech_side_prompt: ``(name, length)`` pairs, placed third.
        speech_side_speech: ``(name, length)`` pairs, placed last.
        top: Weight of ordinary causal positions.
        low: Weight of the down-weighted blocks.

    Returns:
        A square ``float32`` array whose side is the sum of all lengths.

    Raises:
        ValueError: If an ``s_0`` name has a non-integer remainder.
    """
    seq_list = text_side_prompt + text_side_tgt + speech_side_prompt + speech_side_speech
    iter_num = len(text_side_prompt)

    # Segment name -> names of the segments it should down-weight.
    dependency_map = {}

    # Rule 1: s_0<k> segments, driven by their integer index.
    for name, _ in seq_list:
        label = str(name)
        keys = []
        if label.startswith("s_0"):
            idx = int(label[len("s_0"):])
            unpaired = [i for i in range(iter_num) if i != idx]
            keys = (
                [f"t_0{i}" for i in unpaired]
                + [f"s_0{i}" for i in range(idx)]
                + [f"t_1{i}" for i in unpaired]
            )
        dependency_map[name] = keys

    # Rule 2: target-speech segments, driven by their positional prompt partner.
    prompt_keys = [name.split("_")[1] for name, _ in speech_side_prompt]
    for (prompt_name, _), (tgt_name, _) in zip(speech_side_prompt, speech_side_speech):
        paired_key = prompt_name.split("_")[1]
        unpaired_keys = [key for key in prompt_keys if key != paired_key]
        dependency_map[tgt_name] = [f"s_{key}" for key in unpaired_keys] + [f"t_{key}" for key in unpaired_keys]

    # Step 1: start/end position of every segment in the concatenation.
    offsets = {}
    current = 0
    for name, length in seq_list:
        offsets[name] = (current, current + length)
        current += length
    total_len = current

    # Step 2: causal mask with every allowed position set to `top`.
    mask = np.tril(np.ones((total_len, total_len), dtype=np.float32)) * top

    # Step 3: overwrite the selected blocks with `low`.
    # Every query name comes from seq_list, so it is always in offsets.
    for query_name, key_names in dependency_map.items():
        q_start, q_end = offsets[query_name]
        for key_name in key_names:
            if key_name not in offsets:
                continue  # the referenced segment is not part of this sequence
            k_start, k_end = offsets[key_name]
            mask[q_start:q_end, k_start:k_end] = low

    return mask

def all_st_paired_emo_and_all_tt(text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech, top=5, low=1):
    """Build the mask10.png attention mask.

    Every segment starts out causal with weight ``top``. Two rules then set
    blocks to ``low``:

    * Segments named ``s_0<k>`` (``k`` parsed as an integer) get ``low`` on
      ``t_0<i>`` for every ``i`` in ``range(len(text_side_prompt))`` with
      ``i != k``, and on ``s_0<i>`` for ``i < k``.
    * Each entry of ``speech_side_speech`` is paired positionally (``zip``)
      with an entry of ``speech_side_prompt``. With ``<key>`` being the
      second ``"_"``-separated field of a prompt-speech name, the target
      segment gets ``low`` on ``s_<key>`` and ``t_<key>`` for every prompt
      key other than its partner's. This rule replaces any ``s_0`` rule for
      the same name.

    Names that do not exist in the sequence are ignored.
    NOTE(review): the second rule assumes prompt speech/text are named
    ``s_<key>`` / ``t_<key>`` - confirm against the caller.

    Args:
        text_side_prompt: ``(name, length)`` pairs, placed first; its length
            also bounds the indices considered by the ``s_0`` rule.
        text_side_tgt: ``(name, length)`` pairs, placed second.
        speech_side_prompt: ``(name, length)`` pairs, placed third.
        speech_side_speech: ``(name, length)`` pairs, placed last.
        top: Weight of ordinary causal positions.
        low: Weight of the down-weighted blocks.

    Returns:
        A square ``float32`` array whose side is the sum of all lengths.

    Raises:
        ValueError: If an ``s_0`` name has a non-integer remainder.
    """
    seq_list = text_side_prompt + text_side_tgt + speech_side_prompt + speech_side_speech
    iter_num = len(text_side_prompt)

    # Segment name -> names of the segments it should down-weight.
    dependency_map = {}

    # Rule 1: s_0<k> segments, driven by their integer index.
    for name, _ in seq_list:
        label = str(name)
        keys = []
        if label.startswith("s_0"):
            idx = int(label[len("s_0"):])
            keys = [f"t_0{i}" for i in range(iter_num) if i != idx] + [f"s_0{i}" for i in range(idx)]
        dependency_map[name] = keys

    # Rule 2: target-speech segments, driven by their positional prompt partner.
    prompt_keys = [name.split("_")[1] for name, _ in speech_side_prompt]
    for (prompt_name, _), (tgt_name, _) in zip(speech_side_prompt, speech_side_speech):
        paired_key = prompt_name.split("_")[1]
        unpaired_keys = [key for key in prompt_keys if key != paired_key]
        dependency_map[tgt_name] = [f"s_{key}" for key in unpaired_keys] + [f"t_{key}" for key in unpaired_keys]

    # Step 1: start/end position of every segment in the concatenation.
    offsets = {}
    current = 0
    for name, length in seq_list:
        offsets[name] = (current, current + length)
        current += length
    total_len = current

    # Step 2: causal mask with every allowed position set to `top`.
    mask = np.tril(np.ones((total_len, total_len), dtype=np.float32)) * top

    # Step 3: overwrite the selected blocks with `low`.
    # Every query name comes from seq_list, so it is always in offsets.
    for query_name, key_names in dependency_map.items():
        q_start, q_end = offsets[query_name]
        for key_name in key_names:
            if key_name not in offsets:
                continue  # the referenced segment is not part of this sequence
            k_start, k_end = offsets[key_name]
            mask[q_start:q_end, k_start:k_end] = low
    return mask

