# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import glob
import numpy as np
from PIL import Image

from utils.utils import get_nearest_pose_ids
import pickle
import torch

def normalize(v):
    """Return ``v`` scaled to unit Euclidean (L2) length.

    No guard against a zero vector: the division then yields NaN/inf.
    """
    length = np.linalg.norm(v)
    return v / length


def average_poses(poses):
    """Build one "average" camera pose from a stack of poses.

    Args:
        poses: stack of per-camera [R | t] matrices, indexed as ``poses[..., col]``
            with the cameras along axis 0 (e.g. shape (N, 3, 4)).

    Returns:
        (3, 4) array whose columns are ``[x, y, z, center]``:
        ``center`` is the mean of column 3, ``z`` the normalized mean of column 2,
        ``x`` the normalized cross product of the mean of column 1 with ``z``,
        and ``y = cross(z, x)``.
    """
    mean_col1 = poses[..., 1].mean(0)
    mean_col2 = poses[..., 2].mean(0)
    mean_col3 = poses[..., 3].mean(0)

    z_axis = normalize(mean_col2)
    # mean_col1 is only a hint; x is made orthogonal to it and to z.
    x_axis = normalize(np.cross(mean_col1, z_axis))
    # z and x are orthonormal, so their cross product is already unit length.
    y_axis = np.cross(z_axis, x_axis)

    return np.stack([x_axis, y_axis, z_axis, mean_col3], axis=1)


def center_poses(poses, blender2opencv):
    """Re-express poses relative to their average pose, then apply ``blender2opencv``.

    Args:
        poses: (N_images, 3, 4) camera poses.
        blender2opencv: (4, 4) matrix right-multiplied onto every homogeneous pose.

    Returns:
        Tuple ``(poses_centered, transform)`` where ``poses_centered`` is
        (N_images, 3, 4) and ``transform`` is ``inv(avg_pose_homo) @ blender2opencv``.
    """
    avg_homo = np.eye(4)
    avg_homo[:3] = average_poses(poses)
    # Inverse is needed for both outputs; compute it once.
    avg_inv = np.linalg.inv(avg_homo)

    # Append a [0, 0, 0, 1] row to each pose -> (N_images, 4, 4) homogeneous matrices.
    bottom_rows = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))
    poses_homo = np.concatenate([poses, bottom_rows], 1)

    centered = avg_inv @ poses_homo @ blender2opencv
    return centered[:, :3], avg_inv @ blender2opencv


