import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import DropPath
import math
import torch.cuda
from networks import common
from torch.nn.functional import softmax
from layers import *
import cv2

class PositionalEncodingFourier(nn.Module):
    """
    Positional encoding relying on a fourier kernel matching the one used in the
    "Attention Is All You Need" paper. The implementation builds on DeTR code
    https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py

    :param hidden_dim: number of sine/cosine channels generated per spatial axis
                       (the strided sin/cos interleaving in ``forward`` needs an even value)
    :param dim: number of output channels produced by the 1x1 projection
    :param temperature: base of the geometric frequency progression
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        # 1x1 convolution mapping the 2 * hidden_dim raw encodings to `dim` channels.
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim

    def forward(self, B, H, W):
        """
        :param B: batch size of the encoding to generate
        :param H: height of the encoding grid
        :param W: width of the encoding grid
        :return: tensor of shape (B, dim, H, W) on the device of the projection weights
        """
        # Allocate directly on the target device instead of building the mask
        # on the CPU and copying it over on every call.
        device = self.token_projection.weight.device
        not_mask = torch.ones(B, H, W, dtype=torch.bool, device=device)

        # Running counts 1..H and 1..W along each axis.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        # Normalise by the last row / column so the coordinates span (0, 2*pi].
        eps = 1e-6
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Consecutive channel pairs share one frequency (sin for even, cos for odd index).
        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
                             pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
                             pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)

        # (B, H, W, 2 * hidden_dim) -> (B, 2 * hidden_dim, H, W), y-encoding first.
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return self.token_projection(pos)


class XCA(nn.Module):
    """ Cross-Covariance Attention (XCA): attention is computed between channels
    rather than between tokens. Per head, the L2-normalised queries and keys form
    a (head_dim x head_dim) cross-covariance matrix which, after temperature
    scaling and softmax, is used to mix the channels of the values.

    :param dim: channel count of the input tokens
    :param num_heads: number of heads the channels are split into
    :param qkv_bias: whether the fused q/k/v projection has a bias
    :param qk_scale: accepted for interface compatibility; not used by this module
    :param attn_drop: dropout probability applied to the attention matrix
    :param proj_drop: dropout probability applied after the output projection
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # One learnable softmax temperature per head, broadcast over the attention map.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        :param x: tensor unpacked as (B, N, C)
        :return: tensor of shape (B, N, C)
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # (B, N, 3, heads, head_dim) -> three tensors of shape (B, heads, head_dim, N)
        q, k, v = packed.permute(2, 0, 3, 4, 1).unbind(0)

        # Unit-normalise every channel along the token axis.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # Channel-to-channel affinities: (B, heads, head_dim, head_dim).
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        attn = self.attn_drop(attn.softmax(dim=-1))

        mixed = torch.matmul(attn, v)  # (B, heads, head_dim, N)
        out = mixed.permute(0, 3, 1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the parameters to exclude from weight decay."""
        return {'temperature'}


class LayerNorm(nn.Module):
    """
    Layer normalisation over the channel axis for either memory layout.

    With "channels_last" the channels are the trailing dimension and the built-in
    ``F.layer_norm`` is used. With "channels_first" the statistics are taken over
    dimension 1 and the affine parameters are broadcast over two trailing
    dimensions, i.e. the input is indexed as (N, C, H, W).

    :param normalized_shape: number of channels
    :param eps: value added to the variance for numerical stability
    :param data_format: "channels_last" or "channels_first"
    :raises NotImplementedError: for any other ``data_format``
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual normalisation over dim 1 using the biased variance.
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            var = centered.pow(2).mean(1, keepdim=True)
            normed = centered / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)


class BNGELU(nn.Module):
    """
    2-D batch normalisation followed by a GELU activation.

    :param nIn: number of channels of the feature map
    """

    def __init__(self, nIn):
        super().__init__()
        self.bn = nn.BatchNorm2d(nIn, eps=1e-5)
        self.act = nn.GELU()

    def forward(self, x):
        """Normalise `x` and apply the activation."""
        return self.act(self.bn(x))


class Conv(nn.Module):
    """
    2-D convolution, optionally followed by batch normalisation and GELU.

    :param nIn: number of input channels
    :param nOut: number of output channels
    :param kSize: kernel size
    :param stride: convolution stride
    :param padding: padding added to both sides of the input
    :param dilation: spacing between kernel elements
    :param groups: number of blocked connections from input to output channels
    :param bn_act: when True, a BNGELU block is applied to the convolution output
    :param bias: whether the convolution has a learnable bias
    """

    def __init__(self, nIn, nOut, kSize, stride, padding=0, dilation=(1, 1), groups=1, bn_act=False, bias=False):
        super().__init__()

        self.bn_act = bn_act
        conv_options = dict(kernel_size=kSize, stride=stride, padding=padding,
                            dilation=dilation, groups=groups, bias=bias)
        self.conv = nn.Conv2d(nIn, nOut, **conv_options)

        # The normalisation/activation submodule only exists when requested.
        if bn_act:
            self.bn_gelu = BNGELU(nOut)

    def forward(self, x):
        features = self.conv(x)
        return self.bn_gelu(features) if self.bn_act else features


