import pickle
import matplotlib.pyplot as plt
import numpy as np
from shapely.geometry import Polygon
from alphashape import alphashape


def hexbin_plot(points, save_path_png=None, save_path_bin=None, shape=None,
                *, gridsize=1000, cmap='viridis', dpi=300):
    """Render a hexbin height map of a point cloud and optionally save it.

    Points whose z-value lies outside the 1st-99th percentile range are
    treated as outliers and discarded before plotting. Each hexagonal bin is
    coloured by the maximum z-value of the points that fall into it.

    Args:
        points: Three equally long sequences ``(x, y, z)``, e.g. an array of
            shape ``(3, n)``. Each component is converted with ``np.asarray``.
        save_path_png: If given, the figure is written to this path as an
            image at ``dpi`` dots per inch.
        save_path_bin: If given, the figure object is pickled to this path.
        shape: Optional shapely geometry drawn as a red outline on top of the
            heat map. Only ``Polygon`` and ``MultiPolygon`` are drawn; their
            exterior rings are plotted, any other geometry type is ignored.
        gridsize: Forwarded to ``Axes.hexbin`` as ``gridsize``.
        cmap: Colormap forwarded to ``Axes.hexbin``.
        dpi: Resolution used when saving ``save_path_png``.

    Returns:
        The matplotlib figure. It has already been closed via ``plt.close``
        by the time it is returned.
    """
    x, y, z = (np.asarray(component) for component in points)

    # Drop the top and bottom 1% of z-values as outliers. nanpercentile is
    # used so that a single NaN does not turn both thresholds into NaN (which
    # would discard every point); NaN z-values themselves fail both
    # comparisons below and are therefore dropped as well.
    z_threshold_lo, z_threshold_hi = np.nanpercentile(z, [1, 99])
    keep = (z >= z_threshold_lo) & (z <= z_threshold_hi)
    x, y, z = x[keep], y[keep], z[keep]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # 2D heat map where each hexagonal bin gets the max Z value it contains.
        hb = ax.hexbin(x, y, C=z, gridsize=gridsize, cmap=cmap,
                       reduce_C_function=np.max)
        fig.colorbar(hb, ax=ax, label="Z")

        # Overlay the outline of the (alpha-shape) polygon(s), if provided.
        if shape is not None:
            if isinstance(shape, Polygon):
                polygons = [shape]
            elif shape.geom_type == "MultiPolygon":
                polygons = list(shape.geoms)
            else:
                polygons = []
            for polygon in polygons:
                hull_x, hull_y = polygon.exterior.xy
                ax.plot(hull_x, hull_y, 'r-')

        ax.set_xlabel("X")
        ax.set_ylabel("Y")

        if save_path_bin is not None:
            # Persist the figure object itself so it can be reloaded later.
            with open(save_path_bin, 'wb') as f:
                pickle.dump(fig, f)
        if save_path_png is not None:
            # Save through the figure, not pyplot's implicit "current figure".
            fig.savefig(save_path_png, dpi=dpi)
    finally:
        # Always release this specific figure from pyplot, even when plotting
        # or saving fails, so repeated calls do not accumulate open figures.
        plt.close(fig)

    return fig


def compute_alphashape(pointcloud, downsample=100, alpha=0.01, save_path=None):
    """Compute the alpha shape (concave hull) of a point cloud's XY footprint.

    Args:
        pointcloud: Array-like of shape ``(k, n)`` with ``k >= 2``, typically
            ``(3, n)``. The first two rows are used as the X and Y
            coordinates; any further rows are ignored.
        downsample: Stride used to thin the points before the hull is
            computed; every ``downsample``-th point is kept.
        alpha: Alpha parameter forwarded to ``alphashape``.
        save_path: If given, the resulting geometry is pickled to this path.

    Returns:
        The geometry returned by ``alphashape``.

    Raises:
        ValueError: If ``pointcloud`` is not two-dimensional or has fewer
            than two rows.
    """
    import logging

    pointcloud = np.asarray(pointcloud)
    if pointcloud.ndim != 2 or pointcloud.shape[0] < 2:
        raise ValueError(
            "pointcloud must have shape (k, n) with k >= 2, got "
            f"{pointcloud.shape}"
        )

    # Take the XY rows, transpose to (n, 2) points, then keep every
    # `downsample`-th point.
    xy = pointcloud[:2, :].T[::downsample]
    logging.getLogger(__name__).debug(
        "computing alphashape from %d downsampled points", len(xy)
    )

    concave_hull = alphashape(xy, alpha=alpha)
    if save_path is not None:
        with open(save_path, 'wb') as f:
            pickle.dump(concave_hull, f)

    return concave_hull