import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Embed a 2-channel image into a single ``embed_dim`` vector per sample.

    Pipeline: 3x3 convolution (no bias, default zero padding) -> BatchNorm ->
    ReLU -> global average pooling.  An input of shape ``(M, 2, H, W)``
    yields ``(M, embed_dim, 1, 1)``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(2, embed_dim, kernel_size=3, bias=False)
        self.bn1 = nn.BatchNorm2d(embed_dim)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def initialize_conv1(self):
        """Overwrite one kernel tap of the first two filters with a +/-1 pattern.

        Only position ``[0, 0]`` (the top-left tap) of each 3x3 kernel is
        written; every other weight keeps its current value.  Filter 0 gets
        ``(+1, -1)`` across the two input channels and filter 1 gets
        ``(-1, +1)``.  Requires ``embed_dim >= 2``; otherwise indexing
        filter 1 raises ``IndexError``.
        """
        # NOTE(review): only the top-left tap is set, so this +/-1 pattern sits
        # next to the remaining randomly initialised taps — confirm intended.
        kernel = self.conv1.weight.data  # in-place write, not tracked by autograd
        for out_ch, in_ch, value in ((0, 0, 1), (0, 1, -1), (1, 0, -1), (1, 1, 1)):
            kernel[out_ch, in_ch, 0, 0] = value

    def forward(self, x):
        """Return pooled features of shape ``(M, embed_dim, 1, 1)`` for ``x`` of ``(M, 2, H, W)``."""
        activated = F.relu(self.bn1(self.conv1(x)))
        return self.pool(activated)

class Head(nn.Module):
    """Linear read-out mapping each ``embed_dim`` feature vector to one score."""

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.fc = nn.Linear(embed_dim, 1)  # E -> 1 score per vector

    def forward(self, x):
        """Map ``(..., embed_dim)`` to ``(..., 1)`` (e.g. B x N x E -> B x N x 1)."""
        return self.fc(x)

class Detector(nn.Module):
    """Score every ordered pair of the N input maps and softmax-normalise per row.

    For an input of shape ``(B, N, w, h)``, each ordered pair ``(i, j)`` is
    stacked into a 2-channel image ``(x[:, j], x[:, i])``.  The image is
    embedded by ``Encoder`` and optionally refined by a stack of transformer
    encoder layers that attend across all ``N * N`` pairs of one sample.
    ``Head`` then maps it to a scalar.  The output has shape ``(B, N, N)``
    and sums to 1 along the last dimension.

    Args:
        embed_dim: Width of the per-pair embedding.
        *args: Accepted for call-site compatibility; ignored.
        **kwargs: Required keys ``num_encoder_layers`` (only ``1`` is
            supported) and ``num_layers`` (number of transformer layers;
            ``0`` or less disables them).  Optional key ``num_head``
            (default ``1``) sets the attention heads per transformer layer;
            PyTorch requires it to divide ``embed_dim``.

    Raises:
        KeyError: If a required key is missing from ``kwargs``.
        NotImplementedError: If ``num_encoder_layers != 1``.
    """

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__()
        num_encoder_layers = kwargs["num_encoder_layers"]
        if num_encoder_layers != 1:
            raise NotImplementedError(
                f"num_encoder_layers={num_encoder_layers!r} is not supported; "
                "only 1 is implemented"
            )
        self.encoder = Encoder(embed_dim)

        self.head = Head(embed_dim)

        num_layers = kwargs["num_layers"]
        # Optional; the default reproduces the previously hard-coded single head.
        num_head = kwargs.get("num_head", 1)
        if num_layers > 0:
            self.layers = nn.Sequential(*[
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=num_head,
                    dim_feedforward=embed_dim * 2,
                    batch_first=True,  # inputs are B x (N*N) x E
                    norm_first=True,   # pre-LayerNorm blocks
                )
                for _ in range(num_layers)
            ])
        else:
            self.layers = nn.Identity()

        self.embed_dim = embed_dim
        self.num_head = num_head
        self.num_layers = num_layers
        self.num_encoder_layers = num_encoder_layers

    def forward(self, x):
        """Return pairwise row-softmax scores.

        Args:
            x: Tensor of shape ``(B, N, w, h)``.  Assumes ``w, h >= 3``,
                because ``Encoder`` applies an unpadded 3x3 convolution.

        Returns:
            Tensor of shape ``(B, N, N)``.  Entry ``[b, i, j]`` is the
            softmax over ``j`` of the score for the pair whose channel 0 is
            ``x[b, j]`` and channel 1 is ``x[b, i]``.
        """
        batch, n = x.shape[:2]
        spatial = x.shape[2:]

        # Build pairs[b, i, j] = (x[b, j], x[b, i]).  expand() returns views,
        # so the only copy is made by stack().  Row-major flattening of (i, j)
        # gives pair index i * n + j.
        partner = x.unsqueeze(1).expand(batch, n, n, *spatial)  # [b, i, j] -> x[b, j]
        anchor = x.unsqueeze(2).expand(batch, n, n, *spatial)   # [b, i, j] -> x[b, i]
        pairs = torch.stack((partner, anchor), dim=3)           # B x N x N x 2 x w x h
        pairs = pairs.reshape(batch * n * n, 2, *spatial)       # (B*N*N) x 2 x w x h

        emb = self.encoder(pairs)                               # (B*N*N) x E x 1 x 1
        emb = emb.flatten(1)                                    # (B*N*N) x E
        emb = emb.view(batch, n * n, self.embed_dim)            # B x (N*N) x E
        emb = self.layers(emb)                                  # attend across all pairs of a sample

        scores = self.head(emb).view(batch, n, n)               # B x N x N
        return torch.softmax(scores, dim=2)

def get_model(args):
    """Build the model selected by ``args.model``.

    Only ``"default"`` is supported.  It builds a ``Detector`` from
    ``args.embed_dim``, ``args.num_layers`` and ``args.num_encoder_layers``.

    Raises:
        NotImplementedError: For any other ``args.model`` value.
    """
    if args.model != "default":
        raise NotImplementedError
    return Detector(
        args.embed_dim,
        num_layers=args.num_layers,
        num_encoder_layers=args.num_encoder_layers,
    )