# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import cv2
import numpy as np
import glob
from PIL import Image

import random
import pickle

from utils.utils import read_pfm, get_nearest_pose_ids

def normalize(v):
    """Return ``v`` divided by its Euclidean (L2) norm.

    No guard is applied for a zero-length vector: NumPy then produces
    nan/inf values (with a RuntimeWarning) rather than raising.
    """
    magnitude = np.linalg.norm(v)
    return np.divide(v, magnitude)


def average_poses(poses):
    """Compute a single "average" camera pose from a stack of poses.

    Args:
        poses: array whose last axis indexes pose columns. Only columns 1
            (y axis), 2 (z axis) and 3 (camera centre) are read. The caller
            in this file passes shape (N_images, 3, 4).

    Returns:
        (3, 4) array whose columns are ``[x, y, z, centre]``:
        - ``centre`` is the mean camera position;
        - ``z`` is the normalised mean z axis;
        - ``x`` is the normalised cross product of the mean y axis with ``z``;
        - ``y = z x x`` completes the orthonormal frame.
    """

    def _unit(vec):
        # Same arithmetic as the module-level normalize(): divide by L2 norm.
        return vec / np.linalg.norm(vec)

    centre = poses[..., 3].mean(0)
    z_axis = _unit(poses[..., 2].mean(0))

    # The mean y axis only orients the frame and is not returned, so it is
    # used unnormalised.
    x_axis = _unit(np.cross(poses[..., 1].mean(0), z_axis))

    # z_axis and x_axis are orthonormal, hence their cross product already
    # has unit length.
    y_axis = np.cross(z_axis, x_axis)

    return np.stack([x_axis, y_axis, z_axis, centre], axis=1)

def center_poses(poses, blender2opencv):
    """Express camera poses relative to their average pose and flip y/z.

    Args:
        poses: (N_images, 3, 4) camera-to-world matrices.
        blender2opencv: 4x4 matrix. It is only used to build the second
            return value; the axis flip applied to the poses is hard-coded.

    Returns:
        Tuple ``(poses_centered, transform)``:
        - ``poses_centered``: (N_images, 3, 4) poses, left-multiplied by the
          inverse of the homogeneous average pose, with columns 1 and 2
          negated;
        - ``transform``: ``inv(avg_pose_4x4) @ blender2opencv``.
    """
    n_cams = len(poses)

    # Lift the (3, 4) average pose to a 4x4 homogeneous matrix and invert once.
    avg_homo = np.eye(4)
    avg_homo[:3] = average_poses(poses)
    avg_inv = np.linalg.inv(avg_homo)

    # Append a [0, 0, 0, 1] row to every pose -> (N_images, 4, 4).
    bottom_rows = np.tile(np.array([0, 0, 0, 1]), (n_cams, 1, 1))
    poses_homo = np.concatenate([poses, bottom_rows], 1)

    centered = avg_inv @ poses_homo  # (N_images, 4, 4)

    # Negate the y and z columns, i.e. right-multiply by diag(1, -1, -1, 1).
    # NOTE(review): this flip does not read `blender2opencv`; callers passing
    # a different matrix only affect the returned transform.
    centered[:, :, 1:3] = -centered[:, :, 1:3]

    return centered[:, :3], avg_inv @ blender2opencv

