

import math
import logging
from typing import List, Tuple

import torch
import torch.nn as nn
from fairseq.modules import FairseqDropout
from fairseq import utils

from ..components import Embedding, trunc_normal_, LayerNorm, Linear

logger = logging.getLogger(__name__)


def make_token_bucket_position(bucket_size, max_position):
    context_pos = torch.arange(max_position, dtype=torch.long)[:, None]
    memory_pos = torch.arange(max_position, dtype=torch.long)[None, :]
    relative_pos = context_pos - memory_pos
    assert bucket_size < max_position
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid-1, torch.abs(relative_pos))
    log_pos = mid + torch.ceil(
        torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid-1)
    ).long()
    bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos*sign).long()
    return bucket_pos + bucket_size - 1


class AudioAdapter(torch.nn.Module):
    def __init__(self, cfg, embed_dim, attention_heads, num_layers=None):
        super().__init__()
        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=self.__class__.__name__
        )
        self.alpha = cfg.shrink_alpha

        if cfg.feature_encoder_spec is not None:
            feature_encoder_spec = eval(cfg.feature_encoder_spec)
            feature_embed_dim = feature_encoder_spec[-1][0]
            self.embed_audios = torch.nn.Sequential(
                ConvFeatureExtractionModel(
                    conv_layers=feature_encoder_spec,
                    dropout=0.0,
                    conv_bias=cfg.conv_bias
                ),
                TransposeLast(),
                LayerNorm(feature_embed_dim),
                Linear(feature_embed_dim, embed_dim),
            )

        if cfg.abs_pos_type == 'conv':
            num_pos_layers = cfg.conv_pos_depth
            k = max(3, cfg.conv_pos_width // num_pos_layers)
            self.embed_positions = nn.Sequential(
                TransposeLast(),
                *[
                    nn.Sequential(
                        nn.Conv1d(
                            embed_dim,
                            embed_dim,
                            kernel_size=k,
                            padding=k // 2,
                            groups=cfg.conv_pos_groups,
                        ),
                        SamePad(k),
                        TransposeLast(),
                        nn.LayerNorm(embed_dim, elementwise_affine=False),
                        TransposeLast(),
                        nn.GELU(),
                    )
                    for _ in range(num_pos_layers)
                ],
                TransposeLast(),
            )
            if cfg.conv_pos_pre_ln:
                self.embed_positions = nn.Sequential(LayerNorm(embed_dim), self.embed_positions)
            self.cls_pos_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
            trunc_normal_(self.cls_pos_embed)
        elif cfg.abs_pos_type == 'fixed':
            self.embed_positions = Embedding(1024 + 2, embed_dim)
        else:
            raise NotImplementedError

        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

        self.cls_embedding = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
        if cfg.add_type_embedding:
            self.type_embedding = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.type_embedding = None

        if cfg.use_attn_bias:
            num_rel_dis = 2 * cfg.bucket_size - 1
            rp_bucket = make_token_bucket_position(cfg.bucket_size, max_position=1024)
            rp_bucket[0, :] = num_rel_dis
            rp_bucket[:, 0] = num_rel_dis + 1
            rp_bucket[0, 0] = num_rel_dis + 2
            self.register_buffer("rp_bucket", rp_bucket)
            self.rel_pos_table_list = torch.nn.ModuleList(
                [
                    Embedding(num_rel_dis + 3, attention_heads, zero_init=True)
                    for _ in range(num_layers if num_layers is not None else 1)
                ]
            )
        else:
            self.rel_pos_table_list = None

        self.mask_embedding = torch.nn.Parameter(torch.zeros(1, embed_dim))

        trunc_normal_(self.cls_embedding)
        trunc_normal_(self.mask_embedding)
        if cfg.abs_pos_type == 'fixed':
            trunc_normal_(self.embed_positions.weight)

    def get_rel_pos_bias(self, bsz, seq_len):
        rel_pos_bias_list = []
        for rel_pos_table in self.rel_pos_table_list:
            rp_bucket = self.rp_bucket[:seq_len, :seq_len]
            values = rel_pos_table(rp_bucket).unsqueeze(0).expand(bsz, -1, -1, -1)
            values = values.permute(0, 3, 1, 2)
            rel_pos_bias_list.append(values)
        return rel_pos_bias_list

    def gather_features(self, adapter_embedding, self_attn_bias_list, position_ids):
        seq_len, embed_dim = adapter_embedding.shape[-2:]
        adapter_embedding = adapter_embedding.gather(1, position_ids[:, 1:, None].expand(-1, -1, embed_dim) - 1)

        gather_seq_len = position_ids.size(1)
        if self_attn_bias_list is not None:
            new_self_attn_bias_list = []
            for self_attn_bias in self_attn_bias_list:
                self_attn_bias = self_attn_bias.gather(
                    2, position_ids[:, None, :, None].expand(-1, self_attn_bias.size(1), -1, seq_len+1)
                ).gather(3, position_ids[:, None, None, :].expand(-1, self_attn_bias.size(1), gather_seq_len, -1))
                new_self_attn_bias_list.append(self_attn_bias)
        else:
            new_self_attn_bias_list = None

        return adapter_embedding, new_self_attn_bias_list

    def forward(
        self,
        src_audios,
        padding_mask,
        preserve_ids=None,
        preserve_embed=None,
        mask_token=None
    ):
        """
        Args:
            src_audios List of audios
        Returns:
                - **x** (Tensor): the processed embedding of
                  shape `(batch, src_len, embed_dim)`
                - **padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **self_attn_bias_list** (Tensor): list of attention bias in self attention
                 of shape `(bsz, num_attention_heads, src_len, src_len)`.
        """

        bsz = padding_mask.size(0)
        position_ids = utils.new_arange(padding_mask)
        if self.rel_pos_table_list is not None:
            self_attn_bias_list = self.get_rel_pos_bias(bsz, position_ids.size(1))
        else:
            self_attn_bias_list = None

        if preserve_embed is not None:
            pos_embed = self.embed_positions(position_ids)
            seq_len, embed_dim = pos_embed.size(1), pos_embed.size(2)
            adapter_embedding = mask_token.repeat(bsz * seq_len, 1)
            right_preserve_indices = torch.nonzero(preserve_ids.ne(-1).flatten(), as_tuple=False).flatten()
            left_preserve_indices = preserve_ids + (torch.arange(bsz) * seq_len).unsqueeze(1).type_as(preserve_ids)
            left_preserve_indices = left_preserve_indices.view(-1)[right_preserve_indices]
            adapter_embedding[left_preserve_indices] = preserve_embed.reshape(-1, embed_dim)[right_preserve_indices]
            adapter_embedding = adapter_embedding.reshape(bsz, seq_len, embed_dim)
        else:
            adapter_embedding = self.embed_audios(src_audios)
            if preserve_ids is not None:
                padding_mask = preserve_ids.eq(-1)
                position_ids = preserve_ids.masked_fill(preserve_ids.eq(-1), preserve_ids.size(1) - 1)
                adapter_embedding, self_attn_bias_list = self.gather_features(
                    adapter_embedding, self_attn_bias_list, position_ids
                )
            pos_embed = self.embed_positions(adapter_embedding)
            pos_embed = torch.cat([self.cls_pos_embed.expand(bsz, -1, -1), pos_embed], dim=1)
            cls_embedding = self.cls_embedding.expand(bsz, -1, -1)
            adapter_embedding = torch.cat([cls_embedding, adapter_embedding], dim=1)

            if self.layernorm_embedding is not None:
                adapter_embedding = self.layernorm_embedding(adapter_embedding)
            if self.alpha != 1.0:
                adapter_embedding = adapter_embedding * self.alpha + adapter_embedding.detach() * (1 - self.alpha)

        x = adapter_embedding + pos_embed

        if self.type_embedding is not None:
            x += self.type_embedding.expand_as(x)
        x = self.dropout_module(x)

        return x, padding_mask, self_attn_bias_list

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""

        if prefix + 'rel_pos_table.weight' in state_dict:
            rel_pos_table_weight = state_dict[prefix + 'rel_pos_table.weight']
            state_dict[prefix + 'rel_pos_table_list.0.weight'] = rel_pos_table_weight
            del state_dict[prefix + 'rel_pos_table.weight']
        if self.rel_pos_table_list is not None and len(self.rel_pos_table_list) > 1 \
                and prefix + 'rel_pos_table_list.1.weight' not in state_dict:
            logger.info('copy rel_pos_weight to each layer')
            rel_pos_table_weight = state_dict[prefix + 'rel_pos_table_list.0.weight']
            for i in range(len(self.rel_pos_table_list)):
                state_dict[prefix + 'rel_pos_table_list.{}.weight'.format(i)] = rel_pos_table_weight.clone()

        for param_name, param_tensor in self.state_dict().items():
            if (prefix + param_name) not in state_dict:
                logger.info('{} not exists, re-initialized'.format(prefix + param_name))
                state_dict[prefix + param_name] = self.state_dict()[param_name]

        return state_dict


class TransposeLast(nn.Module):
    def __init__(self, tranpose_dim=-2):
        super().__init__()
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        return x.transpose(self.tranpose_dim, -1)


class SamePad(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        conv_bias: bool = False
    ):
        super().__init__()

        def block(
            n_in,
            n_out,
            k,
            stride,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            return nn.Sequential(
                make_conv(),
                nn.Dropout(p=dropout),
                nn.Sequential(
                    TransposeLast(),
                    LayerNorm(dim),
                    TransposeLast(),
                ),
                nn.GELU(),
            )

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):

        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
