import math
import re
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from models.distributional_predictor import AttentionDistributionPredictor, RNNDistributionPredictor
from modules import MultiheadAttention, TransformerEncoderLayer, _get_activation_fn, make_conv_pos
from torch import BoolTensor, FloatTensor, Tensor, pca_lowrank
from transformers import Data2VecAudioModel


def init_bert_params(module):
    """
    Re-initialize one module's weights BERT-style; intended for ``nn.Module.apply``.

    The rules below are applied unconditionally (there are no flags controlling them);
    every random draw is N(0, 0.02^2):
        * ``nn.Linear``: weight is re-drawn; bias (if present) is set to zero.
        * ``nn.Embedding``: weight is re-drawn; the ``padding_idx`` row (if set) is zeroed.
        * ``MultiheadAttention``: ``q_proj``, ``k_proj`` and ``v_proj`` weights are re-drawn.
    Modules of any other type are left untouched.
    """

    def _draw_normal(tensor):
        # Sample on CPU and copy back so the random stream is the same whether or not
        # the parameter lives on CUDA (e.g. under FSDP).
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _draw_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _draw_normal(module.weight.data)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()

    if isinstance(module, MultiheadAttention):
        for projection in (module.q_proj, module.k_proj, module.v_proj):
            _draw_normal(projection.weight.data)


def init_with_wavlm(
    model: nn.Module,
    num_layers: int = 24,
    ckpt: str = "PATH/TO/WavLM_CHECKPOINT",
    need_mask_emb: bool = True,
    info: str = "",
    map_location="cpu",
):
    """
    Load a WavLM checkpoint into ``model``, keeping only its first ``num_layers`` encoder layers.

    Args:
        model: module whose parameter names match the checkpoint's ``"model"`` state dict.
        num_layers: number of encoder layers to copy (``encoder.layers.0`` ... ``encoder.layers.{num_layers - 1}``);
            deeper layers in the checkpoint are discarded.
        ckpt: path to a ``torch.save``'d dict holding the state dict under ``"model"``.
        need_mask_emb: if False, ``mask_emb`` is dropped from the checkpoint and ``model.mask_emb`` is set to None.
        info: tag included in the log message.
        map_location: forwarded to ``torch.load``. Defaults to CPU so loading neither requires nor fills
            the device the checkpoint was saved from; ``load_state_dict`` copies into the model's own devices.

    Raises:
        KeyError: if the checkpoint has fewer than ``num_layers`` encoder layers.
    """
    assert ckpt is not None
    print(f"Initializing WavLM with checkpoint... {ckpt}")
    data = torch.load(ckpt, map_location=map_location)
    state_dict = data["model"]

    # Pull out the per-layer encoder weights. Keys containing relative_attention_bias are
    # kept in place with their original layer index.
    layer_prefix = "encoder.layers."
    layer_keys = [
        key for key in state_dict if key.startswith(layer_prefix) and "relative_attention_bias" not in key
    ]
    layer_params = {key: state_dict.pop(key) for key in layer_keys}

    # "encoder.layers.<i>.<suffix>" -> "<suffix>"
    suffixes = {key[len(layer_prefix) :].partition(".")[2] for key in layer_params}
    for suffix in suffixes:
        for i in range(num_layers):
            name = f"{layer_prefix}{i}.{suffix}"
            state_dict[name] = layer_params[name]

    if not need_mask_emb:
        state_dict.pop("mask_emb", None)
        model.mask_emb = None

    # The model has no layer norm on the encoder output; drop it if the checkpoint has one.
    state_dict.pop("encoder.layer_norm.weight", None)
    state_dict.pop("encoder.layer_norm.bias", None)

    print(f"wavlm/{info}: Initialize with WavLM.")
    model.load_state_dict(state_dict)

    del state_dict
    del layer_params
    del data


def init_with_ckpt(
    model: nn.Module,
    ckpt: str = "PATH/TO/CHECKPOINT",
    name: str = "wavlm",
    need_mask_emb: bool = True,
    info: str = "",
    device: str = "cuda",
):
    """
    Initialize ``model`` from the ``name.``-prefixed entries of a checkpoint's ``"model"`` state dict.

    Args:
        model: module to load into; keys are matched after stripping the ``f"{name}."`` prefix.
        ckpt: checkpoint path; ``""`` means "no checkpoint" and returns without loading.
        name: sub-module prefix to extract from the checkpoint (only keys starting with ``name + "."``).
        need_mask_emb: if False, ``model.mask_emb`` (when present) is set to None and ``mask_emb`` is not loaded.
        info: tag included in log messages.
        device: ``map_location`` passed to ``torch.load``.

    If no key matches the prefix, a message is printed and the model is left unchanged.
    """
    assert ckpt is not None

    if ckpt == "":
        print(f"{name}/{info}: No checkpoint found.")
        return

    if not need_mask_emb and hasattr(model, "mask_emb"):
        model.mask_emb = None
    state_dict = torch.load(ckpt, map_location=device)["model"]

    # Match on "<name>." so that e.g. "wavlm2.x" is not mistaken for a "wavlm" key.
    prefix = f"{name}."
    dit = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}

    if not need_mask_emb:
        dit.pop("mask_emb", None)

    # we remove the layer_normalization in the output of encoder
    dit.pop("encoder.layer_norm.weight", None)
    dit.pop("encoder.layer_norm.bias", None)

    if not dit:
        print(f"{name}/{info}: No matching keys found in checkpoint: {ckpt}")
    else:
        print(f"{name}/{info}: Initialize with checkpoint: {ckpt}")
        model.load_state_dict(dit)

    del state_dict
    del dit


def apply_mask(x: Tensor, mask: BoolTensor, fill_value: Tensor, clone: bool = False):
    """
    Overwrite the entries of ``x`` selected by ``mask`` with ``fill_value``.

    With ``clone=False`` (the default) ``x`` itself is modified and returned; with
    ``clone=True`` a copy is modified and ``x`` is left untouched.
    """
    if clone:
        x = x.clone()
    x[mask] = fill_value
    return x


@torch.no_grad()
def get_rms(x: Tensor, frame_length: int = 2048, hop_length: int = 512):
    """
    Frame-wise root-mean-square energy of a (batch of) time series.

    Frames are taken without padding: frame ``k`` covers samples
    ``[k * hop_length, k * hop_length + frame_length)``. Only frames that fit entirely
    inside the signal are produced.

    Inputs:
        x: (B, T), ``Tensor`` or ``np.ndarray``; T denotes the length of the time series.
           A 1-D input is treated as a batch of one.
        frame_length: number of samples per frame.
        hop_length: number of samples between consecutive frame starts.
    Outputs:
        rms: (B, Tf), ``Tensor``, with Tf = 1 + (T - frame_length) // hop_length frames,
             or Tf = 0 when T < frame_length.
    """
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    if x.dim() == 1:
        x = x.unsqueeze(dim=0)

    if x.shape[-1] < frame_length:
        # Not a single full frame; the old strided version produced a negative size here.
        return x.new_zeros((*x.shape[:-1], 0))

    # unfold yields a (..., Tf, frame_length) sliding-window view. Unlike a hand-built
    # as_strided it respects x's real strides, so non-contiguous inputs frame correctly.
    frames = x.unfold(-1, frame_length, hop_length)
    return torch.sqrt(torch.mean(torch.abs(frames) ** 2, dim=-1))


@torch.no_grad()
def space_indices(indices: Tensor, space: int = 1, maximum: int = 1, already_sorted: bool = True):
    """
    Push index values apart so that consecutive entries differ by at least ``space``.

    The indices are walked left to right. Any entry closer than ``space`` to its predecessor
    is moved to ``predecessor + space``. As soon as an entry ends up greater than ``maximum``,
    the result is truncated right before it.

    When ``already_sorted`` is True the input tensor is adjusted in place (and returned,
    possibly sliced); otherwise a sorted copy is adjusted.
    """
    if not already_sorted:
        indices = torch.sort(indices, descending=False).values

    last = len(indices) - 1
    pos = 0
    while pos < last:
        nxt = pos + 1
        if indices[nxt] - indices[pos] < space:
            indices[nxt] = indices[pos] + space
        if indices[nxt] > maximum:
            return indices[:nxt]
        pos = nxt
    return indices


