from typing import Optional

import plotly.express as px
import plotly.graph_objects as go
from pandas import DataFrame

from src.utils.constants import ATOM_COLOR_MAP


def _unit_cube_scene() -> dict:
    """Return a 3D ``scene`` layout with every axis fixed to [0, 1] and a cubic aspect ratio."""
    return dict(
        xaxis=dict(range=[0, 1], autorange=False),
        yaxis=dict(range=[0, 1], autorange=False),
        zaxis=dict(range=[0, 1], autorange=False),
        aspectmode="manual",
        aspectratio=dict(x=1, y=1, z=1),
    )


def _add_atom_traces(
    fig: go.Figure,
    df_atoms: Optional[DataFrame],
    marker_size: float = 10,
) -> None:
    """Overlay atoms on *fig* as open-diamond markers, one trace per atom type.

    Args:
        fig: Figure to add the traces to (mutated in place).
        df_atoms: Frame with ``X``, ``Y``, ``Z`` and ``atom_type`` columns, or
            ``None`` to add nothing.
        marker_size: Size of the diamond markers.
    """
    if df_atoms is None:
        return
    # groupby(sort=False) yields groups in order of first appearance (the same
    # order as Series.unique()) but partitions the frame in a single pass
    # instead of building one boolean mask per atom type. observed=True avoids
    # empty traces for unused categories when the column is categorical.
    for atom_type, atom_df in df_atoms.groupby(
        "atom_type", sort=False, observed=True
    ):
        fig.add_trace(
            go.Scatter3d(
                x=atom_df["X"],
                y=atom_df["Y"],
                z=atom_df["Z"],
                mode="markers",
                marker=dict(
                    size=marker_size,
                    symbol="diamond-open",
                    # Types missing from the color map get plotly's default
                    # color (None = unset) instead of raising KeyError,
                    # matching how px treats color_discrete_map misses.
                    color=ATOM_COLOR_MAP.get(atom_type),
                ),
                text=atom_df["atom_type"],
                name=f"{atom_type} Atom",
            )
        )


def plot_occ_pointcloud(
    df_occs: DataFrame,
    color_column: str,
    title: str,
    df_atoms: Optional[DataFrame] = None,
    marker_size: float = 5,
    atom_marker_size: float = 10,
) -> go.Figure:
    """Plot a discretely-colored 3D point cloud, optionally overlaid with atoms.

    The axes are fixed to [0, 1], so coordinates are presumably normalized /
    fractional — TODO confirm with callers.

    Args:
        df_occs: Points to plot; needs ``X``, ``Y``, ``Z`` and *color_column*.
        color_column: Column used for discrete coloring through
            ``ATOM_COLOR_MAP``; legend categories follow the map's key order.
        title: Figure title.
        df_atoms: Optional atoms with ``X``, ``Y``, ``Z`` and ``atom_type``
            columns, drawn as open diamonds (one trace per type).
        marker_size: Marker size for the point-cloud trace(s).
        atom_marker_size: Marker size for the atom overlay.

    Returns:
        The plotly figure.
    """
    fig = px.scatter_3d(
        df_occs,
        x="X",
        y="Y",
        z="Z",
        color=color_column,
        title=title,
        color_discrete_map=ATOM_COLOR_MAP,
        category_orders={color_column: list(ATOM_COLOR_MAP.keys())},
    )
    fig.update_layout(
        scene=_unit_cube_scene(),
        legend=dict(title="Occ/Atom", x=0, y=1, traceorder="normal"),
    )
    # Resize only the px traces; atom traces are added afterwards so they
    # keep their own marker size.
    fig.update_traces(marker=dict(size=marker_size))
    _add_atom_traces(fig, df_atoms, atom_marker_size)
    return fig


def plot_density_point_cloud(
    df_densities: DataFrame,
    color_column: str,
    title: str,
    df_atoms: Optional[DataFrame] = None,
    marker_size: float = 5,
    atom_marker_size: float = 10,
) -> go.Figure:
    """Plot a continuously-colored (Viridis) 3D point cloud, optionally with atoms.

    The axes are fixed to [0, 1], so coordinates are presumably normalized /
    fractional — TODO confirm with callers.

    Args:
        df_densities: Points to plot; needs ``X``, ``Y``, ``Z`` and *color_column*.
        color_column: Numeric column mapped onto the Viridis color scale.
        title: Figure title.
        df_atoms: Optional atoms with ``X``, ``Y``, ``Z`` and ``atom_type``
            columns, drawn as open diamonds (one trace per type). Defaults to
            ``None`` (no overlay), consistent with ``plot_occ_pointcloud``.
        marker_size: Marker size for the point-cloud trace.
        atom_marker_size: Marker size for the atom overlay.

    Returns:
        The plotly figure.
    """
    fig = px.scatter_3d(
        df_densities,
        x="X",
        y="Y",
        z="Z",
        color=color_column,
        color_continuous_scale="Viridis",
        title=title,
    )
    fig.update_layout(
        scene=_unit_cube_scene(),
        legend=dict(title="Dens", x=0, y=1, traceorder="normal"),
    )
    # Resize only the px trace; atom traces are added afterwards so they keep
    # their own marker size.
    fig.update_traces(marker=dict(size=marker_size))
    _add_atom_traces(fig, df_atoms, atom_marker_size)
    return fig
