from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import copy
import math
import numpy as np

from .tcn import SingleStageTCN
from .SP import MultiScale_GraphConv_SE
import torch.nn.functional as F

def exponential_descrease(idx_decoder, p=3):
    """Return the exponential decay factor ``exp(-p * idx_decoder)``.

    Used as the ``alpha`` of refinement stage ``idx_decoder``: stage 0 gets 1.0
    and each later stage is scaled down by a further factor ``exp(-p)``.
    """
    exponent = -(p * idx_decoder)
    return math.exp(exponent)

class Linear_Attention(nn.Module):
    """Multi-head linear attention with a sigmoid feature map.

    Per head, with ``phi`` the element-wise sigmoid::

        out_l = phi(q_l) @ (sum_s phi(k_s)^T v_s) / (phi(q_l) . sum_s phi(k_s))

    so the cost is linear in the sequence lengths instead of quadratic.
    Inputs are sequence-major (``[B, L, C]``), and the output is multiplied by
    the frame mask.

    Args:
        in_channel: channel count of queries, keys and values.
        n_features: total projected width; ``.view`` splits it across ``n_heads``,
            so it must be divisible by ``n_heads``.
        out_channel: output channel count.
        n_heads: number of attention heads.
        drop_out: dropout probability applied after the output projection.
    """

    def __init__(self,
                 in_channel,
                 n_features,
                 out_channel,
                 n_heads=4,
                 drop_out=0.05
                 ):
        super().__init__()
        self.n_heads = n_heads

        self.query_projection = nn.Linear(in_channel, n_features)
        self.key_projection   = nn.Linear(in_channel, n_features)
        self.value_projection = nn.Linear(in_channel, n_features)
        self.out_projection   = nn.Linear(n_features, out_channel)
        self.dropout          = nn.Dropout(drop_out)

    def elu(self, x):
        """Feature map applied to queries and keys.

        Despite its name this is a sigmoid, not ELU. The name is kept for
        compatibility. Any non-negative map keeps the normaliser positive.
        """
        return torch.sigmoid(x)

    def forward(self, queries, keys, values, mask, weight=None):
        """Attend from ``queries`` ``[B, L, C]`` over ``keys``/``values`` ``[B, S, C]``.

        Args:
            queries: ``[B, L, C]``.
            keys: ``[B, S, C]``.
            values: ``[B, S, C]``.
            mask: ``[B, 1, L]`` frame-validity mask. It always zeroes padded
                query rows of the output. When ``S == L`` (all call sites in
                this file) it is also applied to the keys, so padded frames do
                not contribute to the key/value summary or the normaliser.
            weight: optional per-query factor, ``[B, L]`` or reshapeable to it,
                that scales the normaliser.

        Returns:
            ``[B, L, out_channel]`` tensor.
        """
        assert queries.dim() == 3, f"queries must be 3D, got {queries.shape}"
        B, L, Cq = queries.shape

        assert keys.dim() == 3 and values.dim() == 3, \
            f"keys and values must be 3D, got {keys.shape}, {values.shape}"
        Bk, S, Ck = keys.shape
        Bv, Sv, Cv = values.shape
        assert Bk == B and Bv == B, "batch size mismatch between queries/keys/values"
        assert Cq == Ck == Cv, "feature dim mismatch between queries/keys/values"
        assert Sv == S, "length mismatch between keys and values"

        assert mask.dim() == 3 and mask.shape[0] == B and mask.shape[2] == L, \
            f"mask must be [B,1,L], got {mask.shape}"

        if weight is not None:
            if weight.dim() != 2:
                weight = weight.view(B, L)  # try to reshape weight to [B, L]
            assert weight.dim() == 2 and weight.shape == (B, L), \
                f"when provided, weight must be [B,L], got {weight.shape}"

        # Project and split into heads: [B, heads, len, C'].
        queries = self.query_projection(queries).view(B, L, self.n_heads, -1).transpose(1, 2)
        keys    = self.key_projection(keys).view(B, S, self.n_heads, -1).transpose(1, 2)
        values  = self.value_projection(values).view(B, S, self.n_heads, -1).transpose(1, 2)

        queries = self.elu(queries)
        keys    = self.elu(keys)

        # The feature map is strictly positive, so padded keys would otherwise
        # leak into every valid frame through KV and the normaliser.
        # The mask only describes the key axis when it has the same length.
        if S == L:
            key_mask = mask[:, 0, :].to(keys.dtype)      # [B, S]
            keys = keys * key_mask[:, None, :, None]

        KV = torch.einsum('bhsc,bhsv->bhcv', keys, values)  # [B, heads, C', C']

        # Per-query normaliser; the epsilon guards an all-padded sequence.
        Z = 1.0 / (torch.einsum('bhlc,bhc->bhl', queries, keys.sum(dim=2)) + 1e-6)  # [B, heads, L]

        if weight is not None:
            Z = Z * weight.unsqueeze(1)  # broadcast to [B, heads, L]

        x = torch.einsum('bhcv,bhlc,bhl->bhlv', KV, queries, Z)  # [B, heads, L, C']

        x = x.transpose(1, 2).reshape(B, L, -1)  # [B, L, heads*C']
        x = self.out_projection(x)
        x = self.dropout(x)

        return x * mask[:, 0, :, None]




