import torch
import torch.nn as nn
import math
from timm.models.vision_transformer import DropPath, Mlp, Attention as BaseAttn
import torch.nn.functional as F
from ptflops import get_model_complexity_info


class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    The embedding is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``, so each output spatial position corresponds to one patch.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Dimension of each patch embedding.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map images ``(B, in_chans, H, W)`` to patch tokens ``(B, N, embed_dim)``."""
        feature_map = self.proj(x)  # (B, embed_dim, H/patch, W/patch)
        tokens = feature_map.flatten(start_dim=2)  # (B, embed_dim, N)
        return tokens.transpose(1, 2)  # (B, N, embed_dim)


class Attention(BaseAttn):
    """timm ``Attention`` with an early exit exposing the raw attention logits.

    The projected output is also routed through ``self.identity`` (an
    ``nn.Identity``) so a forward hook can capture it for visualization or
    distillation without changing the computation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # No-op module that exists only as a forward-hook attachment point.
        self.identity = nn.Identity()

    def forward(self, x, return_latent=False):
        """Run multi-head self-attention.

        Args:
            x: Token tensor unpacked as ``(B, N, C)``; ``C`` is split evenly
                across ``self.num_heads`` heads.
            return_latent: If True, stop before the softmax and return the
                scaled attention logits ``(q @ k^T) * scale`` with shape
                ``(B, num_heads, N, N)``.

        Returns:
            The attention logits when ``return_latent`` is True, otherwise the
            output tokens with shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # index instead of unpacking: keeps TorchScript happy

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        if return_latent:
            # Hand back the pre-softmax logits; a bare ``return`` would yield None.
            return attn
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.identity(x)  # hook point for visualization / distillation
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; a falsy value falls back
            to ``in_features``.
        out_features: Output size; a falsy value falls back to ``in_features``.
        dropout: Probability for the single dropout module applied after both
            linear stages.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.1):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        # fc1 is created before fc2 so seeded parameter initialization draws
        # from the RNG in a stable order.
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward stack over the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class TransformerEncoderBlock(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP sub-layers, each with a residual.

    Args:
        embed_dim: Token embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout rate passed to attention (``attn_drop`` and
            ``proj_drop``) and to the MLP.
        drop_path: Rate for timm's ``DropPath`` (stochastic depth) applied to
            both residual branches. The default of 0 uses ``nn.Identity``,
            which matches the block's original behavior.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = Attention(embed_dim, qkv_bias=True, num_heads=num_heads,
                              attn_drop=dropout, proj_drop=dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout=dropout)

    def forward(self, x):
        """Apply ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``, each branch via ``drop_path``."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class ViT(nn.Module):
    """Vision Transformer classifier.

    Images are split into patch tokens by ``PatchEmbedding``. A learnable class
    token is prepended and learnable positional embeddings are added. The
    sequence then passes through ``depth`` ``TransformerEncoderBlock``s, a
    final ``LayerNorm``, a pooling step, and a linear head.

    Args:
        img_size: Input image side length. The positional embedding is sized
            for ``(img_size // patch_size) ** 2`` patches, so inputs must
            produce exactly that many.
        patch_size: Patch side length.
        in_chans: Number of input channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding dimension.
        depth: Number of encoder blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout rate after adding positional embeddings and inside
            every encoder block.
        global_pool: How the token sequence is reduced for the head.
            ``"avg"`` (default) averages the patch tokens and excludes the
            class token. ``"token"`` uses the class token's output.

    Raises:
        ValueError: If ``global_pool`` is not ``"avg"`` or ``"token"``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=200,
                 embed_dim=192, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.0,
                 global_pool="avg"):
        super().__init__()
        if global_pool not in ("avg", "token"):
            raise ValueError(f"global_pool must be 'avg' or 'token', got {global_pool!r}")
        self.global_pool = global_pool

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.n_patches

        # Learnable class token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learnable positional embedding for the class token plus every patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        self.blocks = nn.Sequential(*[
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialize the embeddings and head.

        Uses truncated-normal (std 0.02) for the embeddings and head weight,
        and zeros for the head bias. The patch embedding, encoder blocks and
        final norm keep their modules' own default initialization.
        """
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def forward(self, x):
        """Return logits ``(B, num_classes)`` for a batch of images ``(B, in_chans, H, W)``."""
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, D)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, D)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.blocks(x)
        x = self.norm(x)
        if self.global_pool == "token":
            pooled = x[:, 0]  # class-token output
        else:
            pooled = x[:, 1:].mean(dim=1)  # mean over patch tokens; class token excluded
        return self.head(pooled)


