
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import torch

from geocalib.perspective_fields import get_perspective_field
from geocalib.utils import rad2deg



def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
    """Display images side by side in a new single-row figure.

    Args:
        imgs: sequence of images; each must expose ``.shape`` with the height
            at index 0 and the width at index 1 (used for column widths).
        titles: optional sequence of per-image titles; skipped when falsy.
        cmaps: one colormap name for every image, or a list/tuple holding one
            colormap per image.
        dpi: resolution of the created figure.
        pad: padding passed to ``Figure.tight_layout``.
        adaptive: if True, each column's width follows its image's W/H ratio;
            otherwise every column uses a 4:3 ratio.

    Returns:
        The newly created figure.
    """
    count = len(imgs)
    if isinstance(cmaps, (list, tuple)):
        cmap_names = cmaps
    else:
        cmap_names = [cmaps] * count

    row_height = 4.5
    widths = (
        [img.shape[1] / img.shape[0] for img in imgs] if adaptive else [4 / 3] * count
    )
    fig, axes = plt.subplots(
        1,
        count,
        figsize=[sum(widths) * row_height, row_height],
        dpi=dpi,
        gridspec_kw={"width_ratios": widths},
    )
    # With a single column, plt.subplots returns one Axes rather than an array.
    if count == 1:
        axes = [axes]

    for idx, (ax, img) in enumerate(zip(axes, imgs)):
        ax.imshow(img, cmap=plt.get_cmap(cmap_names[idx]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[idx])

    fig.tight_layout(pad=pad)
    return fig


def plot_image_grid(
    imgs,
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    fig=None,
    adaptive=True,
    figs=3.0,
    return_fig=False,
    set_lim=False,
) -> plt.Figure:
    """Lay out a 2D grid of images (rows x columns) on one figure.

    Args:
        imgs: list of rows, each a list of images exposing ``.shape`` with the
            height at index 0 and the width at index 1. The number of columns
            is taken from the first row.
        titles: optional nested titles, indexed as ``titles[row][col]``.
        cmaps: one colormap for every image, or a list/tuple with one per column.
        dpi: resolution of a newly created figure (unused when ``fig`` is given).
        pad: padding for ``tight_layout``, applied only when ``fig`` is a
            ``plt.Figure``.
        fig: existing figure (or object with ``.subplots`` and ``.figure``) to
            draw into; a new figure is created when None.
        adaptive: if True, column widths follow the W/H ratio of the first
            row's images; otherwise every column uses 4:3.
        figs: size unit in inches; each row is ``figs`` inches tall.
        return_fig: if True, return ``(fig, axs)`` instead of only ``axs``.
        set_lim: if True, clamp each axes' limits to its image extent.

    Returns:
        ``axs`` indexable as ``axs[row][col]``, or ``(fig, axs)`` when
        ``return_fig`` is True (despite the ``plt.Figure`` annotation).
    """
    n_rows, n_cols = len(imgs), len(imgs[0])
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n_cols

    ratios = (
        [im.shape[1] / im.shape[0] for im in imgs[0]]  # W / H
        if adaptive
        else [4 / 3] * n_cols
    )
    size = [sum(ratios) * figs, n_rows * figs]

    if fig is not None:
        axs = fig.subplots(n_rows, n_cols, gridspec_kw={"width_ratios": ratios})
        fig.figure.set_size_inches(size)
    else:
        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=size, dpi=dpi, gridspec_kw={"width_ratios": ratios}
        )

    # subplots squeezes singleton dimensions; restore an axs[row][col] layout.
    if n_rows == 1 and n_cols == 1:
        axs = [[axs]]
    elif n_cols == 1:
        axs = axs[:, None]
    elif n_rows == 1:
        axs = [axs]

    for r, row_axes in enumerate(axs):
        for c, ax in enumerate(row_axes):
            image = imgs[r][c]
            ax.imshow(image, cmap=plt.get_cmap(cmaps[c]))
            ax.set_axis_off()
            if set_lim:
                ax.set_xlim([0, image.shape[1]])
                ax.set_ylim([image.shape[0], 0])
            if titles:
                ax.set_title(titles[r][c])

    if isinstance(fig, plt.Figure):
        fig.tight_layout(pad=pad)
    if return_fig:
        return fig, axs
    return axs


