import torch
import torch.nn as nn
import torch.nn.functional as FF
from pdb import set_trace as stx
import inspect
from einops import rearrange
from torchsummary import summary
from typing import Optional
from collections import OrderedDict
from typing import Sequence, Tuple, Union, Dict, Optional, Tuple
from .lip_encoder import LipEncoderClassifier
import math
from typing import Dict, List, Optional, Tuple
from torch.nn import init
from torch.nn.parameter import Parameter
from abc import ABC, abstractmethod
import difflib
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor


def is_torch_complex_tensor(c):
    """Tell whether ``c`` is a native complex ``torch.Tensor``.

    A ``ComplexTensor`` (from ``torch_complex``) is reported as ``False``;
    anything else is delegated to ``torch.is_complex``.
    """
    if isinstance(c, ComplexTensor):
        return False
    return torch.is_complex(c)
def is_complex(c):
    """Tell whether ``c`` is complex-valued in either supported representation.

    Returns ``True`` for a ``ComplexTensor`` instance, otherwise the result of
    ``is_torch_complex_tensor(c)``.
    """
    if isinstance(c, ComplexTensor):
        return True
    return is_torch_complex_tensor(c)
# Machine epsilon of the double-precision (float64) dtype, as reported by torch.finfo.
EPS = torch.finfo(torch.double).eps

def new_complex_like(
    ref: Union[torch.Tensor, ComplexTensor],
    real_imag: Tuple[torch.Tensor, torch.Tensor],
):
    """Build a complex tensor of the same flavour as ``ref``.

    Args:
        ref: Reference object; only its type is inspected.
        real_imag: Arguments forwarded positionally to the constructor,
            annotated as a ``(real, imag)`` pair.

    Returns:
        A ``ComplexTensor`` when ``ref`` is one, otherwise the result of
        ``torch.complex`` when ``ref`` is a native complex tensor.

    Raises:
        ValueError: If ``ref`` is neither kind of complex tensor.
    """
    if isinstance(ref, ComplexTensor):
        return ComplexTensor(*real_imag)
    if not is_torch_complex_tensor(ref):
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )
    return torch.complex(*real_imag)

class AbsSeparator(torch.nn.Module, ABC):
    """Abstract interface for separator modules.

    Concrete subclasses must implement ``forward`` and the ``num_spk``
    property; ``forward_streaming`` is optional and raises by default.
    """

    @abstractmethod
    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[Tuple[torch.Tensor], torch.Tensor, OrderedDict]:
        """Run the separator; must be overridden by subclasses.

        Args:
            input: Input tensor.
            ilens: Tensor accompanying ``input`` (presumably per-item lengths,
                judging by the name -- confirm with implementations).
            additional: Optional dictionary of extra inputs.

        Returns:
            A 3-tuple ``(tuple of tensors, tensor, OrderedDict)`` as annotated.
        """
        raise NotImplementedError

    def forward_streaming(
        self,
        input_frame: torch.Tensor,
        buffer=None,
    ):
        """Streaming entry point; not implemented unless a subclass overrides it."""
        raise NotImplementedError

    @property
    @abstractmethod
    def num_spk(self):
        """Abstract property; must be overridden by subclasses."""
        raise NotImplementedError

from packaging.version import parse as V

def get_layer(l_name, library=torch.nn):
    """Return layer object handler from library e.g. from torch.nn

    E.g. if l_name=="elu", returns torch.nn.ELU.

    Args:
        l_name (string): Case insensitive name for layer in library (e.g. 'elu').
        library (module): Library/module (any object supporting ``dir`` and
            ``getattr``) in which to search for an attribute named ``l_name``,
            e.g. ``torch.nn``.

    Returns:
        layer_handler (object): handler for the requested layer e.g. (torch.nn.ELU)

    Raises:
        NotImplementedError: If no attribute of ``library`` matches ``l_name``
            case-insensitively, or if more than one does.
    """
    # Search the library that was actually requested, not always torch.nn.
    available = dir(library)
    wanted = l_name.lower()
    match = [name for name in available if name.lower() == wanted]

    if not match:
        close_matches = difflib.get_close_matches(
            wanted, [name.lower() for name in available]
        )
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )
    if len(match) > 1:
        # Report the real colliding names rather than fuzzy suggestions.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), match)
        )
    return getattr(library, match[0])
    

