#!/usr/bin/env python

"""
Trajectory visualization and lane-query helpers for Argoverse data: a modified
visualize_trajectory function, get_lanes, and related lane utilities.
"""

import argparse
import os
import shutil
import sys
from collections import defaultdict
from typing import Dict, Optional

import matplotlib.animation as anim
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.interpolate as interp

from argoverse.map_representation.map_api import ArgoverseMap

# Matplotlib draw order per OBJECT_TYPE when plotting tracks: higher values are
# drawn on top, so the AGENT track overlays the AV, which overlays OTHERS.
# (Lane centerlines are drawn at zorder=0, beneath all tracks.)
_ZORDER = {"AGENT": 15, "AV": 10, "OTHERS": 5}


def get_batch_lane_direction(pos, city, am):
    """Look up lane direction and confidence for several position sequences.

    Args:
        pos: Iterable of position sequences; only the first two components of
            each position are passed to ``am.get_lane_direction``.
        city: Iterable of city names, paired element-wise with ``pos``
            (``zip`` stops at the shorter of the two).
        am: Map API object providing ``get_lane_direction(xy, city)``, or
            ``None`` to construct a new ``ArgoverseMap`` (a single instance is
            then shared by every query in this call).

    Returns:
        A list with one ``np.ndarray`` per (positions, city) pair. Each row is
        ``np.append(*am.get_lane_direction(...))``. The map call is assumed to
        return a ``(direction, confidence)`` pair, so a row is the direction
        with its confidence appended.
    """
    if am is None:
        # ArgoverseMap is already imported at module level; no local import needed.
        am = ArgoverseMap()

    return [
        np.array([np.append(*am.get_lane_direction(p[:2], c)) for p in ps])
        for ps, c in zip(pos, city)
    ]


def get_lane_direction(pos, city, am):
    """Look up lane direction and confidence for each position in one city.

    Args:
        pos: Sequence of positions; only the first two components of each
            position are passed to ``am.get_lane_direction``.
        city: City name used for every query.
        am: Map API object providing ``get_lane_direction(xy, city)``, or
            ``None`` to construct a new ``ArgoverseMap``.

    Returns:
        ``np.ndarray`` with one row per position, each row being
        ``np.append(*am.get_lane_direction(...))``. The map call is assumed to
        return a ``(direction, confidence)`` pair, so a row is the direction
        with its confidence appended.
    """
    if am is None:
        # ArgoverseMap is already imported at module level; no local import needed.
        am = ArgoverseMap()

    return np.array([np.append(*am.get_lane_direction(p[:2], city)) for p in pos])


def interpolate_polyline(polyline: np.ndarray, num_points: int) -> np.ndarray:
    """Resample ``polyline`` along an interpolating parametric spline.

    Each point that is ``np.allclose`` to its immediate predecessor is dropped
    before fitting. If fewer than four points would remain, the input is
    returned unchanged (duplicates included) and ``num_points`` is ignored.
    Four is the minimum for ``splprep``'s default cubic spline (it requires
    more than ``k=3`` points).

    Args:
        polyline: Array of points, one per row.
        num_points: Number of evenly spaced spline parameter values in
            ``[0, 1]`` at which to evaluate the fitted spline.

    Returns:
        Array with ``num_points`` rows (one column per coordinate), or the
        original ``polyline`` when there are too few distinct points.
    """
    repeated = [
        idx
        for idx in range(1, len(polyline))
        if np.allclose(polyline[idx], polyline[idx - 1])
    ]
    if polyline.shape[0] - len(repeated) < 4:
        return polyline

    distinct = np.delete(polyline, repeated, axis=0) if repeated else polyline

    # s=0 forces the spline to pass through every remaining point.
    spline, _ = interp.splprep(distinct.T, s=0)
    samples = np.linspace(0.0, 1.0, num_points)
    coords = interp.splev(samples, spline)
    return np.column_stack(coords)


