import torch
import numpy as np
import matplotlib.pyplot as plt

# Resolution of the on-screen figures created by plot_dgm and plot_pts.
# Saved files use an explicit dpi passed to savefig instead.
dpi: int = 100

def plot_dgm(dgm, size=5, filename='', save=False):
    """
    Plots a persistence diagram using Matplotlib.

    Finite points are drawn at (birth, death) together with the reference
    diagonal y = x. Points whose death is infinite are drawn on a dotted
    horizontal "infinity" line slightly above the largest finite value so
    they remain visible instead of being dropped.

    Input
    ----------
    dgm : iterable of (birth, death) pairs of floats. A death equal to
        ``float('inf')`` marks a class that never dies.
    size : marker size passed to ``plt.scatter``.
    filename : file name, relative to the ``fig/`` directory, used when
        ``save`` is True.
    save : if True, save the figure to ``fig/{filename}`` at 300 dpi and
        close it; otherwise call ``plt.show()``.
    """
    _inf = float('inf')

    finite_births = []
    finite_deaths = []
    inf_births = []

    for (b, d) in dgm:
        if d == _inf:
            inf_births.append(b)
        else:
            finite_births.append(b)
            finite_deaths.append(d)

    if not finite_births and not inf_births:
        print('No diagram to plot.')
        return

    fig = plt.figure(figsize=(4, 4), dpi=dpi)
    max_ = max(finite_births + finite_deaths + inf_births)
    if max_ <= 0:
        # Every value is 0 (e.g. a single H0 class born at 0): fall back to a
        # unit range so the axis limits are not empty or inverted.
        max_ = 1.0
    # Start slightly below 0 so points born at exactly 0 are not clipped.
    start = -.01
    upper = max_ * 1.1
    # Height at which points with infinite death are drawn.
    infinity_value = max_ * 1.05

    plt.xlim(start, upper)
    plt.ylim(start, upper)
    plt.plot([start, upper], [start, upper], color='tomato', linestyle='--')
    plt.scatter(finite_births, finite_deaths, c='steelblue', s=size, marker='o')
    if inf_births:
        plt.axhline(infinity_value, color='gray', linestyle=':', linewidth=1)
        plt.scatter(inf_births, [infinity_value] * len(inf_births),
                    c='steelblue', s=size, marker='^')
    plt.xlabel('Birth')
    plt.ylabel('Death')
    plt.title('Persistence Diagram')
    if save:
        plt.savefig(f"fig/{filename}", dpi=300)
        plt.close(fig)
    else:
        plt.show()
 

def plot_pts(pts: torch.Tensor, filename='', save=False):
    """
    Scatter-plot a 2-D point cloud inside the square [-1, 1] x [-1, 1].

    Parameters
    ----------
    pts : torch.Tensor
        Points to plot; column 0 is used as x and column 1 as y
        (presumably shape (n, 2) -- TODO confirm with callers).
    filename : str
        File name, relative to the ``fig/`` directory, used when ``save`` is True.
    save : bool
        If True, save the figure to ``fig/{filename}`` at 300 dpi and close it;
        otherwise call ``plt.show()``.
    """
    # .cpu() lets CUDA tensors be plotted too; it is a no-op for CPU tensors.
    pts_np = pts.detach().cpu().numpy()
    # Keep a handle on the figure so it can be closed after saving.
    fig = plt.figure(figsize=(4, 4), dpi=dpi)
    plt.xlim(-1, 1)
    plt.ylim(-1, 1)
    plt.scatter(pts_np[:, 0], pts_np[:, 1], alpha=0.5, s=15, marker='o')
    plt.gca().set_aspect('equal', adjustable='box')
    if save:
        plt.savefig(f"fig/{filename}", dpi=300)
        plt.close(fig)
    else:
        plt.show()

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import display, HTML

