# _*_ coding: utf-8
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from scipy.optimize._basinhopping import AdaptiveStepsize
from einops.layers.torch import Rearrange, Reduce

def general_conv(in_channels, out_channels, kernel_size, bias=False):
    """Build an ``nn.Conv2d`` whose padding is ``kernel_size // 2``.

    With the default stride of 1 and an odd ``kernel_size`` this keeps the
    spatial size unchanged ("same" padding).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size (int).
        bias: whether the convolution has a bias term.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    same_pad = kernel_size // 2
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_pad, bias=bias)
    return conv

class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: hidden width; falls back to ``in_features // 4`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, shared by both dropout applications.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features // 4
        # Creation order (fc1, act, fc2, drop) is kept so parameter init and
        # state_dict layout stay the same.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
        
class EffAttention(nn.Module):
    """Multi-head self-attention computed in independent chunks along axis 1.

    The input's last axis is reduced from ``dim`` to ``dim // 2`` features,
    projected to q/k/v, split into heads, and attention is evaluated
    separately inside consecutive chunks of ``max(1, N // 16)`` positions
    along axis 1 (a position only attends to positions in its own chunk).
    The result is projected back to ``dim`` features.

    Input and output are 4-D tensors ``(B, N, C, dim)``. In this file the
    caller (``ConvTransBlock`` via ``TransBlock``) passes a ``b h w c``
    tensor, so N is the feature-map height and C its width.

    NOTE(review): q/k/v and the heads are carved out of the flattened
    ``(C, 3 * (dim // 2))`` trailing axes. Each token is therefore a whole
    row (one index along axis 1), and the q/k/v split does not line up with
    the ``qkv`` Linear's output slices. ``self.scale`` also uses
    ``dim // num_heads`` rather than the actual per-head width. Confirm both
    are intended.

    Args:
        dim: size of the input's last axis.
        num_heads: number of attention heads.
        qkv_bias: whether ``reduce`` and ``qkv`` have a bias.
        qk_scale: overrides the default scale ``(dim // num_heads) ** -0.5``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability on the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.reduce = nn.Linear(dim, dim // 2, bias=qkv_bias)
        self.qkv = nn.Linear(dim // 2, dim // 2 * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim // 2, dim)
        self.attn_drop = nn.Dropout(attn_drop)
        # proj_drop used to be accepted but ignored; Dropout holds no state,
        # so existing checkpoints still load.
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        x = self.reduce(x)
        B, N, C, _ = x.shape
        # Let reshape infer the per-head width. The previous hard-coded
        # 2*C*num_heads only matched when dim == 4 * num_heads**2
        # (e.g. dim=256, num_heads=8), where -1 yields the same value.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # N // 16 is 0 for N < 16, which torch.split rejects; clamp to 1.
        # For N >= 16 this equals the previous math.ceil(N // 16).
        chunk = max(1, N // 16)
        chunks = zip(q.split(chunk, dim=-2), k.split(chunk, dim=-2), v.split(chunk, dim=-2))

        output = []
        for q_c, k_c, v_c in chunks:
            attn = (q_c @ k_c.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # (B, heads, chunk, d) -> (B, chunk, heads, d) so chunks concat on the token axis.
            output.append((attn @ v_c).transpose(1, 2))
        x = torch.cat(output, dim=1)
        x = x.reshape(B, N, C, -1)
        x = self.proj(x)
        return self.proj_drop(x)

class TransBlock(nn.Module):
    """Pre-norm transformer block: ``x + atten(norm1(x))``, then ``x + mlp(norm2(x))``.

    Args:
        n_feat: unused; kept for call-site compatibility.
        dim: size of the input's last axis.
        num_heads: attention heads, forwarded to ``EffAttention``.
        mlp_ratio: unused; the MLP hidden width is fixed at ``dim // 4``.
        qkv_bias: forwarded to ``EffAttention``.
        qk_scale: forwarded to ``EffAttention``.
        drop: MLP dropout; also passed to ``EffAttention`` as ``proj_drop``.
        attn_drop: forwarded to ``EffAttention``.
        drop_path: unused (no stochastic depth is applied).
        act_layer: activation class for the MLP.
        norm_layer: normalization class applied before attention and MLP.
    """

    def __init__(
            self, n_feat=64, dim=64, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
            drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
        super(TransBlock, self).__init__()
        self.dim = dim
        # The constructor arguments are forwarded instead of being re-hard-coded.
        # Their defaults equal the old hard-coded values, so existing callers
        # build the same modules.
        self.atten = EffAttention(self.dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                  attn_drop=attn_drop, proj_drop=drop)
        self.norm1 = norm_layer(self.dim)
        self.mlp = Mlp(in_features=dim, hidden_features=dim // 4, act_layer=act_layer, drop=drop)
        self.norm2 = norm_layer(self.dim)

    def forward(self, x):
        x = x + self.atten(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class ConvTransBlock(nn.Module):
    """Parallel conv / transformer block fused back with a residual add.

    ``conv1_1`` (1x1) maps the input to ``conv_dim + trans_dim`` channels,
    which are split into two parts:

    * conv part: ``conv_block`` (3x3 conv, ReLU, 3x3 conv) plus an identity skip;
    * transformer part: moved to channels-last, run through ``TransBlock``,
      then moved back to channels-first.

    The two parts are concatenated, projected to ``trans_dim`` channels by
    ``conv1_2`` (1x1), and added to the block input.

    NOTE(review): the final add requires ``conv_dim == trans_dim``.
    ``window_size``, ``type`` and ``input_resolution`` are only stored and
    validated; ``forward`` performs no windowing.
    """

    def __init__(self, conv_dim, trans_dim, window_size=8, type='W', input_resolution=256):
        super().__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.window_size = window_size
        self.input_resolution = input_resolution

        assert type in ['W', 'SW']
        # 'SW' is downgraded to 'W' when input_resolution <= window_size.
        self.type = 'W' if input_resolution <= window_size else type

        # Submodules are created in the original order so random init and
        # state_dict layout are unchanged.
        self.trans_block = TransBlock(self.trans_dim, self.trans_dim)
        widened = self.conv_dim + self.trans_dim
        self.conv1_1 = nn.Conv2d(self.conv_dim, widened, kernel_size=1, stride=1, padding=0, bias=True)
        self.conv1_2 = nn.Conv2d(widened, self.trans_dim, kernel_size=1, stride=1, padding=0, bias=True)

        layers = [
            nn.Conv2d(self.conv_dim, self.conv_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv_dim, self.conv_dim, kernel_size=3, stride=1, padding=1, bias=False),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        conv_part, trans_part = self.conv1_1(x).split((self.conv_dim, self.trans_dim), dim=1)
        conv_part = self.conv_block(conv_part) + conv_part

        tokens = trans_part.permute(0, 2, 3, 1)    # b c h w -> b h w c
        tokens = self.trans_block(tokens)
        trans_part = tokens.permute(0, 3, 1, 2)    # b h w c -> b c h w

        fused = self.conv1_2(torch.cat((conv_part, trans_part), dim=1))
        return x + fused






class ASPPConv(nn.Sequential):
    """Dilated-convolution branch: two 3x3 dilated convs -> BatchNorm -> ReLU.

    Padding equals the dilation, so each 3x3 conv preserves the spatial size.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the branch.
        dilation: dilation (and padding) of both 3x3 convolutions.

    NOTE(review): there is no nonlinearity between the two convolutions.
    The layer indices (0 conv, 1 conv, 2 BN, 3 ReLU) are kept so existing
    checkpoints still load.
    """

    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            # This conv consumes the previous conv's output, so its input width is
            # out_channels. It used to be in_channels, which only worked when the
            # two were equal (and is identical in that case).
            nn.Conv2d(out_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    """Image-level branch: global average pool -> 1x1 conv -> BatchNorm -> ReLU.

    The pooled ``1x1`` map is bilinearly resized back to the input's last two
    (spatial) dimensions, so every spatial position receives the same vector.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        layers = (
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        super().__init__(*layers)

    def forward(self, x):
        target_hw = x.shape[-2:]
        pooled = x
        # Same traversal nn.Sequential.forward performs.
        for layer in self:
            pooled = layer(pooled)
        return F.interpolate(pooled, size=target_hw, mode='bilinear', align_corners=False)


class ASP_conv(nn.Module):
    """Multi-dilation conv stage with a residual add and bilinear rescaling.

    ``x`` goes through one ``ASPPConv`` branch per atrous rate. The branch
    outputs are concatenated on channels and fused by ``project`` (1x1 conv,
    BN, ReLU, Dropout(0.5)), and the result is added to ``x``. The sum is
    resized to ``ceil(H * scale) x ceil(W * scale)``. Unlike classic ASPP,
    there is no 1x1 branch and no image-pooling branch.

    Args:
        in_channels: channels of ``x``.
        out_channels: channels of each branch and of the fused output. It must
            equal ``in_channels`` for the residual add.
        scale: spatial resize factor applied after the residual add.
        atrous_rates: dilation rates, one ``ASPPConv`` branch each.

    Returns:
        ``(out, side)``: ``out`` is the resized feature map and ``side`` is
        ``layer_out(out)``, a 3-channel map of the same spatial size.
    """

    def __init__(self, in_channels=256, out_channels=256, scale=0.8, atrous_rates=(1, 2, 4)):
        super(ASP_conv, self).__init__()
        self.scale = scale
        # 3-channel side output computed from the resized features.
        self.layer_out = nn.Sequential(
            nn.Conv2d(out_channels, 3, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True)
        )
        # out_channels used to be silently overwritten with 256 at this point.
        # That made layer_out and project disagree for any other width; the
        # argument is now honoured, which is identical when it is 256.
        self.convs = nn.ModuleList(
            [ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates])

        # Fuse the concatenated branches back to out_channels.
        self.project = nn.Sequential(
            nn.Conv2d(len(atrous_rates) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        height, width = x.shape[-2:]
        branches = torch.cat([conv(x) for conv in self.convs], dim=1)
        out = x + self.project(branches)
        target = (math.ceil(height * self.scale), math.ceil(width * self.scale))
        out = F.interpolate(out, size=target, mode='bilinear', align_corners=False)
        return out, self.layer_out(out)
    
    
# # Define the atrous (dilated-convolution) pyramid: alternative ASP_conv variant, commented out.
# class ASP_conv(nn.Module):
#     def __init__(self,in_channels=256, out_channels=256, d_rate=2, scale=0.8):
#         super(ASP_conv, self).__init__()
#         self.scale = scale
#         self.layer1_1 = nn.Sequential(
#             nn.Conv2d(out_channels , out_channels , kernel_size=3, stride=1, padding=d_rate, dilation=d_rate),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels , out_channels , kernel_size=3, stride=1, padding=d_rate*2, dilation=d_rate*2),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             )
#         self.layer_out = nn.Sequential(
#             nn.Conv2d(out_channels, 3 , kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(3),
#             nn.ReLU(inplace=True)
#             )

#     def forward(self, x):
#         feature_size = x.shape[-2:]
#         #head
#         out =x + self.layer1_1(x)
#         #scale
#         out = F.interpolate(out, size=(math.ceil(feature_size[0]*self.scale), math.ceil(feature_size[1]*self.scale)), mode = 'bilinear' )
#         return out, self.layer_out(out)


class UpSample(nn.Module):
    """Bilinear resize to a fixed size, then 3x3 conv -> BatchNorm -> ReLU.

    Args:
        in_channels: channels of the input feature map.
        size: target spatial size passed to ``nn.Upsample`` (int or (H, W)).
        out_channels: channels produced by the 3x3 convolution.
    """

    def __init__(self, in_channels=128, size=256, out_channels = 3):
        super().__init__()
        resize = nn.Upsample(size=size, mode='bilinear')
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        # Indices 0..3 match the original Sequential, so state_dict keys are unchanged.
        self.up_layer = nn.Sequential(resize, conv, norm, nn.ReLU(inplace=True))

    def forward(self, x):
        out = self.up_layer(x)
        return out

class RPANet(nn.Module):
    """Encoder/decoder network mapping a 3-channel image to a 3-channel output.

    Encoder: a stride-2 ``head`` conv (3 -> ``feature_size`` channels), then
    three ``ASP_conv`` stages (``down0``..``down2``), each resizing by
    ``args.scale``. Each stage also returns a 3-channel side map. That map is
    bilinearly upsampled to the previous level's size to form a skip:
    ``L_0`` at input size, ``L_1`` at the size after ``down0``, and ``L_2``
    at the size after ``down1``.

    Decoder: starting from the coarsest features, each level runs a
    ``ConvTransBlock`` and an ``UpSample`` to the next finer size
    (3 channels), adds the matching skip, and re-expands channels with
    ``general_conv``. ``layer_out`` (1x1 conv) produces the final 3 channels.

    Args:
        args: object providing ``scale`` (per-stage resize factor) and
            ``feature_size`` (channel width of encoder/transformer features).

    NOTE(review): ``up0`` resizes to the fixed ``input_size`` (512), so inputs
    must be 512 x 512 for ``L_0 + up0(...)`` to line up. ``input_size`` is
    also reused as the channel width of ``c2``/``layer_out``.
    """

    def __init__(self, args):
        super(RPANet, self).__init__()
        # Parameters
        self.in_put = 3
        self.out_put = 3
        self.scale = args.scale
        self.layers = 3
        self.size = args.feature_size      # channel width
        self.input_size = 512              # expected square input side (see NOTE)

        # Spatial side after each encoder stage, for an input_size x input_size
        # input. The stride-2, k=3, p=1 head gives (input_size - 1) // 2 + 1,
        # and each ASP_conv stage resizes to ceil(side * scale). This used to
        # be derived from feature_size (a channel count), which only matched
        # when feature_size == input_size // 2; for that case the values are
        # identical.
        side = (self.input_size - 1) // 2 + 1
        self.mid_num = []
        for _ in range(self.layers):
            side = math.ceil(side * self.scale)
            self.mid_num.append(side)

        # Encoder
        self.head = nn.Conv2d(self.in_put, self.size, kernel_size=3, stride=2, padding=1)
        self.down0 = ASP_conv(self.size, self.size, self.scale)
        self.down1 = ASP_conv(self.size, self.size, self.scale)
        self.down2 = ASP_conv(self.size, self.size, self.scale)

        # Decoder upsamplers: targets are the input size and the sizes after down0 / down1.
        self.up0 = UpSample(self.size, self.input_size, self.out_put)
        self.up1 = UpSample(self.size, self.mid_num[0], self.out_put)
        self.up2 = UpSample(self.size, self.mid_num[1], self.out_put)

        # Transformer blocks
        self.vit0 = ConvTransBlock(self.size, self.size)
        self.vit1 = ConvTransBlock(self.size, self.size)
        self.vit2 = ConvTransBlock(self.size, self.size)

        # Re-expand the 3-channel decoder maps.
        self.c0 = general_conv(3, self.size, kernel_size=3)
        self.c1 = general_conv(3, self.size, kernel_size=3)
        self.c2 = general_conv(3, self.input_size, kernel_size=3)

        self.layer_out = nn.Sequential(
            nn.Conv2d(self.input_size, 3, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x_size = x.shape[-2:]

        # Encoder: each ASP_conv stage returns (features, 3-channel side map).
        out = self.head(x)
        out, out0 = self.down0(out)
        L_0 = F.interpolate(out0, size=x_size, mode='bilinear', align_corners=False)

        out, out1 = self.down1(out)
        L_1 = F.interpolate(out1, size=out0.shape[-2:], mode='bilinear', align_corners=False)

        out, out2 = self.down2(out)
        L_2 = F.interpolate(out2, size=out1.shape[-2:], mode='bilinear', align_corners=False)

        # Decoder: coarse to fine, adding the matching skip at each level.
        out = L_2 + self.up2(self.vit0(out))
        out = self.c0(out)
        out = L_1 + self.up1(self.vit1(out))
        out = self.c1(out)
        out = L_0 + self.up0(self.vit2(out))
        out = self.c2(out)

        return self.layer_out(out)