from typing import Tuple, Dict, List
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import numpy as np


def plot_heatmap(
    info: List[Tuple[Dict[str, float], float]],
    title_plot: str = "heatmap_mass_friction.png",
    vmin: float = -600,
    vmax: float = 15000,
):
    """Save a mass x friction heatmap of rewards to ``title_plot``.

    Each entry of ``info`` is a ``(params, reward)`` pair where ``params``
    holds the ``"mass_coef"`` and ``"friction_coef"`` used for that run.
    Coefficients are rounded to two decimals to form the grid axes, so
    values that differ only by floating-point noise share a cell.

    Rows are masses and columns are frictions, both in ascending order.
    Each reward is placed at the cell of its own (mass, friction) pair, so
    ``info`` may be given in any order. Combinations absent from ``info``
    are left as NaN, which seaborn leaves blank. If a combination occurs
    more than once, the last reward wins.

    Args:
        info: ``(params, reward)`` pairs, one per evaluated configuration.
        title_plot: Output path passed to ``savefig``.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.

    Raises:
        ValueError: If ``info`` is empty.
    """
    if not info:
        raise ValueError("plot_heatmap needs at least one (params, reward) pair")

    # Round before deduplicating so float noise (e.g. 0.30000000000000004
    # vs 0.3) does not create spurious extra rows/columns.
    masses = [round(entry[0]["mass_coef"], 2) for entry in info]
    frictions = [round(entry[0]["friction_coef"], 2) for entry in info]
    rewards = [entry[1] for entry in info]

    distinct_mass = sorted(set(masses))
    distinct_friction = sorted(set(frictions))
    mass_index = {mass: i for i, mass in enumerate(distinct_mass)}
    friction_index = {friction: j for j, friction in enumerate(distinct_friction)}

    # Place each reward by its coordinates instead of reshaping the flat
    # list. Reshaping silently mislabels cells unless ``info`` happens to be
    # sorted mass-major / friction-minor, and it crashes on missing combos.
    heatmap = np.full((len(distinct_mass), len(distinct_friction)), np.nan)
    for mass, friction, reward in zip(masses, frictions, rewards):
        heatmap[mass_index[mass], friction_index[friction]] = reward

    # Use a dedicated figure. Otherwise sns.heatmap draws onto whatever axes
    # are current, stacking plots (and colorbars) across repeated calls.
    fig, ax = plt.subplots()
    sns.heatmap(
        heatmap,
        ax=ax,
        linewidth=0.5,
        vmin=vmin,
        vmax=vmax,
        # Pass the labels here rather than via set_*ticklabels afterwards.
        # That guarantees one label per cell, even where seaborn would
        # otherwise thin out the ticks.
        xticklabels=distinct_friction,
        yticklabels=distinct_mass,
    )
    ax.set_ylabel("Mass")
    ax.set_xlabel("Friction")
    fig.savefig(title_plot, dpi=600)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def plot_surface(
    info: List[Tuple[Dict[str, float], float]],
    title_plot: str = "surface_mass_friction.png",
    vmin: float = -600,
    vmax: float = 15000,
):
    """Save a 3-D triangulated surface of reward over (mass, friction).

    ``plot_trisurf`` triangulates the (mass, friction) points, so they need
    not lie on a regular grid.

    Args:
        info: ``(params, reward)`` pairs. ``params`` must hold
            ``"mass_coef"`` and ``"friction_coef"``.
        title_plot: Output path passed to ``savefig``.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.
    """
    list_of_mass = [entry[0]["mass_coef"] for entry in info]
    list_of_friction = [entry[0]["friction_coef"] for entry in info]
    rewards = [entry[1] for entry in info]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    surf = ax.plot_trisurf(
        list_of_mass,
        list_of_friction,
        rewards,
        cmap=plt.cm.jet,
        linewidth=0.01,
        vmin=vmin,
        vmax=vmax,
    )
    # Label the axes consistently with plot_heatmap.
    ax.set_xlabel("Mass")
    ax.set_ylabel("Friction")
    ax.set_zlabel("Reward")
    fig.colorbar(surf, shrink=0.5, aspect=5)
    # Save this specific figure rather than whatever is current, then close
    # it so repeated calls don't leak open figures.
    fig.savefig(title_plot, dpi=600)
    plt.close(fig)