def create_animation(pts_logs, figsize=(4, 4), size=15, interval=200, repeat_delay=2000):
    """
    Display an inline (Jupyter) animation of a sequence of 2-D point clouds.

    Parameters
    ----------
    pts_logs : sequence of array-likes
        One array per frame. All frames must share the same shape so they can
        be stacked; column 0 is plotted as x and column 1 as y.
    figsize : tuple
        Figure size in inches.
    size : float
        Scatter marker size.
    interval : int
        Delay between frames, in milliseconds.
    repeat_delay : int
        Pause before the animation loops, in milliseconds.
    """
    # Stack first: if the frames are incompatible this raises before a figure
    # is created, so no orphaned figure is left behind.
    x_all = np.stack(list(pts_logs), axis=0)
    fig, ax = plt.subplots(figsize=figsize)
    limits = [-1, 1, -1, 1]  # [xmin, xmax, ymin, ymax], fixed for every frame

    def _update_plot(i):
        # Redraw on the same Axes; cla() resets the limits, so reapply them.
        ax.cla()
        ax.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, s=size, marker='o')
        ax.axis(limits)

    ani = animation.FuncAnimation(fig, _update_plot, frames=x_all.shape[0], interval=interval, repeat_delay=repeat_delay)
    display(HTML(ani.to_jshtml()))
    plt.close(fig)


def create_dgm_animation(dgm_logs, dgm_dim=1, figsize=(4.2, 4), size=15, interval=200, repeat_delay=2000):
    """
    Display an inline (Jupyter) animation of persistence diagrams over time.

    Parameters
    ----------
    dgm_logs : sequence
        One entry per frame; ``dgm_logs[t][dgm_dim]`` is an array-like of
        (birth, death) rows for frame ``t``.
    dgm_dim : int
        Which diagram to take from each entry.
    figsize : tuple
        Figure size in inches.
    size : float
        Scatter marker size.
    interval : int
        Delay between frames, in milliseconds.
    repeat_delay : int
        Pause before the animation loops, in milliseconds.
    """
    # Convert to float arrays so lists and arrays behave alike (.size, slicing,
    # np.isfinite).
    x_all = [np.asarray(dgms_[dgm_dim], dtype=float) for dgms_ in dgm_logs]

    # Shared axis scale across all frames. Only finite values count: an
    # infinite death would otherwise make the axis limits infinite.
    # Empty diagrams are skipped; default to 1 when there is nothing finite.
    finite_vals = [x[np.isfinite(x)] for x in x_all]
    max_ = max((v.max() for v in finite_vals if v.size > 0), default=1)
    start = -.01
    upper = max_ * 1.1

    fig, ax = plt.subplots(figsize=figsize)

    def _update_plot(i):
        # Redraw on the same Axes; cla() clears the limits and labels, so set
        # them again every frame.
        ax.cla()
        ax.set_xlim(start, upper)
        ax.set_ylim(start, upper)
        ax.plot([start, upper], [start, upper], color='tomato', linestyle='--')
        ax.set_xlabel('Birth')
        ax.set_ylabel('Death')
        ax.set_title('Persistence Diagram')
        ax.set_aspect('equal', adjustable='box')
        if x_all[i].size > 0:
            ax.scatter(x_all[i][:, 0], x_all[i][:, 1], alpha=0.5, s=size, marker='o')

    ani = animation.FuncAnimation(fig, _update_plot, frames=len(x_all), interval=interval, repeat_delay=repeat_delay)
    display(HTML(ani.to_jshtml()))
    plt.close(fig)