def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=4,
    ha="left",
    va="top",
    axes=None,
    **kwargs,
):
    """Draw ``text`` on ``axes[idx]`` using axes-relative coordinates.

    Args:
        idx: index of the target axes within ``axes``.
        text: string to display.
        pos: position in axes coordinates, unpacked positionally into
            ``Axes.text``; the default sits near the top-left corner.
        fs: font size.
        color: text colour.
        lcolor: colour of the outline stroked around the glyphs; None
            disables the outline.
        lwidth: outline stroke width.
        ha, va: horizontal / vertical alignment.
        axes: sequence of axes to pick from; defaults to all axes of the
            current figure.
        **kwargs: forwarded to ``Axes.text``.

    Returns:
        The created text artist.
    """
    target = (plt.gcf().axes if axes is None else axes)[idx]

    artist = target.text(
        *pos,
        text,
        transform=target.transAxes,
        fontsize=fs,
        color=color,
        ha=ha,
        va=va,
        zorder=5,
        **kwargs,
    )
    if lcolor is None:
        return artist

    # A wide stroke drawn underneath the normal glyphs keeps text readable
    # on any background.
    outline = path_effects.Stroke(linewidth=lwidth, foreground=lcolor)
    artist.set_path_effects([outline, path_effects.Normal()])
    return artist


def _add_heatmap_contours(ax, heatmap, vmin, vmax, step, cmap, linestyle):
    """Draw labelled contour lines of ``heatmap`` every ``step`` from vmin to vmax.

    Lines are coloured with ``cmap`` normalised to [vmin, vmax] and labelled
    with the rounded level value followed by a degree sign.

    Returns:
        The created ContourSet.
    """
    # Contour levels need concrete bounds; fall back to the data range when
    # the colour limits were left to matplotlib's autoscaling (None).
    if vmin is None:
        vmin = float(np.nanmin(heatmap))
    if vmax is None:
        vmax = float(np.nanmax(heatmap))

    levels = np.arange(vmin, vmax + step, step)
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    colormap = plt.colormaps.get_cmap(cmap)
    contours = ax.contour(
        heatmap,
        levels=levels,
        linewidths=2,
        colors=[colormap(norm(level)) for level in levels],
        linestyles=linestyle,
    )
    contours.set_clim(vmin, vmax)

    # Round before casting: astype(int) truncates toward zero, so a level
    # such as 14.999999 (e.g. produced by the default vmin=-1e-6) would
    # otherwise be labelled "14°" instead of "15°".
    fmt = {
        level: f"{label}°"
        for level, label in zip(levels, np.round(levels).astype(int))
    }
    labels = ax.clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")
    for label in labels:
        label.set_path_effects(
            [
                path_effects.Stroke(linewidth=1, foreground="k"),
                path_effects.Normal(),
            ]
        )
    return contours


def plot_heatmaps(
    heatmaps,
    vmin=-1e-6,  # include negative zero
    vmax=None,
    cmap="Spectral",
    a=0.5,
    axes=None,
    contours_every=None,
    contour_style="solid",
    colorbar=False,
):
    """Overlay one heatmap per axes, optionally with a colorbar and contours.

    Args:
        heatmaps: sequence of 2D numpy arrays or torch tensors; ``heatmaps[i]``
            is drawn on ``axes[i]``. The sequence itself is not modified.
        vmin, vmax: colour limits passed to ``imshow``. ``vmax=None`` lets
            matplotlib autoscale; contours then extend to the data maximum.
        cmap: colormap name used for the image and the contour lines.
        a: overlay opacity, either one number for all axes or a sequence with
            one value per axes.
        axes: target axes; defaults to all axes of the current figure.
        contours_every: if given, draw labelled contour lines every this many
            units from ``vmin`` to ``vmax``.
        contour_style: linestyle of the contour lines.
        colorbar: if True, add a colorbar next to each axes.

    Returns:
        List of created artists, per axes in order: the colorbar (if any), the
        image, and the contour set (if any).
    """
    if axes is None:
        axes = plt.gcf().axes
    artists = []

    for i, ax in enumerate(axes):
        # np.isscalar also accepts ints and numpy scalars, not only floats.
        alpha = a if np.isscalar(a) else a[i]

        heatmap = heatmaps[i]
        if isinstance(heatmap, torch.Tensor):
            # Convert a local copy so the caller's sequence is left untouched
            # (and tuples work); detach() allows tensors that require grad.
            heatmap = heatmap.detach().cpu().numpy()

        art = ax.imshow(heatmap, alpha=alpha, vmin=vmin, vmax=vmax, cmap=cmap)
        if colorbar:
            # Explicit None check so that a legitimate vmax of 0 is honoured.
            cmax = vmax if vmax is not None else np.percentile(heatmap, 99)
            art.set_clim(vmin, cmax)
            artists.append(plt.colorbar(art, ax=ax))
        artists.append(art)

        if contours_every is not None:
            artists.append(
                _add_heatmap_contours(
                    ax, heatmap, vmin, vmax, contours_every, cmap, contour_style
                )
            )

    return artists


