import os
import torch
import datasets


def load_driving_stereo_dataset(data_dir, splits_dir, eval_split, height, width):
    """Build the DrivingStereo evaluation dataset for one weather split.

    The list of test samples is read from
    ``<splits_dir>/<eval_split>/test_files.txt`` (one entry per line).

    Args:
        data_dir: Root directory passed to ``datasets.DRIVINGSTEREO``.
        splits_dir: Root directory containing one sub-folder per split.
        eval_split: One of ``"sunny"``, ``"foggy"``, ``"rainy"``, ``"cloudy"``;
            also forwarded as ``stereo_split``.
        height: Height forwarded to ``datasets.DRIVINGSTEREO``.
        width: Width forwarded to ``datasets.DRIVINGSTEREO``.

    Returns:
        A ``datasets.DRIVINGSTEREO`` instance built with ``is_train=False``.

    Raises:
        ValueError: if ``eval_split`` is not a supported weather split
            (previously the function silently returned ``None``).
        OSError: if the split's ``test_files.txt`` cannot be opened.
    """
    supported_splits = ("sunny", "foggy", "rainy", "cloudy")
    if eval_split not in supported_splits:
        raise ValueError(
            "Unsupported DrivingStereo split %r; expected one of %s"
            % (eval_split, ", ".join(supported_splits))
        )

    # Read the split file inline: no `readlines` helper is in scope in this module.
    with open(os.path.join(splits_dir, eval_split, 'test_files.txt'), 'r') as f:
        filenames = f.read().splitlines()

    # `[0]` and `4` are forwarded positionally; presumably frame indices and the
    # number of scales — TODO confirm against datasets.DRIVINGSTEREO's signature.
    return datasets.DRIVINGSTEREO(data_dir, filenames, height, width,
                                  [0], 4, is_train=False, stereo_split=eval_split)

def load_waymo_da_dataset(data_dir, splits_dir, eval_split, height, width,
                          tf_list_path='./splits/waymo_da/tf_list.txt'):
    """Build the Waymo domain-adaptation evaluation set as one concatenated dataset.

    Segment record names are read from ``tf_list_path``. For each weather
    condition and each segment that has a split folder at
    ``<splits_dir>/<eval_split>/<weather>/<segment>``, a ``datasets.WAYMO`` is
    built over ``<data_dir>/<weather>/color_images/<segment>``. The per-segment
    datasets are concatenated in weather order, then list order.

    Args:
        data_dir: Root directory holding ``<weather>/color_images/<segment>``.
        splits_dir: Root directory of the split folders.
        eval_split: Split sub-folder name; also forwarded as ``stereo_split``.
        height: Height forwarded to ``datasets.WAYMO``.
        width: Width forwarded to ``datasets.WAYMO``.
        tf_list_path: Text file listing one record name per line. The default
            is resolved relative to the current working directory, matching the
            previously hard-coded path.

    Returns:
        A ``torch.utils.data.ConcatDataset`` whose ``K`` attribute is copied
        from the first per-segment dataset.

    Raises:
        FileNotFoundError: if no segment split folder exists for any weather.
    """
    with open(tf_list_path, 'r') as f:
        # Skip blank lines: an empty name would resolve to the weather folder itself.
        records = [line.strip() for line in f.read().splitlines() if line.strip()]

    dataset_list = []
    for weather in ('unknown_day', 'unknown_dusk'):
        split_root = os.path.join(splits_dir, eval_split, weather)
        for record in records:
            # Segment id is the record name up to its first '.' (drops the extension).
            segment = record.split('.')[0]
            split_folder = os.path.join(split_root, segment)
            if not os.path.isdir(split_folder):
                continue
            # Read the split file inline: no `readlines` helper is in scope in this module.
            with open(os.path.join(split_folder, 'test_files.txt'), 'r') as f:
                filenames = f.read().splitlines()
            data_path = os.path.join(data_dir, weather, 'color_images', segment)
            dataset_list.append(datasets.WAYMO(data_path, filenames, height, width,
                                               [0], 4, is_train=False,
                                               stereo_split=eval_split))

    if not dataset_list:
        raise FileNotFoundError(
            "No Waymo segment split folders found under %r for split %r"
            % (splits_dir, eval_split)
        )

    # Concatenate once, after every weather has been scanned. Building this
    # inside the loop crashed when the first weather had no segments.
    dataset = torch.utils.data.ConcatDataset(dataset_list)
    # Expose intrinsics at the top level. This assumes all segments share the
    # first one's K — TODO confirm.
    dataset.K = dataset_list[0].K
    return dataset

def dataset_factory(data_name, data_dir, splits_dir, eval_split, height, width):
    """Build an evaluation dataset by registry name.

    Args:
        data_name: ``'DrivingStereo'`` or ``'waymo_da'``.
        data_dir, splits_dir, eval_split, height, width: Forwarded positionally
            to the selected loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        KeyError: if ``data_name`` is not registered; the message lists the
            supported names.
    """
    loaders = {
        'DrivingStereo': load_driving_stereo_dataset,
        'waymo_da': load_waymo_da_dataset,
    }
    try:
        loader = loaders[data_name]
    except KeyError:
        # Keep the KeyError type for callers, but make the failure self-explanatory.
        raise KeyError(
            "Unknown dataset %r; expected one of %s" % (data_name, sorted(loaders))
        ) from None
    return loader(data_dir, splits_dir, eval_split, height, width)
    
