# pylint: disable=C0114, C0301, R0913, C0303, C0115, C0116, R0914, E0402
from itertools import accumulate
from bisect import bisect
import math
from abc import ABC
import random
import datasets
from torch.utils.data import Dataset
import torch
import librosa
import torch.nn.functional as F
from transformers import AutoTokenizer
import numpy as np
import onnxruntime

from .Normalizer.utils import text_norm

# Per-language multiplier keyed by short language code. Nothing in the visible
# part of this file reads it.
LANG_MULTI = {"en": 1, "zh": 1}

# Spelled-out language name -> short language code. Only referenced from
# commented-out code in the visible part of this file.
# NOTE(review): Chinese maps to 'cn' here while LANG_MULTI is keyed by 'zh' --
# confirm which code the consumers expect.
LANGCODEMAP = {"english": "en", "chinese": "cn"}

# Keyword mapping handed to ``tokenizer.add_special_tokens``. EOS and PAD use
# the same token string. The order of the additional tokens is kept exactly as
# listed, since reordering a token list can change the ids assigned to them.
SPECIAL_TOKENS = {
    "eos_token": "<|endoftext|>",
    "pad_token": "<|endoftext|>",
    "additional_special_tokens": [
        "<|im_start|>",
        "<|im_end|>",
        "<|endofprompt|>",
        "[breath]",
        "<strong>",
        "</strong>",
        "[noise]",
        "[laughter]",
        "[cough]",
        "[clucking]",
        "[accent]",
        "[quick_breath]",
        "<laughter>",
        "</laughter>",
        "[hissing]",
        "[sigh]",
        "[vocalized-noise]",
        "[lipsmack]",
        "[mn]",
    ],
}
    
def log_mel_spectrogram(
        wav,
        sr=16000,
        num_mel_bins=128,
        n_fft=400,
        hop_length=160,
        padding=0,
        pad_or_trim: bool = False,
        max_duration: int = 30
    ):
    """Compute a clamped, rescaled log10 mel spectrogram of a 1-D waveform.

    Args:
        wav: 1-D waveform samples; a NumPy array or anything
            ``torch.as_tensor`` accepts. Converted to float32.
        sr: Sampling rate in Hz; used for the mel filterbank, for the
            ``pad_or_trim`` target length and for ``mel_length``.
        num_mel_bins: Number of mel bands in the filterbank.
        n_fft: FFT size; also the length of the Hann window.
        hop_length: STFT hop in samples.
        padding: Number of zero samples appended to the waveform first.
        pad_or_trim: When true, the waveform is cut or zero-padded to exactly
            ``max_duration * sr`` samples before the STFT.
        max_duration: Target duration in seconds used by ``pad_or_trim``.

    Returns:
        Tuple ``(log_spec, mel_length)``. For a 1-D input ``log_spec`` is a
        tensor of shape ``(num_mel_bins, frames)`` with values in
        ``[-1.0, 1.0]``. ``mel_length`` is a one-element NumPy integer array
        holding ``ceil(samples * 100 / sr)``, where ``samples`` counts the
        waveform after ``padding`` and trimming but before the zero fill added
        by ``pad_or_trim``.
    """
    waveform = torch.as_tensor(wav, dtype=torch.float32)
    if padding > 0:
        waveform = F.pad(waveform, (0, padding))

    # Samples that carry signal (zero fill from pad_or_trim is not counted).
    valid_samples = waveform.size(0)
    if pad_or_trim:
        target_samples = max_duration * sr
        if valid_samples >= target_samples:
            waveform = waveform[:target_samples]
            valid_samples = waveform.size(0)
        else:
            waveform = F.pad(waveform, (0, target_samples - valid_samples))

    # Multiply before dividing: ``n / sr * 100`` picks up float error (e.g.
    # 1120 / 16000 * 100 == 7.000000000000001) and ceil() would then report
    # one frame too many.
    mel_length = np.array(
        [math.ceil(valid_samples * 100 / sr)]).astype('long')

    window = torch.hann_window(n_fft, device=waveform.device)
    stft = torch.stft(
        waveform,
        n_fft,
        hop_length,
        window=window,
        return_complex=True)
    # Power spectrum, with the final STFT frame dropped.
    magnitudes = stft[..., :-1].abs()**2
    filters = torch.from_numpy(
        librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=num_mel_bins)
    ).to(magnitudes.device)
    mel_spec = filters @ magnitudes

    # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()

    # NOTE: for streaming, a fixed floor of -8.0 replaces the
    # utterance-dependent ``log_spec.max() - 8.0``.
    log_spec = torch.clamp(log_spec, min=-8.0, max=0.0)

    # Map the clamped range [-8, 0] onto [-1, 1].
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec, mel_length
    
    



