import torch
import torch.nn as nn
import torch.nn.functional as FF
from pdb import set_trace as stx
import inspect
from einops import rearrange
from torchsummary import summary
from typing import Optional
from collections import OrderedDict
from typing import Sequence, Tuple, Union, Dict, Optional, Tuple
from .lip_encoder import LipEncoderClassifier
import math
from typing import Dict, List, Optional, Tuple
from torch.nn import init
from torch.nn.parameter import Parameter
from abc import ABC, abstractmethod
import difflib
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor


def is_torch_complex_tensor(c):
    """Return True when ``c`` is a native complex ``torch.Tensor``.

    ``ComplexTensor`` instances (from ``torch_complex``) are rejected first;
    everything else is delegated to ``torch.is_complex``.
    """
    if isinstance(c, ComplexTensor):
        return False
    return torch.is_complex(c)
def is_complex(c):
    """Return True for a ``ComplexTensor`` or a native complex ``torch.Tensor``."""
    if isinstance(c, ComplexTensor):
        return True
    return is_torch_complex_tensor(c)
# Machine epsilon of double precision (float64), as reported by torch.finfo.
EPS = torch.finfo(torch.double).eps

def new_complex_like(
    ref: Union[torch.Tensor, ComplexTensor],
    real_imag: Tuple[torch.Tensor, torch.Tensor],
):
    """Build a complex tensor of the same flavour as ``ref``.

    Args:
        ref: Reference object; only its type is inspected.
        real_imag: The parts forwarded (unpacked) to the complex constructor.

    Returns:
        A ``ComplexTensor`` when ``ref`` is one, otherwise the result of
        ``torch.complex`` when ``ref`` is a native complex tensor.

    Raises:
        ValueError: if ``ref`` is neither kind of complex tensor.
    """
    if isinstance(ref, ComplexTensor):
        return ComplexTensor(*real_imag)
    if is_torch_complex_tensor(ref):
        return torch.complex(*real_imag)
    raise ValueError(
        "Please update your PyTorch version to 1.9+ for complex support."
    )

class AbsSeparator(nn.Module, ABC):
    """Abstract base class for separator modules.

    Concrete subclasses must implement ``forward`` and the ``num_spk``
    property; ``forward_streaming`` is optional and raises
    ``NotImplementedError`` unless overridden.
    """

    @abstractmethod
    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[Tuple[torch.Tensor], torch.Tensor, OrderedDict]:
        """Run the separator; must be provided by subclasses."""
        raise NotImplementedError

    def forward_streaming(
        self,
        input_frame: torch.Tensor,
        buffer=None,
    ):
        """Optional streaming entry point; not implemented in the base class."""
        raise NotImplementedError

    @property
    @abstractmethod
    def num_spk(self):
        """Abstract property; must be provided by subclasses."""
        raise NotImplementedError

from packaging.version import parse as V

def get_layer(l_name, library=torch.nn):
    """Return layer object handler from library e.g. from torch.nn

    E.g. if l_name=="elu", returns torch.nn.ELU.

    Args:
        l_name (string): Case insensitive name for layer in library (e.g. 'elu').
        library (module): Module/namespace in which to search for an attribute
            named l_name, e.g. torch.nn.

    Returns:
        layer_handler (object): handler for the requested layer e.g. (torch.nn.ELU)

    Raises:
        NotImplementedError: if no attribute of ``library`` matches ``l_name``
            case-insensitively, or if more than one attribute does.
    """
    # Search the library that was actually requested, not always torch.nn.
    candidates = dir(library)
    wanted = l_name.lower()
    match = [name for name in candidates if name.lower() == wanted]

    if len(match) == 1:
        return getattr(library, match[0])

    if not match:
        close_matches = difflib.get_close_matches(
            wanted, [name.lower() for name in candidates]
        )
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )

    # Ambiguous: several attributes differ only by case; report the real matches.
    raise NotImplementedError(
        "Multiple matches for layer with name {} found in {}.\n "
        "All matches: {}".format(l_name, str(library), match)
    )

    

class GridNetBlock(nn.Module):
    """Grid block: intra BLSTM, inter BLSTM, then multi-head self-attention.

    For an input of shape [B, C, T, Q] the block applies, each with a residual
    connection:

    1. an "intra" path that runs a bidirectional LSTM along the last axis (Q),
    2. an "inter" path that runs a bidirectional LSTM along the T axis,
    3. self-attention over the T axis, with per-head 1x1-conv projections.

    Both LSTM paths unfold their sequence into windows of ``emb_ks`` steps with
    hop ``emb_hs`` and map the LSTM output back with a transposed convolution.

    Args:
        emb_dim: Number of channels C of the input/output.
        emb_ks: Window (kernel) size used when unfolding a sequence.
        emb_hs: Hop (stride) used when unfolding a sequence.
        n_freqs: Size of the last input axis; used to size the parameters of
            the attention-path normalization layers.
        hidden_channels: Hidden size of each LSTM direction.
        n_head: Number of attention heads; must divide ``emb_dim``.
        approx_qk_dim: Approximate size of the flattened query/key features;
            each head uses ``ceil(approx_qk_dim / n_freqs)`` channels.
        activation: Case-insensitive name of a ``torch.nn`` activation class.
        eps: Epsilon passed to the normalization layers.
    """

    def __getitem__(self, key):
        """Return the attribute (e.g. a registered sub-module) named ``key``."""
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        super().__init__()

        # self.att_freq = nn.MultiheadAttention(emb_dim, n_head, batch_first=True)
        # self.norm_att_freq = nn.LayerNorm(emb_dim, eps=eps)
        # self.att_time = nn.MultiheadAttention(emb_dim, n_head, batch_first=True)
        # self.norm_att_time = nn.LayerNorm(emb_dim, eps=eps)

        # Each unfolded window concatenates emb_ks steps of emb_dim channels.
        in_channels = emb_dim * emb_ks

        self.intra_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
        )

        self.inter_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
        )

        E = math.ceil(
            approx_qk_dim * 1.0 / n_freqs
        )  # approx_qk_dim is only approximate
        assert emb_dim % n_head == 0
        for ii in range(n_head):
            self.add_module(
                "attn_conv_Q_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_K_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_V_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, emb_dim // n_head, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps),
                ),
            )
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )
        # self.FFN = FeedForward(emb_dim, 4, bias=True)  # 64*4  256

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def forward(self, x):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]

        Returns:
            out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape
        # Pad T and Q so that windows of emb_ks with hop emb_hs tile them exactly.
        # The frame count is clamped at 0 so that inputs shorter than emb_ks are
        # padded up to one full window instead of failing inside unfold.
        T = (
            max(math.ceil((old_T - self.emb_ks) / self.emb_hs), 0) * self.emb_hs
            + self.emb_ks
        )
        Q = (
            max(math.ceil((old_Q - self.emb_ks) / self.emb_hs), 0) * self.emb_hs
            + self.emb_ks
        )
        x = FF.pad(x, (0, Q - old_Q, 0, T - old_T))

        # intra RNN
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = (
            intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)
        )  # [BT, C, Q]
        intra_rnn = FF.unfold(
            intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
        )  # [BT, C*emb_ks, -1]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, C*emb_ks]
        intra_rnn, _ = self.intra_rnn(intra_rnn)  # [BT, -1, 2*H]
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, 2*H, -1]
        intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]

        intra_rnn = intra_rnn + input_  # [B, C, T, Q]

        # inter RNN
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)  # [B, C, T, Q]
        inter_rnn = (
            inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        )  # [BQ, C, T]

        # inter_rnn = inter_rnn.transpose(1, 2)
        # att_out_time, _ = self.att_time(inter_rnn, inter_rnn, inter_rnn)
        # att_out_time = att_out_time + inter_rnn
        # att_out_time = self.norm_att_time(att_out_time)
        # inter_rnn = att_out_time.transpose(1, 2)

        inter_rnn = FF.unfold(
            inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
        )  # [BQ, C*emb_ks, -1]
        inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, -1, C*emb_ks]
        inter_rnn, _ = self.inter_rnn(inter_rnn)  # [BQ, -1, 2*H]
        inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, 2*H, -1]
        inter_rnn = self.inter_linear(inter_rnn)  # [BQ, C, T]
        inter_rnn = inter_rnn.view([B, Q, C, T])
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()  # [B, C, T, Q]
        inter_rnn = inter_rnn + input_  # [B, C, T, Q]

        # attention (on the un-padded region only)
        inter_rnn = inter_rnn[..., :old_T, :old_Q]
        batch = inter_rnn

        all_Q, all_K, all_V = [], [], []
        for ii in range(self.n_head):
            all_Q.append(self["attn_conv_Q_%d" % ii](batch))  # [B, C, T, Q]
            all_K.append(self["attn_conv_K_%d" % ii](batch))  # [B, C, T, Q]
            all_V.append(self["attn_conv_V_%d" % ii](batch))  # [B, C, T, Q]

        # Heads are stacked along the batch axis: B' = n_head * B.
        Q = torch.cat(all_Q, dim=0)  # [B', C, T, Q]
        K = torch.cat(all_K, dim=0)  # [B', C, T, Q]
        V = torch.cat(all_V, dim=0)  # [B', C, T, Q]

        Q = Q.transpose(1, 2)
        Q = Q.flatten(start_dim=2)  # [B', T, C*Q]
        K = K.transpose(1, 2)
        K = K.flatten(start_dim=2)  # [B', T, C*Q]
        V = V.transpose(1, 2)  # [B', T, C, Q]
        old_shape = V.shape
        V = V.flatten(start_dim=2)  # [B', T, C*Q]
        emb_dim = Q.shape[-1]

        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # [B', T, T]
        attn_mat = FF.softmax(attn_mat, dim=2)  # [B', T, T]
        V = torch.matmul(attn_mat, V)  # [B', T, C*Q]

        V = V.reshape(old_shape)  # [B', T, C, Q]
        V = V.transpose(1, 2)  # [B', C, T, Q]
        emb_dim = V.shape[1]

        batch = V.view([self.n_head, B, emb_dim, old_T, -1])  # [n_head, B, C, T, Q]
        batch = batch.transpose(0, 1)  # [B, n_head, C, T, Q]
        batch = batch.contiguous().view(
            [B, self.n_head * emb_dim, old_T, -1]
        )  # [B, C, T, Q]
        batch = self["attn_concat_proj"](batch)  # [B, C, T, Q]

        out = batch + inter_rnn

        # out1 = self.FFN(out)
        # assert not torch.isnan(out1).any(), "NaN detected in FFN output"
        # out = out1 + out
        return out


