import logging
import os
import sys
import functools
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import scipy.spatial
import matplotlib.pyplot as plt
import matplotlib.patches
import matplotlib
import seaborn

import mpl_toolkits.mplot3d
import mpl_toolkits.mplot3d.art3d
import mpl_toolkits

import tools

# Colour configuration for the plotting helpers in this module.
# To preview a seaborn palette interactively, e.g.
#     seaborn.palplot(seaborn.color_palette("dark", 12))
# (see also ``palette_example``).
USE_HARDCODED_COLORS = True  # True: get_palette_of_length cycles COLORS; False: seaborn SEABORN_PALETTE
COLORS = ("b", "g", "r", "c", "m", "y")  # matplotlib single-letter colour codes
ALPHAS = np.linspace(0, 1, 6)  # NOTE(review): not referenced by the visible code in this module
SEABORN_PALETTE = "colorblind"


# Type aliases used in annotations throughout this module.
Color = Union[str, Tuple[float, float, float]]  # colour name/code or RGB triple
Plotlims = Tuple[float, float]  # (lower, upper) axis limits
FigAx = Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]

logger = logging.getLogger(__name__)


def _vec(x: np.ndarray) -> np.ndarray:
    return np.reshape(x, (-1, 1))


def _repeat_to_length_n(x: list, n: int) -> list:
    lenx = len(x)
    times_to_repeat = int(np.ceil(n / lenx))
    repeated = x * times_to_repeat
    return repeated[:n]


def get_default_interactive_backend() -> str:
    """Return the interactive matplotlib backend to switch back to.

    Only macOS is supported, where the native ``'MacOSX'`` backend is used.

    :raises NotImplementedError: on any other platform. This is an explicit
        check rather than an ``assert``, because ``python -O`` strips asserts
        and would silently return a backend that does not exist there.
    """
    if sys.platform != 'darwin':
        raise NotImplementedError("Need to generalize this if not running on OSX")
    return 'MacOSX'


def plot_list_of_v_forms(
    ax: matplotlib.axes.Axes,
    plist: List[np.ndarray],
    alpha: float,
    colors: Optional[List[Color]] = None) -> None:
    """Draw every V-form in ``plist`` onto ``ax``, clipped to one shared box.

    The box comes from ``get_plotlims_from_list_of_v_forms(plist)``, and each
    V-form is drawn with ``convex_hull_plot``.

    :param ax: axes to draw on.
    :param plist: V-representations to draw. An empty list draws nothing.
    :param alpha: fill transparency used for every shape.
    :param colors: one colour per V-form, cycled if shorter than ``plist``.
        ``None`` or empty falls back to ``get_palette_of_length``. Previously
        this argument was ignored and the default palette was always used.
    """
    if len(plist) == 0:
        # Nothing to draw, and no data to derive plot limits from.
        return
    num_v_forms = len(plist)
    if colors is not None and len(colors) > 0:
        colors = _repeat_to_length_n(list(colors), num_v_forms)
    else:
        colors = get_palette_of_length(num_v_forms)
    xlim, ylim = get_plotlims_from_list_of_v_forms(plist)
    for v_form, color in zip(plist, colors):
        convex_hull_plot(ax, v_form, xlim, ylim, color, alpha)


def _close_rays_with_bounds(
    v_repr: np.ndarray, xlim: Tuple[float, float], ylim: Tuple[float, float]
) -> np.ndarray:
    """ Very wasteful (unnecessary V -> H -> V) completion of
    rays to polytopes.
    This should be okay, since for now it is only to support plotting
    which is inherently limited to a low number of dimensions.

    Converts ``v_repr`` to an H-representation (``tools.v_to_h``), appends the
    four inequalities of the box ``xlim x ylim``, and converts back
    (``tools.h_to_v``). Assuming ``tools`` converts faithfully, the result
    describes the original polyhedron intersected with the box.

    :param v_repr: 3-column V-representation: a leading column followed by
        2-d coordinates. NOTE(review): presumably the cdd convention, where a
        leading 1 marks a vertex and a leading 0 a ray; confirm against ``tools``.
    :param xlim: ``(lower, upper)`` x bounds of the clipping box.
    :param ylim: ``(lower, upper)`` y bounds of the clipping box.
    :returns: V-representation of the clipped set. Empty input is returned
        unchanged. For non-empty input every returned row has a leading 1
        (asserted below).
    """
    assert 3 == v_repr.shape[1], "Assuming 2d for now"

    # Box corners as (2, 1) column vectors: [x, y] of the lower / upper corner.
    lower = _vec(np.array([xlim[0], ylim[0]]))
    upper = _vec(np.array([xlim[1], ylim[1]]))

    # Box written as A z <= b with A = [I; -I] and b = [upper; -lower].
    n = len(upper)
    a_ineq = np.vstack([+1 * np.eye(n), -1 * np.eye(n)])
    b_ineq = np.vstack([+1 * upper, -1 * lower])

    # Rows laid out as [b | -A], i.e. b - A z >= 0.
    # NOTE(review): presumably the (cdd-style) H-representation layout used by
    # tools.v_to_h / tools.h_to_v; confirm.
    to_augment_ineq = np.hstack([b_ineq, -1 * a_ineq])

    if 0 == v_repr.size:
        v_repr_bounded = v_repr
    else:
        # NOTE(review): is_rational presumably selects exact (rational)
        # arithmetic inside tools; confirm.
        is_rational = True
        h_repr = tools.v_to_h(v_repr, None, is_rational)
        h_repr_augmented = np.vstack([h_repr, to_augment_ineq])
        # Empty block with matching width, presumably the equality
        # ("linearity") rows: no equalities are added.
        h_repr_augmented_lin = np.empty((0, h_repr_augmented.shape[1]))
        v_repr_bounded = tools.h_to_v(h_repr_augmented, h_repr_augmented_lin, is_rational)
        # After clipping to a bounded box, every generator should have a leading 1.
        assert (0 == v_repr_bounded.size) or np.all(1 == v_repr_bounded[:, 0])
    return v_repr_bounded


