# -*- coding: utf-8 -*-

import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single ``nn.Conv2d`` whose kernel and stride both equal ``patch_size``
    acts as a shared linear projection applied to every patch. Its output
    grid is flattened (row-major) into a token sequence.

    Args:
        in_channels: Number of channels of the input image.
        patch_size: Side length of each square patch (also the conv stride).
        emb_size: Width of each output token.
        img_size: Side length of the square image, used only to compute
            ``n_patches``.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()
        self.patch_size = patch_size
        # Number of whole patches along one side; any remainder is dropped,
        # which matches what the strided convolution does.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, N, emb_size]`` patch tokens."""
        # [B, emb_size, H/P, W/P] -> [B, emb_size, N] -> [B, N, emb_size]
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)


class ViT(nn.Module):
    """Vision Transformer classifier that also returns its penultimate features.

    Pipeline: ``PatchEmbedding`` -> prepend a learnable class token -> add
    learnable positional embeddings -> dropout -> ``depth`` stacked
    ``nn.TransformerEncoderLayer``s -> take the class-token output ->
    ``LayerNorm -> Linear(emb_size, 512) -> ReLU`` -> ``Linear(512, num_classes)``.

    Args:
        img_size: Side length of the square input image that the positional
            embedding table is sized for.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
        emb_size: Token width (``d_model`` of the encoder).
        depth: Number of transformer encoder layers.
        n_heads: Number of attention heads per encoder layer.
        mlp_dim: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability used after adding the positional
            embedding and inside the encoder layers.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 emb_size=768, depth=12, n_heads=12, mlp_dim=3072, dropout=0.1):
        super(ViT, self).__init__()
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        # NOTE(review): cls_token and pos_embed are drawn from a standard normal
        # (std 1). Common ViT implementations use a much smaller scale (e.g. a
        # truncated normal with std 0.02). Kept as-is so initialisation is unchanged.
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One positional vector per patch, plus one for the class token.
        self.pos_embed = nn.Parameter(torch.randn(1, 1 + self.patch_embed.n_patches, emb_size))
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_size,
            nhead=n_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, 512),
            nn.ReLU(inplace=True)
        )

        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch, presumably ``[B, in_channels, H, W]``. It must not
                produce more than ``(img_size // patch_size) ** 2`` patches.

        Returns:
            ``(logits, embedding)``. ``logits`` is ``[B, num_classes]``.
            ``embedding`` is the ``[B, 512]`` output of ``mlp_head``, detached
            from the autograd graph.

        Raises:
            ValueError: If the input yields more tokens than the positional
                embedding table has positions.
        """
        batch_size = x.size(0)
        x = self.patch_embed(x)  # [B, N, emb_size]

        n_tokens = x.size(1) + 1  # patches + class token
        n_positions = self.pos_embed.size(1)
        if n_tokens > n_positions:
            # Without this check the addition below fails with an opaque
            # tensor-size-mismatch error.
            raise ValueError(
                f"input produces {n_tokens} tokens (class token included) but the "
                f"positional embedding only has {n_positions} positions; the image "
                f"is larger than the img_size this model was built for"
            )

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, emb_size]
        x = torch.cat((cls_tokens, x), dim=1)  # [B, 1+N, emb_size]
        # NOTE(review): a smaller input reuses the *first* n_tokens positions of the
        # row-major patch grid. These do not line up spatially with the patches
        # when the grid width differs; interpolating pos_embed is the usual remedy.
        x = x + self.pos_embed[:, :n_tokens]
        x = self.dropout(x)

        x = self.encoder(x)  # [B, 1+N, emb_size]
        cls_output = x[:, 0]  # class-token output, [B, emb_size]

        x = self.mlp_head(cls_output)  # [B, 512]
        # Shares storage with x but carries no gradient, so callers can use the
        # features without back-propagating into the network.
        embedding = x.detach()
        x = self.linear(x)  # [B, num_classes]

        return x, embedding


def vit(num_classes):
    """Build a ``ViT`` with default hyper-parameters and ``num_classes`` output logits."""
    model = ViT(num_classes=num_classes)
    return model


if __name__ == '__main__':
    # Smoke test. The input size must match the img_size the model was built
    # for. With img_size=32 and the default patch_size=4, the positional
    # embedding covers 1 + 8 * 8 = 65 tokens. A 224x224 image would produce
    # 1 + 56 * 56 = 3137 tokens and fail.
    img_size = 32
    net = ViT(img_size=img_size, num_classes=11)
    net.eval()  # disable dropout for a deterministic forward pass
    dummy_input = torch.randn(2, 3, img_size, img_size)
    with torch.no_grad():
        output, emb = net(dummy_input)
    print(net)
    print("Output shape:", output.shape)
    print("Embedding shape:", emb.shape)
