from typing import Dict, List, Optional, Set, Tuple

import torch

from .reranker import Reranker
from .universal_bandit_optimizer import UniversalBanditOptimizer
from .universal_contextual_bandit import UniversalContextualBandit


class RegretContextualBandit(UniversalContextualBandit):
    """Contextual bandit updated with full-information (supervised) gradients.

    Instead of an importance-weighted estimate built from the sampled action,
    this variant uses the known set of correct arms for every context: the
    gradient coefficient for each (context, arm) pair is
    ``policy_prob - target``, where ``target`` is 1.0 for a correct arm and 0.0
    otherwise. Actions are still sampled each step, but only to report rewards
    and regrets; they do not influence the parameter update.
    """

    def compute_ips_gradients_dense(
        self,
        contexts: torch.Tensor,
        correct_arms_batch: List[List[int]],
        policy_probs: torch.Tensor,
        *,
        coeffs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the batch-averaged gradient ``coeffs.T @ contexts / B``.

        No inverse-propensity weighting is applied here; the method name is
        kept for interface compatibility.

        Args:
            contexts: Context features. Assumed ``[B, D]`` -- its first
                dimension must equal ``policy_probs``'s for the matrix product.
            correct_arms_batch: For each of the ``B`` rows, the indices of the
                correct arms. Every listed arm gets target weight 1.0 (so a row
                with several correct arms has several ones); an empty list
                leaves that row's targets all zero.
            policy_probs: Policy probabilities, ``[B, K]``.
            coeffs: Optional precomputed ``[B, K]`` coefficient matrix. When
                given it is used as-is and ``correct_arms_batch`` is ignored.

        Returns:
            ``(1 / B) * coeffs.T @ contexts``, i.e. ``[K, D]`` for ``[B, D]``
            contexts. For an empty batch (``B == 0``) the unscaled, all-zero
            product is returned instead of dividing by zero.

        Raises:
            ValueError: If ``coeffs`` is given with a shape other than
                ``policy_probs.shape``.
        """
        batch_size, _ = policy_probs.size()

        if coeffs is None:
            targets = torch.zeros_like(policy_probs)
            # Flatten the ragged per-row arm lists into (row, arm) pairs so
            # every target is set with a single indexed write instead of one
            # write per row.
            pairs = [
                (row, arm)
                for row, arms in enumerate(correct_arms_batch)
                for arm in arms
            ]
            if pairs:
                rows, cols = zip(*pairs)
                targets[list(rows), list(cols)] = 1.0
            # Assumption (parent class not visible): if policy_probs is a
            # softmax over logits contexts @ arm_embeddings.T, then
            # policy_probs - targets is the cross-entropy gradient w.r.t. those
            # logits, and the product below is its gradient w.r.t. the arms.
            coeffs = policy_probs - targets
        elif coeffs.shape != policy_probs.shape:
            raise ValueError(
                f"coeffs shape {tuple(coeffs.shape)} does not match "
                f"policy_probs shape {tuple(policy_probs.shape)}"
            )

        gradients = torch.mm(coeffs.t(), contexts)
        if batch_size == 0:
            # Nothing to average over; the product is already all zeros.
            return gradients
        return (1.0 / batch_size) * gradients

    def train_batch(
        self,
        contexts: torch.Tensor,
        correct_arms_batch: List[List[int]],
        arm_embeddings: torch.Tensor,
        true_embeddings: torch.Tensor,
        initial_embeddings: Optional[torch.Tensor] = None,
        step_count: int = 0,
        current_lr: Optional[float] = None,
        **kwargs,
    ) -> Tuple[
        torch.Tensor,  # updated_embeddings
        torch.Tensor,  # rewards
        torch.Tensor,  # regrets
        torch.Tensor,  # policy_probs
        Dict[str, float],  # grad_stats
        Dict[str, float],  # ips_stats
    ]:
        """Run one batched training step; return updated parameters and metrics.

        Args:
            contexts: Batch of contexts (presumably ``[B, D]``).
            correct_arms_batch: Correct arm indices for each context. Used for
                the reward/regret metrics and as the gradient target.
            arm_embeddings: Current arm parameters, passed to the policy and to
                ``update_embeddings`` (presumably ``[K, D]``).
            true_embeddings: Forwarded to ``compute_regret``.
            initial_embeddings: Unused by this implementation.
            step_count: Unused by this implementation.
            current_lr: Forwarded to ``update_embeddings``.
            **kwargs: Ignored.

        Returns:
            ``(updated_embeddings, rewards, regrets, policy_probs, grad_stats,
            ips_stats)``. ``grad_stats`` holds the mean and max per-row L2 norm
            of the gradient (``grad_mean``, ``grad_max``) and its Frobenius
            norm (``grad_frob``). ``ips_stats`` is all zeros because this
            update applies no importance weighting.
        """
        with torch.no_grad():
            policy_probs = self.compute_policy_probabilities(
                contexts, arm_embeddings
            )
            # Sampled actions only feed the reported metrics; the update below
            # is driven by the known correct arms.
            sampled_actions = self.sample_actions(policy_probs)

            rewards = self.compute_rewards(sampled_actions, correct_arms_batch)
            regrets = self.compute_regret(
                contexts, sampled_actions, correct_arms_batch, true_embeddings
            )

            # The gradient is formed in closed form, so neither it nor the
            # diagnostics derived from it need an autograd graph.
            gradients = self.compute_ips_gradients_dense(
                contexts,
                correct_arms_batch,
                policy_probs,
            )

            # NOTE: an earlier TODO proposed disabling these gradient stats;
            # they remain enabled.
            grad_row_norms = torch.norm(gradients, dim=1)
            if grad_row_norms.numel() == 0:
                # No arms: .max() of an empty tensor would raise.
                grad_mean = grad_max = grad_frob = 0.0
            else:
                # Stack the three scalars so they reach the host in a single
                # transfer (one device sync) rather than three .item() calls.
                grad_mean, grad_max, grad_frob = torch.stack(
                    [
                        grad_row_norms.mean(),
                        grad_row_norms.max(),
                        torch.norm(gradients),
                    ]
                ).tolist()
            grad_stats = {
                "grad_mean": float(grad_mean),
                "grad_max": float(grad_max),
                "grad_frob": float(grad_frob),
            }

        updated_embeddings = self.update_embeddings(
            arm_embeddings, gradients, current_lr
        )

        # Placeholder IPS metrics, kept so callers see the same keys as other
        # bandit variants; there is no importance weighting in this update.
        ips_stats: Dict[str, float] = {
            "ips_mean": 0.0,
            "ips_std": 0.0,
            "ips_min": 0.0,
            "ips_max": 0.0,
            "clip_frac": 0.0,
        }

        return (
            updated_embeddings,
            rewards,
            regrets,
            policy_probs,
            grad_stats,
            ips_stats,
        )