@torch.no_grad()
def get_random_mask(
    fea: Tensor,
    span: int = 8,
    max_num_span: int = 10,
    span_space: int = 1,
    real_length: Tensor = None,
    max_mask_percentage: float = 0.5,
):
    """
    Sample random fixed-length spans to mask along the time axis.

    Inputs:
        fea: ``Tensor`` whose first two dims are (B, T); only its shape and device are used.
        span: length of every masked span.
        max_num_span: upper bound on the number of spans per sample.
        span_space: extra gap requested between consecutive span starts (spacing is ``span + span_space``).
        real_length: (B,), ``Tensor``, unpadded length of each sample; None means every sample uses all T frames.
        max_mask_percentage: when ``real_length`` is given, the span count per sample is
            ``min(max_num_span, floor(real_length * max_mask_percentage / span))``.
    Outputs:
        mask: (B, T), ``BoolTensor``, True on masked frames.
        span_start: list of B 1-D tensors holding each sample's span starts. A sample that gets zero
            spans holds a single random start instead, and its mask covers everything from that start
            to the end of the sample.

    Starts are spread apart with ``space_indices``. If spreading drops some of them, the remainder is
    topped up with random, possibly overlapping, starts.
    """
    mask = torch.full(fea.shape[:2], False, dtype=torch.bool, device=fea.device)
    seq_len = mask.shape[-1]

    if real_length is not None:
        num_span_per_sample = (real_length * max_mask_percentage / span).tolist()
        num_span_per_sample = [math.floor(s) if s < max_num_span else max_num_span for s in num_span_per_sample]
        valid_length = (real_length - span).tolist()
        end_per_sample = real_length.tolist()
    else:
        valid_length = [fea.shape[1] - span] * fea.shape[0]
        num_span_per_sample = [max_num_span] * fea.shape[0]
        # Without real_length every sample ends at T (the fallback used to index None here).
        end_per_sample = [fea.shape[1]] * fea.shape[0]

    span_start = []
    for i, (valid, num_span) in enumerate(zip(valid_length, num_span_per_sample)):
        # A sample not longer than `span` has no valid start; allow start 0 rather than
        # calling randperm with a non-positive size.
        valid = max(int(valid), 1)
        end = min(int(end_per_sample[i]), seq_len)

        indices = torch.randperm(valid)[:num_span]
        indices = space_indices(indices, space=span + span_space, maximum=valid, already_sorted=False)

        if len(indices) < num_span:
            indices = torch.cat((indices, torch.randperm(valid, device=indices.device)))[:num_span]

        if (not num_span) or (not len(indices)):
            start = torch.randperm(valid)[0].unsqueeze(dim=0)
            span_start.append(start)
            mask[i, int(start[0]) : end] = True
        else:
            span_start.append(indices)
            # Expand every start into `span` consecutive frames. len(indices) can be smaller than
            # num_span when fewer starts exist; iterating over num_span used to raise IndexError.
            offsets = torch.arange(span, device=indices.device)
            positions = (indices.unsqueeze(1) + offsets).reshape(-1)
            positions = positions[positions < end]
            mask[i][positions.to(mask.device)] = True

    return mask, span_start


