import torch
from torch import nn





class SETR(nn.Module):
    """SETR-style segmentation network: transformer encoder + multi-level decoder.

    The input is cut into non-overlapping ``img_size x img_size`` patches.
    Each patch is flattened and linearly embedded to ``embedding_dim``, and the
    token sequence (shape ``(num_patches, B, embedding_dim)``) is passed through
    a stack of ``nn.TransformerEncoderLayer``. The token maps produced after
    four evenly spaced encoder layers are folded back into 2-D feature maps.
    Each map goes through its own conv head (which upsamples x4), the four
    results are concatenated on the channel axis, and ``self.final`` upsamples
    the rest of the way to the input resolution and projects to ``out_ch``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        img_size: patch side length in pixels. Must be a positive multiple of 4,
            because the decoder upsamples by 4 and then by ``img_size // 4``.
        layers: number of transformer encoder layers (at least 4).
        nhead: attention heads per encoder layer (must divide ``embedding_dim``;
            enforced by ``nn.MultiheadAttention``).
        embedding_dim: token width; must be divisible by 4 so the four decoder
            branches (``embedding_dim // 4`` channels each) concatenate back to
            ``embedding_dim`` channels.

    Raises:
        ValueError: if ``img_size``, ``layers`` or ``embedding_dim`` violate the
            constraints above.
    """

    def __init__(self, in_ch, out_ch, img_size=16, layers=12, nhead=4, embedding_dim=512):
        super().__init__()
        if img_size < 4 or img_size % 4 != 0:
            raise ValueError(f"img_size must be a positive multiple of 4, got {img_size}")
        if layers < 4:
            raise ValueError(f"layers must be at least 4, got {layers}")
        if embedding_dim % 4 != 0:
            raise ValueError(f"embedding_dim must be divisible by 4, got {embedding_dim}")

        # Nominal input resolution. Kept as an attribute for compatibility; the
        # computation does not depend on it (any H, W divisible by img_size works).
        self.img_height = 256
        self.img_size = img_size  # patch side length in pixels
        self.layers = layers
        self.nhead = nhead
        self.embedding_dim = embedding_dim

        # One token = one flattened patch of in_ch * img_size * img_size values;
        # this must match the reshape performed in forward().
        self.linear = nn.Sequential(nn.Linear(in_ch * self.img_size ** 2, self.embedding_dim))

        # NOTE(review): this is a single (1, 1, E) vector broadcast to every token,
        # so it adds no per-position information (it behaves like an extra bias).
        # ViT-style models normally use a (num_patches, 1, E) table; changing it
        # would alter the state_dict shape, so it is left as-is here.
        self.pe = nn.Parameter(torch.randn(1, 1, self.embedding_dim))
        self.pe_dropout = nn.Dropout(0)

        # Default batch_first=False, so the encoder expects (seq, batch, feature).
        self.transformers = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model=self.embedding_dim, nhead=self.nhead)
            for _ in range(self.layers)
        )

        # 0-based indices of the layers whose outputs feed the decoder: the
        # quarter marks of the stack ([2, 5, 8, 11] for 12 layers). Ascending
        # order matters: forward() pairs them positionally with conv_reshapes.
        self.mid_output = [self.layers * k // 4 - 1 for k in range(1, 5)]
        self.conv_reshapes = nn.ModuleList()
        for _ in self.mid_output:
            self.conv_reshapes.append(
                nn.Sequential(
                    nn.Conv2d(self.embedding_dim, self.embedding_dim // 2, kernel_size=3, stride=1, padding=1),
                    nn.Conv2d(self.embedding_dim // 2, self.embedding_dim // 2, kernel_size=3, stride=1, padding=1),
                    nn.Conv2d(self.embedding_dim // 2, self.embedding_dim // 4, kernel_size=3, stride=1, padding=1),
                    nn.Upsample(scale_factor=4, mode='bilinear'),
                )
            )

        # The heads upsample by 4; this stage upsamples by img_size // 4, so the
        # total factor equals the patch size and restores the input resolution.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=self.img_size // 4, mode='bilinear'),
            nn.Conv2d(self.embedding_dim, self.embedding_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.embedding_dim),
            nn.ReLU(),
            nn.Conv2d(self.embedding_dim, out_ch, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Segment a batch of images.

        Args:
            x: tensor of shape (B, in_ch, H, W); H and W must be multiples of
                ``self.img_size``.

        Returns:
            Tensor of shape (B, out_ch, H, W).

        Raises:
            ValueError: if H or W is not divisible by the patch size.
        """
        B, C, H, W = x.shape
        p = self.img_size
        if H % p or W % p:
            raise ValueError(f"input height/width ({H}, {W}) must be divisible by the patch size {p}")
        grid_h, grid_w = H // p, W // p

        # (B, C, H, W) -> (grid_h * grid_w, B, C * p * p): tokens in row-major
        # patch order, each holding one flattened patch.
        x = x.reshape(B, C, grid_h, p, grid_w, p)
        x = x.permute(2, 4, 0, 1, 3, 5).reshape(-1, B, C * p ** 2)

        x = self.linear(x)
        x = x + self.pe
        x = self.pe_dropout(x)

        output = []
        for i, layer in enumerate(self.transformers):
            x = layer(x)  # (num_patches, B, embedding_dim)
            if i in self.mid_output:
                output.append(x)

        # Fold each token map back to (B, embedding_dim, grid_h, grid_w); this
        # inverts the row-major token ordering used above.
        out = [
            head(feat.reshape(grid_h, grid_w, B, -1).permute(2, 3, 0, 1))
            for head, feat in zip(self.conv_reshapes, output)
        ]
        out = torch.cat(out, dim=1)
        return self.final(out)
                
                