import shutil
from collections import OrderedDict

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt

from models.unet_parts import DoubleConv, Down, Up, OutConv
from models.mmffc import MMFFCLayer


class DepthBoundaryAwareModule(nn.Module):
    """Derive a binary boundary mask from a depth map using Sobel gradients.

    The Sobel x/y responses of the depth map are combined into a gradient
    magnitude, normalized by its maximum value, and thresholded.

    Args:
        threshold: normalized gradient magnitude above which a pixel is marked
            as boundary. Defaults to 0.03.
    """

    def __init__(self, threshold=0.03):
        super(DepthBoundaryAwareModule, self).__init__()
        self.threshold = threshold

        # Fixed (non-learnable) Sobel kernels shaped (out=1, in=1, 3, 3).
        # Registered as non-persistent buffers: they follow .to()/.cuda()/.half()
        # together with the module, but are not written to the state_dict, so
        # existing checkpoints keep loading unchanged. Buffers do not require grad.
        self.register_buffer(
            'sobel_x',
            torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32).view(1, 1, 3, 3),
            persistent=False)
        self.register_buffer(
            'sobel_y',
            torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32).view(1, 1, 3, 3),
            persistent=False)

    def forward(self, depth_image):
        """Return a float32 {0, 1} boundary mask with the same shape as ``depth_image``.

        ``depth_image`` must have exactly one channel because the Sobel kernels
        have a single input channel (presumably laid out as (N, 1, H, W)).
        """
        # Match the input's device and dtype so CPU/GPU and float32/float16
        # inputs all work with F.conv2d.
        kernel_x = self.sobel_x.to(depth_image)
        kernel_y = self.sobel_y.to(depth_image)

        # Sobel x and y gradients; padding=1 keeps the spatial size.
        grad_x = F.conv2d(depth_image, kernel_x, padding=1)
        grad_y = F.conv2d(depth_image, kernel_y, padding=1)

        magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2)

        # Normalize to [0, 1] by the maximum over the whole tensor (all batch
        # items together). The clamp avoids 0/0 = NaN on a perfectly flat input,
        # which yields an all-zero mask.
        max_val = magnitude.max().clamp_min(torch.finfo(magnitude.dtype).tiny)
        normalized = magnitude / max_val

        return (normalized > self.threshold).float()


