"""
A simplified Data-Efficient Image Transformer (DeiT) for ImageNet classification.

Reference:
[1] Touvron, Hugo, et al. "Training data-efficient image transformers & distillation through attention."
    ICML, 2021. 
Official code: https://github.com/facebookresearch/deit

NOTE: This minimal version does not include knowledge distillation or some advanced training tricks.
      It is structurally similar to a ViT, but with carefully chosen hyperparams.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------------
# Patch Embedding (same as standard ViT)
# --------------------
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    The embedding is a single strided convolution whose kernel size and
    stride both equal ``patch_size``.

    Args:
        img_size: Side length used to derive ``grid_size`` and ``num_patches``.
        patch_size: Side length of one square patch.
        in_chans: Number of input image channels.
        embed_dim: Number of output channels per patch token.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # Floor division: a partial patch at the border does not count.
        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (patches_per_side, patches_per_side)
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``(B, in_chans, H, W)`` images to ``(B, num_patches, embed_dim)`` tokens."""
        # conv -> (B, embed_dim, H/ps, W/ps); flatten the spatial grid into one
        # axis, then move it in front of the channel axis.
        return self.proj(x).flatten(2).transpose(1, 2)

# --------------------
# Multi-Head Self-Attention
# --------------------
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Channel width of the input tokens; also the output width.
        num_heads: Number of attention heads. Must divide ``dim`` evenly.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        # Fail at construction time: otherwise the mismatch only surfaces as an
        # opaque reshape error the first time forward() runs.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # 1 / sqrt(head_dim) scaling of the attention logits.
        self.scale = head_dim ** -0.5

        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; returns ``(B, N, C)``."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        # -> (3, B, heads, N, head_dim), so q/k/v split along the first axis
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # (B, heads, N, N) attention weights, normalised over the key axis
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge the heads back into C channels.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# --------------------
# MLP / Feed-Forward
# --------------------
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when not given.
        out_features: Width of the output; falls back to ``in_features``
            when not given.
        drop: Dropout probability; the same dropout module is applied after
            the activation and after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the feed-forward network over the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

# --------------------
# Transformer Block (Similar to ViT)
# --------------------
class Block(nn.Module):
    """Pre-norm Transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is applied as ``x + drop_path(sublayer(norm(x)))``.

    Args:
        dim: Channel width of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        attn_drop: Dropout probability forwarded to the attention module.
        proj_drop: Dropout probability forwarded to the attention module's
            output projection and to the MLP.
        drop_path: Stochastic-depth probability of dropping a residual
            branch for a sample during training. Must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``drop_path`` is outside ``[0, 1]``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0,
                 attn_drop=0.0, proj_drop=0.0, drop_path=0.0):
        super().__init__()
        # A probability outside [0, 1] would produce a meaningless keep mask.
        if not 0.0 <= drop_path <= 1.0:
            raise ValueError(f"drop_path must be in [0, 1], got {drop_path!r}")

        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        self.norm2 = nn.LayerNorm(dim)

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_dim, drop=proj_drop)
        self.drop_path_rate = drop_path

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        # drop_path() is itself a no-op in eval mode or when the rate is 0,
        # so no extra guard is needed here.
        # (1) Self-Attention
        x = x + self.drop_path(self.attn(self.norm1(x)), self.drop_path_rate)
        # (2) MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)), self.drop_path_rate)
        return x

    def drop_path(self, x, drop_prob):
        """Stochastic Depth: zero whole samples of ``x`` with probability ``drop_prob``.

        Surviving samples are rescaled by ``1 / (1 - drop_prob)``. Returns ``x``
        unchanged when ``drop_prob`` is 0 or the module is in eval mode.
        """
        if drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Dividing by keep_prob here would give
            # x / 0 * 0 = NaN, so return zeros while staying in the autograd graph.
            return x * 0.0
        # One random draw per sample, broadcast over all remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor.floor_()
        return x / keep_prob * random_tensor

# --------------------
# DeiT: Vision Transformer with slight modifications
# (e.g., a specialized training approach and augmentation).
# For simplicity, we keep it close to the standard ViT,
# but default to a smaller model (fewer parameters) than a typical ViT-Base.
# --------------------
class DeiT(nn.Module):
    """ViT-style image classifier with a single class token.

    The image is split into patch tokens, a learnable class token is
    prepended, learnable positional embeddings are added, and the sequence is
    passed through ``depth`` Transformer blocks. The normalised class token is
    fed to a linear head, and the resulting logits are divided by ``temp``.

    Args:
        img_size: Image side length used to size the positional embedding.
        patch_size: Side length of one square patch.
        in_chans: Number of input image channels.
        num_classes: Number of output classes.
        embed_dim: Token embedding width.
        depth: Number of Transformer blocks.
        num_heads: Number of attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        attn_drop: Dropout probability forwarded to every block's attention.
        proj_drop: Dropout probability forwarded to every block and also
            used after adding the positional embedding.
        drop_path: Stochastic-depth rate; the same value is given to every block.
        temp: Temperature the logits are divided by. Must be positive.

    Raises:
        ValueError: If ``temp`` is not strictly positive.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.0,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        temp=1.0
    ):
        # forward() divides the logits by temp, so a zero (or NaN / negative)
        # value would silently produce inf/NaN or sign-flipped logits.
        if not temp > 0:
            raise ValueError(f"temp must be a positive number, got {temp!r}")

        super().__init__()
        self.num_classes = num_classes
        self.temp = temp
        self.embed_dim = embed_dim
        self.num_tokens = 1  # class token

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Class token and positional embedding (one extra position for the class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=proj_drop)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                drop_path=drop_path
            )
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # weight init
        self._init_weights()

    def _init_weights(self):
        """Initialise the positional embedding, class token and classifier head.

        All other sub-modules keep the initialisation they were constructed with.
        """
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):
        """
        x shape: (B, 3, H, W)
        Return shape: (B, num_classes)

        The returned logits are already divided by ``self.temp``.
        """
        B = x.shape[0]
        # Patch Embedding
        x = self.patch_embed(x)

        # Concat class token (broadcast over the batch without copying)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings, sliced to the actual number of tokens
        x = x + self.pos_embed[:, : x.size(1), :]
        x = self.pos_drop(x)

        # Transformer Blocks
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        # Class token output
        cls_out = x[:, 0]
        logits = self.head(cls_out) / self.temp
        return logits

