from __future__ import absolute_import, division

import numpy as np

from common.camera import world_to_camera, normalize_screen_coordinates


def create_2d_data(data_path, dataset):
    """Load 2D keypoints from ``data_path`` and normalize them per camera.

    The file is read with ``np.load(..., allow_pickle=True)``. Its
    ``'positions_2d'`` entry is unpacked with ``.item()`` into a nested mapping
    ``keypoints[subject][action] -> list of per-camera arrays``.

    For every array, ``kps[..., :2]`` is overwritten in place with the output
    of ``normalize_screen_coordinates``. That call uses the camera's
    ``'res_w'``/``'res_h'`` from ``dataset.cameras()[subject][cam_idx]``.
    When the string ``'gt'`` occurs anywhere in the path, joint index 9 is
    dropped from every array.

    NOTE(review): ``allow_pickle=True`` unpickles data from the file; only
    load trusted files.

    Args:
        data_path: path of the keypoint file, presumably an ``.npz`` archive
            (str or path-like).
        dataset: object whose ``cameras()`` returns
            ``cameras[subject][cam_idx]`` dicts with ``'res_w'`` and ``'res_h'``.

    Returns:
        The nested keypoint mapping, with normalized arrays that are also
        joint-reduced for ``'gt'`` paths.
    """
    keypoints = np.load(data_path, allow_pickle=True)
    keypoints = keypoints['positions_2d'].item()

    # Loop invariants: computed once instead of per camera array.
    cameras = dataset.cameras()
    # str() so path-like objects (accepted by np.load) also work here.
    # NOTE(review): substring test on the whole path, so a directory name
    # such as "length" also matches -- confirm this is intended.
    drop_joint = 'gt' in str(data_path)
    # Remove index 9 (the 10th joint, labelled Neck/Nose in the original note).
    # Assumes kps is (frames, joints, coords) with >= 17 joints -- TODO confirm.
    joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16]

    for subject in keypoints.keys():
        for action in keypoints[subject]:
            for cam_idx, kps in enumerate(keypoints[subject][action]):
                # Normalize camera frame
                cam = cameras[subject][cam_idx]
                kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h'])
                if drop_joint:
                    kps = kps[:, joints, :]
                keypoints[subject][action][cam_idx] = kps

    return keypoints


def read_3d_data(dataset):
    """Attach per-camera 3D joint positions to every action of ``dataset``.

    For each subject/action, ``anim['positions']`` is passed through
    ``world_to_camera``, once per entry of ``anim['cameras']``, using that
    camera's ``'orientation'`` (R) and ``'translation'`` (t). The resulting
    list, one array per camera in ``anim['cameras']`` order, is stored under
    ``anim['positions_3d']``. ``dataset`` is modified in place and also
    returned.
    """
    for subj in dataset.subjects():
        subject_data = dataset[subj]
        for act_name in subject_data.keys():
            anim = subject_data[act_name]
            # Centering on the first joint (pos[:, :1]) is intentionally not
            # applied here; it is deferred to model training.
            anim['positions_3d'] = [
                world_to_camera(anim['positions'], R=view['orientation'], t=view['translation'])
                for view in anim['cameras']
            ]

    return dataset


def fetch(subjects, dataset, keypoints, action_filter=None, stride=1, parse_3d_poses=True):
    """Collect per-camera 2D/3D pose sequences for the given subjects.

    Args:
        subjects: iterable of subject keys.
        dataset: mapping where ``dataset[subject][action]`` may hold
            ``'positions_3d'`` (one array per camera) and ``'cameras'`` (one
            dict per camera with an ``'intrinsic'`` entry).
        keypoints: mapping ``keypoints[subject][action]`` -> sequence of
            per-camera 2D arrays, with frames along axis 0.
        action_filter: optional iterable of action base names. An action is
            kept only if the part of its name before the first space equals
            one of them (e.g. ``'Walking 1'`` matches ``'Walking'``).
        stride: keep every ``stride``-th frame of every returned sequence.
        parse_3d_poses: when False, 3D poses and camera intrinsics are skipped.

    Returns:
        ``(poses_3d, poses_2d, actions, cams)``, where each element of each
        list is one (action, camera) sequence:

        - ``poses_2d``: the 2D arrays.
        - ``actions``: per-frame lists of the action base name.
        - ``poses_3d`` and ``cams``: present only for actions that have
          ``'positions_3d'``. ``cams`` holds per-frame lists of that
          camera's intrinsics.
        - ``poses_3d`` is None when no 3D data was collected.
    """
    out_poses_3d = []
    out_poses_2d = []
    out_actions = []
    out_cam = []

    for subject in subjects:
        for action in keypoints[subject].keys():
            action_name = action.split(' ')[0]
            if action_filter is not None and not any(action_name == a for a in action_filter):
                continue

            poses_2d = keypoints[subject][action]
            for seq in poses_2d:  # one sequence per camera
                out_poses_2d.append(seq)
                out_actions.append([action_name] * seq.shape[0])

            if parse_3d_poses and 'positions_3d' in dataset[subject][action]:
                poses_3d = dataset[subject][action]['positions_3d']
                assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
                for i, seq_3d in enumerate(poses_3d):  # one sequence per camera
                    out_poses_3d.append(seq_3d)
                    cam = dataset[subject][action]['cameras'][i]['intrinsic']
                    out_cam.append([cam] * seq_3d.shape[0])

    if not out_poses_3d:
        out_poses_3d = None

    if stride > 1:
        # Downsample every per-frame list, including the camera lists, so that
        # cams[i] stays aligned with poses_3d[i]. Each list is strided on its
        # own because poses_3d/cams may hold fewer sequences than poses_2d
        # (actions without 'positions_3d').
        out_poses_2d = [seq[::stride] for seq in out_poses_2d]
        out_actions = [seq[::stride] for seq in out_actions]
        out_cam = [seq[::stride] for seq in out_cam]
        if out_poses_3d is not None:
            out_poses_3d = [seq[::stride] for seq in out_poses_3d]

    return out_poses_3d, out_poses_2d, out_actions, out_cam