def tokenize(tokenizer, text, allowed_special='all', max_length=None):
    """Encode text with ``tokenizer`` and append one EOS id to every row.

    Args:
        tokenizer: Callable tokenizer exposing ``eos_token_id``. Its
            ``model_input_names`` attribute is overwritten with
            ``["input_ids", "attention_mask"]`` as a side effect.
        text: A single string or a list of strings.
        allowed_special: Unused; kept so existing callers keep working.
        max_length: Forwarded to the tokenizer together with
            ``truncation=True``.

    Returns:
        Tuple ``(input_ids, text_lens)``: the padded id matrix with one extra
        EOS column appended on the right, and the per-row attention-mask sum
        plus one for that EOS.
    """
    tokenizer.model_input_names = ["input_ids", "attention_mask"]
    encoded = tokenizer(
        text, return_tensors="pt", padding=True,
        truncation=True, max_length=max_length)
    token_ids = encoded['input_ids']
    # Size the EOS column from the encoded batch, not from len(text): for a
    # single string len(text) is the character count, not the number of rows.
    eos_column = torch.full(
        (token_ids.size(0), 1), tokenizer.eos_token_id,
        dtype=torch.long, device=token_ids.device)
    input_ids = torch.cat([token_ids, eos_column], dim=1)
    text_lens = encoded['attention_mask'].sum(dim=-1) + 1
    return input_ids, text_lens



class TTSDataset(Dataset):
    """Map-style dataset yielding speech tokens and transcript text.

    Several on-disk Hugging Face datasets are exposed as one index space.
    Each item is run through ``log_mel_spectrogram`` and an ONNX model to
    obtain a flat token array.
    """

    def __init__(self, datasets_paths, onnx_path, train=True):
        """Load every dataset and open the ONNX session on CPU.

        Args:
            datasets_paths: Iterable of paths readable by
                ``datasets.load_from_disk``.
            onnx_path: Path of the ONNX model used to extract tokens.
            train: Flag copied verbatim into every returned item.
        """
        self.datasets = [self._load_from_path(path) for path in datasets_paths]

        self.lengths = [x.num_rows for x in self.datasets]
        # cum_lengths[k] is the global index of the first row of dataset k.
        self.cum_lengths = [0] + list(accumulate(self.lengths))
        self.total_length = sum(self.lengths)
        self.train = train

        option = onnxruntime.SessionOptions()
        self.onnx_extractor = onnxruntime.InferenceSession(
            onnx_path, sess_options=option, providers=['CPUExecutionProvider'])

    def _search_idx(self, i):
        """Translate global index ``i`` into ``(dataset_idx, sample_idx)``.

        ``i`` wraps around modulo the total number of rows.
        """
        idx = i % self.total_length
        dataset_idx = bisect(self.cum_lengths, idx) - 1
        sample_idx = idx - self.cum_lengths[dataset_idx]
        sample_idx = sample_idx % self.datasets[dataset_idx].num_rows
        return dataset_idx, sample_idx

    def _load_from_path(self, path):
        """Load ``path`` from disk; the splits of a DatasetDict are concatenated."""
        loaded = datasets.load_from_disk(path)
        if isinstance(loaded, datasets.DatasetDict):
            parts = list(loaded.values())
        else:
            parts = [loaded]
        return datasets.concatenate_datasets(parts)

    @staticmethod
    def _load_waveform(record):
        """Return the record's audio as a NumPy array resampled to 16 kHz."""
        audio = record['audio']
        # Convert first: the resampler must receive an ndarray, also when the
        # dataset hands the samples back as a plain list.
        wav = np.asarray(audio['array'])
        sr = audio['sampling_rate']
        if sr != 16000:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
        return wav

    @staticmethod
    def _extract_label(record):
        """Return the transcript, with entity tags stripped when present."""
        if 'text_clean_ner' not in record:
            return record['text']
        label = record['text_clean_ner']
        # Order matters: the '>>' variant of each tag is removed before the
        # plain one.
        for tag in ("<organization>>", "<organization>",
                    "<person>>", "<person>",
                    "<location>>", "<location>"):
            label = label.replace(tag, "")
        return label

    def __len__(self):
        return self.total_length

    def __getitem__(self, i):
        """Return ``{'train', 'token', 'text'}`` for global index ``i``.

        Clips shorter than 8000 samples or of at least 30 * 16000 samples
        (after resampling to 16 kHz) are skipped: a uniformly random other
        index is returned instead.
        """
        dataset_idx, sample_idx = self._search_idx(i)
        record = self.datasets[dataset_idx][sample_idx]

        wav = self._load_waveform(record)
        if len(wav) < 8000 or len(wav) >= 30 * 16000:
            return self.__getitem__(random.randint(0, self.total_length - 1))

        input_text = self._extract_label(record)

        mel, _ = log_mel_spectrogram(wav)
        mel = mel.unsqueeze(0)
        onnx_inputs = self.onnx_extractor.get_inputs()
        token = self.onnx_extractor.run(
            None, {
                onnx_inputs[0].name: mel.detach().cpu().numpy(),
                # Second model input: number of frames (last mel axis).
                onnx_inputs[1].name: np.array([mel.shape[2]], dtype=np.int32)
                })[0].flatten()

        return {
            'train': self.train,
            "token": token,
            "text": input_text
        }

