import math
import re
import time
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from models.distributional_predictor import AttentionDistributionPredictor, RNNDistributionPredictor
from modules import MultiheadAttention, TransformerEncoderLayer, _get_activation_fn, make_conv_pos
from torch import BoolTensor, FloatTensor, Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torchaudio.models.decoder import ctc_decoder


def init_bert_params(module):
    """
    Apply BERT-style initialization to a single module.

    Depending on the type of ``module``:
        1. ``nn.Linear``: the weight is drawn from a normal distribution
           (mean 0, std 0.02) and the bias, if present, is set to zero.
        2. ``nn.Embedding``: the weight is drawn from the same normal
           distribution and the row at ``padding_idx`` (if set) is zeroed.
        3. ``MultiheadAttention``: the weights of ``q_proj``, ``k_proj`` and
           ``v_proj`` are drawn from the same normal distribution.
    Modules of any other type are left untouched.
    """

    def _resample(weight):
        # with FSDP, module params will be on CUDA, so we sample on a CPU copy
        # so that the RNG is consistent with and without FSDP
        cpu_sample = weight.cpu().normal_(mean=0.0, std=0.02)
        weight.copy_(cpu_sample.to(weight.device))

    if isinstance(module, nn.Linear):
        _resample(module.weight.data)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _resample(module.weight.data)
        pad = module.padding_idx
        if pad is not None:
            module.weight.data[pad].zero_()

    if isinstance(module, MultiheadAttention):
        # Same order as the projections are listed: query, key, value.
        for proj in (module.q_proj, module.k_proj, module.v_proj):
            _resample(proj.weight.data)


def init_with_wavlm(
    model: nn.Module,
    num_layers: int = 24,
    ckpt: str = "PATH/TO/WavLM_CHECKPOINT",
    need_mask_emb: bool = True,
    info: str = "",
):
    """
    Load pretrained WavLM weights from ``ckpt`` into ``model``.

    The checkpoint must be a dict whose ``"model"`` entry is the state dict.
    Before loading (strictly, via ``model.load_state_dict``) the state dict is
    adjusted as follows:
        1. Among the ``encoder.layers.<i>.*`` entries (except those containing
           ``relative_attention_bias``, which are kept as they are) only the
           ones with ``i < num_layers`` are retained.
        2. If ``need_mask_emb`` is False, ``mask_emb`` is dropped from the
           state dict and ``model.mask_emb`` is set to ``None``.
        3. ``encoder.layer_norm.{weight,bias}`` are dropped.

    Inputs:
        model: ``nn.Module`` receiving the weights.
        num_layers: ``int``, number of encoder layers to take from the checkpoint.
        ckpt: ``str``, path to the checkpoint file.
        need_mask_emb: ``bool``, whether the model keeps its mask embedding.
        info: ``str``, free-form tag used in the log message.
    Raises:
        KeyError: if the checkpoint has fewer than ``num_layers`` encoder layers.
    """
    assert ckpt is not None
    print(f"Initializing WavLM with checkpoint... {ckpt}")
    # Deserialize on CPU so a checkpoint saved from a GPU also loads on hosts
    # without that device; load_state_dict copies into the model's own tensors.
    data = torch.load(ckpt, map_location="cpu")
    state_dict = data["model"]

    layer_params = {
        key: value
        for key, value in state_dict.items()
        if key.startswith("encoder.layers.") and "relative_attention_bias" not in key
    }
    for key in layer_params:
        state_dict.pop(key)

    # Parameter names relative to a layer, i.e. what follows "<layer index>.".
    suffixes = {re.search(r"(?<=\d\.).*", key).group(0) for key in layer_params}

    # Put back only the first `num_layers` layers.
    for suffix in suffixes:
        for i in range(num_layers):
            key = f"encoder.layers.{i}.{suffix}"
            state_dict[key] = layer_params[key]

    if not need_mask_emb:
        # The checkpoint may or may not ship a mask embedding.
        state_dict.pop("mask_emb", None)
        model.mask_emb = None

    # we remove the layer_normalization in the output of encoder
    state_dict.pop("encoder.layer_norm.weight", None)
    state_dict.pop("encoder.layer_norm.bias", None)

    model.load_state_dict(state_dict)
    # Report success only once the weights have actually been loaded.
    print(f"WavLM/{info}: Initialized with WavLM pretrained weights.")

    del state_dict
    del layer_params


def init_with_ckpt(
    model: nn.Module,
    ckpt: str = "PATH/TO/CHECKPOINT",
    name: str = "WavLM",
    need_mask_emb: bool = True,
    info: str = "",
    device: str = "cuda",
):
    """
    Initialize ``model`` from the sub-module ``name`` stored in a training checkpoint.

    The checkpoint must be a dict whose ``"model"`` entry is a state dict.
    Entries whose key starts with ``name`` are selected and the prefix plus
    one separator character is stripped from their keys. ``mask_emb`` (when
    ``need_mask_emb`` is False) and ``encoder.layer_norm.{weight,bias}`` are
    dropped, and the remainder is loaded strictly into ``model``.

    Inputs:
        model: ``nn.Module`` receiving the weights.
        ckpt: ``str``, checkpoint path; an empty string means "nothing to load".
        name: ``str``, key prefix of the sub-module inside the checkpoint.
        need_mask_emb: ``bool``; when False ``model.mask_emb`` is set to ``None``.
        info: ``str``, free-form tag used in log messages.
        device: ``str``, passed to ``torch.load`` as ``map_location``.
    """
    assert ckpt is not None

    if ckpt == "":
        print(f"{name}/{info}: No checkpoint found.")
        return

    if not need_mask_emb and hasattr(model, "mask_emb"):
        model.mask_emb = None
    state_dict = torch.load(ckpt, map_location=device)["model"]

    # Keep the entries of the requested sub-module, dropping "<name>" and the
    # separator character that follows it.
    prefix_len = len(name) + 1
    dit = {k[prefix_len:]: v for k, v in state_dict.items() if k.startswith(name)}

    if not need_mask_emb:
        dit.pop("mask_emb", None)

    # we remove the layer_normalization in the output of encoder
    dit.pop("encoder.layer_norm.weight", None)
    dit.pop("encoder.layer_norm.bias", None)

    # An empty dict (not None) is what "nothing matched" looks like here; a
    # strict load of an empty state dict would only fail on missing keys.
    if not dit:
        print(f"{name}/{info}: No matching keys found in checkpoint: {ckpt}")
    else:
        print(f"{name}/{info}: Initialize with checkpoint: {ckpt}")
        model.load_state_dict(dit)

    del state_dict
    del dit


def apply_mask(x: Tensor, mask: BoolTensor, fill_value: Tensor, clone: bool = False):
    """
    Write ``fill_value`` into the positions of ``x`` selected by ``mask``.

    With ``clone=True`` the write happens on a copy and ``x`` is left
    untouched; otherwise ``x`` itself is modified in place and returned.
    """
    if clone:
        target = x.clone()
    else:
        target = x
    target[mask] = fill_value
    return target


