from typing import Any, Callable, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from utils.plots import subplots


def density_plots(
    df: pd.DataFrame,
    samples: dict[str, pd.DataFrame],
    pdfs: dict[str, Callable[[str, np.ndarray], np.ndarray]],
    *,
    N_qs: int = 20,
    qs_include_extremes: bool = True,
    hist_kwargs: Optional[dict[str, Any]] = None,
    contour_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[plt.Figure, Any]:
    """Draw a pairwise grid comparing observed data against fitted models.

    Only the numeric columns of ``df`` are used.

    Diagonal cells show:
      * a density-normalised histogram of the data column;
      * each model's marginal pdf, evaluated at quantiles of that column.

    Off-diagonal cells (row ``j``, column ``i``) show bivariate KDE contours
    with the column-``i`` variable on x and the column-``j`` variable on y:
      * one set for the data (color ``C0``);
      * one set for each sample in ``samples`` (colors ``C1``, ``C2``, ...).

    Parameters
    ----------
    df
        Observed data. Non-numeric columns are dropped.
    samples
        Mapping ``label -> DataFrame``. Each frame is drawn with
        ``sns.kdeplot`` using the same column names as ``df``, so it must
        contain those columns.
    pdfs
        Mapping ``label -> pdf``. ``pdf(column_name, x_values)`` must return
        the density at ``x_values`` for that column.
        NOTE: colors are assigned by enumeration order, separately for
        ``pdfs`` and ``samples``. Keep the two mappings in the same key order
        if the same color should denote the same model.
    N_qs
        Number of points at which each pdf is evaluated per column.
    qs_include_extremes
        If True, the evaluation points span the 0th to 100th percentiles
        (column min to max). If False, those two endpoints are excluded and
        ``N_qs`` interior quantiles are used.
    hist_kwargs
        Extra keyword arguments forwarded to ``Axes.hist`` on the diagonal.
    contour_kwargs
        Extra keyword arguments forwarded to ``sns.kdeplot`` for the
        ``samples`` contours only. The data contours do not receive them.

    Returns
    -------
    (fig, axes_mat)
        The figure and axes produced by ``utils.plots.subplots``.
        ``axes_mat`` is iterated here as rows of axes; presumably it is an
        ``n_cols x n_cols`` grid (TODO confirm ``subplots`` layout, especially
        its ``squeeze=True`` behavior for a single numeric column).
    """
    # Use None sentinels instead of ``dict()`` defaults: a default dict is
    # created once at definition time and would be shared by every call.
    hist_kwargs = {} if hist_kwargs is None else hist_kwargs
    contour_kwargs = {} if contour_kwargs is None else contour_kwargs

    numeric_cols = [
        col for col, dtype in df.dtypes.items()
        if pd.api.types.is_numeric_dtype(dtype)
    ]
    df = df[numeric_cols]

    # Probability levels where the pdfs are evaluated. This is the same for
    # every column, so compute it once outside the loop.
    if qs_include_extremes:
        probs = np.linspace(0, 1, N_qs)
    else:
        # Make N_qs + 2 evenly spaced levels, then drop 0 and 1, leaving
        # N_qs strictly interior levels.
        probs = np.linspace(0, 1, N_qs + 2)[1:-1]

    fig, axes_mat = subplots(n=df.shape[1] ** 2, squeeze=True)
    for j, (axes, colj) in enumerate(zip(axes_mat, df.columns)):
        for i, (ax, coli) in enumerate(zip(axes, df.columns)):
            if i == j:
                ax.hist(df[coli], label='data', density=True, **hist_kwargs)
                # Evaluate at data quantiles so the points follow where the
                # data actually lies.
                qs = df[coli].quantile(probs).values

                # Plot each model's marginal density over the histogram.
                for k, (label, pdf) in enumerate(pdfs.items(), 1):
                    ax.plot(qs, pdf(coli, qs), label=label, color=f'C{k}')

                if i == 0:  # only the top-left diagonal cell gets a legend
                    ax.legend()
            else:
                sns.kdeplot(
                    data=df, x=coli, y=colj,
                    color='C0', label='data', ax=ax
                )
                for k, (label, sample) in enumerate(samples.items(), 1):
                    sns.kdeplot(
                        data=sample, x=coli, y=colj,
                        color=f'C{k}', label=label, ax=ax,
                        legend=True,
                        **contour_kwargs
                    )

    return fig, axes_mat
