import os
import numpy as np
import torch.utils.data as torch_data
import lib.utils.calibration as calibration
import lib.utils.kitti_utils as kitti_utils
from PIL import Image


class KittiDataset(torch_data.Dataset):
    """Base dataset for the KITTI object benchmark that resolves per-frame files.

    Only the file-loading helpers live here; ``__len__`` and ``__getitem__``
    raise ``NotImplementedError`` and are meant to be provided by subclasses.

    Layout read under ``root_dir``::

        object/<split>.txt                   one frame id per line
        object/{training,testing}/image_2    %06d.png
        object/{training,testing}/calib      %06d.txt
        object/{training,testing}/planes     %06d.txt
        object/{training,testing}/<lidar>    %06d.bin  (chosen by lidar_mode)
        object/{training,testing}/<labels>   %06d.txt  (chosen by lidar_mode)
    """

    # lidar_mode -> (point-cloud sub-directory, label sub-directory).
    # Even modes pair with the standard 'label_2' labels; odd modes read the
    # "*close*" point-cloud directories and pair with 'label_2_close'.
    # NOTE(review): an older comment called mode 1 "pseudo lidar" and mode 2
    # "pseudo lidar with depth loss", which no longer matches these
    # directories; "dl" in modes 4/5 may stand for depth loss -- confirm.
    _LIDAR_MODE_DIRS = {
        0: ('velodyne', 'label_2'),
        1: ('velodyne_close', 'label_2_close'),
        2: ('pseudo-lidar_velodyne_sparse', 'label_2'),
        3: ('pseudo-lidar_velodyne_close_sparse', 'label_2_close'),
        4: ('pseudo-lidar_velodyne_dl_sparse', 'label_2'),
        5: ('pseudo-lidar_velodyne_dl_close_sparse', 'label_2_close'),
    }

    def __init__(self, root_dir, split='train', lidar_mode=0):
        """Resolve the data directories and read the frame-id list of a split.

        Args:
            root_dir: Dataset root; must contain an ``object/`` directory.
            split: Name of the id list ``object/<split>.txt``. ``'test'``
                selects ``object/testing``; any other value selects
                ``object/training``.
            lidar_mode: Key into ``_LIDAR_MODE_DIRS`` (0-5) choosing which
                point-cloud and label directories are used.

        Raises:
            ValueError: If ``lidar_mode`` is not a supported mode.
        """
        # Validate before any filesystem access so a bad mode gives a clear
        # error instead of an unbound-variable crash.
        try:
            velodyne_name, label_name = self._LIDAR_MODE_DIRS[lidar_mode]
        except KeyError:
            raise ValueError('unsupported lidar_mode %r; expected one of %s'
                             % (lidar_mode, sorted(self._LIDAR_MODE_DIRS))) from None

        self.split = split
        is_test = self.split == 'test'
        self.imageset_dir = os.path.join(root_dir, 'object', 'testing' if is_test else 'training')

        split_dir = os.path.join(root_dir, 'object', split + '.txt')
        # Frame ids are kept as whitespace-stripped strings.
        with open(split_dir) as f:
            self.image_idx_list = [x.strip() for x in f]
        self.num_sample = len(self.image_idx_list)

        self.image_dir = os.path.join(self.imageset_dir, 'image_2')
        self.lidar_dir = os.path.join(self.imageset_dir, velodyne_name)
        self.calib_dir = os.path.join(self.imageset_dir, 'calib')
        self.label_dir = os.path.join(self.imageset_dir, label_name)
        self.plane_dir = os.path.join(self.imageset_dir, 'planes')

    def get_image(self, idx):
        """Load the ``image_2`` PNG of frame ``idx`` with OpenCV -- disabled.

        The leading ``assert False`` deliberately blocks OpenCV use (its
        message cites a deadlock); the code after it only runs when asserts
        are stripped (``python -O``).
        """
        assert False, 'DO NOT USE cv2 NOW, AVOID DEADLOCK'
        import cv2
        # cv2.setNumThreads(0)  # for solving deadlock when switching epoch
        img_file = os.path.join(self.image_dir, '%06d.png' % idx)
        assert os.path.exists(img_file)
        return cv2.imread(img_file)  # (H, W, 3) BGR mode

    def get_image_shape(self, idx):
        """Return ``(height, width, 3)`` of the ``image_2`` PNG of frame ``idx``.

        The channel count is hard-coded to 3, not read from the file.
        """
        img_file = os.path.join(self.image_dir, '%06d.png' % idx)
        assert os.path.exists(img_file)
        # The context manager releases the file handle PIL holds after its
        # lazy open; only the size is needed.
        with Image.open(img_file) as im:
            width, height = im.size
        return height, width, 3

    def get_lidar(self, idx):
        """Return the point cloud of frame ``idx`` as an ``(N, 4)`` float32 array.

        The ``.bin`` file is read as a flat float32 buffer and reshaped to
        4 columns (presumably x, y, z, intensity -- confirm for pseudo-lidar
        modes).
        """
        lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx)
        assert os.path.exists(lidar_file)
        return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)

    def get_calib(self, idx):
        """Return ``calibration.Calibration`` built from ``calib/%06d.txt``."""
        calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx)
        assert os.path.exists(calib_file)
        return calibration.Calibration(calib_file)

    def get_label(self, idx):
        """Return the objects parsed by ``kitti_utils.get_objects_from_label``."""
        label_file = os.path.join(self.label_dir, '%06d.txt' % idx)
        assert os.path.exists(label_file)
        return kitti_utils.get_objects_from_label(label_file)

    def get_road_plane(self, idx):
        """Return the normalised ground-plane coefficients of frame ``idx``.

        Parses the 4th line of ``planes/%06d.txt`` as whitespace-separated
        floats, flips the sign so that the normal's y component is not
        positive, then divides every coefficient by the length of the normal
        ``plane[0:3]``.
        """
        plane_file = os.path.join(self.plane_dir, '%06d.txt' % idx)
        with open(plane_file, 'r') as f:
            lines = f.readlines()
        plane = np.asarray([float(v) for v in lines[3].split()])

        # Ensure normal is always facing up, this is in the rectified camera coordinate
        if plane[1] > 0:
            plane = -plane

        norm = np.linalg.norm(plane[0:3])
        return plane / norm

    def __len__(self):
        """Abstract: subclasses must return the number of samples."""
        raise NotImplementedError

    def __getitem__(self, item):
        """Abstract: subclasses must build the sample for index ``item``."""
        raise NotImplementedError