@torch.no_grad()
def space_indices(indices: Tensor, space: int = 1, maximum: int = 1, already_sorted: bool = True):
    """
    Push ascending indices apart so that neighbours are at least ``space`` apart.

    Walking left to right, an index closer than ``space`` to its predecessor
    is moved to ``predecessor + space``. As soon as an index (other than the
    first one) exceeds ``maximum``, it and everything after it are cut off.

    Note: when ``already_sorted`` is True the adjustment is written into the
    tensor that was passed in; otherwise a sorted copy is adjusted.
    """
    if not already_sorted:
        indices = torch.sort(indices, descending=False).values

    for nxt in range(1, len(indices)):
        prev = nxt - 1
        if indices[nxt] - indices[prev] < space:
            indices[nxt] = indices[prev] + space
        if indices[nxt] > maximum:
            # Keep only the entries before the first one past the limit.
            return indices[:nxt]
    return indices


@torch.no_grad()
def expand_mask(
    mask: Tensor,
    expanded_span: int = 40,
    span_start: Tensor = None,
    max_num_expanded_span: int = 2,
    span_space: int = 1,
    real_length: Tensor = None,
    max_mask_percentage: float = 0.5,
):
    """
    Build a new boolean mask made of a few spans of ``expanded_span`` positions per sample.

    The values of the incoming ``mask`` are discarded: it only serves as a
    template (``torch.full_like``) for the all-False tensor that is filled in
    and returned.

    Inputs:
        mask: template of the returned mask, indexed per sample as ``mask[i]``
            (presumably (B, T) -- TODO confirm against callers).
        expanded_span: ``int``, number of consecutive positions set per span.
        span_start: per-sample candidate start positions, iterated sample by
            sample. NOTE(review): defaults to ``None`` but is iterated
            unconditionally, so it is effectively required.
        max_num_expanded_span: ``int``, upper bound on the number of spans per sample.
        span_space: ``int``, extra gap added to ``expanded_span`` when spacing the starts.
        real_length: optional per-sample lengths. When given, starts must be
            below ``real_length - expanded_span`` and the number of spans is
            ``floor(real_length * max_mask_percentage / expanded_span)``,
            capped at ``max_num_expanded_span``.
        max_mask_percentage: ``float``, see ``real_length``.
    Outputs:
        mask: the newly built mask.
        expanded_span_start: ``list`` holding, per sample, the tensor of chosen start positions.
    """
    # Start from an empty mask; only shape/dtype/device of the input are kept.
    mask = torch.full_like(mask, False)

    if real_length is not None:
        # Span budget per sample, derived from its real length and capped.
        num_span_per_sample = (real_length * max_mask_percentage / expanded_span).tolist()
        num_span_per_sample = [
            math.floor(s) if s < max_num_expanded_span else max_num_expanded_span for s in num_span_per_sample
        ]
        # A start is only usable if a whole span fits before the real end.
        valid_length = (real_length - expanded_span).tolist()
    else:
        # No lengths given: every sample uses the full width and the full budget.
        valid_length = [mask.shape[-1] - expanded_span] * mask.shape[0]
        num_span_per_sample = [max_num_expanded_span] * mask.shape[0]

    expanded_span_start = []
    for i, (indices, valid) in enumerate(zip(span_start, valid_length)):
        # Discard candidates from which a full span would not fit.
        indices = indices[indices < valid]
        num_expanded_span = num_span_per_sample[i]

        # Sort the candidates and push them apart by a span plus the gap.
        indices = space_indices(
            indices,
            space=expanded_span + span_space,
            maximum=valid,
            already_sorted=False,
        )

        if len(indices) < num_expanded_span:
            # Too few candidates: top up with random positions in [0, valid).
            # NOTE(review): the random positions are not re-spaced, so they may
            # overlap the existing starts -- confirm this is intended.
            indices = torch.cat((indices, torch.randperm(valid, device=indices.device)))[:num_expanded_span]
        else:
            # Enough candidates: keep a random subset of the required size.
            indices = indices[torch.randperm(len(indices))][:num_expanded_span]

        if (not num_expanded_span) or (not len(indices)):
            # Fallback: mask from the first original candidate up to the real length.
            # NOTE(review): this reads real_length[i] even though real_length may be
            # None, and span_start[i][0] requires at least one candidate -- confirm
            # callers guarantee both.
            indices = span_start[i][0].unsqueeze(dim=0)
            expanded_span_start.append(indices)
            mask[i][indices : real_length[i]] = True
        else:
            expanded_span_start.append(indices)

            # Expand every start into `expanded_span` consecutive positions.
            indices = torch.as_tensor(
                [indices[j] + offset for j in range(num_expanded_span) for offset in range(expanded_span)]
            )

            mask[i][indices] = True

    return mask, expanded_span_start


def normalize(x: Tensor, p: int = 2, dim: int = -1):
    """Scale ``x`` to unit ``p``-norm along ``dim`` (thin wrapper around ``F.normalize``)."""
    return F.normalize(x, p=p, dim=dim)


def masked_select(x: Tensor, mask: BoolTensor):
    """
    Gather the feature vectors of ``x`` at the positions where ``mask`` is True.

    Inputs:
        x: (B, T, C), ``Tensor``
        mask: (B, T), ``BoolTensor``
    Output:
        x: (-1, C), ``Tensor``
    """
    channels = x.size(-1)
    # Broadcast the (B, T) mask over the channel axis.
    keep = mask.unsqueeze(dim=-1)
    picked = x.masked_select(keep)
    return picked.view(-1, channels)


