# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import numpy as np
import cv2
import re
import copy

from PIL import Image
from kornia.utils import create_meshgrid

import math

def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    return torch.mean((x - y) ** 2)


def img2mse_w(x, y, w):
    """Return the weighted mean squared error ``mean(w * (x - y) ** 2)``."""
    return torch.mean(w * ((x - y) ** 2))


def mse2psnr(x):
    """Convert an MSE value to PSNR in dB, i.e. ``-10 * log10(x)``.

    log(10) is built as a 1-element tensor on ``x``'s device, so the result
    broadcasts to at least shape (1,).
    """
    log_ten = torch.log(torch.Tensor([10.0]).to(x.device))
    return -10.0 * torch.log(x) / log_ten


def load_ckpt(network, ckpt_file, key_prefix, strict=True, freeze=False, load_costreg=False, update_casmvs=False, freezeExceptTimephi=False, update_z=False):
    """Load the ``key_prefix`` part of a checkpoint into ``network`` and set trainability flags.

    Checkpoint keys containing ``key_prefix`` are loaded with the leading
    ``"<key_prefix>."`` stripped. If the file stores a Lightning-style
    ``{"state_dict": ...}`` wrapper, the inner dict is used.

    Args:
        network: module that receives the weights.
        ckpt_file: path passed to ``torch.load``.
        key_prefix: sub-module prefix to extract (e.g. ``"renderer"``).
        strict: forwarded to ``network.load_state_dict``.
        freeze: set ``requires_grad=False`` on every parameter that was present
            in the checkpoint under ``key_prefix``.
        load_costreg: additionally copy ``cost_reg_{l}.*`` weights (l = 0..2)
            into ``cost_reg_{l}_B.*``.
        update_casmvs: if ``key_prefix == "renderer"``, freeze all of ``network``.
        freezeExceptTimephi: skip ``style_var``/``style_mean`` keys when loading.
            When freezing, also leave ``phi`` parameters trainable.
        update_z: for the renderer, train only ``z``/mapping/special/
            ``style_rgb_fc`` parameters. For any other prefix, freeze everything.
    """
    ckpt_dict = torch.load(ckpt_file)

    if "state_dict" in ckpt_dict.keys():
        ckpt_dict = ckpt_dict["state_dict"]

    strip = len(key_prefix) + 1  # length of "<key_prefix>."

    state_dict = {}
    for key, val in ckpt_dict.items():
        if freezeExceptTimephi and ("style_var" in key or "style_mean" in key):
            continue
        if key_prefix not in key:
            continue

        local_key = key[strip:]
        state_dict[local_key] = val
        if load_costreg:
            # Duplicate cost_reg_{l} weights into the cost_reg_{l}_B branch.
            for l in range(3):
                if f"cost_reg_{l}" in key:
                    tmp_key = ".".join([f"cost_reg_{l}_B"] + local_key.split(".")[1:])
                    state_dict[tmp_key] = val
    network.load_state_dict(state_dict, strict)

    if freeze:
        # Only keys that belong to ``key_prefix`` may be frozen. Slicing keys of
        # other sub-modules could otherwise collide with parameter names here.
        name_in_ckpt = set()
        for key in ckpt_dict:
            if key_prefix not in key:
                continue
            if freezeExceptTimephi and ("phi" in key or "style_var" in key or "style_mean" in key):
                continue
            name_in_ckpt.add(key[strip:])

        for name, param in network.named_parameters():
            if name in name_in_ckpt:
                param.requires_grad = False

    if update_z:
        if key_prefix == "renderer":
            for name, param in network.named_parameters():
                param.requires_grad = (
                    name == "z"
                    or "mapping" in name
                    or "Special" in name
                    or "special" in name
                    or "style_rgb_fc" in name
                )
        else:
            for param in network.parameters():
                param.requires_grad = False

    if update_casmvs:  # freeze models except casmvs
        if key_prefix == "renderer":
            for param in network.parameters():
                param.requires_grad = False


def init_log(log, keys):
    """Set ``log[key]`` to a fresh float64 zero tensor of shape (1,) for each key.

    ``log`` is updated in place and also returned for convenience.
    """
    log.update({name: torch.tensor([0.0], dtype=float) for name in keys})
    return log

def SL1loss_with_weight(input, target, weight, reduction="mean", beta=1.0):
    """Element-wise weighted smooth-L1 loss.

    ref: https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/smooth_l1_loss.py
    loss = (weight * SL1(input, target)).mean()  (for reduction="mean")

    Args:
        input, target: tensors of the same shape.
        weight: per-element weights, broadcast against ``input``.
        reduction: ``"mean"``, ``"sum"`` or anything else for no reduction.
        beta: L1/L2 transition point. Values below 1e-5 are treated as pure L1,
            as in the fvcore reference.

    Returns:
        The reduced loss. With ``"mean"`` on an empty input this is a 0 that is
        still attached to the graph, not NaN.
    """
    n = torch.abs(input - target)
    if beta < 1e-5:
        # With beta -> 0 smooth-L1 is plain L1. Evaluating 0.5 * n**2 / beta
        # here would produce inf, and torch.where would then back-propagate
        # NaN gradients through the unselected branch.
        loss = n
    else:
        cond = n < beta
        loss = torch.where(cond, 0.5 * (n ** 2) / beta, n - 0.5 * beta)
    loss = loss * weight

    if reduction == "mean":
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()

    return loss

class SL1Loss(nn.Module):
    """Smooth-L1 depth loss over a depth pyramid (dict input) or over rays (tensor input).

    Dict mode: ``inputs``/``targets`` map ``"level_{l}"`` to depth tensors.
    Only entries with ``target > 0`` contribute, and level ``l`` is weighted by
    ``2 ** (1 - l)``. Target tensors are cut to the number of views
    (``shape[1]``) of the prediction.

    Tensor mode: per-ray smooth-L1 masked by ``targets > 0``. The sum is
    divided by ``len(mask)`` (size of the first dimension), not by the number
    of valid rays.

    If ``have_confi`` is set, ``confi`` (same structure as ``inputs``) is used
    as per-element weights.
    """

    def __init__(self, levels=3, have_confi=False):
        super(SL1Loss, self).__init__()
        self.levels = levels
        self.loss = nn.SmoothL1Loss(reduction="mean")
        self.loss_ray = nn.SmoothL1Loss(reduction="none")
        self.have_confi = have_confi

    def forward(self, inputs, targets, confi=None):
        loss = 0
        if isinstance(inputs, dict):
            for l in range(self.levels):
                depth_pred_l = inputs[f"level_{l}"]
                V = depth_pred_l.shape[1]

                depth_gt_l = targets[f"level_{l}"][:, :V]
                mask_l = depth_gt_l > 0
                pred_valid = depth_pred_l[mask_l]
                gt_valid = depth_gt_l[mask_l]

                if self.have_confi:
                    assert confi is not None
                    level_loss = SL1loss_with_weight(
                        pred_valid, gt_valid, confi[f"level_{l}"][mask_l], reduction="mean"
                    )
                elif pred_valid.numel() > 0:
                    level_loss = self.loss(pred_valid, gt_valid)
                else:
                    # SmoothL1Loss(mean) over an empty selection is NaN. Contribute
                    # a graph-connected zero instead, matching SL1loss_with_weight.
                    level_loss = 0.0 * pred_valid.sum()
                loss = loss + level_loss * 2 ** (1 - l)
        else:
            mask = targets > 0
            if self.have_confi:
                assert confi is not None
                per_ray = SL1loss_with_weight(inputs, targets, confi, reduction="none")
            else:
                per_ray = self.loss_ray(inputs, targets)
            loss = loss + (per_ray * mask).sum() / len(mask)

        return loss

def gradient_x(img):
    """Forward difference along dim 3 of a 4-D tensor: ``img[..., j] - img[..., j + 1]``.

    For (B, C, H, W) input this is the horizontal gradient; the output has width W - 1.
    """
    left = img[:, :, :, :-1]
    right = img[:, :, :, 1:]
    return left - right

def gradient_y(img):
    """Forward difference along dim 2 of a 4-D tensor: ``img[:, :, i] - img[:, :, i + 1]``.

    For (B, C, H, W) input this is the vertical gradient; the output has height H - 1.
    """
    upper = img[:, :, :-1, :]
    lower = img[:, :, 1:, :]
    return upper - lower

