from collections import deque
from typing import Dict, Tuple, Optional

import numpy as np

# Largest signed 32-bit integer (2**31 - 1); used below as the upper bound
# when drawing fresh seeds and in the full-distribution seed range.
INT32_MAX = 2**31 - 1
# Process-wide side effect at import time: every NumPy floating-point error
# category (divide, overflow, underflow, invalid) raises FloatingPointError
# instead of emitting a warning.
np.seterr(divide="raise", over="raise", under="raise", invalid="raise")


class LevelSamplerES:
    """Seed sampler backed by a fixed-capacity FIFO replay buffer.

    Two modes are supported:

    * ``sample_full_distribution=True`` (default): ``self.seeds`` is a buffer
      of ``seed_buffer_size`` slots, where ``-1`` marks an empty slot. Seeds
      reported via :meth:`observe_external_unseen_sample` are appended. When
      the buffer is full, the oldest seed (slot 0) is evicted first. New seeds
      are drawn uniformly from ``[1, INT32_MAX)``. Replayed seeds are drawn
      uniformly from the occupied part of the buffer.
    * ``sample_full_distribution=False``: ``self.seeds`` is the fixed list
      passed to the constructor, and sampling is restricted to it.

    Several constructor arguments (e.g. ``max_score_coef``,
    ``score_transform``, ``temperature``, ``eps``, ``tao``, ``alpha``,
    ``tscl_window_size``, ``partition_manager``) are stored for interface
    compatibility but are not used by this class's logic.
    """

    def __init__(
        self,
        seeds,
        obs_space,
        action_space,
        num_actors=1,
        strategy="random",
        max_score_coef=0.0,
        replay_schedule="fixed",
        score_transform="power",
        temperature=1.0,
        eps=0.05,
        rho=0.6,
        replay_prob=0.95,
        tao=0.1,
        alpha=1.0,
        staleness_coef=0.0,
        staleness_transform="power",
        staleness_temperature=1.0,
        sample_full_distribution=True,
        seed_buffer_size=960,
        seed_buffer_priority="replay_support",
        use_dense_rewards=False,
        tscl_window_size=0,
        gamma=0.999,
        partition_manager=None,
        partition_flush_interval=10,
    ):
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_actors = num_actors

        self.strategy = strategy
        # Read by sample_replay_decision ("fixed" vs. proportional replay).
        self.replay_schedule = replay_schedule
        self.rho = float(rho)
        self.replay_prob = float(replay_prob)
        self.tao = float(tao)

        # Stored for interface compatibility only; not used below.
        self.max_score_coef = max_score_coef
        self.score_transform = score_transform
        self.temperature = temperature
        self.eps = eps
        self.alpha = alpha
        self.tscl_window_size = tscl_window_size
        self.partition_manager = partition_manager
        self.partition_flush_interval = partition_flush_interval

        self.staleness_coef = float(staleness_coef)
        self.staleness_transform = staleness_transform
        self.staleness_temperature = float(staleness_temperature)

        self.sample_full_distribution = bool(sample_full_distribution)
        self.seed_buffer_priority = seed_buffer_priority

        # Honour the argument. It was previously hard-coded to 960.
        self.seed_buffer_size = max(0, int(seed_buffer_size))

        self._init_seed_index(seeds)

        self.running_sample_count = 0
        self.next_seed_index = 0  # For finite buffer sequential strategy

        if self.sample_full_distribution:
            self.capacity = int(self.seed_buffer_size)
            self.working_seed_buffer_size = 0
            self.working_seed_set = set()
            # Latched to True the first time the buffer reaches `rho` fill.
            self._ever_warm = False

        self.use_dense_rewards = use_dense_rewards
        self.gamma = gamma
        self.grounded_values = None

        # Per-slot bookkeeping, aligned index-for-index with `self.seeds`.
        # unseen_seed_weights: 1.0 = not yet observed, 0.0 = observed.
        n_slots = len(self.seeds)
        self.unseen_seed_weights = np.ones(n_slots, dtype=np.float64)
        self.seed_staleness = np.zeros(n_slots, dtype=np.float64)
        self.track_solvable = False
        self.seed_solvable = np.ones(n_slots, dtype=bool)

    def _init_seed_index(self, seeds):
        """Initialise ``self.seeds`` and the ``seed -> slot`` map ``self.seed2index``.

        In full-distribution mode, the buffer always starts as
        ``seed_buffer_size`` empty (``-1``) slots. In finite mode, the given
        seeds are used as-is. If no seeds are given in finite mode, an all-``-1``
        array of ``seed_buffer_size`` slots is used, as before.

        Accepts any sequence or array of ints, or None. NumPy arrays are
        handled explicitly, because ``if seeds:`` raises for them.
        """
        arr = np.asarray(seeds if seeds is not None else [], dtype=np.int64).reshape(-1)
        if self.sample_full_distribution or arr.size == 0:
            self.seeds = np.full(self.seed_buffer_size, -1, dtype=np.int64)
            self.seed2index = {}
        else:
            self.seeds = arr
            self.seed2index = {int(seed): i for i, seed in enumerate(arr)}

    def seed_range(self):
        """Return the inclusive ``(min, max)`` range seeds can take."""
        if not self.sample_full_distribution:
            if len(self.seeds) == 0:
                return (0, 0)
            return (int(self.seeds.min()), int(self.seeds.max()))
        return (0, INT32_MAX)

    @property
    def _proportion_filled(self):
        """Return the fraction of the buffer (full mode) or seed list (finite mode) that has been seen."""
        if self.sample_full_distribution:
            if self.capacity <= 0:
                return 0.0
            return self.working_seed_buffer_size / float(self.capacity)
        if len(self.seeds) == 0:
            return 0.0
        num_unseen = (self.unseen_seed_weights > 0).sum()
        return (len(self.seeds) - num_unseen) / float(len(self.seeds))

    @property
    def is_warm(self):
        """Return whether enough seeds have been seen to start replaying.

        In full-distribution mode this latches: once the fill fraction has
        reached ``rho``, it stays True even if seeds are later dropped.
        """
        if self.sample_full_distribution:
            if not self._ever_warm and self._proportion_filled >= self.rho:
                self._ever_warm = True
            return self._ever_warm
        return self._proportion_filled >= self.rho

    @property
    def requires_value_buffers(self):
        """Return False: this sampler never needs value buffers."""
        return False

    def _update_staleness(self, selected_idx: int):
        """Age every slot by one and reset the staleness of ``selected_idx``.

        Does nothing when ``staleness_coef`` is 0.
        """
        if self.staleness_coef > 0 and len(self.seed_staleness) > 0:
            self.seed_staleness += 1.0
            if 0 <= selected_idx < len(self.seed_staleness):
                self.seed_staleness[selected_idx] = 0.0

    def update_with_rollouts(self, rollouts, external_scores=None):
        """Do nothing; scores are not tracked by this sampler."""
        return

    def after_update(self):
        """Do nothing; provided for interface compatibility."""
        return

    def observe_external_unseen_sample(self, seeds, solvable=None):
        """Record that ``seeds`` were played.

        Full-distribution mode: each seed not already buffered is appended.
        When the buffer is full, the oldest seed is evicted first. If
        ``solvable`` is given, ``solvable[i]`` is stored for ``seeds[i]`` and
        solvability tracking is enabled.

        Finite mode: each known seed is marked as seen and its staleness is
        reset (the latter only when ``staleness_coef > 0``).
        """
        for i, seed in enumerate(seeds):
            self.running_sample_count += 1
            seed = int(seed)

            if not self.sample_full_distribution:
                idx = self.seed2index.get(seed)
                if idx is not None:
                    self.unseen_seed_weights[idx] = 0.0
                    self._update_staleness(idx)
                continue

            if seed in self.working_seed_set or self.capacity <= 0:
                continue

            if self.working_seed_buffer_size >= self.capacity:
                oldest_seed = int(self.seeds[0])
                if oldest_seed != -1:
                    self.drop_seed(oldest_seed)

            # After a possible eviction this is the first free slot.
            idx = self.working_seed_buffer_size
            if idx >= len(self.seeds):
                continue

            self.seeds[idx] = seed
            self.seed2index[seed] = idx
            self.working_seed_set.add(seed)

            if solvable is not None:
                self.track_solvable = True
                if i < len(solvable):
                    self.seed_solvable[idx] = bool(solvable[i])

            self.working_seed_buffer_size = min(self.capacity, idx + 1)

    def sample_replay_level(self, update_staleness: bool = True) -> int:
        """Public wrapper around :meth:`_sample_replay_level`."""
        return self._sample_replay_level(update_staleness=update_staleness)

    def _effective_size(self) -> int:
        """Return the number of leading slots of ``self.seeds`` that are in use."""
        if self.sample_full_distribution:
            return self.working_seed_buffer_size
        return len(self.seeds)

    def _sample_replay_level(self, update_staleness: bool = True) -> int:
        """Return one seed drawn uniformly from the occupied slots.

        ``update_staleness`` is accepted for compatibility but currently ignored.
        Falls back to ``seeds[0]`` (or -1 if there are no slots) when the
        buffer is empty.
        """
        effective_n = self._effective_size()
        if effective_n <= 0:
            return -1 if len(self.seeds) == 0 else int(self.seeds[0])
        seed_idx = int(np.random.choice(effective_n))
        return int(self.seeds[seed_idx])

    def _sample_replay_level_batch(self, batch_size=16, update_staleness: bool = True):
        """Uniformly sample up to ``batch_size`` distinct seeds from the buffer.

        Returns a ``list[int]`` with no repeated slot. Its length is
        ``min(batch_size, number of occupied non -1 slots)``; an empty buffer
        yields ``[]``. ``update_staleness`` is accepted but currently ignored.
        """
        effective_n = self._effective_size()
        candidates = np.flatnonzero(self.seeds[:effective_n] != -1)
        k = min(int(batch_size), len(candidates))
        if k <= 0:
            return []
        picked = np.random.choice(candidates, size=k, replace=False)
        return [int(self.seeds[i]) for i in picked]

    def _sample_unseen_level(self) -> int:
        """Return a seed that has not been seen yet.

        Full mode: a random int in ``[1, INT32_MAX)`` not in the buffer.
        Finite mode: a seed chosen in proportion to ``unseen_seed_weights``,
        or uniformly if every seed has been seen.
        """
        if self.sample_full_distribution:
            seed = int(np.random.randint(1, INT32_MAX))
            while seed in self.working_seed_set:
                seed = int(np.random.randint(1, INT32_MAX))
            return seed

        if len(self.seeds) == 0:
            return -1
        weights = self.unseen_seed_weights
        total = weights.sum()
        if total <= 0:
            idx = int(np.random.choice(len(self.seeds)))
        else:
            idx = int(np.random.choice(len(self.seeds), p=weights / total))
        self._update_staleness(idx)
        return int(self.seeds[idx])

    def sample_replay_decision(self) -> bool:
        """Decide whether the next level should be replayed (True) or new (False).

        Never replays from an empty buffer or before :attr:`is_warm`. With the
        ``"fixed"`` schedule, replays with probability ``replay_prob``; in
        finite mode it always replays once every seed has been seen. Any other
        schedule replays with probability equal to the seen fraction.
        """
        if self._effective_size() <= 0 or not self.is_warm:
            return False
        proportion_seen = self._proportion_filled
        if self.replay_schedule == "fixed":
            if not self.sample_full_distribution and proportion_seen >= 1.0:
                return True
            return bool(np.random.rand() < self.replay_prob)
        return bool(np.random.rand() < proportion_seen)

    def sample(self, strategy: Optional[str] = None) -> int:
        """Return the next seed to play.

        In finite mode, ``"random"`` and ``"sequential"`` are handled directly.
        All other cases choose between replay and a new seed via
        :meth:`sample_replay_decision`.
        """
        self.running_sample_count += 1

        if strategy is None:
            strategy = self.strategy

        if not self.sample_full_distribution:
            if strategy == "random":
                if len(self.seeds) == 0:
                    return -1
                return int(self.seeds[int(np.random.choice(len(self.seeds)))])
            if strategy == "sequential":
                if len(self.seeds) == 0:
                    return -1
                idx = self.next_seed_index
                self.next_seed_index = (self.next_seed_index + 1) % len(self.seeds)
                return int(self.seeds[idx])

        if self.sample_replay_decision():
            return self._sample_replay_level()
        return self._sample_unseen_level()

    def sample_weights(self) -> np.ndarray:
        """Return the replay probability of every slot of ``self.seeds``.

        Weights are uniform over occupied slots already marked seen. If there
        are none, they are uniform over the first ``effective_n`` slots. When
        ``staleness_coef > 0``, the result is mixed with a staleness
        distribution over the seen slots.
        """
        n = len(self.seeds)
        if n == 0:
            return np.zeros(0, dtype=np.float32)

        weights = np.zeros(n, dtype=np.float32)
        effective_n = self._effective_size()
        if effective_n <= 0:
            return weights

        in_buffer = np.zeros(n, dtype=bool)
        in_buffer[:effective_n] = self.seeds[:effective_n] != -1
        seen = in_buffer & (self.unseen_seed_weights == 0.0)

        if seen.any():
            weights[seen] = 1.0 / float(seen.sum())
        else:
            weights[:effective_n] = 1.0 / float(effective_n)

        if self.staleness_coef > 0 and len(self.seed_staleness) == n and seen.any():
            temp = max(self.staleness_temperature, 1e-6)
            s = np.zeros(n, dtype=np.float64)
            # Tiny probabilities may underflow to 0; that is harmless here, but
            # the module-level np.seterr(all="raise") would otherwise raise.
            with np.errstate(under="ignore"):
                if self.staleness_transform == "softmax":
                    z = self.seed_staleness[seen] / temp
                    # Subtracting the max avoids overflow; it cancels on normalisation.
                    s[seen] = np.exp(z - z.max())
                else:
                    s[seen] = (self.seed_staleness[seen].clip(min=0.0) + 1e-3) ** (1.0 / temp)
                total = s.sum()
                if total > 0:
                    s /= total
                    weights = (1.0 - self.staleness_coef) * weights + self.staleness_coef * s

        return weights

    def drop_seed(self, seed: int):
        """Remove ``seed`` from the full-distribution buffer.

        Later slots shift one position forward, so buffer order (oldest first)
        is preserved. All per-slot arrays shift with ``seeds`` so they stay
        aligned, and the vacated tail slot is reset. No-op in finite mode or
        when ``seed`` is not buffered.
        """
        seed = int(seed)
        if not self.sample_full_distribution or seed not in self.working_seed_set:
            return

        seed_idx = self.seed2index.get(seed)
        if seed_idx is None:
            return

        last_idx = self.working_seed_buffer_size - 1
        if last_idx < 0:
            return

        # NumPy handles the overlapping slice copy correctly.
        for arr in (self.seeds, self.seed_solvable, self.seed_staleness, self.unseen_seed_weights):
            arr[seed_idx:last_idx] = arr[seed_idx + 1:last_idx + 1]
        for i in range(seed_idx, last_idx):
            moved = int(self.seeds[i])
            if moved != -1:
                self.seed2index[moved] = i

        self.seeds[last_idx] = -1
        self.seed_solvable[last_idx] = True
        self.seed_staleness[last_idx] = 0.0
        self.unseen_seed_weights[last_idx] = 1.0

        self.working_seed_set.discard(seed)
        self.seed2index.pop(seed, None)
        self.working_seed_buffer_size = max(0, self.working_seed_buffer_size - 1)

    @property
    def solvable_mass(self) -> float:
        """Return the replay probability mass on seeds flagged solvable (1.0 when untracked)."""
        if self.track_solvable and len(self.seed_solvable) == len(self.seeds):
            w = self.sample_weights()
            return float((w * self.seed_solvable.astype(np.float32)).sum())
        return 1.0

    @property
    def max_score(self) -> float:
        """Return 0.0: scores are not tracked by this sampler."""
        return 0.0