# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

import torch
from torch.utils.data import ConcatDataset, WeightedRandomSampler
import numpy as np

from data.llff import LLFF_Dataset
from data.dtu import DTU_Dataset
from data.nerf import NeRF_Dataset
from data.lf_data import LF_Dataset
from data.ithaca import ithaca_Dataset
from data.time_lapse import timeLapse_Dataset
from data.tt import tt_Dataset
# from data.tt_colmap import tt_Dataset
from data.photoTourism import photoTourism_Dataset
from data.waymo_ref import waymo_Dataset

def _weighted_concat(datasets, weights):
    """Concatenate datasets and build a sampler that balances them by weight.

    Every sample of ``datasets[i]`` gets weight ``weights[i] / len(datasets[i])``,
    so the total sampling mass assigned to a dataset depends only on its
    weight, not on how many samples it contains.

    Args:
        datasets: List of map-style datasets (must support ``len``).
        weights: One relative weight per dataset, in the same order.

    Returns:
        A ``(ConcatDataset, WeightedRandomSampler)`` pair. The sampler draws
        ``len(ConcatDataset)`` indices per epoch (with the sampler's default
        ``replacement=True``).

    Raises:
        ValueError: If ``datasets`` and ``weights`` differ in length (``zip``
            would otherwise silently drop the extras), or if a dataset is
            empty (which would otherwise surface as ``ZeroDivisionError``).
    """
    if len(datasets) != len(weights):
        raise ValueError(
            f"Got {len(datasets)} datasets but {len(weights)} weights"
        )

    sample_weights = []
    for dataset, weight in zip(datasets, weights):
        num_samples = len(dataset)
        if num_samples == 0:
            raise ValueError(
                f"Training dataset {type(dataset).__name__} is empty; "
                "check its root directory and split"
            )
        sample_weights.extend([weight / num_samples] * num_samples)

    concat = ConcatDataset(datasets)
    # float64 tensor, same as the original np.array -> torch.from_numpy path.
    weights_tensor = torch.from_numpy(np.array(sample_weights))
    sampler = WeightedRandomSampler(weights_tensor, len(weights_tensor))
    return concat, sampler


