import os
import re
import json
import uuid
import torch
import numpy as np
from scipy.stats import t
from collections import defaultdict

from verl.trainer.ppo.reward import compute_reward
from verl.trainer.ppo.rollout_policy.base_rollout_policy import BaseRolloutPolicy
from verl.utils.torch_functional import pad_sequence_to_length
from verl import DataProto


class ExploreExploitRolloutPolicy(BaseRolloutPolicy):
    """Explore/exploit rollout policy with statistical budget allocation and ICL.

    Each training step proceeds in two phases:

    1. Every prompt receives ``base_rollouts_per_prompt`` ordinary rollouts.
    2. The remaining ``train_batch_size * (rollout.n - base_rollouts_per_prompt)``
       rollouts are split over ``dynamic_rounds`` rounds. In each round all
       rollouts produced so far are scored, the round's budget is distributed
       over prompts by ``_statistical_allocation_with_exploration`` and the new
       rollouts are generated. When ``icl_enabled`` is set, prompts without any
       reward-1 rollout are re-sampled with an in-context-learning (ICL) prompt
       ("exploration"); all other prompts reuse their original prompt
       ("strengthening").

    Reward-1 responses of non-ICL rollouts are appended to ``self.icl_corpus``
    (keyed by user prompt), which supplies the worked examples of later ICL
    prompts. ``postprocess_batch`` optionally adds a novelty bonus to the
    advantages of reward-1 responses.
    """

    # Qwen chat-template segment extractors. Non-greedy so each pattern stops at
    # the first ``<|im_end|>`` after its role header instead of running to the
    # last one in the text (which would swallow any later turns).
    _SYSTEM_PROMPT_RE = re.compile(r"(?<=<\|im_start\|>system\n)(.*?)(?=<\|im_end\|>)", re.DOTALL)
    _USER_PROMPT_RE = re.compile(r"(?<=<\|im_start\|>user\n)(.*?)(?=<\|im_end\|>)", re.DOTALL)

    def __init__(self, config, tokenizer, actor_rollout_wg, reward_fn):
        super().__init__(config, tokenizer, actor_rollout_wg, reward_fn)
        # Guards the denominator of the relative-increase metric in ``get_metrics``.
        self.eps = 1e-8

        policy_kwargs = self.config.rollout_policy.policy_kwargs
        total_rollouts_per_prompt = self.config.actor_rollout_ref.rollout.n

        # Allocation arguments
        self.base_rollouts_per_prompt = policy_kwargs.base_rollouts_per_prompt
        self.dynamic_rounds = policy_kwargs.dynamic_rounds

        # Validate before the values are used (the budget split below divides by
        # ``dynamic_rounds``); check the type before comparing with 0.
        assert self.base_rollouts_per_prompt < total_rollouts_per_prompt, f"base_rollouts_per_prompt ({self.base_rollouts_per_prompt}) must be less than rollout.n ({total_rollouts_per_prompt})"
        assert isinstance(self.dynamic_rounds, int) and self.dynamic_rounds > 0, f"dynamic_rounds ({self.dynamic_rounds}) must be positive integer"

        # Total number of extra (dynamic) rollouts for the whole batch.
        self.total_dynamic_budget_allocation = self.config.data.train_batch_size * (
            total_rollouts_per_prompt - self.base_rollouts_per_prompt
        )
        # Upper bound on a single round's budget. The exact share of each round
        # comes from ``_round_budget`` so that all rounds sum to the total above.
        self.dynamic_budget_allocation_per_round = np.ceil(self.total_dynamic_budget_allocation / self.dynamic_rounds)

        # ICL arguments
        self.icl_enabled = policy_kwargs.icl_enabled
        self.icl_samples_per_prompt = policy_kwargs.icl_samples_per_prompt
        # user prompt -> list of similar training questions (keys into icl_corpus)
        with open(policy_kwargs.train_similar_questions, "r", encoding="utf-8") as f:
            self.train_similar_questions = json.load(f)
        # user prompt -> list of responses that previously scored 1
        with open(policy_kwargs.icl_corpus_path, "r", encoding="utf-8") as f:
            self.icl_corpus = json.load(f)

        # Advantage shaping arguments
        self.advantage_shaping_enabled = policy_kwargs.advantage_shaping_enabled
        self.novelty_strength = policy_kwargs.novelty_strength
        self.novelty_clamp = policy_kwargs.novelty_clamp

    def expand(self, gen_batch, batch):
        """Generate all rollouts for one training step.

        Runs the base rollouts, then the dynamic allocation rounds, then adds
        reward-1 non-ICL responses to the ICL corpus.

        Args:
            gen_batch (DataProto): prompts to roll out. Its ``meta_info`` gains
                ``uid`` (one fresh UUID string per prompt) and ``idx2uid``
                (prompt index -> uid).
            batch (DataProto): original per-prompt training batch, used to score
                rollouts.

        Returns:
            DataProto: every generated rollout, with ``meta_info["uid"]`` (one
            entry per rollout) and ``meta_info["idx2uid"]``.
        """
        # If kept, ``_generate_icl_batch`` would copy the original prompt ids
        # over the ``raw_prompt_ids`` it builds for the ICL prompts.
        gen_batch.non_tensor_batch.pop("raw_prompt_ids", None)

        num_prompts = gen_batch.batch.batch_size[0]
        uids = np.array([str(uuid.uuid4()) for _ in range(num_prompts)], dtype=object)
        gen_batch.meta_info["uid"] = uids
        gen_batch.meta_info["idx2uid"] = dict(enumerate(uids))

        # Base rollouts: indices [0..P-1] repeated base_rollouts_per_prompt times.
        base_rollout_indices = list(range(num_prompts)) * self.base_rollouts_per_prompt
        gen_batch_output_base = self._generate_rollouts(
            gen_batch,
            strengthening_rollout_indices=base_rollout_indices,
            exploration_rollout_indices=[],
        )

        # Apply statistical ICL budget allocation
        gen_batch_output_complete = self._apply_statistical_icl_budget_allocation(
            gen_batch, gen_batch_output_base, batch,
        )

        self._update_icl_corpus(gen_batch_output_complete, batch)

        gen_batch_output_complete.meta_info["idx2uid"] = gen_batch.meta_info["idx2uid"]
        return gen_batch_output_complete

    def compute_reward(self, batch, reward_fn):
        """
        Compute reward for the batch.

        Args:
            batch: DataProto batch
            reward_fn: Reward function

        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]: Reward tensor and extra info
        """
        # Resolves to the module-level ``compute_reward`` import, not this method.
        return compute_reward(batch, reward_fn)

    def postprocess_batch(self, batch):
        """Add a novelty bonus to the advantages of reward-1 responses.

        For every response, ``avg_logprob`` is the mean of ``old_log_probs`` over
        its response tokens and ``novelty = exp(avg_logprob - mean_avg_logprob)``
        where the mean is taken over all responses sharing the same ``uid``.
        Responses whose summed ``token_level_scores`` equal 1 get

            bonus = clamp((1 - novelty) * novelty_strength, 0, max(0, novelty_clamp * mean_adv))

        added on every response token, ``mean_adv`` being the response's mean
        token-level advantage. ``returns`` is set to the shaped advantages.
        No-op when ``advantage_shaping_enabled`` is false.

        Args:
            batch (DataProto): needs ``old_log_probs``, ``token_level_scores``,
                ``advantages``, ``response_mask`` tensors and a ``uid`` non-tensor field.

        Returns:
            DataProto: the same batch, modified in place.
        """
        if not self.advantage_shaping_enabled:
            return batch

        old_log_probs = batch.batch["old_log_probs"]
        rewards = batch.batch["token_level_scores"].sum(dim=1)
        advantage = batch.batch["advantages"]
        uids = batch.non_tensor_batch["uid"]
        response_mask = batch.batch["response_mask"]

        # Average log probability per sequence (clamp avoids division by zero
        # for empty responses).
        seq_lengths = response_mask.sum(dim=1).clamp(min=1)
        avg_logprobs = (old_log_probs * response_mask).sum(dim=1) / seq_lengths

        # Mean of the per-sequence averages within each uid group.
        uid_to_logprobs = defaultdict(list)
        for uid, avg_logprob in zip(uids, avg_logprobs):
            uid_to_logprobs[uid].append(avg_logprob)
        uid_to_avg_pi = {uid: torch.stack(values).mean() for uid, values in uid_to_logprobs.items()}

        for i, (uid, reward) in enumerate(zip(uids, rewards)):
            if reward != 1:  # Only shape correct responses
                continue
            mean_adv = (advantage[i] * response_mask[i]).sum() / seq_lengths[i]
            novelty = torch.exp(avg_logprobs[i] - uid_to_avg_pi[uid])
            # torch.clamp returns ``max`` when ``min > max``; flooring the upper
            # bound at 0 keeps a negative mean advantage from turning the bonus
            # into a penalty.
            upper = max(0.0, self.novelty_clamp * mean_adv.item())
            bonus = torch.clamp((1.0 - novelty) * self.novelty_strength, min=0, max=upper)
            advantage[i] += bonus * response_mask[i]

        batch.batch["advantages"] = advantage
        batch.batch["returns"] = advantage
        return batch

    def repeat_and_merge_batch(self,
                               batch,
                               gen_batch_output,
                               original_batch_size=None):
        """Align the original batch with the generated rollouts and merge them.

        Each rollout row of ``gen_batch_output`` is mapped back to its prompt
        index through ``meta_info["uid"]`` and ``meta_info["idx2uid"]``. The
        matching rows of ``batch`` are selected (repeating prompts as needed)
        and unioned with ``gen_batch_output``.

        Side effects on the passed-in ``batch``: sets
        ``non_tensor_batch["uid"]`` and removes ``meta_info["uid"]`` if present.
        NOTE(review): the pop presumably drops a per-rollout ``uid`` left in
        ``batch.meta_info`` by an earlier union — confirm.

        Args:
            batch (DataProto): per-prompt batch; row ``i`` is prompt ``i``.
            gen_batch_output (DataProto): rollouts carrying ``uid``/``idx2uid`` meta.
            original_batch_size: unused; kept for interface compatibility.

        Returns:
            DataProto: one row per rollout with prompt data and generation outputs.
        """
        batch.meta_info.pop("uid", None)

        idx2uid = gen_batch_output.meta_info["idx2uid"]
        uid2idx = {uid: idx for idx, uid in idx2uid.items()}

        gen_batch_output_uid = gen_batch_output.meta_info["uid"]
        repeat_indices = np.array([uid2idx[uid] for uid in gen_batch_output_uid])

        batch.non_tensor_batch["uid"] = np.array([idx2uid[idx] for idx in range(len(batch.batch))], dtype=object)

        batch = batch.select_idxs(repeat_indices)
        batch = batch.union(gen_batch_output)
        return batch

    def load_checkpoint(self, global_step_folder, global_step):
        """Restore ``self.icl_corpus`` from ``icl_corpus_{global_step}.json``.

        The file is looked up in ``trainer.default_local_dir`` (where
        ``_dump_icl_corpus`` writes it); ``global_step_folder`` is not used.
        Prints a warning and keeps the current corpus if the file is missing.
        """
        icl_corpus_path = os.path.join(self.config.trainer.default_local_dir, f"icl_corpus_{global_step}.json")
        if os.path.exists(icl_corpus_path):
            print(f"Loading ICL corpus from {icl_corpus_path}")
            with open(icl_corpus_path, "r", encoding="utf-8") as f:
                self.icl_corpus = json.load(f)
            print(f"Loaded ICL corpus with {len(self.icl_corpus)} items")
        else:
            print(f"Warning: No ICL corpus found at {icl_corpus_path}")

    def get_metrics(self, batch):
        """Extend the base-class metrics with ICL statistics.

        Adds, under ``additional_metrics/``:
          * ``total_flip_ratio``: fraction of ICL rollouts with score 1.
          * ``flip_ratio_per_prompt_avg``: that fraction per uid, averaged.
          * ``any_correct_ratio``: fraction of uids with >= 1 correct ICL rollout.
          * ``effective_batch_size_without_icl``: uids whose non-ICL rollouts
            contain both score-1 and non-1 scores.
          * ``relative_increase_in_effective_batch_size``: percentage increase of
            the base-class ``effective_batch_size`` over the value above.
        """
        metrics = super().get_metrics(batch)

        token_scores = batch.batch["token_level_scores"].sum(-1).detach().cpu().numpy()
        uids = batch.non_tensor_batch["uid"]
        is_icl = batch.non_tensor_batch["is_icl_rollout"]
        is_correct = token_scores == 1

        total_icl = float(np.sum(is_icl))
        correct_icl = float(np.sum(is_icl * is_correct))
        total_flip_ratio = (correct_icl / total_icl) if total_icl > 0 else 0.0

        # Group ICL correctness and non-ICL scores by prompt uid.
        icl_correct_by_uid = defaultdict(list)
        plain_scores_by_uid = defaultdict(list)
        for uid, icl_flag, score, correct in zip(uids, is_icl, token_scores, is_correct):
            if icl_flag:
                icl_correct_by_uid[uid].append(correct)
            else:
                plain_scores_by_uid[uid].append(score)

        if icl_correct_by_uid:
            flip_ratio_per_prompt = float(np.mean([np.mean(v) for v in icl_correct_by_uid.values()]))
            any_correct_ratio = float(np.mean([any(v) for v in icl_correct_by_uid.values()]))
        else:
            flip_ratio_per_prompt = 0.0
            any_correct_ratio = 0.0

        # Effective batch size without ICL: prompts with mixed outcomes.
        eff_wo_icl = 0
        for scores in plain_scores_by_uid.values():
            scores = np.array(scores)
            if np.any(scores != 1) and np.any(scores == 1):
                eff_wo_icl += 1

        eff = metrics.get("additional_metrics/effective_batch_size", 0)
        metrics.update({
            "additional_metrics/total_flip_ratio": total_flip_ratio,
            "additional_metrics/flip_ratio_per_prompt_avg": flip_ratio_per_prompt,
            "additional_metrics/effective_batch_size_without_icl": eff_wo_icl,
            "additional_metrics/relative_increase_in_effective_batch_size":
                100.0 * (eff - eff_wo_icl) / (eff_wo_icl + self.eps),
            "additional_metrics/any_correct_ratio": any_correct_ratio,
        })
        return metrics

    # Rollout helper functions
    def _generate_rollouts(self, gen_batch, strengthening_rollout_indices, exploration_rollout_indices):
        """Generate one rollout per requested prompt index.

        Args:
            gen_batch (DataProto): prompt batch whose ``meta_info`` holds ``uid``
                (per-prompt array) and ``idx2uid``.
            strengthening_rollout_indices (List[int]): prompt indices (repeats
                allowed) rolled out with their original prompt.
            exploration_rollout_indices (List[int]): prompt indices (repeats
                allowed) rolled out with an ICL prompt from ``_generate_icl_batch``.

        Returns:
            DataProto: rollouts ordered strengthening first, then exploration,
            with ``non_tensor_batch["is_icl_rollout"]`` (0.0 / 1.0),
            ``meta_info["uid"]`` (one uid per rollout) and ``meta_info["idx2uid"]``.

        Raises:
            ValueError: if both index lists are empty.
        """
        strengthening_rollout_indices = list(strengthening_rollout_indices)
        exploration_rollout_indices = list(exploration_rollout_indices)

        # ``uid`` is taken out of ``gen_batch.meta_info`` while the selected batch
        # is built and generated, and always put back afterwards (even on error).
        # NOTE(review): presumably so the per-prompt uid array is not forwarded to
        # ``generate_sequences`` through meta_info shared by ``select_idxs`` — confirm.
        uids = gen_batch.meta_info.pop("uid", [])
        try:
            parts = []
            if strengthening_rollout_indices:
                parts.append(gen_batch.select_idxs(strengthening_rollout_indices))
            if exploration_rollout_indices:
                parts.append(self._generate_icl_batch(gen_batch, exploration_rollout_indices))
            if not parts:
                # Ideally should never happen
                raise ValueError("No rollouts to generate")
            gen_batch_selected = parts[0] if len(parts) == 1 else self._merge_data_batches(*parts)

            # Repetition is already encoded in the index lists, so request one
            # sample per row.
            gen_batch_selected.meta_info["n"] = 1
            gen_batch_selected.meta_info["custom_sampling_params"] = ["n"]

            gen_batch_output_selected = self.actor_rollout_wg.generate_sequences(gen_batch_selected)
        finally:
            gen_batch.meta_info["uid"] = uids

        gen_batch_output_selected.non_tensor_batch["is_icl_rollout"] = np.concatenate([
            np.zeros(len(strengthening_rollout_indices)), np.ones(len(exploration_rollout_indices))
        ])

        # Add metadata
        gen_batch_output_selected.meta_info["uid"] = [
            uids[i] for i in strengthening_rollout_indices + exploration_rollout_indices
        ]
        gen_batch_output_selected.meta_info["idx2uid"] = gen_batch.meta_info["idx2uid"]

        return gen_batch_output_selected

    def _round_budget(self, round_idx):
        """Return the exact number of extra rollouts for dynamic round ``round_idx``.

        ``total_dynamic_budget_allocation`` is split as evenly as possible over
        ``dynamic_rounds`` rounds (earlier rounds take the remainder), so the
        rounds sum to exactly the configured total.
        """
        per_round, remainder = divmod(int(self.total_dynamic_budget_allocation), self.dynamic_rounds)
        return per_round + (1 if round_idx < remainder else 0)

    def _apply_statistical_icl_budget_allocation(self, gen_batch, gen_batch_output_base, batch):
        """Run the dynamic allocation rounds on top of the base rollouts.

        Each round re-scores every rollout produced so far, groups the scores by
        prompt uid and hands that round's budget to
        ``_statistical_allocation_with_exploration``. Prompts with no score-1
        rollout are explored with ICL prompts when ``icl_enabled``; the others
        are strengthened with their original prompt.

        Args:
            gen_batch (DataProto): prompt batch (with ``uid``/``idx2uid`` meta).
            gen_batch_output_base (DataProto): rollouts from the base phase.
            batch (DataProto): original per-prompt batch used for scoring.

        Returns:
            DataProto: base rollouts followed by every dynamic round's rollouts.
        """
        gen_batch_output_complete = gen_batch_output_base
        idx2uid = gen_batch_output_complete.meta_info["idx2uid"]
        num_prompts = len(idx2uid)

        for round_idx in range(self.dynamic_rounds):
            round_budget = self._round_budget(round_idx)
            if round_budget <= 0:
                # Only possible when the total budget is smaller than the number of rounds.
                continue

            scores = self._merge_and_score_rollouts(gen_batch_output_complete, batch)
            uid2score = defaultdict(list)
            for uid, score in zip(gen_batch_output_complete.meta_info["uid"], scores):
                uid2score[uid].append(score)

            reward_lists = [np.array(uid2score[idx2uid[idx]]) for idx in range(num_prompts)]
            std_list = np.array([np.std(rewards) for rewards in reward_lists])

            certain_list = self._statistical_allocation_with_exploration(
                R_list=reward_lists,
                B=round_budget,
            )

            self._log_budget_allocation(round_idx, std_list, certain_list)

            strengthening_rollout_indices = []
            exploration_rollout_indices = []
            for idx, (prompt_scores, weight) in enumerate(zip(reward_lists, certain_list)):
                never_solved = not np.any(prompt_scores == 1)
                target = exploration_rollout_indices if (never_solved and self.icl_enabled) else strengthening_rollout_indices
                target.extend([idx] * int(weight))

            if not strengthening_rollout_indices and not exploration_rollout_indices:
                continue

            gen_batch_output_dynamic = self._generate_rollouts(
                gen_batch, strengthening_rollout_indices, exploration_rollout_indices,
            )
            gen_batch_output_complete = self._merge_data_batches(
                gen_batch_output_complete, gen_batch_output_dynamic,
            )

        self._log_icl_rollouts(gen_batch_output_complete)
        return gen_batch_output_complete

    def _merge_and_score_rollouts(self, gen_batch_output, batch):
        """Merge ``gen_batch_output`` with ``batch`` and return per-rollout scores.

        Returns:
            List[float]: summed reward per rollout row, in row order.
        """
        # Repeat to align with the rollouts, then merge with gen_batch_output.
        batch_copy = self.repeat_and_merge_batch(
            batch=batch,
            gen_batch_output=gen_batch_output,
        )

        reward_tensor, _ = compute_reward(batch_copy, self.reward_fn)
        scores = reward_tensor.sum(-1).cpu().tolist()
        return scores

    def _merge_data_batches(self, gen_batch1, gen_batch2):
        """Row-concatenate two DataProto batches (``gen_batch1`` rows first).

        * ``responses`` / ``rollout_log_probs`` must share their last dimension
          and are concatenated directly.
        * Every other tensor key is left-padded to the longer last dimension
          (tokenizer ``pad_token_id`` for ``prompts``/``input_ids``, 0 otherwise)
          and then concatenated. Tensor keys present only in ``gen_batch2`` are dropped.
        * Non-tensor keys of ``gen_batch1`` are concatenated with ``gen_batch2``'s.
          A key that is missing from ``gen_batch2`` or cannot be concatenated is
          skipped with a printed warning.
        * ``meta_info`` of the result holds only ``uid`` (concatenated, when both
          inputs have it) and ``idx2uid``.

        Raises:
            KeyError: if exactly one of the inputs has ``meta_info["uid"]``.
        """
        fixed_width_keys = ("responses", "rollout_log_probs")
        token_keys = ("prompts", "input_ids")

        consolidated_gen_batch_output = {}
        for key in gen_batch1.batch.keys():
            value1 = gen_batch1.batch[key]
            value2 = gen_batch2.batch[key]
            if key in fixed_width_keys:
                dim = value2.shape[-1]
                assert value1.shape[-1] == dim, f"hidden_dim mismatch: {value1.shape[-1]} != {dim} for key: {key}"
            else:
                max_seq_len = max(value1.shape[-1], value2.shape[-1])
                pad_value = self.tokenizer.pad_token_id if key in token_keys else 0
                value1 = pad_sequence_to_length(value1, max_seq_len=max_seq_len, pad_token_id=pad_value, left_pad=True)
                value2 = pad_sequence_to_length(value2, max_seq_len=max_seq_len, pad_token_id=pad_value, left_pad=True)
            consolidated_gen_batch_output[key] = torch.cat([value1, value2], dim=0)

        consolidated_gen_batch_output = DataProto.from_single_dict(consolidated_gen_batch_output)

        # Add non tensor keys to consolidated_gen_batch_output
        for key in gen_batch1.non_tensor_batch.keys():
            try:
                consolidated_gen_batch_output.non_tensor_batch[key] = np.concatenate([
                    gen_batch1.non_tensor_batch[key], gen_batch2.non_tensor_batch[key]
                ])
            except (KeyError, ValueError) as e:
                # Ideally should never happen
                print(f"Failed to concatenate non-tensor key '{key}': {e}")

        # Merge metadata
        meta1, meta2 = gen_batch1.meta_info, gen_batch2.meta_info
        has_uid1, has_uid2 = "uid" in meta1, "uid" in meta2
        if has_uid1 != has_uid2:
            raise KeyError("'uid' is present in the meta_info of only one of the batches being merged")
        if has_uid1:
            consolidated_gen_batch_output.meta_info["uid"] = np.concatenate([meta1["uid"], meta2["uid"]])

        idx2uid = meta1.get("idx2uid", meta2.get("idx2uid"))
        if idx2uid is not None:
            consolidated_gen_batch_output.meta_info["idx2uid"] = idx2uid
        return consolidated_gen_batch_output

    @staticmethod
    def _statistical_allocation_with_exploration(R_list, B, conf=0.95, lam=0.1):
        """Greedily distribute a rollout budget over prompts.

        Prompt ``i`` has ``G_i`` observed rewards ``R_list[i]`` with sample std
        ``S_i`` (ddof=1). Budget is handed out one unit at a time to the prompt
        with the highest priority

            delta_i + lam * sqrt(log(1 + max(T, 1)) / G_i)

        where ``delta_i`` is how much the two-sided Student-t confidence
        half-width (level ``conf``) shrinks when ``G_i`` grows by one at fixed
        ``S_i``, and ``T`` is the number of units handed out so far. Prompts
        with fewer than two rewards (or a non-finite std) have ``delta_i = inf``.
        ``S_i`` is not updated during allocation; ``G_i`` is. Ties are broken
        toward the prompt with the fewest samples (then the lowest index), so
        several infinite-priority prompts are served round-robin.

        Args:
            R_list: one sequence of rewards per prompt.
            B: budget; units are handed out while ``B > 0`` (a fractional
               budget therefore allocates ``ceil(B)`` units).
            conf: confidence level of the t interval.
            lam: weight of the exploration bonus.

        Returns:
            np.ndarray[int]: extra rollouts per prompt (length ``len(R_list)``).
        """
        rewards = [np.asarray(r, float) for r in R_list]
        num_prompts = len(rewards)
        alloc = np.zeros(num_prompts, dtype=int)
        if num_prompts == 0:
            return alloc

        counts = np.array([len(r) for r in rewards], dtype=float)
        stds = np.array([r.std(ddof=1) if len(r) >= 2 else np.nan for r in rewards], dtype=float)

        def half_width(s, n):
            if n <= 1 or not np.isfinite(s):
                return float("inf")
            t_mult = t.ppf(1 - (1 - conf) / 2, df=n - 1)
            return float(t_mult * s / np.sqrt(n))

        def next_sample_drop(i):
            s, n = stds[i], counts[i]
            if n <= 1 or not np.isfinite(s):
                return float("inf")
            return half_width(s, n) - half_width(s, n + 1)

        # Only the prompt that just received a unit changes its delta, so the
        # deltas are cached and refreshed one at a time.
        drops = np.array([next_sample_drop(i) for i in range(num_prompts)], dtype=float)

        spent = 0
        while B > 0:
            explore = np.sqrt(max(0.0, np.log(1 + max(spent, 1))) / np.maximum(counts, 1e-12))
            priority = drops + lam * explore

            tied = np.flatnonzero(priority == priority.max())
            if tied.size == 0:  # only if priorities contain NaN
                tied = np.arange(num_prompts)
            i_star = int(tied[np.argmin(counts[tied])])

            alloc[i_star] += 1
            counts[i_star] += 1.0
            drops[i_star] = next_sample_drop(i_star)

            B -= 1
            spent += 1

        return alloc

    # ICL helper functions
    def _generate_icl_batch(self, gen_batch, exploration_rollout_indices):
        """Build a prompt batch of ICL prompts, one row per exploration index.

        The user prompt of each selected row is recovered from its
        ``input_ids``. It is wrapped in an ICL prompt (``_prepare_icl_input_ids``)
        and the non-tensor fields of the selected rows are copied over.
        """
        icl_input_ids = []
        for idx in exploration_rollout_indices:
            messages = self._inverse_chat_template_qwen(gen_batch.batch["input_ids"][idx].detach().cpu().tolist())
            icl_input_ids.append(self._prepare_icl_input_ids(messages))

        gen_batch_exploration = self._create_batch_from_icl_input_ids(icl_input_ids)

        # Fancy indexing keeps each field's dtype and 1-D shape, which rebuilding
        # with np.array([...]) does not guarantee for list/dict/string values.
        row_indices = np.asarray(exploration_rollout_indices, dtype=int)
        for key, values in gen_batch.non_tensor_batch.items():
            if key == "raw_prompt_ids":
                # Keep the ids of the ICL prompts, not those of the original prompts.
                continue
            gen_batch_exploration.non_tensor_batch[key] = values[row_indices]

        return gen_batch_exploration

    def _update_icl_corpus(self, gen_batch_output, batch):
        """Append reward-1 responses of non-ICL rollouts to ``self.icl_corpus``.

        With ``data.enable_thinking``, only the text after the last ``</think>``
        is stored, and responses without ``</think>`` are skipped. The corpus is
        then dumped via ``_dump_icl_corpus``.
        """
        scores = self._merge_and_score_rollouts(gen_batch_output, batch)
        is_icl_rollout = gen_batch_output.non_tensor_batch["is_icl_rollout"]
        for i, (score, icl_flag) in enumerate(zip(scores, is_icl_rollout)):
            if score != 1 or icl_flag:
                continue
            user_prompt = self._inverse_chat_template_qwen(
                gen_batch_output.batch["prompts"][i].detach().cpu().tolist(),
            )["user_prompt"]
            response = self.tokenizer.decode(
                gen_batch_output.batch["responses"][i], skip_special_tokens=True,
            )
            if self.config.data.enable_thinking:
                if "</think>" not in response:
                    continue
                response = response.split("</think>")[-1].strip()
            self.icl_corpus.setdefault(user_prompt, []).append(response)

        self._dump_icl_corpus()

    def _truncate_responses(self, icl_prompts, responses, user_prompt):
        """Token-truncate ICL responses so the ICL prompt fits ``max_prompt_length``.

        The budget is 95% of ``data.max_prompt_length``, minus the token lengths
        of the example questions and the user prompt, split evenly over
        ``icl_samples_per_prompt`` responses. A negative budget truncates every
        response to an empty string.
        """
        icl_prompts_len = sum(len(self.tokenizer.encode(p, add_special_tokens=False)) for p in icl_prompts)
        user_prompt_len = len(self.tokenizer.encode(user_prompt, add_special_tokens=False))

        remaining = self.config.data.max_prompt_length * 0.95 - user_prompt_len - icl_prompts_len
        # Clamp at 0: a negative slice bound would drop tokens from the end
        # instead of truncating.
        response_len = max(0, int(remaining / self.icl_samples_per_prompt))
        return [
            self.tokenizer.decode(
                self.tokenizer.encode(response, add_special_tokens=False)[:response_len], skip_special_tokens=True,
            )
            for response in responses
        ]

    def _apply_icl_template(self, icl_questions, responses, user_prompt):
        """
        Construct an XML-style in-context learning (ICL) template
        using solved examples.

        Known math/coding instruction suffixes and the coding system prefix are
        stripped from the example questions. The result is wrapped in the
        tokenizer's chat template as a single user turn (untokenized string).
        """
        remove_suffix_math = "\n\nLet's think step by step and output the final answer within \\boxed{}."
        remove_suffix_coding = "Let's think step by step within <think> </think> tags followed by detailed steps and final code using the provided format with backticks.\n\n"
        remove_prefix_coding = "You are an expert Python programmer. You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests."
        lines = []
        lines.append("<task>")
        lines.append("  You are given several worked examples, each with a <problem> and a <solution>.")
        lines.append("  Extract a general strategy, then think through the new problem")
        lines.append("  and finally provide the detailed solution.")
        lines.append("</task>")
        lines.append("")

        lines.append("<examples>")
        for idx, (icl_question, response) in enumerate(zip(icl_questions, responses), start=1):

            # Remove suffix and prefix
            icl_question = icl_question.replace(remove_suffix_math, '')
            icl_question = icl_question.replace(remove_suffix_coding, '')
            icl_question = icl_question.replace(remove_prefix_coding, '')
            icl_question = icl_question.strip()

            lines.append(f"  <example id='{idx}'>")
            lines.append(f"    <problem>{icl_question}</problem>")
            lines.append(f"    <solution>{response}</solution>")
            lines.append("  </example>")
            lines.append("")
        lines.append("</examples>")
        lines.append("")
        lines.append(f"<new_problem>{user_prompt}</new_problem>")
        lines.append("")

        icl_prompt = "\n".join(lines)

        icl_prompt = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": icl_prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=self.config.data.enable_thinking
        )

        return icl_prompt

    def _dump_icl_corpus(self):
        """Write ``self.icl_corpus`` to ``icl_corpus_{global_step}.json``.

        The file is written only when ``global_step`` is set and
        ``trainer.save_freq`` is positive and divides it. A non-positive
        ``save_freq`` means checkpointing is disabled. The output directory is
        always created, since the log helpers write there as well.
        """
        os.makedirs(self.config.trainer.default_local_dir, exist_ok=True)
        global_step = getattr(self, "global_step", None)
        save_freq = self.config.trainer.save_freq
        if global_step is None or save_freq <= 0 or global_step % save_freq != 0:
            return
        path = os.path.join(self.config.trainer.default_local_dir, f"icl_corpus_{global_step}.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.icl_corpus, f)

    def _inverse_chat_template_qwen(self, chat_output):
        """Extract the system and user turns from a Qwen chat-formatted prompt.

        Args:
            chat_output: token ids (list/tuple, decoded with special tokens kept)
                or the already-decoded text.

        Returns:
            dict: ``user_prompt`` (first user turn, stripped; ``""`` if absent)
            and, when present, ``system_prompt``.
        """
        if isinstance(chat_output, (list, tuple)):
            text = self.tokenizer.decode(chat_output, skip_special_tokens=False)
        else:
            text = chat_output
        return_dict = {}
        system_prompt = self._SYSTEM_PROMPT_RE.findall(text)
        if system_prompt:
            return_dict["system_prompt"] = system_prompt[0].strip()
        user_prompt = self._USER_PROMPT_RE.findall(text)
        return_dict["user_prompt"] = user_prompt[0].strip() if user_prompt else ""
        return return_dict

    def _get_icl_prompt(self, user_prompt):
        """Build the chat-formatted (untokenized) ICL prompt for ``user_prompt``.

        Similar training questions are visited in random order. The first
        ``icl_samples_per_prompt`` questions that have a stored response are
        used, each paired with one random stored response. Without any usable
        example, the plain user prompt is returned in the chat template.
        """
        similar_questions = self.train_similar_questions.get(user_prompt, [])
        icl_questions = []
        chosen_responses = []
        # permutation(n) draws the same random order as choice(n, n, replace=False).
        for question_idx in np.random.permutation(len(similar_questions)):
            question = similar_questions[question_idx]
            corpus_responses = self.icl_corpus.get(question, [])
            if not corpus_responses:
                continue
            # Index draw instead of np.random.choice(list) avoids converting the
            # whole response list to a numpy string array on every call.
            chosen_responses.append(corpus_responses[np.random.randint(len(corpus_responses))])
            icl_questions.append(question)
            if len(icl_questions) >= self.icl_samples_per_prompt:
                break

        if icl_questions:
            responses = self._truncate_responses(icl_questions, chosen_responses, user_prompt)
            return self._apply_icl_template(icl_questions, responses, user_prompt)
        return self.tokenizer.apply_chat_template(
            [{"role": "user", "content": user_prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=self.config.data.enable_thinking
        )

    def _log_icl_rollouts(self, gen_batch_output):
        """Every 20th step, append the first ICL rollout's decoded input_ids to icl_rollouts.txt.

        Best-effort: write failures are printed, not raised. Skipped when there
        is no ICL rollout or ``global_step`` is not set.
        """
        log_dir = self.config.trainer.default_local_dir
        log_file = os.path.join(log_dir, "icl_rollouts.txt")
        selected_indices = np.where(gen_batch_output.non_tensor_batch["is_icl_rollout"] == 1)[0]
        global_step = getattr(self, "global_step", None)
        if len(selected_indices) == 0 or global_step is None or global_step % 20 != 0:
            return
        try:
            os.makedirs(log_dir, exist_ok=True)
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"Step {global_step}\n")
                f.write(f"{self.tokenizer.decode(gen_batch_output.batch['input_ids'][selected_indices[0]], skip_special_tokens=True)}\n")
                f.write("=================================\n\n")
        except Exception as e:
            print(f"Warning: Could not write to log file {log_file}: {e}")

    def _log_budget_allocation(self, round_idx, std_list, certain_list):
        """
        Log budget allocation details to file.

        Appends one line per prompt (reward std and extra rollouts this round)
        to budget_allocation.txt. Best-effort: failures are printed, not raised.
        """
        log_dir = self.config.trainer.default_local_dir
        log_file = os.path.join(log_dir, "budget_allocation.txt")
        try:
            os.makedirs(log_dir, exist_ok=True)
            with open(log_file, "a", encoding="utf-8") as f:
                f.write("=== Budget Allocation ===\n")
                f.write(f"Round {round_idx}\n")
                for i, n in enumerate(certain_list.tolist()):
                    f.write(f"Prompt {i:03d}: std = {std_list[i]} extra_rollout = {n}\n")
                f.write(f"Total rollout allocated: {certain_list.sum().item()} / {self.total_dynamic_budget_allocation}\n")
                f.write("=================================\n\n")
        except Exception as e:
            print(f"Warning: Could not write to log file {log_file}: {e}")

    def _create_batch_from_icl_input_ids(self, icl_input_ids):
        """
        Create a DataProto batch for each ICL sample.

        Every sample is left-padded to the longest sample via
        ``_postprocess_data``. The unpadded ids are stored as
        ``non_tensor_batch["raw_prompt_ids"]`` (1-D object array of lists).

        Args:
            icl_input_ids: Unpadded input_ids for each ICL sample. List[torch.Tensor]
        Returns:
            DataProto
        """
        max_prompt_length = max(len(input_ids) for input_ids in icl_input_ids)
        batch_dict = {'input_ids': [], 'attention_mask': [], 'position_ids': []}
        raw_prompt_ids = []
        for ids_i in icl_input_ids:
            input_ids, attention_mask, position_ids = self._postprocess_data(
                unpadded_input_ids=ids_i.unsqueeze(0),
                max_length=max_prompt_length,
                left_pad=True
            )
            raw_prompt_ids.append(ids_i.tolist())

            batch_dict['input_ids'].append(input_ids[0])
            batch_dict['attention_mask'].append(attention_mask[0])
            batch_dict['position_ids'].append(position_ids[0])

        # Stack tensors
        for key in batch_dict:
            batch_dict[key] = torch.vstack(batch_dict[key])

        # Create DataProto
        batch_dict = DataProto.from_single_dict(batch_dict)

        # np.array(list_of_equal_length_lists, dtype=object) would yield a 2-D
        # array; fill a 1-D object array element by element instead.
        raw_prompt_ids_arr = np.empty(len(raw_prompt_ids), dtype=object)
        for i, ids in enumerate(raw_prompt_ids):
            raw_prompt_ids_arr[i] = ids
        batch_dict.non_tensor_batch["raw_prompt_ids"] = raw_prompt_ids_arr

        # Set generation parameters
        batch_dict.meta_info["eos_token_id"] = self.tokenizer.eos_token_id
        return batch_dict

    def _prepare_icl_input_ids(self, response_dict):
        """Tokenize the ICL prompt for ``response_dict["user_prompt"]`` (1-D tensor, no special tokens added)."""
        user_prompt = response_dict["user_prompt"]
        icl_prompt = self._get_icl_prompt(user_prompt)
        icl_input_ids = self.tokenizer.encode(
            icl_prompt,
            add_special_tokens=False,
            return_tensors="pt",
        )[0]
        return icl_input_ids

    @staticmethod
    def get_policy_name():
        """Registry name of this rollout policy."""
        return "explore_exploit_rollout_policy"
