import logging
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as grd
import numpy as np
import pandas as pd
import seaborn as sns
from typing import Optional

logger = logging.getLogger(__name__)


def plot_basic(
    ys, name, xs=None, ylabel=None, color: Optional[str] = "r", logy=False, ax=None
):
    """Draw one labelled line series, creating a fresh figure when needed.

    Args:
        ys: y values of the series.
        name: label attached to the line (picked up by a later legend call).
        xs: x values; when omitted, ``0 .. len(ys) - 1`` is used.
        ylabel: optional y-axis label (applied only when truthy).
        color: matplotlib colour for the line; ``None`` lets matplotlib choose
            from its colour cycle.
        logy: switch the y axis to a logarithmic scale.
        ax: axes to draw on. When ``None``, a new 10x5 figure with a single
            subplot is created, with its x axis labelled "epoch".

    Returns:
        ``(ax, xs)``: the axes drawn on and the x values actually used.
    """
    if ax is None:
        plt.figure(figsize=(10, 5))
        grid = grd.GridSpec(1, 1)
        ax = plt.subplot(grid[0])
        grid.update(left=0.14, bottom=0.14, top=0.96, right=0.96)
        ax.set_xlabel("epoch", fontsize=20)
        ax.tick_params(axis="both", labelsize=15)

    if ylabel:
        ax.set_ylabel(ylabel, fontsize=20)
    if logy:
        ax.set_yscale("log")

    x_values = np.arange(len(ys)) if xs is None else xs
    ax.plot(x_values, ys, color=color, alpha=0.7, label=name)
    return ax, x_values


def plot_param(param_history, param_names, fig_folder, prefix="sb"):
    """Overlay the histories of several parameters on one axes.

    Args:
        param_history: mapping from parameter name to its sequence of values.
        param_names: names to plot, in order; all share a single axes.
        fig_folder: directory to save the figure into as
            ``<prefix>_<first name>.pdf``. When falsy, the figure is left
            open instead.
        prefix: file-name prefix for the saved PDF.
    """
    ax = None
    for index, pname in enumerate(param_names):
        # The first series keeps the historical red. Later ones use matplotlib's
        # colour cycle (color=None, as plot_loss does) so they are distinguishable.
        color = "r" if index == 0 else None
        ax, _ = plot_basic(ys=param_history[pname], name=pname, ax=ax, color=color)

    # A legend only helps once several series share the axes.
    if ax is not None and len(param_names) > 1:
        ax.legend(loc="best", fontsize=16, frameon=False)

    if fig_folder and param_names:
        # Save/close the figure owning `ax`, not whatever figure happens to be current.
        ax.figure.savefig(os.path.join(fig_folder, f"{prefix}_{param_names[0]}.pdf"))
        plt.close(ax.figure)


def plot_loss(param_history, fig_folder, prefix="sb", logy=True, rms_tail=25):
    """Plot loss and evidence diagnostics from a training history.

    Two figures are produced:

    * ``loss`` and ``aloss`` on one axes, saved as ``loss_<prefix>.pdf``;
    * ``evidence`` with its percentile band, plus ``rms`` on a twin y-axis,
      saved as ``evidence_<prefix>.pdf``.

    Args:
        param_history: mapping with (at least) the keys ``"loss"``,
            ``"aloss"``, ``"evidence"`` and ``"rms"``. Each ``"evidence"``
            entry is unpacked as ``(value, 5th pct, 95th pct)``.
        fig_folder: directory to save the PDFs into. When falsy, the figures
            are left open instead of being saved and closed.
        prefix: tag appended to the output file names.
        logy: use a logarithmic y axis for loss, evidence and rms.
        rms_tail: number of leading ``rms`` points to leave out of the plot.
    """
    _plot_loss_curves(param_history, fig_folder, prefix, logy)
    _plot_evidence_and_rms(param_history, fig_folder, prefix, logy, rms_tail)