class AttModule(nn.Module):
    """Residual block: dilated temporal conv, then linear attention, then a 1x1 conv.

    ``stage`` selects the attention wiring:

    * ``'encoder'``: keys and values are the instance-normalised conv output.
      Queries are ``f`` when it is given, otherwise the same normalised output.
    * any other value (decoder): queries and keys are the normalised conv
      output, and values come from ``f``, which is then required.

    The attention result is scaled by ``alpha`` and added to the conv output.
    The block output is ``(x + ...) * mask``, so ``in_channel`` must equal
    ``out_channel``.
    """

    def __init__(self, dilation, in_channel, out_channel, stage, alpha):
        super(AttModule, self).__init__()
        self.stage = stage
        self.alpha = alpha

        # Dilated 3-tap conv; padding == dilation keeps the sequence length unchanged.
        self.feed_forward = nn.Sequential(
            nn.Conv1d(in_channel, out_channel, 3, padding=dilation, dilation=dilation),
            nn.ReLU(),
        )
        self.instance_norm = nn.InstanceNorm1d(out_channel, track_running_stats=False)
        self.att_layer = Linear_Attention(out_channel, out_channel, out_channel)
        self.conv_out = nn.Conv1d(out_channel, out_channel, 1)
        self.dropout = nn.Dropout()

    def forward(self, x, f, mask):
        """Apply the block to channel-major ``x`` ``[B, C, T]``, with optional ``f`` ``[B, C, T]``."""
        hidden = self.feed_forward(x)

        if self.stage == 'encoder':
            normed = self.instance_norm(hidden).permute(0, 2, 1)
            query = normed if f is None else f.permute(0, 2, 1)
            attended = self.att_layer(query, normed, normed, mask)
        else:
            assert f is not None
            normed = self.instance_norm(hidden).permute(0, 2, 1)
            attended = self.att_layer(normed, normed, f.permute(0, 2, 1), mask)

        hidden = self.alpha * attended.permute(0, 2, 1) + hidden
        hidden = self.dropout(self.conv_out(hidden))
        return (x + hidden) * mask

