import torch
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx
from .arch_util import LayerNorm2d, DownSample, UpSample, FilterLow
from einops import rearrange


# Multi-scale Consistency Calibration Module
class SpatialAttention(nn.Module):
    """Jointly gate two same-shaped feature maps with per-pixel sigmoid masks.

    Both inputs are concatenated along channels and fused by a 3x3 conv
    (LeakyReLU). Two separate 3x3 convs then predict one sigmoid gate per
    input.

    Args:
        dim: channel count of each input (the fusing conv sees ``2 * dim``).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1, bias=True)
        self.conv2_1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=True)
        self.conv2_2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=True)
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        """Return ``(x1 * gate1, x2 * gate2)``; both inputs are (b, dim, h, w)."""
        f_cat = torch.cat([x1, x2], dim=1)
        feats = self.relu(self.conv1(f_cat))
        attn1 = self.sigmoid(self.conv2_1(feats))
        attn2 = self.sigmoid(self.conv2_2(feats))
        return x1 * attn1, x2 * attn2


class MICM(nn.Module):
    """Multi-scale calibration of two feature streams ``x`` and ``y``.

    Each stream is average-pooled with kernel/stride 8, 4 and 2 (coarse to
    fine). At every scale, the pooled features (plus the result of the
    coarser scale, resized to match) go through a 3x3 conv. Both streams
    are then gated jointly by a ``SpatialAttention``. The gated maps are
    upsampled to full resolution and summed with ``in_x(x)`` / ``in_y(y)``.
    That sum passes through GELU and a 3x3 conv and is added back to the
    respective input.

    Args:
        dim: channel count of both streams.
        bias: whether the 3x3 convolutions carry a bias term.
    """

    def __init__(self, dim, bias):
        super(MICM, self).__init__()
        self.in_x = nn.Conv2d(dim, dim, 3, 1, 1, bias=bias)
        self.in_y = nn.Conv2d(dim, dim, 3, 1, 1, bias=bias)
        self.pools_sizes = [8,4,2]
        pools_x, pools_y, convs_x, convs_y, attns = [],[],[],[],[]
        for i in self.pools_sizes:
            pools_x.append(nn.AvgPool2d(kernel_size=i, stride=i))
            pools_y.append(nn.AvgPool2d(kernel_size=i, stride=i))
            convs_x.append(nn.Conv2d(dim, dim, 3, 1, 1, bias=bias))
            convs_y.append(nn.Conv2d(dim, dim, 3, 1, 1, bias=bias))
            attns.append(SpatialAttention(dim))
        self.pools_x = nn.ModuleList(pools_x)
        self.pools_y = nn.ModuleList(pools_y)
        self.convs_x = nn.ModuleList(convs_x)
        self.convs_y = nn.ModuleList(convs_y)
        self.attns = nn.ModuleList(attns)
        self.relu = nn.GELU()  # NOTE: a GELU despite the attribute name
        self.sum_x = nn.Conv2d(dim, dim, 3, 1, 1, bias=bias)
        self.sum_y = nn.Conv2d(dim, dim, 3, 1, 1, bias=bias)

    def forward(self, x, y):
        """Calibrate ``x`` and ``y`` against each other.

        Args:
            x: (b, dim, h, w) features; ``h`` and ``w`` must be >= 8 (the
                largest pool size). Any size >= 8 works.
            y: (b, dim, h, w) features with the same spatial size as ``x``
                (its multi-scale terms are resized to ``x``'s size).

        Returns:
            ``(x_out, y_out)`` with the same shapes as the inputs.
        """
        out_size = x.shape[2:]
        res_x = self.in_x(x)
        res_y = self.in_y(y)
        prev_x = prev_y = None
        levels = zip(self.pools_x, self.pools_y, self.convs_x, self.convs_y, self.attns)
        for pool_x, pool_y, conv_x, conv_y, attn in levels:
            pooled_x = pool_x(x)
            pooled_y = pool_y(y)
            if prev_x is not None:
                # Resize the coarser level to this level's exact pooled size.
                # A fixed scale_factor=2 only lines up when h, w are ~multiples
                # of 8. Where it does line up, this gives identical values:
                # with align_corners=True the sampling grid depends only on the
                # input/output sizes.
                pooled_x = pooled_x + F.interpolate(prev_x, size=pooled_x.shape[2:], mode='bilinear', align_corners=True)
                pooled_y = pooled_y + F.interpolate(prev_y, size=pooled_y.shape[2:], mode='bilinear', align_corners=True)
            x_, y_ = attn(conv_x(pooled_x), conv_y(pooled_y))
            res_x = res_x + F.interpolate(x_, size=out_size, mode='bilinear', align_corners=True)
            res_y = res_y + F.interpolate(y_, size=out_size, mode='bilinear', align_corners=True)
            prev_x, prev_y = x_, y_
        res_x = x + self.sum_x(self.relu(res_x))
        res_y = y + self.sum_y(self.relu(res_y))

        return res_x, res_y
    
def to_3d(x):
    """Flatten a (b, c, h, w) feature map into a (b, h*w, c) token sequence."""
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)

def to_4d(x, h, w):
    """Inverse of ``to_3d``: reshape (b, h*w, c) tokens back into (b, c, h, w)."""
    pattern = 'b (h w) c -> b c h w'
    return rearrange(x, pattern, w=w, h=h)

class DropPath(nn.Dropout):
    """Stochastic depth: drop an entire residual branch per sample.

    Reuses ``nn.Dropout``'s ``p`` and ``training`` flag. In training mode it
    draws a per-sample mask of shape ``(batch, 1, ..., 1)`` with
    ``F.dropout``. Each sample's input is therefore either zeroed or scaled
    by ``1 / (1 - p)``. In eval mode the input is returned unchanged.
    """

    def forward(self, inputs):
        if not self.training:
            return inputs
        # One mask entry per sample, broadcast over the remaining dims. Built
        # with new_ones so it follows the input's device and dtype rather than
        # being forced onto the current CUDA device as float32.
        shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1)
        mask = F.dropout(inputs.new_ones(shape), self.p, training=True)
        return inputs * mask

class MLP(nn.Module):
    """Two-layer perceptron over the last dimension.

    Computes ``Linear -> GELU -> Dropout -> Linear -> Dropout``.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        drop: dropout probability, applied after the activation and again
            after the output layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class DFFN(nn.Module):
    """Gated feed-forward network with a learnable per-patch frequency filter.

    The pipeline is:

    1. Project to ``2 * hidden_features`` channels.
    2. Tile the map into non-overlapping ``patch_size x patch_size`` patches.
    3. Filter each tile in the frequency domain (``rfft2``, multiply by the
       learnable ``self.fft``, ``irfft2``).
    4. Apply a depthwise 3x3 conv and gate as ``gelu(x1) * x2``.
    5. Project back to ``dim`` channels.

    Args:
        dim: input/output channel count.
        hidden_features: gated hidden width; defaults to ``dim`` when ``None``.
        bias: whether the convolutions carry a bias term.

    Input height and width must be divisible by ``patch_size`` (8).
    """

    def __init__(self, dim, hidden_features=None, bias=False):
        super(DFFN, self).__init__()

        if hidden_features is None:
            hidden_features = dim

        self.patch_size = 8

        self.dim = dim
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
                                groups=hidden_features * 2, bias=bias)

        # Frequency weights (channels, 1, 1, p, p//2 + 1): broadcast over batch
        # and tile positions, one rfft2 spectrum per channel.
        self.fft = nn.Parameter(torch.ones((hidden_features * 2, 1, 1, self.patch_size, self.patch_size // 2 + 1)))
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        in_dtype = x.dtype
        x_patch = rearrange(x, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        # The FFT is computed in (at least) float32.
        x_patch_fft = torch.fft.rfft2(x_patch.float())
        x_patch_fft = x_patch_fft * self.fft
        x_patch = torch.fft.irfft2(x_patch_fft, s=(self.patch_size, self.patch_size))
        x = rearrange(x_patch, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
                      patch2=self.patch_size)
        # Cast back so half/bfloat16 models don't feed float32 activations
        # into lower-precision conv weights (a no-op for float32 inputs).
        x = x.to(in_dtype)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)

        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

class DSAttention(nn.Module):
    """Channel-wise cross attention from ``x`` to a frequency-split guidance ``y``.

    Queries come from ``x`` and keys from ``y``. Both are L2-normalised over
    pixels, so each head yields a (c x c) channel-affinity map scaled by a
    learnable per-head temperature. ``y`` is split by ``self.low_filter``
    (``FilterLow`` from ``arch_util``) into the filter output and the
    residual ``y - filter(y)``, and each part gets its own value projection.
    Two 1x1 convs mix the affinity map across heads, one per value branch.
    A single softmax over their concatenation makes the weights of both
    branches jointly sum to 1 along each row.

    NOTE(review): ``v_h`` projects the *filter output* and ``v_l`` the
    residual. If ``FilterLow`` is a low-pass filter, as its name suggests
    (confirm in arch_util), these attribute names are inverted. They are kept
    for checkpoint compatibility.

    Args:
        dim: channel count of ``x`` and ``y``; must be divisible by ``num_head``.
        num_head: number of attention heads.
        bias: bias flag for all convolutions.
    """

    def __init__(self, dim, num_head, bias):
        super(DSAttention, self).__init__()
        self.num_head = num_head
        self.low_filter = FilterLow(recursions=1, kernel_size=5, stride=1, include_pad=False)
        self.temperature = nn.Parameter(torch.ones(num_head, 1, 1), requires_grad=True)

        def pointwise_then_depthwise():
            # 1x1 projection followed by a depthwise 3x3 conv.
            return nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1, bias=bias),
                nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=bias, groups=dim),
            )

        self.q = pointwise_then_depthwise()
        self.k = pointwise_then_depthwise()
        self.v_h = pointwise_then_depthwise()
        self.v_l = pointwise_then_depthwise()

        self.fc1 = nn.Conv2d(num_head, num_head, kernel_size=1, bias=bias)
        self.fc2 = nn.Conv2d(num_head, num_head, kernel_size=1, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        assert x.shape == y.shape, 'The shape of feature maps from target and guidance branch are not equal!'

        _, _, h, w = x.shape
        filtered = self.low_filter(y)
        residual = y - filtered

        def split_heads(t):
            return rearrange(t, 'b (head c) h w -> b head c (h w)', head=self.num_head)

        query = F.normalize(split_heads(self.q(x)), dim=-1)
        key = F.normalize(split_heads(self.k(y)), dim=-1)
        value_filtered = split_heads(self.v_h(filtered))
        value_residual = split_heads(self.v_l(residual))

        affinity = query @ key.transpose(-2, -1) * self.temperature
        logits = torch.cat([self.fc1(affinity), self.fc2(affinity)], dim=-1)
        weights_filtered, weights_residual = torch.chunk(self.softmax(logits), 2, dim=-1)

        mixed = weights_filtered @ value_filtered + weights_residual @ value_residual
        mixed = rearrange(mixed, 'b head c (h w) -> b (head c) h w', head=self.num_head, h=h, w=w)
        return self.project_out(mixed)

# Spatial-Frequency Cross-Attention Transformer Block
class DAFM(nn.Module):
    """Cross-modal transformer block: DSAttention followed by a DFFN.

    ``x`` (the target / visible stream) attends to ``y`` (the guidance / NIR
    stream) through ``DSAttention`` on layer-normalised inputs. The result is
    added to ``x`` as a residual, and a pre-norm ``DFFN`` residual follows.
    Both residual branches pass through the same ``DropPath``.

    Args:
        dim: channel count of ``x`` and ``y``.
        num_heads: attention heads for ``DSAttention`` (must divide ``dim``).
        ffn_expand: DFFN hidden width as a multiple of ``dim``.
        drop_path: stochastic-depth probability; 0 disables it.
        bias: bias flag for ``DSAttention``'s convolutions.
            NOTE(review): not forwarded to ``DFFN`` (which uses bias=False);
            confirm this is intended before changing it, since it would alter
            the checkpoint layout.
    """

    def __init__(self, dim, num_heads, ffn_expand=4, drop_path=0.1, bias=False):
        super(DAFM, self).__init__()

        self.norm1_x = LayerNorm2d(dim)
        self.norm1_y = LayerNorm2d(dim)
        self.attn = DSAttention(dim, num_heads, bias=bias)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = LayerNorm2d(dim)
        ffn_expand_dim = int(dim * ffn_expand)
        self.ffn = DFFN(dim, hidden_features=ffn_expand_dim)

    def forward(self, x, y):
        """Fuse guidance ``y`` into ``x``.

        Args:
            x: visible-stream features, (b, c, h, w).
            y: NIR/guidance features with exactly the same shape as ``x``.

        Returns:
            Tensor with the shape of ``x``.
        """
        assert x.shape == y.shape, (
            f'DAFM expects x and y with identical shapes, got {tuple(x.shape)} and {tuple(y.shape)}'
        )
        attn_out = self.attn(self.norm1_x(x), self.norm1_y(y))
        mid = x + self.drop_path(attn_out)
        # Pre-norm feed-forward residual.
        return mid + self.drop_path(self.ffn(self.norm2(mid)))

##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build an ``nn.Conv2d`` padded by ``kernel_size // 2`` on each side.

    For odd kernels and ``stride == 1`` the spatial size is preserved.
    """
    half = kernel_size // 2
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=half,
        bias=bias,
    )


##########################################################################
## Channel Attention Layer
class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Global average pooling yields one value per channel. A 1x1 bottleneck
    (``channel -> channel // reduction -> channel``) with ReLU and a final
    sigmoid turns that into per-channel weights that rescale the input.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_weights = self.conv_du(self.avg_pool(x))
        return x * channel_weights


##########################################################################
## Channel Attention Block (CAB)
class CAB(nn.Module):
    """Residual conv block: ``x + conv(act(conv(x)))``.

    NOTE(review): despite the name, no channel attention is applied.
    ``reduction`` and ``ca`` are accepted but unused. ``act`` is inserted
    as-is, so one module instance passed to several blocks is shared (and,
    for e.g. PReLU, so are its parameters).
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act, ca=True):
        super(CAB, self).__init__()
        pad = kernel_size // 2
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, kernel_size, padding=pad, bias=bias, stride=1),
            act,
            nn.Conv2d(n_feat, n_feat, kernel_size, padding=pad, bias=bias, stride=1),
        )

    def forward(self, x):
        return self.body(x) + x