@torch.no_grad()
def get_rms_mask(
    rms: Tensor,
    h_up: float = 1.0,
    h_down: float = 0.5,
    l_up: float = 0.49,
    l_down: float = 0.2,
    span: int = 8,
    max_num_span: int = 10,
    span_space: int = 1,
    real_length: Tensor = None,
    max_mask_percentage: float = 0.5,
):
    """
    Sample masking spans whose starts fall on high- and low-energy frames.

    For every row, the thresholds are taken relative to that row's maximum RMS (over its first
    ``valid`` frames):
    - a frame is "high" when ``h_down * max <= rms <= h_up * max``;
    - a frame is "low" when ``l_down * max <= rms <= l_up * max``.

    Up to half of the spans start on random high frames, and the rest on random low frames.
    Starts are spread apart with ``space_indices``. Missing spans are topped up with random,
    possibly overlapping, starts.

    Inputs:
        rms: (B, Tf), ``Tensor``, e.g. the output of ``get_rms``; the mask has this shape.
        h_up, h_down, l_up, l_down: relative amplitude bands (fractions of each row's maximum).
        span: length of every masked span.
        max_num_span: upper bound on spans per sample.
        span_space: extra gap between span starts (spacing is ``span + span_space``).
        real_length: (B,), ``Tensor``, unpadded length per sample; None means all Tf frames.
        max_mask_percentage: when ``real_length`` is given, spans per sample are
            ``min(max_num_span, floor(real_length * max_mask_percentage / span))``.
    Outputs:
        mask: (B, Tf), ``BoolTensor``, True on masked frames.
        span_start: list of B 1-D tensors with each sample's span starts. A sample that gets
            zero spans holds a single random start, and its mask covers that start to the sample end.
    """
    mask = torch.full(rms.shape, False, dtype=torch.bool, device=rms.device)
    seq_len = mask.shape[-1]

    if real_length is not None:
        num_span_per_sample = (real_length * max_mask_percentage / span).tolist()
        num_span_per_sample = [math.floor(s) if s < max_num_span else max_num_span for s in num_span_per_sample]
        valid_length = (real_length - span).tolist()
        end_per_sample = real_length.tolist()
    else:
        valid_length = [rms.shape[-1] - span] * rms.shape[0]
        num_span_per_sample = [max_num_span] * rms.shape[0]
        # Without real_length every sample ends at Tf (the fallback used to index None here).
        end_per_sample = [rms.shape[-1]] * rms.shape[0]

    span_start = []
    for i, (row, valid) in enumerate(zip(rms, valid_length)):
        num_span = num_span_per_sample[i]
        # A sample not longer than `span` has no valid start; keep at least frame 0.
        valid = max(int(valid), 1)
        end = min(int(end_per_sample[i]), seq_len)

        row = row[:valid]
        max_val = torch.max(row)
        # Per-row thresholds live in locals: rescaling the arguments themselves (as before)
        # compounded every previous row's maximum into the next row's bands.
        high_lo, high_hi = h_down * max_val, h_up * max_val
        low_lo, low_hi = l_down * max_val, l_up * max_val
        h_indices = torch.nonzero((row >= high_lo) & (row <= high_hi), as_tuple=False).squeeze(dim=1)
        l_indices = torch.nonzero((row >= low_lo) & (row <= low_hi), as_tuple=False).squeeze(dim=1)

        h_indices = h_indices[torch.randperm(len(h_indices))][: num_span // 2]  # half of spans are for high amplitudes
        l_indices = l_indices[torch.randperm(len(l_indices))][: num_span - len(h_indices)]  # rest for low amplitudes

        # The candidates were just shuffled, so space_indices must sort them first.
        h_indices = space_indices(h_indices, space=span + span_space, maximum=valid, already_sorted=False)
        l_indices = space_indices(l_indices, space=span + span_space, maximum=valid, already_sorted=False)

        indices = torch.cat((h_indices, l_indices))
        if len(indices) < num_span:
            indices = torch.cat((indices, torch.randperm(valid, device=indices.device)))[:num_span]

        if (not num_span) or (not len(indices)):
            start = torch.randperm(valid)[0].unsqueeze(dim=0)
            span_start.append(start)
            mask[i, int(start[0]) : end] = True
        else:
            span_start.append(indices)
            # Expand each start into `span` frames; uses len(indices), which may be < num_span.
            offsets = torch.arange(span, device=indices.device)
            positions = (indices.unsqueeze(1) + offsets).reshape(-1)
            positions = positions[positions < end]
            mask[i][positions.to(mask.device)] = True

    return mask, span_start


@torch.no_grad()
def expand_mask(
    mask: Tensor,
    expanded_span: int = 40,
    span_start: Tensor = None,
    max_num_expanded_span: int = 2,
    span_space: int = 1,
    real_length: Tensor = None,
    max_mask_percentage: float = 0.5,
):
    """
    Build a coarser mask of ``expanded_span``-long spans anchored on existing span starts.

    The incoming ``mask`` is only used for its shape, dtype and device; the returned mask is built
    from scratch.

    Processing per sample:
    1. Starts from ``span_start`` that leave room for a full expanded span are kept.
    2. The kept starts are spread apart with ``space_indices``.
    3. Up to ``max_num_expanded_span`` of them are chosen at random; if too few remain, the list is
       topped up with random (possibly overlapping) starts.

    Inputs:
        mask: (B, T), template for the output.
        expanded_span: length of each expanded span.
        span_start: sequence of B 1-D tensors of candidate starts (e.g. from ``get_rms_mask``).
        max_num_expanded_span: upper bound on expanded spans per sample.
        span_space: extra gap between starts (spacing is ``expanded_span + span_space``).
        real_length: (B,), ``Tensor``, unpadded length per sample; None means all T frames.
        max_mask_percentage: when ``real_length`` is given, spans per sample are
            ``min(max_num_expanded_span, floor(real_length * max_mask_percentage / expanded_span))``.
    Outputs:
        mask: (B, T), True on masked frames.
        expanded_span_start: list of B 1-D tensors with the chosen starts. A sample that gets zero
            spans falls back to its first original start, and its mask covers that start to the sample end.
    """
    mask = torch.full_like(mask, False)
    seq_len = mask.shape[-1]

    if real_length is not None:
        num_span_per_sample = (real_length * max_mask_percentage / expanded_span).tolist()
        num_span_per_sample = [
            math.floor(s) if s < max_num_expanded_span else max_num_expanded_span for s in num_span_per_sample
        ]
        valid_length = (real_length - expanded_span).tolist()
        end_per_sample = real_length.tolist()
    else:
        valid_length = [mask.shape[-1] - expanded_span] * mask.shape[0]
        num_span_per_sample = [max_num_expanded_span] * mask.shape[0]
        # Without real_length every sample ends at T (the fallback used to index None here).
        end_per_sample = [mask.shape[-1]] * mask.shape[0]

    expanded_span_start = []
    for i, (indices, valid) in enumerate(zip(span_start, valid_length)):
        num_expanded_span = num_span_per_sample[i]
        end = min(int(end_per_sample[i]), seq_len)

        indices = indices[indices < valid]
        indices = space_indices(
            indices,
            space=expanded_span + span_space,
            maximum=valid,
            already_sorted=False,
        )

        if len(indices) < num_expanded_span:
            # randperm cannot take a negative size (sequence shorter than expanded_span).
            filler = torch.randperm(max(int(valid), 0), device=indices.device)
            indices = torch.cat((indices, filler))[:num_expanded_span]
        else:
            indices = indices[torch.randperm(len(indices))][:num_expanded_span]

        if (not num_expanded_span) or (not len(indices)):
            start = span_start[i][0].unsqueeze(dim=0)
            expanded_span_start.append(start)
            mask[i, int(start[0]) : end] = True
        else:
            expanded_span_start.append(indices)
            # Expand each start into `expanded_span` frames; len(indices) may be < num_expanded_span,
            # which used to raise IndexError.
            offsets = torch.arange(expanded_span, device=indices.device)
            positions = (indices.unsqueeze(1) + offsets).reshape(-1)
            positions = positions[positions < end]
            mask[i][positions.to(mask.device)] = True

    return mask, expanded_span_start


def normalize(x: Tensor, p: int = 2, dim: int = -1):
    """Scale ``x`` to unit L``p`` norm along ``dim`` (thin wrapper around ``F.normalize``)."""
    return F.normalize(x, p=p, dim=dim)


def masked_select(x: Tensor, mask: BoolTensor):
    """
    Gather the feature vectors of ``x`` at the time steps selected by ``mask``.

    Inputs:
        x: (B, T, C), ``Tensor``
        mask: (B, T), ``BoolTensor``
    Output:
        x: (N, C), ``Tensor``; N is the number of True entries in ``mask``, in row-major order.
    """
    per_channel_mask = mask.unsqueeze(dim=-1)
    flat = torch.masked_select(x, per_channel_mask)
    return flat.view(-1, x.size(-1))


class ConvFeatureExtractionModel(nn.Module):
    """
    Stack of 1-D convolutions mapping a waveform (B, T) to features (B, C, T').

    Each layer is ``Conv1d -> Dropout -> [norm] -> GELU``.
    """

    def __init__(
        self,
        conv_layers: Optional[list] = None,
        dropout: float = 0.0,
        conv_bias: bool = False,
        mode: str = "default",
    ):
        """
        Args:
            conv_layers: list of ``(out_channels, kernel_size, stride)`` triples, one per layer.
                None (the default) means ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``.
            dropout: dropout probability applied after every convolution.
            conv_bias: whether the convolutions have a bias term.
            mode: ``"default"`` adds a per-channel GroupNorm after the first convolution only;
                ``"layer_norm"`` adds a LayerNorm over channels after every convolution;
                any other value adds no normalization.
        """
        super().__init__()
        if conv_layers is None:
            conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

        def block(
            n_in,
            n_out,
            k,
            stride,
            conv_bias=False,
            is_layer_norm=False,
            is_group_norm=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            # Norm sizes use n_out explicitly instead of reading the caller's loop variable.
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        Rearrange("b c t -> b t c"),  # LayerNorm normalizes the last (channel) dim
                        nn.LayerNorm(n_out, elementwise_affine=True),
                        Rearrange("b t c -> b c t"),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1  # raw waveform has a single input channel
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    conv_bias=conv_bias,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                )
            )
            in_d = dim

    def forward(self, x):
        """Map a waveform (B, T) to conv features (B, C, T')."""
        # BxT -> BxCxT
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x


