# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""
The matplotlib plotter implementations (2D and 3D) for all the primitive tasks (in our
case: lines, dots, and trails)
"""
from typing import Any, Callable, Dict, List

import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3

import numpy as np

from .core import BasePlotter, BasePlotterTask


class Matplotlib2DPlotter(BasePlotter):
    """Matplotlib backend for 2D plotter tasks.

    Supported task types (dispatched on ``task.task_type``) are ``"Draw2DLines"``,
    ``"Draw2DDots"`` and ``"Draw2DTrail"``. Each type has a ``_*_create_impl``
    method, which builds matplotlib artists and caches them under
    ``task.task_name``. It also has a ``_*_update_impl`` method, which pushes new
    data into those cached artists.

    On every draw, the axis limits are set from the data of tasks whose
    ``influence_lim`` flag is true. They are then expanded to a square centered
    on the origin.
    """

    _fig: plt.Figure  # plt figure
    _ax: plt.Axes  # plt axis
    # stores artist objects for each task (task name as the key)
    _artist_cache: Dict[str, Any]
    # callables for each task primitive, keyed by task_type
    _create_impl_callables: Dict[str, Callable]
    _update_impl_callables: Dict[str, Callable]

    def __init__(self, task: "BasePlotterTask") -> None:
        """Create the figure/axes and register the per-primitive handlers.

        :param task: the task to plot; forwarded unchanged to
            ``BasePlotter.__init__``.
        """
        fig, ax = plt.subplots()
        self._fig = fig
        self._ax = ax
        self._artist_cache = {}

        self._create_impl_callables = {
            "Draw2DLines": self._lines_create_impl,
            "Draw2DDots": self._dots_create_impl,
            "Draw2DTrail": self._trail_create_impl,
        }
        self._update_impl_callables = {
            "Draw2DLines": self._lines_update_impl,
            "Draw2DDots": self._dots_update_impl,
            "Draw2DTrail": self._trail_update_impl,
        }
        self._init_lim()
        # All state above is set up before the base constructor runs.
        # NOTE(review): presumably BasePlotter.__init__ ends up calling
        # _create_impl; confirm in .core.
        super().__init__(task)

    @property
    def ax(self):
        """The matplotlib axes all artists are drawn on."""
        return self._ax

    @property
    def fig(self):
        """The matplotlib figure owning ``ax``."""
        return self._fig

    def show(self):
        """Show the figure via ``plt.show()``."""
        plt.show()

    def _min(self, x, y):
        """Return ``min(x, y)``, treating ``None`` as "no value yet"."""
        if x is None:
            return y
        if y is None:
            return x
        return min(x, y)

    def _max(self, x, y):
        """Return ``max(x, y)``, treating ``None`` as "no value yet"."""
        if x is None:
            return y
        if y is None:
            return x
        return max(x, y)

    def _init_lim(self):
        """Reset the running x/y limit accumulators to "empty"."""
        self._curr_x_min = None
        self._curr_y_min = None
        self._curr_x_max = None
        self._curr_y_max = None

    def _update_lim(self, xs, ys):
        """Grow the running limits so they include all of ``xs`` and ``ys``."""
        self._curr_x_min = self._min(np.min(xs), self._curr_x_min)
        self._curr_y_min = self._min(np.min(ys), self._curr_y_min)
        self._curr_x_max = self._max(np.max(xs), self._curr_x_max)
        self._curr_y_max = self._max(np.max(ys), self._curr_y_max)

    def _set_lim(self):
        """Apply the accumulated limits (if any), then reset the accumulators.

        The limits are only applied when some task contributed data. Resetting
        afterwards means each draw recomputes limits from scratch.
        """
        if not (
            self._curr_x_min is None
            or self._curr_x_max is None
            or self._curr_y_min is None
            or self._curr_y_max is None
        ):
            self._ax.set_xlim(self._curr_x_min, self._curr_x_max)
            self._ax.set_ylim(self._curr_y_min, self._curr_y_max)
        self._init_lim()

    @staticmethod
    def _lines_extract_xy_impl(index, lines_task):
        """Return the (xs, ys) of line ``index``.

        The task is indexed as ``[line, point, coord]``.
        """
        return lines_task[index, :, 0], lines_task[index, :, 1]

    @staticmethod
    def _trail_extract_xy_impl(index, trail_task):
        """Return the (xs, ys) of the segment between points ``index`` and ``index + 1``."""
        return (trail_task[index : index + 2, 0], trail_task[index : index + 2, 1])

    def _lines_create_impl(self, lines_task):
        """Create one line artist per line in ``lines_task`` and cache them as a list."""
        color = lines_task.color
        self._artist_cache[lines_task.task_name] = [
            self._ax.plot(
                *Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task),
                color=color,
                linewidth=lines_task.line_width,
                alpha=lines_task.alpha
            )[0]
            for i in range(len(lines_task))
        ]

    def _lines_update_impl(self, lines_task):
        """Push new coordinates into the cached line artists of ``lines_task``."""
        lines_artists = self._artist_cache[lines_task.task_name]
        for i in range(len(lines_task)):
            artist = lines_artists[i]
            xs, ys = Matplotlib2DPlotter._lines_extract_xy_impl(i, lines_task)
            artist.set_data(xs, ys)
            if lines_task.influence_lim:
                self._update_lim(xs, ys)

    def _dots_create_impl(self, dots_task):
        """Create a single marker-only artist holding every dot of ``dots_task``."""
        color = dots_task.color
        self._artist_cache[dots_task.task_name] = self._ax.plot(
            dots_task[:, 0],
            dots_task[:, 1],
            c=color,
            linestyle="",
            marker=".",
            markersize=dots_task.marker_size,
            alpha=dots_task.alpha,
        )[0]

    def _dots_update_impl(self, dots_task):
        """Push new coordinates into the cached dots artist of ``dots_task``."""
        dots_artist = self._artist_cache[dots_task.task_name]
        dots_artist.set_data(dots_task[:, 0], dots_task[:, 1])
        if dots_task.influence_lim:
            self._update_lim(dots_task[:, 0], dots_task[:, 1])

    def _trail_create_impl(self, trail_task):
        """Create one line artist per consecutive point pair of the trail.

        Segment ``i`` of ``n`` gets alpha ``trail_task.alpha * (1 - i / (n - 1))``.
        The trail therefore fades linearly from full alpha on the first segment
        to zero on the last.
        """
        n_segments = len(trail_task) - 1
        # A 2-point trail has exactly one segment, which would make the fade
        # ratio 0 / 0 (ZeroDivisionError). Clamping the denominator keeps that
        # single segment at full alpha and leaves longer trails unchanged.
        fade_denominator = max(n_segments - 1, 1)
        self._artist_cache[trail_task.task_name] = [
            self._ax.plot(
                *Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task),
                color=trail_task.color,
                linewidth=trail_task.line_width,
                alpha=trail_task.alpha * (1.0 - i / fade_denominator)
            )[0]
            for i in range(n_segments)
        ]

    def _trail_update_impl(self, trail_task):
        """Push new coordinates into the cached segment artists of ``trail_task``."""
        trails_artists = self._artist_cache[trail_task.task_name]
        for i in range(len(trail_task) - 1):
            artist = trails_artists[i]
            xs, ys = Matplotlib2DPlotter._trail_extract_xy_impl(i, trail_task)
            artist.set_data(xs, ys)
            if trail_task.influence_lim:
                self._update_lim(xs, ys)

    def _create_impl(self, task_list):
        """Create artists for every task in ``task_list``, then draw."""
        for task in task_list:
            self._create_impl_callables[task.task_type](task)
        self._draw()

    def _update_impl(self, task_list):
        """Update artists for every task in ``task_list``, then draw."""
        for task in task_list:
            self._update_impl_callables[task.task_type](task)
        self._draw()

    def _set_aspect_equal_2d(self, zero_centered=True):
        """Give the x and y axes the same span.

        Both limits are replaced by ``[mean - r, mean + r]``. Here ``r`` is the
        largest distance from the center to any current limit. The center is
        the origin when ``zero_centered`` is true; otherwise it is the midpoint
        of each axis' current limits.
        """
        xlim = self._ax.get_xlim()
        ylim = self._ax.get_ylim()

        if not zero_centered:
            xmean = np.mean(xlim)
            ymean = np.mean(ylim)
        else:
            xmean = 0
            ymean = 0

        plot_radius = max(
            [
                abs(lim - mean_)
                for lims, mean_ in ((xlim, xmean), (ylim, ymean))
                for lim in lims
            ]
        )

        self._ax.set_xlim([xmean - plot_radius, xmean + plot_radius])
        self._ax.set_ylim([ymean - plot_radius, ymean + plot_radius])

    def _draw(self):
        """Apply limits, square them up, redraw, and let the GUI process events."""
        self._set_lim()
        self._set_aspect_equal_2d()
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()
        # A tiny pause runs the GUI event loop so the window actually refreshes.
        plt.pause(0.00001)


class Matplotlib3DPlotter(BasePlotter):
    """Matplotlib backend for 3D plotter tasks.

    Supported task types (dispatched on ``task.task_type``) are ``"Draw3DLines"``,
    ``"Draw3DDots"`` and ``"Draw3DTrail"``. Each type has a ``_*_create_impl``
    method, which builds matplotlib artists and caches them under
    ``task.task_name``. It also has a ``_*_update_impl`` method, which pushes new
    data into those cached artists.

    On every draw, the axis limits are set from the data of tasks whose
    ``influence_lim`` flag is true. They are then expanded to a cube centered
    on the midpoint of the limits.
    """

    _fig: plt.Figure  # plt figure
    _ax: p3.Axes3D  # plt 3d axis
    # stores artist objects for each task (task name as the key)
    _artist_cache: Dict[str, Any]
    # callables for each task primitive, keyed by task_type
    _create_impl_callables: Dict[str, Callable]
    _update_impl_callables: Dict[str, Callable]

    def __init__(self, task: "BasePlotterTask") -> None:
        """Create the figure and 3D axes, and register the per-primitive handlers.

        :param task: the task to plot; forwarded unchanged to
            ``BasePlotter.__init__``.
        """
        self._fig = plt.figure()
        # Create the 3D axes through the figure. Since matplotlib 3.6,
        # constructing `p3.Axes3D(fig)` directly no longer adds the axes to the
        # figure (deprecated since 3.4), which leaves an empty window. The "3d"
        # projection is registered by importing mpl_toolkits.mplot3d.
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._artist_cache = {}

        self._create_impl_callables = {
            "Draw3DLines": self._lines_create_impl,
            "Draw3DDots": self._dots_create_impl,
            "Draw3DTrail": self._trail_create_impl,
        }
        self._update_impl_callables = {
            "Draw3DLines": self._lines_update_impl,
            "Draw3DDots": self._dots_update_impl,
            "Draw3DTrail": self._trail_update_impl,
        }
        self._init_lim()
        # All state above is set up before the base constructor runs.
        # NOTE(review): presumably BasePlotter.__init__ ends up calling
        # _create_impl; confirm in .core.
        super().__init__(task)

    @property
    def ax(self):
        """The 3D matplotlib axes all artists are drawn on."""
        return self._ax

    @property
    def fig(self):
        """The matplotlib figure owning ``ax``."""
        return self._fig

    def show(self):
        """Show the figure via ``plt.show()``."""
        plt.show()

    def _min(self, x, y):
        """Return ``min(x, y)``, treating ``None`` as "no value yet"."""
        if x is None:
            return y
        if y is None:
            return x
        return min(x, y)

    def _max(self, x, y):
        """Return ``max(x, y)``, treating ``None`` as "no value yet"."""
        if x is None:
            return y
        if y is None:
            return x
        return max(x, y)

    def _init_lim(self):
        """Reset the running x/y/z limit accumulators to "empty"."""
        self._curr_x_min = None
        self._curr_y_min = None
        self._curr_z_min = None
        self._curr_x_max = None
        self._curr_y_max = None
        self._curr_z_max = None

    def _update_lim(self, xs, ys, zs):
        """Grow the running limits so they include all of ``xs``, ``ys`` and ``zs``."""
        self._curr_x_min = self._min(np.min(xs), self._curr_x_min)
        self._curr_y_min = self._min(np.min(ys), self._curr_y_min)
        self._curr_z_min = self._min(np.min(zs), self._curr_z_min)
        self._curr_x_max = self._max(np.max(xs), self._curr_x_max)
        self._curr_y_max = self._max(np.max(ys), self._curr_y_max)
        self._curr_z_max = self._max(np.max(zs), self._curr_z_max)

    def _set_lim(self):
        """Apply the accumulated limits (if any), then reset the accumulators.

        The limits are only applied when some task contributed data. Resetting
        afterwards means each draw recomputes limits from scratch.
        """
        if not (
            self._curr_x_min is None
            or self._curr_x_max is None
            or self._curr_y_min is None
            or self._curr_y_max is None
            or self._curr_z_min is None
            or self._curr_z_max is None
        ):
            self._ax.set_xlim3d(self._curr_x_min, self._curr_x_max)
            self._ax.set_ylim3d(self._curr_y_min, self._curr_y_max)
            self._ax.set_zlim3d(self._curr_z_min, self._curr_z_max)
        self._init_lim()

    @staticmethod
    def _lines_extract_xyz_impl(index, lines_task):
        """Return the (xs, ys, zs) of line ``index``.

        The task is indexed as ``[line, point, coord]``.
        """
        return lines_task[index, :, 0], lines_task[index, :, 1], lines_task[index, :, 2]

    @staticmethod
    def _trail_extract_xyz_impl(index, trail_task):
        """Return the (xs, ys, zs) of the segment between points ``index`` and ``index + 1``."""
        return (
            trail_task[index : index + 2, 0],
            trail_task[index : index + 2, 1],
            trail_task[index : index + 2, 2],
        )

    def _lines_create_impl(self, lines_task):
        """Create one 3D line artist per line in ``lines_task`` and cache them as a list."""
        color = lines_task.color
        self._artist_cache[lines_task.task_name] = [
            self._ax.plot(
                *Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task),
                color=color,
                linewidth=lines_task.line_width,
                alpha=lines_task.alpha
            )[0]
            for i in range(len(lines_task))
        ]

    def _lines_update_impl(self, lines_task):
        """Push new coordinates into the cached line artists of ``lines_task``."""
        lines_artists = self._artist_cache[lines_task.task_name]
        for i in range(len(lines_task)):
            artist = lines_artists[i]
            xs, ys, zs = Matplotlib3DPlotter._lines_extract_xyz_impl(i, lines_task)
            # 3D line artists take x/y via set_data and z separately.
            artist.set_data(xs, ys)
            artist.set_3d_properties(zs)
            if lines_task.influence_lim:
                self._update_lim(xs, ys, zs)

    def _dots_create_impl(self, dots_task):
        """Create a single marker-only 3D artist holding every dot of ``dots_task``."""
        color = dots_task.color
        self._artist_cache[dots_task.task_name] = self._ax.plot(
            dots_task[:, 0],
            dots_task[:, 1],
            dots_task[:, 2],
            c=color,
            linestyle="",
            marker=".",
            markersize=dots_task.marker_size,
            alpha=dots_task.alpha,
        )[0]

    def _dots_update_impl(self, dots_task):
        """Push new coordinates into the cached dots artist of ``dots_task``."""
        dots_artist = self._artist_cache[dots_task.task_name]
        dots_artist.set_data(dots_task[:, 0], dots_task[:, 1])
        dots_artist.set_3d_properties(dots_task[:, 2])
        if dots_task.influence_lim:
            self._update_lim(dots_task[:, 0], dots_task[:, 1], dots_task[:, 2])

    def _trail_create_impl(self, trail_task):
        """Create one 3D line artist per consecutive point pair of the trail.

        Segment ``i`` of ``n`` gets alpha ``trail_task.alpha * (1 - i / (n - 1))``.
        The trail therefore fades linearly from full alpha on the first segment
        to zero on the last.
        """
        n_segments = len(trail_task) - 1
        # A 2-point trail has exactly one segment, which would make the fade
        # ratio 0 / 0 (ZeroDivisionError). Clamping the denominator keeps that
        # single segment at full alpha and leaves longer trails unchanged.
        fade_denominator = max(n_segments - 1, 1)
        self._artist_cache[trail_task.task_name] = [
            self._ax.plot(
                *Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task),
                color=trail_task.color,
                linewidth=trail_task.line_width,
                alpha=trail_task.alpha * (1.0 - i / fade_denominator)
            )[0]
            for i in range(n_segments)
        ]

    def _trail_update_impl(self, trail_task):
        """Push new coordinates into the cached segment artists of ``trail_task``."""
        trails_artists = self._artist_cache[trail_task.task_name]
        for i in range(len(trail_task) - 1):
            artist = trails_artists[i]
            xs, ys, zs = Matplotlib3DPlotter._trail_extract_xyz_impl(i, trail_task)
            artist.set_data(xs, ys)
            artist.set_3d_properties(zs)
            if trail_task.influence_lim:
                self._update_lim(xs, ys, zs)

    def _create_impl(self, task_list):
        """Create artists for every task in ``task_list``, then draw."""
        for task in task_list:
            self._create_impl_callables[task.task_type](task)
        self._draw()

    def _update_impl(self, task_list):
        """Update artists for every task in ``task_list``, then draw."""
        for task in task_list:
            self._update_impl_callables[task.task_type](task)
        self._draw()

    def _set_aspect_equal_3d(self):
        """Give the x, y and z axes the same span around their own midpoints.

        Each axis' limits are replaced by ``[mean - r, mean + r]``. Here ``mean``
        is that axis' current midpoint, and ``r`` is the largest half-range
        across all three axes.
        """
        xlim = self._ax.get_xlim3d()
        ylim = self._ax.get_ylim3d()
        zlim = self._ax.get_zlim3d()

        xmean = np.mean(xlim)
        ymean = np.mean(ylim)
        zmean = np.mean(zlim)

        plot_radius = max(
            [
                abs(lim - mean_)
                for lims, mean_ in ((xlim, xmean), (ylim, ymean), (zlim, zmean))
                for lim in lims
            ]
        )

        self._ax.set_xlim3d([xmean - plot_radius, xmean + plot_radius])
        self._ax.set_ylim3d([ymean - plot_radius, ymean + plot_radius])
        self._ax.set_zlim3d([zmean - plot_radius, zmean + plot_radius])

    def _draw(self):
        """Apply limits, cube them up, redraw, and let the GUI process events."""
        self._set_lim()
        self._set_aspect_equal_3d()
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()
        # A tiny pause runs the GUI event loop so the window actually refreshes.
        plt.pause(0.00001)