##########################################################################
## Supervised Attention Module
class SAM(nn.Module):
    """Supervised attention module between the two stages.

    Predicts a 3-channel residual from the features and adds it to
    ``x_img``. A sigmoid map computed from that image then gates a
    transformed copy of the features, which is added back to ``x``.
    """

    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(n_feat, n_feat, kernel_size, padding=pad, bias=bias, stride=1)
        self.conv2 = nn.Conv2d(n_feat, 3, kernel_size, padding=pad, bias=bias, stride=1)
        self.conv3 = nn.Conv2d(3, n_feat, kernel_size, padding=pad, bias=bias, stride=1)

    def forward(self, x, x_img):
        """Return ``(gated_features, img)`` where ``img = conv2(x) + x_img``."""
        img = self.conv2(x) + x_img
        gate = torch.sigmoid(self.conv3(img))
        return self.conv1(x) * gate + x, img

##########################################################################
## U-Net

class Encoder(nn.Module):
    """Three-level U-Net encoder built from CAB blocks.

    Channel widths are ``n_feat``, ``n_feat + s`` and ``n_feat + 2s``
    (``s = scale_unetfeats``). Each level's output is passed through
    ``DownSample`` before the next level.

    Args:
        csff: create the cross-stage feature fusion 1x1 convs, used when
            ``forward`` receives both ``encoder_outs`` and ``decoder_outs``.
        cross: create one ``DAFM`` per level (1, 2 and 4 heads), used when
            ``forward`` receives ``guids``.
        reduction, ca, act, bias, kernel_size: forwarded to the blocks.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff, ca, cross=False):
        super(Encoder, self).__init__()

        num = 1  # CAB blocks per level
        widths = (n_feat, n_feat + scale_unetfeats, n_feat + scale_unetfeats * 2)

        def cab_stack(channels):
            blocks = [CAB(channels, kernel_size, reduction, bias=bias, act=act, ca=ca) for _ in range(num)]
            return nn.Sequential(*blocks)

        self.encoder_level1 = cab_stack(widths[0])
        self.encoder_level2 = cab_stack(widths[1])
        self.encoder_level3 = cab_stack(widths[2])

        self.down12 = DownSample(widths[0], scale_unetfeats)
        self.down23 = DownSample(widths[1], scale_unetfeats)

        if cross:
            for idx, (channels, heads) in enumerate(zip(widths, (1, 2, 4)), start=1):
                setattr(self, f'image_event_transformer{idx}',
                        DAFM(channels, num_heads=heads, ffn_expand=4, bias=bias))

        # Cross Stage Feature Fusion (CSFF): csff_enc1..3, then csff_dec1..3.
        if csff:
            for prefix in ('csff_enc', 'csff_dec'):
                for idx, channels in enumerate(widths, start=1):
                    setattr(self, f'{prefix}{idx}', nn.Conv2d(channels, channels, kernel_size=1, bias=bias))

    def forward(self, x, encoder_outs=None, decoder_outs=None, guids=None):
        """Encode ``x`` into three feature levels.

        Args:
            x: input features with ``n_feat`` channels.
            encoder_outs, decoder_outs: per-level features from a previous
                stage, fused through the CSFF convs only when BOTH are given.
            guids: exactly three per-level guidance features for the DAFM blocks.

        Returns:
            ``[enc1, enc2, enc3]``, finest level first.
        """
        fuse = (encoder_outs is not None) and (decoder_outs is not None)
        guide_levels = None
        if guids is not None:
            guid1, guid2, guid3 = guids
            guide_levels = (guid1, guid2, guid3)
        downsamplers = (None, self.down12, self.down23)

        outs = []
        feat = x
        for lvl in range(3):
            tag = lvl + 1
            if lvl > 0:
                feat = downsamplers[lvl](outs[-1])
            enc = getattr(self, f'encoder_level{tag}')(feat)
            if fuse:
                enc = (enc
                       + getattr(self, f'csff_enc{tag}')(encoder_outs[lvl])
                       + getattr(self, f'csff_dec{tag}')(decoder_outs[lvl]))
            if guide_levels is not None:
                enc = getattr(self, f'image_event_transformer{tag}')(enc, guide_levels[lvl])
            outs.append(enc)
        return outs

class Decoder(nn.Module):
    """Three-level U-Net decoder mirroring ``Encoder``.

    Starts from the coarsest encoder feature. At each finer level it
    upsamples with ``SkipUpSample``, adds a conv-transformed encoder skip,
    and refines the result with a CAB stack.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, ca):
        super(Decoder, self).__init__()

        num = 1  # CAB blocks per level
        widths = (n_feat, n_feat + scale_unetfeats, n_feat + scale_unetfeats * 2)

        def cab_stack(channels):
            blocks = [CAB(channels, kernel_size, reduction, bias=bias, act=act, ca=ca) for _ in range(num)]
            return nn.Sequential(*blocks)

        self.decoder_level1 = cab_stack(widths[0])
        self.decoder_level2 = cab_stack(widths[1])
        self.decoder_level3 = cab_stack(widths[2])

        self.skip_attn1 = conv(widths[0], widths[0], kernel_size, bias=bias)
        self.skip_attn2 = conv(widths[1], widths[1], kernel_size, bias=bias)

        self.up21 = SkipUpSample(widths[0], scale_unetfeats)
        self.up32 = SkipUpSample(widths[1], scale_unetfeats)

    def forward(self, outs):
        """Decode ``[enc1, enc2, enc3]`` into ``[dec1, dec2, dec3]`` (finest first)."""
        enc1, enc2, enc3 = outs
        dec3 = self.decoder_level3(enc3)
        dec2 = self.decoder_level2(self.up32(dec3, self.skip_attn2(enc2)))
        dec1 = self.decoder_level1(self.up21(dec2, self.skip_attn1(enc1)))
        return [dec1, dec2, dec3]