class SeparableConv2d(nn.Module):
    """Depthwise-separable 2D convolution with normalization and ReLU.

    A depthwise ``kernel_size`` convolution (one filter per input channel,
    ``groups=inplanes``) is followed by a 1x1 pointwise convolution that mixes
    channels into ``planes`` outputs.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size: spatial size of the depthwise kernel.
        stride: stride of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
        relu_first: if True the layout is ReLU -> depthwise -> norm -> pointwise
            -> norm; otherwise depthwise -> norm -> ReLU -> pointwise -> norm -> ReLU.
        bias: whether the convolutions have a bias term.
        norm_layer: normalization layer class, called with a channel count.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, relu_first=True,
                 bias=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # "Same" padding for odd kernel sizes (at stride 1); for the default
        # 3x3 kernel this equals ``dilation``.
        padding = dilation * (kernel_size - 1) // 2
        depthwise = nn.Conv2d(inplanes, inplanes, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=inplanes, bias=bias)
        bn_depth = norm_layer(inplanes)
        pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
        bn_point = norm_layer(planes)

        # The layer names below define the state_dict keys; keep them stable.
        if relu_first:
            layers = [
                # Not in-place: this ReLU acts directly on the caller's input tensor.
                ('relu', nn.ReLU()),
                ('depthwise', depthwise),
                ('bn_depth', bn_depth),
                ('pointwise', pointwise),
                ('bn_point', bn_point),
            ]
        else:
            layers = [
                ('depthwise', depthwise),
                ('bn_depth', bn_depth),
                ('relu1', nn.ReLU(inplace=True)),
                ('pointwise', pointwise),
                ('bn_point', bn_point),
                ('relu2', nn.ReLU(inplace=True)),
            ]
        self.block = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.block(x)


class ImageDepthFusionModule(nn.Module):
    """Fuse a feature map with an attention-weighted copy of itself, then gate channels.

    ``c`` and ``c * att_map`` are concatenated, reduced back to ``inplane``
    channels by a separable convolution, and finally rescaled per channel by a
    squeeze-and-excitation style gate (global average pool -> 1x1 conv -> ReLU
    -> 1x1 conv -> sigmoid).

    Args:
        norm_layer: normalization layer class forwarded to ``SeparableConv2d``.
        inplane: channel count of ``c`` and of the output.
        reduction: channel reduction ratio of the gating bottleneck.
    """

    def __init__(self, norm_layer=nn.BatchNorm2d, inplane=256, reduction=16):
        super(ImageDepthFusionModule, self).__init__()
        # At least one hidden channel, so small ``inplane`` values still build.
        hidden = max(1, inplane // reduction)
        self.conv1 = SeparableConv2d(inplane * 2, inplane, 3, norm_layer=norm_layer, relu_first=False)
        self.fc1 = nn.Conv2d(inplane, hidden, kernel_size=1)
        self.fc2 = nn.Conv2d(hidden, inplane, kernel_size=1)

    def forward(self, c, att_map):
        """Return the fused, channel-gated features (``inplane`` channels).

        ``att_map`` must broadcast against ``c`` (assumed same spatial size —
        TODO confirm with callers).
        """
        atted_c = c * att_map
        x = torch.cat([c, atted_c], 1)  # 2 * inplane channels
        x = self.conv1(x)  # back to inplane channels

        # True global pooling over both H and W, also for non-square inputs.
        weight = F.adaptive_avg_pool2d(x, 1)
        weight = F.relu(self.fc1(weight))
        weight = torch.sigmoid(self.fc2(weight))
        return x * weight


class MoE(nn.Module):
    """Combine per-scale expert predictions with one learned scalar weight per expert.

    Each expert is an ``OutConv`` head applied to its own input feature map.
    Expert outputs are bilinearly resized to the spatial size of the last input,
    multiplied by their weight and summed.

    Args:
        num_experts: number of experts (uses ``inputs[0..num_experts-1]``).
        input_dims: channel count of each expert's input feature map.
        output_dim: channel count produced by every expert.
    """

    def __init__(self, num_experts, input_dims, output_dim):
        super(MoE, self).__init__()
        self.experts = nn.ModuleList(
            [OutConv(in_channels=input_dims[i], out_channels=output_dim) for i in range(num_experts)])
        # An nn.Parameter (initialized uniformly to 1/num_experts) so the weights
        # are actually optimized, saved in the state_dict and moved with the module.
        self.gating_network = nn.Parameter(torch.full((num_experts,), 1.0 / num_experts))

    def forward(self, inputs):
        """Return ``(moe_output, expert_outputs)``; all outputs share inputs[-1]'s H x W."""
        target_size = inputs[-1].shape[2:]

        # Compute every expert's output and bring it to the target resolution.
        # Resizing by explicit size (both H and W) avoids rounding errors from
        # fractional scale factors and handles non-square ratios.
        expert_outputs = []
        for i, expert in enumerate(self.experts):
            out = expert(inputs[i])
            if out.shape[2:] != target_size:
                out = F.interpolate(out, size=target_size, mode='bilinear')
            expert_outputs.append(out)

        # Weight each expert's output by its gate value and sum them.
        moe_output = sum(out * weight for out, weight in zip(expert_outputs, self.gating_network))
        return moe_output, expert_outputs


class MoEConv(nn.Module):
    """Mixture of per-scale experts with a pixel-wise, learned softmax gate.

    Each expert is an ``OutConv`` head applied to its own input feature map.
    Expert outputs are bilinearly resized to the spatial size of the last input.
    A 1x1 convolution over their concatenation produces one gate logit per
    expert and pixel. The logits are softmax-normalized across experts and used
    to blend the expert outputs.

    Args:
        input_dims: channel count of each expert's input; its length sets the
            number of experts.
        output_dim: channel count produced by every expert (and by the mixture).
    """

    def __init__(self, input_dims, output_dim=3):
        super(MoEConv, self).__init__()
        num_experts = len(input_dims)
        self.experts = nn.ModuleList(
            [OutConv(in_channels=input_dims[i], out_channels=output_dim) for i in range(num_experts)])
        self.gating_network = nn.Conv2d(in_channels=output_dim * num_experts, out_channels=num_experts, kernel_size=1,
                                        padding='same')

    def forward(self, inputs):
        """Return ``(moe_output, expert_outputs)``; all outputs share inputs[-1]'s H x W."""
        target_size = inputs[-1].shape[2:]

        # Compute every expert's output and bring it to the target resolution.
        # Resizing by explicit size (both H and W) avoids off-by-one sizes from
        # fractional scale factors, which would break torch.cat below.
        expert_outputs = []
        for i, expert in enumerate(self.experts):
            out = expert(inputs[i])
            if out.shape[2:] != target_size:
                out = F.interpolate(out, size=target_size, mode='bilinear')
            expert_outputs.append(out)

        # Per-pixel gate: softmax over the expert dimension, one map per expert.
        gate_logits = self.gating_network(torch.cat(expert_outputs, dim=1))
        gates = F.softmax(gate_logits, dim=1).split(1, dim=1)

        # Blend expert outputs with their gate maps and sum them.
        moe_output = sum(out * gate for out, gate in zip(expert_outputs, gates))
        return moe_output, expert_outputs