def depth_smoothness(depth, img, lambda_wt=1.0):
    """Computes image-aware depth smoothness loss.

    ref: https://github.com/Boese0601/RC-MVSNet/blob/c8dfe7c6c3ebadd47d70087aac776a268a478fae/losses/modules.py#L56

    Depth gradients are penalised less where the image itself has strong
    gradients: weight = exp(-lambda_wt * mean_c |dI|).

    Args:
        depth: depth map whose spatial dims are dims 2 and 3, e.g. (B, 1, H, W).
        img: guidance image in the same channels-first layout, presumably
            (B, C, H, W) -- TODO confirm at call sites.
        lambda_wt: edge-sensitivity of the weights.

    Returns:
        Scalar: mean |weighted dx| + mean |weighted dy|.
    """
    depth_dx = gradient_x(depth)
    depth_dy = gradient_y(depth)
    image_dx = gradient_x(img)
    image_dy = gradient_y(img)
    # gradient_x/gradient_y difference dims 3 and 2, so inputs are channels-first
    # and the per-pixel gradient magnitude must be averaged over the channel
    # dim (1). The commented-out original gradients (dims 2/1) suggest the
    # reference was channels-last, where dim 3 was the channel axis. Keeping
    # dim 3 here averaged over width and made the weights constant per row.
    weights_x = torch.exp(-(lambda_wt * torch.mean(torch.abs(image_dx), 1, keepdim=True)))
    weights_y = torch.exp(-(lambda_wt * torch.mean(torch.abs(image_dy), 1, keepdim=True)))
    smoothness_x = depth_dx * weights_x
    smoothness_y = depth_dy * weights_y
    return torch.mean(torch.abs(smoothness_x)) + torch.mean(torch.abs(smoothness_y))

def self_supervision_loss(
    loss_fn,
    rays_pixs,
    rendered_depth,
    depth_map,
    rays_gt_rgb,
    unpre_imgs,
    rendered_rgb,
    intrinsics,
    c2ws,
    w2cs,
    confi=None
):
    """Self-supervised depth loss from rendered target-view depth reprojected into source views.

    Target pixels are back-projected with ``rendered_depth`` and reprojected
    into every source view (all views but the last, which is the target).
    The projected z then serves as a pseudo ground truth for the source-view
    depth maps at all three pyramid levels.

    The pseudo target is zeroed (and hence ignored by a ``target > 0`` masked
    ``loss_fn``) unless all of the following hold:
      * the projection lands inside the image,
      * the source colour at that point matches the target GT colour (< 0.02),
      * the rendered colour matches the GT colour (< 0.02),
      * the local 5x5 patch of the source image is textured (std > 0.05).

    Args:
        loss_fn: callable ``(pred, target[, confi])`` -> scalar, e.g. ``SL1Loss``.
        rays_pixs: (2, N) pixel coords of the target rays, row first then column.
        rendered_depth: N rendered depths of those rays.
        depth_map: dict ``"level_{l}"`` -> per-source-view depth maps indexed
            ``[:, v : v + 1]``.
        rays_gt_rgb, rendered_rgb: (N, 3) colours.
        unpre_imgs: source images indexed ``[:, v]`` as (1, 3, H, W) --
            presumably un-normalised RGB; TODO confirm.
        intrinsics, c2ws, w2cs: per-view cameras indexed ``[0, v]``; the last
            view is the target.
        confi: optional dict of confidence maps laid out like ``depth_map``.

    Returns:
        The level-weighted (``2 ** (1 - l)``) loss averaged over source views.
    """
    loss = 0
    nb_src = len(w2cs[0]) - 1
    H, W = depth_map["level_0"].shape[-2:]

    k = 5  # texture-check patch size
    half = k // 2
    # Blend weights that spread k integer indices evenly over [lower, higher].
    idx_device = unpre_imgs.device
    w_low = torch.arange(k - 1, -1, -1, device=idx_device).view(1, -1)
    w_high = torch.arange(0, k, device=idx_device).view(1, -1)

    # Back-project target pixels (x, y, 1) to world space with the rendered depth.
    target_points = torch.stack(
        [rays_pixs[1], rays_pixs[0], torch.ones(rays_pixs[0].shape[0], device=rays_pixs.device)],
        dim=-1,
    )
    target_points = rendered_depth.view(-1, 1) * (
        target_points @ torch.inverse(intrinsics[0, -1]).t()
    )
    target_points = target_points @ c2ws[0, -1][:3, :3].t() + c2ws[0, -1][:3, 3]

    rgb_mask = (rendered_rgb - rays_gt_rgb).abs().mean(dim=-1) < 0.02

    for v in range(nb_src):
        points_v = target_points @ w2cs[0, v][:3, :3].t() + w2cs[0, v][:3, 3]
        points_v = points_v @ intrinsics[0, v].t()
        z_pred = points_v[:, -1].clone()
        points_v = points_v[:, :2] / points_v[:, -1:]

        # With align_corners=True, pixel 0 maps to -1 and pixel W-1 (H-1) to +1,
        # so normalise by W-1 / H-1 (as conver_to_ndc and homo_warp do).
        points_unit = points_v.clone()
        points_unit[:, 0] = points_unit[:, 0] / (W - 1)
        points_unit[:, 1] = points_unit[:, 1] / (H - 1)
        grid = 2 * points_unit - 1
        grid_4d = grid.view(1, -1, 1, 2)

        warped_rgbs = F.grid_sample(
            unpre_imgs[:, v],
            grid_4d,
            align_corners=True,
            mode="bilinear",
            padding_mode="zeros",
        ).squeeze()
        photo_mask = (warped_rgbs.t() - rays_gt_rgb).abs().mean(dim=-1) < 0.02

        # k x k patch around each projected pixel, clipped to stay inside the image.
        pixel_coor = points_v.round().long()
        pixel_coor[:, 0] = pixel_coor[:, 0].clip(half, W - half - 1)
        pixel_coor[:, 1] = pixel_coor[:, 1].clip(half, H - half - 1)
        lower_b = pixel_coor - half
        higher_b = pixel_coor + half

        ind_h = (lower_b[:, 1:] * w_low + higher_b[:, 1:] * w_high) // (k - 1)
        ind_w = (lower_b[:, 0:1] * w_low + higher_b[:, 0:1] * w_high) // (k - 1)

        patches_h = torch.gather(
            unpre_imgs[:, v].mean(dim=1).expand(ind_h.shape[0], -1, -1),
            1,
            ind_h.unsqueeze(-1).expand(-1, -1, W),
        )
        patches = torch.gather(patches_h, 2, ind_w.unsqueeze(1).expand(-1, k, -1))
        ent_mask = patches.view(-1, k * k).std(dim=-1) > 0.05

        in_mask = (grid > -1.0) * (grid < 1.0)
        in_mask = (in_mask[..., 0] * in_mask[..., 1]).float()
        target_z = z_pred * in_mask * photo_mask * ent_mask * rgb_mask

        for l in range(3):
            depth = F.grid_sample(
                depth_map[f"level_{l}"][:, v : v + 1],
                grid_4d,
                align_corners=True,
                mode="bilinear",
                padding_mode="zeros",
            ).squeeze()
            level_weight = 2 ** (1 - l)
            if confi is not None:
                confi_l = F.grid_sample(
                    confi[f"level_{l}"][:, v : v + 1],
                    grid_4d,
                    align_corners=True,
                    mode="bilinear",
                    padding_mode="zeros",
                ).squeeze()
                loss = loss + loss_fn(depth, target_z, confi_l) * level_weight
            else:
                loss = loss + loss_fn(depth, target_z) * level_weight
    loss = loss / nb_src

    return loss

