import torch
import torch.nn as nn
import math
import torch.nn.functional as F


def compute_depth_expectation(prob, depth_values):
    depth_values = depth_values.view(*depth_values.shape, 1, 1)
    depth = torch.sum(prob * depth_values, 1)
    return depth


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvBlock, self).__init__()

        if kernel_size == 3:
            self.conv = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
            )
        elif kernel_size == 1:
            self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)

        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        return out


class ConvBlock_double(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvBlock_double, self).__init__()

        if kernel_size == 3:
            self.conv = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
            )
        elif kernel_size == 1:
            self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)

        self.nonlin = nn.ELU(inplace=True)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1)
        self.nonlin_2 = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        out = self.conv_2(out)
        out = self.nonlin_2(out)
        return out


class DecoderFeature(nn.Module):
    def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]):
        super(DecoderFeature, self).__init__()
        self.num_ch_dec = num_ch_dec
        self.feat_channels = feat_channels

        self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1)
        self.upconv_3_1 = ConvBlock_double(
            self.feat_channels[2] + self.num_ch_dec[3],
            self.num_ch_dec[3],
            kernel_size=1)

        self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3)
        self.upconv_2_1 = ConvBlock_double(
            self.feat_channels[1] + self.num_ch_dec[2],
            self.num_ch_dec[2],
            kernel_size=3)

        self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3)
        self.upconv_1_1 = ConvBlock_double(
            self.feat_channels[0] + self.num_ch_dec[1],
            self.num_ch_dec[1],
            kernel_size=3)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, ref_feature):
        x = ref_feature[3]

        x = self.upconv_3_0(x)
        x = torch.cat((self.upsample(x), ref_feature[2]), 1)
        x = self.upconv_3_1(x)

        x = self.upconv_2_0(x)
        x = torch.cat((self.upsample(x), ref_feature[1]), 1)
        x = self.upconv_2_1(x)

        x = self.upconv_1_0(x)
        x = torch.cat((self.upsample(x), ref_feature[0]), 1)
        x = self.upconv_1_1(x)
        return x


class UNet(nn.Module):
    def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'):
        super(UNet, self).__init__()
        basic_block = ConvBnReLU
        num_depth = 128

        self.conv0 = basic_block(inp_ch, num_depth)
        if channel_mode == 'v0':
            channels = [num_depth, num_depth // 2, num_depth // 4, num_depth // 8, num_depth // 8]
        elif channel_mode == 'v1':
            channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth]
        self.down_sample_times = down_sample_times
        for i in range(down_sample_times):
            setattr(
                self, 'conv_%d' % i,
                nn.Sequential(
                    basic_block(channels[i], channels[i + 1], stride=2),
                    basic_block(channels[i + 1], channels[i + 1])
                )
            )
        for i in range(down_sample_times - 1, -1, -1):
            setattr(self, 'deconv_%d' % i,
                    nn.Sequential(
                        nn.ConvTranspose2d(
                            channels[i + 1],
                            channels[i],
                            kernel_size=3,
                            padding=1,
                            output_padding=1,
                            stride=2,
                            bias=False),
                        nn.BatchNorm2d(channels[i]),
                        nn.ReLU(inplace=True)
                    )
                    )
            self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0)

    def forward(self, x):
        features = {}
        conv0 = self.conv0(x)
        x = conv0
        features[0] = conv0
        for i in range(self.down_sample_times):
            x = getattr(self, 'conv_%d' % i)(x)
            features[i + 1] = x
        for i in range(self.down_sample_times - 1, -1, -1):
            x = features[i] + getattr(self, 'deconv_%d' % i)(x)
        x = self.prob(x)
        return x


class ConvBnReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)


class HourglassDecoder(nn.Module):
    def __init__(self, cfg):
        super(HourglassDecoder, self).__init__()
        self.inchannels = cfg.model.decode_head.in_channels  # [256, 512, 1024, 2048]
        self.decoder_channels = cfg.model.decode_head.decoder_channel  # [64, 64, 128, 256]
        self.min_val = cfg.data_basic.depth_normalize[0]
        self.max_val = cfg.data_basic.depth_normalize[1]

        self.num_ch_dec = self.decoder_channels  # [64, 64, 128, 256]
        self.num_depth_regressor_anchor = 512
        self.feat_channels = self.inchannels
        unet_in_channel = self.num_ch_dec[1]
        unet_out_channel = 256

        self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec)
        self.conv_out_2 = UNet(inp_ch=unet_in_channel,
                               output_chal=unet_out_channel + 1,
                               down_sample_times=3,
                               channel_mode='v0',
                               )

        self.depth_regressor_2 = nn.Sequential(
            nn.Conv2d(unet_out_channel,
                      self.num_depth_regressor_anchor,
                      kernel_size=3,
                      padding=1,
                      ),
            nn.BatchNorm2d(self.num_depth_regressor_anchor),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                self.num_depth_regressor_anchor,
                self.num_depth_regressor_anchor,
                kernel_size=1,
            )
        )
        self.residual_channel = 16
        self.conv_up_2 = nn.Sequential(
            nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1),
            nn.BatchNorm2d(self.residual_channel),
            nn.ReLU(),
            nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
            nn.Upsample(scale_factor=4),
            nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.residual_channel, 1, 1, padding=0),
        )

    def get_bins(self, bins_num):
        depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda')
        depth_bins_vec = torch.exp(depth_bins_vec)
        return depth_bins_vec

    def register_depth_expectation_anchor(self, bins_num, B):
        depth_bins_vec = self.get_bins(bins_num)
        depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
        self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)

    def upsample(self, x, scale_factor=2):
        return F.interpolate(x, scale_factor=scale_factor, mode='nearest')

    def regress_depth_2(self, feature_map_d):
        prob = self.depth_regressor_2(feature_map_d).softmax(dim=1)
        B = prob.shape[0]
        if "depth_expectation_anchor" not in self._buffers:
            self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
        d = compute_depth_expectation(
            prob,
            self.depth_expectation_anchor[:B, ...]
        ).unsqueeze(1)
        return d

    def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
                               torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
        meshgrid = torch.stack((x, y))
        meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
        return meshgrid

    def forward(self, features_mono, **kwargs):
        '''
        trans_ref2src: list of transformation matrix from the reference view to source view. [B, 4, 4]
        inv_intrinsic_pool: list of inverse intrinsic matrix.
        features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...].
        '''
        outputs = {}
        # get encoder feature of the reference view
        ref_feat = features_mono

        feature_map_mono = self.decoder_mono(ref_feat)
        feature_map_mono_pred = self.conv_out_2(feature_map_mono)
        confidence_map_2 = feature_map_mono_pred[:, -1:, :, :]
        feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :]

        depth_pred_2 = self.regress_depth_2(feature_map_d_2)

        B, _, H, W = depth_pred_2.shape

        meshgrid = self.create_mesh_grid(H, W, B)

        depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \
                          self.conv_up_2(
                              torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1)
                          )
        confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4)

        outputs = dict(
            prediction=depth_pred_mono,
            confidence=confidence_map_mono,
            pred_logit=None,
        )
        return outputs
