import copy
import math
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import difflib
from typing import Optional, Tuple, Union

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, BertTokenizer, BertModel, Wav2Vec2Model, Wav2Vec2Config
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureEncoder
from .motion_encoder import VQEncoderV6


def audio_to_time_aligned_text_features(inputs, processor, model, tokenizer, bert_model):
    """Spread text-token embeddings over the frames of a CTC speech model.

    The audio is transcribed with ``model``, the transcription is embedded
    with ``bert_model``, and every frame whose arg-max token is a character
    is matched (by a character-level sequence alignment) to the text token
    containing that character. Matched frames receive that token's embedding.
    Frame 0 always receives the embedding of the first text token (``[CLS]``
    for a BERT tokenizer). Every other frame (``<pad>``, ``|`` or a character
    the alignment could not place) copies the embedding of the nearest
    matched frame, preferring the earlier frame on a tie.

    Only the first item of the batch is processed.

    Args:
        inputs: Object exposing ``input_values``, the tensor fed to ``model``.
        processor: Object exposing ``tokenizer.get_vocab()`` (token -> id) and
            ``decode(ids)`` (ids -> transcription string).
        model: Callable returning an object whose ``logits`` have shape
            (batch, time_steps, vocab_size).
        tokenizer: Text tokenizer; called as ``tokenizer(text, return_tensors='pt')``
            and must provide ``convert_ids_to_tokens``.
        bert_model: Callable returning an object with ``last_hidden_state``.

    Returns:
        Tuple ``(transcription, features_per_timestep, all_token_embeddings)``
        where ``features_per_timestep`` has shape (1, time_steps, hidden) and
        ``all_token_embeddings`` has shape (num_text_tokens, hidden).
    """
    device = inputs.input_values.device

    with torch.no_grad():
        logits = model(inputs.input_values).logits  # (batch, time_steps, vocab_size)

    # Frame-wise arg-max decoding of the first batch item.
    predicted_ids = torch.argmax(logits, dim=-1)[0]  # (time_steps,)
    vocab = processor.tokenizer.get_vocab()
    id_to_token = {token_id: token for token, token_id in vocab.items()}
    tokens_per_timestep = [id_to_token[token_id] for token_id in predicted_ids.tolist()]

    transcription = processor.decode(predicted_ids)
    inputs_bert = tokenizer(transcription, return_tensors='pt')
    tokens_bert = tokenizer.convert_ids_to_tokens(inputs_bert['input_ids'][0].tolist())

    with torch.no_grad():
        outputs_bert = bert_model(**inputs_bert.to(device))
    all_token_embeddings = outputs_bert.last_hidden_state[0]  # (num_text_tokens, hidden)

    # Characters emitted by the speech model, with the frame each came from.
    frame_chars = []
    frame_char_steps = []
    for step, token in enumerate(tokens_per_timestep):
        if token not in ('<pad>', '|'):
            frame_chars.append(token.lower())
            frame_char_steps.append(step)

    # Characters of the text tokens, with the token each belongs to.
    text_chars = []
    text_char_token_idx = []
    for token_idx, token in enumerate(tokens_bert):
        if token in ('[CLS]', '[SEP]'):
            continue
        for char in token.replace('##', ''):
            text_chars.append(char)
            text_char_token_idx.append(token_idx)

    # autojunk must be off: with the default heuristic, any character that is
    # frequent in a sequence of 200+ items is ignored as an anchor, which
    # would leave most frames of a long transcription unmatched.
    matcher = difflib.SequenceMatcher(None, frame_chars, text_chars, autojunk=False)
    step_to_token_idx = {}
    for frame_start, text_start, size in matcher.get_matching_blocks():
        for offset in range(size):
            step = frame_char_steps[frame_start + offset]
            step_to_token_idx[step] = text_char_token_idx[text_start + offset]

    # Token index per frame; None marks frames that still need an embedding.
    token_idx_per_step = [
        0 if step == 0 else step_to_token_idx.get(step)
        for step in range(len(tokens_per_timestep))
    ]
    zero_embedding = torch.zeros(all_token_embeddings.shape[-1]).to(device)
    features_per_timestep = torch.stack([
        zero_embedding if token_idx is None else all_token_embeddings[token_idx]
        for token_idx in token_idx_per_step
    ])

    # Unmatched frames copy the nearest matched frame (left wins ties).
    source_step = _nearest_valid_index([token_idx is not None for token_idx in token_idx_per_step])
    gather_index = torch.tensor(source_step, dtype=torch.long, device=features_per_timestep.device)
    features_per_timestep = features_per_timestep[gather_index].unsqueeze(0)
    return transcription, features_per_timestep, all_token_embeddings


