import wandb
from omegaconf import OmegaConf

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.recording.loggers.abstract_logger import AbstractLogger
from ltsgns_mp.recording.loggers.logger_util.wandb_util import reset_wandb_env, wandbfy
from ltsgns_mp.recording.visualizations.plotly_visualizations import get_frame_duration
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import *


class CustomWAndBLogger(AbstractLogger):
    """
    Logs (some) recorded results using wandb.ai.

    Constructing the logger starts a wandb run whose project, group, run name, job type and tags are derived
    from the `recorder.wandb` section of the config. `log_iteration` then forwards the recorded scalars and,
    if enabled, the visualizations of one iteration to wandb, and `finalize` closes the run.
    """

    # Number of trailing characters that are kept for the group/run name and for the job type.
    # NOTE(review): presumably chosen to stay within wandb's field length limits - confirm against the wandb docs.
    _MAX_NAME_LENGTH = 127
    _MAX_JOB_TYPE_LENGTH = 64

    def __init__(self, config: ConfigDict, algorithm: AbstractAlgorithm):
        """
        Resets the wandb environment and initializes a new wandb run from the given config.
        Args:
            config: The full experiment config. Reads `recorder.wandb` (project_name, task_name, group_name,
                run_name, job_type, tags, entity, start_method, dev), `env.name`, `algorithm.name`,
                `evaluation.eval_only` and, for evaluation-only runs, `eval_name`.
            algorithm: The algorithm that is being recorded. Only passed on to the parent logger.
        """
        super().__init__(config=config, algorithm=algorithm)
        reset_wandb_env()

        wandb_params = config.recorder.wandb
        project_name = wandb_params.get("project_name")
        task_name = wandb_params.get("task_name")
        environment_name = config.env.name

        # Specify the project further by the task name or, if there is none, by the environment name.
        # Without either, the plain project name is used.
        if task_name is not None:
            project_name = project_name + "_" + task_name
        elif environment_name is not None:
            project_name = project_name + "_" + environment_name

        # Assemble the group as "[DEV_][<eval_name>_]<group_name>". A missing group_name is simply left out
        # instead of raising a TypeError during string concatenation.
        group_parts = []
        if wandb_params.dev:
            group_parts.append("DEV")
        if config.evaluation.eval_only:
            group_parts.append(config.eval_name)
        configured_group = wandb_params.get("group_name")
        if configured_group is not None:
            group_parts.append(configured_group)
        groupname = "_".join(group_parts) if group_parts else None

        groupname = self._keep_tail(groupname, self._MAX_NAME_LENGTH)
        runname = self._keep_tail(wandb_params.get("run_name"), self._MAX_NAME_LENGTH)
        job_type = self._keep_tail(wandb_params.get("job_type"), self._MAX_JOB_TYPE_LENGTH)

        # Copy the configured tags into a fresh list. Appending to the config's own list would modify the
        # shared config, so that every further logger built from it would pile up duplicate tags.
        configured_tags = wandb_params.get("tags", [])
        tags = [] if configured_tags is None else list(configured_tags)
        for section_name in ("algorithm", "env"):
            section_tag = config.get(section_name).get("name")
            if section_tag is not None:
                tags.append(section_tag)

        entity = wandb_params.get("entity")

        start_method = wandb_params.get("start_method")
        settings = wandb.Settings(start_method=start_method) if start_method is not None else None

        self.wandb_logger = wandb.init(project=project_name,  # name of the whole project
                                       tags=tags,  # tags to search the runs by. Contains algorithm and env name
                                       job_type=job_type,  # name of your experiment
                                       group=groupname,  # group of identical hyperparameters for different seeds
                                       name=runname,  # individual repetitions
                                       dir=self._recording_directory,  # local directory for wandb recording
                                       config=OmegaConf.to_container(config, resolve=True),  # full file config
                                       reinit=False,
                                       entity=entity,
                                       settings=settings
                                       )
        # look at weights
        # wandb.watch(algorithm.simulator.decoder, log='all')
        # wandb.watch(algorithm.simulator.processor, log='all')

    @staticmethod
    def _keep_tail(text, max_length: int):
        """
        Truncates a string to its last `max_length` characters.
        Args:
            text: The string to truncate, or None
            max_length: Maximum number of trailing characters to keep
        Returns: The truncated string, or None if `text` is None so that an unset value is passed on unchanged

        """
        if text is None:
            return None
        return text[-max_length:]

    def _to_wandb_media(self, figure):
        """
        Converts a recorded figure into an object that can be logged to wandb.
        Args:
            figure: A figure or animation as stored in the recorded visualizations
        Returns: The result of `wandbfy` for this figure, using the frame duration of the environment's
            visualization config

        """
        return wandbfy(figure, get_frame_duration(self._config.env.visualization))

    def log_iteration(self, recorded_values: ValueDict, iteration: int) -> None:
        """
        Parses and logs the given dict of recorder metrics to wandb.
        Args:
            recorded_values: A dictionary of previously recorded things. Its visualizations entry, if present,
                is removed from the dictionary after logging.
            iteration: The current iteration of the algorithm. Used as the wandb step.
        Returns:

        """
        wandb_log_dict = {}

        # log scalars
        if keys.SCALARS in recorded_values:
            # category is train or eval
            for category_name, category_metrics in recorded_values[keys.SCALARS].items():
                for task_quantity_name, task_quantity_metrics in category_metrics.items():
                    if isinstance(task_quantity_metrics, dict):
                        # NOTE(review): nested metrics are keyed without the category prefix, so equal
                        # task/metric names in different categories overwrite each other - confirm intended.
                        for metric_name, metric_value in task_quantity_metrics.items():
                            wandb_log_dict[f"{task_quantity_name}/{metric_name}"] = metric_value
                    else:
                        # if it is not nested, just log the metric
                        wandb_log_dict[f"{category_name}/{task_quantity_name}"] = task_quantity_metrics
            wandb_log_dict['default/iteration'] = iteration

        # log visualizations and animations
        if self._config.recorder.wandb.log_visualizations and keys.VISUALIZATIONS in recorded_values:
            vis_dict = recorded_values[keys.VISUALIZATIONS]
            for task_name, task_figures in vis_dict.items():
                for vis_name, vis_figure in task_figures.items():
                    wandb_log_dict[f"{task_name}/{vis_name}"] = self._to_wandb_media(vis_figure)

            # The training figures are additionally exposed under the literal "train/" prefix. If the train
            # key already is "train", the loop above has written exactly these entries, and converting every
            # figure a second time would only repeat the work.
            if keys.TRAIN != "train" and keys.TRAIN in vis_dict:
                for vis_name, vis_figure in vis_dict[keys.TRAIN].items():
                    wandb_log_dict[f"train/{vis_name}"] = self._to_wandb_media(vis_figure)

        if wandb_log_dict:  # logging dictionary is not empty
            wandb.log(wandb_log_dict, step=iteration)

        # delete visualizations to try to fix memory leak
        if keys.VISUALIZATIONS in recorded_values:
            del recorded_values[keys.VISUALIZATIONS]

    def finalize(self) -> None:
        """
        Properly close the wandb logger
        Returns:

        """
        wandb.finish()
