"""
Simplified Vision Transformer (ViT) for ImageNet classification.
Reference:
[1] Dosovitskiy, A., et al. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
    In ICLR, 2021.
[2] https://github.com/google-research/vision_transformer
[3] https://pytorch.org/vision/stable/models/vision_transformer.html
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """
    Turn an image batch into a sequence of patch embeddings.

    A single Conv2d whose kernel size equals its stride cuts the image into
    non-overlapping ``patch_size x patch_size`` tiles and linearly projects
    each tile to ``embed_dim`` channels.

    Args:
        img_size: side length of the (square) nominal input image.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: output embedding dimension per patch.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per side at the nominal resolution (integer division floors,
        # matching how the strided convolution discards leftover pixels).
        per_side = img_size // patch_size
        self.grid_size = (per_side, per_side)
        self.num_patches = per_side * per_side
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        x shape: (B, in_chans, H, W)
        returns shape: (B, num_patches, embed_dim), patches in row-major order
        """
        # (B, D, H/ps, W/ps) -> (B, D, num_patches) -> (B, num_patches, D)
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)


class Attention(nn.Module):
    """
    Multi-Head Self-Attention module.

    Projects tokens to queries, keys and values with one fused linear layer,
    applies scaled dot-product attention independently per head, then mixes
    the concatenated head outputs with a final linear projection.

    Args:
        dim: token embedding dimension (last axis of the input).
        num_heads: number of attention heads; must evenly divide ``dim``.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # on the first forward pass.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling of the query-key dot products.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        x shape: (B, N, dim)
        returns shape: (B, N, dim)
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)   # (3, B, num_heads, N, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]   # each shape: (B, num_heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, num_heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MLP(nn.Module):
    """
    Two-layer feed-forward network used inside each Transformer block.

    Layout: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: hidden-layer width; falls back to ``in_features``
            when falsy (``None`` or 0).
        out_features: output width; falls back to ``in_features`` when falsy.
        drop: dropout probability, shared by both dropout applications.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        # Registration order (fc1, act, fc2, drop) fixes state_dict key order.
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Block(nn.Module):
    """
    Single pre-norm transformer encoder block:
    1) LayerNorm
    2) Multi-Head Self-Attention
    3) Residual (branch optionally dropped by stochastic depth)
    4) LayerNorm
    5) MLP
    6) Residual (branch optionally dropped by stochastic depth)

    Args:
        dim: token embedding dimension.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the attention projection and
            inside the MLP.
        drop_path: stochastic-depth probability in [0, 1] applied to each
            residual branch during training.

    Raises:
        ValueError: if ``drop_path`` is outside [0, 1].
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, drop_path=0.0):
        super().__init__()
        if not 0.0 <= drop_path <= 1.0:
            raise ValueError(f"drop_path must be in [0, 1], got {drop_path}")
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        self.norm2 = nn.LayerNorm(dim)

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_dim, drop=proj_drop)

        # Stochastic depth (a.k.a. drop path), used in some ViT variants
        self.drop_path_rate = drop_path

    def forward(self, x):
        # drop_path() itself returns its input unchanged in eval mode or when
        # the rate is 0, so no extra guard is needed here.
        x = x + self.drop_path(self.attn(self.norm1(x)), self.drop_path_rate)
        x = x + self.drop_path(self.mlp(self.norm2(x)), self.drop_path_rate)
        return x

    def drop_path(self, x, drop_prob):
        """
        DropPath (stochastic depth): zero entire samples of ``x`` with
        probability ``drop_prob``.

        Surviving samples are scaled by ``1 / (1 - drop_prob)`` so the expected
        value is preserved. Returns ``x`` unchanged when ``drop_prob`` is 0 or
        the module is in eval mode, and all zeros when ``drop_prob >= 1``.
        """
        if drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped; rescaling by 1/keep_prob would produce
            # inf * 0 = NaN, so return the exact result directly.
            return torch.zeros_like(x)
        # One keep/drop decision per sample, broadcast over the other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        return x / keep_prob * random_tensor


