# Motion Transformer (MTR): https://arxiv.org/abs/2209.13508
# Published at NeurIPS 2022
# Written by Shaoshuai Shi 
# All Rights Reserved

import argparse
import datetime
import glob
import os
import re
import time
from pathlib import Path

import numpy as np
import torch
import sys
sys.path.append('MTR')
from mtr.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
from mtr.models import model as model_utils
from mtr.utils import common_utils


def parse_config(cfg_file='MTR/tools/cfgs/waymo/mtr+100_percent_data.yaml', seed=1024):
    """Load the MTR YAML config into the module-level ``cfg`` and seed NumPy's RNG.

    Args:
        cfg_file: path of the YAML config to load (a relative path is resolved
            against the current working directory). Defaults to the
            100%-data Waymo config previously hard-coded here.
        seed: value passed to ``np.random.seed`` (default 1024, as before).

    Returns:
        The module-level ``cfg`` object after ``cfg_from_yaml_file`` has been
        applied to it (presumably populated in place — see mtr.config).
    """
    cfg_from_yaml_file(cfg_file, cfg)

    # Seed the global NumPy RNG so any NumPy-based randomness is reproducible.
    np.random.seed(seed)

    return cfg


def transform_trajs_to_center_coords(obj_trajs, center_xyz, center_heading, heading_index, rot_vel_index=None):
    """Re-express every object's trajectory in the local frame of each center object.

    One copy of ``obj_trajs`` is made per center object. In that copy the
    position is translated by the center's xyz (or xy), positions are rotated
    by ``-center_heading``, the heading attribute is shifted by
    ``-center_heading``, and (optionally) the velocity pair is rotated too.

    Args:
        obj_trajs (num_objects, num_timestamps, num_attrs): torch tensor whose
            first attributes are [x, y, z] or [x, y].
        center_xyz (num_center_objects, 3 or 2): center positions.
        center_heading (num_center_objects): center headings in radians.
        heading_index: position of the heading angle on the attribute axis.
        rot_vel_index: optional pair of attribute indices holding (vx, vy).

    Returns:
        Tensor of shape (num_center_objects, num_objects, num_timestamps, num_attrs).
    """
    n_obj, n_t, n_attr = obj_trajs.shape
    assert center_xyz.shape[0] == center_heading.shape[0]
    assert center_xyz.shape[1] in [3, 2]
    n_center = center_xyz.shape[0]
    xyz_dim = center_xyz.shape[1]
    neg_heading = -center_heading

    def _rotate(pairs):
        # Rotate a (n_center, n_obj, n_t, 2) slice into each center's frame.
        flat = pairs.view(n_center, -1, 2)
        turned = common_utils.rotate_points_along_z(points=flat, angle=neg_heading)
        return turned.view(n_center, n_obj, n_t, 2)

    # One independent copy of all trajectories per center object.
    centered = obj_trajs.clone().view(1, n_obj, n_t, n_attr).repeat(n_center, 1, 1, 1)
    centered[..., :xyz_dim] -= center_xyz[:, None, None, :]
    centered[..., 0:2] = _rotate(centered[..., 0:2])
    centered[..., heading_index] -= center_heading[:, None, None]

    if rot_vel_index is not None:
        assert len(rot_vel_index) == 2
        centered[..., rot_vel_index] = _rotate(centered[..., rot_vel_index])

    return centered


