"""
This generates the chord progression and extracts the chord progression model, with the picture.
"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from collections import Counter
import itertools, copy
from more_itertools import split_before
import os, traceback, time, warnings, sys
import multiprocessing
from miditoolkit.midi.parser import MidiFile
from miditoolkit.midi.containers import Instrument
from miditoolkit.midi.containers import Note as mtkNote
from chorder import Dechorder

def merge_drums(p_midi):  # merge all percussions
    """Collapse every percussion track of ``p_midi`` into one drum instrument.

    Instruments without notes are dropped. Notes of all drum instruments are
    pooled, sorted by start tick and stored in a new
    ``Instrument(program=0, is_drum=True, name="drum")`` that is appended after
    the remaining non-drum instruments. ``p_midi.instruments`` is replaced by
    the resulting list; nothing is returned.
    """
    melodic_tracks = []
    percussion_notes = []

    for track in p_midi.instruments:
        if len(track.notes) == 0:
            continue  # empty tracks are discarded
        if track.is_drum:
            percussion_notes.extend(track.notes)
        else:
            melodic_tracks.append(track)

    if percussion_notes:
        percussion_notes.sort(key=lambda note: note.start)
        # groupby only collapses *adjacent* notes that compare equal
        unique_notes = [note for note, _ in itertools.groupby(percussion_notes)]

        merged_drums = Instrument(program=0, is_drum=True, name="drum")
        merged_drums.notes = unique_notes
        melodic_tracks.append(merged_drums)

    p_midi.instruments = melodic_tracks


def merge_sparse_track(p_midi, CANDI_THRES=50, MIN_THRES=5):  # merge tracks that have too few notes
    """Fold sparse tracks into a dense track of the same instrument.

    A track with fewer than ``CANDI_THRES`` notes is "sparse". Each sparse track
    is merged into the first dense track sharing its ``(program, is_drum)``
    pair. If no such dense track exists, the sparse track is kept only when it
    has more than ``MIN_THRES`` notes; otherwise it is dropped.
    ``p_midi.instruments`` is replaced by the resulting list.
    """
    dense_tracks = []
    sparse_tracks = []
    for track in p_midi.instruments:
        if len(track.notes) < CANDI_THRES:
            sparse_tracks.append(track)
        else:
            dense_tracks.append(track)

    # first dense track per (program, is_drum); sparse tracks kept later are not merge targets
    merge_target = {}
    for track in dense_tracks:
        merge_target.setdefault((track.program, track.is_drum), track)

    kept_tracks = dense_tracks
    for track in sparse_tracks:
        host = merge_target.get((track.program, track.is_drum))
        if host is not None:
            host.notes.extend(track.notes)
        elif len(track.notes) > MIN_THRES:
            kept_tracks.append(track)

    p_midi.instruments = kept_tracks


def limit_max_track(p_midi, MAX_TRACK=40):  # merge track with least notes and limit the maximum amount of track to 40
    """Sort the tracks of ``p_midi`` and cap their number at ``MAX_TRACK``.

    ``p_midi.instruments`` is sorted in place so that drum tracks come first,
    followed by the other tracks in order of decreasing note count.

    When there are more than ``MAX_TRACK`` tracks, the first ``MAX_TRACK`` are
    deep-copied and every overflow track is merged into the kept track with
    the fewest notes that shares its ``program`` and ``is_drum`` flag. Overflow
    tracks without such a partner are dropped. The reduced list is then stored
    back into ``p_midi.instruments``; it is left in the ascending note-count
    order used while merging.

    Files with zero or one track are handled without error.
    """
    good_instruments = p_midi.instruments
    good_instruments.sort(
        key=lambda x: (not x.is_drum, -len(x.notes)))  # place drum track or the most note track at first

    # The sanity check compares the first two tracks, so it needs at least two of them.
    if len(good_instruments) > 1:
        assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len(
            good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3])

    if len(good_instruments) <= MAX_TRACK:
        return

    new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK])

    for cur_ins in good_instruments[MAX_TRACK:]:
        # Try the kept tracks with the fewest notes first.
        new_good_instruments.sort(key=lambda x: len(x.notes))
        for ins in new_good_instruments:
            if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum:
                ins.notes.extend(cur_ins.notes)
                break
        # An overflow track with no matching kept track is discarded.

    # Write the reduced list back; without this the limit never takes effect.
    p_midi.instruments = new_good_instruments

def get_init_note_events(p_midi):  # extract all notes in midi file
    """Flatten every note of ``p_midi`` into sortable "ON" events.

    Each event is the list
    ``[start_tick, "ON", pitch, program, is_drum, track_idx, duration]``.
    A note longer than ``4 * p_midi.ticks_per_beat`` ticks is cut into
    consecutive pieces of at most that length.

    Returns:
        ``(note_events, note_on_ticks, note_dur_lst)``: the sorted events with
        adjacent duplicates removed, plus the start tick and duration of every
        emitted piece (in emission order, duplicates included).
    """
    events, onset_ticks, durations = [], [], []

    for track_no, track in enumerate(p_midi.instruments):
        for note in track.notes:
            length = note.end - note.start
            bar_ticks = 4 * p_midi.ticks_per_beat  # four beats

            if length / bar_ticks > 1:
                # over-long note: full-bar pieces followed by the remainder
                pieces = []
                remaining = length
                while remaining != 0:
                    piece = bar_ticks if remaining > bar_ticks else remaining
                    pieces.append(piece)
                    remaining -= piece
            else:
                pieces = [length]

            cursor = note.start
            for piece in pieces:
                events.append([cursor, "ON", note.pitch, track.program,
                               track.is_drum, track_no, piece])
                onset_ticks.append(cursor)
                durations.append(piece)
                cursor += piece

    events.sort(key=lambda e: (e[0], e[1] == "ON", e[5], e[4], e[3], e[2], e[-1]))
    unique_events = [evt for evt, _ in itertools.groupby(events)]
    return unique_events, onset_ticks, durations


def calculate_measure(p_midi: "MidiFile", first_event_tick,
                      last_event_tick):  # calculate measures and append measure symbol to event_seq
    """Build begin/end-of-measure events from the time signatures of ``p_midi``.

    Consecutive time signatures with the same numerator and denominator are
    collapsed. For every remaining signature the measure length is
    ``ticks_per_beat * (4 / denominator) * numerator`` ticks; measures are laid
    out from the signature's tick up to the next signature (or, for the last
    one, up to ``last_event_tick`` plus one measure). A measure that would
    cross that limit is cut short at it.

    Args:
        p_midi: object exposing ``time_signature_changes`` and ``ticks_per_beat``.
        first_event_tick: tick of the earliest note.
        last_event_tick: tick of the latest note start.

    Returns:
        A list of ``[tick, "BOM" | "EOM", None, None, None, None, 0]`` events.

    Raises:
        AssertionError: if there is no time signature, if the first signature
            starts after tick 0 and after the first note, or if a measure
            length is not a whole number of ticks.
    """
    measure_events = []
    time_signature_changes = p_midi.time_signature_changes  # time signatures

    if not time_signature_changes:  # no time_signature_changes
        raise AssertionError("No time_signature_changes")

    if time_signature_changes[0].time != 0 and \
            time_signature_changes[0].time > first_event_tick:
        raise AssertionError("First time signature start with None zero tick")

    # clean duplicate time_signature_changes (each entry is compared with its predecessor)
    temp_sig = [time_signature_changes[0]]
    for previous_time_sig, time_sig in zip(time_signature_changes, time_signature_changes[1:]):
        if (previous_time_sig.numerator, previous_time_sig.denominator) != \
                (time_sig.numerator, time_sig.denominator):
            temp_sig.append(time_sig)
    time_signature_changes = temp_sig

    for idx, time_sig in enumerate(time_signature_changes):
        # calculate measures, eg: how many ticks per measure
        ticks_per_measure = p_midi.ticks_per_beat * (4 / time_sig.denominator) * time_sig.numerator

        # Raised explicitly: an `assert` would be stripped under `python -O`
        # and the measures of this segment would silently be missing.
        if not ticks_per_measure.is_integer():
            raise AssertionError("ticks_per_measure Error")
        measure_ticks = int(ticks_per_measure)

        cur_tick = time_sig.time
        if idx < len(time_signature_changes) - 1:
            next_tick = time_signature_changes[idx + 1].time
        else:
            next_tick = last_event_tick + measure_ticks

        for measure_start_tick in range(cur_tick, next_tick, measure_ticks):
            # the last measure of a segment is truncated at next_tick
            measure_end_tick = min(measure_start_tick + measure_ticks, next_tick)
            measure_events.append([measure_start_tick, "BOM", None, None, None, None, 0])
            measure_events.append([measure_end_tick, "EOM", None, None, None, None, 0])
    return measure_events


def prettify(note_events, ticks_per_beat):
    """Trim, re-base, quantize and split a sorted event list into measures.

    Steps, in order:
      1. Keep only the span from the "BOM" right before the first "ON" event
         to the "EOM" right after the last "ON" event.
      2. Shift all ticks so the span starts at tick 0 (events are modified in place).
      3. Pass the events to ``quantize_by_nth`` with a step of
         ``ticks_per_beat / 8`` (presumably a 32nd-note grid; the helper is
         defined elsewhere, confirm its exact behaviour there).
      4. Sort, drop adjacent duplicates and split before every "BOM".
      5. Store each measure's tick span in the last field of its first event.

    Returns:
        A list of measures, each a list of events starting with its "BOM".

    Raises:
        AssertionError: if the events around the first/last note are not
            measure markers, or if a measure spans 0 or more than 100 ticks
            (message "Measure duration error").
    """
    def _event_order(evt):
        kind = evt[1]
        return (evt[0], kind == "ON", kind == "BOM", kind == "EOM",
                evt[5], evt[4], evt[3], evt[2], evt[-1])

    first_on = next(idx for idx, evt in enumerate(note_events) if evt[1] == "ON")
    last_on = next(idx for idx in range(len(note_events) - 1, -1, -1)
                   if note_events[idx][1] == "ON")

    assert note_events[first_on - 1][1] == "BOM", "measure_start Error"
    assert note_events[last_on + 1][1] == "EOM", "measure_end Error"

    # drop the measures before the first and after the last note
    note_events = note_events[first_on - 1: last_on + 2]

    # check again on the trimmed list
    assert note_events[0][1] == "BOM", "measure_start Error"
    assert note_events[-1][1] == "EOM", "measure_end Error"

    # re-base so that the first kept measure starts at tick 0
    offset = note_events[0][0]
    if offset != 0:
        for evt in note_events:
            evt[0] -= offset

    from fractions import Fraction
    grid_step = Fraction(ticks_per_beat, 8)
    note_events = quantize_by_nth(grid_step, note_events)

    note_events.sort(key=_event_order)
    note_events = [evt for evt, _ in itertools.groupby(note_events)]

    # sorted once more before splitting, as in the original flow
    note_events.sort(key=_event_order)
    split_score = list(split_before(note_events, lambda evt: evt[1] == "BOM"))

    for measure in split_score:
        measure_dur = measure[-1][0] - measure[0][0]
        if measure_dur > 100:
            raise AssertionError("Measure duration error")
        measure[0][-1] = measure_dur  # the "BOM" event carries the measure length

        if measure_dur == 0:
            raise AssertionError("Measure duration error")
    return split_score

def measure_calc_chord(evt_seq):
    """Estimate the chord at the start of one measure and return it as a "CHR" event.

    ``evt_seq`` is one measure: a "BOM" event, then "ON" events, then a final
    event that is ignored. Only notes starting within the first 8 ticks after
    the "BOM" are used (ticks are presumably quantized units at this point;
    confirm against the caller). Drum notes and pitches outside 21..108 are
    skipped. A note on the measure start is given the fixed end tick 8.

    Returns:
        ``[bom_tick, 'CHR', None, None, None, None, label]`` where ``label`` is
        ``'NA'`` when ``Dechorder.get_chord_quality`` reports a negative score,
        otherwise the root name from ``pit2alphabet`` followed by the chord
        quality (with quality ``'7'`` written as ``'D7'``).

    Raises:
        AssertionError: if the first event is not "BOM" or an inner event is
            not "ON".
    """
    assert evt_seq[0][1] == 'BOM', "wrong measure for chord"
    bom_tick = evt_seq[0][0]
    ts = min(evt_seq[0][-1], 8)  # analysis window: the measure length, capped at 8
    mtknotes = []
    # get some notes from it
    for evt in evt_seq[1:-1]:
        assert evt[1] == 'ON', "wrong measure for chord: " + evt[1] + evt_seq[-1][1]
        # exclude drums: either remapped to program 128 or still flagged is_drum
        # (merged drum tracks carry program 0 with is_drum=True)
        if evt[3] == 128 or evt[4]:
            continue
        o, p, d = evt[0] - bom_tick, evt[2], evt[-1]  # onset offset, pitch, duration
        if p < 21 or p > 108:  # exclude unusual pitch
            continue
        if o < 8:
            note = mtkNote(60, p, o, o + d if o > 0 else 8)  # velocity, pitch, start, end
            mtknotes.append(note)
        else:
            break  # events are in tick order, so later ones are outside the window too

    chord, score = Dechorder.get_chord_quality(mtknotes, start=0, end=ts)
    if score < 0:
        return [bom_tick, 'CHR', None, None, None, None, 'NA']
    return [bom_tick, 'CHR', None, None, None, None,
            pit2alphabet[chord.root_pc] + (chord.quality if chord.quality != '7' else 'D7')]


def get_pos_and_cc(split_score):
    """Regroup each measure by track and add "CC" and "POS" marker events.

    Every measure is sorted in place so that it reads "BOM", "CHR", the "ON"
    events ordered by track index, then "EOM". For each track present in the
    measure the output contains:
      * a "CC" event at the measure start whose last field is the track index,
      * a "POS" event for every distinct note onset (last field = offset from
        the measure start),
      * the track's "ON" events,
    ordered by tick, with "CC" before "POS" before "ON" on equal ticks.

    Returns:
        One flat list: for each measure its first two events followed by the
        per-track groups. The measure's trailing event is not copied over.
    """
    flattened = []
    for measure in split_score:
        measure.sort(key=lambda e: (e[1] == "EOM", e[1] == "ON", e[1] == 'CHR', e[1] == "BOM", e[-2]))
        measure_start = measure[0][0]

        # split measure by track (index -2 holds the track index)
        track_ids = {evt[-2] for evt in measure[2:-1]}
        per_track = [[evt for evt in measure if evt[-2] == tid] for tid in track_ids]

        for track_events in per_track:
            track_id = -1
            onsets = set()
            for evt in track_events:
                if evt[1] != "ON":
                    continue
                assert track_id == -1 or track_id == evt[
                    -2], "Error: found inconsistent trackid within same track"
                track_id = evt[-2]
                onsets.add(evt[0] - measure_start)

            track_events.extend(
                [onset + measure_start, "POS", None, None, None, None, onset] for onset in onsets)
            track_events.insert(0, [measure_start, "CC", None, None, None, None, track_id])
            track_events.sort(
                key=lambda e: (e[0], e[1] == "ON", e[1] == "POS", e[1] == "CC", e[5], e[4], e[3], e[2]))

        flattened.append(measure[0])
        flattened.append(measure[1])
        for track_events in per_track:
            flattened.extend(track_events)

    return flattened


def event_seq_to_str(new_event_seq):
    """Turn an event list into a flat list of string tokens, four per event.

    * "ON"  -> pitch, duration, track and instrument tokens (from ``pit2str``,
      ``dur2str``, ``trk2str`` and ``ins2str``).
    * "POS" -> ``pos2str`` of the last field, then the fillers 'RZ', 'TZ', 'YZ'.
    * "BOM" -> ``bom2str`` of the last field, then the fillers.
    * "CC"  -> 'NT', then the fillers.
    * "CHR" -> 'H' + chord label, then the fillers.

    Raises:
        AssertionError: for any other event type.
    """
    fillers = ['RZ', 'TZ', 'YZ']
    tokens = []

    for evt in new_event_seq:
        kind = evt[1]
        if kind == 'ON':
            tokens += [
                pit2str(evt[2]),   # pitch
                dur2str(evt[-1]),  # duration
                trk2str(evt[-2]),  # track
                ins2str(evt[3]),   # instrument
            ]
            continue

        if kind == 'POS':
            head = pos2str(evt[-1])  # type (time position)
        elif kind == 'BOM':
            head = bom2str(evt[-1])
        elif kind == 'CC':
            head = 'NT'
        elif kind == 'CHR':
            head = 'H' + evt[-1]
        else:
            assert False, ("evt type error", evt[1])
        tokens.append(head)
        tokens.extend(fillers)
    return tokens

def midi_to_event_seq_str(midi_file_path, readonly=False):
    """Convert one MIDI file into string tokens plus its per-measure chords.

    The file is loaded with ``MidiFile``, invalid notes are removed, and the
    tracks go through ``merge_drums``, ``merge_sparse_track`` (skipped when
    ``readonly`` is true) and ``limit_max_track``. Note and measure events are
    then combined, cleaned by ``prettify``, given a chord event per measure,
    regrouped by ``get_pos_and_cc`` and rendered by ``event_seq_to_str``.

    Returns:
        ``(char_events, chords)``: the token list and the chord labels of all
        measures whose label is not ``'NA'``, in measure order.
    """
    midi = MidiFile(midi_file_path)

    for track in midi.instruments:
        track.remove_invalid_notes(verbose=False)

    # merge drums into one instrument (program=0)
    merge_drums(midi)
    if not readonly:
        # not sorted yet
        merge_sparse_track(midi)
    limit_max_track(midi)

    events, onset_ticks, _ = get_init_note_events(midi)
    events.extend(calculate_measure(midi, min(onset_ticks), max(onset_ticks)))
    events.sort(key=lambda e: (e[0], e[1] == "ON", e[1] == "BOM", e[1] == "EOM",
                               e[5], e[4], e[3], e[2]))

    measures = prettify(events, midi.ticks_per_beat)

    chords = []
    for measure in measures:  # calculate chord for every measure
        chord_evt = measure_calc_chord(measure)
        label = chord_evt[-1]
        if label != 'NA':
            chords.append(label)
        measure.insert(1, chord_evt)  # the chord event goes right after the "BOM"

    char_events = event_seq_to_str(get_pos_and_cc(measures))
    return char_events, chords

def mp_worker(file_path):
    """Process one MIDI file and never let an ordinary exception escape.

    Returns:
        ``(event_seq, chords, file_name)`` on success, where ``file_name`` is
        the last '/'-separated component of ``file_path`` up to its first '.'.
        On any ``Exception`` the path is printed and the string ``"error"`` is
        returned. Assertion failures print a short message (nothing extra for
        "No time_signature_changes"); everything else prints a traceback
        header via ``traceback.print_exc(limit=0)``.
    """
    try:
        event_seq, chords = midi_to_event_seq_str(file_path)
        last_component = file_path.rsplit('/', 1)[-1]
        file_name = last_component.split('.', 1)[0]
        return event_seq, chords, file_name
    except Exception as err:
        print(file_path)
        # I/O and value errors are reported with a traceback even if they also
        # happen to be assertion errors
        reported_with_traceback = isinstance(err, (OSError, EOFError, ValueError, KeyError))
        if isinstance(err, AssertionError) and not reported_with_traceback:
            reason = str(err)
            if reason == "Measure duration error":
                print("Measure duration error", file_path)
            elif reason != "No time_signature_changes":
                print("Other Assertion Error", reason, file_path)
        else:
            traceback.print_exc(limit=0)
        return "error"
# Ad-hoc run on a single sample MIDI file. These statements sit at module level,
# so they execute whenever this file is run or imported.
# NOTE(review): `analyze_chords_in_midi` is neither defined nor imported in the
# visible part of this file - confirm where it is supposed to come from;
# otherwise this call raises NameError.
file_path = './generated/midis/pop909/great/3/D:min-G:maj-A:min-F:maj_109.mid'
analyze_chords_in_midi(file_path)