'''
Calculate metrics for the reconstructed data
'''

import os
import sys
dirof = os.path.dirname
sys.path.append(dirof(dirof(dirof(os.path.abspath(__file__)))))

from torch.utils.data import Dataset, DataLoader
from utils_common.utils import *
from utils_midi.utils_midi import RemiTokenizer
from utils_midi import remi_utils
from m2m.evaluate import Metric
from tqdm import tqdm

def main():
    '''
    Script entry point: run the QnA reconstruction evaluation.
    '''
    evaluate_qna_recon()


def procedures():
    '''
    Run the evaluation procedures defined in this file.

    Currently this only calls evaluate_qna_recon(), i.e. it does the same
    thing as main().
    '''
    evaluate_qna_recon()


def _identity_collate(batch):
    '''
    Collate function that returns the list of samples unchanged.

    Defined at module level (instead of as a lambda) so that it can be pickled
    when DataLoader worker processes are started with the "spawn" or
    "forkserver" multiprocessing start methods.
    '''
    return batch


def _pad_to_two_bars(remi_seq):
    '''
    Return remi_seq with enough 'b-1' tokens appended to contain at least two.

    Sequences that already contain two or more 'b-1' tokens are returned as is.
    '''
    bar_cnt = remi_seq.count('b-1')
    if bar_cnt < 2:
        return remi_seq + ['b-1'] * (2 - bar_cnt)
    return remi_seq


def _split_after_first_bar(remi_seq):
    '''
    Split remi_seq right after its first 'b-1' token.

    Returns:
        (first, rest): `first` ends with the first 'b-1' token, `rest` holds
        every token after it.

    Raises:
        ValueError: if remi_seq contains no 'b-1' token.
    '''
    cut = remi_seq.index('b-1') + 1
    return remi_seq[:cut], remi_seq[cut:]


def _update_segment_metrics(metric, out_remi, tgt_remi):
    '''
    Accumulate the metrics computed over the whole (two-bar) segment.
    '''
    tgt_insts = remi_utils.from_remi_get_inst_and_voice(tgt_remi)
    out_insts = remi_utils.from_remi_get_inst_and_voice(out_remi)

    # Instrument IoU
    metric.update('inst_iou', metric.calculate_inst_iou_from_inst(out_insts, tgt_insts))

    # Voice WER
    metric.update('voice_wer', metric.calculate_wer(out_insts, tgt_insts))


def _update_bar_metrics(metric, out_seq, tgt_seq):
    '''
    Accumulate the metrics computed on one bar of output vs. target tokens.
    '''
    # Melody recall
    metric.update('melody_recall', metric.calculate_melody_recall_mbar(out_seq, tgt_seq))

    # Pitch sequence similarity
    metric.update('pitch_wer', metric.calculate_pitch_wer(out_seq, tgt_seq))
    metric.update('pitch_iou', metric.calculate_pitch_iou(out_seq, tgt_seq))

    # Groove similarity. The second returned value (SOR) is not reported.
    pos_wer, _ = metric.calculate_groove_wer_sor_mbar(out_seq, tgt_seq)
    metric.update('pos_wer', pos_wer)
    metric.update('pos_iou', metric.calculate_groove_iou_mbar(out_seq, tgt_seq))

    # --- Track-wise metrics ---
    # Track-wise pitch sequence WER
    metric.update('track_pitch_wer', metric.calculate_avg_track_pitch_wer(out_seq, tgt_seq))

    # Track-wise pitch IoU
    metric.update('track_pitch_iou', metric.calculate_avg_track_pitch_iou(out_seq, tgt_seq))

    # Track-wise groove similarity (IoU only; the track-wise position WER is
    # currently not computed).
    metric.update('track_pos_iou', metric.calculate_avg_track_pos_iou(out_seq, tgt_seq))

    # Duration difference
    metric.update('dur_diff', metric.calculate_dur_dif_per_track(out_seq, tgt_seq))

    # The F1 metrics are only computed when the target bar has more than one
    # token (presumably a one-token bar holds just its 'b-1' marker, i.e. no
    # notes -- confirm against the tokenizer).
    if len(tgt_seq) > 1:
        # Note F1
        metric.update('note_f1', metric.calculate_note_f1_q16(out_seq, tgt_seq))

        # Note_i F1
        metric.update('note_i_f1', metric.calculate_note_i_f1_q16(out_seq, tgt_seq))

        # Melody F1
        metric.update('melody_f1', metric.calculate_melody_f1_q16(out_seq, tgt_seq))