def plot_horizon_lines(
    cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
):
    """Draw the horizon line of each camera as the zero level of its latitude field.

    Args:
        cameras, gravities: per-image camera / gravity objects, passed pairwise
            to ``get_perspective_field``; its second output is used as the
            latitude field (assumed batched, first element taken via ``[0, 0]``).
        line_colors: one colour for all lines, or a list with one per camera.
        lw: line width.
        styles: one linestyle for all lines, or a list with one per camera.
        alpha: line opacity, either one value or a list with one per camera.
        ax: a single Axes used for every camera, or a list with one Axes per
            camera; defaults to the current axes.

    Returns:
        List of the created contour sets, one per camera.
    """
    n = len(cameras)
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * n
    if not isinstance(styles, list):
        styles = [styles] * n
    if not isinstance(alpha, list):
        alpha = [alpha] * n

    if ax is None:
        ax = plt.gcf().gca()
    if isinstance(ax, plt.Axes):
        ax = [ax] * n

    assert len(ax) == n, f"{len(ax)}, {n}"

    artists = []
    for i in range(n):
        _, lat = get_perspective_field(cameras[i], gravities[i])
        # horizon line is zero level of the latitude field
        lat = lat[0, 0].detach().cpu().numpy()
        # Pass the linestyle to contour() directly: ContourSet.collections was
        # deprecated in Matplotlib 3.8 and removed in 3.10. Wrapping it in a
        # list keeps dash tuples intact (a bare tuple would be read as several
        # per-level styles).
        contours = ax[i].contour(
            lat,
            levels=[0],
            linewidths=lw,
            colors=line_colors[i],
            linestyles=[styles[i]],
            alpha=alpha[i],
        )
        artists.append(contours)
    return artists


def plot_vector_fields(
    vector_fields,
    cmap="lime",
    subsample=15,
    scale=None,
    lw=None,
    alphas=0.8,
    axes=None,
):
    """Draw a sparse grid of arrows for each 2D vector field.

    Args:
        vector_fields: sequence of numpy arrays or torch tensors, each indexed
            as (2, H, W) with the x and y components first; field ``i`` is
            drawn on ``axes[i]``. All fields share the first field's H and W.
        cmap: a single matplotlib colour (str) for every arrow, or a sequence
            with one (H, W, 3) colour image per field.
        subsample: roughly the number of arrows along the shorter image side.
        scale: quiver scale relative to the shorter side; defaults to
            ``subsample / min(H, W)``.
        lw: arrow shaft width (quiver ``width``); defaults to ``0.1 / subsample``.
        alphas: None (opaque), one number applied everywhere, or per-field
            (H, W) opacity maps.
        axes: target axes; defaults to all axes of the current figure.

    Returns:
        List of the created quiver artists.
    """
    if axes is None:
        axes = plt.gcf().axes

    vector_fields = [
        v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
        for v in vector_fields
    ]

    artists = []

    H, W = vector_fields[0].shape[-2:]
    if scale is None:
        scale = subsample / min(H, W)

    if lw is None:
        lw = 0.1 / subsample

    if alphas is None:
        alphas = 1.0
    if np.isscalar(alphas):
        # Broadcast one opacity (int or float) to a per-pixel map per field.
        alpha_map = np.ones_like(vector_fields[0][0]) * alphas
        alphas = np.stack([alpha_map] * len(vector_fields), 0)
    else:
        alphas = np.array(alphas)

    # Pixel spacing between arrows. Clamp to 1 so fields smaller than
    # `subsample` pixels do not yield a zero step (W % 0, arange step 0).
    step = max(1, min(W, H) // subsample)
    offset_x = ((W % step) + step) // 2

    samples_x = np.arange(offset_x, W, step)
    samples_y = np.arange(int(step * 0.9), H, step)

    x_grid, y_grid = np.meshgrid(samples_x, samples_y)

    # Quiver units follow the shorter image side, matching `scale` above.
    units = "width" if H > W else "height"

    for i, ax in enumerate(axes):
        vector_field = vector_fields[i]

        a = alphas[i][samples_y][:, samples_x]
        x, y = vector_field[:, samples_y][:, :, samples_x]

        c = cmap
        if not isinstance(cmap, str):
            c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)

        arrows = ax.quiver(
            x_grid,
            y_grid,
            x,
            y,
            scale=scale * min(H, W),
            scale_units=units,
            units=units,
            alpha=a,
            color=c,
            angles="xy",
            antialiased=True,
            width=lw,
            headaxislength=3.5,
            zorder=5,
        )

        artists.append(arrows)

    return artists


