import glob
import re
import os.path as osp
from PIL import Image
from .bases import BaseImageDataset


class VeRi(BaseImageDataset):
    """
    VeRi-776

    Reference:
    Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016.

    URL: https://vehiclereid.github.io/VeRi/

    Dataset statistics:
    # identities: 776
    # images: 37778 (train) + 1678 (query) + 11579 (gallery)
    # cameras: 20

    Expected layout under ``<root>/VeRi``:
        image_train/         training images (``*.jpg``)
        image_query/         query images (``*.jpg``)
        image_test/          gallery images (``*.jpg``)
        keypoint_train.txt   optional keypoint annotations for image_train
        keypoint_test.txt    optional keypoint annotations for image_test

    After construction ``self.train``, ``self.query`` and ``self.gallery`` are
    lists of ``(img_path, pid, kpt, camid, 1)`` tuples as produced by
    ``_process_dir``.
    """

    dataset_dir = 'VeRi'

    def __init__(self, root='../Dataset', verbose=True, **kwargs):
        """Index the three image splits found under ``<root>/VeRi``.

        Args:
            root: directory that contains the ``VeRi`` folder.
            verbose: when true, print a banner and the statistics reported by
                ``print_dataset_statistics``.
            **kwargs: accepted for call-site compatibility; not used here.

        Raises:
            RuntimeError: if the dataset folder or one of the three image
                folders does not exist.
        """
        super(VeRi, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.query_dir = osp.join(self.dataset_dir, 'image_query')
        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')

        # Keypoint annotation files are optional: a missing file is recorded
        # as None, which makes _process_dir emit the all -1 placeholder.
        keypoint_train_file = osp.join(self.dataset_dir, 'keypoint_train.txt')
        self.keypoint_train_file = keypoint_train_file if osp.exists(keypoint_train_file) else None
        keypoint_test_file = osp.join(self.dataset_dir, 'keypoint_test.txt')
        self.keypoint_test_file = keypoint_test_file if osp.exists(keypoint_test_file) else None

        self._check_before_run()

        # Only the training split is relabelled to contiguous ids; the query
        # split is processed without any keypoint file.
        train = self._process_dir(self.train_dir, relabel=True, keypoint=self.keypoint_train_file)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False, keypoint=self.keypoint_test_file)

        if verbose:
            print("=> VeRi-776 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper.

        Raises:
            RuntimeError: naming the first required directory that is missing.
        """
        required_dirs = (self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir)
        for path in required_dirs:
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    @staticmethod
    def _load_keypoints(keypoint_path):
        """Read a keypoint annotation file into ``{image file name: [str, ...]}``.

        Each line is split on single spaces. The first field is an image path
        of which only the part after the last ``/`` is kept as the key; the
        remaining fields except the last one are stored as raw strings.
        NOTE(review): the dropped final field is presumably a non-coordinate
        label - confirm against the annotation format.
        """
        # The context manager guarantees the handle is closed even if reading fails.
        with open(keypoint_path, "r") as fp:
            lines = fp.read().split('\n')

        keypoints = {}
        for raw_line in lines:
            fields = raw_line.split(' ')
            keypoints[fields[0].split('/')[-1]] = fields[1:-1]
        return keypoints

    def _process_dir(self, dir_path, relabel=False, keypoint=None):
        """Build the sample list for one image directory.

        Args:
            dir_path: directory whose ``*.jpg`` files are indexed (sorted by path).
            relabel: when true, map the raw identity ids found in this
                directory to contiguous labels ``0..N-1`` in ascending id order.
            keypoint: path to a keypoint annotation file, or None.

        Returns:
            A list of ``(img_path, pid, kpt, camid, 1)`` tuples. ``camid`` is
            zero-based. ``kpt`` is a list of 40 numbers: annotated values
            greater than zero are divided by the image width (even index) or
            height (odd index); every other entry is -1. The trailing ``1`` is
            a constant.
            NOTE(review): the meaning of the constant is defined by the
            consumers of this list - confirm there.

        Raises:
            AttributeError: if a file name does not match ``<pid>_c<camid>``.
            AssertionError: if a parsed pid or camera id is out of range.
        """
        img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        # Parse each file name once. Matching against the base name (not the
        # whole path) keeps directory components that happen to look like
        # "<digits>_c<digits>" from being mistaken for the identity/camera.
        parsed = []
        for img_path in img_paths:
            img_name = osp.basename(img_path)
            pid, camid = map(int, pattern.search(img_name).groups())
            if pid == -1:
                continue  # junk images are just ignored
            parsed.append((img_path, img_name, pid, camid))

        # Sorting makes the raw-id -> label mapping explicitly deterministic
        # instead of relying on set iteration order.
        unique_pids = sorted({entry[2] for entry in parsed})
        pid2label = {pid: label for label, pid in enumerate(unique_pids)}

        keypoints = self._load_keypoints(keypoint) if keypoint is not None else None

        dataset = []
        for img_path, img_name, pid, camid in parsed:
            assert 0 <= pid <= 776  # pid == 0 means background
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]

            kpt = [-1] * 40
            if keypoints is not None and img_name in keypoints:
                # Only the size is needed; close the image file right away so
                # large splits do not accumulate open file handles.
                with Image.open(img_path) as img:
                    width, height = img.size
                for idx, raw_value in enumerate(keypoints[img_name]):
                    value = int(raw_value)
                    if value > 0:
                        # Even indices are scaled by width, odd ones by height.
                        kpt[idx] = value / width if idx % 2 == 0 else value / height
            dataset.append((img_path, pid, kpt, camid, 1))

        return dataset