def get_training_dataset(args, downsample=1.0, only_dtu=False, only_llff=False, only_ithaca=False, use_far_view=False, llffdownsample=1.0, ithaca_all=None):
    """Build the (possibly multi-source) training dataset and its sampler.

    Exactly one source mix is selected; the flags are checked in this order:
    ``only_dtu`` > ``only_llff`` > ``only_ithaca`` > default mix
    (DTU + the LLFF-format datasets at ``ibrnet1_path``, ``ibrnet2_path``
    and ``llff_path``).

    Args:
        args: Parsed config. The attributes read depend on the branch, e.g.
            ``dtu_path``, ``dtu_pre_path``, ``ibrnet1_path``, ``ibrnet2_path``,
            ``llff_path``, ``nb_views`` and, for Ithaca, ``ithaca_path``,
            ``ithaca_label_path`` plus the style/consistency options.
        downsample: Down-sampling factor passed to every dataset.
        only_dtu: Train on DTU only.
        only_llff: Train on the three LLFF-format datasets only.
        only_ithaca: Train on the Ithaca dataset only.
        use_far_view: Forwarded to every dataset.
        llffdownsample: Extra factor multiplied into ``downsample`` for the
            LLFF-format datasets only.
        ithaca_all: Forwarded to ``ithaca_Dataset``.

    Returns:
        ``(train_dataset, train_sampler)``: a ``ConcatDataset`` and a
        ``WeightedRandomSampler`` that samples each source in proportion to
        its weight regardless of its size.

    Raises:
        ValueError: If any selected source dataset is empty.
    """
    llff_downsample = downsample * llffdownsample

    if only_dtu:
        datasets = [
            DTU_Dataset(
                original_root_dir=args.dtu_path,
                preprocessed_root_dir=args.dtu_pre_path,
                split="train",
                max_len=-1,
                downSample=downsample,
                nb_views=args.nb_views,
                use_far_view=use_far_view
            )
        ]
        weights = [1.0]

    elif only_llff:
        datasets = [
            LLFF_Dataset(
                root_dir=args.ibrnet1_path,
                split="train",
                max_len=-1,
                downSample=llff_downsample,
                nb_views=args.nb_views,
                imgs_folder_name="images",
                use_far_view=use_far_view
            ),
            LLFF_Dataset(
                root_dir=args.ibrnet2_path,
                split="train",
                max_len=-1,
                downSample=llff_downsample,
                nb_views=args.nb_views,
                imgs_folder_name="images",
                use_far_view=use_far_view
            ),
            LLFF_Dataset(
                root_dir=args.llff_path,
                split="train",
                max_len=-1,
                downSample=llff_downsample,
                nb_views=args.nb_views,
                imgs_folder_name="images_4",
                use_far_view=use_far_view
            ),
        ]
        # Same relative proportions as the LLFF part of the default mix,
        # doubled so that they sum to 1.0.
        weights = [0.22*2, 0.12*2, 0.16*2]

    elif only_ithaca:
        datasets = [
            ithaca_Dataset(
                root_dir=args.ithaca_path,
                table_root_dir=args.ithaca_label_path,
                split="train",
                max_len=-1,
                downSample=downsample,
                nb_views=args.nb_views,
                use_far_view=use_far_view,
                ithaca_all=ithaca_all,
                use_two_cam=args.ithaca_use_two_cams,
                need_style_img=args.geonerfMDMM,
                need_style_label=args.geonerfMDMM,
                src_specify=args.src_specify,
                ref_specify=args.ref_specify,
                cam_diff_weather=args.cam_diff_weather,
                read_lidar=args.read_lidar,
                camfile=args.camfile,
                input_phi_to_test=args.input_phi_to_test,
                style_dataset=args.style_dataset,
                specify_file=args.specify_file,
                n_output_views=args.n_output_views,
                pretrain_dataset=args.pretrain_dataset,
                to_calculate_consistency=args.to_calculate_consistency,
                update_z=args.update_z,
                styleSame=args.styleSame,
                to_calculate_FID=args.to_calculate_FID,
            )
        ]
        weights = [1.0]

    else:
        datasets = [
            DTU_Dataset(
                original_root_dir=args.dtu_path,
                preprocessed_root_dir=args.dtu_pre_path,
                split="train",
                max_len=-1,
                downSample=downsample,
                nb_views=args.nb_views,
                use_far_view=use_far_view
            ),
            LLFF_Dataset(
                root_dir=args.ibrnet1_path,
                split="train",
                max_len=-1,
                downSample=llff_downsample,
                nb_views=args.nb_views,
                imgs_folder_name="images",
                use_far_view=use_far_view
            ),
            LLFF_Dataset(
                root_dir=args.ibrnet2_path,
                split="train",
                max_len=-1,
                downSample=llff_downsample,
                nb_views=args.nb_views,
                imgs_folder_name="images",
                use_far_view=use_far_view
            ),
            LLFF_Dataset(
                root_dir=args.llff_path,
                split="train",
                max_len=-1,
                downSample=llff_downsample,
                nb_views=args.nb_views,
                imgs_folder_name="images_4",
                use_far_view=use_far_view
            ),
        ]
        weights = [0.5, 0.22, 0.12, 0.16]

    return _weighted_concat(datasets, weights)


