# pylint: disable=C0114, C0301, R0913, C0303, C0115, C0116, R0914, E0402
from itertools import accumulate
from bisect import bisect
import math
from abc import ABC
import random
import datasets
from torch.utils.data import Dataset
import torch
import librosa
import torch.nn.functional as F
# import zhconv
import numpy as np

from .Normalizer.utils import text_norm, remove_punctuation


# Per-language multiplier applied to a dataset's row count when the training
# lengths are built in MultiHFDataset.__init__ (a multiplier of 2 makes every
# row of that dataset reachable under two indices).
LANG_MULTI = {
    'en': 1, 
    'zh': 1,
    'de': 2,
    'fr': 2,
    'it': 8,
    'es': 2,
}
# Lower-case English language names -> two-letter language codes.
# Only referenced from commented-out code in this file.
LANGCODEMAP = {
    'english': 'en',
    'chinese': 'zh',
    'french': 'fr',
    'german': 'de',
    'spanish': 'es',
    'italian': 'it'
}

# Punctuation character -> replacement character table.
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably consumed by text normalisation elsewhere - confirm before editing.
NORMAL_MAP = {
    '’': "'",
    "!": '.',
    "、": ',',
    '‘': "'",
    '？': '?',
    '！': '。',
    '＇': "'",
    '；': ',',
    '，': ',',
    '：': ',',
    ':': ',',
    '…': '.',
}

# Punctuation / symbol characters (the list contains repeated entries).
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a strip-list for text cleaning - confirm against callers.
REMOVE_TOKEN = [
    '"', '’', '”', '“', '!', '、', '，', '？', '！', '；', '：', '（', '）', '《', '》', '【', '】', '「', '」', '‘', '’', '“', '”', '…', '—', '–', '﹏', '～', '·', '•', '、', '＂', '＇', '｀', '＃', '＄', '％', '＆', '＊', '＋', '－', '／', '＝', '＠', '＜', '＞', '［', '］', '｛', '｜', '｝', '～', '！', '＼', '＂', '＇', '｀', '＃', '＄', '％', '＆', '＊', '＋', '－', '／', '＝', '＠', '＜', '＞', '［', '］', '｛', '｜', '｝', '～', '！', '＼', 
]

# Token ids that cal_speech_token never counts. -100 matches the default
# ``ignore_id`` used for label padding elsewhere in this file.
# NOTE(review): the remaining ids are presumably tokenizer ids of
# punctuation / special tokens - confirm against the tokenizer vocabulary.
NON_SPEECH_TOKENS = [-100, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 25, 26, 27, 28, 29, 30, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 98, 120, 171, 223, 230, 231, 234, 253, 353, 485, 893, 918, 1231, 1260, 1432, 1453, 2751, 2945, 3246, 3253, 4246, 4589, 4852, 5341, 5990, 9383, 10852, 10922, 13556, 13684, 15025, 16259, 18097, 18150, 18807, 20075, 20111, 20387, 20409, 22781, 23757, 24396, 25472, 25629, 28493, 29026, 29305, 29464, 30932, 35746, 39107, 40698, 43807, 50199]

def cal_speech_token(ys, max_id=50256):
    """Count the tokens in ``ys`` that are not excluded, along the last axis.

    A token is counted when it is not listed in ``NON_SPEECH_TOKENS`` and its
    id is ``<= max_id``.

    Args:
        ys: integer tensor of token ids (any device).
        max_id: largest token id that is still counted.

    Returns:
        Float tensor with the last dimension of ``ys`` reduced to a count.
    """
    # Build the lookup table on the same device as ``ys``; a CPU table would
    # make torch.isin fail with a device mismatch for accelerator inputs.
    non_speech = torch.tensor(NON_SPEECH_TOKENS, dtype=torch.long, device=ys.device)
    mask = ~torch.isin(ys, non_speech)
    mask = torch.logical_and(mask, ys <= max_id)
    return mask.float().sum(-1)