def _plot_loss_curves(param_history, fig_folder, prefix, logy):
    """Overlay ``loss`` and ``aloss``; save to ``loss_<prefix>.pdf`` if requested."""
    loss = param_history["loss"]
    aloss = param_history["aloss"]
    ax, _ = plot_basic(ys=loss, name="loss", ylabel="loss", logy=logy)
    # aloss is right-aligned with loss, so its last point shares loss's last
    # epoch. Building the range directly also avoids a length mismatch when
    # aloss happens to be longer than loss.
    aloss_xs = np.arange(len(loss) - len(aloss), len(loss))
    plot_basic(ys=aloss, xs=aloss_xs, name="aloss", ylabel="loss", ax=ax, color=None)

    if fig_folder:
        ax.figure.savefig(os.path.join(fig_folder, f"loss_{prefix}.pdf"))
        plt.close(ax.figure)


def _plot_evidence_and_rms(param_history, fig_folder, prefix, logy, rms_tail):
    """Plot evidence with its 5-95% band and ``rms`` on a twin y-axis."""
    evidence_triplets = param_history["evidence"]
    evidence = [value for value, _, _ in evidence_triplets]
    evidence_5pct = [low for _, low, _ in evidence_triplets]
    evidence_95pct = [high for _, _, high in evidence_triplets]

    # NOTE(review): a log y axis masks non-positive values; confirm the
    # evidence values are positive whenever logy is set.
    ax, xs = plot_basic(ys=evidence, name="evidence", ylabel="log_prob", logy=logy)
    ax.fill_between(xs, evidence_5pct, evidence_95pct, alpha=0.5)

    # rms reuses the evidence x values, so this assumes len(rms) == len(evidence)
    # (TODO confirm). The first `rms_tail` points are dropped.
    ax_rms = ax.twinx()
    plot_basic(
        xs=xs[rms_tail:],
        ys=param_history["rms"][rms_tail:],
        name="rms",
        ylabel="rms",
        ax=ax_rms,
        logy=logy,
        color="b",
    )

    if fig_folder:
        ax.figure.savefig(os.path.join(fig_folder, f"evidence_{prefix}.pdf"))
        plt.close(ax.figure)
    # plt.savefig(os.path.join(fig_folder, f"loss_ave_rel_{name}.pdf"))