class SFI(nn.Module):
    """Fuse a spatial feature map into temporal features via channel affinity.

    ``feature_s`` is projected to ``n_features`` channels. A softmax-normalised
    ``[c, c]`` channel-affinity matrix between it and ``feature_t`` then mixes
    the channels of ``feature_t``. A residual connection and a per-frame MLP
    follow.

    Args:
        in_channel: channel count of ``feature_s`` along its last axis
            (the joint count when called from ``STI``).
        n_features: channel count of ``feature_t`` and of the output.
    """

    def __init__(self, in_channel, n_features):
        super().__init__()
        self.conv_s = nn.Conv1d(in_channel, n_features, 1)  # e.g. 19 -> 64
        self.softmax = nn.Softmax(dim=-1)
        self.ff = nn.Sequential(nn.Linear(n_features, n_features),
                                nn.GELU(),
                                nn.Dropout(0.3),
                                nn.Linear(n_features, n_features))  # 64 -> 64

    def forward(self, feature_s, feature_t, mask):
        """Fuse ``feature_s`` into ``feature_t``.

        Args:
            feature_s: spatial features ``(n, t, v)``.
            feature_t: temporal features from the previous layer ``(n, c, t)``.
            mask: ``(n, 1, t)`` frame-validity mask.

        Returns:
            ``(n, c, t)`` fused features, zero at masked frames.
        """
        feature_s = self.conv_s(feature_s.permute(0, 2, 1))  # (n,v,t) -> (n,c,t)

        # Only valid frames enter the affinity, and it is averaged over their
        # count. With an all-ones mask this equals summing over t and dividing by t.
        valid = mask.to(feature_s.dtype)                                   # (n,1,t)
        n_valid = valid.sum(dim=-1, keepdim=True).clamp(min=1.0)           # (n,1,1)
        affinity = torch.einsum("nct,ndt->ncd", feature_s * valid, feature_t * valid) / n_valid
        affinity = self.softmax(affinity)                                  # (n,c,c)

        feature_cross = torch.einsum("ncd,ndt->nct", affinity, feature_t)  # (n,c,t)
        feature_cross = feature_cross + feature_t
        # The MLP acts per frame on the channel vector.
        feature_cross = self.ff(feature_cross.permute(0, 2, 1)).permute(0, 2, 1) + feature_t

        return feature_cross * mask



import torch
import torch.nn as nn
import torch.nn.functional as F

