from [anonymous] import MultiTrack, Bar
import os
import sys
dirof = os.path.dirname
sys.path.append(dirof(dirof(dirof(os.path.abspath(__file__)))))

import random
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from utils_midi import remi_utils
from typing import List
from tqdm import tqdm
from utils_instrument.inst_map import InstMapUtil
from utils_chord.chord_detect_from_remi import chord_to_id
from utils_common.utils import jpath, ls, read_yaml, save_json, read_json
import matplotlib.pyplot as plt


def get_dataloader(config, split):
    '''
    Build a DataLoader over PianoReductionDataset for the given split.

    Args:
        config: Mapping that provides 'bs' (batch size for non-test splits),
            'bs_test' (batch size for the 'test' split) and 'num_workers'.
            It is also forwarded unchanged to the dataset constructor.
        split: Name of the split. Only 'train' is shuffled; 'test' uses
            the 'bs_test' batch size.

    Returns:
        A torch DataLoader that collates with the dataset's own collate_fn.
    '''
    batch_size = config['bs_test'] if split == 'test' else config['bs']

    # Reference the class directly instead of eval()-ing its name: same
    # result, without evaluating a string as code.
    dataset = PianoReductionDataset(config=config, split=split)

    return DataLoader(
        dataset=dataset,
        shuffle=(split == 'train'),
        batch_size=int(batch_size),
        num_workers=config['num_workers'],
        collate_fn=dataset.collate_fn,
    )


class PianoReductionDataset(Dataset):
    '''
    The dataset class for Moyu's ICASSP 2023 paper.

    Piano roll-based parallel dataset of one-bar samples: each item is the
    pair (piano roll of the full mixture, piano roll of the piano tracks).

    Bars are read from 'metadata/segment_dataset_1bar_q16_norm.json' under
    config['data_root'] and are kept only when they
        - are in 4/4 time,
        - contain at least one instrument,
        - contain at least one piano program (ids 0-7),
        - have a piano pitch range of at least 40% of the bar's pitch range.
    '''

    def __init__(self, config, split):
        '''
        Load the metadata JSON, select the split and filter the bars.

        Args:
            config: Mapping providing 'data_root'; 'octave_shift' is read
                later by __getitem__ (defaults to False when absent).
            split: Split key in the metadata JSON. 'valid' is looked up
                under the 'validation' key; other names are used as-is.
        '''
        self.config = config

        dataset_root = config['data_root']
        meta_fp = jpath(dataset_root, 'metadata', 'segment_dataset_1bar_q16_norm.json')

        # Read the dataset
        print('Loading the dataset...')
        dataset = read_json(meta_fp)
        dataset = dataset['validation'] if split == 'valid' else dataset[split]

        # Filtering: remove bars that are not 4/4, are blank, have no piano,
        # or whose piano part spans too little of the bar's pitch range.
        print('Filtering the dataset...')
        piano_id_set = set(range(8))
        self.piano_ids = sorted(piano_id_set)

        data_filtered = {}
        for bar_name, bar_data in tqdm(dataset.items()):
            meta = bar_data['meta']

            # Filter out non-4/4 time signature
            if meta['time_signature'] != '(4, 4)':
                continue

            # Filter out blank bars (no instruments listed)
            insts_in_sample = meta['insts']
            if not insts_in_sample:
                continue

            # Filter out bars with no piano. split() without an argument
            # tolerates repeated/trailing spaces, which would otherwise make
            # int('') raise.
            inst_ids = [int(tok) for tok in insts_in_sample.split()]
            if piano_id_set.isdisjoint(inst_ids):
                continue

            # Filter out range of piano < bar range * 0.4
            piano_range = meta['piano_pitch_range'] + 1
            entire_range = meta['pitch_range'] + 1
            if piano_range / entire_range < 0.4:
                continue

            data_filtered[bar_name] = bar_data

        # Guard against an empty split so the summary line cannot divide by zero.
        n_total = len(dataset)
        valid_rate = len(data_filtered) / n_total * 100 if n_total else 0.0
        print('Filtered data:', len(data_filtered), 'out of', n_total, f'valid rate: {valid_rate:.2f}%')

        # Re-index with integer starting from 0
        self.data = dict(enumerate(data_filtered.values()))

    def __len__(self):
        '''Return the number of bars that survived filtering.'''
        return len(self.data)

    def __getitem__(self, index):
        '''
        Build the (mixture, piano) piano-roll pair for one bar.

        Args:
            index: Integer key into the re-indexed sample dict.

        Returns:
            Tuple (mix_proll, piano_proll) as produced by Bar.to_piano_roll.
            NOTE(review): the original comment gives the shape as [16, 128];
            confirm against Bar.to_piano_roll.
        '''
        # The content has already been quantized and normalized (1-bar dataset)
        mt = MultiTrack.from_remiz_str(self.data[index]['content'])
        assert len(mt) == 1, 'Only support 1-bar'
        bar = mt[0]

        # Create piano roll for the mixture
        mix_proll = bar.to_piano_roll()

        # Octave shift for mixture piano roll (optional, off by default)
        if self.config.get('octave_shift', False):
            mix_proll = shift_octave(mix_proll)

        # Create piano roll for the piano tracks only
        piano_proll = bar.to_piano_roll(of_insts=self.piano_ids)

        return mix_proll, piano_proll

    def collate_fn(self, batch):
        '''
        Stack a list of (mix_proll, piano_proll) numpy pairs into a batch.

        Args:
            batch: Sequence of 2-tuples of equally-shaped numpy arrays.

        Returns:
            Dict with float tensors 'mix_prolls' and 'piano_prolls', each
            with a new leading batch dimension.
        '''
        mix_prolls = torch.stack([torch.from_numpy(pair[0]) for pair in batch]).float()
        piano_prolls = torch.stack([torch.from_numpy(pair[1]) for pair in batch]).float()

        return {
            'mix_prolls': mix_prolls,
            'piano_prolls': piano_prolls,
        }
    