def generate_centered_trajs_for_agents(center_objects, obj_trajs_past, obj_types, center_indices, sdc_index, timestamps, obj_trajs_future, device=None):
    """Build per-center agent-history features expressed in each center object's frame.

    Args:
        center_objects (num_center_objects, 10) numpy array:
            [cx, cy, cz, dx, dy, dz, heading, vel_x, vel_y, valid]
        obj_trajs_past (num_objects, num_timestamps, 10) numpy array, same layout.
        obj_types (num_objects,) numpy array of type strings.
        center_indices (num_center_objects,): index of each center object in obj_trajs_past.
        sdc_index: unused; the self-driving-car one-hot channel (4) stays zero.
        timestamps (num_timestamps,) numpy array; written into the last
            time-embedding channel.
        obj_trajs_future: unused; kept for signature compatibility.
        device: torch device for the returned tensors. ``None`` (default)
            selects ``'cuda'`` when available, otherwise ``'cpu'``.

    Returns:
        ret_obj_trajs (num_center_objects, num_objects, num_timestamps, num_timestamps + 18):
            [x, y, z, dx, dy, dz] (6) + one-hot type/center mask (5)
            + time embedding (num_timestamps + 1) + [sin, cos] heading (2)
            + [vel_x, vel_y] (2) + [acc_x, acc_y] (2); all zero where invalid.
        ret_obj_valid_mask (num_center_objects, num_objects, num_timestamps): bool,
            True where the valid attribute is > 0.
    """
    assert obj_trajs_past.shape[-1] == 10
    assert center_objects.shape[-1] == 10
    num_center_objects = center_objects.shape[0]
    num_objects, num_timestamps, _ = obj_trajs_past.shape
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Work on CPU float tensors.
    center_objects = torch.from_numpy(center_objects).float()
    obj_trajs_past = torch.from_numpy(obj_trajs_past).float()
    timestamps = torch.from_numpy(timestamps)

    # Transform coordinates into each center object's frame.
    obj_trajs = transform_trajs_to_center_coords(
        obj_trajs=obj_trajs_past,
        center_xyz=center_objects[:, 0:3],
        center_heading=center_objects[:, 6],
        heading_index=6, rot_vel_index=[7, 8]
    )

    # One-hot channels: 0 vehicle, 1 pedestrian, 2 cyclist, 3 "is center", 4 SDC (unused).
    object_onehot_mask = torch.zeros((num_center_objects, num_objects, num_timestamps, 5))
    object_onehot_mask[:, obj_types == 'TYPE_VEHICLE', :, 0] = 1
    # NOTE(review): 'TYPE_PEDESTRAIN' is misspelled (flagged by the original TODO), so a
    # correctly spelled 'TYPE_PEDESTRIAN' never sets this bit. Left as-is because the
    # model's expected inputs may depend on it — confirm before changing.
    object_onehot_mask[:, obj_types == 'TYPE_PEDESTRAIN', :, 1] = 1
    object_onehot_mask[:, obj_types == 'TYPE_CYCLIST', :, 2] = 1
    object_onehot_mask[torch.arange(num_center_objects), center_indices, :, 3] = 1

    # Time embedding: one-hot step index plus the raw timestamp in the last channel.
    step_idx = torch.arange(num_timestamps)
    object_time_embedding = torch.zeros((num_center_objects, num_objects, num_timestamps, num_timestamps + 1))
    object_time_embedding[:, :, step_idx, step_idx] = 1
    object_time_embedding[:, :, step_idx, -1] = timestamps

    # Use torch ops: np.sin/np.cos on tensors round-trip through NumPy and fail on GPU tensors.
    object_heading_embedding = torch.zeros((num_center_objects, num_objects, num_timestamps, 2))
    object_heading_embedding[:, :, :, 0] = torch.sin(obj_trajs[:, :, :, 6])
    object_heading_embedding[:, :, :, 1] = torch.cos(obj_trajs[:, :, :, 6])

    # Acceleration by backward difference of velocity, assuming a fixed 0.1 s step.
    vel = obj_trajs[:, :, :, 7:9]  # (num_center_objects, num_objects, num_timestamps, 2)
    vel_pre = torch.roll(vel, shifts=1, dims=2)
    acce = (vel - vel_pre) / 0.1
    if num_timestamps > 1:
        # Step 0 has no predecessor (roll wrapped around); reuse step 1's value.
        acce[:, :, 0, :] = acce[:, :, 1, :]

    ret_obj_trajs = torch.cat((
        obj_trajs[:, :, :, 0:6],
        object_onehot_mask,
        object_time_embedding,
        object_heading_embedding,
        obj_trajs[:, :, :, 7:9],
        acce,
    ), dim=-1)

    # The last attribute is the validity flag; zero every feature at invalid steps.
    ret_obj_valid_mask = obj_trajs[:, :, :, -1]  # (num_center_objects, num_objects, num_timestamps)
    ret_obj_trajs[ret_obj_valid_mask == 0] = 0

    return ret_obj_trajs.to(device), (ret_obj_valid_mask > 0).to(device)

