import math

import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import gym

import plotting
import path_config


def _get_screen(env):
    """Capture the current frame of ``env`` as a float32 image.

    Args:
        env: Environment whose ``render(mode='rgb_array')`` returns a pixel
            array (presumably (H, W, C) uint8 — TODO confirm for the gym
            version in use).

    Returns:
        A C-contiguous float32 array with the same axis layout as the
        rendered frame, with every value divided by 255.
    """
    frame = env.render(mode='rgb_array')
    # Keep the frame in its rendered layout: transposing to channel-first and
    # immediately back again is a no-op that only produced a strided view.
    return np.ascontiguousarray(frame, dtype=np.float32) / 255


def get_cartpole_pixels(state: np.ndarray,
                        env_name: str = "CartPole-v1") -> np.ndarray:
    """Render a single frame of a gym environment in a given state.

    A fresh environment is created, seeded with 0 and reset. Its unwrapped
    state is then overwritten with ``state`` and one frame is captured via
    ``_get_screen``. The environment is always closed afterwards, even if
    rendering fails.

    Args:
        state: Value assigned to ``env.env.state``. The script below passes
            ``[x, x_dot, theta, theta_dot]``.
        env_name: Gym environment id to instantiate. Defaults to
            ``"CartPole-v1"``, the environment this script draws. Previously
            this was read from a global only defined under ``__main__``.

    Returns:
        The frame returned by ``_get_screen`` (float32, values divided by 255).
    """
    env = gym.make(env_name)
    try:
        env.seed(0)
        env.reset()
        # Overwrite the physics state of the unwrapped env so the rendered
        # frame reflects `state` rather than the freshly reset one.
        env.env.state = state
        to_plot = _get_screen(env)
    finally:
        env.close()
        # Close every matplotlib figure open at this point.
        plt.close("all")

    return to_plot


def add_x_line(ax: matplotlib.axes.Axes):
    """Draw the horizontal cart-position bracket onto ``ax``.

    Reads the module-level globals ``x``, ``tick_len``, ``lw``, ``x_bel`` and
    ``x_str`` (assigned in the ``__main__`` section). It draws a short
    vertical tick at 0 and another at ``x``, both centred on y = -0.20, and
    joins them with a horizontal line. ``x_str`` is written at the midpoint,
    shifted down by ``x_bel``.
    """
    baseline = -.20
    tick_extent = (baseline - tick_len, baseline + tick_len)
    for tick_x in (0, x):
        ax.plot((tick_x, tick_x), tick_extent, color="k", lw=lw)

    ax.plot((0, x), (baseline, baseline), color="k", lw=lw)
    ax.text(x * .5, baseline - x_bel, x_str)


def add_xprime_line(ax: matplotlib.axes.Axes):
    """Draw the cart-velocity arrow onto ``ax``.

    Reads the module-level globals ``x``, ``xdot``, ``tau``, ``secs``, ``lw``,
    ``x_bel`` and ``xdot_str``. The arrow starts at ``(x, -0.40)`` and runs
    horizontally by ``tau * secs * xdot``. ``xdot_str`` is placed at the
    arrow's midpoint, shifted down by ``x_bel``.
    """
    arrow_y = -.40
    displacement = tau * secs * xdot

    ax.arrow(x=x, y=arrow_y, dx=displacement, dy=0,
             head_width=0.025, head_length=0.05,
             fc='k', ec='k', lw=lw)

    label_y = arrow_y - x_bel
    ax.text(x + displacement / 2, label_y, xdot_str)


def add_theta_arc(ax: matplotlib.axes.Axes):
    """Mark the pole angle ``theta`` with two ticks and an arc.

    Reads the module-level globals ``x``, ``theta``, ``tick_len``, ``lw`` and
    ``theta_str``. It draws:

    * a tick on the vertical through ``(x, 0)``;
    * a tick rotated by ``theta`` from that vertical;
    * an arc of radius 0.5 between the two ticks;
    * ``theta_str`` at radius 0.6 along the half-angle.
    """
    radius = .5
    label_radius = .6
    inner = radius - tick_len
    outer = radius + tick_len

    # Tick on the vertical reference direction.
    ax.plot((x, x), (inner, outer), color="k", lw=lw)

    # Tick along the direction rotated by theta from vertical.
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)
    ax.plot((x + inner * sin_t, x + outer * sin_t),
            (inner * cos_t, outer * cos_t),
            color="k", lw=lw)

    # Arc angles are given in degrees measured from the +x axis, so the
    # vertical sits at 90 and the rotated direction at 90 - theta (degrees).
    rotated_deg = 90 - theta * 360 / (2 * np.pi)
    ax.add_patch(matplotlib.patches.Arc((x, 0),
                                        width=2 * radius,
                                        height=2 * radius,
                                        angle=0,
                                        theta1=min(rotated_deg, 90),
                                        theta2=max(rotated_deg, 90),
                                        lw=lw))

    half_angle = theta * .5
    ax.text(x + label_radius * np.sin(half_angle),
            label_radius * np.cos(half_angle), theta_str)