class MultiRangeGridNetBlock(nn.Module):
    """Grid block whose intra/inter paths fuse three BLSTM branches.

    For an input of shape [B, C, T, Q] the block runs:

    1. an "intra" stage along the last axis (Q) and
    2. an "inter" stage along the T axis,

    each made of three parallel bidirectional-LSTM branches that see the
    sequence step by step (branch 1), through windows of 4 steps (branch 2)
    and through windows of 8 steps (branch 3), all with hop 1. The branch
    outputs are concatenated on the channel axis, fused by a 1x1 convolution,
    normalized and added to the stage input.

    3. Multi-head self-attention over the T axis follows, using per-head
       1x1-conv projections.

    NOTE(review): the final residual adds the block input, not the output of
    the inter stage (``GridNetBlock`` adds the inter-stage output) -- confirm
    this is intended.

    Args:
        emb_dim: Number of channels C of the input/output.
        emb_ks: Accepted for signature compatibility but not used; the padding
            kernel is fixed to 8 (the largest branch window).
        emb_hs: Accepted for signature compatibility but not used; the hop is
            fixed to 1.
        n_freqs: Size of the last input axis; used to size the parameters of
            the attention-path normalization layers.
        hidden_channels: Hidden size of each LSTM direction.
        n_head: Number of attention heads; must divide ``emb_dim``.
        approx_qk_dim: Approximate size of the flattened query/key features;
            each head uses ``ceil(approx_qk_dim / n_freqs)`` channels.
        activation: Case-insensitive name of a ``torch.nn`` activation class.
        eps: Epsilon passed to the normalization layers.
    """

    def __getitem__(self, key):
        """Return the attribute (e.g. a registered sub-module) named ``key``."""
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        super().__init__()

        # ---- intra stage: three branches with window sizes 1, 4 and 8 ----
        self.intra_branch1_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_branch1_rnn = nn.LSTM(
            emb_dim, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_branch1_linear = nn.Conv1d(
            hidden_channels * 2, emb_dim, kernel_size=1
        )

        self.intra_branch2_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_branch2_rnn = nn.LSTM(
            emb_dim * 4, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_branch2_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 4, stride=1
        )

        self.intra_branch3_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.intra_branch3_rnn = nn.LSTM(
            emb_dim * 8, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_branch3_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 8, stride=1
        )

        self.intra_fusion_conv = nn.Conv2d(emb_dim * 3, emb_dim, kernel_size=1)
        self.intra_fusion_norm = LayerNormalization4D(emb_dim, eps=eps)

        # ---- inter stage: same three-branch layout ----
        self.inter_branch1_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_branch1_rnn = nn.LSTM(
            emb_dim, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_branch1_linear = nn.Conv1d(
            hidden_channels * 2, emb_dim, kernel_size=1
        )

        self.inter_branch2_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_branch2_rnn = nn.LSTM(
            emb_dim * 4, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_branch2_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 4, stride=1
        )

        self.inter_branch3_norm = LayerNormalization4D(emb_dim, eps=eps)
        self.inter_branch3_rnn = nn.LSTM(
            emb_dim * 8, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_branch3_linear = nn.ConvTranspose1d(
            hidden_channels * 2, emb_dim, 8, stride=1
        )

        self.inter_fusion_conv = nn.Conv2d(emb_dim * 3, emb_dim, kernel_size=1)
        self.inter_fusion_norm = LayerNormalization4D(emb_dim, eps=eps)

        # ---- attention projections ----
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0
        for ii in range(n_head):
            self.add_module(
                "attn_conv_Q_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_K_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                "attn_conv_V_%d" % ii,
                nn.Sequential(
                    nn.Conv2d(emb_dim, emb_dim // n_head, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps),
                ),
            )
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        self.emb_dim = emb_dim
        self.n_head = n_head
        # Padding is driven by the largest branch window (8) and hop 1; the
        # emb_ks / emb_hs constructor arguments are intentionally not stored.
        self.emb_ks = 8
        self.emb_hs = 1

    def forward(self, x):
        """MultiRangeGridNetBlock forward.

        Args:
            x: [B, C, T, Q]

        Returns:
            out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape

        # Pad so every branch window fits. The frame count is clamped at 0 so
        # that inputs shorter than the largest window are padded up to it
        # instead of failing inside unfold.
        T = (
            max(math.ceil((old_T - self.emb_ks) / self.emb_hs), 0) * self.emb_hs
            + self.emb_ks
        )
        Q = (
            max(math.ceil((old_Q - self.emb_ks) / self.emb_hs), 0) * self.emb_hs
            + self.emb_ks
        )
        x = FF.pad(x, (0, Q - old_Q, 0, T - old_T))

        residual = x

        # ---------------- intra stage (sequences along Q) ----------------
        intra_input = x

        # branch 1: one step per position
        intra_b1 = self.intra_branch1_norm(intra_input)
        intra_b1 = intra_b1.permute(0, 2, 3, 1).contiguous().view(B * T, Q, C)
        intra_b1, _ = self.intra_branch1_rnn(intra_b1)  # [BT, Q, H*2]
        intra_b1 = intra_b1.transpose(1, 2)  # [BT, H*2, Q]
        intra_b1 = self.intra_branch1_linear(intra_b1)  # [BT, C, Q]
        intra_b1 = intra_b1.view(B, T, C, Q).permute(0, 2, 1, 3).contiguous()

        # branch 2: windows of 4 steps, hop 1
        intra_b2 = self.intra_branch2_norm(intra_input)
        intra_b2 = intra_b2.transpose(1, 2).contiguous().view(B * T, C, Q)
        intra_b2 = FF.unfold(intra_b2[..., None], (4, 1), stride=(1, 1))
        intra_b2 = intra_b2.transpose(1, 2)  # [BT, -1, C*4]
        intra_b2, _ = self.intra_branch2_rnn(intra_b2)
        intra_b2 = intra_b2.transpose(1, 2)
        intra_b2 = self.intra_branch2_linear(intra_b2)  # [BT, C, Q]
        intra_b2 = intra_b2.view(B, T, C, Q).transpose(1, 2).contiguous()

        # branch 3: windows of 8 steps, hop 1
        intra_b3 = self.intra_branch3_norm(intra_input)
        intra_b3 = intra_b3.transpose(1, 2).contiguous().view(B * T, C, Q)
        intra_b3 = FF.unfold(intra_b3[..., None], (8, 1), stride=(1, 1))
        intra_b3 = intra_b3.transpose(1, 2)  # [BT, -1, C*8]
        intra_b3, _ = self.intra_branch3_rnn(intra_b3)
        intra_b3 = intra_b3.transpose(1, 2)
        intra_b3 = self.intra_branch3_linear(intra_b3)  # [BT, C, Q]
        intra_b3 = intra_b3.view(B, T, C, Q).transpose(1, 2).contiguous()

        intra_concat = torch.cat([intra_b1, intra_b2, intra_b3], dim=1)
        intra_fused = self.intra_fusion_conv(intra_concat)
        intra_fused = self.intra_fusion_norm(intra_fused)

        intra_output = intra_fused + intra_input

        # ---------------- inter stage (sequences along T) ----------------
        inter_input = intra_output

        # branch 1: one step per position
        inter_b1 = self.inter_branch1_norm(inter_input)
        inter_b1 = inter_b1.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        inter_b1 = inter_b1.transpose(1, 2)
        inter_b1, _ = self.inter_branch1_rnn(inter_b1)  # [BQ, T, H*2]
        inter_b1 = inter_b1.transpose(1, 2)  # [BQ, H*2, T]
        inter_b1 = self.inter_branch1_linear(inter_b1)  # [BQ, C, T]
        inter_b1 = inter_b1.view(B, Q, C, T).permute(0, 2, 3, 1).contiguous()

        # branch 2: windows of 4 steps, hop 1
        inter_b2 = self.inter_branch2_norm(inter_input)
        inter_b2 = inter_b2.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        inter_b2 = FF.unfold(inter_b2[..., None], (4, 1), stride=(1, 1))
        inter_b2 = inter_b2.transpose(1, 2)  # [BQ, -1, C*4]
        inter_b2, _ = self.inter_branch2_rnn(inter_b2)
        inter_b2 = inter_b2.transpose(1, 2)
        inter_b2 = self.inter_branch2_linear(inter_b2)  # [BQ, C, T]
        inter_b2 = inter_b2.view(B, Q, C, T).permute(0, 2, 3, 1).contiguous()

        # branch 3: windows of 8 steps, hop 1
        inter_b3 = self.inter_branch3_norm(inter_input)
        inter_b3 = inter_b3.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        inter_b3 = FF.unfold(inter_b3[..., None], (8, 1), stride=(1, 1))
        inter_b3 = inter_b3.transpose(1, 2)  # [BQ, -1, C*8]
        inter_b3, _ = self.inter_branch3_rnn(inter_b3)
        inter_b3 = inter_b3.transpose(1, 2)
        inter_b3 = self.inter_branch3_linear(inter_b3)  # [BQ, C, T]
        inter_b3 = inter_b3.view(B, Q, C, T).permute(0, 2, 3, 1).contiguous()

        inter_concat = torch.cat([inter_b1, inter_b2, inter_b3], dim=1)
        inter_fused = self.inter_fusion_conv(inter_concat)
        inter_fused = self.inter_fusion_norm(inter_fused)

        inter_output = inter_fused + inter_input

        # Drop the padding before attention.
        inter_output = inter_output[..., :old_T, :old_Q]

        # ---------------- attention over T ----------------
        batch = inter_output

        all_Q, all_K, all_V = [], [], []
        for ii in range(self.n_head):
            all_Q.append(self["attn_conv_Q_%d" % ii](batch))
            all_K.append(self["attn_conv_K_%d" % ii](batch))
            all_V.append(self["attn_conv_V_%d" % ii](batch))

        # Heads are stacked along the batch axis: B' = n_head * B.
        Q = torch.cat(all_Q, dim=0)
        K = torch.cat(all_K, dim=0)
        V = torch.cat(all_V, dim=0)

        Q = Q.transpose(1, 2)
        Q = Q.flatten(start_dim=2)  # [B', T, C*Q]
        K = K.transpose(1, 2)
        K = K.flatten(start_dim=2)  # [B', T, C*Q]
        V = V.transpose(1, 2)  # [B', T, C, Q]
        old_shape = V.shape
        V = V.flatten(start_dim=2)  # [B', T, C*Q]
        emb_dim = Q.shape[-1]

        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # [B', T, T]
        attn_mat = FF.softmax(attn_mat, dim=2)
        V = torch.matmul(attn_mat, V)

        V = V.reshape(old_shape)
        V = V.transpose(1, 2)  # [B', C, T, Q]
        emb_dim = V.shape[1]

        batch = V.view([self.n_head, B, emb_dim, old_T, -1])
        batch = batch.transpose(0, 1)
        batch = batch.contiguous().view([B, self.n_head * emb_dim, old_T, -1])
        batch = self["attn_concat_proj"](batch)

        # Residual uses the (cropped) block input, see the class docstring.
        out = batch + residual[..., :old_T, :old_Q]
        return out

class LayerNormalization4D(nn.Module):
    """Normalize a 4-D tensor over its channel axis (dim 1).

    Mean and (biased) variance are computed over dim 1 independently for every
    other position; a learnable per-channel scale ``gamma`` (initialized to 1)
    and shift ``beta`` (initialized to 0) are then applied.

    NOTE: a second class with this same name is defined further down in this
    file; at import time that later definition replaces this one.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        shape = (1, input_dimension, 1, 1)
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Return the normalized tensor; raises ValueError unless ``x`` is 4-D."""
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        channel_axis = (1,)
        mean = x.mean(dim=channel_axis, keepdim=True)  # [B,1,T,F]
        variance = x.var(dim=channel_axis, unbiased=False, keepdim=True)
        std = torch.sqrt(variance + self.eps)  # [B,1,T,F]
        return ((x - mean) / std) * self.gamma + self.beta


class LayerNormalization4DCF(nn.Module):
    """Normalize a 4-D tensor jointly over dims 1 and 3.

    Mean and (biased) variance are computed over dims (1, 3) for every
    (dim 0, dim 2) position. The learnable scale ``gamma`` and shift ``beta``
    have shape ``[1, input_dimension[0], 1, input_dimension[1]]``.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        shape = (1, input_dimension[0], 1, input_dimension[1])
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Return the normalized tensor; raises ValueError unless ``x`` is 4-D."""
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        reduce_axes = (1, 3)
        mean = x.mean(dim=reduce_axes, keepdim=True)  # [B,1,T,1]
        variance = x.var(dim=reduce_axes, unbiased=False, keepdim=True)
        std = torch.sqrt(variance + self.eps)  # [B,1,T,1]
        return ((x - mean) / std) * self.gamma + self.beta


def to_3d(x):
    """Flatten the two trailing axes: ``[b, c, h, w] -> [b, h*w, c]``."""
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)

def to_4d(x, h, w):
    """Inverse of ``to_3d``: ``[b, h*w, c] -> [b, c, h, w]`` for the given h, w."""
    pattern = 'b (h w) c -> b c h w'
    return rearrange(x, pattern, h=h, w=w)

class BaseEncoder(nn.Module):
    """Base class for encoders: shape helpers, padding and config export.

    ``forward`` must be provided by subclasses. ``get_out_chan`` returns
    ``self.out_chan``, which this base class does not set itself.
    """

    def unsqueeze_to_3D(self, x: torch.Tensor):
        """Bring a 1-D or 2-D tensor to 3-D; other ranks are returned as is."""
        if x.ndim == 1:
            return x.reshape(1, 1, -1)
        if x.ndim == 2:
            return x.unsqueeze(1)
        return x

    def unsqueeze_to_2D(self, x: torch.Tensor):
        """Bring a 1-D or ``[B, 1, L]`` tensor to 2-D; other ranks pass through.

        A 3-D input must have size 1 on its middle axis (checked by assert).
        """
        if x.ndim == 1:
            return x.reshape(1, -1)
        shape = x.shape
        if len(shape) == 3:
            assert shape[1] == 1
            return x.reshape(shape[0], -1)
        return x

    def pad(self, x: torch.Tensor, lcm: int):
        """Zero-pad the last axis of ``x`` up to the next multiple of ``lcm``."""
        remainder = int(x.shape[-1]) % lcm
        if not remainder:
            return x
        tail_shape = [*x.shape[:-1], lcm - remainder]
        tail = torch.zeros(tail_shape, dtype=x.dtype, device=x.device)
        return torch.cat([x, tail], dim=-1)

    def get_out_chan(self):
        """Return ``self.out_chan``."""
        return self.out_chan

    def forward(self, *args, **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError

    def get_config(self):
        """Collect public, non-method instance attributes (except ``training``)."""
        return {
            key: value
            for key, value in self.__dict__.items()
            if not key.startswith("_")
            and key != "training"
            and not inspect.ismethod(value)
        }

class STFTEncoder(BaseEncoder):
    """Encoder that turns waveforms into stacked real/imaginary STFT features.

    A square-root Hann window of length ``win_length`` is registered as a
    non-persistent buffer. ``bias`` is stored but not used by this class, and
    extra positional/keyword arguments are accepted and ignored.

    NOTE(review): ``torch.stft`` documents ``win_length <= n_fft``; the
    defaults here (512 and 256) do not satisfy that -- confirm the intended
    default values.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        n_fft: int = 256,
        bias: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.bias = bias

        analysis_window = torch.hann_window(self.win_length).sqrt()
        self.register_buffer("window", analysis_window, persistent=False)

    def forward(self, x: torch.Tensor):
        """Compute the STFT of ``x`` and return it as a real tensor.

        The input is first brought to 2-D with ``unsqueeze_to_2D``. Real and
        imaginary parts are stacked on a new axis 1 and the two trailing axes
        of the stacked tensor are swapped.
        """
        waveform = self.unsqueeze_to_2D(x)

        spectrum = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(waveform.device),
            return_complex=True,
        )

        real_imag = torch.stack((spectrum.real, spectrum.imag), dim=1)
        return real_imag.transpose(2, 3).contiguous()


class BaseDecoder(nn.Module):
    """Base class for decoders: length padding helper and config export.

    ``forward`` must be provided by subclasses.
    """

    def pad_to_input_length(self, separated_audio, input_frames):
        """Pad the last axis of ``separated_audio`` on the right to ``input_frames``."""
        missing = input_frames - separated_audio.shape[-1]
        return nn.functional.pad(separated_audio, [0, missing])

    def forward(self, *args, **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError

    def get_config(self):
        """Collect public, non-method instance attributes (except ``training``)."""
        return {
            key: value
            for key, value in self.__dict__.items()
            if not key.startswith("_")
            and key != "training"
            and not inspect.ismethod(value)
        }

class STFTDecoder_1(BaseDecoder):
    """Decoder that turns stacked real/imaginary spectra back into waveforms.

    A square-root Hann window of length ``win_length`` is registered as a
    non-persistent buffer. ``in_chan``, ``kernel_size``, ``padding``,
    ``stride`` and ``bias`` are stored but not used by ``forward``; extra
    positional/keyword arguments are accepted and ignored.

    NOTE(review): ``win_length`` defaults to 512 while ``n_fft`` defaults to
    256 -- confirm these defaults against the ``torch.istft`` requirements.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        n_fft: int = 256,
        in_chan: int = 2,
        n_src: int = 2,
        kernel_size: int = -1,
        stride: int = 1,
        bias: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.in_chan = in_chan
        self.n_src = n_src
        self.kernel_size = kernel_size
        # "Same"-style padding for a positive kernel size, otherwise 0.
        if self.kernel_size > 0:
            self.padding = (self.kernel_size - 1) // 2
        else:
            self.padding = 0
        self.stride = stride
        self.bias = bias

        synthesis_window = torch.hann_window(self.win_length).sqrt()
        self.register_buffer("window", synthesis_window, persistent=False)

    def forward(self, x, input_shape):
        """Invert the spectra in ``x`` to waveforms.

        Args:
            x: 5-D tensor ``(B, n_src, N, T, F)``; index 0 / 1 of axis ``N``
                are used as real / imaginary parts.
            input_shape: Object whose ``.shape`` provides the batch size
                (first entry) and the output length (last entry).

        Returns:
            Tensor viewed as ``(batch_size, self.n_src, length)``.
        """
        B, n_src, N, T, F = x.shape
        batch_size = input_shape.shape[0]
        length = input_shape.shape[-1]

        stacked = x.view(B * n_src, N, T, F)  # (B*n_src, 2, T, F)
        spec = torch.complex(stacked[:, 0], stacked[:, 1])
        spec = spec.transpose(1, 2).contiguous()  # (B*n_src, F, T)

        waveform = torch.istft(
            spec,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window.to(stacked.device),
            length=length,
        )  # (B*n_src, L)

        return waveform.view(batch_size, self.n_src, length)  # (B, n_src, L)


def _get_activation_fn(activation):
    if activation == "relu":
        return FF.relu
    elif activation == "gelu":
        return FF.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class multi_OverlapPatchEmbed(nn.Module):
    """Embed a 2-D feature map with four parallel convolutions.

    A 1x1 convolution and three 3x3 convolutions with dilation 1, 2 and 3
    (padding equal to the dilation, so spatial size is kept) each produce
    ``embed_dim`` channels; the outputs are concatenated on the channel axis,
    giving ``4 * embed_dim`` channels. ``bias`` is accepted but not used.
    """

    def __init__(self, in_c=2, embed_dim=64, bias=False):
        super().__init__()

        self.conv_1x1 = nn.Conv2d(in_c, embed_dim, kernel_size=1, dilation=1, padding=0)
        for rate in (1, 2, 3):
            setattr(
                self,
                "conv_3x3_d%d" % rate,
                nn.Conv2d(in_c, embed_dim, kernel_size=3, dilation=rate, padding=rate),
            )

    def forward(self, x):
        """Return the channel-wise concatenation of the four branch outputs."""
        branches = (
            self.conv_1x1,
            self.conv_3x3_d1,
            self.conv_3x3_d2,
            self.conv_3x3_d3,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)



class multi_OverlapPatchEmbed_v(nn.Module):
    """Embed a 1-D feature sequence with four parallel convolutions.

    A kernel-1 convolution and three kernel-3 convolutions with dilation 1, 2
    and 3 (padding equal to the dilation, so length is kept) each produce
    ``embed_dim`` channels; the outputs are concatenated on the channel axis,
    giving ``4 * embed_dim`` channels. ``bias`` is accepted but not used.
    """

    def __init__(self, in_c=1024, embed_dim=64, bias=False):
        super().__init__()

        self.conv_1x1 = nn.Conv1d(in_c, embed_dim, kernel_size=1, dilation=1, padding=0)
        for rate in (1, 2, 3):
            setattr(
                self,
                "conv_3x3_d%d" % rate,
                nn.Conv1d(in_c, embed_dim, kernel_size=3, dilation=rate, padding=rate),
            )

    def forward(self, x):
        """Return the channel-wise concatenation of the four branch outputs."""
        branches = (
            self.conv_1x1,
            self.conv_3x3_d1,
            self.conv_3x3_d2,
            self.conv_3x3_d3,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)

class MultiRangeConv2d(nn.Module):
    """Four parallel 2-D convolutions with different receptive fields.

    A 1x1 convolution and three 3x3 convolutions with dilation 1, 2 and 3
    (padding equal to the dilation, so spatial size is kept) each produce
    ``out_channels_each`` channels; ``forward`` concatenates them on the
    channel axis, giving ``4 * out_channels_each`` channels.
    """

    def __init__(self, in_channels, out_channels_each):
        super().__init__()

        self.conv_1x1 = nn.Conv2d(
            in_channels, out_channels_each, kernel_size=1, dilation=1, padding=0
        )
        for rate in (1, 2, 3):
            setattr(
                self,
                "conv_3x3_d%d" % rate,
                nn.Conv2d(
                    in_channels,
                    out_channels_each,
                    kernel_size=3,
                    dilation=rate,
                    padding=rate,
                ),
            )

    def forward(self, x):
        """Return the channel-wise concatenation of the four branch outputs."""
        branches = (
            self.conv_1x1,
            self.conv_3x3_d1,
            self.conv_3x3_d2,
            self.conv_3x3_d3,
        )
        # concatenate on channel dim
        return torch.cat([branch(x) for branch in branches], dim=1)

class LayerNormFunction(torch.autograd.Function):
    """Channel-wise layer normalization for 4-D tensors with a manual backward.

    The forward pass normalizes over dim 1 and applies a per-channel affine
    transform; the backward pass returns gradients for ``x``, ``weight`` and
    ``bias`` (and ``None`` for ``eps``).
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalize ``x`` over dim 1, then scale by ``weight`` and add ``bias``."""
        ctx.eps = eps
        _, channels, _, _ = x.size()
        mean = x.mean(1, keepdim=True)
        variance = (x - mean).pow(2).mean(1, keepdim=True)
        normalized = (x - mean) / (variance + eps).sqrt()
        # The normalized activations are what the backward pass needs.
        ctx.save_for_backward(normalized, variance, weight)
        scale = weight.view(1, channels, 1, 1)
        shift = bias.view(1, channels, 1, 1)
        return scale * normalized + shift

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, grad_weight, grad_bias, None)``."""
        _, channels, _, _ = grad_output.size()
        normalized, variance, weight = ctx.saved_tensors

        scaled_grad = grad_output * weight.view(1, channels, 1, 1)
        grad_mean = scaled_grad.mean(dim=1, keepdim=True)
        grad_proj = (scaled_grad * normalized).mean(dim=1, keepdim=True)

        inv_std = 1. / torch.sqrt(variance + ctx.eps)
        grad_x = inv_std * (scaled_grad - normalized * grad_proj - grad_mean)

        # Reduce over every axis except the channel axis.
        grad_weight = (grad_output * normalized).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return grad_x, grad_weight, grad_bias, None

class LayerNorm2d(nn.Module):
    """Module wrapper around ``LayerNormFunction``.

    Holds a per-channel ``weight`` (initialized to ones) and ``bias``
    (initialized to zeros) and passes them, with ``eps``, to
    ``LayerNormFunction.apply``.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Apply ``LayerNormFunction`` to ``x`` with this module's parameters."""
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
    """Build a bias-free ``nn.Conv2d`` (3x3, stride 1, padding 1 by default)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return layer

class SEWeightModule(nn.Module):  # SE block
    """Squeeze-and-excitation weight generator.

    Global-average-pools the input to 1x1, passes it through a two-layer 1x1
    convolution bottleneck (``channels -> channels // reduction -> channels``)
    with ReLU in between, and returns sigmoid gates of shape
    ``[B, channels, 1, 1]``. The gates are returned, not applied to the input.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the per-channel gate weights for ``x``."""
        pooled = self.avg_pool(x)
        hidden = self.relu(self.fc1(pooled))
        return self.sigmoid(self.fc2(hidden))




from typing import *
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter, init
from einops import rearrange
from einops.layers.torch import Rearrange
import math 
import difflib
from torch.nn import *
import math
import numpy as np

class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that can normalize a tensor whose feature axis is dim 1.

    Args:
        seq_last (bool): whether the sequence dim is the last dim. When True,
            dims 1 and -1 are swapped before and after the normalization.
        **kwargs: forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, seq_last: bool, **kwargs) -> None:
        super().__init__(**kwargs)
        self.seq_last = seq_last

    def forward(self, input: Tensor) -> Tensor:
        """Apply layer normalization, transposing around it when ``seq_last``."""
        if not self.seq_last:
            return super().forward(input)
        # [B, H, Seq] -> [B, Seq, H], or [B,H,w,h] -> [B,h,w,H]
        swapped = input.transpose(-1, 1)
        return super().forward(swapped).transpose(-1, 1)


class GlobalLayerNorm(nn.Module):
    """Layer norm whose statistics are taken jointly over dims 1 and 2.

    Args:
        dim_hidden: Size of the hidden dimension the affine parameters cover.
        seq_last: If true the affine parameters have shape ``[dim_hidden, 1]``
            (for ``[B, H, Seq]`` inputs); otherwise ``[dim_hidden]``
            (for ``[B, Seq, H]`` inputs).
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, dim_hidden: int, seq_last: bool, eps: float = 1e-5) -> None:
        super().__init__()
        self.dim_hidden = dim_hidden
        self.seq_last = seq_last
        self.eps = eps

        affine_shape = [dim_hidden, 1] if seq_last else [dim_hidden]
        self.weight = Parameter(torch.empty(affine_shape))
        self.bias = Parameter(torch.empty(affine_shape))
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input`` of shape [B, Seq, H] or [B, H, Seq]."""
        variance, centre = torch.var_mean(input, dim=(1, 2), unbiased=False, keepdim=True)
        normalized = (input - centre) / torch.sqrt(variance + self.eps)
        return normalized * self.weight + self.bias

    def extra_repr(self) -> str:
        return '{dim_hidden}, seq_last={seq_last}, eps={eps}'.format(
            dim_hidden=self.dim_hidden,
            seq_last=self.seq_last,
            eps=self.eps,
        )


class LayerNormalization4D(nn.Module):
    """Layer norm over dim 1 of a 4-D tensor with per-channel affine terms.

    Args:
        input_dimension: Size of dim 1; ``gamma``/``beta`` have shape
            ``[1, input_dimension, 1, 1]``.
        eps: Lower bound applied to the variance before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        affine_shape = (1, input_dimension, 1, 1)
        self.gamma = Parameter(torch.ones(affine_shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(affine_shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over dim 1 inside a disabled CUDA autocast region.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        with torch.autocast(device_type="cuda", enabled=False):
            reduce_dims = (1,)
            centre = x.mean(dim=reduce_dims, keepdim=True)
            variance = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
            # The variance is clamped from below (rather than eps being added).
            scale = torch.sqrt(variance.clamp(min=self.eps))
            normalized = (x - centre) / scale
            return normalized * self.gamma + self.beta


class LayerNormalization4DCF(nn.Module):
    """Layer norm over dims 1 and 3 of a 4-D tensor.

    Args:
        input_dimension: Pair ``(size of dim 1, size of dim 3)``;
            ``gamma``/``beta`` have shape ``[1, dim1, 1, dim3]``.
        eps: Lower bound applied to the variance before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        dim1_size, dim3_size = input_dimension[0], input_dimension[1]
        affine_shape = (1, dim1_size, 1, dim3_size)
        self.gamma = Parameter(torch.ones(affine_shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(affine_shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over dims (1, 3) inside a disabled CUDA autocast region.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        with torch.autocast(device_type="cuda", enabled=False):
            reduce_dims = (1, 3)
            centre = x.mean(dim=reduce_dims, keepdim=True)
            variance = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
            # The variance is clamped from below (rather than eps being added).
            scale = torch.sqrt(variance.clamp(min=self.eps))
            normalized = (x - centre) / scale
            return normalized * self.gamma + self.beta

import torch as th
import torch.nn as nn


class ResBlock(nn.Module):
    """
    Residual block built from two bias-free 1x1 ``Conv1d`` layers, each
    followed by batch norm, with PReLU activations and a final max-pool of
    kernel size 3. When the channel counts differ, the skip path goes through
    an extra bias-free 1x1 convolution.

    Used for the speaker encoder to obtain a speaker embedding; ref to
        https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
        and
        https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py
    """

    def __init__(self, in_dims, out_dims):
        super().__init__()
        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        # A projection is only needed when the skip path changes width.
        self.downsample = in_dims != out_dims
        if self.downsample:
            self.conv_downsample = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)

    def forward(self, x):
        """Return the pooled residual output for ``x``."""
        branch = self.prelu1(self.batch_norm1(self.conv1(x)))
        branch = self.batch_norm2(self.conv2(branch))
        shortcut = self.conv_downsample(x) if self.downsample else x
        branch += shortcut
        return self.maxpool(self.prelu2(branch))

class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape

    Constructor arguments are forwarded unchanged to ``nn.LayerNorm``.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Normalize ``x`` over its channel dimension (dim 1).

        Raises:
            RuntimeError: If ``x`` is not a 3D tensor.
        """
        if x.dim() != 3:
            # Instances have no ``__name__``; the class name lives on the type.
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        # nn.LayerNorm normalizes trailing dims, so move channels last and back.
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        x = th.transpose(x, 1, 2)
        return x


class tfgridnet_v2(nn.Module):
    """Audio-visual time-frequency separation network for two speakers.

    The mixture is STFT-encoded and embedded, fused with one visual feature
    stream (``v``) and one lip-encoder stream (``face``) per speaker, refined
    by a stack of ``MultiRangeGridNetBlock`` layers and decoded back to
    waveforms. The forward pass indexes speakers 0 and 1 explicitly, so it
    handles exactly two speakers.

    Shape remarks in the comments below are carried over from the original
    author's notes and are not verified by this code.
    """

    def __init__(
        self,
        num_layers=6,  # 12
        win=256,
        hop_length=128,
        n_fft=256,
        inp_channels=2,
        out_channels=2,
        dim=48,
        bias=False,
        vpre_channels=768,
        num_source=2,
        lstm_hidden_units=128,
        attn_n_head=4,
        attn_approx_qk_dim=512,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="prelu",
        eps=1.0e-5,
    ):
        super().__init__()
        self.num_source = num_source
        self.win = win
        self.hop_length = hop_length
        self.num_layers = num_layers
        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1

        # NOTE: sub-modules are created in a fixed order so that parameter
        # names and seeded initialization stay stable.
        self.pre_v = multi_OverlapPatchEmbed_v(vpre_channels, dim)
        self.stft_encoder = STFTEncoder(win, hop_length, n_fft, bias)

        t_ksize = 3
        ks = (t_ksize, 3)
        padding = (t_ksize // 2, 1)
        self.conv = nn.Sequential(
            multi_OverlapPatchEmbed(inp_channels, emb_dim),
            nn.GroupNorm(1, emb_dim * 4, eps=eps),
            nn.PReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(emb_dim * 4 + dim * 8, emb_dim, ks, padding=padding),
            nn.GroupNorm(1, emb_dim, eps=eps),
            nn.PReLU(),
        )

        self.lipencoder = LipEncoderClassifier(num_speakers=2936, tcn_channels=256, lip_emb_dim=192)

        self.fusion_conv0 = self._make_fusion_head(dim, eps)
        self.fusion_conv1 = self._make_fusion_head(dim, eps)

        self.blocks = nn.ModuleList([
            MultiRangeGridNetBlock(
                emb_dim,
                emb_ks,
                emb_hs,
                n_freqs,
                lstm_hidden_units,
                n_head=attn_n_head,
                approx_qk_dim=attn_approx_qk_dim,
                activation=activation,
                eps=eps,
            )
            for _ in range(num_layers)
        ])
        self.deconv = nn.ConvTranspose2d(emb_dim, num_source * out_channels, ks, padding=padding)
        self.stft_decoder = STFTDecoder_1(win, hop_length, n_fft, in_chan=dim, n_src=num_source, kernel_size=3, stride=1, bias=bias)

    @staticmethod
    def _make_fusion_head(dim, eps):
        """Build the 1x1 conv + GroupNorm + PReLU head that fuses one speaker's streams."""
        return nn.Sequential(
            nn.Conv2d(dim * 12, dim * 4, kernel_size=1),
            nn.GroupNorm(1, dim * 4, eps=eps),
            nn.PReLU(),
        )

    @staticmethod
    def _spread_over_tf(seq, n_frames, n_freqs):
        """Linearly resample a [B, C, L] sequence to ``n_frames`` and tile it over frequency."""
        resampled = FF.interpolate(seq, size=n_frames, mode='linear', align_corners=True)
        return resampled.unsqueeze(-1).repeat(1, 1, 1, n_freqs)

    def forward(self, x, v, face=None):
        """Separate the mixture ``x`` using per-speaker cues ``v`` and ``face``.

        Args:
            x: Mixture waveform; a channel dim is inserted at dim 1.
            v: Visual features, indexed as ``v[:, k]`` for speaker ``k``.
            face: Lip-region input, indexed as ``face[:, k]``. Although it
                defaults to ``None`` it is indexed unconditionally, so it
                must be provided.

        Returns:
            Tuple ``(stft_out_spec, source, logits)``: the complex STFT of
            both estimated sources stacked on dim 1, the decoder output
            rescaled by the mixture standard deviation, and the list of the
            two lip-encoder logits.
        """
        x = x.unsqueeze(1)
        mix_std_ = torch.std(x, dim=(1, 2), keepdim=True)  # [B, 1, 1]
        x = x / mix_std_

        logits0, lip_emb0 = self.lipencoder(face[:, 0].float())
        logits1, lip_emb1 = self.lipencoder(face[:, 1].float())
        logits = [logits0, logits1]

        feature_map = self.stft_encoder(x)
        assert not torch.isnan(feature_map).any(), "NaN in stft_encoder output"
        batch_size, _, n_frames, n_freqs = feature_map.size()
        inp_enc_level1 = self.conv(feature_map)
        assert not torch.isnan(inp_enc_level1).any(), "NaN in conv"

        v00 = self.pre_v(v[:, 0])
        v01 = self.pre_v(v[:, 1])

        # Bring every per-speaker stream onto the [B, C, T, F] grid of the audio.
        lip_emb0 = self._spread_over_tf(lip_emb0.transpose(1, 2), n_frames, n_freqs)
        lip_emb1 = self._spread_over_tf(lip_emb1.transpose(1, 2), n_frames, n_freqs)
        v00 = self._spread_over_tf(v00, n_frames, n_freqs)
        v01 = self._spread_over_tf(v01, n_frames, n_freqs)

        fusion0 = self.fusion_conv0(torch.cat([inp_enc_level1, v00, lip_emb0], dim=1))
        fusion1 = self.fusion_conv1(torch.cat([inp_enc_level1, v01, lip_emb1], dim=1))

        batch = self.conv1(torch.cat([inp_enc_level1, fusion0, fusion1], dim=1))
        for layer_idx in range(self.num_layers):
            batch = self.blocks[layer_idx](batch)  # [B, -1, T, F]
        y = self.deconv(batch)  # [B, n_srcs*2, T, F]
        assert not torch.isnan(y).any(), "NaN in est_sources"
        y = y.view([batch_size, self.num_source, 2, n_frames, n_freqs])

        source = mix_std_ * self.stft_decoder(y, x)

        # Re-encode each estimated waveform and pack (real, imag) as complex.
        complex_specs = []
        for spk in (0, 1):
            reencoded = self.stft_encoder(source[:, spk])
            complex_specs.append(torch.complex(reencoded[:, 0], reencoded[:, 1]))
        stft_out_spec = torch.stack(complex_specs, dim=1)

        return stft_out_spec, source, logits

import sys
import os
# Extend the import search path before the ``espnet`` / ``datamodule`` imports
# that follow; presumably those packages are resolved from this directory
# (TODO confirm). The path is absolute and machine-specific.
sys.path.append("/auto_vsr_pretrain_model/auto_avsr_av")
import argparse
# Parse an empty argument list: the parser declares no options, so ``args``
# starts out as an empty namespace. NOTE(review): ``args`` is rebound further
# down at module level, so this value looks like a placeholder.
parser = argparse.ArgumentParser()
args, _ = parser.parse_known_args(args=[])

from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.nets_utils import (
    make_non_pad_mask
)
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.e2e_asr_conformer_av import E2E as E2E_av
from pytorch_lightning import LightningModule
from datamodule.transforms import TextTransform
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.scorers.length_bonus import LengthBonus
from argparse import Namespace

def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Compute argmax accuracy over the positions that are not ignored.

    Args:
        pad_outputs: Tensor of shape (B * T, D) holding per-class scores.
        pad_targets: LongTensor of shape (B * T) holding reference labels.
        ignore_label: Target value marking positions to leave out.

    Returns:
        float: fraction of kept positions predicted correctly, or ``0.0``
        when every position is ignored.
    """
    keep = pad_targets != ignore_label
    n_kept = keep.sum()
    predictions = pad_outputs.argmax(dim=-1)  # shape: (B * T,)
    n_hits = (predictions[keep] == pad_targets[keep]).sum()

    if n_kept == 0:
        return 0.0
    return float(n_hits) / float(n_kept)

# Module-level hyper-parameter table, converted into an attribute-style
# ``Namespace`` right after the literal. Judging by the key names it describes
# a main encoder/decoder, an ``aux_``-prefixed auxiliary encoder, and a fusion
# head -- presumably the configuration expected by ``E2E_av`` (TODO confirm
# against that class).
config_dict = {
    # Main-branch encoder / decoder settings.
    "adim": 768,
    "aheads": 12,
    "eunits": 3072,
    "elayers": 12,
    "transformer_input_layer": "conv3d",
    "dropout_rate": 0.1,
    "transformer_attn_dropout_rate": 0.1,
    "transformer_encoder_attn_layer_type": "rel_mha",
    "macaron_style": True,
    "use_cnn_module": True,
    "cnn_module_kernel": 31,
    "zero_triu": False,
    "a_upsample_ratio": 1,
    "relu_type": "swish",
    "ddim": 768,
    "dheads": 12,
    "dunits": 3072,
    "dlayers": 6,
    "lsm_weight": 0.1,
    "transformer_length_normalized_loss": False,
    "mtlalpha": 0.1,
    "ctc_type": "builtin",
    "rel_pos_type": "latest",

    # Same settings for the ``aux_`` branch; only the input layer differs
    # ("conv1d" instead of "conv3d").
    "aux_adim": 768,
    "aux_aheads": 12,
    "aux_eunits": 3072,
    "aux_elayers": 12,
    "aux_transformer_input_layer": "conv1d",
    "aux_dropout_rate": 0.1,
    "aux_transformer_attn_dropout_rate": 0.1,
    "aux_transformer_encoder_attn_layer_type": "rel_mha",
    "aux_macaron_style": True,
    "aux_use_cnn_module": True,
    "aux_cnn_module_kernel": 31,
    "aux_zero_triu": False,
    "aux_a_upsample_ratio": 1,
    "aux_relu_type": "swish",
    "aux_dunits": 3072,
    "aux_dlayers": 6,
    "aux_lsm_weight": 0.1,
    "aux_transformer_length_normalized_loss": False,
    "aux_mtlalpha": 0.1,
    "aux_ctc_type": "builtin",
    "aux_rel_pos_type": "latest",

    # Fusion head settings.
    "fusion_hdim": 8192,
    "fusion_norm": "batchnorm",
}

# Rebinds the module-level ``args`` name to a namespace built from the table.
args = Namespace(**config_dict)

class MLPHead(torch.nn.Module):
    """Two-layer MLP: linear -> optional norm -> ReLU -> linear.

    Args:
        idim: Input feature size.
        hdim: Hidden feature size.
        odim: Output feature size.
        norm: ``"batchnorm"`` or ``"layernorm"`` selects the normalization
            after the first linear layer; any other value disables it.
    """

    def __init__(self, idim, hdim, odim, norm="batchnorm"):
        super().__init__()
        self.norm = norm

        self.fc1 = torch.nn.Linear(idim, hdim)
        if norm == "batchnorm":
            self.bn1 = torch.nn.BatchNorm1d(hdim)
        elif norm == "layernorm":
            self.norm1 = torch.nn.LayerNorm(hdim)
        self.nonlin1 = torch.nn.ReLU(inplace=True)
        self.fc2 = torch.nn.Linear(hdim, odim)

    def _normalize(self, hidden):
        """Apply the configured normalization to the hidden activations."""
        if self.norm == "batchnorm":
            # BatchNorm1d expects channels on dim 1, so swap dims 1 and 2 around it.
            return self.bn1(hidden.transpose(1, 2)).transpose(1, 2)
        if self.norm == "layernorm":
            return self.norm1(hidden)
        return hidden

    def forward(self, x):
        hidden = self._normalize(self.fc1(x))
        return self.fc2(self.nonlin1(hidden))

def get_beam_search_decoder_av(model, token_list, ctc_weight=0.1, beam_size=40):
    """Assemble a ``BatchBeamSearch`` over the model's decoder and CTC head.

    Args:
        model: Object exposing ``decoder``, ``ctc``, ``sos`` and ``eos``.
        token_list: Output vocabulary; its length is used as the vocab size.
        ctc_weight: Weight of the CTC scorer; the decoder gets ``1 - ctc_weight``.
        beam_size: Beam width.

    Returns:
        The configured ``BatchBeamSearch``. The language-model and
        length-bonus scorers are registered with weight 0.
    """
    vocab_size = len(token_list)

    scorers = {
        "decoder": model.decoder,
        "ctc": CTCPrefixScorer(model.ctc, model.eos),
        "length_bonus": LengthBonus(vocab_size),
        "lm": None
    }
    weights = {
        "decoder": 1.0 - ctc_weight,
        "ctc": ctc_weight,
        "lm": 0.0,
        "length_bonus": 0.0,
    }

    # With a pure-CTC setup there is no decoder score to pre-prune on.
    if ctc_weight == 1.0:
        pre_beam_key = None
    else:
        pre_beam_key = "decoder"

    return BatchBeamSearch(
        beam_size=beam_size,
        vocab_size=vocab_size,
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=token_list,
        pre_beam_score_key=pre_beam_key,
    )

class ModelModule(LightningModule):
    """Lightning wrapper around the ``E2E_av`` model plus an extra fusion MLP.

    Args:
        args: Namespace providing at least ``adim`` and ``fusion_norm``; it is
            also forwarded to ``E2E_av`` and saved as hyper-parameters.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)

        self.adim = args.adim
        fusion_norm = args.fusion_norm
        last_index = self.adim - 1
        self.sos = last_index
        self.eos = last_index
        self.text_transform = TextTransform()
        self.token_list = self.text_transform.token_list
        self.ignore_id = -1

        self.model = E2E_av(len(self.token_list), args, ignore_id=-1)
        # Fusion head of this wrapper; ``self.model`` carries its own ``fusion``.
        self.fusion = MLPHead(
            idim=self.adim + self.adim,
            hdim=8192,
            odim=self.adim,
            norm=fusion_norm,
        )

    def _encode_and_fuse(self, x, a):
        """Encode ``a`` with the auxiliary encoder and fuse it with ``x``."""
        encoded, _ = self.model.aux_encoder(a, None)
        fused = self.fusion(torch.cat((x, encoded), dim=-1))
        return encoded, fused

    def forward(self, x, a, label=None):
        """Return ``(x, encoded a, fused features)``; ``label`` is unused."""
        a, f = self._encode_and_fuse(x, a)
        return x, a, f

    def forward_predicted_v_embed(self, x, a, label=None):
        """Decode text from ``x`` fused with the encoded ``a`` via beam search.

        A new beam-search decoder is built and stored on ``self.beam_search``
        on every call; ``label`` is unused.
        """
        self.beam_search = get_beam_search_decoder_av(self.model, self.token_list)
        _, f = self._encode_and_fuse(x, a)
        audiovisual_feat = f.squeeze(0)
        hypotheses = self.beam_search(audiovisual_feat)
        best = [hyp.asdict() for hyp in hypotheses[:1]]
        # Drop the first token of the best hypothesis before post-processing.
        token_ids = torch.tensor([int(tok) for tok in best[0]["yseq"][1:]])
        return self.text_transform.post_process(token_ids).replace("<eos>", "")

    def forward_predicted(self, video, audio):
        """Return the fused features produced by the wrapped model's own fusion.

        ``video`` receives a leading dim before encoding; ``audio`` is passed
        to the auxiliary encoder as given.
        """
        video_feat, _ = self.model.encoder(video.unsqueeze(0), None)
        audio_feat, _ = self.model.aux_encoder(audio, None)
        joint = torch.cat((video_feat, audio_feat), dim=-1)
        return self.model.fusion(joint)
    
   

import torch
import torch.nn.functional as F

def si_snr(est_source, true_source, eps=1e-8):
    """
    Scale-Invariant Signal-to-Noise Ratio (SI-SNR)

    Args:
        est_source (Tensor): estimated source [B, T]
        true_source (Tensor): reference source [B, T]
        eps (float): small constant guarding against division by zero

    Returns:
        Tensor: SI-SNR of every sample [B]
    """
    # Unpacking doubles as a check that the reference is 2-D.
    _n_batch, _n_samples = true_source.size()

    est_centered = est_source - est_source.mean(dim=1, keepdim=True)
    ref_centered = true_source - true_source.mean(dim=1, keepdim=True)

    # Project the estimate onto the reference to get the target component.
    inner = (est_centered * ref_centered).sum(dim=1, keepdim=True)  # [B, 1]
    ref_energy = (ref_centered ** 2).sum(dim=1, keepdim=True) + eps
    projection = inner * ref_centered / ref_energy  # [B, T]
    residual = est_centered - projection

    signal_power = (projection ** 2).sum(dim=1)  # [B]
    residual_power = (residual ** 2).sum(dim=1) + eps  # [B]

    return 10 * torch.log10(signal_power / residual_power + eps)  # [B]


class tfgridnet_v2_step2(nn.Module):
    """Two-stage audio-visual separator.

    Stage 1 (``tfgridnet1``, frozen) produces a first separation. Its outputs
    are aligned with the reference ``targets``, pushed together with the video
    clips through a frozen audio-visual encoder (``model_av``), and the fused
    features replace the visual input of stage 2 (``tfgridnet2``, trainable).

    Args:
        num_layers ... eps: Forwarded to both ``tfgridnet_v2`` stages.
        sep_ckpt_path: Checkpoint whose ``"state_dict"`` entries prefixed with
            ``"audio_model."`` initialize both stages (loaded non-strictly).
        av_ckpt_path: State dict loaded into the audio-visual model.

    The two checkpoint files are read while the module is constructed.
    """

    def __init__(
        self,
        num_layers=6,  # 12
        win=256,
        hop_length=128,
        n_fft=256,
        inp_channels=2,
        out_channels=2,
        dim=48,
        bias=False,
        vpre_channels=768,
        num_source=2,
        lstm_hidden_units=128,
        attn_n_head=4,
        attn_approx_qk_dim=512,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="prelu",
        eps=1.0e-5,
        sep_ckpt_path="/home/xueke/DPT_1d_main/checkpoint_improve_tfgridnet_LRS2_SS/LRS2-restormer/epoch=113-16.5.ckpt",
        av_ckpt_path="/home/xueke/DPT_1d_main/auto_vsr_pretrain_model/avsr_trlrwlrs2lrs3vox2avsp_base.pth",
    ):
        super(tfgridnet_v2_step2, self).__init__()
        self.num_source = num_source
        self.win = win
        self.hop_length = hop_length
        self.num_layers = num_layers
        assert n_fft % 2 == 0

        # Both stages share one configuration. ``activation`` and ``eps`` are
        # forwarded as given (they used to be silently replaced by constants).
        stage_kwargs = dict(
            num_layers=num_layers,
            win=win,
            hop_length=hop_length,
            n_fft=n_fft,
            inp_channels=inp_channels,
            out_channels=out_channels,
            dim=dim,
            bias=bias,
            vpre_channels=vpre_channels,
            num_source=num_source,
            lstm_hidden_units=lstm_hidden_units,
            attn_n_head=attn_n_head,
            attn_approx_qk_dim=attn_approx_qk_dim,
            emb_dim=emb_dim,
            emb_ks=emb_ks,
            emb_hs=emb_hs,
            activation=activation,
            eps=eps,
        )
        self.tfgridnet1 = tfgridnet_v2(**stage_kwargs)
        self.tfgridnet2 = tfgridnet_v2(**stage_kwargs)

        # Keep only the separator weights and strip their checkpoint prefix.
        ckpt1 = torch.load(sep_ckpt_path, map_location=lambda storage, loc: storage, weights_only=True)
        prefix = "audio_model."
        stage_state = {
            key[len(prefix):]: value
            for key, value in ckpt1["state_dict"].items()
            if key.startswith(prefix)
        }

        # Stage 1 is frozen (but deliberately not switched to eval mode here).
        self.tfgridnet1.load_state_dict(stage_state, strict=False)
        for p in self.tfgridnet1.parameters():
            p.requires_grad = False

        # Stage 2 starts from the same weights and stays trainable.
        self.tfgridnet2.load_state_dict(stage_state, strict=False)

        # The audio-visual model uses the module-level ``config_dict``.
        self.model_av = ModelModule(Namespace(**config_dict))
        ckpt = torch.load(av_ckpt_path, map_location=lambda storage, loc: storage, weights_only=True)
        self.model_av.model.load_state_dict(ckpt)
        for p in self.model_av.parameters():
            p.requires_grad = False

        # Learnable scalars handed back by ``forward``; not used inside it.
        self.a = nn.Parameter(torch.tensor(1.0))
        self.b = nn.Parameter(torch.tensor(0.0))

    def _av_features(self, clips, waves):
        """Run the frozen audio-visual encoder one sample at a time.

        ``forward_predicted`` adds the batch dim to the video itself, so each
        clip is passed without it while the waveform keeps a batch dim of 1.
        """
        per_sample = [
            self.model_av.forward_predicted(clips[i], waves[i:i + 1])
            for i in range(clips.size(0))
        ]
        return torch.cat(per_sample, dim=0)

    def forward(self, x, v, face=None, targets=None, videos=None):
        """Run both stages and return the stage-2 outputs.

        Args:
            x: Mixture waveform passed to both stages.
            v: Visual features for stage 1.
            face: Lip-region input passed to both stages.
            targets: Reference sources, indexed as ``targets[:, k]``; used to
                pick the stage-1 output order with the higher mean SI-SNR.
            videos: Video clips, indexed as ``videos[:, k]``.

        ``face``, ``targets`` and ``videos`` default to ``None`` but are
        indexed unconditionally, so all three must be provided.

        Returns:
            ``(stft_out_spec, source, logits, a, b)`` where the first three
            come from stage 2 and ``a``/``b`` are this module's scalars.
        """
        # Nothing computed here feeds a trainable parameter, so the whole
        # first stage can run without building an autograd graph.
        with torch.no_grad():
            _, source, _ = self.tfgridnet1(x, v, face)

            straight = (si_snr(source[:, 0], targets[:, 0])
                        + si_snr(source[:, 1], targets[:, 1])) / 2
            swapped = (si_snr(source[:, 1], targets[:, 0])
                       + si_snr(source[:, 0], targets[:, 1])) / 2
            # Decide the order per sample; a plain ``if`` on this tensor only
            # works for batch size 1.
            keep_order = (straight > swapped).unsqueeze(-1)  # [B, 1]
            source1 = torch.where(keep_order, source[:, 0], source[:, 1]).unsqueeze(-1)
            source2 = torch.where(keep_order, source[:, 1], source[:, 0]).unsqueeze(-1)

            f1 = self._av_features(videos[:, 0], source1)
            f2 = self._av_features(videos[:, 1], source2)

        # Stack per speaker, then swap the last two dims to match the layout
        # stage 2 expects for its visual input (presumably channels before
        # time, as in ``v`` -- TODO confirm).
        fusied_y = torch.stack((f1, f2), dim=1).transpose(2, 3)

        stft_out_spec_1, source_1, logits_1 = self.tfgridnet2(x, fusied_y, face)

        return stft_out_spec_1, source_1, logits_1, self.a, self.b


if __name__ == '__main__':
    # Smoke test: build the two-stage model (this loads both checkpoints) and
    # push one random example through it.
    import os
    import glob
    from PIL import Image
    from torchvision import transforms
    transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(), 
            transforms.Normalize(mean=[0.485], std=[0.229])  
        ])


    os.environ["CUDA_VISIBLE_DEVICES"] = "0"   
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = tfgridnet_v2_step2().to(device)
    input_signal = torch.randn(1,32000).to(device) 
    v = torch.randn(1,2,768,50).to(device)
    mouth = torch.randn(1,2,50,1,112,112).to(device)
    target = torch.randn(1, 2, 32000).to(device)
    label1 = torch.tensor([4498, 10, 33, 1]).to(device)
    label2 = torch.tensor([4498, 10, 33, 1]).to(device)
    videos = torch.randn(1,2,50,1,88,88).to(device)
    import time
    start_time = time.time()
    # ``forward`` returns exactly five values: spectrogram, waveforms, logits
    # and the two learnable scalars.
    out1, out2, logits, a, b = model(input_signal, v, mouth, target, videos)
    end_time = time.time()
    print(logits[0].shape)
    print(logits[1].shape)
    print(out1.shape)
    print(out2.shape)
    print(end_time - start_time)
