"""Music Dataset for chord-based transformer

    Author: Joey.Zhu
    Email: joey8273@qq.com
    Date: 2024/12/09
    
"""
import torch
from torch.utils.data.dataset import Dataset
import sys, pickle
import numpy as np
from glob import glob
from copy import deepcopy
import random
from preprocess import tokenizer
import os

class MusicDataset(Dataset):
    """Torch dataset of tokenized (src, tgt) music sequences.

    Every element of ``data`` is a mapping with the keys ``"src"``, ``"tgt"``
    and ``"tgt_chord_msk"``. ``__getitem__`` applies a random pitch
    transposition, pads (or truncates) the sequences to fixed lengths and
    returns them as ``torch.long`` tensors. Token id 0 is used for padding.
    """

    def __init__(self, data, max_seq_len, aug=None):
        """Build the vocabulary and compute the fixed padded lengths.

        Args:
            data (list): samples; each a dict with "src", "tgt" and
                "tgt_chord_msk" sequences.
            max_seq_len (int): upper bound for the target (prediction) length.
            aug: stored as ``pitchaug_range``. NOTE(review): it is not read
                anywhere in this class; augmentation is always applied.
        """
        # create vocab
        self.remi_plus = tokenizer.RemiPlus()
        self.event2id = self.remi_plus.event2id
        self.id2event = self.remi_plus.id2event
        self.vocab_size = self.remi_plus.vocab_size
        self.data = data

        self.max_seq_len = max_seq_len  # prediction length

        self.ep_start_pitchaug = 0  # pitch augmentation (unused in this class)

        self.pitchaug_range = aug
        self.final_training_data = []

        # Fixed lengths every sample is padded to. The target length is
        # capped at max_seq_len, plus one extra position.
        self.constants = {
            "max_src_len": max(len(x["src"]) for x in self.data),
            "max_tgt_len": min(max_seq_len, max(len(x["tgt"]) for x in self.data)) + 1,
        }
        print(self.constants)

    def _shift_pitch_tokens(self, seq, pitches, offset):
        """Return a new list where every pitched token of ``seq`` is moved by ``offset``."""
        shifted = [
            tok + offset if pitch > 0 else tok
            for tok, pitch in zip(seq, pitches)
        ]
        # Sanity check: a shifted pitch token must still decode to a pitch.
        for tok, pitch in zip(shifted, pitches):
            if pitch > 0:
                assert self.remi_plus.getPitch(tok) > 0
        return shifted

    def data_pitch_augment(self, src_seq, tgt_seq):
        """Transpose all pitch tokens of a sample by one random offset.

        The offset is drawn so that the lowest pitch stays >= 21 and the
        highest pitch stays below 108 (presumably the piano range). Tokens
        for which ``getPitch`` returns a non-positive value are left
        untouched (per the original author's note this covers drum pitches;
        not verifiable from this file).

        Args:
            src_seq (list): src token sequence
            tgt_seq (list): tgt token sequence

        Returns:
            tuple: (augmented src sequence, augmented tgt sequence). When no
            transposition is possible the *input objects themselves* are
            returned, so callers must not mutate the result in place.
        """
        get_pitch = self.remi_plus.getPitch
        src_pitches = [get_pitch(tok) for tok in src_seq]
        tgt_pitches = [get_pitch(tok) for tok in tgt_seq]

        pitched = [p for p in src_pitches + tgt_pitches if p > 0]
        if not pitched:
            return src_seq, tgt_seq

        low, high = 21 - min(pitched), 108 - max(pitched)
        # No room to transpose: the pitches already span (or exceed) the
        # whole allowed range, and randint requires low < high.
        if low >= high:
            return src_seq, tgt_seq

        # NOTE(review): randint's upper bound is exclusive, so the highest
        # pitch is never moved all the way up to 108 -- confirm this is intended.
        pitch_offset = int(np.random.randint(low, high, size=1)[0])

        aug_src_phrase = self._shift_pitch_tokens(src_seq, src_pitches, pitch_offset)
        aug_tgt_phrase = self._shift_pitch_tokens(tgt_seq, tgt_pitches, pitch_offset)
        return aug_src_phrase, aug_tgt_phrase

    def __getitem__(self, index):
        """Return the padded, tensorized sample at ``index``.

        Args:
            index (int): the index of the sample

        Returns:
            dict: {
                "src"           : <src sequence>,
                "src_msk"       : <src padding mask, 1 marks padding>,
                "tgt"           : <target sequence>,
                "tgt_msk"       : <target padding mask, 1 marks padding>,
                "tgt_chord_msk" : <target chord mask, padded with 0>,
            }
        """
        entry = self.data[index]

        # pitch augment
        src, tgt = self.data_pitch_augment(src_seq=entry["src"], tgt_seq=entry["tgt"])
        tgt_chord_msk = [int(flag) for flag in entry["tgt_chord_msk"]]

        max_src_len = self.constants["max_src_len"]
        max_tgt_len = self.constants["max_tgt_len"]

        # Padding is done by building NEW lists. data_pitch_augment may hand
        # back the stored sequences themselves, and extending those in place
        # would permanently pad self.data and corrupt later reads of the
        # same index.
        src_pad = max_src_len - len(src)
        src_msk = [0] * len(src) + [1] * src_pad
        src = list(src) + [0] * src_pad

        if len(tgt) > max_tgt_len:
            # Reached when a target is longer than max_seq_len + 1: truncate.
            print("Should not be here")
            tgt = list(tgt[:max_tgt_len])
            tgt_chord_msk = tgt_chord_msk[:max_tgt_len]
            tgt_msk = [0] * len(tgt)
        else:
            tgt_pad = max_tgt_len - len(tgt)
            tgt_msk = [0] * len(tgt) + [1] * tgt_pad
            tgt_chord_msk = tgt_chord_msk + [0] * tgt_pad
            tgt = list(tgt) + [0] * tgt_pad

        # Internal invariants: everything has its fixed length.
        assert len(src) == max_src_len
        assert len(tgt) == max_tgt_len
        assert len(src_msk) == max_src_len
        assert len(tgt_msk) == max_tgt_len
        assert len(tgt_chord_msk) == max_tgt_len

        return {
            "src"           : torch.tensor(src, dtype=torch.long),
            "src_msk"       : torch.tensor(src_msk, dtype=torch.long),
            "tgt"           : torch.tensor(tgt, dtype=torch.long),
            "tgt_msk"       : torch.tensor(tgt_msk, dtype=torch.long),
            "tgt_chord_msk" : torch.tensor(tgt_chord_msk, dtype=torch.long),
        }

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    # def set_epoch(self,ep):
        

