# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from torch.utils.data import Dataset
from torchvision import transforms as T

import os
import cv2
import numpy as np
from PIL import Image

from utils.utils import read_pfm, get_nearest_pose_ids
import glob
import json
import pickle
from pyquaternion import Quaternion
import random

import torchvision
torchvision.set_video_backend("video_reader")

def normalize(v):
    """Return ``v`` divided by its norm.

    The norm is whatever ``np.linalg.norm`` computes with default arguments.
    No guard is applied for a zero norm.
    """
    magnitude = np.linalg.norm(v)
    return v / magnitude


def average_poses(poses):
    """Build a single "average" pose from a stack of poses.

    The last axis of ``poses`` is read as four columns: columns 0-2 are axis
    vectors and column 3 is the position (assumes shape (N, 3, 4), matching
    the shape comments used elsewhere in this file -- TODO confirm).

    Returns:
        A (3, 4) array whose columns are the x, y, z axes and the center.
    """
    # Mean position over all poses.
    origin = poses[..., 3].mean(0)  # (3)

    # Mean of column 1; only a direction hint, so it is left un-normalized.
    y_hint = poses[..., 1].mean(0)  # (3)

    # Unit z axis from the mean of column 2.
    axis_z = normalize(poses[..., 2].mean(0))  # (3)

    # x is perpendicular to both the y hint and z.
    axis_x = normalize(np.cross(y_hint, axis_z))  # (3)

    # z and x are unit length and orthogonal, so their cross product is too.
    axis_y = np.cross(axis_z, axis_x)  # (3)

    return np.stack([axis_x, axis_y, axis_z, origin], 1)  # (3, 4)

def center_poses(poses, blender2opencv):
    """Re-express ``poses`` relative to their average pose.

    Args:
        poses: stack of 3x4 poses (assumes shape (N_images, 3, 4) -- TODO
            confirm against callers).
        blender2opencv: 4x4 matrix; only used to build the second return
            value.

    Returns:
        A tuple ``(poses_centered, transform)`` where ``poses_centered`` is
        (N_images, 3, 4) with columns 1 and 2 negated after centering, and
        ``transform`` is ``inv(average pose) @ blender2opencv``.
    """
    # Promote the (3, 4) average pose to a 4x4 homogeneous matrix.
    avg_homo = np.eye(4)
    avg_homo[:3] = average_poses(poses)
    avg_inv = np.linalg.inv(avg_homo)

    # Append a [0, 0, 0, 1] row to every pose -> (N_images, 4, 4).
    n_poses = len(poses)
    bottom_rows = np.tile(np.array([0, 0, 0, 1]), (n_poses, 1, 1))  # (N_images, 1, 4)
    homo = np.concatenate([poses, bottom_rows], 1)

    centered = avg_inv @ homo  # (N_images, 4, 4)

    # Flip the sign of columns 1 and 2; `centered` is a fresh array, so the
    # in-place negate cannot touch the caller's data.
    flipped_cols = centered[:, :, 1:3]
    np.negative(flipped_cols, out=flipped_cols)

    # Drop the homogeneous row again -> (N_images, 3, 4).
    centered = centered[:, :3]

    return centered, avg_inv @ blender2opencv

