#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jul  9 14:41:15 2020

@author: zw
"""

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.modules.batchnorm import _BatchNorm
from torch import nn, einsum
from torch.utils import model_zoo

import numpy as np
import math

from .resnet import *
from .resnet import BasicBlock
import torchvision
import os

############################################################################### 

class Encoding_rawres50(nn.Module):
    """ResNet-50 encoder built on torchvision's ``models.resnet50``.

    ``forward`` runs the stem (conv1 -> bn1 -> relu -> maxpool) and then the
    four residual stages, returning every stage output so that decoders can
    use them as skip connections.

    Args:
        trinum: stored as ``self.trinum``; not read anywhere in this class.
        pretrain: forwarded as ``pretrained=`` to ``models.resnet50``.
    """

    def __init__(self, trinum=1, pretrain=True):
        super(Encoding_rawres50, self).__init__()
        self.trinum = trinum
        self.resnet = models.resnet50(pretrained=pretrain)

    def forward(self, x):
        """Return ``[x1, x2, x3, x4]``, the outputs of ``layer1`` .. ``layer4``."""
        net = self.resnet
        feat = net.maxpool(net.relu(net.bn1(net.conv1(x))))

        stage_outputs = []
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            feat = stage(feat)
            stage_outputs.append(feat)
        return stage_outputs

class Encoding(nn.Module):
    """Encoder wrapping the project-local ``resnet50()`` backbone.

    The stem is conv1/bn1/relu1, followed by conv2/bn2/relu2 and conv3/bn3/relu3
    when the backbone's ``deep_base`` flag is set, and then maxpool. After the
    stem come the four residual stages. The author's annotations give the stage
    outputs as 256/512/1024/2048 channels at spatial sizes 88/44/22/11; that
    presumably holds for the input size they trained with (TODO confirm).
    """

    def __init__(self):
        super(Encoding, self).__init__()
        self.resnet = resnet50()

    def forward(self, x):
        """Return ``[x1, x2, x3, x4]``, the outputs of ``layer1`` .. ``layer4``."""
        net = self.resnet

        stem = [(net.conv1, net.bn1, net.relu1)]
        if net.deep_base:
            # deep stem: two extra conv/bn/relu triples after the first one
            stem.append((net.conv2, net.bn2, net.relu2))
            stem.append((net.conv3, net.bn3, net.relu3))
        for conv, norm, act in stem:
            x = act(norm(conv(x)))
        x = net.maxpool(x)

        features = []
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
            features.append(x)
        return features

class UpConvBlock(nn.Module):
    """Fuse a coarse feature map with a 2x-finer skip feature map.

    Both inputs pass through their own 3x3 conv -> BN -> ReLU, which reduces
    them to ``out // 4`` channels. The coarse branch is then bilinearly
    upsampled by 2, the two branches are concatenated along channels, and a
    1x1 conv -> BN -> ReLU produces ``out`` channels.

    Args:
        inp1: channels of the coarse input ``x``.
        inp2: channels of the skip input ``x_skip``.
        out: output channels. The 1x1 conv expects ``out // 2`` input channels,
            which matches the concatenation (``2 * (out // 4)``) only when
            ``out`` is a multiple of 4.
    """

    def __init__(self, inp1, inp2, out):
        super(UpConvBlock, self).__init__()
        quarter = out // 4

        # coarse-input branch
        self.inp = nn.Conv2d(inp1, quarter, 3, padding=1)
        self.inp_bn = nn.BatchNorm2d(quarter)
        self.inp_relu = nn.ReLU(inplace=True)

        # skip-input branch
        self.skip = nn.Conv2d(inp2, quarter, 3, padding=1)
        self.skip_bn = nn.BatchNorm2d(quarter)
        self.skip_relu = nn.ReLU(inplace=True)

        # fusion of the concatenated branches
        self.Up = nn.Conv2d(out // 2, out, 1)
        self.Up_bn = nn.BatchNorm2d(out)
        self.Up_relu = nn.ReLU(inplace=True)

        self.upscore = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, x_skip):
        """Return the fused map at the resolution of ``x_skip`` (assumed 2x that of ``x``)."""
        coarse = self.inp_relu(self.inp_bn(self.inp(x)))
        coarse = self.upscore(coarse)

        fine = self.skip(x_skip)
        fine = self.skip_relu(self.skip_bn(fine))

        merged = torch.cat((coarse, fine), 1)
        return self.Up_relu(self.Up_bn(self.Up(merged)))

class Decoding(nn.Module):
    """Top-down decoder built from a chain of ``UpConvBlock`` stages.

    Stage ``i`` maps ``channels[i]`` input channels to ``channels[i + 1]``,
    fusing with a skip map of ``channels[i + 1]`` channels, where
    ``channels = [2048, 1024, 512, 256]``. Because of that, ``ups`` can be at
    most 3.

    Args:
        ups: number of upsampling stages.
    """

    def __init__(self, ups=3):
        super(Decoding, self).__init__()
        self.ups = ups

        channels = [2048, 1024, 512, 256]
        self.layer = nn.Sequential(*[
            UpConvBlock(channels[i], channels[i + 1], channels[i + 1])
            for i in range(ups)
        ])

    def forward(self, Enc_list):
        """Decode from the last (deepest) entry of ``Enc_list`` upwards.

        ``Enc_list[-1]`` is the starting map. The remaining entries, walked
        from the end towards the front, serve as skip inputs.

        Returns:
            ``[Enc_list[-1], stage_1_out, ..., stage_ups_out]``, deepest first.
        """
        current = Enc_list[-1]
        skips = list(reversed(Enc_list[:-1]))

        decoded = [current]
        for idx in range(self.ups):
            current = self.layer[idx](current, skips[idx])
            decoded.append(current)
        return decoded

class OutConvBlock(nn.Module):
    """Prediction head: a 1x1 conv to ``outch`` channels, then bilinear upsampling.

    Args:
        inp: input channels.
        upscale: spatial upsampling factor applied after the conv.
        outch: output channels.
    """

    def __init__(self, inp, upscale, outch):
        super(OutConvBlock, self).__init__()
        self.outch = outch
        self.outconv = nn.Conv2d(inp, outch, 1)
        self.upscore = nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return the upsampled ``outch``-channel prediction for ``x``."""
        logits = self.outconv(x)
        return self.upscore(logits)

class Out(nn.Module):
    """A set of ``OutConvBlock`` side-output heads, one per decoder level.

    Head ``i`` takes ``[2048, 1024, 512, 256][i]`` channels and upsamples by
    ``[32, 16, 8, 4][i]``. This matches the deepest-first order produced by
    ``Decoding``.

    Args:
        outnum: number of heads (at most 4).
        outch: output channels of every head.
    """

    def __init__(self, outnum, outch=1):
        super(Out, self).__init__()
        self.outnum = outnum
        scale = [32, 16, 8, 4]
        inp_channel = [2048, 1024, 512, 256]
        self.layer = nn.Sequential(*[
            OutConvBlock(inp_channel[i], scale[i], outch) for i in range(outnum)
        ])

    def forward(self, out_list):
        """Apply head ``i`` to ``out_list[i]`` and return the results in reverse order."""
        predictions = [self.layer[i](out_list[i]) for i in range(self.outnum)]
        predictions.reverse()
        return predictions

# Global-Interaction
class GI(nn.Module):
    """Global-Interaction: lambda-layer style mixing over a single feature map.

    A 1x1 conv projects ``x`` into queries (``Nh`` heads of depth ``dk``),
    keys (``du`` x ``dk``) and values (``du`` x ``dv``, with ``dv = d // Nh``).
    Keys are softmax-normalised over the flattened spatial axis. The queries
    are applied to two kinds of lambda:

    - a content lambda: keys x values summed over positions, shared by all positions;
    - per-position lambdas: a ``(1, r, r)`` Conv3d over the value maps.

    The two results are summed and reshaped to ``(N, C, H, W)``. The reshape
    only works when the input has ``d`` channels.

    Args:
        d: input/output channels; must be divisible by ``Nh``.
        dk: query/key depth.
        du: size of the ``u`` axis shared by keys and values.
        Nh: number of query heads.
        m: unused.
        r: kernel size of the position-lambda conv. The padding is
            ``(r - 1) // 2``, so H and W are preserved only for odd ``r``.
        stride: 1, or 2 to average-pool the output (kernel 3, stride 2, padding 1).
    """

    def __init__(self, d, dk=16, du=1, Nh=4, m=None, r=23, stride=1):
        super(GI, self).__init__()
        self.d = d
        self.dk = dk
        self.du = du
        self.Nh = Nh
        assert d % Nh == 0, 'd should be divided by Nh'
        dv = d // Nh
        self.dv = dv
        assert stride in [1, 2]
        self.stride = stride

        pad = (r - 1) // 2
        self.conv_qkv = nn.Conv2d(d, Nh * dk + dk * du + dv * du, 1, bias=False)
        self.norm_q = nn.BatchNorm2d(Nh * dk)
        self.norm_v = nn.BatchNorm2d(dv * du)
        self.softmax = nn.Softmax(dim=-1)
        self.lambda_conv = nn.Conv3d(du, dk, (1, r, r), padding=(0, pad, pad))

        if self.stride > 1:
            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Return the mixed map, shape ``(N, C, H, W)``, or avg-pooled when ``stride == 2``."""
        N, C, H, W = x.shape
        HW = H * W
        heads, dk, du, dv = self.Nh, self.dk, self.du, self.dv

        query, key, value = torch.split(
            self.conv_qkv(x), [heads * dk, dk * du, dv * du], dim=1)
        query = self.norm_q(query).view(N, heads, dk, HW)
        value = self.norm_v(value).view(N, du, dv, HW)
        # keys: softmax over the last axis, i.e. over the H*W positions
        key = self.softmax(key.view(N, du, dk, HW))

        # content lambda: one (dk x dv) matrix per sample
        content_lambda = torch.einsum('bukm,buvm->bkv', key, value)
        content_out = torch.einsum('bhkm,bkv->bhvm', query, content_lambda)

        # position lambdas: one (dk x dv) matrix per spatial position
        position_lambda = self.lambda_conv(value.view(N, du, dv, H, W)).view(N, dk, dv, HW)
        position_out = torch.einsum('bhkm,bkvm->bhvm', query, position_lambda)

        mixed = (content_out + position_out).reshape(N, C, H, W)
        return self.avgpool(mixed) if self.stride > 1 else mixed

# Co-Local-Interaction
class CLI(nn.Module):
    """Co-Local-Interaction: cross lambda mixing between two feature maps.

    Each input has its own 1x1 q/k/v projection and its own position-lambda
    Conv3d. The BN layers ``norm_q`` and ``norm_v`` are shared by both inputs.
    The interaction is crossed: queries from ``LR`` read the lambdas built from
    ``Covx``, and queries from ``Covx`` read the lambdas built from ``LR``. The
    two results are concatenated along channels, giving ``2 * C`` output
    channels.

    Args:
        d: channels of each input; must be divisible by ``Nh``.
        dk: query/key depth.
        du: size of the ``u`` axis shared by keys and values.
        Nh: number of query heads.
        m: unused.
        r: kernel size of the position-lambda convs (odd ``r`` preserves H and W).
        stride: 1, or 2 to average-pool both results before concatenation.
    """

    def __init__(self, d, dk=16, du=1, Nh=4, m=None, r=5, stride=1):
        super(CLI, self).__init__()
        self.d = d
        self.dk = dk
        self.du = du
        self.Nh = Nh
        assert d % Nh == 0, 'd should be divided by Nh'
        dv = d // Nh
        self.dv = dv
        assert stride in [1, 2]
        self.stride = stride

        qkv_channels = Nh * dk + dk * du + dv * du
        pad = (r - 1) // 2
        self.conv_qkv_Covx = nn.Conv2d(d, qkv_channels, 1, bias=False)
        self.conv_qkv_LR = nn.Conv2d(d, qkv_channels, 1, bias=False)
        # shared between the two inputs, so their running statistics mix both
        self.norm_q = nn.BatchNorm2d(Nh * dk)
        self.norm_v = nn.BatchNorm2d(dv * du)
        self.softmax = nn.Softmax(dim=-1)
        self.lambda_conv_Covx = nn.Conv3d(du, dk, (1, r, r), padding=(0, pad, pad))
        self.lambda_conv_LR = nn.Conv3d(du, dk, (1, r, r), padding=(0, pad, pad))

        if self.stride > 1:
            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def _split_qkv(self, conv, feat, N, HW):
        """Project ``feat`` with ``conv`` and return normalised ``(q, k, v)``."""
        q, k, v = torch.split(
            conv(feat), [self.Nh * self.dk, self.dk * self.du, self.dv * self.du], dim=1)
        q = self.norm_q(q).view(N, self.Nh, self.dk, HW)
        v = self.norm_v(v).view(N, self.du, self.dv, HW)
        k = self.softmax(k.view(N, self.du, self.dk, HW))
        return q, k, v

    def _apply_lambdas(self, q, k, v, lambda_conv, shape):
        """Apply queries ``q`` to the content and position lambdas of ``(k, v)``."""
        N, C, H, W = shape
        content = torch.einsum('bukm,buvm->bkv', k, v)
        yc = torch.einsum('bhkm,bkv->bhvm', q, content)
        position = lambda_conv(v.view(N, self.du, self.dv, H, W)).view(N, self.dk, self.dv, H * W)
        yp = torch.einsum('bhkm,bkvm->bhvm', q, position)
        return (yc + yp).reshape(N, C, H, W)

    def forward(self, Covx, LR):
        """Return ``cat(out_Covx, out_LR)`` along channels.

        ``LR`` is assumed to have the same shape as ``Covx``, which supplies
        N, C, H and W.
        """
        shape = Covx.shape
        N, _, H, W = shape

        # Covx is projected before LR, which keeps the shared-BN update order.
        q_Covx, k_Covx, v_Covx = self._split_qkv(self.conv_qkv_Covx, Covx, N, H * W)
        q_LR, k_LR, v_LR = self._split_qkv(self.conv_qkv_LR, LR, N, H * W)

        # crossed: each branch's lambdas are read by the *other* branch's queries
        out_Covx = self._apply_lambdas(q_LR, k_Covx, v_Covx, self.lambda_conv_Covx, shape)
        out_LR = self._apply_lambdas(q_Covx, k_LR, v_LR, self.lambda_conv_LR, shape)

        if self.stride > 1:
            out_Covx = self.avgpool(out_Covx)
            out_LR = self.avgpool(out_LR)

        return torch.cat((out_Covx, out_LR), 1)

# IEM
class IEM(nn.Module):
    """Fuse a main feature map with the convex-branch and order-branch features.

    Each of the three inputs (all ``inp`` channels) is reduced to
    ``inp // rate`` channels by its own 3x3 conv -> BN -> ReLU. The reduced
    convex and order maps are combined by ``CLI``, which returns
    ``2 * (inp // rate)`` channels. That result is concatenated with the reduced
    main map, mixed by ``GI`` over ``3 * (inp // rate)`` channels, and projected
    back to ``inp`` channels by a 1x1 conv -> BN -> ReLU.

    ``CLI`` and ``GI`` require ``inp // rate`` and ``3 * (inp // rate)`` to be
    divisible by their head count (4 by default).

    Args:
        inp: channels of each input and of the output.
        rate: channel reduction factor for the three input branches.
    """

    def __init__(self, inp, rate=32):
        super(IEM, self).__init__()
        self.x_tri = nn.Conv2d(inp, inp//rate, 3, padding=1)
        self.x_tri_bn = nn.BatchNorm2d(inp//rate)
        self.x_tri_relu = nn.ReLU(inplace=True)

        self.conx_tri = nn.Conv2d(inp, inp//rate, 3, padding=1)
        self.conx_tri_bn = nn.BatchNorm2d(inp//rate)
        self.conx_tri_relu = nn.ReLU(inplace=True)

        self.order_tri = nn.Conv2d(inp, inp//rate, 3, padding=1)
        self.order_tri_bn = nn.BatchNorm2d(inp//rate)
        self.order_tri_relu = nn.ReLU(inplace=True)

        self.outconv = nn.Conv2d((inp//rate)*3, inp, 1)
        self.outconv_bn = nn.BatchNorm2d(inp)
        self.outconv_relu = nn.ReLU(inplace=True)

        self.CLI = CLI(d=inp//rate, r=5)
        self.GI = GI(d=(inp//rate)*3)

    def forward(self, x, conx, order):
        """Return the fused map with ``inp`` channels at the input resolution."""
        x_out = self.x_tri_relu(self.x_tri_bn(self.x_tri(x)))
        conx = self.conx_tri_relu(self.conx_tri_bn(self.conx_tri(conx)))
        # The order branch uses its own BN/ReLU. It previously reused
        # conx_tri_bn / conx_tri_relu, which left order_tri_bn unused and mixed
        # the BN statistics of two different inputs. Weights trained with the
        # old wiring never updated order_tri_bn, so such checkpoints will
        # behave differently after this fix.
        order = self.order_tri_relu(self.order_tri_bn(self.order_tri(order)))

        out_conx_order = self.CLI(conx, order)
        fuse = torch.cat((x_out, out_conx_order), 1)
        fuse = self.GI(fuse)
        return self.outconv_relu(self.outconv_bn(self.outconv(fuse)))

class FG_Bridge(nn.Module):
    """Figure-ground bridge: residually gate each feature level with an ``IEM``.

    Level ``i`` becomes ``out_list[i] * (1 + IEM_i(out_list[i], conx_list[i], order_list[i]))``
    when ``i <= fusenum``; otherwise it passes through unchanged. ``IEM_i`` is
    built for ``[2048, 1024, 512, 256][i]`` channels, so the lists are expected
    deepest-first.

    NOTE(review): the comparison is inclusive (``i <= fusenum``). With the
    default ``fusenum=4`` and at most four levels, every level is gated.
    Confirm whether ``fusenum`` is meant as a count (which would call for ``<``).

    Args:
        outnum: number of feature levels (at most 4).
        fusenum: highest level index that is gated.
    """

    def __init__(self, outnum, fusenum=4):
        super(FG_Bridge, self).__init__()
        self.outnum = outnum
        self.fusenum = fusenum
        inp_channel = [2048, 1024, 512, 256]
        self.layer = nn.Sequential(*[IEM(inp_channel[i]) for i in range(outnum)])

    def forward(self, out_list, conx_list, order_list):
        """Return the list of (possibly gated) feature maps, in input order."""
        gated = []
        for level in range(self.outnum):
            feat = out_list[level]
            if level > self.fusenum:
                gated.append(feat)
                continue
            gate = self.layer[level](feat, conx_list[level], order_list[level])
            gated.append(feat * (1 + gate))
        return gated

class BaseFGM(nn.Module):
    """Figure-ground model with a main branch plus order and convex side branches.

    One shared encoder feeds two auxiliary decoders (order, convex). Their
    outputs gate the encoder features through ``FG_Bridge``, and the gated
    features are decoded by the main branch.

    Args:
        ups: number of upsampling stages in each decoder.
        pretrain: accepted but not used. ``Encoding()`` takes no pretraining
            flag; ``Encoding_rawres50(pretrain=...)`` is the torchvision-based
            alternative.

    Note: ``final_fuse`` is created (so it appears in the state dict) but is
    never used in ``forward``.
    """

    def __init__(self, ups=3, pretrain=True):
        super(BaseFGM, self).__init__()
        # Keep this construction order: it fixes parameter-initialisation RNG order.
        self.encoding = Encoding()

        self.decoding_main = Decoding(ups=ups)
        self.out_main = Out(outnum=ups+1)
        self.final_fuse = nn.Conv2d(4, 1, 1, padding=0)

        self.decoding_order = Decoding(ups=ups)
        self.decoding_con = Decoding(ups=ups)

        self.fg_bridge = FG_Bridge(outnum=ups+1)

        # order head: 3x3 conv (256 -> 32) -> BN -> ReLU -> 1x1 conv (32 -> 2)
        self.outorder = nn.Conv2d(256, 32, 3, padding=1)
        self.outorder_bn = nn.BatchNorm2d(32)
        self.outorder_relu = nn.ReLU(inplace=True)
        self.outorderlast = nn.Conv2d(32, 2, 1, padding=0)

        # convex head: same layout as the order head
        self.outconx = nn.Conv2d(256, 32, 3, padding=1)
        self.outconx_bn = nn.BatchNorm2d(32)
        self.outconx_relu = nn.ReLU(inplace=True)
        self.outconxlast = nn.Conv2d(32, 2, 1, padding=0)

        self.outup = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)

    def _side_output(self, feat, conv, norm, act, last):
        """Run one side head (conv -> norm -> act -> last), upsample x4, apply sigmoid."""
        hidden = act(norm(conv(feat)))
        return torch.sigmoid(self.outup(last(hidden)))

    def forward(self, x):
        """Run the full model.

        The author's annotated encoder stage shapes (for their input size) are
        [256, 88, 88], [512, 44, 44], [1024, 22, 22] and [2048, 11, 11].

        Returns:
            ``[main_preds, [order_pred], [convex_pred]]``. ``main_preds`` holds
            the sigmoid side outputs of the main decoder, in the order returned
            by ``Out``. The order and convex predictions are sigmoid maps
            computed from the last (finest) decoder output of their branch.
        """
        enc_feats = self.encoding(x)  # shallow -> deep

        order_feats = self.decoding_order(enc_feats)  # deep -> shallow
        convex_feats = self.decoding_con(enc_feats)  # deep -> shallow

        # The bridge works deepest-first, so flip the encoder list in and out.
        gated = self.fg_bridge(list(reversed(enc_feats)), convex_feats, order_feats)
        gated = list(reversed(gated))

        main_feats = self.decoding_main(gated)
        main_preds = [torch.sigmoid(p) for p in self.out_main(main_feats)]

        order_pred = self._side_output(
            order_feats[3], self.outorder, self.outorder_bn, self.outorder_relu, self.outorderlast)
        convex_pred = self._side_output(
            convex_feats[3], self.outconx, self.outconx_bn, self.outconx_relu, self.outconxlast)

        return [main_preds, [order_pred], [convex_pred]]
