'''
Song-level inference with Moyu's model: run the model bar by bar over a MIDI
song and save the predicted bars as a new MIDI file.
'''
import os
import sys
dirof = os.path.dirname
sys.path.insert(0, (dirof(__file__)))
sys.path.insert(0, dirof(dirof(dirof(__file__))))

import torch
import mlconfig
from lightning_model_moyu import load_lit_model
from [anonymous] import MultiTrack, Bar
from utils_common.utils import create_dir_if_not_exist, ls, jpath
from tqdm import tqdm
import numpy as np


def main():
    '''Default entry point when this file is run as a script: the single-song demo.'''
    test_single_demo()


def procedures():
    '''Full pipeline: run the single-song demo, then infer over the whole test set.'''
    for step in (test_single_demo, infer_all_test_set):
        step()


def infer_all_test_set(
    config_fp='/home/[anonymous]/work/[anonymous]/[anonymous]/baselines/moyu_piano/hparams/baseline.yaml',
    save_dir='/data2/[anonymous]/[anonymous]_data/out_moyu/infer_test_set',
    test_set_dir='/data2/[anonymous]/Datasets/slakh2100_flac_redux/original/test',
):
    '''
    Run song-level inference on every entry of the test-set directory.

    For each name returned by `ls(test_set_dir)`, reads
    `<test_set_dir>/<song_name>/all_src.mid` and writes the prediction to
    `<save_dir>/<song_name>.mid` (songs that infer_for_song skips produce no file).

    Args:
        config_fp: path to the mlconfig YAML of the model to use.
        save_dir: output directory; created if it does not exist.
        test_set_dir: directory containing one sub-directory per song.
    '''
    create_dir_if_not_exist(save_dir)

    song_names = ls(test_set_dir)
    pbar = tqdm(song_names)
    for song_name in pbar:
        pbar.set_description(f'Testing {song_name}...')
        midi_fp = f'{test_set_dir}/{song_name}/all_src.mid'
        save_fp = jpath(save_dir, f'{song_name}.mid')
        # NOTE(review): infer_for_song reloads the config and model for every
        # song; consider loading once if this loop becomes a bottleneck.
        infer_for_song(config_fp, midi_fp, save_fp)


def test_all_demos():
    '''
    Run inference with the note-loss-only model on a fixed list of demo songs.

    Each song's `all_src.mid` from the test split is processed and the result is
    written to `<save_dir>/<song_name>.mid`.
    '''
    config_fp = '/home/[anonymous]/work/[anonymous]/[anonymous]/baselines/moyu_piano/hparams/note_loss_only.yaml'
    save_dir = '/data2/[anonymous]/[anonymous]_data/out_moyu/note_loss_only'
    test_set_dir = '/data2/[anonymous]/Datasets/slakh2100_flac_redux/original/test'
    demo_songs = ['Track01876', 'Track01877', 'Track01880', 'Track01884', 'Track01889']

    create_dir_if_not_exist(save_dir)
    for song_name in demo_songs:
        midi_fp = f'{test_set_dir}/{song_name}/all_src.mid'
        print('Testing', song_name, '...')
        # Run the per-song inference directly with this function's config and
        # output directory (the zero-argument single-demo helper cannot take them).
        save_fp = jpath(save_dir, f'{song_name}.mid')
        infer_for_song(config_fp, midi_fp, save_fp)
    

def test_single_demo():
    '''
    Run inference for one hard-coded demo song.

    The output is written to `<save_dir>/moyu_<song_name>_<model_type>.mid`.
    '''
    # Specify paths
    config_fp = '/home/[anonymous]/work/[anonymous]/[anonymous]/baselines/moyu_piano/hparams/note_loss_only_octave_shift.yaml'
    song_name = 'Track01889'
    # Derive the input path from song_name so the two cannot get out of sync
    midi_fp = f'/data2/[anonymous]/Datasets/slakh2100_flac_redux/original/test/{song_name}/all_src.mid'
    model_type = 'note_loss_octave_shift'
    save_dir = '/home/[anonymous]/work/[anonymous]/[anonymous]/_misc'

    # No need to load the config here: infer_for_song loads it from config_fp.
    save_fn = f'moyu_{song_name}_{model_type}.mid'
    save_fp = os.path.join(save_dir, save_fn)
    infer_for_song(config_fp, midi_fp, save_fp)


def infer_for_song(config_fp, midi_fp, save_fp):
    '''
    Do the inference for a single song.

    Loads the MIDI at `midi_fp` and calls quantize_to_16th() and
    normalize_pitch() on it. Each bar's piano roll is then fed to the model one
    bar at a time. The predictions are turned back into bars and written to
    `save_fp` as MIDI. If the song does not have exactly one time signature
    equal to (4, 4), a message is printed and nothing is written.

    Args:
        config_fp: path to the mlconfig YAML describing the model to load.
        midi_fp: path to the input MIDI file.
        save_fp: path of the output MIDI file.
    '''
    # Load data first so that skipped songs don't pay for loading the model
    mt = MultiTrack.from_midi(midi_fp)
    mt.quantize_to_16th()
    mt.normalize_pitch()

    # Do nothing if time signature is not a single 4/4.
    # `len(...) != 1` also covers an empty list, which previously hit an IndexError.
    time_sigs = mt.time_signatures
    if len(time_sigs) != 1 or time_sigs[0] != (4, 4):
        print('Time signature is not 4/4, skip...')
        return

    # Load model
    config = mlconfig.load(config_fp)
    lit_model = load_lit_model(config)
    model = lit_model.model
    model.config = config
    # Inference: put modules such as dropout / batch-norm into evaluation mode
    model.eval()

    out_bars = []
    # No autograd graph is needed; this also guarantees the prediction does not
    # require grad, so `.numpy()` below cannot fail on it.
    with torch.no_grad():
        for bar in mt.bars:
            mix_proll = bar.to_piano_roll()  # [16, 128] if time signature is 4/4
            mix_proll = torch.tensor(mix_proll).unsqueeze(0).float().cuda()
            inp = {
                'mix_prolls': mix_proll,
            }
            out = model(inp)['pred'][0]  # drop the batch dimension added above

            out_bars.append(Bar.from_piano_roll(out.cpu().numpy()))

    # Save output
    out_mt = MultiTrack.from_bars(out_bars)
    out_mt.to_midi(save_fp)


if __name__ == "__main__":
    # Only run the default entry point when executed directly, not on import.
    main()