def add_thetadot_arc(ax: matplotlib.axes.Axes):
    """Draw an arrowed unit-radius arc showing the angular velocity.

    Reads the module-level globals ``x``, ``theta``, ``thetadot``, ``tau``,
    ``secs``, ``lw`` and ``thetadot_str``. The arc spans from ``theta`` to
    ``theta + thetadot * tau * secs`` (angles from vertical) around
    ``(x, 0)``. An arrow head sits at the far end, and ``thetadot_str`` is
    placed at radius 1.1 along the mean angle.
    """
    label_radius = 1.1
    theta_end = theta + thetadot * tau * secs

    start_deg = 90 - theta * 360 / (2 * np.pi)
    end_deg = 90 - theta_end * 360 / (2 * np.pi)
    ax.add_patch(matplotlib.patches.Arc((x, 0),
                                        width=2,
                                        height=2,
                                        angle=0,
                                        theta1=min(end_deg, start_deg),
                                        theta2=max(end_deg, start_deg),
                                        lw=lw))

    # Endpoints of the arc on the unit circle centred at (x, 0).
    tip_x, tip_y = x + np.sin(theta_end), np.cos(theta_end)
    base_x, base_y = x + np.sin(theta), np.cos(theta)

    # A tiny step along the chord from base to tip only serves to orient the
    # arrow head in the direction of rotation.
    ax.arrow(x=tip_x,
             y=tip_y,
             dx=.01 * (tip_x - base_x),
             dy=.01 * (tip_y - base_y),
             head_width=0.025,
             head_length=0.05, fc='k', ec='k', lw=lw)

    mean_angle = (theta + theta_end) / 2
    ax.text(x + label_radius * np.sin(mean_angle),
            label_radius * np.cos(mean_angle),
            thetadot_str)


if __name__ == "__main__":
    # NOTE: the add_* helpers above read their inputs (x, xdot, theta,
    # thetadot, tau, secs, lw, tick_len, x_bel, *_str) from module-level
    # globals assigned here. Keep these assignments at module scope.

    # "pgf" writes the figure to disk; any other format (e.g. "png") skips
    # saving.
    fig_format = "pgf"
    save_plot = fig_format == "pgf"
    if save_plot:
        font_family = "serif"
        plotting.initialise_pgf_plots("pdflatex", font_family)

    env_name = "CartPole-v1"
    tau = 0.02
    secs = 10  # velocity arrows show the displacement over tau * secs

    x_bel = .15  # vertical offset of labels below their line

    lw = 1.50
    plot_scale = 2.5  # alternatives tried: 2.75, 3.0

    # Data-space extent of the rendered frame.
    upper = +600 / 250
    lower = -200 / 250
    left = -2.4
    right = +2.4
    tick_len = .025

    degrees = +40
    theta = degrees * 2 * math.pi / 360
    thetadot = +0.512501
    state = np.array([-1.4, -2.0, theta, thetadot])
    to_plot = get_cartpole_pixels(state)

    x = state[0]
    xdot = state[1]
    theta = state[2]
    thetadot = state[3]

    # One of "values", "sign" or "none"; controls the label text.
    string_type = "none"

    if string_type == "values":
        x_str = "$x = {:.4f}$".format(x)
        xdot_str = "$\\dot{{x}} = {:.4f}$".format(xdot)
        theta_str = "$\\theta = {:.4f}$".format(theta)
        thetadot_str = "$\\dot{{\\theta}} = {:.4f}$".format(thetadot)
    elif string_type == "none":
        x_str = "$x$"
        # NOTE(review): unlike the "values" branch this literal is not
        # .format()-ed, so the doubled braces are passed through verbatim.
        xdot_str = "$\\dot{{x}}$"
        theta_str = r"$\theta$"
        thetadot_str = r"$\dot{\theta}$"
    elif string_type == "sign":
        x_str = "$x \\le 0$" if x <= 0 else "$x \\ge 0$"
        xdot_str = "$\\dot{{x}} \\le 0$" if xdot <= 0 else "$\\dot{{x}} \\ge 0$"
        theta_str = "$\\theta \\le 0$" if theta <= 0 else "$\\theta \\ge 0$"
        thetadot_str = "$\\dot{{\\theta}} \\le 0$" if thetadot <= 0 else "$\\dot{{\\theta}} \\ge 0$"
    else:
        raise ValueError("Wrong string type")

    fig, axs = plotting.wrapped_subplot(1, 1, plot_scale)

    ax = axs[0, 0]
    extent = [left, right, lower, upper]
    ax.imshow(to_plot, interpolation='none', extent=extent)

    add_theta_arc(ax)
    add_thetadot_arc(ax)
    add_x_line(ax)
    add_xprime_line(ax)

    ax.set_xlim((left, 0))
    ax.set_ylim((-.75, 1.25))
    ax.get_yaxis().set_visible(False)

    fig.tight_layout()
    if save_plot:
        paths = path_config.get_paths()
        ident = "cartpole_schematic_{}".format(string_type)
        filepath = paths["plots"]
        fig_path = plotting.smart_save_fig(fig, ident, fig_format, filepath)
        # fig_path only exists when the figure was saved. Printing it
        # unconditionally raised NameError for non-pgf formats.
        print("Figure written to '{}'".format(fig_path))