def unified_focal_loss(prob_volume, depth_values, interval, depth_gt, mask, weight, gamma, alpha):
    """Unified focal loss on a depth probability volume.

    A ground-truth depth that falls in bin ``[d, d + interval)`` gives that bin
    the soft target ``1 - (gt - d) / interval``; every other bin targets 0.
    ``F.binary_cross_entropy_with_logits`` against this target is reweighted
    per voxel. Positive bins use ``pos_weight ** gamma``; negative bins use
    ``alpha * neg_weight ** gamma``. The result is masked, averaged over masked
    voxels and scaled by ``weight``.

    Args:
        prob_volume: volume passed as logits to BCE, shaped like ``depth_values``.
        depth_values: (b, d, h, w) bin start depths.
        interval: bin width (scalar or broadcastable).
        depth_gt: (b, h, w) ground-truth depth.
        mask: (b, h, w) validity mask.
        weight, gamma, alpha: loss scale, focal exponent, negative-term factor.

    Returns:
        Scalar loss. Returns 0 (not NaN) when ``mask`` selects nothing.

    NOTE(review): the inline ranges ([1, 3] / [0, 1]) assume ``prob_volume >= 0``.
    If it holds raw logits, ``neg_weight`` can be negative and ``pow(gamma)``
    with non-integer gamma yields NaN -- confirm.
    """
    depth_gt_volume = depth_gt.unsqueeze(1).expand_as(depth_values)  # (b, d, h, w)

    gt_index_volume = ((depth_values <= depth_gt_volume) * ((depth_values + interval) > depth_gt_volume))

    gt_unity_index_volume = torch.zeros_like(prob_volume, requires_grad=False)
    gt_unity_index_volume[gt_index_volume] = 1.0 - (depth_gt_volume[gt_index_volume] - depth_values[gt_index_volume]) / interval

    gt_unity, _ = torch.max(gt_unity_index_volume, dim=1, keepdim=True)
    gt_unity = torch.where(gt_unity > 0.0, gt_unity, torch.ones_like(gt_unity))  # (b, 1, h, w)
    pos_weight = (sigmoid((gt_unity - prob_volume).abs() / gt_unity, base=5) - 0.5) * 4 + 1  # [1, 3]
    neg_weight = (sigmoid(prob_volume / gt_unity, base=5) - 0.5) * 2  # [0, 1]
    focal_weight = pos_weight.pow(gamma) * (gt_unity_index_volume > 0.0).float() + alpha * neg_weight.pow(gamma) * (
            gt_unity_index_volume <= 0.0).float()

    mask = mask.unsqueeze(1).expand_as(depth_values).float()
    per_voxel = F.binary_cross_entropy_with_logits(prob_volume, gt_unity_index_volume, reduction="none")
    # Guard the normaliser: an empty mask would otherwise give 0 / 0 = NaN.
    loss = (per_voxel * focal_weight * mask).sum() / mask.sum().clamp_min(1e-8)
    loss = loss * weight
    return loss

def sigmoid(x, base=2.71828):
    """Generalised logistic function ``1 / (1 + base ** -x)``.

    The default ``base`` approximates e, giving the standard sigmoid.
    """
    denom = 1 + torch.pow(base, -x)
    return 1 / denom

def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET, use_ori_cmap=False):
    """Render a depth map as an image tensor for logging.

    Args:
        depth: numpy array or torch tensor (moved to CPU).
        minmax: optional ``(min, max)`` for normalisation. By default the
            minimum positive depth (background <= 0 ignored) and the maximum.
        cmap: OpenCV colormap applied when ``use_ori_cmap`` is False.
        use_ori_cmap: if True, return the normalised grayscale image instead
            of a colour-mapped one.

    Returns:
        ``(image, [mi, ma])``: ``image`` is the ``T.ToTensor()`` conversion of
        the rendered map (3 channels with a colormap); ``[mi, ma]`` is the
        range used.
    """
    if not isinstance(depth, np.ndarray):
        depth = depth.cpu().numpy()

    x = np.nan_to_num(depth)  # change nan to 0
    if minmax is None:
        positive = x[x > 0]
        # Ignore background (non-positive) pixels for the lower bound. If
        # there are none, fall back to the global minimum instead of raising
        # on an empty array.
        mi = np.min(positive) if positive.size > 0 else np.min(x)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)  # normalize to 0~1
    # Clip before the uint8 cast. Values outside [mi, ma] (e.g. zero
    # background, or an explicit minmax) would otherwise wrap around.
    x = (255 * np.clip(x, 0.0, 1.0)).astype(np.uint8)

    if not use_ori_cmap:
        x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
    else:
        x_ = Image.fromarray(x)  # single-channel grayscale image
    x_ = T.ToTensor()(x_)

    return x_, [mi, ma]


def abs_error(depth_pred, depth_gt, mask):
    """Absolute depth error ``|pred - gt|`` over the entries selected by ``mask``.

    Works for numpy arrays (``np.abs``) and torch tensors (``.abs()``). The
    backend is chosen from the type of the masked prediction.
    """
    pred_sel = depth_pred[mask]
    diff = pred_sel - depth_gt[mask]
    if type(pred_sel) is np.ndarray:
        return np.abs(diff)
    return diff.abs()


def acc_threshold(depth_pred, depth_gt, mask, threshold):
    """Per-pixel accuracy indicator: 1.0 where ``|pred - gt| < threshold`` over masked entries.

    Returns a float numpy array for numpy input, otherwise a float tensor.
    """
    within = abs_error(depth_pred, depth_gt, mask) < threshold
    if type(depth_pred) is np.ndarray:
        return within.astype("float")
    return within.float()


# Ray helpers
def _full_image_pixels(H, W, chunk=-1, chunk_id=-1):
    """Return flattened float (ys, xs) of every pixel of an H x W image on the GPU, optionally one chunk."""
    ys, xs = torch.meshgrid(
        torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)
    )  # pytorch's meshgrid has indexing='ij'
    ys, xs = ys.cuda().reshape(-1), xs.cuda().reshape(-1)
    if chunk > 0:
        window = slice(chunk_id * chunk, (chunk_id + 1) * chunk)
        ys, xs = ys[window], xs[window]
    return ys, xs


# Ray helpers
def get_rays(
    H,
    W,
    intrinsics_target,
    c2w_target,
    chunk=-1,
    chunk_id=-1,
    train=True,
    train_batch_size=-1,
    mask=None,
    points_2d=None,
    train_patch=False,
    whole_img=False
):
    """Generate camera rays of the target view for a set of pixels.

    Pixel selection, in priority order:
      * ``points_2d`` given: use ``points_2d[0]`` as rows and ``points_2d[1]`` as columns.
      * not ``train``, or ``whole_img``: every pixel (optionally chunk ``chunk_id`` of size ``chunk``).
      * ``train_patch``: a random square patch of ``int(sqrt(train_batch_size))``
        pixels per side, fully inside the image.
      * ``mask is None``: ``train_batch_size`` uniformly random pixels.
      * otherwise: ``train_batch_size`` pixels drawn from ``mask.nonzero()``.
        Each valid pixel is used at most once unless there are fewer valid
        pixels than requested, in which case random repeats pad the batch.

    Directions use the pinhole model with +z forward (``(x - cx) / fx, (y - cy) / fy, 1``).

    Returns:
        rays_orig: (N, 3) camera centre repeated per ray.
        rays_dir: (N, 3) world-space, un-normalised directions.
        rays_pixs: (2, N) pixel coords, row first then column.
    """
    if points_2d is not None:
        ys, xs = points_2d[0], points_2d[1]
    elif not train or whole_img:
        ys, xs = _full_image_pixels(H, W, chunk, chunk_id)
    elif train_patch:
        width = int(train_batch_size**0.5)
        # Sample an integer top-left corner (a centre-based patch could land
        # on half-pixels) so that the whole patch stays inside the image.
        x_topleft, y_topleft = (
            torch.randint(0, int(W - (width - 1)), (1,)).float().cuda(),
            torch.randint(0, int(H - (width - 1)), (1,)).float().cuda(),
        )
        x_topleft, y_topleft = x_topleft[0], y_topleft[0]
        ys, xs = torch.meshgrid(
            torch.linspace(y_topleft, y_topleft + (width - 1), width),
            torch.linspace(x_topleft, x_topleft + (width - 1), width),
        )
        ys, xs = ys.cuda().reshape(-1), xs.cuda().reshape(-1)
    elif mask is None:
        xs, ys = (
            torch.randint(0, W, (train_batch_size,)).float().cuda(),
            torch.randint(0, H, (train_batch_size,)).float().cuda(),
        )
    else:
        ## directly sample points from mask(feasible) points
        mask_points_coord = mask.nonzero()
        nb_valid = len(mask_points_coord)
        if nb_valid >= train_batch_size:
            rand_idx = torch.randperm(nb_valid)[:train_batch_size]
        else:
            # Use every valid pixel once, then pad with random repeats. Both
            # pieces are drawn on the CPU (as before) and moved together to the
            # mask's device; previously a CUDA tensor was concatenated with a
            # CPU one, which raised.
            rand_idx = torch.cat(
                (torch.randperm(nb_valid), torch.randint(nb_valid, (train_batch_size - nb_valid,)))
            )
        ys_xs = mask_points_coord[rand_idx.to(mask_points_coord.device)]
        ys, xs = ys_xs[:, 0], ys_xs[:, 1]

    dirs = torch.stack(
        [
            (xs - intrinsics_target[0, 2]) / intrinsics_target[0, 0],
            (ys - intrinsics_target[1, 2]) / intrinsics_target[1, 1],
            torch.ones_like(xs),
        ],
        -1,
    )  # use 1 instead of -1

    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_dir = (
        dirs @ c2w_target[:3, :3].t()
    )  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    rays_orig = c2w_target[:3, -1].clone().reshape(1, 3).expand(rays_dir.shape[0], -1)

    rays_pixs = torch.stack((ys, xs))  # row col

    return rays_orig, rays_dir, rays_pixs