def save_animation(pts_logs, filename, indexes, save=False, figsize=(4, 4), size=15, dpi=300,title_fontsize=20):
    """
    Draw selected frames of a point-cloud log side by side in one figure.

    Parameters
    ----------
    pts_logs : sequence of array-likes
        Per-frame point clouds; column 0 is plotted as x and column 1 as y.
    filename : str
        File name, relative to the ``fig/`` directory, used when ``save`` is True.
    indexes : sequence of int
        Frames to draw, one panel each, left to right. Titles show 1-based
        time ``t = index + 1``.
    save : bool
        If True, save to ``fig/{filename}`` and close the figure; otherwise
        call ``plt.show()``.
    figsize : tuple
        Size of a single panel in inches; the figure width scales with the
        number of panels.
    size : float
        Scatter marker size.
    dpi : int
        Resolution of the saved file.
    title_fontsize : float
        Font size of the panel titles.
    """
    num_plots = len(indexes)
    # squeeze=False keeps ``axes`` a 2-D array even when only one frame is
    # requested, so iteration works for a single index as well.
    fig, axes = plt.subplots(1, num_plots, figsize=(figsize[0] * num_plots, figsize[1]), squeeze=False)

    for idx, ax in zip(indexes, axes[0]):
        pts = pts_logs[idx]
        ax.scatter(pts[:, 0], pts[:, 1], alpha=0.5, s=size, marker='o')
        ax.set_aspect('equal')
        ax.axis('off')  # hides the axis lines, ticks and tick labels
        ax.set_title(f't = {idx+1}', fontsize=title_fontsize)

    plt.tight_layout()
    if save:
        plt.savefig(f"fig/{filename}", dpi=dpi)
        plt.close(fig)
    else:
        plt.show()

        
def save_dgm_animation(dgm_logs, dgm_dim, filename, indexes, save=False, figsize=(4, 4), size=15, dpi=300, title_fontsize=12):
    """
    Draw selected persistence diagrams from a log side by side in one figure.

    Every panel uses the same axis scale, computed from all frames in
    ``dgm_logs`` (not only the selected ones), so panels are comparable.

    Parameters
    ----------
    dgm_logs : sequence
        One entry per frame; ``dgm_logs[t][dgm_dim]`` is an array-like of
        (birth, death) rows for frame ``t``.
    dgm_dim : int
        Which diagram to take from each entry.
    filename : str
        File name, relative to the ``fig/`` directory, used when ``save`` is True.
    indexes : sequence of int
        Frames to draw, one panel each, left to right. Titles show 1-based
        time ``t = index + 1``.
    save : bool
        If True, save to ``fig/{filename}`` and close the figure; otherwise
        call ``plt.show()``.
    figsize : tuple
        Size of a single panel in inches; the figure width scales with the
        number of panels.
    size : float
        Scatter marker size.
    dpi : int
        Resolution of the saved file.
    title_fontsize : float
        Font size of the panel titles.
    """
    # Convert to float arrays so lists and arrays behave alike (.size, slicing,
    # np.isfinite).
    x_all = [np.asarray(dgms_[dgm_dim], dtype=float) for dgms_ in dgm_logs]

    # Only finite values count toward the scale: an infinite death would
    # otherwise make the axis limits infinite. Default to 1 when nothing is finite.
    finite_vals = [x[np.isfinite(x)] for x in x_all]
    max_ = max((v.max() for v in finite_vals if v.size > 0), default=1)
    start = -.01
    upper = max_ * 1.1

    num_plots = len(indexes)
    # squeeze=False keeps ``axes`` a 2-D array even for a single panel.
    fig, axes = plt.subplots(1, num_plots, figsize=(figsize[0] * num_plots, figsize[1]), squeeze=False)

    for idx, ax in zip(indexes, axes[0]):
        # Index by frame number, not panel position, for both the emptiness
        # check and the data that is drawn.
        dgm = x_all[idx]
        ax.set_xlim(start, upper)
        ax.set_ylim(start, upper)
        ax.plot([start, upper], [start, upper], color='tomato', linestyle='--')
        if dgm.size > 0:
            ax.scatter(dgm[:, 0], dgm[:, 1], alpha=0.5, s=size, marker='o')
        ax.set_aspect('equal')
        ax.set_title(f't = {idx+1}', fontsize=title_fontsize)

    plt.tight_layout()
    if save:
        plt.savefig(f"fig/{filename}", dpi=dpi)
        plt.close(fig)
    else:
        plt.show()