def compute_heading_xy(xy: np.ndarray) -> np.ndarray:
    """Estimate the heading (yaw, radians) at every point of a 2-D trajectory.

    Interior points use the direction between their two neighbours (central
    difference); the first and last points use a one-sided difference.

    Args:
        xy: array-like of shape (N, >=2); columns 0 and 1 are x and y.

    Returns:
        float64 array of shape (N,) with values from ``np.arctan2``, i.e. in
        [-pi, pi]. A trajectory with fewer than two points has no direction,
        so zeros are returned instead of raising IndexError. Identical
        consecutive points give ``arctan2(0, 0) == 0``.
    """
    xy = np.asarray(xy)
    x, y = xy[:, 0], xy[:, 1]
    num_points = len(x)

    heading = np.zeros(num_points)
    if num_points < 2:
        return heading

    heading[1:-1] = np.arctan2(y[2:] - y[:-2], x[2:] - x[:-2])
    heading[0] = np.arctan2(y[1] - y[0], x[1] - x[0])
    heading[-1] = np.arctan2(y[-1] - y[-2], x[-1] - x[-2])
    return heading


def compute_speed_xy(xy: np.ndarray, DT=0.1):
    """Estimate per-point velocity of a 2-D trajectory sampled every ``DT`` seconds.

    Interior points use the central difference ``(p[i+1] - p[i-1]) / (2*DT)``;
    the endpoints use a one-sided difference over ``DT``.

    Args:
        xy: array-like of shape (N, >=2); columns 0 and 1 are x and y.
        DT: sampling interval in seconds; must be positive.

    Returns:
        (vx, vy, v): float64 arrays of shape (N,), with ``v = hypot(vx, vy)``.
        Fewer than two points give zero velocity instead of raising IndexError.

    Raises:
        ValueError: if ``DT`` is not positive (would yield inf/nan or
            sign-flipped velocities).
    """
    if DT <= 0:
        raise ValueError(f"DT must be positive, got {DT}")

    xy = np.asarray(xy)
    x, y = xy[:, 0], xy[:, 1]
    num_points = len(x)

    vx = np.zeros(num_points)
    vy = np.zeros(num_points)
    if num_points >= 2:
        vx[1:-1] = (x[2:] - x[:-2]) / (2 * DT)
        vy[1:-1] = (y[2:] - y[:-2]) / (2 * DT)
        vx[0], vy[0] = (x[1] - x[0]) / DT, (y[1] - y[0]) / DT
        vx[-1], vy[-1] = (x[-1] - x[-2]) / DT, (y[-1] - y[-2]) / DT

    v = np.hypot(vx, vy)  # sqrt(vx^2 + vy^2)
    return vx, vy, v