def _nearest_valid_index(valid):
    """Map every position to the index of its nearest ``True`` entry.

    Valid positions map to themselves. An invalid position maps to the closest
    valid one, choosing the left neighbour when both sides are equally far.
    If no position is valid, every position maps to itself.
    """
    size = len(valid)
    nearest = list(range(size))

    # Forward pass: closest valid index at or before each position.
    left_of = [None] * size
    last_seen = None
    for idx, is_valid in enumerate(valid):
        if is_valid:
            last_seen = idx
        left_of[idx] = last_seen

    # Backward pass: combine with the closest valid index after each position.
    next_seen = None
    for idx in range(size - 1, -1, -1):
        if valid[idx]:
            next_seen = idx
            continue
        left = left_of[idx]
        if left is not None and (next_seen is None or idx - left <= next_seen - idx):
            nearest[idx] = left
        elif next_seen is not None:
            nearest[idx] = next_seen
    return nearest


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> LeakyReLU(0.2, in-place) -> Linear."""

    def __init__(self, in_dim, hidden_size, out_dim):
        super().__init__()
        # Layers are created in the order they are applied.
        layers = [
            nn.Linear(in_dim, hidden_size),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(hidden_size, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the perceptron to the last dimension of ``inputs``."""
        return self.mlp(inputs)