class GridNetBlock(nn.Module):
    """Grid block: BiLSTM along the last axis, BiLSTM along the T axis, then
    multi-head self-attention across the T axis.

    Each RNN stage normalises its input, unfolds it into overlapping windows
    of ``emb_ks`` steps with stride ``emb_hs``, runs a bidirectional LSTM,
    folds the result back with a transposed convolution and adds a residual.

    Args:
        emb_dim: Number of channels ``C`` of the input/output.
        emb_ks: Window (kernel) size used by ``unfold`` and the transposed conv.
        emb_hs: Hop (stride) used by ``unfold`` and the transposed conv.
        n_freqs: Size of the last axis expected by the attention layer norms.
        hidden_channels: Hidden size of each LSTM direction.
        n_head: Number of attention heads; must divide ``emb_dim``.
        approx_qk_dim: Approximate flattened query/key size; the per-head
            channel count is ``ceil(approx_qk_dim / n_freqs)``.
        activation: Case-insensitive name of a ``torch.nn`` activation class.
        eps: Epsilon used by the normalisation layers.
    """

    def __getitem__(self, key):
        """Allow ``self["name"]`` access to attributes/submodules by name."""
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        super().__init__()

        # self.att_freq = nn.MultiheadAttention(emb_dim, n_head, batch_first=True)
        # self.norm_att_freq = nn.LayerNorm(emb_dim, eps=eps)
        # self.att_time = nn.MultiheadAttention(emb_dim, n_head, batch_first=True)
        # self.norm_att_time = nn.LayerNorm(emb_dim, eps=eps)

        # Each unfolded window stacks emb_ks steps of emb_dim channels.
        in_channels = emb_dim * emb_ks

        self.intra_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
        )

        self.inter_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
        )

        E = math.ceil(
            approx_qk_dim * 1.0 / n_freqs
        )  # approx_qk_dim is only approximate
        assert emb_dim % n_head == 0
        # Submodule names are kept as-is so existing checkpoints still load.
        for ii in range(n_head):
            self.add_module(
                "attn_conv_Q_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_K_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_V_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, emb_dim // n_head, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps),
                ),
            )
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )
        # self.FFN = FeedForward(emb_dim, 4, bias=True)  # 64*4  256

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def forward(self, x):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]

        Returns:
            out: [B, C, T, Q] (same T and Q as the input; padding added
            internally is cropped before the attention stage).
        """
        B, C, old_T, old_Q = x.shape
        # Pad both axes so that windows of size emb_ks with hop emb_hs tile them
        # exactly; the transposed convs then restore the padded lengths.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = FF.pad(x, (0, Q - old_Q, 0, T - old_T))

        # intra RNN: sequences run along the last (Q) axis
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = (
            intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)
        )  # [BT, C, Q]
        intra_rnn = FF.unfold(
            intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
        )  # [BT, C*emb_ks, -1]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, C*emb_ks]
        intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, H]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, H, -1]
        intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]

        intra_rnn = intra_rnn + input_  # [B, C, T, Q]

        # inter RNN: sequences run along the T axis
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)  # [B, C, T, Q]
        inter_rnn = (
            inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        )  # [BQ, C, T]

        # inter_rnn = inter_rnn.transpose(1, 2)
        # att_out_time, _ = self.att_time(inter_rnn, inter_rnn, inter_rnn)
        # att_out_time = att_out_time + inter_rnn
        # att_out_time = self.norm_att_time(att_out_time)
        # inter_rnn = att_out_time.transpose(1, 2)

        inter_rnn = FF.unfold(
            inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
        )  # [BQ, C*emb_ks, -1]
        inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, -1, C*emb_ks]
        inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BQ, -1, H]
        inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, H, -1]
        inter_rnn = self.inter_linear(inter_rnn)  # [BQ, C, T]
        inter_rnn = inter_rnn.view([B, Q, C, T])
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()  # [B, C, T, Q]
        inter_rnn = inter_rnn + input_  # [B, C, T, Q]

        # attention: drop the padding first
        inter_rnn = inter_rnn[..., :old_T, :old_Q]
        batch = inter_rnn

        all_Q, all_K, all_V = [], [], []
        for ii in range(self.n_head):
            all_Q.append(self["attn_conv_Q_%d" % ii](batch))  # [B, E, T, Q]
            all_K.append(self["attn_conv_K_%d" % ii](batch))  # [B, E, T, Q]
            all_V.append(self["attn_conv_V_%d" % ii](batch))  # [B, C/n_head, T, Q]

        # Heads are stacked along the batch axis: B' = n_head * B.
        query = torch.cat(all_Q, dim=0)  # [B', E, T, Q]
        key = torch.cat(all_K, dim=0)  # [B', E, T, Q]
        value = torch.cat(all_V, dim=0)  # [B', C/n_head, T, Q]

        query = query.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        key = key.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        value = value.transpose(1, 2)  # [B', T, C/n_head, Q]
        old_shape = value.shape
        value = value.flatten(start_dim=2)  # [B', T, C/n_head*Q]
        qk_dim = query.shape[-1]

        # Scaled dot-product attention over the T axis.
        attn_mat = torch.matmul(query, key.transpose(1, 2)) / (qk_dim**0.5)  # [B', T, T]
        attn_mat = FF.softmax(attn_mat, dim=2)  # [B', T, T]
        value = torch.matmul(attn_mat, value)  # [B', T, C/n_head*Q]

        value = value.reshape(old_shape)  # [B', T, C/n_head, Q]
        value = value.transpose(1, 2)  # [B', C/n_head, T, Q]
        head_dim = value.shape[1]

        # Un-stack the heads from the batch axis and concatenate them on channels.
        batch = value.view([self.n_head, B, head_dim, old_T, -1])  # [n_head, B, C/n_head, T, Q]
        batch = batch.transpose(0, 1)  # [B, n_head, C/n_head, T, Q]
        batch = batch.contiguous().view(
            [B, self.n_head * head_dim, old_T, -1]
        )  # [B, C, T, Q]
        batch = self["attn_concat_proj"](batch)  # [B, C, T, Q]

        out = batch + inter_rnn
        return out


class MultiRangeGridNetBlock(nn.Module):
    """Grid block with three parallel BiLSTM branches per axis.

    For each of the two stages ("intra": sequences along the last axis,
    "inter": sequences along the T axis) the input is processed by three
    branches -- a plain BiLSTM, a BiLSTM over unfolded windows of 4 steps and
    a BiLSTM over unfolded windows of 8 steps (both with stride 1). The branch
    outputs are concatenated on the channel axis, fused by a 1x1 convolution
    plus normalisation, and added to the stage input. Multi-head
    self-attention across the T axis follows.

    NOTE: the ``emb_ks`` and ``emb_hs`` constructor arguments are accepted but
    not used; ``self.emb_ks`` / ``self.emb_hs`` are fixed to 8 and 1.
    """

    def __getitem__(self, key):
        """Allow ``self["name"]`` access to attributes/submodules by name."""
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        """Build the branch, fusion and attention submodules.

        Args:
            emb_dim: Number of channels of the input/output.
            emb_ks: Unused (window sizes are hard-coded to 4 and 8 below).
            emb_hs: Unused (strides are hard-coded to 1 below).
            n_freqs: Size of the last axis expected by the attention layer norms.
            hidden_channels: Hidden size of each LSTM direction.
            n_head: Number of attention heads; must divide ``emb_dim``.
            approx_qk_dim: Approximate flattened query/key size; the per-head
                channel count is ``ceil(approx_qk_dim / n_freqs)``.
            activation: Case-insensitive layer name resolved via ``get_layer``.
            eps: Epsilon used by the normalisation layers.
        """
        super().__init__()
        
        # -------------------- intra --------------------
        # Branch 1: Do not use unfold, directly process the original sequence with LSTM
        self.intra_branch1_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_branch1_rnn = nn.LSTM(
            emb_dim, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_branch1_linear = nn.Conv1d(
            hidden_channels * 2, emb_dim, kernel_size=1
        )
        
        # Branch 2: Use unfold, ks=4, hs=1
        self.intra_branch2_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_branch2_rnn = nn.LSTM(
            emb_dim * 4, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_branch2_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 4, stride=1
        )
        
        # Branch 3: Use unfold, ks=8, hs=1
        self.intra_branch3_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_branch3_rnn = nn.LSTM(
            emb_dim * 8, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_branch3_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 8, stride=1
        )
        

        # Fuse the three concatenated intra branches back to emb_dim channels.
        self.intra_fusion_conv = nn.Conv2d(emb_dim * 3, emb_dim, kernel_size=1)
        self.intra_fusion_norm = LayerNormalization4D(emb_dim, eps=eps)
        
        # -------------------- inter --------------------
        # Branch 1: Do not use unfold, directly process the original sequence with LSTM
        self.inter_branch1_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_branch1_rnn = nn.LSTM(
            emb_dim, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_branch1_linear = nn.Conv1d(
            hidden_channels * 2, emb_dim, kernel_size=1
        )
        
        # Branch 2: Use unfold, ks=4, hs=1
        self.inter_branch2_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_branch2_rnn = nn.LSTM(
            emb_dim * 4, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_branch2_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 4, stride=1
        )
        
        # Branch 3: Use unfold, ks=8, hs=1
        self.inter_branch3_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_branch3_rnn = nn.LSTM(
            emb_dim * 8, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_branch3_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 8, stride=1
        )

        

        # Fuse the three concatenated inter branches back to emb_dim channels.
        self.inter_fusion_conv = nn.Conv2d(emb_dim * 3, emb_dim, kernel_size=1)
        self.inter_fusion_norm = LayerNormalization4D(emb_dim, eps=eps)
        

        # Per-head query/key channel count; approx_qk_dim is only approximate.
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0
        for ii in range(n_head):
            self.add_module(
                "attn_conv_Q_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_K_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_V_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, emb_dim // n_head, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps),
                ),
            )
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )
        
        self.emb_dim = emb_dim
        self.n_head = n_head
        # Hard-coded to the largest branch window / stride; the constructor
        # arguments of the same names are ignored.
        self.emb_ks = 8
        self.emb_hs = 1

    def forward(self, x):
        """Apply the intra stage, the inter stage and attention.

        Args:
            x: 4-D tensor unpacked as [B, C, T, Q].

        Returns:
            Tensor of shape [B, C, T, Q].
        """
        B, C, old_T, old_Q = x.shape
        

        # With emb_hs == 1 these expressions evaluate to T == old_T and
        # Q == old_Q, so the pad below adds nothing.
        # NOTE(review): inputs with T or Q smaller than 8 are therefore not
        # padded up to the ks=8 unfold window -- confirm callers guarantee this.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = FF.pad(x, (0, Q - old_Q, 0, T - old_T))
        
    
        # Block input, re-used for the final residual connection.
        residual = x
        
        # -------------------- intra --------------------
        intra_input = x
        
 
        # Branch 1: plain BiLSTM along the last (Q) axis.
        intra_b1 = self.intra_branch1_norm(intra_input)
        intra_b1 = intra_b1.permute(0, 2, 3, 1).contiguous().view(B * T, Q, C)
        intra_b1, _ = self.intra_branch1_rnn(intra_b1)  # [BT, Q, H*2]
        intra_b1 = intra_b1.transpose(1, 2)  # [BT, H*2, Q]
        intra_b1 = self.intra_branch1_linear(intra_b1)  # [BT, C, Q]
        intra_b1 = intra_b1.view(B, T, C, Q).permute(0, 2, 1, 3).contiguous()
        
        
        # Branch 2: BiLSTM over windows of 4 steps (stride 1) along Q.
        intra_b2 = self.intra_branch2_norm(intra_input)
        intra_b2 = intra_b2.transpose(1, 2).contiguous().view(B * T, C, Q)
        intra_b2 = FF.unfold(intra_b2[..., None], (4, 1), stride=(1, 1))
        intra_b2 = intra_b2.transpose(1, 2)  # [BT, -1, C*4]
        intra_b2, _ = self.intra_branch2_rnn(intra_b2)
        intra_b2 = intra_b2.transpose(1, 2)
        intra_b2 = self.intra_branch2_linear(intra_b2)  # [BT, C, Q]
        intra_b2 = intra_b2.view(B, T, C, Q).transpose(1, 2).contiguous()
        
        
        # Branch 3: BiLSTM over windows of 8 steps (stride 1) along Q.
        intra_b3 = self.intra_branch3_norm(intra_input)
        intra_b3 = intra_b3.transpose(1, 2).contiguous().view(B * T, C, Q)
        intra_b3 = FF.unfold(intra_b3[..., None], (8, 1), stride=(1, 1))
        intra_b3 = intra_b3.transpose(1, 2)  # [BT, -1, C*8]
        intra_b3, _ = self.intra_branch3_rnn(intra_b3)
        intra_b3 = intra_b3.transpose(1, 2)
        intra_b3 = self.intra_branch3_linear(intra_b3)  # [BT, C, Q]
        intra_b3 = intra_b3.view(B, T, C, Q).transpose(1, 2).contiguous()
        

        # Concatenate the branches on the channel axis and fuse back to C channels.
        intra_concat = torch.cat([intra_b1, intra_b2, intra_b3], dim=1)
        intra_fused = self.intra_fusion_conv(intra_concat)
        intra_fused = self.intra_fusion_norm(intra_fused)
        
        
        # Residual connection around the intra stage.
        intra_output = intra_fused + intra_input
        
        # -------------------- inter --------------------
        inter_input = intra_output
        
        # Branch 1: plain BiLSTM along the T axis.
        inter_b1 = self.inter_branch1_norm(inter_input)
        inter_b1 = inter_b1.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        inter_b1 = inter_b1.transpose(1, 2)
        inter_b1, _ = self.inter_branch1_rnn(inter_b1)  # [BQ, T, H*2]
        inter_b1 = inter_b1.transpose(1, 2)  # [BQ, H*2, T]
        inter_b1 = self.inter_branch1_linear(inter_b1)  # [BQ, C, T]
        inter_b1 = inter_b1.view(B, Q, C, T).permute(0, 2, 3, 1).contiguous()
        
        # Branch 2: BiLSTM over windows of 4 steps (stride 1) along T.
        inter_b2 = self.inter_branch2_norm(inter_input)
        inter_b2 = inter_b2.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        inter_b2 = FF.unfold(inter_b2[..., None], (4, 1), stride=(1, 1))
        inter_b2 = inter_b2.transpose(1, 2)  # [BQ, -1, C*4]
        inter_b2, _ = self.inter_branch2_rnn(inter_b2)
        inter_b2 = inter_b2.transpose(1, 2)
        inter_b2 = self.inter_branch2_linear(inter_b2)  # [BQ, C, T]
        inter_b2 = inter_b2.view(B, Q, C, T).permute(0, 2, 3, 1).contiguous()
        
        # Branch 3: BiLSTM over windows of 8 steps (stride 1) along T.
        inter_b3 = self.inter_branch3_norm(inter_input)
        inter_b3 = inter_b3.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        inter_b3 = FF.unfold(inter_b3[..., None], (8, 1), stride=(1, 1))
        inter_b3 = inter_b3.transpose(1, 2)  # [BQ, -1, C*8]
        inter_b3, _ = self.inter_branch3_rnn(inter_b3)
        inter_b3 = inter_b3.transpose(1, 2)
        inter_b3 = self.inter_branch3_linear(inter_b3)  # [BQ, C, T]
        inter_b3 = inter_b3.view(B, Q, C, T).permute(0, 2, 3, 1).contiguous()
        
        # Concatenate the branches on the channel axis and fuse back to C channels.
        inter_concat = torch.cat([inter_b1, inter_b2, inter_b3], dim=1)
        inter_fused = self.inter_fusion_conv(inter_concat)
        inter_fused = self.inter_fusion_norm(inter_fused)
        
        # Residual connection around the inter stage.
        inter_output = inter_fused + inter_input
        
        # Crop back to the original T and Q sizes.
        inter_output = inter_output[..., :old_T, :old_Q]
        
        # Multi-head self-attention across the T axis.
        batch = inter_output
        
        all_Q, all_K, all_V = [], [], []
        for ii in range(self.n_head):
            all_Q.append(self["attn_conv_Q_%d" % ii](batch))
            all_K.append(self["attn_conv_K_%d" % ii](batch))
            all_V.append(self["attn_conv_V_%d" % ii](batch))

        # Heads are stacked along the batch axis (n_head * B); note that the
        # name Q is re-bound here from the padded axis length to the queries.
        Q = torch.cat(all_Q, dim=0)
        K = torch.cat(all_K, dim=0)
        V = torch.cat(all_V, dim=0)

        # Flatten channel and last axes so that each T step is one token.
        Q = Q.transpose(1, 2)
        Q = Q.flatten(start_dim=2)
        K = K.transpose(1, 2)
        K = K.flatten(start_dim=2)
        V = V.transpose(1, 2)
        old_shape = V.shape
        V = V.flatten(start_dim=2)
        emb_dim = Q.shape[-1]

        # Scaled dot-product attention; softmax over the key (last) axis.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)
        attn_mat = FF.softmax(attn_mat, dim=2)
        V = torch.matmul(attn_mat, V)

        V = V.reshape(old_shape)
        V = V.transpose(1, 2)
        emb_dim = V.shape[1]

        # Un-stack the heads from the batch axis and concatenate them on channels.
        batch = V.view([self.n_head, B, emb_dim, old_T, -1])
        batch = batch.transpose(0, 1)
        batch = batch.contiguous().view([B, self.n_head * emb_dim, old_T, -1])
        batch = self["attn_concat_proj"](batch)

        # Final residual uses the block input rather than inter_output.
        # NOTE(review): GridNetBlock in this file adds the attention output to
        # the inter-stage output instead -- confirm this difference is intended.
        out = batch + residual[..., :old_T, :old_Q]
        return out

class LayerNormalization4D(nn.Module):
    """Layer normalisation over axis 1 of a 4-D tensor.

    Mean and (biased) variance are taken over axis 1 independently for every
    position of the remaining axes; the result is scaled by ``gamma`` and
    shifted by ``beta``, both of shape ``[1, input_dimension, 1, 1]``.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        shape = (1, input_dimension, 1, 1)
        # gamma starts at one and beta at zero, i.e. an identity affine map.
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` over axis 1.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        reduce_axes = (1,)
        mean = x.mean(dim=reduce_axes, keepdim=True)
        variance = x.var(dim=reduce_axes, unbiased=False, keepdim=True)
        scale = torch.sqrt(variance + self.eps)
        normalised = (x - mean) / scale
        return normalised * self.gamma + self.beta


class LayerNormalization4DCF(nn.Module):
    """Layer normalisation over axes 1 and 3 of a 4-D tensor.

    ``input_dimension`` is a pair giving the sizes of axes 1 and 3; the
    learnable ``gamma`` / ``beta`` have shape
    ``[1, input_dimension[0], 1, input_dimension[1]]``.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        shape = (1, input_dimension[0], 1, input_dimension[1])
        # gamma starts at one and beta at zero, i.e. an identity affine map.
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` jointly over axes 1 and 3.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        reduce_axes = (1, 3)
        mean = x.mean(dim=reduce_axes, keepdim=True)  # [B,1,T,1]
        variance = x.var(dim=reduce_axes, unbiased=False, keepdim=True)  # [B,1,T,1]
        scale = torch.sqrt(variance + self.eps)
        normalised = (x - mean) / scale
        return normalised * self.gamma + self.beta