class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolution blocks applied to a raw waveform.

    Each entry of ``conv_layers`` is a ``(channels, kernel, stride)`` triple
    and yields one block: convolution, dropout, an optional normalization and
    a GELU. The normalization depends on ``mode``:
        ``"layer_norm"``: layer norm over channels in every block;
        ``"default"``: group norm in the first block only;
        anything else: no normalization.

    Inputs:
        x: (B, T_audio), ``Tensor``
    Outputs:
        x: (B, C, T), ``Tensor``
    """

    def __init__(
        self,
        conv_layers: list = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2,
        dropout: float = 0.0,
        conv_bias: bool = False,
        mode: str = "default",
    ):
        super().__init__()

        def build_block(in_channels, out_channels, kernel, stride, use_layer_norm, use_group_norm):
            # Convolution first, so its weights are the first to be initialized.
            conv = nn.Conv1d(in_channels, out_channels, kernel, stride=stride, bias=conv_bias)
            nn.init.kaiming_normal_(conv.weight)

            stages = [conv, nn.Dropout(p=dropout)]
            if use_layer_norm:
                # LayerNorm normalizes the last axis, so swap channel and time
                # axes around it.
                stages.append(
                    nn.Sequential(
                        Rearrange("b c t -> b t c"),
                        nn.LayerNorm(out_channels, elementwise_affine=True),
                        Rearrange("b c t -> b t c"),
                    )
                )
            elif use_group_norm:
                stages.append(nn.GroupNorm(out_channels, out_channels, affine=True))
            stages.append(nn.GELU())
            return nn.Sequential(*stages)

        self.conv_layers = nn.ModuleList()
        in_channels = 1  # the waveform has a single channel
        for index, spec in enumerate(conv_layers):
            assert len(spec) == 3, "invalid conv definition: " + str(spec)
            out_channels, kernel, stride = spec
            self.conv_layers.append(
                build_block(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    use_layer_norm=(mode == "layer_norm"),
                    use_group_norm=(mode == "default" and index == 0),
                )
            )
            in_channels = out_channels

    def forward(self, x):
        # BxT -> BxCxT
        hidden = x.unsqueeze(1)
        for block in self.conv_layers:
            hidden = block(hidden)
        return hidden


class TransformerEncoder(nn.Module):
    """Convolutional positional embedding followed by a stack of transformer layers.

    Inputs:
        x: features to encode, ``Tensor``. Positions flagged by
            ``padding_mask`` are zeroed in place on the given tensor.
        padding_mask: optional key padding mask, forwarded to every layer.
        layer: optional 0-based index of the last layer to run.
    Outputs:
        x: output of the last executed layer.
        layer_results: ``list`` of ``(output, attention_weights)`` pairs; entry 0
            is the input of the first layer (attention weights ``None``),
            followed by one entry per executed layer.
    """

    def __init__(self, args):
        super().__init__()
        self.dropout = args.dropout
        self.pos_conv = make_conv_pos(args.encoder_embed_dim, args.conv_pos, args.conv_pos_groups)

        self.relative_position_embedding = args.relative_position_embedding
        self.num_buckets = args.num_buckets
        self.max_distance = args.max_distance

        self.layers = nn.ModuleList()
        for index in range(args.encoder_layers):
            self.layers.append(
                TransformerEncoderLayer(
                    embed_dim=args.encoder_embed_dim,
                    ffn_embed_dim=args.ffn_embed_dim,
                    num_heads=args.num_heads,
                    activation=args.activation,
                    dropout=args.dropout,
                    bias=args.bias,
                    normalize_before=args.normalize,
                    # Only the first layer is built with the relative attention bias.
                    has_relative_attention_bias=(self.relative_position_embedding and index == 0),
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    qk_norm=args.qk_norm,
                )
            )

        self.apply(init_bert_params)

    def forward(self, x: Tensor, padding_mask=None, layer=None):
        return self.extract_features(x, padding_mask, layer)

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        if padding_mask is not None:
            x[padding_mask] = 0

        # Positional information comes from a convolution over the time axis.
        positional = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
        x = F.dropout(x + positional, p=self.dropout, training=self.training)

        # Entry 0 is the input of the first layer; it has no attention weights.
        layer_results = [(x, None)]
        pos_bias = None

        for index, block in enumerate(self.layers):
            # The position bias returned by one layer is handed to the next.
            x, attn_weights, pos_bias = block(
                x,
                key_padding_mask=padding_mask,
                need_weights=True,
                pos_bias=pos_bias,
            )
            layer_results.append((x, attn_weights))
            if index == tgt_layer:
                break
        return x, layer_results


class PredictionHead(nn.Module):
    """Two-layer feed-forward head with an optional parameter-free input layer norm.

    The hidden width is ``input_dim // 2``.

    Inputs:
        x: (B, T, input_dim), ``Tensor``
    Outputs:
        x: (B, T, output_dim), ``Tensor``
    """

    def __init__(self, input_dim: int, output_dim: int, activation: str, norm_input: bool = True):
        super().__init__()
        self.norm_input = norm_input
        hidden_dim = input_dim // 2
        stages = [
            nn.Linear(input_dim, hidden_dim),
            _get_activation_fn(activation, module=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.simple_ffn = nn.Sequential(*stages)

    def forward(self, x: Tensor):
        # Normalize over the feature axis only, without learnable scale/shift.
        normed = F.layer_norm(x, [x.shape[-1]]) if self.norm_input else x
        return self.simple_ffn(normed)


class WavLM(nn.Module):
    """WavLM backbone: convolutional feature extractor followed by a transformer encoder.

    The frame-level padding mask computed during the last ``forward`` call is
    cached on the module and can be read with ``get_padding_mask``.
    """

    def __init__(self, args):
        super().__init__()
        # NOTE(security): ``eval`` executes arbitrary code. ``conv_feature_layers``
        # must come from a trusted configuration, never from external input.
        feature_enc_layers = eval(args.conv_feature_layers)
        conv_embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(feature_enc_layers, mode=args.extractor_mode)
        self.layer_norm = nn.LayerNorm(conv_embed)
        self.post_extract_proj = nn.Linear(conv_embed, args.encoder_embed_dim)
        self.dropout_input = nn.Dropout(args.dropout_input)

        self.encoder = TransformerEncoder(args)

        self.mask_emb = nn.Parameter(FloatTensor(args.encoder_embed_dim).uniform_(), requires_grad=True)
        self.padding_mask = None
        self.elbo = args.output_rep == "elbo"
        self.normalize = args.normalize
        self.freeze_cnn = args.freeze_cnn

    def forward_padding_mask(self, features: Tensor, padding_mask: Tensor) -> None:
        """Reduce a sample-level padding mask to frame level and cache it in ``self.padding_mask``.

        Trailing samples that do not fill a whole frame are dropped; a frame
        counts as padding only when all of its samples are padding.
        """
        num_frames = features.size(1)
        extra = padding_mask.size(1) % num_frames
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        per_frame = padding_mask.view(padding_mask.size(0), num_frames, -1)
        self.padding_mask = per_frame.all(-1)

    def get_padding_mask(self):
        """Return the frame-level padding mask of the last ``forward`` call (``None`` if there was none)."""
        return self.padding_mask

    def forward(
        self,
        waveform: Tensor,
        padding_mask: Optional[Tensor] = None,
        output_layer: Optional[int] = None,
        ret_layer_results: bool = False,
    ):
        """
        Inputs:
            waveform: (B, T_audio), ``Tensor``
            padding_mask: (B, T_audio), ``BoolTensor``, key padding mask.
            output_layer: ``int``, varies between [1, 24].
            ret_layer_results: ``bool``, default False.
        Outputs:
            features: (B, T, C), ``Tensor``
            layers_rep: [feature_encoder_output, layer_1_output, layer_2_output, ..., layer_n_output], ``list``;
                ``None`` unless ``ret_layer_results`` is set.
            feature_extractor_output: the projected features fed to the encoder with
                padded frames zeroed, ``Tensor``; ``None`` unless the model runs in "elbo" mode.
        """
        if self.normalize:
            waveform = F.layer_norm(waveform, [waveform.shape[-1]])

        if self.freeze_cnn:
            with torch.no_grad():
                features = self.feature_extractor(waveform)
        else:
            features = self.feature_extractor(waveform)

        # (B, C, T) -> (B, T, C)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        features = self.post_extract_proj(features)
        features = self.dropout_input(features)

        if padding_mask is not None:
            self.forward_padding_mask(features, padding_mask)
        else:
            self.padding_mask = None

        feature_extractor_output = None
        if self.elbo:
            # Same tensor object as ``features``: the zeroing below is also
            # visible to the encoder input.
            feature_extractor_output = features
            # Without a mask there is nothing to zero. Indexing with ``None``
            # would select, and therefore wipe, the whole tensor.
            if self.padding_mask is not None:
                feature_extractor_output[self.padding_mask] = 0

        features, layer_results = self.encoder(
            features,
            padding_mask=self.padding_mask,
            # ``output_layer`` is 1-based, the encoder expects a 0-based index.
            layer=None if output_layer is None else output_layer - 1,
        )

        return (
            features,
            layer_results if ret_layer_results else None,
            feature_extractor_output,
        )


class SelfAttentionPooling(nn.Module):
    """
    Self-attention pooling over the time axis.
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 1)
        self.softmax = nn.functional.softmax

    def forward(self, batch_rep, att_mask=None):
        """
        N: batch size, T: sequence length, H: Hidden dimension
        input:
            batch_rep : size (N, T, H)
            att_mask : optional additive mask, added to the (N, T) attention logits
        attention_weight:
            att_w : size (N, T, 1)
        return:
            utter_rep: size (N, H)
        """
        # One scalar score per frame.
        scores = self.W(batch_rep).squeeze(-1)
        if att_mask is not None:
            scores = att_mask + scores
        weights = self.softmax(scores, dim=-1).unsqueeze(-1)
        # Attention-weighted sum of the frames.
        return (batch_rep * weights).sum(dim=1)