class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that repeats every ``period`` positions.

    A table of ``period`` sinusoidal rows is tiled ``max_seq_len // period + 1``
    times and stored in the ``pe`` buffer. ``forward`` adds the first
    ``x.size(1)`` rows to the input and applies dropout. Inputs longer than
    the stored buffer are handled by tiling the period further on the fly.

    Args:
        d_model: Size of the last dimension of the inputs (odd values allowed).
        dropout: Dropout probability applied after adding the encoding.
        period: Number of positions after which the encoding repeats.
        max_seq_len: Sequence length the pre-computed buffer is sized for.
    """

    def __init__(self, d_model, dropout=0.1, period=20, max_seq_len=64):
        super(PeriodicPositionalEncoding, self).__init__()
        self.period = period
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, period, d_model)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)  # (1, repeat_num * period, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the periodic encoding to ``x`` (batch, seq_len, d_model)."""
        seq_len = x.size(1)
        pe = self.pe
        if seq_len > pe.size(1):
            # The table is periodic, so longer inputs only need more tiles.
            tiles = -(-seq_len // self.period)  # ceil division
            pe = pe[:, :self.period].repeat(1, tiles, 1)
        x = x + pe[:, :seq_len, :]
        return self.dropout(x)
    

class CustomMultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention that returns per-head weights.

    Inputs are batch-first, ``(batch, length, embed_dim)``. The key/value
    length may differ from the query length.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super(CustomMultiheadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected, batch_size):
        # (batch, length, embed_dim) -> (batch, num_heads, length, head_dim);
        # the length is inferred so keys/values need not match the query length.
        return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``.

        Returns:
            Tuple ``(output, attn_weights)`` with ``output`` of shape
            (batch, query_len, embed_dim) and ``attn_weights`` of shape
            (batch, num_heads, query_len, key_len).
        """
        batch_size, query_len, embed_dim = query.size()

        # Linear projections, split into heads
        Q = self._split_heads(self.query_proj(query), batch_size)
        K = self._split_heads(self.key_proj(key), batch_size)
        V = self._split_heads(self.value_proj(value), batch_size)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, embed_dim)

        # Apply final linear projection
        output = self.out_proj(attn_output)
        return output, attn_weights  # Return the per-head attention weights


# class JointEmbedding(nn.Module):
#     def __init__(self, args):
#         super(JointEmbedding, self).__init__()
#         self.args = args.model   
#         self.audio_processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')
#         self.audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
#         self.config_wav2vec = Wav2Vec2Config.from_pretrained('facebook/wav2vec2-base-960h')
#         self.audio_encoder_fintune = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h').feature_extractor
#         # print(self.audio_encoder_fintune)
#         # print(bug)
        
#         self.asr = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')
#         self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#         self.bert_model = BertModel.from_pretrained('bert-base-uncased')

#         self.audio_low_mapping = MLP(512+512, self.args.hidden_size, self.args.audio_f)
#         self.audio_high_mapping = MLP(768+768+512+512, self.args.hidden_size, self.args.audio_f)
#         # self.audio_sa = nn.MultiheadAttention(embed_dim=self.args.audio_f, num_heads=8, batch_first=True)
#         self.audio_sa = CustomMultiheadAttention(embed_dim=self.args.audio_f, num_heads=8,)

#         args_top = copy.deepcopy(self.args)
#         args_top.vae_layer = 3
#         args_top.vae_length = self.args.motion_f
#         args_top.vae_test_dim = self.args.motion_dim
#         self.motion_encoder_low = VQEncoderV6(args_top) 

#         args_top_high = copy.deepcopy(self.args)
#         args_top_high.vae_layer = 6
#         args_top_high.vae_length = self.args.motion_f
#         args_top_high.vae_test_dim = self.args.motion_dim
#         self.motion_encoder_high = VQEncoderV6(args_top_high) 
#         self.motion_low_mapping = MLP(self.args.motion_f, self.args.hidden_size, self.args.motion_f)
#         self.motion_high_mapping = MLP(self.args.motion_f, self.args.hidden_size, self.args.motion_f)
#         # self.motion_sa = nn.MultiheadAttention(embed_dim=self.args.audio_f, num_heads=8, batch_first=True)
#         self.motion_sa = CustomMultiheadAttention(embed_dim=self.args.audio_f, num_heads=8,)
        
#         self.down_sample = 2 # for downsample 30 fps motion to 15
#         self.smplx_model = None
#         self.get_motion_reps = None
#         self.audio_to_time_aligned_text_features = audio_to_time_aligned_text_features
#         self.low_temp = nn.Parameter(torch.tensor(0.07))
#         self.low_level_loss_fn = None
#         self.high_temp = nn.Parameter(torch.tensor(0.07))
#         self.high_level_loss_fn = None

#     def _reset_parameters(self):
#         nn.init.normal_(self.mask_embeddings, 0, self.args.hidden_size ** -0.5)
    
#     def forward(self, in_audio=None, in_motion=None, cached_audio_low=None, cached_audio_high=None, cached_rep15d=None):
#         # motion feature
#         if cached_rep15d is not None:
#             in_motion = cached_rep15d[:,::self.down_sample]
#         else:
#             in_motion = self.get_motion_reps(in_motion, self.smplx_model)["rep15d"][:,::self.down_sample]
        
#         raw_motion_low = self.motion_encoder_low(in_motion)
#         raw_motion_high = self.motion_encoder_high(in_motion)

#         motion_low = self.motion_low_mapping(raw_motion_low)
#         motion_high = self.motion_high_mapping(raw_motion_high)
#         motion_high_att, motion_high_weight = self.motion_sa(motion_high, motion_high, motion_high)
#         bs, n, c = motion_high.shape
#         # print("a:", motion_high_weight[:, :, 0, :].unsqueeze(2).shape, "b:", motion_high.transpose(1, 2).view(bs, 8, c//8, n).shape)
#         motion_high_att_before_sum = motion_high_weight[:, :, 0, :].unsqueeze(2) * motion_high.transpose(1, 2).view(bs, 8, c//8, n)
#         motion_high_att_before_sum = motion_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)
#         motion_low = F.interpolate(motion_low.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         motion_high_att = F.interpolate(motion_high_att.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         motion_high_att_before_sum = F.interpolate(motion_high_att_before_sum.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         motion_cls = motion_high_att[:, 0]

#         # audio feature
#         if cached_audio_low is not None:
#             raw_audio_low = cached_audio_low
#             raw_audio_high = cached_audio_high
            
#             audio_list = [i.cpu().numpy() for i in in_audio]
#             inputs = self.audio_processor(audio_list, sampling_rate=16000, return_tensors="pt", padding=True).to(in_audio.device)
#             finetune_audio_low = self.audio_encoder_fintune(inputs.input_values).transpose(1, 2)
#             diff = raw_audio_low.shape[1] - finetune_audio_low.shape[1]
#             if diff > 0:
#                 finetune_audio_low = torch.cat([finetune_audio_low, finetune_audio_low[:, -diff:]], dim=1)
#             raw_audio_low = torch.cat([raw_audio_low, finetune_audio_low], dim=-1) # bs, t, 1024
#         else:
#             print("error! must have cached audio in training")
        
#         # print(raw_audio_low.shape, raw_audio_high.shape, "before")
#         raw_audio_low = F.interpolate(raw_audio_low.transpose(1, 2), scale_factor=30/50, mode='linear', align_corners=True).transpose(1, 2) 
#         raw_audio_high = F.interpolate(raw_audio_high.transpose(1, 2), scale_factor=15/50, mode='linear', align_corners=True).transpose(1, 2) 
#         # print(raw_audio_low.shape, raw_audio_high.shape, "after")
#         audio_low = self.audio_low_mapping(raw_audio_low)
#         raw_audio_high = torch.cat([raw_audio_low[:, ::2], raw_audio_high], dim=-1)
#         audio_high = self.audio_high_mapping(raw_audio_high)
#         audio_high_att, audio_high_weight = self.audio_sa(audio_high, audio_high, audio_high)
#         bs, n, c = audio_high.shape
#         audio_high_att_before_sum = audio_high_weight[:, :, 0, :].unsqueeze(2) * audio_high.transpose(1, 2).view(bs, 8, c//8, n)
#         audio_high_att_before_sum = audio_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)
#         audio_high_att = F.interpolate(audio_high_att.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         audio_high_att_before_sum = F.interpolate(audio_high_att_before_sum.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         audio_cls = audio_high_att[:, 0]
#         # low_infonce, low_acc = self.low_level_loss_fn(audio_low, motion_low, learned_temp=self.low_temp)
        
#         # fix temp to 0.1 is better than learned temp
#         low_infonce, low_acc = self.low_level_loss_fn(audio_low, motion_low)
#         high_infonce = self.high_level_loss_fn(audio_cls, motion_cls)
#         return {
#             "audio_low":audio_low,
#             "audio_high":audio_high_att,
#             "audio_cls":audio_cls,
#             "audio_high_weight":audio_high_att_before_sum,
#             "motion_low":motion_low,
#             "motion_high":motion_high_att,
#             "motion_cls":motion_cls,
#             "motion_high_weight":motion_high_att_before_sum,
#             "low_level_loss": [low_infonce, low_acc],
#             "high_level_loss": high_infonce
#             }

#     def get_audio_features(self, in_audio):
#         audio_list = [i.cpu().numpy() for i in in_audio]
#         inputs = self.audio_processor(audio_list, sampling_rate=16000, return_tensors="pt", padding=True).to(in_audio.device)
#         raw_audio_low = self.audio_encoder.feature_extractor(inputs.input_values).transpose(1, 2)
#         finetune_audio_low = self.audio_encoder_fintune(inputs.input_values).transpose(1, 2)
#         diff = raw_audio_low.shape[1] - finetune_audio_low.shape[1]
#         if diff > 0:
#             finetune_audio_low = torch.cat([finetune_audio_low, finetune_audio_low[:, -diff:]], dim=1)
#         raw_audio_low = torch.cat([raw_audio_low, finetune_audio_low], dim=-1)

#         raw_audio_high = self.audio_encoder(inputs.input_values).last_hidden_state
#         _, bert_time_aligned_text, _ = audio_to_time_aligned_text_features(inputs, self.audio_processor, self.asr, self.bert_tokenizer, self.bert_model)
#         raw_audio_high = torch.cat([raw_audio_high, bert_time_aligned_text], dim=2)
#         raw_audio_low = F.interpolate(raw_audio_low.transpose(1, 2), scale_factor=30/50, mode='linear', align_corners=True).transpose(1, 2) 
#         raw_audio_high = F.interpolate(raw_audio_high.transpose(1, 2), scale_factor=15/50, mode='linear', align_corners=True).transpose(1, 2)
        
#         if raw_audio_low.shape[1] % 2 == 1:
#             raw_audio_low = torch.cat([raw_audio_low, raw_audio_low[:, -1:]], dim=1)
#         diff = raw_audio_low[:, ::2].shape[1] - raw_audio_high.shape[1]
#         if diff > 0:
#             raw_audio_high = torch.cat([raw_audio_high, raw_audio_high[:, -diff:]], dim=1)

#         audio_low = self.audio_low_mapping(raw_audio_low)
#         # print(audio_low.shape[1]//2, raw_audio_high.shape[1])
#         raw_audio_high = torch.cat([raw_audio_low[:, ::2], raw_audio_high], dim=-1)
#         audio_high = self.audio_high_mapping(raw_audio_high)
#         audio_high_att, audio_high_weight = self.audio_sa(audio_high, audio_high, audio_high)
#         bs, n, c = audio_high.shape
#         audio_high_att_before_sum = audio_high_weight[:, :, 0, :].unsqueeze(2) * audio_high.transpose(1, 2).view(bs, 8, c//8, n)
#         audio_high_att_before_sum = audio_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)
#         audio_high_att = F.interpolate(audio_high_att.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         audio_high_att_before_sum = F.interpolate(audio_high_att_before_sum.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         audio_cls = audio_high_att[:, 0]
#         return {
#             "audio_low":audio_low,
#             "audio_high":audio_high_att,
#             "audio_cls":audio_cls,
#             "audio_high_weight":audio_high_att_before_sum,
#             }

#     def get_motion_features(self, in_motion):
#         in_motion = self.get_motion_reps(in_motion, self.smplx_model)["rep15d"][:,::self.down_sample]
#         raw_motion_low = self.motion_encoder_low(in_motion)
#         raw_motion_high = self.motion_encoder_high(in_motion)
#         motion_low = self.motion_low_mapping(raw_motion_low)
#         motion_high = self.motion_high_mapping(raw_motion_high)
#         motion_high_att, motion_high_weight = self.motion_sa(motion_high, motion_high, motion_high)
#         bs, n, c = motion_high.shape
#         motion_high_att_before_sum = motion_high_weight[:, :, 0, :].unsqueeze(2) * motion_high.transpose(1, 2).view(bs, 8, c//8, n)
#         motion_high_att_before_sum = motion_high_att_before_sum.reshape(bs, c, n).transpose(1, 2)
#         motion_low = F.interpolate(motion_low.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         motion_high_att = F.interpolate(motion_high_att.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         motion_high_att_before_sum = F.interpolate(motion_high_att_before_sum.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True).transpose(1, 2)
#         motion_cls = motion_high_att[:, 0]
#         return {
#             "motion_low":motion_low,
#             "motion_high":motion_high_att,
#             "motion_cls":motion_cls,
#             "motion_high_weight":motion_high_att_before_sum,
#             }

def reinitialize_weights(module):
    """Re-initialise, in place, the weights and biases of ``module``.

    Every sub-module (including ``module`` itself) is visited:

    * ``weight`` tensors with two or more dimensions get Xavier-uniform values;
    * 1-D ``weight`` tensors of normalisation layers (LayerNorm, GroupNorm,
      BatchNorm/InstanceNorm families) are reset to ones, because a
      near-zero gain would suppress the layer's output;
    * any other 1-D ``weight`` is drawn from N(0, 0.02);
    * ``bias`` tensors are zeroed.

    Args:
        module: The ``nn.Module`` whose parameters are re-initialised.
    """
    norm_layers = (nn.LayerNorm, nn.GroupNorm, nn.modules.batchnorm._NormBase)
    for submodule in module.modules():
        weight = getattr(submodule, 'weight', None)
        if isinstance(weight, torch.Tensor):
            if weight.dim() >= 2:
                torch.nn.init.xavier_uniform_(weight)
            elif isinstance(submodule, norm_layers):
                torch.nn.init.ones_(weight)
            else:
                torch.nn.init.normal_(weight, mean=0.0, std=0.02)
        bias = getattr(submodule, 'bias', None)
        if isinstance(bias, torch.Tensor):
            torch.nn.init.zeros_(bias)
        




class WrapedMotionCNN(nn.Module):
    """Two-stage motion encoder producing low- and high-level features.

    ``feature_extractor`` (a 3-layer ``VQEncoderV6``) yields the low-level
    features. ``encoder_cnn`` (a 6-layer ``VQEncoderV6``) takes those features,
    detached, concatenated with the raw input; its output receives a periodic
    positional encoding and passes through one transformer encoder layer to
    give the high-level features.

    Args:
        args: Configuration object; ``motion_f``, ``motion_dim`` and
            ``hidden_size`` are read here, and deep copies with adjusted
            ``vae_*`` fields are handed to ``VQEncoderV6``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        feature_dim = self.args.motion_f

        # Created before the CNN encoders so the construction order (and hence
        # random-number consumption) stays as it always was.
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=8,
            dim_feedforward=self.args.hidden_size,
            dropout=0.1,
            batch_first=True,
        )

        low_level_args = self._vq_encoder_args(vae_layer=3, vae_test_dim=self.args.motion_dim)
        self.feature_extractor = VQEncoderV6(low_level_args)

        # The second encoder sees the raw motion plus the low-level features.
        high_level_args = self._vq_encoder_args(
            vae_layer=6,
            vae_test_dim=self.args.motion_dim + feature_dim,
        )
        self.encoder_cnn = VQEncoderV6(high_level_args)

        self.pos_encoding = PeriodicPositionalEncoding(
            d_model=feature_dim, period=20, max_seq_len=64, dropout=0.0
        )
        self.encoder_trans = nn.TransformerEncoder(transformer_layer, num_layers=1)

    def _vq_encoder_args(self, vae_layer, vae_test_dim):
        # Private copy of the config with the VQ-encoder specific fields set.
        encoder_args = copy.deepcopy(self.args)
        encoder_args.vae_layer = vae_layer
        encoder_args.vae_length = self.args.motion_f
        encoder_args.vae_test_dim = vae_test_dim
        return encoder_args

    def forward(self, 
        inputs,
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
        ):
        """Encode ``inputs`` into low- and high-level motion features.

        Only ``inputs`` is used; the remaining keyword arguments are accepted
        and ignored.

        Returns:
            Dict with ``"low_level"`` (output of ``feature_extractor``) and
            ``"high_level"`` (output of the transformer stage).
        """
        low_level = self.feature_extractor(inputs)
        # detach(): the high-level branch does not back-propagate into the
        # low-level extractor through this concatenation.
        cnn_input = torch.cat([low_level.detach(), inputs], dim=-1)
        encoded = self.encoder_cnn(cnn_input)
        high_level = self.encoder_trans(self.pos_encoding(encoded))
        return {"low_level": low_level, "high_level": high_level}
        

