from matplotlib import pyplot as plt
import json
import os
import numpy as np
import pandas as pd


def plot_mse(sinrs, mse, title):
    """Plot MSE in dB against SIR on a new figure.

    Args:
        sinrs: x-axis values (axis labelled "SIR [dB]").
        mse: linear-scale MSE values, one per entry of ``sinrs``; they are
            plotted as ``10 * log10(mse)``.
        title: title shown above the axes.

    Returns:
        The newly created matplotlib ``Figure``.
    """
    fig, ax = plt.subplots()
    # Convert linear MSE to decibels.
    mse_db = 10 * np.log10(np.asarray(mse))
    # Draw on the axes created above instead of pyplot's implicit
    # "current axes", so plotting and styling target the same object.
    ax.plot(sinrs, mse_db)
    ax.grid()
    ax.set_ylabel("MSE [dB]")
    ax.set_xlabel("SIR [dB]")
    ax.set_title(title)
    return fig


def plot_ber(sinrs, ber, title):
    """Plot bit error rate against SIR on a new figure with a log y-axis.

    Args:
        sinrs: x-axis values (axis labelled "SIR [dB]").
        ber: BER values, one per entry of ``sinrs``.
        title: title shown above the axes.

    Returns:
        The newly created matplotlib ``Figure``; its y-axis is logarithmic
        and limited to [1e-5, 1].
    """
    figure, axes = plt.subplots()
    axes.plot(sinrs, ber)
    axes.grid()
    axes.set_ylabel("BER")
    axes.set_xlabel("SIR [dB]")
    axes.set_title(title)
    # Switch to log scale first, then pin the limits so they stick.
    axes.set_yscale("log")
    axes.set_ylim([1e-5, 1])
    return figure


def _load_result(result, eval_json):
    """Return ``(mse_db, ber)`` for a single entry of ``result_list``.

    The entry is resolved either from ``eval_json`` (key ``"id"``) or from an
    external ``.csv``/``.pkl`` file under ``$RESULTS`` (key ``"external"``).

    Raises:
        ValueError: if ``eval_json[id]["id"]`` does not match ``id``, or the
            entry has neither an ``"id"`` nor an ``"external"`` key.
        FileNotFoundError: if neither external file exists.
    """
    if "id" in result:
        idx = result["id"]
        entry = eval_json[idx]
        # Explicit check (not ``assert``) so it also runs under ``python -O``.
        if entry["id"] != idx:
            raise ValueError(
                f"evals.json entry at {idx!r} has id {entry['id']!r}, expected {idx!r}"
            )
        # evals.json stores linear MSE; convert to dB for plotting.
        mse_db = 10 * np.log10(np.array(entry["mse"]))
        return mse_db, entry["ber"]

    if "external" in result:
        filename_csv = os.path.join(os.environ["RESULTS"], result["external"] + ".csv")
        filename_pkl = os.path.join(os.environ["RESULTS"], result["external"] + ".pkl")
        if os.path.exists(filename_csv):
            data = np.loadtxt(filename_csv, delimiter=",", max_rows=2)
        elif os.path.exists(filename_pkl):
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            data = pd.read_pickle(filename_pkl)
        else:
            raise FileNotFoundError(
                f"no external result found: tried {filename_csv!r} and {filename_pkl!r}"
            )
        # Row 0 is used as-is for the MSE axis (presumably already in dB --
        # TODO confirm), row 1 is the BER curve.
        return data[0], data[1]

    raise ValueError(f"result entry needs an 'id' or 'external' key: {result!r}")


def plot_results(dataset_title, dataset_path, result_list, ncol=5):
    """Plot MSE [dB] and BER versus SIR for several results side by side.

    Args:
        dataset_title: figure suptitle; ``None`` means no suptitle.
        dataset_path: directory under ``$DATASETS`` that holds ``evals.json``
            and ``meta.json``; ``meta.json["sinr"]`` gives the x-axis values.
        result_list: dicts with ``"title"`` (legend label), ``"marker"``
            (matplotlib marker), and either ``"id"`` (index into
            ``evals.json``; the entry has linear ``"mse"`` and ``"ber"``) or
            ``"external"`` (base name of a ``.csv``/``.pkl`` file under
            ``$RESULTS`` whose first two rows are the MSE and BER curves).
        ncol: number of columns in the shared legend below the plots.

    Returns:
        The matplotlib ``Figure``. Its left axes show MSE in dB, with the
        lower limit clamped at -60. Its right axes show BER on a log scale
        in [1e-5, 1].

    Raises:
        KeyError: if ``$DATASETS`` (or ``$RESULTS`` for external results)
            is not set.
        FileNotFoundError: if a dataset file or an external result is missing.
        ValueError: on an id mismatch or a malformed result entry.
    """
    dataset_dir = os.path.join(os.environ["DATASETS"], dataset_path)
    with open(os.path.join(dataset_dir, "evals.json"), encoding="utf-8") as f:
        eval_json = json.load(f)
    with open(os.path.join(dataset_dir, "meta.json"), encoding="utf-8") as f:
        sinrs = np.array(json.load(f)["sinr"])

    # Load every curve before creating the figure, so a bad entry fails fast
    # without leaving a half-built figure open.
    curves = [(result, *_load_result(result, eval_json)) for result in result_list]

    fig, ax = plt.subplots(1, 2, figsize=(16, 5))
    for result, mse_db, ber in curves:
        title = result["title"]
        ax[0].plot(sinrs, mse_db, label=title, marker=result["marker"])
        ax[1].plot(sinrs, ber, label=title, marker=result["marker"])
    ax[1].set_yscale("log")
    ax[0].grid()
    ax[1].grid()
    ax[0].set_xlabel("SIR [dB]")
    ax[0].set_ylabel("MSE [dB]")
    ax[1].set_ylim([1e-5, 1])
    ax[1].set_xlabel("SIR [dB]")
    ax[1].set_ylabel("BER")
    # Keep very low MSE values from squashing the interesting range.
    lo, hi = ax[0].get_ylim()
    if lo < -60:
        ax[0].set_ylim(-60, hi)
    handles, labels = ax[0].get_legend_handles_labels()
    # Number of legend rows for `ncol` columns (ceil division); each extra
    # row pushes the legend further below the axes.
    rows = (len(handles) + ncol - 1) // ncol
    fig.legend(handles, labels, loc='lower center', ncol=ncol,
               bbox_to_anchor=(0.5, 0.02 - 0.04 * rows))
    if dataset_title is not None:
        fig.suptitle(dataset_title)
    return fig