def fetch_tr_val(subjects, dataset, keypoints, action_filter=None, stride=1, parse_3d_poses=True,
                 train_ratio=0.9):
    """Collect per-camera pose sequences, each split into train/validation parts.

    For every (action, camera) sequence with ``n`` frames, the first
    ``int(n * train_ratio)`` frames go to the training lists and the rest to
    the validation lists. A camera's 3D sequence and camera list are split at
    the index computed from that camera's 2D sequence, so the pairs stay
    aligned.

    Args:
        subjects: iterable of subject keys.
        dataset: mapping where ``dataset[subject][action]`` may hold
            ``'positions_3d'`` (one array per camera) and ``'cameras'`` (one
            dict per camera with an ``'intrinsic'`` entry).
        keypoints: mapping ``keypoints[subject][action]`` -> sequence of
            per-camera 2D arrays, with frames along axis 0.
        action_filter: optional iterable of action base names. An action is
            kept only if the part of its name before the first space equals
            one of them.
        stride: keep every ``stride``-th frame of every TRAINING sequence.
            Validation sequences are not downsampled.
        parse_3d_poses: when False, 3D poses and camera intrinsics are skipped.
        train_ratio: fraction of each sequence used for training, in [0, 1].

    Returns:
        An 8-tuple ``(poses_3d, poses_2d, actions, cams, poses_3d_val,
        poses_2d_val, actions_val, cams_val)``:

        - Each list element is one (action, camera) sequence.
        - ``actions*`` and ``cams*`` hold per-frame lists.
        - ``poses_3d`` is None when no 3D data was collected; ``poses_3d_val``
          stays an empty list in that case.

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError('train_ratio must be in [0, 1], got %r' % (train_ratio,))

    out_poses_3d = []
    out_poses_2d = []
    out_actions = []
    out_cam = []

    out_poses_3d_val = []
    out_poses_2d_val = []
    out_actions_val = []
    out_cam_val = []

    for subject in subjects:
        for action in keypoints[subject].keys():
            action_name = action.split(' ')[0]
            if action_filter is not None and not any(action_name == a for a in action_filter):
                continue

            poses_2d = keypoints[subject][action]
            for seq in poses_2d:  # one sequence per camera
                tr_size = int(seq.shape[0] * train_ratio)
                val_size = seq.shape[0] - tr_size

                out_poses_2d.append(seq[:tr_size])
                out_poses_2d_val.append(seq[tr_size:])
                out_actions.append([action_name] * tr_size)
                out_actions_val.append([action_name] * val_size)

            if parse_3d_poses and 'positions_3d' in dataset[subject][action]:
                poses_3d = dataset[subject][action]['positions_3d']
                assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
                for i, seq_3d in enumerate(poses_3d):  # one sequence per camera
                    # Split at the 2D sequence's index so 2D/3D pairs stay aligned.
                    tr_size = int(poses_2d[i].shape[0] * train_ratio)
                    val_size = poses_2d[i].shape[0] - tr_size

                    out_poses_3d.append(seq_3d[:tr_size])
                    out_poses_3d_val.append(seq_3d[tr_size:])
                    cam = dataset[subject][action]['cameras'][i]['intrinsic']
                    out_cam.append([cam] * tr_size)
                    out_cam_val.append([cam] * val_size)

    if not out_poses_3d:
        out_poses_3d = None

    if stride > 1:
        # Downsample every per-frame training list, including the camera lists,
        # so that cams[i] stays aligned with poses_3d[i]. Each list is strided
        # on its own because poses_3d/cams may hold fewer sequences than
        # poses_2d (actions without 'positions_3d').
        out_poses_2d = [seq[::stride] for seq in out_poses_2d]
        out_actions = [seq[::stride] for seq in out_actions]
        out_cam = [seq[::stride] for seq in out_cam]
        if out_poses_3d is not None:
            out_poses_3d = [seq[::stride] for seq in out_poses_3d]

    return out_poses_3d, out_poses_2d, out_actions, out_cam, out_poses_3d_val, out_poses_2d_val, out_actions_val, out_cam_val
