import os
import PIL.Image as Image
from torch.utils.data import Dataset
import numpy
import librosa
import librosa.core
import librosa.feature
import glob
import sys
# Download the corresponding audio files
import logging

# Configure the root logger: records at DEBUG level and above are written to
# the file "baseline.log".
logging.basicConfig(level=logging.DEBUG, filename="baseline.log")
# Logger used by the functions in this module (its name is a single space).
logger = logging.getLogger(' ')
# Additionally attach a stream handler to this logger (StreamHandler() with no
# argument writes to sys.stderr), using a "time - level - message" layout.
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
def file_load(wav_name, mono=False):
    """
    Load an audio file with librosa, keeping the file's own sampling rate.

    wav_name : str
        target .wav file
    mono : boolean
        passed through to librosa.load; when True, a multi channel file is
        merged into mono data

    return : tuple( numpy.array( float ), number ) or None
        the (signal, sampling_rate) pair returned by librosa.load, or None
        when the file could not be loaded (the failure is logged, not raised)
    """
    try:
        # sr=None asks librosa not to resample.
        return librosa.load(wav_name, sr=None, mono=mono)
    except Exception:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt and SystemExit still propagate and a long-running
        # loading loop can be interrupted.
        logger.error("file_broken or not exists!! : {}".format(wav_name))
        return None

                                                                                      

# Convert a wave file into a log-mel feature array.
def _img_loader(file_name,
                n_mels=64,
                frames=1,
                n_fft=4096,
                hop_length=2048,
                power=2.0,
                vector_array_size=64):
    """
    Convert an audio file into a fixed-size log-mel feature array.

    file_name : str
        target .wav file
    n_mels : int
        number of mel bands passed to librosa.feature.melspectrogram
    frames : int
        number of consecutive spectrogram frames concatenated into one
        feature vector
    n_fft : int
        FFT window length passed to librosa.feature.melspectrogram
    hop_length : int
        hop length passed to librosa.feature.melspectrogram
    power : float
        spectrogram exponent passed to librosa.feature.melspectrogram; also
        used to scale the log conversion (20 / power * log10)
    vector_array_size : int
        number of feature vectors (rows) taken from the start of the clip;
        the default 64 is the value that was previously hard-coded

    return : numpy.array( float32 )
        * shape = (1, vector_array_size, n_mels * frames)
        * an empty array of shape (0, n_mels * frames) when
          vector_array_size < 1

    raises : ValueError
        when the spectrogram is not 2-D (e.g. multi-channel audio) or has
        fewer than vector_array_size + frames - 1 frames
    raises : TypeError
        when file_load could not read the file and returned None
    """
    # 01 calculate the number of dimensions
    dims = n_mels * frames

    # 02 generate melspectrogram using librosa
    y, sr = file_load(file_name)
    mel_spectrogram = librosa.feature.melspectrogram(y=y,
                                                     sr=sr,
                                                     n_fft=n_fft,
                                                     hop_length=hop_length,
                                                     n_mels=n_mels,
                                                     power=power)

    # 03 convert melspectrogram to log mel energy
    # (epsilon keeps log10 away from zero for silent bins)
    log_mel_spectrogram = 20.0 / power * numpy.log10(mel_spectrogram + sys.float_info.epsilon)

    # 04 nothing to extract when no rows are requested
    if vector_array_size < 1:
        return numpy.empty((0, dims), dtype='float32')

    # 05 validate the spectrogram before slicing it
    # The slicing below indexes the array as (n_mels, time); any other rank
    # would be sliced along the wrong axes.
    if log_mel_spectrogram.ndim != 2:
        raise ValueError(
            "{}: expected a 2-D mel spectrogram (single-channel audio), "
            "got {} dimensions".format(file_name, log_mel_spectrogram.ndim))

    # Every shifted window [t, t + vector_array_size) must fit in the clip.
    # Without this check numpy raises an opaque broadcast error, or silently
    # repeats a lone frame across all rows.
    required_frames = vector_array_size + frames - 1
    available_frames = log_mel_spectrogram.shape[1]
    if available_frames < required_frames:
        raise ValueError(
            "{}: clip too short, {} spectrogram frames available but {} "
            "required".format(file_name, available_frames, required_frames))

    # 06 generate feature vectors by concatenating multiframes
    # Column block t holds the spectrogram shifted by t frames, so each row
    # is `frames` consecutive frames laid side by side.
    vector_array = numpy.zeros((vector_array_size, dims), dtype='float32')
    for t in range(frames):
        vector_array[:, n_mels * t: n_mels * (t + 1)] = log_mel_spectrogram[:, t: t + vector_array_size].T

    # 07 add a leading axis of length 1 in front of (rows, dims)
    return vector_array.reshape(1, vector_array_size, dims)




# The root directory contains one sub-directory per class; return the class names and their indices.
def _find_classes(root):
    """
    List the class sub-directories found directly under root.

    root : str
        directory whose immediate sub-directories are the classes

    return : tuple( list( str ), dict( str, int ) )
        the sub-directory names in sorted order, and a mapping from each
        name to its position in that sorted list
    """
    # Only directories count as classes; plain files under root are ignored.
    class_names = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    classes_indices = {name: position for position, name in enumerate(class_names)}
    return class_names, classes_indices  # 'class_name': index

# Return samples, a list whose entries are (file path, class index) pairs.
# image_dir is the training directory (train_dir).
def _make_dataset(image_dir):
    """
    Collect every .wav file found in the class sub-directories of image_dir.

    image_dir : str
        root directory containing one sub-directory per class

    return : list( tuple( str, int ) )
        (absolute file path, class index) pairs, grouped by class in sorted
        class-name order and sorted by path within each class
    """
    samples = []  # (file_path, class_idx)

    class_names, class_indices = _find_classes(image_dir)

    for class_name in sorted(class_names):
        target_dir = os.path.join(image_dir, class_name)
        if not os.path.isdir(target_dir):
            continue

        # Escape the directory part so glob metacharacters ("[", "]", "*",
        # "?") in a directory name are matched literally; only "*.wav" is
        # meant to act as a pattern.
        pattern = os.path.join(glob.escape(os.path.abspath(target_dir)), "*.wav")

        class_idx = class_indices[class_name]
        samples.extend((path, class_idx) for path in sorted(glob.glob(pattern)))

    return samples


class ImageDataset(Dataset):
    """
    Dataset over the .wav files laid out as image_dir/<class_name>/<file>.wav.

    Indexing yields (features, class index, file name), where features is
    the array produced by _img_loader for that file.
    """

    def __init__(self, image_dir, mode='RGB', transform=None):
        """
        image_dir : str
            root directory with one sub-directory per class
        mode : str
            stored on the instance; not read by this class
        transform : callable or None
            stored on the instance; not applied by __getitem__
        """
        self.image_dir = image_dir
        self.transform = transform
        # The directory tree is scanned once here; audio is loaded per item.
        self.samples = _make_dataset(self.image_dir)
        self.targets = [class_idx for _, class_idx in self.samples]
        self.mode = mode

    def __getitem__(self, index):
        """Return (features, class index, base file name) for one sample."""
        path, label = self.samples[index]
        features = _img_loader(path)
        # NOTE: self.transform is intentionally not applied here.
        return features, label, os.path.basename(path)

    def __len__(self):
        """Return the number of samples found under image_dir."""
        return len(self.samples)
