# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import cv2
import numpy as np
from PIL import Image

from utils.utils import read_pfm, get_nearest_pose_ids
import glob
import json
import pickle
from pyquaternion import Quaternion
import random
from random import shuffle


def normalize(v):
    """Return ``v`` scaled to unit Euclidean (L2) length.

    There is no zero-length guard: a zero vector is divided by zero, which
    numpy reports with a RuntimeWarning and nan results.
    """
    length = np.linalg.norm(v)
    return v / length


def average_poses(poses):
    """Build an "average" camera pose from a stack of poses.

    Args:
        poses: array indexed as ``poses[i, :, col]``, where columns 0-2 are
            the x/y/z axis vectors and column 3 is the camera center.
            Presumably (N_images, 3, 4) camera-to-world matrices; confirm
            against callers.

    Returns:
        A (3, 4) array whose columns are ``[x, y, z, center]``. ``center`` is
        the mean camera center and ``z`` is the normalized mean z axis.
        ``x`` and ``y`` are re-orthogonalized from the mean y axis via cross
        products.
    """
    mean_center = poses[..., 3].mean(0)
    z_axis = normalize(poses[..., 2].mean(0))
    # The mean y axis is only a helper direction; it is not normalized
    # because it only feeds the cross product below.
    y_hint = poses[..., 1].mean(0)

    x_axis = normalize(np.cross(y_hint, z_axis))
    # z and x are unit and orthogonal, so their cross product is already unit.
    y_axis = np.cross(z_axis, x_axis)

    return np.stack([x_axis, y_axis, z_axis, mean_center], axis=1)

def center_poses(poses, blender2opencv):
    """Re-express camera poses in the frame of their average pose.

    Args:
        poses: poses stacked along axis 0. Presumably (N_images, 3, 4)
            camera-to-world matrices; confirm against callers.
        blender2opencv: a (4, 4) matrix. It is used only for the second
            return value and is not applied to the centered poses.

    Returns:
        A tuple of:

        * the (N_images, 3, 4) poses left-multiplied by the inverse of the
          homogeneous average pose, with columns 1 and 2 (y and z axes)
          negated;
        * ``inv(pose_avg_homo) @ blender2opencv``.
    """
    avg_homo = np.eye(4)
    avg_homo[:3] = average_poses(poses)
    inv_avg = np.linalg.inv(avg_homo)

    # Append a [0, 0, 0, 1] row to every pose -> (N_images, 4, 4).
    bottom_rows = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))
    homogeneous = np.concatenate([poses, bottom_rows], 1)

    centered = inv_avg @ homogeneous
    # Flip the sign of the y and z axis columns. This is presumably the same
    # as right-multiplying by diag(1, -1, -1, 1); confirm before replacing.
    centered[:, :, 1:3] *= -1

    return centered[:, :3], inv_avg @ blender2opencv

