import torch
from torch import nn
from models.deeplab.deeplab_ori.backbone.resnet import ResNet101_multiscale, ResNet50_multiscale
from models.shared import conv_block, up_conv
from models.layers import ASPP
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
import torch.nn.functional as F
from .shared import conv_block
import math


class DeepLabv3p_boost(nn.Module):
    """ResNet-101 encoder + ASPP with a three-level U-Net style decoder.

    Every decoder level owns a prediction head, so ``forward`` returns three
    maps (coarse to fine), each with ``num_classes`` channels.

    Args:
        in_ch: Number of input channels, passed to the backbone.
        num_classes: Number of output channels of each prediction head.
        **kwargs: Accepted so all models share one constructor signature;
            not used here.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet101_multiscale(in_ch, output_stride=16, pretrained=True)
        out_ch = num_classes

        # Frozen (no-grad) vector initialised to [0.5, 0.5, 0.0].  It is not
        # read anywhere in this class -- presumably per-head thresholds
        # consumed by training/evaluation code; confirm with callers.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Channel widths of the five backbone outputs; the original author's
        # note gives [3, 64, 256, 512, 2048] for this backbone.
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        # The decoder starts from e5 concatenated with its ASPP output, hence
        # the ``+ 256`` on the input width.
        self.Up5 = up_conv(filters[4] + 256, filters[3])
        self.Up_conv5 = conv_block(filters[3] + filters[3], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[2] + filters[2], filters[2])

        # Only three decoder levels are built; the two finest levels
        # (Up3/Up2 with Pred2/Pred1) are disabled in this variant.

        # One head per decoder level: conv_block down to filters[0] channels
        # followed by a 1x1 convolution to the class logits.
        self.Pred5 = nn.Sequential(
            conv_block(filters[4] + 256, filters[0], kernel_size=1),
            nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0),
        )
        self.Pred4 = nn.Sequential(
            conv_block(filters[3], filters[0]),
            nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0),
        )
        self.Pred3 = nn.Sequential(
            conv_block(filters[2], filters[0]),
            nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        """Run encoder, ASPP and decoder.

        Args:
            x: Input batch accepted by the backbone.

        Returns:
            Tuple ``(pred5, pred4, pred3)`` of per-level predictions, ordered
            from the deepest decoder level to the shallowest one built here.
        """
        # The two shallowest backbone outputs are not used by this variant.
        _e1, _e2, e3, e4, e5 = self.resnet(x)

        # Deepest level: backbone features fused with their ASPP context.
        e5 = torch.cat([e5, self.aspp(e5)], dim=1)
        pred5 = self.Pred5(e5)

        # Level 4: upsample, concatenate the e4 skip, refine, predict.
        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        pred4 = self.Pred4(d5)

        # Level 3: same pattern with the e3 skip.
        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        pred3 = self.Pred3(d4)

        return pred5, pred4, pred3


class Decoder(nn.Module):
    """Fuse a coarse feature map with a skip feature map and predict logits.

    The skip ("low level") features go through a 3x3 conv + norm + ReLU, the
    coarse ("high level") features are bilinearly resized to the same spatial
    size, both are concatenated and refined, and a small head produces
    ``num_classes`` channels.

    Args:
        num_classes: Output channels of the prediction head.
        high_level_inplanes: Channels of the coarse input.
        low_level_inplanes: Channels of the skip input.
        low_dim: Channels the skip input is projected to before fusion.
        out_dim: Channels of the fused intermediate features.
        rm_BN: Currently has no effect; kept for interface compatibility.
        BatchNorm: Normalisation layer class, called with a channel count.
    """

    def __init__(self, num_classes, high_level_inplanes, low_level_inplanes, low_dim=48, out_dim=256, rm_BN=False, BatchNorm=SynchronizedBatchNorm2d):
        super().__init__()

        # Projection of the skip features.
        self.conv1 = nn.Conv2d(low_level_inplanes, low_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = BatchNorm(low_dim)
        self.relu = nn.ReLU()

        # Refinement of the concatenated (coarse + projected skip) features.
        fused_channels = high_level_inplanes + low_dim
        self.last_conv = nn.Sequential(
            nn.Conv2d(fused_channels, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(out_dim),
            nn.ReLU(),
        )

        # Prediction head on top of the refined features.
        head_layers = [
            nn.Dropout(0.5),
            nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(out_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(out_dim, num_classes, kernel_size=1, stride=1),
        ]
        self.last_conv2 = nn.Sequential(*head_layers)

        self._init_weight()

    def forward(self, x_high_level, low_level_feat):
        """Return ``(mid, logits)``.

        ``mid`` is the refined fused feature map (``out_dim`` channels) and
        ``logits`` is the head output (``num_classes`` channels); both have
        the spatial size of ``low_level_feat``.
        """
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))

        # Bring the coarse features to the skip resolution before fusing.
        coarse = F.interpolate(x_high_level, size=skip.shape[2:], mode='bilinear', align_corners=True)

        mid = self.last_conv(torch.cat((coarse, skip), dim=1))
        return mid, self.last_conv2(mid)

    def _init_weight(self):
        """Kaiming-normal init for convolutions; unit scale / zero shift for norms."""
        norm_types = (SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, norm_types):
                module.weight.data.fill_(1)
                module.bias.data.zero_()


class DeepLabv3p_boost_similar(nn.Module):
    """ResNet-101 + ASPP encoder with a chain of three ``Decoder`` heads.

    Each ``Decoder`` receives the previous level's intermediate features and
    one backbone skip feature, and yields new intermediate features plus a
    prediction.  ``forward`` returns the three predictions, coarse to fine.

    Args:
        in_ch: Number of input channels, passed to the backbone.
        num_classes: Number of output channels of each prediction.
        **kwargs: Accepted so all models share one constructor signature;
            not used here.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet101_multiscale(in_ch, output_stride=16, pretrained=False)
        out_ch = num_classes

        # Frozen (no-grad) vector initialised to [0.5, 0.5, 0.0].  It is not
        # read anywhere in this class -- presumably per-head thresholds
        # consumed by training/evaluation code; confirm with callers.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Channel widths of the five backbone outputs; the original author's
        # note gives [3, 64, 256, 512, 2048] for this backbone.
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        # high_level_inplanes of each head equals out_dim of the head before
        # it (256 from ASPP for the first one).
        self.Pred5 = Decoder(out_ch, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(out_ch, high_level_inplanes=256, low_dim=512, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(out_ch, high_level_inplanes=128, low_dim=256, out_dim=128, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)`` for input batch ``x``."""
        # The two shallowest backbone outputs are not used by this variant.
        _e1, _e2, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)

        # Each head consumes the previous head's intermediate features.
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _mid_e3, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3

class DeepLabv3p_boost_res50_similar(DeepLabv3p_boost_similar):
    """``DeepLabv3p_boost_similar`` with the backbone swapped for ResNet-50.

    Constructor arguments are those of the parent class.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super(DeepLabv3p_boost_res50_similar, self).__init__(in_ch, num_classes, **kwargs)
        # NOTE(review): the parent constructor has already built a ResNet-101
        # and sized ASPP / the decoders from *its* ``out_s``; that backbone is
        # discarded here.  This relies on ResNet-50 exposing the same channel
        # widths -- confirm against the backbone module.
        backbone = ResNet50_multiscale(in_ch, output_stride=16, pretrained=False)
        self.resnet = backbone


class DeepLabv3p_boost_similar_base(nn.Module):
    """ResNet-101 + ASPP encoder with three uniform-width ``Decoder`` heads.

    Same topology as a chain of three ``Decoder`` modules fed by the backbone
    skips, but every head uses ``low_dim=256`` and ``out_dim=256``.
    ``forward`` returns the three predictions, coarse to fine.

    Args:
        in_ch: Number of input channels, passed to the backbone.
        num_classes: Number of output channels of each prediction.
        **kwargs: Accepted so all models share one constructor signature;
            not used here.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet101_multiscale(in_ch, output_stride=16, pretrained=False)
        out_ch = num_classes

        # Frozen (no-grad) vector initialised to [0.5, 0.5, 0.0].  It is not
        # read anywhere in this class -- presumably per-head thresholds
        # consumed by training/evaluation code; confirm with callers.
        self.ts = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.ts[:2] = 0.5

        # Channel widths of the five backbone outputs; the original author's
        # note gives [3, 64, 256, 512, 2048] for this backbone.
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        # All heads take and emit 256-channel intermediate features.
        self.Pred5 = Decoder(out_ch, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(out_ch, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(out_ch, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[2])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3)`` for input batch ``x``."""
        # The two shallowest backbone outputs are not used by this variant.
        _e1, _e2, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)

        # Each head consumes the previous head's intermediate features.
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        _mid_e3, pred3 = self.Pred3(mid_e4, e3)

        return pred5, pred4, pred3


class DeepLabv3p_boost_5similar(nn.Module):
    """ResNet-101 + ASPP encoder with a chain of five ``Decoder`` heads.

    All five backbone outputs are used as skips.  The shallowest one is first
    passed through a small stack of 5x5 convolutions (``conv_e1``) before it
    enters the last head.  ``forward`` returns five predictions, coarse to
    fine.

    Args:
        in_ch: Number of input channels, passed to the backbone.
        num_classes: Number of output channels of each prediction.
        **kwargs: Accepted so all models share one constructor signature;
            not used here.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.resnet = ResNet101_multiscale(in_ch, output_stride=16, pretrained=True)
        out_ch = num_classes
        BatchNorm = nn.BatchNorm2d

        # Frozen (no-grad) vector of five 0.5 entries.  It is not read
        # anywhere in this class -- presumably per-head thresholds consumed
        # by training/evaluation code; confirm with callers.
        self.ts = nn.Parameter(torch.zeros(5), requires_grad=False)
        self.ts[:] = 0.5

        # Channel widths of the five backbone outputs; the original author's
        # note gives [3, 64, 256, 512, 2048] for this backbone.
        filters = self.resnet.out_s
        self.aspp = ASPP(filters[-1], 256)

        # high_level_inplanes of each head equals out_dim of the head before
        # it (256 from ASPP for the first one; Decoder's default out_dim is
        # 256 as well).
        self.Pred5 = Decoder(out_ch, high_level_inplanes=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(out_ch, high_level_inplanes=256, out_dim=256, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(out_ch, high_level_inplanes=256, out_dim=128, low_level_inplanes=filters[2])
        self.Pred2 = Decoder(out_ch, high_level_inplanes=128, out_dim=32, low_level_inplanes=filters[1])

        # Lifts the shallowest backbone output to 32 channels with three
        # 5x5 conv + BN + ReLU stages (spatial size preserved by padding=2).
        self.conv_e1 = nn.Sequential(
            nn.Conv2d(filters[0], 32, kernel_size=5, stride=1, padding=2, bias=False),
            BatchNorm(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2, bias=False),
            BatchNorm(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2, bias=False),
            BatchNorm(32),
            nn.ReLU(),
        )

        self.Pred1 = Decoder(out_ch, high_level_inplanes=32, low_dim=32, out_dim=32, low_level_inplanes=32)

    def forward(self, x):
        """Return ``(pred5, pred4, pred3, pred2, pred1)`` for input batch ``x``."""
        e1, e2, e3, e4, e5 = self.resnet(x)

        aspp_e5 = self.aspp(e5)

        # Each head consumes the previous head's intermediate features.
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        mid_e3, pred3 = self.Pred3(mid_e4, e3)
        mid_e2, pred2 = self.Pred2(mid_e3, e2)

        # The shallowest skip is enriched before the final fusion.
        e1 = self.conv_e1(e1)
        _mid_e1, pred1 = self.Pred1(mid_e2, e1)

        return pred5, pred4, pred3, pred2, pred1


class UNet_no_dec(nn.Module):
    """
    Encoder half of a U-Net (no decoder, no output head).

    Five ``conv_block`` stages with widths 64, 128, 256, 512, 1024, separated
    by 2x2 max-pooling.  ``forward`` returns the output of every stage.
    Paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch=3, kernel_size=3):
        super().__init__()

        base_width = 64
        # Width doubles at every stage: [64, 128, 256, 512, 1024].
        self.filters = [base_width * 2 ** level for level in range(5)]
        w1, w2, w3, w4, w5 = self.filters

        self.ks = kernel_size

        # One pooling layer in front of each of stages 2..5.
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, w1, kernel_size=self.ks)
        self.Conv2 = conv_block(w1, w2, kernel_size=self.ks)
        self.Conv3 = conv_block(w2, w3, kernel_size=self.ks)
        self.Conv4 = conv_block(w3, w4, kernel_size=self.ks)
        self.Conv5 = conv_block(w4, w5, kernel_size=self.ks)

    def forward(self, x):
        """Return the five stage outputs ``(e1, e2, e3, e4, e5)``, shallow to deep."""
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))
        return e1, e2, e3, e4, e5