class TransformerEncoder(nn.Module):
    """
    Convolutional positional embedding followed by a stack of ``TransformerEncoderLayer``s.

    ``forward`` has two modes:
      * default: run the layers (optionally stopping early) and return every layer's output;
      * ``student_pretraining``: build a small-span mask (RMS-based or random) and an expanded
        large-span mask, write ``mask_emb`` into the small-span positions before the first layer
        and into the large-span positions before layer ``encoder_layers // 2``.
    """

    def __init__(self, args):
        super().__init__()
        # Framing for get_rms (only used when masks depend on waveform energy).
        self.frame_length = args.frame_length
        self.hop_length = args.hop_length
        # Amplitude bands for get_rms_mask, as fractions of each row's maximum RMS.
        self.h_up = args.h_up
        self.h_down = args.h_down
        self.l_up = args.l_up
        self.l_down = args.l_down
        # Small spans (first mask) and large spans (expand_mask) for student pre-training.
        self.small_span = args.small_span
        self.num_small_span = args.num_small_span
        self.large_span = args.large_span
        self.num_large_span = args.num_large_span
        self.span_space = args.span_space
        self.max_mask_percentage = args.max_mask_percentage
        self.encoder_layers = args.encoder_layers
        self.dropout = args.dropout
        self.pos_conv = make_conv_pos(args.encoder_embed_dim, args.conv_pos, args.conv_pos_groups)
        self.mask_depend_on_rms = args.mask_depend_on_rms

        self.relative_position_embedding = args.relative_position_embedding
        self.num_buckets = args.num_buckets
        self.max_distance = args.max_distance

        # Only layer 0 is built with a relative attention bias (when enabled); the pos_bias it
        # returns is passed on to every following layer in extract_features.
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    embed_dim=args.encoder_embed_dim,
                    ffn_embed_dim=args.ffn_embed_dim,
                    num_heads=args.num_heads,
                    activation=args.activation,
                    dropout=args.dropout,
                    bias=args.bias,
                    normalize_before=True,
                    has_relative_attention_bias=(self.relative_position_embedding and i == 0),
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    qk_norm=args.qk_norm,
                )
                for i in range(args.encoder_layers)
            ]
        )

        # No final LayerNorm on the encoder output (checkpoint loaders drop encoder.layer_norm.*).
        # self.layer_norm = nn.LayerNorm(self.args.encoder_embed_dim)

        # Re-initialize all submodules (Linear / Embedding / MultiheadAttention) BERT-style.
        self.apply(init_bert_params)

    def forward(
        self,
        x: Tensor,
        padding_mask=None,
        layer=None,
        student_pretraining=False,
        waveform=None,
        mask_emb=None,
    ):
        """
        Args:
            x: frame features; assumed (B, T, C) — TODO confirm (dims 0/1 are read as batch/time,
                and extract_features transposes dims 1 and 2 for the positional conv).
            padding_mask: optional (B, T) bool mask, True on padded frames (real lengths are
                computed as the count of ~padding_mask).
            layer: default mode only; 0-based index of the last layer to run, None runs all.
            student_pretraining: enable the masking mode.
            waveform: raw audio, required when ``mask_depend_on_rms`` is set.
            mask_emb: value written into masked frames.
        Returns:
            default mode: ``(x, layer_results)``.
            student mode: ``(x, layer_results, real_length, interlayer, small_span_mask, large_span_mask)``.
            ``layer_results[0]`` is the input after positional conv + dropout; ``layer_results[k]``
            is ``(output, attn_weights)`` of layer ``k - 1``.
        """
        if student_pretraining:
            if padding_mask is not None:
                real_length = torch.sum(~padding_mask, dim=-1, dtype=torch.int)
            else:
                real_length = torch.full((x.size(0),), fill_value=x.size(1), device=x.device, dtype=torch.int)

            if self.mask_depend_on_rms:
                # NOTE(review): get_rms_mask returns a mask shaped like the RMS frames; apply_mask
                # later indexes x with it, so frame_length/hop_length must yield exactly T frames — confirm.
                rms = get_rms(waveform, frame_length=self.frame_length, hop_length=self.hop_length)
                small_span_mask, span_start = get_rms_mask(
                    rms,
                    self.h_up,
                    self.h_down,
                    self.l_up,
                    self.l_down,
                    self.small_span,
                    self.num_small_span,
                    self.span_space,
                    real_length,
                    self.max_mask_percentage,
                )
            else:
                small_span_mask, span_start = get_random_mask(
                    x,
                    self.small_span,
                    self.num_small_span,
                    self.span_space,
                    real_length,
                    self.max_mask_percentage,
                )
            # Larger spans are anchored on the small-span starts.
            large_span_mask, expanded_span_start = expand_mask(
                small_span_mask,
                self.large_span,
                span_start,
                self.num_large_span,
                self.span_space,
                real_length,
                self.max_mask_percentage,
            )
            # Layer index before which the large-span mask is applied.
            interlayer = self.encoder_layers // 2
            x, layer_results = self.extract_features(
                x,
                padding_mask,
                None,
                student_pretraining,
                interlayer,
                small_span_mask,
                large_span_mask,
                mask_emb,
            )
        else:
            x, layer_results = self.extract_features(x, padding_mask, layer)

        # No final layer norm is applied (see __init__).

        if student_pretraining:
            return (
                x,
                layer_results,
                real_length,
                interlayer,
                small_span_mask,
                large_span_mask,
            )
        else:
            return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        student_pretraining=False,
        interlayer=0,
        small_span_mask=None,
        large_span_mask=None,
        mask_emb=None,
    ):
        """
        Add the convolutional positional embedding, apply dropout and run the layer stack.

        Note: padded positions of the caller's ``x`` are zeroed in place before anything else.
        In default mode the loop stops after layer ``tgt_layer`` (0-based); in student mode all
        layers run, with ``mask_emb`` written into ``small_span_mask`` positions first and into
        ``large_span_mask`` positions right before layer ``interlayer`` (both on clones).

        Returns ``(x, layer_results)``; see ``forward`` for the layout of ``layer_results``.
        """
        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        x = F.dropout(x, p=self.dropout, training=self.training)

        layer_results = []
        attn_weights = None
        layer_results.append((x, attn_weights))
        pos_bias = None

        if student_pretraining:
            x = apply_mask(x, small_span_mask, mask_emb, clone=True)
            for i, layer in enumerate(self.layers):
                if i == interlayer:
                    x = apply_mask(x, large_span_mask, mask_emb, clone=True)
                x, attn_weights, pos_bias = layer(
                    x,
                    key_padding_mask=padding_mask,
                    need_weights=True,
                    pos_bias=pos_bias,
                )
                layer_results.append((x, attn_weights))
        else:
            for i, layer in enumerate(self.layers):
                x, attn_weights, pos_bias = layer(
                    x,
                    key_padding_mask=padding_mask,
                    need_weights=True,
                    pos_bias=pos_bias,
                )
                layer_results.append((x, attn_weights))
                # Early exit: tgt_layer is a 0-based layer index (None never matches).
                if i == tgt_layer:
                    break
        return x, layer_results