def plot_prediction(
    predicted_dist,
    data=None,
    name="cbnn",
    ax=None,
    fig_folder=None,
    label=None,
    xlabel="spatial coordinates",
    plot_envelope=True,
    logy=False,
    alpha_envelope=0.5,
    hatch=None,
):
    """Plot predicted mean curves with optional 5-95% envelopes and raw data.

    Args:
        predicted_dist: 4-tuple ``(xtest, y_means, y_5pcts, y_95pcts)``. The
            y arrays are either 1-D, or 2-D with one output per column (each
            column is plotted as its own curve).
        data: optional observations, either a single ``(xs, ys)`` pair or a
            list of such pairs, drawn as markers.
        name: used in the output file name ``prediction_<name>.pdf``. If it
            contains ``"dir"``, the last output column is not plotted.
        ax: axes to draw on; when ``None``, a new 7x7 figure is created.
        fig_folder: if given, the figure owning ``ax`` is saved there. The
            figure is not closed, so callers can keep drawing on ``ax``.
        label: legend label for the mean curves.
        xlabel: x-axis label.
        plot_envelope: whether to shade the 5-95% band.
        logy: use a logarithmic y axis.
        alpha_envelope: opacity of the shaded band.
        hatch: optional hatch pattern for the shaded band.

    Returns:
        The axes that was drawn on.
    """
    xtest, y_means, y_5pcts, y_95pcts = predicted_dist

    if ax is None:
        plt.figure(figsize=(7, 7))
        gs = grd.GridSpec(1, 1)
        ax = plt.subplot(gs[0])
        gs.update(left=0.14, bottom=0.14, top=0.96, right=0.96)

    if data is not None:
        # Accept either one (xs, ys) pair or a list of them.
        observations = data if isinstance(data, list) else [data]
        for xs0, ys0 in observations:
            ax.plot(xs0, ys0, "o", markersize=5, alpha=0.5)

    if len(y_means.shape) > 1:
        ypack = [
            (y_means[:, k], y_5pcts[:, k], y_95pcts[:, k])
            for k in range(y_means.shape[1])
        ]
        # NOTE(review): for "dir" models the last output column is skipped;
        # presumably it is redundant (e.g. a normalised remainder) - confirm.
        if "dir" in name:
            ypack = ypack[:-1]
    else:
        ypack = [(y_means, y_5pcts, y_95pcts)]

    for ym, y_5pct, y_95pct in ypack:
        ax.plot(xtest, ym, label=label)
        if plot_envelope:
            ax.fill_between(
                xtest,
                y_5pct,
                y_95pct,
                alpha=alpha_envelope,
                hatch=hatch,
            )

    # Work on `ax` explicitly: pyplot's state-machine calls would target the
    # *current* axes, which differs from `ax` when the caller passes one in.
    # Only draw a legend when something is labelled, to avoid matplotlib's
    # "No artists with labels" warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="best", fontsize=16, frameon=False)
    ax.set_xlabel(xlabel, fontsize=20)
    if logy:
        ax.set_yscale("log")
    ax.tick_params(axis="both", labelsize=15)
    if fig_folder is not None:
        ax.figure.savefig(os.path.join(fig_folder, f"prediction_{name}.pdf"))
    return ax


def plot_weights_dist(
    layers,
    ax=None,
    fig_folder=None,
    name="a",
    xlabel="scale magnitude",
):
    """Plot per-layer density histograms of weight and bias scale values.

    Args:
        layers: iterable of 4-tuples ``(weight_loc, weight_scale, bias_loc,
            bias_scale)``. Only the two ``*_scale`` entries are plotted. They
            are presumably torch tensors (TODO confirm); NumPy arrays are
            accepted too.
        ax: axes to draw on; when ``None``, a new 8x8 figure is created.
        fig_folder: if given, the figure is saved there as
            ``<name>_scales_dist.pdf``.
        name: prefix for the hue labels and the output file name.
        xlabel: x-axis label.

    Returns:
        The axes the histogram was drawn on.

    Raises:
        ValueError: if ``layers`` is empty (``pd.concat`` has nothing to join).
    """
    if ax is None:
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111)

    frames = []
    for j, (_, weight_scale, _, bias_scale) in enumerate(layers):
        for kind, values in (("weight", weight_scale), ("bias", bias_scale)):
            frame = pd.DataFrame(_flat_values(values), columns=["data"])
            frame["hue"] = f"{name}_{j}_{kind}"
            frames.append(frame)

    df = pd.concat(frames, ignore_index=True)
    # Bins span [0, 99th percentile]. Values outside the bin range are not
    # counted, so the top 1% (presumably outliers) do not stretch the axis.
    upper = np.percentile(df["data"].to_numpy(), 99)
    sns.histplot(
        data=df,
        hue="hue",
        x="data",
        stat="density",
        bins=np.linspace(0.0, upper, 31),
        common_norm=False,
        common_bins=True,
        element="step",
        ax=ax,
    )
    ax.set_xlabel(xlabel, fontsize=20)
    if fig_folder is not None:
        ax.figure.savefig(os.path.join(fig_folder, f"{name}_scales_dist.pdf"))

    return ax


def _flat_values(values):
    """Return *values* as a flattened 1-D NumPy copy."""
    if hasattr(values, "detach"):
        # Tensor-like input: detach from autograd and move to host memory,
        # since `.numpy()` refuses grad-requiring or non-CPU tensors.
        values = values.detach().cpu().numpy()
    return np.asarray(values).flatten()
