import matplotlib as mpl
import matplotlib.ticker as ticker
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from jax import device_get
import jax
from fairgym.envs.default_reward import known_loss_func_dict

# Global matplotlib styling applied at import time to every figure produced
# by this module.
# Optional overrides that are intentionally left disabled:
#     "font.family": "serif"
#     "font.weight": "bold"
mpl.rcParams["mathtext.fontset"] = "cm"
mpl.rcParams["mathtext.rm"] = "serif"
# Font type 42 for both the PDF and PostScript backends.
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"] = 42

# Base font size for all text elements.
font = dict(size=13)

mpl.rc("font", **font)


def _phase_plot_step(env, action, return_truncated):
    """Advance ``env`` one step and return ``(observation, reward, info)``.

    Supports both the 5-tuple step API (which includes ``truncated``) and the
    older 4-tuple one; the termination flags are not needed by the caller.
    """
    if return_truncated:
        observation, reward, _terminated, _truncated, info = env.step(action)
    else:
        observation, reward, _terminated, info = env.step(action)
    return observation, reward, info


def _phase_plot_metric(metric_name, info):
    """Evaluate the ``known_loss_func_dict`` entry ``metric_name`` on a step's info."""
    return known_loss_func_dict[metric_name](
        info["prev_state"],
        info["prev_action"],
        info["prev_results"],
        info["current_state"],
    )


