# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import cv2
import numpy as np
import glob
from PIL import Image

import random
import pickle

from utils.utils import read_pfm, get_nearest_pose_ids

from utils.colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
import pandas as pd

class photoTourism_Dataset(Dataset):
    """Single-scene Phototourism dataset producing GeoNeRF-style samples.

    The scene lives in ``root_dir/scene`` and must contain one ``*.tsv`` image
    list and a COLMAP model under ``dense/sparse`` (``images.bin``,
    ``cameras.bin``, ``points3D.bin``); images are read from ``dense/images``.
    Every ``ref_view_stride``-th registered image is used as a target view and
    its source views are ranked by ``get_nearest_pose_ids`` (``"dist"``
    method). When ``need_style_img`` is set, each sample additionally carries a
    style image (only ``style_dataset == 'waymo'`` is populated here).
    """

    # Stride over the registered images when picking target (reference) views.
    ref_view_stride = 50

    def __init__(
        self,
        root_dir,
        split,
        nb_views,
        downSample=1.0,
        max_len=-1,
        scene="None",
        use_far_view=False,
        need_style_img=False,
        need_style_label=False,
        src_specify='all',
        ref_specify='all',
        style_dataset='ithaca',
        n_output_views=1,
        input_phi_to_test=False,
        to_calculate_consistency=False,
        far_consistency=False,
    ):
        """Store the configuration, then load the scene metadata.

        Args:
            root_dir: directory containing the scene folders.
            split: split name; ``"train"`` selects the training Waymo style
                images and is interpolated into the style index file name.
            nb_views: number of source views per sample.
            downSample: multiplier applied to the 960x640 working resolution.
            max_len: if > 0, caps the dataset length.
            scene: scene sub-directory of ``root_dir``; must not be ``"None"``.
            need_style_img / need_style_label: attach a style image / its
                domain label to each sample.
            input_phi_to_test: replace the metas with 16 copies of the
                second meta.
            far_consistency: requires ``to_calculate_consistency``.
            use_far_view, src_specify, ref_specify, n_output_views,
            to_calculate_consistency: stored as attributes (this class only
                uses some of them for logging / validation).
        """
        self.root_dir = root_dir
        self.split = split
        self.scene = scene  # scene sub-directory under root_dir
        self.use_far_view = use_far_view

        self.downSample = downSample
        self.scale_factor = 1.0  # kept for API compatibility; build_metas uses its own local factor
        self.interval_scale = 1.0
        self.max_len = max_len
        self.nb_views = nb_views

        self.img_wh = (int(960 * self.downSample), int(640 * self.downSample))
        self.need_style_img = need_style_img
        self.need_style_label = need_style_label
        self.src_specify = src_specify
        self.ref_specify = ref_specify
        self.style_dataset = style_dataset
        self.n_output_views = n_output_views
        self.domain2label = {'night': 0, 'sunny': 1, 'rain': 2, 'cloud': 3, 'snow': 4}
        self.input_phi_to_test = input_phi_to_test
        self.to_calculate_consistency = to_calculate_consistency
        self.far_consistency = far_consistency
        if self.far_consistency:
            assert self.to_calculate_consistency, \
                "far_consistency=True requires to_calculate_consistency=True"

        self.build_metas()
        self.define_transforms()

    def define_transforms(self):
        """Set ``self.transform``: PIL image -> tensor normalized with ImageNet mean/std."""
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def build_metas(self):
        """Load the scene's image list, cameras and bounds, and build the sample metas.

        Sets ``scene_name``, ``files``, ``img_ids``, ``image_paths``, ``remap``,
        ``intrinsics`` (3x3 K rescaled to ``img_wh``), ``cam2worlds`` /
        ``world2cams`` (N, 4, 4), ``xyz_world``, ``near_fars`` (N, 2),
        ``scans`` and ``metas`` (list of ``(scan, ref_view_id, src_view_ids)``),
        plus the style metas when ``need_style_img`` is set.

        Raises:
            FileNotFoundError: if the scene directory holds no ``*.tsv`` file.
            ValueError: if no 3D point lies in front of some camera.
        """
        assert self.scene != "None"
        scene_dir = os.path.join(self.root_dir, self.scene)
        sparse_dir = os.path.join(scene_dir, 'dense/sparse')

        # Read all files listed in the tsv (a sorted glob keeps the choice deterministic).
        tsv_files = sorted(glob.glob(os.path.join(scene_dir, '*.tsv')))
        if not tsv_files:
            raise FileNotFoundError(f"No .tsv file found in {scene_dir}")
        tsv = tsv_files[0]
        self.scene_name = os.path.basename(tsv)[:-4]
        self.files = pd.read_csv(tsv, sep='\t')
        self.files = self.files[~self.files['id'].isnull()]  # remove rows without id
        self.files.reset_index(inplace=True, drop=True)

        # Step 1: map tsv filenames to COLMAP image ids.
        # Attention! The 'id' column in the tsv is BROKEN, don't use it!
        # Instead, read the id from images.bin using the image file name.
        imdata = read_images_binary(os.path.join(sparse_dir, 'images.bin'))
        img_path_to_id = {v.name: v.id for v in imdata.values()}
        self.img_ids = []
        self.image_paths = {}  # {id: filename}
        for filename in self.files['filename']:
            if filename in img_path_to_id:
                id_ = img_path_to_id[filename]
                self.image_paths[id_] = filename
                self.img_ids.append(id_)

        self.build_remap()

        # Step 2: read camera intrinsics and rescale them to self.img_wh.
        camdata = read_cameras_binary(os.path.join(sparse_dir, 'cameras.bin'))
        img_w_, img_h_ = self.img_wh
        intrinsics = []  # K of id_: self.intrinsics[self.remap[id_]]
        for id_ in self.img_ids:
            # Follow the image's camera_id instead of assuming camera id == image id.
            cam = camdata[imdata[id_].camera_id]
            # params are read as (fx, fy, cx, cy); the source resolution is taken
            # as twice the principal point (assumes a centred principal point).
            img_w, img_h = int(cam.params[2] * 2), int(cam.params[3] * 2)
            K = np.zeros((3, 3), dtype=np.float32)
            K[0, 0] = cam.params[0] * img_w_ / img_w  # fx
            K[1, 1] = cam.params[1] * img_h_ / img_h  # fy
            K[0, 2] = cam.params[2] * img_w_ / img_w  # cx
            K[1, 2] = cam.params[3] * img_h_ / img_h  # cy
            K[2, 2] = 1
            intrinsics.append(K)
        self.intrinsics = np.stack(intrinsics).astype('float32')

        # Step 3: build 4x4 world-to-camera matrices (ordered like self.img_ids).
        bottom = np.array([0, 0, 0, 1.]).reshape(1, 4)
        w2c_mats = []
        for id_ in self.img_ids:
            im = imdata[id_]
            R = im.qvec2rotmat()
            t = im.tvec.reshape(3, 1)
            w2c_mats.append(np.concatenate([np.concatenate([R, t], 1), bottom], 0))
        w2c_mats = np.stack(w2c_mats, 0)  # (N_images, 4, 4)
        # Poses are kept in COLMAP's camera convention (no axis flip applied).
        self.cam2worlds = np.linalg.inv(w2c_mats).astype('float32')  # (N_images, 4, 4)

        # Step 4: per-image near/far bounds from the sparse 3D points.
        pts3d = read_points3d_binary(os.path.join(sparse_dir, 'points3D.bin'))
        self.xyz_world = np.array([p.xyz for p in pts3d.values()])
        xyz_world_h = np.concatenate([self.xyz_world, np.ones((len(self.xyz_world), 1))], -1)
        near_fars = []  # near_fars of id_: self.near_fars[self.remap[id_]]
        for id_, w2c in zip(self.img_ids, w2c_mats):
            z_cam = (xyz_world_h @ w2c.T)[:, 2]  # depth in this camera's frame
            z_cam = z_cam[z_cam > 0]  # drop points behind the camera
            if z_cam.size == 0:
                raise ValueError(
                    f"No 3D point lies in front of image {id_}; cannot compute near/far bounds"
                )
            # 0.1 / 99.9 percentiles are robust to outlier points.
            near_fars.append(np.percentile(z_cam, [0.1, 99.9]))
        self.near_fars = np.stack(near_fars).astype('float32')

        # Step 5: rescale the scene. Currently a no-op; an alternative is
        # scale_factor = self.near_fars[:, 1].max() / 5 (largest far bound -> 5).
        scale_factor = 1
        # Scale only the translation; the homogeneous row must stay [0, 0, 0, 1].
        self.cam2worlds[..., :3, 3] /= scale_factor
        self.world2cams = np.linalg.inv(self.cam2worlds)  # (N_images, 4, 4)
        self.near_fars /= scale_factor
        self.xyz_world /= scale_factor

        # Step 6: one meta per reference view; source views are all other images
        # ranked by get_nearest_pose_ids (the reference itself is excluded via tar_id).
        self.metas = []
        self.scans = [self.scene]
        img_ids_arr = np.array(self.img_ids)
        for scan in self.scans:
            for ref_view in self.img_ids[::self.ref_view_stride]:
                ref_idx = self.remap[ref_view]
                src_views_idx = get_nearest_pose_ids(
                    self.cam2worlds[ref_idx],
                    ref_poses=self.cam2worlds,
                    num_select=len(self.cam2worlds) - 1,
                    angular_dist_method="dist",
                    tar_id=ref_idx,
                )
                src_views = img_ids_arr[src_views_idx].tolist()
                self.metas.append((scan, ref_view, src_views))

        if self.input_phi_to_test:
            only_one = self.metas[1]
            self.metas = [only_one for _ in range(16)]

        print(self.split, "meta", self.src_specify, len(self.metas))

        if self.need_style_img:
            self._build_style_metas()

    def _build_style_metas(self):
        """Collect style image paths (and, if requested, domain labels) for the samples."""
        metas_style, metas_style_label = [], []
        if self.style_dataset == 'waymo':
            if self.split == 'train':
                day_imgs_name = sorted(glob.glob('./data/waymo/sunny/Day/*'))
                dawn_imgs_name = sorted(glob.glob('./data/waymo/sunny/Dawn/Dusk/*'))
                night_imgs_name = sorted(glob.glob('./data/waymo/sunny/Night/*'))
            else:
                day_imgs_name = sorted(glob.glob('./data/waymo/val/sunny/Day/*'))
                dawn_imgs_name = sorted(glob.glob('./data/waymo/val/sunny/Dawn/Dusk/*'))
                night_imgs_name = sorted(glob.glob('./data/waymo/val/sunny/Night/*'))

            waymo_idx_file = f"./data/waymo/waymo_style_idx_{self.split}.pickle"  # MDMM use ithaca to train
            # NOTE: pickle.load can execute arbitrary code; only load trusted local index files.
            with open(waymo_idx_file, 'rb') as f:
                waymo_idx = pickle.load(f)

            # (index key, sorted file list, domain label) in output order.
            groups = (
                (waymo_idx['day'], day_imgs_name, 'sunny'),
                (waymo_idx['dawn'], dawn_imgs_name, 'sunny'),
                (waymo_idx['night'], night_imgs_name, 'night'),
            )
            for idxs, names, domain in groups:
                metas_style += [names[x] for x in idxs]
                metas_style_label += [self.domain2label[domain]] * len(idxs)

        self.metas_style = metas_style
        if self.need_style_label:
            self.metas_style_label = metas_style_label

        print(self.split, self.style_dataset, "meta style:", self.ref_specify, len(self.metas_style))

    def build_remap(self):
        """Build ``self.remap`` so that ``self.remap[img_id]`` is the row index of that id."""
        self.remap = np.zeros(np.max(self.img_ids) + 1, dtype=int)
        for i, item in enumerate(self.img_ids):
            self.remap[item] = i

    def __len__(self):
        """Number of metas, capped at ``max_len`` when ``max_len > 0``."""
        if self.max_len <= 0:
            return len(self.metas)
        # Never report more items than __getitem__ can index.
        return min(len(self.metas), self.max_len)

    def __getitem__(self, idx):
        """Assemble the source views plus the target view for meta ``idx``.

        Returns a dict of stacked arrays with the target view last
        (``nb_views + 1`` entries along the first axis): normalized images,
        placeholder zero depth maps, poses, near/far bounds, intrinsics,
        projection matrices for 3 pyramid levels (K scaled by 1/2**l) and
        nearest-view indices; optionally a style image and its label.
        """
        if self.split == "train" and self.scene == "None":
            # build_metas asserts scene != "None", so this branch is normally not taken.
            noisy_factor = float(np.random.choice([1.0, 0.5]))
            close_views = int(np.random.choice([3, 4, 5]))
        else:
            noisy_factor = 1.0
            close_views = 5

        scan, target_view, src_views = self.metas[idx]
        view_ids = src_views[:self.nb_views] + [target_view]

        affine_mats, affine_mats_inv = [], []
        imgs, depths_h, depths_aug = [], [], []
        depths = []
        intrinsics, w2cs, c2ws, near_fars = [], [], [], []
        w, h = self.img_wh

        for vid in view_ids:
            img_filename = os.path.join(self.root_dir, scan, 'dense/images', self.image_paths[vid])
            # Context manager closes the file handle as soon as the pixels are read.
            with Image.open(img_filename) as raw:
                img = raw.convert('RGB')
            img = img.resize(self.img_wh, Image.BICUBIC)
            imgs.append(self.transform(img))

            index_mat = self.remap[vid]

            intrinsic = self.intrinsics[index_mat].copy()
            intrinsic[:2] = intrinsic[:2] * noisy_factor
            intrinsics.append(intrinsic)

            w2c = self.world2cams[index_mat]
            w2cs.append(w2c)
            c2ws.append(self.cam2worlds[index_mat])

            # Projection matrices for 3 pyramid levels (K halved per level).
            aff = []
            aff_inv = []
            for l in range(3):
                proj_mat_l = np.eye(4)
                intrinsic_temp = intrinsic.copy()
                intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
                proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
                aff.append(proj_mat_l.copy())
                aff_inv.append(np.linalg.inv(proj_mat_l))
            affine_mats.append(np.stack(aff, axis=-1))
            affine_mats_inv.append(np.stack(aff_inv, axis=-1))

            # No ground-truth depth for this dataset: zero placeholders.
            depths_h.append(np.zeros([h, w]))
            depths.append(np.zeros([h // 4, w // 4]))
            depths_aug.append(np.zeros([h // 4, w // 4]))

            near_fars.append(self.near_fars[index_mat])

        imgs = np.stack(imgs)
        depths = np.stack(depths)
        depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug)
        affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv)
        intrinsics = np.stack(intrinsics)
        w2cs = np.stack(w2cs)
        c2ws = np.stack(c2ws)
        near_fars = np.stack(near_fars)

        if len(c2ws[:-1]) > 5:
            tgt_num_select = close_views
        elif len(c2ws[:-1]) == 2:
            tgt_num_select = 2
        else:
            tgt_num_select = len(c2ws[:-1]) - 1

        # Nearest source views for every source view (self included, no tar_id).
        closest_idxs = []
        for pose in c2ws[:-1]:
            closest_idxs.append(
                get_nearest_pose_ids(
                    pose,
                    ref_poses=c2ws[:-1],
                    num_select=tgt_num_select,
                    angular_dist_method="dist",
                )
            )

        # NOTE(review): ref_poses includes the target pose itself and no tar_id is
        # passed, so the target's own index may be selected; also, with
        # nb_views == 2 the row lengths may differ if get_nearest_pose_ids clamps
        # num_select to len(ref_poses) - 1 — confirm intended behaviour.
        closest_idxs.append(
            get_nearest_pose_ids(
                c2ws[-1],
                ref_poses=c2ws[:],
                num_select=tgt_num_select,
                angular_dist_method="dist",
            )
        )
        closest_idxs = np.stack(closest_idxs, axis=0)

        second_closest_idxs = []

        assert imgs.shape[0] == self.nb_views + 1
        assert w2cs.shape[0] == self.nb_views + 1
        assert intrinsics.shape[0] == self.nb_views + 1

        sample = {}
        sample["images"] = imgs
        sample["depths"] = depths
        sample["depths_h"] = depths_h
        sample["depths_aug"] = depths_aug
        sample["w2cs"] = w2cs
        sample["c2ws"] = c2ws
        sample["near_fars"] = near_fars
        sample["intrinsics"] = intrinsics
        sample["affine_mats"] = affine_mats
        sample["affine_mats_inv"] = affine_mats_inv
        sample["closest_idxs"] = closest_idxs
        sample["second_closest_idxs"] = second_closest_idxs
        sample["view_ids"] = np.array(view_ids)
        sample["input_img_label"] = self.domain2label['sunny']
        if self.need_style_img:
            style_path = self.metas_style[idx]
            with Image.open(style_path) as raw_style:
                img = raw_style.convert("RGB")
            if img.size != (w, h):
                img = img.resize((w, h), Image.BICUBIC)
            sample["style_img"] = self.transform(img)

            if self.need_style_label:
                sample["style_img_label"] = self.metas_style_label[idx]

        return sample