def evaluate_qna_recon():
    '''
    Evaluate the QnA reconstructions against their ground truth.

    Iterates over every (ground truth, reconstruction) pair in PairedDataset,
    accumulates segment-level and bar-level metrics in a Metric object, then
    averages them and saves the result as 'qna_recon_metrics.json' in the
    baseline results directory.

    Raises:
        ValueError: if a ground-truth sequence contains no 'b-1' token (only
        the reconstruction is padded).
    '''
    dataset = PairedDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=_identity_collate,
        num_workers=12,
    )
    metric = Metric()
    for batch in tqdm(dataloader):
        # batch_size is 1 and samples are not collated, so unwrap the only one
        gt_remi, recon_remi = batch[0]

        # Pad output to 2 bars
        recon_remi = _pad_to_two_bars(recon_remi)

        gt_bar1_seq, gt_bar2_seq = _split_after_first_bar(gt_remi)
        out_bar1_seq, out_bar2_seq = _split_after_first_bar(recon_remi)

        # Segment-level evaluation
        _update_segment_metrics(metric, recon_remi, gt_remi)

        # Bar-level evaluation
        for tgt_seq, out_seq in [(gt_bar1_seq, out_bar1_seq), (gt_bar2_seq, out_bar2_seq)]:
            _update_bar_metrics(metric, out_seq, tgt_seq)

    scores = metric.average()
    save_dir = '/data2/[anonymous]/Datasets/slakh2100_flac_redux/baseline_results'
    save_fn = 'qna_recon_metrics.json'
    save_fp = jpath(save_dir, save_fn)
    save_json(scores, save_fp)





class PairedDataset(Dataset):
    '''
    Dataset of (ground truth, reconstruction) MIDI pairs tokenized to REMI.

    The data directory is expected to contain one sub-directory per song.
    Inside each, files whose name contains 'gt' are ground truth and files
    whose name contains 'recon' are reconstructions; they are paired by their
    numeric id order (e.g., gt_079.mid <-> the recon file with id 79).

    Each item is a tuple (gt_remi, recon_remi) as returned by
    RemiTokenizer.midi_to_remi for the two files.
    '''

    # Directory used when no data_dir is passed to the constructor
    DEFAULT_DATA_DIR = '/data2/[anonymous]/Datasets/slakh2100_flac_redux/baseline_results/QnA_demo'

    # Tokenizer options shared by ground-truth and reconstructed files, so
    # both sides of a pair are always tokenized identically
    REMI_KWARGS = dict(
        normalize_pitch=True,
        return_pitch_shift=False,
        return_key=False,
        reorder_by_inst=True,
        include_ts=False,
        include_tempo=False,
        include_velocity=False,
    )

    def __init__(self, data_dir=None):
        '''
        Index all (song dir, gt file, recon file) triples under data_dir.

        Args:
            data_dir: root directory holding one sub-directory per song.
                Defaults to DEFAULT_DATA_DIR when None.
        '''
        self.data_dir = self.DEFAULT_DATA_DIR if data_dir is None else data_dir
        song_dir_names = ls(self.data_dir)

        indices = []
        for song_dir_name in song_dir_names:
            midi_fns = ls(jpath(self.data_dir, song_dir_name))
            for gt_fn, recon_fn in self._pair_files(midi_fns):
                indices.append((song_dir_name, gt_fn, recon_fn))
        self.indices = indices
        self.tk = RemiTokenizer()

    @staticmethod
    def _file_id(fn):
        '''
        Extract the numeric id from a file name (e.g., gt_079.mid -> 79).
        '''
        return int(fn.split('_')[1].split('.')[0])

    @staticmethod
    def _pair_files(midi_fns):
        '''
        Pair ground-truth and reconstruction file names of one song directory.

        Both groups are sorted by numeric id and zipped together.

        Returns:
            List of (gt_fn, recon_fn) tuples.
        '''
        gt_fns = [fn for fn in midi_fns if 'gt' in fn]
        recon_fns = [fn for fn in midi_fns if 'recon' in fn]
        assert len(gt_fns) == len(recon_fns), (
            'Unpaired files: {} gt vs {} recon'.format(len(gt_fns), len(recon_fns))
        )

        # Sort numerically by id so that, e.g., id 10 comes after id 2
        gt_fns = sorted(gt_fns, key=PairedDataset._file_id)
        recon_fns = sorted(recon_fns, key=PairedDataset._file_id)
        return list(zip(gt_fns, recon_fns))

    def _to_remi(self, midi_fp):
        '''
        Tokenize one MIDI file to REMI with the shared tokenizer options.
        '''
        return self.tk.midi_to_remi(midi_path=midi_fp, **self.REMI_KWARGS)

    def __len__(self):
        '''
        Number of (gt, recon) pairs found across all song directories.
        '''
        return len(self.indices)

    def __getitem__(self, index):
        '''
        Load and tokenize the index-th pair.

        Returns:
            (gt_remi, recon_remi)
        '''
        song_dir_name, gt_fn, recon_fn = self.indices[index]
        song_dir_path = jpath(self.data_dir, song_dir_name)

        gt_remi = self._to_remi(jpath(song_dir_path, gt_fn))
        recon_remi = self._to_remi(jpath(song_dir_path, recon_fn))

        return gt_remi, recon_remi


# Run the evaluation only when this file is executed as a script
if __name__ == '__main__':
    main()