import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
#from utils import save_net,load_net
from models.cbam_model import ChannelAttention,SpatialAttention
from models.deform_conv import DeformConv2D
import time
import numpy as np
from models.FourierEncoding import (BasicEncoding, PositionalEncoding, GaussianEncoding, PosEncoding)


from torch.autograd import Variable



import random


from torch.nn import Dropout


    
#from thirdparty_complex.complexLayers import ComplexConv2d, ComplexMaxPool2d
from math import ceil


from models.ses_cov import SESConv_H_H, SESConv_Z2_H, SESConv_H_H_1x1, SESMaxProjection


import torch.utils.model_zoo as model_zoo
from models.pys_model import PyConvHead
from models.GCT import GCT




def get_coordinates(batch_size, res_original_w, res_original_h, k_HR, device, lower = -0.99, upper = 0.99, endpoint_s = True):
    '''
    Build a batch of normalized 2-D sample coordinates on a regular grid.

    A ``res_original_w x res_original_h`` grid of (x, y) points spanning
    ``[lower, upper]`` on both axes is created, every ``k_HR``-th row and
    column is kept, and the kept points are flattened row-major.

    Args:
        batch_size (int): number of identical copies stacked along dim 0.
        res_original_w (int): number of grid rows; the y coordinate varies
            along this axis.
        res_original_h (int): number of grid columns; the x coordinate varies
            along this axis.
        k_HR (int): sub-sampling stride applied to both grid axes (>= 1).
        device: torch device on which the result is created.
        lower (float): first value of each ``np.linspace`` axis.
        upper (float): last value of each axis (included if ``endpoint_s``).
        endpoint_s (bool): forwarded to ``np.linspace(endpoint=...)``.

    Returns:
        torch.Tensor: float32 tensor of shape
        ``(batch_size, ceil(res_original_w / k_HR) * ceil(res_original_h / k_HR), 2)``;
        the last axis holds ``(x, y)``.

    Raises:
        ValueError: if ``k_HR`` is smaller than 1.
    '''
    if k_HR < 1:
        raise ValueError(f"k_HR must be a positive stride, got {k_HR}")

    x = np.linspace(lower, upper, res_original_h, endpoint=endpoint_s)
    y = np.linspace(lower, upper, res_original_w, endpoint=endpoint_s)
    # meshgrid(x, y) returns arrays of shape (res_original_w, res_original_h).
    xx, yy = np.meshgrid(x, y)

    # Keep every k_HR-th row/column; stacked shape is (2, rows, cols).
    d_HR = np.stack([xx[::k_HR, ::k_HR], yy[::k_HR, ::k_HR]])
    # Flatten row-major to (2, N) and transpose to (N, 2) point pairs.
    flat = d_HR.reshape(2, d_HR.shape[1] * d_HR.shape[2]).T
    cor_map = torch.tensor(flat, dtype=torch.float, device=device)

    return cor_map.unsqueeze(0).repeat(batch_size, 1, 1)

                
class SE_Encoder(nn.Module):
    """Image encoder: VGG_Backbone features, upsampled 2x, fed to a PyConv head gated by GCT.

    Args:
        BatchNorm: normalization layer class handed to ``PyConvHead``.
        k_size (int): forwarded to ``GCT`` as its second argument.
    """

    def __init__(self, BatchNorm=nn.BatchNorm2d, k_size=3):
        super().__init__()
        # Sub-modules are built in this fixed order so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.features = VGG_Backbone()
        self.pyconvhead = PyConvHead(512, 512, BatchNorm)
        self.GCT = GCT(512, k_size)

    def forward(self, x):
        """Encode ``x`` and return the element-wise product of both heads."""
        upsampled = F.interpolate(self.features(x), scale_factor=2)
        pyramid_out = self.pyconvhead(upsampled)
        gate_out = self.GCT(upsampled)
        # presumably 512 channels, given PyConvHead(512, 512, ...) -- confirm
        return pyramid_out * gate_out

#############################################################################################

