from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import numpy as np
import torch
from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset


class coco(Dataset):
    """COCO-format detection dataset yielding ``(image, target)`` pairs.

    ``target`` is a dict with keys ``boxes`` (float32 ``[x1, y1, x2, y2]`` rows),
    ``labels`` (int64 indices into ``self.classes``), ``image_id``, ``area`` and
    ``iscrowd`` (int32).

    Args:
        img_dir: Directory that the annotation file's ``file_name`` entries are
            relative to.
        anno_file: Path to a COCO-style JSON annotation file.
        transforms: Optional callable ``(image, target) -> (image, target)``.
        benchmark: ``'BDD100K'`` selects the ten BDD100K detection classes; any
            other value selects a single ``'car'`` class.

    Raises:
        KeyError: (from ``__init__``) if a name in ``self.classes`` has no
            category of that name in the annotation file.
    """

    def __init__(self, img_dir, anno_file, transforms=None, benchmark='BDD100K'):
        # Compare by value: ``is`` only matched interned strings, so a benchmark
        # name built at runtime (e.g. from argparse) fell through to 'car'-only.
        if benchmark == 'BDD100K':
            self.classes = ('__background__', "bike", "bus", "car", "motor", "person",
                            "rider", "traffic light", "traffic sign", "train", "truck")
        else:
            self.classes = ('__background__', 'car',)

        self.img_dir = img_dir
        self.anno_file = anno_file
        self.transforms = transforms

        # load COCO API. It already parses the JSON into ``.dataset``, so reuse that
        # dict instead of reading a (potentially very large) file a second time.
        self.COCO = COCO(self.anno_file)
        self.anno = self.COCO.dataset

        cats = self.COCO.loadCats(self.COCO.getCatIds())

        self.num_classes = len(self.classes)
        # Class name -> contiguous training label (0 is background).
        self._class_to_ind = {cls: ind for ind, cls in enumerate(self.classes)}
        # Category name -> COCO category id, as declared in the annotation file.
        self._class_to_coco_cat_id = {c['name']: c['id'] for c in cats}

        # COCO category id -> training label, restricted to this benchmark's classes.
        self.coco_cat_id_to_class_ind = {
            self._class_to_coco_cat_id[cls]: self._class_to_ind[cls]
            for cls in self.classes[1:]
        }

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.anno['images'])

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th image entry.

        COCO ``[x, y, w, h]`` boxes are converted to ``[x1, y1, x2, y2]`` and
        clipped to ``[0, width - 1] x [0, height - 1]``. Annotations with
        non-positive ``area``, degenerate boxes, or a category outside
        ``self.classes`` are dropped.
        """
        img_info = self.anno['images'][idx]
        image_idx = img_info['id']
        width = img_info['width']
        height = img_info['height']

        img_path = os.path.join(self.img_dir, img_info['file_name'])
        # The context manager closes the file; convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        ann_ids = self.COCO.getAnnIds(imgIds=image_idx, iscrowd=None)
        objs = self.COCO.loadAnns(ann_ids)

        # Sanitize bboxes -- some are invalid. Results are collected locally rather
        # than written back into the annotation dicts shared with the COCO index.
        clean_boxes = []
        labels = []
        iscrowd = []
        for obj in objs:
            cls = self.coco_cat_id_to_class_ind.get(obj['category_id'])
            if cls is None:
                # Category not in this benchmark's class list (e.g. 'car'-only).
                continue
            x, y, w, h = obj['bbox']
            x1 = max(0, x)
            y1 = max(0, y)
            x2 = min(width - 1, x1 + max(0, w - 1))
            y2 = min(height - 1, y1 + max(0, h - 1))
            if obj['area'] > 0 and x2 > x1 and y2 > y1:
                clean_boxes.append([x1, y1, x2, y2])
                labels.append(cls)
                iscrowd.append(int(obj.get('iscrowd', 0)))

        # convert everything into a torch.Tensor; reshape keeps (0, 4) for no boxes.
        image_id = torch.tensor([image_idx])
        boxes = torch.as_tensor(clean_boxes, dtype=torch.float32).reshape(-1, 4)
        gt_classes = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int32)

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {"boxes": boxes, "labels": gt_classes, "image_id": image_id, "area": area, "iscrowd": iscrowd}

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of ``(image, target)`` pairs into ``(images, targets)``."""
        return tuple(zip(*batch))

    @property
    def class_to_coco_cat_id(self):
        """Mapping from category name to COCO category id in the annotation file."""
        return self._class_to_coco_cat_id