def to_3d(x):
    """Rearrange a ``(b, c, h, w)`` input to ``(b, h*w, c)`` via einops."""
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)

def to_4d(x, h, w):
    """Rearrange a ``(b, h*w, c)`` input back to ``(b, c, h, w)`` via einops.

    ``h`` and ``w`` give the sizes used to split the merged middle axis.
    """
    pattern = 'b (h w) c -> b c h w'
    return rearrange(x, pattern, h=h, w=w)

class BaseEncoder(nn.Module):
    """Base class for encoders: shape helpers, padding and config export.

    Subclasses are expected to override ``forward``.
    """

    def unsqueeze_to_3D(self, x: torch.Tensor):
        """Reshape 1-D input to ``(1, 1, L)`` and 2-D input to ``(B, 1, L)``.

        Inputs of any other rank are returned unchanged.
        """
        if x.ndim == 1:
            return x.reshape(1, 1, -1)
        if x.ndim == 2:
            return x.unsqueeze(1)
        return x

    def unsqueeze_to_2D(self, x: torch.Tensor):
        """Reshape 1-D input to ``(1, L)`` and ``(B, 1, L)`` input to ``(B, L)``.

        A 3-D input must have size 1 on axis 1 (checked with ``assert``).
        Inputs of any other rank are returned unchanged.
        """
        if x.ndim == 1:
            return x.reshape(1, -1)
        shape = x.shape
        if len(shape) != 3:
            return x
        assert shape[1] == 1
        return x.reshape(shape[0], -1)

    def pad(self, x: torch.Tensor, lcm: int):
        """Zero-pad the last axis of ``x`` up to the next multiple of ``lcm``.

        Returns ``x`` itself when its last axis is already a multiple.
        """
        remainder = int(x.shape[-1]) % lcm
        if not remainder:
            return x
        filler_shape = list(x.shape[:-1]) + [lcm - remainder]
        filler = torch.zeros(filler_shape, dtype=x.dtype, device=x.device)
        return torch.cat([x, filler], dim=-1)

    def get_out_chan(self):
        """Return ``self.out_chan`` (not set by this base class)."""
        return self.out_chan

    def forward(self, *args, **kwargs):
        """Must be overridden by subclasses."""
        raise NotImplementedError

    def get_config(self):
        """Collect the public, non-method entries of the instance ``__dict__``.

        Keys starting with ``_`` and the ``training`` flag are skipped.
        """
        return {
            name: value
            for name, value in self.__dict__.items()
            if not name.startswith("_")
            and name != "training"
            and not inspect.ismethod(value)
        }