class STI(nn.Module):
    """Two-branch spatio-temporal encoder with Gaussian-mixture query weighting.

    The input map is split by ``conv_in`` into ``L - 1`` spatial slices (used by
    ``SFI`` fusion at the layers listed in ``SFI_layer``) and one temporal map.
    The temporal map seeds two parallel stacks of encoder ``AttModule`` layers:
    the ``_c`` branch and the ``_l`` branch. ``Model`` feeds the first output to
    its class head and the second to its boundary head.

    When Gaussian weighting is enabled, each layer predicts ``M`` Gaussians per
    pooled temporal segment (``G = pool_len * M`` in total). It turns them into
    per-frame weights in ``[0.1, 1.0]``, which scale the attention queries.
    """

    def __init__(
        self, node, in_channel, n_features, out_channel,
        num_layers, SFI_layer,
        M=2,
        channel_masking_rate=0.3,
        alpha=1, div_margin=0.1, div_weight=0.1,
        pool_len=64,
        use_gaussian_weighting=True, use_diversity_loss=True
    ):
        super().__init__()
        self.SFI_layer = SFI_layer
        self.L = len(SFI_layer) + 1
        self.num_layers = num_layers
        self.channel_masking_rate = channel_masking_rate
        self.dropout = nn.Dropout2d(p=channel_masking_rate)

        self.use_gaussian_weighting = use_gaussian_weighting
        self.use_diversity_loss = use_diversity_loss

        # conv_in yields L maps: L-1 spatial slices for SFI plus one temporal map.
        self.conv_in = nn.Conv2d(in_channel, self.L, kernel_size=1)
        self.conv_t  = nn.Conv1d(node, n_features, 1)

        self.SFI_layers_c = nn.ModuleList([SFI(node, n_features) for _ in range(self.L-1)])
        self.SFI_layers_l = nn.ModuleList([SFI(node, n_features) for _ in range(self.L-1)])

        self.layers_c = nn.ModuleList([AttModule(2**i, n_features, n_features, 'encoder', alpha)
                                       for i in range(num_layers)])
        self.layers_l = nn.ModuleList([AttModule(2**i, n_features, n_features, 'encoder', alpha)
                                       for i in range(num_layers)])

        self.M = M
        self.pool_len = pool_len
        self.G = pool_len * M
        self.eps = 1e-6
        self.div_margin = div_margin
        self.div_weight = div_weight

        # Off-diagonal mask so the diversity penalty ignores each mean's self-distance.
        eye = torch.eye(self.G, dtype=torch.bool)
        self.register_buffer('div_mask', ~eye)

        self.gauss_cls_pred = None
        self.gauss_loc_pred = None

        # Detached CPU snapshots of the last forward pass, indexed by layer.
        self.gauss_info = {
            'mu_c': [None] * num_layers,
            'sigma_c': [None] * num_layers,
            'g_c': [None] * num_layers,
            'scale_c': [None] * num_layers,
            'w_c': [None] * num_layers,
            'mu_l': [None] * num_layers,
            'sigma_l': [None] * num_layers,
            'g_l': [None] * num_layers,
            'scale_l': [None] * num_layers,
            'w_l': [None] * num_layers
        }

        # Per-layer MLPs mapping a frame's C*G Gaussian-weighted features to G scales.
        hidden = n_features * self.G // 2
        self.scale_mlp_c = nn.ModuleList([
            nn.Sequential(
                nn.Linear(n_features * self.G, hidden),
                nn.ReLU(),
                nn.Linear(hidden, self.G),
                nn.Softplus()
            ) for _ in range(num_layers)
        ])
        self.scale_mlp_l = nn.ModuleList([
            nn.Sequential(
                nn.Linear(n_features * self.G, hidden),
                nn.ReLU(),
                nn.Linear(hidden, self.G),
                nn.Softplus()
            ) for _ in range(num_layers)
        ])

        # NOTE(review): these two nets are not used by forward(); kept so existing
        # checkpoints keep the same parameter names.
        self.layer_weight_net_c = nn.Sequential(
            nn.Linear(n_features * num_layers, n_features),
            nn.ReLU(),
            nn.Linear(n_features, num_layers),
            nn.Softmax(dim=1)
        )
        self.layer_weight_net_l = nn.Sequential(
            nn.Linear(n_features * num_layers, n_features),
            nn.ReLU(),
            nn.Linear(n_features, num_layers),
            nn.Softmax(dim=1)
        )

        # Final output branches.
        self.conv_out_c = nn.Conv1d(n_features, out_channel, 1)
        self.conv_out_l = nn.Conv1d(n_features, out_channel, 1)

        # Build the Gaussian predictors eagerly. Created lazily inside forward(),
        # their parameters would be missing from an optimizer constructed before
        # the first forward pass, and from a freshly built model that loads a
        # checkpoint.
        if use_gaussian_weighting or use_diversity_loss:
            self._init_gauss_predictors(n_features)

    def _init_gauss_predictors(self, feat_dim):
        """Create the per-layer ``(mu, sigma)`` predictors for both branches.

        forward() indexes the predictors by layer index, so there must be at least
        ``num_layers`` of them. ``self.L`` is kept as a lower bound so that
        checkpoints saved with the earlier ``L``-sized lists still load.
        """
        n_pred = max(self.L, self.num_layers)
        self.gauss_cls_pred = nn.ModuleList([
            nn.Linear(feat_dim, 2 * self.M) for _ in range(n_pred)
        ])
        self.gauss_loc_pred = nn.ModuleList([
            nn.Linear(feat_dim, 2 * self.M) for _ in range(n_pred)
        ])

    def _gaussian_weights(self, feature, predictor, scale_mlp, pool_idxs, time_idx, seg_len):
        """Compute the per-frame query weights of one branch at one layer.

        Args:
            feature: ``[B, C, t_len]`` branch features.
            predictor: ``Linear(C -> 2*M)`` giving raw means and raw widths per pooled segment.
            scale_mlp: maps each frame's ``C*G`` Gaussian-weighted features to ``G`` scales.
            pool_idxs: ``[1, pool_len, 1]`` segment indices.
            time_idx: ``[1, 1, t_len]`` float frame indices.
            seg_len: frames per pooled segment (``t_len / pool_len``).

        Returns:
            ``(w, mu_abs, sigma_abs, g, scale, div_term)``:

            * ``w`` is ``[B, t_len]``, clamped to ``[0.1, 1.0]``.
            * ``mu_abs`` and ``sigma_abs`` are ``[B, G, 1]``, in frames.
            * ``g`` and ``scale`` are ``[B, G, t_len]``.
            * ``div_term`` is a scalar, or ``None`` when the diversity loss is off.
        """
        B, C, t_len = feature.shape
        pooled = F.adaptive_avg_pool1d(feature.contiguous(), self.pool_len)  # [B, C, pool_len]
        params = predictor(pooled.permute(0, 2, 1)).permute(0, 2, 1)        # [B, 2*M, pool_len]
        mu_raw, sigma_raw = params.chunk(2, dim=1)                           # each [B, M, pool_len]

        sigma_unit = F.softplus(sigma_raw) + self.eps  # positive width, in segment units
        mu_unit = torch.sigmoid(mu_raw)                # offset inside the segment, in (0, 1)

        div_term = None
        if self.use_diversity_loss:
            # reshape, not view: mu_raw is a non-contiguous slice of a permuted tensor.
            mu_flat = mu_raw.reshape(B, self.G)
            gap = (mu_flat.unsqueeze(2) - mu_flat.unsqueeze(1)).abs()
            div_term = torch.relu(self.div_margin - gap)[:, self.div_mask].mean()

        # Segment index + in-segment offset, converted to absolute frame positions.
        mu_abs = ((pool_idxs + mu_unit.permute(0, 2, 1)) * seg_len).reshape(B, self.G, 1)
        sigma_abs = (sigma_unit.permute(0, 2, 1) * seg_len).reshape(B, self.G, 1)

        g = torch.exp(-(time_idx - mu_abs) ** 2 / (2 * sigma_abs ** 2 + self.eps))  # [B, G, t_len]

        # Row (b, t) holds the C*G Gaussian-weighted features of frame t.
        fg = torch.einsum('bct,bgt->btcg', feature, g).reshape(B * t_len, C * self.G)
        # The MLP output is laid out (b, t, g): restore that order, then move G in
        # front. A direct view(B, G, t_len) would mix the Gaussian and time axes.
        scale = scale_mlp(fg).view(B, t_len, self.G).permute(0, 2, 1)  # [B, G, t_len]

        w = torch.clamp(torch.einsum('bgt,bgt->bt', g, scale), min=0.1, max=1.0)
        return w, mu_abs, sigma_abs, g, scale, div_term

    def _record_gauss_info(self, idx, suffix, mu, sigma, g, scale, w):
        """Store detached CPU copies of one branch's Gaussian tensors for layer ``idx``."""
        for name, value in (('mu', mu), ('sigma', sigma), ('g', g), ('scale', scale), ('w', w)):
            self.gauss_info[f'{name}_{suffix}'][idx] = value.detach().cpu().clone()

    def forward(self, x, mask, joint_text_embedding=None):
        """Encode ``x`` into classification- and localisation-branch features.

        Args:
            x: ``[B, in_channel, t_len, node]``. The time axis precedes the joint
                axis: after ``conv_in`` the last axis is permuted into
                ``conv_t``'s ``node`` input channels.
            mask: ``[B, 1, t_len]`` frame mask.
            joint_text_embedding: unused; accepted for interface compatibility.

        Returns:
            ``(out_c, out_l, div_loss)``, where ``out_c`` and ``out_l`` are
            ``[B, out_channel, t_len]`` and ``div_loss`` is the weighted
            diversity penalty (``0.0`` when disabled).
        """
        if self.channel_masking_rate > 0:
            x = self.dropout(x)

        x = self.conv_in(x)                                   # [B, L, t_len, node]
        feat_s, feat_t = torch.split(x, (self.L-1, 1), dim=1)
        feat_t = feat_t.squeeze(1).permute(0, 2, 1)           # [B, node, t_len]
        feature_st = self.conv_t(feat_t)                      # [B, C, t_len]
        B, C, t_len = feature_st.shape
        device = x.device

        use_gauss = self.use_gaussian_weighting or self.use_diversity_loss
        if use_gauss and self.gauss_cls_pred is None:
            # Fallback for instances whose predictors were never built.
            self._init_gauss_predictors(C)
            self.gauss_cls_pred.to(device)
            self.gauss_loc_pred.to(device)

        if use_gauss:
            # Loop-invariant index grids shared by every layer and both branches.
            pool_idxs = torch.arange(self.pool_len, device=device).view(1, self.pool_len, 1)
            time_idx = torch.arange(t_len, device=device).view(1, 1, t_len).float()
            seg_len = t_len / float(self.pool_len)

        div_loss = 0.0
        all_feats_c, all_feats_l = [], []
        feature_st_c, feature_st_l = feature_st, feature_st

        for idx, (layer_c, layer_l) in enumerate(zip(self.layers_c, self.layers_l)):
            if idx in self.SFI_layer:
                sfi_idx = self.SFI_layer.index(idx)
                feature_st_c = self.SFI_layers_c[sfi_idx](feat_s[:, sfi_idx], feature_st_c, mask)
                feature_st_l = self.SFI_layers_l[sfi_idx](feat_s[:, sfi_idx], feature_st_l, mask)

            if use_gauss:
                w_c, mu_c, sigma_c, g_c, scale_c, div_c = self._gaussian_weights(
                    feature_st_c, self.gauss_cls_pred[idx], self.scale_mlp_c[idx],
                    pool_idxs, time_idx, seg_len)
                if div_c is not None:
                    div_loss += div_c

                w_l, mu_l, sigma_l, g_l, scale_l, div_l = self._gaussian_weights(
                    feature_st_l, self.gauss_loc_pred[idx], self.scale_mlp_l[idx],
                    pool_idxs, time_idx, seg_len)
                if div_l is not None:
                    div_loss += div_l

                self._record_gauss_info(idx, 'c', mu_c, sigma_c, g_c, scale_c, w_c)
                self._record_gauss_info(idx, 'l', mu_l, sigma_l, g_l, scale_l, w_l)
            else:
                w_c = torch.ones(B, t_len, device=device, dtype=x.dtype)
                w_l = torch.ones(B, t_len, device=device, dtype=x.dtype)

            # The weighted features act as the attention queries of each encoder layer.
            Q_c = feature_st_c * w_c.unsqueeze(1)
            feature_st_c = layer_c(feature_st_c, Q_c, mask)  # [B, C, t_len]

            Q_l = feature_st_l * w_l.unsqueeze(1)
            feature_st_l = layer_l(feature_st_l, Q_l, mask)  # [B, C, t_len]

            all_feats_c.append(feature_st_c)
            all_feats_l.append(feature_st_l)

        # Sum the outputs of all layers.
        final_feat_c = torch.stack(all_feats_c, dim=1).sum(1)  # [B, C, t_len]
        final_feat_l = torch.stack(all_feats_l, dim=1).sum(1)  # [B, C, t_len]

        out_c = self.conv_out_c(final_feat_c) * mask  # [B, out_channel, t_len]
        out_l = self.conv_out_l(final_feat_l) * mask  # [B, out_channel, t_len]

        return out_c, out_l, div_loss * self.div_weight



       
