import torch

from utils.mpi.homography_sampler import HomographySample
from utils.mpi.rendering_utils import transform_G_xyz, sample_pdf, gather_pixel_by_pxpy


def render(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, use_alpha=False, is_bg_depth_inf=False):
    """
    Composite a stack of S planes into one image and one depth map.

    :param rgb_BS3HW: BxSx3xHxW, color on every plane
    :param sigma_BS1HW: BxSx1xHxW, density on every plane; used directly as alpha
                        when use_alpha is True
    :param xyz_BS3HW: BxSx3xHxW, 3D position of every plane pixel (channel 2 is z)
    :param use_alpha: True -> alpha composition, False -> plane volume rendering
    :param is_bg_depth_inf: forwarded to plane_volume_rendering, unused in alpha mode
    :return: (imgs_syn, depth_syn, blend_weights, weights)
    """
    if use_alpha:
        # sigma is interpreted as a per-plane alpha here
        imgs_syn, weights = alpha_composition(sigma_BS1HW, rgb_BS3HW)
        # composite only the z channel to obtain depth
        depth_syn, _ = alpha_composition(sigma_BS1HW, xyz_BS3HW[:, :, 2:])
        # No rgb blending with alpha composition; the running product of
        # (1 - alpha) along the plane axis is returned in its place
        blend_weights = torch.cumprod(1 - sigma_BS1HW + 1e-6, dim=1)
        return imgs_syn, depth_syn, blend_weights, weights

    return plane_volume_rendering(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, is_bg_depth_inf)


def alpha_composition(alpha_BK1HW, value_BKCHW):
    """
    Over-composite K planes, following the composition equation from
    'Single-View View Synthesis with Multiplane Images'.
    Plane k=0 is the nearest one, plane k=K-1 the farthest one.

    :param alpha_BK1HW: BxKx1xHxW, alpha at each of the K planes
    :param value_BKCHW: BxKxCxHxW, rgb/disparity at each of the K planes
    :return: (composited value BxCxHxW, per-plane weights BxKx1xHxW)
    """
    B, K, _, H, W = alpha_BK1HW.size()

    # running product of (1 - alpha), inclusive of the current plane
    see_through_incl = torch.cumprod(1 - alpha_BK1HW, dim=1)  # BxKx1xHxW

    # shift by one plane so that plane k is only attenuated by planes 0..k-1;
    # nothing lies in front of the nearest plane, hence the leading ones
    nothing_in_front = alpha_BK1HW.new_ones((B, 1, 1, H, W))
    see_through_excl = torch.cat(
        (nothing_in_front, see_through_incl[:, :K - 1]),
        dim=1,
    )  # BxKx1xHxW

    weights = alpha_BK1HW * see_through_excl  # BxKx1xHxW
    value_composed = (value_BKCHW * weights).sum(dim=1)  # BxCxHxW
    return value_composed, weights


