import math
from typing import Any
import torch
import torch.nn as nn

from utils.torch_utilities import concat_non_padding, restore_from_concat, create_mask_from_length
from models.content_encoder.content_encoder import ContentEncoder


######################
# fastspeech modules
######################
class LayerNorm(nn.LayerNorm):
    """Layer normalization that can also normalize a non-last axis.

    Built on ``nn.LayerNorm`` with ``eps=1e-12``.

    :param int nout: size of the axis being normalized
    :param int dim: ``-1`` normalizes the last axis directly; any other value
        swaps axis 1 with the last axis, normalizes, and swaps back (so in
        practice it normalizes axis 1, e.g. channels of a ``[B, C, T]`` tensor)
    """

    def __init__(self, nout, dim=-1):
        """Create the norm over ``nout`` features along ``dim``."""
        super().__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Normalize ``x`` along the configured axis.

        :param torch.Tensor x: input tensor
        :return: tensor of the same shape, layer-normalized
        :rtype torch.Tensor
        """
        normalize = super().forward
        if self.dim != -1:
            # Bring axis 1 to the end, normalize there, then restore the layout.
            return normalize(x.transpose(1, -1)).transpose(1, -1)
        return normalize(x)


class DurationPredictor(nn.Module):
    """Convolutional per-frame predictor in the FastSpeech style.

    ``n_layers`` blocks of ``pad -> Conv1d -> ReLU -> LayerNorm (channels) ->
    Dropout`` are followed by a linear layer producing one value per frame.
    Features are re-masked after every block and at the output.

    :param in_channels: feature size of the input
    :param filter_channels: channel count inside every conv block
    :param n_layers: number of conv blocks
    :param kernel_size: temporal width of each convolution
    :param p_dropout: dropout probability at the end of each block
    :param padding: ``"SAME"`` pads ``(kernel_size - 1) // 2`` frames on each
        side; any other value pads ``kernel_size - 1`` frames on the left only
    """

    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        n_layers: int = 2,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        padding: str = "SAME"
    ):
        super().__init__()
        self.conv = nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding

        if padding == "SAME":
            half_width = (kernel_size - 1) // 2
            pad_amounts = (half_width, half_width)
        else:
            pad_amounts = (kernel_size - 1, 0)

        for layer_idx in range(n_layers):
            block_in = filter_channels if layer_idx else in_channels
            # Sub-module order is kept stable so state_dict keys stay
            # conv.{i}.1 (Conv1d) and conv.{i}.3 (LayerNorm).
            self.conv.append(
                nn.Sequential(
                    nn.ConstantPad1d(pad_amounts, 0),
                    nn.Conv1d(
                        block_in,
                        filter_channels,
                        kernel_size,
                        stride=1,
                        padding=0
                    ),
                    nn.ReLU(),
                    LayerNorm(filter_channels, dim=1),
                    nn.Dropout(p_dropout),
                )
            )
        self.linear = nn.Linear(filter_channels, 1)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
        """Predict one value per frame.

        :param x: features, ``[B, T, E]``
        :param x_mask: frame mask (nonzero = valid); assumed ``[B, T]`` since it
            is broadcast against ``[B, C, T]`` after ``unsqueeze(1)``
        :return: ``[B, T, 1]`` predictions, zeroed where the mask is zero
        """
        channel_mask = x_mask.unsqueeze(1).to(x.device).float()  # [B, 1, T]
        hidden = x.transpose(1, -1)  # [B, E, T]
        for block in self.conv:
            hidden = block(hidden) * channel_mask
        frame_mask = channel_mask.transpose(1, -1)  # [B, T, 1]
        return self.linear(hidden.transpose(1, -1)) * frame_mask


######################
# adapter modules
######################


class ContentAdapterBase(nn.Module):
    """Common parent of the content adapters in this module.

    It only records the output feature size so owners can query it.

    :param d_out: feature size of the adapted content this adapter emits
    """

    def __init__(self, d_out: int):
        super().__init__()
        self.d_out = d_out  # output feature dimension, exposed to owners


class SinusoidalPositionalEmbedding(nn.Module):
    """Add fixed sine/cosine positional encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``.

    :param d_model: feature size of the input (odd sizes are supported)
    :param dropout: dropout probability applied after the addition
    :param max_len: longest sequence length the table covers
    """

    def __init__(self, d_model, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # The buffer keeps its historical [max_len, 1, d_model] layout so that
        # checkpoints containing it still load; forward() re-orients it.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x``.

        :param x: batch-first features, ``[B, T, d_model]``
        :return: tensor of the same shape, with dropout applied
        :raises ValueError: if ``T`` exceeds ``max_len``
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        # [T, 1, d_model] -> [1, T, d_model] so it broadcasts over the batch
        # axis; adding the [T, 1, d_model] slice directly misaligns T with B.
        x = x + self.pe[:seq_len].transpose(0, 1)
        return self.dropout(x)


class ContentAdapter(ContentAdapterBase):
    """Self-attention adapter with a learned CLS token for duration prediction.

    A learned CLS embedding is prepended to the content sequence, and the
    result is encoded by a batch-first ``nn.TransformerEncoder``. Two heads
    read the encoder output:

    - ``duration_predictor``, which sees a gradient-rescaled copy;
    - a 1x1 ``Conv1d`` projection to ``d_out`` channels.

    :param d_model: feature size of the input content and of the encoder
    :param d_out: output feature size of the content projection
    :param num_layers: number of transformer encoder layers
    :param num_heads: attention heads per encoder layer
    :param duration_predictor: module called as ``duration_predictor(x, mask)``;
        the last axis of its output is squeezed
    :param dropout: dropout inside the encoder layers
    :param norm_first: pre-norm (True) or post-norm (False) encoder layers
    :param activation: feed-forward activation of the encoder layers
    :param duration_grad_scale: factor applied to the duration head's gradient
        before it reaches the encoder (0 blocks it entirely)
    """

    def __init__(
        self,
        d_model: int,
        d_out: int,
        num_layers: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        norm_first: bool = False,
        activation: str = "gelu",
        duration_grad_scale: float = 0.0,
    ):
        super().__init__(d_out)
        self.duration_grad_scale = duration_grad_scale
        self.cls_embed = nn.Parameter(torch.randn(d_model))
        # The nested-tensor fast path is disabled when an NPU backend is present.
        if hasattr(torch, "npu") and torch.npu.is_available():
            enable_nested_tensor = False
        else:
            enable_nested_tensor = True
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation=activation,
            norm_first=norm_first,
            batch_first=True
        )
        self.encoder_layers = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=enable_nested_tensor
        )
        self.duration_predictor = duration_predictor
        self.content_proj = nn.Conv1d(d_model, d_out, 1)

    def forward(self, x, x_mask):
        """Encode ``x`` and predict durations.

        :param x: content features, ``[B, T, d_model]`` (batch-first)
        :param x_mask: ``[B, T]`` mask, nonzero/True at valid positions
        :return: tuple ``(content, mask, cls_duration, token_durations)``:

            - ``content``: ``[B, T, d_out]``
            - ``mask``: ``[B, T]``, with the same dtype as ``x_mask``
            - ``cls_duration``: read at the CLS position
            - ``token_durations``: read at the content positions

            The duration shapes assume ``duration_predictor`` returns
            ``[B, T + 1, 1]``.
        """
        batch_size = x.size(0)
        cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
        cls_embed = cls_embed.to(x.device).unsqueeze(1)
        x = torch.cat([cls_embed, x], dim=1)

        # Build the CLS mask in x_mask's own dtype. A default float tensor here
        # would promote a bool mask to float (and the returned mask with it).
        cls_mask = torch.ones(
            batch_size, 1, dtype=x_mask.dtype, device=x_mask.device
        )
        x_mask = torch.cat([cls_mask, x_mask], dim=1)
        x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool())
        # Forward value is x; the gradient into x from the duration head is
        # scaled by duration_grad_scale.
        x_grad_rescaled = x * self.duration_grad_scale + x.detach(
        ) * (1 - self.duration_grad_scale)
        duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)
        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        # Drop the CLS slot from content/mask; its duration is returned separately.
        return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]


class PrefixAdapter(ContentAdapterBase):
    """Fuse content with an instruction prefix through a shared self-attention stack.

    Content is mapped to ``d_model`` by an MLP and gets a learned CLS token in
    front. The instruction is mapped by its own MLP. ``concat_non_padding``
    joins the two streams, and the joint sequence goes through a batch-first
    ``nn.TransformerEncoder``. ``restore_from_concat`` then recovers the
    (CLS + content) part. That part feeds ``duration_predictor`` (on a
    gradient-rescaled copy) and a 1x1 ``Conv1d`` projection to ``d_out``.

    No positional encoding is added to the joint sequence.

    :param content_dim: feature size of incoming content
    :param d_model: width of the transformer
    :param d_out: output feature size of the content projection
    :param prefix_dim: feature size of incoming instruction features
    :param num_layers: number of encoder layers
    :param num_heads: attention heads per layer
    :param duration_predictor: module called as ``duration_predictor(x, mask)``
    :param dropout: dropout in the MLPs and encoder layers
    :param norm_first: pre-norm (True) or post-norm (False) encoder layers
    :param use_last_norm: apply a final ``nn.LayerNorm`` after the encoder
    :param activation: encoder feed-forward activation
    :param duration_grad_scale: factor on the duration head's gradient into the
        shared features
    """

    def __init__(
        self,
        content_dim: int,
        d_model: int,
        d_out: int,
        prefix_dim: int,
        num_layers: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        norm_first: bool = False,
        use_last_norm: bool = True,
        activation: str = "gelu",
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.duration_grad_scale = duration_grad_scale
        # Creation order below is deliberate: it fixes parameter init order
        # and the order of submodules in state_dict / parameters().
        self.prefix_mlp = self._make_mlp(prefix_dim, d_model, dropout)
        self.content_mlp = self._make_mlp(content_dim, d_model, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=norm_first
        )
        # The nested-tensor fast path is disabled when an NPU backend is present.
        on_npu = hasattr(torch, "npu") and torch.npu.is_available()
        self.cls_embed = nn.Parameter(torch.randn(d_model))
        self.layers = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=not on_npu
        )
        self.use_last_norm = use_last_norm
        if use_last_norm:
            self.last_norm = nn.LayerNorm(d_model)
        self.duration_predictor = duration_predictor
        self.content_proj = nn.Conv1d(d_model, d_out, 1)
        nn.init.normal_(self.cls_embed, 0., 0.02)
        nn.init.xavier_uniform_(self.content_proj.weight)
        nn.init.constant_(self.content_proj.bias, 0.)

    @staticmethod
    def _make_mlp(in_dim, hidden_dim, dropout):
        """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, content, content_mask, instruction, instruction_mask):
        """Condition ``content`` on ``instruction`` and predict durations.

        :param content: ``[B, T, content_dim]`` content features
        :param content_mask: ``[B, T]`` mask of valid content frames
        :param instruction: instruction features with last axis ``prefix_dim``
        :param instruction_mask: mask of valid instruction tokens
        :return: tuple ``(content, mask, cls_duration, token_durations)``.
            Content and mask exclude the CLS slot. The durations are the
            predictor's outputs at the CLS position and at the content
            positions.
        """
        batch_size = content.size(0)
        cls_token = self.cls_embed.view(1, 1, -1).expand(batch_size, 1, -1)
        cls_token = cls_token.to(content.device)
        tokens = torch.cat([cls_token, self.content_mlp(content)], dim=1)
        cls_valid = torch.ones(
            batch_size, 1, dtype=torch.bool, device=content_mask.device
        )
        token_mask = torch.cat([cls_valid, content_mask], dim=1)

        prefix = self.prefix_mlp(instruction)
        joint, joint_mask, perm = concat_non_padding(
            prefix, instruction_mask, tokens, token_mask
        )
        hidden = self.layers(joint, src_key_padding_mask=~joint_mask.bool())
        if self.use_last_norm:
            hidden = self.last_norm(hidden)
        _, hidden = restore_from_concat(hidden, instruction_mask, token_mask, perm)

        # Forward value is hidden; the duration head's gradient into it is
        # scaled by duration_grad_scale.
        scale = self.duration_grad_scale
        duration_input = hidden * scale + hidden.detach() * (1 - scale)
        duration = self.duration_predictor(duration_input, token_mask).squeeze(-1)
        projected = self.content_proj(hidden.transpose(1, 2)).transpose(1, 2)
        return (
            projected[:, 1:],
            token_mask[:, 1:],
            duration[:, 0],
            duration[:, 1:],
        )


