import os
from pathlib import Path

import plotly

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.recording.loggers.abstract_logger import AbstractLogger
from ltsgns_mp.recording.visualizations.plotly_visualizations import get_frame_duration
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import *
from ltsgns_mp.recording.visualizations.graph_visualizer import GraphVisualizer


def visualize_trajectories(recorded_values, graph_visualizer: GraphVisualizer,
                           iteration: int, save_animation: bool = False,
                           save_path: str | Path = None):
    """
    Render the trajectory animation of every task that still carries raw visualization data.

    Expects `recorded_values[keys.VISUALIZATIONS]` to map task names directly to task dictionaries.
    Each task dictionary that has a "to_visualize" entry is modified in place: the rendered animation
    is stored under "trajectory_prediction" and the "to_visualize" entry is removed.

    Args:
        recorded_values: Dictionary of recorded values. Nothing happens if it has no
            `keys.VISUALIZATIONS` entry.
        graph_visualizer: Visualizer whose `visualize_trajectory` turns the evaluation and predicted
            trajectories into an animation.
        iteration: Accepted but not used by this function.
        save_animation: Accepted but not used by this function.
        save_path: Accepted but not used by this function.
    """
    if keys.VISUALIZATIONS not in recorded_values:
        return

    for _task_name, task_entry in recorded_values[keys.VISUALIZATIONS].items():
        if "to_visualize" not in task_entry:
            # nothing to render for this task
            continue

        raw_data = task_entry["to_visualize"]
        task_entry["trajectory_prediction"] = graph_visualizer.visualize_trajectory(
            eval_traj=raw_data["eval_traj"],
            predicted_traj=raw_data["predicted_traj"],
        )

        # the raw trajectories are large and no longer needed once the animation exists
        del task_entry["to_visualize"]


class VisualizationLogger(AbstractLogger):
    """
    Renders trajectory visualizations from the recorded values of an iteration and, if
    `config.recorder.visualizations.save_on_disk` is set, additionally writes them to disk as HTML files.
    """

    def __init__(self, config: ConfigDict, algorithm: AbstractAlgorithm):
        """
        Args:
            config: The full configuration. `config.env.visualization` is handed to the GraphVisualizer,
                `config.recorder.visualizations` holds the settings of this logger.
            algorithm: The algorithm instance, forwarded to the AbstractLogger constructor.
        """
        super().__init__(config=config, algorithm=algorithm)

        env_config = config.env
        self.graph_visualizer = GraphVisualizer(visualization_config=env_config.visualization)
        self._vis_config = config.recorder.visualizations
        if self._vis_config.save_on_disk:
            # note: self._vis_path only exists when saving to disk is enabled
            self._vis_path = os.path.join(self._recording_directory, keys.VISUALIZATIONS)
            os.makedirs(self._vis_path, exist_ok=True)

    def log_iteration(self, recorded_values: ValueDict, iteration: int
                      ) -> None:
        """
        Renders the trajectory visualizations contained in the recorded values of the current iteration
        Args:
            recorded_values: A dictionary of previously recorded values. Its visualization entries are
                modified in place, see visualize_trajectories()
            iteration: The current iteration of the algorithm
        Returns:

        """
        self._writer.info(f"Logging visualizations for iteration '{iteration}'")
        graph_visualizer = self.graph_visualizer
        self.visualize_trajectories(recorded_values, graph_visualizer, iteration)

    def visualize_trajectories(self, recorded_values, graph_visualizer: GraphVisualizer,
                               iteration: int):
        """
        Replaces the raw "to_visualize" data of every task by a rendered animation.

        Expects `recorded_values[keys.VISUALIZATIONS]` to be nested as
        context size -> task name -> task dictionary. Every task dictionary with a "to_visualize" entry
        gets the animation stored under "trajectory_prediction" and loses the "to_visualize" entry.
        If saving to disk is enabled, plotly figures are also written to
        `<vis_path>/iteration_<iteration>/<task name>/<context size>.html`.

        Args:
            recorded_values: A dictionary of previously recorded values. Nothing happens if it has no
                `keys.VISUALIZATIONS` entry.
            graph_visualizer: Visualizer whose `visualize_trajectory` creates the animation
            iteration: The current iteration of the algorithm, used to name the output folder
        """
        if keys.VISUALIZATIONS not in recorded_values:
            return

        save_on_disk = self._vis_config.save_on_disk
        step_path = None
        if save_on_disk:
            # create folder for this step
            step_path = os.path.join(self._vis_path, f"iteration_{iteration:03}")
            os.makedirs(step_path, exist_ok=True)

        vis_dict = recorded_values[keys.VISUALIZATIONS]
        for context_size_str, context_dict in vis_dict.items():
            for task_name, task_dict in context_dict.items():
                if "to_visualize" not in task_dict:
                    continue
                data_dict = task_dict["to_visualize"]
                animation = graph_visualizer.visualize_trajectory(eval_traj=data_dict["eval_traj"],
                                                                  predicted_traj=data_dict["predicted_traj"],
                                                                  )
                task_dict["trajectory_prediction"] = animation

                # delete to_visualize key as this is not needed anymore and contains large data
                del task_dict["to_visualize"]

                if save_on_disk:
                    self._save_animation_html(animation, step_path, task_name, context_size_str)

    def _save_animation_html(self, animation, step_path: str, task_name: str, context_size_str) -> None:
        """
        Writes a plotly figure to `<step_path>/<task_name>/<context_size_str>.html`.
        Animations of any other type are skipped, as only plotly html export is implemented right now.
        """
        # use the public plotly API instead of the private plotly.graph_objs._figure module
        from plotly.graph_objects import Figure
        from plotly.io import to_html

        if not isinstance(animation, Figure):
            return

        frame_duration = get_frame_duration(self._config.env.visualization)
        html_content = to_html(animation, include_plotlyjs="cdn", auto_play=True,
                               animation_opts={"frame": {"duration": frame_duration,
                                                         "redraw": True  # must be true
                                                         },
                                               "fromcurrent": True,
                                               "transition": {"duration": frame_duration},
                                               })
        # create task folder if not exists
        task_path = os.path.join(step_path, task_name)
        os.makedirs(task_path, exist_ok=True)
        # explicit encoding, so that writing does not depend on (or fail due to) the platform default
        with open(os.path.join(task_path, f"{context_size_str}.html"), "w", encoding="utf-8") as f:
            f.write(html_content)

    def finalize(self) -> None:
        """
        Nothing to finalize for this logger.
        """
        pass