def get_vit_tiny(num_classes=200, pretrained=False, dropout=0.1, depth=12):
    """Build a ``ViT`` with its default sizes (``embed_dim=192``, ``num_heads=12``).

    Args:
        num_classes: Number of output classes.
        pretrained: Accepted for API compatibility only. No pretrained weights
            are loaded, and passing True emits a ``UserWarning``.
        dropout: Dropout rate forwarded to ``ViT``.
        depth: Number of encoder blocks.

    Returns:
        A randomly initialized ``ViT``.
    """
    if pretrained:
        import warnings

        warnings.warn(
            "get_vit_tiny: pretrained weights are not available; "
            "returning a randomly initialized model.",
            stacklevel=2,
        )
    return ViT(num_classes=num_classes, dropout=dropout, depth=depth)


def get_model_flops(model, img_size=224, device="cuda", in_chans=3):
    """Compute and print the MACs and parameter count of ``model`` with ptflops.

    Args:
        model: The ``nn.Module`` to profile.
        img_size: Side length of the square dummy input.
        device: Currently unused. The model is not moved anywhere; the
            argument is kept so existing call sites keep working.
        in_chans: Channel count of the dummy input (default 3).

    Returns:
        ``(macs, params)`` as human-readable strings, because ``as_strings=True``
        is passed to ptflops.
    """
    # ptflops takes the input resolution without the batch dimension.
    input_res = (in_chans, img_size, img_size)
    macs, params = get_model_complexity_info(
        model,
        input_res,
        as_strings=True,
        print_per_layer_stat=False,
        verbose=False,
    )
    print(f"[Model Stats] FLOPs (MACs): {macs}, Params: {params}")
    return macs, params


if __name__ == "__main__":
    # To measure FLOPs instead, uncomment:
    #     model = get_vit_tiny(num_classes=1000, dropout=0.1)
    #     get_model_flops(model, img_size=224, device="cpu")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Peak-memory statistics come from torch.cuda and raise on a CPU device.
    # On CPU the passes still run, but no memory figures are reported.
    measure_memory = device == "cuda"

    model = get_vit_tiny(depth=3, num_classes=1000, dropout=0.1).to(device)
    model.eval()  # forward-only measurement; no training

    # Synthetic input batch.
    B = 256  # batch size; adjust as needed
    img_size = 224
    x = torch.randn(B, 3, img_size, img_size, device=device)

    if measure_memory:
        torch.cuda.reset_peak_memory_stats(device)  # reset the peak counter
    with torch.no_grad():
        _ = model(x)
    if measure_memory:
        peak_memory = torch.cuda.max_memory_allocated(device) / (1024 ** 3)  # bytes -> GB
        print(f"Peak GPU memory (forward only, B={B}): {peak_memory:.2f} GB")
    else:
        print("CUDA not available: skipping peak GPU memory measurement.")

    # Training step: forward + backward + optimizer update.
    model.train()
    y = torch.randint(0, 1000, (B,), device=device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    if measure_memory:
        torch.cuda.reset_peak_memory_stats(device)
    optimizer.zero_grad()
    out = model(x)
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    if measure_memory:
        peak_memory_train = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
        print(f"Peak GPU memory (training, B={B}): {peak_memory_train:.2f} GB")