class CDilated(nn.Module):
    """
    Dilated 2-D convolution whose padding is derived from the kernel size and
    the dilation rate.
    """

    def __init__(self, nIn, nOut, kSize, stride=1, d=0, groups=1, bias=False):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: optional stride rate for down-sampling
        :param d: optional dilation rate
        :param groups: number of blocked connections from input to output channels
        :param bias: whether the convolution has a learnable bias

        NOTE(review): the default ``d=0`` is forwarded unchanged as the
        convolution's dilation (and yields zero padding); callers are expected
        to pass ``d >= 1`` explicitly -- confirm whether the default should be 1.
        """
        super().__init__()
        half_kernel = int((kSize - 1) / 2)
        self.conv = nn.Conv2d(
            nIn,
            nOut,
            kSize,
            stride=stride,
            padding=half_kernel * d,
            bias=bias,
            dilation=d,
            groups=groups,
        )

    def forward(self, input):
        """
        :param input: input feature map
        :return: transformed feature map
        """
        return self.conv(input)


class DilatedConv(nn.Module):
    """
    A single Dilated Convolution layer in the Consecutive Dilated Convolutions (CDC) module:
    a depth-wise dilated convolution and batch norm, followed by a point-wise
    expand / GELU / project MLP, optional layer scale, and a residual connection.
    """
    def __init__(self, dim, k, dilation=1, stride=1, drop_path=0.,
                 layer_scale_init_value=1e-6, expan_ratio=6):
        """
        :param dim: input dimension
        :param k: kernel size
        :param dilation: dilation rate
        :param stride: stride of the depth-wise convolution
        :param drop_path: drop_path rate
        :param layer_scale_init_value: initial value of the per-channel scale `gamma`;
                                       a value <= 0 disables the scaling
        :param expan_ratio: channel expansion factor of the inverted bottleneck
        """

        super().__init__()

        # groups=dim makes the dilated convolution depth-wise.
        self.ddwconv = CDilated(dim, dim, kSize=k, stride=stride, groups=dim, d=dilation)
        self.bn1 = nn.BatchNorm2d(dim)

        # Registered (and therefore part of the state dict) but not applied in forward().
        self.norm = LayerNorm(dim, eps=1e-6)

        hidden_dim = expan_ratio * dim
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                      requires_grad=True)
        else:
            self.gamma = None

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """
        :param x: feature map laid out as (N, C, H, W)
        :return: `x` plus the (optionally dropped) transformed branch
        """
        residual = x

        branch = self.bn1(self.ddwconv(x))

        # The Linear layers act on the last axis, so move channels there.
        branch = branch.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        branch = self.pwconv2(self.act(self.pwconv1(branch)))
        if self.gamma is not None:
            branch = self.gamma * branch
        branch = branch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return residual + self.drop_path(branch)


class FuseBlock(nn.Module):
    """
    Builds a normalised fusion kernel from the Scharr gradients of a depth
    feature map and a guidance image.

    Both inputs are first passed through a 1x1 ``common.ConvBNReLU2D`` block
    requested with ``num_feature // 2`` output channels. Every resulting channel
    is filtered with the two 3x3 Scharr kernels (stored as parameters); the
    x-gradients are average-pooled over the height and the y-gradients over the
    width. The element-wise maximum of the depth and guide responses is taken
    per direction, the two directions are multiplied (broadcasting back to a
    full map) and the result is normalised across channels.

    :param num_feature: number of channels of the depth feature map
    :param act: activation name forwarded to ``common.ConvBNReLU2D``
    :param norm: normalisation setting forwarded to ``common.ConvBNReLU2D``
    :param kernel_size: only used to derive ``self.num = kernel_size ** 2``
    :param num_res: accepted for interface compatibility; not used by this block
    :param scale: upscale factor of the ``pix_shf`` PixelShuffle (not used in forward)
    """

    def __init__(self, num_feature, act, norm, kernel_size, num_res, scale=2):
        super(FuseBlock, self).__init__()

        self.scale = scale
        self.num = kernel_size * kernel_size

        self.depth_kernel = nn.Sequential(
            common.ConvBNReLU2D(in_channels=num_feature, out_channels=num_feature // 2,
                                kernel_size=1, act=act, norm=norm))

        # The guidance branch is hard-wired to 3 input channels.
        self.guide_kernel = nn.Sequential(
            common.ConvBNReLU2D(in_channels=3, out_channels=num_feature // 2,
                                kernel_size=1, act=act, norm=norm))

        self.pix_shf = nn.PixelShuffle(upscale_factor=scale)
        # Scharr kernels, registered as parameters (shape (1, 1, 3, 3)).
        self.scharr_x = nn.Parameter(torch.Tensor([[[[3, 0, -3], [10, 0, -10], [3, 0, -3]]]]))
        self.scharr_y = nn.Parameter(torch.Tensor([[[[3, 10, 3], [0, 0, 0], [-3, -10, -3]]]]))

    def _scharr_gradients(self, feat):
        """Filter every channel of `feat` independently with both Scharr kernels."""
        channels = feat.size(1)
        # One depth-wise convolution applies the shared 3x3 kernel to each channel,
        # replacing a Python loop of single-channel convolutions.
        kernel_x = self.scharr_x.repeat(channels, 1, 1, 1)
        kernel_y = self.scharr_y.repeat(channels, 1, 1, 1)
        grad_x = F.conv2d(feat, kernel_x, padding=1, groups=channels)
        grad_y = F.conv2d(feat, kernel_y, padding=1, groups=channels)
        return grad_x, grad_y

    def forward(self, depth, guide):
        """
        :param depth: depth feature map, (B, C, H, W)
        :param guide: guidance image; pooled with the depth map's H and W, so it is
                      expected to share the depth map's spatial size
        :return: fusion kernel normalised over the channel axis
        """
        _, _, h, w = depth.size()
        depth_feat = self.depth_kernel(depth)
        guide_feat = self.guide_kernel(guide)

        # Scharr operation
        depth_grad_x, depth_grad_y = self._scharr_gradients(depth_feat)
        guide_grad_x, guide_grad_y = self._scharr_gradients(guide_feat)

        # x-gradients are averaged over the height, y-gradients over the width.
        depth_hor_pool = F.avg_pool2d(depth_grad_x, kernel_size=(h, 1))
        depth_ver_pool = F.avg_pool2d(depth_grad_y, kernel_size=(1, w))
        guide_hor_pool = F.avg_pool2d(guide_grad_x, kernel_size=(h, 1))
        guide_ver_pool = F.avg_pool2d(guide_grad_y, kernel_size=(1, w))

        result_hor = torch.max(depth_hor_pool, guide_hor_pool)
        result_ver = torch.max(depth_ver_pool, guide_ver_pool)
        # (B, C, 1, W) * (B, C, H, 1) broadcasts to (B, C, H, W).
        fuse_kernel = result_hor * result_ver

        # Divide by the channel-wise L1 norm, but never amplify: the divisor is
        # floored at 1. clamp() avoids an in-place masked write on a graph tensor.
        abs_kernel_sum = torch.sum(torch.abs(fuse_kernel), dim=1, keepdim=True) + 1e-4
        abs_kernel_sum = abs_kernel_sum.clamp(min=1.0)

        return fuse_kernel / abs_kernel_sum


class AvgPool(nn.Module):
    """
    Chain of `ratio` identical 3x3, stride-2, padding-1 average-pooling layers,
    i.e. the input is spatially halved `ratio` times.

    :param ratio: number of pooling layers to stack
    """

    def __init__(self, ratio):
        super().__init__()
        self.pool = nn.ModuleList(
            [nn.AvgPool2d(3, stride=2, padding=1) for _ in range(0, ratio)]
        )

    def forward(self, x):
        """Apply the pooling layers one after another."""
        out = x
        for layer in self.pool:
            out = layer(out)
        return out


class LiteGfm(nn.Module):
    """
    LiteGfm encoder.

    A three-stage convolutional encoder. Every stage is a run of ``DilatedConv``
    blocks followed by a ``FuseBlock`` which produces a fusion kernel from the
    stage's input features and a down-sampled copy of the (normalised) input image.

    :param in_chans: number of channels of the input image
    :param model: 'litegfm' (dims 48/80/128) or 'litegfm-small' (dims 32/64/128)
    :param height: input height; together with `width` selects ``self.dilation``
    :param width: input width
    :param global_block: per stage, how many trailing blocks are global blocks.
                         ``forward_features`` calls only the last block of each stage
                         as a fuse block, so it relies on the value 1 for every stage
    :param global_block_type: per stage, type of the global block ('AAG' or 'None';
                              only 'AAG' can actually be built)
    :param drop_path_rate: final stochastic-depth rate; rates grow linearly over all blocks
    :param layer_scale_init_value: forwarded to every ``DilatedConv``
    :param expan_ratio: forwarded to every ``DilatedConv``
    :param heads: accepted for interface compatibility; not used by this class
    :param use_pos_embd_xca: accepted for interface compatibility; not used by this class
    :raises ValueError: if `model` is not one of the supported variants
    """
    def __init__(self, in_chans=3, model='litegfm', height=320, width=1024,
                 global_block=(1, 1, 1), global_block_type=('AAG', 'AAG', 'AAG'),
                 drop_path_rate=0.2, layer_scale_init_value=1e-6, expan_ratio=6,
                 heads=(8, 8, 8), use_pos_embd_xca=(True, False, False), **kwargs):

        super().__init__()

        if model == 'litegfm':
            dims = [48, 80, 128]
        elif model == 'litegfm-small':
            dims = [32, 64, 128]
        else:
            # Fail here instead of with an AttributeError on self.dims further down.
            raise ValueError(
                f"unknown model {model!r}; expected 'litegfm' or 'litegfm-small'")

        self.num_ch_enc = np.array(dims)
        self.depth = [4, 4, 7]
        self.dims = dims

        # self.dilation is only defined for the two known resolutions.
        # NOTE(review): it is not forwarded to DilatedConv below, so every block is
        # built with DilatedConv's default dilation -- confirm this is intended.
        if height == 192 and width == 640:
            self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 2, 4, 6]]
        elif height == 320 and width == 1024:
            self.dilation = [[1, 2, 5], [1, 2, 5], [1, 2, 5, 2, 4, 10]]

        for g in global_block_type:
            assert g in ['None', 'AAG']

        self.downsample_layers = nn.ModuleList()  # stem and 2 intermediate downsampling conv layers
        stem1 = nn.Sequential(
            Conv(in_chans, self.dims[0], kSize=3, stride=2, padding=1, bn_act=True),
            Conv(self.dims[0], self.dims[0], kSize=3, stride=1, padding=1, bn_act=True),
            Conv(self.dims[0], self.dims[0], kSize=3, stride=1, padding=1, bn_act=True),
        )

        self.stem2 = nn.Sequential(
            Conv(self.dims[0], self.dims[0], kSize=3, stride=2, padding=1, bn_act=False),
        )

        self.downsample_layers.append(stem1)

        # input_downsample[i] halves the input image i + 1 times.
        self.input_downsample = nn.ModuleList()
        for i in range(1, 5):
            self.input_downsample.append(AvgPool(i))

        for i in range(2):
            downsample_layer = nn.Sequential(
                Conv(self.dims[i], self.dims[i + 1], kSize=3, stride=2, padding=1, bn_act=False),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depth))]
        cur = 0
        for i in range(3):
            stage_blocks = []
            for j in range(self.depth[i]):
                # The last global_block[i] blocks of the stage are global blocks.
                if j > self.depth[i] - global_block[i] - 1:
                    if global_block_type[i] == 'AAG':
                        stage_blocks.append(FuseBlock(num_feature=self.dims[i], act='PReLU', norm=None,
                                                      kernel_size=3, num_res=2, scale=2))
                    else:
                        raise NotImplementedError
                else:
                    stage_blocks.append(DilatedConv(dim=self.dims[i], k=3, drop_path=dp_rates[cur + j],
                                                    layer_scale_init_value=layer_scale_init_value,
                                                    expan_ratio=expan_ratio))

            self.stages.append(nn.Sequential(*stage_blocks))
            cur += self.depth[i]

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-initialise conv/linear weights and reset norm layers to identity."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """
        :param x: input image batch
        :return: ``(features, kernels)`` -- two lists with one entry per stage: the
                 output of the stage's DilatedConv blocks and the fusion kernel
                 produced by the stage's FuseBlock
        """
        x = (x - 0.45) / 0.225

        # Guidance images for the three stages: the normalised input halved 2, 3 and
        # 4 times. The single-halving pool (index 0) is not consumed, so it is skipped.
        x_down = [self.input_downsample[i](x) for i in range(1, 4)]

        features = []
        kernels = []

        x = self.downsample_layers[0](x)
        x = self.stem2(x)

        for i in range(3):
            if i > 0:
                x = self.downsample_layers[i](x)
            stage_input = x
            blocks = list(self.stages[i])

            # All blocks but the last refine the features.
            for block in blocks[:-1]:
                x = block(x)

            # The last block builds the fusion kernel from the stage *input* and the
            # down-sampled image; it does not modify x.
            kernels.append(blocks[-1](stage_input, x_down[i]))
            features.append(x)

        return features, kernels

    def forward(self, x):
        """Return the per-stage features and fusion kernels for `x`."""
        x, kernels = self.forward_features(x)

        return x, kernels