def xxyyzz_grid(pred_func: Callable,
                x_min: float,
                x_max: float,
                y_min: float,
                y_max: float,
                h: float = 0.01) -> Tuple[np.ndarray,
                                          np.ndarray,
                                          np.ndarray]:
    """Evaluate ``pred_func`` on a regular grid covering a rectangle.

    :param pred_func: vectorised function taking an ``(N, 2)`` array of
        ``(x, y)`` points and returning ``N`` values.
    :param x_min: left edge of the grid (inclusive).
    :param x_max: right edge (exclusive, as with ``np.arange``).
    :param y_min: bottom edge of the grid (inclusive).
    :param y_max: top edge (exclusive, as with ``np.arange``).
    :param h: grid spacing in both directions. The default 0.01 is the value
        that used to be hard-coded.
    :returns: ``(xx, yy, zz)``, 2-d arrays of identical shape as produced by
        ``np.meshgrid``, where ``zz`` holds the predictions reshaped to the grid.
    """
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    zz = np.reshape(pred_func(grid_points), xx.shape)
    return xx, yy, zz


def _hacky_binary_contourf(
    ax: matplotlib.axes.Axes,
    xx: np.ndarray,
    yy: np.ndarray,
    zz: np.ndarray,
    color: Color,
    alpha: float,
    seed: Optional[int] = 0,
) -> None:
    """Fill, in a single colour, the grid cells where ``zz`` is non-zero.

    Zeros are pushed far outside the single contour band ``[0, 2]``, so only
    non-zero cells with values in that band (presumably the 1s of a 0/1
    indicator; TODO confirm with callers) get filled. A tiny jitter, uniform in
    ``[-0.005, 0.005)``, is added to every cell, presumably so that
    ``contourf`` never sees a perfectly flat field.

    :param ax: axes to draw on.
    :param xx: grid x coordinates, as from ``np.meshgrid``.
    :param yy: grid y coordinates, as from ``np.meshgrid``.
    :param zz: values on the same grid; not modified.
    :param color: fill colour.
    :param alpha: fill transparency.
    :param seed: seed for the jitter's private random generator. The jitter no
        longer draws from (or advances) the global ``np.random`` state, so plots
        are reproducible. Pass ``None`` for fresh entropy on every call.
    """
    zz = zz.astype(float)  # astype copies, so the caller's array is untouched
    zz[0 == zz] = +9999.9
    rng = np.random.default_rng(seed)
    zz = zz + (rng.random(zz.shape) - 0.5) * 0.01
    ax.contourf(xx, yy, zz, levels=[0, 2], colors=[color], alpha=alpha)


def _add_class_to_bruteforced_decision_boundary_plot(
    ax: matplotlib.axes.Axes,
    pred_func: Callable,
    x: np.ndarray,
    y: np.ndarray,
    color: Color,
    ind: int,
) -> None:
    """Shade one classifier's predicted region and scatter that class's points.

    ``pred_func`` is evaluated on a dense grid over the padded bounding box of
    ``x`` and shaded via ``_hacky_binary_contourf`` (alpha 0.4). Then the rows
    of ``x`` whose label in ``y`` equals ``ind`` are drawn as small dots in the
    same colour.
    """
    (x_min, x_max), (y_min, y_max) = get_plotlims_from_pointcloud(x)
    xx, yy, zz = xxyyzz_grid(pred_func, x_min, x_max, y_min, y_max)
    _hacky_binary_contourf(ax, xx, yy, zz, color, 0.4)

    in_class = (y == ind).flatten()
    ax.scatter(x[in_class, 0], x[in_class, 1], color=color, s=2)