def plane_volume_rendering(rgb_BS3HW, sigma_BS1HW, xyz_BS3HW, is_bg_depth_inf):
    """
    Volume-render a stack of S planes from per-plane densities.

    :param rgb_BS3HW: BxSx3xHxW, color on every plane
    :param sigma_BS1HW: BxSx1xHxW, density on every plane
    :param xyz_BS3HW: BxSx3xHxW, 3D position of every plane pixel
    :param is_bg_depth_inf: forwarded to weighted_sum_mpi
    :return: (rgb_out, depth_out, transparency_acc BxSx1xHxW, weights BxSx1xHxW)
    """
    B, S, _, H, W = sigma_BS1HW.size()

    # Euclidean distance between the same pixel on consecutive planes
    plane_step = xyz_BS3HW[:, 1:] - xyz_BS3HW[:, :-1]  # Bx(S-1)x3xHxW
    plane_gap = torch.norm(plane_step, dim=2, keepdim=True)  # Bx(S-1)x1xHxW

    # the last plane has no successor; it is given a fixed distance of 1e3
    last_gap = torch.full((B, 1, 1, H, W),
                          fill_value=1e3,
                          dtype=xyz_BS3HW.dtype,
                          device=xyz_BS3HW.device)
    plane_gap = torch.cat((plane_gap, last_gap), dim=1)  # BxSx1xHxW

    transparency = torch.exp(-sigma_BS1HW * plane_gap)  # BxSx1xHxW
    alpha = 1 - transparency  # BxSx1xHxW

    # add small eps to avoid zero transparency_acc
    # cumprod maps [a, b, c] -> [a, a*b, a*b*c] while [1, a, a*b] is required,
    # so prepend ones and drop the last entry
    acc_inclusive = torch.cumprod(transparency + 1e-6, dim=1)  # BxSx1xHxW
    leading_ones = torch.ones((B, 1, 1, H, W),
                              dtype=transparency.dtype,
                              device=transparency.device)
    transparency_acc = torch.cat((leading_ones, acc_inclusive[:, :-1]), dim=1)  # BxSx1xHxW

    weights = transparency_acc * alpha  # BxSx1xHxW
    rgb_out, depth_out = weighted_sum_mpi(rgb_BS3HW, xyz_BS3HW, weights, is_bg_depth_inf)
    return rgb_out, depth_out, transparency_acc, weights


def weighted_sum_mpi(rgb_BS3HW, xyz_BS3HW, weights, is_bg_depth_inf):
    """
    Reduce per-plane color and z with the given weights along the plane axis.

    :param rgb_BS3HW: BxSx3xHxW
    :param xyz_BS3HW: BxSx3xHxW, only the z channel (index 2 onwards) is used
    :param weights: BxSx1xHxW
    :param is_bg_depth_inf: if True, the missing weight mass (1 - sum of weights)
                            is filled with a depth of 1000; otherwise the weighted
                            z is normalized by the weight sum
    :return: (rgb_out Bx3xHxW, depth_out Bx1xHxW)
    """
    total_weight = weights.sum(dim=1)  # Bx1xHxW
    rgb_out = (weights * rgb_BS3HW).sum(dim=1)  # Bx3xHxW
    weighted_z = (weights * xyz_BS3HW[:, :, 2:, :, :]).sum(dim=1)  # Bx1xHxW

    if is_bg_depth_inf:
        # for dtu dataset, set large depth if weight_sum is small
        depth_out = weighted_z + (1 - total_weight) * 1000
    else:
        # eps keeps the division finite where the weights sum to zero
        depth_out = weighted_z / (total_weight + 1e-5)

    return rgb_out, depth_out


def get_xyz_from_depth(meshgrid_homo,
                       depth,
                       K_inv):
    """
    Back-project every pixel to a 3D point: xyz = (K_inv @ pixel_homo) * depth.

    :param meshgrid_homo: 3xHxW, homogeneous pixel coordinates
    :param depth: Bx1xHxW
    :param K_inv: Bx3x3
    :return: Bx3xHxW
    :raises AssertionError: if the spatial size of depth differs from the meshgrid
    """
    H, W = meshgrid_homo.size(1), meshgrid_homo.size(2)
    B, _, H_d, W_d = depth.size()
    # Both dimensions must be part of the condition. Written as
    # `assert H == H_d, W == W_d` the second comparison is only the assertion
    # message, so a width mismatch (e.g. W_d == 1) would silently broadcast.
    assert H == H_d and W == W_d, \
        f"depth is {H_d}x{W_d} but meshgrid is {H}x{W}"

    # 3xHxW -> Bx3xHW; expand shares the grid across the batch without copying it
    meshgrid_src_homo_B3N = meshgrid_homo.unsqueeze(0).expand(B, -1, -1, -1).reshape(B, 3, -1)
    xyz_src = torch.matmul(K_inv, meshgrid_src_homo_B3N)  # Bx3xHW
    xyz_src = xyz_src.reshape(B, 3, H, W) * depth  # Bx3xHxW

    return xyz_src