class WrapedWav2Vec(nn.Module):
    """Trainable wav2vec 2.0 front-end exposing low- and high-level features.

    Uses the convolutional feature extractor, the feature projection and the
    first transformer layer of the ``facebook/wav2vec2-base-960h`` encoder,
    followed by a linear projection from 768 to 512 channels.
    """

    def __init__(self):
        super(WrapedWav2Vec, self).__init__()
        # Load the checkpoint once and keep only the sub-modules that are used.
        pretrained = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
        self.feature_extractor = pretrained.feature_extractor
        self.feature_projection = pretrained.feature_projection
        self.encoder = pretrained.encoder
        # Keep only the first transformer layer of the encoder.
        self.encoder.layers = self.encoder.layers[:1]
        self.proj_down = nn.Linear(768, 512)

    def forward(self, 
        inputs,
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
        ):
        """Encode ``inputs`` into low- and high-level audio features.

        ``attention_mask``, ``output_attentions``, ``output_hidden_states`` and
        ``return_dict`` are forwarded to the encoder; ``mask_time_indices`` is
        accepted but unused.

        Returns:
            Dict with ``"low_level"`` (feature-extractor output with dims 1 and
            2 swapped) and ``"high_level"`` (encoder output projected by
            ``proj_down``).
        """
        finetune_audio_low = self.feature_extractor(inputs).transpose(1, 2)
        # detach(): the high-level branch does not back-propagate into the
        # feature extractor through this path.
        hidden_states, _ = self.feature_projection(finetune_audio_low.detach())
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both tuple and ModelOutput returns.
        hidden_states = self.proj_down(encoder_outputs[0])
        return {
            "low_level": finetune_audio_low,
            "high_level": hidden_states
        }


