import time
from collections import defaultdict
from typing import Any, Dict, List

import jax
import jax.numpy as jnp
import numpy as np
import wandb


class EpisodeLogger:
    """Collect per-step metrics across repeats and log their mean to wandb.

    ``log_step`` is expected to be called ``config["NUM_REPEATS"]`` times for
    each ``step``. Every call is reduced to a flat ``{key: float}`` dict
    (array entries are averaged over the envs flagged in
    ``info["returned_episode"]``). Once all repeats of a step have been seen,
    the step is logged only if *every* repeat was valid; otherwise it is
    dropped.

    Config keys read: ``NUM_REPEATS``, ``NUM_STEPS``, ``NUM_ENVS``.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        # step -> flattened metrics of the valid repeats received so far
        self._batches: Dict[int, List[Dict[str, float]]] = defaultdict(list)
        # step -> number of repeats processed so far (valid or not)
        self._seen: Dict[int, int] = defaultdict(int)
        # monotonic timestamp of the previous flush; None until the first one
        self._last_time = None

    @staticmethod
    def _all_finite(arr) -> bool:
        """Return True iff ``arr`` is numeric and contains only finite values.

        ``np.isfinite`` raises ``TypeError`` for dtypes it does not support
        (strings, objects, ...); such data is reported as not finite instead
        of letting the exception escape.
        """
        try:
            return bool(np.all(np.isfinite(arr)))
        except TypeError:
            return False

    def _summarize_info(self, info: Dict[str, Any]):
        """Reduce one repeat's ``info`` to a flat ``{key: float}`` dict.

        Returns ``None`` when the repeat is invalid: the ``returned_episode``
        mask is missing, non-numeric or non-finite, no episode finished, or
        any entry is non-numeric, non-finite, or cannot be indexed by the
        mask.
        """
        finished = info.get("returned_episode", None)
        if finished is None:
            return None

        finished_np = np.asarray(finished)
        if not self._all_finite(finished_np):
            return None

        # float is fine here; the count is only compared against zero
        num_finished = float(finished_np.sum())
        if not (num_finished > 0):
            return None

        mask = finished_np > 0
        flat: Dict[str, float] = {}
        for key, value in info.items():
            arr = np.asarray(value)
            if arr.ndim == 0:
                if not self._all_finite(arr):
                    return None
                flat[key] = float(arr)
                continue

            # average only over finished envs
            try:
                selected = arr[mask]
            except (IndexError, ValueError, TypeError):
                # mask does not fit this entry's shape -> invalid repeat
                return None
            if selected.size == 0 or not self._all_finite(selected):
                return None
            flat[key] = float(selected.mean())
        return flat

    def log_step(self, step: int, info: Dict[str, Any]):
        """Record one repeat of ``step`` and flush/drop the step when complete.

        Args:
            step: Step index the repeat belongs to; also used as the wandb
                step when the aggregate is logged.
            info: Metrics of this repeat. Must contain ``"returned_episode"``;
                every value must be convertible to a numeric array.

        Invalid repeats are counted but contribute no batch, which causes the
        whole step to be dropped once all of its repeats have arrived.
        """
        self._seen[step] += 1

        flat = self._summarize_info(info)
        if flat is not None:
            self._batches[step].append(flat)

        self._maybe_close_step(step)

    def _maybe_close_step(self, step: int):
        """
        If we've seen all repeats for this step, either flush (if all valid)
        or drop (if any were invalid / missing).
        """
        num_repeats = self.config["NUM_REPEATS"]
        # .get avoids creating a stray defaultdict entry for unknown steps
        if self._seen.get(step, 0) < num_repeats:
            return  # wait for remaining repeats

        if len(self._batches[step]) == num_repeats:
            self._flush(step)
        # drop whatever is left (no-op after a flush)
        self._batches.pop(step, None)
        self._seen.pop(step, None)

    def _flush(self, step: int):
        """Mean-aggregate the repeats of ``step`` and send them to wandb.

        Adds two derived entries when possible:
        ``stealth_ratio`` = ``-returned_episode_returns / max(mean_delta_sq, 1e-8)``
        (only when ``mean_delta_sq > 0``), and ``sps`` =
        ``NUM_STEPS * NUM_ENVS * NUM_REPEATS`` divided by the time elapsed
        since the previous flush (absent on the first flush).
        """
        reps = self._batches.pop(step)
        self._seen.pop(step, None)

        # mean-aggregate every key seen (all repeats are valid by construction)
        keys = set().union(*[r.keys() for r in reps])
        agg: Dict[str, float] = {
            k: float(np.mean([r[k] for r in reps if k in r])) for k in keys
        }

        # compute stealth_ratio if possible (guarded)
        if "mean_delta_sq" in agg and "returned_episode_returns" in agg:
            den = agg["mean_delta_sq"]
            if den > 0:
                agg["stealth_ratio"] = -agg["returned_episode_returns"] / max(den, 1e-8)

        # compute SPS; a monotonic clock keeps the interval immune to
        # wall-clock adjustments
        now = time.monotonic()
        if self._last_time is not None:
            dt = now - self._last_time
            steps = (
                self.config["NUM_STEPS"]
                * self.config["NUM_ENVS"]
                * self.config["NUM_REPEATS"]
            )
            if dt > 0:
                agg["sps"] = steps / dt
        else:
            print('Logging started')
        self._last_time = now

        # send to wandb
        wandb.log(agg, step=step)
