# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import cv2
import numpy as np
import glob
from PIL import Image

import random
import pickle

from utils.utils import read_pfm, get_nearest_pose_ids

class tt_Dataset(Dataset):
    def __init__(
        self,
        root_dir,
        split,
        nb_views,
        downSample=1.0,
        max_len=-1,
        scene="None",
        use_far_view=False,
        need_style_img=False,
        need_style_label=False,
        src_specify='all',
        ref_specify='all',
        style_dataset='ithaca',
        n_output_views=1,
        input_phi_to_test=False,
        to_calculate_consistency=False,
        far_consistency=False,
    ):
        """Store the configuration, then build metadata, cameras and transforms.

        Args:
            root_dir: Dataset root; scene files are read from
                ``{root_dir}/{scene}``.
            split: Split name; the value "train" (together with
                ``need_style_img``) enables the style-image metadata.
            nb_views: Number of source views placed in front of the target
                view in every sample.
            downSample: Factor applied to the image size and to the first two
                rows of each intrinsic matrix.
            max_len: Length limit consulted by ``__len__``; a non-positive
                value means "no limit".
            scene: Scene sub-directory, e.g. "advanced/Auditorium".
                ``build_metas`` asserts that it is not the string "None".
            use_far_view: Stored on the instance; not read by the code
                visible in this class.
            need_style_img: Attach a style image to training samples.
            need_style_label: Additionally attach the style image's label.
            src_specify: Only echoed in the metadata log line.
            ref_specify: Only echoed in the style metadata log line.
            style_dataset: Style metadata is collected only when this is
                truthy; the value is also echoed in the log line.
            n_output_views: Pair-file entries are turned into metas only when
                this equals 1.
            input_phi_to_test: Replace the metas by 16 copies of the second
                entry.
            to_calculate_consistency: Build view pairs from camera poses
                instead of reading them from the pair file.
            far_consistency: Pair each view with a farther neighbour;
                requires ``to_calculate_consistency``.
        """
        # Scene location and split.
        self.root_dir = root_dir
        self.scene = scene
        self.split = split
        self.use_far_view = use_far_view

        # Sampling / geometry configuration.
        self.nb_views = nb_views
        self.max_len = max_len
        self.downSample = downSample
        self.scale_factor = 1.0
        self.interval_scale = 1.0

        # Style-transfer related options.
        self.domain2label = {'night':0, 'sunny':1, 'rain':2, 'cloud':3, 'snow':4}
        self.style_dataset = style_dataset
        self.need_style_img = need_style_img
        self.need_style_label = need_style_label
        self.src_specify = src_specify
        self.ref_specify = ref_specify
        self.n_output_views = n_output_views

        # Evaluation modes.
        self.input_phi_to_test = input_phi_to_test
        self.to_calculate_consistency = to_calculate_consistency
        self.far_consistency = far_consistency
        if self.far_consistency:
            assert self.to_calculate_consistency == True

        self.build_metas()
        # In consistency mode build_metas has already loaded the cameras.
        if not self.to_calculate_consistency:
            self.build_proj_mats()
        self.define_transforms()

    def define_transforms(self):
        """Create ``self.transform``: tensor conversion followed by normalisation."""
        # Per-channel statistics handed to Normalize.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        to_tensor = T.ToTensor()
        normalise = T.Normalize(mean=channel_mean, std=channel_std)
        self.transform = T.Compose([to_tensor, normalise])

    def build_metas(self):
        """Build ``self.metas`` and the view-id bookkeeping for the scene.

        Each meta is a tuple ``(scan, target_view, source_views)``. Besides
        ``self.metas`` this sets ``self.scans``, ``self.id_list`` (sorted
        unique view ids) and ``self.remap`` (via ``build_remap``). For the
        "train" split with ``need_style_img`` it also fills
        ``self.metas_style`` and, if requested, ``self.metas_style_label``.

        Raises:
            AssertionError: If ``self.scene`` is the string "None".
        """
        self.metas = []
        assert self.scene != "None"
        self.scans = [self.scene]
        self.id_list = []

        pair_file = f"{self.root_dir}/{self.scene}/pair.txt"
        if self.to_calculate_consistency:
            self._collect_consistency_metas(pair_file)
        else:
            self._collect_pair_file_metas(pair_file)
            if self.input_phi_to_test:
                # Evaluate one fixed entry (the second one) 16 times.
                self.metas = [self.metas[1]] * 16

        self.id_list = np.unique(self.id_list)
        self.build_remap()

        print(self.split,"meta",self.src_specify,len(self.metas))

        if self.need_style_img and self.split == "train":
            self._collect_style_metas()

    def _collect_consistency_metas(self, pair_file):
        """Pair every second view with one of its pose neighbours.

        Only the viewpoint count is read from ``pair_file``; neighbours come
        from ``get_nearest_pose_ids`` on the loaded camera poses. Both views
        of a pair become targets that share the same source-view list.
        """
        # (rank of the partner view, first rank kept as source, max #pairs)
        if self.far_consistency:
            partner_rank, first_kept, max_pairs = 5, 1, 15
        elif self.scene == 'advanced/Temple':
            partner_rank, first_kept, max_pairs = 0, 0, 10
        else:
            partner_rank, first_kept, max_pairs = 1, 0, 15

        for scan in self.scans:
            with open(pair_file) as f:
                num_viewpoint = int(f.readline())
            self.id_list = np.arange(num_viewpoint)
            # Poses are needed right away to rank the neighbours.
            self.build_proj_mats()

            used_targets = []
            for ref_view in range(0, num_viewpoint, 2):
                ranked = get_nearest_pose_ids(
                    self.cam2worlds[ref_view],
                    ref_poses=self.cam2worlds,
                    num_select=len(self.cam2worlds[:-1]),
                    angular_dist_method="dist",
                    tar_id=ref_view,
                ).tolist()
                partner = ranked[partner_rank]
                if ref_view in used_targets or partner in used_targets:
                    continue
                # Sources are the ranked neighbours without the partner (and
                # without the ranks below ``first_kept``).
                head = ranked[first_kept:partner_rank]
                tail = ranked[partner_rank + 1:]
                self.metas.append((scan, ref_view, head + tail))
                self.metas.append((scan, partner, head + tail))
                used_targets += [ref_view, partner]
                if len(used_targets) // 2 >= max_pairs:
                    break

    def _collect_pair_file_metas(self, pair_file):
        """Read target/source lists from ``pair_file``.

        The file starts with the viewpoint count, followed by two lines per
        entry: the reference view id, then a line whose odd-indexed tokens
        are the source view ids.
        """
        for scan in self.scans:
            with open(pair_file) as f:
                num_viewpoint = int(f.readline())
                # NOTE(review): one entry is consumed for every 15th
                # viewpoint index, but entries are read sequentially, so this
                # takes the FIRST ceil(n / 15) entries of the file rather
                # than every 15th view. Kept as in the original; confirm
                # whether that is intended.
                for _ in range(0, num_viewpoint, 15):
                    ref_view = int(f.readline().rstrip())
                    src_views = [
                        int(x) for x in f.readline().rstrip().split()[1::2]
                    ]
                    if self.n_output_views == 1:
                        self.metas.append((scan, ref_view, src_views))
                        # Keep the id list flat: entries may list different
                        # numbers of source views, and a ragged list of lists
                        # cannot be turned into an array by np.unique.
                        self.id_list.extend([ref_view] + src_views)

    def _collect_style_metas(self):
        """Fill ``self.metas_style`` (and labels) for training.

        Only called for the "train" split, so only the training image
        folders are globbed.
        """
        metas_style, metas_style_label = [], []
        if self.style_dataset :
            day_imgs_name = sorted(glob.glob('./waymo/sunny/Day/*'))
            dawn_imgs_name = sorted(glob.glob('./waymo/sunny/Dawn/Dusk/*'))
            night_imgs_name = sorted(glob.glob('./waymo/sunny/Night/*'))

            waymo_idx_file = f"./data/waymo/waymo_style_idx_{self.split}.pickle"
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only use index files from a trusted source.
            with open(waymo_idx_file, 'rb') as f:
                waymo_idx = pickle.load(f)

            day_imgs_idx = waymo_idx['day']
            dawn_imgs_idx = waymo_idx['dawn']
            night_imgs_idx = waymo_idx['night']

            metas_style = (
                [day_imgs_name[x] for x in day_imgs_idx]
                + [dawn_imgs_name[x] for x in dawn_imgs_idx]
                + [night_imgs_name[x] for x in night_imgs_idx]
            )
            # Day and dawn images both get the 'sunny' label.
            metas_style_label = (
                [self.domain2label['sunny']] * len(day_imgs_idx)
                + [self.domain2label['sunny']] * len(dawn_imgs_idx)
                + [self.domain2label['night']] * len(night_imgs_idx)
            )

        self.metas_style = metas_style
        if self.need_style_label:
            self.metas_style_label = metas_style_label

        print(self.split,self.style_dataset,"meta style:",self.ref_specify,len(self.metas_style))

    def build_proj_mats(self):
        near_fars, intrinsics, world2cams, cam2worlds = [], [], [], []
        for vid in self.id_list:
            proj_mat_filename = os.path.join(
                self.root_dir, self.scene, f"cams/{vid:08d}_cam.txt"
            )
            intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename)
            # intrinsic[:2] /= 2
            extrinsic[:3, 3] *= self.scale_factor

            intrinsic[:2] = intrinsic[:2] * self.downSample
            intrinsics += [intrinsic.copy()]

            near_fars += [near_far]
            world2cams += [extrinsic]
            cam2worlds += [np.linalg.inv(extrinsic)]

        self.near_fars, self.intrinsics = np.stack(near_fars), np.stack(intrinsics)
        self.world2cams, self.cam2worlds = np.stack(world2cams), np.stack(cam2worlds)

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = [line.rstrip() for line in f.readlines()]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ")
        extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ")
        intrinsics = intrinsics.reshape((3, 3))
        # depth_min & depth_interval: line 11
        # depth_min, depth_max = lines[11].split()
        depth_min, depth_interval, n_depth, depth_max = lines[11].split()
        depth_min, depth_max = float(depth_min), float(depth_max)
        # depth_min, depth_max = float(depth_min), float(depth_max)+0.5
        # depth_max = depth_min + float(depth_interval) * int(n_depth)
        # depth_min = float(depth_min) * self.scale_factor
        # depth_max = depth_min + float(depth_interval) * 256 * self.interval_scale * self.scale_factor

        return intrinsics, extrinsics, [depth_min, depth_max]

    def build_remap(self):
        self.remap = np.zeros(np.max(self.id_list) + 1).astype("int")
        for i, item in enumerate(self.id_list):
            self.remap[item] = i

    def __len__(self):
        """Return the number of samples.

        A positive ``self.max_len`` limits the length, but never beyond the
        number of available metas: ``__getitem__`` indexes ``self.metas``
        directly, so advertising more samples would raise ``IndexError``.
        """
        available = len(self.metas)
        if self.max_len <= 0:
            return available
        return min(available, self.max_len)

    def __getitem__(self, idx):
        """Assemble the sample for ``self.metas[idx]``.

        The sample contains ``self.nb_views`` source views followed by the
        target view (last position). Depth maps are zero-filled placeholders.

        Returns:
            dict with the keys ``images``, ``depths``, ``depths_h``,
            ``depths_aug``, ``w2cs``, ``c2ws``, ``near_fars``, ``intrinsics``,
            ``affine_mats``, ``affine_mats_inv``, ``closest_idxs``,
            ``second_closest_idxs``, ``view_ids`` and ``input_img_label``;
            for the "train" split with ``need_style_img`` also ``style_img``
            and, with ``need_style_label``, ``style_img_label``.
        """
        if self.split == "train" and self.scene == "None":
            # Take the single drawn element explicitly: calling float()/int()
            # on a size-1 array is deprecated in recent NumPy releases.
            noisy_factor = float(np.random.choice([1.0, 0.5], 1)[0])
            close_views = int(np.random.choice([3, 4, 5], 1)[0])
        else:
            noisy_factor = 1.0
            close_views = 5

        _scan, target_view, _src_views = self.metas[idx]
        print(f'target_view: {target_view}, _src_views: {_src_views[:self.nb_views]}')
        if not self.to_calculate_consistency:
            # Re-rank the candidate source views by pose distance. Poses are
            # stored by position in ``self.id_list``, so the target must be
            # addressed by its remapped position, both to fetch its pose and
            # to exclude it from the ranking. (Assumes ``tar_id`` is an index
            # into ``ref_poses``, consistent with how the returned indices
            # are used below; confirm against utils.get_nearest_pose_ids.)
            target_pos = int(self.remap[target_view])
            nearest_pos = get_nearest_pose_ids(
                self.cam2worlds[target_pos],
                ref_poses=self.cam2worlds,
                num_select=len(self.cam2worlds[:-1]),
                angular_dist_method="dist",
                tar_id=target_pos,
            )
            nearest_pos = np.array(nearest_pos)
            _src_views = self.id_list[nearest_pos].tolist()
        view_ids = _src_views[:self.nb_views] + [target_view]

        affine_mats, affine_mats_inv = [], []
        imgs, depths_h, depths_aug = [], [], []
        depths = []
        intrinsics, w2cs, c2ws, near_fars = [], [], [], []

        for vid in view_ids:
            img_filename = os.path.join(self.root_dir, self.scene, f"images/{vid:08d}.jpg")

            img = Image.open(img_filename)
            img_wh = np.round(
                np.array(img.size) * self.downSample * noisy_factor
            ).astype("int")
            w, h = img_wh
            img = img.resize(img_wh, Image.BICUBIC)
            imgs.append(self.transform(img))

            index_mat = self.remap[vid]

            intrinsic = self.intrinsics[index_mat].copy()
            intrinsic[:2] = intrinsic[:2] * noisy_factor
            intrinsics.append(intrinsic)

            w2c = self.world2cams[index_mat]
            w2cs.append(w2c)
            c2ws.append(self.cam2worlds[index_mat])

            # Projection matrices for three levels; level l uses the
            # intrinsics with the first two rows divided by 2**l.
            aff = []
            aff_inv = []
            for level in range(3):
                proj_mat_l = np.eye(4)
                intrinsic_temp = intrinsic.copy()
                intrinsic_temp[:2] = intrinsic_temp[:2] / (2**level)
                proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
                aff.append(proj_mat_l.copy())
                aff_inv.append(np.linalg.inv(proj_mat_l))
            affine_mats.append(np.stack(aff, axis=-1))
            affine_mats_inv.append(np.stack(aff_inv, axis=-1))

            # No depth is loaded here: zero maps at full and quarter size.
            depths_h.append(np.zeros([h, w]))
            depths.append(np.zeros([h // 4, w // 4]))
            depths_aug.append(np.zeros([h // 4, w // 4]))

            near_fars.append(self.near_fars[index_mat])

        imgs = np.stack(imgs)
        depths = np.stack(depths)
        depths_h, depths_aug = np.stack(depths_h), np.stack(depths_aug)
        affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv)
        intrinsics = np.stack(intrinsics)
        w2cs = np.stack(w2cs)
        c2ws = np.stack(c2ws)
        near_fars = np.stack(near_fars)

        source_poses = c2ws[:-1]
        if len(source_poses) > 5:
            tgt_num_select = close_views
        else:
            tgt_num_select = len(source_poses) - 1

        def neighbour_table(**extra):
            # One row per source view (ranked among the source views only),
            # plus a final row for the target ranked among all views.
            rows = [
                get_nearest_pose_ids(
                    pose,
                    ref_poses=c2ws[:-1],
                    num_select=close_views,
                    angular_dist_method="dist",
                    **extra,
                )
                for pose in source_poses
            ]
            rows.append(
                get_nearest_pose_ids(
                    c2ws[-1],
                    ref_poses=c2ws[:],
                    num_select=tgt_num_select,
                    angular_dist_method="dist",
                    **extra,
                )
            )
            return np.stack(rows, axis=0)

        closest_idxs = neighbour_table()
        second_closest_idxs = neighbour_table(second_close_step=2)

        assert imgs.shape[0] == self.nb_views+1
        assert w2cs.shape[0] == self.nb_views+1
        assert intrinsics.shape[0] == self.nb_views+1

        sample = {}
        sample["images"] = imgs
        sample["depths"] = depths
        sample["depths_h"] = depths_h
        sample["depths_aug"] = depths_aug
        sample["w2cs"] = w2cs
        sample["c2ws"] = c2ws
        sample["near_fars"] = near_fars
        sample["intrinsics"] = intrinsics
        sample["affine_mats"] = affine_mats
        sample["affine_mats_inv"] = affine_mats_inv
        sample["closest_idxs"] = closest_idxs
        sample["second_closest_idxs"] = second_closest_idxs
        sample["view_ids"] = np.array(view_ids)
        sample["input_img_label"] = self.domain2label['sunny']
        if self.need_style_img and self.split == "train":
            style_path = self.metas_style[idx]
            style_img = Image.open(style_path).convert("RGB")
            # Match the size of the last processed (target) view.
            if style_img.size != (w, h):
                style_img = style_img.resize((w, h), Image.BICUBIC)
            sample["style_img"] = self.transform(style_img)

            if self.need_style_label:
                sample["style_img_label"] = self.metas_style_label[idx]

        return sample