def plot_latitudes(
    latitude,
    is_radians=True,
    vmin=-90,
    vmax=90,
    cmap="seismic",
    contours_every=15,
    alpha=0.4,
    axes=None,
    **kwargs,
):
    """Overlay latitude maps in degrees, with contour lines, one per axes.

    Args:
        latitude: sequence of latitude maps, one per axes.
        is_radians: if True, convert each map with ``rad2deg`` first.
        vmin, vmax: colour / contour range in degrees.
        cmap: colormap name.
        contours_every: spacing of the contour lines in degrees.
        alpha: overlay opacity.
        axes: target axes; defaults to all axes of the current figure.
        **kwargs: extra options forwarded to ``plot_heatmaps``.

    Returns:
        The artists returned by ``plot_heatmaps``.
    """
    axes = plt.gcf().axes if axes is None else axes
    assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"

    if is_radians:
        latitude = [rad2deg(field) for field in latitude]

    return plot_heatmaps(
        latitude,
        a=alpha,
        axes=axes,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        contours_every=contours_every,
        **kwargs,
    )


def plot_perspective_fields(cameras, gravities, axes=None, **kwargs):
    """Overlay up-vector arrows and latitude maps for each camera.

    Args:
        cameras, gravities: per-image camera / gravity objects, passed pairwise
            to ``get_perspective_field``. Its outputs are used as the up field
            (first element via ``[0]``) and the latitude field (via ``[0, 0]``);
            they are presumably batched, so confirm this against
            ``get_perspective_field``.
        axes: target axes, one per camera; defaults to all axes of the current
            figure.
        **kwargs: routed by name. Options accepted only by
            ``plot_vector_fields`` (e.g. ``subsample``, ``alphas``) go there.
            Options accepted only by ``plot_latitudes``, or by the
            ``plot_heatmaps`` it forwards to (e.g. ``alpha``,
            ``contours_every``), go there. Names accepted by both (e.g.
            ``cmap``) and unknown names are passed to both.

    Returns:
        All artists created by the two plotting calls, per camera in order.
    """
    import inspect

    if axes is None:
        axes = plt.gcf().axes

    assert len(axes) == len(cameras), f"{len(axes)}, {len(cameras)}"

    def _named_params(fn):
        """Return the names of ``fn``'s explicit (non-variadic) parameters."""
        return {
            name
            for name, param in inspect.signature(fn).parameters.items()
            if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        }

    vf_params = _named_params(plot_vector_fields)
    lat_params = _named_params(plot_latitudes) | _named_params(plot_heatmaps)
    # Previously every kwarg went to both functions, so any option specific
    # to one of them raised TypeError in the other.
    vf_kwargs = {
        k: v for k, v in kwargs.items() if k in vf_params or k not in lat_params
    }
    lat_kwargs = {
        k: v for k, v in kwargs.items() if k in lat_params or k not in vf_params
    }

    artists = []
    for i, ax in enumerate(axes):
        up, lat = get_perspective_field(cameras[i], gravities[i])
        artists += plot_vector_fields([up[0]], axes=[ax], **vf_kwargs)
        artists += plot_latitudes([lat[0, 0]], axes=[ax], **lat_kwargs)

    return artists


def _minmax_normalize(c):
    """Rescale ``c`` linearly to [0, 1]; a constant map becomes all zeros."""
    lo, hi = c.min(), c.max()
    span = hi - lo
    # A constant map (span == 0) would otherwise compute 0 / 0 and give NaNs.
    return (c - lo) / span if span > 0 else c - lo


def plot_confidences(
    confidence,
    as_log=True,
    vmin=-4,
    vmax=0,
    cmap="turbo",
    alpha=0.4,
    axes=None,
    **kwargs,
):
    """Overlay per-image confidence maps, min-max normalised to [0, 1].

    Args:
        confidence: sequence of confidence maps, one per axes. They must be
            torch tensors when ``as_log`` is True, because ``torch.log10`` is
            applied.
        as_log: if True, take log10 of each map (values clipped below at
            1e-5) and clip the result to [vmin, vmax] before normalising.
        vmin, vmax: clipping range in log10 space; only used when ``as_log``.
        cmap: colormap name.
        alpha: overlay opacity.
        axes: target axes; defaults to all axes of the current figure.
        **kwargs: extra options forwarded to ``plot_heatmaps``.

    Returns:
        The artists returned by ``plot_heatmaps``.
    """
    if axes is None:
        axes = plt.gcf().axes

    assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"

    if as_log:
        confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]

    # normalize each map to [0, 1] using its own min / max
    confidence = [_minmax_normalize(c) for c in confidence]
    return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)


def save_plot(path, **kw):
    """Write the current pyplot figure to ``path``, cropped tight with zero padding.

    Extra keyword arguments are forwarded to ``plt.savefig``. Passing
    ``bbox_inches`` or ``pad_inches`` raises TypeError, because both are fixed here.
    """
    tight_crop = dict(bbox_inches="tight", pad_inches=0)
    plt.savefig(path, **tight_crop, **kw)