class LF_Dataset(Dataset):
    """Light-field dataset yielding a set of source views plus one target view per sample.

    For every scene directory under ``root_dir`` this class reads::

        <root_dir>/<scene>/all/rgb/<view>.jpg
        <root_dir>/<scene>/all/intrinsics/<view>.txt   (16 numbers -> 4x4; top-left 3x3 used)
        <root_dir>/<scene>/all/pose/<view>.txt         (16 numbers -> 4x4 c2w matrix)

    Groupings of (target view, source views) ("metas") are cached as pickle files
    under ``./configs`` relative to the current working directory.
    """

    def __init__(self, root_dir, split, nb_views=3, levels=1, img_wh=None, downSample=1.0,
                 max_len=-1, near=4, test_mode="near", kwargs=None, args=None, scene="None"):
        """
        Args:
            root_dir: dataset root with one sub-directory per scene. The part of its
                second-to-last '/'-separated component before the first '_' is
                stored as ``self.town``.
            split: one of 'train', 'val', 'test', 'finetune'. 'train' leaves out the
                last directory returned by glob.
            nb_views: number of source views per sample.
            levels: number of FPN levels (stored only).
            img_wh: optional (w, h); only validated here (``img_wh * downSample``
                must be multiples of 32).
            downSample: scale factor applied to loaded images and intrinsics.
            max_len: if > 0, overrides ``__len__``.
            near: used in the metas cache filename when ``scene == "None"``.
            test_mode: view selection for 'test'/'finetune':
                'sameInput', 'near', 'far' or 'seq'.
            kwargs, args: stored verbatim (``kwargs`` defaults to a fresh ``{}``).
            scene: restrict to one scene directory; the string "None" means all.
        """
        self.root_dir = root_dir
        self.town = self.root_dir.split('/')[-2].split('_')[0]
        self.split = split
        assert self.split in ['train', 'val', 'test', 'finetune'], \
            'split must be either "train", "val", "test" or "finetune"!'
        # Avoid a shared mutable default argument.
        self.kwargs = {} if kwargs is None else kwargs
        self.args = args
        self.scene = scene
        self.sequential = False

        self.img_wh = img_wh
        self.downSample = downSample
        self.scale_factor = 1.0
        self.max_len = max_len
        if img_wh is not None:
            assert img_wh[0] * downSample % 32 == 0 and img_wh[1] * downSample % 32 == 0, \
                'img_wh * downSample must both be multiples of 32!'
        self.near = near
        self.build_metas()
        self.n_views = nb_views
        self.levels = levels  # FPN levels
        self.build_proj_mats()
        self.define_transforms()
        print(f'==> image down scale: {self.downSample}')

        self.white_back = False
        self.test_mode = test_mode

    def define_transforms(self):
        """Image transform: ``ToTensor`` then per-channel mean/std normalization
        (the commonly used ImageNet statistics)."""
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def find_files(self, dir, exts):
        """Return sorted paths in ``dir`` matching any glob pattern in ``exts``
        (e.g. ``['*.png', '*.jpg']``); ``[]`` if ``dir`` is not a directory."""
        if not os.path.isdir(dir):
            return []
        files_grabbed = []
        for ext in exts:
            files_grabbed.extend(glob.glob(os.path.join(dir, ext)))
        return sorted(files_grabbed)

    def _list_view_ids(self, scene_id):
        """Return the file stems of ``<root_dir>/<scene_id>/all/rgb/*.jpg`` in sorted order."""
        data_path = os.path.join(self.root_dir, scene_id, 'all/rgb')
        return [os.path.basename(p).split('.')[0] for p in self.find_files(data_path, ['*.jpg'])]

    @staticmethod
    def _interpolate_src_views(view_list, i, out):
        """Return the 3 source views for target ``view_list[i]`` ("interpolate" scheme).

        Two consecutive views starting ``out`` positions after the target plus one
        view ``out`` positions before it. Near the ends the indices wrap around,
        shifted by one so that the last and first views are treated as adjacent
        to the same position.
        """
        num_viewpoint = len(view_list)
        if i - out < 0:
            src_views = view_list[i + out:(i + out) + 2]
            src_views += view_list[-(out - i) - 1:-(out - i)]
        elif (i + 1) + out >= num_viewpoint:
            start = out - (num_viewpoint - (i + 1))
            src_views = view_list[start:start + 2]
            src_views += view_list[(i - out):(i - out) + 1]
        else:
            src_views = view_list[i + out:(i + out) + 2]
            src_views += view_list[(i - out):(i - out) + 1]
        assert len(src_views) == 3, f"{len(src_views)},{i},{i-out},{i+1+out}"
        return src_views

    def build_metas(self):
        """Populate ``scenes``, ``id_list``, ``metas`` and ``remap``.

        ``metas`` is a list of ``(scene_id, target_view, src_views)`` tuples. It is
        loaded from a pickle cache under ``./configs`` when present; otherwise it is
        generated with ``_interpolate_src_views`` and written to that cache.
        """
        self.metas = []
        scene_dirs = [os.path.basename(p) for p in glob.glob(f'{self.root_dir}/*')]
        if self.split == 'train':
            # Hold out the last entry; glob order is filesystem-dependent.
            self.scenes = scene_dirs[:-1]
        else:
            self.scenes = scene_dirs
        if self.scene != "None":
            self.scenes = [self.scene]
        print(self.scenes)

        self.id_list = {scene_id: self._list_view_ids(scene_id) for scene_id in self.scenes}

        # Offset (in views) between the target and its source views. Defined for
        # both branches below: generation needs it even when scene == "None".
        out = 8
        if self.scene != "None":
            meta_filename = f'./configs/lf_metas_interpolate{out}_{self.scene}.pickle'
            if self.sequential:
                meta_filename = f'./configs/lf_metas_sequential_{self.scene}.pickle'
        else:
            # NOTE(review): every split shares the "train" cache name, so a cache
            # written by one split is reused by the others — confirm this is intended.
            meta_filename = f'./configs/lf_metas_near{self.near}_train.pickle'

        if os.path.isfile(meta_filename):
            # The cache is written by this class; only load trusted files.
            with open(meta_filename, 'rb') as f:
                self.metas = pickle.load(f)
        else:
            for scene_id in self.scenes:
                view_list = self.id_list[scene_id]
                for i, target_view in enumerate(view_list):
                    src_views = self._interpolate_src_views(view_list, i, out)
                    self.metas.append((scene_id, target_view, src_views))
            meta_dir = os.path.dirname(meta_filename)
            if meta_dir:
                os.makedirs(meta_dir, exist_ok=True)
            with open(meta_filename, 'wb') as f:
                pickle.dump(self.metas, f)

        self.build_remap()

    def build_remap(self):
        """Map ``remap[scene][view_name]`` -> position of that view in ``id_list[scene]``."""
        self.remap = {
            scene: {item: i for i, item in enumerate(self.id_list[scene])}
            for scene in self.scenes
        }

    def build_proj_mats(self):
        """Read per-view intrinsics and poses for every scene.

        Sets dicts keyed by scene name:
            ``intrinsics[scene]``: (V, 3, 3), first two rows scaled by ``downSample``.
            ``cam2worlds[scene]``: (V, 4, 4) pose-file matrices.
            ``world2cams[scene]``: (V, 4, 4) their inverses, translation * ``scale_factor``.
            ``proj_mats[scene]``: list of V ``(proj_mat, near_far)`` tuples, where
                ``proj_mat`` uses intrinsics divided by 4 and ``near_far`` is fixed.
        """
        def parse_txt(filename):
            assert os.path.isfile(filename)
            with open(filename) as f:
                nums = f.read().split()
            return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)

        proj_mats, intrinsics, world2cams, cam2worlds = {}, {}, {}, {}
        for scene in self.scenes:
            cur_proj_mats, cur_intrinsics, cur_world2cams, cur_cam2worlds = [], [], [], []
            for view in self.id_list[scene]:
                intr_path = os.path.join(self.root_dir, scene, 'all/intrinsics', f'{view}.txt')
                intrinsic = parse_txt(intr_path)[:3, :3]

                extr_path = os.path.join(self.root_dir, scene, 'all/pose', f'{view}.txt')
                c2w = parse_txt(extr_path)
                w2c = np.linalg.inv(c2w)
                w2c[:3, 3] *= self.scale_factor

                intrinsic[:2] = intrinsic[:2] * self.downSample
                cur_intrinsics.append(intrinsic.copy())

                # Projection matrix with intrinsics at 1/4 scale.
                proj_mat_l = np.eye(4)
                intrinsic[:2] = intrinsic[:2] / 4
                proj_mat_l[:3, :4] = intrinsic @ w2c[:3, :4]

                # Fixed depth range shared by all views.
                near_far = np.array([1e-10, 14]) * self.scale_factor
                cur_proj_mats.append((proj_mat_l, near_far))
                cur_world2cams.append(w2c)
                cur_cam2worlds.append(c2w)

            # Keep (proj_mat, near_far) pairs in a list: np.stack on these ragged
            # pairs raises ValueError on NumPy >= 1.24. Indexing/unpacking is unchanged.
            proj_mats[scene] = cur_proj_mats
            intrinsics[scene] = np.stack(cur_intrinsics)
            world2cams[scene] = np.stack(cur_world2cams)
            cam2worlds[scene] = np.stack(cur_cam2worlds)

        self.proj_mats, self.intrinsics, self.world2cams, self.cam2worlds = proj_mats, intrinsics, world2cams, cam2worlds

    def __len__(self):
        return len(self.metas) if self.max_len <= 0 else self.max_len

    def _select_view_ids(self, scene_id, target_view, src_views):
        """Return ``(view_ids, target_view)``; ``view_ids`` ends with the target view.

        ``target_view`` may be replaced (test_mode 'far').

        Raises:
            ValueError: unsupported split or test_mode.
        """
        if self.split == 'train':
            # Random subset of the available source views. Index by the real list
            # length: randperm(self.near) could exceed it (e.g. 3 sources, near=4).
            ids = torch.randperm(len(src_views))[:self.n_views].tolist()
            return [src_views[i] for i in ids] + [target_view], target_view
        if self.split == 'val':
            return [src_views[i] for i in range(self.n_views)] + [target_view], target_view
        if self.split not in ('test', 'finetune'):
            raise ValueError(f"unsupported split: {self.split}")

        view_list = self.id_list[scene_id]
        if self.test_mode == 'sameInput':
            return [view_list[0], view_list[10], target_view], target_view
        if self.test_mode == 'near':
            return [src_views[i] for i in range(self.n_views)] + [target_view], target_view
        if self.test_mode == 'far':
            # Target 5 views further along (wrapping), sources unchanged.
            target_idx = self.remap[scene_id][target_view]
            target_view = view_list[(target_idx + 5) % len(view_list)]
            return [src_views[i] for i in range(self.n_views)] + [target_view], target_view
        if self.test_mode == 'seq':
            # Sources every 5 views *before* the target, wrapping at the start.
            # View names are strings, so work on their positions in view_list.
            target_idx = self.remap[scene_id][target_view]
            seq = [view_list[(target_idx - 5 * k) % len(view_list)]
                   for k in range(self.n_views, 0, -1)]
            return seq + [target_view], target_view
        raise ValueError(f"WRONG test_mode: {self.test_mode}")

    def _sequential_source_path(self, vid_num):
        """Path of a previously rendered novel view used as input in sequential mode.

        Polls every 35 s (printing "wait") until exactly one file matches.
        """
        import time

        if vid_num >= 7:
            cor_vid = vid_num - 3
        elif vid_num in (1, 2):
            cor_vid = vid_num - 1
        else:  # vid_num in (4, 5); callers exclude 0, 3 and 6
            cor_vid = vid_num - 2
        data_path = os.path.join('./logs/lf_data/africa/Generalizable_ver0/nb_3-sequential/evaluation',
                                 f'{0:08d}_{cor_vid:02d}_novel.png')
        while len(glob.glob(data_path)) != 1:
            time.sleep(35)
            print("wait")
        return data_path

    def _load_image(self, data_path):
        """Load the single file matching ``data_path``, resize and transform it.

        The image is resized to ``round([1280, 800] * downSample)`` regardless of its
        on-disk size. Returns ``(tensor, (w, h))``.
        """
        matches = glob.glob(data_path)
        assert len(matches) == 1, \
            f'[data path wrong] data_path: {data_path}, glob:{matches}'
        wh = np.round(np.array([1280, 800]) * self.downSample).astype('int')
        w, h = int(wh[0]), int(wh[1])
        with Image.open(matches[0]) as src:
            img = src.convert('RGB')
        img = img.resize((w, h), Image.BILINEAR)
        return self.transform(img), (w, h)

    def _affine_mats(self, scene_id, index_mat):
        """Projection matrices ``K @ w2c[:3]`` (and inverses) with K scaled by 1, 1/2, 1/4.

        Returns two arrays of shape (4, 4, 3), the last axis indexing the scale.
        """
        w2c = self.world2cams[scene_id][index_mat]
        aff, aff_inv = [], []
        for l in range(3):
            proj_mat_l = np.eye(4)
            intrinsic_temp = self.intrinsics[scene_id][index_mat].copy()
            intrinsic_temp[:2] = intrinsic_temp[:2] / (2 ** l)
            proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
            aff.append(proj_mat_l.copy())
            aff_inv.append(np.linalg.inv(proj_mat_l))
        return np.stack(aff, axis=-1), np.stack(aff_inv, axis=-1)

    def __getitem__(self, idx):
        """Assemble one sample; views are the selected sources followed by the target.

        Keys (V = number of selected views):
            'images': (V, 3, H, W) float tensor (ToTensor is channels-first).
            'depths_h' (V, H, W), 'depths' / 'depths_aug' (V, H//4, W//4): zero placeholders.
            'w2cs', 'c2ws': (V, 4, 4) float32.  'intrinsics': (V, 3, 3) float32.
            'near_fars': (V, 2) float32.
            'affine_mats', 'affine_mats_inv': (V, 4, 4, 3).
            'c2ws_all': poses of the target and all meta source views, float32.
            'closest_idxs': stacked ``get_nearest_pose_ids`` results, one per source view.
        """
        sample = {}
        scene_id, target_view, src_views = self.metas[idx]
        view_ids, target_view = self._select_view_ids(scene_id, target_view, src_views)

        affine_mat, affine_mat_inv = [], []
        imgs, depths_h, depths, depths_aug = [], [], [], []
        intrinsics, w2cs, c2ws, near_fars = [], [], [], []
        for i, vid in enumerate(view_ids):
            data_path = os.path.join(self.root_dir, scene_id, 'all/rgb', f'{vid}.jpg')
            if self.sequential:
                vid_num = self.remap[scene_id][vid]
                # Non-target views other than 0/3/6 come from earlier renderings.
                if i != len(view_ids) - 1 and vid_num not in (0, 3, 6):
                    data_path = self._sequential_source_path(vid_num)

            img, (w, h) = self._load_image(data_path)
            imgs.append(img)
            depths_h.append(np.zeros([h, w]))
            depths.append(np.zeros([h // 4, w // 4]))
            depths_aug.append(np.zeros([h // 4, w // 4]))

            index_mat = self.remap[scene_id][vid]
            _, near_far = self.proj_mats[scene_id][index_mat]
            intrinsics.append(self.intrinsics[scene_id][index_mat])
            w2cs.append(self.world2cams[scene_id][index_mat])
            c2ws.append(self.cam2worlds[scene_id][index_mat])

            aff, aff_inv = self._affine_mats(scene_id, index_mat)
            affine_mat.append(aff)
            affine_mat_inv.append(aff_inv)
            near_fars.append(near_far)

        imgs = torch.stack(imgs).float()
        depths = np.stack(depths)
        depths_h = np.stack(depths_h)
        depths_aug = np.stack(depths_aug)

        affine_mat, affine_mat_inv = np.stack(affine_mat), np.stack(affine_mat_inv)
        intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(w2cs), np.stack(c2ws), np.stack(near_fars)

        if isinstance(src_views[0], list):
            view_ids_all = [v for group in src_views for v in group]
        else:
            view_ids_all = [target_view] + list(src_views)
        c2ws_all = np.stack([self.cam2worlds[scene_id][self.remap[scene_id][v]] for v in view_ids_all])

        # For each source view, the nearest source views by camera distance.
        closest_idxs = np.stack([
            get_nearest_pose_ids(
                pose,
                ref_poses=c2ws[:-1],
                num_select=self.n_views,
                angular_dist_method="dist",
            )
            for pose in c2ws[:-1]
        ], axis=0)

        sample['images'] = imgs
        sample["depths"] = depths
        sample["depths_h"] = depths_h
        sample["depths_aug"] = depths_aug
        sample['w2cs'] = w2cs.astype(np.float32)
        sample['c2ws'] = c2ws.astype(np.float32)
        sample['near_fars'] = near_fars.astype(np.float32)
        sample['intrinsics'] = intrinsics.astype(np.float32)
        sample['affine_mats'] = affine_mat
        sample['affine_mats_inv'] = affine_mat_inv
        sample['c2ws_all'] = c2ws_all.astype(np.float32)
        sample["closest_idxs"] = closest_idxs
        return sample