"""
Universal contextual bandit algorithm for tool embedding learning.

This module implements the batched IPS (Inverse Propensity Scoring) algorithm
with an efficient, optimizer-agnostic workflow that is compatible with
`embed_trainer/universal_bandit_optimizer.py` and the experiment manager.

It mirrors the logic of contextual_bandit.py while removing legacy scheduler
logic and redundant operations, and reuses efficient vectorized tricks shown in
low_rank_contextual_bandit.py (e.g., single-pass IPS stats computation).
"""

from typing import Dict, List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F

from .universal_bandit_optimizer import UniversalBanditOptimizer


class UniversalContextualBandit:
    """
    GPU-based contextual bandit for learning tool (arm) embeddings.

    - Softmax policy with temperature and optional epsilon-uniform mixing
    - Dense IPS gradient by default (``use_dense = True``); a sparse
      scalar-weight IPS gradient is used when ``use_dense`` is False
    - Uses a provided UniversalBanditOptimizer instance for parameter updates

    Gradients are built in closed form from the IPS formulas documented on
    the gradient methods and handed to ``optimizer.step``; this class never
    calls ``backward()``.
    """

    def __init__(
        self,
        num_arms: int,
        embedding_dim: int,
        device: torch.device,
        optimizer: "UniversalBanditOptimizer",
        *,
        temperature: float = 1.0,
        epsilon: float = 0.0,
        lambda_reg: float = 0.0,
        clip_value: float = 10.0,
        noise_std: float = 0.0,
    ):
        """
        Args:
            num_arms: Number of arms (tools), K.
            embedding_dim: Embedding dimension, D.
            device: Torch device used for tensors created by this class.
            optimizer: External optimizer whose ``step(embeddings, grads, lr)``
                performs the parameter update.
            temperature: Softmax temperature; must be strictly positive.
            epsilon: Weight of the uniform distribution in the policy mix;
                clamped into [0, 1].
            lambda_reg: L2 prior coefficient (only active together with
                ``noise_std > 0``).
            clip_value: Symmetric clipping bound for the scalar IPS weights.
            noise_std: If > 0 and lambda_reg > 0, the prior inverse covariance
                is ``I / noise_std**2``.

        Raises:
            ValueError: If ``temperature <= 0`` (scores are divided by it).
        """
        if temperature <= 0.0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        self.num_arms = num_arms
        self.embedding_dim = embedding_dim
        self.device = device
        self.optimizer = optimizer

        self.temperature = temperature
        self.epsilon = float(max(0.0, min(1.0, epsilon)))
        self.lambda_reg = lambda_reg
        self.clip_value = clip_value
        self.noise_std = noise_std
        # Optional Gaussian prior inverse covariance for regularization.
        self.inv_sigma: Optional[torch.Tensor] = None
        if self.lambda_reg > 0.0 and self.noise_std > 0.0:
            self.inv_sigma = torch.eye(
                self.embedding_dim, device=self.device, dtype=torch.float32
            ) / (self.noise_std**2)
        # Dense IPS gradients by default; set False for the sparse variant.
        self.use_dense: bool = True

    # -------------------- Policy and Sampling --------------------
    def compute_policy_probabilities(
        self, contexts: torch.Tensor, arm_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """Compute policy P(a|x) via softmax(contexts @ arms^T / tau) with epsilon mix.

        Args:
            contexts: [B, D]
            arm_embeddings: [K, D]
        Returns:
            policy_probs: [B, K]
        """
        scores = torch.mm(contexts, arm_embeddings.t())  # [B, K]
        probs = F.softmax(scores / self.temperature, dim=1)
        if self.epsilon > 0.0:
            # Uniform mixing keeps every arm's probability >= epsilon / K,
            # which bounds the IPS correction r / P(a_t).
            uniform = torch.full_like(probs, 1.0 / self.num_arms)
            probs = (1.0 - self.epsilon) * probs + self.epsilon * uniform
        return probs

    @staticmethod
    def sample_actions(policy_probs: torch.Tensor) -> torch.Tensor:
        """Sample one action per row of policy probabilities. Returns [B]."""
        return torch.multinomial(policy_probs, num_samples=1).squeeze(-1)

    def compute_rewards(
        self, sampled_actions: torch.Tensor, correct_arms_batch: List[List[int]]
    ) -> torch.Tensor:
        """Binary rewards: 1 if the chosen arm is in the correct set else 0.

        Actions are moved to the host once (``tolist``) rather than with one
        ``.item()`` sync per sample. Rows without a matching entry in
        ``correct_arms_batch`` keep reward 0.

        Returns:
            rewards: [B] float tensor on ``self.device``.
        """
        hits = [
            1.0 if action in correct else 0.0
            for action, correct in zip(
                sampled_actions.tolist(), correct_arms_batch
            )
        ]
        rewards = torch.zeros(sampled_actions.size(0), device=self.device)
        rewards[: len(hits)] = torch.tensor(
            hits, device=self.device, dtype=rewards.dtype
        )
        return rewards

    def compute_regret(
        self,
        contexts: torch.Tensor,
        sampled_actions: torch.Tensor,
        correct_arms_batch: List[List[int]],
        true_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Latent regret using true embeddings (for reporting only):
        best_correct_score - chosen_score, per sample.

        Samples with an empty correct set have no defined "best correct"
        score; their regret is reported as 0 instead of raising.

        Returns:
            regrets: [B] float tensor on ``self.device``.
        """
        true_scores = torch.mm(contexts, true_embeddings.t())  # [B, K]
        regrets = torch.zeros(sampled_actions.size(0), device=self.device)
        for row, correct in enumerate(correct_arms_batch):
            if not correct:
                # max() over an empty selection would raise.
                continue
            best = true_scores[row, correct].max()
            regrets[row] = best - true_scores[row, sampled_actions[row]]
        return regrets

    # -------------------- IPS Weights and Gradients --------------------
    @staticmethod
    def _chosen_probs(
        policy_probs: torch.Tensor, sampled_actions: torch.Tensor
    ) -> torch.Tensor:
        """Return P_t(a_t) per row, floored at 1e-9 to avoid division by 0. [B]."""
        return (
            policy_probs.gather(1, sampled_actions.unsqueeze(-1))
            .squeeze(-1)
            .clamp(min=1e-9)
        )

    @staticmethod
    def _dense_coefficients(
        policy_probs: torch.Tensor,
        sampled_actions: torch.Tensor,
        correction: torch.Tensor,
    ) -> torch.Tensor:
        """Return P_t(a) - 1{a_t=a} * correction_t as a [B, K] tensor."""
        action_mask = torch.zeros_like(policy_probs).scatter_(
            1, sampled_actions.unsqueeze(-1), 1.0
        )
        return policy_probs - action_mask * correction.unsqueeze(-1)

    def _prior_gradient(
        self,
        arm_embeddings: Optional[torch.Tensor],
        initial_embeddings: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Return λ (θ - θ̃) Σ^{-1} if the prior is configured and inputs given."""
        if (
            self.inv_sigma is None
            or arm_embeddings is None
            or initial_embeddings is None
        ):
            return None
        # inv_sigma is created as float32; match the embeddings so that
        # float64 (or other-device) embeddings do not fail in torch.mm.
        inv_sigma = self.inv_sigma.to(
            device=arm_embeddings.device, dtype=arm_embeddings.dtype
        )
        return self.lambda_reg * torch.mm(
            arm_embeddings - initial_embeddings, inv_sigma
        )

    def _compute_ips_weights_and_stats(
        self,
        rewards: torch.Tensor,
        policy_probs: torch.Tensor,
        sampled_actions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
        """
        Compute two IPS variants and summary statistics in a single pass.

        Returns:
          - ips_weights_clipped: [B] scalar weights g_t = 1 - r_t / P_t(a_t),
            clamped to [-clip_value, clip_value]
          - dense_coeffs: [B, K] coefficients P_t(a) - 1{a_t=a} * r_t / P_t(a_t)
            (not clipped)
          - stats: dict for logging; ``clip_frac`` is the fraction of raw
            weights whose magnitude actually exceeded ``clip_value``, and
            ``ips_std`` is 0.0 for a single-sample batch.
        """
        chosen_probs = self._chosen_probs(policy_probs, sampled_actions)
        correction = rewards / chosen_probs  # r_t / P_t(a_t)  [B]

        # Scalar IPS weights (clipped) for the sparse update.
        raw_weights = 1.0 - correction
        ips_weights_clipped = torch.clamp(
            raw_weights, -self.clip_value, self.clip_value
        )

        # Dense coefficients for the dense update.
        dense_coeffs = self._dense_coefficients(
            policy_probs, sampled_actions, correction
        )

        # Sample std of one element is NaN; report 0 spread instead.
        if ips_weights_clipped.numel() > 1:
            ips_std = float(ips_weights_clipped.std().item())
        else:
            ips_std = 0.0

        stats = {
            "ips_mean": float(ips_weights_clipped.mean().item()),
            "ips_std": ips_std,
            "ips_min": float(ips_weights_clipped.min().item()),
            "ips_max": float(ips_weights_clipped.max().item()),
            "clip_frac": float(
                (raw_weights.abs() > self.clip_value).float().mean().item()
            ),
        }
        return ips_weights_clipped, dense_coeffs, stats

    def compute_ips_gradients_dense(
        self,
        contexts: torch.Tensor,
        sampled_actions: torch.Tensor,
        rewards: torch.Tensor,
        policy_probs: torch.Tensor,
        *,
        arm_embeddings: Optional[torch.Tensor] = None,
        initial_embeddings: Optional[torch.Tensor] = None,
        coeffs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Dense IPS gradient per note:
          g_{t,a} = ( P_t(a) - 1{a_t=a} * r_t / P_t(a_t) ) * x_t
        Batched and averaged over the batch to yield [K, D].
        Optionally adds L2 prior gradient λ Σ^{-1} (θ - θ̃) if configured.
        If `coeffs` is provided (from the IPS helper), reuse it to avoid recompute.
        """
        B = contexts.size(0)

        if coeffs is None:
            correction = rewards / self._chosen_probs(
                policy_probs, sampled_actions
            )
            coeffs = self._dense_coefficients(
                policy_probs, sampled_actions, correction
            )

        # Gradient: (1/B) * coeffs^T @ x_t  -> [K, D]
        gradients = (1.0 / B) * torch.mm(coeffs.t(), contexts)

        prior = self._prior_gradient(arm_embeddings, initial_embeddings)
        if prior is not None:
            gradients = gradients + prior
        return gradients

    def compute_ips_gradients(
        self,
        contexts: torch.Tensor,
        sampled_actions: torch.Tensor,
        rewards: torch.Tensor,
        policy_probs: torch.Tensor,
        *,
        arm_embeddings: Optional[torch.Tensor] = None,
        initial_embeddings: Optional[torch.Tensor] = None,
        ips_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Sparse/original IPS gradient using scalar weights:
          g_t = 1 - r_t / P_t(a_t)   (clipped to [-clip_value, clip_value])
          ∇_Θ = (1/B) * ( one_hot(a_t)^T * g_t ) @ x_t
        Adds optional L2 prior gradient if configured. If `ips_weights` is
        provided it is used as-is (no further clipping).
        """
        B = contexts.size(0)

        if ips_weights is None:
            chosen_probs = self._chosen_probs(policy_probs, sampled_actions)
            ips_weights = torch.clamp(
                1.0 - rewards / chosen_probs, -self.clip_value, self.clip_value
            )

        # One-hot mask scaled by g_t; dtype follows policy_probs like the
        # dense path instead of a hard-coded float32.
        action_mask = torch.zeros_like(policy_probs).scatter_(
            1, sampled_actions.unsqueeze(-1), 1.0
        )
        weighted_mask = action_mask * ips_weights.unsqueeze(-1)

        # Gradient: (1/B) * weighted_mask^T @ x_t -> [K, D]
        gradients = (1.0 / B) * torch.mm(weighted_mask.t(), contexts)

        prior = self._prior_gradient(arm_embeddings, initial_embeddings)
        if prior is not None:
            gradients = gradients + prior
        return gradients

    # -------------------- Update and Train Step --------------------
    def update_embeddings(
        self,
        arm_embeddings: torch.Tensor,
        gradients: torch.Tensor,
        current_lr: Optional[float] = None,
    ) -> torch.Tensor:
        """Update embeddings via optimizer. Returns updated tensor."""
        return self.optimizer.step(arm_embeddings, gradients, current_lr)

    def train_batch(
        self,
        contexts: torch.Tensor,
        correct_arms_batch: List[List[int]],
        arm_embeddings: torch.Tensor,
        true_embeddings: torch.Tensor,
        initial_embeddings: Optional[torch.Tensor] = None,
        step_count: int = 0,
        current_lr: Optional[float] = None,
        **kwargs,
    ) -> Tuple[
        torch.Tensor,  # updated_embeddings
        torch.Tensor,  # rewards
        torch.Tensor,  # regrets
        torch.Tensor,  # policy_probs
        Dict[str, float],  # grad_stats
        Dict[str, float],  # ips_stats
    ]:
        """
        One batched training step: act, observe rewards, compute the IPS
        gradient, and apply it via the optimizer.

        Args:
            contexts: [B, D] query contexts.
            correct_arms_batch: Per-sample lists of correct arm indices.
            arm_embeddings: [K, D] current parameters.
            true_embeddings: [K, D] embeddings used only for regret reporting.
            initial_embeddings: Prior mean for the optional L2 regularizer.
            step_count: Accepted for interface compatibility; unused here.
            current_lr: Forwarded to ``optimizer.step``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            (updated_embeddings, rewards, regrets, policy_probs,
             grad_stats, ips_stats)
        """
        with torch.no_grad():
            policy_probs = self.compute_policy_probabilities(
                contexts, arm_embeddings
            )
            sampled_actions = self.sample_actions(policy_probs)

            rewards = self.compute_rewards(sampled_actions, correct_arms_batch)
            regrets = self.compute_regret(
                contexts, sampled_actions, correct_arms_batch, true_embeddings
            )

            # Scalar weights, dense coefficients and logging stats share one
            # pass; whichever gradient variant runs below reuses its input.
            ips_weights_scalar, dense_coeffs, ips_stats = (
                self._compute_ips_weights_and_stats(
                    rewards, policy_probs, sampled_actions
                )
            )

        if self.use_dense:
            gradients = self.compute_ips_gradients_dense(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                initial_embeddings=initial_embeddings,
                coeffs=dense_coeffs,
            )
        else:
            gradients = self.compute_ips_gradients(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                initial_embeddings=initial_embeddings,
                ips_weights=ips_weights_scalar,
            )

        # Per-arm gradient norms for logging.
        grad_row_norms = torch.norm(gradients, dim=1)
        grad_stats = {
            "grad_mean": float(grad_row_norms.mean().item()),
            "grad_max": float(grad_row_norms.max().item()),
            "grad_frob": float(torch.norm(gradients).item()),
        }

        updated_embeddings = self.update_embeddings(
            arm_embeddings, gradients, current_lr
        )

        return (
            updated_embeddings,
            rewards,
            regrets,
            policy_probs,
            grad_stats,
            ips_stats,
        )

    # -------------------- Evaluation --------------------
    @torch.no_grad()
    def evaluate_policy(
        self,
        all_contexts: torch.Tensor,
        arm_embeddings: torch.Tensor,
        all_correct_arms: List[List[int]],
        ks: Sequence[int] = (1, 3, 5, 10),
    ) -> Dict[str, float]:
        """
        Compute ranking metrics for the score-ordered ranking of arms.

        For every usable cutoff ``k`` in ``ks`` three keys are produced:

        - ``hit_at_{k}``: fraction of queries with at least one correct arm
          in the top-k.
        - ``recall_at_{k}``: mean of |top-k ∩ correct| / |correct|.
        - ``ndcg_at_{k}``: mean NDCG@k with binary relevance and discount
          1 / log2(rank + 1) for 1-based ranks, normalized by the ideal
          ranking's DCG.

        Metrics are averaged over *all* queries; queries with an empty
        correct set contribute 0. Cutoffs ``k <= 0`` or ``k > num_arms`` are
        skipped (no keys are emitted for them).

        Args:
            all_contexts: [N, D] query contexts.
            arm_embeddings: [K, D] arm parameters.
            all_correct_arms: One list of correct arm indices per query.
            ks: Cutoffs to evaluate.

        Returns:
            Metric dict; empty if there are no queries or no usable cutoff.

        Raises:
            ValueError: If ``all_correct_arms`` has fewer entries than
                ``all_contexts`` has rows.
        """
        num_queries = len(all_contexts)
        if num_queries == 0:
            return {}

        max_k = min(max((k for k in ks if k > 0), default=0), self.num_arms)
        if max_k <= 0:
            return {}

        true_sets = [set(arms) for arms in all_correct_arms[:num_queries]]
        if len(true_sets) < num_queries:
            raise ValueError(
                "all_correct_arms must have one entry per context row"
            )

        scores = torch.mm(all_contexts, arm_embeddings.t())  # [N, num_arms]
        # One topk for the largest cutoff; smaller cutoffs are its prefixes.
        _, topk_indices = torch.topk(scores, k=max_k, dim=1)  # [N, max_k]
        # Move rankings and discounts to the host once instead of syncing
        # the device per query.
        ranked = topk_indices.tolist()
        discounts = (
            1.0 / torch.log2(torch.arange(2, max_k + 2, dtype=torch.float64))
        ).tolist()

        rslt: Dict[str, float] = {}
        for k in ks:
            if k <= 0 or k > max_k:
                continue

            k_discounts = discounts[:k]
            total_hit = 0.0
            total_recall = 0.0
            total_ndcg = 0.0

            for preds, true_set in zip(ranked, true_sets):
                if not true_set:
                    continue

                # Discount of every rank (within top-k) holding a correct arm.
                gains = [
                    disc
                    for disc, pred in zip(k_discounts, preds)
                    if pred in true_set
                ]
                if gains:
                    total_hit += 1.0
                total_recall += len(gains) / len(true_set)

                # Ideal DCG: all correct arms ranked first; > 0 because k >= 1
                # and true_set is non-empty.
                idcg = sum(k_discounts[: min(k, len(true_set))])
                total_ndcg += sum(gains) / idcg

            rslt[f"hit_at_{k}"] = total_hit / num_queries
            rslt[f"recall_at_{k}"] = total_recall / num_queries
            rslt[f"ndcg_at_{k}"] = total_ndcg / num_queries

        return rslt