# --------------------
# Predefined Configs
# --------------------
def deit_tiny(temp=1.0, **kwargs):
    """
    DeiT-Tiny: embed_dim=192, depth=12, num_heads=3

    Args:
        temp: Temperature forwarded to ``DeiT``.
        **kwargs: Additional ``DeiT`` keyword arguments (e.g. ``num_classes``,
            ``drop_path``). A value given here takes precedence over the
            preset value of the same name.

    Returns:
        A newly constructed ``DeiT`` model.
    """
    config = dict(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4.0,
    )
    # Merge rather than pass the presets next to **kwargs: a caller overriding
    # e.g. img_size would otherwise hit "got multiple values for keyword argument".
    config.update(kwargs)
    return DeiT(temp=temp, **config)

def deit_small(temp=1.0, **kwargs):
    """
    DeiT-Small: embed_dim=384, depth=12, num_heads=6
    (The standard 'DeiT-S' config from the paper)

    Args:
        temp: Temperature forwarded to ``DeiT``.
        **kwargs: Additional ``DeiT`` keyword arguments (e.g. ``num_classes``,
            ``drop_path``). A value given here takes precedence over the
            preset value of the same name.

    Returns:
        A newly constructed ``DeiT`` model.
    """
    config = dict(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.0,
    )
    # Merge rather than pass the presets next to **kwargs: a caller overriding
    # e.g. img_size would otherwise hit "got multiple values for keyword argument".
    config.update(kwargs)
    return DeiT(temp=temp, **config)

def deit_base(temp=1.0, **kwargs):
    """
    DeiT-Base: embed_dim=768, depth=12, num_heads=12
    (Similar to ViT-Base shape)

    Args:
        temp: Temperature forwarded to ``DeiT``.
        **kwargs: Additional ``DeiT`` keyword arguments (e.g. ``num_classes``,
            ``drop_path``). A value given here takes precedence over the
            preset value of the same name.

    Returns:
        A newly constructed ``DeiT`` model.
    """
    config = dict(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
    )
    # Merge rather than pass the presets next to **kwargs: a caller overriding
    # e.g. img_size would otherwise hit "got multiple values for keyword argument".
    config.update(kwargs)
    return DeiT(temp=temp, **config)