def get_finetuning_dataset(args, downsample=1.0, use_far_view=False, ithaca_all=None):
    """Build the training dataset used for fine-tuning, chosen by ``args.dataset_name``.

    Args:
        args: Parsed config. ``args.dataset_name`` selects the dataset; the
            remaining attributes read depend on that choice (dataset root
            path, ``scene``, ``nb_views`` and style/consistency options).
        downsample: Down-sampling factor forwarded to the dataset.
        use_far_view: Forwarded to the datasets that accept it.
        ithaca_all: Forwarded to ``ithaca_Dataset``.

    Returns:
        ``(train_dataset, None)``. No custom sampler is used for fine-tuning.
        For ``"nerf"`` with ``args.scene == "None"``, the dataset is a
        ``ConcatDataset`` over all eight NeRF synthetic scenes.

    Raises:
        ValueError: If ``args.dataset_name`` is not one of the supported
            names. Previously this failed later with ``UnboundLocalError``.
    """
    name = args.dataset_name

    if name == "dtu":
        train_dataset = DTU_Dataset(
            original_root_dir=args.dtu_path,
            preprocessed_root_dir=args.dtu_pre_path,
            split="train",
            max_len=-1,
            downSample=downsample,
            nb_views=args.nb_views,
            scene=args.scene,
            use_far_view=use_far_view,
        )
    elif name == "llff":
        # Prefer a dedicated test root when one is configured.
        train_dataset = LLFF_Dataset(
            root_dir=args.llff_test_path if args.llff_test_path is not None else args.llff_path,
            split="train",
            max_len=-1,
            downSample=downsample,
            nb_views=args.nb_views,
            scene=args.scene,
            imgs_folder_name="images",
            use_far_view=use_far_view,
        )
    elif name == "nerf":
        if args.scene != "None":
            train_dataset = NeRF_Dataset(
                root_dir=args.nerf_path,
                split="train",
                max_len=-1,
                downSample=downsample,
                nb_views=args.nb_views,
                scene=args.scene,
                use_far_view=use_far_view,
            )
        else:
            # No scene given: fine-tune on every synthetic scene at once.
            nerf_scenes = ["chair", "drums", "ficus", "hotdog", "lego", "materials", "mic", "ship"]
            train_dataset = ConcatDataset([
                NeRF_Dataset(
                    root_dir=args.nerf_path,
                    split="train",
                    max_len=-1,
                    downSample=downsample,
                    nb_views=args.nb_views,
                    scene=scene,
                    use_far_view=use_far_view,
                )
                for scene in nerf_scenes
            ])
    elif name == "lf_data":
        train_dataset = LF_Dataset(
            root_dir=args.lf_path,
            split="train",
            max_len=-1,
            downSample=downsample,
            nb_views=args.nb_views,
            scene=args.scene,
        )
    elif name == 'ithaca':
        train_dataset = ithaca_Dataset(
            root_dir=args.ithaca_path,
            table_root_dir=args.ithaca_label_path,
            split="train",
            max_len=-1,
            downSample=downsample,
            nb_views=args.nb_views,
            use_far_view=use_far_view,
            ithaca_all=ithaca_all,
            use_two_cam=args.ithaca_use_two_cams,
            need_style_img=args.geonerfMDMM,
            need_style_label=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            cam_diff_weather=args.cam_diff_weather,
            read_lidar=args.read_lidar,
            camfile=args.camfile,
            input_phi_to_test=args.input_phi_to_test,
            style_dataset=args.style_dataset,
            specify_file=args.specify_file,
            n_output_views=args.n_output_views,
            pretrain_dataset=args.pretrain_dataset,
            to_calculate_consistency=args.to_calculate_consistency,
            update_z=args.update_z,
            styleSame=args.styleSame,
            to_calculate_FID=args.to_calculate_FID,
        )
    elif name == 'timeLapse':
        train_dataset = timeLapse_Dataset(
            root_dir=args.timeLapse_path,
            split="train",
            max_len=-1,
            downSample=downsample,
            nb_views=args.nb_views,
            need_style_img=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            input_phi_to_test=args.input_phi_to_test,
            style_dataset=args.style_dataset,
            save_video_frames=args.save_video_frames,
        )
    elif name == 'tt':
        train_dataset = tt_Dataset(
            root_dir=args.tt_path,
            split="train",
            max_len=-1,
            scene=args.scene,
            downSample=downsample,
            nb_views=args.nb_views,
            use_far_view=use_far_view,
            need_style_img=args.geonerfMDMM,
            need_style_label=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            style_dataset=args.style_dataset,
            n_output_views=args.n_output_views,
            input_phi_to_test=args.input_phi_to_test,
            to_calculate_consistency=args.to_calculate_consistency,
            far_consistency=args.far_consistency,
        )
    elif name == 'photoTourism':
        train_dataset = photoTourism_Dataset(
            root_dir=args.photoTourism_path,
            split="train",
            max_len=-1,
            scene=args.scene,
            downSample=downsample,
            nb_views=args.nb_views,
            use_far_view=use_far_view,
            need_style_img=args.geonerfMDMM,
            need_style_label=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            style_dataset=args.style_dataset,
            n_output_views=args.n_output_views,
            input_phi_to_test=args.input_phi_to_test,
            to_calculate_consistency=args.to_calculate_consistency,
            far_consistency=args.far_consistency,
        )
    elif name == 'waymo_ref':
        # NOTE(review): uses split="val" even for fine-tuning; presumably
        # waymo_Dataset only provides a val split. Confirm this is intended.
        train_dataset = waymo_Dataset(
            split="val",
            max_len=-1,
            scene=args.scene,
            downSample=downsample,
        )
    else:
        raise ValueError(f"Unsupported dataset_name for fine-tuning: {name!r}")

    train_sampler = None

    return train_dataset, train_sampler

