import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from Scripts.Useful import scale_state


def Plot_Vortices_on_Trajectories(*args):
    """Overlay one or more trajectories on the Taylor-Green vortex vorticity field.

    The background is ``cos(x) * cos(y)`` sampled on a grid that extends 2*pi
    beyond the bounding box of all trajectories on every side.

    Parameters
    ----------
    *args
        One item per trajectory, indexable as ``(xs, ys, label)``: the x and y
        coordinate sequences and the legend label for that trajectory.

    Returns
    -------
    fig, ax
        The matplotlib figure and axes containing the plot.

    Raises
    ------
    ValueError
        If no trajectories are given.
    """
    if not args:
        raise ValueError("Plot_Vortices_on_Trajectories needs at least one trajectory")

    min_x = min(min(arg[0]) for arg in args)
    max_x = max(max(arg[0]) for arg in args)
    min_y = min(min(arg[1]) for arg in args)
    max_y = max(max(arg[1]) for arg in args)
    # Flattened (x1, y1, x2, y2, ...) so a single ax.plot call draws every line.
    x_y_tuple = tuple(coord for arg in args for coord in (arg[0], arg[1]))
    labels = [arg[2] for arg in args]

    # Pad the bounding box by one period on each side and size the grid from the
    # padded span (~30 samples per unit length). This keeps the grid non-empty
    # even when all trajectories have zero extent along one axis.
    pad = 2 * np.pi
    x, y = np.meshgrid(
        np.linspace(min_x - pad, max_x + pad, int(30 * (max_x - min_x + 2 * pad))),
        np.linspace(min_y - pad, max_y + pad, int(30 * (max_y - min_y + 2 * pad))),
    )
    TGV_vort = np.cos(x) * np.cos(y)

    fig, ax = plt.subplots()
    im = ax.pcolormesh(x, y, TGV_vort, cmap='RdBu', shading='gouraud')
    # The colorbar goes in its own axes, placed just right of the main plot.
    cax = fig.add_axes([0.95, 0.1, 0.02, 0.8])
    fig.colorbar(im, cax=cax, orientation='vertical')
    ax.plot(*x_y_tuple)
    ax.set_xlabel('x - axis')
    ax.set_ylabel('y - axis')
    ax.legend(labels)
    ax.set_title("Trajectories in TGV")
    ax.axis('equal')
    # NOTE: the x-window is fixed and does not follow the trajectory extent.
    ax.set_xlim(-np.pi / 2, 2.5 * np.pi)
    fig.set_figwidth(2.0)
    return fig, ax


