import time

import jax.numpy as jnp
import numpy as np
import wandb
from craftax.craftax_classic.constants import Achievement

batch_logs = {}
log_times = []

def compute_score(info):
    achievements_value = []
    for achievement in Achievement:
        name = f"Achievements/{achievement.name.lower()}"
        achievements_value.append(info[name])
    # Geometric mean with an offset of 1%
    achievements_value = jnp.array(achievements_value)
    score= jnp.exp(jnp.mean(jnp.log(1 + achievements_value))) - 1.0
    return score

def create_log_dict(info, config):
    to_log = {
        "episode_return": info["returned_episode_returns"],
        "episode_length": info["returned_episode_lengths"],
        "score": compute_score(info),
    }

    sum_achievements = 0
    for k, v in info.items():
        if "achievements" in k.lower():
            to_log[k] = v
            sum_achievements += v / 100.0

    to_log["achievements"] = sum_achievements

    if config["GOAL_TYPE"] in ["SGRL", "SGRL_nopri"]:
        to_log["epsilon_mask"] = info["epsilon_mask"]
        to_log["feasible_action"] = info["feasible_action"]
    return to_log


def batch_log(update_step, log, config):
    update_step = int(update_step)
    if update_step not in batch_logs:
        batch_logs[update_step] = []

    batch_logs[update_step].append(log)

    if len(batch_logs[update_step]) == config["NUM_REPEATS"]:
        agg_logs = {}
        for key in batch_logs[update_step][0]:
            agg = []
            if key in ["goal_heatmap"]:
                agg = [batch_logs[update_step][0][key]]
            else:
                for i in range(config["NUM_REPEATS"]):
                    val = batch_logs[update_step][i][key]
                    if not jnp.isnan(val):
                        agg.append(val)

            if len(agg) > 0:
                if key in [
                    "episode_length",
                    "episode_return",
                    "exploration_bonus",
                    "e_mean",
                    "e_std",
                    "rnd_loss",
                ]:
                    agg_logs[key] = np.mean(agg)
                else:
                    agg_logs[key] = np.array(agg)

        log_times.append(time.time())

        if config["DEBUG"]:
            if len(log_times) == 1:
                print("Started logging")
            elif len(log_times) > 1:
                dt = log_times[-1] - log_times[-2]
                steps_between_updates = (
                    config["NUM_STEPS"] * config["NUM_ENVS"] * config["NUM_REPEATS"]
                )
                sps = steps_between_updates / dt
                agg_logs["sps"] = sps

        wandb.log(agg_logs)