class New_bay_Net(nn.Module):
    """End-to-end density network: SE_Encoder -> Encoder2z -> CC_Decoder.

    Args:
        downsample_ratio: stored as an attribute; not read inside this class.
        input_size: stored and forwarded to ``Encoder2z``.
        load_weights (bool): accepted for API compatibility; unused.
    """

    def __init__(self, downsample_ratio, input_size, load_weights=False):
        super().__init__()
        self.seen = 0
        self.downsample_ratio = downsample_ratio
        self.input_size = input_size

        # Stage configs kept as attributes; no layers are built from them here.
        self.frontend_feat1 = [64, 64]
        self.frontend_feat2 = ['M', 128, 128]
        self.frontend_feat3 = ['M', 256, 256, 256]
        self.frontend_feat4 = ['M', 512, 512, 512]

        # Sub-modules are created in a fixed order so parameter initialisation
        # draws from the RNG in the same sequence.
        self.modelA = SE_Encoder()

        self.m = 64  # passed to PositionalEncoding (earlier note: pos = [b, 1024, 4*m])
        self.matrix_size = 256
        self.latent_c_per = 128

        self.pos_encode_layer = PositionalEncoding(0.5, self.m)

        pos_out_dim = 4 * self.m
        weight_dim = 4 * self.latent_c_per + 4

        self.Encoder2z = Encoder2z(self.input_size, weight_dim)
        self.cc_decoder = build_cc_decoder(self.matrix_size, self.m, pos_out_dim, self.latent_c_per)

        # Third value returned by the decoder on the most recent forward().
        self.kl_div = 0

        # Not used in forward(); still registered, so it is part of state_dict().
        self.final_layer = nn.Conv2d(128, 1, 1)

    def forward(self, x, grid_c, mode, epoch):
        """Run the network on an image batch.

        Args:
            x: input image batch.
            grid_c: unused; kept for call-site compatibility.
            mode (str): ``'train'`` sets the decoder's boolean mode flag.
            epoch: unused.

        Returns:
            tuple: ``(density, density, sigma)`` -- the decoder's first output
            is returned twice, followed by its second output.
        """
        latent = self.Encoder2z(self.modelA(x))

        n_batch = latent.shape[0]
        rows, cols = latent.shape[2], latent.shape[3]
        coords = get_coordinates(n_batch, rows, cols, 1, latent.device)
        encoded_coords = self.pos_encode_layer(coords)

        is_training = bool(mode == 'train')
        density, out_sigma, kl_div = self.cc_decoder(latent, encoded_coords, is_training)

        self.kl_div = kl_div
        return density, density, out_sigma
        

        
        
        


    
