import torch
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.modules.encoders.adapter import ResnetBlock

class ca_net(nn.Module):
    """Two stacked pre-norm Transformer decoder layers.

    The target sequence ``x`` is refined by each layer in turn, and both
    layers cross-attend to the same memory ``y``. Each layer uses 8 heads, a
    1024-wide feed-forward block, SiLU activation, ``batch_first=True`` (so
    tensors are laid out as (batch, seq, d_model)) and ``norm_first=True``.
    """

    def __init__(self, d_model):
        super().__init__()

        def decoder_layer():
            # A fresh module (and a fresh SiLU) per call, so the two layers
            # share no weights.
            return nn.TransformerDecoderLayer(
                d_model,
                8,
                1024,
                activation=nn.SiLU(),
                batch_first=True,
                norm_first=True,
            )

        self.layer1 = decoder_layer()
        self.layer2 = decoder_layer()

    def forward(self, x, y):
        """Pass ``x`` through ``layer1`` then ``layer2``, attending to ``y`` in both."""
        hidden = x
        for layer in (self.layer1, self.layer2):
            hidden = layer(hidden, y)
        return hidden

# class fusionnet(nn.Module):
#  v1
#     def __init__(self, ch=[320, 640, 1280, 1280]):
#         super().__init__()
#         self.layers_sa = nn.ModuleList([nn.Sequential(nn.TransformerEncoderLayer(ch[i],8,1024,activation=nn.SiLU(),batch_first=True,norm_first=True),
#                                     nn.TransformerEncoderLayer(ch[i],8,1024,activation=nn.SiLU(),batch_first=True,norm_first=True),
#                                     nn.TransformerEncoderLayer(ch[i],8,1024,activation=nn.SiLU(),batch_first=True,norm_first=True)) 
#                                     for i in range(len(ch))])
#         self.mask_para = nn.Parameter(torch.randn(1,64*64,8))
#         self.mask_sa = nn.Sequential(nn.TransformerEncoderLayer(ch[0]+8,8,1024,activation=nn.SiLU(),batch_first=True,norm_first=True),
#                                     nn.TransformerEncoderLayer(ch[0]+8,8,1024,activation=nn.SiLU(),batch_first=True,norm_first=True),
#                                     nn.TransformerEncoderLayer(ch[0]+8,8,1024,activation=nn.SiLU(),batch_first=True,norm_first=True)) 
#         self.linear = nn.Conv2d(8,1,3,1,1)

#     def forward(self, adapter_feature, ori_img_feature):
#         '''
#          adapter_feature: Sketch features : [bs,320,64,64], [bs,640,32,32], [bs,1280,16,16], [bs,1280,8,8]
#          ori_img_feature: ori_image + noise feature: [bs,320,64,64], [bs,640,32,32], [bs,1280,16,16], [bs,1280,8,8]
#          return:
#          [h_edge, h_fea, fus_layer]
#         '''
#         hs = []
#         for i, (layer_sa) in enumerate(self.layers_sa):
#             b1,c1,h1,w1 = adapter_feature[i].shape
#             b2,c2,h2,w2 = ori_img_feature[i].shape
#             assert c1==c2, "The dim c of Adapter_fearture mismatches ori_image_feature"
#             ad_f = adapter_feature[i].view(b1,c1,h1*w1).permute(0,2,1)
#             ori_f = ori_img_feature[i].view(b2,c2,h2*w2).permute(0,2,1)
#             h_edge = layer_sa(ori_f+ad_f)
#             h_edge = h_edge.permute(0,2,1).view(b2,c2,h2,w2)

#             if i==0:
#                 tmp = torch.cat([ori_f+ad_f, self.mask_para.repeat(b1,1,1)], dim=-1)
#                 h_mask = self.mask_sa(tmp)[:,:,-8:]
#                 h_mask = h_mask.permute(0,2,1).view(b1,8,h1,w1)
#                 h_mask = torch.sigmoid(self.linear(h_mask))

#             hs.append([h_edge,h_mask])
#             #hs.append([adapter_feature[i],ori_img_feature[i]])
#         return hs
    
