import torch
import random
import numpy as np
import re
import evaluate


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a boolean mask that is True at padded positions.

    Args:
        lengths (torch.Tensor): 1-D tensor of sequence lengths, shape (B,).
        max_len (int): Width of the mask. If <= 0 (default), the largest value
            in ``lengths`` is used (0 for an empty batch).

    Returns:
        torch.Tensor: Bool tensor of shape (B, max_len). Element [b, t] is
        True when ``t >= lengths[b]``, i.e. the position is padding.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    if max_len <= 0:
        # Tensor.max() raises on an empty tensor, so an empty batch gets width 0.
        max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0
    positions = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    # (1, max_len) compared with (B, 1) broadcasts to (B, max_len).
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)


def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a (size, size) chunk-wise causal mask for a streaming encoder.

    Row ``i`` belongs to chunk ``c = i // chunk_size``. It may see every
    column from the start of chunk ``c - num_left_chunks`` (clamped at 0)
    up to the end of its own chunk ``c`` (clamped at ``size``).

    Args:
        size (int): size of the (square) mask.
        chunk_size (int): size of each chunk; must be positive when size > 0.
        num_left_chunks (int): number of left chunks each row may see.
            <0: all left chunks (full history)
            >=0: only num_left_chunks chunks before the current one
        device (torch.device): device on which the mask is created.

    Returns:
        torch.Tensor: bool tensor of shape (size, size); True = visible.

    Raises:
        ValueError: if ``size > 0`` and ``chunk_size <= 0``.

    Examples:
        >>> subsequent_chunk_mask(4, 2).int()
        tensor([[1, 1, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 1],
                [1, 1, 1, 1]], dtype=torch.int32)
    """
    if size == 0:
        return torch.zeros(0, 0, device=device, dtype=torch.bool)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Vectorised: one broadcast comparison instead of one slice write per row
    # (a per-row loop launches `size` separate kernels on GPU).
    pos = torch.arange(size, device=device)
    chunk_idx = torch.div(pos, chunk_size, rounding_mode="floor")
    # Exclusive end of each row's own chunk; columns >= size do not exist,
    # so no explicit clamp at `size` is needed.
    end = (chunk_idx + 1) * chunk_size
    if num_left_chunks < 0:
        start = torch.zeros_like(pos)
    else:
        start = ((chunk_idx - num_left_chunks) * chunk_size).clamp(min=0)
    cols = pos.unsqueeze(0)  # (1, size), broadcast against per-row bounds (size, 1)
    return (cols >= start.unsqueeze(1)) & (cols < end.unsqueeze(1))


def add_optional_chunk_mask(
    xs: torch.Tensor,
    use_dynamic_chunk: bool = True,
    use_dynamic_left_chunk: bool = False,
    decoding_chunk_size: int = 0,
    static_chunk_size: int = 32,
    num_decoding_left_chunks: int = -1,
):
    """Build an additive chunk attention mask for the encoder input ``xs``.

    Args:
        xs (torch.Tensor): padded input; ``xs.size(1)`` is used as the
            sequence length L (presumably (B, L, D) -- confirm with callers).
            Must have a floating-point dtype (``torch.finfo`` is applied).
        use_dynamic_chunk (bool): choose the chunk size per call according to
            ``decoding_chunk_size``; if False, ``static_chunk_size`` is used.
        use_dynamic_left_chunk (bool): in random (training) mode, also sample
            the number of visible left chunks.
        decoding_chunk_size (int): only used when ``use_dynamic_chunk``:
            0: training, sample a random chunk size (see below).
            <0: full context (chunk size L, all left chunks).
            >0: use this chunk size with ``num_decoding_left_chunks``.
        static_chunk_size (int): chunk size used when ``use_dynamic_chunk``
            is False.
        num_decoding_left_chunks (int): number of visible left chunks when
            ``decoding_chunk_size > 0`` or in static mode.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Random chunk (``decoding_chunk_size == 0``): draw ``c`` uniformly from
    [1, L-1]. If ``c > L // 2`` the full context is used. Otherwise the chunk
    size is ``c % 51 + 20`` (i.e. 20..70). With ``use_dynamic_left_chunk``,
    the number of left chunks is drawn from [9, (L-1)//chunk - 1] when that
    range is non-empty; otherwise all left chunks stay visible.

    Returns:
        torch.Tensor: (1, 1, L, L) tensor in ``xs.dtype``. It is 0 where the
        chunk mask allows attention and ``torch.finfo(xs.dtype).min`` where
        it does not (presumably added to attention logits and broadcast over
        batch and heads -- confirm with callers).
    """
    max_len = xs.size(1)
    if not use_dynamic_chunk:
        chunk_size = static_chunk_size
        num_left_chunks = num_decoding_left_chunks
    elif decoding_chunk_size < 0:
        chunk_size = max_len
        num_left_chunks = -1
    elif decoding_chunk_size > 0:
        chunk_size = decoding_chunk_size
        num_left_chunks = num_decoding_left_chunks
    else:
        num_left_chunks = -1
        if max_len <= 1:
            # torch.randint(1, max_len) requires max_len > 1; a 0/1-frame input
            # is fully visible for any positive chunk size.
            chunk_size = 1
        else:
            sampled = torch.randint(1, max_len, (1,)).item()
            if sampled > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = sampled % 51 + 20
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    # torch.randint(low, high) needs low < high; otherwise keep
                    # full left context (this replaces a bare `except: pass`).
                    if max_left_chunks > 9:
                        num_left_chunks = torch.randint(9, max_left_chunks, (1,)).item()

    allowed = subsequent_chunk_mask(max_len, chunk_size, num_left_chunks, xs.device)  # (L, L) bool
    # 1.0 where attention is blocked, 0.0 where allowed, then scaled to the most
    # negative representable value of xs.dtype.
    blocked = 1.0 - allowed[None, None, :, :].to(dtype=xs.dtype)
    return blocked * torch.finfo(xs.dtype).min


def speed_augment(waveform, prob=1.0, speeds=None):
    """Randomly time-stretch a waveform by linear resampling.

    With probability ``prob`` a speed factor is drawn uniformly from
    ``speeds`` and the last axis is resampled to ``int(speed * n)`` samples
    with ``np.interp``. Any leading axes (e.g. channels) are resampled
    independently, so both 1-D and N-D arrays are accepted.

    Args:
        waveform: array-like with a ``shape`` attribute; samples on the last axis.
        prob (float): probability of applying the augmentation.
        speeds (tuple): (low, high) range for the speed factor; defaults to
            (0.9, 1.1).

    Returns:
        The input object unchanged when the augmentation is not applied (or the
        last axis is empty); otherwise a new float64 ``np.ndarray`` whose last
        axis has length ``int(speed * n)``.
    """
    if random.random() < prob:
        if speeds is None:
            speeds = (0.9, 1.1)
        speed = random.uniform(speeds[0], speeds[1])
        n_old = waveform.shape[-1]
        if n_old == 0:
            # np.interp needs at least one sample point to interpolate from.
            return waveform
        new_length = int(speed * n_old)

        # Interpolate with NumPy on a normalised [0, 1] time axis so both ends line up.
        x_old = np.linspace(0, 1, n_old)
        x_new = np.linspace(0, 1, new_length)
        data = np.asarray(waveform)
        rows = data.reshape(-1, n_old)
        # np.interp is 1-D only, so resample each leading-axis row separately.
        resampled = np.array([np.interp(x_new, x_old, row) for row in rows])
        waveform = resampled.reshape(data.shape[:-1] + (new_length,))

    return waveform


def volume_augment(waveform, prob=1.0, volumes=None):
    """Randomly scale a waveform by a gain drawn in decibels.

    With probability ``prob`` a gain in dB is drawn uniformly from ``volumes``
    (default (-5, 5)), converted to a linear amplitude factor
    ``10 ** (dB / 20)`` and multiplied into ``waveform``. Otherwise the input
    object is returned as-is.
    """
    if not (random.random() < prob):
        return waveform
    low, high = (-5, 5) if volumes is None else (volumes[0], volumes[1])
    gain_db = random.uniform(low, high)
    return waveform * 10.0 ** (gain_db / 20.0)


# Any character that is neither a word character nor whitespace (raw string so
# "\w"/"\s" reach the regex engine without an invalid-escape warning).
PATTERN_PUNCTUATION = re.compile(r"[^\w\s]")
# Runs of ASCII Latin letters.
PATTERN_ENGLISH = re.compile("[a-zA-Z]+")
# A single CJK Unified Ideograph (U+4E00..U+9FFF); the capture group makes
# re.split keep each character as its own piece.
PATTERN_CHINESE = re.compile("([\u4e00-\u9fff])")


class CalculateBLEU(object):
    """Normalise transcripts and score hypotheses with corpus BLEU.

    ``rewrite`` writes a normalised copy of an ``<id> <text>`` file, and
    ``evaluate_bleu`` joins ASR, reference and hypothesis files by utterance
    id, writes a per-utterance report and returns corpus BLEU and mean AL.
    """

    # Default path/name passed to ``evaluate.load`` for the BLEU metric.
    DEFAULT_BLEU_METRIC = "/opt/nas/n/xiayinfeng/Whisper/metrics/bleu"

    def __init__(self):
        pass

    def rewrite(self, source_txt, target_txt):
        """Normalise a transcript file and write the result to ``target_txt``.

        Each stripped line of ``source_txt`` (UTF-8) is split on tabs if it
        contains one, otherwise on its first space. The first field is the id
        and the last field is the text. Lines with neither separator (including
        blank lines) are printed and skipped. The text is cleaned as follows:
        bracketed spans ((), {}, [], 【】, <>) are removed, "▁" becomes a
        space, punctuation becomes a space, the text is lower-cased, runs of
        spaces are collapsed, and spaces are put around every CJK character
        (see ``cn_norm``). ``target_txt`` is overwritten with
        ``<id>\\t<text>`` lines joined by newlines, without a trailing newline.
        """
        text = []
        with open(source_txt, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if "\t" in line:
                    fields = line.split("\t")
                elif " " in line:
                    fields = line.split(" ", 1)
                else:
                    print(line)
                    continue
                txt = re.sub("\\(.*?\\)|\\{.*?}|\\[.*?]|\\【.*?】|\\<.*?>", "", fields[-1])
                txt = txt.replace("▁", " ")
                txt = PATTERN_PUNCTUATION.sub(" ", txt)
                txt = txt.lower()
                txt = re.sub(" +", " ", txt)
                text.append(fields[0] + "\t" + self.cn_norm(txt))
        with open(target_txt, "w", encoding="utf-8") as fw:
            fw.write("\n".join(text))

    def cn_norm(self, txt):
        """Put single spaces around every CJK character (U+4E00..U+9FFF).

        The text is split on each CJK character (the capture group in
        ``PATTERN_CHINESE`` keeps the characters). Whitespace-only pieces are
        dropped and the remaining pieces are joined with single spaces.
        Non-CJK pieces keep their own whitespace, so applying this twice to
        mixed CJK/Latin text is NOT idempotent.
        """
        pieces = PATTERN_CHINESE.split(txt)
        return " ".join(p for p in pieces if p.strip())

    def evaluate_bleu(self, asr_path, ref_path, hyp_path, bleu_path, bleu_metric_path=None):
        """Score hypotheses against references with corpus BLEU and write a report.

        Input files are UTF-8, one utterance per line, tab-separated:
        ``asr_path``: ``<id>\\t<asr text>``; ``ref_path``: ``<id>\\t<reference>``;
        ``hyp_path``: ``<id>\\t<hypothesis>\\t<AL>``. All texts are passed
        through ``cn_norm`` once. An utterance is scored only if its id appears
        in ``asr_path`` and it has exactly one non-empty reference and one
        non-empty hypothesis. Ids missing from ``asr_path``, and hypothesis
        lines whose AL is not a float, are reported on stdout and skipped.

        Args:
            asr_path: ASR transcript file.
            ref_path: reference translation file.
            hyp_path: hypothesis file with an AL value per line.
            bleu_path: report file to (over)write.
            bleu_metric_path: path or name passed to ``evaluate.load``;
                defaults to ``DEFAULT_BLEU_METRIC``.

        Returns:
            tuple: (bleu, mean_al). ``bleu`` is the ``"bleu"`` field of the
            metric result; ``mean_al`` is the arithmetic mean of the scored
            utterances' AL values.

        Raises:
            ValueError: if no utterance has all three entries, or a line does
                not have the expected number of tab-separated fields.
        """
        if bleu_metric_path is None:
            bleu_metric_path = self.DEFAULT_BLEU_METRIC

        with open(asr_path, "r", encoding="UTF-8") as f:
            asr_lines = f.readlines()
        with open(ref_path, "r", encoding="UTF-8") as f:
            ref_lines = f.readlines()
        with open(hyp_path, "r", encoding="UTF-8") as f:
            hyp_lines = f.readlines()

        # id -> [asr, ref, hyp, al] once complete.
        entries = {}
        for line in asr_lines:
            ids, label = line.strip("\n").split("\t")
            entries[ids] = [self.cn_norm(label)]

        for line in ref_lines:
            ids, label = line.strip("\n").split("\t")
            # Normalise exactly once: cn_norm is not idempotent on mixed text.
            label = self.cn_norm(label)
            if not label:
                continue
            if ids not in entries:
                print(f"{ids} not in asr text")
                continue
            entries[ids].append(label)

        for line in hyp_lines:
            ids, label, al = line.strip("\n").split("\t")
            label = self.cn_norm(label)
            if not label:
                continue
            if ids not in entries:
                print(f"{ids} not in asr text")
                continue
            try:
                al_value = float(al)
            except ValueError:
                print(f"{ids} has invalid AL value: {al!r}")
                continue
            entries[ids].extend([label, al_value])

        complete = [key for key, values in entries.items() if len(values) == 4]
        if not complete:
            raise ValueError("no utterance has ASR, reference and hypothesis entries")

        # Load the metric before opening the report so a load failure does not
        # leave a truncated report behind.
        bleu = evaluate.load(bleu_metric_path)

        ref_lists = []
        hyp_lists = []
        all_als = []
        with open(bleu_path, "w", encoding="utf-8") as fout:
            for key in complete:
                asr, ref, hyp, al = entries[key]
                ref_lists.append([ref])
                hyp_lists.append(hyp)
                all_als.append(al)
                # Per-utterance BLEU is not computed; 0.0 is written as a placeholder.
                fout.write(f"utt: {key} \n")
                fout.write(f"BLEU: {0.0:.4f} AL: {al:.2f}\n")
                fout.write(f"asr: {asr} \n")
                fout.write(f"ref: {ref} \n")
                fout.write(f"hyp: {hyp} \n")
                fout.write("\n\n")
            results = bleu.compute(predictions=hyp_lists, references=ref_lists)
            ave_al = sum(all_als) / len(all_als)
            fout.write(f"BLEUs: {results['bleu']:.4f} AL: {ave_al:.2f}\n")
        return results["bleu"], ave_al