def conver_to_ndc(ray_pts, w2c_ref, intrinsics_ref, W_H, depth_values):
    """Map world-space ray samples into a reference view's normalised [0, 1]^3 volume.

    x, y become pixel coordinates divided by ``W_H`` (expected ``[W-1, H-1]``).
    z is rescaled so that the reference view's near plane maps to 0 and its
    far plane to 1; both planes are bilinearly sampled from
    ``depth_values[:, :1]`` and ``depth_values[:, -1:]`` at each point's
    projected location.

    Args:
        ray_pts: (nb_rays, nb_samples, 3) world points.
        w2c_ref: world-to-camera matrix of the reference view (top 3x4 used).
        intrinsics_ref: (3, 3) intrinsics of the reference view.
        W_H: 2-element tensor used to normalise pixel x, y.
        depth_values: 4-D plane-depth tensor whose first/last channel give near/far.

    Returns:
        (nb_rays, nb_samples, 3) normalised coordinates. Tiny positive near/far
        gaps are printed for debugging.
    """
    nb_rays, nb_samples = ray_pts.shape[:2]
    flat_pts = ray_pts.reshape(-1, 3)

    # world -> reference camera
    rotation, translation = w2c_ref[:3, :3], w2c_ref[:3, 3:]
    cam_pts = torch.matmul(flat_pts, rotation.t()) + translation.reshape(1, 3)

    # camera -> pixels, then x, y scaled to 0~1
    ndc = cam_pts @ intrinsics_ref.t()
    ndc[:, :2] = ndc[:, :2] / (ndc[:, -1:] * W_H.reshape(1, 2))

    grid = ndc[None, None, :, :2] * 2 - 1
    sample_kwargs = dict(align_corners=True, mode="bilinear", padding_mode="border")
    near = F.grid_sample(depth_values[:, :1], grid, **sample_kwargs).squeeze()
    far = F.grid_sample(depth_values[:, -1:], grid, **sample_kwargs).squeeze()

    ndc[:, 2] = (ndc[:, 2] - near) / (far - near)  # z to 0~1
    gap = far - near
    small_gaps = gap[gap < 1e-4]
    if small_gaps.sum() > 0:
        print(small_gaps)

    return ndc.view(nb_rays, nb_samples, 3)

