from pathlib import Path

import torch
from torch_geometric.data import Data

from ltsgns_mp.recording.visualizations.plotly_visualizations import visualize_trajectory as visualize_trajectory_plotly
from ltsgns_mp.util.own_types import ConfigDict


class GraphVisualizer:
    """
    Visualize graph trajectories with a configurable backend and save the results to disk.

    The backend is selected from ``visualization_config.backend``; currently only ``"plotly"`` is supported.
    """

    def __init__(self, visualization_config: ConfigDict) -> None:
        """
        Args:
            visualization_config: Visualization config. Must provide ``backend``; ``save_animation`` additionally
                reads ``fps`` and ``matplotlib.dpi`` when saving matplotlib animations.

        Raises:
            ValueError: If ``visualization_config.backend`` is not a supported backend.
        """
        self.config: ConfigDict = visualization_config

        self._backend = visualization_config.backend
        if self._backend == "plotly":
            self._vis_fn = visualize_trajectory_plotly
        else:
            raise ValueError(f"Unknown backend {self._backend}.")

    def visualize_trajectory(self,
                             eval_traj: Data,
                             predicted_traj: torch.Tensor,
                             ):
        """
        Visualize a trajectory of graphs with the configured backend.

        Args:
            eval_traj: Data object containing the ground truth trajectory, the context and optional point cloud data.
                Has the anchor graph at the current timestep with the edge information.
            predicted_traj: Tensor of shape (num_timesteps, num_nodes, dim_world) containing the predicted positions
                of the mesh nodes.

        Returns:
            Whatever the backend visualization function returns. According to the original documentation, this is
            the path to a gif file or a plotly Figure, depending on the output type config.
            NOTE(review): confirm against the plotly backend implementation.
        """
        return self._vis_fn(eval_traj=eval_traj,
                            predicted_traj=predicted_traj,
                            visualization_config=self.config,
                            )

    def save_animation(self, animation, save_path: str | Path, filename: str | Path) -> None:
        """
        Save a visualization produced by ``visualize_trajectory`` (or a matplotlib animation) to disk.

        Args:
            animation: A plotly ``Figure`` or a ``matplotlib.animation.Animation`` (e.g. ``FuncAnimation``).
            save_path: Directory to save into; created (including parents) if it does not exist.
            filename: Name of the output file inside ``save_path``. For plotly figures a ``.gif`` suffix is
                replaced by ``.html``, because the figure is written as an interactive HTML page.

        Raises:
            ValueError: If ``animation`` is neither a plotly Figure nor a matplotlib animation.
        """
        import plotly.graph_objects as go
        import plotly.offline as pyo

        # Path() accepts str, Path and any other os.PathLike.
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)
        file_path = save_path / filename

        if isinstance(animation, go.Figure):  # plotly
            # Plotly figures are written as HTML here, so swap a requested .gif extension for .html.
            if file_path.suffix == ".gif":
                file_path = file_path.with_suffix(".html")
            pyo.plot(animation, filename=str(file_path), auto_open=False)
            return

        # matplotlib is only needed for this branch; import it lazily so that plotly-only environments
        # without matplotlib can still save figures. If it is unavailable, `animation` cannot be a
        # matplotlib animation and falls through to the ValueError below.
        try:
            from matplotlib import animation as plt_anim
        except ImportError:
            plt_anim = None

        # Animation is the common base of FuncAnimation/ArtistAnimation and defines `save`.
        if plt_anim is not None and isinstance(animation, plt_anim.Animation):  # matplotlib
            animation.save(file_path, writer='pillow',
                           fps=self.config.fps,
                           dpi=self.config.matplotlib.dpi
                           )
            return

        raise ValueError(f"Unknown return type of visualization function: {type(animation)}")

    @property
    def visualization_config(self) -> ConfigDict:
        """The visualization config this visualizer was constructed with."""
        return self.config