def disparity_consistency_src_to_tgt(meshgrid_homo, K_src_inv, disparity_src,
                                     G_tgt_src, K_tgt, disparity_tgt):
    """
    Mean absolute disparity difference between the source disparity map moved
    into the target frame and the target disparity map read at the projected
    pixel locations.

    :param meshgrid_homo: 3xHxW, homogeneous pixel coordinates
    :param K_src_inv: Bx3x3
    :param disparity_src: Bx1xHxW
    :param G_tgt_src: Bx4x4
    :param K_tgt: Bx3x3
    :param disparity_tgt: Bx1xHxW
    :return: scalar tensor, averaged over the pixels that project inside the image
    """
    B, _, H, W = disparity_src.size()

    # lift the source pixels to 3D points in the source frame
    src_depth = torch.reciprocal(disparity_src)
    src_points_B3N = get_xyz_from_depth(meshgrid_homo, src_depth, K_src_inv).view(B, 3, H * W)

    # move the points into the target frame and project them with K_tgt
    tgt_points_B3N = transform_G_xyz(G_tgt_src, src_points_B3N, is_return_homo=False)
    projected_B3N = torch.matmul(K_tgt, tgt_points_B3N)
    pxpy_tgt_B2N = projected_B3N[:, 0:2, :] / projected_B3N[:, 2:, :]  # Bx2xN

    # keep only the projections landing inside [0, W-1] x [0, H-1]
    px_B1N = pxpy_tgt_B2N[:, 0:1, :]
    py_B1N = pxpy_tgt_B2N[:, 1:2, :]
    inside_B1N = (px_B1N >= 0) & (px_B1N <= W - 1) & (py_B1N >= 0) & (py_B1N <= H - 1)  # Bx1xN

    # disparity of the source points as seen from the target frame
    warped_disparity = torch.reciprocal(tgt_points_B3N[:, 2:, :])  # Bx1xN
    # presumably reads disparity_tgt at the projected pixel coordinates -- confirm
    # against gather_pixel_by_pxpy
    sampled_disparity = gather_pixel_by_pxpy(disparity_tgt, pxpy_tgt_B2N)  # Bx1xN

    abs_diff = torch.abs(warped_disparity - sampled_disparity)
    # NOTE(review): an all-False mask makes this the mean of an empty tensor
    return torch.mean(abs_diff[inside_B1N])


def get_src_xyz_from_plane_disparity(meshgrid_src_homo,
                                     mpi_disparity_src,
                                     K_src_inv):
    """
    3D position, in the source camera frame, of every pixel on every plane:
    xyz[b, s] = (K_src_inv[b] @ pixel_homo) * (1 / mpi_disparity_src[b, s]).

    :param meshgrid_src_homo: 3xHxW, homogeneous pixel coordinates
    :param mpi_disparity_src: BxS
    :param K_src_inv: Bx3x3
    :return: BxSx3xHxW
    """
    B, S = mpi_disparity_src.size()
    H, W = meshgrid_src_homo.size(1), meshgrid_src_homo.size(2)
    mpi_depth_src = torch.reciprocal(mpi_disparity_src)  # BxS

    # K_inv @ pixel does not depend on the plane, so it is computed once per
    # batch element (matmul broadcasts the 3xHW grid over B) instead of
    # replicating the grid and the intrinsics S times.
    meshgrid_src_homo_3N = meshgrid_src_homo.reshape(3, H * W)
    rays_B3N = torch.matmul(K_src_inv, meshgrid_src_homo_3N)  # Bx3xHW

    # scale each ray by the depth of each plane: Bx1x3xHW * BxSx1x1 -> BxSx3xHW
    xyz_src = rays_B3N.unsqueeze(1) * mpi_depth_src[:, :, None, None]
    xyz_src_BS3HW = xyz_src.reshape(B, S, 3, H, W)

    return xyz_src_BS3HW