def get_palette_of_length(num_colors: int) -> list:
    """Return ``num_colors`` colours.

    The colours come from ``SEABORN_PALETTE`` unless ``USE_HARDCODED_COLORS``
    is set, in which case ``COLORS`` is cycled.
    """
    if not USE_HARDCODED_COLORS:
        return seaborn.color_palette(SEABORN_PALETTE, num_colors)
    return _repeat_to_length_n(COLORS, num_colors)


def pred_func_unwrapped(x: np.ndarray,
                        model) -> np.ndarray:
    """Return the model's softmax probability of class 0 for each row of ``x``.

    :param x: array of inputs, one sample per row; converted to a float32 tensor.
    :param model: callable (e.g. a ``torch.nn.Module``) mapping that tensor to
        logits with classes along dim 1.
    :returns: 1-d float32 array ``softmax(model(x), dim=1)[:, 0]``.
    """
    inputs = torch.as_tensor(x, dtype=torch.float32)
    # Inference only: skip building an autograd graph. This saves memory and
    # time on the dense evaluation grids this function is used with.
    with torch.no_grad():
        logits = model(inputs)
        probs = torch.nn.functional.softmax(logits, dim=1)
    return probs[:, 0].numpy()


def bruteforced_prob_contour_plot(x: np.ndarray,
                                  y: np.ndarray,
                                  model,
                                  include_contours: bool,
                                  include_axes: bool,
                                  plot_scale: float
                                  ) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Scatter a two-class 2-d dataset over the model's class-0 probability.

    The model is evaluated (through ``pred_func_unwrapped``) on a dense grid
    covering the padded bounding box of ``x``. The result is drawn as filled
    greyscale contours underneath the data points.

    :param x: ``(num_points, 2)`` array of inputs.
    :param y: labels for the rows of ``x``. Points labelled 0 and 1 are drawn
        in the first two palette colours. Both ``(n,)`` and ``(n, 1)`` label
        arrays are accepted.
    :param model: model passed to ``pred_func_unwrapped``.
    :param include_contours: draw the probability contours with alpha 0.4.
        When False they are drawn fully transparent.
    :param include_axes: apply ``_style_plot``; otherwise hide both axes.
    :param plot_scale: figure size in inches (single panel).
    :returns: ``(fig, ax)``.
    """
    pred_func = functools.partial(pred_func_unwrapped, model=model)
    (x_min, x_max), (y_min, y_max) = get_plotlims_from_pointcloud(x)
    xx, yy, zz = xxyyzz_grid(pred_func, x_min, x_max, y_min, y_max)

    color0, color1 = get_palette_of_length(2)
    # Flatten so column-vector labels (n, 1) also give a valid row mask,
    # consistent with _add_class_to_bruteforced_decision_boundary_plot.
    labels = np.ravel(y)

    fig, axs = wrapped_subplot(1, 1, plot_scale)
    ax = axs[0, 0]

    point_size = 1
    for label, color in ((0, color0), (1, color1)):
        rows = labels == label
        ax.scatter(x[rows, 0], x[rows, 1], color=color, s=point_size)

    # The contour is drawn even when hidden (alpha 0), presumably so the axes
    # end up in the same state either way. The colormap is passed by name
    # because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    contour_alpha = 0.4 if include_contours else 0
    ax.contourf(xx, yy, zz, alpha=contour_alpha, cmap="Greys")

    if include_axes:
        _style_plot(ax)
    else:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    return fig, ax


def bruteforced_decision_boundary_plot(x: np.ndarray,
                                       y: np.ndarray,
                                       classifiers: List[Callable],
                                       plot_scale: float
                                       ) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Overlay the predicted regions of several per-class classifiers.

    Classifier ``k`` is shaded in the ``k``-th palette colour, together with
    the points of ``x`` labelled ``k`` in ``y``. The axes then get the
    standard ``_style_plot`` grid.

    :returns: ``(fig, ax)`` for a single panel of size ``plot_scale``.
    """
    fig, axs = wrapped_subplot(1, 1, plot_scale)
    palette = get_palette_of_length(len(classifiers))
    ax = axs[0, 0]

    for class_idx, (classifier, color) in enumerate(zip(classifiers, palette)):
        _add_class_to_bruteforced_decision_boundary_plot(
            ax, classifier, x, y, color, class_idx
        )

    _style_plot(ax)
    return fig, ax


