import math
import torch
import torch.nn as nn
from sonics.layers.embedding import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
)


class STTokenizer(nn.Module):
    """Spectro-temporal tokenizer.

    Builds two token sequences from an input ``x`` of shape
    ``(B, input_spec_dim, input_temp_dim)`` (i.e. ``(B, F, T)``) and
    concatenates them along the token axis:

    * temporal tokens: the time axis is cut into non-overlapping clips of
      ``t_clip`` frames; each token is a projection of all ``F`` bins of one
      clip;
    * spectral tokens: the frequency axis is cut into non-overlapping clips of
      ``f_clip`` bins; each token is a projection of all ``T`` frames of one
      clip.

    Args:
        input_spec_dim: Size ``F`` of the spectral (frequency) axis.
        input_temp_dim: Size ``T`` of the temporal axis.
        t_clip: Clip length (conv kernel size and stride) along time.
        f_clip: Clip length (conv kernel size and stride) along frequency.
        embed_dim: Dimension of every output token.
        pre_norm: Forwarded to both ``Tokenizer1D`` instances.
        pe_learnable: Forwarded to both ``Tokenizer1D`` instances.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        t_clip,
        f_clip,
        embed_dim,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.embed_dim = embed_dim
        self.pre_norm = pre_norm
        self.pe_learnable = pe_learnable

        # floor((1280 - 5) / 5 + 1) = 256
        self.num_temporal_tokens = self._num_clips(input_temp_dim, t_clip)
        # floor((128 - 3) / 3 + 1) = 42
        self.num_spectral_tokens = self._num_clips(input_spec_dim, f_clip)
        # 256 + 42 = 298
        self.num_tokens = self.num_temporal_tokens + self.num_spectral_tokens
        # For ViT, num_tokens = (1280 * 128)//(5 * 3) = 10922 :)

        # Temporal tokenizer: conv over the time axis, F bins as input channels.
        self.temporal_tokenizer = Tokenizer1D(
            input_spec_dim,
            embed_dim,
            clip_size=t_clip,
            num_clips=self.num_temporal_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )
        # Spectral tokenizer: conv over the frequency axis, T frames as channels.
        self.spectral_tokenizer = Tokenizer1D(
            input_temp_dim,
            embed_dim,
            clip_size=f_clip,
            num_clips=self.num_spectral_tokens,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

    @staticmethod
    def _num_clips(length, clip):
        """Return the Conv1d output length for kernel == stride == ``clip``.

        Same as PyTorch's ``L_out = floor((L_in + 2*p - d*(k - 1) - 1) / s + 1)``
        with ``p=0``, ``d=1`` and ``k=s=clip``. It uses exact integer
        arithmetic instead of float division.
        """
        return (length - clip) // clip + 1

    def forward(self, x):
        """Tokenize ``x`` of shape ``(B, F, T)``.

        Returns temporal tokens followed by spectral tokens, concatenated on
        dim 1. The result is expected to be ``(B, T/t + F/f, dim)``, assuming
        the positional encoders preserve shape.
        """
        # Temporal tokenization: channels = F, sequence = T.
        temporal_tokens = self.temporal_tokenizer(x)  # shape: (B, T/t, dim)

        # Spectral tokenization: swap axes so channels = T, sequence = F.
        spectral_input = x.permute(0, 2, 1)  # shape: (B, T, F)
        spectral_tokens = self.spectral_tokenizer(
            spectral_input
        )  # shape: (B, F/f, dim)

        return torch.cat(
            (temporal_tokens, spectral_tokens), dim=1
        )  # shape: (B, T/t + F/f, dim)


class Tokenizer1D(nn.Module):
    """Temporal/Spectral Tokenizer.

    Cuts the last axis of an input of shape ``(B, input_dim, L)`` into
    non-overlapping clips of ``clip_size`` samples. Each clip, across all
    ``input_dim`` channels, is projected to one ``token_dim`` token by a strided
    ``Conv1d``. A GELU is applied, then a positional encoding, then an optional
    LayerNorm.

    Whisper uses a temporal tokenizer, but its time clip size is very small and
    it uses stride=1, so complexity is very high. Here stride == clip_size
    (non-overlapping clips) to reduce complexity.

    Args:
        input_dim: Number of input channels (the axis that is NOT tokenized).
        token_dim: Dimension of each output token.
        clip_size: Kernel size and stride of the convolution.
        num_clips: Number of tokens; only passed to LearnedPositionalEncoding
            (used when ``pe_learnable`` is True).
        pre_norm: If True, the conv has no bias and a LayerNorm(eps=1e-6) is
            applied to the output tokens; otherwise the conv has a bias and no
            norm is applied.
        pe_learnable: If True use LearnedPositionalEncoding, otherwise
            SinusoidPositionalEncoding.
    """

    def __init__(
        self,
        input_dim,
        token_dim,
        clip_size,
        num_clips,
        pre_norm=False,
        pe_learnable=False,
    ):
        super().__init__()
        self.conv1d = nn.Conv1d(
            input_dim,
            token_dim,
            clip_size,
            stride=clip_size,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        self.act = nn.GELU()
        self.pos_encoder = (
            SinusoidPositionalEncoding(token_dim)
            if not pe_learnable
            else LearnedPositionalEncoding(token_dim, num_clips)
        )
        self.norm_pre = nn.LayerNorm(token_dim, eps=1e-6) if pre_norm else nn.Identity()

    def forward(self, x):
        """Tokenize ``x`` of shape ``(B, input_dim, L)``.

        Returns tokens laid out as ``(B, num_tokens, token_dim)``, where
        ``num_tokens = floor((L - clip_size) / clip_size + 1)``. This assumes
        the positional encoder preserves shape; it is defined in
        ``sonics.layers.embedding`` — TODO confirm.
        """
        x = self.conv1d(x)  # (B, C, L) -> (B, dim, L/clip)
        x = self.act(x)
        x = x.transpose(1, 2)  # (B, dim, N) -> (B, N, dim): one row per token
        x = self.pos_encoder(x)  # add position embeds
        x = self.norm_pre(x)
        return x