def get_validation_dataset(args, downsample=1.0, use_far_view=False, ithaca_all=None):
    """Build the validation dataset selected by ``args.dataset_name``.

    If ``args.dataset_name`` contains ``"train"``, the source is picked by
    ``args.only_ithaca`` / ``args.only_llff``, falling back to DTU.
    Otherwise ``args.dataset_name`` must name one of the supported datasets.

    Args:
        args: Parsed config. Reads ``scene``, ``dataset_name`` and, depending
            on the branch, ``finetune``, ``eval``, ``only_ithaca``,
            ``only_llff``, the dataset root paths and the style/consistency
            options.
        downsample: Down-sampling factor forwarded to the dataset.
        use_far_view: Forwarded to the datasets that accept it.
        ithaca_all: Forwarded to ``ithaca_Dataset``.

    Returns:
        The validation dataset. For ``"nerf"`` with ``args.scene == "None"``,
        this is a ``ConcatDataset`` over all eight NeRF synthetic scenes.

    Raises:
        ValueError: If ``args.dataset_name`` is not supported. Previously this
            failed with ``UnboundLocalError`` at the return statement.
    """
    # Only multi-scene ("None") runs that are neither fine-tuning nor
    # evaluating get a capped validation set (max_len=2); everything else
    # uses -1 (presumably "no cap"). `and` short-circuits, so
    # args.finetune/args.eval are only read when scene == "None".
    if args.scene == "None" and not (args.finetune or args.eval):
        max_len = 2
    else:
        max_len = -1

    # Builders for configurations that are reachable from two branches below.
    def build_ithaca():
        return ithaca_Dataset(
            root_dir=args.ithaca_path,
            table_root_dir=args.ithaca_label_path,
            split="val",
            max_len=max_len,
            downSample=downsample,
            nb_views=args.nb_views,
            use_far_view=use_far_view,
            ithaca_all=ithaca_all,
            use_two_cam=args.ithaca_use_two_cams,
            need_style_img=args.geonerfMDMM,
            need_style_label=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            cam_diff_weather=args.cam_diff_weather,
            read_lidar=args.read_lidar,
            camfile=args.camfile,
            input_phi_to_test=args.input_phi_to_test,
            style_dataset=args.style_dataset,
            specify_file=args.specify_file,
            n_output_views=args.n_output_views,
            pretrain_dataset=args.pretrain_dataset,
            to_calculate_consistency=args.to_calculate_consistency,
            update_z=args.update_z,
            styleSame=args.styleSame,
            to_calculate_FID=args.to_calculate_FID,
        )

    def build_dtu():
        return DTU_Dataset(
            original_root_dir=args.dtu_path,
            preprocessed_root_dir=args.dtu_pre_path,
            split="val",
            max_len=max_len,
            downSample=downsample,
            nb_views=args.nb_views,
            scene=args.scene,
            use_far_view=use_far_view,
        )

    llff_root = args.llff_test_path if args.llff_test_path is not None else args.llff_path if False else None  # placeholder, resolved lazily below
    del llff_root

    name = args.dataset_name

    if "train" in name:
        if args.only_ithaca:
            val_dataset = build_ithaca()
        elif args.only_llff:
            val_dataset = LLFF_Dataset(
                root_dir=args.llff_test_path if args.llff_test_path is not None else args.llff_path,
                split="val",
                max_len=max_len,
                downSample=downsample,
                nb_views=args.nb_views,
                scene=args.scene,
                imgs_folder_name="images",
                use_far_view=use_far_view,
            )
        else:
            val_dataset = build_dtu()
    elif name == "dtu":
        val_dataset = build_dtu()
    elif name == "llff":
        val_dataset = LLFF_Dataset(
            root_dir=args.llff_test_path if args.llff_test_path is not None else args.llff_path,
            split="val",
            max_len=max_len,
            downSample=downsample,
            nb_views=args.nb_views,
            scene=args.scene,
            imgs_folder_name="images",
            use_far_view=use_far_view,
            need_style_img=args.geonerfMDMM,
            need_style_label=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            input_phi_to_test=args.input_phi_to_test,
        )
    elif name == "nerf":
        if args.scene != "None":
            val_dataset = NeRF_Dataset(
                root_dir=args.nerf_path,
                split="val",
                max_len=max_len,
                downSample=downsample,
                nb_views=args.nb_views,
                scene=args.scene,
                use_far_view=use_far_view,
            )
        else:
            # No scene given: validate on every synthetic scene.
            nerf_scenes = ["chair", "drums", "ficus", "hotdog", "lego", "materials", "mic", "ship"]
            val_dataset = ConcatDataset([
                NeRF_Dataset(
                    root_dir=args.nerf_path,
                    split="val",
                    max_len=max_len,
                    downSample=downsample,
                    nb_views=args.nb_views,
                    scene=scene,
                    use_far_view=use_far_view,
                )
                for scene in nerf_scenes
            ])
    elif name == "lf_data":
        val_dataset = LF_Dataset(
            root_dir=args.lf_path,
            split="val",
            max_len=max_len,
            downSample=downsample,
            nb_views=args.nb_views,
            scene=args.scene,
        )
    elif name == 'ithaca':
        val_dataset = build_ithaca()
    elif name == 'timeLapse':
        val_dataset = timeLapse_Dataset(
            root_dir=args.timeLapse_path,
            split="val",
            max_len=max_len,
            downSample=downsample,
            nb_views=args.nb_views,
            need_style_img=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            input_phi_to_test=args.input_phi_to_test,
            style_dataset=args.style_dataset,
            save_video_frames=args.save_video_frames,
        )
    elif name == 'tt':
        val_dataset = tt_Dataset(
            root_dir=args.tt_path,
            split="val",
            max_len=max_len,
            scene=args.scene,
            downSample=downsample,
            nb_views=args.nb_views,
            use_far_view=use_far_view,
            need_style_img=args.geonerfMDMM,
            need_style_label=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            style_dataset=args.style_dataset,
            n_output_views=args.n_output_views,
            input_phi_to_test=args.input_phi_to_test,
            to_calculate_consistency=args.to_calculate_consistency,
            far_consistency=args.far_consistency,
        )
    elif name == 'photoTourism':
        val_dataset = photoTourism_Dataset(
            root_dir=args.photoTourism_path,
            split="val",
            max_len=max_len,
            scene=args.scene,
            downSample=downsample,
            nb_views=args.nb_views,
            use_far_view=use_far_view,
            need_style_img=args.geonerfMDMM,
            need_style_label=args.geonerfMDMM,
            src_specify=args.src_specify,
            ref_specify=args.ref_specify,
            style_dataset=args.style_dataset,
            n_output_views=args.n_output_views,
            input_phi_to_test=args.input_phi_to_test,
            to_calculate_consistency=args.to_calculate_consistency,
            far_consistency=args.far_consistency,
        )
    elif name == 'waymo_ref':
        val_dataset = waymo_Dataset(
            split="val",
            max_len=max_len,
            scene=args.scene,
            downSample=downsample,
        )
    else:
        raise ValueError(f"Unsupported dataset_name for validation: {name!r}")

    return val_dataset