class Encoder2z(nn.Module):
    """Refine encoder features with a conv stack plus a residual SESN block.

    ``frontend5`` is ``make_layers_4([256, 256], in_channels=512,
    batch_norm=True)`` (512 -> 256 -> 256 channels); ``sesn1`` is added to
    its output as a residual branch.

    Args:
        input_size: accepted for API compatibility; unused.
        m_size (int): stored and placed in the (unused) ``frontend_feat4`` config.
        load_weights (bool): accepted for API compatibility; unused.
    """

    def __init__(self, input_size, m_size, load_weights=False):
        super().__init__()

        self.seen = 0
        self.m_size = m_size
        self.ratio = 2

        # Layer configs; only frontend_feat5 is used to build layers.
        self.frontend_feat1 = []
        self.frontend_feat2 = []
        self.frontend_feat3 = [128, 256]
        self.frontend_feat4 = [512, self.m_size]
        self.frontend_feat5 = [256, 256]

        self.frontend5 = make_layers_4(self.frontend_feat5, in_channels=512, batch_norm=True)

        scales = [0.8 * 1.11 ** k for k in range(3)]
        self.sesn1 = nn.Sequential(
            SESConv_Z2_H(256, 256, 3, 3, scales=scales,
                         padding=1, bias=True,
                         basis_type='A'),
            SESMaxProjection(),
            # NOTE(review): nn.LeakyReLU's first positional argument is
            # negative_slope, so LeakyReLU(True) has slope 1.0 and acts as the
            # identity. Probably meant LeakyReLU(inplace=True); left unchanged
            # because fixing it alters the function of trained weights.
            nn.LeakyReLU(True),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=1),
        )

        self._initialize_weights()

    def forward(self, feature_high):
        """Apply ``frontend5`` and add the ``sesn1`` residual branch."""
        base = self.frontend5(feature_high)
        return self.sesn1(base) + base

    def _initialize_weights(self):
        """Re-initialise sub-modules, visiting them in registration order."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.001)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.weight, 0.001)
            elif isinstance(module, (SESConv_H_H, SESConv_Z2_H, SESConv_H_H_1x1)):
                # std = sqrt(2 / n), n = weight elements per input channel.
                fan = module.weight.nelement() / module.in_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))

      
class CC_Decoder(nn.Module):
    """Per-pixel residual MLP decoder producing a density map and a sigma map.

    The incoming feature map is resized to 64x64. Each pixel's
    ``matrix_size``-dim feature vector then goes through four residual
    Linear + PReLU layers. Two heads follow: ``last1`` for the mean and
    ``last2``, whose output is exponentiated, for sigma.

    Args:
        matrix_size (int): feature width; must equal the channel count of the
            ``feature_high`` tensor given to ``forward``.
        m_size (int): stored; not used by ``forward``.
        pos_out_dim (int): stored as ``pos_dim``; not used by ``forward``.
        latent_c_per (int): stored and used to compute ``weight_dim``.
    """

    def __init__(self, matrix_size, m_size, pos_out_dim, latent_c_per):
        super(CC_Decoder, self).__init__()

        self.m_size = m_size
        self.matrix_size = matrix_size
        self.pos_dim = pos_out_dim
        self.latent_c_per = latent_c_per
        self.weight_dim = 4 * self.latent_c_per + 4

        # NOTE: sub-modules are created in a fixed order; reordering changes
        # how default initialisation consumes the RNG.

        # Heads: mean (last1) and log-sigma (last2, exponentiated in forward).
        self.last1 = nn.Linear(self.matrix_size, 1)
        self.last2 = torch.nn.Sequential(
            nn.Linear(self.matrix_size, self.matrix_size),
            nn.PReLU(),
            nn.Linear(self.matrix_size, 1),
        )

        self.act = nn.PReLU()
        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()
        self.act3 = nn.PReLU()
        self.act4 = nn.PReLU()
        self.act5 = nn.LeakyReLU()
        self.act6 = nn.LeakyReLU()

        # Standard normal; not used by forward(). Only move it to the GPU when
        # one exists -- an unconditional .cuda() made construction fail on
        # CPU-only machines.
        self.N = torch.distributions.Normal(0, 1)
        if torch.cuda.is_available():
            self.N.loc = self.N.loc.cuda()
            self.N.scale = self.N.scale.cuda()

        # Residual MLP applied independently to every pixel's feature vector.
        self.inr1 = nn.Linear(self.matrix_size, self.matrix_size)
        self.inr2 = nn.Linear(self.matrix_size, self.matrix_size)
        self.inr3 = nn.Linear(self.matrix_size, self.matrix_size)
        self.inr4 = nn.Linear(self.matrix_size, self.matrix_size)

        self._initialize_weights()
        self.omega_0 = 30.0  # not read by forward()
        self.kl = 0

        # Built after _initialize_weights(), so it keeps PyTorch's default
        # initialisation. Not used by forward().
        self.trans_layer = torch.nn.Sequential(nn.Linear(256, 256), nn.PReLU())

        self.W_normalize = False

    def forward(self, feature_high, x2, mode):
        """Decode a feature map into density and sigma predictions.

        Args:
            feature_high: feature map of shape ``(b, matrix_size, H, W)``;
                resized to 64x64 with bicubic interpolation.
            x2: encoded query coordinates; only ``x2.shape[0]`` (batch size)
                is read here.
            mode (bool): training flag; currently unused.

        Returns:
            tuple: ``(density, sigma, kl)`` -- ``density`` has shape
            ``(b, 1, 64, 64)`` and is non-negative (``abs``), ``sigma`` has
            shape ``(b, 64 * 64)`` and is ``exp`` of the sigma head, and
            ``kl`` is ``self.kl``.
        """
        b = x2.shape[0]

        feature_high = F.interpolate(feature_high, size=(64, 64), mode='bicubic', align_corners=True)
        density_w = feature_high.shape[2]
        density_h = feature_high.shape[3]
        # (b, C, H, W) -> (b, H*W, C): one feature vector per output pixel.
        b1_f = torch.reshape(feature_high, (b, self.matrix_size, density_w * density_h))
        b1_f = b1_f.transpose(1, 2)

        out1 = self.act1(self.inr1(b1_f)) + b1_f
        out2 = self.act2(self.inr2(out1)) + out1
        out3 = self.act3(self.inr3(out2)) + out2
        out4 = self.act4(self.inr4(out3)) + out3

        # squeeze(-1) (not a bare squeeze) so the batch axis survives when b == 1.
        out_mu = self.act(self.last1(out4)).squeeze(-1)
        out_sigma = torch.exp(self.last2(out4).squeeze(-1))

        out_mu = out_mu.reshape((b, 1, density_w, density_h))
        return torch.abs(out_mu), out_sigma, self.kl

    def _initialize_weights(self):
        """Initialise Conv2d / BatchNorm2d / Linear sub-modules registered so far."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
   
