# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import glob
import numpy as np
from PIL import Image

import pickle

from utils.utils import get_nearest_pose_ids

def normalize(v):
    """Return ``v`` divided by its Euclidean (Frobenius) norm.

    No guard is applied for a zero-norm input; the division is performed as-is.
    """
    length = np.linalg.norm(v)
    return v / length


def average_poses(poses):
    """Collapse a stack of camera poses into one representative (3, 4) pose.

    Args:
        poses: array indexed as ``poses[..., k]`` with ``k`` in 0..3 selecting
            the x, y, z axis columns and the translation column; the mean is
            taken over the first axis (callers pass ``(N_images, 3, 4)``).

    Returns:
        (3, 4) array whose columns are ``[x, y, z, center]``.
    """
    # Translation of the average pose: mean of the last column.
    mean_center = poses[..., 3].mean(0)

    # The mean z column, re-normalized, fixes the z axis.
    z_axis = normalize(poses[..., 2].mean(0))

    # The mean y column is only a hint; it is not orthogonal to z in general.
    y_hint = poses[..., 1].mean(0)

    # x is perpendicular to both the y hint and z.
    x_axis = normalize(np.cross(y_hint, z_axis))

    # z and x are unit length and orthogonal, so their cross product is too.
    y_axis = np.cross(z_axis, x_axis)

    return np.stack([x_axis, y_axis, z_axis, mean_center], 1)


def center_poses(poses, blender2opencv):
    """Re-express poses relative to their average pose and flip the y/z axes.

    Args:
        poses: (N_images, 3, 4) array of poses.
        blender2opencv: (4, 4) matrix; only used to build the second return
            value, it is not applied to the returned poses.

    Returns:
        Tuple ``(poses_centered, transform)`` where ``poses_centered`` is
        (N_images, 3, 4) and ``transform`` is
        ``inv(average pose as 4x4) @ blender2opencv``.
    """
    # Promote the (3, 4) average pose to a 4x4 matrix and invert it once.
    avg_homo = np.eye(4)
    avg_homo[:3] = average_poses(poses)
    avg_inv = np.linalg.inv(avg_homo)

    # Append the row (0, 0, 0, 1) to every pose -> (N_images, 4, 4).
    n_poses = len(poses)
    bottom_rows = np.tile(np.array([0, 0, 0, 1]), (n_poses, 1, 1))
    poses_4x4 = np.concatenate([poses, bottom_rows], 1)

    recentred = avg_inv @ poses_4x4  # (N_images, 4, 4)

    # Negate columns 1 and 2 (y and z axes); keep column 0 and the translation.
    recentred = np.concatenate(
        [recentred[:, :, :1], -recentred[:, :, 1:3], recentred[:, :, 3:]], 2
    )

    # Drop the homogeneous row again -> (N_images, 3, 4).
    return recentred[:, :3], avg_inv @ blender2opencv


