import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Dict, Iterable, Optional, List


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a boolean mask that is True at padded positions.

    Args:
        lengths (torch.Tensor): valid length of every sequence, shape (B,).
        max_len (int): width of the mask. When not positive, the largest
            value in ``lengths`` is used instead.

    Returns:
        torch.Tensor: bool tensor of shape (B, max_len) on ``lengths.device``;
        entry ``[b, t]`` is True when ``t >= lengths[b]`` (i.e. padding).

    Examples:
        >>> make_pad_mask(torch.tensor([5, 3, 2]))
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    if max_len > 0:
        width = max_len
    else:
        width = lengths.max().item()
    batch = lengths.size(0)

    positions = torch.arange(0, width, dtype=torch.int64, device=lengths.device)
    # Row b holds 0..width-1; compare each against that row's length.
    grid = positions.unsqueeze(0).expand(batch, width)
    return grid >= lengths.unsqueeze(-1)


class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the input's dtype.

    The input is normalized in its own dtype (no float32 upcast before the
    normalization); only the result is converted to ``x.dtype``.
    """

    def forward(self, x: Tensor) -> Tensor:
        normalized = super().forward(x)
        return normalized.to(dtype=x.dtype)

class Linear(nn.Linear):
    """``nn.Linear`` that casts its weight and bias to the input's dtype
    on every call, so the stored parameters may use a different precision."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)

class GELUActivation(nn.Module):
    """GELU activation as used in Google's original BERT code.

    Note that OpenAI GPT uses a slightly different tanh approximation,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))))``, which
    gives slightly different results. By default this module delegates to
    the C implementation in ``nn.functional.gelu``; pass
    ``use_gelu_python=True`` to use the pure-Python erf formulation instead.
    See the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        self.act = self._gelu_python if use_gelu_python else nn.functional.gelu

    def _gelu_python(self, x: Tensor) -> Tensor:
        # x * Phi(x), where Phi is the standard normal CDF written via erf.
        half_x = x * 0.5
        return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)

class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that casts its weight and bias to the input's dtype
    right before the convolution is performed."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        cast_weight = weight.to(x.dtype)
        cast_bias = bias if bias is None else bias.to(x.dtype)
        return super()._conv_forward(x, cast_weight, cast_bias)


def sinusoids(length, channels, max_timescale=10000):
    """Return sinusoids for positional embedding.

    Args:
        length: number of positions (rows).
        channels: embedding width; must be even. The first ``channels // 2``
            columns are sines, the remaining ones cosines.
        max_timescale: largest timescale; the inverse timescales are spaced
            geometrically from 1 down to ``1 / max_timescale``.

    Returns:
        torch.Tensor of shape (length, channels).
    """
    assert channels % 2 == 0
    half = channels // 2
    # With a single frequency (channels == 2) the divisor ``half - 1`` used to
    # be 0, producing an infinite increment and NaN embeddings. Clamping it to
    # 1 yields the only sensible frequency (1.0) and leaves channels >= 4
    # unchanged.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a (size, size) chunk-wise attention mask for a streaming encoder.

    Row ``i`` is True for every column inside the chunk containing ``i`` and,
    depending on ``num_left_chunks``, for columns in preceding chunks.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk; must be positive when ``size > 0``
        num_left_chunks (int): number of left chunks
            <0: use full chunk (every earlier chunk is visible)
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: bool mask of shape (size, size); True = visible.

    Raises:
        ValueError: if ``size > 0`` and ``chunk_size <= 0``.

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        tensor([[ True,  True, False, False],
                [ True,  True, False, False],
                [ True,  True,  True,  True],
                [ True,  True,  True,  True]])
    """
    if size > 0 and chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    positions = torch.arange(size, device=device)
    chunk_index = torch.div(positions, chunk_size, rounding_mode="floor") if size > 0 else positions
    # Exclusive end of each row's window: the end of its own chunk. Columns
    # never exceed size - 1, so no explicit clamp to ``size`` is needed.
    ends = (chunk_index + 1) * chunk_size
    if num_left_chunks < 0:
        starts = torch.zeros_like(positions)
    else:
        starts = ((chunk_index - num_left_chunks) * chunk_size).clamp(min=0)

    cols = positions.unsqueeze(0)
    # Build the whole mask with one broadcast comparison instead of one
    # indexed assignment (and, on GPU, one kernel launch) per row.
    return (cols >= starts.unsqueeze(1)) & (cols < ends.unsqueeze(1))



