from typing import Tuple, Union

import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import plotting

# Type of the colour values forwarded to matplotlib's ``color=`` keyword:
# either a string (name / hex) or a tuple of floats such as RGB or RGBA.
# ``Tuple[float, ...]`` denotes a variable-length tuple; the previous
# ``Tuple[float]`` meant a tuple of exactly one float.
Color = Union[str, Tuple[float, ...]]
# PGF 1.01 manual reference:
# https://stuff.mit.edu/afs/athena/contrib/tex-contrib/beamer/pgf-1.01/doc/generic/pgf/version-for-tex4ht/en/pgfmanualse21.html


def get_airplane_coordinates() -> np.ndarray:
    """Return the vertices of an airplane outline as a fresh (41, 2) array.

    Each row is an (x, y) vertex, listed in drawing order. The outline spans
    x in [-0.55, 0.45]; the nose sits at x = 0.45, so before any rotation the
    airplane points along +x. The first and last vertices are both at the
    nose but are not identical, i.e. the outline is an open polyline.
    A new array is built on every call, so callers may modify it freely.
    """
    outline = (
        (0.45, -0.000808621),
        (0.45, -0.00282241),
        (0.449083, -0.00667069),
        (0.441203, -0.0125362),
        (0.411877, -0.0271983),
        (0.37082, -0.0369155),
        (0.319134, -0.0418603),
        (0.274049, -0.0438776),
        (0.138235, -0.0438776),
        (-0.218623, -0.401629),
        (-0.229255, -0.406395),
        (-0.25858, -0.404564),
        (-0.271409, -0.399612),
        (-0.141277, -0.192329),
        (-0.0895918, -0.0438776),
        (-0.298717, -0.0447948),
        (-0.390543, -0.033981),
        (-0.501063, -0.168871),
        (-0.54212, -0.168871),
        (-0.508946, -0.0125362),
        (-0.55, -0.00172586),
        (-0.55, 0.00172241),
        (-0.508946, 0.01235),
        (-0.54212, 0.168871),
        (-0.501063, 0.168871),
        (-0.390543, 0.033981),
        (-0.298717, 0.0446121),
        (-0.0895918, 0.0436948),
        (-0.141277, 0.192329),
        (-0.271409, 0.399429),
        (-0.25858, 0.404378),
        (-0.229255, 0.406395),
        (-0.218623, 0.401447),
        (0.138235, 0.0436948),
        (0.274049, 0.0436948),
        (0.319134, 0.0416741),
        (0.37082, 0.0369121),
        (0.411877, 0.0270121),
        (0.441203, 0.01235),
        (0.449083, 0.00648793),
        (0.45, 0.00263966),
    )
    return np.array(outline, dtype=float)


def _list_pgfplotto(ax: matplotlib.axes.Axes,
                    points: np.ndarray,
                    color: Color,
                    linewidth: float = .75):
    """Draw ``points`` on ``ax`` as one connected, open polyline.

    Consecutive vertices are joined by straight segments; the last vertex is
    not joined back to the first.

    Args:
        ax: Axes to draw on.
        points: Array of shape (N, 2) holding (x, y) vertices in drawing
            order.
        color: Colour forwarded to ``Axes.plot``.
        linewidth: Line width forwarded to ``Axes.plot`` (defaults to the
            previously hard-coded 0.75).
    """
    pts = np.asarray(points)
    # With fewer than two vertices there is no segment to draw; the previous
    # per-segment loop drew nothing in that case, so keep that behaviour.
    if pts.shape[0] < 2:
        return
    # One Line2D for the whole outline instead of one artist per segment:
    # far fewer artists, and matplotlib renders real joins at the corners
    # instead of separately drawn segments.
    ax.plot(pts[:, 0], pts[:, 1], color=color, linewidth=linewidth)


def add_airplane_at(ax: matplotlib.axes.Axes,
                    direction: float,
                    center: np.ndarray,
                    scale: float,
                    color: Color):
    """Draw the airplane outline on ``ax`` with a given heading, size and position.

    The outline from ``get_airplane_coordinates`` is transformed in this
    order: rotated counter-clockwise about the origin by ``direction``,
    uniformly scaled by ``scale``, then shifted by ``center``.

    Args:
        ax: Axes to draw on.
        direction: Rotation angle in radians. The unrotated airplane points
            along +x, so this is the heading measured counter-clockwise from +x.
        center: Offset added to every scaled vertex; any array that
            broadcasts against an (N, 2) array works, e.g. shape (2,) or (1, 2).
        scale: Uniform scale factor applied after rotation.
        color: Colour used for the outline.
    """
    cos_t = np.cos(direction)
    sin_t = np.sin(direction)
    # Standard counter-clockwise 2-D rotation matrix.
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])

    # Vertices are stored as rows, so rotate the transposed (2, N) view and
    # transpose back.
    rotated = (rotation @ get_airplane_coordinates().T).T
    placed = rotated * scale + center
    _list_pgfplotto(ax, placed, color)


if __name__ == "__main__":
    # Demo: draw a single airplane with a random heading at the origin.
    # NOTE(review): assumes plotting.wrapped_subplot returns a 2-D grid of
    # Axes (indexed [0, 0] below) and creates the figure through pyplot.
    fig, axs = plotting.wrapped_subplot(1, 1)
    ax = axs[0, 0]

    color = "k"
    scale = .1
    center = np.zeros((1, 2))
    direction = np.random.uniform(0, 2 * np.pi)
    add_airplane_at(ax, direction, center, scale, color)

    # Equal data aspect so the rotated outline is not stretched differently
    # along x and y.
    ax.set_aspect("equal")
    # Without this the script builds the figure and exits without showing it.
    plt.show()
