# multihop_rag_bandit.py
import json
import time
from typing import Dict, List, Optional, Tuple

import torch

from .llm_judge import JudgeLLM
from .universal_bandit_optimizer import UniversalBanditOptimizer
from .universal_contextual_bandit import UniversalContextualBandit
from .universal_data_loader import UniversalBanditDataLoader


class MultiHopRAGBandit(UniversalContextualBandit):
    """
    Adapts the contextual bandit framework for multi-hop RAG.

    For every batch of contexts the bandit samples up to `retrieval_k`
    candidate documents (arms) from its current policy and asks a JudgeLLM to
    pick one of them; the picked document is treated as the action taken.
    The action is appended to a JSONL log, scored with the base class's
    `compute_rewards` / `compute_regret` against `correct_arms_batch`, and
    used in an IPS gradient step that updates the document embeddings
    (`arm_embeddings`) online.
    """

    def __init__(
        self,
        # Parent class arguments
        num_arms: int,
        embedding_dim: int,
        device: torch.device,
        optimizer: UniversalBanditOptimizer,
        # New arguments for this class
        judge_llm: JudgeLLM,
        data_loader: UniversalBanditDataLoader,
        retrieval_k: int = 5,
        rerank_info_output_path: str = "rerank_info",
        **kwargs,
    ):
        """
        Build the bandit and fix the path of the rerank log for this run.

        Args:
            num_arms: Forwarded to the parent class.
            embedding_dim: Forwarded to the parent class.
            device: Forwarded to the parent class.
            optimizer: Forwarded to the parent class.
            judge_llm: Judge whose `batched_judge` picks one candidate per
                query in `train_batch`.
            data_loader: Source of query texts for the rerank log.
            retrieval_k: Number of candidate documents sampled per context.
            rerank_info_output_path: Path prefix of the rerank log; a
                timestamp and the `.jsonl` suffix are appended here.
            **kwargs: Forwarded to the parent class.

        Raises:
            ValueError: If `retrieval_k` is smaller than 1.
        """
        # Fail at construction time rather than inside the first training
        # step, where sampling zero candidates cannot succeed.
        if retrieval_k < 1:
            raise ValueError(
                f"retrieval_k must be at least 1, got {retrieval_k}"
            )
        super().__init__(
            num_arms=num_arms,
            embedding_dim=embedding_dim,
            device=device,
            optimizer=optimizer,
            **kwargs,
        )
        self.judge_llm = judge_llm
        self.retrieval_k = retrieval_k
        self.data_loader = data_loader
        # The timestamp is taken once, so every batch of this instance is
        # appended to the same file.
        self.rerank_info_output_path = (
            f"{rerank_info_output_path}_{time.strftime('%Y%m%d_%H%M%S')}.jsonl"
        )

    def _record_rerank_info(
        self,
        sampled_actions: torch.Tensor,
        batch_query_indices: torch.Tensor,
    ) -> None:
        """
        Append one JSON line per query describing the document picked for it.

        Each line holds `query_id`, `query_text` and `arm_id`. The query id is
        stored so the correct answer can be looked up later.

        Args:
            sampled_actions: Picked document (arm) id per query; element i
                belongs to element i of `batch_query_indices`.
            batch_query_indices: Query indices of the current batch.
        """
        # Convert each tensor once instead of calling `.item()` per row.
        query_ids = batch_query_indices.tolist()
        arm_ids = sampled_actions.tolist()

        # Only query texts are looked up: arm texts are not part of the
        # logged record, so fetching them would be wasted work.
        query_texts = self.data_loader.get_texts_by_indices(
            query_indices=query_ids
        )["queries"]

        lines = [
            json.dumps(
                {
                    "query_id": query_id,
                    "query_text": query_texts[i],
                    "arm_id": arm_ids[i],
                }
            )
            + "\n"
            for i, query_id in enumerate(query_ids)
        ]
        # Explicit encoding keeps the log independent of the platform locale.
        with open(self.rerank_info_output_path, "a", encoding="utf-8") as f:
            f.writelines(lines)

    def train_batch(
        self,
        contexts: torch.Tensor,
        correct_arms_batch: List[List[int]],
        arm_embeddings: torch.Tensor,
        true_embeddings: torch.Tensor,
        batch_query_indices: torch.Tensor,
        initial_embeddings: Optional[torch.Tensor] = None,
        step_count: int = 0,
        current_lr: Optional[float] = None,
        **kwargs,
    ) -> Tuple:
        """
        Performs one online training step for a batch of sub-questions.

        Args:
            contexts: Batch of context (query) representations fed to the
                policy.
            correct_arms_batch: Per-query lists of correct arm ids, passed to
                `compute_rewards` and `compute_regret`.
            arm_embeddings: Current document embeddings; the updated version
                is returned.
            true_embeddings: Passed to `compute_regret`.
            batch_query_indices: Query indices of the batch, handed to the
                judge and written to the rerank log.
            initial_embeddings: Forwarded to the gradient computation.
            step_count: Accepted for interface compatibility; not used here.
            current_lr: Forwarded to `update_embeddings`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Tuple of `(updated_embeddings, rewards, regrets, policy_probs,
            grad_stats, ips_stats)`, where `grad_stats` holds the mean and max
            per-row gradient norm and the norm of the whole gradient.
        """
        with torch.no_grad():
            # 1. Compute policy probabilities P(a|x) for all arms (documents)
            policy_probs = self.compute_policy_probabilities(
                contexts, arm_embeddings
            )

            # 2. Sample candidate documents for each context. Sampling is
            # without replacement, so asking for more candidates than there
            # are arms would raise; cap k at the number of arms.
            num_candidates = min(self.retrieval_k, policy_probs.size(-1))
            candidate_indices = torch.multinomial(
                policy_probs, num_samples=num_candidates
            )  # Shape: [B, k]

            # 3. Let the judge pick, per query, the position (within the k
            # candidates) of the document to use.
            reranked_indices = self.judge_llm.batched_judge(
                batch_query_indices, candidate_indices
            )  # Shape: [B]

            # The judge's return type is not visible from this file; gather
            # needs int64 positions on the candidates' device, so coerce
            # defensively (no copy when it already matches).
            reranked_indices = torch.as_tensor(
                reranked_indices,
                dtype=torch.long,
                device=candidate_indices.device,
            )

            # Safety: clamp reranked positions into [0, k-1]
            reranked_indices = torch.clamp(
                reranked_indices, 0, candidate_indices.size(1) - 1
            )

            # Translate each picked position into the actual document id.
            sampled_actions = torch.gather(
                candidate_indices, 1, reranked_indices.unsqueeze(-1)
            ).squeeze(-1)

            # ! at this point, we have batched query indices,
            # ! and the candidate indices, which are enough
            # ! for us to feed into an external LLM to
            # ! generate the final answer and compute metrics
            self._record_rerank_info(sampled_actions, batch_query_indices)

            # 4. Score the picked documents against the known correct arms.
            rewards = self.compute_rewards(sampled_actions, correct_arms_batch)
            regrets = self.compute_regret(
                contexts, sampled_actions, correct_arms_batch, true_embeddings
            )

            # 5. Compute IPS statistics for logging
            # NOTE(review): the action is picked by the judge among k sampled
            # candidates, so policy_probs is presumably only an approximation
            # of its true propensity — confirm this is intended.
            ips_weights_scalar, dense_coeffs, ips_stats = (
                self._compute_ips_weights_and_stats(
                    rewards, policy_probs, sampled_actions
                )
            )

        # 6. Compute gradients from contexts, actions, rewards and policy
        # probabilities using the base class's IPS gradient methods.
        if self.use_dense:
            gradients = self.compute_ips_gradients_dense(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                initial_embeddings=initial_embeddings,
                coeffs=dense_coeffs,
            )
        else:
            gradients = self.compute_ips_gradients(
                contexts,
                sampled_actions,
                rewards,
                policy_probs,
                arm_embeddings=arm_embeddings,
                initial_embeddings=initial_embeddings,
                ips_weights=ips_weights_scalar,
            )

        # 7. Gradient stats: per-row norms plus the norm of the full gradient.
        grad_row_norms = torch.norm(gradients, dim=1)
        grad_stats = {
            "grad_mean": float(grad_row_norms.mean().item()),
            "grad_max": float(grad_row_norms.max().item()),
            "grad_frob": float(torch.norm(gradients).item()),
        }

        # 8. Apply the update and hand everything back to the caller.
        updated_embeddings = self.update_embeddings(
            arm_embeddings, gradients, current_lr
        )

        return (
            updated_embeddings,
            rewards,
            regrets,
            policy_probs,
            grad_stats,
            ips_stats,
        )