def Plot_Policy(actor, critic, env, scaler, x_con, y_con):
    """Visualise a policy and value function over the Taylor-Green vortex domain.

    The actor and critic are queried on a ``y_con`` x ``x_con`` grid covering
    x in [-pi/2, 3*pi/2] and y in [-pi/2, pi/2]. Three figures are drawn:

    1. the heading angles as unit arrows over the vorticity ``cos(x) * cos(y)``;
    2. the critic's value at every grid point;
    3. the policy spread: ``stddev()`` when ``env.type == 'Surf'``, otherwise
       the distribution's ``concentration``.

    Parameters
    ----------
    actor : callable
        Maps a scaled state to a distribution-like object. For ``'Surf'``
        environments its ``mean()`` and ``stddev()`` are used. Otherwise its
        ``loc`` (used directly as the heading angle) and ``concentration``
        are used.
    critic : callable
        Maps a scaled state to a value that is stored in a float array.
    env
        Environment providing ``pos_to_state``, ``observation_space``,
        ``type``, ``target_dir`` and ``particle``. For ``'Surf'``
        environments ``env.particle.current_x``/``current_y`` are overwritten
        and left at the last grid point.
    scaler
        Passed through to ``scale_state``.
    x_con, y_con : int
        Number of sample points along x and y.

    Returns
    -------
    figs : list
        ``[fig1, fig2, fig3]`` in the order listed above.
    axes : list
        The matching axes.
    done : bool
        True when the mean of the spread grid (figure 3) is below 0.1.
    """
    is_surf = env.type == 'Surf'

    # Dense grid for the background vorticity field.
    x, y = np.meshgrid(np.linspace(-np.pi / 2, 3 * np.pi / 2, int(30 * (2 * np.pi))),
                       np.linspace(-np.pi / 2, np.pi / 2, int(30 * (np.pi))))
    TGV_vort = np.cos(x) * np.cos(y)

    # Coarse grid at which the actor and critic are evaluated.
    x2, y2 = np.meshgrid(np.linspace(-np.pi / 2, 3 * np.pi / 2, x_con),
                         np.linspace(-np.pi / 2, np.pi / 2, y_con))
    grid_shape = x2.shape
    states = np.zeros(grid_shape + (env.observation_space.shape[0],))
    actions = np.zeros(grid_shape)
    conc = np.zeros(grid_shape)
    value = np.zeros(grid_shape)
    for i, j in np.ndindex(grid_shape):
        states[i, j] = env.pos_to_state(x2[i, j], y2[i, j])
        scaled_state = scale_state(states[i, j], scaler)
        dist = actor(scaled_state)
        if is_surf:
            # The policy outputs a parameter tau (the distribution mean).
            # env.particle.surf converts it into a heading angle for a particle
            # placed at this grid point.
            env.particle.current_x = x2[i, j]
            env.particle.current_y = y2[i, j]
            tau = dist.mean().numpy().reshape(-1)
            actions[i, j] = np.float64(
                env.particle.surf(tau, np.array(env.target_dir)[:, np.newaxis]))
            conc[i, j] = dist.stddev().numpy().reshape(-1)
        else:
            actions[i, j] = dist.loc.numpy().reshape(-1)
            conc[i, j] = dist.concentration.numpy().reshape(-1)
        value[i, j] = critic(scaled_state)

    # Figure 1: heading angles as unit arrows over the vorticity field.
    fig1, axs1 = plt.subplots()
    axs1.pcolormesh(x, y, TGV_vort, cmap='RdBu', shading='gouraud')
    axs1.quiver(x2, y2, np.cos(actions), np.sin(actions))
    axs1.set_ylim(-np.pi / 2, np.pi / 2)
    fig1.set_figheight(3.2)
    axs1.set_title("Policy Representation")
    axs1.axis('equal')

    # Figure 2: critic value, with a colorbar in its own axes on the right.
    fig2, axs2 = plt.subplots()
    im = axs2.pcolormesh(x2, y2, value, cmap='plasma')
    cax = fig2.add_axes([0.95, 0.14, 0.02, 0.72])
    cax.set_title("v(s)", y=1.0)
    fig2.colorbar(im, cax=cax, orientation='vertical')
    axs2.set_ylim(-np.pi / 2, np.pi / 2)
    fig2.set_figheight(3.2)
    axs2.set_title("Value Function")
    axs2.axis('equal')

    # Figure 3: policy spread. Only the labels depend on the environment type.
    if is_surf:
        spread_label, spread_title = "STD", "Policy Standard Deviation"
    else:
        spread_label, spread_title = "conc", "Policy Concentration"
    fig3, axs3 = plt.subplots()
    im = axs3.pcolormesh(x2, y2, conc, cmap='viridis')
    cax = fig3.add_axes([0.95, 0.14, 0.02, 0.72])
    cax.set_title(spread_label, y=1.0)
    fig3.colorbar(im, cax=cax, orientation='vertical')
    axs3.set_ylim(-np.pi / 2, np.pi / 2)
    fig3.set_figheight(3.2)
    axs3.set_title(spread_title)
    axs3.axis('equal')

    # NOTE(review): the same < 0.1 threshold is applied to both spread measures.
    # A small std means a narrow policy, but a small concentration presumably
    # means a broad one. Confirm the intended criterion for non-'Surf' envs.
    done = bool(conc.mean() < 0.1)
    return [fig1, fig2, fig3], [axs1, axs2, axs3], done


def Plot_Learning_Curve(learning_episodes, average_return):
    """Plot the average return as a function of training episode.

    Parameters
    ----------
    learning_episodes : sequence
        x-coordinates (episode numbers).
    average_return : sequence
        y-coordinates, one per entry of ``learning_episodes``.

    Returns
    -------
    fig, ax
        The newly created matplotlib figure and axes.
    """
    figure, axes = plt.subplots()
    axes.plot(learning_episodes, average_return)
    axes.set(xlabel='Episode', ylabel='Average Return', title="Learning curve")
    return figure, axes


