import os
import librosa
import pandas as pd
import torch

from torch.utils.data import Dataset
from numpy.lib.stride_tricks import as_strided
from tqdm import tqdm
from glob import glob
import numpy as np
from .config import *


class MyDataset(Dataset):
    """In-memory dataset of paired ``*_m.wav`` / ``*_v.wav`` audio with pitch labels.

    For every ``<name>_m.wav`` found in ``<path>/<group>/`` the companion files
    ``<name>_v.wav`` and ``<name>.csv`` are loaded. The CSV must provide the
    columns ``freqs``, ``pe_confi`` and ``mss_confi`` with one row per analysis
    frame. All examples are loaded eagerly in ``__init__`` and kept in
    ``self.data`` as dictionaries (see :meth:`load` for the keys).

    NOTE(review): ``_m`` / ``_v`` presumably denote the mixture and the isolated
    vocal track, and ``freqs`` is presumably in Hz -- confirm with the data
    preparation scripts.
    """

    # Length (in samples) of each analysis frame cut from the ``_v`` waveform.
    # The waveform is zero-padded by half a window on both sides before framing.
    WINDOW_LENGTH = 1024

    def __init__(self, path, hop_length, pitch_th=0.5, sr=16000, sequence_length=None, groups=None, num_class=360):
        """Load every file of the requested groups into memory.

        Args:
            path: Root directory that contains one sub-directory per group.
            hop_length: Hop between consecutive analysis frames, in samples.
            pitch_th: Confidence threshold; a frame's pitch bin is set to 1 only
                when its ``pe_confi`` is strictly greater than this value.
            sr: Sample rate the audio is resampled to when loaded.
            sequence_length: Segment length in seconds, or ``None`` to keep each
                file as a single example.
            groups: Iterable of group (sub-directory) names. ``None`` selects
                everything returned by :meth:`availabe_groups`.
            num_class: Number of pitch bins in the label matrix. ``360`` selects
                the 20-cent mapping, any other value the ``Q``/``F_MIN`` mapping.
        """
        self.path = path
        self.sample_rate = sr
        self.HOP_LENGTH = hop_length
        self.th = pitch_th
        self.seq_len = int(sequence_length * sr) if sequence_length is not None else None
        self.num_class = num_class
        self.data = []

        # ``groups=None`` is the documented default; previously it crashed on
        # ``len(None)``. Materialise to a list so ``len`` works for any iterable.
        groups = list(groups) if groups is not None else self.availabe_groups()

        print(f"Loading {len(groups)} group{'s' if len(groups) > 1 else ''} "
              f"of {self.__class__.__name__} at {path}")
        for group in groups:
            for input_files in tqdm(self.files(group), desc='Loading group %s' % group):
                self.data.extend(self.load(*input_files))

    def __getitem__(self, index):
        """Return the pre-loaded example stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of pre-loaded examples."""
        return len(self.data)

    @staticmethod
    def availabe_groups():
        """Return the group names used when ``groups`` is not given.

        The misspelled name is kept because it is part of the public interface.
        """
        return ['test']

    def files(self, group):
        """Return sorted ``(m_wav, v_wav, csv)`` path triples for ``group``.

        Raises:
            AssertionError: If a ``*_m.wav`` file lacks its ``*_v.wav`` or
                ``.csv`` companion.
        """
        suffix = '_m.wav'
        audio_m_files = glob(os.path.join(self.path, group, '*' + suffix))
        # Swap only the trailing suffix: ``str.replace`` would also rewrite a
        # directory component that happens to contain ``_m.wav``.
        stems = [f[:-len(suffix)] for f in audio_m_files]
        audio_v_files = [stem + '_v.wav' for stem in stems]
        label_files = [stem + '.csv' for stem in stems]

        missing = [f for f in audio_v_files + label_files if not os.path.isfile(f)]
        if missing:
            # Raised explicitly so the check survives ``python -O``; the
            # exception type is unchanged for backward compatibility.
            raise AssertionError(f'missing companion files: {missing}')

        return sorted(zip(audio_m_files, audio_v_files, label_files))

    def load(self, audio_m_path, audio_v_path, label_path):
        """Load one audio pair with its labels and cut it into examples.

        Each returned dict has the keys ``audio_m`` and ``audio_v`` (1-D
        tensors), ``audio_v_frames`` (``(steps, WINDOW_LENGTH)`` tensor),
        ``pe_confi`` and ``mss_confi`` (numpy arrays), ``pitch``
        (``(steps, num_class)`` tensor) and ``file`` (base name of the
        ``_v`` file).

        With ``seq_len`` set, one example is produced per full ``seq_len``
        window, plus one extra example covering the last ``seq_len`` samples.
        That tail example is always appended, also when the audio length is an
        exact multiple of ``seq_len``. Otherwise a single example holding the
        whole file is returned.

        Raises:
            AssertionError: If the CSV row count differs from the frame count.
            ValueError: If the ``_v`` audio is too short to be framed.
        """
        audio_m, _ = librosa.load(audio_m_path, sr=self.sample_rate)
        audio_v, _ = librosa.load(audio_v_path, sr=self.sample_rate)
        audio_steps = len(audio_m) // self.HOP_LENGTH + 1

        audio_v_frames = self._frame_audio(audio_v, audio_steps, audio_v_path)
        audio_m = torch.from_numpy(audio_m)
        audio_v = torch.from_numpy(audio_v)

        df_pitch = pd.read_csv(label_path)
        if len(df_pitch) != audio_steps:
            raise AssertionError(
                f'{label_path}: expected {audio_steps} label rows, found {len(df_pitch)}')

        pe_confi = df_pitch['pe_confi'].values
        pitch = self._pitch_labels(df_pitch['freqs'].to_numpy(dtype=float), pe_confi)
        # The labels above use the raw confidences; folding happens afterwards.
        pe_confi = self._fold_confidence(pe_confi)
        mss_confi = self._fold_confidence(df_pitch['mss_confi'].values)

        file_name = os.path.basename(audio_v_path)
        per_sample = dict(audio_m=audio_m, audio_v=audio_v)
        per_frame = dict(audio_v_frames=audio_v_frames, pe_confi=pe_confi,
                         mss_confi=mss_confi, pitch=pitch)

        def cut(sample_slice, frame_slice):
            # Key order matches the historical layout of the example dicts.
            item = {key: value[sample_slice] for key, value in per_sample.items()}
            item.update((key, value[frame_slice]) for key, value in per_frame.items())
            item['file'] = file_name
            return item

        if self.seq_len is None:
            return [cut(slice(None), slice(None))]

        n_steps = self.seq_len // self.HOP_LENGTH
        data = []
        for i in range(len(audio_m) // self.seq_len):
            begin_t = i * self.seq_len
            begin_step = begin_t // self.HOP_LENGTH
            data.append(cut(slice(begin_t, begin_t + self.seq_len),
                            slice(begin_step, begin_step + n_steps)))
        # Tail example aligned to the end of the file so the remainder is covered.
        data.append(cut(slice(-self.seq_len, None), slice(-n_steps, None)))
        return data

    def _frame_audio(self, audio, n_frames, source='audio'):
        """Cut zero-padded ``audio`` into ``n_frames`` overlapping windows."""
        padded = torch.from_numpy(np.pad(audio, self.WINDOW_LENGTH // 2))
        # ``unfold`` is bounds-checked, unlike raw ``as_strided`` which silently
        # reads past the buffer when ``audio`` is shorter than expected.
        frames = padded.unfold(0, self.WINDOW_LENGTH, self.HOP_LENGTH)
        if frames.shape[0] < n_frames:
            raise ValueError(
                f'{source}: only {frames.shape[0]} frames available, {n_frames} required')
        return frames[:n_frames]

    def _pitch_labels(self, freqs, confi):
        """Build the ``(len(freqs), num_class)`` one-hot pitch label matrix.

        A frame's bin is 1 when its confidence exceeds ``self.th`` and 0
        otherwise. Frames with a non-positive or NaN frequency, or whose bin
        falls outside ``[0, num_class)``, are left all-zero.
        """
        pitch = torch.zeros((len(freqs), self.num_class))
        if self.num_class == 360:
            voiced = freqs > 0
            # 20-cent bins relative to 10; CONST comes from the config module.
            cents = 1200 * np.log2(freqs[voiced] / 10)
            bins = np.rint((cents - CONST) / 20)
        else:
            voiced = (freqs >= F_MIN) & (freqs <= F_MAX)
            bins = np.rint(Q * np.log2(freqs[voiced] / F_MIN))

        rows = np.flatnonzero(voiced).astype(np.int64)
        # Filter before the integer cast so inf/NaN never reach ``astype``, and
        # so negative bins cannot wrap around to the top of the label matrix.
        in_range = (bins >= 0) & (bins < self.num_class)
        rows = rows[in_range]
        bins = bins[in_range].astype(np.int64)

        active = (confi[rows] > self.th).astype(np.float32)
        pitch[torch.from_numpy(rows), torch.from_numpy(bins)] = torch.from_numpy(active)
        return pitch

    def _fold_confidence(self, confi):
        """Map values below ``self.th`` to ``1 - value``; keep the others.

        The array is returned untouched when its sum equals its length.
        """
        if confi.sum() != len(confi):
            confi = np.where(confi >= self.th, confi, 1 - confi)
        return confi