def aggreate_inputs_near_far(input_depths, w2cs, Ks):
    """Placeholder for aggregating source-view near/far planes into the target view.

    Not implemented. The previous body returned an undefined name and always
    failed with ``NameError``; it now fails explicitly. In this file it is only
    referenced from commented-out code.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("aggreate_inputs_near_far is not implemented yet")

def get_sample_points(
    nb_coarse,
    nb_fine,
    near,
    far,
    rays_o,
    rays_d,
    nb_views,
    w2cs,
    intrinsics,
    depth_values,
    W_H,
    with_noise=False,
    gene_mask="None",
    rays_gt_depth=None,
    target_space=False,
    specify_depth=None,
):
    """Sample depths along each ray and project the samples into every source view.

    ``nb_coarse`` depths are spaced uniformly in [near, far]. If ``nb_fine > 0``,
    extra depths are drawn from a categorical distribution whose logits count,
    for each coarse sample, the source views in whose normalised volume the
    sample falls. All depths are then sorted. With ``with_noise`` each depth is
    jittered uniformly within its interval.

    Overrides: ``gene_mask == "one_pt"`` replaces every depth with
    ``rays_gt_depth``; a non-None ``specify_depth`` replaces every depth with
    ``specify_depth``.

    Returns:
        pts_depth: (nb_rays, S) depths, S = nb_coarse (+ nb_fine).
        ray_pts: (nb_rays, S, 3) world points.
        ray_pts_ndc: dict ``"level_{l}"`` (l = 0..2) -> (nb_rays, S, nb_views, 3)
            normalised coordinates from ``conver_to_ndc`` with that level's
            ``depth_values``.

    Raises:
        NotImplementedError: if ``target_space`` is requested.
    """
    device = rays_o.device
    nb_rays = rays_o.shape[0]

    with torch.no_grad():
        t_vals = torch.linspace(0.0, 1.0, steps=nb_coarse).view(1, nb_coarse).to(device)
        pts_depth = near * (1.0 - t_vals) + far * (t_vals)
        pts_depth = pts_depth.expand([nb_rays, nb_coarse])

        ## Importance sampling; the visibility count is only needed when fine samples are drawn.
        if nb_fine > 0:
            ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1)

            ## Counting the number of source views for which the points are valid
            valid_points = torch.zeros([nb_rays, nb_coarse], device=device)
            for idx in range(nb_views):
                w2c_ref, intrinsic_ref = w2cs[0, idx], intrinsics[0, idx]
                ray_pts_ndc = conver_to_ndc(
                    ray_pts,
                    w2c_ref,
                    intrinsic_ref,
                    W_H,
                    depth_values=depth_values["level_0"][:, idx],
                )
                valid_points += (
                    ((ray_pts_ndc >= 0) & (ray_pts_ndc <= 1)).sum(dim=-1) == 3
                ).float()

            ## Creating a distribution based on the counted values and sample more points
            point_distr = torch.distributions.categorical.Categorical(
                logits=valid_points
            )
            # Drawn on the CPU and moved, as before, so seeded runs stay reproducible.
            t_vals = (
                point_distr.sample([nb_fine]).t()
                - torch.rand([nb_rays, nb_fine]).to(device)
            ) / (nb_coarse - 1)
            pts_depth_fine = near * (1.0 - t_vals) + far * (t_vals)

            pts_depth = torch.cat([pts_depth, pts_depth_fine], dim=-1)
            pts_depth, _ = torch.sort(pts_depth)

        if with_noise:  ## Add noise to sample points during training
            # get intervals between samples
            mids = 0.5 * (pts_depth[..., 1:] + pts_depth[..., :-1])
            upper = torch.cat([mids, pts_depth[..., -1:]], -1)
            lower = torch.cat([pts_depth[..., :1], mids], -1)
            # stratified samples in those intervals
            t_rand = torch.rand(pts_depth.shape, device=device)
            pts_depth = lower + (upper - lower) * t_rand

        ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1)

        if gene_mask == "one_pt":
            pts_depth = rays_gt_depth.unsqueeze(-1).repeat(1, pts_depth.shape[1])
            ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1)
        if specify_depth is not None:
            pts_depth = specify_depth.unsqueeze(-1).repeat(1, pts_depth.shape[1])
            ray_pts = rays_o.unsqueeze(1) + pts_depth.unsqueeze(-1) * rays_d.unsqueeze(1)

        if target_space:
            # Target-space NDC would need aggreate_inputs_near_far, which is unfinished.
            raise NotImplementedError("target space ndc not done!")

        ray_pts_ndc = {"level_0": [], "level_1": [], "level_2": []}
        for idx in range(nb_views):
            w2c_ref, intrinsic_ref = w2cs[0, idx], intrinsics[0, idx]
            for l in range(3):
                ray_pts_ndc[f"level_{l}"].append(
                    conver_to_ndc(
                        ray_pts,
                        w2c_ref,
                        intrinsic_ref,
                        W_H,
                        depth_values=depth_values[f"level_{l}"][:, idx],
                    )
                )
        for l in range(3):
            ray_pts_ndc[f"level_{l}"] = torch.stack(ray_pts_ndc[f"level_{l}"], dim=2)

        return pts_depth, ray_pts, ray_pts_ndc

def gene_vis_mask(depth, c2ws, w2cs, K, H, W):
    """Heuristic visibility mask of target-view pixels with respect to the source views.

    Each target pixel is unprojected with the target depth (the last entry of
    ``depth``). The pixel counts as visible (True) if, for at least one source
    view v, ``||point - camera_centre_v|| - depth[v][pixel] < 0.2``.

    Args:
        depth: per-view depth maps, target last; indexed ``depth[v]`` as (H, W)
            (original note: (nb_views+1, H, W)).
        c2ws: camera-to-world matrices, target last.
        w2cs: unused; kept for interface compatibility.
        K: intrinsics, target last.
        H, W: image size.

    Returns:
        (H, W) bool tensor, True = visible.

    NOTE(review): ``depth[v]`` is read at the *target* pixel grid rather than at
    the pixel where the point projects into view v. A Euclidean distance is also
    compared with a depth value whose convention (z vs. ray length) is not shown
    here. Confirm this heuristic is intended.
    """
    with torch.no_grad():
        device = depth.device
        # unproject novel view pixel points to 3d world space by depth GT
        c2w_tgt, K_tgt = c2ws[-1], K[-1]
        depth_tgt = depth[-1]

        _xs = torch.arange(0, W, device=device)
        _ys = torch.arange(0, H, device=device)
        xs, ys = torch.meshgrid(_xs, _ys)  # (W, H) each
        dirs = torch.stack(
            [
                (xs - K_tgt[0, 2]) / K_tgt[0, 0],
                (ys - K_tgt[1, 2]) / K_tgt[1, 1],
                torch.ones_like(xs),
            ],
            -1,
        )
        rays_d_tgt = (
            dirs @ c2w_tgt[:3, :3].t()
        )  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
        rays_d_tgt = rays_d_tgt.permute(1, 0, 2)  # (H, W, 3)
        rays_o_tgt = c2w_tgt[:3, -1].clone().reshape(1, 1, 3).repeat(H, W, 1)

        tgt_pts_3d = rays_o_tgt + depth_tgt.unsqueeze(-1) * rays_d_tgt  # (H,W,3)

        # compare distance between tgt_pts_3d and cam_o of each input view in 3d world coord.
        vis_mask_all = []  # 1: vis
        for v in range(len(c2ws) - 1):
            c2w_v, depth_v = c2ws[v], depth[v]
            cam_o_v = c2w_v[:3, -1].clone().reshape(1, 1, 3).repeat(H, W, 1)

            dist_v = torch.linalg.norm(tgt_pts_3d - cam_o_v, dim=2)
            vis_mask_all.append((dist_v - depth_v) < 0.2)

        vis_mask = torch.stack(vis_mask_all).any(dim=0)

    return vis_mask

def get_rays_pts(
    H,
    W,
    c2ws,
    w2cs,
    intrinsics,
    near_fars,
    depth_values,
    nb_coarse,
    nb_fine,
    nb_views,
    chunk=-1,
    chunk_idx=-1,
    train=False,
    train_batch_size=-1,
    target_img=None,
    target_depth=None,
    points_2d=None,
    gene_mask="None",
    train_patch=False,
    train_vis_novel=False,
    depth_gt=None,
    output_mask=False,
    style_novel=None,
    inputWhole_cs_loss=False,
    pred_depth_map=None,
    new_c2ws=None
):
    """Build target-view rays, gather their ground truth and sample points along them.

    If ``output_mask`` is set, only the pixel-sampling mask is returned:
      * training: ``target_depth > 0`` (None if ``target_depth`` has no positive
        value), combined with ``gene_vis_mask`` when ``train_vis_novel``;
      * otherwise: ``(target_depth > 0) * gene_vis_mask(...)``.

    Rays come from the first view (``c2ws[0, 0]``, whole image) when
    ``inputWhole_cs_loss``. Otherwise they come from the target view (last),
    or from ``new_c2ws`` if given.

    Returns:
        (pts_depth, ray_pts, ray_pts_ndc, rays_dir, rays_gt_rgb, rays_gt_depth,
        rays_pixs, rays_orig, other_output). GT colour/depth are None when not
        available. ``other_output`` may hold ``"rays_pseudo_style_rgb"``.
    """
    other_output = {}

    depth_mask = None
    if train:
        if target_depth.sum() > 0:
            depth_mask = target_depth > 0

        if train_vis_novel:
            vis_mask = gene_vis_mask(depth_gt, c2ws[0], w2cs[0], intrinsics[0], H, W)
            # Without usable GT depth fall back to visibility alone
            # (``None * vis_mask`` used to raise TypeError).
            depth_mask = vis_mask if depth_mask is None else depth_mask * vis_mask

        if output_mask:
            return depth_mask

    elif output_mask:
        depth_mask = target_depth > 0
        vis_mask = gene_vis_mask(depth_gt, c2ws[0], w2cs[0], intrinsics[0], H, W)
        return depth_mask * vis_mask

    if inputWhole_cs_loss:
        ray_K, ray_c2w, whole_img = intrinsics[0, 0], c2ws[0, 0], True
    else:
        ray_K = intrinsics[0, -1]
        ray_c2w = new_c2ws if new_c2ws is not None else c2ws[0, -1]
        whole_img = False

    rays_orig, rays_dir, rays_pixs = get_rays(
        H,
        W,
        ray_K,
        ray_c2w,
        chunk=chunk,
        chunk_id=chunk_idx,
        train=train,
        train_batch_size=train_batch_size,
        mask=depth_mask,
        points_2d=points_2d,
        train_patch=train_patch,
        whole_img=whole_img,
    )

    ## Extracting ground truth color and depth of target view
    rays_gt_rgb = None
    rays_gt_depth = None
    # Stays None outside training, so specify_depth below never hits an undefined name.
    rays_pred_depth = None
    if train:
        rays_pixs_int = rays_pixs.long()
        rows, cols = rays_pixs_int[0], rays_pixs_int[1]
        rays_gt_rgb = target_img[:, rows, cols].permute(1, 0)
        rays_gt_depth = target_depth[rows, cols]
        if style_novel is not None:
            other_output["rays_pseudo_style_rgb"] = style_novel[:, rows, cols].permute(1, 0)
        if inputWhole_cs_loss:
            assert pred_depth_map is not None
            rays_pred_depth = pred_depth_map[rows, cols]
    elif gene_mask != "None":
        rays_pixs_int = rays_pixs.long()
        rays_gt_depth = target_depth[rays_pixs_int[0], rays_pixs_int[1]]

    # travel along the rays
    near, far = near_fars[0, -1, 0], near_fars[0, -1, 1]  ## near/far of the target view
    W_H = torch.tensor([W - 1, H - 1], device=rays_orig.device)
    pts_depth, ray_pts, ray_pts_ndc = get_sample_points(
        nb_coarse,
        nb_fine,
        near,
        far,
        rays_orig,
        rays_dir,
        nb_views,
        w2cs,
        intrinsics,
        depth_values,
        W_H,
        with_noise=train,
        gene_mask=gene_mask,
        rays_gt_depth=rays_gt_depth,
        specify_depth=rays_pred_depth,
    )

    return (
        pts_depth,
        ray_pts,
        ray_pts_ndc,
        rays_dir,
        rays_gt_rgb,
        rays_gt_depth,
        rays_pixs,
        rays_orig,
        other_output,
    )


def normal_vect(vect, dim=-1):
    """Scale ``vect`` to unit L2 length along ``dim``.

    1e-7 is added to the norm so zero vectors do not divide by zero.
    """
    length = (vect ** 2).sum(dim=dim, keepdim=True).sqrt()
    return vect / (length + 1e-7)


def interpolate_3D(feats, pts_ndc):
    """Trilinearly sample a 3-D feature volume at normalised [0, 1] points.

    ``pts_ndc[..., :]`` holds (x, y, z) in [0, 1]; the last three dims are
    (rows, cols, 3). Points outside are clamped (``padding_mode="border"``).

    Returns:
        Features permuted to (rows, cols, batch, C), with singleton dims squeezed.
    """
    rows, cols = pts_ndc.shape[-3:-1]
    sample_grid = pts_ndc.view(-1, 1, rows, cols, 3) * 2 - 1.0  # [1 1 H W 3] (x,y,z)
    sampled = F.grid_sample(
        feats, sample_grid, align_corners=True, mode="bilinear", padding_mode="border"
    )
    # (N, C, 1, rows, cols) -> (rows, cols, N, C)
    return sampled[:, :, 0].permute(2, 3, 0, 1).squeeze()


def interpolate_2D(feats, imgs, pts_ndc):
    """Bilinearly sample 2-D feature maps and images at normalised (x, y) in [0, 1].

    Only ``pts_ndc[..., :2]`` is used; the last three dims of ``pts_ndc`` are
    (rows, cols, 3). Samples are clamped at the border.

    Returns:
        features: (rows, cols, C, batch), singleton dims squeezed.
        images: (rows, cols, C_img, batch), singleton dims squeezed.
        in_mask: (rows, cols, batch) float, 1 where the point is strictly
            inside (-1, 1) in both x and y.
    """
    rows, cols = pts_ndc.shape[-3:-1]
    sample_grid = pts_ndc[..., :2].view(-1, rows, cols, 2) * 2 - 1.0  # [1 H W 2] (x,y)

    def _sample(source):
        out = F.grid_sample(
            source, sample_grid, align_corners=True, mode="bilinear", padding_mode="border"
        )
        return out.permute(2, 3, 1, 0).squeeze()

    features = _sample(feats)
    images = _sample(imgs)
    with torch.no_grad():
        inside = (sample_grid > -1.0) * (sample_grid < 1.0)
        in_mask = (inside[..., 0] * inside[..., 1]).float().permute(1, 2, 0)

    return features, images, in_mask


def read_pfm(filename):
    """Read a PFM (Portable Float Map) file.

    Returns:
        (data, scale): ``data`` holds 4-byte floats, shaped (H, W, 3) for a
        ``"PF"`` header or (H, W) for ``"Pf"``, flipped vertically with
        ``np.flipud``. ``scale`` is the absolute value of the header scale;
        a negative header scale selects little-endian data.

    Raises:
        Exception: if the magic line is not ``PF``/``Pf`` or the dimension line is malformed.
    """
    # ``with`` closes the file on the error paths too; previously an exception
    # raised for a bad header leaked the open handle.
    with open(filename, "rb") as file:
        header = file.readline().decode("utf-8").rstrip()
        if header == "PF":
            color = True
        elif header == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file.")

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8"))
        if not dim_match:
            raise Exception("Malformed PFM header.")
        width, height = map(int, dim_match.groups())

        scale = float(file.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(file, endian + "f")

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data, scale


def homo_warp(src_feat, proj_mat, depth_values, src_grid=None, pad=0):
    """Warp source-view features onto a stack of reference-view depth planes.

    When ``src_grid`` is None, a sampling grid is built from ``proj_mat`` and
    ``depth_values``. Otherwise the given grid (as returned by an earlier call)
    is reused, and ``proj_mat``, ``depth_values`` and ``pad`` are ignored.

    Args:
        src_feat: (B, C, H, W) source feature map.
        proj_mat: (B, 3, 4) matrix. ``[:, :, :3]`` multiplies homogeneous
            reference pixel coordinates, and ``[:, :, 3:]`` divided by depth
            is added to the result.
        depth_values: (B, D) per-plane depths, or an already expanded 4-D
            (B, D, H_pad, W_pad) tensor.
        src_grid: optional cached grid from a previous call.
        pad: number of extra pixels the reference grid extends past the
            source image on every side.

    Returns:
        warped_src_feat: (B, C, D, H_pad, W_pad) sampled features; zero where
            the grid falls outside ``src_feat`` (``padding_mode="zeros"``).
        src_grid: sampling grid in [-1, 1] coordinates, shaped
            (B, D, W_pad, H_pad, 2). NOTE: its memory is ordered
            (D, H_pad, W_pad), so the two spatial axes are mislabelled; it is
            only meant to be passed back in as ``src_grid``.
    """
    if src_grid is None:
        B, _, H, W = src_feat.shape
        device = src_feat.device

        if pad > 0:
            H_pad, W_pad = H + pad * 2, W + pad * 2
        else:
            H_pad, W_pad = H, W

        # Broadcast per-plane depths to every (padded) reference pixel.
        if depth_values.dim() != 4:
            depth_values = depth_values[..., None, None].repeat(1, 1, H_pad, W_pad)
        D = depth_values.shape[1]

        R = proj_mat[:, :, :3]  # (B, 3, 3)
        T = proj_mat[:, :, 3:]  # (B, 3, 1)
        # create grid from the ref frame; shifting by ``pad`` makes it cover
        # ``pad`` pixels beyond each image border.
        ref_grid = create_meshgrid(
            H_pad, W_pad, normalized_coordinates=False, device=device
        )  # (1, H, W, 2)
        if pad > 0:
            ref_grid -= pad

        ref_grid = ref_grid.permute(0, 3, 1, 2)  # (1, 2, H, W)
        ref_grid = ref_grid.reshape(1, 2, W_pad * H_pad)  # (1, 2, H*W)
        ref_grid = ref_grid.expand(B, -1, -1)  # (B, 2, H*W)
        ref_grid = torch.cat(
            (ref_grid, torch.ones_like(ref_grid[:, :1])), 1
        )  # (B, 3, H*W) homogeneous pixel coordinates
        # Tile once per depth plane; ordering (D, H, W) matches the depth reshape.
        ref_grid_d = ref_grid.repeat(1, 1, D)  # (B, 3, D*H*W)
        src_grid_d = R @ ref_grid_d + T / depth_values.reshape(B, 1, D * W_pad * H_pad)
        del ref_grid_d, ref_grid, proj_mat, R, T, depth_values  # release (GPU) memory

        src_grid = (
            src_grid_d[:, :2] / src_grid_d[:, 2:]
        )  # divide by depth (B, 2, D*H*W)
        del src_grid_d
        # Normalise against the *unpadded* source size, which is what
        # grid_sample samples from.
        src_grid[:, 0] = src_grid[:, 0] / ((W - 1) / 2) - 1  # scale to -1~1
        src_grid[:, 1] = src_grid[:, 1] / ((H - 1) / 2) - 1  # scale to -1~1
        src_grid = src_grid.permute(0, 2, 1)  # (B, D*H*W, 2)
        # Kept as (B, D, W_pad, H_pad, 2) for compatibility with cached grids;
        # the reuse path below reads the axes back in the same order.
        src_grid = src_grid.view(B, D, W_pad, H_pad, 2)

    B, D, W_pad, H_pad = src_grid.shape[:4]
    warped_src_feat = F.grid_sample(
        src_feat,
        src_grid.view(B, D, W_pad * H_pad, 2),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )  # (B, C, D, H*W)
    warped_src_feat = warped_src_feat.view(B, -1, D, H_pad, W_pad)
    return warped_src_feat, src_grid

##### Functions for view selection
# Small epsilon for the view-selection helpers: it is added to vector norms
# before dividing and keeps arccos arguments strictly inside (-1, 1).
# (float32 only has about 7 significant decimal digits.)
TINY_NUMBER = 1.0e-5

def angular_dist_between_2_vectors(vec1, vec2):
    """Return the angle (radians) between corresponding vectors of two arrays.

    Vectors are stored along the last axis. Each is normalised by its norm
    plus ``TINY_NUMBER`` (so zero vectors do not divide by zero). The dot
    product is clipped to [-1, 1] before ``arccos``, so results lie in [0, pi].

    Args:
        vec1, vec2: arrays of vectors along the last axis whose shapes
            broadcast against each other, e.g. (N, 3) and (1, 3).

    Returns:
        Array of angles with the last axis reduced.
    """
    # Normalise and dot along the same (last) axis; 2-D inputs behave exactly
    # as before, and 1-D / N-D inputs are now handled consistently.
    vec1_unit = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + TINY_NUMBER)
    vec2_unit = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + TINY_NUMBER)
    cos_sim = np.clip(np.sum(vec1_unit * vec2_unit, axis=-1), -1.0, 1.0)
    return np.arccos(cos_sim)


def batched_angular_dist_rot_matrix(R1, R2):
    """Return the rotation angle (radians) between paired rotation matrices.

    For each pair, the relative rotation ``R2[i].T @ R1[i]`` has trace
    ``1 + 2 cos(theta)``; theta is recovered with ``arccos``. The cosine is
    clipped to (-1 + TINY_NUMBER, 1 - TINY_NUMBER) first.

    Args:
        R1, R2: arrays of shape (N, 3, 3).

    Returns:
        Array of shape (N,) with the angle for each pair.
    """
    assert (
        R1.shape[-1] == 3
        and R2.shape[-1] == 3
        and R1.shape[-2] == 3
        and R2.shape[-2] == 3
    )
    relative = np.matmul(R2.transpose(0, 2, 1), R1)
    cos_theta = (np.trace(relative, axis1=1, axis2=2) - 1) / 2.0
    cos_theta = np.clip(cos_theta, a_min=-1 + TINY_NUMBER, a_max=1 - TINY_NUMBER)
    return np.arccos(cos_theta)


def get_nearest_pose_ids(
    tar_pose,
    ref_poses,
    num_select,
    must_select_2=False,
    tar_id=-1,
    angular_dist_method="dist",
    scene_center=(0, 0, 0),
    second_close_step=1,
    th=None
):
    """Return indices of the reference poses closest to a target pose.

    Args:
        tar_pose: target pose; ``[:3, :3]`` is used as its rotation and
            ``[:3, 3]`` as its position.
        ref_poses: stack of poses indexed the same way on the last two axes.
        num_select: number of ids to return. It is clamped to
            ``len(ref_poses) - 1`` unless there are exactly two poses.
        must_select_2: if True, ``num_select`` is forced to 2.
        tar_id: if >= 0, index of the target inside ``ref_poses``; that entry
            is never selected.
        angular_dist_method: "matrix" (rotation angle), "vector" (angle
            between positions seen from ``scene_center``), "dist" (Euclidean
            distance between positions) or "abs" (L1 distance).
        scene_center: reference point for the "vector" method.
        second_close_step: stride applied to the sorted ids (2 keeps every
            other nearest pose).
        th: if given, return only the ``th``-th selected id; out-of-range
            values fall back to the last selected id.

    Returns:
        Array of selected ids, or a single id when ``th`` is given.

    Raises:
        ValueError: for an unknown ``angular_dist_method``.
    """
    num_cams = len(ref_poses)
    if num_cams != 2:
        num_select = min(num_select, num_cams - 1)
    if must_select_2:
        num_select = 2
    batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0)

    if angular_dist_method == "matrix":
        dists = batched_angular_dist_rot_matrix(
            batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3]
        )
    elif angular_dist_method == "vector":
        tar_cam_locs = batched_tar_pose[:, :3, 3]
        ref_cam_locs = ref_poses[:, :3, 3]
        scene_center = np.array(scene_center)[None, ...]
        tar_vectors = tar_cam_locs - scene_center
        ref_vectors = ref_cam_locs - scene_center
        dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors)
    elif angular_dist_method == "dist":
        tar_cam_locs = batched_tar_pose[:, :3, 3]
        ref_cam_locs = ref_poses[:, :3, 3]
        dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1)
    elif angular_dist_method == "abs":
        tar_cam_locs = batched_tar_pose[:, :3, 3]
        ref_cam_locs = ref_poses[:, :3, 3]
        dists = np.sum(np.abs(tar_cam_locs - ref_cam_locs), axis=1)
    else:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError("unknown angular distance calculation method!")

    if tar_id >= 0:
        assert tar_id < num_cams
        if not np.issubdtype(dists.dtype, np.floating):
            dists = dists.astype(np.float64)  # np.inf needs a float array
        # An infinite distance (instead of a finite sentinel such as 1e3)
        # guarantees the target sorts last however far apart the cameras are.
        dists[tar_id] = np.inf

    sorted_ids = np.argsort(dists)[::second_close_step]
    selected_ids = sorted_ids[:num_select]
    if th is not None:
        if th >= len(selected_ids):
            th = -1
        selected_ids = selected_ids[th]

    return selected_ids

def rgb2ycbcr(image):
    """Convert an RGB tensor to YCbCr.

    Args:
        image: ``torch.Tensor`` of shape (*, 3, H, W) with channels R, G, B.

    Returns:
        Tensor of the same shape with channels Y, Cb, Cr. Cb and Cr carry a
        +0.5 offset.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` lacks a size-3 channel axis at position -3.
    """
    if not torch.is_tensor(image):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
    has_channel_axis = image.dim() >= 3 and image.size(-3) == 3
    if not has_channel_axis:
        raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))

    red, green, blue = (image[..., channel, :, :] for channel in range(3))

    offset = 0.5
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    chroma_b = 0.564 * (blue - luma) + offset
    chroma_r = 0.713 * (red - luma) + offset
    return torch.stack((luma, chroma_b, chroma_r), dim=-3)

def alter_batch(batch, cycle=False, other=None, tgt_view_id=-1):
    """Return a deep copy of ``batch`` with a different view in the target slot.

    Two modes (``cycle`` takes precedence over ``tgt_view_id``):

    * ``cycle=True``: builds a new image from a copy of the last view of
      ``batch["images"]``. The pixels at ``other["rays_pixs"]`` (batch element
      0 only) are overwritten with the normalised ``other["rendered_rgb"]``.
      That image goes into view slot 0, and the original view-0 image goes
      into the last slot. For depths (when ``batch["depths"]`` is a dict) and
      every key other than images/depths/depths_h/depths_aug/closest_idxs/
      second_closest_idxs, the first and last views are swapped.
    * ``tgt_view_id != -1``: copies view ``tgt_view_id`` into the last slot,
      for the same set of keys.

    In both modes ``closest_idxs`` / ``second_closest_idxs`` are recomputed
    from ``new_batch["c2ws"][0]`` (batch element 0 only). With neither mode
    selected, the unmodified deep copy is returned.

    Args:
        batch: dict of batched tensors. NOTE(review): the ``[:, i]`` indexing
            suggests every swapped value is shaped (B, V, ...) with views on
            axis 1 — confirm against the dataloader.
        cycle: enable the cycle mode described above.
        other: required when ``cycle`` is True. The keys "rendered_rgb" and
            "rays_pixs" are used; "points_2d_v0" is read but not used.
        tgt_view_id: view index copied into the last slot when != -1.

    Returns:
        The modified copy; ``batch`` itself is not written to.
    """

    new_batch = copy.deepcopy(batch)

    def new_close_idx(c2ws):
        """Recompute neighbour indices for every view from the poses ``c2ws``.

        The last entry of ``c2ws`` is treated as the target view. Distances use
        ``angular_dist_method="dist"`` (Euclidean distance between the poses'
        translation columns). ``tar_id`` is not passed, so each view's own index
        (distance 0) is a candidate.

        * Non-target views: up to 5 nearest among the non-target views.
        * Target view: ``len(c2ws) - 2`` nearest among all views.

        The second list repeats this with ``second_close_step=2`` (every other
        nearest id). Both lists are returned as tensors with a leading axis of
        size 1.

        NOTE(review): ``np.stack`` needs all rows to have equal length. Given
        the clamping in get_nearest_pose_ids, this appears to hold for 4 views
        but not for every view count — verify before changing the view count.
        """
        closest_idxs = []
        for pose in c2ws[:-1]:
            closest_idxs.append(
                get_nearest_pose_ids(
                    pose, ref_poses=c2ws[:-1], num_select=5, angular_dist_method="dist"
                )
            )
        closest_idxs.append(
            get_nearest_pose_ids(
                c2ws[-1], ref_poses=c2ws[:], num_select=len(c2ws[:-1])-1, angular_dist_method="dist"
            )
        )
        closest_idxs = np.stack(closest_idxs, axis=0)
        closest_idxs = torch.from_numpy(closest_idxs).unsqueeze(0)

        second_closest_idxs = []
        for pose in c2ws[:-1]:
            second_closest_idxs.append(
                get_nearest_pose_ids(
                    pose, ref_poses=c2ws[:-1], num_select=5, angular_dist_method="dist", second_close_step=2
                )
            )
        second_closest_idxs.append(
            get_nearest_pose_ids(
                c2ws[-1], ref_poses=c2ws[:], num_select=len(c2ws[:-1])-1, angular_dist_method="dist", second_close_step=2
            )
        )
        second_closest_idxs = np.stack(second_closest_idxs, axis=0)
        second_closest_idxs = torch.from_numpy(second_closest_idxs).unsqueeze(0)

        return closest_idxs, second_closest_idxs

    if cycle:
        rendered_rgb = other["rendered_rgb"] # (bs, 3)
        rays_pixs = other["rays_pixs"].long() # (2, bs)
        points_2d_v0 = other["points_2d_v0"] # (2, bs)  -- read but unused below
        
        # Start from a copy of the last (target) view's image.
        novel_img = copy.deepcopy(batch["images"][:, -1])
        # NOTE(review): presumably these mean/std match how batch["images"]
        # was normalised — confirm against the dataset code.
        transform_n = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # (bs, 3) -> (3, bs, 1) so Normalize sees channels first, then squeeze
        # back (to (3, bs) when bs > 1).
        rendered_rgb = transform_n(rendered_rgb.permute(1,0).unsqueeze(-1)).squeeze()
        # Overwrite the rendered pixels; only batch element 0 is written.
        # rays_pixs[0] indexes axis 2 and rays_pixs[1] axis 3 of the image.
        novel_img[0, :, rays_pixs[0], rays_pixs[1]] = rendered_rgb
        # The partly rendered image becomes view 0; old view 0 becomes the target.
        new_batch["images"][:, 0], new_batch["images"][:, -1] = novel_img, copy.deepcopy(batch["images"][:, 0])
        if isinstance(new_batch["depths"], dict):
            # Swap first/last views in every depth level and the auxiliary depth maps.
            for l in range(3):
                new_batch["depths"][f"level_{l}"][:, 0], new_batch["depths"][f"level_{l}"][:, -1] = copy.deepcopy(batch["depths"][f"level_{l}"][:, -1]), copy.deepcopy(batch["depths"][f"level_{l}"][:, 0])
            new_batch["depths_h"][:, 0], new_batch["depths_h"][:, -1] = copy.deepcopy(batch["depths_h"][:, -1]), copy.deepcopy(batch["depths_h"][:, 0])
            new_batch["depths_aug"][:, 0], new_batch["depths_aug"][:, -1] = copy.deepcopy(batch["depths_aug"][:, -1]), copy.deepcopy(batch["depths_aug"][:, 0])
        # Swap first/last views for every remaining per-view entry.
        for k in batch:
            if k not in ["images","depths","depths_h","depths_aug","closest_idxs","second_closest_idxs"]:
                new_batch[k][:, 0], new_batch[k][:, -1] = copy.deepcopy(batch[k][:, -1]), copy.deepcopy(batch[k][:, 0])
        
        # Neighbour indices must follow the new view order.
        new_batch["closest_idxs"], new_batch["second_closest_idxs"] = new_close_idx(new_batch["c2ws"][0].detach().cpu().numpy())

    elif tgt_view_id != -1:
        # Copy view ``tgt_view_id`` into the last (target) slot for each entry.
        new_batch["images"][:, -1] = copy.deepcopy(batch["images"][:, tgt_view_id])
        if isinstance(new_batch["depths"], dict):
            for l in range(3):
                new_batch["depths"][f"level_{l}"][:, -1] = copy.deepcopy(batch["depths"][f"level_{l}"][:, tgt_view_id])
            new_batch["depths_h"][:, -1] = copy.deepcopy(batch["depths_h"][:, tgt_view_id])
            new_batch["depths_aug"][:, -1] = copy.deepcopy(batch["depths_aug"][:, tgt_view_id])
        for k in batch:
            if k not in ["images","depths","depths_h","depths_aug","closest_idxs","second_closest_idxs"]:
                new_batch[k][:, -1] = copy.deepcopy(batch[k][:, tgt_view_id])
        
        new_batch["closest_idxs"], new_batch["second_closest_idxs"] = new_close_idx(new_batch["c2ws"][0].detach().cpu().numpy())

    return new_batch

def inverse_sigmoid(x, eps=1e-5):
    """Return the logit ``-log(1/x - 1)`` of a tensor of probabilities.

    Before the logit is taken, values are nudged by ``eps`` away from 0 and 1:
    ``eps`` is added where ``x <= 0.5`` and subtracted where ``x > 0.5``. This
    keeps inputs of exactly 0 or 1 finite. The input tensor is not modified,
    so the function also works on tensors that require grad.

    Args:
        x: tensor of values, presumably in [0, 1].
        eps: size of the nudge away from 0 and 1.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    # Both branches are chosen from the *original* values, and no in-place
    # update touches the caller's tensor.
    x = torch.where(x <= 0.5, x + eps, x - eps)
    return -torch.log((1 / x) - 1)

from torchvision import models
class VGG16_perceptual(torch.nn.Module):
    """Truncated pretrained VGG16 that exposes four intermediate activations.

    The ``features`` stack of ``models.vgg16(pretrained=True)`` is split into
    four sequential stages over layer indices [0, 2), [2, 4), [4, 14) and
    [14, 21). Their outputs are returned as (relu1_1, relu1_2, relu3_2,
    relu4_2). Children keep their original feature index as name, so
    ``state_dict`` keys are e.g. ``slice3.10.weight``.

    Args:
        requires_grad: if False (default), every parameter is frozen.
    """

    def __init__(self, requires_grad=False):
        super(VGG16_perceptual, self).__init__()
        backbone = models.vgg16(pretrained=True).features

        def make_stage(start, stop):
            # Copy layers [start, stop) of the backbone, keeping their indices.
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), backbone[idx])
            return stage

        self.slice1 = make_stage(0, 2)
        self.slice2 = make_stage(2, 4)
        self.slice3 = make_stage(4, 14)
        self.slice4 = make_stage(14, 21)

        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad = False

    def forward(self, X):
        """Run the four stages in order and return each stage's output."""
        outputs = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            h = stage(h)
            outputs.append(h)
        return tuple(outputs)

## from ComoGAN
from torch.nn import init
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal
                                (ignored by kaiming).

    Conv*/Linear* modules get their weight re-drawn and bias zeroed. Affine
    BatchNorm2d layers get weight ~ N(1, init_gain) and zero bias; non-affine
    ones (no weight/bias) are left untouched.

    Raises:
        NotImplementedError: for an unknown ``init_type`` (raised at the first
            Conv/Linear module encountered).

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # BatchNorm2d(affine=False) has weight and bias set to None.
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Initialize the weights of a network in place and return it.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- accepted for interface compatibility only. It is
                              currently unused: the network is NOT moved to any
                              device here.

    Return the same network object after initialization.
    """
    # gpu_ids defaults to None (not a shared mutable list); it is ignored.
    init_weights(net, init_type, init_gain=init_gain)
    return net

def delta_t_compute(t1, t2):
    """Return the circular distance between two angles, normalised to [0, 1].

    The angles are in radians. Their absolute difference is wrapped modulo
    2*pi, so e.g. 0.1 and 2*pi - 0.1 count as close. 0 means identical
    angles and 1 means opposite angles (pi apart). For differences up to 2*pi
    this matches the previous two-branch formula exactly. Larger differences
    are now wrapped instead of producing negative values.
    """
    diff = abs(t1 - t2) % (2 * math.pi)
    return min(diff, 2 * math.pi - diff) / math.pi  # [0,1]