def convex_hull_plot_simple(v_repr: np.ndarray) -> FigAx:
    """Draw one V-form on a fresh single-panel figure.

    The plot limits are the V-form's bounding box, widened by 30% of its
    extent on each side.
    """
    fig, axs = wrapped_subplot(1, 1)
    ax = axs[0, 0]
    xlim, ylim = get_plotlims_from_v_form(v_repr, .3)
    convex_hull_plot(ax, v_repr, xlim, ylim)
    return fig, ax


def list_of_convex_hull_plot_simple(lv: List[np.ndarray]) -> FigAx:
    """Draw several V-forms on one panel, each in its own colour.

    All shapes are clipped to a box derived from the whole list (empty
    V-forms included). Empty V-forms are then skipped when drawing. Colours
    come from seaborn's "dark" palette, with alpha 0.4.
    """
    fig, axs = wrapped_subplot(1, 1)
    ax = axs[0, 0]
    xlim, ylim = get_plotlims_from_list_of_v_forms(lv)

    nonempty = [v_repr for v_repr in lv if v_repr.size > 0]
    palette = seaborn.color_palette("dark", len(nonempty))

    for v_repr, color in zip(nonempty, palette):
        convex_hull_plot(ax, v_repr, xlim, ylim, color=color, alpha=0.4)

    return fig, ax


def convex_hull_plot_simple_vectorised(
    lv_repr: List[np.ndarray]
) -> Tuple[matplotlib.figure.Figure, np.ndarray]:
    """Draw each V-form in its own panel of a single-row figure.

    All panels share the same limits, derived from the whole list, so they are
    directly comparable.

    :param lv_repr: V-representations, one per panel.
    :returns: ``(fig, axs)`` where ``axs`` is the ``(1, len(lv_repr))`` array of
        axes. The annotation previously named ``np.array``, which is a function,
        not a type.
    :raises AssertionError: if the shared limits are not finite (e.g. every
        V-form is empty).
    """
    xlim, ylim = get_plotlims_from_list_of_v_forms(lv_repr)
    assert np.all(np.isfinite(xlim)) and np.all(np.isfinite(ylim))
    fig, axs = wrapped_subplot(1, len(lv_repr))
    for ax, v_repr in zip(axs[0], lv_repr):
        convex_hull_plot(ax, v_repr, xlim, ylim, color="blue")
    return fig, axs


def convex_hull_plot(
    ax: matplotlib.axes.Axes,
    v_repr: np.ndarray,
    xlim: Tuple[float, float],
    ylim: Tuple[float, float],
    color: Color = "grey",
    alpha: float = 0.2,
) -> None:
    """Draw a 2-d V-form clipped to a box, and set the axes limits to that box.

    Rays are closed off against the box by ``_close_rays_with_bounds``. The
    resulting points are drawn by ``_convex_hull_plot_kernel``.

    :param ax: axes to draw on.
    :param v_repr: 3-column V-representation (leading column + 2-d coordinates).
    :param xlim: ``(lower, upper)`` clipping box and axes limits in x.
    :param ylim: ``(lower, upper)`` clipping box and axes limits in y.
    :param color: colour name or RGB tuple. Callers pass seaborn palette
        tuples, hence ``Color`` rather than ``str``.
    :param alpha: fill transparency.
    """
    v_repr_closed = _close_rays_with_bounds(v_repr, xlim, ylim)
    # Drop the leading column, keeping only the 2-d point coordinates.
    points_to_plot = v_repr_closed[:, 1:]
    _convex_hull_plot_kernel(ax, points_to_plot, color, alpha)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)


def _matrix_rank_0(x: np.ndarray) -> int:
    """ Matrix rank calculation that naturally handles the empty case """
    if 0 == x.size:
        mr = 0
    else:
        mr = np.linalg.matrix_rank(x)
    return mr


def _convex_hull_plot_kernel(
    ax: matplotlib.axes.Axes,
    points: np.ndarray,
    color: Color = "grey",
    alpha: float = 0.2,
) -> None:
    """
    Plot the convex hull of a finite 2-d point set (a polytope, no rays).

    The drawing depends on the affine dimension of the points:

    * 2 (full-dimensional): a filled polygon;
    * 1 (all points collinear): a thick line through the points;
    * 0 (no points, one point, or one repeated point): a point marker.

    :param ax: axes to draw on.
    :param points: ``(num_points, 2)`` array of points.
    :param color: fill/line colour.
    :param alpha: transparency of the drawn shape.
    """
    num_points, point_dim = points.shape
    assert 2 == point_dim

    # The rank of the *centred* points is the affine dimension of their hull.
    # The plain matrix rank is wrong for sets that do not contain the origin:
    # three collinear points off the origin have rank 2, yet their hull is flat
    # and scipy.spatial.ConvexHull raises on it.
    if num_points > 1:
        affine_rank = np.linalg.matrix_rank(points - np.mean(points, axis=0))
    else:
        affine_rank = 0

    if affine_rank >= 2:
        hull = scipy.spatial.ConvexHull(points)
        # Fully transparent markers draw nothing visible. Presumably they are
        # kept so the points still take part in the axes' autoscaling.
        ax.plot(points[:, 0], points[:, 1], ".", color=color, alpha=0.0)
        # For 2-d input, scipy returns hull.vertices in counterclockwise order,
        # so they can be used directly as the polygon outline.
        outline = points[hull.vertices]
        poly = matplotlib.patches.Polygon(outline,
                                          facecolor=color,
                                          edgecolor=None,
                                          alpha=alpha)
        ax.add_patch(poly)
    elif affine_rank == 1:
        ax.plot(points[:, 0], points[:, 1], lw=4, color=color, alpha=alpha)
    else:
        ax.plot(points[:, 0], points[:, 1], ".", lw=5, color=color, alpha=alpha)


