from src.utils.midi import ids_to_midi, midi_to_ids
from src.model.pianoformer import PianoT5Gemma
from miditoolkit import MidiFile
from torch.nn.utils.rnn import pad_sequence
from transformers import LogitsProcessorList, LogitsProcessor
from tqdm import tqdm
import torch
from src.utils.midi import normalize_midi
from miditoolkit import MidiFile, Note, TempoChange, Instrument, ControlChange
import bisect

class BatchSparseForcedTokenProcessor(LogitsProcessor):
    """Constrain decoding so every 8th token is copied from the source window.

    For each sample, the tokens found at positions 0, 8, 16, ... of
    ``input_ids`` are recorded. When the decoder is about to emit one of those
    positions, that token is forced. At every other position the scores are
    masked to the id range ``config.valid_id_range[position % 8]``.

    Args:
        input_ids: Per-sample token sequences for the current window
            (indexable as ``input_ids[i][j]``).
        config: Object exposing ``valid_id_range``; entry ``k`` is read as a
            ``(low, high)`` pair, ids in ``[low, high)`` stay allowed.
        target_len: Decoder length at which this window counts as finished
            (used for progress reporting only).
        origin_len: Decoder length at which this window counts as started
            (used for progress reporting only).
        already: Progress value already accumulated before this window.
        weight: Share of the overall progress attributed to this window.
        progress_callback: Optional callable receiving the current progress
            value on every decoding step; falsy to disable.
    """

    def __init__(self, input_ids, config, target_len, origin_len, already, weight, progress_callback):
        # One {position: token} map per sample, holding only positions 0, 8, 16, ...
        self.batch_map = [
            {pos: row[pos] for pos in range(0, len(row), 8)}
            for row in input_ids
        ]
        self.valid_id_range = config.valid_id_range
        self.target_len = target_len
        self.origin_len = origin_len
        self.already = already
        self.weight = weight
        self.progress_callback = progress_callback

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask ``scores`` in place for the next token and return them."""
        if self.progress_callback:
            self.progress_callback(
                (input_ids.shape[1] - self.origin_len) / (self.target_len - self.origin_len) * self.weight + self.already
            )
        # Position of the token about to be generated (the decoder sequence
        # carries one leading start token).
        step = input_ids.shape[1] - 1
        # Slot inside the 8-token group. Kept separate from `step` so the
        # forced-token lookup of every sample uses the real position.
        slot = step % 8
        neg_inf = float('-inf')
        for i in range(scores.shape[0]):
            sample_map = self.batch_map[i]
            if step in sample_map:
                forced_token_id = sample_map[step]
                scores[i] = neg_inf
                scores[i, forced_token_id] = 0.0
            else:
                # Range is looked up lazily: only non-forced slots need one.
                id_range = self.valid_id_range[slot]
                scores[i, :id_range[0]] = neg_inf
                scores[i, id_range[1]:] = neg_inf
        return scores

@torch.no_grad()
def batch_performance_render(
        model, 
        score_midi_objs, 
        max_context_length=4096, 
        overlap_ratio=0.5, 
        temperature=1.0,
        top_p=0.95,
        device="cpu",
        progress_callback=None
    ):
    """Render a batch of score MIDIs into performance MIDIs with sliding windows.

    Each score is tokenized with ``midi_to_ids``, the batch is right-padded,
    and the padded sequence is processed in overlapping windows. For every
    window after the first, the decoder is primed with the already generated
    tokens of the overlap, except the last 20% of the overlap, which is
    generated again in the new window.

    Args:
        model: Model exposing ``config`` and ``generate``.
        score_midi_objs: Iterable of score MIDI objects accepted by
            ``midi_to_ids``.
        max_context_length: Window length in tokens; must not exceed 4096.
        overlap_ratio: Fraction of a window shared with the next one.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus sampling threshold forwarded to ``generate``.
        device: Device the token tensors are placed on.
        progress_callback: Optional callable receiving a progress value that
            runs from 0 towards 1 across all windows.

    Returns:
        A list with one ``ids_to_midi`` result per input score, in order.
        An empty batch yields an empty list.

    Raises:
        ValueError: If ``max_context_length`` exceeds 4096, or if windowing is
            needed and ``max_context_length`` / ``overlap_ratio`` do not give
            a positive stride that is at most one window long.
    """
    def slide_window(total_len, window_len):
        """Return ``(start, end)`` windows covering ``[0, total_len)``."""
        if total_len <= window_len:
            return [(0, total_len)]
        # Windows and strides are aligned to the 8-token groups used by the
        # logits processor.
        window_len = window_len // 8 * 8
        stride = int(window_len * (1 - overlap_ratio)) // 8 * 8
        if window_len <= 0 or not 0 < stride <= window_len:
            # A zero stride would never terminate; a stride longer than the
            # window would leave uncovered gaps between windows.
            raise ValueError(
                "max_context_length and overlap_ratio must give an 8-aligned "
                "stride that is positive and no longer than the window"
            )
        out = []
        start = 0
        while start + window_len <= total_len:
            out.append((start, start + window_len))
            start += stride
        if out[-1][1] != total_len:
            out.append((start, total_len))
        return out

    if max_context_length > 4096:
        raise ValueError("You should set max_context_length <= 4096!")
    batch_ids = [
        torch.tensor(midi_to_ids(model.config, score_midi_obj), dtype=torch.long).to(device)
        for score_midi_obj in score_midi_objs
    ]
    if not batch_ids:
        return []
    len_list = [len(ids) for ids in batch_ids]

    input_ids = pad_sequence(batch_ids, batch_first=True, padding_value=model.config.pad_token_id)
    batch_size = input_ids.shape[0]
    windows = slide_window(input_ids.shape[1], max_context_length)

    # Generated tokens so far, indexed by absolute position (no start token).
    res_tensor = None
    for i, (start, end) in enumerate(tqdm(windows)):
        window_ids = input_ids[:, start:end]
        extra_kwargs = {}
        if i == 0:
            redo = 0
            prefix_len = 0
            new_tokens = end - start
        else:
            last_end = windows[i - 1][1]
            # The last 20% of the overlap is discarded and generated again.
            redo = int((last_end - start) * 0.2)
            # res_tensor is indexed by absolute position, so this slice is
            # exactly the tokens at window positions 0 .. prefix_len - 1.
            prefix = res_tensor[:, start:last_end - redo]
            prefix_len = prefix.shape[1]
            start_tensor = torch.full(
                (batch_size, 1), model.config.bos_token_id, dtype=torch.long, device=device
            )
            extra_kwargs["decoder_input_ids"] = torch.cat([start_tensor, prefix], dim=1)
            new_tokens = end - last_end + redo

        logits_processor = LogitsProcessorList([
            BatchSparseForcedTokenProcessor(
                window_ids,
                model.config,
                # Progress bounds are decoder lengths inside this window: it
                # starts after the primed prefix and ends at the window length.
                end - start,
                prefix_len,
                i / len(windows),
                1 / len(windows),
                progress_callback,
            )
        ])
        output = model.generate(
            window_ids,
            do_sample=True,
            max_new_tokens=new_tokens,
            logits_processor=logits_processor,
            temperature=temperature,
            top_p=top_p,
            **extra_kwargs,
        )
        if i == 0:
            # Drop the leading decoder start token.
            res_tensor = output[:, 1:]
        else:
            # Explicit end index: a `:-redo` slice would be empty for redo == 0.
            keep = res_tensor.shape[1] - redo
            res_tensor = torch.cat([res_tensor[:, :keep], output[:, -new_tokens:]], dim=1)

    rows = res_tensor.cpu().numpy().tolist()
    # Cut each row back to its unpadded length before decoding.
    return [ids_to_midi(model.config, row[:n]) for row, n in zip(rows, len_list)]


def map_midi(score_midi_obj, performance_midi_obj):
    """Transfer the timing of a performance onto the notes of a score.

    Both inputs go through ``normalize_midi`` and only the first instrument of
    each is used. Score note ``i`` is paired with performance note ``i``. A
    tempo is derived for every interval between distinct score onsets from the
    ratio of score interval to performance interval (clamped to 10..300 BPM),
    and the remaining difference is applied as a per-note shift of the onset.
    Velocities, durations and control changes come from the performance,
    pitches from the score.

    NOTE(review): the arithmetic treats normalized performance times as
    milliseconds (1 tick == 1 ms at 120 BPM with 500 ticks per beat) --
    confirm against ``normalize_midi``.

    Args:
        score_midi_obj: Score MIDI object accepted by ``normalize_midi``.
        performance_midi_obj: Performance MIDI object accepted by
            ``normalize_midi``; it needs at least as many notes as the score.

    Returns:
        A new ``MidiFile`` (500 ticks per beat) with one piano instrument
        holding the re-timed notes, the performance control changes (all
        written as controller 64) and the derived tempo changes.
    """
    ticks_per_beat = 500

    def ms_per_tick_at(bpm):
        # ms per tick = (ms per minute / beats per minute) / ticks per beat
        return (60 * 1000.0 / bpm) / ticks_per_beat

    def compute_duration(start_time, target_duration, tempo_list):
        """Return how many ticks ``target_duration`` ms last from tick ``start_time``."""
        if target_duration <= 0:
            return 0
        if not tempo_list:
            # Without tempo information fall back to the default 120 BPM.
            tempo_list = [TempoChange(120, 0)]

        # Step 1: find the tempo segment containing start_time. bisect_right
        # gives the insertion point, so the active tempo sits one index before.
        tempo_times = [t.time for t in tempo_list]
        start_tempo_idx = bisect.bisect_right(tempo_times, start_time) - 1
        # start_time before the first tempo change: use the first tempo.
        if start_tempo_idx < 0:
            start_tempo_idx = 0

        # Step 2: walk the tempo segments until the duration is used up.
        total_ticks_duration = 0.0
        time_remaining_ms = float(target_duration)
        current_tick = start_time
        current_tempo_idx = start_tempo_idx

        # The small epsilon absorbs floating point residue.
        while time_remaining_ms > 1e-9:
            ms_per_tick = ms_per_tick_at(tempo_list[current_tempo_idx].tempo)

            # The last tempo lasts forever; any other one ends at the next change.
            end_of_segment_tick = float('inf')
            if current_tempo_idx + 1 < len(tempo_list):
                end_of_segment_tick = tempo_list[current_tempo_idx + 1].time

            ticks_in_segment = end_of_segment_tick - current_tick
            ms_in_segment = ticks_in_segment * ms_per_tick

            if time_remaining_ms <= ms_in_segment:
                # The rest of the duration fits into this segment.
                total_ticks_duration += time_remaining_ms / ms_per_tick
                time_remaining_ms = 0
            else:
                # Consume the whole segment and move on to the next tempo.
                total_ticks_duration += ticks_in_segment
                time_remaining_ms -= ms_in_segment
                current_tick = end_of_segment_tick
                current_tempo_idx += 1

        return round(total_ticks_duration)

    def ms_to_tick(target_ms, tempo_list):
        """Return the tick position reached after ``target_ms`` milliseconds."""
        if target_ms <= 0:
            return 0
        if not tempo_list:
            # Without tempo information fall back to the default 120 BPM.
            tempo_list = [TempoChange(120, 0)]

        accumulated_ms = 0.0

        # Every tempo except the last one spans a bounded segment
        # [tempo[i].time, tempo[i + 1].time).
        for i in range(len(tempo_list) - 1):
            current_tempo_event = tempo_list[i]
            ticks_in_segment = tempo_list[i + 1].time - current_tempo_event.time

            # Zero-length segments contribute nothing.
            if ticks_in_segment == 0:
                continue

            ms_per_tick = ms_per_tick_at(current_tempo_event.tempo)
            ms_in_segment = ticks_in_segment * ms_per_tick

            if target_ms <= accumulated_ms + ms_in_segment:
                # The target lies inside this segment.
                ticks_needed = (target_ms - accumulated_ms) / ms_per_tick
                return round(current_tempo_event.time + ticks_needed)

            accumulated_ms += ms_in_segment

        # Not found in any bounded segment: the target lies in the last tempo.
        last_tempo_event = tempo_list[-1]
        ticks_needed = (target_ms - accumulated_ms) / ms_per_tick_at(last_tempo_event.tempo)
        return round(last_tempo_event.time + ticks_needed)

    norm_score = normalize_midi(score_midi_obj)
    norm_performance = normalize_midi(performance_midi_obj)

    score_notes = norm_score.instruments[0].notes
    performance_notes = norm_performance.instruments[0].notes
    performance_ccs = norm_performance.instruments[0].control_changes

    # Shift both note lists so they start at time 0 and record, for every
    # distinct score onset, (score onset, performance onset, first note index).
    start_list = []
    last = -1
    score_start = score_notes[0].start
    performance_start = performance_notes[0].start
    for i, score_note in enumerate(score_notes):
        performance_note = performance_notes[i]
        performance_note.end -= performance_start
        performance_note.start -= performance_start
        score_note.end -= score_start
        score_note.start -= score_start
        if score_note.start != last:
            start_list.append((score_note.start, performance_note.start, i))
            last = score_note.start

    for cc in performance_ccs:
        cc.time -= performance_start

    # Distances between consecutive distinct onsets, in score and performance.
    score_interval_list = []
    performance_interval_list = []
    for i in range(len(start_list) - 1):
        score_interval_list.append(start_list[i + 1][0] - start_list[i][0])
        performance_interval_list.append(start_list[i + 1][1] - start_list[i][1])

    # One tempo per onset interval, clamped to 10..300 BPM. When clamping
    # kicks in, the part of the interval the tempo cannot express is kept as
    # an onset offset instead.
    tempo_list = []
    start_note_offset = []
    for score_interval, performance_interval in zip(score_interval_list, performance_interval_list):
        if performance_interval != 0:
            bpm = 120.0 / performance_interval * score_interval
        else:
            bpm = 300

        if bpm > 300:
            start_note_offset.append(300 / 120.0 * performance_interval - score_interval)
        elif bpm < 10:
            start_note_offset.append(10 / 120.0 * performance_interval - score_interval)
        else:
            start_note_offset.append(0)
        tempo_list.append(max(min(bpm, 300), 10))

    # Offsets accumulate: every later onset inherits all earlier corrections.
    for i in range(1, len(start_note_offset)):
        start_note_offset[i] += start_note_offset[i - 1]

    # Expand the per-interval values to per-note values. note_start_offset
    # starts with 0, so note i receives the value recorded for note i - 1.
    note_tempo_list = []
    note_start_offset = [0]
    cnt = 0
    for i in range(len(score_notes)):
        if cnt < len(start_list) - 2 and i >= start_list[cnt + 1][2]:
            cnt += 1
        note_tempo_list.append(tempo_list[cnt])
        note_start_offset.append(start_note_offset[cnt])

    for i, score_note in enumerate(score_notes):
        score_note.start += note_start_offset[i]
    note_interval_list = [0]
    for i in range(len(score_notes) - 1):
        note_interval_list.append(score_notes[i + 1].start - score_notes[i].start)

    # Per-note shift: gap between the real performance onset and the onset
    # predicted by the tempo curve, converted back with the local tempo.
    micro_shift_list = [0]
    last_time = 0
    for i in range(1, len(score_notes)):
        local_tempo = note_tempo_list[i - 1]
        last_time += note_interval_list[i] / local_tempo * 120
        micro_shift_list.append((performance_notes[i].start - last_time) / 120 * local_tempo)

    res = MidiFile(ticks_per_beat=ticks_per_beat)
    start_time_list = []
    tempo_list_filter = []
    last = -1
    for i, score_note in enumerate(score_notes):
        start_tick = round(score_note.start + micro_shift_list[i])
        start_time_list.append(start_tick)
        # Emit a tempo change only when the rounded tempo actually changes.
        rounded_tempo = round(note_tempo_list[i])
        if last != rounded_tempo:
            tempo_list_filter.append(TempoChange(rounded_tempo, start_tick))
            last = rounded_tempo

    res_notes = []
    for i, score_note in enumerate(score_notes):
        start_tick = start_time_list[i]
        res_notes.append(
            Note(
                performance_notes[i].velocity,
                score_note.pitch,
                start_tick,
                start_tick + compute_duration(start_tick, performance_notes[i].duration, tempo_list_filter)
            )
        )

    cc_list = [
        ControlChange(64, cc.value, ms_to_tick(cc.time, tempo_list_filter))
        for cc in performance_ccs
    ]

    res.tempo_changes = tempo_list_filter
    res.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=res_notes, control_changes=cc_list))
    return res


# Script entry point: currently a no-op, so running this module directly does nothing.
if __name__ == "__main__":
    pass