def get_lanes(df: pd.DataFrame, city_name: str, avm: Optional[ArgoverseMap] = None) -> list:
    """Return the lane centerlines whose extent overlaps the trajectories in ``df``.

    Args:
        df: Trajectory table; only its ``"X"`` and ``"Y"`` columns are read.
        city_name: Key into ``avm.city_lane_centerlines_dict``.
        avm: Map API instance; a new ``ArgoverseMap`` is built when ``None``.

    Returns:
        List of ``lane.centerline`` arrays whose axis-aligned X/Y extent
        strictly overlaps the X/Y bounding box of ``df``. Lanes that only touch
        an edge of the box are excluded.
    """
    avm = ArgoverseMap() if avm is None else avm
    lanes = avm.city_lane_centerlines_dict[city_name]

    # Bounding box of every observed position in the sequence (NaNs skipped).
    x_min, x_max = df["X"].min(), df["X"].max()
    y_min, y_max = df["Y"].min(), df["Y"].max()

    lane_centerlines = []
    for lane in lanes.values():
        cl = lane.centerline
        # Column 0 is X and column 1 is Y of the centerline points.
        if (
            cl[:, 0].min() < x_max
            and cl[:, 1].min() < y_max
            and cl[:, 0].max() > x_min
            and cl[:, 1].max() > y_min
        ):
            lane_centerlines.append(cl)

    return lane_centerlines


def get_all_lanes(city_name: str, avm: Optional[ArgoverseMap] = None) -> list:
    """Return the centerline of every lane in a city.

    Args:
        city_name: Key into ``avm.city_lane_centerlines_dict``.
        avm: Map API instance; a new ``ArgoverseMap`` is built when ``None``.

    Returns:
        List of ``lane.centerline`` values, in the dictionary's iteration order.
    """
    avm = ArgoverseMap() if avm is None else avm
    lanes = avm.city_lane_centerlines_dict[city_name]
    return [lane.centerline for lane in lanes.values()]

    
def visualize_trajectory(
    df: pd.DataFrame, lane_centerlines: Optional[list] = None, show: bool = True, smoothen: bool = False
) -> None:
    """Plot every track of a sequence on top of lane centerlines.

    Drawing goes to matplotlib figure 0. ``plt.figure(0, ...)`` reuses that
    figure if it already exists, and nothing here clears it.

    Args:
        df: Sequence table with ``"CITY_NAME"``, ``"TRACK_ID"``,
            ``"OBJECT_TYPE"``, ``"X"`` and ``"Y"`` columns. ``OBJECT_TYPE``
            must be one of ``"AGENT"``, ``"AV"`` or ``"OTHERS"`` (other values
            raise ``KeyError``).
        lane_centerlines: Iterable of centerline arrays (column 0 = X,
            column 1 = Y) to draw as dashed grey lines. When ``None``, they are
            fetched with ``get_lanes`` for the city of the first row.
        show: Call ``plt.show()`` after drawing.
        smoothen: Pass each track through ``interpolate_polyline``, requesting
            three times its original point count, before plotting.
    """
    city_name = df["CITY_NAME"].values[0]
    if lane_centerlines is None:
        lane_centerlines = get_lanes(df, city_name)

    plt.figure(0, figsize=(8, 7))

    # Clip the view to the extent of the observed trajectories. The limits are
    # set before anything is plotted, so later plots do not autoscale them.
    plt.xlim(df["X"].min(), df["X"].max())
    plt.ylim(df["Y"].min(), df["Y"].max())

    for lane_cl in lane_centerlines:
        plt.plot(lane_cl[:, 0], lane_cl[:, 1], "--", color="grey", alpha=1, linewidth=1, zorder=0)

    plt.xlabel("Map X")
    plt.ylabel("Map Y")

    color_dict = {"AGENT": "#d33e4c", "OTHERS": "#13d4f2", "AV": "#007672"}
    # Number of tracks already drawn per object type. Only the first track of
    # each type is given a legend label.
    object_type_tracker: Dict[str, int] = defaultdict(int)

    for _, group_data in df.groupby("TRACK_ID"):
        object_type = group_data["OBJECT_TYPE"].values[0]
        color = color_dict[object_type]
        zorder = _ZORDER[object_type]
        label = object_type if not object_type_tracker[object_type] else ""

        cor_x = group_data["X"].values
        cor_y = group_data["Y"].values

        if smoothen:
            polyline = np.column_stack((cor_x, cor_y))
            smooth_polyline = interpolate_polyline(polyline, cor_x.shape[0] * 3)
            cor_x, cor_y = smooth_polyline[:, 0], smooth_polyline[:, 1]

        # Whole track as a solid line.
        plt.plot(cor_x, cor_y, "-", color=color, label=label, alpha=1, linewidth=1, zorder=zorder)

        # Circle marker on the track's last point (same style for every object type).
        plt.plot(cor_x[-1], cor_y[-1], "o", color=color, label=label, alpha=1, markersize=7, zorder=zorder)

        object_type_tracker[object_type] += 1

    plt.axis("off")

    if show:
        plt.show()

        


