# encoding: utf-8
"""
@author:  sherlock
@contact: sherlockliao01@gmail.com
"""

import copy
import logging
import os
import pdb

from tabulate import tabulate
from termcolor import colored

# Module-level logger; ImageDataset.show_train / show_test emit their
# statistics tables through it.
logger = logging.getLogger(__name__)


class Dataset(object):
    """An abstract class representing a Dataset.
    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Records are tuples. ``parse_data`` reads them as
    ``(img_path, gid, pids, camid)``, where ``pids`` is either a list of
    person IDs or the string form of one (e.g. ``"['3', '7']"``).

    Args:
        train (list): training records.
        query (list): query records.
        gallery (list): gallery records.
        transform: transform function (stored, not applied here).
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information.
        **kwargs: accepted and ignored by this class.

    Raises:
        ValueError: if ``mode`` is not 'train', 'query' or 'gallery'.
    """
    _junk_pids = []  # contains useless person IDs, e.g. background, false detections

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        self.train = train
        self.query = query
        self.gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        # Training statistics are only computed eagerly in train mode;
        # combine_all() (re)computes them for any mode.
        if self.mode == 'train':
            self.num_train_pids = self.get_num_pids(self.train)
            self.num_train_gids = self.get_num_gids(self.train)
            self.num_train_cams = self.get_num_cams(self.train)

        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3]).

        ``sum`` starts from ``0``, so ``0 + dataset`` returns the dataset
        itself. Any other left operand is delegated to ``__add__`` when a
        subclass defines one; otherwise ``NotImplemented`` is returned so
        Python raises a ``TypeError`` rather than an ``AttributeError``.
        """
        if other == 0:
            return self
        # This class does not define __add__; only subclasses may.
        add = getattr(self, '__add__', None)
        if add is None:
            return NotImplemented
        return add(other)

    @staticmethod
    def _split_pid_string(text):
        """Splits a stringified pid list such as ``"['3', '7']"`` into ``['3', '7']``.

        Whitespace and matching single/double quotes around each token are
        removed and empty tokens (e.g. from ``"[]"``) are dropped. A string
        without brackets is split on commas as-is.
        """
        start = text.find('[')
        end = text.find(']', start + 1)
        if start != -1 and end != -1:
            text = text[start + 1:end]
        tokens = []
        for token in text.split(','):
            # repr() of a list puts a space after each comma, so strip first.
            token = token.strip()
            if len(token) >= 2 and token[0] == token[-1] and token[0] in '\'"':
                token = token[1:-1]
            if token:
                tokens.append(token)
        return tokens

    def parse_data(self, data):
        """Counts distinct group IDs, person IDs and cameras in ``data``.

        Args:
            data (list): records indexed as ``(img_path(s), gid, pids, camid)``.
                ``pids`` may be a list/tuple of IDs or the string form of a
                list; ``-1`` entries (string or int) are not counted. Any
                other ``pids`` type contributes no person IDs.

        Returns:
            tuple: ``(num_gids, num_pids, num_cams)``.
        """
        gids = set()
        pids = set()
        cams = set()
        for info in data:
            gids.add(info[1])
            raw_pids = info[2]
            if isinstance(raw_pids, str):  # training data stores a stringified list
                candidates = self._split_pid_string(raw_pids)
            elif isinstance(raw_pids, (list, tuple)):
                candidates = raw_pids
            else:
                candidates = ()
            for pid in candidates:
                # Treat -1 as "no person" whether it arrives as str or int.
                if pid != '-1' and pid != -1:
                    pids.add(pid)
            cams.add(info[3])
        return len(gids), len(pids), len(cams)

    def get_num_gids(self, data):
        """Returns the number of distinct group identities in ``data``."""
        return self.parse_data(data)[0]

    def get_num_pids(self, data):
        """Returns the number of distinct person identities in ``data``."""
        return self.parse_data(data)[1]

    def get_num_cams(self, data):
        """Returns the number of distinct cameras in ``data``."""
        return self.parse_data(data)[2]

    def show_summary(self):
        """Shows dataset statistics."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training.

        Query/gallery records whose pid is in ``_junk_pids`` are dropped; the
        rest get their pid and camid prefixed with ``<dataset_name>_test_``
        (presumably to keep them distinct from training IDs). This class does
        not set ``self.dataset_name``; subclasses must provide it. Afterwards
        ``self.train`` holds the combined list and all ``num_train_*``
        statistics are recomputed.
        """
        # NOTE(review): records are unpacked below as 3-tuples
        # (img_path, pid, camid), but parse_data reads 4-tuples
        # (img_path, gid, pids, camid). 4-tuple query/gallery records make the
        # unpacking raise ValueError, and 3-tuple records make parse_data raise
        # IndexError. Confirm the intended record layout before using
        # combineall with non-empty query/gallery.
        combined = copy.deepcopy(self.train)

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = self.dataset_name + "_test_" + str(pid)
                camid = self.dataset_name + "_test_" + str(camid)
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        # Refresh every training statistic, matching what __init__ computes.
        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_gids = self.get_num_gids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str, os.PathLike or list): file path(s).

        Raises:
            RuntimeError: if any path does not exist.
        """
        if isinstance(required_files, (str, os.PathLike)):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))


class ImageDataset(Dataset):
    """A base class representing ImageDataset.
    All other image datasets should subclass it.

    This class does not implement ``__getitem__`` (the inherited one raises
    ``NotImplementedError``); subclasses are expected to. The intended
    contract is to return ``img``, ``pid``, ``camid`` and ``img_path``
    where ``img`` has shape (channel, height, width), so data in each
    batch has shape (batch_size, channel, height, width).
    """

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def _stats_row(self, name, data):
        """Builds one table row: [name, #gids, #pids, #images, #cameras]."""
        num_gids, num_pids, num_cams = self.parse_data(data)
        return [name, num_gids, num_pids, len(data), num_cams]

    def _log_stats_table(self, rows):
        """Renders ``rows`` as a pipe-format table and logs it in cyan."""
        table = tabulate(
            rows,
            tablefmt="pipe",
            headers=['subset', '# gids', '# pids', '# images', '# cameras'],
            numalign="left",
        )
        logger.info("=> Loaded %s in csv format: \n%s",
                    self.__class__.__name__, colored(table, "cyan"))

    def show_train(self):
        """Logs a statistics table for the training split."""
        self._log_stats_table([self._stats_row('train', self.train)])

    def show_test(self):
        """Logs a statistics table for the query and gallery splits."""
        self._log_stats_table([
            self._stats_row('query', self.query),
            self._stats_row('gallery', self.gallery),
        ])