##########################################################################
##---------- Resizing Modules ----------    
class DownSample(nn.Module):
    """Halve the spatial size (bilinear) and widen channels by ``s_factor`` (1x1 conv).

    NOTE: this module-level definition replaces the ``DownSample`` imported
    from ``.arch_util`` for any code that looks the name up at call time.
    """

    def __init__(self, in_channels, s_factor):
        super(DownSample, self).__init__()
        resize = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        widen = nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)
        self.down = nn.Sequential(resize, widen)

    def forward(self, x):
        return self.down(x)

class UpSample(nn.Module):
    """Double the spatial size (bilinear) and narrow channels from ``in_channels + s_factor`` to ``in_channels``.

    NOTE: this module-level definition replaces the ``UpSample`` imported
    from ``.arch_util`` for any code that looks the name up at call time.
    """

    def __init__(self, in_channels, s_factor):
        super(UpSample, self).__init__()
        resize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        narrow = nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(resize, narrow)

    def forward(self, x):
        return self.up(x)

class SkipUpSample(nn.Module):
    """Upsample ``x`` 2x, narrow it to ``in_channels`` and add the skip tensor ``y``.

    ``y`` must have ``in_channels`` channels and exactly twice ``x``'s
    spatial size.
    """

    def __init__(self, in_channels, s_factor):
        super(SkipUpSample, self).__init__()
        resize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        narrow = nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(resize, narrow)

    def forward(self, x, y):
        return self.up(x) + y