def shift_octave(proll: np.ndarray) -> np.ndarray:
    '''
    Double every active cell of a 2-D piano roll one octave down and up.

    For each non-zero cell (t, p) of the input, the cells (t, p - 12) and
    (t, p + 12) that fall inside the roll are raised to at least the value
    of (t, p). Targets outside the pitch axis are dropped.

    All source values are taken from a snapshot of the input, so a value
    written by one shift is never re-used as the source of another (a note
    only ever spreads one octave in each direction).

    Args:
        proll: 2-D array indexed [time, pitch]. It is modified in place.

    Returns:
        The same array object that was passed in, after modification.
    '''
    octave = 12
    snapshot = proll.copy()

    # Shift down: source pitches [12, P) land on target pitches [0, P - 12).
    src = snapshot[:, octave:]
    dst = proll[:, :-octave]
    # where=src != 0 restricts the update to active source cells, leaving
    # every other target cell untouched.
    np.maximum(dst, src, out=dst, where=(src != 0))

    # Shift up: source pitches [0, P - 12) land on target pitches [12, P).
    src = snapshot[:, :-octave]
    dst = proll[:, octave:]
    np.maximum(dst, src, out=dst, where=(src != 0))

    return proll
        

def draw_proll(proll, save_fp):
    '''
    Render a matrix with matshow, save it to disk, then display it.

    Args:
        proll: 2-D array-like accepted by Axes.matshow.
        save_fp: Destination path passed to Figure.savefig.
    '''
    fig, ax = plt.subplots(figsize=(8, 6))  # 8 inches wide, 6 inches high
    ax.matshow(proll)  # Draw on this Axes rather than via plt.matshow
    ax.set_title("Matrix Visualization", pad=20)

    # Save before showing: with interactive backends the figure may already
    # be closed when plt.show() returns, and a later savefig would then
    # write an empty image.
    fig.savefig(save_fp)
    plt.show()

    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)