# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Modified minimally to support external std-window scoring and uniform+gate baseline.

from collections import namedtuple, defaultdict, deque
import queue

import numpy as np
import torch

# Largest signed 32-bit integer. Used below as the upper end of the seed range
# reported in full-distribution mode and as the (exclusive) upper bound passed
# to np.random.randint when drawing fresh seeds.
INT32_MAX = 2147483647
# Ask NumPy to raise FloatingPointError on every floating-point error category
# (divide, overflow, underflow, invalid) instead of warning.
np.seterr(all="raise")


class LevelSamplerDrop:
    def __init__(
        self,
        seeds,
        obs_space,
        action_space,
        num_actors=1,
        strategy="random",
        max_score_coef=0.0,
        replay_schedule="fixed",
        score_transform="power",
        temperature=1.0,
        eps=0.05,
        rho=1.0,
        replay_prob=0.95,
        alpha=1.0,
        staleness_coef=0,
        staleness_transform="power",
        staleness_temperature=1.0,
        sample_full_distribution=True,
        seed_buffer_size=0,
        seed_buffer_priority="replay_support",
        use_dense_rewards=False,
        tscl_window_size=0,
        gamma=0.999,
    ):
        """
        Inputs:
            seeds: List of initial seeds.
            strategy: Sampling strategy (random, sequential, policy entropy, ..., std_window, uniform_mean_gate).
        """
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_actors = num_actors
        self.strategy = strategy
        self.max_score_coef = max_score_coef
        self.replay_schedule = replay_schedule
        self.score_transform = score_transform
        self.temperature = temperature
        self.eps = eps
        self.rho = rho
        self.replay_prob = replay_prob
        self.alpha = alpha
        self.staleness_coef = staleness_coef
        self.staleness_transform = staleness_transform
        self.staleness_temperature = staleness_temperature
        self.gamma = gamma
        self.use_dense_rewards = use_dense_rewards

        # self.seed_buffer_size = seed_buffer_size if not seeds else len(seeds)
        self.seed_buffer_size = seed_buffer_size
        N = self.seed_buffer_size
        self._init_seed_index(seeds)

        self.unseen_seed_weights = np.array([1.0] * N)
        self.seed_scores = np.array([0.0] * N, dtype=np.float)
        self.partial_seed_scores = np.zeros((num_actors, N), dtype=np.float)
        self.partial_seed_max_scores = np.ones((num_actors, N), dtype=np.float) * float(
            "-inf"
        )
        self.partial_seed_steps = np.zeros((num_actors, N), dtype=np.int32)
        self.seed_staleness = np.array([0.0] * N, dtype=np.float)

        self.running_sample_count = 0
        self.next_seed_index = 0

        self.track_solvable = False

        self.grounded_values = None
        if str(self.strategy).startswith("grounded"):
            self.grounded_values = np.array([np.NINF] * N, dtype=np.float)

        # NEW: Sample count & early-stop flag for each seed
        self.seed_sample_counts = np.zeros(N, dtype=np.int32)
        self.seed_early_stopped = np.zeros(N, dtype=np.bool)

        self.sample_full_distribution = sample_full_distribution
        if self.sample_full_distribution:
            self.seed2actor = defaultdict(set)
            self.working_seed_buffer_size = 0
            self.seed_buffer_priority = seed_buffer_priority
            self.staging_seed_set = set()
            self.working_seed_set = set()

            self.seed2timestamp_buffer = {}
            self.partial_seed_scores_buffer = [{} for _ in range(num_actors)]
            self.partial_seed_max_scores_buffer = [{} for _ in range(num_actors)]
            self.partial_seed_steps_buffer = [{} for _ in range(num_actors)]

        if str(self.strategy).startswith("tscl"):
            self.tscl_window_size = tscl_window_size
            self.tscl_return_window = [
                deque(maxlen=self.tscl_window_size) for _ in range(N)
            ]
            self.tscl_episode_window = [
                deque(maxlen=self.tscl_window_size) for _ in range(N)
            ]
            self.unseen_seed_weights = np.zeros(N)  # force uniform over seen

    def seed_range(self):
        if not self.sample_full_distribution:
            return (int(min(self.seeds)), int(max(self.seeds)))
        else:
            return (0, INT32_MAX)

    def _init_seed_index(self, seeds):
        if seeds:
            self.seeds = np.array(seeds, dtype=np.int64)
            self.seed2index = {seed: i for i, seed in enumerate(seeds)}
        else:
            self.seeds = np.zeros(self.seed_buffer_size, dtype=np.int64) - 1
            self.seed2index = {}

    def _init_solvable_tracking(self):
        self.track_solvable = True
        self.staging_seed2solvable = {}
        self.seed_solvable = np.ones(self.seed_buffer_size, dtype=np.bool)

    @property
    def _proportion_filled(self):
        if self.sample_full_distribution:
            return self.working_seed_buffer_size / self.seed_buffer_size
        else:
            num_unseen = (self.unseen_seed_weights > 0).sum()
            proportion_seen = (len(self.seeds) - num_unseen) / len(self.seeds)
            return proportion_seen

    # === New: direct score setter used by std_window runner ===
    def set_score_direct(self, seed, score):
        """Write a score (e.g., std over a window) to a seed from outside.
        Marks it as seen so it participates in sampling."""
        seed_idx = self.seed2index.get(int(seed), -1)
        if seed_idx < 0:
            return False
        self.seed_scores[seed_idx] = float(score)
        self.unseen_seed_weights[seed_idx] = 0.0
        return True

    # === Original update entry preserved ===
    def update_with_rollouts(self, rollouts, external_scores=None):
        # Our new strategies are handled by the Runner externally.
        if self.strategy in ["std_window", "uniform_mean_gate"]:
            return
        if self.strategy in ["random", "off"]:
            return

        if self.strategy == "uniform":
            score_function = self._uniform
        elif self.strategy == "policy_entropy":
            score_function = self._average_entropy
        elif self.strategy == "least_confidence":
            score_function = self._average_least_confidence
        elif self.strategy == "min_margin":
            score_function = self._average_min_margin
        elif self.strategy == "gae":
            score_function = self._average_gae
        elif self.strategy == "value_l1":
            score_function = self._average_value_l1
        elif self.strategy == "signed_value_loss":
            score_function = self._average_signed_value_loss
        elif self.strategy == "positive_value_loss":
            score_function = self._average_positive_value_loss
        elif self.strategy == "grounded_signed_value_loss":
            score_function = self._average_grounded_signed_value_loss
        elif self.strategy == "grounded_positive_value_loss":
            score_function = self._average_grounded_positive_value_loss
        elif self.strategy == "one_step_td_error":
            score_function = self._one_step_td_error
        elif self.strategy == "alt_advantage_abs":
            score_function = self._average_alt_advantage_abs
        elif self.strategy == "tscl_window":
            score_function = self._tscl_window
        else:
            raise ValueError(f"Unsupported strategy, {self.strategy}")

        if external_scores is not None:
            score_function = self._average_external_score

        self._update_with_rollouts(
            rollouts, score_function, external_scores=external_scores
        )

    def update_seed_score(
        self, actor_index, seed, score, max_score, num_steps, running_mean=True
    ):
        if self.sample_full_distribution and seed in getattr(
            self, "staging_seed_set", set()
        ):
            score, seed_idx = self._partial_update_seed_score_buffer(
                actor_index, seed, score, num_steps, done=True, running_mean=running_mean
            )
        else:
            score, seed_idx = self._partial_update_seed_score(
                actor_index,
                seed,
                score,
                max_score,
                num_steps,
                done=True,
                running_mean=running_mean,
            )
        return score, seed_idx

    def _partial_update_seed_score(
        self, actor_index, seed, score, max_score, num_steps, done=False, running_mean=True
    ):
        seed_idx = self.seed2index.get(seed, -1)
        if seed_idx < 0:
            return 0, None
        partial_score = self.partial_seed_scores[actor_index][seed_idx]
        partial_max_score = self.partial_seed_max_scores[actor_index][seed_idx]
        partial_num_steps = self.partial_seed_steps[actor_index][seed_idx]

        running_num_steps = partial_num_steps + num_steps
        merged_score = partial_score + (score - partial_score) * num_steps
        if running_mean:
            merged_score = merged_score / float(running_num_steps)
        merged_max_score = max(partial_max_score, max_score)

        if done:
            self.partial_seed_scores[actor_index][seed_idx] = 0.0
            self.partial_seed_max_scores[actor_index][seed_idx] = float("-inf")
            self.partial_seed_steps[actor_index][seed_idx] = 0
            self.unseen_seed_weights[seed_idx] = 0.0
            old_score = self.seed_scores[seed_idx]
            total_score = (
                self.max_score_coef * merged_max_score
                + (1 - self.max_score_coef) * merged_score
            )
            self.seed_scores[seed_idx] = (
                1 - self.alpha
            ) * old_score + self.alpha * total_score
        else:
            self.partial_seed_scores[actor_index][seed_idx] = merged_score
            self.partial_seed_max_scores[actor_index][seed_idx] = merged_max_score
            self.partial_seed_steps[actor_index][seed_idx] = running_num_steps
        return merged_score, seed_idx

    @property
    def _next_buffer_index(self):
        if self._proportion_filled < 1.0:
            return getattr(self, "working_seed_buffer_size", 0)
        else:
            if getattr(self, "seed_buffer_priority", "replay_support") == "replay_support":
                return self.sample_weights().argmin()
            else:
                return self.seed_scores.argmin()

    def _partial_update_seed_score_buffer(
        self, actor_index, seed, score, num_steps, done=False, running_mean=True
    ):
        """Merge a score for a *staged* seed and, when done, try to admit it
        into the working buffer.

        Partial state for staged seeds lives in per-actor dicts keyed by the
        seed itself (not by buffer slot).

        Args:
            actor_index: Which actor's staging dicts to read and write.
            seed: The staged seed being scored.
            score: Score of the new episode segment.
            num_steps: Number of steps the segment covers.
            done: If true, decide whether the seed enters the working buffer
                and drop all of its staging state; otherwise only store the
                merged partial score and step count.
            running_mean: If true, divide the merged score by the accumulated
                number of steps.

        Returns:
            ``(merged_score, seed_idx)`` where ``seed_idx`` is the slot the
            seed was written to, ``None`` if the episode finished but the seed
            was not admitted, or ``-1`` if the episode is not done yet.
        """
        seed_idx = -1
        # Remember every actor that touched this seed so that their staging
        # entries can all be dropped once the seed is resolved.
        self.seed2actor[seed].add(actor_index)
        partial_score = self.partial_seed_scores_buffer[actor_index].get(seed, 0)
        partial_num_steps = self.partial_seed_steps_buffer[actor_index].get(seed, 0)

        running_num_steps = partial_num_steps + num_steps
        merged_score = partial_score + (score - partial_score) * num_steps
        if running_mean:
            merged_score = merged_score / float(running_num_steps)

        if done:
            # Candidate slot: first free slot, or the lowest-priority slot
            # once the buffer is full (see _next_buffer_index).
            seed_idx = self._next_buffer_index
            # Admit the seed if it scores at least as high as the slot's
            # current occupant, or if the slot has never been filled.
            if (
                self.seed_scores[seed_idx] <= merged_score
                or self.unseen_seed_weights[seed_idx] > 0
            ):
                self.unseen_seed_weights[seed_idx] = 0.0
                getattr(self, "working_seed_set", set()).discard(self.seeds[seed_idx])
                getattr(self, "working_seed_set", set()).add(seed)
                # NOTE(review): the replaced seed's entry in seed2index is not
                # removed here, so it keeps pointing at this slot; confirm
                # that this is intended.
                self.seeds[seed_idx] = seed
                self.seed2index[seed] = seed_idx
                self.seed_scores[seed_idx] = merged_score
                # Clear per-actor partial score/steps that belonged to the
                # previous occupant of this slot (the partial max scores are
                # left untouched).
                self.partial_seed_scores[:, seed_idx] = 0.0
                self.partial_seed_steps[:, seed_idx] = 0
                if hasattr(self, "seed2timestamp_buffer"):
                    # Staleness starts at the number of samples drawn since
                    # the seed was first staged.
                    self.seed_staleness[seed_idx] = (
                        self.running_sample_count - self.seed2timestamp_buffer[seed]
                    )
                self.working_seed_buffer_size = min(
                    getattr(self, "working_seed_buffer_size", 0) + 1, self.seed_buffer_size
                )
                if self.track_solvable:
                    self.seed_solvable[seed_idx] = self.staging_seed2solvable.get(
                        seed, True
                    )

                # NEW: Reset sample count and early-stop flag when new seed is written to buffer
                self.seed_sample_counts[seed_idx] = 0
                self.seed_early_stopped[seed_idx] = False

            else:
                # Rejected: signal to the caller that no slot was written.
                seed_idx = None

            # Whether admitted or rejected, the seed leaves the staging area.
            for a in self.seed2actor[seed]:
                self.partial_seed_scores_buffer[a].pop(seed, None)
                self.partial_seed_steps_buffer[a].pop(seed, None)
            del self.seed2timestamp_buffer[seed]
            del self.seed2actor[seed]
            self.staging_seed_set.remove(seed)
            if self.track_solvable:
                del self.staging_seed2solvable[seed]
        else:
            # Episode still running: keep the merged state for this actor.
            self.partial_seed_scores_buffer[actor_index][seed] = merged_score
            self.partial_seed_steps_buffer[actor_index][seed] = running_num_steps
        return merged_score, seed_idx

    # ---- score functions (unchanged originals) ----
    def _uniform(self, **kwargs):
        return 1.0, 1.0

    def _average_entropy(self, **kwargs):
        episode_logits = kwargs["episode_logits"]
        num_actions = self.action_space.n
        max_entropy = -(1.0 / num_actions) * np.log(1.0 / num_actions) * num_actions
        scores = -torch.exp(episode_logits) * episode_logits.sum(-1) / max_entropy
        return scores.mean().item(), scores.max().item()

    def _average_least_confidence(self, **kwargs):
        episode_logits = kwargs["episode_logits"]
        scores = 1 - torch.exp(episode_logits.max(-1, keepdim=True)[0])
        return scores.mean().item(), scores.max().item()

    def _average_min_margin(self, **kwargs):
        episode_logits = kwargs["episode_logits"]
        top2_confidence = torch.exp(episode_logits.topk(2, dim=-1)[0])
        scores = top2_confidence[:, 0] - top2_confidence[:, 1]
        return 1 - scores.mean().item(), 1 - scores.min().item()

    def _average_gae(self, **kwargs):
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]
        advantages = returns - value_preds
        return advantages.mean().item(), advantages.max().item()

    def _average_value_l1(self, **kwargs):
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]
        abs_advantages = (returns - value_preds).abs()
        return abs_advantages.mean().item(), abs_advantages.max().item()

    def _average_signed_value_loss(self, **kwargs):
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]
        advantages = returns - value_preds
        return advantages.mean().item(), advantages.max().item()

    def _average_positive_value_loss(self, **kwargs):
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]
        clipped_advantages = (returns - value_preds).clamp(0)
        return clipped_advantages.mean().item(), clipped_advantages.max().item()

    def _average_grounded_signed_value_loss(self, **kwargs):
        seed = kwargs["seed"]
        actor_idx = kwargs["actor_index"]
        done = kwargs["done"]
        value_preds = kwargs["value_preds"]
        episode_logits = kwargs["episode_logits"]
        partial_steps = 0
        if (
            self.sample_full_distribution
            and seed in getattr(self, "partial_seed_steps_buffer", [{}])[actor_idx]
        ):
            partial_steps = self.partial_seed_steps_buffer[actor_idx][seed]
        elif seed in self.seed2index:
            partial_steps = self.partial_seed_steps[actor_idx][self.seed2index[seed]]
        new_steps = len(episode_logits)
        total_steps = partial_steps + new_steps
        grounded_value = kwargs.get("grounded_value", None)
        if done and grounded_value is not None:
            if self.use_dense_rewards:
                advantages = grounded_value - value_preds[0]
            else:
                advantages = grounded_value - value_preds
            mean_score = (total_steps / new_steps) * advantages.mean().item()
            max_score = advantages.max().item()
        else:
            mean_score, max_score = 0, 0
        return mean_score, max_score

    def _average_external_score(self, **kwargs):
        done = kwargs["done"]
        external_scores = kwargs["external_scores"]
        if done:
            mean_score = external_scores.item()
        else:
            mean_score = 0
        return mean_score, mean_score

    def _average_grounded_positive_value_loss(self, **kwargs):
        seed = kwargs["seed"]
        actor_idx = kwargs["actor_index"]
        done = kwargs["done"]
        value_preds = kwargs["value_preds"]
        episode_logits = kwargs["episode_logits"]
        partial_steps = 0
        if (
            self.sample_full_distribution
            and seed in getattr(self, "partial_seed_steps_buffer", [{}])[actor_idx]
        ):
            partial_steps = self.partial_seed_steps_buffer[actor_idx][seed]
        elif seed in self.seed2index:
            partial_steps = self.partial_seed_steps[actor_idx][self.seed2index[seed]]
        new_steps = len(episode_logits)
        total_steps = partial_steps + new_steps
        grounded_value = kwargs.get("grounded_value", None)

        if done and grounded_value is not None:
            if self.use_dense_rewards:
                advantages = grounded_value - value_preds[0]
            else:
                advantages = grounded_value - value_preds
            advantages = advantages.clamp(0)
            mean_score = (total_steps / new_steps) * advantages.mean().item()
            max_score = advantages.max().item()
        else:
            mean_score, max_score = 0, 0
        return mean_score, max_score

    def _one_step_td_error(self, **kwargs):
        rewards = kwargs["rewards"]
        value_preds = kwargs["value_preds"]
        max_t = len(rewards)
        if max_t > 1:
            td_errors = (
                rewards[:-1] + self.gamma * value_preds[1:max_t] - value_preds[: max_t - 1]
            ).abs()
        else:
            td_errors = rewards[0] - value_preds[0]
        return td_errors.mean().item(), td_errors.max().item()

    def _average_alt_advantage_abs(self, **kwargs):
        returns = kwargs["alt_returns"]
        value_preds = kwargs["value_preds"]
        abs_advantages = (returns - value_preds).abs()
        return abs_advantages.mean().item(), abs_advantages.max().item()

    def _tscl_window(self, **kwargs):
        rewards = kwargs["rewards"]
        seed = kwargs["seed"]
        seed_idx = self.seed2index.get(seed, -1)
        assert seed_idx >= 0
        episode_total_reward = rewards.sum().item()
        self.tscl_return_window[seed_idx].append(episode_total_reward)
        self.tscl_episode_window[seed_idx].append(self.running_sample_count)
        x = self.tscl_episode_window[seed]
        y = self.tscl_return_window[seed]
        A = np.vstack([x, np.ones(len(x))]).T
        c, _ = np.linalg.lstsq(A, y, rcond=None)[0]
        c = abs(c)
        return c, c

    @property
    def requires_value_buffers(self):
        return self.strategy in [
            "gae",
            "value_l1",
            "signed_value_loss",
            "positive_value_loss",
            "grounded_signed_value_loss",
            "grounded_positive_value_loss",
            "one_step_td_error",
            "alt_advantage_abs",
            "tscl_window",
        ]

    @property
    def _has_working_seed_buffer(self):
        return not self.sample_full_distribution or (
            self.sample_full_distribution and self.seed_buffer_size > 0
        )

    def _update_with_rollouts(self, rollouts, score_function, external_scores=None):
        if not self._has_working_seed_buffer:
            return

        level_seeds = rollouts.level_seeds
        policy_logits = rollouts.action_log_dist
        total_steps, num_actors = policy_logits.shape[:2]
        done = ~(rollouts.masks > 0)
        cliffhanger = ~(rollouts.cliffhanger_masks > 0)

        for actor_index in range(num_actors):
            start_t = 0
            done_steps = done[:, actor_index].nonzero()[:, 0]
            for t in done_steps:
                if not start_t < total_steps:
                    break
                if t == 0:  # previous cycle full update
                    continue
                seed_t = level_seeds[start_t, actor_index].item()
                score_function_kwargs = {}
                score_function_kwargs["actor_index"] = actor_index
                score_function_kwargs["done"] = True
                episode_logits = policy_logits[start_t:t, actor_index]
                score_function_kwargs["episode_logits"] = torch.log_softmax(
                    episode_logits, -1
                )
                score_function_kwargs["seed"] = seed_t
                if external_scores is not None:
                    score_function_kwargs["external_scores"] = external_scores[actor_index]
                if self.requires_value_buffers:
                    score_function_kwargs["returns"] = rollouts.returns[
                        start_t:t, actor_index
                    ]
                    if self.strategy == "alt_advantage_abs":
                        score_function_kwargs["alt_returns"] = rollouts.alt_returns[
                            start_t:t, actor_index
                        ]
                    score_function_kwargs["rewards"] = rollouts.rewards[
                        start_t:t, actor_index
                    ]
                    if rollouts.use_popart:
                        score_function_kwargs["value_preds"] = rollouts.denorm_value_preds[
                            start_t:t, actor_index
                        ]
                    else:
                        score_function_kwargs["value_preds"] = rollouts.value_preds[
                            start_t:t, actor_index
                        ]
                if not cliffhanger[t, actor_index]:
                    grounded_value = None
                    if self.grounded_values is not None:
                        seed_idx = self.seed2index.get(seed_t, None)
                        score_function_kwargs["seed_idx"] = seed_idx
                        grounded_value_ = rollouts.rewards[start_t:t].sum(0)[actor_index]
                        if seed_idx is not None:
                            grounded_value = max(
                                self.grounded_values[seed_idx], grounded_value_
                            )
                        else:
                            grounded_value = grounded_value_
                        score_function_kwargs["grounded_value"] = grounded_value
                    score, max_score = score_function(**score_function_kwargs)
                    num_steps = len(episode_logits)
                    _, seed_idx = self.update_seed_score(
                        actor_index,
                        seed_t,
                        score,
                        max_score,
                        num_steps,
                        running_mean=(external_scores is not None),
                    )
                    if (
                        seed_idx is not None
                        and self.grounded_values is not None
                        and grounded_value is not None
                    ):
                        self.grounded_values[seed_idx] = grounded_value
                start_t = t.item()

            if start_t < total_steps:
                seed_t = level_seeds[start_t, actor_index].item()
                score_function_kwargs = {}
                score_function_kwargs["actor_index"] = actor_index
                score_function_kwargs["done"] = False
                episode_logits = policy_logits[start_t:, actor_index]
                score_function_kwargs["episode_logits"] = torch.log_softmax(
                    episode_logits, -1
                )
                score_function_kwargs["seed"] = seed_t
                if external_scores is not None:
                    score_function_kwargs["external_scores"] = external_scores[actor_index]
                if self.requires_value_buffers:
                    score_function_kwargs["returns"] = rollouts.returns[
                        start_t:, actor_index
                    ]
                    if self.strategy == "alt_advantage_abs":
                        score_function_kwargs["alt_returns"] = rollouts.alt_returns[
                            start_t:, actor_index
                        ]
                    score_function_kwargs["rewards"] = rollouts.rewards[
                        start_t:, actor_index
                    ]
                    if rollouts.use_popart:
                        score_function_kwargs["value_preds"] = rollouts.denorm_value_preds[
                            start_t:t, actor_index
                        ]
                    else:
                        score_function_kwargs["value_preds"] = rollouts.value_preds[
                            start_t:, actor_index
                        ]
                score, max_score = score_function(**score_function_kwargs)
                num_steps = len(episode_logits)
                if self.sample_full_distribution and seed_t in getattr(
                    self, "staging_seed_set", set()
                ):
                    self._partial_update_seed_score_buffer(
                        actor_index,
                        seed_t,
                        score,
                        num_steps,
                        running_mean=(external_scores is not None),
                    )
                else:
                    self._partial_update_seed_score(
                        actor_index,
                        seed_t,
                        score,
                        max_score,
                        num_steps,
                        running_mean=(external_scores is not None),
                    )

    def after_update(self):
        if not self._has_working_seed_buffer:
            return
        for actor_index in range(self.partial_seed_scores.shape[0]):
            for seed_idx in range(self.partial_seed_scores.shape[1]):
                if self.partial_seed_scores[actor_index][seed_idx] != 0:
                    self.update_seed_score(
                        actor_index, self.seeds[seed_idx], 0, float("-inf"), 0
                    )
        self.partial_seed_scores.fill(0)
        self.partial_seed_steps.fill(0)
        if self.sample_full_distribution:
            for actor_index in range(self.num_actors):
                actor_staging_seeds = list(
                    self.partial_seed_scores_buffer[actor_index].keys()
                )
                for seed in actor_staging_seeds:
                    if self.partial_seed_scores_buffer[actor_index][seed] > 0:
                        self.update_seed_score(actor_index, seed, 0, float("-inf"), 0)

    def _update_staleness(self, selected_idx):
        if self.staleness_coef > 0:
            self.seed_staleness = self.seed_staleness + 1
            self.seed_staleness[selected_idx] = 0

    def sample_replay_decision(self):
        if self.sample_full_distribution:
            proportion_filled = self._proportion_filled
            if self.seed_buffer_size > 0:
                if self.replay_schedule == "fixed":
                    return (proportion_filled >= self.rho) and (
                        np.random.rand() < self.replay_prob
                    )
                else:
                    return (proportion_filled >= self.rho) and (
                        np.random.rand() < min(proportion_filled, self.replay_prob)
                    )
            else:
                return False
        elif self.replay_schedule == "fixed":
            proportion_seen = self._proportion_filled
            if proportion_seen >= self.rho:
                if np.random.rand() < self.replay_prob or not proportion_seen < 1.0:
                    return True
            return False
        else:
            proportion_seen = self._proportion_filled
            return (proportion_seen >= self.rho) and (np.random.rand() < proportion_seen)

    @property
    def is_warm(self):
        return self._proportion_filled >= self.rho

    def observe_external_unseen_sample(self, seeds, solvable=None):
        for i, seed in enumerate(seeds):
            self.running_sample_count += 1
            if not (seed in self.staging_seed_set or seed in self.working_seed_set):
                self.seed2timestamp_buffer[seed] = self.running_sample_count
                self.staging_seed_set.add(seed)

                if solvable is not None:
                    if not self.track_solvable:  # lazy init of solvable tracking
                        self._init_solvable_tracking()
                    self.staging_seed2solvable[seed] = solvable[i]
            else:
                seed_idx = self.seed2index.get(seed, None)
                if seed_idx is not None:
                    self._update_staleness(seed_idx)

    def sample_replay_level(self, update_staleness=True):
        return self._sample_replay_level(update_staleness=update_staleness)

    def _sample_replay_level(self, update_staleness=True):
        sample_weights = self.sample_weights()
        if np.isclose(np.sum(sample_weights), 0):
            sample_weights = np.ones_like(self.seeds, dtype=np.float) / len(self.seeds)
            sample_weights = sample_weights * (1 - self.unseen_seed_weights)
            sample_weights /= np.sum(sample_weights)
        elif np.sum(sample_weights, 0) != 1.0:
            sample_weights = sample_weights / np.sum(sample_weights, 0)
        seed_idx = np.random.choice(range(len(self.seeds)), 1, p=sample_weights)[0]
        seed = self.seeds[seed_idx]
        if update_staleness:
            self._update_staleness(seed_idx)
        return int(seed)

    def _sample_unseen_level(self):
        if self.sample_full_distribution:
            seed = int(np.random.randint(1, INT32_MAX))
            while seed in getattr(self, "staging_seed_set", set()) or seed in getattr(
                self, "working_seed_set", set()
            ):
                seed = int(np.random.randint(1, INT32_MAX))
            self.seed2timestamp_buffer[seed] = self.running_sample_count
            self.staging_seed_set.add(seed)
        else:
            sample_weights = self.unseen_seed_weights / self.unseen_seed_weights.sum()
            seed_idx = np.random.choice(range(len(self.seeds)), 1, p=sample_weights)[0]
            seed = self.seeds[seed_idx]
            self._update_staleness(seed_idx)
        return int(seed)

    def sample(self, strategy=None):
        if strategy == "full_distribution":
            raise ValueError(
                "One-off sampling via full_distribution strategy is not supported."
            )
        self.running_sample_count += 1
        if not strategy:
            strategy = self.strategy

        if not self.sample_full_distribution:
            if strategy == "random":
                seed_idx = np.random.choice(range((len(self.seeds))))
                seed = self.seeds[seed_idx]
                return int(seed)
            if strategy == "sequential":
                seed_idx = self.next_seed_index
                self.next_seed_index = (self.next_seed_index + 1) % len(self.seeds)
                seed = self.seeds[seed_idx]
                return int(seed)

        replay_decision = self.sample_replay_decision()
        if replay_decision:
            return self._sample_replay_level()
        else:
            return self._sample_unseen_level()

    def sample_weights(self):
        """
        Robust version:
        - Handles all-zero weights
        - Handles NaNs in scores/weights
        - Respects unseen_seed_weights masking
        """
        scores = np.array(self.seed_scores, dtype=np.float64)
        scores = np.nan_to_num(scores, nan=0.0)

        mask_seen = (1.0 - self.unseen_seed_weights).astype(
            np.float64
        )  # 1 for seen, 0 for unseen

        def _safe_normalize(x, mask=None):
            x = np.array(x, dtype=np.float64)
            x = np.nan_to_num(x, nan=0.0)
            if mask is not None:
                x = x * mask
            z = np.sum(x)
            if (not np.isfinite(z)) or z <= 0:
                if mask is not None:
                    m_sum = np.sum(mask)
                    if m_sum > 0:
                        return mask / m_sum
                n = len(x)
                if n == 0:
                    return x
                return np.ones(n, dtype=np.float64) / float(n)
            return x / z

        # --- base weights from scores ---
        weights = self._score_transform(self.score_transform, self.temperature, scores)
        weights = np.nan_to_num(weights, nan=0.0)

        # zero out unseen levels
        weights = weights * mask_seen
        weights = _safe_normalize(weights, mask_seen)

        # --- staleness mixing ---
        if self.staleness_coef > 0:
            staleness_scores = np.array(self.seed_staleness, dtype=np.float64)
            staleness_scores = np.nan_to_num(staleness_scores, nan=0.0)

            staleness_weights = self._score_transform(
                self.staleness_transform, self.staleness_temperature, staleness_scores
            )
            staleness_weights = np.nan_to_num(staleness_weights, nan=0.0)
            staleness_weights = staleness_weights * mask_seen
            staleness_weights = _safe_normalize(staleness_weights, mask_seen)

            weights = (
                1.0 - self.staleness_coef
            ) * weights + self.staleness_coef * staleness_weights
            weights = _safe_normalize(weights, mask_seen)

        return weights

    def _score_transform(self, transform, temperature, scores):
        if transform == "constant":
            weights = np.ones_like(scores)
        if transform == "max":
            weights = np.zeros_like(scores)
            scores = scores[:]
            scores[self.unseen_seed_weights > 0] = -float("inf")
            argmax = np.random.choice(np.flatnonzero(np.isclose(scores, scores.max())))
            weights[argmax] = 1.0
        elif transform == "eps_greedy":
            weights = np.zeros_like(scores)
            weights[scores.argmax()] = 1.0 - self.eps
            weights += self.eps / len(self.seeds)
        elif transform == "rank":
            temp = np.flip(scores.argsort())
            ranks = np.empty_like(temp)
            ranks[temp] = np.arange(len(temp)) + 1
            weights = 1 / ranks ** (1.0 / temperature)
        elif transform == "power":
            eps = 0 if self.staleness_coef > 0 else 1e-3
            weights = (np.array(scores).clip(0) + eps) ** (1.0 / temperature)
        elif transform == "softmax":
            weights = np.exp(np.array(scores) / temperature)
        elif transform == "match":
            weights = np.array([(1 - score) * score for score in scores])
            weights = weights ** (1.0 / temperature)
        elif transform == "match_rank":
            weights = np.array([(1 - score) * score for score in scores])
            temp = np.flip(weights.argsort())
            ranks = np.empty_like(temp)
            ranks[temp] = np.arange(len(temp)) + 1
            weights = 1 / ranks ** (1.0 / temperature)
        return weights

    def _drop_seed_idx(self, seed_idx):
        if seed_idx is None or seed_idx < 0 or seed_idx >= len(self.seeds):
            return

        seed = int(self.seeds[seed_idx])

        if self.sample_full_distribution:
            getattr(self, "working_seed_set", set()).discard(seed)

        self.seeds[seed_idx] = -1
        self.seed_scores[seed_idx] = 0.0
        self.seed_staleness[seed_idx] = 0.0
        self.unseen_seed_weights[seed_idx] = 0.0

        self.partial_seed_scores[:, seed_idx] = 0.0
        self.partial_seed_max_scores[:, seed_idx] = float("-inf")
        self.partial_seed_steps[:, seed_idx] = 0

        self.seed_sample_counts[seed_idx] = 0
        self.seed_early_stopped[seed_idx] = False

        if seed in self.seed2index:
            del self.seed2index[seed]

        if self.track_solvable and hasattr(self, "seed_solvable"):
            if 0 <= seed_idx < len(self.seed_solvable):
                self.seed_solvable[seed_idx] = False

    def mark_sample_and_maybe_drop(self, seed, reached_threshold, max_samples=100):

        seed_idx = self.seed2index.get(int(seed), -1)
        if seed_idx < 0:
            return False, False

        self.seed_sample_counts[seed_idx] += 1

        early_stop_triggered = False
        dropped = False

        if reached_threshold and not self.seed_early_stopped[seed_idx]:
            self.seed_early_stopped[seed_idx] = True
            early_stop_triggered = True

        if (
            self.seed_sample_counts[seed_idx] > max_samples
            and not self.seed_early_stopped[seed_idx]
        ):
            self._drop_seed_idx(seed_idx)
            dropped = True

        return early_stop_triggered, dropped

    @property
    def solvable_mass(self):
        if self.track_solvable:
            sample_weights = self.sample_weights()
            return np.sum(sample_weights[getattr(self, "seed_solvable", slice(None))])
        else:
            return 1.0

    @property
    def max_score(self):
        return max(self.seed_scores)
