from collections.abc import Sequence, Mapping

import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt

from bayesflow.utils import logging
from bayesflow.utils.dict_utils import dicts_to_arrays


def pairs_samples(
    samples: Mapping[str, np.ndarray] | np.ndarray = None,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    height: float = 2.5,
    color: str | tuple = "#132a70",
    alpha: float = 0.9,
    label_fontsize: int = 14,
    tick_fontsize: int = 12,
    **kwargs,
) -> sns.PairGrid:
    """
    A more flexible pair plot function for multiple distributions based upon
    collected samples.

    Parameters
    ----------
    samples     : dict[str, Tensor], default: None
        Sample draws from any dataset
    variable_keys       : list or None, optional, default: None
       Select keys from the dictionary provided in samples.
       By default, select all keys.
    variable_names    : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    height      : float, optional, default: 2.5
        The height of the pair plot
    color       : str, optional, default : '#8f2727'
        The color of the plot
    alpha       : float in [0, 1], optional, default: 0.9
        The opacity of the plot
    label_fontsize    : int, optional, default: 14
        The font size of the x and y-label texts (parameter names)
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    **kwargs    : dict, optional
        Additional keyword arguments passed to the sns.PairGrid constructor
    """

    plot_data = dicts_to_arrays(
        estimates=samples,
        variable_keys=variable_keys,
        variable_names=variable_names,
    )

    g = _pairs_samples(
        plot_data=plot_data,
        height=height,
        color=color,
        alpha=alpha,
        label_fontsize=label_fontsize,
        tick_fontsize=tick_fontsize,
    )

    return g


def _pairs_samples(
    plot_data: dict,
    height: float = 2.5,
    color: str | tuple = "#132a70",
    color2: str | tuple = "gray",
    alpha: float = 0.9,
    label_fontsize: int = 14,
    tick_fontsize: int = 12,
    legend_fontsize: int = 14,
    **kwargs,
) -> sns.PairGrid:
    # internal version of pairs_samples creating the seaborn plot

    # Parameters
    # ----------
    # plot_data   : output of bayesflow.utils.dict_utils.dicts_to_arrays
    # other arguments are documented in pairs_samples

    estimates_shape = plot_data["estimates"].shape
    if len(estimates_shape) != 2:
        raise ValueError(
            f"Samples for a single distribution should be a matrix, but "
            f"your samples array has a shape of {estimates_shape}."
        )

    variable_names = plot_data["estimates"].variable_names

    # Convert samples to pd.DataFrame
    if plot_data["priors"] is not None:
        # differentiate posterior from prior draws
        # row bind posterior and prior draws
        samples = np.vstack((plot_data["priors"], plot_data["estimates"]))
        data_to_plot = pd.DataFrame(samples, columns=variable_names)

        # ensure that the source of the samples is stored
        source_prior = np.repeat("Prior", plot_data["priors"].shape[0])
        source_post = np.repeat("Posterior", plot_data["estimates"].shape[0])
        data_to_plot["_source"] = np.concatenate((source_prior, source_post))
        data_to_plot["_source"] = pd.Categorical(data_to_plot["_source"], categories=["Prior", "Posterior"])

        # initialize plot
        g = sns.PairGrid(
            data_to_plot,
            height=height,
            hue="_source",
            palette=[color2, color],
            **kwargs,
        )

        # ensures that color doesn't overwrite palette
        color = None

    else:
        # plot just the one set of distributions
        data_to_plot = pd.DataFrame(plot_data["estimates"], columns=variable_names)

        # initialize plot
        g = sns.PairGrid(data_to_plot, height=height, **kwargs)

    # add histograms + KDEs to the diagonal
    g.map_diag(
        histplot_twinx,
        fill=True,
        kde=True,
        color=color,
        alpha=alpha,
        stat="density",
        common_norm=False,
    )

    # add scatterplots to the upper diagonal
    g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

    # add KDEs to the lower diagonal
    try:
        g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha)
    except Exception as e:
        logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
        g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

    # need to add legend here such that colors are recognized
    if plot_data["priors"] is not None:
        g.add_legend(fontsize=legend_fontsize, loc="center right")
        g._legend.set_title(None)

    # Generate grids
    dim = g.axes.shape[0]
    for i in range(dim):
        for j in range(dim):
            g.axes[i, j].grid(alpha=0.5)
            g.axes[i, j].set_axisbelow(True)

    dim = g.axes.shape[0]
    for i in range(dim):
        # Modify tick sizes
        for j in range(i + 1):
            g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
            g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)

        # adjust font size of labels
        # the labels themselves remain the same as before, i.e., variable_names
        g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
        g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)

    # Return figure
    g.tight_layout()

    return g


# create a histogram plot on a twin y axis
# this ensures that the y scaling of the diagonal plots
# in independent of the y scaling of the off-diagonal plots
def histplot_twinx(x, **kwargs):
    # Create a twin axis
    ax2 = plt.gca().twinx()

    # create a histogram on the twin axis
    sns.histplot(x, **kwargs, ax=ax2)

    # make the twin axis invisible
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)
    ax2.set_ylabel("")
    ax2.set_yticks([])
    ax2.set_yticklabels([])

    return None
