import torch
import torch.nn as nn
import math


class PatchEmbedding(nn.Module):
    """Project an image into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every non-overlapping patch to one ``embed_dim`` vector.

    Attributes:
        img_size: Image side length passed to the constructor.
        patch_size: Patch side length passed to the constructor.
        n_patches: ``(img_size // patch_size) ** 2``.
        proj: The patch-projection convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens of shape ``[B, N, D]`` for an image batch."""
        feature_map = self.proj(x)
        # Collapse the two spatial axes into one token axis, then move the
        # channel axis last so each token is a D-dimensional vector.
        tokens = torch.flatten(feature_map, start_dim=2)
        return tokens.transpose(1, 2)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are separate linear projections of the same
    input. Dropout is applied to the attention probabilities, and a final
    linear layer mixes the concatenated head outputs.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim should be fully divided by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Registered in the order q, k, v, out so that parameter creation
        # (and therefore seeded initialisation) follows a fixed sequence.
        for layer_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(self, layer_name, nn.Linear(embed_dim, embed_dim))

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape ``[B, N, C]`` into ``[B, heads, N, head_dim]``."""
        batch, seq_len, _ = t.shape
        per_head = t.view(batch, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, N, C]``; shape is preserved."""
        batch, seq_len, channels = x.shape

        queries = self._split_heads(self.q_proj(x))
        keys = self._split_heads(self.k_proj(x))
        values = self._split_heads(self.v_proj(x))

        # Pairwise similarity between tokens, scaled by sqrt(head_dim).
        scale = math.sqrt(self.head_dim)
        scores = (queries @ keys.transpose(-2, -1)) / scale  # [B, heads, N, N]
        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values, then merge the heads back into C channels.
        context = (weights @ values).transpose(1, 2).contiguous()
        merged = context.view(batch, seq_len, channels)
        return self.out_proj(merged)


class MultiBranchMHSA(nn.Module):
    """Multi-branch MHSA layer.

    The input is layer-normalised once, fed to ``num_branches`` independent
    ``MultiHeadAttention`` modules, and the branch outputs are averaged.
    No residual connection is added here.

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of heads in every attention branch.
        num_branches: Number of parallel attention branches; must be >= 1.
        dropout: Dropout probability forwarded to every branch.

    Raises:
        ValueError: If ``num_branches`` is smaller than 1.
    """

    def __init__(self, embed_dim, num_heads=12, num_branches=3, dropout=0.1):
        super().__init__()
        # Fail at construction time: with zero branches forward() would only
        # surface the problem later as an unrelated ZeroDivisionError.
        if num_branches < 1:
            raise ValueError(f"num_branches must be >= 1, got {num_branches}")
        self.branches = nn.ModuleList([
            MultiHeadAttention(embed_dim, num_heads, dropout)
            for _ in range(num_branches)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return the mean of all branch outputs for the normalised input."""
        x_ = self.norm(x)
        # Accumulate from 0 in branch order, matching the built-in sum().
        total = 0
        for attn in self.branches:
            total = total + attn(x_)
        return total / len(self.branches)


class MultiBranchMLP(nn.Module):
    """Average of several independent feed-forward branches.

    Every branch is ``LayerNorm -> Linear -> GELU -> Dropout -> Linear ->
    Dropout`` with a hidden width of ``int(embed_dim * mlp_ratio)``. Each
    branch owns its LayerNorm. No residual connection is added here.

    Args:
        embed_dim: Token embedding size (input and output width).
        mlp_ratio: Hidden width as a multiple of ``embed_dim``.
        num_branches: Number of parallel MLP branches; must be >= 1.
        dropout: Dropout probability used inside every branch.

    Raises:
        ValueError: If ``num_branches`` is smaller than 1.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, num_branches=3, dropout=0.1):
        super().__init__()
        # Fail at construction time: with zero branches forward() would only
        # surface the problem later as an unrelated ZeroDivisionError.
        if num_branches < 1:
            raise ValueError(f"num_branches must be >= 1, got {num_branches}")
        hidden_dim = int(embed_dim * mlp_ratio)
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, embed_dim),
                nn.Dropout(dropout),
            )
            for _ in range(num_branches)
        ])

    def forward(self, x):
        """Return the mean of all branch outputs for ``x``."""
        # Accumulate from 0 in branch order, matching the built-in sum().
        total = 0
        for mlp in self.branches:
            total = total + mlp(x)
        return total / len(self.branches)


class ParallelTransformerBlock(nn.Module):
    """Transformer block built from multi-branch attention and MLP sub-layers.

    Each sub-layer's output is added back to its input (residual
    connection); attention runs first, the MLP second.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0,
                 dropout=0.1, attn_branches=3, mlp_branches=3):
        super().__init__()
        self.attn = MultiBranchMHSA(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_branches=attn_branches,
            dropout=dropout,
        )
        self.mlp = MultiBranchMLP(
            embed_dim=embed_dim,
            mlp_ratio=mlp_ratio,
            num_branches=mlp_branches,
            dropout=dropout,
        )

    def forward(self, x):
        """Apply the attention residual, then the MLP residual."""
        after_attn = x + self.attn(x)
        return after_attn + self.mlp(after_attn)


class multibranchViTTiny(nn.Module):
    """Small Vision Transformer built from multi-branch transformer blocks.

    Pipeline: patch embedding -> prepend class token -> add positional
    embedding -> dropout -> ``depth`` transformer blocks -> LayerNorm ->
    mean over the patch tokens -> linear classification head.

    Args:
        img_size: Input image side length used to size the positional embedding.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        num_classes: Output size of the classification head.
        embed_dim: Token embedding size.
        depth: Number of stacked transformer blocks.
        num_heads: Attention heads per attention branch.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability used throughout the model.
        attn_branches: Parallel attention branches per block.
        mlp_branches: Parallel MLP branches per block.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=200,
                 embed_dim=192, depth=6, num_heads=12, mlp_ratio=4.0, dropout=0.1,
                 attn_branches=2, mlp_branches=2):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.n_patches

        # One extra position for the class token prepended in forward().
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        self.blocks = nn.Sequential(*[
            ParallelTransformerBlock(embed_dim, num_heads, mlp_ratio, dropout,
                                     attn_branches=attn_branches, mlp_branches=mlp_branches)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; intended for use with ``nn.Module.apply``.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm layers get unit weight and zero bias. Other module types
        are left untouched.
        """
        # Only the visited module is inspected. The classification head is an
        # nn.Linear, so its bias is zeroed when apply() reaches it; nothing
        # here needs to reach back into self.head.
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for an image batch."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.blocks(x)
        x = self.norm(x)
        # Pool over the patch tokens only; index 0 (the class token) is
        # excluded from the average.
        patch_tokens = x[:, 1:]
        return self.head(patch_tokens.mean(dim=1))


def get_multibranch_vit(num_classes=200, attn_branches=3, mlp_branches=3, dropout=0.1):
    """Build a ``multibranchViTTiny`` with the given head size and branch counts.

    All other constructor arguments keep the defaults of ``multibranchViTTiny``.
    """
    overrides = {
        "num_classes": num_classes,
        "attn_branches": attn_branches,
        "mlp_branches": mlp_branches,
        "dropout": dropout,
    }
    return multibranchViTTiny(**overrides)