class STFTEncoder(BaseEncoder):
    """Encoder that turns waveforms into stacked real/imaginary STFT features.

    A square-root Hann window of ``win_length`` samples is registered as a
    non-persistent buffer and used for the transform.

    NOTE(review): the defaults give ``win_length`` (512) larger than ``n_fft``
    (256); ``torch.stft`` documents ``win_length <= n_fft`` -- confirm callers
    always pass compatible values.
    """

    def __init__(
        self,
        win_length: int = 512,         # 32ms @16kHz
        hop_length: int = 128,         # 8ms @16kHz
        n_fft: int = 256,              # FFT size
        bias: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.bias = bias

        window = torch.sqrt(torch.hann_window(self.win_length))
        self.register_buffer("window", window, persistent=False)

    def forward(self, x: torch.Tensor):
        """Compute the STFT of ``x``.

        Args:
            x: Waveform(s); reshaped to 2-D by ``unsqueeze_to_2D`` first.

        Returns:
            Real tensor of shape ``(B, 2, T, F)`` holding the real part at
            index 0 and the imaginary part at index 1 of axis 1.
        """
        waveform = self.unsqueeze_to_2D(x)

        complex_spec = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(waveform.device),
            return_complex=True,
        )

        # Stack real/imag on a new axis 1, then swap the last two axes: B, 2, T, F
        real_imag = torch.stack([complex_spec.real, complex_spec.imag], 1)
        return real_imag.transpose(2, 3).contiguous()