def token_to_word(text):
    """Join character tokens into a sentence: " " separates characters, "|" separates words."""
    # Hard coding but it is only used here for now.
    # Assumption that units are characters. Doesn't handle BPE.
    joined = "".join(text)
    # Single pass: drop the inter-character spaces, turn "|" into a word space.
    table = str.maketrans("|", " ", " ")
    return joined.translate(table).strip()


def downsample(x, x_len, sample_rate, sample_style):
    """
    Reduce the time resolution of ``x`` (batch, time, feature) by ``sample_rate``.

    ``"drop"`` keeps every ``sample_rate``-th frame; ``"concat"`` drops the
    trailing frames that do not fill a group and stacks each group of
    ``sample_rate`` frames along the feature axis. The lengths are divided
    (floor) by ``sample_rate`` in both cases.

    Raises:
        NotImplementedError: for any other ``sample_style``.
    """
    batch_size, timestep, feature_dim = x.shape
    x_len = x_len // sample_rate

    if sample_style == "drop":
        # Drop the unselected timesteps
        return x[:, ::sample_rate, :].contiguous(), x_len

    if sample_style == "concat":
        # Drop the redundant frames and concat the rest according to sample rate
        leftover = timestep % sample_rate
        if leftover != 0:
            x = x[:, :-leftover, :]
        stacked = x.contiguous().view(batch_size, int(timestep / sample_rate), feature_dim * sample_rate)
        return stacked, x_len

    raise NotImplementedError