def add_optional_chunk_mask(xs: torch.Tensor,
                            use_dynamic_chunk: bool = True,
                            use_dynamic_left_chunk: bool = False,
                            decoding_chunk_size: int = 0,
                            static_chunk_size: int = 16,
                            num_decoding_left_chunks: int = 10):
    """Build an optional chunk-based attention mask for the encoder.

    Args:
        xs (torch.Tensor): padded input. The sequence length L is read from
            ``xs.size(-1)`` (the LAST axis). NOTE(review): this implies a
            time-last layout such as (B, D, L); confirm against callers.
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to sample a random number of
            left chunks during training (only used when
            ``decoding_chunk_size == 0``).
        decoding_chunk_size (int): decoding chunk size for dynamic chunk:
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            when it is greater than 0; ignored when use_dynamic_chunk is True.
        num_decoding_left_chunks (int): number of left chunks, used with a
            positive ``decoding_chunk_size`` and in the static-chunk branch:
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: bool tensor of shape (1, L, L). It is the logical NOT of
        ``subsequent_chunk_mask``, so True marks positions OUTSIDE each row's
        visible chunk window.

    Raises:
        ValueError: if use_dynamic_chunk is False and static_chunk_size <= 0.
    """
    max_len = xs.size(-1)
    if use_dynamic_chunk:
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # Training: sample a random chunk size.
            normal_chunk = False  # experiment switch: Gaussian-sampled chunk size
            if normal_chunk:
                # gaussian: 1 sigma > 68%, 2 sigma > 95%, 3 sigma > 99%
                chunk_size = int(torch.normal(16.0, 4.0, size=(1,)).item())
                if chunk_size < 2 or chunk_size > 24:
                    chunk_size = max_len
            else:
                # torch.randint requires low < high; for max_len <= 1 fall back
                # to a chunk size of 1 instead of raising (any chunk size yields
                # the same mask for such short inputs).
                chunk_size = torch.randint(1, max(max_len, 2), (1,)).item()

            num_left_chunks = -1
            # NOTE(review): with normal_chunk disabled the sampled size is
            # always < max_len (except for empty input), so full context is
            # never selected here during training. Confirm this is intended.
            if chunk_size > max_len:
                chunk_size = max_len
            else:
                # Remap the sampled value into [32, 95].
                chunk_size = chunk_size % 64 + 32
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    if max_left_chunks > 0:
                        num_left_chunks = torch.randint(0, max_left_chunks, (1,)).item()
                    else:
                        # randint(0, 0) raises. Here every frame lies in the
                        # first chunk, so the left-chunk count has no effect.
                        num_left_chunks = 0

        chunk_masks = subsequent_chunk_mask(max_len, chunk_size, num_left_chunks, xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(max_len, static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
    else:
        raise ValueError("no chunk setting is specified")
    return ~chunk_masks


def cross_mask_process3(
    rets,  # no_causal attention
    mels_len,  # lengths of speech_feature
    all_tokens_num,  # nums of all token
    cif_pres,  # cif
    prompt_token_num: int = 4,  # nums of prompt token
    wait_k: int = 2,
    ):
    """Write a streaming (CIF + wait-k) token-by-frame mask into a copy of ``rets``.

    For batch item ``i``:

    * The CIF weights of its first ``mels_len[i]`` frames are rescaled to sum
      to ``all_tokens_num[i] - prompt_token_num`` and cumulatively summed.
    * Token ``t`` (0-based) gets the threshold
      ``max(t + 1 - prompt_token_num, 0) + wait_k``; ``wait_k`` is raised to
      at least 1.
    * Entry ``[t, f]`` is True when the cumulative CIF over frames
      ``0..f-1`` reaches that threshold. Frame 0 is always False.

    The resulting block replaces ``[i, :all_tokens_num[i], :mels_len[i]]``.
    Every other entry keeps its value from ``rets``, and ``rets`` itself is
    not modified.

    NOTE(review): True lands on the LATER frames a token has not yet waited
    for, so True presumably means "masked out" — confirm against the
    attention code that consumes this tensor.
    """
    wait_k = max(wait_k, 1)
    device = rets.device
    out = rets.clone()

    for idx in range(rets.size(0)):
        n_frames = mels_len[idx]
        n_tokens = all_tokens_num[idx]
        cif = cif_pres[idx][:n_frames]

        # Cumulative "tokens fired" after each frame, normalized so the total
        # equals the number of non-prompt tokens.
        scale = (n_tokens - prompt_token_num) / torch.sum(cif)
        fired = torch.cumsum(cif * scale, dim=0).to(device)

        # Fired-token count each token must wait for before a frame counts.
        token_ids = torch.arange(1, n_tokens + 1).to(device)
        need = torch.clip(token_ids - prompt_token_num, min=0) + wait_k

        reached = (need.unsqueeze(1).expand(-1, n_frames)
                   <= fired.unsqueeze(0).expand(n_tokens, -1))

        stream_mask = torch.zeros((n_tokens, n_frames), dtype=torch.bool, device=device)
        # Shift right by one frame: frame f uses the CIF total up to f - 1.
        stream_mask[:, 1:] = reached[:, :-1]
        out[idx, :n_tokens, :n_frames] = stream_mask
    return out

