# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import cv2
import numpy as np
from PIL import Image

from utils.utils import read_pfm, get_nearest_pose_ids
import glob
import json
import pickle
from pyquaternion import Quaternion
import random
from random import shuffle


def normalize(v):
    """Return ``v`` divided by its Euclidean (Frobenius for matrices) norm."""
    length = np.linalg.norm(v)
    return v / length


def average_poses(poses):
    """Build a single "average" camera pose from a stack of poses.

    Args:
        poses: stack of camera matrices indexed as ``poses[i, :, col]``, where
            columns 0..3 are read as the x, y, z axes and the camera center
            (``center_poses`` passes (N, 3, 4) arrays).

    Returns:
        (3, 4) array whose columns are the orthonormal axes x, y, z followed by
        the mean camera center.
    """
    center = poses[..., 3].mean(0)
    # Mean viewing axis becomes the new z axis.
    forward = normalize(poses[..., 2].mean(0))
    # Mean y is only a hint to orient x; it is re-derived below.
    up_hint = poses[..., 1].mean(0)
    right = normalize(np.cross(up_hint, forward))
    # forward and right are unit and orthogonal, so this is already unit length.
    up = np.cross(forward, right)
    return np.stack([right, up, forward, center], axis=1)

def center_poses(poses, blender2opencv):
    """Re-express camera poses in the frame of their average pose.

    Args:
        poses: (N, 3, 4) camera-to-world matrices.
        blender2opencv: (4, 4) matrix, used only for the second return value.

    Returns:
        A pair ``(poses_centered, transform)``:
        - poses_centered: (N, 3, 4) poses relative to the average pose, with the
          camera y and z axes negated.
        - transform: ``inv(average_pose_homogeneous) @ blender2opencv``.
    """
    avg_homo = np.eye(4)
    avg_homo[:3] = average_poses(poses)
    avg_inv = np.linalg.inv(avg_homo)

    # Append a [0, 0, 0, 1] row to each pose to make it a 4x4 homogeneous matrix.
    bottom_rows = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N, 1, 4)
    poses_homo = np.concatenate([poses, bottom_rows], 1)  # (N, 4, 4)

    centered = avg_inv @ poses_homo
    # Negate the y and z columns (equivalent to right-multiplying by diag(1, -1, -1, 1)).
    centered = centered * np.array([1, -1, -1, 1])
    return centered[:, :3], avg_inv @ blender2opencv

