import numpy as np
import matplotlib.pyplot as plt
from ripser import ripser
from persim import plot_diagrams
from matplotlib.patches import Circle
from scipy.spatial import distance_matrix
import networkx as nx

# Point cloud: samples of the rose curve rho(theta) = radius_multiplier * cos(n_petals * theta),
# converted from polar to Cartesian coordinates. `data` ends up as an
# (n_points, 2) array of (x, y) rows.
n_points = 30  # Number of points to generate
n_petals = 3  # Number of petals
radius_multiplier = 2  # Controls the size of the petals

# NOTE(review): np.linspace includes both 0 and 2*pi by default, so the first
# and last samples land on the same point (up to rounding).
theta = np.linspace(0, 2 * np.pi, n_points)
r = radius_multiplier * np.cos(n_petals * theta)  # Flower-shaped radius

x, y = r * np.cos(theta), r * np.sin(theta)
data = np.stack((x, y), axis=1)

def plot_filtration(data, r, ax, show_complex=True):
    """Draw one stage of a ball-growing filtration onto ``ax``.

    Every point gets a translucent disk of radius ``r``. When ``show_complex``
    is true, a black segment is drawn between every pair of points whose disks
    intersect, i.e. whose centres are at most ``2 * r`` apart. This is the
    1-skeleton of the Rips complex at edge-length threshold ``2 * r``.

    Parameters
    ----------
    data : ndarray, shape (n, 2)
        Point coordinates. Column 0 is plotted as x and column 1 as y.
    r : float
        Disk radius for this filtration stage.
    ax : matplotlib Axes
        Target axes. Modified in place: aspect, artists, ticks, axis
        visibility and title are all set here.
    show_complex : bool, default True
        Whether to draw the edges between intersecting disks.
    """
    ax.set_aspect('equal', 'box')

    # zorder=3 keeps the points above the translucent disks (patches default
    # to zorder 1) and above the edges (lines default to zorder 2).
    ax.scatter(data[:, 0], data[:, 1], c='black', s=50, zorder=3)

    for point in data:
        ax.add_patch(Circle(point, r, color='aqua', alpha=0.5))

    if show_complex:
        dists = distance_matrix(data, data)
        # Two disks of radius r intersect iff their centres are <= 2r apart.
        # np.triu(..., k=1) keeps each unordered pair once and drops i == j.
        rows, cols = np.nonzero(np.triu(dists <= 2 * r, k=1))
        for i, j in zip(rows, cols):
            ax.plot([data[i, 0], data[j, 0]], [data[i, 1], data[j, 1]], 'k-')

    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    # Clear the title with an empty string. A non-str value such as [] gets
    # stringified by matplotlib and would show up as a literal "[]" title.
    ax.set_title('')

# Filtration snapshots: one panel per radius, side by side, saved to
# pd_fil.png. The loop variable is `radius` rather than `r` so it does not
# overwrite the module-level `r` array of polar radii used to build `data`.
radii = [0, 0.2, 0.3]

# squeeze=False keeps `axes` 2-D even when there is a single radius, so the
# iteration below works for any len(radii) >= 1.
fig, axes = plt.subplots(1, len(radii), figsize=(12, 7), squeeze=False)

for ax, radius in zip(axes[0], radii):
    plot_filtration(data, radius, ax)
    ax.set_title(f'r = {radius}')

fig.tight_layout()
fig.savefig("pd_fil.png")
plt.close(fig)

# Persistent homology of the point cloud. `diagrams` is ripser's list of
# (birth, death) arrays, one per homology dimension (index 0 = H0, 1 = H1).
diagrams = ripser(data)['dgms']

fig, (ax_pd, ax_bar) = plt.subplots(1, 2, figsize=(8, 3))

# Left panel: persistence diagram, drawn into an explicit axes. show is left
# at False so no window is displayed (or blocked on) before the barcode panel
# has been added and the figure saved.
plot_diagrams(diagrams, ax=ax_pd)

# Right panel: barcode. An H0 class that never dies is reported with
# death = inf, and a segment with a non-finite endpoint is not drawn, so cap
# infinite deaths just past the largest finite death to keep that bar visible.
finite_deaths = np.concatenate(
    [dgm[np.isfinite(dgm[:, 1]), 1] for dgm in diagrams]
)
inf_cap = 1.1 * finite_deaths.max() if finite_deaths.size else 1.0

# H0 (connected components) in blue, H1 (loops) in red, stacked vertically.
bar_styles = [('b-', r'$\beta_0$'), ('r-', r'$\beta_1$')]
row = 0
for dgm, (style, label) in zip(diagrams, bar_styles):
    for k, (start, end) in enumerate(dgm):
        if np.isinf(end):
            end = inf_cap
        # Only the first bar of each dimension carries a legend label.
        ax_bar.plot([start, end], [row, row], style, label=label if k == 0 else "")
        row += 1

ax_bar.set_xlabel("Time")
ax_bar.set_ylabel("Barcodes")
ax_bar.legend()
fig.tight_layout()
fig.savefig("pd.png")
plt.close(fig)