def wrapped_subplot(subplot_rows: int,
                    subplot_cols: int,
                    plot_scale: float = 3.0) -> Tuple[matplotlib.figure.Figure, np.ndarray]:
    """Create a grid of subplots and always return the axes as a 2-d array.

    Each panel is ``plot_scale`` inches square. ``plt.subplots`` returns a bare
    Axes for 1x1 and a 1-d array for a single row or column; reshaping gives a
    uniform ``(subplot_rows, subplot_cols)`` array instead.
    """
    figsize = (subplot_cols * plot_scale, subplot_rows * plot_scale)
    fig, axes = plt.subplots(subplot_rows, subplot_cols, figsize=figsize)
    return fig, np.reshape(axes, (subplot_rows, subplot_cols))


def _style_plot(ax) -> None:
    major_axis_grid_color = "black"
    # major_axis_grid_color = "grey"

    axline_lw = 0.50
    major_grid_lw = .50
    minor_grid_lw = .25

    do_minor_ticks = False
    if do_minor_ticks:
        ax.minorticks_on()
        ax.grid(which="minor", color="grey", linewidth=minor_grid_lw)

    ax.grid(which="major", color=major_axis_grid_color, linewidth=major_grid_lw)
    ax.axhline(y=0, color=major_axis_grid_color, linewidth=axline_lw)
    ax.axvline(x=0, color=major_axis_grid_color, linewidth=axline_lw)


def _nudge_out_limits(lim: Plotlims, factor: float) -> Plotlims:
    """Widen ``lim`` by ``factor`` times its extent at each end.

    This leaves a margin so that shapes touching the limits stay visible.
    ``factor`` must lie in ``[0, 0.5]``.

    The margin passes through ``np.nan_to_num``. As a result, the "empty"
    limits ``(inf, -inf)``, whose extent is infinite and whose product with 0
    is NaN, come back as ``(inf, -inf)`` rather than NaN.
    """
    assert 0 <= factor <= 0.5, "factor should be a small positive number"
    lower, upper = lim[0], lim[1]
    margin = np.nan_to_num((upper - lower) * factor)
    return lower - margin, upper + margin


def get_plotlims_from_v_form(v: np.ndarray,
                             factor: float) -> Tuple[Plotlims, Plotlims]:
    """Padded ``(xlim, ylim)`` bounding box of a V-form's coordinates.

    Column 0 (the leading column) is ignored and columns 1 and 2 give x and y.
    The box is widened by ``factor`` of its extent on each side via
    ``_nudge_out_limits``.
    """
    if v.size > 0:
        coords = v[:, 1:]
        mins = coords.min(axis=0)
        maxs = coords.max(axis=0)
    else:
        # Empty V-form: an "inverted" box (lower=+inf, upper=-inf) that any
        # other limits dominate when enveloped.
        mins = (np.inf, np.inf)
        maxs = (-np.inf, -np.inf)

    xlim = _nudge_out_limits((mins[0], maxs[0]), factor)
    ylim = _nudge_out_limits((mins[1], maxs[1]), factor)
    return xlim, ylim


def envelope_list_of_plot_limits(ll: List[tuple]) -> Tuple[Plotlims, Plotlims]:
    """Smallest ``(xlim, ylim)`` pair that contains every pair in ``ll``.

    :param ll: sequence of ``((xmin, xmax), (ymin, ymax))`` pairs, as returned
        by ``get_plotlims_from_v_form``.
    :returns: ``((min xmin, max xmax), (min ymin, max ymax))``. An empty ``ll``
        gives ``((inf, -inf), (inf, -inf))``, the same "nothing to show"
        limits that ``get_plotlims_from_v_form`` uses for an empty V-form.
        Previously an empty ``ll`` raised IndexError.
    """
    if len(ll) == 0:
        return (np.inf, -np.inf), (np.inf, -np.inf)

    ll_np = np.array(ll)
    ll_np_x = ll_np[:, 0, :]
    ll_np_y = ll_np[:, 1, :]
    xlims = (np.min(ll_np_x[:, 0]), np.max(ll_np_x[:, 1]))
    ylims = (np.min(ll_np_y[:, 0]), np.max(ll_np_y[:, 1]))
    return xlims, ylims