class waymo_Dataset(Dataset):
    """Map-style dataset of Waymo "style" images with domain labels.

    Each item is one RGB image. Images come from the hard-coded
    ``./waymo/sunny/{Day,Dawn/Dusk,Night}`` folders, or the ``./waymo/val/...``
    equivalents for any split other than ``'train'``. Each image is resized to
    ``img_wh`` and normalized with ImageNet mean/std. Day and Dawn/Dusk images
    get the ``'sunny'`` label and Night images the ``'night'`` label (see
    ``domain2label``).

    Many constructor arguments are only stored or are ignored here.
    Presumably they exist so this class shares a constructor signature with
    the other GeoNeRF datasets; confirm against callers.
    """

    # Maximum number of images drawn at random from each time-of-day folder.
    samples_per_domain = 250

    def __init__(
        self,
        root_dir=None,
        table_root_dir=None,
        split='val',
        nb_views=None,
        downSample=1.0,
        max_len=-1,
        scene="None",
        use_far_view=False,
        ithaca_all=None,
        use_two_cam=False,
        need_style_img=True,
        need_style_label=True,
        src_specify='all',
        ref_specify='all',
        cam_diff_weather=None,
        read_lidar=False,
        camfile='colmap',
        input_phi_to_test=False,
        style_dataset='waymo',
        pretrain_dataset='ithaca',
        specify_file="None",
        n_output_views=1,
        to_calculate_consistency=False,
        update_z=False,
        styleSame=False,
    ):
        self.root_dir = root_dir
        self.table_root_dir = table_root_dir
        self.split = split
        self.scene = scene
        self.use_far_view = use_far_view
        self.use_two_cam = use_two_cam
        self.need_style_img = need_style_img
        self.need_style_label = need_style_label
        self.src_specify = src_specify
        self.ref_specify = ref_specify
        self.cam_diff_weather = cam_diff_weather
        self.camfile = camfile
        self.style_dataset = style_dataset
        self.specify_file = specify_file

        self.downSample = downSample
        self.scale_factor = 1.0 / 1500
        self.interval_scale = 1.06
        # max_len <= 0 means "use every collected image" (see __len__).
        self.max_len = max_len
        self.nb_views = nb_views
        # Output (width, height); (768, 512) at downSample=1.0.
        self.img_wh = (int(128 * 6 * self.downSample), int(128 * 4 * self.downSample))

        self.blender2opencv = np.array(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )

        self.domain2label = {'night': 0, 'sunny': 1, 'rain': 2, 'cloud': 3, 'snow': 4}

        self.build_metas()
        self.define_transforms()

    def define_transforms(self):
        """Set ``self.transform``: PIL image -> tensor, ImageNet mean/std normalized."""
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    @staticmethod
    def _sample_indices(n_items, n_samples):
        """Draw ``min(n_items, n_samples)`` distinct indices from ``range(n_items)``."""
        return random.sample(range(n_items), min(n_items, n_samples))

    def build_metas(self):
        """Fill ``self.metas_style`` (image paths) and ``self.metas_style_label``.

        A random subset of at most ``samples_per_domain`` images is drawn per
        time-of-day folder. The chosen indices are cached in a per-split
        pickle file under ``./data/waymo/``, so later runs reuse the same
        subset. The cache file and its directory are created on first use.

        Raises:
            ValueError: if any of the three image folders yields no files.
        """
        prefix = './waymo' if self.split == 'train' else './waymo/val'
        day_imgs_name = sorted(glob.glob(f'{prefix}/sunny/Day/*'))
        dawn_imgs_name = sorted(glob.glob(f'{prefix}/sunny/Dawn/Dusk/*'))
        night_imgs_name = sorted(glob.glob(f'{prefix}/sunny/Night/*'))

        for folder, names in (('Day', day_imgs_name),
                              ('Dawn/Dusk', dawn_imgs_name),
                              ('Night', night_imgs_name)):
            if not names:
                raise ValueError(
                    f"No Waymo style images found in {prefix}/sunny/{folder}/ "
                    f"(paths are relative to the working directory {os.getcwd()})"
                )

        waymo_idx_file = f"./data/waymo/waymo_estimate_t_{self.split}.pickle"
        print(waymo_idx_file)
        if os.path.isfile(waymo_idx_file):
            # NOTE(review): pickle.load can execute arbitrary code. This cache
            # is meant to be written only by the branch below; never point it
            # at an untrusted file.
            with open(waymo_idx_file, 'rb') as f:
                waymo_idx = pickle.load(f)
            day_imgs_idx = waymo_idx['day']
            dawn_imgs_idx = waymo_idx['dawn']
            night_imgs_idx = waymo_idx['night']
        else:
            # Sampling order (day, dawn, night) is kept so the RNG stream
            # matches earlier versions of this code.
            day_imgs_idx = self._sample_indices(len(day_imgs_name), self.samples_per_domain)
            dawn_imgs_idx = self._sample_indices(len(dawn_imgs_name), self.samples_per_domain)
            night_imgs_idx = self._sample_indices(len(night_imgs_name), self.samples_per_domain)

            os.makedirs(os.path.dirname(waymo_idx_file), exist_ok=True)
            with open(waymo_idx_file, 'wb') as f:
                pickle.dump({'day': day_imgs_idx, 'dawn': dawn_imgs_idx, 'night': night_imgs_idx}, f)

        sunny_label = self.domain2label['sunny']
        night_label = self.domain2label['night']
        self.metas_style = (
            [day_imgs_name[i] for i in day_imgs_idx]
            + [dawn_imgs_name[i] for i in dawn_imgs_idx]
            + [night_imgs_name[i] for i in night_imgs_idx]
        )
        # Dawn/Dusk images intentionally share the 'sunny' label with Day.
        self.metas_style_label = (
            [sunny_label] * len(day_imgs_idx)
            + [sunny_label] * len(dawn_imgs_idx)
            + [night_label] * len(night_imgs_idx)
        )

    def __len__(self):
        """Return the number of collected images, capped by ``max_len`` when it is > 0."""
        n_items = len(self.metas_style)
        # Never report more items than exist, otherwise __getitem__ would
        # raise IndexError for the indices past the end.
        return n_items if self.max_len <= 0 else min(self.max_len, n_items)

    def __getitem__(self, idx):
        """Load, resize and normalize style image ``idx``.

        Returns:
            dict with ``style_path`` (str), ``style_img`` (the output of
            ``self.transform``) and ``style_img_label`` (int from
            ``domain2label``).
        """
        sample = {}

        style_path = self.metas_style[idx]
        # Source images are presumably 1920x1208; confirm against the data.
        img = Image.open(style_path).convert("RGB")
        w, h = self.img_wh
        if img.size != (w, h):
            img = img.resize((w, h), Image.BICUBIC)
        style_img = self.transform(img)

        sample["style_path"] = style_path
        sample["style_img"] = style_img
        sample["style_img_label"] = self.metas_style_label[idx]

        return sample
