import logging
import os
import json_tricks as json
from collections import OrderedDict

import numpy as np

from dataset.KeypointDataset import KeypointDataset
from utils.utils import save_object

logger = logging.getLogger(__name__)


class CategoryDataset(KeypointDataset):
    """
    Class for birds full body and body parts dataset.

    Joint names come from ``cfg.DATASET.LEGEND``; annotations are read from
    ``<root>/annot/<image_set>.json``, a list of records with at least the
    keys ``image``, ``joints`` and ``joints_vis`` (optionally
    ``bbox_max_side`` and, when ``cfg.DATASET.HEADBOXES`` is set,
    ``headbox``).
    """

    def __init__(self, cfg, root, image_set, is_train, transform=None):
        super().__init__(cfg, root, image_set, is_train, transform)

        legend = cfg.DATASET.LEGEND
        assert isinstance(legend, list)
        assert len(legend) > 0

        self.dataset_joints = np.array(legend)
        self.num_joints = len(legend)

        # Parse the annotation file once and reuse it both for the training
        # db and for the ground truth used by evaluate().
        annotations = self._load_annotations()
        self.db = self._get_db(annotations)
        self.use_headboxes = cfg.DATASET.HEADBOXES
        # (sic) attribute name kept as-is; other code may read it.
        self.threshhold = cfg.TEST.PRED_THRESHOLD

        # Per-joint visibility, arranged as (num_joints, num_samples).
        self.jnt_visible = np.array(
            [elem['joints_vis'] for elem in annotations]
        ).swapaxes(0, 1)

        # Raw (un-shifted) annotated coordinates, arranged as
        # (num_joints, coord_dims, num_samples).
        self.pos_gt_src = np.array(
            [elem['joints'] for elem in annotations]
        ).transpose((1, 2, 0))

        # Normalisation size per sample; falls back to the largest model
        # input side when the record has no 'bbox_max_side'.
        self.bbox_size = np.array([
            elem['bbox_max_side'] if 'bbox_max_side' in elem
            else max(cfg.MODEL.IMAGE_SIZE)
            for elem in annotations
        ])

        self.headboxes = []
        if self.use_headboxes:
            # presumably each headbox is [[x1, y1], [x2, y2]] — TODO confirm
            self.headboxes = np.array(
                [elem['headbox'] for elem in annotations]
            ).transpose((1, 2, 0))

        logger.info('=> load {} samples'.format(len(self.db)))

    def _load_annotations(self):
        """Return the parsed ``<root>/annot/<image_set>.json`` annotation list."""
        file_name = os.path.join(self.root, 'annot', self.image_set + '.json')
        with open(file_name, encoding='utf-8') as anno_file:
            return json.load(anno_file)

    def _get_db(self, annotations=None):
        """Build the list of sample records consumed by the base dataset.

        Args:
            annotations: already-parsed annotation list; when ``None`` the
                annotation file is loaded from disk.

        Returns:
            A list of dicts with keys ``image``, ``joints_3d``,
            ``joints_3d_vis``, ``filename`` and ``imgnum``.
        """
        if annotations is None:
            annotations = self._load_annotations()

        image_dir = 'images.zip@' if self.data_format == 'zip' else 'images'

        gt_db = []
        for a in annotations:
            joints = np.array(a['joints'])
            joints_vis = np.array(a['joints_vis'])
            assert len(joints) == self.num_joints, \
                'joint num diff: {} vs {}'.format(len(joints),
                                                  self.num_joints)

            # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float64)
            joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float64)

            # Shift x/y by -1 (annotations presumably 1-based, MPII-style).
            # evaluate() undoes this shift before comparing to pos_gt_src.
            joints_3d[:, 0:2] = joints[:, 0:2] - 1
            joints_3d_vis[:, 0] = joints_vis
            joints_3d_vis[:, 1] = joints_vis

            gt_db.append(
                {
                    'image': os.path.join(self.root, image_dir, a['image']),
                    'joints_3d': joints_3d,
                    'joints_3d_vis': joints_3d_vis,
                    'filename': '',
                    'imgnum': 0,
                }
            )

        return gt_db

    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
        """Compute per-joint PCK (percentage of correct keypoints).

        A visible joint counts as correct when its prediction error, divided
        by the sample's normalisation size, is ``<= self.threshhold``. The
        size is ``0.6 * ||headbox[1] - headbox[0]||`` when headboxes are
        enabled, otherwise the sample's ``bbox_size``.

        Args:
            cfg: experiment config (unused here, kept for the common API).
            preds: predictions shaped (num_samples, num_joints, >=2), in the
                same coordinate frame as ``self.db`` (i.e. shifted by -1).
            output_dir: when truthy, the x/y predictions are pickled to
                ``<output_dir>/pred.pkl``.

        Returns:
            ``(name_value, mean)`` where ``name_value`` is an OrderedDict of
            joint name -> PCK (NaN for joints never visible) plus ``'Mean'``,
            the visibility-weighted mean over joints that were visible.
        """
        preds = preds[:, :, 0:2]

        if output_dir:
            pred_file = os.path.join(output_dir, 'pred.pkl')
            save_object(preds, pred_file)

        # _get_db shifted the training targets by -1, so predictions live in
        # that shifted frame while pos_gt_src keeps the raw annotation values.
        # Undo the shift so both are compared in the same frame.
        pos_pred_src = np.transpose(preds + 1.0, [1, 2, 0])

        uv_error = pos_pred_src - self.pos_gt_src
        uv_err = np.linalg.norm(uv_error, axis=1)  # (num_joints, num_samples)

        if self.use_headboxes:
            headsizes = self.headboxes[1, :, :] - self.headboxes[0, :, :]
            headsizes = np.linalg.norm(headsizes, axis=0)
            headsizes *= 0.6
            scale = np.multiply(headsizes, np.ones((len(uv_err), 1)))
        else:
            scale = np.multiply(self.bbox_size, np.ones((len(uv_err), 1)))

        scaled_uv_err = np.divide(uv_err, scale)
        scaled_uv_err = np.multiply(scaled_uv_err, self.jnt_visible)
        jnt_count = np.sum(self.jnt_visible, axis=1)
        less_than_threshold = np.multiply(
            scaled_uv_err <= self.threshhold, self.jnt_visible)
        hits = np.sum(less_than_threshold, axis=1)

        # A joint never visible has no defined PCK: report NaN for it and keep
        # it out of the mean instead of letting 0/0 turn 'Mean' into NaN.
        valid = jnt_count > 0
        pck = np.full(jnt_count.shape, np.nan, dtype=np.float64)
        np.divide(100. * hits, jnt_count, out=pck, where=valid)

        total = np.sum(jnt_count)
        if total > 0:
            jnt_ratio = jnt_count / np.float64(total)
        else:
            jnt_ratio = np.zeros(jnt_count.shape, dtype=np.float64)
        mean = np.sum(pck[valid] * jnt_ratio[valid])

        name_value = [(kp, pck[i]) for i, kp in enumerate(self.dataset_joints)]
        name_value.append(('Mean', mean))
        name_value = OrderedDict(name_value)

        return name_value, name_value['Mean']