def deal_pred_input(pred_traj_data, dt=0.1, box_size=(4.5, 2.0, 1.8)):
    """Pack one (x, y) trajectory into the arguments of ``generate_centered_trajs_for_agents``.

    The trajectory is treated as a single 'TYPE_VEHICLE' object. It gets a
    fixed box size and z = 0, and every step is marked valid. Heading and
    velocity are estimated from the xy points by finite differences.

    Args:
        pred_traj_data: array of shape (T, >=2) with x, y in columns 0 and 1.
        dt: sampling interval in seconds. It is used for both the velocity
            estimate and the timestamps, so the two stay consistent. The
            default 0.1 equals the previous hard-coded value.
        box_size: (length, width, height) assigned to every step.

    Returns:
        Tuple ``(center_objects (1, 10), obj_trajs_past (1, T, 10),
        obj_types (1,), center_indices (1,) int32, sdc_index None,
        timestamps (T,) float32, obj_trajs_future None)``. The center object
        is the last step of the trajectory. Attribute layout:
        [x, y, z, length, width, height, heading, vx, vy, valid].
    """
    x = pred_traj_data[:, 0]
    y = pred_traj_data[:, 1]
    box_length, box_width, box_height = box_size
    z = np.zeros_like(x)
    length = np.zeros_like(x) + box_length
    width = np.zeros_like(x) + box_width
    height = np.zeros_like(x) + box_height
    heading = compute_heading_xy(pred_traj_data)
    vx, vy, _ = compute_speed_xy(pred_traj_data, DT=dt)
    valid = np.ones_like(x)
    obj_trajs_past = np.stack((x, y, z, length, width, height, heading, vx, vy, valid), axis=-1)

    center_objects = obj_trajs_past[-1]  # the most recent step is the center
    obj_types = np.array('TYPE_VEHICLE').reshape(1)
    center_indices = np.array([0]).astype(np.int32)
    sdc_index = None
    timestamps = (np.arange(len(x)).astype(np.float32) * dt).astype(np.float32)
    obj_trajs_future = None

    return center_objects[None], obj_trajs_past[None], obj_types, center_indices, sdc_index, timestamps, obj_trajs_future


def gt_2_ego(gt_xy, yaw):
    """Express a world-frame (x, y) trajectory in an ego frame anchored at its first point.

    The points are translated so that ``gt_xy[0]`` becomes the origin. They
    are then rotated by ``-yaw``, so the ``yaw`` direction becomes "forward".
    Finally the columns are reordered to ``[right, forward]``: column 0 is the
    lateral offset (positive to the right of the yaw direction) and column 1
    is the forward offset. The original comments call these X and Z.

    Args:
        gt_xy: array-like of shape (T, 2) with world x, y.
        yaw: heading in radians, counter-clockwise from +x, that defines the
            forward axis.

    Returns:
        numpy array of shape (T, 2). It has the same floating dtype as
        ``gt_xy``; integer input is promoted to float64.
    """
    gt_np = np.asarray(gt_xy)
    if not np.issubdtype(gt_np.dtype, np.floating):
        gt_np = gt_np.astype(np.float64)
    gt = torch.from_numpy(gt_np)  # shares memory with the input; not modified below

    # Build the rotation in the trajectory's own dtype. torch.matmul does not
    # promote dtypes, so a fixed float32 matrix crashed on float64 input.
    theta = torch.as_tensor(-yaw, dtype=gt.dtype)
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    R = torch.stack((torch.stack((cos_t, sin_t)), torch.stack((-sin_t, cos_t))))  # [2, 2]

    rel_gt = gt - gt[0]
    gt_local = torch.matmul(rel_gt, R)  # [T, 2]: column 0 = forward (Z), column 1 = left

    # Reorder to [right (X), forward (Z)].
    return torch.stack((-gt_local[:, 1], gt_local[:, 0]), dim=-1).numpy()


def deal_gt_input(gt_traj_data):
    """Pack a ground-truth (x, y) trajectory into the arguments of ``generate_centered_trajs_for_agents``.

    Ground-truth and predicted trajectories are encoded identically. This
    function therefore delegates to ``deal_pred_input`` instead of keeping a
    copy-pasted duplicate that could drift out of sync.

    Args:
        gt_traj_data: array of shape (T, >=2) with x, y in columns 0 and 1.

    Returns:
        Same tuple as ``deal_pred_input``: ``(center_objects (1, 10),
        obj_trajs_past (1, T, 10), obj_types, center_indices, sdc_index None,
        timestamps (T,), obj_trajs_future None)``.
    """
    return deal_pred_input(gt_traj_data)