class BaseDecoder(nn.Module):
    """Base class for decoders: length padding and config export.

    Subclasses are expected to override ``forward``.
    """

    def pad_to_input_length(self, separated_audio, input_frames):
        """Right-pad the last axis of ``separated_audio`` to ``input_frames``."""
        missing = input_frames - separated_audio.shape[-1]
        return nn.functional.pad(separated_audio, [0, missing])

    def forward(self, *args, **kwargs):
        """Must be overridden by subclasses."""
        raise NotImplementedError

    def get_config(self):
        """Collect the public, non-method entries of the instance ``__dict__``.

        Keys starting with ``_`` and the ``training`` flag are skipped.
        """
        return {
            name: value
            for name, value in self.__dict__.items()
            if not name.startswith("_")
            and name != "training"
            and not inspect.ismethod(value)
        }

class STFTDecoder_1(BaseDecoder):
    """Decoder that inverts stacked real/imaginary STFT features to waveforms.

    A square-root Hann window of ``win_length`` samples is registered as a
    non-persistent buffer and used for the inverse transform.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        n_fft: int = 256,
        in_chan: int = 2,
        n_src: int = 2,
        kernel_size: int = -1,
        stride: int = 1,
        bias: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.in_chan = in_chan
        self.n_src = n_src
        self.kernel_size = kernel_size
        # "Same"-style padding for a positive kernel size, otherwise none.
        if self.kernel_size > 0:
            self.padding = (self.kernel_size - 1) // 2
        else:
            self.padding = 0
        self.stride = stride
        self.bias = bias

        window = torch.sqrt(torch.hann_window(self.win_length))
        self.register_buffer("window", window, persistent=False)

    def forward(self, x, input_shape):
        """Reconstruct waveforms from a 5-D spectrogram tensor.

        Args:
            x: Tensor unpacked as ``(B, n_src, N, T, F)``; index 0 / 1 of the
                ``N`` axis are used as real / imaginary parts.
            input_shape: Object whose ``.shape`` supplies the batch size
                (first entry) and the output length (last entry).

        Returns:
            Tensor of shape ``(batch_size, self.n_src, length)``.
        """
        n_batch, n_sources, n_parts, n_frames, n_bins = x.shape
        ref_shape = input_shape.shape
        batch_size, length = ref_shape[0], ref_shape[-1]

        # Fold the source axis into the batch: (B*n_src, 2, T, F)
        stacked = x.view(n_batch * n_sources, n_parts, n_frames, n_bins)
        complex_spec = torch.complex(stacked[:, 0], stacked[:, 1])
        complex_spec = complex_spec.transpose(1, 2).contiguous()  # (B*n_src, F, T)

        waveform = torch.istft(
            complex_spec,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(stacked.device),
            length=length,
        )  # (B*n_src, L)

        return waveform.view(batch_size, self.n_src, length)  # (B, n_src, L)


def _get_activation_fn(activation):
    if activation == "relu":
        return FF.relu
    elif activation == "gelu":
        return FF.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class multi_OverlapPatchEmbed(nn.Module):
    """Four parallel 2-D convolutions concatenated on the channel axis.

    The branches are a 1x1 convolution and three 3x3 convolutions with
    dilation 1, 2 and 3 (padding equal to the dilation, so spatial size is
    kept). The output has ``4 * embed_dim`` channels. ``bias`` is accepted
    but not used.
    """

    def __init__(self, in_c=2, embed_dim=64, bias=False):
        super().__init__()

        # self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv_1x1 = nn.Conv2d(in_c, embed_dim, kernel_size=1, dilation=1, padding=0)
        for rate in (1, 2, 3):
            branch = nn.Conv2d(in_c, embed_dim, kernel_size=3, dilation=rate, padding=rate)
            setattr(self, "conv_3x3_d%d" % rate, branch)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on axis 1."""
        branches = (
            self.conv_1x1,
            self.conv_3x3_d1,
            self.conv_3x3_d2,
            self.conv_3x3_d3,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)



class multi_OverlapPatchEmbed_v(nn.Module):
    """Four parallel 1-D convolutions concatenated on the channel axis.

    The branches are a kernel-1 convolution and three kernel-3 convolutions
    with dilation 1, 2 and 3 (padding equal to the dilation, so the sequence
    length is kept). The output has ``4 * embed_dim`` channels. ``bias`` is
    accepted but not used.
    """

    def __init__(self, in_c=1024, embed_dim=64, bias=False):
        super().__init__()

        # self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv_1x1 = nn.Conv1d(in_c, embed_dim, kernel_size=1, dilation=1, padding=0)
        for rate in (1, 2, 3):
            branch = nn.Conv1d(in_c, embed_dim, kernel_size=3, dilation=rate, padding=rate)
            setattr(self, "conv_3x3_d%d" % rate, branch)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on axis 1."""
        branches = (
            self.conv_1x1,
            self.conv_3x3_d1,
            self.conv_3x3_d2,
            self.conv_3x3_d3,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)

class MultiRangeConv2d(nn.Module):
    """Four parallel 2-D convolutions concatenated on the channel axis.

    The branches are a 1x1 convolution and three 3x3 convolutions with
    dilation 1, 2 and 3 (padding equal to the dilation, so spatial size is
    kept). The output has ``4 * out_channels_each`` channels.
    """

    def __init__(self, in_channels, out_channels_each):
        super().__init__()

        self.conv_1x1 = nn.Conv2d(in_channels, out_channels_each, kernel_size=1, dilation=1, padding=0)
        for rate in (1, 2, 3):
            branch = nn.Conv2d(
                in_channels, out_channels_each, kernel_size=3, dilation=rate, padding=rate
            )
            setattr(self, "conv_3x3_d%d" % rate, branch)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on axis 1."""
        branches = (
            self.conv_1x1,
            self.conv_3x3_d1,
            self.conv_3x3_d2,
            self.conv_3x3_d3,
        )
        # concatenate on channel dim
        return torch.cat([branch(x) for branch in branches], dim=1)