class UNet_boost_similar(nn.Module):
    """U-Net encoder + ASPP with a chain of five ``Decoder`` heads.

    Each ``Decoder`` receives the previous level's intermediate features and
    one encoder skip feature.  ``forward`` returns five predictions, coarse
    to fine.

    Args:
        in_ch: Number of input channels, passed to the encoder.
        num_classes: Number of output channels of each prediction.
        **kwargs: Accepted so all models share one constructor signature;
            not used here.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super().__init__()
        self.unet = UNet_no_dec(in_ch)
        out_ch = num_classes

        # Frozen (no-grad) vector of five 0.5 entries.  It is not read
        # anywhere in this class -- presumably per-head thresholds consumed
        # by training/evaluation code; confirm with callers.
        self.ts = nn.Parameter(torch.zeros(5), requires_grad=False)
        self.ts[:] = 0.5

        filters = self.unet.filters  # [64, 128, 256, 512, 1024]
        self.aspp = ASPP(filters[-1], 256)

        # high_level_inplanes of each head equals out_dim of the head before
        # it (256 from ASPP for the first one).
        self.Pred5 = Decoder(out_ch, high_level_inplanes=256, low_dim=256, out_dim=256, low_level_inplanes=filters[4])
        self.Pred4 = Decoder(out_ch, high_level_inplanes=256, low_dim=512, out_dim=128, low_level_inplanes=filters[3])
        self.Pred3 = Decoder(out_ch, high_level_inplanes=128, low_dim=256, out_dim=128, low_level_inplanes=filters[2])
        self.Pred2 = Decoder(out_ch, high_level_inplanes=128, low_dim=128, out_dim=64, low_level_inplanes=filters[1])
        self.Pred1 = Decoder(out_ch, high_level_inplanes=64, low_dim=64, out_dim=128, low_level_inplanes=filters[0])

    def forward(self, x):
        """Return ``(pred5, pred4, pred3, pred2, pred1)`` for input batch ``x``."""
        # Encoder stage outputs, shallow to deep (64, 128, 256, 512, 1024 channels).
        e1, e2, e3, e4, e5 = self.unet(x)

        aspp_e5 = self.aspp(e5)

        # Each head consumes the previous head's intermediate features.
        mid_e5, pred5 = self.Pred5(aspp_e5, e5)
        mid_e4, pred4 = self.Pred4(mid_e5, e4)
        mid_e3, pred3 = self.Pred3(mid_e4, e3)
        mid_e2, pred2 = self.Pred2(mid_e3, e2)
        _mid_e1, pred1 = self.Pred1(mid_e2, e1)

        return pred5, pred4, pred3, pred2, pred1


class Edge(nn.Module):
    """Split features into an edge half and an area half and re-fuse them.

    The first ``feat_dim // 2`` channels of the feature map feed an edge
    branch (1-channel auxiliary output), the remaining channels feed an area
    branch (``num_classes``-channel auxiliary output).  A projection of the
    input image acts as a query: its channel-wise dot product with each
    branch gives two scores per pixel, which are soft-maxed and used to
    weight the two branches before the final prediction head.

    Args:
        x_dim: Number of channels of the image tensor passed to ``forward``.
        feat_dim: Number of channels of the feature tensor; must be even.
        num_classes: Output channels of the area branch and of the head.

    Raises:
        ValueError: If ``feat_dim`` is odd or smaller than 2, since the
            feature map could not be split into two equal halves.
    """

    def __init__(self, x_dim, feat_dim, num_classes):
        super().__init__()

        # An odd width would leave the two halves with different channel
        # counts, which no forward pass could survive; fail early instead.
        if feat_dim < 2 or feat_dim % 2:
            raise ValueError(
                "feat_dim must be an even integer >= 2, got {}".format(feat_dim))

        self.feat_dim = feat_dim
        half = feat_dim // 2

        # Edge branch: features, then a 1-channel map.
        self.edge1 = conv_block(half, half, kernel_size=3)
        self.edge2 = nn.Conv2d(half, 1, kernel_size=3, padding=1)

        # Area branch: features, then a num_classes-channel map.
        self.area1 = conv_block(half, half, kernel_size=3)
        self.area2 = nn.Conv2d(half, num_classes, kernel_size=3, padding=1)

        # Head applied to the attention-weighted concatenation.
        self.output = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(feat_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(feat_dim, num_classes, kernel_size=1, stride=1))

        # Image query projected to the width of one branch.
        self.img_q = conv_block(x_dim, half, kernel_size=3)

    def forward(self, x, feat):
        """Return ``(edge, area, out)``.

        Args:
            x: Image tensor with ``x_dim`` channels; resized to ``feat``'s
                spatial size.
            feat: Feature tensor with ``feat_dim`` channels.

        Returns:
            ``edge`` (1 channel) and ``area`` (``num_classes`` channels) from
            the two branches, and ``out`` (``num_classes`` channels) from the
            head on the re-weighted features.
        """
        dim = self.feat_dim // 2
        scale = math.sqrt(dim)

        edge1 = self.edge1(feat[:, :dim])
        edge2 = self.edge2(edge1)

        area1 = self.area1(feat[:, dim:])
        area2 = self.area2(area1)

        x = F.interpolate(x, size=feat.size()[2:], mode='bilinear', align_corners=True)
        img_q = self.img_q(x)

        # Scaled dot product between the image query and each branch.
        att_e = (img_q * edge1).sum(dim=1, keepdim=True) / scale
        att_a = (img_q * area1).sum(dim=1, keepdim=True) / scale
        att = F.softmax(torch.cat((att_e, att_a), dim=1), dim=1)

        # Weight each branch out of place rather than writing into slices of
        # the concatenated tensor; the values are the same and autograd does
        # not have to track in-place view updates.
        out = torch.cat((edge1 * att[:, 0:1], area1 * att[:, 1:2]), dim=1)
        out = self.output(out)

        return edge2, area2, out

class UNet_boost_similar_3edge(UNet_boost_similar):
    """Three-level ``UNet_boost_similar`` whose finest prediction comes from ``Edge``.

    ``forward`` returns ``(pred5, pred4, pred3, edge, area)`` where ``pred3``,
    ``edge`` and ``area`` are produced by the ``Edge`` module from the input
    image and the level-3 intermediate features.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super(UNet_boost_similar_3edge, self).__init__(in_ch, num_classes, **kwargs)

        # 128 is the out_dim of the parent's Pred3, whose intermediate
        # features are what Edge receives.
        self.edge = Edge(x_dim=in_ch, feat_dim=128, num_classes=num_classes)

    def forward(self, x):
        # NOTE(review): the parent's Pred2/Pred1 heads are never called here,
        # and Pred3's own prediction is computed but discarded.
        _, _, skip3, skip4, skip5 = self.unet(x)

        feat5, pred5 = self.Pred5(self.aspp(skip5), skip5)
        feat4, pred4 = self.Pred4(feat5, skip4)
        feat3, _ = self.Pred3(feat4, skip3)

        edge, area, pred3 = self.edge(x, feat3)
        return pred5, pred4, pred3, edge, area