class JointEmbedding(nn.Module):
    def __init__(self, args):
        """Build the audio and motion branches of the joint embedding model.

        Args:
            args: Configuration object; ``args.model`` is stored as
                ``self.args``. ``hidden_size``, ``audio_f`` and ``motion_f``
                are read from it here, and it is passed on to
                ``WrapedMotionCNN``.
        """
        super().__init__()
        self.args = args.model

        # Pretrained speech / text components.
        wav2vec_name = 'facebook/wav2vec2-base-960h'
        bert_name = 'bert-base-uncased'
        self.audio_processor = Wav2Vec2Processor.from_pretrained(wav2vec_name)
        self.audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec_name)
        self.config_wav2vec = Wav2Vec2Config.from_pretrained(wav2vec_name)
        self.audio_encoder_fintune = WrapedWav2Vec()

        self.asr = Wav2Vec2ForCTC.from_pretrained(wav2vec_name)
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_name)
        self.bert_model = BertModel.from_pretrained(bert_name)

        # Audio heads: the low-level input is 2 x 512 channels, the high-level
        # input is 3 x 512 channels after the two 768 -> 512 projections.
        hidden_size = self.args.hidden_size
        audio_dim = self.args.audio_f
        self.audio_low_mapping = MLP(2 * 512, hidden_size, audio_dim)
        self.audio_high_mapping = MLP(3 * 512, hidden_size, audio_dim)
        self.audio_down_proj_2 = nn.Linear(768, 512)
        self.audio_down_proj_3 = nn.Linear(768, 512)
        self.audio_sa = CustomMultiheadAttention(embed_dim=audio_dim, num_heads=8)

        # Motion heads.
        self.motion_encoder_fintune = WrapedMotionCNN(self.args)
        motion_dim = self.args.motion_f
        self.motion_low_mapping = MLP(motion_dim, hidden_size, motion_dim)
        self.motion_high_mapping = MLP(motion_dim, hidden_size, motion_dim)
        # NOTE(review): sized with audio_f rather than motion_f; the two are
        # presumably equal in the shipped configs -- confirm.
        self.motion_sa = CustomMultiheadAttention(embed_dim=self.args.audio_f, num_heads=8)

        self.down_sample = 2  # for downsample 30 fps motion to 15
        # The following None attributes are not populated in this constructor.
        self.smplx_model = None
        self.get_motion_reps = None
        self.audio_to_time_aligned_text_features = audio_to_time_aligned_text_features
        self.low_temp = nn.Parameter(torch.tensor(0.07))
        self.low_level_loss_fn = None
        self.high_temp = nn.Parameter(torch.tensor(0.07))
        self.high_level_loss_fn = None

    def _reset_parameters(self):
        """Re-draw ``mask_embeddings`` from N(0, hidden_size ** -0.5).

        NOTE(review): this class's ``__init__`` does not create
        ``mask_embeddings``; unless it is assigned elsewhere, calling this
        method raises ``AttributeError`` -- confirm whether it is still needed.
        """
        init_std = self.args.hidden_size ** -0.5
        nn.init.normal_(self.mask_embeddings, mean=0, std=init_std)
    
    def forward(self, in_audio=None, in_motion=None, cached_audio_low=None, cached_audio_high=None, cached_rep15d=None):
        """Embed paired audio and motion and compute the contrastive losses.

        Args:
            in_audio: Batch of raw waveforms (a tensor; each item is passed to
                the wav2vec processor at a 16 kHz sampling rate).
            in_motion: Raw motion, converted with ``self.get_motion_reps`` when
                ``cached_rep15d`` is not given.
            cached_audio_low: Pre-computed low-level audio features
                (batch, time, channels). Required.
            cached_audio_high: Pre-computed high-level audio features. The
                first 768 channels go through ``audio_down_proj_2`` and the
                rest through ``audio_down_proj_3`` (presumably wav2vec hidden
                states followed by time-aligned BERT features -- TODO confirm).
                Required.
            cached_rep15d: Optional pre-computed motion representation; used
                instead of ``in_motion`` when given.

        Returns:
            Dict with the audio features (``audio_low``, ``audio_high``,
            ``audio_cls``, ``audio_high_weight``), the motion features
            (``motion_low``, ``motion_high``, ``motion_cls``,
            ``motion_high_weight``), ``low_level_loss`` (``[loss, accuracy]``
            as returned by ``low_level_loss_fn``) and ``high_level_loss``.

        Raises:
            ValueError: If ``in_audio`` or either cached audio feature is missing.
        """
        # Fail fast: this path cannot run without the cached audio features.
        if in_audio is None or cached_audio_low is None or cached_audio_high is None:
            raise ValueError(
                "JointEmbedding.forward requires in_audio, cached_audio_low and "
                "cached_audio_high (must have cached audio in training)"
            )

        def resample_time(x, scale):
            # Linear interpolation along dim 1 of a (batch, time, channels) tensor.
            return F.interpolate(
                x.transpose(1, 2), scale_factor=scale, mode='linear', align_corners=True
            ).transpose(1, 2)

        def pad_time(x, target_len):
            # Extend dim 1 to target_len by repeating the trailing frames.
            missing = target_len - x.shape[1]
            if missing > 0:
                x = torch.cat([x, x[:, -missing:]], dim=1)
            return x

        def first_query_weighted(values, weights, num_heads):
            # Scale each head's channels of ``values`` by the attention row of
            # the first query position, keeping the (batch, n, c) layout.
            bs, n, c = values.shape
            per_head = values.transpose(1, 2).view(bs, num_heads, c // num_heads, n)
            weighted = weights[:, :, 0, :].unsqueeze(2) * per_head
            return weighted.reshape(bs, c, n).transpose(1, 2)

        # motion feature
        if cached_rep15d is not None:
            in_motion = cached_rep15d[:, ::self.down_sample]
        else:
            in_motion = self.get_motion_reps(in_motion, self.smplx_model)["rep15d"][:, ::self.down_sample]

        motion_features = self.motion_encoder_fintune(in_motion)
        motion_low = self.motion_low_mapping(motion_features["low_level"])
        motion_high = self.motion_high_mapping(motion_features["high_level"])
        motion_high_att, motion_high_weight = self.motion_sa(motion_high, motion_high, motion_high)
        motion_high_att_before_sum = first_query_weighted(
            motion_high, motion_high_weight, self.motion_sa.num_heads
        )
        motion_low = resample_time(motion_low, 2)
        motion_high_att = resample_time(motion_high_att, 2)
        motion_high_att_before_sum = resample_time(motion_high_att_before_sum, 2)
        motion_cls = motion_high_att[:, 0]

        # audio feature
        raw_audio_low = cached_audio_low
        raw_audio_high = torch.cat(
            [
                self.audio_down_proj_2(cached_audio_high[:, :, :768]),
                self.audio_down_proj_3(cached_audio_high[:, :, 768:]),
            ],
            dim=-1,
        )

        audio_list = [clip.cpu().numpy() for clip in in_audio]
        inputs = self.audio_processor(audio_list, sampling_rate=16000, return_tensors="pt", padding=True).to(in_audio.device)
        finetune_audio = self.audio_encoder_fintune(inputs.input_values)
        # The fine-tuned branch may be shorter than the cached features.
        finetune_audio_low = pad_time(finetune_audio["low_level"], raw_audio_low.shape[1])
        finetune_audio_high = pad_time(finetune_audio["high_level"], raw_audio_high.shape[1])
        raw_audio_low = torch.cat([raw_audio_low, finetune_audio_low], dim=-1)  # bs, t, 1024

        # Presumably 50 Hz audio frames resampled to the 30 / 15 fps motion
        # rates -- TODO confirm.
        raw_audio_low = resample_time(raw_audio_low, 30/50)
        raw_audio_high = resample_time(raw_audio_high, 15/50)
        finetune_audio_high = resample_time(finetune_audio_high, 15/50)

        audio_low = self.audio_low_mapping(raw_audio_low)
        raw_audio_high = torch.cat([finetune_audio_high, raw_audio_high], dim=-1)
        audio_high = self.audio_high_mapping(raw_audio_high)
        audio_high_att, audio_high_weight = self.audio_sa(audio_high, audio_high, audio_high)
        audio_high_att_before_sum = first_query_weighted(
            audio_high, audio_high_weight, self.audio_sa.num_heads
        )
        audio_high_att = resample_time(audio_high_att, 2)
        audio_high_att_before_sum = resample_time(audio_high_att_before_sum, 2)
        audio_cls = audio_high_att[:, 0]

        # fix temp to 0.1 is better than learned temp
        low_infonce, low_acc = self.low_level_loss_fn(audio_low, motion_low)
        high_infonce = self.high_level_loss_fn(audio_cls, motion_cls)
        return {
            "audio_low":audio_low,
            "audio_high":audio_high_att,
            "audio_cls":audio_cls,
            "audio_high_weight":audio_high_att_before_sum,
            "motion_low":motion_low,
            "motion_high":motion_high_att,
            "motion_cls":motion_cls,
            "motion_high_weight":motion_high_att_before_sum,
            "low_level_loss": [low_infonce, low_acc],
            "high_level_loss": high_infonce
            }

    def get_audio_features(self, in_audio):
        """Compute low-level and high-level audio features for a batch of waveforms.

        The waveforms are run through ``self.audio_processor`` (declared at
        16 kHz), then through the base audio encoder, the fine-tuned audio
        encoder and the ASR/BERT text-alignment helper. The resulting streams
        are concatenated, resampled along the time axis, projected and passed
        through ``self.audio_sa``.

        Args:
            in_audio: batch of waveforms; it is iterated over its first
                dimension and must expose ``.device`` (presumably a
                ``(batch, samples)`` tensor -- TODO confirm against callers).

        Returns:
            dict with keys:
                ``"audio_low"``: output of ``self.audio_low_mapping``.
                ``"audio_high"``: attention output, upsampled 2x along time.
                ``"audio_cls"``: first time step of ``"audio_high"``.
                ``"audio_high_weight"``: projected high-level features scaled
                    by the first-query attention weights, upsampled 2x.
        """

        def resample_time(feat, factor):
            # F.interpolate works on (batch, channels, time); swap axes around the call.
            return F.interpolate(
                feat.transpose(1, 2), scale_factor=factor, mode='linear', align_corners=True
            ).transpose(1, 2)

        def repeat_tail(feat, missing):
            # Lengthen dim 1 by re-appending its last `missing` entries.
            return torch.cat([feat, feat[:, -missing:]], dim=1)

        waveforms = [clip.cpu().numpy() for clip in in_audio]
        inputs = self.audio_processor(
            waveforms, sampling_rate=16000, return_tensors="pt", padding=True
        ).to(in_audio.device)
        wav = inputs.input_values

        # Low-level stream: conv feature-extractor output + fine-tuned low-level features.
        conv_feats = self.audio_encoder.feature_extractor(wav).transpose(1, 2)

        tuned = self.audio_encoder_fintune(wav)
        tuned_low = tuned["low_level"]
        tuned_high = tuned["high_level"]

        gap = conv_feats.shape[1] - tuned_low.shape[1]
        if gap > 0:
            tuned_low = repeat_tail(tuned_low, gap)
        low_stack = torch.cat([conv_feats, tuned_low], dim=-1)

        # High-level stream: encoder hidden states + time-aligned text features.
        hidden = self.audio_encoder(wav).last_hidden_state

        gap = hidden.shape[1] - tuned_high.shape[1]
        if gap > 0:
            tuned_high = repeat_tail(tuned_high, gap)

        _, text_feats, _ = audio_to_time_aligned_text_features(
            inputs, self.audio_processor, self.asr, self.bert_tokenizer, self.bert_model
        )
        high_stack = torch.cat([hidden, text_feats], dim=2)
        # The first 768 channels and the remaining channels get separate projections.
        head_part = self.audio_down_proj_2(high_stack[:, :, :768])
        tail_part = self.audio_down_proj_3(high_stack[:, :, 768:])
        high_stack = torch.cat([head_part, tail_part], dim=-1)

        # NOTE(review): 30/50 and 15/50 look like frame-rate conversions
        # (50 -> 30 and 50 -> 15 frames per second) -- confirm.
        low_stack = resample_time(low_stack, 30/50)
        high_stack = resample_time(high_stack, 15/50)
        tuned_high = resample_time(tuned_high, 15/50)

        # Make the low-level length even, then top up the high-level streams
        # until they hold at least half as many steps.
        if low_stack.shape[1] % 2 == 1:
            low_stack = repeat_tail(low_stack, 1)
        gap = low_stack.shape[1] // 2 - high_stack.shape[1]
        if gap > 0:
            high_stack = repeat_tail(high_stack, gap)
            tuned_high = repeat_tail(tuned_high, gap)

        audio_low = self.audio_low_mapping(low_stack)
        audio_high = self.audio_high_mapping(torch.cat([tuned_high, high_stack], dim=-1))
        audio_high_att, audio_high_weight = self.audio_sa(audio_high, audio_high, audio_high)

        bs, n, c = audio_high.shape
        # Channels are split into 8 groups (presumably one per attention head --
        # TODO confirm) and scaled by the weights of query position 0.
        first_query_weight = audio_high_weight[:, :, 0, :].unsqueeze(2)
        grouped = audio_high.transpose(1, 2).view(bs, 8, c//8, n)
        weighted = (first_query_weight * grouped).reshape(bs, c, n).transpose(1, 2)

        audio_high_att = resample_time(audio_high_att, 2)
        weighted = resample_time(weighted, 2)

        return {
            "audio_low": audio_low,
            "audio_high": audio_high_att,
            "audio_cls": audio_high_att[:, 0],
            "audio_high_weight": weighted,
        }

    def get_motion_features(self, in_motion):
        """Compute low-level and high-level motion features for a batch of motions.

        The input is converted with ``self.get_motion_reps`` (its ``"rep15d"``
        entry is used), subsampled along dim 1 by ``self.down_sample``, encoded
        by ``self.motion_encoder_fintune``, projected, and passed through
        ``self.motion_sa``. Every returned sequence is upsampled 2x along time
        with linear interpolation.

        Args:
            in_motion: motion batch accepted by ``self.get_motion_reps``
                (presumably ``(batch, frames, dims)`` -- TODO confirm).

        Returns:
            dict with keys:
                ``"motion_low"``: projected low-level features, upsampled 2x.
                ``"motion_high"``: attention output, upsampled 2x.
                ``"motion_cls"``: first time step of ``"motion_high"``.
                ``"motion_high_weight"``: projected high-level features scaled
                    by the first-query attention weights, upsampled 2x.
        """

        def upsample_2x(feat):
            # F.interpolate works on (batch, channels, time); swap axes around the call.
            return F.interpolate(
                feat.transpose(1, 2), scale_factor=2, mode='linear', align_corners=True
            ).transpose(1, 2)

        rep = self.get_motion_reps(in_motion, self.smplx_model)["rep15d"]
        rep = rep[:, ::self.down_sample]

        encoded = self.motion_encoder_fintune(rep)
        motion_low = self.motion_low_mapping(encoded["low_level"])
        motion_high = self.motion_high_mapping(encoded["high_level"])

        motion_high_att, motion_high_weight = self.motion_sa(motion_high, motion_high, motion_high)

        bs, n, c = motion_high.shape
        # Channels are split into 8 groups (presumably one per attention head --
        # TODO confirm) and scaled by the weights of query position 0.
        first_query_weight = motion_high_weight[:, :, 0, :].unsqueeze(2)
        grouped = motion_high.transpose(1, 2).view(bs, 8, c//8, n)
        weighted = (first_query_weight * grouped).reshape(bs, c, n).transpose(1, 2)

        motion_low = upsample_2x(motion_low)
        motion_high_att = upsample_2x(motion_high_att)
        weighted = upsample_2x(weighted)

        return {
            "motion_low": motion_low,
            "motion_high": motion_high_att,
            "motion_cls": motion_high_att[:, 0],
            "motion_high_weight": weighted,
        }
  