class LayerNormFunction(torch.autograd.Function):
    """Channel-wise layer normalisation of a 4-D tensor with a manual backward.

    Statistics are taken over axis 1; ``weight`` and ``bias`` hold one value
    per channel and are broadcast as ``(1, C, 1, 1)``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalise ``x`` over axis 1, then apply the per-channel affine map."""
        ctx.eps = eps
        _, C, _, _ = x.size()
        mu = x.mean(1, keepdim=True)
        centered = x - mu
        var = centered.pow(2).mean(1, keepdim=True)
        y = centered / (var + eps).sqrt()
        # The normalised activations (before the affine map) are what backward needs.
        ctx.save_for_backward(y, var, weight)
        return weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``x``, ``weight``, ``bias`` and ``None`` for ``eps``."""
        eps = ctx.eps
        _, C, _, _ = grad_output.size()
        y, var, weight = ctx.saved_tensors

        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)

        # Parameter gradients: reduce over every axis except the channel axis.
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return gx, grad_weight, grad_bias, None

class LayerNorm2d(nn.Module):
    """Module wrapper around ``LayerNormFunction``.

    Holds a per-channel ``weight`` (initialised to ones) and ``bias``
    (initialised to zeros) and applies them after normalisation.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        # Assigning an nn.Parameter attribute registers it under that name.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` with the stored weight, bias and epsilon."""
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
    """Create a bias-free ``nn.Conv2d`` with the given hyper-parameters."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Create a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return layer


from typing import *
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter, init
from einops import rearrange
from einops.layers.torch import Rearrange
import math 
import difflib
from torch.nn import *
import math
import numpy as np

class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that can normalise axis 1 instead of the last axis.

    Args:
        seq_last (bool): When true, axes 1 and -1 are swapped before and after
            the normalisation, e.g. ``[B, H, Seq] -> [B, Seq, H]`` and back.
        **kwargs: Forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, seq_last: bool, **kwargs) -> None:
        super().__init__(**kwargs)
        self.seq_last = seq_last

    def forward(self, input: Tensor) -> Tensor:
        """Normalise ``input``, swapping axes 1 and -1 around the call if needed."""
        if not self.seq_last:
            return super().forward(input)
        # [B, H, Seq] -> [B, Seq, H], or [B, H, w, h] -> [B, h, w, H]
        swapped = input.transpose(-1, 1)
        return super().forward(swapped).transpose(-1, 1)


class GlobalLayerNorm(nn.Module):
    """Layer norm whose statistics are taken jointly over dims 1 and 2.

    Args:
        dim_hidden (int): size of the hidden (feature) dimension.
        seq_last (bool): whether the sequence dim is the last dim; selects
            an affine-parameter shape of ``[dim_hidden, 1]`` (True) or
            ``[dim_hidden]`` (False).
        eps (float): value added to the variance before the square root.
    """

    def __init__(self, dim_hidden: int, seq_last: bool, eps: float = 1e-5) -> None:
        super().__init__()
        self.dim_hidden = dim_hidden
        self.seq_last = seq_last
        self.eps = eps

        # A trailing singleton lets the parameters broadcast over a last sequence dim.
        affine_shape = [dim_hidden, 1] if seq_last else [dim_hidden]
        self.weight = Parameter(torch.ones(affine_shape))
        self.bias = Parameter(torch.zeros(affine_shape))

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input`` (shape [B, Seq, H] or [B, H, Seq]) and apply the affine."""
        variance, mu = torch.var_mean(input, dim=(1, 2), unbiased=False, keepdim=True)
        normalized = (input - mu) / torch.sqrt(variance + self.eps)
        return normalized * self.weight + self.bias

    def extra_repr(self) -> str:
        return f'{self.dim_hidden}, seq_last={self.seq_last}, eps={self.eps}'