def get_plotlims_from_list_of_v_forms(
    lx: List[np.ndarray],
    factor: float = 0.25,
) -> Tuple[Plotlims, Plotlims]:
    """Joint ``(xlim, ylim)`` covering every V-form in ``lx``, padded on each side.

    Each V-form's unpadded box comes from ``get_plotlims_from_v_form``. The
    boxes are enveloped, and the envelope is widened by ``factor`` of its
    extent at each end.

    :param lx: V-representations. Empty ones do not affect the result unless
        all are empty.
    :param factor: padding fraction in ``[0, 0.5]`` (enforced by
        ``_nudge_out_limits``). The default is the previously hard-coded 0.25.
    :returns: ``(xlim, ylim)`` as a tuple, matching the annotation. A list used
        to be returned.
    """
    individual_plotlims = [get_plotlims_from_v_form(v, 0.0) for v in lx]
    xlim, ylim = envelope_list_of_plot_limits(individual_plotlims)
    return _nudge_out_limits(xlim, factor), _nudge_out_limits(ylim, factor)


def get_plotlims_from_pointcloud(xy: np.ndarray,
                                 factor: float = 0.15) -> Tuple[Plotlims, Plotlims]:
    """Padded ``(xlim, ylim)`` bounding box of a non-empty 2-d point cloud.

    :param xy: ``(num_points, 2)`` array with at least one row.
    :param factor: padding on each side as a fraction of the extent, in
        ``[0, 0.5]`` (enforced by ``_nudge_out_limits``). The default is the
        previously hard-coded 0.15.
    :returns: ``(xlim, ylim)``.
    """
    nr, nc = xy.shape
    assert nr > 0
    assert nc == 2

    mins = np.min(xy, axis=0)
    maxs = np.max(xy, axis=0)

    xlim = _nudge_out_limits((mins[0], maxs[0]), factor)
    ylim = _nudge_out_limits((mins[1], maxs[1]), factor)
    return xlim, ylim


def initialise_pgf_plots(texsystem: str, font_family: str) -> None:
    """Switch pyplot to the PGF backend and configure LaTeX text rendering.

    Sets ``pgf.texsystem`` and ``font.family`` from the arguments, empties
    ``font.serif``, and enables ``text.usetex``.
    See https://matplotlib.org/users/customizing.html for the rc keys.
    """
    plt.switch_backend("pgf")
    matplotlib.rcParams.update({
        "pgf.texsystem": texsystem,
        "font.family": font_family,
        "font.serif": [],
        "text.usetex": True,
    })


