from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose
import cv2
from .layers import *

from torchvision.utils import save_image


from model.depth_anything.dpt import DepthAnything
from model.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet


class BasicConv1(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and/or GELU.

    Args:
        in_planes: input channel count of the convolution.
        out_planes: output channel count; also exposed as ``out_channels``.
        kernel_size, stride, padding, dilation, groups, bias: forwarded
            unchanged to ``nn.Conv2d``.
        gelu: when true, a GELU activation is applied as the last step.
        bn: when true, a BatchNorm2d (eps=1e-5, momentum=0.01, affine) is
            applied right after the convolution.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, gelu=False, bn=False, bias=True):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # Disabled stages are stored as None so forward() can skip them.
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None

        if gelu:
            self.gelu = nn.GELU()
        else:
            self.gelu = None

    def forward(self, x):
        """Run conv -> (batch norm) -> (GELU) and return the result."""
        y = self.conv(x)
        for stage in (self.bn, self.gelu):
            if stage is not None:
                y = stage(y)
        return y

class ChannelPool(nn.Module):
    """Reduce dim 1 to two maps: its maximum and its mean, concatenated on dim 1."""

    def forward(self, x):
        """Return a tensor whose dim 1 has size 2: [max over dim 1, mean over dim 1]."""
        peak = x.max(dim=1, keepdim=True).values
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((peak, average), dim=1)
class SpatialGate(nn.Module):
    """Spatially gated depthwise branch plus a plain depthwise branch.

    ``forward`` computes ``dw1(x) * spatial(compress(x)) + dw2(x)`` where

    * ``compress`` pools dim 1 into a max map and a mean map,
    * ``spatial`` is a 3x3 conv mapping those two maps to one map
      (no activation is applied to it),
    * ``dw1`` is two dilated depthwise convs (5x5 dilation 2, then 7x7
      dilation 3), padded so the spatial size is unchanged,
    * ``dw2`` is a 3x3 depthwise conv.

    Args:
        channel: channel count of the input; all depthwise convs use
            ``groups=channel``.
    """

    def __init__(self, channel):
        super().__init__()
        k = 3
        self.compress = ChannelPool()
        self.spatial = BasicConv1(2, 1, k, stride=1, padding=(k - 1) // 2, gelu=False)

        # Built in the same order as they appear inside the Sequential.
        dilated_5x5 = BasicConv1(channel, channel, 5, stride=1, dilation=2, padding=4, groups=channel)
        dilated_7x7 = BasicConv1(channel, channel, 7, stride=1, dilation=3, padding=9, groups=channel)
        self.dw1 = nn.Sequential(dilated_5x5, dilated_7x7)

        self.dw2 = BasicConv1(channel, channel, k, stride=1, padding=1, groups=channel)

    def forward(self, x):
        """Return ``dw1(x) * gate + dw2(x)`` with ``gate = spatial(compress(x))``."""
        gate = self.spatial(self.compress(x))
        gated = self.dw1(x) * gate
        return gated + self.dw2(x)


class LocalAttention(nn.Module):
    """Per-channel re-weighting based on the deviation from the spatial mean.

    ``forward`` returns ``a * (x - mean(x)) * x + b * x`` where the mean is
    taken over dims 2 and 3, and ``a`` / ``b`` are learnable per-channel
    scalars initialised to 0 and 1, so the module starts as the identity.

    Args:
        channel: number of channels; sets the size of ``a`` and ``b``.
        p: exponent stored as ``num_patch = 2 ** p``.

    Note: ``num_patch`` and ``sig`` are created here but are not read by
    ``forward``.
    """

    def __init__(self, channel, p) -> None:
        super().__init__()
        self.channel = channel

        self.num_patch = 2 ** p
        self.sig = nn.Sigmoid()

        # Shape (channel, 1, 1) so they broadcast over the two trailing dims.
        per_channel = (channel, 1, 1)
        self.a = nn.Parameter(torch.zeros(*per_channel))
        self.b = nn.Parameter(torch.ones(*per_channel))

    def forward(self, x):
        """Return ``a * (x - mean_{2,3}(x)) * x + b * x``."""
        spatial_mean = x.mean(dim=(2, 3), keepdim=True)
        centered = x - spatial_mean
        modulated = self.a * centered * x
        passthrough = self.b * x
        return modulated + passthrough

class ParamidAttention(nn.Module):
    """SpatialGate followed by a stack of LocalAttention, mixed with the input.

    ``forward`` returns ``a * local_attention(spatial_gate(x)) + b * x``.
    ``a`` starts at 0 and ``b`` at 1, so the module is the identity at
    initialisation.

    Args:
        channel: channel count of the input feature map.
    """

    def __init__(self, channel) -> None:
        super().__init__()
        levels = 1
        self.spatial_gate = SpatialGate(channel)

        # One LocalAttention per level, with p counting down from levels-1 to 0.
        stages = []
        for p in reversed(range(levels)):
            stages.append(LocalAttention(channel, p=p))
        self.local_attention = nn.Sequential(*stages)

        self.a = nn.Parameter(torch.zeros(channel, 1, 1))
        self.b = nn.Parameter(torch.ones(channel, 1, 1))

    def forward(self, x):
        """Return the learnable blend of the attended features and ``x``."""
        attended = self.local_attention(self.spatial_gate(x))
        return self.a * attended + self.b * x


class EBlock(nn.Module):
    """Encoder stage: ``num_res`` ResBlocks applied in sequence.

    Args:
        out_channel: passed as both arguments of every ``ResBlock``.
        num_res: number of ResBlocks in the stack.
    """

    def __init__(self, out_channel, num_res=8):
        super().__init__()
        blocks = []
        for _ in range(num_res):
            blocks.append(ResBlock(out_channel, out_channel))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the ResBlock stack to ``x``."""
        out = self.layers(x)
        return out


class DBlock(nn.Module):
    """Decoder stage: ``num_res`` ResBlocks applied in sequence.

    Args:
        channel: passed as both arguments of every ``ResBlock``.
        num_res: number of ResBlocks in the stack.
    """

    def __init__(self, channel, num_res=8):
        super().__init__()
        self.layers = nn.Sequential(
            *(ResBlock(channel, channel) for _ in range(num_res))
        )

    def forward(self, x):
        """Apply the ResBlock stack to ``x``."""
        out = self.layers(x)
        return out
class EBlock1(nn.Module):
    """Encoder stage that wraps a single ``UNet`` module.

    Args:
        out_channel: passed as the first two arguments of ``UNet``.
        num_res: passed as the third argument of ``UNet``.
    """

    def __init__(self, out_channel, num_res=8):
        super().__init__()
        unet = UNet(out_channel, out_channel, num_res)
        self.layers = unet

    def forward(self, x):
        """Apply the wrapped UNet to ``x``."""
        out = self.layers(x)
        return out


class DBlock1(nn.Module):
    """Decoder stage that wraps a single ``UNet`` module.

    Args:
        channel: passed as the first two arguments of ``UNet``.
        num_res: passed as the third argument of ``UNet``.
    """

    def __init__(self, channel, num_res=8):
        super().__init__()
        unet = UNet(channel, channel, num_res)
        self.layers = unet

    def forward(self, x):
        """Apply the wrapped UNet to ``x``."""
        out = self.layers(x)
        return out

class SCM(nn.Module):
    """Small conv stack that lifts a 3-channel input to ``out_plane`` channels.

    The stack is 3x3 -> 1x1 -> 3x3 -> 1x1 ``BasicConv`` layers with widths
    ``out_plane // 4``, ``out_plane // 2``, ``out_plane // 2``, ``out_plane``
    (the last one without ReLU), followed by an affine ``InstanceNorm2d``.

    Args:
        out_plane: number of output channels.
    """

    def __init__(self, out_plane):
        super().__init__()
        quarter = out_plane // 4
        half = out_plane // 2
        stages = [
            BasicConv(3, quarter, kernel_size=3, stride=1, relu=True),
            BasicConv(quarter, half, kernel_size=1, stride=1, relu=True),
            BasicConv(half, half, kernel_size=3, stride=1, relu=True),
            BasicConv(half, out_plane, kernel_size=1, stride=1, relu=False),
            nn.InstanceNorm2d(out_plane, affine=True),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv stack and the instance norm."""
        return self.main(x)

class FAM(nn.Module):
    """Fuse two feature maps: concatenate on dim 1, then a 3x3 ``BasicConv``.

    Args:
        channel: channel count of each input and of the output; the merge
            conv takes ``channel * 2`` input channels.
    """

    def __init__(self, channel):
        super().__init__()
        fused_in = channel * 2
        self.merge = BasicConv(fused_in, channel, kernel_size=3, stride=1, relu=False)

    def forward(self, x1, x2):
        """Return ``merge(cat([x1, x2], dim=1))``."""
        stacked = torch.cat([x1, x2], dim=1)
        return self.merge(stacked)

class FocalNet(nn.Module):
    """Three-scale encoder/decoder with a residual connection to the input.

    The input is also downsampled by 0.5 twice; the downsampled copies are
    embedded by ``SCM2`` / ``SCM1`` and fused into the encoder by ``FAM2`` /
    ``FAM1``. ``pyramid_attentions`` is applied at the coarsest scale, and
    encoder features are concatenated into the decoder as skip connections.

    Args:
        num_res: forwarded to every encoder and decoder block.

    ``forward`` returns only the full-resolution output ``prediction + x``.
    """

    def __init__(self, num_res=4):
        super().__init__()

        # Channel widths of the three scales.
        c1 = 32
        c2 = c1 * 2
        c4 = c1 * 4

        self.Encoder = nn.ModuleList([
            EBlock1(c1, num_res),
            EBlock(c2, num_res),
            EBlock(c4, num_res),
        ])

        # Indices 0-2: stem and stride-2 convs; 3-4: stride-2 transposed
        # convs; 5: projection back to 3 channels.
        self.feat_extract = nn.ModuleList([
            BasicConv(3, c1, kernel_size=3, relu=True, stride=1),
            BasicConv(c1, c2, kernel_size=3, relu=True, stride=2),
            BasicConv(c2, c4, kernel_size=3, relu=True, stride=2),
            BasicConv(c4, c2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(c2, c1, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(c1, 3, kernel_size=3, relu=False, stride=1),
        ])

        self.Decoder = nn.ModuleList([
            DBlock(c4, num_res),
            DBlock(c2, num_res),
            DBlock1(c1, num_res),
        ])

        # 1x1 convs that halve the channel count after each skip concatenation.
        self.Convs = nn.ModuleList([
            BasicConv(c4, c2, kernel_size=1, relu=True, stride=1),
            BasicConv(c2, c1, kernel_size=1, relu=True, stride=1),
        ])

        # 3-channel heads for the two lower-resolution decoder stages.
        self.ConvsOut = nn.ModuleList([
            BasicConv(c4, 3, kernel_size=3, relu=False, stride=1),
            BasicConv(c2, 3, kernel_size=3, relu=False, stride=1),
        ])

        self.FAM1 = FAM(c4)
        self.SCM1 = SCM(c4)
        self.FAM2 = FAM(c2)
        self.SCM2 = SCM(c2)

        self.pyramid_attentions = nn.Sequential(ParamidAttention(c4))

    def forward(self, x):
        """Return the full-resolution output ``feat_extract[5](...) + x``.

        The two lower-resolution outputs are still computed, in the same
        order as before, but only the last (full-resolution) one is returned.
        """
        half_img = F.interpolate(x, scale_factor=0.5)
        quarter_img = F.interpolate(half_img, scale_factor=0.5)
        half_embed = self.SCM2(half_img)
        quarter_embed = self.SCM1(quarter_img)

        # Scale 1 (input resolution).
        stem = self.feat_extract[0](x)
        skip1 = self.Encoder[0](stem)

        # Scale 2 (after the first stride-2 conv).
        feat = self.feat_extract[1](skip1)
        feat = self.FAM2(feat, half_embed)
        skip2 = self.Encoder[1](feat)

        # Scale 3 (after the second stride-2 conv).
        feat = self.feat_extract[2](skip2)
        feat = self.FAM1(feat, quarter_embed)
        feat = self.Encoder[2](feat)

        feat = self.pyramid_attentions(feat)

        # Decode scale 3, then upsample to scale 2.
        feat = self.Decoder[0](feat)
        head3 = self.ConvsOut[0](feat)
        feat = self.feat_extract[3](feat)
        out_quarter = head3 + quarter_img

        feat = torch.cat([feat, skip2], dim=1)
        feat = self.Convs[0](feat)
        feat = self.Decoder[1](feat)
        head2 = self.ConvsOut[1](feat)
        # Upsample to scale 1.
        feat = self.feat_extract[4](feat)
        out_half = head2 + half_img

        feat = torch.cat([feat, skip1], dim=1)
        feat = self.Convs[1](feat)
        feat = self.Decoder[2](feat)
        feat = self.feat_extract[5](feat)
        out_full = feat + x

        # NOTE: out_quarter and out_half are computed but not returned.
        outputs = [out_quarter, out_half, out_full]
        return outputs[2]


# class SFTLayer(nn.Module):
#     def __init__(self, dim, dim_hidden):
#         super(SFTLayer, self).__init__()
#         # self.SFT_conv0 = nn.Conv2d(dim, dim_hidden, 3, padding=1)
#         self.SFT_scale_conv0 = nn.Conv2d(dim, dim_hidden, 3, padding=1)
#         self.SFT_scale_conv1 = nn.Conv2d(dim_hidden, dim, 3, padding=1)
#         self.SFT_shift_conv0 = nn.Conv2d(dim, dim_hidden, 3, padding=1)
#         self.SFT_shift_conv1 = nn.Conv2d(dim_hidden, dim, 3, padding=1)
#         self.gelu = nn.GELU()

#     def forward(self, x):
#         # x[0]: fea; x[1]: cond
#         scale = self.SFT_scale_conv1(self.gelu(self.SFT_scale_conv0(x[1])))
#         shift = self.SFT_shift_conv1(self.gelu(self.SFT_shift_conv0(x[1])))
#         return x[0] * (scale + 1) + shift


def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return ``module``.

    The write is done outside autograd tracking, so ``requires_grad`` flags
    are left as they were.
    """
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module


class SFTLayer(nn.Module):
    """Feature modulation: ``fea * (scale + 1) + shift``.

    ``scale`` and ``shift`` are each predicted from a condition map by their
    own conv3x3 -> GELU -> conv3x3 branch.

    Args:
        dim: channel count of the condition map and of the predicted
            ``scale`` / ``shift``.
        dim_hidden: channel count between the two convs of each branch.
    """

    def __init__(self, dim, dim_hidden):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        self.SFT_scale_conv0 = nn.Conv2d(dim, dim_hidden, **conv_args)
        self.SFT_scale_conv1 = nn.Conv2d(dim_hidden, dim, **conv_args)
        self.SFT_shift_conv0 = nn.Conv2d(dim, dim_hidden, **conv_args)
        self.SFT_shift_conv1 = nn.Conv2d(dim_hidden, dim, **conv_args)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Modulate ``x[0]`` (features) using ``x[1]`` (condition).

        Args:
            x: indexable pair; ``x[0]`` is the feature map and ``x[1]`` the
                condition map.
        """
        features = x[0]
        condition = x[1]

        scale_hidden = self.gelu(self.SFT_scale_conv0(condition))
        scale = self.SFT_scale_conv1(scale_hidden)

        shift_hidden = self.gelu(self.SFT_shift_conv0(condition))
        shift = self.SFT_shift_conv1(shift_hidden)

        return features * (scale + 1) + shift


class FocalNetDepth(nn.Module):
    """FocalNet whose encoder features are modulated by an estimated depth map.

    A frozen, pretrained ``DepthAnything`` model predicts a depth map for
    each input image. ``CondNet`` turns that map into three condition maps
    (32, 64 and 128 channels), and an ``SFTLayer`` applies each of them to
    the output of the matching FocalNet encoder stage.

    Args:
        num_res: forwarded to ``FocalNet`` (it was previously accepted but
            ignored, so the default of 4 was always used).

    Notes:
        * Constructing the module loads the pretrained depth weights via
          ``DepthAnything.from_pretrained``.
        * The depth model is moved to CUDA at construction only when CUDA is
          available, so the module can also be built on CPU-only hosts.
        * ``train()`` is overridden so the frozen depth model stays in eval
          mode even when the parent module is switched to training.
    """

    def __init__(self, num_res=4):
        super(FocalNetDepth, self).__init__()

        self.depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_vitl14').eval()
        if torch.cuda.is_available():
            # Same placement as before on GPU hosts; skipped on CPU-only hosts.
            self.depth_anything.cuda()
        for param in self.depth_anything.parameters():
            param.requires_grad = False
        self.focalnet = FocalNet(num_res)

        # Pre-processing applied to each image (as an HWC numpy array)
        # before it is fed to the depth model.
        self.depth_transform = Compose([
            Resize(
                width=518,
                height=518,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        # Depth map (1 channel) -> condition maps matching the three encoder
        # stages (32 / 64 / 128 channels, two stride-2 steps).
        self.CondNet = nn.ModuleList([
            BasicConv(1, 32, kernel_size=3, relu=True, stride=1),
            BasicConv(32, 64, kernel_size=3, relu=True, stride=2),
            BasicConv(64, 128, kernel_size=3, relu=True, stride=2)
        ])

        self.sft1 = SFTLayer(32, 32)
        self.sft2 = SFTLayer(64, 64)
        self.sft3 = SFTLayer(128, 128)

    def train(self, mode=True):
        """Set the training mode, keeping the frozen depth model in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo
        the ``.eval()`` applied to the depth model at construction.
        """
        super().train(mode)
        self.depth_anything.eval()
        return self

    def get_depth(self, x):
        """Estimate a depth map for every image in ``x``, scaled to [0, 1].

        Args:
            x: image batch of shape (B, C, H, W).
                NOTE(review): the transform looks like it expects RGB values
                in [0, 1] — confirm against the data loader.

        Returns:
            Tensor of shape (B, 1, H, W) on ``x``'s device. Min-max
            normalisation uses the minimum and maximum of the whole batch.
        """
        _, _, x_h, x_w = x.shape
        # Run the depth model wherever its weights live rather than assuming
        # the default CUDA device.
        depth_device = next(self.depth_anything.parameters()).device

        prepared = []
        for image in x.detach().cpu():
            # (C, H, W) tensor -> contiguous (H, W, C) numpy array.
            hwc = image.permute(1, 2, 0).contiguous().numpy()
            chw = self.depth_transform({'image': hwc})['image']
            prepared.append(torch.from_numpy(chw))
        batch = torch.stack(prepared, dim=0).to(depth_device)

        with torch.no_grad():
            depth = self.depth_anything(batch)
        depth = F.interpolate(depth.unsqueeze(dim=1), (x_h, x_w), mode='bilinear', align_corners=False)

        # NOTE: normalised over the whole batch, so a sample's values depend
        # on the other samples it is batched with.
        d_min = depth.min()
        d_range = (depth.max() - d_min).clamp(min=1e-8)  # guard a constant depth map
        depth = (depth - d_min) / d_range
        return depth.to(x.device)

    def forward(self, x):
        """Return the full-resolution output ``prediction + x``.

        Follows the same data flow as ``FocalNet.forward``, with an SFT
        modulation by the depth condition after each encoder stage. The
        modulated features are also the ones used as decoder skips.
        """
        x_depth = self.get_depth(x)

        # Condition maps for the three encoder stages.
        x_depth1 = self.CondNet[0](x_depth)
        x_depth2 = self.CondNet[1](x_depth1)
        x_depth3 = self.CondNet[2](x_depth2)

        x_2 = F.interpolate(x, scale_factor=0.5)
        x_4 = F.interpolate(x_2, scale_factor=0.5)
        z2 = self.focalnet.SCM2(x_2)
        z4 = self.focalnet.SCM1(x_4)

        outputs = list()
        # Scale 1 (input resolution).
        x_ = self.focalnet.feat_extract[0](x)
        res1 = self.focalnet.Encoder[0](x_)
        res1_sft = self.sft1([res1, x_depth1])
        # Scale 2 (after the first stride-2 conv).
        z = self.focalnet.feat_extract[1](res1_sft)
        z = self.focalnet.FAM2(z, z2)
        res2 = self.focalnet.Encoder[1](z)
        res2_sft = self.sft2([res2, x_depth2])
        # Scale 3 (after the second stride-2 conv).
        z = self.focalnet.feat_extract[2](res2_sft)
        z = self.focalnet.FAM1(z, z4)
        z = self.focalnet.Encoder[2](z)
        z_sft = self.sft3([z, x_depth3])

        z = self.focalnet.pyramid_attentions(z_sft)

        z = self.focalnet.Decoder[0](z)
        z_ = self.focalnet.ConvsOut[0](z)
        # Upsample to scale 2.
        z = self.focalnet.feat_extract[3](z)
        outputs.append(z_+x_4)

        z = torch.cat([z, res2_sft], dim=1)
        z = self.focalnet.Convs[0](z)
        z = self.focalnet.Decoder[1](z)
        z_ = self.focalnet.ConvsOut[1](z)
        # Upsample to scale 1.
        z = self.focalnet.feat_extract[4](z)
        outputs.append(z_+x_2)

        z = torch.cat([z, res1_sft], dim=1)
        z = self.focalnet.Convs[1](z)
        z = self.focalnet.Decoder[2](z)
        z = self.focalnet.feat_extract[5](z)
        outputs.append(z+x)

        # NOTE: the two lower-resolution outputs are computed but only the
        # full-resolution one is returned.
        return outputs[2]