class LayerNormalization4D(nn.Module):
    """Normalize a 4-D tensor over dim 1 with a learnable per-channel affine.

    ``gamma`` (ones) and ``beta`` (zeros) have shape
    ``[1, input_dimension, 1, 1]``. The variance is clamped from below by
    ``eps`` before the square root is taken.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        affine_shape = (1, input_dimension, 1, 1)
        self.gamma = Parameter(torch.ones(affine_shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(affine_shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        # CUDA autocast is switched off for the statistics computation.
        with torch.autocast(device_type="cuda", enabled=False):
            if x.ndim != 4:
                raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
            reduce_dims = (1,)
            centered = x - x.mean(dim=reduce_dims, keepdim=True)
            variance = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
            spread = torch.sqrt(torch.clamp(variance, self.eps))
            return centered / spread * self.gamma + self.beta


class LayerNormalization4DCF(nn.Module):
    """Normalize a 4-D tensor jointly over dims 1 and 3 with a learnable affine.

    ``input_dimension`` must be a pair; ``gamma`` (ones) and ``beta`` (zeros)
    have shape ``[1, input_dimension[0], 1, input_dimension[1]]``. The
    variance is clamped from below by ``eps`` before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        dim_c, dim_f = input_dimension[0], input_dimension[1]
        affine_shape = (1, dim_c, 1, dim_f)
        self.gamma = Parameter(torch.ones(affine_shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(affine_shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        # CUDA autocast is switched off for the statistics computation.
        with torch.autocast(device_type="cuda", enabled=False):
            if x.ndim != 4:
                raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
            reduce_dims = (1, 3)
            centered = x - x.mean(dim=reduce_dims, keepdim=True)
            variance = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
            spread = torch.sqrt(torch.clamp(variance, self.eps))
            return centered / spread * self.gamma + self.beta

import torch as th
import torch.nn as nn


class ResBlock(nn.Module):
    """
    Resnet block for speaker encoder to obtain speaker embedding
    ref to
        https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
        and
        https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py

    Two 1x1 Conv1d + BatchNorm1d stages with PReLU activations, a residual
    connection (projected by a 1x1 conv when ``in_dims != out_dims``) and a
    final ``MaxPool1d(3)``.
    """

    def __init__(self, in_dims, out_dims):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        # The skip path needs a projection only when the channel count changes.
        self.downsample = in_dims != out_dims
        if self.downsample:
            self.conv_downsample = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)

    def forward(self, x):
        hidden = self.prelu1(self.batch_norm1(self.conv1(x)))
        hidden = self.batch_norm2(self.conv2(hidden))
        shortcut = self.conv_downsample(x) if self.downsample else x
        hidden += shortcut
        return self.maxpool(self.prelu2(hidden))

class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape

    All constructor arguments are forwarded unchanged to ``nn.LayerNorm``.

    Raises:
        RuntimeError: if the input to ``forward`` is not a 3D tensor.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        if x.dim() != 3:
            # Module instances have no ``__name__`` attribute, so the name
            # must be read from the class; otherwise this error path raises
            # AttributeError instead of the intended RuntimeError.
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        # nn.LayerNorm normalizes trailing dims: move channels last, then back.
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        x = th.transpose(x, 1, 2)
        return x


class tfgridnet_v2(nn.Module):
    """Audio-visual source separation model built from ``MultiRangeGridNetBlock`` layers.

    Pipeline, as implemented in ``forward``:
      1. The mixture is scaled by its standard deviation and passed through
         ``STFTEncoder``.
      2. The spectrogram is embedded with ``multi_OverlapPatchEmbed`` + GroupNorm + PReLU.
      3. For each of the two speakers, a visual feature stream (``v`` through
         ``pre_v``) and a lip embedding (``face`` through ``LipEncoderClassifier``)
         are interpolated to the audio frame count, repeated along frequency and
         fused with the audio embedding by a 1x1 conv.
      4. Audio embedding and both fused maps are concatenated, reduced by
         ``conv1`` and refined by ``num_layers`` ``MultiRangeGridNetBlock`` layers.
      5. ``deconv`` + ``STFTDecoder_1`` produce the separated sources, which are
         re-encoded to complex spectrograms for the return value.
    """
    def __init__(self, 
        num_layers = 6,  # original note: 12 (presumably an alternative depth)
        win = 256,
        hop_length = 128,
        n_fft = 256,
        inp_channels=2, 
        out_channels=2, 
        dim = 48,
        bias = False,
        vpre_channels=768,
        num_source = 2,
        lstm_hidden_units=128,
        attn_n_head=4,
        attn_approx_qk_dim=512,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="prelu",
        eps=1.0e-5,
        
    ):
        """Build all sub-modules.

        Args:
            num_layers: number of ``MultiRangeGridNetBlock`` layers.
            win, hop_length, n_fft: passed to ``STFTEncoder`` and ``STFTDecoder_1``;
                ``n_fft`` must be even.
            inp_channels: input channels of the audio patch embedding.
            out_channels: per-source output channels of ``deconv``.
                NOTE(review): ``forward`` reshapes with a hard-coded 2, so values
                other than 2 look unsupported -- confirm.
            dim: base width used for ``pre_v``, the fusion convs and the STFT decoder.
            bias: passed to ``STFTEncoder`` and ``STFTDecoder_1``.
            vpre_channels: input channels of ``pre_v``.
            num_source: number of separated sources.
                NOTE(review): ``forward`` handles exactly two speaker streams.
            lstm_hidden_units, attn_n_head, attn_approx_qk_dim, emb_ks, emb_hs,
            activation: forwarded to every ``MultiRangeGridNetBlock``.
            emb_dim: embedding width of the audio branch and of the blocks.
            eps: epsilon for the GroupNorm layers and the blocks.
        """

        super(tfgridnet_v2, self).__init__()
        self.num_source = num_source
        self.win = win
        self.hop_length = hop_length
        self.num_layers = num_layers
        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1  # number of one-sided frequency bins

        # Visual-feature front end, shared by both speakers.
        self.pre_v = multi_OverlapPatchEmbed_v(vpre_channels, dim)
        self.stft_encoder = STFTEncoder(win, hop_length, n_fft, bias)
        # self.patch_embed = multi_OverlapPatchEmbed(inp_channels, dim)
        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        # Audio embedding; the GroupNorm fixes its output width at emb_dim*4 channels.
        self.conv = nn.Sequential(
            # nn.Conv2d(inp_channels, emb_dim, ks, padding=padding),
            multi_OverlapPatchEmbed(inp_channels, emb_dim),
            nn.GroupNorm(1, emb_dim*4, eps=eps),
            nn.PReLU()
        ) 
        # Reduces [audio embedding (emb_dim*4) + two fused maps (dim*4 each)] to emb_dim.
        self.conv1 = nn.Sequential(
            nn.Conv2d(emb_dim*4+dim*8, emb_dim, ks, padding=padding),
            nn.GroupNorm(1, emb_dim, eps=eps),
            nn.PReLU()
        )
        # self.conv_v0 = nn.Conv2d(dim*8, dim*4, kernel_size=1)

        # Lip encoder with hard-coded settings, shared by both speakers.
        # NOTE(review): lip_emb_dim=192 only matches the dim*12 fusion input if
        # it equals dim*4 (true for the default dim=48) -- confirm for other dims.
        self.lipencoder = LipEncoderClassifier(num_speakers=2936, tcn_channels=256, lip_emb_dim=192)

        # Per-speaker 1x1 fusion convs: dim*12 input channels -> dim*4.
        self.fusion_conv0 = nn.Sequential(
            nn.Conv2d(dim*12, dim*4, kernel_size=1),
            nn.GroupNorm(1, dim*4, eps=eps),
            nn.PReLU()
        )

        self.fusion_conv1 = nn.Sequential(
            nn.Conv2d(dim*12, dim*4, kernel_size=1),
            nn.GroupNorm(1, dim*4, eps=eps),
            nn.PReLU()
        )
        self.blocks = nn.ModuleList([])
        for _ in range(num_layers):
            self.blocks.append(
                MultiRangeGridNetBlock(
                    emb_dim,
                    emb_ks,
                    emb_hs,
                    n_freqs,
                    lstm_hidden_units,
                    n_head=attn_n_head,
                    approx_qk_dim=attn_approx_qk_dim,
                    activation=activation,
                    eps=eps,
                )
            )
        self.deconv = nn.ConvTranspose2d(emb_dim, num_source * out_channels, ks, padding=padding)
        self.stft_decoder = STFTDecoder_1(win, hop_length, n_fft, in_chan=dim, n_src=num_source, kernel_size=3, stride=1, bias=bias)

    def forward(self, x, v, face=None):
        """Separate two sources from a mixture using visual and lip cues.

        Args:
            x: mixture; a channel dim is inserted at position 1. The ``__main__``
                demo in this file passes shape (1, 32000).
            v: visual features, indexed as ``v[:, 0]`` / ``v[:, 1]`` per speaker.
                The demo passes shape (1, 2, 768, 50).
            face: lip-region frames, indexed as ``face[:, 0]`` / ``face[:, 1]``.
                The demo passes shape (1, 2, 50, 1, 112, 112).
                NOTE(review): defaults to None but is indexed unconditionally,
                so omitting it fails -- confirm whether it should be required.

        Returns:
            Tuple ``(stft_out_spec, source, logits)``: the complex STFTs of both
            estimated sources stacked on dim 1, the rescaled decoder output, and
            the list ``[logits0, logits1]`` from the lip encoder.
        """
        x= x.unsqueeze(1)
        # Scale by the mixture's standard deviation; undone on `source` below.
        # NOTE(review): no epsilon here -- a zero-variance input would divide by zero.
        mix_std_ = torch.std(x, dim=(1, 2), keepdim=True)  # [B, 1, 1]
        x = x / mix_std_

        # The same lip encoder processes each speaker's frames.
        # Example shape noted by the author for the embeddings: [4, 50, 192].
        logits0, lip_emb0 = self.lipencoder(face[:, 0].float())
        logits1, lip_emb1 = self.lipencoder(face[:, 1].float())
        # utt_emb0 = lip_emb0.mean(dim=1)  # [B, D]
        # utt_emb1 = lip_emb1.mean(dim=1)  # [B, D]
        # logits0 = self.classifier(utt_emb0)
        # logits1 = self.classifier(utt_emb1)
        logits = [logits0, logits1]

        feature_map = self.stft_encoder(x)  # example noted by the author: [4, 2, 251, 129]
        assert not torch.isnan(feature_map).any(), "NaN in stft_encoder output"
        # Local F is the frequency-bin count; torch.nn.functional is imported as FF.
        B, C, T, F = feature_map.size()
        inp_enc_level1 = self.conv(feature_map)  # emb_dim*4 channels (see GroupNorm in self.conv)
        assert not torch.isnan(inp_enc_level1).any(), "NaN in conv"


        # Shared visual front end applied per speaker; presumably yields [B, C, T_video].
        v00 = self.pre_v(v[:, 0])
        v01 = self.pre_v(v[:, 1])

        # Lip embeddings: move the feature dim to position 1, resample to T audio
        # frames, then repeat along the frequency axis.
        lip_emb0 = lip_emb0.transpose(1, 2)
        lip_emb0 = FF.interpolate(lip_emb0, size=T, mode='linear', align_corners=True)
        lip_emb0 = lip_emb0.unsqueeze(-1)
        lip_emb0 = lip_emb0.repeat(1, 1,1, F)

        lip_emb1 = lip_emb1.transpose(1, 2)
        lip_emb1 = FF.interpolate(lip_emb1, size=T, mode='linear', align_corners=True) 
        lip_emb1 = lip_emb1.unsqueeze(-1)
        lip_emb1 = lip_emb1.repeat(1, 1, 1, F)  # [B, C, T, F]

        # Visual features: same resample-and-repeat treatment.
        v00 = FF.interpolate(v00, size=T, mode='linear', align_corners=True)  # [B, C, T]
        v00 = v00.unsqueeze(-1)  # [B, C, T, 1]
        v00 = v00.repeat(1, 1, 1, F)  # [B, C, T, F]

        v01 = FF.interpolate(v01, size=T, mode='linear', align_corners=True)  # [B, C, T]
        v01 = v01.unsqueeze(-1)  # [B, C, T, 1]
        v01 = v01.repeat(1, 1, 1, F)  # [B, C, T, F]
        # Channel-wise concat of audio, visual and lip streams; the fusion convs
        # expect dim*12 input channels in total.
        fusion0 = torch.cat([inp_enc_level1, v00, lip_emb0], dim=1)
        fusion1 = torch.cat([inp_enc_level1, v01, lip_emb1], dim=1)

        fusion0 = self.fusion_conv0(fusion0)  # [B, dim*4, T, F]
        fusion1 = self.fusion_conv1(fusion1)  # [B, dim*4, T, F]
        # v0_all = torch.cat([v00, v01], dim=1)
        # for i in self.inter:
        #     inp_enc_level1, v0_all = i(inp_enc_level1, v0_all)

        # conv1 takes emb_dim*4 + dim*8 channels and outputs emb_dim.
        batch = self.conv1(torch.cat([inp_enc_level1, fusion0, fusion1], dim=1))
        for ii in range(self.num_layers):
            batch = self.blocks[ii](batch)  # [B, -1, T, F]
        y = self.deconv(batch)  # [B, num_source*out_channels, T, F]
        assert not torch.isnan(y).any(), "NaN in est_sources"
        # NOTE(review): the literal 2 assumes out_channels == 2.
        y = y.view([B, self.num_source, 2, T, F])
        
        source = self.stft_decoder(y, x)
        # Undo the input scaling applied at the top of forward.
        source = mix_std_ * source

        # Re-encode each estimated source and combine channel 0 / channel 1 of
        # the encoder output as real / imaginary parts.
        stft_out0 = self.stft_encoder(source[:,0])
        stft_out0 = torch.complex(stft_out0[:, 0], stft_out0[:, 1])
        stft_out1 = self.stft_encoder(source[:,1])
        stft_out1 = torch.complex(stft_out1[:, 0], stft_out1[:, 1])
        stft_out_spec = torch.stack([stft_out0, stft_out1], dim=1)

        return stft_out_spec, source, logits


if __name__ == '__main__':
    # Smoke test: build the default model, run one random example through it,
    # and print output shapes, the trainable-parameter count and the elapsed
    # wall-clock time of the forward pass.
    import os
    import time

    # Expose only GPU 0; set before the first CUDA call below.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = tfgridnet_v2().to(device)

    # Random stand-ins for the mixture, the visual features and the lip frames.
    input_signal = torch.randn(1,32000).to(device)
    v = torch.randn(1,2,768,50).to(device)
    mouth = torch.randn(1,2,50,1,112,112).to(device)

    start_time = time.time()
    out1, out2, logits = model(input_signal, v, mouth)
    if device.type == "cuda":
        # CUDA kernels run asynchronously; wait for them so that the measured
        # interval covers the whole forward pass.
        torch.cuda.synchronize()
    end_time = time.time()

    print(logits[0].shape)
    print(logits[1].shape)
    print(out1.shape)
    print(out2.shape)

    # Summary output.
    print(":", sum(p.numel() for p in model.parameters() if p.requires_grad))
    print(":", out1.shape)
    print(f": {end_time - start_time:.4f} 秒")