class VisionTransformer(nn.Module):
    """
    Simplified Vision Transformer for ImageNet classification.
    - Patch embedding
    - Learnable class token
    - Position embeddings
    - Transformer encoder blocks
    - MLP Head

    Args:
        img_size: side length of the square native input resolution; fixes
            the size of the learned position-embedding table.
        patch_size: side length of each square patch.
        in_chans: number of input image channels.
        num_classes: number of output logits.
        embed_dim: token embedding dimension.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        attn_drop: dropout probability on attention weights.
        proj_drop: dropout probability on embedded tokens, after the attention
            projection and inside each MLP.
        drop_path: stochastic-depth rate; the same value is used in every
            block (there is no depth-dependent schedule).
        temp: logits are divided by this temperature; must be positive.

    Raises:
        ValueError: if ``temp`` is not positive.

    Inputs whose spatial size differs from ``img_size`` are supported: the
    patch part of the position embedding is bicubically resampled to the
    actual patch grid (at the native size it is used as-is).
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        temp=1.0
    ):
        super().__init__()
        # temp == 0 would give inf/NaN logits; a negative value flips them.
        if temp <= 0:
            raise ValueError(f"temp must be positive, got {temp}")
        self.num_classes = num_classes
        self.temp = temp
        self.embed_dim = embed_dim
        self.num_tokens = 1  # prefix tokens before the patches: just the class token
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding: one learned vector per token at the native grid
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=proj_drop)

        # Transformer blocks (uniform drop_path rate across depth)
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                drop_path=drop_path
            )
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes)

        # Init
        self._init_weights()

    def _init_weights(self):
        """
        Initialise the class token, position embedding and classifier head.

        All other sub-modules keep PyTorch's default initialisation.
        """
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def _resize_pos_embed(self, grid_h, grid_w):
        """
        Return position embeddings for a ``grid_h x grid_w`` patch grid.

        At the native grid this is ``self.pos_embed`` itself. Otherwise the
        patch positions are reshaped to their 2-D grid, bicubically resampled
        to the requested size and flattened back in row-major order (matching
        ``PatchEmbed``); the class-token slot is kept unchanged.

        returns shape: (1, num_tokens + grid_h * grid_w, embed_dim)
        """
        native_h, native_w = self.patch_embed.grid_size
        if (grid_h, grid_w) == (native_h, native_w):
            return self.pos_embed
        n = self.num_tokens
        prefix = self.pos_embed[:, :n]
        grid = self.pos_embed[:, n:].reshape(1, native_h, native_w, self.embed_dim)
        grid = grid.permute(0, 3, 1, 2)  # (1, D, h, w) layout expected by interpolate
        grid = F.interpolate(grid, size=(grid_h, grid_w), mode="bicubic", align_corners=False)
        grid = grid.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, self.embed_dim)
        return torch.cat((prefix, grid), dim=1)

    def forward(self, x):
        """
        x: shape (B, in_chans, H, W)
        returns: shape (B, num_classes), logits divided by ``temp``
        """
        B = x.shape[0]
        patch = self.patch_embed.patch_size
        # A Conv2d with kernel == stride == patch yields floor(H / patch) rows
        # and floor(W / patch) columns of patches.
        grid_h, grid_w = x.shape[-2] // patch, x.shape[-1] // patch

        # Embed patches
        x = self.patch_embed(x)  # (B, grid_h * grid_w, embed_dim)

        # Concat class token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, 1 + grid_h * grid_w, embed_dim)

        # Add positional embedding (resampled if the grid is not the native one)
        x = x + self._resize_pos_embed(grid_h, grid_w)
        x = self.pos_drop(x)

        # Pass through Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        # Final norm
        x = self.norm(x)

        # cls_token output
        cls_out = x[:, 0]
        logits = self.head(cls_out) / self.temp
        return logits


def vit_tiny_patch16_224(temp=1.0, **kwargs):
    """
    Build ViT-Tiny/16 for 224x224 inputs.

    Architecture: embed_dim=192, depth=12, num_heads=3, mlp_ratio=4.0.
    Extra keyword arguments are forwarded to ``VisionTransformer``; repeating
    one of the fixed architecture arguments raises ``TypeError``.
    """
    return VisionTransformer(
        temp=temp,
        img_size=224,
        patch_size=16,
        num_heads=3,
        depth=12,
        embed_dim=192,
        mlp_ratio=4.0,
        **kwargs
    )

def vit_small_patch16_224(temp=1.0, **kwargs):
    """
    Build ViT-Small/16 for 224x224 inputs.

    Architecture: embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0.
    Extra keyword arguments are forwarded to ``VisionTransformer``; repeating
    one of the fixed architecture arguments raises ``TypeError``.
    """
    return VisionTransformer(
        temp=temp,
        img_size=224,
        patch_size=16,
        num_heads=6,
        depth=12,
        embed_dim=384,
        mlp_ratio=4.0,
        **kwargs
    )

def vit_base_patch16_224(temp=1.0, **kwargs):
    """
    Build ViT-Base/16 for 224x224 inputs.

    Architecture: embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0.
    Extra keyword arguments are forwarded to ``VisionTransformer``; repeating
    one of the fixed architecture arguments raises ``TypeError``.
    """
    return VisionTransformer(
        temp=temp,
        img_size=224,
        patch_size=16,
        num_heads=12,
        depth=12,
        embed_dim=768,
        mlp_ratio=4.0,
        **kwargs
    )

def vit_large_patch16_224(temp=1.0, **kwargs):
    """
    ViT-Large: embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0,
    patch_size=16, img_size=224.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    # The docstring must be the first statement; a debug print used to precede
    # it, which left __doc__ as None and printed on every call.
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        temp=temp,
        **kwargs
    )
    return model