def build_cc_decoder(matrix_size, m_size, pos_out_dim, latent_c_per):
    """Factory for ``CC_Decoder``; every argument is forwarded unchanged."""
    decoder = CC_Decoder(
        matrix_size=matrix_size,
        m_size=m_size,
        pos_out_dim=pos_out_dim,
        latent_c_per=latent_c_per,
    )
    return decoder
    

def make_layers_2(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a 3x3 conv stack with in-place LeakyReLU activations.

    Args:
        cfg: sequence of ints (conv output channels) and ``'M'`` (2x2 max-pool).
        in_channels (int): channels entering the first conv.
        batch_norm (bool): insert ``BatchNorm2d`` after each conv.
        dilation (bool): use dilation 2 / padding 2 instead of 1 / 1; in both
            cases the 3x3 conv preserves the spatial size.

    Returns:
        nn.Sequential: the layers in ``cfg`` order.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, entry, kernel_size=3, padding=rate, dilation=rate))
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.LeakyReLU(inplace=True))
        channels = entry
    return nn.Sequential(*modules)


    
def make_layers_4(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a 3x3 conv stack with PReLU activations.

    Args:
        cfg: sequence of ints (conv output channels) and ``'M'`` (2x2 max-pool).
        in_channels (int): channels entering the first conv.
        batch_norm (bool): insert ``BatchNorm2d`` between each conv and its PReLU.
        dilation (bool): use dilation 2 / padding 2 instead of 1 / 1; in both
            cases the 3x3 conv preserves the spatial size.

    Returns:
        nn.Sequential: the layers in ``cfg`` order.
    """
    spacing = 2 if dilation else 1
    stack = []
    prev = in_channels
    for item in cfg:
        if item != 'M':
            unit = [nn.Conv2d(prev, item, kernel_size=3, padding=spacing, dilation=spacing)]
            if batch_norm:
                unit.append(nn.BatchNorm2d(item))
            unit.append(nn.PReLU())
            stack.extend(unit)
            prev = item
        else:
            stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*stack)
    
    
class SEEN(nn.Module):
    """Stack of scale-equivariant conv blocks (SESConv_Z2_H + SESMaxProjection).

    Args:
        pool_size: unused; kept for API compatibility.
        kernel_size: ignored -- every conv below is built with kernel size 3.
        scales (list of float, optional): scales for every ``SESConv_Z2_H``;
            ``None`` (the default) means ``[1.0]``.
        basis_type (str): basis type for every ``SESConv_Z2_H``.
        n_channels: unused; kept for API compatibility.
        input_chanels: unused; kept for API compatibility.
        **kwargs: forwarded to every ``SESConv_Z2_H``.
    """

    def __init__(self, pool_size=2, kernel_size=11, scales=None, basis_type='A',
                 n_channels=3, input_chanels=16, **kwargs):
        super(SEEN, self).__init__()

        # Fresh list per instance instead of a shared mutable default.
        if scales is None:
            scales = [1.0]

        self.seen = 0

        # The kernel_size argument is overridden: all convs use 3x3 kernels.
        kernel_size = 3

        # Every stage is 64 channels wide. Unpacking a bare int
        # (`C1, C2, C3, C4, C5 = 64`) raises TypeError, so assign explicitly.
        C1 = C2 = C3 = C4 = C5 = 64

        # NOTE(review): nn.LeakyReLU's first positional argument is
        # negative_slope, so LeakyReLU(True) has slope 1.0 (identity).
        # Probably intended as LeakyReLU(inplace=True) -- confirm before changing.
        self.inc = nn.Sequential(
            SESConv_Z2_H(3, C1, kernel_size, 7, scales=scales,
                         padding=1, bias=True,
                         basis_type=basis_type, **kwargs),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.BatchNorm2d(C1),)

        self.down1 = nn.Sequential(
            SESConv_Z2_H(C1, C2, kernel_size, 7, scales=scales,
                         padding=kernel_size // 2, bias=True,
                         basis_type=basis_type, **kwargs),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(C2),)

        self.down2 = nn.Sequential(
            SESConv_Z2_H(C2, C3, kernel_size, 7, scales=scales,
                         padding=kernel_size // 2, bias=True,
                         basis_type=basis_type, **kwargs),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(C3),)

        self.down3 = nn.Sequential(
            SESConv_Z2_H(C3, C4, kernel_size, 7, scales=scales,
                         padding=kernel_size // 2, bias=True,
                         basis_type=basis_type, **kwargs),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(C4),)

        # Built but not applied in forward().
        self.down4 = nn.Sequential(
            SESConv_Z2_H(C4, C5, kernel_size, 7, scales=scales,
                         padding=kernel_size // 2, bias=True,
                         basis_type=basis_type, **kwargs),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.BatchNorm2d(C5),
        )

    def forward(self, x):
        """Apply ``inc`` then three residual ``down`` blocks; returns the third.

        NOTE(review): down1-down3 each contain MaxPool2d(2, 2) while the
        residual term is the block input, so e.g. ``self.down1(x1) + x1``
        adds tensors of different spatial size unless SESConv_Z2_H changes the
        size in a compensating way -- confirm; as written this looks like it
        cannot broadcast.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1) + x1
        x3 = self.down2(x2) + x2
        x4 = self.down3(x3) + x3
        return x4