class Processor():
    """Waveform augmentation and log-mel feature extraction at 16 kHz.

    Args:
        num_mel_bins: number of mel bins ``wav2mel`` asks for.
        augment: enable the random augmentations in ``wav2mel``.
        prob: base probability of each augmentation (0.15 when ``None``).
    """

    def __init__(self, num_mel_bins=128, augment=True, prob=None):
        self.num_mel_bins = num_mel_bins
        self.sr = 16000
        self.augment = augment
        self.prob = 0.15 if prob is None else prob

    @staticmethod
    def _num_frames(num_samples, hop_length):
        """Return ``ceil(num_samples / hop_length)`` using integer arithmetic."""
        # Integer ceil-division avoids float round-off: 1120 / 16000 * 100
        # evaluates to 7.000000000000001 and would be ceiled to 8.
        return -(-num_samples // hop_length)

    def speed_augment(self,waveform,speeds=None):
        """Randomly resample ``waveform`` along its last axis.

        Applied with probability ``2 * self.prob``. The new length is the old
        length times a factor drawn uniformly from ``speeds`` (default
        ``(0.8, 1.1)``), capped at 30 seconds of samples.
        """
        if random.random() < self.prob*2:
            if speeds is None:
                speeds = (0.8,1.1)
            speed = random.uniform(speeds[0],speeds[1])
            new_length = min(int(speed*waveform.size(-1)),self.sr*30)
            # interpolate() rejects an output size of 0 (very short inputs).
            new_length = max(new_length, 1)
            waveform = torch.nn.functional.interpolate(
                waveform.unsqueeze(0),
                size=new_length,
                mode='linear',
                align_corners=True,
            ).squeeze(0)
        return waveform

    def volume_augment(self,waveform,volumes=None):
        """Randomly scale ``waveform`` with probability ``self.prob``.

        The gain is drawn uniformly from ``volumes`` (default ``(-5, 5)``)
        and converted to a linear factor via ``10 ** (gain / 20)``.
        """
        if random.random() < self.prob:
            if volumes is None:
                volumes = (-5,5)
            gain = random.uniform(volumes[0],volumes[1])
            gain = 10. ** (gain / 20.)
            waveform = waveform*gain
        return waveform

    def wav_trim(self,waveform,max_t=8000):
        """Randomly cut up to ``max_t`` trailing samples off a 2-D waveform.

        Applied with probability ``self.prob``, and only when the drawn cut
        is shorter than half of the waveform.
        """
        if random.random() < self.prob:
            max_frames = waveform.size(-1)
            length = random.randint(1, max_t)
            if length < max_frames / 2:
                waveform = waveform.clone().detach()[:,:max_frames - length]
        return waveform

    def spec_aug(self,mel,num_t_mask=45, num_f_mask=2, max_t=20, max_f=10, max_w=80):
        """Zero out random time and frequency bands of ``mel``.

        ``mel`` is indexed ``[freq, time]``; a masked copy with the same
        layout is returned and the input is left untouched. Time masking and
        frequency masking are each applied with probability ``self.prob``.

        Args:
            num_t_mask: number of time masks drawn when time masking fires.
            num_f_mask: number of frequency masks drawn when it fires.
            max_t: maximum width of one time mask, in frames.
            max_f: maximum width of one frequency mask, in bins.
            max_w: unused; kept for signature compatibility.
        """
        mel_clone = mel.clone().detach().transpose(0, 1)
        max_frames = mel_clone.size(0)
        max_freq = mel_clone.size(1)
        # time mask
        if random.random() < self.prob:
            for _ in range(num_t_mask):
                start = random.randint(0, max_frames - 1)
                length = random.randint(1, max_t)
                end = min(max_frames, start + length)
                mel_clone[start:end, :] = 0
        # freq mask
        if random.random() < self.prob:
            for _ in range(num_f_mask):
                start = random.randint(0, max_freq - 1)
                length = random.randint(1, max_f)
                end = min(max_freq, start + length)
                mel_clone[:, start:end] = 0
        return mel_clone.transpose(0, 1)

    def compute_log_mel_spectrogram(
            self,
            wav,
            num_mel_bins=128,
            n_fft=400,
            hop_length=160,
            padding=0,
            pad_or_trim: bool = True,
            max_duration: int = 30
        ):
        """Compute a normalised log-mel spectrogram of ``wav``.

        Args:
            wav: waveform tensor; a leading size-1 axis is squeezed away.
            num_mel_bins: number of mel filters. Defaults to 128 regardless
                of ``self.num_mel_bins``; ``wav2mel`` passes it explicitly.
            n_fft: FFT size, also the Hann window length.
            hop_length: STFT hop in samples.
            padding: number of zero samples appended before anything else.
            pad_or_trim: pad with zeros / cut to exactly ``max_duration``
                seconds before the STFT.
            max_duration: target duration in seconds for ``pad_or_trim``.

        Returns:
            ``(log_spec, mel_length)``: the spectrogram tensor and a
            one-element integer numpy array holding the number of frames
            covered by the un-padded audio.
        """
        sample_rate = self.sr
        waveform = wav.squeeze(0)  # (channel=1, sample) -> (sample,)
        if padding > 0:
            waveform = F.pad(waveform, (0, padding))
        mel_length = np.array(
            [self._num_frames(waveform.size(0), hop_length)]).astype('long')

        if pad_or_trim:
            length = max_duration * sample_rate
            if waveform.size(0) >= length:
                waveform = waveform[:length]
                # The valid length shrinks together with the audio.
                mel_length = np.array(
                    [self._num_frames(waveform.size(0), hop_length)]).astype('long')
            else:
                waveform = F.pad(waveform, (0, length - waveform.size(0)))
        window = torch.hann_window(n_fft)
        stft = torch.stft(
            waveform,
            n_fft,
            hop_length,
            window=window,
            return_complex=True)
        # Power spectrum with the final STFT frame dropped.
        magnitudes = stft[..., :-1].abs()**2
        filters = torch.from_numpy(
            librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins))
        mel_spec = filters @ magnitudes

        # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        # NOTE: for streaming, a fixed floor of -8.0 replaces the
        # utterance-dependent ``log_spec.max() - 8.0`` floor.
        log_spec = torch.clamp(log_spec, min=-8.0, max=0.0)

        log_spec = (log_spec + 4.0) / 4.0
        return log_spec,mel_length

    def wav2mel(self, wav, training=True):
        """Convert a 1-D numpy waveform into mel features.

        Speed, volume and spectrogram augmentation are applied only when
        both ``self.augment`` and ``training`` are true.

        Returns:
            Dict with ``'mel'`` (numpy array) and ``'mel_lens'`` (one-element
            integer numpy array).
        """
        wav = torch.from_numpy(wav).unsqueeze(0).float()
        augmenting = self.augment and training
        if augmenting:
            wav = self.speed_augment(wav)
            wav = self.volume_augment(wav)
        mel, mel_length = self.compute_log_mel_spectrogram(wav,self.num_mel_bins)
        if augmenting:
            mel = self.spec_aug(mel)
        return {
            'mel': mel.numpy(),
            'mel_lens': mel_length
        }