class timeLapse_Dataset(Dataset):
    """Dataset that serves the image files found directly under ``root_dir``.

    Each item is a dict holding one normalized image, zero-filled placeholder
    depth maps, and empty lists for every camera-related entry. When
    ``need_style_img`` is set, the item also carries a style image taken from
    the same file list.
    """

    def __init__(
        self,
        root_dir,
        split,
        nb_views,
        downSample=1.0,
        max_len=-1,
        scene="None",
        need_style_img=False,
        src_specify='all',
        ref_specify='all',
        input_phi_to_test=False,
        style_dataset='ithaca',
        save_video_frames=False,
    ):
        """Store the configuration, optionally dump video frames, index files.

        Args:
            root_dir: directory globbed for frames by ``build_metas``; read as
                a video file path by ``video_to_frames``.
            split: split name; only used in log output within this class.
            nb_views: stored on the instance; not read within this class.
            downSample: scale applied to the nominal (768, 512) ``img_wh``.
            max_len: if positive, upper bound on the reported dataset length.
            scene: stored on the instance; not read within this class.
            need_style_img: also index and return a style image per item.
            src_specify: label printed next to the frame count.
            ref_specify: label printed next to the style frame count.
            input_phi_to_test: replace the file lists with 16 copies of
                ``<root_dir>/frame0.jpg``.
            style_dataset: stored on the instance; not read within this class.
            save_video_frames: run ``video_to_frames`` before indexing.
        """
        self.root_dir = root_dir
        self.split = split
        self.scene = scene
        self.need_style_img = need_style_img
        self.src_specify = src_specify
        self.ref_specify = ref_specify
        self.input_phi_to_test = input_phi_to_test
        self.style_dataset = style_dataset

        self.downSample = downSample
        self.scale_factor = 1.0 / 1500
        self.interval_scale = 1.06
        self.max_len = max_len
        self.nb_views = nb_views
        # Nominal size, (768, 512) at downSample=1. __getitem__ derives the
        # actual size from each image instead.
        self.img_wh = (int(128*6 * self.downSample), int(128*4 * self.downSample))

        self.blender2opencv = np.array(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )

        if save_video_frames:
            self.video_to_frames()

        self.build_metas()
        self.define_transforms()

    def video_to_frames(self):
        """Dump frames of the video at ``self.root_dir`` as JPEG files.

        Frames go to ``./videos_h264_imgs/<video name>/frame<count>.jpg`` and
        a ``time.txt`` file records the start offset and sampling interval.
        The capture is released even if extraction fails, and the method
        returns normally once the video is exhausted.

        NOTE(review): this method opens ``root_dir`` as a video file, while
        ``build_metas`` globs it as a directory -- confirm how callers are
        expected to combine the two.
        """
        import cv2

        time_skips = float(10000)  # sampling interval in milliseconds
        start_time_ms = 0

        vidcap = cv2.VideoCapture(self.root_dir)
        try:
            vidcap.set(cv2.CAP_PROP_POS_MSEC, start_time_ms)
            success, image = vidcap.read()
            count = 0

            video_name = self.root_dir.split('/')[-1].split('.')[0]
            dir_name = './videos_h264_imgs/' + video_name
            print(video_name)
            os.makedirs(dir_name, exist_ok=True)

            with open(f'{dir_name}/time.txt', 'w') as f:
                f.write(f"start: {start_time_ms//1000}s\n")
                f.write(f"interval: {time_skips//1000}s\n")

            while success:
                cv2.imwrite(f"{dir_name}/frame{count}.jpg", image)
                # NOTE(review): the seek uses the pre-increment count, so the
                # frame saved as frame1 is read from the same timestamp as
                # frame0 -- confirm whether that is intended.
                vidcap.set(cv2.CAP_PROP_POS_MSEC, start_time_ms + (count * time_skips))
                success, image = vidcap.read()
                print('Read a new frame: ', success)
                count += 1
        finally:
            vidcap.release()

    def define_transforms(self):
        """Create the tensor conversion + per-channel normalization pipeline."""
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def build_metas(self):
        """Index the files under ``root_dir`` into the dataset's file lists.

        Sets ``self.scans`` and ``self.metas`` (and ``self.metas_style`` when
        style images are requested) to the sorted glob of ``root_dir/*``.
        With ``input_phi_to_test`` the meta lists become 16 copies of
        ``<root_dir>/frame0.jpg`` instead.
        """
        # Glob once; every list gets its own copy so they stay independent.
        found = sorted(glob.glob(f'{self.root_dir}/*'))
        only_one = f'{self.root_dir}/frame0.jpg'

        self.scans = found
        if self.input_phi_to_test:
            self.metas = [only_one for _ in range(16)]
        else:
            self.metas = list(found)
        print(self.split,"meta",self.src_specify,len(self.metas))

        if self.need_style_img:
            if self.input_phi_to_test:
                self.metas_style = [only_one for _ in range(16)]
            else:
                self.metas_style = list(found)
            print(self.split,"meta style:",self.ref_specify,len(self.metas_style))

    def read_depth(self, sampleToken=None):
        """Placeholder: no depth is available for this dataset; returns None."""
        return

    def __len__(self):
        """Return the number of items, capped by ``max_len`` when positive.

        The cap never exceeds the number of indexed files, so every index
        below ``len(self)`` is valid for ``__getitem__``.
        """
        available = len(self.metas)
        if self.max_len <= 0:
            return available
        return min(self.max_len, available)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as handle:
            return handle.convert("RGB")

    def __getitem__(self, idx):
        """Load item ``idx``.

        The image is halved in each dimension when its width or height is at
        least 1000 pixels, otherwise kept at its native size. Depth entries
        are zero arrays and all camera entries are empty lists.

        Returns:
            dict with keys ``images``, ``depths``, ``depths_h``,
            ``depths_aug``, ``w2cs``, ``c2ws``, ``near_fars``, ``intrinsics``,
            ``affine_mats``, ``affine_mats_inv``, ``closest_idxs``,
            ``second_closest_idxs`` and, when ``need_style_img`` is set,
            ``style_img`` and ``style_img_label``.
        """
        img = self._load_rgb(self.metas[idx])
        W, H = img.size
        if W >= 1000 or H >= 1000:
            w, h = W // 2, H // 2
        else:
            w, h = W, H
        if img.size != (w, h):
            img = img.resize((w, h), Image.BICUBIC)

        # A single view per item, stacked along a new leading axis.
        imgs = np.stack([self.transform(img)])

        # No depth available: zero placeholders at full and quarter size.
        depths_h = np.stack([np.zeros([h, w])])
        depths = np.stack([np.zeros([h // 4, w // 4])])
        depths_aug = np.stack([np.zeros([h // 4, w // 4])])

        sample = {}
        sample["images"] = imgs
        sample["depths"] = depths
        sample["depths_h"] = depths_h
        sample["depths_aug"] = depths_aug
        # No camera information for this dataset: the entries stay empty.
        sample["w2cs"] = []
        sample["c2ws"] = []
        sample["near_fars"] = []
        sample["intrinsics"] = []
        sample["affine_mats"] = []
        sample["affine_mats_inv"] = []
        sample["closest_idxs"] = []
        sample["second_closest_idxs"] = []

        if self.need_style_img:
            style = self._load_rgb(self.metas_style[idx])
            # The style image is resized to match the content image.
            if style.size != (w, h):
                style = style.resize((w, h), Image.BICUBIC)
            sample["style_img"] = self.transform(style)
            sample["style_img_label"] = 0

        return sample
