"""tokenizer for chord-based transformer

    Author: Joey.Zhu
    Email: joey8273@qq.com
    Date: 2024/12/11
    
"""
import miditoolkit
import math
from miditoolkit.midi import parser as mid_parser  
from miditoolkit.midi import containers as ct
from miditok import TokenizerConfig, REMI, Event, TokSequence
from preprocess.extract_chord_progression import generate_chords_progression # if train need preprocess pkg
from collections import defaultdict
from preprocess.chord_map import chord_map # if train need preprocess pkg
import random
from typing import Union
from tools.oss import upload

class RemiPlus(object):
    def __init__(self):
        """Build the REMI tokenizer and the lookup tables derived from its vocab.

        Takes no arguments; every setting below is hard-coded.
        """
        # Bounds checked by getChordProgression (inclusive on both ends).
        self.min_progression_len = 4
        self.max_progression_len = 8

        config = TokenizerConfig(
            use_chords=True,
            chord_tokens_with_root_note=True,
            use_programs=True,
            # use_time_signatures=True,  (left disabled by the original author)
        )
        self.tokenizer = REMI(tokenizer_config=config)

        # Forward (event name -> id) and reverse (id -> event name) lookups.
        vocab = self.tokenizer.vocab
        self.event2id = vocab
        self.id2event = {idx: name for name, idx in vocab.items()}
        self.vocab_size = len(vocab)

        self.chord_map = chord_map()
        # Must come after event2id is set: the map is derived from the vocab.
        self.velocity_map = self.get_velocity_map()
        # NOTE: no dedicated "ChordPgs_Start" / "ChordPgs_End" tokens are added
        # to the vocab; the original author noted there is no need to annotate
        # the chord notes.

    def get_velocity_map(self):
        velocities = []
        for k, _ in self.event2id.items():
            if k.startswith('Velocity_'):
                velocities.append(int(k.split("_")[1]))
        split_len = len(velocities) // 3
        velocity_map = {}
        velocity_map['low'] = velocities[:split_len]
        velocity_map['mid'] = velocities[split_len:2 * split_len]
        velocity_map['high'] = velocities[split_len * 2:]
        return velocity_map

    def getPitch(self, input_event):
        """Return corresponding note pitch
        if input_event is not a note, it returns -1
        we just offer the Pitch, PitchDrum will not return 
        Args:
            input_event (str or int): REMI Event Name or vocab ID
        """
        if isinstance(input_event,int):
            input_event = self.id2event[input_event]
        elif isinstance(input_event,str):
            pass
        else:
            try:
                input_event = int(input_event)
                input_event = self.id2event[input_event]
            except:
                raise TypeError("input_event should be int or str, input_event={}, type={}".format(input_event,type(input_event)))
        
        if not input_event.startswith("Pitch") or input_event.startswith("PitchDrum"):
            return -1

        assert int(input_event.split("_")[1]) >=21 and int(input_event.split("_")[1]) <=108
        return int(input_event.split("_")[1])
    
    def getChordProgression(self, chords):
        """ read chord progression from the sqlite db(which is achived in another project)

        """
        chords, _ = generate_chords_progression(chords)
        if len(chords) < self.min_progression_len or len(chords) > self.max_progression_len:
            return None
        return chords

    def getChords(self, events: Union[list[Event], list[str]]):
        chords = []
        if isinstance(events, list) and all(isinstance(event, Event) for event in events):
            for e in events:
                if e.type_ == "Chord":
                    chords.append(e.value)
        elif isinstance(events, list) and all(isinstance(event, str) for event in events):
            for s in events:
                if s.startswith("Chord"):
                    chords.append(s)
        else:
            print("not supported type of input")
        return chords

    def midi2RemiPlus(self, midi_path):
        """convert midi file to token representation

        Args:
            midi_path (str): the path of input midi file
        Returns:
            list: sequence of tokens
        """  
        try:
            tokens = self.tokenizer.encode(midi_path)
        except Exception as e:
            print(e.__str__())
            return None
        return tokens
    
    def generateChordProgressionToken(self, chordProgression, programs = '0-0-0', velocity = 'mid'):
        """
        this function get the encoder input token according to the input
        """
        if chordProgression is None:
            return None
        src = []
        chordProgression = chordProgression.split("-")
        programs = programs.split("-")
        for chord in chordProgression:
            src.extend(['Bar_None', "Chord_" + chord])
            for program in programs:
                src.append('Program_' + str(program))
                pitch = self.chord_map.get_random_pitch(chord)
                src.append('Pitch_' + pitch)
                src.append('Velocity_' + str(random.choice(self.velocity_map[velocity]) ))
        return  list(map(lambda x: self.event2id[x], src))

    def preprocessRemiPlus(self, remi_sequence: TokSequence, max_seq_len=1024, verbose=True):
        """Preprocess token sequence

        slicing the sequence for training our models

        Args:
            remi_sequence (List): the music token seqeunce(using remi+)
            max_seq_len (Int): maximum sequence length for each data

        Return:
            {
                "src" : <corressponding chord progression condition>,
                "tgt_segments" : <list of target sequences>,
                "tgt_segments_chord_binary_msk" : <list of target sequences theme msk>
            }
        """
        if verbose:
            print(f"the sequence length is:{len(remi_sequence.ids)}")
        # 1. genertae chord progression
        chord_progression = self.getChordProgression(self.getChords(remi_sequence.events))
        if chord_progression is None:
            return None
        chord_binary_msk = []
        in_chord, mask_index = False, 0
        cur_chord = ''
        cur_attr_map = {}
        chord_map = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) # chord : {program_pitch_velocity: cnt}
        for _, e in enumerate(remi_sequence.events):
            if e.type_ == 'Chord' and not in_chord and e.value in chord_progression:
                in_chord = True
                mask_index = 1
                chord_binary_msk.append(mask_index)
                cur_chord = e.value
            elif in_chord:
                if e.type_ == 'Position':
                    mask_index = 0
                    in_chord = False
                    cur_chord = ''
                else:
                    mask_index += 1
                    if e.type_ == 'Program' or e.type_ == 'Pitch' or e.type_ == 'Velocity':
                        cur_attr_map[e.type_] = str(e.value)
                    elif len(cur_attr_map) == 3:
                        chord_map[cur_chord][cur_attr_map['Pitch']][cur_attr_map['Program'] + '_' + cur_attr_map['Velocity']] += 1
                        cur_attr_map = {}
                    else:
                        cur_attr_map = {}
                chord_binary_msk.append(mask_index)
            else:
                chord_binary_msk.append(mask_index)
        
        src = []
        # generate src according to chord_map
        # we use the most frequency value
        for chord in chord_progression:
            src.extend(['Bar_None', "Chord_" + chord])
            picth_to_program_velocity = chord_map[chord]
            for pitch, program_velocity in picth_to_program_velocity.items():
                program_velocity = max(program_velocity, key=program_velocity.get)
                program_velocity = program_velocity.split('_')
                program, velocity = program_velocity[0], program_velocity[1]
                src.extend(['Program_' + program, 'Pitch_' + pitch, 'Velocity_' + velocity])
                # src.append('Duration_4.0.0') this may not be nessacery
        
        ids = remi_sequence.ids
        src = list(map(lambda x: self.event2id[x], src))
        tgt_segments, tgt_segments_theme_msk = [], []
        # TODO: 是否按照边界严格划分
        for x in range(0,len(ids),max_seq_len):
            tgt_segments.append(ids[x:x+max_seq_len+1])
            tgt_segments_theme_msk.append(chord_binary_msk[x:x+max_seq_len+1])

        return {
            "src" : src,
            "tgt_segments" : tgt_segments,
            "tgt_segments_chord_binary_msk" : tgt_segments_theme_msk
        }

    def REMIID2midi(self,event_ids, out_path, upload_to_oss=False, oss_config=None):
        """convert tokens to midi file
        
        """
        # this is not right, dkw
        # return self.tokenizer.decode(event_ids)

        # tokens = list(map(lambda x: self.id2event[x], event_ids))
        generated_midi = self.tokenizer(event_ids)
        if upload_to_oss:
            if not oss_config:
                raise ValueError("上传到 OSS 时需要提供 oss_config")
            # 创建 OSS 认证对象
            mid_url, wav_url = upload(oss_config, generated_midi, out_path)
            # 创建mp3文件

            return mid_url, wav_url
        else:
            # 保存到本地文件
            generated_midi.dump_midi(out_path)
            return out_path, None

        

    def __str__(self):
        # TODO: 修饰你的tokenizer
        pass

if __name__ == '__main__':
    # Smoke run: build the tokenizer and print its string form.
    remi_plus = RemiPlus()

    print(remi_plus)

    