class Faces:
    """Merge coplanar, edge-adjacent triangles into polygonal faces.

    Adapted from:
    https://stackoverflow.com/questions/53816211/plot-3d-connected-prism-matplotlib-based-on-vertices

    Two triangles end up in the same face when they are connected through a
    chain of triangles that share an edge and have equal (rounded,
    sign-insensitive) unit normals. ``simplify`` returns one vertex array per
    face, in boundary order.
    """

    def __init__(self, tri, sig_dig=12, method="convexhull"):
        """
        :param tri: sequence of triangles, each a ``(3, 3)`` array of vertices.
        :param sig_dig: number of *decimal places* (``np.around`` semantics,
            despite the name) to which both the vertex coordinates and the face
            normals are rounded before comparison.
        :param method: ``"convexhull"`` orders a merged face by a 2-d convex
            hull. Any other value orders it by angle around its mean.
        """
        self.method = method
        self.tri = np.around(np.array(tri), sig_dig)
        # Union-find parent pointers: every triangle starts in its own group.
        self.grpinx = list(range(len(tri)))
        norms = np.around([self.norm(s) for s in self.tri], sig_dig)
        # inv[k] identifies triangle k's rounded normal: equal ids mean equal
        # normals. ravel() guards against NumPy versions that return a
        # non-1-d inverse when axis is given.
        _, inv = np.unique(norms, return_inverse=True, axis=0)
        self.inv = np.ravel(inv)

    def norm(self, sq):
        """Unit normal of triangle ``sq`` with absolute-valued components.

        Taking the absolute value makes a normal and its negation compare equal.
        """
        cr = np.cross(sq[2] - sq[0], sq[1] - sq[0])
        return np.abs(cr / np.linalg.norm(cr))

    def isneighbor(self, tr1, tr2):
        """True when the two triangles share exactly two vertices (an edge)."""
        a = np.concatenate((tr1, tr2), axis=0)
        return len(a) == len(np.unique(a, axis=0)) + 2

    def order(self, v):
        """Return the distinct vertices of a (presumably planar) point set in boundary order.

        Sets with at most 3 rows are returned unchanged. Otherwise the points
        are projected onto two in-plane directions and ordered by a 2-d convex
        hull, or by angle around their mean depending on ``method``.
        """
        if len(v) <= 3:
            return v
        v = np.unique(v, axis=0)
        n = self.norm(v[:3])
        y = np.cross(n, v[1] - v[0])
        y = y / np.linalg.norm(y)
        c = np.dot(v, np.c_[v[1] - v[0], y])
        if self.method == "convexhull":
            h = scipy.spatial.ConvexHull(c)
            return v[h.vertices]
        else:
            mean = np.mean(c, axis=0)
            d = c - mean
            s = np.arctan2(d[:, 0], d[:, 1])
            return v[np.argsort(s)]

    def _find(self, i):
        """Return the root of triangle ``i``'s group, compressing the path to it."""
        root = i
        while self.grpinx[root] != root:
            root = self.grpinx[root]
        while i != root:
            parent = self.grpinx[i]
            self.grpinx[i] = root
            i = parent
        return root

    def simplify(self):
        """Group the triangles into faces and return each face's ordered vertices.

        Uses union-find, so the grouping is transitive regardless of triangle
        order. The previous single-pass relabelling could split one planar face
        into several pieces.
        """
        n = len(self.tri)
        for i in range(n):
            for j in range(i + 1, n):
                # The cheap normal-id comparison goes first; the neighbour test
                # needs an np.unique.
                if self.inv[i] == self.inv[j] and self.isneighbor(self.tri[i], self.tri[j]):
                    root_i, root_j = self._find(i), self._find(j)
                    if root_i != root_j:
                        # Attach the larger root under the smaller, so each group
                        # is labelled by its lowest triangle index.
                        self.grpinx[max(root_i, root_j)] = min(root_i, root_j)
        labels = np.array([self._find(k) for k in range(n)])
        groups = []
        for label in np.unique(labels):
            u = np.concatenate(self.tri[labels == label])
            groups.append(self.order(u))
        return groups


def simple_convex_hull_plot3d(
    verts: np.ndarray
) -> Tuple[matplotlib.figure.Figure, mpl_toolkits.mplot3d.axes3d.Axes3D]:
    """Show the 3-d convex hull of ``verts`` with coplanar facets merged.

    :param verts: ``(num_points, 3)`` array of points.
    :returns: ``(fig, ax)``. Note that ``plt.show()`` is called before returning.
    """
    x = verts[:, 0]
    y = verts[:, 1]
    z = verts[:, 2]

    # Compute the triangles that make up the convex hull of the data points.
    hull = scipy.spatial.ConvexHull(verts)
    triangles = [verts[s] for s in hull.simplices]

    # Combine coplanar triangles into single faces. Faces rounds both vertex
    # coordinates and normals to ``sig_dig`` decimal places. The former value
    # of 1 moved every drawn vertex by up to 0.05 and could merge neighbouring
    # faces that are not actually coplanar. 6 decimals still absorb
    # floating-point noise in the normals.
    faces = Faces(triangles, sig_dig=6).simplify()

    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in current matplotlib
    # (auto-add was deprecated in 3.4 and dropped in 3.6), which leaves an
    # empty plot. Create the 3-d axes through the figure instead.
    ax = fig.add_subplot(projection="3d")
    pc = mpl_toolkits.mplot3d.art3d.Poly3DCollection(faces, facecolor="salmon", edgecolor="k", alpha=0.9)
    ax.add_collection3d(pc)

    ax.set_xlim(np.min(x), np.max(x))
    ax.set_ylim(np.min(y), np.max(y))
    ax.set_zlim(np.min(z), np.max(z))

    # Same camera angles as before. The old ``ax.dist = 10`` is dropped: 10 was
    # the former default distance, and the attribute is deprecated since
    # matplotlib 3.6.
    ax.view_init(elev=10, azim=30)

    plt.show()
    return fig, ax


def _finitemin(x: np.ndarray) -> float:
    return np.min(x[np.isfinite(x)])


def _finitemax(x: np.ndarray) -> float:
    return np.max(x[np.isfinite(x)])


def _combine_possibly_infinite_plotlims(pl1: Plotlims,
                                        pl2: Plotlims) -> Plotlims:
    """Union of two ``(lower, upper)`` limits that ignores non-finite endpoints.

    The lower bound is the smallest finite lower endpoint and the upper bound is
    the largest finite upper endpoint (see ``_finitemin`` / ``_finitemax``).
    """
    lowers = np.array([pl1[0], pl2[0]])
    uppers = np.array([pl1[1], pl2[1]])
    return _finitemin(lowers), _finitemax(uppers)


