import os
from pathlib import Path
from typing import List

import jax
import jax.numpy as jnp

from qdax.logging.metric_loggers import Metric, MetricLogger
from qdax.types import Metrics


def log_accumulated_metrics(
    metrics: Metrics,
    metric_loggers: List[MetricLogger],
    current_timestep: int,
    last_timestep: int,
    metrics_subsample: int = 1,
) -> None:
    """Logs uniformly metrics on the interval [current_timestep, last_timestep).

    Each metric array is preprocessed before logging:

    1. 2-D arrays are averaged over axis 1 with `jnp.nanmean`.
    2. Every array is flattened.
    3. One value every `metrics_subsample` values is kept.

    The remaining ``n`` values of a metric are paired with the ``n`` evenly
    spaced points of [current_timestep, last_timestep), in decreasing order. The
    first value gets the largest timestep and the last value gets
    `current_timestep`. Each timestep is the floor of its linspace point.

    Args:
        metrics: mapping from metric name to an array of accumulated values.
        metric_loggers: loggers that each receive every produced `Metric`.
        current_timestep: start (inclusive) of the logged interval.
        last_timestep: end (exclusive) of the logged interval.
        metrics_subsample: stride used to subsample the flattened values.
    """
    metrics = jax.tree_util.tree_map(
        lambda x: jnp.nanmean(x, axis=1) if len(x.shape) == 2 else x,
        metrics,
    )
    metrics = jax.tree_util.tree_map(
        lambda x: x.flatten()[::metrics_subsample].block_until_ready(), metrics
    )

    interval = last_timestep - current_timestep
    for metric_name, metric_value in metrics.items():
        # Transfer the array to the host once. Indexing a device array element
        # by element would trigger one device read per value.
        host_values = jax.device_get(metric_value)
        num_values = len(host_values)
        # All elements share one dtype, so pick the Python scalar type once.
        is_integer = "int" in str(host_values.dtype)

        for i, raw_value in enumerate(host_values):
            # Use exact integer arithmetic rather than a float32 `jnp.linspace`,
            # which cannot represent every timestep above 2**24. The grid is
            # built from this metric's own length, so metrics of different
            # lengths neither raise IndexError nor receive misaligned timesteps.
            x_axis_value = current_timestep + (
                ((num_values - 1 - i) * interval) // num_values
            )
            metric = Metric(
                name=metric_name,
                value=int(raw_value) if is_integer else float(raw_value),
                x_axis_value=int(x_axis_value),
            )

            for metric_logger in metric_loggers:
                metric_logger.log(metric)


def log_statistics(
    statistics: Metrics,
    current_timestep: int,
    metric_loggers: List[MetricLogger],
) -> None:
    """Send every entry of `statistics` to all `metric_loggers` at `current_timestep`.

    Entries whose key contains `per_skill` hold one value per skill. Each item of
    such an entry is logged separately under the name `<key>_<index>`. Every other
    entry is converted to a single float and logged under its own key.
    """

    def _emit(metric_name, raw_value) -> None:
        # Build one Metric and hand it to every logger.
        metric = Metric(
            name=metric_name,
            value=float(raw_value),
            x_axis_value=int(current_timestep),
        )
        for metric_logger in metric_loggers:
            metric_logger.log(metric)

    for key, value in statistics.items():
        if "per_skill" not in key:
            _emit(f"{key}", value)
            continue

        for skill_index, skill_value in enumerate(value):
            _emit(f"{key}_{skill_index}", skill_value)


def get_output_dir() -> str:
    """Return the base directory under which artifacts should be saved.

    Artifacts and artifact sub-directories are written below this directory.
    Ichor runs must use the directory mounted into the docker image. Its path is
    read from the `ICHOR_OUTPUT_DATASET` environment variable. When that variable
    is not set (e.g. for local runs), the current working directory is used
    instead. A variable that is set to an empty string is returned unchanged.
    """
    ichor_dir = os.environ.get("ICHOR_OUTPUT_DATASET")
    if ichor_dir is not None:
        return ichor_dir
    return str(Path.cwd())
