from typing import List, Set, Tuple

import torch

from .reranker import Reranker
from .universal_bandit_optimizer import UniversalBanditOptimizer
from .universal_contextual_bandit import UniversalContextualBandit


class DynamicArmContextualBandit(UniversalContextualBandit):
    """
    A contextual bandit that simulates a dynamic arm environment.

    A predefined set of arms is masked (unavailable) for an initial portion
    of the training process. During this phase, actions are sampled from the
    policy distribution with the masked arms' probabilities zeroed out and
    the remaining probabilities re-normalized.
    """

    def __init__(
        self,
        # All arguments from the parent class
        num_arms: int,
        embedding_dim: int,
        device: torch.device,
        optimizer: UniversalBanditOptimizer,
        # New parameters for dynamic arm logic
        masked_arm_indices: Set[int],
        unmask_progress_point: float,
        **kwargs,
    ):
        """
        Initializes the dynamic arm bandit.

        Args:
            num_arms: Forwarded to the parent constructor.
            embedding_dim: Forwarded to the parent constructor.
            device: Forwarded to the parent constructor.
            optimizer: Forwarded to the parent constructor.
            masked_arm_indices: Integer indices of the arms that should be
                                unavailable at the start. May be empty.
                                Negative indices count from the end, as in
                                tensor indexing.
            unmask_progress_point: A float between 0.0 and 1.0. The arms stay
                                   masked while the training progress
                                   (step_count / total_steps) is strictly
                                   below this value; 0.0 disables masking.
            **kwargs: All other arguments for the parent
                      UniversalContextualBandit.

        Raises:
            ValueError: If unmask_progress_point is outside [0.0, 1.0], or if
                        masking is enabled but no arm would remain available.
            IndexError: If a masked index is out of range for the arm count.
        """
        # Validate the pure-argument constraint before constructing the
        # parent so an invalid configuration fails fast.
        if not (0.0 <= unmask_progress_point <= 1.0):
            raise ValueError(
                "unmask_progress_point must be between 0.0 and 1.0"
            )

        super().__init__(
            num_arms=num_arms,
            embedding_dim=embedding_dim,
            device=device,
            optimizer=optimizer,
            **kwargs,
        )

        self.unmask_progress_point = unmask_progress_point

        # Check the indices on the host: an out-of-range index would
        # otherwise only be detected inside the tensor index assignment
        # below. The accepted range mirrors what tensor indexing accepts.
        masked_indices = sorted(set(masked_arm_indices))
        out_of_range = [
            idx
            for idx in masked_indices
            if not -self.num_arms <= idx < self.num_arms
        ]
        if out_of_range:
            raise IndexError(
                f"masked_arm_indices out of range for {self.num_arms} arms: "
                f"{out_of_range}"
            )

        # Create a boolean mask tensor. `True` for available, `False` for masked.
        self.availability_mask = torch.ones(
            self.num_arms, device=self.device, dtype=torch.bool
        )
        if masked_indices:
            indices_tensor = torch.tensor(
                masked_indices, device=self.device, dtype=torch.long
            )
            self.availability_mask[indices_tensor] = False

        # With every arm masked there is nothing left to sample from during
        # the masked phase. Only enforced when masking can actually apply.
        if unmask_progress_point > 0.0 and not bool(
            self.availability_mask.any()
        ):
            raise ValueError(
                "masked_arm_indices must leave at least one arm available"
            )

    def _renormalize_over_available_arms(
        self, policy_probs: torch.Tensor
    ) -> torch.Tensor:
        """
        Zero out masked arms and re-normalize each row of ``policy_probs``.

        Rows are summed over dim=1, so ``policy_probs`` must be 2D with arms
        on the last dimension; the availability mask is broadcast across rows.
        A row whose available arms carry exactly zero probability falls back
        to a uniform distribution over the available arms instead of becoming
        an all-zero (invalid) distribution.
        """
        available = self.availability_mask.to(
            device=policy_probs.device, dtype=policy_probs.dtype
        )
        masked_probs = policy_probs * available
        row_sums = masked_probs.sum(dim=1, keepdim=True)

        # Compare with `== 0` (not `<= 0`) so NaN rows are not silently
        # replaced by the fallback and keep propagating to the caller.
        no_mass = row_sums == 0
        safe_sums = torch.where(no_mass, torch.ones_like(row_sums), row_sums)
        uniform_available = available / available.sum().clamp_min(1.0)

        return torch.where(
            no_mass, uniform_available, masked_probs / safe_sums
        )

    def train_batch(
        self,
        contexts: torch.Tensor,
        correct_arms_batch: List[List[int]],
        arm_embeddings: torch.Tensor,
        true_embeddings: torch.Tensor,
        # Additional parameters to control the dynamic state
        step_count: int,
        total_steps: int,
        **kwargs,
    ) -> Tuple:  # Return signature matches the parent class
        """
        Performs one training step with dynamic arm masking.

        While ``step_count / max(total_steps, 1)`` is below
        ``unmask_progress_point``, actions are sampled from the policy
        distribution restricted to the available arms; afterwards they are
        sampled from the unmodified policy distribution.

        Args:
            contexts: Batch of contexts, forwarded to the parent helpers.
            correct_arms_batch: Per-example lists of correct arm indices,
                                forwarded to the reward/regret helpers.
            arm_embeddings: Current arm embeddings to be updated.
            true_embeddings: Forwarded to the regret computation.
            step_count: Current training step.
            total_steps: Total number of training steps (values below 1 are
                         treated as 1 when computing progress).
            **kwargs: ``current_lr`` (optional) is forwarded to
                      ``update_embeddings``; it is ``None`` when absent.

        Returns:
            A tuple ``(updated_embeddings, rewards, regrets, policy_probs,
            grad_stats, ips_stats)``. ``policy_probs`` is the unmasked policy
            distribution; ``grad_stats`` holds ``grad_mean``, ``grad_max``
            and ``grad_frob`` as floats.
        """
        # Determine if we are in the masked phase of training
        progress = step_count / max(total_steps, 1)
        is_masked_phase = progress < self.unmask_progress_point

        with torch.no_grad():
            # Step 1: Compute policy probabilities as usual
            policy_probs = self.compute_policy_probabilities(
                contexts, arm_embeddings
            )

            # Step 2: In the masked phase, sample only among available arms
            if is_masked_phase:
                sampling_probs = self._renormalize_over_available_arms(
                    policy_probs
                )
            else:
                sampling_probs = policy_probs
            sampled_actions = self.sample_actions(sampling_probs)

            # Step 3: Compute rewards and regrets with the final chosen actions
            rewards = self.compute_rewards(sampled_actions, correct_arms_batch)
            regrets = self.compute_regret(
                contexts, sampled_actions, correct_arms_batch, true_embeddings
            )

            # NOTE(review): the IPS weights (and the gradients below) use the
            # unmasked policy_probs, although during the masked phase the
            # actions were drawn from the re-normalized distribution.
            # Confirm this is the intended propensity.
            ips_weights_scalar, dense_coeffs, ips_stats = (
                self._compute_ips_weights_and_stats(
                    rewards, policy_probs, sampled_actions
                )
            )

        # Compute gradients (dense by default)
        if self.use_dense:
            gradients = self.compute_ips_gradients_dense(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                coeffs=dense_coeffs,
            )
        else:
            gradients = self.compute_ips_gradients(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                ips_weights=ips_weights_scalar,
            )

        # NOTE(review): an earlier note here asked to "also disable gradient
        # stats"; they are still computed on every call.
        # Gradient stats: per-row L2 norms plus the overall norm.
        grad_row_norms = torch.norm(gradients, dim=1)
        grad_stats = {
            "grad_mean": float(grad_row_norms.mean().item()),
            "grad_max": float(grad_row_norms.max().item()),
            "grad_frob": float(torch.norm(gradients).item()),
        }

        updated_embeddings = self.update_embeddings(
            arm_embeddings, gradients, kwargs.get("current_lr")
        )

        return (
            updated_embeddings,
            rewards,
            regrets,
            policy_probs,
            grad_stats,
            ips_stats,
        )