class CrossAttentionAdapter(ContentAdapterBase):
    """Condition content on a prefix with a single cross-attention layer.

    ``content`` attends to ``prefix``. The masked attention output is added
    back to ``content`` and layer-normalized. The result feeds three heads:

    - a masked-mean pooled ``global_duration_mlp``;
    - the per-frame ``duration_predictor``;
    - a 1x1 ``Conv1d`` projection to ``d_out``.

    Both duration heads read a gradient-rescaled copy of the features.

    :param d_out: output feature size of the content projection
    :param content_dim: feature size of ``content`` (attention embed dim)
    :param prefix_dim: feature size of ``prefix`` (attention key/value dim)
    :param num_heads: attention heads
    :param duration_predictor: module called as ``duration_predictor(x, mask)``
    :param dropout: attention / MLP dropout
    :param duration_grad_scale: factor on the duration heads' gradient into the
        shared features
    """

    def __init__(
        self,
        d_out: int,
        content_dim: int,
        prefix_dim: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.attn = nn.MultiheadAttention(
            embed_dim=content_dim,
            num_heads=num_heads,
            dropout=dropout,
            kdim=prefix_dim,
            vdim=prefix_dim,
            batch_first=True,
        )
        self.duration_grad_scale = duration_grad_scale
        self.duration_predictor = duration_predictor
        self.global_duration_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(content_dim, 1)
        )
        self.norm = nn.LayerNorm(content_dim)
        self.content_proj = nn.Conv1d(content_dim, d_out, 1)

    def forward(self, content, content_mask, prefix, prefix_mask):
        """Cross-attend ``content`` to ``prefix`` and predict durations.

        :param content: ``[B, Tc, content_dim]`` (batch-first)
        :param content_mask: ``[B, Tc]`` mask, nonzero/True for valid frames
            (assumed binary)
        :param prefix: ``[B, Tp, prefix_dim]``
        :param prefix_mask: ``[B, Tp]`` mask, True for valid prefix tokens
        :return: tuple ``(content, content_mask, global_duration,
            local_duration)``:

            - ``content``: ``[B, Tc, d_out]``
            - ``content_mask``: returned unchanged
            - ``global_duration``: ``[B]``
            - ``local_duration``: the predictor output with its last axis
              squeezed
        """
        # Attention weights are never used, so skip computing/averaging them.
        attn_output, _ = self.attn(
            query=content,
            key=prefix,
            value=prefix,
            key_padding_mask=~prefix_mask.bool(),
            need_weights=False,
        )
        frame_mask = content_mask.unsqueeze(-1).float()  # [B, Tc, 1]
        attn_output = attn_output * frame_mask
        x = self.norm(attn_output + content)
        # Forward value is x; the duration heads' gradient into x is scaled.
        x_grad_rescaled = x * self.duration_grad_scale + x.detach(
        ) * (1 - self.duration_grad_scale)
        # Clamp the valid-frame count so a row with no valid frames yields a
        # zero mean instead of 0/0 = NaN.
        num_valid = content_mask.sum(dim=1, keepdim=True).float().clamp(min=1.0)
        x_aggregated = (x_grad_rescaled * frame_mask).sum(dim=1) / num_valid
        global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
        local_duration = self.duration_predictor(
            x_grad_rescaled, content_mask
        ).squeeze(-1)
        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        return content, content_mask, global_duration, local_duration