##########################################################################
class RFFNet_paper(nn.Module):
    """Two-stage guided restoration network.

    Stage 1 runs a plain encoder/decoder on the input image. ``SAM`` turns
    its finest output into an intermediate image plus gated features.

    Stage 2 proceeds as follows:

    - Fuse the SAM features with fresh shallow features.
    - Calibrate them against the guidance-image features (``MICM``).
    - Encode with cross-stage features from stage 1 (CSFF) and per-level
      guidance features (DAFM).
    - Decode and map back to an image residual.

    Args:
        in_c: input image channels.
        out_c: output image channels.
        aux_chn: guidance image channels.
        n_feat, scale_unetfeats, kernel_size, reduction, bias: see
            ``Encoder`` / ``Decoder``.

    NOTE(review): ``SAM`` hard-codes 3 image channels and ``forward`` adds
    ``x3_img`` back to the output, so in practice ``in_c == out_c == 3``.
    The spatial size presumably must be a multiple of 32 (DFFN's 8x8
    patches at the 1/4-resolution level) — verify before relying on it.
    """

    def __init__(self, in_c=3, out_c=3, aux_chn=1, n_feat=64, scale_unetfeats=32, kernel_size=3, reduction=4, bias=False):
        super(RFFNet_paper, self).__init__()

        act=nn.PReLU()  # a single PReLU instance shared by every CAB below
        self.shallow_feat1 = conv(in_c, n_feat, kernel_size, bias=bias)
        self.shallow_feat2 = conv(in_c, n_feat, kernel_size, bias=bias)
        self.shallow_feat_guid = conv(aux_chn, n_feat, kernel_size, bias=bias)

        self.calibrate = MICM(n_feat, bias=bias)

        # Stage 1: plain U-Net (no CSFF, no guidance).
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, ca=True, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, ca=True)

        # Stage 2: encoder with CSFF from stage 1 and DAFM guidance fusion;
        # a separate encoder produces the per-level guidance features.
        self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True, ca=False, cross=True)
        self.guid_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False, ca=False)
        self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, ca=False)

        self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
        
        self.concat12  = conv(n_feat*2, n_feat, kernel_size, bias=bias)
        self.tail     = conv(n_feat, out_c, kernel_size, bias=bias)

    def forward(self, x3_img, guid):
        """Restore ``x3_img`` using the guidance image ``guid``.

        Args:
            x3_img: degraded input image, (b, in_c, H, W).
            guid: guidance image, (b, aux_chn, H, W).

        Returns:
            ``(stage2_img + x3_img, stage1_img)``: the final prediction
            (residual added to the input) and the stage-1 SAM image estimate.
        """
        # ---------------- Stage 1 ----------------
        x1 = self.shallow_feat1(x3_img)
        feat1 = self.stage1_encoder(x1)
        res1 = self.stage1_decoder(feat1)
        # SAM: stage-1 image estimate plus attention-gated features for stage 2.
        x2_samfeats, stage1_img = self.sam12(res1[0], x3_img)

        # ---------------- Stage 2 ----------------
        x2 = self.shallow_feat2(x3_img)
        guid = self.shallow_feat_guid(guid)

        # Merge fresh shallow features with the SAM features from stage 1.
        x2_cat = self.concat12(torch.cat([x2, x2_samfeats], 1))

        # Mutually calibrate image and guidance features before encoding.
        x2_cat, guid = self.calibrate(x2_cat, guid)

        guids = self.guid_encoder(guid)

        # CSFF from stage 1 (feat1, res1) plus per-level guidance features.
        feat2 = self.stage2_encoder(x2_cat, feat1, res1, guids)
        res2 = self.stage2_decoder(feat2)

        stage2_img = self.tail(res2[0])

        return stage2_img+x3_img, stage1_img
    
if __name__ == "__main__":
    # Inference benchmark: latency per forward pass, FLOPs and parameter count.
    import time

    from thop import profile

    warmup_iters = 500
    num = 200

    # Model
    print('==> Building model..')
    # eval(): measure inference, so DropPath (active only in train mode) is off.
    model = RFFNet_paper().cuda(0).eval()

    noisy = torch.randn(1, 3, 128, 128).cuda(0)
    guid = torch.randn(1, 1, 128, 128).cuda(0)

    # no_grad: don't build (and time) autograd graphs during the benchmark.
    with torch.no_grad():
        for _ in range(warmup_iters):
            model(noisy, guid)
        # Let the queued warm-up kernels finish before the clock starts.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num):
            model(noisy, guid)
            torch.cuda.synchronize()  # per-iteration latency, as before
        time_used = time.perf_counter() - start

        flops, params = profile(model, (noisy, guid))
    print('flops: %.2f G, params: %.2f M, time: %.2f ms' % (flops / 1e9, params / 1e6, time_used / num * 1e3))