class UNet_boost_similar_5edge(UNet_boost_similar):
    """Five-level ``UNet_boost_similar`` whose finest prediction comes from ``Edge``.

    ``forward`` returns ``(pred5, pred4, pred3, pred2, pred1, edge, area)``
    where ``pred1``, ``edge`` and ``area`` are produced by the ``Edge`` module
    from the input image and the level-1 intermediate features.
    """

    def __init__(self, in_ch, num_classes, **kwargs):
        super(UNet_boost_similar_5edge, self).__init__(in_ch, num_classes, **kwargs)

        # 128 is the out_dim of the parent's Pred1, whose intermediate
        # features are what Edge receives.
        self.edge = Edge(x_dim=in_ch, feat_dim=128, num_classes=num_classes)

    def forward(self, x):
        # NOTE(review): Pred1's own prediction is computed but discarded; the
        # returned pred1 is Edge's output.
        skip1, skip2, skip3, skip4, skip5 = self.unet(x)

        feat5, pred5 = self.Pred5(self.aspp(skip5), skip5)
        feat4, pred4 = self.Pred4(feat5, skip4)
        feat3, pred3 = self.Pred3(feat4, skip3)
        feat2, pred2 = self.Pred2(feat3, skip2)
        feat1, _ = self.Pred1(feat2, skip1)

        edge, area, pred1 = self.edge(x, feat1)
        return pred5, pred4, pred3, pred2, pred1, edge, area