class Decoder(nn.Module):
    """Refinement stage: 1x1 input conv, stacked decoder ``AttModule`` layers, 1x1 output conv.

    Layer ``i`` uses dilation ``2 ** i``. Every layer receives the same
    ``fencoder`` tensor as its attention values.
    """

    def __init__(self, in_channel, n_features, out_channel, num_layers, alpha=1):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channel, n_features, 1)
        blocks = []
        for level in range(num_layers):
            blocks.append(AttModule(2 ** level, n_features, n_features, 'decoder', alpha))
        self.layers = nn.ModuleList(blocks)
        self.conv_out = nn.Conv1d(n_features, out_channel, 1)

    def forward(self, x, fencoder, mask):
        """Return ``(logits, features)``, where ``features`` is the last hidden map."""
        hidden = self.conv_in(x)
        for block in self.layers:
            hidden = block(hidden, fencoder, mask)
        return self.conv_out(hidden), hidden

    
class Model(nn.Module):
    """Predict frame-level action classes and action boundaries.

    Pipeline: spatial encoder ``SP`` (MS-G3D), then text-conditioned joint
    attention, then the two-branch temporal encoder ``STI``, then class and
    boundary heads. After that come ``n_stages_asb - 1`` refinement
    ``Decoder`` stages for classes and ``n_stages_brb - 1`` ``SingleStageTCN``
    stages for boundaries.

    Args:
        in_channel: input channel count passed to the spatial encoder ``SP``.
        n_features: hidden width (e.g. 64).
        n_classes: the number of action classes.
        n_stages: default number of stages for both branches.
        n_layers: layers per ``STI`` branch (e.g. 10).
        n_refine_layers: layers per refinement stage.
        n_stages_asb: stage count of the segmentation branch; any non-int falls back to ``n_stages``.
        n_stages_brb: stage count of the boundary branch; any non-int falls back to ``n_stages``.
        SFI_layer: sequence of ``STI`` layer indices where spatial features are
            fused. ``None`` means no fusion.
        dataset: dataset name; ``"LARA"`` selects 19 joints, anything else 25.
    """

    def __init__(
        self,
        in_channel: int,
        n_features: int,
        n_classes: int,
        n_stages: int,
        n_layers: int,
        n_refine_layers: int,
        n_stages_asb: Optional[int] = None,
        n_stages_brb: Optional[int] = None,
        SFI_layer: Optional[list] = None,
        dataset: Optional[str] = None,
        **kwargs: Any
    ) -> None:

        if not isinstance(n_stages_asb, int):
            n_stages_asb = n_stages

        if not isinstance(n_stages_brb, int):
            n_stages_brb = n_stages

        if SFI_layer is None:
            # STI calls len()/index() on this; an empty list means "no SFI fusion".
            SFI_layer = []

        super().__init__()

        # Two learnable logit scales, initialised to log(1 / 0.07) ~= 2.6593.
        self.logit_scale = nn.Parameter(torch.ones(2) * np.log(1 / 0.07))

        self.in_channel = in_channel
        node = 19 if dataset == "LARA" else 25

        self.SP = MultiScale_GraphConv_SE(13, in_channel, n_features, dataset)  # MS-G3D
        self.STI = STI(node, n_features, n_features,
                       n_features, n_layers, SFI_layer,
                       use_diversity_loss=False,
                       use_gaussian_weighting=True,
                       )
        self.joint_att = Spatial_AttLayer(n_features, n_features, 32, 64, 4)

        self.conv_cls = nn.Conv1d(n_features, n_classes, 1)
        self.conv_bound = nn.Conv1d(n_features, 1, 1)
        self.conv_feature = nn.Conv1d(n_features, 512, 1)
        self.conv_feature_split = nn.Conv1d(n_features, 512, 1)

        # Action segmentation branch: later stages get an exponentially smaller alpha.
        asb = [
            Decoder(n_classes, n_features, n_classes, n_refine_layers,
                    alpha=exponential_descrease(s))
            for s in range(n_stages_asb - 1)
        ]
        conv_asb_feature = [nn.Conv1d(n_features, 512, 1) for _ in range(n_stages_asb - 1)]
        # Boundary regression branch.
        brb = [
            SingleStageTCN(1, n_features, 1, n_refine_layers) for _ in range(n_stages_brb - 1)
        ]
        self.asb = nn.ModuleList(asb)
        self.brb = nn.ModuleList(brb)
        self.conv_asb_feature = nn.ModuleList(conv_asb_feature)

        self.activation_asb = nn.Softmax(dim=1)
        self.activation_brb = nn.Sigmoid()
        self.ff_embedding = nn.Linear(512, 512)
        self.div_loss = None

    def forward(self, x: torch.Tensor, mask: torch.Tensor,
                joint_text_embedding: torch.Tensor) -> Tuple[Any, ...]:
        """Run the network.

        Args:
            x: input to the spatial encoder ``SP``; its output is ``(n, c, t, v)``.
            mask: ``[B, 1, T]`` frame mask.
            joint_text_embedding: 2-D tensor whose last axis has 512 features.
                NOTE(review): presumably one row per joint; confirm against the caller.

        Returns:
            In training mode: ``(outputs_cls, outputs_bound, outputs_feature,
            out_feature_split, logit_scale)``, where the first three are
            per-stage lists.
            In eval mode: ``(out_cls, out_bound, out_feature, bound_feature)``
            from the last stages. ``out_feature`` and ``bound_feature`` are
            ``None`` when the corresponding branch has no refinement stages.
        """
        joint_text_embedding = self.ff_embedding(joint_text_embedding)
        # [K, 512] -> [1, 512, 1, K], the layout Spatial_AttLayer expands over the batch.
        joint_text_embedding = joint_text_embedding.permute(1, 0).unsqueeze(0).unsqueeze(2)

        x = self.SP(x) * mask.unsqueeze(3)  # MS-G3D (n,c,t,v)

        x = self.joint_att(x, joint_text_embedding)

        feature_cls, feature_brd, div_loss = self.STI(x, mask, joint_text_embedding)
        self.div_loss = div_loss

        out_cls = self.conv_cls(feature_cls)
        out_bound = self.conv_bound(feature_brd)

        if self.training:
            # Only returned during training, so only computed there.
            out_feature = self.conv_feature(feature_cls)
            out_feature_split = self.conv_feature_split(feature_cls + feature_brd)

            outputs_cls = [out_cls]
            outputs_bound = [out_bound]
            outputs_feature = [out_feature]

            for as_stage, conv_stage in zip(self.asb, self.conv_asb_feature):
                out_cls, feature_cls = as_stage(self.activation_asb(out_cls) * mask, feature_cls * mask, mask)
                out_feature = conv_stage(feature_cls)
                outputs_cls.append(out_cls)
                outputs_feature.append(out_feature)

            for br_stage in self.brb:
                out_bound, _ = br_stage(self.activation_brb(out_bound), mask)
                outputs_bound.append(out_bound)

            return (outputs_cls, outputs_bound, outputs_feature, out_feature_split, self.logit_scale)

        out_feature = None
        for as_stage, conv_stage in zip(self.asb, self.conv_asb_feature):
            out_cls, feature_cls = as_stage(self.activation_asb(out_cls) * mask, feature_cls * mask, mask)
            out_feature = conv_stage(feature_cls)

        bound_feature = None
        for br_stage in self.brb:
            out_bound, bound_feature = br_stage(self.activation_brb(out_bound), mask)

        return (out_cls, out_bound, out_feature, bound_feature)



