from midiprocessor import midi_utils, MidiEncoder, enc_remigen_utils, enc_remigen2_utils


def get_midi_pos_info(
    encoder,
    midi_path=None,
    midi_obj=None,
    remove_empty_bars=True,
):
    """Collect position info for a MIDI file through ``encoder``.

    Args:
        encoder: Object providing ``collect_pos_info``,
            ``convert_pos_info_to_pos_info_id``,
            ``convert_pos_info_id_to_pos_info`` and
            ``remove_empty_bars_for_pos_info``.
        midi_path: Passed to ``midi_utils.load_midi`` when ``midi_obj`` is
            not supplied; unused otherwise.
        midi_obj: An already loaded MIDI object. Takes precedence over
            ``midi_path``.
        remove_empty_bars: When true, the result is passed through
            ``encoder.remove_empty_bars_for_pos_info`` to drop the beginning
            and ending empty bars.

    Returns:
        The position info after an encode/decode round trip and, optionally,
        empty-bar removal.
    """
    if midi_obj is None:
        midi_obj = midi_utils.load_midi(midi_path)

    collect_options = {
        'trunc_pos': None,
        'tracks': None,
        'remove_same_notes': False,
        'end_offset': 0,
    }
    collected = encoder.collect_pos_info(midi_obj, **collect_options)
    # Drop the local reference to the MIDI object; it is not needed past
    # this point.
    del midi_obj

    # Round-trip through the id representation so that later analysis (chord
    # and cadence detection) works on exactly the same info that ends up in
    # the resulting token sequences.
    as_ids = encoder.convert_pos_info_to_pos_info_id(collected)
    round_tripped = encoder.convert_pos_info_id_to_pos_info(as_ids)

    if not remove_empty_bars:
        return round_tripped
    return encoder.remove_empty_bars_for_pos_info(round_tripped)


def convert_pos_info_to_tokens(encoder, pos_info, **kwargs):
    """Turn position info into a list of token strings.

    Args:
        encoder: Object providing ``convert_pos_info_to_pos_info_id`` and an
            ``encoding_method`` attribute equal to ``'REMIGEN'`` or
            ``'REMIGEN2'``.
        pos_info: Position info to convert.
        **kwargs: Extra keyword arguments forwarded to the selected module's
            ``convert_pos_info_to_token_lists``. ``ignore_ts``,
            ``sort_insts`` and ``sort_notes`` are already fixed by this
            function, so repeating them here raises ``TypeError``.

    Returns:
        The first token list produced by ``convert_pos_info_to_token_lists``,
        converted with ``convert_remigen_token_list_to_token_str_list``.

    Raises:
        ValueError: If ``encoder.encoding_method`` is neither ``'REMIGEN'``
            nor ``'REMIGEN2'``; the offending value is the exception argument.
    """
    id_form = encoder.convert_pos_info_to_pos_info_id(pos_info)

    method = encoder.encoding_method
    if method not in ('REMIGEN', 'REMIGEN2'):
        raise ValueError(method)
    codec = enc_remigen_utils if method == 'REMIGEN' else enc_remigen2_utils

    token_lists = codec.convert_pos_info_to_token_lists(
        id_form, ignore_ts=False, sort_insts='id', sort_notes=None, **kwargs
    )
    # Only the first of the returned token lists is used.
    return codec.convert_remigen_token_list_to_token_str_list(token_lists[0])


if __name__ == '__main__':
    # Smoke test: encode one MIDI file with the REMIGEN encoder and print the
    # first 100 tokens.
    # Usage: python <this_file> [midi_path]
    import sys

    # Use the path given on the command line; otherwise keep the original
    # default sample file.
    midi_path = sys.argv[1] if len(sys.argv) > 1 else 'test.mid'

    enc = MidiEncoder("REMIGEN")

    pi = get_midi_pos_info(enc, midi_path=midi_path)

    tokens = convert_pos_info_to_tokens(enc, pi)

    print(tokens[:100])