class Collator:
    """Batch collator: pads speech tokens and builds truncated-text inputs."""

    def __init__(self, pretrain_path):
        """Load the tokenizer from ``pretrain_path`` and add SPECIAL_TOKENS."""
        self.tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)

    @staticmethod
    def _truncate_texts(input_text, text_len, train):
        """Cut every token row to a prefix and compute its target value.

        Rows shorter than 6 tokens are kept whole with target 0.0. Longer rows
        are cut at a random point in ``[4, len - 1]`` when ``train`` is true,
        otherwise at ``len // 2 + 2``.

        Returns:
            ``(trunc_text, norm_target)``: a list of
            ``{'text', 'text_len'}`` dicts and a list of Python floats.
        """
        trunc_text, norm_target = [], []
        for row, row_len in zip(input_text, text_len):
            # Plain ints keep random.randint / slicing independent of tensor
            # ``__index__`` support and keep the targets plain floats.
            length = int(row_len)
            if length < 6:
                trunc_text.append({'text': row, 'text_len': length})
                norm_target.append(0.0)
                continue
            if train:
                end = random.randint(4, length - 1)
            else:
                end = length // 2 + 2
            trunc_text.append({'text': row[:end], 'text_len': end})
            # NOTE(review): end >= 4 here, so min(end, 2) is always 2 and the
            # first term is the constant 1.0 -- confirm the intended weighting.
            norm_target.append(
                (min(end, 2) * 0.5 + (length - end) * 1) / length)
        return trunc_text, norm_target

    def __call__(self, batch):
        """Collate dataset items (dicts with 'train', 'token', 'text').

        The train flag is read from the first item only.
        """
        train = batch[0]['train']
        text = [b["text"] for b in batch]
        token = [{"speech": b["token"], "speech_len": len(b["token"])} for b in batch]
        input_text, text_len = tokenize(self.tokenizer, text)

        # ``tokenizer.pad`` pads the first name in ``model_input_names``, so
        # the attribute is repointed before each call.
        self.tokenizer.model_input_names = ["speech", "speech_len"]
        inputs_spch = self.tokenizer.pad(
            token, return_tensors="pt", padding='longest', max_length=750)

        trunc_text, norm_target = self._truncate_texts(input_text, text_len, train)

        self.tokenizer.model_input_names = ["text", "text_len"]
        inputs_text_trunc = self.tokenizer.pad(
            trunc_text, return_tensors="pt", padding='longest', max_length=200)
        # All entries are floats, so the dtype no longer flips to int64 when
        # every text in the batch is short.
        norm_target = torch.tensor(norm_target)
        # Replace padded speech positions (filled with the text pad id) by 0.
        # NOTE(review): a real speech token equal to pad_token_id would be
        # zeroed as well -- presumably the id ranges do not overlap; verify.
        inputs_spch['speech'].masked_fill_(
            inputs_spch['speech'] == self.tokenizer.pad_token_id, 0)

        return {
            'text': input_text,
            'raw_text': text,
            'text_len': text_len,
            'speech': inputs_spch['speech'],
            'speech_len': inputs_spch['speech_len'],
            'tunc_text': inputs_text_trunc['text'],
            'tunc_text_len': inputs_text_trunc['text_len'],
            'norm_target': norm_target,
        }