class ithaca_Dataset(Dataset):
    """Ithaca365 multi-view dataset for GeoNeRF-style rendering with style references.

    Each item holds ``nb_views`` source views plus one target view of the same
    capture timestamp, the matching camera matrices (read from a COLMAP
    ``poses_bounds*.npy`` file in ``root_dir``), and optionally a style image
    with weather labels. Only ``use_two_cam=True`` with ``camfile='colmap'`` is
    implemented; other configurations fail during construction.

    NOTE(review): capture-date paths come from ``glob(f'{root_dir}/*')`` (so they
    already contain ``root_dir``) and are later re-joined with ``root_dir``; this
    only resolves correctly when ``root_dir`` is an absolute path.
    """

    def __init__(
        self,
        root_dir,
        table_root_dir,
        split,
        nb_views,
        downSample=1.0,
        max_len=-1,
        scene="None",
        use_far_view=False,
        ithaca_all=None,
        use_two_cam=False,
        need_style_img=False,
        need_style_label=False,
        src_specify='all',
        ref_specify='all',
        cam_diff_weather=None,
        read_lidar=False,
        camfile='colmap',
        input_phi_to_test=False,
        style_dataset='ithaca',
        pretrain_dataset='ithaca',
        specify_file="None",
        n_output_views=1,
        to_calculate_consistency=False,
        to_calculate_FID=False,
        update_z=False,
        styleSame=False,
    ):
        """Store the configuration and eagerly build view metadata and cameras.

        Args:
            root_dir: directory with one sub-directory per capture date (each
                holding ``cam<k>`` image folders) and the ``poses_bounds*.npy``
                camera file(s).
            table_root_dir: stored on the instance; not used by this class.
            split: ``'train'`` or ``'val'``.
            nb_views: number of source views per item (must be 2 with
                ``use_two_cam``).
            downSample: scale applied to the 768x512 working resolution.
            max_len: when > 0, overrides ``len(dataset)``.
            scene: ``"None"`` enables random ``close_views`` during training.
            ithaca_all: object used by ``read_depth`` for LiDAR rendering.
            src_specify / ref_specify: weather filter (``'all'``,
                ``'sunny+night'`` or a single name) for content scans and style
                references respectively.
            cam_diff_weather: ``'colmap'`` uses a fixed pair of captures from
                two dates for validation.
            style_dataset: ``'ithaca'`` or ``'waymo'`` style source.
            pretrain_dataset: selects the weather-label vocabulary.
            specify_file: image path used as the only scan, or ``"None"``.
            n_output_views: target views per scan in validation (forced to 2
                when ``to_calculate_consistency``).
            update_z / styleSame: use the input image itself as style reference.
        """
        self.root_dir = root_dir
        self.table_root_dir = table_root_dir
        self.split = split
        self.scene = scene
        self.use_far_view = use_far_view
        self.use_two_cam = use_two_cam
        self.need_style_img = need_style_img
        self.need_style_label = need_style_label
        self.src_specify = src_specify
        self.ref_specify = ref_specify
        self.cam_diff_weather = cam_diff_weather
        self.camfile = camfile
        self.input_phi_to_test = input_phi_to_test
        self.style_dataset = style_dataset
        self.pretrain_dataset = pretrain_dataset
        self.specify_file = specify_file
        self.n_output_views = n_output_views
        self.to_calculate_consistency = to_calculate_consistency
        self.update_z = update_z
        self.styleSame = styleSame
        self.to_calculate_FID = to_calculate_FID

        if self.to_calculate_consistency:
            self.n_output_views = 2

        self.downSample = downSample
        self.scale_factor = 1.0 / 1500
        self.interval_scale = 1.06
        self.max_len = max_len
        self.nb_views = nb_views
        # Working resolution (768, 512) at downSample == 1.
        self.img_wh = (int(128 * 6 * self.downSample), int(128 * 4 * self.downSample))

        self.ithaca_all = ithaca_all

        if self.use_two_cam:
            assert self.nb_views == 2
        self.blender2opencv = np.array(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )

        if self.pretrain_dataset == 'waymo':
            self.domain2label = {'Day': 0, 'Dawn/Dusk': 1, 'Night': 2}
        else:
            self.domain2label = {'night': 0, 'sunny': 1, 'rain': 2, 'cloud': 3, 'snow': 4}

        self.time2weather()
        self.read_lidar = read_lidar

        self.build_metas()
        self.build_proj_mats()
        self.define_transforms()

    def define_transforms(self):
        """Set ``self.transform``: PIL image -> tensor normalized with ImageNet mean/std."""
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def time2weather(self):
        """Set ``self.time2weather_map``: capture-date folder name -> weather name."""
        self.time2weather_map = {'01-16-2022': 'cloud', '01-17-2022': 'snow', '01-17-2022b': 'snow', '01-17-2022c': 'snow', '01-20-2022': 'sunny', '01-23-2022': 'night', '02-01-2022': 'night', '02-03-2022': 'snow','02-04-2022': 'snow', '02-04-2022b': 'snow', '02-11-2022': 'sunny', '02-17-2022': 'cloud', '02-17-2022b': 'night', '02-21-2022': 'sunny', '02-22-2022': 'rain', '02-24-2022': 'night', '02-25-2022': 'snow', '03-03-2022b': 'night', '08-20-2021': 'sunny', '10-04-2021': 'rain', '10-08-2021': 'sunny', '11-19-2021': 'cloud', '11-22-2021': 'cloud', '11-23-2021': 'cloud', '11-29-2021': 'snow', '11-30-2021': 'cloud', '12-01-2021': 'cloud', '12-02-2021': 'night', '12-03-2021': 'sunny', '12-06-2021': 'cloud', '12-07-2021': 'cloud', '12-08-2021': 'rain', '12-09-2021': 'night', '12-13-2021': 'sunny', '12-14-2021': 'sunny', '12-15-2021': 'rain', '12-16-2021': 'cloud', '12-18-2021': 'rain', '12-19-2021': 'rain', '12-19-2021b': 'cloud'}

        return

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _weather_list(spec):
        """Expand a weather filter ('all', 'sunny+night' or a single name) to a list."""
        if spec == 'all':
            return ['sunny', 'rain', 'snow', 'cloud', 'night']
        if spec == 'sunny+night':
            return ['sunny', 'night']
        return [spec]

    @staticmethod
    def _scene_key(scan):
        """Dict key identifying a scan; paired (list-valued) scans use their first capture."""
        if isinstance(scan["date"], (list, tuple)):
            return scan["date"][0] + '_' + scan["time"][0]
        return scan["date"] + '_' + scan["time"]

    def _domain_label(self, scan):
        """Weather-domain label of a scan (first capture of a paired scan)."""
        date = scan["date"]
        if isinstance(date, (list, tuple)):
            date = date[0]
        return self.domain2label[self.time2weather_map[date.split('/')[-1]]]

    def _image_path(self, scan, vid):
        """Image file of view ``vid`` for ``scan``.

        For paired scans, views 0-3 come from the first capture and views >= 4
        from the second capture's ``cam{vid % 4}``.
        """
        if isinstance(scan["date"], (list, tuple)):
            pair = 0 if vid < 4 else 1
            return os.path.join(self.root_dir, scan["date"][pair], f"cam{vid % 4}", scan["time"][pair])
        return os.path.join(self.root_dir, scan["date"], f"cam{vid}", scan["time"])

    # ------------------------------------------------------------- metadata

    def build_metas(self):
        """Build scans, per-item view metadata, optional style references and the camera remap."""
        self.metas = []
        if self.use_two_cam:
            self._build_scans()
        self._build_view_metas()
        if self.need_style_img:
            self._build_style_metas()
        self.id_list = np.unique(self.id_list)
        self.build_remap()

    def _build_scans(self):
        """Fill ``self.scans`` with ``{'date': ..., 'time': ...}`` capture records.

        Records come from ``specify_file`` when set, otherwise from the cached
        pickle ``./data/ithaca/scans/{src_specify}_{split}.pickle`` (built by
        scanning ``root_dir`` on first use). Test modes then override the list.
        """
        self.scans = []
        src_weather = self._weather_list(self.src_specify)
        # Capture dates used for content scans when src_specify == 'all'.
        train_date = ['12-14-2021', '10-08-2021', '12-19-2021b', '12-01-2021', '11-30-2021',
                      '02-03-2022', '11-29-2021', '02-04-2022b', '02-22-2022', '12-09-2021']

        if self.specify_file != "None":
            # Layout assumed: <date dir>/cam<k>/<time>.png
            parts = self.specify_file.split('/')
            date = '/'.join(parts[:-2])
            time = parts[-1]
            print(date, time)
            self.scans = [{'date': date, 'time': time}]
        else:
            scan_filename = f'./data/ithaca/scans/{self.src_specify}_{self.split}.pickle'
            if self.update_z:
                scan_filename = f'./data/ithaca/scans/{self.src_specify}_val.pickle'

            if not os.path.isfile(scan_filename):
                self.scans = self._discover_scans(src_weather, train_date)
                if self.split == 'val':
                    shuffle(self.scans)
                with open(scan_filename, 'wb') as f:
                    pickle.dump(self.scans, f)
            else:
                with open(scan_filename, 'rb') as f:
                    self.scans = pickle.load(f)

            if self.split == 'val':
                self.scans = self.scans[:100] if self.to_calculate_FID else self.scans[:15]

        if self.update_z:
            self.scans = [self.scans[0]] * 300

        if self.input_phi_to_test:
            self.scans = [self.scans[7]] * 16

        if self.cam_diff_weather == 'colmap' and self.split == 'val':
            # Fixed pair of captures from two dates; build_proj_mats loads
            # poses_bounds_0116_1219b.npy for this mode.
            self.scans = [{'date': ['01-16-2022', '12-19-2021b'], 'time': ['1642366393014409.png', '1639938265215199.png']}]

    def _discover_scans(self, src_weather, train_date):
        """Scan ``root_dir`` for timestamps whose cam0, cam1 and cam2 images all exist.

        The last 15 timestamps of each date belong to the validation split and
        the others to the training split.
        """
        scans = []
        for date in sorted(glob.glob(f'{self.root_dir}/*')):
            if 'poses_bounds' in date or '02-17-2022b' in date:
                continue
            only_date = date.split('/')[-1]  # e.g. 08-20-2021
            if self.src_specify == 'all':
                if only_date not in train_date:
                    continue
            elif self.time2weather_map[only_date] not in src_weather:
                continue

            date_path = os.path.join(self.root_dir, date, "cam0")
            times = sorted(x.split('/')[-1] for x in glob.glob(f'{date_path}/*'))
            val_start = len(times) - 15
            for t_idx, time in enumerate(times):
                if self.split == 'train' and t_idx >= val_start:
                    continue
                if self.split == 'val' and t_idx < val_start:
                    continue
                if all(os.path.isfile(os.path.join(self.root_dir, date, f"cam{vid}", time)) for vid in range(3)):
                    scans.append({'date': date, 'time': time})
        return scans

    def _view_layouts(self):
        """Return the (target_view, source_views) pairs generated for each scan."""
        if self.split == 'train':
            return [(1, [0, 2])]
        if self.cam_diff_weather == 'colmap':
            # Views >= 4 come from the second capture of the pair.
            return [(1, [0, 6]), (1, [2, 4]), (1, [0, 2]), (1, [4, 6])]
        if self.n_output_views == 1:
            return [(1, [0, 2])]
        if self.n_output_views == 2:
            assert self.to_calculate_consistency, "n_output_views == 2 requires to_calculate_consistency"
            # Targets 4/5 are virtual cameras between cam0/cam1 and cam1/cam2 (see __getitem__).
            return [(4, [0, 2]), (5, [0, 2])]
        if self.n_output_views == 3:
            return [(2, [0, 2]), (1, [0, 2]), (0, [0, 2])]
        return []

    def _build_view_metas(self):
        """Fill ``metas`` / ``metas_label`` / ``id_list`` from ``self.scans``.

        Each meta is ``(scan, target_view, source_views)``. If
        ``./configs/ithaca_{split}[_2cams].pickle`` exists, ``metas`` are loaded
        from it instead and labels / camera ids are re-derived from them.
        """
        self.id_list = []
        self.metas, self.metas_label = [], []
        if self.use_two_cam:
            meta_filename = f'./configs/ithaca_{self.split}_2cams.pickle'
        else:
            meta_filename = f'./configs/ithaca_{self.split}.pickle'
        # Paired validation metas reference views beyond cam0..cam2.
        extra_view_ids = self.split != 'train' and self.cam_diff_weather == 'colmap'

        if os.path.isfile(meta_filename):
            with open(meta_filename, 'rb') as f:
                self.metas = pickle.load(f)
            # The cache only stores metas; rebuild what the rest of the class indexes.
            self.metas_label = [self._domain_label(scan) for scan, _, _ in self.metas]
            for _, tgt_view, src_views in self.metas:
                self.id_list.append([0, 1, 2])
                if extra_view_ids:
                    self.id_list.append([tgt_view] + list(src_views))
            return

        if self.use_two_cam:
            for date_time in self.scans:
                for tgt_view, src_views in self._view_layouts():
                    self.metas.append((date_time, tgt_view, src_views))
                    self.metas_label.append(self._domain_label(date_time))
                    if extra_view_ids:
                        self.id_list.append([tgt_view] + src_views)
                self.id_list.append([0, 1, 2])

        print(self.split, "meta", self.src_specify, len(self.metas))

    def _collect_ithaca_style(self):
        """Return (paths, labels) of Ithaca365 cam0 images whose date weather matches ``ref_specify``.

        Labels are only collected when ``need_style_label`` is set.
        """
        mode_list = self._weather_list(self.ref_specify)
        paths, labels = [], []
        for date in sorted(glob.glob(f'{self.root_dir}/*')):
            if 'poses_bounds' in date or '02-17-2022b' in date:
                continue
            weather = self.time2weather_map[date.split('/')[-1]]
            if weather not in mode_list:
                continue
            imgs = glob.glob(f'{date}/cam0/*')
            paths += imgs
            if self.need_style_label:
                labels += [self.domain2label[weather] for _ in imgs]
        return paths, labels

    def _collect_waymo_style(self):
        """Return (paths, labels) of Waymo style images.

        A random subset of 100 images per time-of-day group is drawn once and
        cached as index lists under ``./data/waymo/``. Day and Dawn/Dusk images
        are labelled 'sunny' and Night images 'night'. In FID mode only the
        ``ref_specify`` group ('day', 'dawn' or 'night') is used.
        """
        base = './waymo/' if self.split == 'train' else './waymo/val/'
        groups = {
            'day': (sorted(glob.glob(base + 'sunny/Day/*')), 'sunny'),
            'dawn': (sorted(glob.glob(base + 'sunny/Dawn/Dusk/*')), 'sunny'),
            'night': (sorted(glob.glob(base + 'sunny/Night/*')), 'night'),
        }

        if self.to_calculate_FID:
            waymo_idx_file = f"./data/waymo/waymo_calculateFID_{self.ref_specify}_{self.split}.pickle"
            print(waymo_idx_file)
            if self.ref_specify not in groups:
                # Nothing to sample; do not leave an empty, unreadable cache file behind.
                return [], []
            names, weather = groups[self.ref_specify]
            if not os.path.isfile(waymo_idx_file):
                idx = random.sample(range(len(names)), 100)
                with open(waymo_idx_file, 'wb') as f:
                    pickle.dump({self.ref_specify: idx}, f)
            else:
                with open(waymo_idx_file, 'rb') as f:
                    idx = pickle.load(f)[self.ref_specify]
            return [names[x] for x in idx], [self.domain2label[weather] for _ in idx]

        waymo_idx_file = f"./data/waymo/waymo_style_idx_{self.split}.pickle"
        print(waymo_idx_file)
        if not os.path.isfile(waymo_idx_file):
            waymo_idx = {key: random.sample(range(len(names)), 100) for key, (names, _) in groups.items()}
            with open(waymo_idx_file, 'wb') as f:
                pickle.dump(waymo_idx, f)
        else:
            with open(waymo_idx_file, 'rb') as f:
                waymo_idx = pickle.load(f)
        paths = [names[x] for key, (names, _) in groups.items() for x in waymo_idx[key]]
        labels = [self.domain2label[weather] for key, (_, weather) in groups.items() for _ in waymo_idx[key]]
        return paths, labels

    def _build_style_metas(self):
        """Pick a style reference image (and label) for every item.

        In training, ``metas_style`` / ``metas_style_label`` are lists aligned
        with ``metas`` (all truncated to a common length); otherwise they are
        dicts keyed by ``_scene_key(scan)``.
        """
        if self.style_dataset == 'ithaca':
            metas_style, metas_style_label = self._collect_ithaca_style()
        elif self.style_dataset == 'waymo':
            metas_style, metas_style_label = self._collect_waymo_style()
        else:
            metas_style, metas_style_label = [], []

        # Path -> index into metas_style_label, captured before any shuffling.
        style_label_map = {path: i for i, path in enumerate(metas_style)}

        if self.split == 'train':
            self.metas_style = metas_style
            # Always set (possibly empty) so the truncation below cannot fail.
            self.metas_style_label = metas_style_label
        else:
            shuffle_file = "./data/waymo/val_metas_style_rand.pickle"
            # NOTE(review): this cache is shared by every style_dataset / ref_specify;
            # a file written under another configuration can make the label lookup raise KeyError.
            if not os.path.isfile(shuffle_file):
                shuffle(metas_style)
                with open(shuffle_file, 'wb') as f:
                    pickle.dump(metas_style, f)
            else:
                with open(shuffle_file, 'rb') as f:
                    metas_style = pickle.load(f)

            self.metas_style, self.metas_style_label = {}, {}
            for i, (scan, _, _) in enumerate(self.metas):
                key = self._scene_key(scan)
                self.metas_style[key] = metas_style[i]
                if self.need_style_label:
                    self.metas_style_label[key] = metas_style_label[style_label_map[metas_style[i]]]

        print(self.split, self.style_dataset, "meta style:", self.ref_specify, len(self.metas_style))

        if self.update_z or self.styleSame:
            print("style use input images")
            first_scan, _, first_src = self.metas[0]
            if self.split == 'train':
                self.metas_style = [self._image_path(first_scan, first_src[0])] * 300
                self.metas_style_label = self.metas_label
            elif self.split == 'val':
                self.metas_style, self.metas_style_label = {}, {}
                for i, (scan, _, src_views) in enumerate(self.metas):
                    key = self._scene_key(scan)
                    # Style each scene with its own first source view so image and label match.
                    self.metas_style[key] = self._image_path(scan, src_views[0])
                    self.metas_style_label[key] = self.metas_label[i]

        if self.split == 'train':
            min_len = min(len(self.metas), len(self.metas_style))
            self.metas = self.metas[:min_len]
            self.metas_label = self.metas_label[:min_len]
            self.metas_style = self.metas_style[:min_len]
            self.metas_style_label = self.metas_style_label[:min_len]

    # -------------------------------------------------------------- cameras

    def build_proj_mats(self):
        """Load intrinsics, extrinsics and near/far bounds from COLMAP ``poses_bounds``.

        Sets float32 arrays ``self.near_fars`` (N, 2), ``self.intrinsics``
        (N, 3, 3), ``self.world2cams`` and ``self.cam2worlds`` (N, 4, 4), one
        row per row of the ``poses_bounds*.npy`` file.

        Raises:
            ValueError: if the configuration loads no cameras (anything other
                than ``use_two_cam=True`` with ``camfile='colmap'``).
        """
        near_fars, intrinsics, world2cams, cam2worlds = [], [], [], []

        if self.use_two_cam and self.camfile == 'colmap':
            if self.cam_diff_weather == 'colmap':
                poses_file = "poses_bounds_0116_1219b.npy"
            else:
                poses_file = "poses_bounds.npy"
            poses_bounds = np.load(os.path.join(self.root_dir, poses_file))  # (N_cams, 17)
            poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
            bounds = poses_bounds[:, -2:]  # (N_images, 2)
            # Original-resolution intrinsics, shared by all images; rescaled per item in __getitem__.
            H, W, focal = poses[0, :, -1]

            # Reorder/negate axes from the LLFF layout, then center on the average pose.
            poses = np.concatenate(
                [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1
            )  # (N_images, 3, 4)
            poses, _ = center_poses(poses, self.blender2opencv)

            # Rescale so the nearest depth ends up a little above 1.0.
            near_original = bounds.min()
            scale_factor = near_original * 0.75  # 0.75 is the default parameter
            bounds /= scale_factor
            poses[..., 3] /= scale_factor

            near_fars = bounds.tolist()
            intrinsic = np.array([[focal, 0, W / 2], [0, focal, H / 2], [0, 0, 1]])
            for pose in poses:
                c2w = np.eye(4)
                c2w[:3] = pose
                cam2worlds.append(c2w)
                world2cams.append(np.linalg.inv(c2w))
                intrinsics.append(intrinsic.copy())

        if not cam2worlds:
            raise ValueError(
                "ithaca_Dataset loaded no cameras: only use_two_cam=True with camfile='colmap' is supported"
            )

        self.near_fars, self.intrinsics = np.stack(near_fars).astype('float32'), np.stack(intrinsics).astype('float32')
        self.world2cams, self.cam2worlds = np.stack(world2cams).astype('float32'), np.stack(cam2worlds).astype('float32')

    def read_depth(self, sampleToken=None):
        """Debug helper: ask ``ithaca_all`` to render LiDAR points onto cam0 into ./d.png.

        Returns None; no depth array is produced.
        """
        self.ithaca_all.render_pointcloud_in_image(sampleToken, camera_channel='cam0', pointsensor_channel='LIDAR_TOP', out_path='./d.png')
        return

    def build_remap(self):
        """Map each camera id in ``self.id_list`` to its row in the camera arrays."""
        self.remap = {item: i for i, item in enumerate(self.id_list)}

    def __len__(self):
        return len(self.metas) if self.max_len <= 0 else self.max_len

    def __getitem__(self, idx):
        """Assemble one multi-view sample; the target view is always last.

        Returns a dict with ``images``, ``depths`` / ``depths_h`` /
        ``depths_aug`` (zero placeholders), ``w2cs``, ``c2ws``, ``near_fars``,
        ``intrinsics``, ``affine_mats`` / ``affine_mats_inv`` (3 pyramid levels
        on the last axis), ``closest_idxs``, ``second_closest_idxs`` (empty),
        ``view_ids`` and, when enabled, ``style_img``, ``style_img_label`` and
        ``input_img_label``.
        """
        training = self.split == "train" and self.scene == "None"
        # Resolution jitter factor; currently fixed at 1.0.
        noisy_factor = 1.0
        close_views = int(np.random.choice([3, 4, 5])) if training else 5

        scan, target_view, src_views = self.metas[idx]
        view_ids = src_views[:self.nb_views] + [target_view]

        affine_mats, affine_mats_inv = [], []
        imgs, depths, depths_h, depths_aug = [], [], [], []
        intrinsics, w2cs, c2ws, near_fars = [], [], [], []

        w, h = self.img_wh
        w, h = int(w * noisy_factor), int(h * noisy_factor)

        for ori_vid in view_ids:
            # Consistency targets 4/5 are virtual cameras that reuse cam1's image and intrinsics.
            virtual = self.to_calculate_consistency and ori_vid >= 4
            vid = 1 if virtual else ori_vid

            img = Image.open(self._image_path(scan, vid)).convert("RGB")
            W, H = img.size
            if img.size != (w, h):
                img = img.resize((w, h), Image.BICUBIC)
            imgs.append(self.transform(img))

            index_mat = self.remap[vid]
            intrinsic = self.intrinsics[index_mat].copy()
            # Rescale intrinsics from the source image size to the working resolution.
            intrinsic[0] = intrinsic[0] * w / W
            intrinsic[1] = intrinsic[1] * h / H
            intrinsics.append(intrinsic)

            if virtual:
                c2w_l = self.cam2worlds[self.remap[0]]
                c2w_m = self.cam2worlds[self.remap[1]]
                c2w_r = self.cam2worlds[self.remap[2]]
                if ori_vid == 4:
                    o_new = (c2w_l[:3, -1] + c2w_m[:3, -1]) / 2
                elif ori_vid == 5:
                    o_new = (c2w_m[:3, -1] + c2w_r[:3, -1]) / 2
                else:
                    raise ValueError(f"no virtual camera defined for view {ori_vid}")
                # Keep cam1's orientation, move its center to the midpoint.
                c2w = c2w_m.copy()
                c2w[:3, -1] = o_new
                w2c = np.linalg.inv(c2w)
            else:
                w2c = self.world2cams[index_mat]
                c2w = self.cam2worlds[index_mat]
            w2cs.append(w2c)
            c2ws.append(c2w)

            aff, aff_inv = [], []
            for level in range(3):
                proj_mat_l = np.eye(4)
                intrinsic_temp = intrinsic.copy()
                intrinsic_temp[:2] = intrinsic_temp[:2] / (2**level)
                proj_mat_l[:3, :4] = intrinsic_temp @ w2c[:3, :4]
                aff.append(proj_mat_l.copy())
                aff_inv.append(np.linalg.inv(proj_mat_l))
            affine_mats.append(np.stack(aff, axis=-1))
            affine_mats_inv.append(np.stack(aff_inv, axis=-1))

            near_fars.append(self.near_fars[index_mat])

            depths_h.append(np.zeros([h, w]))
            depths.append(np.zeros([h // 4, w // 4]))
            depths_aug.append(np.zeros([h // 4, w // 4]))

        imgs = np.stack(imgs)
        depths, depths_h, depths_aug = np.stack(depths), np.stack(depths_h), np.stack(depths_aug)
        affine_mats, affine_mats_inv = np.stack(affine_mats), np.stack(affine_mats_inv)
        intrinsics = np.stack(intrinsics)
        w2cs = np.stack(w2cs)
        c2ws = np.stack(c2ws)
        near_fars = np.stack(near_fars)

        closest_idxs = []
        must_select_2 = self.nb_views == 2
        for pose in c2ws[:-1]:
            closest_idxs.append(
                get_nearest_pose_ids(
                    pose,
                    ref_poses=c2ws[:-1],
                    num_select=close_views,
                    must_select_2=must_select_2,
                    angular_dist_method="dist",
                )
            )
        closest_idxs.append(
            get_nearest_pose_ids(
                c2ws[-1],
                ref_poses=c2ws[:],
                num_select=len(c2ws[:-1]) - 1,
                must_select_2=must_select_2,
                angular_dist_method="dist",
            )
        )
        closest_idxs = np.stack(closest_idxs, axis=0)

        second_closest_idxs = []

        assert imgs.shape[0] == self.nb_views + 1
        assert w2cs.shape[0] == self.nb_views + 1
        assert intrinsics.shape[0] == self.nb_views + 1

        sample = {}
        sample["images"] = imgs
        sample["depths"] = depths
        sample["depths_h"] = depths_h
        sample["depths_aug"] = depths_aug
        sample["w2cs"] = w2cs
        sample["c2ws"] = c2ws
        sample["near_fars"] = near_fars
        sample["intrinsics"] = intrinsics
        sample["affine_mats"] = affine_mats
        sample["affine_mats_inv"] = affine_mats_inv
        sample["closest_idxs"] = closest_idxs
        sample["second_closest_idxs"] = second_closest_idxs
        sample["view_ids"] = np.array(view_ids)

        if self.need_style_img:
            # Train style metas are lists aligned with metas; val ones are dicts keyed by scene.
            style_key = idx if self.split == 'train' else self._scene_key(scan)
            img = Image.open(self.metas_style[style_key]).convert("RGB")
            if img.size != (w, h):
                img = img.resize((w, h), Image.BICUBIC)
            sample["style_img"] = self.transform(img)

            if self.need_style_label:
                sample["style_img_label"] = self.metas_style_label[style_key]
                sample["input_img_label"] = self.metas_label[idx]

        return sample
