import logging
import os

import numpy as np

from .CountingDataset import CountingDataset
from ..DatasetSplits import DatasetSplits

# Module-named logger; records still propagate to the root logger's handlers.
LOGGER = logging.getLogger(__name__)


class CARPKDataset(CountingDataset):
    """CARPK car-counting dataset.

    Image names are read from ``ImageSets/<split>.txt`` and per-image box
    annotations from ``Annotations/<image stem>.txt``, both under ``self.root``.
    Only the TRAIN and TEST splits are supported.
    """

    def __init__(self, base_dir, plot_dir=None, in_memory=True, split=DatasetSplits.TEST, **kwargs):
        # NOTE(review): the meaning of the positional ``5`` is defined by
        # CountingDataset.__init__ (not visible here) — confirm before changing.
        super().__init__(base_dir, 5, split=split, img_sub_dir='Images', plot_dir=plot_dir, in_memory=in_memory, **kwargs)

    def _get_file_names(self):
        """Return the image file names (``<id>.png``) listed for the current split.

        Raises:
            ValueError: if ``self.split`` is neither TRAIN nor TEST.
        """
        if self.split not in (DatasetSplits.TRAIN, DatasetSplits.TEST):
            raise ValueError(f'Invalid split: {self.split}')
        split_file = os.path.join(self.root, 'ImageSets', f'{self.split.name.lower()}.txt')
        with open(split_file, 'r', encoding='utf-8') as f:
            # strip() instead of chopping the last character: the final line may
            # lack a trailing newline and the file may use CRLF line endings.
            # Blank lines are skipped so they do not become a bogus '.png' entry.
            return [line.strip() + '.png' for line in f if line.strip()]

    def _get_labels(self):
        """Parse the box annotations of every image in ``self.file_names``.

        Populates ``self.annotations`` with, per image file name, a dict holding
        ``'box_examples_coordinates'`` (int array of shape (n, 2, 2), each box
        stored as two (x, y) rows) and ``'points'`` (float array of shape (n, 2),
        the midpoint of each box's two rows).

        Returns:
            dict mapping image file name to its object count (number of boxes).
        """
        self.annotations = {}
        labels = {}
        for im_id in self.file_names:
            ann_path = os.path.join(self.root, 'Annotations', f'{im_id.rsplit(".", 1)[0]}.txt')
            boxes = []
            with open(ann_path, 'r', encoding='utf-8') as f:
                for obj in f:
                    fields = obj.split()
                    if not fields:  # tolerate blank / trailing empty lines
                        continue
                    # Assumes each line starts with corner coordinates
                    # ``x1 y1 x2 y2`` (extra fields ignored) — TODO confirm
                    # against the CARPK annotation spec.
                    boxes.append(np.array(fields[:4], dtype=int).reshape(2, 2))

            # np.stack raises on an empty list, so images without boxes get
            # correctly shaped empty arrays instead.
            if boxes:
                box_arr = np.stack(boxes)
            else:
                box_arr = np.zeros((0, 2, 2), dtype=int)
            # Box centre = midpoint of its two corner rows: (p0 + p1) / 2.
            points = box_arr.mean(axis=1)

            self.annotations[im_id] = {'box_examples_coordinates': box_arr, 'points': points}
            labels[im_id] = len(points)
        return labels

    def __getitem__(self, item):
        """Item access is not implemented for this dataset."""
        raise NotImplementedError(f'{type(self).__name__} does not implement __getitem__')
