from typing import List, Tuple

import torch

from .reranker import Reranker
from .universal_bandit_optimizer import UniversalBanditOptimizer
from .universal_contextual_bandit import UniversalContextualBandit


class RerankingContextualBandit(UniversalContextualBandit):
    """
    A contextual bandit that uses a two-stage retrieve-and-rerank policy.

    Stage 1 samples up to ``retrieval_k`` candidate arms (without replacement)
    from the base policy returned by ``compute_policy_probabilities``.
    Stage 2 hands those candidates to ``reranker`` and plays the candidate it
    ranks best. Inherits from the base bandit and overrides ``train_batch``.
    """

    def __init__(
        self,
        # Takes all arguments of the parent class
        num_arms: int,
        embedding_dim: int,
        device: torch.device,
        optimizer: UniversalBanditOptimizer,
        # Plus the reranker-specific arguments
        reranker: Reranker,
        retrieval_k: int,
        **kwargs,
    ):
        """
        Args:
            num_arms, embedding_dim, device, optimizer, **kwargs: forwarded
                unchanged to ``UniversalContextualBandit.__init__``.
            reranker: object whose ``rerank(query_indices, candidate_indices)``
                is called in ``train_batch`` to order the retrieved candidates.
            retrieval_k: number of candidates drawn per context in stage 1.
                At use time it is clamped to the number of arms (the width of
                the policy-probability tensor).
        """
        super().__init__(
            num_arms=num_arms,
            embedding_dim=embedding_dim,
            device=device,
            optimizer=optimizer,
            **kwargs,
        )
        self.reranker = reranker
        self.retrieval_k = retrieval_k

    def _select_reranked_actions(
        self,
        policy_probs: torch.Tensor,
        batch_query_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Retrieve candidates from ``policy_probs`` and return the reranker's pick.

        Returns a 1-D tensor holding one chosen arm index per row of
        ``policy_probs`` (assumed to be [B, num_arms] — TODO confirm).
        """
        # torch.multinomial without replacement raises when asked for more
        # samples than a row has non-zero entries; never request more than
        # the number of arms. (Rows with fewer than k non-zero probabilities
        # can still fail here.)
        num_candidates = min(self.retrieval_k, policy_probs.shape[-1])
        candidate_indices = torch.multinomial(
            policy_probs, num_samples=num_candidates
        )  # [B, k]

        # NOTE(review): assumes rerank returns, for each candidate position,
        # that candidate's rank (0 = best), shape [B, k] — confirm against
        # the Reranker implementation; argmin below relies on it.
        candidate_ranks = self.reranker.rerank(
            batch_query_indices, candidate_indices
        )

        # Position (0..k-1) of the best-ranked candidate in each row, [B, 1].
        best_positions = torch.argmin(candidate_ranks, dim=1, keepdim=True)

        # Map each position back to the original arm index, [B].
        return torch.gather(candidate_indices, 1, best_positions).squeeze(-1)

    def train_batch(
        self,
        contexts: torch.Tensor,
        correct_arms_batch: List[List[int]],
        arm_embeddings: torch.Tensor,
        true_embeddings: torch.Tensor,
        batch_query_indices: torch.Tensor,  # New: needs indices for reranker
        **kwargs,
    ) -> Tuple:  # Return type is the same as parent
        """Run one retrieve-rerank-update step on a batch.

        Args:
            contexts: batch of contexts fed to ``compute_policy_probabilities``.
            correct_arms_batch: per-row lists of correct arm ids, used by
                ``compute_rewards`` and ``compute_regret``.
            arm_embeddings: current arm embeddings used by the policy and
                passed to ``update_embeddings``.
            true_embeddings: passed through to ``compute_regret``.
            batch_query_indices: per-row query ids forwarded to the reranker.
            **kwargs: ``current_lr`` (if present) is passed to
                ``update_embeddings``; ``None`` otherwise.

        Returns:
            ``(updated_embeddings, rewards, regrets, policy_probs, grad_stats,
            ips_stats)``. ``grad_stats`` has keys ``grad_mean`` / ``grad_max``
            (mean / max of per-row gradient norms) and ``grad_frob``
            (Frobenius norm of the gradient matrix).
        """
        with torch.no_grad():
            # 1. Base retrieval policy over all arms.
            policy_probs = self.compute_policy_probabilities(
                contexts, arm_embeddings
            )

            # 2-4. Retrieve candidates, rerank, and map back to arm ids.
            sampled_actions = self._select_reranked_actions(
                policy_probs, batch_query_indices
            )

            # 5. Rewards and regrets for the action actually played.
            rewards = self.compute_rewards(sampled_actions, correct_arms_batch)
            regrets = self.compute_regret(
                contexts, sampled_actions, correct_arms_batch, true_embeddings
            )

            # NOTE(review): policy_probs are the *base retrieval* probabilities,
            # but sampled_actions were chosen by retrieval + reranking, so the
            # true propensity of each played action generally differs from
            # policy_probs[b, action]. If the parent uses these as logging
            # propensities, the IPS weights below are biased — confirm intent.
            ips_weights_scalar, dense_coeffs, ips_stats = (
                self._compute_ips_weights_and_stats(
                    rewards, policy_probs, sampled_actions
                )
            )

        # Compute gradients (dense by default). Left outside no_grad as in the
        # parent-facing contract; the helpers may rely on autograd.
        if self.use_dense:
            gradients = self.compute_ips_gradients_dense(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                coeffs=dense_coeffs,
            )
        else:
            gradients = self.compute_ips_gradients(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                ips_weights=ips_weights_scalar,
            )

        # Gradient statistics are for logging only, so keep them off the
        # autograd graph.
        # TODO: consider making these optional; each .item() forces a
        # device-to-host sync on accelerator tensors.
        with torch.no_grad():
            grad_row_norms = torch.norm(gradients, dim=1)
            grad_stats = {
                "grad_mean": float(grad_row_norms.mean().item()),
                "grad_max": float(grad_row_norms.max().item()),
                "grad_frob": float(torch.norm(gradients).item()),
            }

        updated_embeddings = self.update_embeddings(
            arm_embeddings, gradients, kwargs.get("current_lr")
        )

        return (
            updated_embeddings,
            rewards,
            regrets,
            policy_probs,
            grad_stats,
            ips_stats,
        )