def get_tgt_xyz_from_plane_disparity(xyz_src_BS3HW,
                                     G_tgt_src):
    """
    Apply G_tgt_src to the per-plane 3D points via transform_G_xyz.

    :param xyz_src_BS3HW: BxSx3xHxW
    :param G_tgt_src: Bx4x4
    :return: BxSx3xHxW
    """
    B, S, _, H, W = xyz_src_BS3HW.size()

    # one copy of the transform per plane, batch and plane axes merged: (B*S)x4x4
    G_per_plane = G_tgt_src.unsqueeze(1).repeat(1, S, 1, 1)
    G_tgt_src_Bs44 = G_per_plane.reshape(B * S, 4, 4)

    # flatten the planes the same way so each one pairs with its transform
    points_Bs3N = xyz_src_BS3HW.reshape(B * S, 3, H * W)
    moved_Bs3N = transform_G_xyz(G_tgt_src_Bs44, points_Bs3N)  # (B*S)x3xHW

    return moved_Bs3N.reshape(B, S, 3, H, W)  # BxSx3xHxW


def render_tgt_rgb_depth(H_sampler: HomographySample,
                         mpi_rgb_src,
                         mpi_sigma_src,
                         mpi_disparity_src,
                         xyz_tgt_BS3HW,
                         G_tgt_src,
                         K_src_inv, K_tgt,
                         use_alpha=False,
                         is_bg_depth_inf=False):
    """
    Warp the source MPI into the target view with H_sampler and composite it.

    :param H_sampler: homography sampler, its sample() method is called once
    :param mpi_rgb_src: BxSx3xHxW
    :param mpi_sigma_src: BxSx1xHxW
    :param mpi_disparity_src: BxS
    :param xyz_tgt_BS3HW: BxSx3xHxW
    :param G_tgt_src: Bx4x4
    :param K_src_inv: Bx3x3
    :param K_tgt: Bx3x3
    :param use_alpha: forwarded to render
    :param is_bg_depth_inf: forwarded to render
    :return: (tgt_rgb_syn, tgt_depth_syn, tgt_mask); tgt_mask is the sampler mask
             summed over the S planes, Bx1xHxW
    """
    B, S, _, H, W = mpi_rgb_src.size()
    num_planes_total = B * S

    def _per_plane(mat_Bnn, n):
        # Bxnxn -> (B*S)xnxn, every matrix repeated once per plane
        return mat_Bnn.unsqueeze(1).repeat(1, S, 1, 1).contiguous().reshape(num_planes_total, n, n)

    mpi_depth_src = torch.reciprocal(mpi_disparity_src)  # BxS

    # note that here we concat the mpi_src with xyz_tgt, because H_sampler will sample them for tgt frame
    # mpi_src is the same in whatever frame, but xyz has to be in tgt frame
    stacked_src = torch.cat((mpi_rgb_src, mpi_sigma_src, xyz_tgt_BS3HW), dim=2)  # BxSx(3+1+3)xHxW

    # homography warping of mpi_src into tgt frame
    # per the original comment the sampler returns BsxCxHxW and BsxHxW
    warped_BsCHW, valid_BsHW = H_sampler.sample(stacked_src.view(num_planes_total, 7, H, W),
                                                mpi_depth_src.view(num_planes_total),
                                                _per_plane(G_tgt_src, 4),
                                                _per_plane(K_src_inv, 3),
                                                _per_plane(K_tgt, 3))

    # split the warped stack back into its rgb / sigma / xyz channels
    warped = warped_BsCHW.view(B, S, 7, H, W)
    tgt_rgb_BS3HW = warped[:, :, 0:3, :, :]
    tgt_sigma_BS1HW = warped[:, :, 3:4, :, :]
    tgt_xyz_BS3HW = warped[:, :, 4:, :, :]

    # sampler mask as float32 (1 where set, 0 elsewhere)
    mask_shape = (B, S, H, W)
    mask_ones = torch.ones(mask_shape, dtype=torch.float32, device=mpi_rgb_src.device)
    mask_zeros = torch.zeros(mask_shape, dtype=torch.float32, device=mpi_rgb_src.device)
    tgt_mask_BSHW = torch.where(valid_BsHW.view(mask_shape), mask_ones, mask_zeros)

    # zero the density wherever the warped z is negative
    tgt_z_BS1HW = tgt_xyz_BS3HW[:, :, -1:]
    tgt_sigma_BS1HW = torch.where(tgt_z_BS1HW >= 0,
                                  tgt_sigma_BS1HW,
                                  torch.zeros_like(tgt_sigma_BS1HW, device=tgt_sigma_BS1HW.device))

    # mpi composition
    tgt_rgb_syn, tgt_depth_syn, _, _ = render(tgt_rgb_BS3HW, tgt_sigma_BS1HW, tgt_xyz_BS3HW,
                                              use_alpha=use_alpha,
                                              is_bg_depth_inf=is_bg_depth_inf)
    tgt_mask = torch.sum(tgt_mask_BSHW, dim=1, keepdim=True)  # Bx1xHxW

    return tgt_rgb_syn, tgt_depth_syn, tgt_mask


