# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import random
import os
import json_tricks as json
from collections import OrderedDict

import cv2
import numpy as np
import torch
from scipy.io import loadmat, savemat
from skimage import io

from utils.transforms import get_affine_transform
from utils.transforms import affine_transform
from utils.transforms import fliplr_joints

from dataset.JointsDataset import JointsDataset


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class PennActionDatasetSeg(JointsDataset):
    """Penn Action frames paired with enlarged bounding boxes.

    Layout read from ``root`` by ``_get_db``::

        frames/<video:04d>/<frame:06d>.jpg
        labels/<video:04d>.mat      (entries ``nframes`` and ``bbox``)

    Video and frame numbers start at 1 on disk.

    NOTE(review): the base-class ``__init__`` is not called here; its
    signature is not visible from this file -- confirm that is intended.
    """

    # Pixels added on every side of the annotated box in __getitem__.
    BBOX_MARGIN = 50

    def __init__(self, root):
        """Index every frame of every video found under ``root``.

        Args:
            root: directory holding the ``frames`` and ``labels`` folders.
        """
        self.root = root
        self.num_videos = 2326

        self.db = self._get_db()
        self.length = len(self.db)  # the real dataset length used in ttt

        logger.info('=> load %d samples', self.length)

    def _get_db(self):
        """Build the per-frame record list from the label files.

        Returns:
            A list of dicts, one per frame, with keys ``image`` (path to the
            jpg), ``bbox`` (float64 array), ``frame_i`` (0-based index within
            the video) and ``is_last_frame`` (bool).
        """
        frame_path = os.path.join(self.root, 'frames')
        label_path = os.path.join(self.root, 'labels')

        gt_db = []
        for video_idx in range(self.num_videos):
            video_name = '{:04d}'.format(video_idx + 1)
            label = loadmat(os.path.join(label_path, video_name + '.mat'))

            nframes = label['nframes'].item()
            boxes = label['bbox']
            num_boxes = boxes.shape[0]

            for j in range(nframes):
                # there are mistakes in dataset, two videos are missing bbox
                # in last frame: fall back to the last available box.
                row = min(j, num_boxes - 1)
                gt_db.append(
                    {
                        'image': os.path.join(
                            frame_path, video_name,
                            '{:06d}.jpg'.format(j + 1)),
                        # np.float was removed in NumPy 1.24; np.float64 is
                        # the dtype it used to alias.
                        'bbox': boxes[row].astype(np.float64),
                        'frame_i': j,
                        'is_last_frame': j == nframes - 1,
                    }
                )
        return gt_db

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.db)

    def __getitem__(self, idx):
        """Return the image path and the margin-expanded box for one frame.

        The first two box entries are lowered by ``BBOX_MARGIN`` (not below
        0); the last two are raised by ``BBOX_MARGIN`` and capped at the
        image width and height respectively.  Presumably the box is
        ``(x1, y1, x2, y2)`` -- confirm against the label files.

        Args:
            idx: index into the frame database.

        Returns:
            Tuple ``(image_file, bbox)`` where ``bbox`` is a fresh float
            array, so the stored record is left untouched.

        Raises:
            IOError: if the image cannot be read from disk.
        """
        db_rec = self.db[idx]
        image_file = db_rec['image']
        # Copy only the array we are about to modify in place.
        bbox = db_rec['bbox'].copy()

        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('Failed to read image: {}'.format(image_file))
        h, w = image.shape[:2]

        margin = self.BBOX_MARGIN
        bbox[0] = max(bbox[0] - margin, 0)
        bbox[1] = max(bbox[1] - margin, 0)
        bbox[2] = min(bbox[2] + margin, w)
        bbox[3] = min(bbox[3] + margin, h)

        return image_file, bbox