def plot_decomp(x: np.ndarray,
                xlims: Plotlims,
                ylims: Plotlims,
                vlists: List[list],
                plot_scale: float,
                *,
                hide_axes: bool = True,
                tight_layout: bool = True) -> FigAx:
    """Plot groups of polyhedral pieces, one hue per group.

    Group ``k`` (``vlists[k]``) gets the ``k``-th colour from
    ``get_palette_of_length``. Its pieces are drawn in successive shades of
    that colour (``seaborn.light_palette``), cycling when there are more pieces
    than shades. Within a group, pieces are clipped to a box derived from the
    group itself (``get_plotlims_from_list_of_v_forms``). Pieces whose V-form
    has no rows are skipped.

    The final axes limits combine the padded bounding box of ``x`` with
    ``xlims``/``ylims``, ignoring non-finite endpoints.

    :param x: ``(num_points, 2)`` array. Only its extent is used; the points
        themselves are not drawn.
    :param xlims: extra x limits to include; may be infinite.
    :param ylims: extra y limits to include; may be infinite.
    :param vlists: list of groups, each a ``list`` of V-representations.
    :param plot_scale: figure size in inches.
    :param hide_axes: hide both axes (default, as before).
    :param tight_layout: call ``fig.tight_layout()`` (default, as before).
    :raises TypeError: if an element of ``vlists`` is not a ``list``.
    :returns: ``(fig, ax)``.
    """
    alpha = 0.9
    base_colors = get_palette_of_length(len(vlists))
    palettes = [seaborn.light_palette(base_color) for base_color in base_colors]

    fig, axs = wrapped_subplot(1, 1, plot_scale)
    ax = axs[0, 0]

    for vl, shades in zip(vlists, palettes):
        if not isinstance(vl, list):
            raise TypeError(
                "each element of vlists must be a list, got {}".format(type(vl).__name__))
        nonempty = [v_form for v_form in vl if v_form.shape[0] > 0]
        if not nonempty:
            continue
        xlim, ylim = get_plotlims_from_list_of_v_forms(nonempty)
        for i, v_form in enumerate(nonempty):
            convex_hull_plot(ax, v_form, xlim, ylim, shades[i % len(shades)], alpha)

    point_xlims, point_ylims = get_plotlims_from_pointcloud(x)
    ax.set_xlim(*_combine_possibly_infinite_plotlims(point_xlims, xlims))
    ax.set_ylim(*_combine_possibly_infinite_plotlims(point_ylims, ylims))

    if hide_axes:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    if tight_layout:
        fig.tight_layout()
    return fig, ax


def finalize_plot() -> None:
    """Switch pyplot back to the platform's interactive backend.

    This is presumably meant to follow ``initialise_pgf_plots``. The backend
    name comes from ``get_default_interactive_backend``.
    """
    plt.switch_backend(get_default_interactive_backend())


def palette_example() -> FigAx:
    """Preview seaborn's 12-colour "dark" palette.

    Draws one vertical column of 10 dots per colour, at x = colour index.

    :returns: ``(fig, ax)``, so callers can show, save or close the figure.
        The figure used to be created and then dropped, returning None.
    """
    num_color = 12
    palette = seaborn.color_palette("dark", num_color)

    fig, axs = wrapped_subplot(1, 1)
    ax = axs[0, 0]

    plot_y = np.arange(10)
    for idx, color in enumerate(palette):
        plot_x = np.full((10, 1), idx, dtype=float)
        ax.scatter(plot_x, plot_y, color=color)
    return fig, ax


def smart_save_fig(fig: matplotlib.figure.Figure,
                   ident: str,
                   fig_format: str,
                   filepath: str) -> str:
    """Save ``fig`` to ``<filepath>/<ident>.<fig_format>`` and return that path.

    :param fig: figure to save.
    :param ident: file stem.
    :param fig_format: file extension, e.g. ``"pdf"``.
    :param filepath: target directory, created (with parents) if missing. An
        empty string means the current working directory.
    :returns: the path passed to ``fig.savefig``.
    """
    filename = f"{ident}.{fig_format}"
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when one was actually given.
    if filepath:
        os.makedirs(filepath, exist_ok=True)
    fig_path = os.path.join(filepath, filename)
    fig.savefig(fig_path)
    return fig_path


if __name__ == "__main__":
    # Demo: show the convex hull of five random points in the unit cube.
    random_vertices = np.random.rand(5, 3)
    fig, ax = simple_convex_hull_plot3d(random_vertices)