class tt_Dataset(Dataset):
    """Single-scene multi-view dataset with cameras read from ``poses_bounds.npy``.

    Presumably Tanks and Temples (``tt_`` prefix, example scene
    ``advanced/Auditorium``) -- confirm. For every view of the scene the 10
    nearest other poses (``get_nearest_pose_ids`` with
    ``angular_dist_method="dist"``) are recorded as source views.
    ``__getitem__`` returns the first ``nb_views`` of them followed by the
    reference (target) view, with images, cameras, per-level projection
    matrices and neighbour indices. Depth entries are all-zero placeholders.

    Files read:
        ``<root_dir>/<scene>/poses_bounds.npy``
        ``<root_dir>/<scene>/images/{view_id:08d}.jpg``
        With ``need_style_img`` and ``style_dataset == 'waymo'``: images under
        ``./waymo/...`` plus ``./data/waymo/waymo_style_idx_<split>.pickle``.

    Args:
        root_dir: dataset root directory.
        split: ``'train'`` selects the Waymo training style folders, anything
            else the validation folders; also part of the pickle filename.
        nb_views: number of source views per sample.
        downSample: scale factor applied to image size and intrinsics.
        max_len: if > 0, reported by ``__len__`` instead of ``len(metas)``.
        scene: scene sub-directory; must not be the string ``"None"``.
        use_far_view, n_output_views: stored but not read by this class.
        need_style_img: also load a style image per sample.
        need_style_label: also return the style image's domain label.
        src_specify, ref_specify: only used in log prints.
        style_dataset: only ``'waymo'`` populates the style-image list.
        input_phi_to_test: replace the metas with 16 copies of ``metas[1]``.
    """

    def __init__(
        self,
        root_dir,
        split,
        nb_views,
        downSample=1.0,
        max_len=-1,
        scene="None",
        use_far_view=False,
        need_style_img=False,
        need_style_label=False,
        src_specify='all',
        ref_specify='all',
        style_dataset='ithaca',
        n_output_views=1,
        input_phi_to_test=False,
    ):
        self.root_dir = root_dir
        self.split = split
        self.scene = scene  # e.g. advanced/Auditorium
        self.use_far_view = use_far_view  # stored; not read by this class

        self.downSample = downSample
        self.max_len = max_len
        self.nb_views = nb_views

        self.need_style_img = need_style_img
        self.need_style_label = need_style_label
        self.src_specify = src_specify
        self.ref_specify = ref_specify
        self.style_dataset = style_dataset
        self.n_output_views = n_output_views  # stored; not read by this class
        # Integer label for each style-domain name.
        self.domain2label = {'night':0, 'sunny':1, 'rain':2, 'cloud':3, 'snow':4}
        self.input_phi_to_test = input_phi_to_test

        # diag(1, -1, -1, 1): flips the y and z axes.
        self.blender2opencv = np.array(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )

        self.build_metas()
        self.define_transforms()

    def define_transforms(self):
        """Set ``self.transform``: ToTensor then per-channel mean/std normalisation.

        The mean/std values are the commonly used ImageNet channel statistics.
        """
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def build_metas(self):
        """Build ``self.metas`` and, optionally, the style-image lists.

        Each meta is ``(scan, ref_view, src_views)``. ``src_views`` holds the
        ids returned by ``get_nearest_pose_ids(num_select=11)`` minus the
        first one (presumably ``ref_view`` itself -- confirm). Also sets
        ``self.id_list`` (sorted unique view ids) and ``self.remap``.
        """
        self.metas = []
        assert self.scene != "None"
        self.scans = [self.scene]

        self.build_proj_mats()
        self.id_list = []

        for scan in self.scans:
            num_viewpoint = len(self.near_fars)
            for ref_view in range(num_viewpoint):
                src_views = get_nearest_pose_ids(
                    self.cam2worlds[ref_view, :, :],
                    ref_poses=self.cam2worlds[..., :],
                    num_select=10 + 1,
                    angular_dist_method="dist",
                )[1:].tolist()
                self.metas += [(scan, ref_view, src_views)]
                self.id_list.append([ref_view] + src_views)

        if self.input_phi_to_test:
            # Test mode: repeat a single meta entry (index 1) 16 times.
            only_one = self.metas[1]
            self.metas = [only_one for _ in range(16)]

        # Flatten before np.unique so rows of unequal length are also accepted.
        self.id_list = np.unique([vid for ids in self.id_list for vid in ids])
        self.build_remap()

        print(self.split,"meta",self.src_specify,len(self.metas))

        if self.need_style_img:
            metas_style, metas_style_label = [], []
            if self.style_dataset == 'waymo':
                metas_style, metas_style_label = self._waymo_style_metas()
            # NOTE(review): any other style_dataset (including the default
            # 'ithaca') leaves these lists empty, so __getitem__ raises
            # IndexError when it indexes self.metas_style.

            self.metas_style = metas_style
            if self.need_style_label:
                self.metas_style_label = metas_style_label

            print(self.split,self.style_dataset,"meta style:",self.ref_specify,len(self.metas_style))

    def _waymo_style_metas(self):
        """Return ``(paths, labels)`` of the Waymo style images for this split.

        Candidates are sorted globs of the Day, Dawn/Dusk and Night folders.
        The indices kept per group come from the pickle's 'day', 'dawn' and
        'night' entries. Day and dawn images are labelled 'sunny', night
        images 'night'.
        """
        if self.split == 'train':
            day_imgs_name = sorted(glob.glob('./waymo/sunny/Day/*'))
            dawn_imgs_name = sorted(glob.glob('./waymo/sunny/Dawn/Dusk/*'))
            night_imgs_name = sorted(glob.glob('./waymo/sunny/Night/*'))
        else:
            day_imgs_name = sorted(glob.glob('./waymo/val/sunny/Day/*'))
            dawn_imgs_name = sorted(glob.glob('./waymo/val/sunny/Dawn/Dusk/*'))
            night_imgs_name = sorted(glob.glob('./waymo/val/sunny/Night/*'))

        waymo_idx_file = f"./data/waymo/waymo_style_idx_{self.split}.pickle" # MDMM use ithaca to train
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(waymo_idx_file, 'rb') as f:
            waymo_idx = pickle.load(f)

        day_imgs_idx = waymo_idx['day']
        dawn_imgs_idx = waymo_idx['dawn']
        night_imgs_idx = waymo_idx['night']

        paths = (
            [day_imgs_name[x] for x in day_imgs_idx]
            + [dawn_imgs_name[x] for x in dawn_imgs_idx]
            + [night_imgs_name[x] for x in night_imgs_idx]
        )
        labels = (
            [self.domain2label['sunny']] * len(day_imgs_idx)
            + [self.domain2label['sunny']] * len(dawn_imgs_idx)
            + [self.domain2label['night']] * len(night_imgs_idx)
        )
        return paths, labels

    def build_proj_mats(self):
        """Load, recentre and rescale the scene's cameras.

        Each row of ``poses_bounds.npy`` holds a flattened 3x5 matrix (3x4
        pose plus an ``[H, W, focal]`` column) followed by two depth bounds.
        Sets float32 ``self.near_fars`` (N, 2), ``self.intrinsics`` (N, 3, 3),
        ``self.world2cams`` and ``self.cam2worlds`` (N, 4, 4).
        """
        poses_bounds = np.load(os.path.join(self.root_dir, self.scene, "poses_bounds.npy"))  # (N_cams, 17)
        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
        bounds = poses_bounds[:, -2:]  # (N_images, 2)

        # Step 1: original intrinsics, taken from the first image and shared by
        # all images; they are scaled by self.downSample below.
        H, W, focal = poses[0, :, -1]

        # Step 2: reorder/negate the first two pose columns ("to blender
        # coord"), then recentre on the average pose.
        poses = np.concatenate(
            [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1
        )  # (N_images, 3, 4)
        poses, _ = center_poses(poses, self.blender2opencv)

        # Step 3: rescale so the smallest near bound becomes 1 / 0.75 (~1.33).
        near_original = bounds.min()
        scale_factor = near_original * 0.75  # 0.75 is the default parameter
        bounds /= scale_factor
        poses[..., 3] /= scale_factor

        intrinsic = np.array([[focal, 0, W / 2], [0, focal, H / 2], [0, 0, 1]])
        intrinsic[:2] = intrinsic[:2] * self.downSample

        cam2worlds, world2cams, intrinsics = [], [], []
        for _c2w in poses:
            c2w = np.eye(4)
            c2w[:3] = _c2w
            cam2worlds.append(c2w)
            world2cams.append(np.linalg.inv(c2w))
            intrinsics.append(intrinsic.copy())

        self.near_fars = np.stack(bounds.tolist()).astype('float32')
        self.intrinsics = np.stack(intrinsics).astype('float32')
        self.world2cams = np.stack(world2cams).astype('float32')
        self.cam2worlds = np.stack(cam2worlds).astype('float32')

    def build_remap(self):
        """Map each view id in ``self.id_list`` to its position in that list.

        ``self.remap[view_id]`` is the row of the per-camera arrays used for
        that view in ``__getitem__``. Ids absent from the list map to 0.
        """
        self.remap = np.zeros(np.max(self.id_list) + 1, dtype=int)
        for i, item in enumerate(self.id_list):
            self.remap[item] = i

    def __len__(self):
        return len(self.metas) if self.max_len <= 0 else self.max_len

    @staticmethod
    def _pyramid_proj_mats(intrinsic, w2c, n_levels=3):
        """Return 4x4 projection matrices and their inverses for each level.

        Level ``l`` divides the first two intrinsic rows by ``2**l`` and
        multiplies by ``w2c[:3, :4]``. Both outputs are stacked on the last
        axis, giving shape (4, 4, n_levels).
        """
        aff, aff_inv = [], []
        for level in range(n_levels):
            proj_mat_l = np.eye(4)
            intrinsic_l = intrinsic.copy()
            intrinsic_l[:2] = intrinsic_l[:2] / (2 ** level)
            proj_mat_l[:3, :4] = intrinsic_l @ w2c[:3, :4]
            aff.append(proj_mat_l.copy())
            aff_inv.append(np.linalg.inv(proj_mat_l))
        return np.stack(aff, axis=-1), np.stack(aff_inv, axis=-1)

    @staticmethod
    def _neighbour_ids(c2ws, close_views, tgt_num_select, **extra):
        """Stack ``get_nearest_pose_ids`` results for every view in ``c2ws``.

        Source views (all but the last) search among the source views with
        ``close_views``. The target view (last) searches among all views with
        ``tgt_num_select``. ``extra`` is forwarded (e.g. second_close_step).
        """
        ids = [
            get_nearest_pose_ids(
                pose,
                ref_poses=c2ws[:-1],
                num_select=close_views,
                angular_dist_method="dist",
                **extra,
            )
            for pose in c2ws[:-1]
        ]
        ids.append(
            get_nearest_pose_ids(
                c2ws[-1],
                ref_poses=c2ws[:],
                num_select=tgt_num_select,
                angular_dist_method="dist",
                **extra,
            )
        )
        return np.stack(ids, axis=0)

    def __getitem__(self, idx):
        """Assemble the source views plus the target view for meta ``idx``.

        With V = nb_views + 1 (source views first, target last), the returned
        dict holds:
            images: (V, C, h, w) transformed images;
            depths, depths_aug: zeros (V, h // 4, w // 4); depths_h: zeros (V, h, w);
            w2cs, c2ws: (V, 4, 4); intrinsics: (V, 3, 3); near_fars: (V, 2);
            affine_mats, affine_mats_inv: (V, 4, 4, 3);
            closest_idxs, second_closest_idxs: stacked get_nearest_pose_ids output;
            style_img (and style_img_label) when style images are requested.
        """
        if self.split == "train" and self.scene == "None":
            # NOTE(review): build_metas asserts scene != "None", so this
            # augmentation branch is unreachable as written.
            # Draw scalars directly: float()/int() on a size-1 array is
            # deprecated since NumPy 1.25.
            noisy_factor = float(np.random.choice([1.0, 0.5]))
            close_views = int(np.random.choice([3, 4, 5]))
        else:
            noisy_factor = 1.0
            close_views = 5

        _, target_view, _src_views = self.metas[idx]
        view_ids = _src_views[:self.nb_views] + [target_view]

        affine_mats, affine_mats_inv = [], []
        imgs, depths_h, depths_aug = [], [], []
        depths = []
        intrinsics, w2cs, c2ws, near_fars = [], [], [], []

        for vid in view_ids:
            img_filename = os.path.join(self.root_dir, self.scene, f"images/{vid:08d}.jpg")

            # Convert to RGB, as for the style image below: Normalize uses
            # 3-channel statistics, so grayscale/CMYK JPEGs would fail.
            with Image.open(img_filename) as raw:
                img = raw.convert("RGB")
            img_wh = np.round(
                np.array(img.size) * self.downSample * noisy_factor
            ).astype("int")
            w, h = (int(v) for v in img_wh)
            img = img.resize((w, h), Image.BICUBIC)
            imgs.append(self.transform(img))

            index_mat = self.remap[vid]

            intrinsic = self.intrinsics[index_mat].copy()
            intrinsic[:2] = intrinsic[:2] * noisy_factor
            intrinsics.append(intrinsic)

            w2c = self.world2cams[index_mat]
            w2cs.append(w2c)
            c2ws.append(self.cam2worlds[index_mat])

            aff, aff_inv = self._pyramid_proj_mats(intrinsic, w2c)
            affine_mats.append(aff)
            affine_mats_inv.append(aff_inv)

            # Depth maps are not available for this dataset: zero placeholders.
            depths_h.append(np.zeros([h, w]))
            depths.append(np.zeros([h // 4, w // 4]))
            depths_aug.append(np.zeros([h // 4, w // 4]))

            near_fars.append(self.near_fars[index_mat])

        imgs = np.stack(imgs)
        depths = np.stack(depths)
        depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug)
        affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv)
        intrinsics = np.stack(intrinsics)
        w2cs = np.stack(w2cs)
        c2ws = np.stack(c2ws)
        near_fars = np.stack(near_fars)

        n_src = len(c2ws[:-1])
        tgt_num_select = close_views if n_src >= 5 else n_src - 1

        closest_idxs = self._neighbour_ids(c2ws, close_views, tgt_num_select)
        second_closest_idxs = self._neighbour_ids(
            c2ws, close_views, tgt_num_select, second_close_step=2
        )

        assert imgs.shape[0] == self.nb_views+1
        assert w2cs.shape[0] == self.nb_views+1
        assert intrinsics.shape[0] == self.nb_views+1

        sample = {}
        sample["images"] = imgs
        sample["depths"] = depths
        sample["depths_h"] = depths_h
        sample["depths_aug"] = depths_aug
        sample["w2cs"] = w2cs
        sample["c2ws"] = c2ws
        sample["near_fars"] = near_fars
        sample["intrinsics"] = intrinsics
        sample["affine_mats"] = affine_mats
        sample["affine_mats_inv"] = affine_mats_inv
        sample["closest_idxs"] = closest_idxs
        sample["second_closest_idxs"] = second_closest_idxs
        if self.need_style_img:
            style_path = self.metas_style[idx]
            with Image.open(style_path) as raw_style:
                img = raw_style.convert("RGB")  # (1920, 1208)
            # Resize to the target view's size (w, h from the last loop pass).
            if img.size != (w, h):
                img = img.resize((w, h), Image.BICUBIC)
            sample["style_img"] = self.transform(img)

            if self.need_style_label:
                sample["style_img_label"] = self.metas_style_label[idx]

        return sample