def Plot_Policy_Turb(actor, critic, env, scaler, x_con, y_con, time, surf=None):
    """Visualise a policy on a snapshot of the turbulent flow at ``time``.

    The background is the z-vorticity interpolated from ``env.velf`` on a dense
    grid over [0, 2*pi]^2. The policy is evaluated on a ``y_con`` x ``x_con``
    grid over the same square.

    Parameters
    ----------
    actor, critic : callable
        Called on ``scale_state(state, scaler)``; only used when ``surf`` is
        None. ``actor`` returns a distribution-like object. For ``'Surf'``
        environments its ``mean()``/``stddev()`` are used. Otherwise its
        ``loc`` (used directly as the heading angle) and ``concentration``
        are used. ``critic``'s result is stored in a float array.
    env
        Environment providing ``pos_to_state``, ``observation_space``,
        ``type``, ``target_dir``, ``particle`` and ``velf``. When the particle
        is queried, ``env.particle.current_x``/``current_y``/``current_t`` are
        overwritten and left at the last grid point.
    scaler
        Passed through to ``scale_state``.
    x_con, y_con : int
        Number of policy sample points along x and y.
    time : float
        Flow time at which the field and the policy are evaluated.
    surf : optional
        If given, the learned policy is skipped. The arrows instead show
        ``env.particle.surf(surf, target)`` evaluated at every grid point.

    Returns
    -------
    figs, axs : list
        ``[policy]`` when ``surf`` is given, otherwise ``[policy, value, spread]``.
    done : bool
        True when the mean of the std/concentration grid is below 0.1. It is
        always False when ``surf`` is given.

    Notes
    -----
    Calls ``plt.close('all')`` before returning, which closes every pyplot
    figure, including ones created by the caller.
    """
    # Dense grid for the background z-vorticity, interpolated point by point.
    n_dense = int(30 * (2 * np.pi))
    x, y = np.meshgrid(np.linspace(0.0, 2 * np.pi, n_dense),
                       np.linspace(0.0, 2 * np.pi, n_dense))
    vort_z = np.zeros(x.shape)
    for i, j in np.ndindex(x.shape):
        vort_z[i, j] = env.velf.Interpolate_Fields(x[i, j], y[i, j], time, env.velf.vort_z)

    # Coarse grid at which the policy is evaluated.
    x2, y2 = np.meshgrid(np.linspace(0.0, 2 * np.pi, x_con), np.linspace(0.0, 2 * np.pi, y_con))
    grid_shape = x2.shape
    actions = np.zeros(grid_shape)
    if surf is None:
        is_surf = env.type == 'Surf'
        states = np.zeros(grid_shape + (env.observation_space.shape[0],))
        conc = np.zeros(grid_shape)
        value = np.zeros(grid_shape)
        for i, j in np.ndindex(grid_shape):
            states[i, j] = env.pos_to_state(x2[i, j], y2[i, j], time)
            scaled_state = scale_state(states[i, j], scaler)
            dist = actor(scaled_state)
            if is_surf:
                # Place the particle at this grid point *and* at the plotted
                # time. This matches the state above and the surf-override
                # branch below; otherwise surf() would see a stale time.
                env.particle.current_x = x2[i, j]
                env.particle.current_y = y2[i, j]
                env.particle.current_t = time
                tau = dist.mean().numpy().reshape(-1)
                actions[i, j] = np.float64(
                    env.particle.surf(tau, np.array(env.target_dir)[:, np.newaxis]))
                conc[i, j] = dist.stddev().numpy().reshape(-1)
            else:
                actions[i, j] = dist.loc.numpy().reshape(-1)
                conc[i, j] = dist.concentration.numpy().reshape(-1)
            value[i, j] = critic(scaled_state)
    else:
        # Fixed surf parameter: heading depends only on position and time.
        for i, j in np.ndindex(grid_shape):
            env.particle.current_x = x2[i, j]
            env.particle.current_y = y2[i, j]
            env.particle.current_t = time
            actions[i, j] = np.float64(env.particle.surf(surf, np.array(env.target_dir)[:, np.newaxis]))

    figs = []
    axs = []

    # Policy headings as unit arrows over the vorticity snapshot.
    fig1, axs1 = plt.subplots(figsize=(6, 6))
    axs1.pcolormesh(x, y, vort_z, cmap='viridis', shading='gouraud')
    axs1.quiver(x2, y2, np.cos(actions), np.sin(actions))
    axs1.set_ylim(0.0, 2 * np.pi)
    axs1.set_xlim(0.0, 2 * np.pi)
    axs1.set_title("Policy Representation")
    axs1.axis('equal')
    axs1.set_aspect('equal')
    figs.append(fig1)
    axs.append(axs1)

    done = False

    if surf is None:
        # Critic value, with a colorbar in its own axes on the right.
        fig2, axs2 = plt.subplots()
        im = axs2.pcolormesh(x2, y2, value, cmap='plasma')
        cax = fig2.add_axes([0.95, 0.14, 0.02, 0.72])
        cax.set_title("v(s)", y=1.0)
        fig2.colorbar(im, cax=cax, orientation='vertical')
        axs2.set_ylim(0.0, 2 * np.pi)
        # Same [0, 2*pi] window as the data and the other panels.
        axs2.set_xlim(0.0, 2 * np.pi)
        axs2.set_title("Value Function")
        axs2.axis('equal')
        figs.append(fig2)
        axs.append(axs2)

        # Policy spread. Only the labels depend on the environment type.
        fig3, axs3 = plt.subplots()
        im = axs3.pcolormesh(x2, y2, conc, cmap='viridis')
        cax = fig3.add_axes([0.95, 0.14, 0.02, 0.72])
        if is_surf:
            cax.set_title("STD", y=1.0)
            axs3.set_title("Policy Standard Deviation")
        else:
            cax.set_title("conc", y=1.0)
            axs3.set_title("Policy Concentration")
        fig3.colorbar(im, cax=cax, orientation='vertical')
        axs3.set_ylim(0.0, 2 * np.pi)
        axs3.set_xlim(0.0, 2 * np.pi)
        axs3.axis('equal')
        figs.append(fig3)
        axs.append(axs3)

        # NOTE(review): the same < 0.1 threshold is applied to std and to
        # concentration. A small concentration presumably means a broad
        # policy; confirm the intended criterion for non-'Surf' envs.
        done = bool(conc.mean() < 0.1)

    # Closes *all* pyplot figures, presumably so repeated calls do not pile up
    # open figures. The returned Figure objects remain valid Python objects
    # (e.g. for ``savefig``).
    plt.close('all')

    return figs, axs, done