import torch
import torch.nn as nn
from sonics.layers import (
    SinusoidPositionalEncoding,
    LearnedPositionalEncoding,
    Transformer,
)
from timm.layers import PatchEmbed


class ViT(nn.Module):
    """Vision Transformer encoder for single-channel 2-D inputs (e.g. mel spectrograms).

    The input is split into non-overlapping square patches by timm's
    ``PatchEmbed`` (one input channel). Positional information is added, followed
    by dropout, and the token sequence is encoded by a ``Transformer`` stack.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        num_heads,
        num_layers,
        pe_learnable=False,
        patch_norm=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        """Build the patch embedding, positional encoding and transformer stack.

        Args:
            image_size: ``(height, width)`` of the input, or a single int for a
                square input. Both dimensions must be divisible by ``patch_size``.
            patch_size: Side length (int) of the square, non-overlapping patches.
            embed_dim: Token embedding dimension.
            num_heads: Number of attention heads, passed to ``Transformer``.
            num_layers: Number of transformer layers, passed to ``Transformer``.
            pe_learnable: If True use ``LearnedPositionalEncoding``, otherwise
                ``SinusoidPositionalEncoding``.
            patch_norm: If True, ``PatchEmbed`` applies ``nn.LayerNorm`` to the
                patch embeddings.
            pos_drop_rate: Dropout probability applied after the positional encoding.
            attn_drop_rate: Attention dropout, passed to ``Transformer``.
            proj_drop_rate: Projection dropout, passed to ``Transformer``.
            mlp_ratio: MLP hidden-size ratio, passed to ``Transformer``.

        Raises:
            ValueError: If either image dimension is not divisible by ``patch_size``.
        """
        super().__init__()
        # Accept a bare int for square inputs (timm's PatchEmbed does too).
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        height, width = image_size[0], image_size[1]

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if height % patch_size != 0 or width % patch_size != 0:
            raise ValueError("Image dimensions must be divisible by patch size.")

        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pe_learnable = pe_learnable
        self.patch_norm = patch_norm
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        self.num_patches = (height // patch_size) * (width // patch_size)

        # Single input channel (the original ViT uses 3 RGB channels).
        self.patch_encoder = PatchEmbed(
            img_size=image_size,
            patch_size=patch_size,
            in_chans=1,
            embed_dim=embed_dim,
            norm_layer=nn.LayerNorm if patch_norm else None,
        )
        # A learned encoding needs to know the sequence length up front.
        self.pos_encoder = (
            LearnedPositionalEncoding(embed_dim, self.num_patches)
            if pe_learnable
            else SinusoidPositionalEncoding(embed_dim)
        )
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=self.attn_drop_rate,
            proj_drop=self.proj_drop_rate,
            mlp_ratio=self.mlp_ratio,
        )

    def forward(self, x):
        """Encode an input batch into a sequence of patch tokens.

        Args:
            x: Tensor of shape ``(B, H, W)`` or ``(B, 1, H, W)``. A 3-D input gets
                a singleton channel axis inserted at dim 1, because the patch
                embedding is built with ``in_chans=1`` and expects a 4-D tensor.

        Returns:
            The transformer output, ``B x num_patches x embed_dim``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)  # timm PatchEmbed expects a 4D tensor

        patches = self.patch_encoder(x)
        embeddings = self.pos_drop(self.pos_encoder(patches))
        return self.transformer(embeddings)  # B x num_patches x embed_dim


# Example input geometry, presumably for a quick smoke test of ViT (TODO: confirm
# how these values are consumed). Both spatial sizes are multiples of
# patch_size (128 = 8 * 16, 9216 = 576 * 16), which ViT requires.
batch_size, input_height = 1, 128
input_width = 4 * 6 * 384  # == 9216
patch_size = 16