class LLFF_Dataset(Dataset):
    """Multi-view dataset over LLFF-style scene folders.

    Every scene directory under ``root_dir`` is expected to contain a
    ``poses_bounds.npy`` file and an image folder (``imgs_folder_name``).
    ``build_metas`` reads all poses once at construction time; ``__getitem__``
    then loads ``nb_views`` source views plus one target view and packs the
    images, camera matrices and depth bounds into a sample dictionary.
    """

    def __init__(
        self,
        root_dir,
        split,
        nb_views,
        downSample=1.0,
        max_len=-1,
        scene="None",
        imgs_folder_name="images",
        use_far_view=False,
        need_style_img=False,
        need_style_label=False,
        src_specify='all',
        ref_specify='all',
        input_phi_to_test=False,
    ):
        """Store the configuration and index the dataset.

        Args:
            root_dir: directory that contains one sub-directory per scene.
            split: ``"train"`` or ``"val"``; any other value yields no targets.
            nb_views: number of source views returned per sample.
            downSample: factor applied to the base 960x640 resolution.
            max_len: if > 0, upper bound on the reported dataset length.
            scene: glob pattern (relative to ``root_dir``) selecting scenes, or
                the string ``"None"`` to use every scene.
            imgs_folder_name: name of the image folder inside each scene.
            use_far_view: additionally build a sample with a far target view.
            need_style_img: also return a style image with each sample.
            need_style_label: also return the style image's domain label.
            src_specify: stored on the instance; not read inside this class.
            ref_specify: only echoed in the style-meta log line.
            input_phi_to_test: test mode that repeats the third scene 16 times
                and keeps a single validation target per repetition.
        """
        self.root_dir = root_dir
        self.split = split
        self.nb_views = nb_views
        self.scene = scene
        self.imgs_folder_name = imgs_folder_name
        self.use_far_view = use_far_view

        self.need_style_img = need_style_img
        self.need_style_label = need_style_label
        self.src_specify = src_specify
        self.ref_specify = ref_specify
        self.domain2label = {'night':0, 'sunny':1, 'rain':2, 'cloud':3, 'snow':4}
        self.input_phi_to_test = input_phi_to_test

        self.downsample = downSample
        self.max_len = max_len
        # (width, height) every image is resized to.
        self.img_wh = (int(960 * self.downsample), int(640 * self.downsample))

        self.define_transforms()
        self.blender2opencv = np.array(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )

        self.render_poses = []

        self.build_metas()

    def define_transforms(self):
        """Create the image transform: tensor conversion + fixed mean/std normalization."""
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def spherify_poses(self, poses, bds):
        """Recenter poses and generate a 120-step circular camera path.

        Args:
            poses: (N_images, 3, 5) array; columns 0..3 hold the pose and the
                last column is carried over unchanged to the outputs.
            bds: depth bounds array. It is NOT modified; a rescaled copy is
                returned.

        Returns:
            Tuple ``(poses_reset, new_poses, bds)``: the recentred/rescaled
            input poses (N_images, 3, 5), the generated path after the same
            axis swap + ``center_poses`` used for the dataset poses
            (120, 3, 4), and the rescaled bounds.
        """

        def to_homogeneous(p):
            # Append the row (0, 0, 0, 1) to each (3, 4) matrix in the stack.
            bottom = np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
            return np.concatenate([p, bottom], 1)

        rays_d = poses[:, :3, 2:3]
        rays_o = poses[:, :3, 3:4]

        def min_line_dist(rays_o, rays_d):
            # Least-squares point closest to all lines (origin rays_o, direction rays_d).
            A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
            b_i = -A_i @ rays_o
            pt_mindist = np.squeeze(
                -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
            )
            return pt_mindist

        pt_mindist = min_line_dist(rays_o, rays_d)

        center = pt_mindist
        up = (poses[:, :3, 3] - center).mean(0)

        # Build a frame whose third axis is the mean direction from the
        # center to the cameras.
        vec0 = normalize(up)
        vec1 = normalize(np.cross([.1, .2, .3], vec0))
        vec2 = normalize(np.cross(vec0, vec1))
        pos = center
        c2w = np.stack([vec1, vec2, vec0, pos], 1)

        poses_reset = np.linalg.inv(to_homogeneous(c2w[None])) @ to_homogeneous(poses[:, :3, :4])

        # Rescale so that the RMS distance of the cameras from the origin is 1.
        rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))

        sc = 1. / rad
        poses_reset[:, :3, 3] *= sc
        # Out-of-place on purpose: the caller's bounds array (a view into the
        # loaded poses_bounds data) must keep its original scale, because
        # build_metas derives its own scale factor from it afterwards.
        bds = bds * sc
        rad *= sc

        centroid = np.mean(poses_reset[:, :3, 3], 0)
        zh = centroid[2]
        radcircle = np.sqrt(rad**2 - zh**2)
        new_poses = []

        for th in np.linspace(0., 2. * np.pi, 120):

            camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
            up = np.array([0, 0, -1.])

            vec2 = normalize(camorigin)
            vec0 = normalize(np.cross(vec2, up))
            vec1 = normalize(np.cross(vec2, vec0))
            pos = camorigin
            p = np.stack([vec0, vec1, vec2, pos], 1)

            new_poses.append(p)

        new_poses = np.stack(new_poses, 0)

        # Re-attach the last column of the first input pose to every output pose.
        new_poses = np.concatenate(
            [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
        )
        poses_reset = np.concatenate(
            [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1
        )

        # Same axis swap + centering that build_metas applies to the dataset poses.
        new_poses = np.concatenate(
            [new_poses[..., 1:2], -new_poses[..., :1], new_poses[..., 2:4]], -1
        )
        new_poses, _ = center_poses(new_poses, self.blender2opencv)

        return poses_reset, new_poses, bds

    def build_metas(self):
        """Scan the scene folders and precompute per-view camera data.

        Populates ``self.scans``, ``self.meta`` (list of
        ``{"scan", "target_idx"}`` targets), and the per-scan dictionaries
        ``image_paths``, ``near_far``, ``id_list``, ``c2ws``, ``w2cs`` and
        ``intrinsics``. ``closest_idxs``, ``affine_mats`` and
        ``affine_mats_inv`` are created but left empty. When
        ``need_style_img`` is set, also fills ``self.metas_style`` (and
        ``self.metas_style_label`` if requested).
        """
        scan_pattern = self.scene if self.scene != "None" else "*"
        self.scans = [
            os.path.basename(scan_dir)
            for scan_dir in sorted(glob.glob(os.path.join(self.root_dir, scan_pattern)))
        ]

        if self.input_phi_to_test:
            only_one = self.scans[2]
            self.scans = [only_one for _ in range(16)]

        self.meta = []
        self.image_paths = {}
        self.near_far = {}
        self.id_list = {}
        self.closest_idxs = {}
        self.c2ws = {}
        self.w2cs = {}
        self.intrinsics = {}
        self.affine_mats = {}
        self.affine_mats_inv = {}

        # Train/val index lists; loaded once on first use (non-train splits only).
        pairs = None

        for scan in self.scans:
            self.image_paths[scan] = sorted(
                glob.glob(os.path.join(self.root_dir, scan, self.imgs_folder_name, "*"))
            )
            assert len(self.image_paths[scan]) > 0, f"no images found for scan {scan}"
            poses_bounds = np.load(
                os.path.join(self.root_dir, scan, "poses_bounds.npy")
            )  # (N_images, 17)
            poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
            bounds = poses_bounds[:, -2:]  # (N_images, 2)

            # Step 1: rescale focal length according to training resolution
            H, W, focal = poses[0, :, -1]  # original intrinsics, same for all images

            focal = [focal * self.img_wh[0] / W, focal * self.img_wh[1] / H]

            # Circular render path for novel-view tests. NOTE(review): this is
            # overwritten per scan, so only the last scan's path is kept.
            _, self.render_poses, _ = self.spherify_poses(poses, bounds)

            # Step 2: correct poses
            poses = np.concatenate(
                [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1
            )
            poses, _ = center_poses(poses, self.blender2opencv)

            # Step 3: correct scale so that the nearest depth is at a little more than 1.0
            near_original = bounds.min()
            scale_factor = near_original * 0.75  # 0.75 is the default parameter
            bounds /= scale_factor
            poses[..., 3] /= scale_factor

            self.near_far[scan] = bounds.astype('float32')

            num_viewpoint = len(self.image_paths[scan])

            if self.split == 'train':
                # Every 8th view is held out.
                val_ids = list(range(0, num_viewpoint, 8))
            else:
                if pairs is None:
                    import torch
                    pairs = torch.load('configs/lists/pairs.th')
                val_ids = pairs[f'{scan}_val']
                input_ids = pairs[f'{scan}_train']

                if self.input_phi_to_test:
                    val_ids = val_ids[:1]

            w, h = self.img_wh

            self.id_list[scan] = []
            self.closest_idxs[scan] = []
            self.c2ws[scan] = []
            self.w2cs[scan] = []
            self.intrinsics[scan] = []
            self.affine_mats[scan] = []
            self.affine_mats_inv[scan] = []
            for idx in range(num_viewpoint):
                # val: only the listed validation views are targets.
                # train on a chosen scene: every view that is not held out.
                # train on all scenes: every view.
                if self.split == "val":
                    is_target = idx in val_ids
                elif self.split == "train":
                    is_target = self.scene == "None" or idx not in val_ids
                else:
                    is_target = False
                if is_target:
                    self.meta.append({"scan": scan, "target_idx": idx})

                if self.split == 'train':
                    view_ids = get_nearest_pose_ids(
                        poses[idx, :, :],
                        ref_poses=poses[..., :],
                        num_select=10 + 1,
                        angular_dist_method="dist",
                    )
                else:
                    # Neighbours are searched among the input (train) views
                    # only, then mapped back to absolute view indices.
                    view_ids_unconvert = get_nearest_pose_ids(
                        poses[idx, :, :],
                        ref_poses=poses[input_ids, :, :],
                        num_select=10 + 1,
                        angular_dist_method="abs",
                    )
                    view_ids = np.array(input_ids)[view_ids_unconvert].tolist()

                self.id_list[scan].append(view_ids)

                c2w = np.eye(4).astype('float32')
                c2w[:3] = poses[idx]
                w2c = np.linalg.inv(c2w)
                self.c2ws[scan].append(c2w)
                self.w2cs[scan].append(w2c)

                # Principal point is placed at the image center.
                intrinsic = np.array(
                    [[focal[0], 0, w / 2], [0, focal[1], h / 2], [0, 0, 1]]
                ).astype('float32')
                self.intrinsics[scan].append(intrinsic)

        if self.need_style_img:
            if self.split == 'train':
                day_imgs_name = sorted(glob.glob('./waymo/sunny/Day/*'))
                dawn_imgs_name = sorted(glob.glob('./waymo/sunny/Dawn/Dusk/*'))
                night_imgs_name = sorted(glob.glob('./waymo/sunny/Night/*'))
            else:
                day_imgs_name = sorted(glob.glob('./waymo/val/sunny/Day/*'))
                dawn_imgs_name = sorted(glob.glob('./waymo/val/sunny/Dawn/Dusk/*'))
                night_imgs_name = sorted(glob.glob('./waymo/val/sunny/Night/*'))

            # NOTE: pickle executes arbitrary code on load; only use index
            # files from a trusted source.
            waymo_idx_file = f"./data/waymo/waymo_style_idx_{self.split}.pickle" # MDMM use ithaca to train
            with open(waymo_idx_file, 'rb') as f:
                waymo_idx = pickle.load(f)

            day_imgs_idx = waymo_idx['day']
            dawn_imgs_idx = waymo_idx['dawn']
            night_imgs_idx = waymo_idx['night']
            metas_style = (
                [day_imgs_name[x] for x in day_imgs_idx]
                + [dawn_imgs_name[x] for x in dawn_imgs_idx]
                + [night_imgs_name[x] for x in night_imgs_idx]
            )
            # Day and dawn images are both labelled 'sunny'; night images 'night'.
            metas_style_label = (
                [self.domain2label['sunny']] * len(day_imgs_idx)
                + [self.domain2label['sunny']] * len(dawn_imgs_idx)
                + [self.domain2label['night']] * len(night_imgs_idx)
            )

            self.metas_style = metas_style
            if self.need_style_label:
                self.metas_style_label = metas_style_label

            print(self.split,"meta style:",self.ref_specify,len(self.metas_style))

    def __len__(self):
        """Return the number of targets, capped at ``max_len`` when it is positive."""
        if self.max_len <= 0:
            return len(self.meta)
        # Never report more items than exist, otherwise __getitem__ would
        # index past the end of self.meta.
        return min(self.max_len, len(self.meta))

    def __getitem__(self, idx):
        """Load the views for target ``idx`` and assemble a sample.

        Returns:
            A dict with the keys ``images``, ``depths``, ``depths_h``,
            ``depths_aug`` (the three depth entries are zero placeholders),
            ``w2cs``, ``c2ws``, ``near_fars``, ``affine_mats``,
            ``affine_mats_inv``, ``intrinsics``, ``closest_idxs``,
            ``second_closest_idxs`` (empty list), ``input_img_label`` and,
            if enabled, ``style_img`` / ``style_img_label``. In every stacked
            entry the source views come first and the target view is last.
            When ``use_far_view`` is set and the split is ``'train'``, a list
            of two such dicts is returned instead (regular target, then far
            target).
        """
        if self.split == "train" and self.scene == "None":
            # Random resolution / neighbour-count augmentation. The size-1
            # result is indexed explicitly because implicit conversion of a
            # 1-element array to a Python scalar is deprecated in NumPy >= 1.25.
            noisy_factor = float(np.random.choice([1.0, 0.75, 0.5], 1)[0])
            close_views = int(np.random.choice([3, 4, 5], 1)[0])
        else:
            noisy_factor = 1.0
            close_views = 5

        return_pair = self.use_far_view and self.split == 'train'
        sample_list = []
        sample_num = 2 if return_pair else 1

        scan = self.meta[idx]["scan"]
        target_idx = self.meta[idx]["target_idx"]

        # First entry is the target itself, the rest are candidate sources.
        view_ids = self.id_list[scan][target_idx]
        target_view = view_ids[0]
        src_views = view_ids[1:]
        if self.split == "train" and self.scene == "None":
            import torch
            # Random subset of the available candidates (10 in the usual case).
            ids = torch.randperm(len(src_views))[:self.nb_views]
        else:
            ids = np.arange(self.nb_views)
        src_views = [src_views[i] for i in ids]
        view_ids = list(src_views) + [target_view]

        for s_num in range(sample_num):
            if self.use_far_view and (s_num == 1 or self.split == 'val'):
                # Replace the target by a farther view selected relative to
                # the regular target's pose.
                src_views, near_target = view_ids[:-1], view_ids[-1]
                far_tgt = get_nearest_pose_ids(
                    self.c2ws[scan][near_target],
                    ref_poses=np.array(self.c2ws[scan]),
                    num_select=20,
                    th=15,
                    angular_dist_method="dist",
                )
                view_ids = src_views + [far_tgt]

            imgs, depths, depths_h, depths_aug = [], [], [], []
            intrinsics, w2cs, c2ws, near_fars = [], [], [], []
            affine_mats, affine_mats_inv = [], []

            w, h = self.img_wh
            w, h = int(w * noisy_factor), int(h * noisy_factor)

            for vid in view_ids:
                img_filename = self.image_paths[scan][vid]
                img = Image.open(img_filename).convert("RGB")
                if img.size != (w, h):
                    img = img.resize((w, h), Image.BICUBIC)
                img = self.transform(img)
                imgs.append(img)

                # Intrinsics follow the (possibly reduced) image resolution.
                intrinsic = self.intrinsics[scan][vid].copy()
                intrinsic[:2] = intrinsic[:2] * noisy_factor
                intrinsics.append(intrinsic)

                w2c = self.w2cs[scan][vid]
                w2cs.append(w2c)
                c2ws.append(self.c2ws[scan][vid])

                # Projection matrices (and inverses) at three scales: the
                # intrinsics are divided by 1, 2 and 4.
                aff = []
                aff_inv = []
                for l in range(3):
                    proj_mat_l = np.eye(4)
                    intrinsic_temp = intrinsic.copy()
                    intrinsic_temp[:2] = intrinsic_temp[:2] / (2**l)
                    proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
                    aff.append(proj_mat_l.copy())
                    aff_inv.append(np.linalg.inv(proj_mat_l))
                aff = np.stack(aff, axis=-1)
                aff_inv = np.stack(aff_inv, axis=-1)

                affine_mats.append(aff)
                affine_mats_inv.append(aff_inv)

                near_fars.append(self.near_far[scan][vid])

                # No depth data is loaded here: zero placeholders.
                depths_h.append(np.zeros([h, w]))
                depths.append(np.zeros([h // 4, w // 4]))
                depths_aug.append(np.zeros([h // 4, w // 4]))

            imgs = np.stack(imgs)
            depths = np.stack(depths)
            depths_h = np.stack(depths_h)
            depths_aug = np.stack(depths_aug)
            affine_mats = np.stack(affine_mats)
            affine_mats_inv = np.stack(affine_mats_inv)
            intrinsics = np.stack(intrinsics)
            w2cs = np.stack(w2cs)
            c2ws = np.stack(c2ws)
            near_fars = np.stack(near_fars)

            # How many neighbours to look up per view depends on the number
            # of source views (all views except the last one, the target).
            num_src = len(c2ws[:-1])
            if num_src > 5:
                tgt_num_select = close_views
            elif num_src == 2:
                tgt_num_select = 2
            else:
                tgt_num_select = num_src - 1

            # Source views are matched against the source views only; the
            # target is matched against all views.
            closest_idxs = [
                get_nearest_pose_ids(
                    pose, ref_poses=c2ws[:-1], num_select=tgt_num_select, angular_dist_method="dist"
                )
                for pose in c2ws[:-1]
            ]
            closest_idxs.append(
                get_nearest_pose_ids(
                    c2ws[-1], ref_poses=c2ws[:], num_select=tgt_num_select, angular_dist_method="dist"
                )
            )
            closest_idxs = np.stack(closest_idxs, axis=0)

            second_closest_idxs = []

            sample = {}
            sample["images"] = imgs
            sample["depths"] = depths
            sample["depths_h"] = depths_h
            sample["depths_aug"] = depths_aug
            sample["w2cs"] = w2cs
            sample["c2ws"] = c2ws
            sample["near_fars"] = near_fars
            sample["affine_mats"] = affine_mats
            sample["affine_mats_inv"] = affine_mats_inv
            sample["intrinsics"] = intrinsics
            sample["closest_idxs"] = closest_idxs
            sample["second_closest_idxs"] = second_closest_idxs
            sample["input_img_label"] = self.domain2label['sunny']
            if self.need_style_img:
                # NOTE(review): the style list is indexed with the dataset
                # index, so it must be at least as long as the dataset.
                style_path = self.metas_style[idx]
                img = Image.open(style_path).convert("RGB") # (1920, 1208)
                if img.size != (w, h):
                    img = img.resize((w, h), Image.BICUBIC)
                style_img = self.transform(img)
                sample["style_img"] = style_img

                if self.need_style_label:
                    sample["style_img_label"] = self.metas_style_label[idx]

            if return_pair:
                sample_list.append(sample)

        if return_pair:
            return sample_list
        return sample