class ExperimentalCrossAttentionAdapter(ContentAdapterBase):
    """Cross-attention adapter with MLP + LayerNorm pre-processing and an MLP head.

    Each stream passes through a two-layer MLP and a LayerNorm of its own
    width. Then ``content`` attends to ``prefix``, and the residual sum is
    normalized by ``norm1``. That result feeds three heads:

    - a masked-mean ``global_duration_mlp``;
    - the per-frame ``duration_predictor``;
    - an MLP ``content_proj`` followed by ``norm2``.

    Both duration heads read a gradient-rescaled copy of the features.

    :param d_out: output feature size of the content head
    :param content_dim: feature size of ``content``
    :param prefix_dim: feature size of ``prefix``
    :param num_heads: attention heads
    :param duration_predictor: module called as ``duration_predictor(x, mask)``
    :param dropout: dropout in the MLPs and attention
    :param duration_grad_scale: factor on the duration heads' gradient into the
        shared features
    """

    def __init__(
        self,
        d_out: int,
        content_dim: int,
        prefix_dim: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.content_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(content_dim, content_dim),
        )
        self.content_norm = nn.LayerNorm(content_dim)
        self.prefix_mlp = nn.Sequential(
            nn.Linear(prefix_dim, prefix_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(prefix_dim, prefix_dim),
        )
        # prefix_mlp emits prefix_dim features, so the norm must be sized to
        # prefix_dim. Sizing it by content_dim raised a shape error whenever
        # prefix_dim != content_dim; nothing changes when the two are equal.
        self.prefix_norm = nn.LayerNorm(prefix_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=content_dim,
            num_heads=num_heads,
            dropout=dropout,
            kdim=prefix_dim,
            vdim=prefix_dim,
            batch_first=True,
        )
        self.duration_grad_scale = duration_grad_scale
        self.duration_predictor = duration_predictor
        self.global_duration_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(content_dim, 1)
        )
        self.content_proj = nn.Sequential(
            nn.Linear(content_dim, d_out),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_out, d_out),
        )
        self.norm1 = nn.LayerNorm(content_dim)
        self.norm2 = nn.LayerNorm(d_out)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights and zero biases for every ``nn.Linear``.

        ``self.apply`` recurses into all registered submodules, so this also
        re-initializes any ``nn.Linear`` inside the injected
        ``duration_predictor``.
        """
        def _init_weights(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.)

        self.apply(_init_weights)

    def forward(self, content, content_mask, prefix, prefix_mask):
        """Cross-attend processed ``content`` to processed ``prefix``.

        :param content: ``[B, Tc, content_dim]`` (batch-first)
        :param content_mask: ``[B, Tc]`` mask, nonzero/True for valid frames
            (assumed binary)
        :param prefix: ``[B, Tp, prefix_dim]``
        :param prefix_mask: ``[B, Tp]`` mask, True for valid prefix tokens
        :return: tuple ``(content, content_mask, global_duration,
            local_duration)``:

            - ``content``: ``[B, Tc, d_out]``
            - ``content_mask``: returned unchanged
            - ``global_duration``: ``[B]``
            - ``local_duration``: the predictor output with its last axis
              squeezed
        """
        content = self.content_norm(self.content_mlp(content))
        prefix = self.prefix_norm(self.prefix_mlp(prefix))
        # Attention weights are never used, so skip computing/averaging them.
        attn_output, _ = self.attn(
            query=content,
            key=prefix,
            value=prefix,
            key_padding_mask=~prefix_mask.bool(),
            need_weights=False,
        )
        frame_mask = content_mask.unsqueeze(-1).float()  # [B, Tc, 1]
        attn_output = attn_output * frame_mask
        x = self.norm1(attn_output + content)
        # Forward value is x; the duration heads' gradient into x is scaled.
        x_grad_rescaled = x * self.duration_grad_scale + x.detach(
        ) * (1 - self.duration_grad_scale)
        # Clamp the valid-frame count so a row with no valid frames yields a
        # zero mean instead of 0/0 = NaN.
        num_valid = content_mask.sum(dim=1, keepdim=True).float().clamp(min=1.0)
        x_aggregated = (x_grad_rescaled * frame_mask).sum(dim=1) / num_valid
        global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
        local_duration = self.duration_predictor(
            x_grad_rescaled, content_mask
        ).squeeze(-1)
        content = self.norm2(self.content_proj(x))
        return content, content_mask, global_duration, local_duration


class ContentEncoderAdapterMixin:
    """Mixin that encodes content and optionally conditions it on an instruction.

    The host class gains ``content_encoder`` and an optional
    ``content_adapter``. This ``__init__`` does not call ``super().__init__()``.

    NOTE(review): when the host is an ``nn.Module`` and these arguments are
    modules, ``nn.Module.__init__`` presumably has to run before this one;
    confirm against the host classes.
    """

    def __init__(
        self,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase | None = None
    ):
        self.content_encoder = content_encoder
        self.content_adapter = content_adapter

    def encode_content(
        self,
        content: list[Any],
        task: list[str],
        device: str | torch.device,
        instruction: torch.Tensor | None = None,
        instruction_lengths: torch.Tensor | None = None
    ):
        """Encode ``content``; when ``instruction`` is given, adapt it as well.

        :param content: raw content items, forwarded to
            ``content_encoder.encode_content``
        :param task: task names, forwarded likewise
        :param device: device forwarded to the content encoder
        :param instruction: optional instruction features for
            ``content_adapter``; when ``None`` the encoder output is returned
            without adaptation
        :param instruction_lengths: valid lengths of ``instruction``, turned
            into a mask by ``create_mask_from_length``; required together with
            ``instruction``
        :return: dict with ``"content"``, ``"content_mask"`` and
            ``"length_aligned_content"``; when ``instruction`` is given, also
            ``"global_duration_pred"`` and ``"local_duration_pred"``
        :raises ValueError: if ``instruction`` is given without
            ``instruction_lengths`` or without a configured ``content_adapter``
        """
        # Validate before running the encoder, so misuse fails fast with a clear
        # message instead of a TypeError from a None call deep inside.
        if instruction is not None:
            if self.content_adapter is None:
                raise ValueError(
                    "`instruction` was given but no `content_adapter` is set"
                )
            if instruction_lengths is None:
                raise ValueError(
                    "`instruction_lengths` is required when `instruction` is given"
                )

        content_output: dict[
            str, torch.Tensor] = self.content_encoder.encode_content(
                content, task, device=device
            )
        content = content_output["content"]
        content_mask = content_output["content_mask"]

        if instruction is not None:
            instruction_mask = create_mask_from_length(instruction_lengths)
            (
                content,
                content_mask,
                global_duration_pred,
                local_duration_pred,
            ) = self.content_adapter(
                content, content_mask, instruction, instruction_mask
            )

        return_dict = {
            "content": content,
            "content_mask": content_mask,
            "length_aligned_content": content_output["length_aligned_content"],
        }
        if instruction is not None:
            return_dict["global_duration_pred"] = global_duration_pred
            return_dict["local_duration_pred"] = local_duration_pred

        return return_dict