class fusionnet(nn.Module):
    """Fuse adapter (sketch) features with original-image features, level by level.

    For each level ``i`` the two feature maps are summed, flattened into a
    token sequence and passed through three pre-norm
    ``nn.TransformerEncoderLayer`` s of width ``ch[i]``. At level 0 a learned
    embedding ``mask_para`` (``mask_dim`` channels per token) is concatenated
    to every token before the encoder. Those extra output channels are then
    split off and turned into a single-channel sigmoid mask by a 3x3
    convolution. That one mask tensor is returned alongside every level's
    fused features.

    Args:
        ch: Encoder width per level. ``ch[0]`` must equal the level-0 feature
            channel count plus ``mask_dim`` (default 320 + 16). Every other
            entry equals that level's feature channel count.
        mask_dim: Number of learned mask channels appended at level 0.
        mask_tokens: Number of level-0 tokens (H * W) the learned mask
            embedding covers; the default 64 * 64 matches a 64x64 level-0 map.
    """

    def __init__(self, ch=(320 + 16, 640, 1280, 1280), mask_dim=16, mask_tokens=64 * 64):
        super().__init__()
        # Creation order (encoders, then mask_para, then linear) and attribute
        # names match the original, so RNG use and state_dict keys are unchanged.
        self.layers_sa = nn.ModuleList([self._encoder_stack(width) for width in ch])
        self.mask_para = nn.Parameter(torch.randn(1, mask_tokens, mask_dim))
        self.linear = nn.Conv2d(mask_dim, 1, 3, 1, 1)

    @staticmethod
    def _encoder_stack(width):
        """Return three pre-norm self-attention encoder layers of size ``width``."""
        return nn.Sequential(
            *(
                nn.TransformerEncoderLayer(
                    width, 8, 1024, activation=nn.SiLU(), batch_first=True, norm_first=True
                )
                for _ in range(3)
            )
        )

    def forward(self, adapter_feature, ori_img_feature):
        """Fuse each level and predict the shared level-0 mask.

        Args:
            adapter_feature: Per-level sketch features, indexed by level.
                With the defaults these are [bs,320,64,64], [bs,640,32,32],
                [bs,1280,16,16] and [bs,1280,8,8].
            ori_img_feature: Per-level original-image (+ noise) features with
                the same channel counts as ``adapter_feature``.

        Returns:
            A list with one ``[h_edge, h_mask]`` pair per level in
            ``layers_sa``. ``h_edge`` has the level's feature channel count
            and spatial size. ``h_mask`` is the level-0 mask of shape
            [bs,1,H0,W0] with values in (0, 1); the same tensor appears in
            every pair.

        Raises:
            ValueError: If a level's adapter and image channel counts differ,
                or if the level-0 token count (H0 * W0) does not match
                ``mask_para``.
        """
        hs = []
        h_mask = None  # set at level 0, which is always processed first
        for i, layer_sa in enumerate(self.layers_sa):
            ad = adapter_feature[i]
            ori = ori_img_feature[i]
            _, c1, _, _ = ad.shape
            _, c2, h2, w2 = ori.shape
            if c1 != c2:
                raise ValueError(
                    "The dim c of Adapter_feature mismatches ori_image_feature "
                    f"at level {i}: {c1} vs {c2}"
                )
            # [B, C, H, W] -> [B, H*W, C]. flatten/reshape copy only when needed,
            # so unlike view they also accept maps whose H/W can't merge in place.
            tokens = ad.flatten(2).transpose(1, 2) + ori.flatten(2).transpose(1, 2)
            bs, n_tokens, _ = tokens.shape
            if i == 0:
                if n_tokens != self.mask_para.shape[1]:
                    raise ValueError(
                        f"Level-0 features have {n_tokens} tokens but mask_para "
                        f"expects {self.mask_para.shape[1]}"
                    )
                # Append the learned mask channels to every level-0 token.
                tokens = torch.cat([tokens, self.mask_para.expand(bs, -1, -1)], dim=-1)
            h_edge = layer_sa(tokens)
            h_edge = h_edge.transpose(1, 2).reshape(bs, -1, h2, w2)
            if i == 0:
                # The first c1 channels are features; the rest are the mask
                # channels. Splitting at c1 (not a hard-coded 320) keeps this
                # correct for any ch[0].
                h_mask = torch.sigmoid(self.linear(h_edge[:, c1:]))
                h_edge = h_edge[:, :c1]
            hs.append([h_edge, h_mask])
        return hs