'''
A simple rule-based band arrangement algorithm that, for each bar,
1. splits the notes evenly along the pitch-range axis, and
2. allocates each pitch group to a different instrument.
'''
import os
import sys

sys.path.append('.')
sys.path.append('..')
sys.path.append('../..')
print(os.getcwd())

from utils_midi.utils_midi import RemiTokenizer
from utils_midi import remi_utils
from utils_common.utils import *
import pretty_midi

def main():
    '''
    Arrange one hard-coded song for a hard-coded instrument set and save the result.

    The song key is looked up in the song_path.yaml mapping (song key -> MIDI
    file path). The arrangement is written under ./outputs/ as
    rule_<song>_<inst_set>.mid. Edit the constants below to arrange a
    different song or to target a different band.

    Raises:
        KeyError: if the song key is missing from song_path.yaml.
    '''
    arrange_song = 'q_fly_me'
    inst_set = 'jband'
    # Alternative instrument sets, listed from the highest to the lowest part:
    # instruments = ['i-40', 'i-41', 'i-42'] # string trio
    # instruments = ['i-80', 'i-26', 'i-29', 'i-33'] # rock band
    instruments = ['i-64', 'i-40', 'i-61', 'i-26', 'i-0', 'i-44', 'i-33'] # jazz band

    # Relative path: resolved against the current working directory.
    midi_fp_dict = read_yaml('../../utils_arrange/song_path.yaml')
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if arrange_song not in midi_fp_dict:
        raise KeyError('Song not found in the song_path.yaml')
    midi_fp = midi_fp_dict[arrange_song]

    save_dir = './outputs/'
    # Create the output directory up front so the final MIDI write cannot
    # fail just because ./outputs/ does not exist yet.
    os.makedirs(save_dir, exist_ok=True)

    save_fn = 'rule_{}_{}.mid'.format(arrange_song, inst_set)
    save_fp = jpath(save_dir, save_fn)
    rule_based_band_arrange(midi_fp, instruments, save_fp)


def _opd_seq_to_notes(opd_seq):
    '''
    Convert an opd token sequence into a list of [position, pitch, duration] token triples.

    A note is emitted at every 'd-' token, using the most recent 'o-' and 'p-'
    tokens seen so far; all other tokens are ignored. A 'd-' token that
    appears before any 'o-' or 'p-' token cannot form a complete note and is
    skipped.
    '''
    notes = []
    cur_pos = cur_pitch = None
    for tok in opd_seq:
        if tok.startswith('o-'):
            cur_pos = tok
        elif tok.startswith('p-'):
            cur_pitch = tok
        elif tok.startswith('d-'):
            if cur_pos is None or cur_pitch is None:
                continue
            notes.append([cur_pos, cur_pitch, tok])
    return notes


def _split_evenly(items, n_groups):
    '''
    Cut `items` into `n_groups` consecutive slices whose lengths differ by at most one.

    The slices that receive one extra item are placed at the END. With
    pitch-descending notes, any surplus therefore goes to the lowest parts.

    >>> _split_evenly([1, 2, 3, 4, 5], 3)
    [[1], [2, 3], [4, 5]]
    >>> _split_evenly([1, 2], 3)
    [[], [1], [2]]
    '''
    base, extra = divmod(len(items), n_groups)
    groups = []
    start = 0
    for i in range(n_groups):
        size = base + (1 if i >= n_groups - extra else 0)
        groups.append(items[start:start + size])
        start += size
    return groups


def rule_based_band_arrange(midi_fp, insts, out_fp):
    '''
    Arrange a MIDI file for a band by splitting each bar's notes along the pitch axis.

    Within every bar, notes are sorted from highest to lowest pitch. They are
    then cut into len(insts) contiguous groups whose sizes differ by at most
    one. The first instrument in `insts` receives the highest-pitched group
    and the last instrument receives the lowest. When the note count is not
    divisible by len(insts), the extra notes go to the lowest-pitched
    instruments. Every instrument token is emitted in every bar, even if its
    group is empty.

    Args:
        midi_fp: path of the source MIDI file.
        insts: instrument tokens (e.g. 'i-0'); the order decides which part gets
            which pitch register. Must be non-empty.
        out_fp: path of the MIDI file to write.

    Raises:
        ValueError: if `insts` is empty.
    '''
    if not insts:
        raise ValueError('insts must contain at least one instrument token')

    # Get remi from midi
    tk = RemiTokenizer()
    remi_seq = tk.midi_to_remi(midi_fp,
                               normalize_pitch=False,
                               return_pitch_shift=False,
                               return_key=False,
                               reorder_by_inst=True,
                               include_ts=False,
                               include_tempo=False,
                               include_velocity=False)

    res = []
    bar_indices = remi_utils.from_remi_get_bar_idx(remi_seq)
    for bar_id in bar_indices:
        bar_start_idx, bar_end_idx = bar_indices[bar_id]
        bar_seq = remi_seq[bar_start_idx:bar_end_idx]
        opd_seq = remi_utils.from_remi_get_global_opd_seq(bar_seq)

        # (position, pitch, duration) token triples for this bar
        notes = _opd_seq_to_notes(opd_seq)

        # Highest pitch first, so the first instrument gets the top register.
        notes.sort(key=lambda note: int(note[1].split('-')[1]), reverse=True)

        # Emit each instrument token followed by the opd tokens of its group.
        for inst, inst_notes in zip(insts, _split_evenly(notes, len(insts))):
            res.append(inst)
            for note in inst_notes:
                res.extend(note)
        res.append('b-1')

    # Save remi to file
    tk.remi_to_midi(res, out_fp)


def merge_all_notes_to_piano(remi_seq):
    '''
    Collapse all instruments of a REMI sequence into one instrument-less stream.

    For each bar found by remi_utils.from_remi_get_bar_idx, the bar's global
    opd sequence is appended, followed by a 'b-1' bar-line token. No
    instrument token is written.
    NOTE(review): the name suggests piano; this relies on the consumer treating
    instrument-less notes as piano — confirm.
    '''
    merged = []
    bar_ranges = remi_utils.from_remi_get_bar_idx(remi_seq)
    for key in bar_ranges:
        start, end = bar_ranges[key]
        merged += remi_utils.from_remi_get_global_opd_seq(remi_seq[start:end])
        merged.append('b-1')
    return merged


if __name__ == '__main__':
    # Script entry point: run the hard-coded demo arrangement defined in main().
    main()