def tokenize(tokenizer, text, src_ids, text_target=None, tgt_ids=None, max_length=None, ignore_id=-100,):
    """Tokenize ``text`` (and optionally ``text_target``) into padded tensors.

    The first column of ``input_ids`` is overwritten with ``src_ids``. When a
    target is supplied, the first column of ``labels`` is overwritten with
    ``tgt_ids`` and every label equal to the tokenizer's pad id is then
    replaced by ``ignore_id``.

    Returns:
        The object returned by ``tokenizer``, modified in place.
    """
    encoded = tokenizer(
        text,
        text_target=text_target,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    encoded['input_ids'][:, 0] = src_ids
    if text_target is None:
        return encoded

    labels = encoded['labels']
    labels[:, 0] = tgt_ids
    # Padding positions must not contribute to the loss.
    labels[labels == tokenizer.pad_token_id] = ignore_id
    return encoded

def determine_language(path):
    """Infer a lower-case language code from an upper-case code in ``path``.

    Recognised codes are ZH, EN, CN, FR, IT, DE and ES, checked in that
    priority order; ``CN`` is reported as ``'zh'``.

    A code that appears as a standalone alphabetic token (delimited by any
    non-letter such as ``/``, ``_``, ``-`` or a digit) wins over a code that
    is merely a substring, so ``/data/GEN/FR`` resolves to ``'fr'`` instead
    of matching the ``EN`` inside ``GEN``. If no standalone token is found
    the plain substring scan is used as a fallback.

    Raises:
        ValueError: if no language code can be found in ``path``.
    """
    codes = ['ZH','EN', 'CN', 'FR', 'IT', 'DE', 'ES']
    text = str(path)
    # Split the path into runs of letters; everything else is a separator.
    tokens = set(''.join(ch if ch.isalpha() else ' ' for ch in text).split())
    for lang in codes:
        if lang in tokens:
            return lang.lower().replace('cn', 'zh')
    # Fallback: plain substring match for paths without delimited codes.
    for lang in codes:
        if lang in text:
            return lang.lower().replace('cn', 'zh')
    raise ValueError(f'Cannot determine language from path: {path}')


class MultiHFDataset(Dataset, ABC):
    """Base class presenting several on-disk HF datasets as one index space.

    Each path is loaded with ``datasets.load_from_disk`` and its language is
    derived from the path. In training mode a dataset's share of the index
    space is its row count times ``LANG_MULTI[language]``; otherwise it is
    the plain row count.

    Args:
        datasets_paths: paths accepted by ``datasets.load_from_disk``.
        train: apply the per-language multipliers.
        target_split: ``None``, one string used for every dataset, or one
            entry per dataset.
    """

    def __init__(self, datasets_paths, train=True, target_split=None):
        self.datasets = [self._load_from_path(p) for p in datasets_paths]
        # self.languages = [LANGCODEMAP[dataset[0]['language']] for dataset in self.datasets]
        self.languages = [determine_language(p) for p in datasets_paths]

        if train:
            self.lengths = [
                ds.num_rows * LANG_MULTI[code]
                for ds, code in zip(self.datasets, self.languages)
            ]
        else:
            self.lengths = [ds.num_rows for ds in self.datasets]

        # cum_lengths[k] is the first global index owned by dataset k.
        self.cum_lengths = list(accumulate(self.lengths, initial=0))
        self.total_length = self.cum_lengths[-1]
        self.train = train

        if isinstance(target_split, str):
            # One split string is broadcast to every dataset.
            target_split = [target_split for _ in datasets_paths]
        self.target_split=target_split
        self.processor = Processor(80, True)

    def _search_idx(self, i):
        """Map global index ``i`` to ``(dataset_idx, sample_idx)``.

        ``i`` wraps modulo ``total_length``; the in-dataset offset wraps
        modulo the dataset's real row count, which is how multiplied
        datasets repeat their rows.
        """
        wrapped = i % self.total_length
        dataset_idx = bisect(self.cum_lengths, wrapped) - 1
        offset = wrapped - self.cum_lengths[dataset_idx]
        rows = self.datasets[dataset_idx].num_rows
        return dataset_idx, offset % rows

    def _load_from_path(self, path):
        """Load ``path`` from disk; a ``DatasetDict`` is flattened by
        concatenating all of its members."""
        loaded = datasets.load_from_disk(path)
        if type(loaded).__name__ == 'DatasetDict':
            parts = [loaded[sub_name] for sub_name in loaded]
        else:
            parts = [loaded]
        return datasets.concatenate_datasets(parts)


class MTDataset(MultiHFDataset):
    """Text translation pairs read from each record's ``'translation'`` dict.

    In training mode the index space is doubled: indices at or beyond
    ``total_length`` return the same pair with source and target swapped
    (only when no ``target_split`` is configured).
    """

    def __len__(self):
        return self.total_length * 2 if self.train else self.total_length

    def __getitem__(self, i):
        """Return a dict with keys ``input``, ``label``, ``src_lang`` and
        ``tgt_lang`` for global index ``i``."""
        dataset_idx, sample_idx = self._search_idx(i)
        pair = self.datasets[dataset_idx][sample_idx]['translation']
        langs = list(pair.keys())

        if self.target_split is not None:
            # Direction is dictated by the configured "src-tgt" split.
            src_lang, tgt_lang = self.target_split[dataset_idx].split("-")
            assert src_lang in langs and tgt_lang in langs, f"Specified target split {self.target_split}, but only {langs} found in data"
        elif i >= self.total_length:
            # Second half of the index space: reversed direction.
            src_lang, tgt_lang = langs[1], langs[0]
        else:
            src_lang, tgt_lang = langs[0], langs[1]

        return {
            "input": pair[src_lang],
            "label": pair[tgt_lang],
            "src_lang": src_lang,
            "tgt_lang": tgt_lang
        }

class STDataset(MultiHFDataset):
    """Speech-translation samples: audio, mel features and text pairs.

    Each record is expected to provide ``audio`` (``array`` and
    ``sampling_rate``), a transcript (``text_clean_ner`` or ``text``) and one
    or more ``trans_<lang>`` fields. Records that cannot be used (no
    translation, under 8000 samples of audio, feature extraction failure)
    are replaced by a randomly drawn record.
    """

    # Entity tags stripped from ``text_clean_ner``; each tag is also matched
    # with a doubled closing bracket.
    _NER_TAGS = ("<organization>", "<person>", "<location>")

    def __len__(self):
        return self.total_length

    def task2prompt(self, language:str = 'en'):
        """Return the ``<|xx|>`` prompt token for ``language``.

        ``'cn'`` is emitted as ``'zh'``; ``'zh'`` itself is accepted as well,
        matching the codes used by the rest of this file.
        """
        assert language in ['en','cn','zh','de','fr','es','it','haw'] #'de','fr','es','it'
        whisper_language = language.replace('cn','zh')
        return f'<|{whisper_language}|>'

    def _random_item(self):
        """Return a uniformly drawn replacement sample."""
        return self.__getitem__(random.randint(0, self.total_length-1))

    def _truncate(self, wav):
        """Return the audio prefix used for the truncated mel features.

        Training: with probability 0.7, and only for audio of at least
        24000 samples, cut at a random point that leaves at least 8000
        samples on either side; otherwise keep everything.
        Evaluation: audio of at least 24000 samples is cut at
        ``len // 2 + 12000``; shorter audio is kept whole.
        """
        if self.train:
            if random.random() < 0.7 and len(wav) >= 24000:
                end_frame = random.randint(8000, len(wav) - 8000)
                return wav[:end_frame]
            return wav
        if len(wav) >= 24000:
            return wav[:len(wav) // 2 + 12000]
        return wav

    def __getitem__(self, i):
        """Return one sample dict with keys ``wav``, ``mel_full``,
        ``mel_trunc``, ``src_text``, ``tgt_text``, ``src_lang``, ``tgt_lang``."""
        dataset_idx, sample_idx = self._search_idx(i)
        record = self.datasets[dataset_idx][sample_idx]
        # Only ``trans_<lang>`` fields count; requiring the underscore keeps
        # split('_')[1] from failing on keys such as 'transcript'.
        valid_target = [
            key.split('_')[1]
            for key in record.keys()
            if key.startswith('trans_') and record[key] is not None
        ]
        if not valid_target:
            return self._random_item()

        if self.target_split is None:
            # src_lang = LANGCODEMAP[record['language']]
            src_lang = self.languages[dataset_idx]
            tgt_lang = random.choice(valid_target)
        else:
            src_lang, tgt_lang = self.target_split[dataset_idx].split("-")
            # Record fields use 'cn' where the split string may say 'zh'.
            tgt_lang = tgt_lang.replace('zh', 'cn')
            assert tgt_lang in valid_target
            # assert src_lang.replace('cn', 'zh') == LANGCODEMAP[record['language']]

        # Convert before resampling: librosa.resample needs an ndarray.
        wav = np.asarray(record['audio']['array'])
        sr = record['audio']['sampling_rate']
        if sr != 16000:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)

        if len(wav) < 8000:
            return self._random_item()

        wav_trunc = self._truncate(wav)

        try:
            mel = self.processor.wav2mel(wav, self.train)
            mel_trunc = self.processor.wav2mel(wav_trunc, self.train)
        except Exception:
            # Best effort: skip samples whose features cannot be computed,
            # but let KeyboardInterrupt / SystemExit propagate.
            return self._random_item()

        if 'text_clean_ner' in record.keys():
            asr_label = record['text_clean_ner']
            for tag in self._NER_TAGS:
                asr_label = asr_label.replace(tag + ">", "").replace(tag, "")
        else:
            asr_label = record['text']

        input_text = asr_label
        label = record[f'trans_{tgt_lang}']

        # if random.random() < 0.05 and self.target_split is None:
        #     label = input_text
        #     tgt_lang = src_lang

        return {
            "wav": wav,
            "mel_full": mel,
            "mel_trunc": mel_trunc,
            "src_text": text_norm(input_text, src_lang, filt_num=False),
            "tgt_text": text_norm(label, tgt_lang, filt_num=False),
            "src_lang": src_lang.replace('cn', 'zh'),
            "tgt_lang": tgt_lang.replace('cn', 'zh'),
        }

class STCollator:
    """Batch collator for the sample dicts produced by ``STDataset``.

    The tokenizer is used both for text tokenisation and, through its
    ``pad`` method, for padding the mel feature dicts; its
    ``model_input_names`` attribute is switched between the two uses.

    Args:
        tokenizer: tokenizer exposing ``__call__``, ``pad`` and
            ``added_tokens_encoder``.
        n_mels: number of mel bins handed to the internal ``Processor``.
        max_length: truncation length for text tokenisation.
        ignore_id: stored on the instance; not used by the methods here.
    """

    def __init__(self, tokenizer, n_mels=128, max_length=400, ignore_id=-100):
        self.tokenizer = tokenizer
        self.processor = Processor(n_mels, True)
        self.max_length = max_length
        self.ignore_id = ignore_id
        self.tokenizer.model_input_names = ["mel", "mel_lens"]

    def lid2label(self, language:str = 'en'):
        """Map a language code to its integer label (``cn`` and ``zh`` share 0)."""
        label_of = {'cn':0,'en':1,'de':2,'fr':3,'es':4,'it':5, 'zh':0}
        assert language in ['en','cn','de','fr','es','it', 'zh']
        return label_of[language]

    def tokenize_text(self, text, langs):
        """Tokenize a batch of strings and stamp in the language tokens.

        Position 1 of every sequence is overwritten with the id of
        ``<|lang|>`` and padded positions are filled with 50258.

        Returns:
            ``(input_ids, lengths)`` where ``lengths`` is the attention-mask
            sum minus one.
        """
        self.tokenizer.model_input_names = ["input_ids", "attention_mask"]
        lang_ids = []
        for code in langs:
            lang_ids.append(self.tokenizer.added_tokens_encoder[f'<|{code}|>'])
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        ids = encoded['input_ids']
        attention = encoded['attention_mask']
        ids[:, 1] = torch.tensor(lang_ids).long()
        ids.masked_fill_(attention == 0, 50258)
        lengths = attention.sum(dim=-1) - 1 # exclude eos
        return ids, lengths

    def __call__(self, batch):
        """Collate a list of sample dicts into model inputs plus raw fields.

        Returns:
            Dict with ``inputs`` (a plain dict of tensors), ``src_txt``,
            ``tgt_txt``, ``splits`` (``"src-tgt"`` strings) and ``wavs``.
        """
        wavs, trunc_mels, full_mels = [], [], []
        src_txt, tgt_txt, src_lang, tgt_lang = [], [], [], []
        for sample in batch:
            wavs.append(sample["wav"])
            trunc_mels.append(sample["mel_trunc"])
            full_mels.append(sample["mel_full"])
            src_txt.append(sample["src_text"])
            tgt_txt.append(sample["tgt_text"])
            src_lang.append(sample["src_lang"])
            tgt_lang.append(sample["tgt_lang"])

        splits = [f"{s}-{t}" for s, t in zip(src_lang, tgt_lang)]

        src_ids, src_lens = self.tokenize_text(src_txt, src_lang)
        tgt_ids, tgt_lens = self.tokenize_text(tgt_txt, tgt_lang)

        # Switch the tokenizer back to feature padding for the mel dicts.
        self.tokenizer.model_input_names = ["mel", "mel_lens"]
        padded = self.tokenizer.pad(trunc_mels, return_tensors="pt", padding='longest',max_length=1500)

        padded['ast_ids'] = tgt_ids
        padded['ast_lens'] = tgt_lens
        padded['mel_lens'] = padded['mel_lens'].squeeze(-1)
        # Drop the frames beyond the longest valid length in the batch.
        longest = padded['mel_lens'].max()
        padded['mel'] = padded['mel'][:, :, :longest]
        inputs = padded.data
        inputs['text'] = src_ids[:, 1:] # exclude bos
        inputs['text_lens'] = src_lens

        padded_full = self.tokenizer.pad(full_mels, return_tensors="pt", padding='longest',max_length=1500)
        full_longest = padded_full['mel_lens'].max()
        inputs['f_mel'] = padded_full['mel'][:, :, :full_longest]
        inputs['f_mel_lens'] = padded_full['mel_lens'].squeeze(-1)

        return {
            'inputs': inputs,
            'src_txt': src_txt,
            'tgt_txt': tgt_txt,
            'splits': splits,
            'wavs': wavs,
        }



# class MTCollator:
#     def __init__(self, tokenizer, max_length=400, ignore_id=-100):
#         self.tokenizer = tokenizer
#         self.max_length = max_length
#         self.ignore_id = ignore_id

#     def __call__(self, batch):
#         src_txt = [b["input"] for b in batch]
#         tgt_txt = [b["label"] for b in batch]
#         src_lang = [b["src_lang"] for b in batch]
#         tgt_lang = [b["tgt_lang"] for b in batch]
#         src_ids = torch.tensor([LANGUAGES2ID[l] for l in src_lang]).long()
#         tgt_ids = torch.tensor([LANGUAGES2ID[l] for l in tgt_lang]).long()
#         splits = [f"{s}-{t}" for s, t in zip(src_lang, tgt_lang)]
#         inputs = tokenize(
#             self.tokenizer,
#             src_txt,
#             src_ids,
#             text_target=tgt_txt,
#             tgt_ids=tgt_ids,
#             max_length=self.max_length,
#             ignore_id=self.ignore_id
#         )

#         return {
#             'inputs': inputs,
#             'src_txt': src_txt,
#             'tgt_txt': tgt_txt,
#             'splits': splits
#         }