class Spatial_AttLayer(nn.Module):
    """Joint-to-joint attention whose attention map is computed from text embeddings.

    Queries and keys come from ``text_feature`` (one embedding per joint). Values
    come from ``feature``. Each of the ``num_heads`` ``V x V`` maps mixes the
    ``v_dim`` value channels across joints. The heads are concatenated,
    projected to ``out_channels`` and added back to ``feature``, so
    ``out_channels`` must equal ``in_channels``.

    Args:
        in_channels: channel count ``C`` of ``feature``.
        out_channels: output channel count (must equal ``in_channels`` for the residual).
        qk_dim: per-head query/key width.
        v_dim: value width, shared by all heads.
        num_heads: number of attention heads.
        text_dim: channel count of ``text_feature``. Previously hard-coded to 512,
            which remains the default.
    """

    def __init__(self, in_channels, out_channels, qk_dim, v_dim, num_heads, text_dim=512):
        super().__init__()
        self.qk_dim = qk_dim
        self.v_dim = v_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.text_dim = text_dim

        self.query_conv = nn.Conv2d(in_channels=text_dim, out_channels=num_heads * qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=text_dim, out_channels=num_heads * qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=v_dim, kernel_size=1)
        self.conv_out = nn.Conv2d(in_channels=num_heads * v_dim, out_channels=out_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feature, text_feature):
        """Apply the attention.

        Args:
            feature: ``[N, C, T, V]``.
            text_feature: ``[1 or N, text_dim, 1, V]``; it is expanded over the batch.

        Returns:
            ``[N, out_channels, T, V]``.
        """
        N, C, T, V = feature.size()

        text = text_feature.expand(N, -1, 1, -1)  # [N, text_dim, 1, V]
        q = self.query_conv(text).view(N, self.num_heads, self.qk_dim, V)
        k = self.key_conv(text).view(N, self.num_heads, self.qk_dim, V)
        v = self.value_conv(feature)               # [N, v_dim, T, V]

        energy = torch.einsum('nhcu,nhcv->nhuv', q, k)          # [N, H, V, V]
        attention = self.softmax(energy / math.sqrt(self.qk_dim))

        z = torch.einsum('nctu,nhuv->nhctv', v, attention)       # [N, H, v_dim, T, V]
        z = z.reshape(N, self.num_heads * self.v_dim, T, V)
        return self.conv_out(z) + feature