class PredictionHead(nn.Module):
    """Two-layer feed-forward head with a bottleneck of ``input_dim // 2`` units.

    Inputs:
        x: (B, T, input_dim), ``Tensor``
    Outputs:
        x: (B, T, output_dim), ``Tensor``
    """

    def __init__(self, input_dim: int, output_dim: int, activation: str, norm_input: bool = True):
        super().__init__()
        hidden_dim = input_dim // 2
        self.norm_input = norm_input
        # Positions 0/1/2 determine the state-dict keys (simple_ffn.0.*, simple_ffn.2.*).
        layers = [
            nn.Linear(input_dim, hidden_dim),
            _get_activation_fn(activation, module=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.simple_ffn = nn.Sequential(*layers)

    def forward(self, x: Tensor):
        if not self.norm_input:
            return self.simple_ffn(x)
        # Parameter-free layer norm over the feature dimension before the FFN.
        normed = F.layer_norm(x, [x.shape[-1]])
        return self.simple_ffn(normed)


class WavLM(nn.Module):
    """
    Speech encoder: conv feature extractor -> LayerNorm -> linear projection -> dropout -> TransformerEncoder.

    ``args`` must provide ``conv_feature_layers``, ``extractor_mode``, ``encoder_embed_dim``,
    ``dropout_input``, ``output_rep``, ``normalize`` and ``freeze_cnn``, plus everything
    ``TransformerEncoder`` reads.
    """

    def __init__(self, args):
        super().__init__()
        # NOTE(review): ``eval`` executes the config string. It may contain list arithmetic such as
        # "[(512, 10, 5)] + [(512, 3, 2)] * 4", which ast.literal_eval rejects, so only use trusted configs.
        feature_enc_layers = eval(args.conv_feature_layers)
        conv_embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(feature_enc_layers, mode=args.extractor_mode)
        self.layer_norm = nn.LayerNorm(conv_embed)
        self.post_extract_proj = nn.Linear(conv_embed, args.encoder_embed_dim)
        self.dropout_input = nn.Dropout(args.dropout_input)

        self.encoder = TransformerEncoder(args)

        # Vector written into masked frames during student pre-training.
        self.mask_emb = nn.Parameter(FloatTensor(args.encoder_embed_dim).uniform_(), requires_grad=True)
        # Frame-level padding mask computed by the most recent forward call.
        self.padding_mask = None
        self.elbo = args.output_rep == "elbo"
        self.normalize = args.normalize
        self.freeze_cnn = args.freeze_cnn

    def forward_padding_mask(self, features: Tensor, padding_mask: Tensor) -> Tensor:
        """
        Downsample a sample-level padding mask (B, T_audio) to feature frames and store it.

        Trailing samples that do not fill a whole frame are dropped. A frame is marked as padding
        only if all of its samples are padding. The (B, T) result is stored in ``self.padding_mask``.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        self.padding_mask = padding_mask.all(-1)

    def get_padding_mask(self):
        """Frame-level padding mask from the most recent forward call (None if none was given)."""
        return self.padding_mask

    def forward(
        self,
        waveform: Tensor,
        padding_mask: Optional[Tensor] = None,
        output_layer: Optional[int] = None,
        ret_layer_results: bool = False,
        student_pretraining=False,
    ):
        """
        Inputs:
            waveform: (B, T_audio), ``Tensor``
            padding_mask: (B, T_audio), ``BoolTensor``, key padding mask (True on padding).
            output_layer: ``int``, 1-based index of the last encoder layer to run; None runs all.
            ret_layer_results: ``bool``, default False.
            student_pretraining: run the encoder in its masking mode.
        Outputs:
            default mode: ``(features, layer_results_or_None, feature_extractor_output_or_None)``
                - features: (B, T, C), ``Tensor``
                - layer_results: encoder's per-layer list, only when ``ret_layer_results``
                - feature_extractor_output: projected CNN features with padded frames zeroed,
                  only when ``output_rep == "elbo"``
            student mode: the encoder's 6-tuple
                ``(features, layer_results, real_length, interlayer, small_span_mask, large_span_mask)``.
        """
        if self.normalize:
            waveform = F.layer_norm(waveform, [waveform.shape[-1]])

        if self.freeze_cnn:
            with torch.no_grad():
                features = self.feature_extractor(waveform)
        else:
            features = self.feature_extractor(waveform)

        features = features.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        features = self.layer_norm(features)

        features = self.post_extract_proj(features)
        features = self.dropout_input(features)

        if padding_mask is not None:
            self.forward_padding_mask(features, padding_mask)
        else:
            self.padding_mask = None

        feature_extractor_output = None
        if self.elbo:
            # Previously `features[self.padding_mask] = 0` ran even when the mask was None,
            # and indexing with None zeroes the WHOLE tensor (and the aliased encoder input).
            if self.padding_mask is None:
                feature_extractor_output = features
            else:
                feature_extractor_output = features.masked_fill(self.padding_mask.unsqueeze(-1), 0)

        if student_pretraining:
            # (features, layer_results, real_length, interlayer, small_span_mask, large_span_mask)
            (
                features,
                layer_results,
                real_length,
                interlayer,
                small_span_mask,
                large_span_mask,
            ) = self.encoder(
                features,
                padding_mask=self.padding_mask,
                layer=None,
                student_pretraining=True,
                waveform=waveform,
                mask_emb=self.mask_emb,
            )
            return (
                features,
                layer_results,
                real_length,
                interlayer,
                small_span_mask,
                large_span_mask,
            )

        features, layer_results = self.encoder(
            features,
            padding_mask=self.padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )
        return features, (layer_results if ret_layer_results else None), feature_extractor_output


class Data2Vec(nn.Module):
    """Thin wrapper around a pretrained Hugging Face ``Data2VecAudioModel``."""

    def __init__(self, args):
        super().__init__()
        self.model = Data2VecAudioModel.from_pretrained(
            args.path_to_data2vec,
            cache_dir=args.path_to_huggingface_cache,
        )
        config = self.model.config
        # Turn off feature-axis masking and layer drop for this use.
        config.mask_feature_prob = 0.0
        config.layerdrop = 0.0
        print(self.model.config)

    def forward(self, waveform: Tensor, padding_mask: Optional[Tensor] = None, ret_layer_results: bool = False):
        """
        Run the backbone and return its raw output object.

        ``ret_layer_results`` maps to ``output_hidden_states``. No explicit time mask is passed.

        NOTE(review): ``padding_mask`` is forwarded as HF ``attention_mask``, which uses 1 for real
        samples and 0 for padding. Confirm callers pass that convention despite the parameter name.
        """
        return self.model(
            waveform,
            attention_mask=padding_mask,
            output_hidden_states=ret_layer_results,
            mask_time_indices=None,
        )


class SelfAttentionPooling(nn.Module):
    """
    Self-attention pooling over the time axis.
    Original paper: "Self-Attention Encoding and Pooling for Speaker Recognition",
    https://arxiv.org/pdf/2008.01077v1.pdf

    A single linear layer scores every frame; the scores are softmax-normalized over time
    and used as weights for a sum of the frames.

    NOTE(review): a class with this same name is defined again later in this module, and that
    later definition replaces this one at import time. Consider deleting one of the copies.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 1)
        self.softmax = nn.functional.softmax

    def forward(self, batch_rep, att_mask=None):
        """
        Args:
            batch_rep: (N, T, H) frame representations.
            att_mask: optional additive mask broadcastable to (N, T), added to the scores before
                the softmax (e.g. -inf on frames to ignore — TODO confirm caller convention).
        Returns:
            (N, H) pooled representation.
        """
        scores = self.W(batch_rep).squeeze(-1)
        if att_mask is not None:
            scores = att_mask + scores
        weights = self.softmax(scores, dim=-1).unsqueeze(-1)
        return (batch_rep * weights).sum(dim=1)


class SelfAttentionPooling(nn.Module):
    """
    Self-attention pooling over the time axis.
    Original paper: "Self-Attention Encoding and Pooling for Speaker Recognition",
    https://arxiv.org/pdf/2008.01077v1.pdf

    A single linear layer scores every frame; the scores are softmax-normalized over time
    and used as weights for a sum of the frames.

    NOTE(review): this definition re-declares a class of the same name that appears earlier in
    this module; being later, this is the one in effect. Consider deleting the duplicate.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 1)
        self.softmax = nn.functional.softmax

    def forward(self, batch_rep, att_mask=None):
        """
        Args:
            batch_rep: (N, T, H) frame representations.
            att_mask: optional additive mask broadcastable to (N, T), added to the scores before
                the softmax (e.g. -inf on frames to ignore — TODO confirm caller convention).
        Returns:
            (N, H) pooled representation.
        """
        scores = self.W(batch_rep).squeeze(-1)
        if att_mask is not None:
            scores = att_mask + scores
        weights = self.softmax(scores, dim=-1).unsqueeze(-1)
        return (batch_rep * weights).sum(dim=1)


class Data2VecFinetuneWrapper(nn.Module):
    """Classification head(s) on top of a ``Data2Vec`` backbone.

    ``args.output_rep`` selects which backbone hidden states feed the head:

    * ``"last_layer"`` - the backbone's ``last_hidden_state``;
    * ``"layer_<n>"`` - the n-th (1-based) entry of ``hidden_states[1:]``;
    * ``"weighted_sum"`` - softmax-weighted sum of ``hidden_states[1:]`` with learned,
      input-independent weights (``self.weights``);
    * ``"weighted_hiddens"`` - weighted sum whose per-utterance layer weights are
      predicted by ``self.layer_distribution``;
    * ``"elbo"`` - one projector/classifier pair per layer; every per-layer prediction
      is returned together with the predicted layer-distribution logits.

    ``forward`` returns ``(prediction_logits, layer_distribution_logits)``: a list of
    ``[bs, num_classes]`` logit tensors and either the layer-distribution logits or
    ``None``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.data2vec = Data2Vec(args)

        if self.args.output_rep in [
            "last_layer",
            "weighted_sum",
            "weighted_hiddens",
        ] or self.args.output_rep.startswith("layer"):
            self.projector = nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
            self.classifier = nn.Linear(self.args.projector_dim, self.args.num_classes)
        elif self.args.output_rep == "elbo":
            # One independent projector/classifier pair per encoder layer.
            self.projector = nn.ModuleList(
                [
                    nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
                    for _ in range(self.args.encoder_layers)
                ]
            )
            self.classifier = nn.ModuleList(
                [nn.Linear(self.args.projector_dim, self.args.num_classes) for _ in range(self.args.encoder_layers)]
            )
        else:
            raise NotImplementedError(f"output_rep {self.args.output_rep} is not implemented.")

        if self.args.output_rep == "weighted_sum":
            self.elbo = False
            self.weights = nn.Parameter(torch.rand(self.args.encoder_layers))
            self.layer_index = None
            print(f"Using weighted sum of {list(self.weights.shape)} representations as output representation.")
        elif self.args.output_rep == "last_layer":
            self.weights = None
            self.elbo = False
            self.layer_index = None
            print("Using last layer representation as output representation.")
        elif self.args.output_rep in ["elbo", "weighted_hiddens"]:
            self.elbo = self.args.output_rep == "elbo"
            self.layer_index = None
            self.distribution_prediction = self.args.distribution_prediction
            # NOTE(review): "single" and "multiple" build a predictor here, but
            # get_layer_distribution only handles "from_12_transformer", "from_last" and
            # "from_24_layers_mhsa"/"from_24_layers_rnn" (no RNN predictor is built here).
            if self.args.distribution_prediction == "single":
                self.layer_distribution = nn.Sequential(
                    nn.Linear(self.args.encoder_embed_dim, self.args.encoder_embed_dim // 2),
                    nn.ReLU(),
                    nn.Linear(self.args.encoder_embed_dim // 2, 1),
                )
            elif self.args.distribution_prediction == "multiple":
                self.layer_distribution = nn.ModuleList(
                    [
                        nn.Sequential(
                            nn.Linear(self.args.encoder_embed_dim, self.args.encoder_embed_dim // 2),
                            nn.ReLU(),
                            nn.Linear(self.args.encoder_embed_dim // 2, 1),
                        )
                        for _ in range(self.args.encoder_layers)
                    ]
                )
            elif self.args.distribution_prediction in ["from_12_transformer", "from_last"]:
                self.layer_distribution = nn.Linear(self.args.encoder_embed_dim, self.args.encoder_layers)
            elif self.args.distribution_prediction == "from_24_layers_mhsa":
                self.layer_distribution = AttentionDistributionPredictor(
                    hid_dim=self.args.encoder_embed_dim,
                    num_heads=self.args.dist_att_n_num_heads,
                    dist_mlp=self.args.dist_mlp,
                )
            print(
                f"Using elbo representation as output representation trained with {self.args.distribution_prediction} "
                f"distribution."
            )
        elif self.args.output_rep.startswith("layer"):
            self.weights = None
            self.elbo = False
            self.layer_index = int(self.args.output_rep.split("_")[-1]) - 1
            print(f"Using {self.layer_index}th layer representation as output representation.")
        else:
            raise Exception(
                f"Expected args.output_rep to be: elbo, "
                f"layer_n, weighted or last_layer, got: {self.args.output_rep}"
            )

    @staticmethod
    def forward_padding_mask(features, padding_mask):
        """Downsample a sample-level padding mask to the frame rate of ``features``.

        Trailing samples that do not fill a whole frame are dropped; a frame is marked
        as padded only if all of its samples are padded.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        return padding_mask.all(-1)

    @staticmethod
    def _select(module, index):
        """Return ``module[index]`` for per-layer containers, otherwise ``module`` itself."""
        return module[index] if isinstance(module, (nn.ModuleList, list)) else module

    def _weighted_sum(self, layer_results: list):
        """Sum the layer outputs weighted by ``softmax(self.weights)``; keeps one layer's shape."""
        stacked_feature = torch.stack(layer_results, dim=0)
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(len(layer_results), -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        return weighted_feature.view(*origin_shape)

    def _run_backbone(self, waveform, padding_mask):
        """Run the Data2Vec backbone.

        Args:
            waveform: batch of raw audio, passed straight to ``self.data2vec``.
            padding_mask: sample-level mask with 1/True marking padding, or ``None``
                when nothing is padded.

        Returns:
            ``(layer_results, last_hidden_state, frame_padding_mask)`` where
            ``layer_results`` is ``hidden_states[1:]`` (the first entry, the backbone
            input, is dropped) and ``frame_padding_mask`` is the frame-level mask
            produced by ``forward_padding_mask`` (True = padded frame).
        """
        if padding_mask is None:
            # NOTE(review): assumes the last waveform axis is the sample axis, which is
            # what forward_padding_mask already relies on.
            padding_mask = torch.zeros(waveform.size(0), waveform.size(-1), dtype=torch.long, device=waveform.device)
        elif padding_mask.dtype == torch.bool:
            # Bool tensors do not support ``- 1``; use the equivalent 0/1 integer mask.
            padding_mask = padding_mask.long()
        # The backbone receives the inverted (data2vec-format) mask.
        data2vec_padding_mask = torch.abs(padding_mask - 1)
        # Disable autograd only when the backbone is frozen; otherwise keep the caller's mode.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.args.freeze_backbone):
            output = self.data2vec(waveform, padding_mask=data2vec_padding_mask, ret_layer_results=True)
        last_hidden_state = output.last_hidden_state
        frame_padding_mask = self.forward_padding_mask(last_hidden_state, padding_mask)
        return output.hidden_states[1:], last_hidden_state, frame_padding_mask

    @torch.no_grad()
    def layer_inference(self, waveform: Tensor, padding_mask: Optional[Tensor], i_th_layer: int):
        """Classify from the ``i_th_layer``-th (0-based) layer output without gradients.

        Returns ``([logits], None)`` to match the output format of ``forward``.
        """
        layer_results, _, frame_padding_mask = self._run_backbone(waveform, padding_mask)
        prediction_logits = self._process_single_hidden(
            layer_results[i_th_layer],
            frame_padding_mask,
            self._select(self.projector, i_th_layer),
            self._select(self.classifier, i_th_layer),
        )
        return [prediction_logits], None

    @staticmethod
    def _get_layer_results_stack_and_mask(layer_results, padding_mask):
        """Stack layers to ``[bs, seq, dim, n_layers]`` and zero the padded frames."""
        stacked_layer_results = torch.stack(list(layer_results), dim=-1)
        return stacked_layer_results.masked_fill(padding_mask[:, :, None, None], 0.0)

    def apply_mask_to_merged_output(self, layer_results, padding_mask, real_length_feature_extractor):
        """Masked time-average of every layer; returns ``([bs, dim, n_layers], n_layers)``."""
        stacked_masked_layer_results = self._get_layer_results_stack_and_mask(layer_results, padding_mask)
        stacked_layer_results = stacked_masked_layer_results.sum(dim=1) / real_length_feature_extractor.unsqueeze(-1)
        return stacked_layer_results, len(layer_results)

    def get_layer_distribution(self, layer_results, padding_mask):
        """Predict per-utterance logits over layers with ``self.layer_distribution``.

        Raises:
            AttributeError: if ``self.distribution_prediction`` is not supported here.
        """
        # clamp(min=1) keeps fully padded rows finite instead of dividing by zero.
        real_length_feature_extractor = torch.sum(~padding_mask, dim=-1, keepdim=True).clamp(min=1)

        if self.distribution_prediction in ["from_24_layers_rnn", "from_24_layers_mhsa"]:
            merged, _ = self.apply_mask_to_merged_output(layer_results, padding_mask, real_length_feature_extractor)
            return self.layer_distribution(merged.permute(0, 2, 1))  # input [bs, n_layers, dim]
        if self.distribution_prediction in ["from_12_transformer", "from_last"]:
            # Data2Vec hidden states are plain [bs, seq, dim] tensors: mask the chosen layer
            # directly (indexing ``[0]`` would select a single batch element).
            layer_index = 11 if self.distribution_prediction == "from_12_transformer" else -1
            layer_result = layer_results[layer_index].masked_fill(padding_mask.unsqueeze(-1), 0.0)
            return self.layer_distribution(layer_result.sum(dim=1) / real_length_feature_extractor)
        raise AttributeError(f"Unknown distribution prediction {self.distribution_prediction}")

    def _elbo(self, layer_results: list, wavlm_format_padding_mask: torch.Tensor):
        """Per-layer predictions plus the predicted layer-distribution logits."""
        layer_distribution = self.get_layer_distribution(layer_results, wavlm_format_padding_mask)
        prediction_logits = [
            self._process_single_hidden(
                fea=hidden,
                padding_mask=wavlm_format_padding_mask,
                projector=self.projector[i],
                classifier=self.classifier[i],
            )
            for i, hidden in enumerate(layer_results)
        ]
        return prediction_logits, layer_distribution

    def _process_single_hidden(self, fea, padding_mask=None, projector=None, classifier=None):
        """Project frames, mean-pool over non-padded frames and classify.

        Args:
            fea: ``[bs, seq, dim]`` hidden states.
            padding_mask: optional ``[bs, seq]`` boolean mask, True for padded frames.
            projector: per-frame module applied before the ReLU.
            classifier: module mapping the pooled features to logits.
        """
        fea = F.relu(projector(fea))  # [bs, seq, projector_dim]
        if padding_mask is None:
            real_length = torch.full(
                (fea.size(0), 1),
                fill_value=fea.size(1),
                dtype=fea.dtype,
                device=fea.device,
            )
        else:
            real_length = torch.sum(~padding_mask, dim=-1, keepdim=True).clamp(min=1)
            # Mask *after* the projector: a biased projector would otherwise give padded
            # frames a non-zero activation that leaks into the pooled sum.
            fea = fea.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        pooled_output = fea.sum(dim=1) / real_length  # [bs, projector_dim]
        return classifier(pooled_output)  # [bs, num_classes]

    def _weighted_hidden_representation(self, layer_results, padding_mask):
        """Classify the layer mix weighted by the predicted per-utterance distribution."""
        layer_distribution_logits = self.get_layer_distribution(layer_results, padding_mask)
        layer_distribution_probas = F.softmax(layer_distribution_logits, dim=1)
        stacked_layer_results = self._get_layer_results_stack_and_mask(layer_results, padding_mask)
        weighted_layer_results = layer_distribution_probas[:, None, None, :] * stacked_layer_results
        weighted_averaged_layer_results = torch.sum(weighted_layer_results, dim=-1)
        prediction_logits = self._process_single_hidden(
            weighted_averaged_layer_results, padding_mask, self.projector, self.classifier
        )
        return prediction_logits, layer_distribution_logits

    def forward(self, waveform: Tensor, padding_mask: Optional[Tensor] = None):
        """Return ``(list_of_logits, layer_distribution_logits_or_None)``.

        ``padding_mask`` marks padded samples with 1/True; ``None`` means no padding.
        """
        layer_results, fea, frame_padding_mask = self._run_backbone(waveform, padding_mask)
        output_rep = self.args.output_rep
        if output_rep == "elbo":
            return self._elbo(layer_results, frame_padding_mask)
        if output_rep == "weighted_hiddens":
            prediction_logits, layer_distribution_logits = self._weighted_hidden_representation(
                layer_results, frame_padding_mask
            )
            return [prediction_logits], layer_distribution_logits
        if output_rep == "weighted_sum":
            fea = self._weighted_sum(layer_results)
        elif output_rep.startswith("layer"):
            fea = layer_results[self.layer_index]
        elif output_rep != "last_layer":
            raise NotImplementedError(f"output_rep {output_rep} is not implemented.")
        prediction_logits = self._process_single_hidden(fea, frame_padding_mask, self.projector, self.classifier)
        return [prediction_logits], None


class WavLMFinetuneWrapper(nn.Module):
    """Classification head(s) on top of a ``WavLM`` backbone.

    ``args.deep_model == "CNNSelfAttention"`` selects an AvgPool1d + Conv1d projector,
    ``SelfAttentionPooling`` and an MLP classifier; any other value selects a linear
    projector, masked mean pooling and a linear classifier.

    ``args.output_rep`` selects which hidden states feed the head: ``"last_layer"``,
    ``"layer_<n>"`` (1-based), ``"weighted_sum"``, ``"weighted_hiddens"`` or ``"elbo"``.
    ``forward`` returns ``(prediction_logits, layer_distribution_logits)``: a list of
    ``[bs, num_classes]`` logits and either the layer-distribution logits or ``None``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.wavlm = WavLM(self.args)
        if self.args.init_with_ckpt:
            init_with_ckpt(self.wavlm, self.args.init_with_ckpt, "wavlm", need_mask_emb=False)
        elif self.args.init_with_wavlm:
            init_with_wavlm(
                self.wavlm,
                self.args.encoder_layers,
                self.args.path_to_wavlm,
                need_mask_emb=False,
            )
        else:
            print("No initialization method specified. Initializing with random weights.")
        self.freeze_upstream = self.args.freeze_upstream

        single_head = self.args.output_rep in [
            "last_layer",
            "weighted_sum",
            "weighted_hiddens",
        ] or self.args.output_rep.startswith("layer")

        if self.args.deep_model == "CNNSelfAttention":
            if single_head:
                self.elbo = False
                self.projector = self._make_cnn_projector()
                self.attention = SelfAttentionPooling(self.args.projector_dim)
                self.classifier = self._make_cnn_classifier()
            elif self.args.output_rep == "elbo":
                # One CNN head per encoder layer (built projectors, then attentions, then classifiers).
                self.projector = nn.ModuleList([self._make_cnn_projector() for _ in range(self.args.encoder_layers)])
                self.attention = nn.ModuleList(
                    [SelfAttentionPooling(self.args.projector_dim) for _ in range(self.args.encoder_layers)]
                )
                self.classifier = nn.ModuleList(
                    [self._make_cnn_classifier() for _ in range(self.args.encoder_layers)]
                )
        else:
            if single_head:
                self.projector = nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
                self.classifier = nn.Linear(self.args.projector_dim, self.args.num_classes)
                self.attention = None
            elif self.args.output_rep == "elbo":
                if self.args.elbo_share_downstream_weights:
                    self.projector = nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
                    self.classifier = nn.Linear(self.args.projector_dim, self.args.num_classes)
                    self.attention = None
                else:
                    self.projector = nn.ModuleList(
                        [
                            nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
                            for _ in range(self.args.encoder_layers)
                        ]
                    )
                    self.classifier = nn.ModuleList(
                        [
                            nn.Linear(self.args.projector_dim, self.args.num_classes)
                            for _ in range(self.args.encoder_layers)
                        ]
                    )
                    self.attention = [None] * self.args.encoder_layers
            else:
                raise NotImplementedError(f"output_rep {self.args.output_rep} is not implemented.")

        if self.args.output_rep == "weighted_sum":
            self.elbo = False
            self.weights = nn.Parameter(torch.zeros(self.args.encoder_layers))
            self.layer_index = None
            print(f"Using weighted sum of {list(self.weights.shape)} representations as output representation.")
        elif self.args.output_rep == "last_layer":
            self.weights = None
            self.elbo = False
            self.layer_index = None
            print("Using last layer representation as output representation.")
        elif self.args.output_rep in ["elbo", "weighted_hiddens"]:
            self.elbo = self.args.output_rep == "elbo"
            self.layer_index = None
            self.distribution_prediction = self.args.distribution_prediction
            # NOTE(review): "single", "multiple" and the "from_cnn" fallback build a predictor
            # here, but get_layer_distribution does not handle them and will raise.
            if self.args.distribution_prediction == "single":
                self.layer_distribution = self._make_distribution_mlp(1)
            elif self.args.distribution_prediction == "multiple":
                self.layer_distribution = nn.ModuleList(
                    [self._make_distribution_mlp(1) for _ in range(self.args.encoder_layers)]
                )
            elif self.args.distribution_prediction in ["from_12_transformer", "from_last"]:
                self.layer_distribution = nn.Linear(self.args.encoder_embed_dim, self.args.encoder_layers)
            elif self.args.distribution_prediction == "from_24_layers_mhsa":
                self.layer_distribution = AttentionDistributionPredictor(
                    hid_dim=self.args.encoder_embed_dim,
                    num_heads=self.args.dist_att_n_num_heads,
                    dist_mlp=self.args.dist_mlp,
                )
            else:
                self.distribution_prediction = "from_cnn"
                self.layer_distribution = self._make_distribution_mlp(self.args.encoder_layers)

            print(
                f"Using elbo representation as output representation trained with {self.args.distribution_prediction} "
                f"distribution."
            )
        elif self.args.output_rep.startswith("layer"):
            self.weights = None
            self.elbo = False
            self.layer_index = int(self.args.output_rep.split("_")[-1]) - 1
            print(f"Using {self.layer_index}th layer representation as output representation.")
        else:
            raise Exception(
                f"Expected args.output_rep to be: elbo, "
                f"layer_n, weighted or last_layer, got: {self.args.output_rep}"
            )

    def _make_cnn_projector(self):
        """Build the AvgPool1d + 3x Conv1d projector of the CNNSelfAttention head."""
        a = self.args
        return nn.Sequential(
            nn.AvgPool1d(a.deep_model_kernel_size, a.deep_model_pooling, a.deep_model_padding),
            nn.Dropout(p=a.deep_model_dropout),
            nn.Conv1d(a.encoder_embed_dim, a.projector_dim, a.deep_model_kernel_size, padding=a.deep_model_padding),
            nn.ReLU(),
            nn.Dropout(p=a.deep_model_dropout),
            nn.Conv1d(a.projector_dim, a.projector_dim, a.deep_model_kernel_size, padding=a.deep_model_padding),
            nn.ReLU(),
            nn.Dropout(p=a.deep_model_dropout),
            nn.Conv1d(a.projector_dim, a.projector_dim, a.deep_model_kernel_size, padding=a.deep_model_padding),
        )

    def _make_cnn_classifier(self):
        """Build the two-layer MLP classifier of the CNNSelfAttention head."""
        return nn.Sequential(
            nn.Linear(self.args.projector_dim, self.args.projector_dim),
            nn.ReLU(),
            nn.Linear(self.args.projector_dim, self.args.num_classes),
        )

    def _make_distribution_mlp(self, out_dim):
        """Build an ``embed -> embed // 2 -> out_dim`` MLP used as a layer-distribution predictor."""
        return nn.Sequential(
            nn.Linear(self.args.encoder_embed_dim, self.args.encoder_embed_dim // 2),
            nn.ReLU(),
            nn.Linear(self.args.encoder_embed_dim // 2, out_dim),
        )

    @staticmethod
    def _select(module, index):
        """Return ``module[index]`` for per-layer containers, otherwise ``module`` itself (shared head)."""
        return module[index] if isinstance(module, (nn.ModuleList, list)) else module

    def _weighted_sum(self, layer_results: list):
        """Sum the layer outputs weighted by ``softmax(self.weights)``; keeps one layer's shape."""
        stacked_feature = torch.stack(layer_results, dim=0)
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(len(layer_results), -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        return weighted_feature.view(*origin_shape)

    @torch.no_grad()
    def layer_inference(self, waveform: Tensor, padding_mask: Optional[Tensor], i_th_layer: int):
        """Classify from the ``i_th_layer``-th (0-based) layer output without gradients.

        Per-layer heads (elbo) are indexed by ``i_th_layer``; shared heads are used as-is.
        Returns ``([logits], None)`` to match the output format of ``forward``.
        """
        _, layer_results, _ = self.wavlm(
            waveform=waveform,
            padding_mask=padding_mask,
            ret_layer_results=True,
            student_pretraining=False,
        )
        layer_results = layer_results[1:]  # cut input to wavlm
        prediction_logits = self._process_single_hidden(
            layer_results[i_th_layer][0],
            self._select(self.projector, i_th_layer),
            self._select(self.attention, i_th_layer),
            self._select(self.classifier, i_th_layer),
        )
        return [prediction_logits], None

    @staticmethod
    def _get_layer_results_stack_and_mask(layer_results, padding_mask):
        """Stack each layer's hidden state (element ``[0]``) to ``[bs, seq, dim, n_layers]`` and zero padded frames."""
        stacked_layer_results = torch.stack([layer[0] for layer in layer_results], dim=-1)
        return stacked_layer_results.masked_fill(padding_mask[:, :, None, None], 0.0)

    def apply_mask_to_merged_output(self, layer_results, padding_mask, real_length_feature_extractor):
        """Masked time-average of every layer; returns ``([bs, dim, n_layers], n_layers)``."""
        stacked_masked_layer_results = self._get_layer_results_stack_and_mask(layer_results, padding_mask)
        stacked_layer_results = stacked_masked_layer_results.sum(dim=1) / real_length_feature_extractor.unsqueeze(-1)
        return stacked_layer_results, len(layer_results)

    def get_layer_distribution(self, layer_results, feature_extractor_output):
        """Predict per-utterance logits over layers with ``self.layer_distribution``.

        ``feature_extractor_output`` is currently unused.

        Raises:
            AttributeError: if ``self.distribution_prediction`` is not supported here.
        """
        padding_mask = self.wavlm.get_padding_mask()
        # clamp(min=1) keeps fully padded rows finite instead of dividing by zero.
        real_length_feature_extractor = torch.sum(~padding_mask, dim=-1, keepdim=True).clamp(min=1)

        if self.distribution_prediction in ["from_24_layers_rnn", "from_24_layers_mhsa"]:
            merged, _ = self.apply_mask_to_merged_output(layer_results, padding_mask, real_length_feature_extractor)
            return self.layer_distribution(merged.permute(0, 2, 1))  # input [bs, n_layers, dim]
        if self.distribution_prediction in ["from_12_transformer", "from_last"]:
            layer_index = 11 if self.distribution_prediction == "from_12_transformer" else -1
            hidden = layer_results[layer_index][0].masked_fill(padding_mask.unsqueeze(-1), 0.0)
            return self.layer_distribution(hidden.sum(dim=1) / real_length_feature_extractor)
        raise AttributeError(f"Unknown distribution prediction {self.distribution_prediction}")

    def _elbo(self, layer_results: list, feature_extractor_output):
        """Per-layer predictions plus the predicted layer-distribution logits."""
        layer_distribution = self.get_layer_distribution(layer_results, feature_extractor_output)
        # _select handles both shared heads and per-layer ModuleLists/lists, including the
        # CNN elbo heads, which are always per-layer.
        prediction_logits = [
            self._process_single_hidden(
                fea=layer[0],
                projector=self._select(self.projector, i),
                attention=self._select(self.attention, i),
                classifier=self._select(self.classifier, i),
            )
            for i, layer in enumerate(layer_results)
        ]
        return prediction_logits, layer_distribution

    def _process_single_hidden(self, fea, projector=None, attention=None, classifier=None):
        """Project frames, pool them (attention or masked mean) and classify.

        Uses ``self.wavlm.get_padding_mask()`` as the frame mask. It is presumably a
        ``[bs, seq]`` bool tensor with True for padded frames; confirm against WavLM.
        """
        padding_mask = self.wavlm.get_padding_mask()
        if padding_mask is not None:
            real_length = torch.sum(~padding_mask, dim=-1, keepdim=True)
            fea = fea.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        else:
            real_length = torch.full(
                (fea.size(0), 1),
                fill_value=fea.size(1),
                dtype=fea.dtype,
                device=fea.device,
            )

        if self.args.deep_model == "CNNSelfAttention":
            # Conv1d/AvgPool1d expect [bs, channels, time].
            fea = projector(fea.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            fea = F.relu(projector(fea))  # [bs, seq, projector_dim]
            if padding_mask is not None:
                # Mask *after* the projector: a biased projector would otherwise give
                # padded frames a non-zero activation that leaks into the mean.
                fea = fea.masked_fill(padding_mask.unsqueeze(-1), 0.0)

        if attention is None:
            pooled_output = fea.sum(dim=1) / real_length.clamp(min=1)  # [bs, projector_dim]
        else:
            # Each sample keeps ceil(real_length / pooling) frames of the pooled sequence.
            # The mask is sized to the actual pooled length so it always matches the logits.
            n_valid = torch.ceil(real_length / self.args.deep_model_pooling)  # [bs, 1]
            positions = torch.arange(fea.size(1), device=fea.device)
            keep = (positions.unsqueeze(0) < n_valid).to(torch.get_default_dtype())
            attention_mask = (1.0 - keep) * -100000.0  # additive: 0 keep, -1e5 ignore
            pooled_output = attention(fea, att_mask=attention_mask)

        return classifier(pooled_output)  # [bs, num_classes]

    def _weighted_hidden_representation(self, layer_results, feature_extractor_output):
        """Classify the layer mix weighted by the predicted per-utterance distribution."""
        layer_distribution_logits = self.get_layer_distribution(layer_results, feature_extractor_output)
        layer_distribution_probas = F.softmax(layer_distribution_logits, dim=1)
        padding_mask = self.wavlm.get_padding_mask()
        stacked_layer_results = self._get_layer_results_stack_and_mask(layer_results, padding_mask)
        weighted_layer_results = layer_distribution_probas[:, None, None, :] * stacked_layer_results
        weighted_averaged_layer_results = torch.sum(weighted_layer_results, dim=-1)
        prediction_logits = self._process_single_hidden(
            weighted_averaged_layer_results, self.projector, self.attention, self.classifier
        )
        return prediction_logits, layer_distribution_logits

    def forward(self, waveform: Tensor, padding_mask: Optional[Tensor] = None):
        """Return ``(list_of_logits, layer_distribution_logits_or_None)``."""
        # Disable autograd only when the backbone is frozen; otherwise keep the caller's mode.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_upstream):
            fea, layer_results, feature_extractor_output = self.wavlm(
                waveform=waveform,
                padding_mask=padding_mask,
                ret_layer_results=True,
                student_pretraining=False,
            )
        layer_results = layer_results[1:]  # cut input to wavlm
        output_rep = self.args.output_rep
        if output_rep == "elbo":
            return self._elbo(layer_results, feature_extractor_output)
        if output_rep == "weighted_hiddens":
            prediction_logits, layer_distribution_logits = self._weighted_hidden_representation(
                layer_results, feature_extractor_output
            )
            return [prediction_logits], layer_distribution_logits
        if output_rep == "weighted_sum":
            fea = self._weighted_sum([layer[0] for layer in layer_results])
        elif output_rep.startswith("layer"):
            fea = layer_results[self.layer_index][0]
        # Otherwise ("last_layer"): classify the backbone's final output ``fea``.
        prediction_logits = self._process_single_hidden(fea, self.projector, self.attention, self.classifier)
        return [prediction_logits], None