class RNNLayer(nn.Module):
    """Single recurrent layer with optional layer norm, dropout, time-downsampling and projection.

    ``module`` names the recurrent class in ``torch.nn`` (looked up upper-cased).
    ``out_dim`` is ``2 * dim`` for a bidirectional layer, ``dim`` otherwise.
    """

    def __init__(
        self,
        input_dim,
        module,
        bidirection,
        dim,
        dropout,
        layer_norm,
        sample_rate,
        proj,
    ):
        super().__init__()
        # Setup
        self.out_dim = dim * 2 if bidirection else dim
        self.dropout = dropout
        self.layer_norm = layer_norm
        self.sample_rate = sample_rate
        self.proj = proj

        # Recurrent layer
        rnn_cls = getattr(nn, module.upper())
        self.layer = rnn_cls(input_dim, dim, bidirectional=bidirection, num_layers=1, batch_first=True)

        # Regularizations
        if self.layer_norm:
            self.ln = nn.LayerNorm(self.out_dim)
        if self.dropout > 0:
            self.dp = nn.Dropout(p=dropout)

        # Additional projection layer
        if self.proj:
            self.pj = nn.Linear(self.out_dim, self.out_dim)

    def forward(self, input_x, x_len):
        """Run the layer on a padded batch; returns the padded output and the output lengths."""
        if not self.training:
            self.layer.flatten_parameters()

        # Pack so that the recurrence skips the padded frames.
        packed = pack_padded_sequence(input_x, x_len, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.layer(packed)
        output, x_len = pad_packed_sequence(packed_out, batch_first=True)

        # Normalizations
        if self.layer_norm:
            output = self.ln(output)
        if self.dropout > 0:
            output = self.dp(output)

        # Perform Downsampling
        if self.sample_rate > 1:
            output, x_len = downsample(output, x_len, self.sample_rate, "drop")

        if self.proj:
            output = torch.tanh(self.pj(output))

        return output, x_len


class RNNs(nn.Module):
    """Stack of ``RNNLayer`` modules followed by a linear classifier with log-softmax.

    The input is first downsampled by ``round(total_rate / upstream_rate)``
    (no downsampling when ``total_rate`` is -1) using ``sample_style``.
    ``dim``, ``dropout``, ``layer_norm``, ``proj`` and ``sample_rate`` are
    indexed per layer; ``len(dim)`` fixes the number of layers.
    """

    def __init__(
        self,
        input_size,
        output_size,
        upstream_rate,
        module,
        bidirection,
        dim,
        dropout,
        layer_norm,
        proj,
        sample_rate,
        sample_style,
        total_rate,
    ):
        super().__init__()

        if total_rate == -1:
            self.sample_rate = 1
        else:
            self.sample_rate = round(total_rate / upstream_rate)
        self.sample_style = sample_style

        feature_size = input_size
        if sample_style == "concat":
            # Concatenating frames multiplies the feature width.
            feature_size *= self.sample_rate

        self.rnns = nn.ModuleList()
        for idx, hidden in enumerate(dim):
            layer = RNNLayer(
                feature_size,
                module,
                bidirection,
                hidden,
                dropout[idx],
                layer_norm[idx],
                sample_rate[idx],
                proj[idx],
            )
            self.rnns.append(layer)
            feature_size = layer.out_dim

        self.linear = nn.Linear(feature_size, output_size)

    def forward(self, x, x_len):
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, input_length, num_features).
            x_len (torch.IntTensor): Tensor of dimension (batch_size).
        Returns:
            Tensor: Log-probabilities of dimension (batch_size, output_length, number_of_classes).
            Tensor: Lengths after downsampling and the recurrent layers.
        """
        # Perform Downsampling
        if self.sample_rate > 1:
            x, x_len = downsample(x, x_len, self.sample_rate, self.sample_style)

        for layer in self.rnns:
            x, x_len = layer(x, x_len)

        scores = self.linear(x)
        return scores.log_softmax(dim=-1), x_len


def read_tokens(file_path):
    """Read one token per line from a UTF-8 file; a blank line stands for the " " token."""
    with open(file_path, "r", encoding="utf-8") as handle:
        # An all-whitespace line strips to "", which is mapped to a single space.
        return [line.strip() or " " for line in handle]


class MHFA(nn.Module):
    """Multi-head attention over layer-weighted features.

    Keys and values are two separately weighted sums of the per-layer
    features. Each head scores every frame from the compressed keys; the
    scores are normalized over the frame axis and used to scale the compressed
    values. The per-head results are concatenated and projected per frame
    (no reduction over frames is applied).
    """

    def __init__(self, head_nb=8, inputs_dim=1024, compression_dim=128, outputs_dim=1024, n_layers=24):
        super().__init__()

        # Learnable per-layer weights, one set for the keys and one for the values
        self.weights_k = nn.Parameter(data=torch.ones(n_layers), requires_grad=True)
        self.weights_v = nn.Parameter(data=torch.ones(n_layers), requires_grad=True)

        # Hyper-parameters
        self.head_nb = head_nb
        self.ins_dim = inputs_dim
        self.cmp_dim = compression_dim
        self.ous_dim = outputs_dim

        # Compression layers for keys and values
        self.cmp_linear_k = nn.Linear(self.ins_dim, self.cmp_dim)
        self.cmp_linear_v = nn.Linear(self.ins_dim, self.cmp_dim)

        # One attention score per head
        self.att_head = nn.Linear(self.cmp_dim, self.head_nb)

        # Final projection of the concatenated heads
        self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, self.ous_dim)

    def forward(self, x, padding_mask):
        # Input x has shape: [Batch, Dim, Frame_len, Nb_Layer]
        layer_w_k = F.softmax(self.weights_k, dim=-1)
        layer_w_v = F.softmax(self.weights_v, dim=-1)

        # Weighted sum over layers, then [Batch, Dim, Frame_len] -> [Batch, Frame_len, Dim]
        k = (x * layer_w_k).sum(dim=-1).transpose(1, 2)
        v = (x * layer_w_v).sum(dim=-1).transpose(1, 2)

        k = self.cmp_linear_k(k)  # bs, seq_len, hid_dim
        v = self.cmp_linear_v(v)  # bs, seq_len, hid_dim

        att_k = self.att_head(k)  # bs, seq_len, n_heads

        # Padded frames get -inf so that they receive zero attention.
        frame_mask = padding_mask.unsqueeze(-1).expand_as(att_k)
        att = F.softmax(att_k.masked_fill(frame_mask, float("-inf")), dim=1)

        # bs, seq_len, n_heads, hid_dim: every head scales the values of each frame
        pooling_outs = v.unsqueeze(-2) * att.unsqueeze(-1)

        # Concatenate the heads before the final projection
        batch, frames = pooling_outs.shape[:2]
        pooling_outs = pooling_outs.reshape(batch, frames, -1)

        return self.pooling_fc(pooling_outs)


class WavLMFinetuneWrapper(nn.Module):
    def __init__(self, cfg):
        """Build the WavLM backbone, the CTC decoder and the heads selected by ``cfg.model.output_rep``.

        Args:
            cfg: configuration object; ``cfg.model`` is stored as ``self.args`` and
                ``cfg.dataset`` supplies the token dictionary, lexicon path and
                dictionary length.

        Raises:
            NotImplementedError: if ``cfg.model.output_rep`` is not one of the
                representations handled by the projector selection below.
            AttributeError: if ``output_rep`` is ``"elbo"``/``"weighted_hiddens"`` and
                ``cfg.model.distribution_prediction`` is not a known type.

        Attributes set here that the other methods read: ``elbo``, ``weights``,
        ``layer_index``, ``projector``, ``asr_model``, ``layer_position_encoding``
        and, depending on the mode, ``mhfa`` / ``layer_distribution`` /
        ``distribution_prediction``.
        """
        super().__init__()
        self.cfg = cfg
        self.args = cfg.model
        data_args = cfg.dataset
        self.freeze_upstream = self.args.freeze_upstream
        self.output_representation = self.args.output_rep
        # WavLM is defined elsewhere in the project (not visible in this chunk).
        self.wavlm = WavLM(self.args)

        ######################################### Initializing CTC Decoder ############################################
        # Lexicon-constrained beam search without a language model (lm=None, lm_weight=0).
        self.decoder = ctc_decoder(
            tokens=read_tokens(data_args.token_dictionary),
            lexicon=data_args.lexicon_path,
            nbest=self.args.nbest,
            beam_size=self.args.beam,
            beam_threshold=self.args.beam_threshold,
            lm=None,
            lm_weight=0,
            blank_token="<blank>",
            sil_token="|",
            unk_word="<unk>",
        )
        # Map token string -> index for the first `dictionary_len` ids; only used to find the pad id.
        dictionary = dict(
            zip(
                self.decoder.idxs_to_tokens(torch.tensor(list(range(data_args.dictionary_len))).long()),
                range(data_args.dictionary_len),
            )
        )
        self.pad_token_idx = dictionary["<pad>"]

        ######################################## Initializing WavLM backbone ##########################################
        if self.args.init_with_wavlm:
            init_with_wavlm(
                self.wavlm,
                self.args.encoder_layers,
                self.args.path_to_wavlm,
                need_mask_emb=False,
            )
        else:
            print("No initialization method specified. Initializing with random weights.")

        ########################################## Initializing Projector #############################################
        # Single shared projector for the single-head modes; one projector per encoder
        # layer for the ensemble modes. This chain is also the only place that rejects
        # an unknown `output_rep`.
        if (
            self.args.output_rep in ["last_layer", "weighted_sum", "weighted_hiddens"]
            or self.args.output_rep.startswith("layer")
            or (self.args.output_rep == "elbo" and not self.args.n_asr_models)
        ):
            self.projector = nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
        elif (
            self.args.output_rep == "elbo" and self.args.n_asr_models
        ) or self.args.output_rep == "weighted_sum_ensemble":
            self.projector = nn.ModuleList(
                [
                    nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
                    for i in range(self.args.encoder_layers)
                ]
            )
        elif self.args.output_rep == "mhfa":
            self.projector = nn.Linear(self.args.encoder_embed_dim, self.args.projector_dim)
        else:
            raise NotImplementedError(f"output_rep {self.args.output_rep} is not implemented.")

        ######################## Initializing layer_distribution if needed or weights ################################
        if self.args.output_rep == "weighted_sum":
            self.elbo = False
            # Zero logits -> uniform layer mixture after softmax.
            self.weights = nn.Parameter(torch.zeros(self.args.encoder_layers))
            self.layer_index = None
            print(f"Using weighted sum of {list(self.weights.shape)} representations as output representation.")

        elif self.args.output_rep == "mhfa":
            # NOTE: `self.weights` is not assigned in this branch; `forward` tests the
            # "mhfa" mode before it ever reads `self.weights`.
            self.elbo = False
            self.mhfa = MHFA(
                head_nb=self.args.mhfa_head_nb,
                inputs_dim=self.args.encoder_embed_dim,
                compression_dim=self.args.mhfa_compression_dim,
                outputs_dim=self.args.encoder_embed_dim,
                n_layers=self.args.encoder_layers,
            )
            self.layer_index = None
            print(f"Using MHFA as output representation.")

        elif self.args.output_rep == "weighted_sum_ensemble":
            self.elbo = False
            # One row of layer-mixing logits per ensemble member (row n is read by
            # `_weighted_sum(..., n_layer=n)`).
            self.weights = nn.Parameter(torch.zeros((self.args.encoder_layers, self.args.encoder_layers)))
            self.layer_index = None
            print(
                f"Using ensemble of weighted sum of {list(self.weights.shape)} representations as output representation."
            )
        elif self.args.output_rep == "last_layer":
            self.weights = None
            self.elbo = False
            self.layer_index = None
            print(f"Using {self.args.output_rep} representation as output representation.")
        elif self.args.output_rep in ["elbo", "weighted_hiddens"]:
            if self.args.output_rep == "elbo":
                self.elbo = True
            else:
                self.elbo = False
            self.layer_index = None
            self.weights = None
            self.distribution_prediction = self.args.distribution_prediction
            # Head that predicts one logit per encoder layer from pooled hidden states.
            if self.args.distribution_prediction in ["from_12_transformer", "from_last", "from_24_layers_summation"]:
                if self.args.distribution_prediction_architecture == "linear":
                    self.layer_distribution = nn.Linear(self.args.encoder_embed_dim, self.args.encoder_layers)
                else:
                    self.layer_distribution = nn.Sequential(
                        nn.Linear(self.args.encoder_embed_dim, self.args.encoder_embed_dim // 2),
                        nn.ReLU(),
                        nn.Linear(self.args.encoder_embed_dim // 2, self.args.encoder_layers),
                    )
            elif self.args.distribution_prediction == "from_cnn":
                # NOTE(review): a head is created for "from_cnn", but `get_layer_distribution`
                # has no "from_cnn" branch and raises AttributeError for it.
                self.distribution_prediction = "from_cnn"
                self.layer_distribution = nn.Linear(self.args.encoder_embed_dim, self.args.encoder_layers)
            elif self.args.distribution_prediction == "from_24_layers_rnn":
                self.layer_distribution = RNNDistributionPredictor(
                    input_size=self.args.encoder_embed_dim,
                    hidden_size=self.args.rnn_hid_dim,
                    d=self.args.d,
                    rnn_layers=self.args.rnn_n_layers,
                    dropout=self.args.rnn_prediction_dropout,
                )
            elif self.args.distribution_prediction == "from_24_layers_mhsa":
                self.layer_distribution = AttentionDistributionPredictor(
                    hid_dim=self.args.encoder_embed_dim, num_heads=self.args.dist_att_n_num_heads
                )
            else:
                raise AttributeError(f"Unknown distribution prediction type: {self.args.distribution_prediction}")

            print(
                f"Using {self.args.output_rep} representation as output representation trained with "
                f" {self.args.distribution_prediction} distribution."
            )
        elif self.args.output_rep.startswith("layer"):
            self.weights = None
            self.elbo = False
            # "layer_N" is 1-based in the config; store the 0-based list index.
            self.layer_index = int(self.args.output_rep.split("_")[-1]) - 1
            print(f"Using {self.layer_index}th layer representation as output representation.")
        else:
            # Unreachable in practice: every value that gets here was already rejected by
            # the NotImplementedError in the projector selection above.
            raise Exception(
                f"Expected self.args.output_rep to be: elbo, "
                f"layer_n, weighted or last_layer, got: {self.args.output_rep}"
            )

        ############################################ Initializing asr_model ##########################################
        # NOTE(review): several heads below are moved with a hard-coded `.to("cuda")`, which
        # fails on CPU-only hosts, while the plain linear heads are left on the default
        # device — confirm that callers move the whole module afterwards.
        if (self.elbo and self.args.n_asr_models) or self.args.output_rep == "weighted_sum_ensemble":
            # Ensemble modes: one ASR head per encoder layer.
            if self.args.asr_model == "linear":
                self.asr_model = nn.ModuleList(
                    [
                        nn.Linear(self.args.projector_dim, data_args.dictionary_len)
                        for i in range(self.args.encoder_layers)
                    ]
                ).to("cuda")
            else:
                self.asr_model = nn.ModuleList(
                    [
                        RNNs(
                            input_size=self.args.projector_dim,
                            output_size=data_args.dictionary_len,
                            upstream_rate=self.args.upsample_rate,  # 320
                            module=self.args.module,
                            bidirection=self.args.bidirection,
                            dim=self.args.rnn_dim,
                            dropout=self.args.rnn_dropout,
                            layer_norm=self.args.rnn_layer_norm,
                            proj=self.args.rnn_proj,
                            sample_rate=self.args.rnn_sample_rate,  # [1, 1]
                            sample_style=self.args.sample_style,
                            total_rate=self.args.rnn_total_rate,  # -1
                        )
                        for i in range(self.args.encoder_layers)
                    ]
                ).to("cuda")
        elif self.args.output_rep == "mhfa":
            # MHFA always uses a linear head here, regardless of `args.asr_model`.
            self.asr_model = nn.Linear(self.args.projector_dim, data_args.dictionary_len)
        else:
            if self.args.asr_model == "linear":
                self.asr_model = nn.Linear(self.args.projector_dim, data_args.dictionary_len)
            else:
                self.asr_model = RNNs(
                    input_size=self.args.projector_dim,
                    output_size=data_args.dictionary_len,
                    upstream_rate=self.args.upsample_rate,  # 320
                    module=self.args.module,
                    bidirection=self.args.bidirection,
                    dim=self.args.rnn_dim,
                    dropout=self.args.rnn_dropout,
                    layer_norm=self.args.rnn_layer_norm,
                    proj=self.args.rnn_proj,
                    sample_rate=self.args.rnn_sample_rate,  # [1, 1]
                    sample_style=self.args.sample_style,
                    total_rate=self.args.rnn_total_rate,  # -1
                ).to("cuda")
        # Optional learned per-layer embedding that `forward` adds to every layer output
        # (only for ELBO training with linear heads).
        if self.elbo and self.args.asr_model == "linear" and self.args.layer_position_encoding:
            self.layer_position_encoding = nn.Embedding(
                num_embeddings=self.args.encoder_layers, embedding_dim=self.args.encoder_embed_dim
            ).to("cuda")
        else:
            self.layer_position_encoding = None

    def _weighted_sum(self, layer_results: list, n_layer=None):
        stacked_feature = torch.stack(layer_results, dim=0)
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(len(layer_results), -1)
        weights = self.weights[n_layer, :] if n_layer is not None else self.weights
        norm_weights = F.softmax(weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        weighted_feature = weighted_feature.view(*origin_shape)
        return weighted_feature

    @staticmethod
    def _get_layer_results_stack_and_mask(layer_results, padding_mask):
        n_layers = len(layer_results)
        batch_size, seq_len, feature_dim = layer_results[0][0].shape
        stacked_layer_results = torch.empty(
            batch_size,
            seq_len,
            feature_dim,
            n_layers,
            device=layer_results[0][0].device,
            dtype=layer_results[0][0].dtype,
        )
        for i in range(n_layers):
            stacked_layer_results[..., i] = layer_results[i][0]

        zero_mask = padding_mask.unsqueeze(-1).unsqueeze(-1).expand_as(stacked_layer_results)
        stacked_masked_layer_results = stacked_layer_results.masked_fill(zero_mask, 0.0)
        return stacked_masked_layer_results

    def apply_mask_to_merged_output(self, layer_results, padding_mask, real_length_feature_extractor):
        stacked_masked_layer_results = self._get_layer_results_stack_and_mask(layer_results, padding_mask)
        stacked_layer_results = stacked_masked_layer_results.sum(dim=1) / real_length_feature_extractor.unsqueeze(
            -1
        )  # [bs, 1024, 24]
        return stacked_layer_results, len(layer_results)

    def get_layer_distribution(self, layer_results, feature_extractor_output=None):
        padding_mask = self.wavlm.get_padding_mask()
        real_length_feature_extractor = torch.sum(~padding_mask, dim=-1, keepdim=True)

        if self.distribution_prediction in ["from_12_transformer", "from_last"]:
            zero_mask = torch.zeros_like(layer_results[0][0])
            zero_mask[padding_mask] = 1.0
            layer_results = [layer_results[i][0] for i in range(len(layer_results))]
            layer_result = layer_results[11 if self.distribution_prediction == "from_12_transformer" else -1]
            layer_result = torch.where(zero_mask.bool(), torch.zeros_like(layer_result), layer_result)
            layer_result = layer_result.sum(dim=1) / real_length_feature_extractor  # [bs, 1024]
            layer_distribution_logits = self.layer_distribution(layer_result)
        elif self.distribution_prediction == "from_all_24_transformer":
            layer_results = [layer_results[i][0] for i in range(len(layer_results))]
            layer_result = torch.stack(layer_results, dim=-1)
            assert (
                len(layer_result.shape) == 4 and layer_result.shape[-1] == self.args.encoder_layers
            ), f"Expected to have: bs, seq_len, hid_dim, n_layers shape layers, got: {layer_result.shape}"
            zero_mask = torch.zeros_like(layer_result)
            zero_mask[padding_mask] = 1.0
            layer_result = torch.where(zero_mask.bool(), torch.zeros_like(layer_result), layer_result)
            # print(f"Real length in layer24 distribution: {real_length_feature_extractor.unsqueeze(-1)}")
            layer_result = layer_result.sum(dim=1) / real_length_feature_extractor.unsqueeze(-1)  # [bs, 1024, 24]
            layer_distribution_logits = self.layer_distribution(layer_result.permute(0, 2, 1))
        elif self.distribution_prediction == "from_24_layers_summation":
            layer_results, n_layers = self.apply_mask_to_merged_output(
                layer_results, padding_mask, real_length_feature_extractor
            )  # [bs, 1024, 24]
            layer_results = layer_results.sum(dim=-1) / n_layers  # [bs, 1024]
            layer_distribution_logits = self.layer_distribution(layer_results)

        elif self.distribution_prediction == "from_24_layers_rnn":
            layer_results, n_layers = self.apply_mask_to_merged_output(
                layer_results, padding_mask, real_length_feature_extractor
            )  # [bs, 1024, 24]
            layer_results = layer_results.permute(0, 2, 1)  # [bs, 24, 1024]
            layer_distribution_logits = self.layer_distribution(layer_results)  # [bs, 24]

        elif self.distribution_prediction == "from_24_layers_mhsa":
            layer_results, n_layers = self.apply_mask_to_merged_output(
                layer_results, padding_mask, real_length_feature_extractor
            )  # [bs, 1024, 24]
            layer_results = layer_results.permute(0, 2, 1)  # [bs, 24, 1024]
            layer_distribution_logits = self.layer_distribution(layer_results)

        else:
            print(f"Distribution prediction type is not implemented {self.distribution_prediction}.")
            raise AttributeError
        return layer_distribution_logits

    def _elbo(self, layer_results: list, feature_extractor_output):
        layer_distribution_logits = self.get_layer_distribution(layer_results, feature_extractor_output)
        prediction_logits = []
        prediction_logits_lens = []

        for i in range(len(layer_results)):  # hidden: c (encoder_hid_dim)
            logits, logits_lens = self._process_single_hidden(
                fea=layer_results[i][0],
                projector=self.projector[i] if self.args.n_asr_models else self.projector,
                asr_model=self.asr_model[i] if self.args.n_asr_models else self.asr_model,
            )
            prediction_logits.append(logits)
            prediction_logits_lens.append(logits_lens)
        # print(f"Len of prediction_log_probas: {len(prediction_log_probas)}")
        return (
            prediction_logits,  # * 24,
            prediction_logits_lens,  # * 24,
            layer_distribution_logits,
        )

    def _weighted_sum_ensemble(self, layer_results):
        prediction_logits = []
        prediction_logits_len = []
        n_layers = len(layer_results)
        for layer in range(n_layers):
            torch.cuda.empty_cache()
            fea = self._weighted_sum(layer_results=layer_results, n_layer=layer)
            torch.cuda.empty_cache()
            log_probs, log_probs_len = self._process_single_hidden(fea, self.projector[layer], self.asr_model[layer])
            prediction_logits.append(log_probs)
            prediction_logits_len.append(log_probs_len)
        return prediction_logits, prediction_logits_len

    def _weighted_hidden_representation(self, layer_results):
        layer_distribution_logits = self.get_layer_distribution(layer_results)
        layer_distribution_probas = F.softmax(layer_distribution_logits, dim=1)
        padding_mask = self.wavlm.get_padding_mask()
        stacked_layer_results = self._get_layer_results_stack_and_mask(layer_results, padding_mask)
        weighted_layer_results = layer_distribution_probas.unsqueeze(1).unsqueeze(1) * stacked_layer_results
        weighted_averaged_layer_results = torch.sum(weighted_layer_results, dim=-1)
        log_probs, log_probs_len = self._process_single_hidden(
            weighted_averaged_layer_results, self.projector, self.asr_model
        )
        return log_probs, log_probs_len, layer_distribution_logits

    def _mhfa(self, layer_results):
        # bs, seq_len
        padding_mask = self.wavlm.get_padding_mask()
        stacked_layer_results = self._get_layer_results_stack_and_mask(layer_results, padding_mask)
        stacked_layer_results = stacked_layer_results.permute(0, 2, 1, 3)
        weighted_layer_result = self.mhfa(stacked_layer_results, padding_mask)
        log_probs, log_probs_len = self._process_single_hidden(weighted_layer_result, self.projector, self.asr_model)
        return log_probs, log_probs_len

    def process_target_labels(self, labels):
        """Turn padded label-id sequences into reference tokens and words.

        For every sequence in ``labels`` the entries equal to ``self.pad_token_idx``
        are dropped, the rest are mapped to tokens with the CTC decoder's
        vocabulary, and the tokens are joined into words by ``token_to_word``.

        Returns:
            ``(tokens_per_sample, words_per_sample)``, two lists aligned with ``labels``.
        """
        tokens_per_sample, words_per_sample = [], []
        for sequence in labels:
            not_padding = sequence != self.pad_token_idx
            tokens = self.decoder.idxs_to_tokens(sequence[not_padding])
            words = token_to_word(tokens).split()
            tokens_per_sample.append(tokens)
            words_per_sample.append(words)
        return tokens_per_sample, words_per_sample

    def _process_single_hidden(self, fea, projector, asr_model):
        padding_mask = self.wavlm.get_padding_mask()
        if padding_mask is not None:
            real_length = torch.sum(~padding_mask, dim=-1, keepdim=True)
            zero_mask = torch.zeros_like(fea)
            zero_mask[padding_mask] = 1.0
            fea = torch.where(zero_mask.bool(), torch.zeros_like(fea), fea)
        else:
            real_length = torch.full(
                (fea.size(0), 1),
                fill_value=fea.size(1),
                dtype=fea.dtype,
                device=fea.device,
            )
        # torch.cuda.synchronize()
        # end_time = time.time()
        # shuffle_time = end_time - start_time
        # print(f"Get padding mask : {shuffle_time}")

        if projector is not None:
            fea = projector(fea)  # torch.Size([bs, l, h])
        # print(f"Before asr: {real_length.squeeze().cpu()}")

        if self.cfg.model.asr_model == "linear":
            log_probs_len = real_length
            log_probs = asr_model(fea).log_softmax(dim=-1)
        else:
            log_probs, log_probs_len = asr_model(fea, real_length.squeeze().cpu())  # .tolist()

        return log_probs, log_probs_len  # , pred_tokens_batch, pred_words_batch

    def _apply_layer_position_encoding(self, layer_results):
        encoded_layer_results = []
        n_layers = len(layer_results)
        bs, seq_len, hid_dim = layer_results[0][0].shape
        embeddings = self.layer_position_encoding(torch.arange(n_layers, device=layer_results[0][0].device))
        for i in range(n_layers):
            expanded_embedding = embeddings[i].unsqueeze(0).unsqueeze(0).expand(bs, seq_len, hid_dim)
            encoded_layer_result = layer_results[i][0] + expanded_embedding
            encoded_layer_results.append([encoded_layer_result])
        return encoded_layer_results

    def forward(
        self,
        waveform: Tensor,
        padding_mask: Optional[Tensor] = None,
        labels: torch.Tensor = None,
        inference: bool = True,
    ):
        if inference:
            target_tokens_batch, target_words_batch = self.process_target_labels(labels)
        else:
            target_tokens_batch, target_words_batch = None, None
        # start_time = time.time()
        if self.freeze_upstream:
            with torch.no_grad():
                fea, layer_results, feature_extractor_output = self.wavlm(
                    waveform=waveform, padding_mask=padding_mask, ret_layer_results=True
                )
        else:
            fea, layer_results, feature_extractor_output = self.wavlm(
                waveform=waveform, padding_mask=padding_mask, ret_layer_results=True
            )
        # torch.cuda.synchronize()
        # end_time = time.time()
        # shuffle_time = end_time - start_time
        # print(f"WavLM forward: {shuffle_time}")
        layer_results = layer_results[1:]  # cut input to WavLM
        if self.layer_position_encoding:
            layer_results = self._apply_layer_position_encoding(layer_results)
        if self.elbo:
            # start_time = time.time()
            (
                log_probs,
                log_probs_len,
                layer_distribution_logits,
            ) = self._elbo(layer_results, feature_extractor_output)
            # torch.cuda.synchronize()
            # end_time = time.time()
            # shuffle_time = end_time - start_time
            # print(f"Total elbo forward time : {shuffle_time}")
        elif self.output_representation == "weighted_hiddens":
            (
                log_probs,
                log_probs_len,
                layer_distribution_logits,
            ) = self._weighted_hidden_representation(layer_results)
        elif self.output_representation == "mhfa":
            (
                log_probs,
                log_probs_len,
            ) = self._mhfa(layer_results)
            layer_distribution_logits = None
        elif self.weights is not None:
            layer_results = [layer_results[i][0] for i in range(len(layer_results))]
            layer_distribution_logits = None
            # ensemble case
            if len(self.weights.shape) > 1:
                log_probs, log_probs_len = self._weighted_sum_ensemble(layer_results)
            else:
                fea = self._weighted_sum(layer_results)
                log_probs, log_probs_len = self._process_single_hidden(fea, self.projector, self.asr_model)
        elif self.layer_index is not None:
            fea = layer_results[self.layer_index][0]
            log_probs, log_probs_len = self._process_single_hidden(fea, self.projector, self.asr_model)
            layer_distribution_logits = None
        else:
            log_probs, log_probs_len = self._process_single_hidden(fea, self.projector, self.asr_model)
            layer_distribution_logits = None
        # print(f"Vesper2: Pred words batch: {pred_words_batch}")
        # torch.cuda.synchronize()
        # end_time = time.time()
        # shuffle_time = end_time - start_time
        # print(f"ASR forward time: {shuffle_time}")
        if inference and self.cfg.model.output_rep not in ["elbo", "weighted_sum_ensemble"]:
            # start_time = time.time()
            pred_tokens_batch, pred_words_batch = self.decode(log_probs, log_probs_len, elbo=False)
            # torch.cuda.synchronize()
            # end_time = time.time()
            # shuffle_time = end_time - start_time
            # print(f"Total decoding time : {shuffle_time}")
        else:
            pred_words_batch, pred_tokens_batch = None, None
        return (
            log_probs if self.output_representation in ["elbo", "weighted_hiddens"] else [log_probs],
            log_probs_len if self.output_representation in ["elbo", "weighted_hiddens"] else [log_probs_len],
            pred_tokens_batch,
            pred_words_batch,
            target_tokens_batch,
            target_words_batch,
            layer_distribution_logits,
        )

    @torch.no_grad()
    def ith_layer_inference(
        self, waveform: Tensor, i: int, padding_mask: Optional[Tensor] = None, labels: torch.Tensor = None
    ):
        """
        Method used to inference model trained with weighted sum for the i'th layer inference
        """
        target_tokens_batch, target_words_batch = self.process_target_labels(labels)
        fea, layer_results, feature_extractor_output = self.wavlm(
            waveform=waveform, padding_mask=padding_mask, ret_layer_results=True
        )
        layer_results = layer_results[1:]  # cut input to WavLM
        fea = layer_results[i][0]
        log_probs, log_probs_len = self._process_single_hidden(fea, self.projector, self.asr_model)
        layer_distribution_logits = None
        pred_tokens_batch, pred_words_batch = self.decode(log_probs, log_probs_len, elbo=False)
        return (
            [log_probs],
            [log_probs_len],
            pred_tokens_batch,
            pred_words_batch,
            target_tokens_batch,
            target_words_batch,
            layer_distribution_logits,
        )

    @torch.no_grad()
    def decode(self, log_probs, input_lens, elbo=False):
        """Decoder that take log probabilities as input and outputs decoded seq"""
        decoded_words, decoded_tokens = [], []
        # start_time = time.time()
        if elbo:
            n_layers = len(log_probs)
            bs = log_probs[0].shape[0]
            log_probs = torch.cat(log_probs, dim=0)
            input_lens = torch.cat(input_lens, dim=0)

        # Move log_probs and input_lens to CPU only once
        log_probs_cpu = log_probs.float().contiguous().cpu()
        input_lens_cpu = input_lens.squeeze().cpu()
        # torch.cuda.synchronize()
        # end_time = time.time()
        # shuffle_time = end_time - start_time
        # print(f"Preparation in decode: {shuffle_time}")
        # start_time = time.time()
        decoded = self.decoder(log_probs_cpu, input_lens_cpu)
        # torch.cuda.synchronize()
        # end_time = time.time()
        # shuffle_time = end_time - start_time
        # print(f"Decoding itself: {shuffle_time}")
        if elbo:
            # start_time = time.time()
            for i in range(n_layers):
                current_layer_words = []
                current_layer_tokens = []
                for j in range(bs):
                    idx = i * bs + j
                    current_layer_words.append(decoded[idx][0][1])
                    current_layer_tokens.append(self.decoder.idxs_to_tokens(decoded[idx][0][0]))
                decoded_words.append(current_layer_words)
                decoded_tokens.append(current_layer_tokens)
            # torch.cuda.synchronize()
            # end_time = time.time()
            # shuffle_time = end_time - start_time
            # print(f"For loop: {shuffle_time}")
        else:
            decoded_words = [[decoded[i][0][1] for i in range(len(decoded))]]
            decoded_tokens = [[self.decoder.idxs_to_tokens(decoded[i][0][0]) for i in range(len(decoded))]]
        # print(f"Decoded tokens: {decoded_tokens}")
        # print(f"Len of decoded tokens: {len(decoded_tokens)}")
        # print(f"Len of decoded tokens[0]: {len(decoded_tokens[0])}")
        return decoded_tokens, decoded_words


class IdentityPredictionHead(nn.Module):
    """A linear layer that starts out as the identity mapping.

    The weight is initialised to the identity matrix and the bias to zero, so
    that ``forward(x) == x`` before any training step. Both remain ordinary
    trainable parameters.

    Inputs:
        x: (B, T, input_dim), ``Tensor``
    Outputs:
        x: (B, T, output_dim), ``Tensor``
    """

    def __init__(self, input_dim: int, output_dim: int):
        """Create the layer; ``input_dim`` and ``output_dim`` must be equal."""
        super().__init__()
        assert input_dim == output_dim, (
            f"Was going to use eye matrix for initialization, but found:" f" {input_dim} and {output_dim} "
        )
        self.eye_ffn = nn.Linear(input_dim, output_dim)
        with torch.no_grad():
            self.eye_ffn.weight.copy_(torch.eye(input_dim))
            # nn.Linear draws a random bias by default; without zeroing it the
            # layer would compute x + noise instead of the identity.
            self.eye_ffn.bias.zero_()

    def forward(self, x: Tensor):
        """Apply the linear layer to the last dimension of ``x``."""
        return self.eye_ffn(x)