class VGG_Backbone(nn.Module):
    """VGG19-initialised feature extractor with residual SESN refinements.

    Five conv stages are built with ``make_layers_vgg`` (2/2/4/4/4 convs)
    and filled with torchvision's pretrained VGG19 weights. Residual
    scale-equivariant blocks (``SESConv_Z2_H`` + ``SESMaxProjection``) refine
    the input image and the outputs of the first three stages.

    Args:
        load_weights (bool): accepted for API compatibility; unused (the
            pretrained VGG19 weights are always loaded).
    """

    def __init__(self, load_weights=False):
        super(VGG_Backbone, self).__init__()

        self.seen = 0
        ## frontend feature extraction: VGG19 conv layout split into 5 stages
        ## (full cfg: 64,64,M,128,128,M,256x4,M,512x4,M,512x4)
        self.frontend_feat1 = [64, 64]
        self.frontend_feat2 = ['M', 128, 128]
        self.frontend_feat3 = ['M', 256, 256, 256, 256]
        self.frontend_feat4 = ['M', 512, 512, 512, 512]
        self.frontend_feat5 = ['M', 512, 512, 512, 512]

        # NOTE: sub-modules are created in a fixed order; reordering changes
        # how default initialisation consumes the RNG.
        self.layers = nn.ModuleList([
            make_layers_vgg(self.frontend_feat1, in_channels = 3),
            make_layers_vgg(self.frontend_feat2, in_channels = 64),
            make_layers_vgg(self.frontend_feat3, in_channels = 128),
            make_layers_vgg(self.frontend_feat4, in_channels = 256),
            make_layers_vgg(self.frontend_feat5, in_channels = 512),
        ])

        scales = [0.8 * 1.11**i for i in range(4)]  # others
        #scales = [0.9 * 1.05**i for i in range(3)]  # variant used for ships

        # NOTE(review): nn.LeakyReLU's first positional argument is
        # negative_slope, so LeakyReLU(True) below has slope 1.0 (identity).
        # Probably intended as LeakyReLU(inplace=True); left unchanged because
        # fixing it alters the function of trained weights.
        self.sesn1 = nn.Sequential(
            SESConv_Z2_H(64, 64, 3, 3, scales=scales,
                         padding=1, bias=True,
                         basis_type='A'),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.BatchNorm2d(64),
            )

        self.sesn2 = nn.Sequential(
            SESConv_Z2_H(128, 128, 3, 3, scales=scales,
                         padding=1, bias=True,
                         basis_type='A'),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.BatchNorm2d(128),)

        self.sesn3 = nn.Sequential(
            SESConv_Z2_H(256, 256, 3, 3, scales=scales,
                         padding=1, bias=True,
                         basis_type='A'),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.BatchNorm2d(256),)

        self.inc1 = nn.Sequential(
            SESConv_Z2_H(3, 3, 3, 3, scales=scales,
                         padding=1, bias=True,
                         basis_type='A'),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.BatchNorm2d(3),)

        self.inc2 = nn.Sequential(
            SESConv_Z2_H(3, 3, 3, 3, scales=scales,
                         padding=1, bias=True,
                         basis_type='A'),
            SESMaxProjection(),
            nn.LeakyReLU(True),
            nn.BatchNorm2d(3),)

        # Copy pretrained VGG19 tensors into the five stages by position. Each
        # stage's state_dict holds (weight, bias) per conv -- 4, 4, 8, 8, 8
        # tensors -- so n_list is each stage's offset into the source list.
        # Assumes VGG19's state_dict lists its conv tensors first, in layer
        # order -- TODO confirm when upgrading torchvision.
        n_list = [0, 4, 8, 16, 24]
        mod = models.vgg19(pretrained = True)
        # Build the source list once rather than once per copied tensor.
        pretrained_tensors = list(mod.state_dict().values())
        for offset, stage in zip(n_list, self.layers):
            for k, target in enumerate(stage.state_dict().values()):
                source = pretrained_tensors[offset + k]
                if target.shape != source.shape:
                    raise RuntimeError(
                        f"VGG19 weight shape mismatch at index {offset + k}: "
                        f"expected {tuple(target.shape)}, got {tuple(source.shape)}")
                # state_dict() tensors share storage with the parameters, so
                # this in-place copy writes directly into the layer weights.
                target.copy_(source)

    def forward(self, x):
        """Extract features from an image batch.

        Returns:
            The output of the fifth VGG stage. Before stages 4 and 5 the
            feature map is resized to a square ``(H, H)``, where ``H`` is its
            current height.
        """
        x = self.inc1(x)
        x = self.inc2(x) + x

        x = self.layers[0](x)
        x = self.sesn1(x) + x

        x = self.layers[1](x)
        x = self.sesn2(x) + x

        x = self.layers[2](x)
        x = self.sesn3(x) + x

        # Rescaling operator: resize to a square map whose side is the current
        # height (an alternative configuration used size=(128, 128)).
        side = x.shape[2]
        x = F.interpolate(x, size=(side, side), mode='bicubic', align_corners=False)

        x = self.layers[3](x)
        x = self.layers[4](x)
        return x
        
def make_layers_vgg(cfg, in_channels=3, batch_norm=False):
    """Build a plain VGG stage from ``cfg``.

    Each int ``v`` adds a 3x3 / padding-1 ``Conv2d`` to ``v`` channels,
    optionally a ``BatchNorm2d``, and an in-place ``ReLU``. Each ``'M'`` adds a
    2x2 max-pool. Layer order matters: VGG_Backbone copies pretrained tensors
    into these layers by position.

    Returns:
        nn.Sequential: the layers in ``cfg`` order.
    """
    def _conv_unit(c_in, c_out):
        # One conv + (optional BN) + ReLU unit.
        conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
        if batch_norm:
            return [conv, nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)]
        return [conv, nn.ReLU(inplace=True)]

    layers = []
    current = in_channels
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.extend(_conv_unit(current, v))
            current = v
    return nn.Sequential(*layers)