import plotly.express as px
import plotly.graph_objects as go
import os


# Labels that are never used to name output files (they are dropped before the
# class names for the filename are chosen).
_BOUNDARY_LABELS = ("Boundary Graphs", "Our Boundary Graphs")


def _class_names(labels) -> list:
    """Return up to two class names, sorted, for use in an output filename.

    Boundary-graph labels are removed. The remaining distinct labels are
    compared as strings, so numeric and string labels behave the same. They
    are sorted so the filename does not depend on set iteration order, which
    for strings varies between interpreter runs because of hash randomization.
    """
    names = {str(label) for label in labels}
    names.difference_update(_BOUNDARY_LABELS)  # no error if a label is absent
    return sorted(names)[:2]


def _save_figure(fig, labels, dataset_name: str, method: str, dim_tag: str) -> None:
    """Write ``fig`` as SVG and JPG under ``./plots/{svg,jpg}/<dataset_name>/``.

    The file stem is ``<method>_<class A>_<class B>_<dim_tag>``. If fewer than
    two non-boundary classes exist, only the available class names appear.
    """
    stem = "_".join([method, *_class_names(labels), dim_tag])
    for fmt in ("svg", "jpg"):
        out_dir = f"./plots/{fmt}/{dataset_name}"  # relative to the current working dir
        os.makedirs(out_dir, exist_ok=True)
        fig.write_image(f"{out_dir}/{stem}.{fmt}", format=fmt)


def plot_latent_space_3d(
    umap_data,
    labels,
    dataset_name: str,
    method: str = "UMAP",
    *,
    show: bool = False,
):
    """Plot a 3D latent-space embedding with plotly and save it as SVG and JPG.

    Args:
        umap_data: Embedding coordinates. Columns 0, 1 and 2 are used as the
            x, y and z axes (assumed to be an array of shape (n_points, >=3)).
        labels: Per-point class labels, presumably one per row of
            ``umap_data``. Must support ``.astype(str)`` (e.g. a NumPy array
            or pandas Series).
        dataset_name: Subdirectory created under ``./plots/svg`` and
            ``./plots/jpg``.
        method: Embedding method name, used in axis titles, the plot title
            and the filename.
        show: If True, also display the figure after saving it.

    Output files are ``./plots/{svg,jpg}/<dataset_name>/<method>_<A>_<B>_3D.<ext>``,
    where A and B are the first two class labels in sorted order after the
    boundary-graph labels are removed.
    """
    fig = px.scatter_3d(
        x=umap_data[:, 0],
        y=umap_data[:, 1],
        z=umap_data[:, 2],
        color=labels.astype(str),
        labels={'x': f'{method}-1', 'y': f'{method}-2', 'z': f'{method}-3', 'color': 'Class'},
        title=f"3D {method} Visualization",
        height=700,
        width=700,
    )
    fig.update_traces(marker=dict(size=2, opacity=0.7))
    # Keep legend markers readable even though the point markers are tiny.
    fig.update_layout(legend=dict(itemsizing='constant'))
    _save_figure(fig, labels, dataset_name, method, "3D")
    if show:
        fig.show()


def plot_latent_space_2d(
    umap_data,
    labels,
    dataset_name: str,
    method: str = "UMAP",
    *,
    show: bool = False,
):
    """Plot a 2D latent-space embedding with plotly and save it as SVG and JPG.

    Args:
        umap_data: Embedding coordinates. Columns 0 and 1 are used as the
            x and y axes (assumed to be an array of shape (n_points, >=2)).
        labels: Per-point class labels, presumably one per row of
            ``umap_data``. Must support ``.astype(str)`` (e.g. a NumPy array
            or pandas Series).
        dataset_name: Subdirectory created under ``./plots/svg`` and
            ``./plots/jpg``.
        method: Embedding method name, used in axis titles, the plot title
            and the filename.
        show: If True, also display the figure after saving it.

    Output files are ``./plots/{svg,jpg}/<dataset_name>/<method>_<A>_<B>_2D.<ext>``,
    where A and B are the first two class labels in sorted order after the
    boundary-graph labels are removed.
    """
    fig = px.scatter(
        x=umap_data[:, 0],
        y=umap_data[:, 1],
        color=labels.astype(str),
        labels={'x': f'{method}-1', 'y': f'{method}-2', 'color': 'Class'},
        title=f"2D {method} Visualization",
        height=700,
        width=700,
    )
    fig.update_traces(marker=dict(size=3, opacity=0.7))
    # Keep legend markers readable even though the point markers are small.
    fig.update_layout(legend=dict(itemsizing='constant'))
    _save_figure(fig, labels, dataset_name, method, "2D")
    if show:
        fig.show()
