'''
A simple rule-based piano reduction algorithm that plays all the notes in the original score on the piano.
'''
import os
import sys

# Extend the import search path with the current directory and its two
# parents; presumably this lets the project packages imported below
# (utils_midi, utils_common) resolve when the script is launched from a
# sub-directory of the repository -- TODO confirm the intended launch dir.
sys.path.append('.')
sys.path.append('..')
sys.path.append('../..')
# Show the working directory, which the relative path entries above (and the
# relative output directory used in main) are resolved against.
print(os.getcwd())

from utils_midi.utils_midi import RemiTokenizer
from utils_midi import remi_utils
from utils_common.utils import *
import pretty_midi

def main():
    '''
    Run the rule-based piano arrangement on one hard-coded song.

    The input MIDI path, song name and excluded instruments are configured
    inline below. The result is written to ``./outputs_acc/rule_<song>.mid``.
    '''
    # Alternative: look the MIDI path up from a song registry.
    # midi_fp_dict = read_yaml('utils_arrange/song_path.yaml')
    # assert arrange_song in midi_fp_dict, 'Song not found in the song_path.yaml'
    # midi_fp = midi_fp_dict[arrange_song]

    arrange_song = 'caihong'
    midi_fp = '/data2/[anonymous]/Datasets/slakh2100_flac_redux/[anonymous]_data/infer_input/full_song/caihong/caihong.mid'
    exclude = []

    # Alternative configuration:
    # arrange_song = 'q_piano_5'
    # midi_fp = '/data2/[anonymous]/Datasets/slakh2100_flac_redux/[anonymous]_data/infer_input/full_song/slakh/demo piano/all_src1889_norm.mid'
    # exclude = ['i-73']
    save_dir = './outputs_acc'

    # The output directory is relative to the working directory and may not
    # exist on a fresh checkout; create it up front so that writing the MIDI
    # file does not fail on a missing directory.
    os.makedirs(save_dir, exist_ok=True)

    save_fn = 'rule_{}.mid'.format(arrange_song)
    save_fp = jpath(save_dir, save_fn)
    rule_based_piano_arrange(midi_fp, save_fp, exclude_insts=exclude)


def rule_based_piano_arrange(midi_fp, out_fp, exclude_insts=None):
    '''
    Reduce a multi-track MIDI file to a single piano part and save it.

    The file at ``midi_fp`` is tokenized to a REMI sequence, the instruments
    listed in ``exclude_insts`` are optionally removed bar by bar, the
    remaining notes are merged with ``merge_all_notes_to_piano`` and the
    result is written to ``out_fp``.

    midi_fp: path of the input MIDI file.
    out_fp: path of the MIDI file to write.
    exclude_insts: instrument tokens to drop (e.g. 'i-73'), or None to leave
        the tokenized sequence untouched before merging.
    '''
    tokenizer = RemiTokenizer()

    # Tokenizer options: pitch-normalized, grouped by instrument, without
    # time-signature / tempo / velocity tokens.
    tokenize_opts = dict(
        normalize_pitch=True,
        return_pitch_shift=False,
        return_key=False,
        reorder_by_inst=True,
        include_ts=False,
        include_tempo=False,
        include_velocity=False,
    )
    tokens = tokenizer.midi_to_remi(midi_fp, **tokenize_opts)

    # An empty list still goes through the per-bar rebuild; only None skips it.
    if exclude_insts is not None:
        tokens = _drop_excluded_instruments(tokens, exclude_insts)

    # Collapse every remaining track into one piano part and write it out.
    piano_tokens = merge_all_notes_to_piano(tokens)
    tokenizer.remi_to_midi(piano_tokens, out_fp)


def _drop_excluded_instruments(remi_seq, exclude_insts):
    '''
    Rebuild ``remi_seq`` bar by bar without the instruments in ``exclude_insts``.

    For every bar, each kept instrument token is emitted followed by its own
    per-track token sequence, and the bar is closed with 'b-1'.
    '''
    kept = []
    bar_spans = remi_utils.from_remi_get_bar_idx(remi_seq)
    for bar_key in bar_spans:
        lo, hi = bar_spans[bar_key]
        bar_tokens = remi_seq[lo:hi]
        per_track = remi_utils.from_remi_get_opd_seq_per_track(bar_tokens)
        present = remi_utils.from_remi_get_inst_and_voice(bar_tokens)

        # Remove each unwanted instrument that actually occurs in this bar.
        for unwanted in exclude_insts:
            if unwanted in present:
                present.remove(unwanted)

        for inst in present:
            kept += [inst, *per_track[inst]]
        kept += ['b-1']
    return kept


def merge_all_notes_to_piano(remi_seq):
    '''
    Merge the tracks of every bar in ``remi_seq`` into one token stream.

    Each bar is replaced by the sequence returned by
    ``remi_utils.from_remi_get_global_opd_seq`` for that bar, followed by a
    'b-1' token. Returns the new token list; ``remi_seq`` is only read.
    '''
    bar_spans = remi_utils.from_remi_get_bar_idx(remi_seq)
    merged = []
    for bar_key in bar_spans:
        first, last = bar_spans[bar_key]
        bar_tokens = remi_seq[first:last]
        merged += remi_utils.from_remi_get_global_opd_seq(bar_tokens)
        merged += ['b-1']
    return merged


# Run the hard-coded arrangement only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()