def getMusicDataset(args, i=0, cut=False, training=True):
    """Load a pickled sample list and wrap it in a MusicDataset.

    Args:
        args (obj): parsed command-line arguments; reads
            ``training_data_path`` or ``validating_data_path`` and ``max_len``.
        i (int): zero-based shard index, only used when ``training`` and
            ``cut`` are both true.
        cut (bool): when true, ``args.training_data_path`` is treated as a
            directory and the shard ``train_lmd_full_1024_part_<i + 1>.pkl``
            inside it is loaded.
        training (bool): load the training data if true, otherwise the
            validation data.

    Returns:
        MusicDataset: the dataset instance
    """
    # Pick the pickle file first, then load it in one place.
    if not training:
        pkl_path = args.validating_data_path
    elif cut:
        pkl_path = os.path.join(args.training_data_path, f'train_lmd_full_1024_part_{i + 1}.pkl')
    else:
        pkl_path = args.training_data_path

    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at trusted files.
    with open(pkl_path, "rb") as handle:
        samples = pickle.load(handle)

    return MusicDataset(data=samples, max_seq_len=args.max_len)

if __name__ == '__main__':
    # Smoke test: load one pickled training set and construct the dataset,
    # which also prints the computed padding lengths.
    sample_pkl = './generated/data_pkl/train_pop909_1024.pkl'
    with open(sample_pkl, "rb") as f:
        data = pickle.load(f)
    dataset = MusicDataset(data=data, max_seq_len=4096)
    