def predict_mpi_coarse_to_fine(mpi_predictor, src_imgs, xyz_src_BS3HW_coarse,
                               disparity_coarse_src, S_fine, is_bg_depth_inf):
    """
    Predict an MPI, optionally refining the plane disparities first.

    With S_fine > 0 a coarse MPI is predicted without gradients, its rendering
    weights are used to sample S_fine extra disparities via sample_pdf, and the
    predictor is run again on the merged, descending-sorted disparities.
    Otherwise the predictor is run once on the coarse disparities.

    :param mpi_predictor: callable(src_imgs, disparity) whose result is indexable;
                          element 0 is used as a BxSx4xHxW tensor (rgb + sigma)
    :param src_imgs: passed through to mpi_predictor
    :param xyz_src_BS3HW_coarse: BxS_coarsex3xHxW, plane points for the coarse planes
    :param disparity_coarse_src: BxS_coarse
    :param S_fine: number of fine disparities to sample; <= 0 disables refinement
    :param is_bg_depth_inf: forwarded to plane_volume_rendering
    :return: (predictor output, disparities the predictor was run on)
    """
    if not S_fine > 0:
        # no refinement requested: single pass on the coarse planes
        return mpi_predictor(src_imgs, disparity_coarse_src), disparity_coarse_src

    # the coarse pass only decides where to place planes, so it needs no gradients
    with torch.no_grad():
        # predict coarse mpi
        coarse_outputs = mpi_predictor(src_imgs, disparity_coarse_src)
        coarse_mpi = coarse_outputs[0]  # BxS_coarsex4xHxW
        coarse_rgb = coarse_mpi[:, :, 0:3, :, :]  # BxS_coarsex3xHxW
        coarse_sigma = coarse_mpi[:, :, 3:, :, :]  # BxS_coarsex1xHxW
        plane_weights = plane_volume_rendering(
            coarse_rgb,
            coarse_sigma,
            xyz_src_BS3HW_coarse,
            is_bg_depth_inf
        )[3]
        # one weight per plane: average over channel, height and width
        plane_weights = plane_weights.mean((2, 3, 4)).unsqueeze(1).unsqueeze(2)

        # sample fine disparity
        coarse_bins = disparity_coarse_src.unsqueeze(1).unsqueeze(2)
        disparity_fine_src = sample_pdf(coarse_bins, plane_weights, S_fine)
        disparity_fine_src = disparity_fine_src.squeeze(2).squeeze(1)

        # assemble coarse and fine disparity, largest disparity first
        disparity_all_src = torch.cat((disparity_coarse_src, disparity_fine_src), dim=1)  # Bx(S_coarse + S_fine)
        disparity_all_src, _ = torch.sort(disparity_all_src, dim=1, descending=True)

    # second pass on all planes, outside no_grad
    return mpi_predictor(src_imgs, disparity_all_src), disparity_all_src
