import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

class MHSA(nn.Module):
    """Multi-head self-attention over a 2D feature map with learned position terms.

    Queries, keys and values are produced by 1x1 convolutions. The attention
    logits are the sum of a content-content term (``q^T k``) and a
    content-position term built from two learned parameters, ``rel_h`` and
    ``rel_w``, that are broadcast-added into one position table.

    In training mode the queries of the two halves of the batch are exchanged
    before the logits are computed; keys and values are left untouched.
    NOTE(review): the purpose of this training-only exchange is not documented
    in this file -- confirm with the author.

    Args:
        n_dims: number of input (and output) channels; must be divisible by
            ``heads``.
        width: size of the first spatial axis the position table is built for.
        height: size of the second spatial axis the position table is built for.
        heads: number of attention heads.

    Raises:
        ValueError: if ``n_dims`` is not divisible by ``heads``.
    """

    def __init__(self, n_dims, width=14, height=14, heads=4):
        super().__init__()
        if n_dims % heads != 0:
            # Fail early: the per-head reshape in forward() cannot work otherwise.
            raise ValueError(
                f"n_dims ({n_dims}) must be divisible by heads ({heads})"
            )
        self.heads = heads

        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        # The two tables broadcast against each other to
        # (1, heads, n_dims // heads, width, height) when summed in forward().
        self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: 4-D tensor unpacked as ``(n_batch, C, width, height)``. The
                spatial sizes must match those given to the constructor,
                otherwise the position term cannot be added to the content term.

        Returns:
            Tensor with the same shape as ``x``.
        """
        n_batch, C, width, height = x.size()
        head_dim = C // self.heads

        # (n_batch, heads, head_dim, width * height); reshape (rather than view)
        # also accepts conv outputs that are not contiguous.
        q = self.query(x).reshape(n_batch, self.heads, head_dim, -1)
        k = self.key(x).reshape(n_batch, self.heads, head_dim, -1)
        v = self.value(x).reshape(n_batch, self.heads, head_dim, -1)

        if self.training:
            n_half = n_batch // 2
            # Exchange the first and last ``n_half`` query blocks. This is done
            # out-of-place: a tuple assignment on slices of ``q`` operates on
            # views, so the first write would clobber the data the second write
            # reads and both halves would end up identical. With an odd batch
            # the middle sample keeps its own queries.
            q = torch.cat(
                (q[n_batch - n_half:], q[n_half:n_batch - n_half], q[:n_half]),
                dim=0,
            )

        # (n_batch, heads, positions, positions)
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)

        # (1, heads, positions, head_dim) @ (n_batch, heads, head_dim, positions)
        content_position = (self.rel_h + self.rel_w).view(1, self.heads, head_dim, -1).permute(0, 1, 3, 2)
        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.matmul(v, attention.permute(0, 1, 3, 2))
        out = out.view(n_batch, C, width, height)

        return out

class TransformerEncoder(nn.Sequential):
    """Attention block followed by a small classification head.

    The pipeline is: ``MHSA`` -> 2x2 max-pool (stride 2) -> batch-norm -> ReLU
    -> flatten -> ``fc1`` -> ``fc2``. Although the class derives from
    ``nn.Sequential``, ``forward`` is overridden and calls the sub-modules
    explicitly. No activation is applied between ``fc1`` and ``fc2``.

    Args:
        in_channels: channel count of the incoming feature map.
        emb_size: stored on the instance; not otherwise used by this class.
        patch_size: passed to ``MHSA`` as both ``width`` and ``height``.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, in_channels=256, emb_size=128, patch_size=10, num_classes=10):
        super().__init__()
        self.in_channels, self.emb_size = in_channels, emb_size
        self.patch_size, self.num_classes = patch_size, num_classes

        # Side length after the stride-2 pooling, and the flattened size per sample.
        pooled_side = patch_size // 2
        flat_features = in_channels * pooled_side * pooled_side

        self.mhsa = MHSA(n_dims=in_channels, width=patch_size, height=patch_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.bn = nn.BatchNorm2d(in_channels)
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run the attention block and head on ``x``; returns ``(rows, num_classes)`` logits."""
        pooled_side = self.patch_size // 2
        flat_features = self.in_channels * pooled_side * pooled_side

        features = F.relu(self.bn(self.pool(self.mhsa(x))))
        hidden = self.fc1(features.view(-1, flat_features))
        return self.fc2(hidden)
