from miditoolkit import MidiFile, Note, Instrument, TempoChange, ControlChange
import bisect
import numpy as np
import os
from copy import copy
import random
from collections import defaultdict

# NOTE(review): first legacy draft of normalize_midi, kept only as an inert
# module-level string literal (it is never executed). It re-times events with a
# per-event tempo ratio; the active normalize_midi defined later in this file
# supersedes it.
"""
def normalize_midi(midi_obj, target_ticks_per_beat = 500, target_tempo = 120):
    ticks_per_beat = midi_obj.ticks_per_beat
    merged_events = []
    for i in range(len(midi_obj.instruments)):
        filter_control_changes = []
        for cc in midi_obj.instruments[i].control_changes:
            if cc.number == 64:
                filter_control_changes.append(cc)
        merged_events.extend(midi_obj.instruments[i].notes + filter_control_changes)
    merged_events.sort(key=lambda x: (x.start, x.pitch) if isinstance(x, Note) else (x.time, x.number))
    
    time_interval = []
    last_time = 0
    for note in merged_events:
        if isinstance(note, Note):
            time_interval.append(note.start - last_time)
            last_time = note.start
        else:
            time_interval.append(note.time - last_time)
            last_time = note.time

    output_notes = []
    output_cc = []
    ind = -1
    now_tempo = 120
    now_time = 0
    for i, note in enumerate(merged_events):
        if isinstance(note, Note):
            time = note.start
        else:
            time = note.time
        while ind + 1 < len(midi_obj.tempo_changes) and time >= midi_obj.tempo_changes[ind+1].time:
            now_tempo = midi_obj.tempo_changes[ind+1].tempo
            ind += 1
        ratio = target_ticks_per_beat * target_tempo / now_tempo / ticks_per_beat
        start_time = time_interval[i] * ratio + now_time
        if isinstance(note, Note):
            end_time = (note.end - note.start) * ratio + start_time
            output_notes.append(Note(note.velocity, note.pitch, round(start_time), round(end_time)))
        else:
            output_cc.append(ControlChange(64, note.value, round(start_time)))
        now_time = round(start_time)
    
    output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_cc))
    output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
    for note in output_notes:
        output_midi_obj.max_tick = max(output_midi_obj.max_tick, note.end)
    for cc in output_cc:
        output_midi_obj.max_tick = max(output_midi_obj.max_tick, cc.time)
    return output_midi_obj
"""

# NOTE(review): second legacy draft of normalize_midi (converts via
# get_tick_to_time_mapping but has no same-pitch overlap cleanup), kept only as
# an inert module-level string literal; superseded by the active
# normalize_midi defined below.
"""
def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):    
    # 创建一个新的、干净的MidiFile对象用于输出
    output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
    
    # 获取原始MIDI的tick到秒的精确映射
    # 这是最关键的一步，partitura和miditoolkit都有类似功能
    # miditoolkit的get_tick_to_time_mapping()可以处理所有tempo变化
    tick_to_time_map = midi_obj.get_tick_to_time_mapping()
    
    # 计算从秒转换回目标tick的比例因子
    # 目标MIDI中，每秒对应的tick数 = target_ticks_per_beat * (target_tempo / 60)
    seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)

    merged_notes = []
    merged_cc = []

    # 遍历所有乐器轨道
    for instrument in midi_obj.instruments:
        # 只处理非鼓组的乐器
        if not instrument.is_drum:
            # --- 处理音符 (Notes) ---
            for note in instrument.notes:
                # 1. 将原始tick转换为绝对秒数
                start_time_sec = tick_to_time_map[note.start]
                end_time_sec = tick_to_time_map[note.end]
                
                # 2. 将绝对秒数转换为目标tick
                new_start_tick = round(start_time_sec * seconds_to_target_ticks_factor)
                new_end_tick = round(end_time_sec * seconds_to_target_ticks_factor)
                
                # 避免duration为0的音符
                if new_start_tick == new_end_tick:
                    new_end_tick += 1

                merged_notes.append(Note(velocity=note.velocity, 
                                         pitch=note.pitch, 
                                         start=new_start_tick, 
                                         end=new_end_tick))
            
            # --- 处理延音踏板 (CC #64) ---
            for cc in instrument.control_changes:
                if cc.number == 64:
                    # 1. 将原始tick转换为绝对秒数
                    time_sec = tick_to_time_map[cc.time]
                    
                    # 2. 将绝对秒数转换为目标tick
                    new_time_tick = round(time_sec * seconds_to_target_ticks_factor)
                    
                    merged_cc.append(ControlChange(number=64, 
                                                   value=cc.value, 
                                                   time=new_time_tick))

    # --- 排序并创建新乐器 ---
    # 按开始时间排序，对于同时开始的事件，CC优先于Note
    merged_notes.sort(key=lambda x: (x.start, x.pitch))
    merged_cc.sort(key=lambda x: (x.time, x.number))
    
    output_instrument = Instrument(program=0, is_drum=False, name="Piano")
    output_instrument.notes = merged_notes
    output_instrument.control_changes = merged_cc
    output_midi_obj.instruments.append(output_instrument)
    
    # --- 正确计算 max_tick ---
    max_tick = 0
    if output_instrument.notes:
        max_tick = max(max_tick, max(n.end for n in output_instrument.notes))
    if output_instrument.control_changes:
        max_tick = max(max_tick, max(c.time for c in output_instrument.control_changes))
    
    output_midi_obj.max_tick = max_tick

    return output_midi_obj
"""

def _remove_same_pitch_overlaps(notes):
    """Return *notes* with overlaps between notes of the same pitch resolved.

    Notes are grouped by pitch (ascending) and, within a pitch, ordered by
    onset (stable sort, so notes with equal onsets keep their input order).
    A note that reaches or passes the onset of the next note of the same pitch
    is cut off at that onset. Any note left with ``end <= start`` (i.e. one that
    shares its onset with the next same-pitch note) is dropped.

    The surviving Note objects are the input objects, with ``end`` modified in
    place where they were truncated.
    """
    notes_by_pitch = defaultdict(list)
    for note in notes:
        notes_by_pitch[note.pitch].append(note)

    cleaned = []
    for pitch in sorted(notes_by_pitch):
        same_pitch = sorted(notes_by_pitch[pitch], key=lambda n: n.start)
        # Only ``end`` is mutated, so each successor's onset is still the
        # original one when it is used as the cut-off point.
        for current, following in zip(same_pitch, same_pitch[1:]):
            if current.end >= following.start:
                current.end = following.start
        cleaned.extend(n for n in same_pitch if n.start < n.end)
    return cleaned


def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
    """Normalize a MidiFile into a single piano track on a fixed tempo grid.

    1. Collect the notes and sustain-pedal (CC #64) events of every non-drum
       track.
    2. Re-time them: source tick -> absolute seconds (via
       ``midi_obj.get_tick_to_time_mapping()``, which accounts for the source's
       tempo changes) -> tick at ``target_ticks_per_beat`` / ``target_tempo``.
       Notes that collapse to zero length are stretched to 1 tick.
    3. Resolve same-pitch overlaps (see ``_remove_same_pitch_overlaps``).
    4. Set ``max_tick`` to the last event tick plus one beat of padding.

    Args:
        midi_obj (MidiFile): source MIDI object; it is not modified.
        target_ticks_per_beat (int): ticks per beat of the output file.
        target_tempo (float): constant output tempo in BPM.

    Returns:
        MidiFile: new object with one tempo change and one "Piano" instrument
        whose notes are sorted by (start, pitch) and whose CC #64 events are
        sorted by time.
    """
    output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))

    tick_to_time_map = midi_obj.get_tick_to_time_mapping()
    # Output ticks per second = ticks_per_beat * beats_per_second.
    seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)

    def to_target_tick(tick):
        # Source tick -> absolute seconds -> nearest tick on the output grid.
        return round(tick_to_time_map[tick] * seconds_to_target_ticks_factor)

    melodic_instruments = [inst for inst in midi_obj.instruments if not inst.is_drum]

    # --- 1. Convert all notes ---
    converted_notes = []
    for instrument in melodic_instruments:
        for note in instrument.notes:
            new_start_tick = to_target_tick(note.start)
            new_end_tick = to_target_tick(note.end)
            if new_start_tick >= new_end_tick:
                # Keep every note at least 1 tick long.
                new_end_tick = new_start_tick + 1
            converted_notes.append(Note(velocity=note.velocity,
                                        pitch=note.pitch,
                                        start=new_start_tick,
                                        end=new_end_tick))

    # --- 2. Resolve same-pitch overlaps (avoids parse errors downstream) ---
    merged_notes = _remove_same_pitch_overlaps(converted_notes)

    # --- 3. Convert sustain-pedal (CC #64) events ---
    merged_cc = [
        ControlChange(number=64, value=cc.value, time=to_target_tick(cc.time))
        for instrument in melodic_instruments
        for cc in instrument.control_changes
        if cc.number == 64
    ]

    # --- 4. Sort and build the output track ---
    merged_notes.sort(key=lambda x: (x.start, x.pitch))
    merged_cc.sort(key=lambda x: (x.time, x.number))

    output_instrument = Instrument(program=0, is_drum=False, name="Piano")
    output_instrument.notes = merged_notes
    output_instrument.control_changes = merged_cc
    output_midi_obj.instruments.append(output_instrument)

    # --- 5. max_tick ---
    # ``default=`` keeps this safe when a list is empty (or every value is
    # None), where a bare max() over the filtered generator would raise.
    max_tick = max(
        0,
        max((n.end for n in merged_notes if n.end is not None), default=0),
        max((c.time for c in merged_cc if c.time is not None), default=0),
    )

    # One beat of padding so the final event is not cut off.
    output_midi_obj.max_tick = max_tick + target_ticks_per_beat

    return output_midi_obj

def midi_to_ids(config, midi_obj, normalize=True):
    """Encode the notes of the first track as a flat list of token ids.

    Every note contributes eight consecutive ids, in this order:
    pitch, interval (onset minus the previous note's onset; the first note is
    measured from tick 0), velocity, duration, then the pedal (control-change)
    value sampled at 0, 1/4, 2/4 and 3/4 of the way to the next onset. For the
    last note the "next onset" is taken to be 4990 ticks later. Each id is the
    raw value plus the matching ``config.*_start`` offset, clamped to
    ``[valid_id_range[slot][0], valid_id_range[slot][1] - 1]``.

    Args:
        config: provides ``pitch_start``, ``timing_start``, ``velocity_start``,
            ``pedal_start`` and ``valid_id_range`` (indexed by slot 0-7).
        midi_obj: MidiFile-like object; only ``instruments[0]`` is read.
        normalize: if True, the object is passed through ``normalize_midi``
            first.

    Returns:
        list: ``8 * len(notes)`` token ids.
    """
    source = normalize_midi(midi_obj) if normalize else midi_obj
    track = source.instruments[0]
    controls = track.control_changes
    control_times = [cc.time for cc in controls]

    def pedal_at(tick):
        # Value of the latest control change at or before `tick` (0 if none);
        # assumes `controls` is sorted by time.
        count = bisect.bisect_right(control_times, tick)
        return controls[count - 1].value if count else 0

    onsets = [note.start for note in track.notes]
    gaps = [cur - prev for prev, cur in zip([0] + onsets, onsets)]
    gaps.append(4990)  # pseudo gap after the final note

    tokens = []
    elapsed = 0
    for idx, note in enumerate(track.notes):
        gap, next_gap = gaps[idx], gaps[idx + 1]
        elapsed += gap
        raw = [
            config.pitch_start + note.pitch,
            config.timing_start + gap,
            config.velocity_start + note.velocity,
            config.timing_start + note.duration,
            config.pedal_start + pedal_at(elapsed),
        ]
        raw.extend(config.pedal_start + pedal_at(elapsed + next_gap * q / 4)
                   for q in (1, 2, 3))
        for slot, token in enumerate(raw):
            bounds = config.valid_id_range[slot]
            tokens.append(min(bounds[1] - 1, max(bounds[0], token)))
    return tokens

def ids_to_midi(config, ids, target_ticks_per_beat=500, target_tempo=120):
    """Decode a flat token-id list (8 ids per note) into a single-track MidiFile.

    Each group is read as pitch, interval, velocity, duration and four pedal
    values, with the matching ``config.*_start`` offset subtracted. A note's
    onset is the previous onset plus its interval (starting from tick 0). Its
    four pedal values become CC #64 events at the onset and at 1/4, 2/4 and 3/4
    of the next note's interval (4990 ticks is assumed after the last note).

    Args:
        config: provides ``pitch_start``, ``timing_start``, ``velocity_start``
            and ``pedal_start``.
        ids: token ids, 8 per note.
        target_ticks_per_beat (int): ticks per beat of the output file.
        target_tempo (float): constant tempo (BPM) of the output file.

    Returns:
        MidiFile: one "Piano" instrument; ``max_tick`` is one past the latest
        note end / CC time.
    """
    group_starts = range(0, len(ids), 8)
    gaps = [ids[base + 1] - config.timing_start for base in group_starts]
    gaps.append(4990)  # pseudo gap after the final note

    notes = []
    pedal_events = []
    onset = 0
    for k, base in enumerate(group_starts):
        pitch = ids[base] - config.pitch_start
        velocity = ids[base + 2] - config.velocity_start
        duration = ids[base + 3] - config.timing_start
        pedal_values = [ids[base + 4 + q] - config.pedal_start for q in range(4)]

        onset = onset + gaps[k]
        notes.append(Note(velocity, pitch, onset, onset + duration))

        # Pedal samples mirror the positions midi_to_ids reads them from.
        next_gap = gaps[k + 1]
        pedal_events.append(ControlChange(64, pedal_values[0], onset))
        for q in (1, 2, 3):
            pedal_events.append(
                ControlChange(64, pedal_values[q], round(onset + next_gap * q / 4)))

    last_tick = max([0]
                    + [note.end for note in notes]
                    + [event.time for event in pedal_events])

    output = MidiFile(ticks_per_beat=target_ticks_per_beat)
    output.instruments.append(Instrument(program=0, is_drum=False, name="Piano",
                                         notes=notes, control_changes=pedal_events))
    output.tempo_changes.append(TempoChange(target_tempo, 0))
    output.max_tick = last_tick + 1

    return output

def read_corresp(corresp_path):
    """Parse an alignment ``*_corresp.txt`` file into (score, performance) index pairs.

    The first line is skipped. Every other line is split on tabs. Columns 0/1/3
    are read as score note id / onset / pitch, and columns 5/6/8 as performance
    note id / onset / pitch. A ``'*'`` in an id column means "no note on that
    side". On each side, note ids are replaced by their rank after sorting by
    (onset, pitch, id); performance entries are de-duplicated first.

    Pairs are collected from the leading lines that have a score note, stopping
    at the first line without one. A missing performance match is ``-1``. If
    the first (or last) collected pair is unmatched, it gets the smallest (or
    largest) performance rank instead. The returned list is sorted.
    """
    with open(corresp_path, "r") as handle:
        rows = [line.split("\t") for line in handle.readlines()[1:]]

    score_keys = []
    perf_keys = set()
    for row in rows:
        if row[0] != '*':
            score_keys.append((float(row[1]), int(row[3]), int(row[0])))
        if row[5] != '*':
            perf_keys.add((float(row[6]), int(row[8]), int(row[5])))
    score_rank = {key[2]: rank for rank, key in enumerate(sorted(score_keys))}
    perf_rank = {key[2]: rank for rank, key in enumerate(sorted(perf_keys))}

    pairs = []
    for row in rows:
        if row[0] == '*':
            break
        score_idx = score_rank[int(row[0])]
        perf_idx = perf_rank[int(row[5])] if row[5] != '*' else -1
        pairs.append((score_idx, perf_idx))

    matched_ranks = [perf_rank[int(row[5])] for row in rows if row[5] != '*']
    # Anchor unmatched endpoints so later interpolation has both ends known.
    if pairs[0][1] == -1:
        pairs[0] = (pairs[0][0], min(matched_ranks))
    if pairs[-1][1] == -1:
        pairs[-1] = (pairs[-1][0], max(matched_ranks))
    pairs.sort()
    return pairs

def interpolate(a, b):
    """Fill the NaN entries of ``b`` by linear interpolation over positions ``a``.

    ``a[i]`` is the x coordinate of ``b[i]``. A ramp from 0 to 1e-5 is added to
    ``a`` so that repeated coordinates become distinct. ``np.interp`` needs
    increasing sample points, so ``a`` is presumably non-decreasing at the call
    sites (TODO confirm). Known values of ``b`` are kept unchanged. Entries
    outside the known range take the nearest known value (np.interp default).
    Every result is rounded to an int.
    """
    positions = np.array(a) + np.linspace(0, 1e-5, len(a))
    values = np.array(b)
    known = np.nonzero(~np.isnan(values))[0]
    filled = np.interp(positions, positions[known], values[known])
    filled[known] = values[known]
    return list(map(round, filled.tolist()))

def align_score_and_performance(config, score_midi_obj, performance_midi_obj):
    """Align a score to a performance and return paired token sequences.

    Steps:
    1. Normalize both inputs with ``normalize_midi``.
    2. Write them to ``temp/score.mid`` and ``temp/performance.mid``; the
       ``temp/`` directory presumably must already exist.
    3. Run ``tools/AlignmentTool/MIDIToMIDIAlign.sh`` under ``timeout 120s``.
    4. Read ``temp/score_corresp.txt`` with ``read_corresp``.

    For every score note, an output note is built with the score pitch and the
    matched performance note's velocity, onset and duration. Matched onsets are
    re-assigned in ascending order. Unmatched notes get values interpolated
    with ``interpolate``. The normalized performance's control changes are
    copied up to 4999 ticks past the latest note end.

    Returns:
        tuple: ``(x, label)``, the ``midi_to_ids`` encodings of the normalized
        score and of the aligned performance. They are asserted to have equal
        length and equal pitch ids (every 8th id).
    """
    norm_score_midi_obj = normalize_midi(score_midi_obj)
    norm_performance_midi_obj = normalize_midi(performance_midi_obj)

    norm_score_midi_obj.dump("temp/score.mid")
    norm_performance_midi_obj.dump("temp/performance.mid")

    # The alignment script has to be launched from its own directory. Always
    # return to the caller's working directory, even if the (up to 120 s) call
    # is interrupted, e.g. by KeyboardInterrupt.
    original_cwd = os.getcwd()
    os.chdir("./tools/AlignmentTool")
    try:
        os.system("timeout 120s ./MIDIToMIDIAlign.sh ../../temp/performance ../../temp/score")
    finally:
        os.chdir(original_cwd)

    corresp_list = read_corresp("temp/score_corresp.txt")
    aligned_midi_obj = MidiFile(ticks_per_beat=500)
    score_notes = norm_score_midi_obj.instruments[0].notes
    performance_notes = norm_performance_midi_obj.instruments[0].notes

    score_start_list = []
    vel_list = []
    start_list = []
    duration_list = []
    unknown_ids = set()  # a set: membership is tested once per pair below
    for i, (score_idx, perf_idx) in enumerate(corresp_list):
        if perf_idx != -1:
            perf_note = performance_notes[perf_idx]
            vel_list.append(perf_note.velocity)
            start_list.append(perf_note.start)
            duration_list.append(perf_note.end - perf_note.start)
        else:
            vel_list.append(np.nan)
            duration_list.append(np.nan)
            unknown_ids.add(i)
        score_start_list.append(score_notes[score_idx].start)

    # Hand out the matched performance onsets in ascending order so the known
    # onsets are monotonic along score order; unmatched slots stay NaN.
    start_list.sort()
    sorted_known_starts = iter(start_list)
    temp = [np.nan if i in unknown_ids else next(sorted_known_starts)
            for i in range(len(corresp_list))]
    start_list = interpolate(score_start_list, temp)
    vel_list = interpolate(start_list, vel_list)
    duration_list = interpolate(start_list, duration_list)

    output_notes = []
    end_list = []
    for i, (score_idx, _) in enumerate(corresp_list):
        end = start_list[i] + duration_list[i]
        end_list.append(end)
        output_notes.append(Note(vel_list[i], score_notes[score_idx].pitch, start_list[i], end))

    # Copy pedal events up to 4999 ticks past the last note-off; the early
    # break relies on the CCs being time-sorted (normalize_midi sorts them).
    max_tick = max(end_list) + 4999
    output_ccs = []
    for cc in norm_performance_midi_obj.instruments[0].control_changes:
        if cc.time > max_tick:
            break
        output_ccs.append(cc)

    aligned_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano",
                                                   notes=output_notes, control_changes=output_ccs))
    # The score is already normalized; encode it as-is so `x` is built from
    # exactly the notes that `label` indexes into (and avoid a second pass).
    x = midi_to_ids(config, norm_score_midi_obj, normalize=False)
    label = midi_to_ids(config, aligned_midi_obj, normalize=False)
    assert len(x) == len(label)
    assert all(x[i] == label[i] for i in range(0, len(x), 8))
    return x, label

def enhanced_ids(config, ids):
    """Return a copy of ``ids`` with random integer noise on three id slots.

    Within each group of 8 ids, slot 1 (interval), slot 2 (velocity) and
    slot 3 (duration) are shifted. The shift is ``round(N(0, 1) * std)``,
    rejection-sampled into a slot- and value-specific window. After 10 failed
    draws it falls back to 0. The shifted raw value (id minus
    ``valid_id_range[slot][0]``) is clamped to [0, 4990] for intervals,
    [0, 127] for velocities and [0, 4999] for durations. All other slots are
    copied unchanged.
    """
    retry = 10

    def gaussian_noise(std, low, high):
        for _ in range(retry):
            candidate = round(np.random.randn() * std)
            if low <= candidate <= high:
                return candidate
        return 0

    out = copy(ids)
    for pos, token in enumerate(out):
        slot = pos % 8
        if slot not in (1, 2, 3):
            continue
        base = config.valid_id_range[slot][0]
        value = token - base
        if slot == 3:
            if value == 10:
                noise = gaussian_noise(5, -9, 5)
            else:
                noise = gaussian_noise(5, -4, 5)
            ceiling = 4999
        elif slot == 2:
            if value == 5:
                noise = gaussian_noise(2.5, -4, 2)
            elif value == 120:
                noise = gaussian_noise(2.5, -2, 7)
            else:
                noise = gaussian_noise(2.5, -2, 2)
            ceiling = 127
        else:
            noise = gaussian_noise(5, -4, 5)
            ceiling = 4990
        out[pos] = base + min(max(value + noise, 0), ceiling)
    return out

def enhanced_ids_uniform(config, ids):
    """Return a copy of ``ids`` with uniform integer noise on three id slots.

    Within each group of 8 ids, slot 1 (interval), slot 2 (velocity) and
    slot 3 (duration) are shifted by ``random.randint(low, high)``. The window
    depends on the slot and, for velocity/duration, on the current raw value.
    The shifted raw value (id minus ``valid_id_range[slot][0]``) is clamped to
    [0, 4990] for intervals, [0, 127] for velocities and [0, 4999] for
    durations. All other slots are copied unchanged.
    """
    out = copy(ids)
    for pos, token in enumerate(out):
        slot = pos % 8
        if slot not in (1, 2, 3):
            continue
        base = config.valid_id_range[slot][0]
        value = token - base
        if slot == 3:
            low, high = (-9, 5) if value == 10 else (-4, 5)
            ceiling = 4999
        elif slot == 2:
            if value == 5:
                low, high = -4, 2
            elif value == 120:
                low, high = -2, 7
            else:
                low, high = -2, 2
            ceiling = 127
        else:
            low, high, ceiling = -4, 5, 4990
        out[pos] = base + min(max(value + random.randint(low, high), 0), ceiling)
    return out

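# Example round trip. NOTE: outdated - midi_to_ids and ids_to_midi now take a
# `config` object as their first argument.
#if __name__ == "__main__":
#    midi_obj = MidiFile("data/midi/test/2.mid")
#    ids = midi_to_ids(midi_obj)
#    midi = ids_to_midi(ids)
#    midi.dump("data/rebuild/2.mid")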