class LEDN(nn.Module):
    """Dual-branch (image + depth) U-Net style encoder-decoder.

    An image encoder and a depth encoder run in parallel. When ``use_idf`` is
    set, the two feature maps are fused by an ``MMFFCLayer`` at each of the four
    encoder scales. The fused map then continues down the image branch and also
    serves as that scale's skip connection. Three ``Up`` stages decode, and the
    prediction comes either from a single ``OutConv`` head or from a per-scale
    mixture of experts (``MoEConv``).

    Args:
        img_channels: ignored; the image encoder input is fixed at 4 channels
            (image plus one appended channel). Kept for interface compatibility.
        depth_channel: channel count of ``depths`` for the depth encoder.
        n_classes: number of output channels.
        bilinear: forwarded to ``Up``; also halves some decoder widths via ``factor``.
        use_dba: append a depth-boundary mask to ``imgs`` (otherwise append ``depths``).
        use_idf: fuse depth-branch features into the image branch at every scale.
        use_moe: predict with the ``MoEConv`` head instead of ``outc``.
    """

    def __init__(self, img_channels=4, depth_channel=1, n_classes=3, bilinear=False,
                 use_dba=True, use_idf=True, use_moe=True):
        super(LEDN, self).__init__()
        self.use_dba = use_dba
        self.use_idf = use_idf
        self.use_moe = use_moe

        # NOTE: submodules are created in a fixed order; changing it alters the
        # random initialization for a given seed and the state_dict layout.
        factor = 2 if bilinear else 1
        self.dba = DepthBoundaryAwareModule()
        # The image encoder always sees 4 channels: the image plus either the
        # boundary mask or the depth map appended in forward().
        img_channels = 4  # if use_dba else 3

        self.en_inc = (DoubleConv(img_channels, 64))
        self.dp_inc = (DoubleConv(depth_channel, 64))
        # NOTE(review): idf0..idf3 are never called in forward(); they are kept
        # because removing them would change the state_dict keys.
        self.idf0 = ImageDepthFusionModule(inplane=64)
        self.ffc0 = MMFFCLayer(64, 64)

        self.en_down1 = (Down(64, 128))
        self.dp_down1 = (Down(64, 128))
        self.idf1 = ImageDepthFusionModule(inplane=128)
        self.ffc1 = MMFFCLayer(128, 128)

        self.en_down2 = (Down(128, 256))
        self.dp_down2 = (Down(128, 256))
        self.idf2 = ImageDepthFusionModule(inplane=256)
        self.ffc2 = MMFFCLayer(256, 256)

        self.en_down3 = (Down(256, 256))
        self.dp_down3 = (Down(256, 256))
        self.idf3 = ImageDepthFusionModule(inplane=256)
        self.ffc3 = MMFFCLayer(256, 256)

        # Decoder input widths = upsampled features + skip connection channels.
        self.de_up3 = (Up(512, 256 // factor, bilinear))
        self.de_up2 = (Up(256, 128 // factor, bilinear))
        self.de_up1 = (Up(128, 64, bilinear))

        self.outc = (OutConv(64, n_classes))
        # One expert per decoder scale, channel counts matching de_up3/2/1 outputs.
        self.moe = MoEConv([256 // factor, 128 // factor, 64], n_classes)

    def forward(self, imgs, depths):
        """Run the network.

        Args:
            imgs: image tensor. With one extra channel appended it must have 4
                channels (presumably (N, 3, H, W) RGB — TODO confirm).
            depths: depth tensor for the depth branch. With ``use_dba`` it is also
                fed to the Sobel boundary module, which needs a single channel.

        Returns:
            ``(moe_output, expert_outputs)`` when ``use_moe`` is set, otherwise
            ``(logits, None)``.
        """
        # Build the 4-channel image-encoder input.
        if self.use_dba:
            boundary_mask = self.dba(depths)
            imgs = torch.cat((imgs, boundary_mask), dim=1)
        else:
            imgs = torch.cat((imgs, depths), dim=1)

        # Encoder: at each scale optionally fuse depth features into the image
        # branch; the (fused) maps x1..x3 double as skip connections.
        x1 = self.en_inc(imgs)
        if self.use_idf:
            x1_dp = self.dp_inc(depths)
            x1 = self.ffc0((x1, x1_dp))

        x2 = self.en_down1(x1)
        if self.use_idf:
            x2_dp = self.dp_down1(x1_dp)
            x2 = self.ffc1((x2, x2_dp))

        x3 = self.en_down2(x2)
        if self.use_idf:
            x3_dp = self.dp_down2(x2_dp)
            x3 = self.ffc2((x3, x3_dp))

        x4 = self.en_down3(x3)
        if self.use_idf:
            x4_dp = self.dp_down3(x3_dp)
            x4 = self.ffc3((x4, x4_dp))

        # Decoder.
        x_de3 = self.de_up3(x4, x3)
        x_de2 = self.de_up2(x_de3, x2)
        x_de1 = self.de_up1(x_de2, x1)

        if self.use_moe:
            moe_output, expert_outputs = self.moe([x_de3, x_de2, x_de1])
            return moe_output, expert_outputs
        logits = self.outc(x_de1)
        return logits, None