def make_replicator_phase_plot(
    agent,  # Assumed to be already trained
    env,
    loss_name,
    disparity_name,
    env_seed,
    filename,
    res=4,
    options=None,
    num_runs=5,
    H=False,  # if H, just export data. No phase plot.
    return_truncated=True,
):
    """Sweep a grid of group qualification rates and plot the induced dynamics.

    For each grid point ``(q_1, q_2)`` the environment is reset with
    ``options["pr_q"]`` set to that point, the agent's policy is queried once
    per seed in ``range(num_runs)``, and the one-step reward, loss, disparity
    and induced next qualification rates are averaged over those runs.

    Args:
        agent: Object exposing ``policy(observation, info, seed=...)``.
        env: Environment exposing ``reset(seed=, return_info=True, options=)``
            and ``step(action)``.
        loss_name: Key into ``known_loss_func_dict`` used for the loss.
        disparity_name: Key into ``known_loss_func_dict`` used for the
            disparity; also used as the colour-bar label and in the filename.
        env_seed: Seed passed to every ``env.reset`` call.
        filename: Prefix for all files written.
        res: Number of grid points per axis, spanning ``[0.01, 0.99]``.
        options: Extra options forwarded to ``env.reset``; ``pr_q`` is added
            (or overridden) for each grid point.
        num_runs: Number of agent seeds to average over.
        H: If truthy, the rollout length in steps. In that case, for each seed
            the mean loss over an ``H``-step rollout from each grid point is
            saved to ``{filename}_episodic_mean_loss_H={H}_seed={seed}.npy``
            and no plots are produced.
        return_truncated: Whether ``env.step`` returns a 5-tuple (including
            ``truncated``) rather than a 4-tuple.

    Returns:
        None. Writes ``.npy`` files when ``H`` is truthy, otherwise writes one
        PDF per colour quantity: ``{filename}_color=Loss.pdf`` and
        ``{filename}_color={disparity_name}.pdf``.
    """
    # options used for environment
    if options is None:
        options = {}

    x = np.linspace(0.01, 0.99, res)  # q_1
    y = np.linspace(0.01, 0.99, res)  # q_2

    # https://eli.thegreenplace.net/2014/meshgrids-and-disambiguating-rows-and-columns-from-cartesian-coordinates/
    # Arrays indexed by row = y = group 2, column = x = group 1.

    # Velocity vectors
    Vx = np.zeros((res, res))  # x-component (group 1)
    Vy = np.zeros((res, res))  # y-component (group 2)

    # Reward AT this point (reward may be 1 - combination of loss and disparity)
    R = np.zeros((res, res))

    # Loss at this point
    L = np.zeros((res, res))

    # Disparity
    D = np.zeros((res, res))

    # One independent array per seed. ``[arr] * num_runs`` would repeat a
    # reference to a single array, so every seed would share the same buffer.
    cumulative_loss = [np.zeros((res, res)) for _ in range(num_runs)]

    for ix in tqdm(range(res)):  # group 1
        for iy in range(res):  # group 2

            pr_q = np.array([x[ix], y[iy]])
            next_pr_q = np.zeros((2,))
            reset_options = options | {"pr_q": pr_q}

            for agent_seed in range(num_runs):
                # Reset for every run: the environment is stateful, so each
                # seed must start from this grid point rather than from the
                # state the previous run left behind.
                observation, init_info = env.reset(
                    seed=env_seed, return_info=True, options=reset_options
                )
                observation = device_get(observation)

                # let agent decide what non-deterministic policy to deploy in
                # this fake scenario. We will average ALL depicted quantities
                # (disparity, dynamics, etc) over this non-determinism.
                action = agent.policy(observation, init_info, seed=agent_seed)
                observation, reward, next_info = _phase_plot_step(
                    env, action, return_truncated
                )

                # what is reward for this state/action?
                R[iy, ix] += reward

                step_loss = _phase_plot_metric(loss_name, next_info)
                L[iy, ix] += step_loss

                # what is color value to display at this state/action?
                # Read before any rollout below replaces the step info.
                D[iy, ix] += _phase_plot_metric(disparity_name, next_info)

                # induced qualification rate
                next_pr_q += next_info["current_state"].pr_Y1

                # collect mean loss over an H-step rollout starting from this
                # point and seed
                if H:
                    # Start from this run's own first-step loss, not from the
                    # running sum accumulated in L over earlier seeds.
                    episode_loss = step_loss
                    rollout_info = next_info
                    for _ in range(H - 1):
                        action = agent.policy(
                            observation, rollout_info, seed=agent_seed
                        )
                        observation, _, rollout_info = _phase_plot_step(
                            env, action, return_truncated
                        )
                        episode_loss = episode_loss + _phase_plot_metric(
                            loss_name, rollout_info
                        )

                    cumulative_loss[agent_seed][iy, ix] = episode_loss / H

            # perform averaging over num_runs
            L[iy, ix] /= num_runs
            R[iy, ix] /= num_runs
            D[iy, ix] /= num_runs
            next_pr_q /= num_runs

            # what is velocity?
            Vx[iy, ix], Vy[iy, ix] = next_pr_q - pr_q

    # if storing episodic mean loss, don't generate phase plots
    if H:
        for agent_seed in range(num_runs):
            array_title = f"{filename}_episodic_mean_loss_H={H}_seed={agent_seed}.npy"
            print("saving", array_title)
            np.save(
                array_title,
                cumulative_loss[agent_seed],
            )
        return

    for (color_array, color_name, colormin, colormax, cmap) in [
        (L, "Loss", 0, 1, "Oranges"),
        (D, disparity_name, 0, np.max(D), "Blues"),
    ]:

        # contourf needs strictly increasing levels, so widen a flat range.
        if colormin == colormax:
            colormax += 0.001

        fig = plt.figure()
        scale = 6

        fig.set_size_inches(scale, scale)

        ax = fig.add_subplot(1, 1, 1)
        ax.set_xlim([0.01, 0.99])
        ax.set_ylim([0.01, 0.99])
        ax.set_xlabel("Group 1 qualification rate $s_1$")
        ax.set_ylabel("Group 2 qualification rate $s_2$")
        ax.xaxis.set_major_locator(ticker.FixedLocator([0.1, 0.3, 0.5, 0.7, 0.9]))
        ax.yaxis.set_major_locator(ticker.FixedLocator([0.1, 0.3, 0.5, 0.7, 0.9]))

        ax.streamplot(x, y, Vx, Vy, color="black", linewidth=0.6, arrowsize=0.8)

        cs = ax.contourf(
            x,
            y,
            color_array,
            cmap=plt.get_cmap(cmap),
            levels=np.linspace(colormin, colormax, 9),
            alpha=0.8,
        )

        fig.colorbar(
            cs,
            ax=ax,
            fraction=0.046,
            pad=0.04,
            ticks=[colormin, colormax],
            ticklocation="left",
        )
        # Colour-bar label, placed to the right of the axes in data coordinates.
        ax.text(
            1.15,
            0.5,
            color_name,
            rotation=270,
            rotation_mode="anchor",
            horizontalalignment="center",
            verticalalignment="baseline",
            multialignment="center",
        )

        # plot equal qualification line
        ax.plot(
            [0.02, 0.98],
            [0.02, 0.98],
            color="black",
            linewidth=3.5,
        )

        print("saving", f"{filename}_color={color_name}.pdf")
        # Save this figure explicitly, then close it so repeated calls do not
        # accumulate open figures.
        fig.savefig(f"{filename}_color={color_name}.pdf", bbox_inches="tight")
        plt.close(fig)

    # if "loss" in color:
    #     with open(
    #         "color_" + color + "_" + filename.split(".")[0] + ".pickle", "wb"
    #     ) as f:
    #         pickle.dump([AR, BR], f)
