# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import copy
import logging
import os
import re
import uuid
import json
import ast
from dotenv import load_dotenv
from collections import defaultdict, Counter
from typing import List, Optional, Union, Dict, Any, Callable, Tuple
from together import Together
import ray
import base64
import zlib
import pickle
from retry import retry
import random

import datasets
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

import verl.utils.torch_functional as verl_F
from verl.utils.dataset.rl_dataset import RLHFDataset
from datasets import load_dataset
from verl.utils.dataset.dynamic_prompts import PROPOSER_SYSTEM_PROMPT_F, PROPOSER_SYSTEM_PROMPT_STDIN, INITIAL_PROPOSER_F, INITIAL_PROPOSER_STDIN, NEXT_PROPOSER_F, NEXT_PROPOSER_STDIN, SOLVER_F, SOLVER_STDIN, FINAL_SOLVER, PROPOSER_CONTEXT_LIST_SUBGOALS_MEAN_REWARD, PROPOSER_CONTEXT_SYNTHESIS_SUBGOALS_MEAN_REWARD_ROLLOUTS, SYNTHESIS_SYSTEM_PROMPT, SYNTHESIS_SUMMARY_PROMPT
from verl.utils.dataset.compute_outputs import compute_outputs

logger = logging.getLogger(__name__)

"""
The DynamicDataset class is not a subclass of the RLHFDataset class, and it implements all methods itself.

The possible options for question generation are:
- frozen_proposer_api: A frozen proposer from the Together API.

The possible options for aggregation are:
- all_subgoals_human_reward: Aggregate all rewards for the subgoal (requires num. parallel subgoals = 1), and passed to the proposer in human-readable categories.

The possible options for dataset are:
- livecodebench_v6: The LiveCodeBench v6 dataset.
"""

def has_quotes(s):
    """Return True when ``s`` both starts and ends with the same quote character.

    Double quotes (``"``) and single quotes (``'``) are recognised. The opening
    and closing characters must be of the same kind.
    """
    for quote_char in ('"', "'"):
        if s.startswith(quote_char) and s.endswith(quote_char):
            return True
    return False

def extract_code(responses: List[str]) -> List[str]:
    """Pull the last ```python fenced block out of every response.

    Args:
        responses: Raw model responses.

    Returns:
        One entry per response, in order: the text between the last
        "```python" marker and the next "```". When a response has no
        "```python" marker (or is not a string at all) the entry is the
        fixed message "Did not wrap code in ```python``` tags.".
    """
    fence = "```python"
    fallback = "Did not wrap code in ```python``` tags."

    extracted_responses = []
    for response in responses:
        # str.split never raises, so a missing fence has to be detected
        # explicitly; otherwise the whole un-fenced response would be returned
        # as if it were code.
        if not isinstance(response, str) or fence not in response:
            extracted_responses.append(fallback)
            continue
        extracted_responses.append(response.split(fence)[-1].split("```")[0])

    return extracted_responses

@retry(AssertionError, tries=5, delay=2)
def query_together_formatted(client: Together, messages: List[Dict[str, str]], model: str, max_tokens: int = 2048, n: int = 1, temperature: float = 1.0, top_k: float = 1.0) -> Tuple[str, str]:
    """Run one chat completion and extract the ``<answer>`` section of the first choice.

    Args:
        client: Together client used to issue the request.
        messages: Chat messages forwarded as-is.
        model: Model identifier forwarded as-is.
        max_tokens: Generation budget forwarded as-is.
        n: Number of choices requested; only the first choice is read.
        temperature: Sampling temperature forwarded as-is.
        top_k: Top-k value forwarded as-is.

    Returns:
        ``(formatted, raw)`` where ``raw`` is the full text of the first choice
        and ``formatted`` is the text after the last ``<answer>`` tag up to the
        following ``</answer>`` tag, truncated to 1300 characters. When the
        tags are absent the (truncated) full text is used.

    Note:
        The decorator only retries on ``AssertionError``; this body does not
        assert anything itself, so other failures (e.g. network errors)
        propagate on the first attempt.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        n=n,
        temperature=temperature,
        top_k=top_k
    )
    raw_text = completion.choices[0].message.content
    # Keep only the answer section and cap it so it fits in downstream prompts.
    answer_text = raw_text.split("<answer>")[-1].split("</answer>")[0][:1300]
    return str(answer_text), str(raw_text)

@retry(tries=5, delay=2)
def query_together(client: Together, messages: List[Dict[str, str]], model: str, max_tokens: int = 2048, n: int = 1, temperature: float = 1.0, top_k: float = 1.0, top_p: float = 1.0) -> Any:
    """Issue a chat completion request and return the raw response object.

    Args:
        client: Together client used to issue the request.
        messages: Chat messages forwarded as-is.
        model: Model identifier forwarded as-is.
        max_tokens: Generation budget forwarded as-is.
        n: Number of choices requested.
        temperature: Sampling temperature forwarded as-is.
        top_k: Top-k value forwarded as-is.
        top_p: Top-p value forwarded as-is.

    Returns:
        The unmodified object returned by ``client.chat.completions.create``;
        callers in this file read ``.choices[i].message.content`` from it.

    Note:
        The whole call is wrapped by the ``retry`` decorator (5 tries, 2 s
        delay), which here is not restricted to a specific exception type.
        NOTE(review): every call site in this file passes ``top_k=1.0``;
        confirm against the Together API that this does not collapse the
        ``n`` samples into identical outputs.
    """
    return client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        n=n,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

class FrozenProposerAPI:
    """
    Generates subgoal questions with a frozen proposer model served through the Together API.

    The proposer keeps no curriculum state of its own: the history it conditions on is
    handed in by the caller on every call (state lives in the DynamicDataset class).
    """

    def __init__(self, config: DictConfig, proposer_model: str, testtype: str):
        """
        Read the proposer settings from ``config.dynamic`` and create the Together client.

        Args:
            config: Dataset config; ``config.dynamic`` supplies ``num_test_cases``,
                ``most_k_recent_subgoals`` and ``proposer_context_method``.
            proposer_model: Model identifier passed to the Together API.
            testtype: ``"functional"`` or ``"stdin"``; selects the prompt templates.

        Raises:
            ValueError: If ``TOGETHER_API_KEY`` is not set (after loading ``.env``).
        """
        self.config = config
        self.proposer_model = proposer_model
        self.num_inputs = self.config.dynamic.get("num_test_cases", 3)
        self.most_k_recent_subgoals = self.config.dynamic.get("most_k_recent_subgoals", 3)
        self.proposer_context_method = self.config.dynamic.get("proposer_context_method", "list_subgoals_rewards")
        self.testtype = testtype

        load_dotenv()

        # Load Together API key from environment
        self.api_key = os.getenv("TOGETHER_API_KEY")
        if not self.api_key:
            raise ValueError("TOGETHER_API_KEY not found in environment variables")

        # Initialize Together client
        self.client = Together(api_key=self.api_key)

    def _create_proposer_context(self, historical_data: Dict[str, Any], method: str = "list_subgoals_rewards") -> Tuple[str, Optional[Dict[str, Any]]]:
        """
        Build the context string shown to the proposer from the most recent subgoals.

        Args:
            historical_data: Dict with ``"final_question"`` and ``"subgoals"`` (uid -> record);
                each record is read for ``step``, ``subgoal`` and ``mean_reward`` (and, for the
                synthesis method, ``rollouts``, ``ground_truth`` and ``computed_outputs``).
            method: ``"list_subgoals_rewards"`` or ``"llm_synthesis_subgoals_rewards_rollouts"``.

        Returns:
            ``(context, logs)``. ``logs`` is ``None`` for ``"list_subgoals_rewards"`` and a
            dict describing the synthesis call otherwise.

        Raises:
            ValueError: If ``method`` is not one of the registered methods.
        """
        registered_methods = ["list_subgoals_rewards", "llm_synthesis_subgoals_rewards_rollouts"]
        if method not in registered_methods:
            raise ValueError(f"Method {method} not found in registered methods: {registered_methods}")

        # Put the latest subgoals first
        prev_subgoals_sorted = sorted(historical_data["subgoals"].values(), key=lambda x: x["step"], reverse=True)
        recent_k_prev_subgoals = prev_subgoals_sorted[:self.most_k_recent_subgoals]
        prev_subgoals = str([(subgoal["subgoal"], subgoal["mean_reward"]) for subgoal in recent_k_prev_subgoals])
        print(f"Previous subgoals: {prev_subgoals}")

        if method == "list_subgoals_rewards":
            print(f"Using list_subgoals_rewards method")
            return PROPOSER_CONTEXT_LIST_SUBGOALS_MEAN_REWARD.format(subgoals=prev_subgoals), None

        # method == "llm_synthesis_subgoals_rewards_rollouts"
        print(f"Using llm_synthesis_subgoals_rewards_rollouts method")

        def bin_reward(mean_reward):
            # Map the mean reward onto the coarse difficulty labels used in the prompt.
            if mean_reward < 0.2:
                return "TOO HARD"
            if mean_reward < 0.8:
                return "RIGHT DIFFICULTY"
            return "TOO EASY"

        prev_subgoals = str([
            (
                subgoal["subgoal"],
                bin_reward(subgoal["mean_reward"]),
                extract_code(subgoal["rollouts"]),
                subgoal["ground_truth"],
                subgoal["computed_outputs"],
            )
            for subgoal in recent_k_prev_subgoals
        ])
        just_subgoals = str([(subgoal["subgoal"], bin_reward(subgoal["mean_reward"])) for subgoal in recent_k_prev_subgoals])

        messages = [
            {"role": "system", "content": SYNTHESIS_SYSTEM_PROMPT},
            {"role": "user", "content": SYNTHESIS_SUMMARY_PROMPT.format(subgoals=prev_subgoals, goal=historical_data["final_question"])}
        ]
        synthesis_summary_extracted, synthesis_summary = query_together_formatted(
            client=self.client,
            messages=messages,
            model=self.proposer_model,
            max_tokens=16384,
            n=1,
            temperature=1.0,
            top_k=1.0,
        )
        print(f"Synthesis summary: {synthesis_summary_extracted}")

        # Can't have ndarrays when logging to json after.
        # NOTE(review): this rewrites the records inside ``historical_data`` in place, and
        # ``list(x)`` would split a plain string into characters - confirm the element type
        # of ``computed_outputs`` with the caller.
        for subgoal in recent_k_prev_subgoals:
            subgoal["computed_outputs"] = [list(x) for x in subgoal["computed_outputs"]]

        to_log = {
            "subgoals_passed_to_proposer": list(recent_k_prev_subgoals),
            "subgoals_passed_to_proposer_string": prev_subgoals,
            "input": messages,
            "synthesis_summary": synthesis_summary,
            "synthesis_summary_extracted": synthesis_summary_extracted,
        }

        return PROPOSER_CONTEXT_SYNTHESIS_SUBGOALS_MEAN_REWARD_ROLLOUTS.format(synthesis=synthesis_summary_extracted, subgoals=just_subgoals), to_log

    def _build_proposer_prompts(self, stage: str, final_question: str, proposer_context: Optional[str] = None) -> List[Dict[str, str]]:
        """Return the system + user messages for the ``"initial"`` or ``"next"`` proposal stage."""
        if self.testtype not in ("functional", "stdin"):
            raise ValueError(f"Test type {self.testtype} not supported")

        print(f"Using {self.testtype} test type to generate {stage} subgoals")

        if self.testtype == "functional":
            system_template = PROPOSER_SYSTEM_PROMPT_F
            user_template = INITIAL_PROPOSER_F if stage == "initial" else NEXT_PROPOSER_F
        else:
            system_template = PROPOSER_SYSTEM_PROMPT_STDIN
            user_template = INITIAL_PROPOSER_STDIN if stage == "initial" else NEXT_PROPOSER_STDIN

        if stage == "initial":
            user_content = user_template.format(endproblem=final_question, num_inputs=self.num_inputs)
        else:
            user_content = user_template.format(endproblem=final_question, context=proposer_context, num_inputs=self.num_inputs)

        return [
            {"role": "system", "content": system_template.format(num_inputs=self.num_inputs)},
            {"role": "user", "content": user_content},
        ]

    @staticmethod
    def _extract_inputs(content: str, num_inputs: int) -> List[Any]:
        """Parse the last ``num_inputs`` fenced ```input blocks of one proposer response."""
        segments = content.split("```input")
        # segments[0] is the text *before* the first ```input fence, so it is never an
        # input. Clamping at 1 also prevents negative indices from wrapping around when
        # the response has fewer fences than requested.
        first_index = max(1, len(segments) - num_inputs)

        inputs = []
        for segment in segments[first_index:]:
            base_input = segment.split("```")[0].replace("\\n", "\n").strip()
            try:
                inputs.append(ast.literal_eval(base_input))
            except Exception:
                # Deliberately best-effort: anything that is not a Python literal is
                # kept as a quoted string.
                inputs.append(f"'{base_input}'")
        return inputs

    def _propose_from_prompts(self, proposer_prompts: List[Dict[str, str]], num_subgoals: int, max_tokens: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Query the proposer, parse each choice, run its program and format the subgoals."""
        proposer_output = query_together(
            client=self.client,
            messages=proposer_prompts,
            model=self.proposer_model,
            max_tokens=max_tokens,
            temperature=1.0,
            top_k=1.0,
            top_p=1.0,
            n=num_subgoals
        )

        proposer_outputs = [proposer_output.choices[i].message.content for i in range(num_subgoals)]

        # For each fenced section the *last* occurrence in the response wins.
        extracted_subgoals = [text.split("```subgoal")[-1].split("```")[0][:700] for text in proposer_outputs]
        extracted_programs = [text.split("```python")[-1].split("```")[0] for text in proposer_outputs]
        extracted_inputs = [self._extract_inputs(text, int(self.num_inputs)) for text in proposer_outputs]

        computed_outputs = [
            compute_outputs(code=program, extracted_inputs=test_cases, testtype=self.testtype)
            for test_cases, program in zip(extracted_inputs, extracted_programs)
        ]

        subgoals = [
            self._format_question(index, subgoal_text, test_inputs, test_outputs)
            for index, (subgoal_text, test_inputs, test_outputs) in enumerate(zip(extracted_subgoals, extracted_inputs, computed_outputs))
        ]

        records = [
            {
                "input": proposer_prompts,
                "output_step1": proposer_outputs[i],
                "extracted_subgoal": subgoals[i],
                "extracted_program": extracted_programs[i],
                "extracted_inputs": extracted_inputs[i],
                "computed_outputs": computed_outputs[i],
            }
            for i in range(num_subgoals)
        ]
        return subgoals, records

    def propose_initial_subgoals(self, final_question: str, num_subgoals: int, max_tokens: int = 2000) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Generate initial subgoals using the Together API.

        Args:
            final_question: The final question to solve
            num_subgoals: Number of subgoals to generate
            max_tokens: Maximum number of tokens to generate

        Returns:
            ``(subgoals, logs)``: the formatted subgoal dictionaries and one log record
            per subgoal (prompts, raw output and everything extracted from it).

        Raises:
            ValueError: If ``self.testtype`` is neither ``"functional"`` nor ``"stdin"``.
        """
        proposer_prompts = self._build_proposer_prompts(stage="initial", final_question=final_question)
        return self._propose_from_prompts(proposer_prompts, num_subgoals, max_tokens)

    def propose_subgoals(self, num_subgoals: int, historical_data: Dict[str, Any], proposer_context: str, proposer_context_logs: Dict[str, Any], max_tokens: int = 2000) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Generate new subgoals using the Together API with dynamic prompts.

        Args:
            num_subgoals: Number of subgoals to generate
            historical_data: The historical data to use for the next subgoal; only
                ``historical_data["final_question"]`` is read here
            proposer_context: Context string inserted into the proposer prompt
            proposer_context_logs: Attached to every log record under ``"proposer_context"``
            max_tokens: Maximum number of tokens to generate

        Returns:
            ``(subgoals, logs)``: the formatted subgoal dictionaries and one log record
            per subgoal.

        Raises:
            ValueError: If ``self.testtype`` is neither ``"functional"`` nor ``"stdin"``.
        """
        final_question = historical_data["final_question"]
        # TODO: Modify this to aggregate subgoals and rewards according to different mechanisms

        proposer_prompts = self._build_proposer_prompts(stage="next", final_question=final_question, proposer_context=proposer_context)
        subgoals, records = self._propose_from_prompts(proposer_prompts, num_subgoals, max_tokens)

        return subgoals, [{"proposer_context": proposer_context_logs, **record} for record in records]

    def _format_question(self, index: int, extracted_subgoal: str, extracted_input: List[str], computed_output: List[str], generated_at_step: int = -1, question_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Format a generated question into the expected dataset format.

        The first input/output pair becomes the example shown in the solver prompt; the
        remaining pairs become the ground-truth test cases.

        Args:
            index: Index of the question
            extracted_subgoal: Subgoal text extracted from the proposer response
            extracted_input: Test inputs; replaced by "No input provided" placeholders when empty
            computed_output: Outputs for the inputs; replaced by "No output provided"
                placeholders when empty
            generated_at_step: Stored in ``extra_info``
            question_id: Stored in ``extra_info``; a fresh UUID is generated when None

        Returns:
            Formatted question dictionary

        Raises:
            ValueError: If ``self.testtype`` is neither ``"functional"`` nor ``"stdin"``.
        """

        if len(extracted_input) == 0:
            extracted_input = ["No input provided"]*self.num_inputs

        if len(computed_output) == 0:
            computed_output = ["No output provided"]*len(extracted_input)

        raw_subgoal = extracted_subgoal[:self.config.max_prompt_length]

        if self.testtype == "functional":
            if isinstance(extracted_input[0], str) and ("\n" in extracted_input[0]):
                # Show multi-line inputs as a comma separated argument list.
                example_input = extracted_input[0].replace("\\n", "\n").replace("\n", ", ")
            else:
                example_input = extracted_input[0]
            solver_prompt = SOLVER_F.format(subproblem=raw_subgoal, max_tokens=self.config.max_response_length, num_inputs=self.num_inputs, example_input=example_input, example_output=computed_output[0])
        elif self.testtype == "stdin":
            solver_prompt = SOLVER_STDIN.format(subproblem=raw_subgoal, max_tokens=self.config.max_response_length, num_inputs=self.num_inputs, example_input=extracted_input[0], example_output=computed_output[0])
        else:
            raise ValueError(f"Test type {self.testtype} not supported")

        return {
            "prompt": [{"role": "user", "content": solver_prompt}],
            "data_source": "dynamic_generated",
            "ability": "coding",
            "reward_model": {
                "style": "model",
                # Index 0 is the example test case, so ground truth starts at index 1.
                "ground_truth": [{"input": extracted_input[i], "output": computed_output[i], "testtype": self.testtype} for i in range(1, len(extracted_input))],
                "example_test_cases": [{"input": extracted_input[0], "output": computed_output[0]}],
            },
            "extra_info": {
                "split": "dynamic",
                "index": index,
                "generated_at_step": generated_at_step,
                "question_id": question_id if question_id is not None else str(uuid.uuid4()),
                "raw_subgoal": raw_subgoal
            }
        }


class DynamicDataset(Dataset):
    """
    A dynamic dataset that can add new questions at runtime.
    
    This dataset has the ability to:
    - Add new questions at every nth step
    - Generate questions with a proposer model
    - Save generated questions for later use

    This dataset does not have the ability to:
    - Support multi-modal data
    
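    Args:
        tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.
        config (DictConfig): Options including dynamic dataset configuration
            (read from the ``dynamic`` key).
        processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos;
            stored as ``self.processor``.

    Note:
        The constructor takes no ``data_files`` argument (no Parquet paths); the final
        question and its test cases are loaded from the dataset named by ``config.dynamic.dataset``.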
    """
    
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        # Initialize dynamic dataset configuration
        self.dynamic_config = config.get("dynamic", {})
        self.config = config
        self.dataset = self.dynamic_config.get("dataset", "livecodebench_v6")
        self.final_question_id = self.dynamic_config.get("final_question_id", "0")
        self.final_question, self.val_test_cases = self._setup_test_dataset(dataset_name=self.dataset, final_question_id=self.final_question_id)
        self.testtype = self.val_test_cases[0]["testtype"]

        self.add_questions_every_n_steps = self.dynamic_config.get("add_questions_every_n_steps", 1)
        self.num_new_questions_per_step = self.dynamic_config.get("num_new_questions_per_step", 1)
        self.num_test_cases = self.dynamic_config.get("num_test_cases", 3)

        self.proposer_context_method = self.dynamic_config.get("proposer_context_method", "list_subgoals_rewards")
        self.proposer_cls = self.dynamic_config.get("question_generator", "frozen_proposer_api")
        self.proposer_model = self.dynamic_config.get("proposer_model", "deepseek-ai/DeepSeek-R1")
        self.proposer_max_tokens = self.dynamic_config.get("proposer_max_tokens", 2000)

        self.serialize_dataset = False
        self.prompt_key = config.get("prompt_key", "prompt")

        # Initialize tracking variables
        self.current_step = 0
        self.current_dataset = []
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_prompt_length = config.get("max_prompt_length", 1024)
        self.return_raw_chat = config.get("return_raw_chat", False)
        self.return_full_prompt = config.get("return_full_prompt", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
        self.need_tools_kwargs = config.get("need_tools_kwargs", False)

        # Contains the final problem, subgoals and rewards
        # subgoals is a list of dictionaries, with the format:
        # {"step": ...,subgoal": "...", "rollouts": [], "rewards": []}
        self.historical_data = {
            "final_question": self.final_question,
            "subgoals": {},
       }

        self.dataset_ref = ray.put(self.current_dataset)
        self.historical_data_ref = ray.put(self.historical_data)
                
        self._initialize_proposer(proposer_cls=self.proposer_cls, proposer_model=self.proposer_model, testtype=self.testtype)
    
    def _translate_private_test_cases(self, encoded_data):
        decoded_data = base64.b64decode(encoded_data)
        decompressed_data = zlib.decompress(decoded_data)
        original_data = pickle.loads(decompressed_data)
        final_data = json.loads(original_data)
        return final_data
    
    def _setup_test_dataset(self, dataset_name: str, final_question_id: str):
        """Get the final problem from the dataset."""

        # Only livecodebench_v6 is implemented for now
        # TODO: Add HumanEval and other datasets
        match dataset_name:
            case "livecodebench_v6":
                lcb_codegen_v6 = load_dataset("livecodebench/code_generation_lite", version_tag="release_v6", trust_remote_code=True)["test"]
                print(final_question_id)
                print(type(final_question_id))
                final_question_id = str(final_question_id)
                row = lcb_codegen_v6.filter(lambda q: q['question_id'] == final_question_id)[0]
                final_question = row["question_content"]
                public_test_cases = json.loads(row["public_test_cases"])
                private_test_cases = self._translate_private_test_cases(row["private_test_cases"])
            case _:
                raise ValueError(f"Dataset {dataset_name} not supported")
            
        #print(f"Final question: {final_question}")
        #print(f"Test cases: {test_cases}")

        return final_question, private_test_cases
    
    def _initialize_proposer(self, proposer_cls: str, proposer_model: str, testtype: str):
        """Initialize the dynamic dataset components."""
        
        match proposer_cls:
            case "frozen_proposer_api":
                self.proposer = FrozenProposerAPI(self.config, self.proposer_model, testtype)
            case _:
                raise ValueError(f"Proposer type {proposer_cls} not supported")
        
        print(f"Dynamic dataset initialized: add_questions_every_n_steps={self.add_questions_every_n_steps}, "
                   f"num_new_questions_per_step={self.num_new_questions_per_step}",
                   f"proposer={proposer_cls}, proposer_model={proposer_model}",
                   f"proposer_context_method={self.proposer_context_method}")
    
    def _validate_subgoal(self, subgoal: dict):
        """
        Check if a subgoal is valid.

        A subgoal is valid if:
        - There is a correct number of test cases (both input and output)
        - There is a subgoal (not empty)
        - There is a valid Python program (but this is implicitly checked by checking for all test cases)
        """

        if len(subgoal["reward_model"]["ground_truth"]) < (self.num_test_cases - 1):
            #print(f"Insufficient number of test cases provided {len(subgoal['reward_model']['ground_truth'])} < {self.num_test_cases - 1}")
            # The first test case is the sample test case given in the prompt, so we need to subtract 1
            # We don't need to check if there is a sample test case because ground_truth is only filled after the sample test case
            # ground_truth having a len > 0 implies that the sample test case was provided
            return False

        for test_case in subgoal["reward_model"]["ground_truth"]:
            if (test_case["input"] == "") or (test_case["input"] == "''") or (test_case["input"] == "No input provided") or (test_case["input"] == None):
                #print(f"No input provided for test case {test_case}")
                return False
            if (test_case["output"] == "") or (test_case["output"] == "''") or (test_case["output"] == "None") or (test_case["output"] == "None\n") or (test_case["output"] == "[TIMEOUT]") or (test_case["output"] == None):
                #print(f"No output provided for test case {test_case}")
                return False
            if "Error" in test_case["output"]:
                #print(f"Error in output for test case {test_case}")
                return False
        
        for test_case in subgoal["reward_model"]["example_test_cases"]:
            if (test_case["input"] == "") or (test_case["input"] == "''") or (test_case["input"] == "No input provided") or (test_case["input"] == None):
                #print(f"No input provided for test case {test_case}")
                return False
            if (test_case["output"] == "") or (test_case["output"] == "''") or (test_case["output"] == "None") or (test_case["output"] == "None\n") or (test_case["output"] == "[TIMEOUT]") or (test_case["output"] == None):
                #print(f"No output provided for test case {test_case}")
                return False
            if "Error" in test_case["output"]:
                #print(f"Error in output for test case {test_case}")
                return False
        
        if subgoal["extra_info"]["raw_subgoal"] == "":
            #print("No raw subgoal provided")
            return False
        
        if "<think>" in subgoal["extra_info"]["raw_subgoal"]:
            return False
        
        return True
    
    def sample_subgoals(self, num_subgoals: int, subgoals_sample_set: List[Dict[str, Any]], sample_method: str = "uniform"):
        if sample_method not in ["uniform"]:
            raise ValueError(f"Sample method {sample_method} not supported")
        
        if sample_method == "uniform":
            return random.sample(subgoals_sample_set, num_subgoals)
    
    def add_historical_subgoals(self, step: int, step_data: Dict[str, Any], logger):
        """
        Add a new subgoal to the historical data.
        step_data is a dictionary of all subgoals in the step.

        Input:
            step_data: Dict[str, Any]
                - subgoal: str
                - rollouts: List[str]
                - ground_truth: List[Dict[str, Any]]
                - computed_outputs: List[str]
                - rewards: List[float]
                - example_inputs: List[str]
                - example_outputs: List[str]
                - generated_at_step: int

        historical_data = {
            "final_question": self.final_question,
            "subgoals": {
                <uid>: {
                    "step": int,
                    "subgoal": str,
                    "rollouts": List[str],
                    "ground_truth": List[Dict[str, Any]],
                    "mean_reward": float,
                    "min_reward": float,
                    "max_reward": float,
                    "example_inputs": List[str],
                    "example_outputs": List[str],
                    "generated_at_step": int,
                },
            },
        }
        """

        # Log per question rewards here as well (only the newly added questions), as well as logging buffer size

        historical_data = ray.get(self.historical_data_ref)

        buffer_size = 0 # Buffer size is just the number of subgoals in step_data that are already in historical_data
        non_buffer_size = 0 # Non-buffer size is just the number of subgoals in step_data that are not in historical_data

        buffer_rewards_mean = []
        buffer_rewards_min = []
        buffer_rewards_max = []
        non_buffer_rewards_mean = []
        non_buffer_rewards_min = []
        non_buffer_rewards_max = []

        # Overwrite previous subgoals with the same uid with the current step's rollouts and rewards
        # historical_data["subgoals"] contains only the most recent rollouts and rewards for each subgoal
        for uid, data in step_data.items():
            if data["ground_truth"][0]["input"] == "[PAD]":
                continue
            
            mean_r = np.mean(data["rewards"])
            min_r = np.min(data["rewards"])
            max_r = np.max(data["rewards"])

            if uid in historical_data["subgoals"]:
                historical_data["subgoals"][uid]["step"] = step
                historical_data["subgoals"][uid]["rollouts"] = data["rollouts"]
                historical_data["subgoals"][uid]["computed_outputs"] = data["computed_outputs"]
                historical_data["subgoals"][uid]["mean_reward"] = mean_r
                historical_data["subgoals"][uid]["min_reward"] = min_r
                historical_data["subgoals"][uid]["max_reward"] = max_r
                
                buffer_size += 1
                buffer_rewards_mean.append(mean_r)
                buffer_rewards_min.append(min_r)
                buffer_rewards_max.append(max_r)
            else:
                historical_data["subgoals"][uid] = {
                    "step": step,
                    "subgoal": data["subgoal"],
                    "rollouts": data["rollouts"],
                    "ground_truth": data["ground_truth"],
                    "computed_outputs": data["computed_outputs"],
                    "mean_reward": mean_r,
                    "min_reward": min_r,
                    "max_reward": max_r,
                    "example_inputs": data["example_inputs"],
                    "example_outputs": data["example_outputs"],
                    "generated_at_step": data["generated_at_step"],
                }

                non_buffer_size += 1
                non_buffer_rewards_mean.append(mean_r)
                non_buffer_rewards_min.append(min_r)
                non_buffer_rewards_max.append(max_r)

        self.historical_data_ref = ray.put(historical_data)
        print(f"historical_data_ref at step {step} in add_historical_subgoals: {self.historical_data_ref}")

        # Log the total buffer size, the mean reward of the buffer (should be higher than the overall reward), mean reward of non-buffer subgoals
        # Log the mean, min, and max reward of each subgoal in the buffer
        # Log the mean, min and max reward of each non-buffer subgoal

        # Handle the case where the buffer is empty
        if buffer_size == 0:
            basics = {
                # Non-buffer rewards
                "non_buffer/size": non_buffer_size,
                "non_buffer/rewards/mean": np.mean(non_buffer_rewards_mean),
                "non_buffer/rewards/min": np.min(non_buffer_rewards_min),
                "non_buffer/rewards/max": np.max(non_buffer_rewards_max),
                
                # Buffer rewards (empty)
                "buffer/buffer_size": 0,
            }

        else:
            basics = {
                # Non-buffer rewards
                "non_buffer/size": non_buffer_size,
                "non_buffer/rewards/mean": np.mean(non_buffer_rewards_mean),
                "non_buffer/rewards/min": np.min(non_buffer_rewards_min),
                "non_buffer/rewards/max": np.max(non_buffer_rewards_max),
                
                # Buffer rewards
                "buffer/buffer_size": buffer_size,
                "buffer/rewards/mean": np.mean(buffer_rewards_mean),
                "buffer/rewards/min": np.min(buffer_rewards_min),
                "buffer/rewards/max": np.max(buffer_rewards_max),
            }
        
        # Add all questions in step_data (want to track per question performance over time)
        per_question_mean = {f"solver_per_question/rewards/{uid}/mean": np.mean(step_data[uid]["rewards"]) for uid in step_data.keys() if step_data[uid]["ground_truth"][0]["input"] != "[PAD]"}
        per_question_min = {f"solver_per_question/rewards/{uid}/min": np.min(step_data[uid]["rewards"]) for uid in step_data.keys() if step_data[uid]["ground_truth"][0]["input"] != "[PAD]"}
        per_question_max = {f"solver_per_question/rewards/{uid}/max": np.max(step_data[uid]["rewards"]) for uid in step_data.keys() if step_data[uid]["ground_truth"][0]["input"] != "[PAD]"}
        data = basics | per_question_mean | per_question_min | per_question_max

        logger.log(data=data, step=step)
    
    def add_new_questions(self, step: int, logger, num_replicas: int = 1):
        """
        Rebuild the current dataset for `step` and, on generation steps, add newly proposed questions.

        The working dataset is reset to an empty list and refilled from up to three sources:

        1. Replay buffer (only when `dynamic_config["use_replay_buffer"]` is set): historical subgoals
           whose mean reward lies in [`replay_buffer_min`, `replay_buffer_max`] and that were generated
           fewer than `replay_buffer_max_steps` steps ago.
        2. Sampled past subgoals (only when `dynamic_config["sample_past_subgoals"]` is set): drawn from
           the historical subgoals that missed the buffer and whose mean reward is neither 0.0 nor 1.0.
           That candidate set is only populated while the replay buffer is enabled.
        3. Newly proposed subgoals, only when `step % self.add_questions_every_n_steps == 0`.

        On generation steps the method additionally pads the dataset to a multiple of
        `config.train_batch_size`, stores it in the Ray object store, dumps the proposer logs to a JSON
        file, scores every non-padding question with a remote model, and logs metrics.

        NOTE: on non-generation steps the method returns early; `self.dataset_ref` is then only refreshed
        if the replay-buffer or sampling branch ran.

        Args:
            step: Current training step.
            logger: Metrics tracker exposing `log(data=..., step=...)`. Inside this method it shadows the
                module-level `logging` logger.
            num_replicas: Number of times the list of newly accepted questions is repeated.

        Returns:
            None
        """

        def _format_historical(uid, subgoal):
            # Re-create a dataset row from a stored subgoal, keeping its original uid.
            formatted = self.proposer._format_question(
                index=0,
                extracted_subgoal=subgoal["subgoal"],
                extracted_input=subgoal["example_inputs"] + [x["input"] for x in subgoal["ground_truth"]],
                computed_output=subgoal["example_outputs"] + [x["output"] for x in subgoal["ground_truth"]],
                question_id=uid,
                generated_at_step=subgoal["generated_at_step"],
            )
            assert formatted["extra_info"]["question_id"] == uid, "Question uid should be the same"
            return formatted

        def _token_lengths(texts):
            # Number of tokens (no special tokens) of each text, as a float tensor.
            return torch.Tensor([
                self.tokenizer.batch_encode_plus([text], return_tensors="pt", add_special_tokens=False)["input_ids"].shape[-1]
                for text in texts
            ])

        past_subgoals_sample_set = []

        # The previous dataset is fetched only to report its size; it is discarded right after.
        self.current_dataset = ray.get(self.dataset_ref)
        print(f"starting current_dataset length: {len(self.current_dataset)}")
        self.current_dataset = []

        print(f"historical_data_ref at step {step} in add_new_questions: {self.historical_data_ref}")
        historical_data = ray.get(self.historical_data_ref)
        self.historical_data_ref = ray.put(historical_data)

        # If using buffer, need to add it to current_dataset
        buffer = []
        if self.dynamic_config.get("use_replay_buffer", False):
            min_threshold = self.dynamic_config.get("replay_buffer_min", 0.1)
            max_threshold = self.dynamic_config.get("replay_buffer_max", 0.8)
            max_age_steps = self.dynamic_config.get("replay_buffer_max_steps", 10)

            for uid, subgoal in historical_data["subgoals"].items():
                mean_reward = subgoal["mean_reward"]
                in_reward_band = (mean_reward >= min_threshold) and (mean_reward <= max_threshold)

                # Add to buffer
                if in_reward_band and (step - subgoal["generated_at_step"] < max_age_steps):
                    buffer.append(_format_historical(uid, subgoal))

                # Create set of all other subgoals we could sample from for added diversity
                elif (mean_reward != 0.0) and (mean_reward != 1.0):
                    past_subgoals_sample_set.append(_format_historical(uid, subgoal))

            self.current_dataset.extend(buffer)
            self.dataset_ref = ray.put(self.current_dataset)
            print(f"Added {len(buffer)} questions from buffer to current_dataset. Total size of current_dataset: {len(self.current_dataset)}")

        generate_new_questions = step % self.add_questions_every_n_steps == 0
        num_questions_left_to_generate = self.num_new_questions_per_step if generate_new_questions else 0

        # Sample past subgoals up to the batch size
        sampled_subgoals = []
        if self.dynamic_config.get("sample_past_subgoals", False):
            num_questions_to_sample = self.config.train_batch_size - len(self.current_dataset) - num_questions_left_to_generate
            if (num_questions_to_sample > 0) and (len(past_subgoals_sample_set) > 0):
                if len(past_subgoals_sample_set) > num_questions_to_sample:
                    # More candidates than free slots: sample a subset
                    sampled_subgoals = self.sample_subgoals(num_subgoals=num_questions_to_sample, subgoals_sample_set=past_subgoals_sample_set, sample_method=self.dynamic_config.get("past_subgoals_sample_method", "uniform"))
                else:
                    # Not more candidates than free slots: take all of them
                    sampled_subgoals = past_subgoals_sample_set
                self.current_dataset.extend(sampled_subgoals)
                self.dataset_ref = ray.put(self.current_dataset)
                print(f"Added {len(sampled_subgoals)} questions from past_subgoals_sample_set to current_dataset. Total size of current_dataset: {len(self.current_dataset)}")

        if not generate_new_questions:
            return

        added_questions = []
        log_output = []
        total_num_questions_generated = 0

        # Retrieve context on previous subgoals to inform the proposer on what the next subgoal should be.
        # The context is built for every step that takes the non-initial branch below (step != 1), so that
        # branch can never hit an unbound `proposer_context`.
        proposer_context, proposer_context_logs = None, None
        if step != 1:
            proposer_context, proposer_context_logs = self.proposer._create_proposer_context(historical_data=historical_data, method=self.proposer_context_method)

        # Keep asking the proposer until enough generated subgoals pass validation
        while num_questions_left_to_generate > 0:
            print(f"Starting to generate {num_questions_left_to_generate} questions at step {step}")

            total_num_questions_generated += num_questions_left_to_generate

            if step == 1:
                new_questions, logs = self.proposer.propose_initial_subgoals(final_question=self.final_question, num_subgoals=num_questions_left_to_generate, max_tokens=self.proposer_max_tokens)
                print(f"Generating {num_questions_left_to_generate} initial subgoals at step {step}")
            else:
                new_questions, logs = self.proposer.propose_subgoals(num_subgoals=num_questions_left_to_generate, historical_data=historical_data, proposer_context=proposer_context, proposer_context_logs=proposer_context_logs, max_tokens=self.proposer_max_tokens)
                print(f"Generating {num_questions_left_to_generate} subgoals at step {step}")

            for q, log in zip(new_questions, logs):
                if self._validate_subgoal(q):
                    added_questions.append(q)
                    log_output.append(log)
                    num_questions_left_to_generate -= 1
                else:
                    print(f"Invalid generated subgoal: {q}")

        # Create num_replicas number of duplicates of the added questions.
        # NOTE(review): list repetition repeats references, so replicas of one question share the same
        # dict (and therefore the same `index` / `generated_at_step`) — confirm this is intended.
        added_questions = added_questions * num_replicas
        taken_indices = {0}
        for q in added_questions:
            if q["extra_info"]["index"] in taken_indices:
                q["extra_info"]["index"] = max(taken_indices) + 1
            q["extra_info"]["generated_at_step"] = step  # Set the generation step
            taken_indices.add(q["extra_info"]["index"])

        # Add to generated questions list
        self.current_dataset.extend(added_questions)

        # Pad the dataset with placeholder rows so its size is a multiple of train_batch_size.
        # All padding rows are references to one dict and hence share a single question_id.
        if len(self.current_dataset) % self.config.train_batch_size != 0:
            n_padding = self.config.train_batch_size - (len(self.current_dataset) % self.config.train_batch_size)
            self.current_dataset.extend(n_padding*[{
                "prompt": [{"role": "user", "content": "[PAD]"}],
                "data_source": "dynamic_batch_padding",
                "ability": "coding",
                "reward_model": {
                    "style": "model",
                    "ground_truth": [{'input': '[PAD]', 'output': '[PAD]'} for _ in range(self.num_test_cases - 1)],
                },
                "extra_info": {
                    "split": "dynamic",
                    "index": 0,
                    "generated_at_step": -1,
                    "question_id": str(uuid.uuid4()),
                    "raw_subgoal": "[PAD]"
                }
            }])
            print(f"Added {n_padding} padding questions to current_dataset")

        print(f"len(current_dataset) after adding new questions: {len(self.current_dataset)}")
        print(f"current_dataset ref original: {self.dataset_ref}")
        self.dataset_ref = ray.put(self.current_dataset)

        print(f"Added {len(added_questions)} new questions. Total size of self.current_dataset: {len(self.current_dataset)}")

        # Log inputs/outputs of the proposer to json
        # TODO: Fix self.config to point to base config and not config.data, and then use that to get trainer.experiment name instead
        dump_path = f"/home/ubuntu/ethakira/verl/experiments/proposer_dumps/{self.config.experiment_name}/{step}.json"
        # Create the per-experiment dump directory on first use instead of failing on open()
        os.makedirs(os.path.dirname(dump_path), exist_ok=True)
        with open(dump_path, "w") as f:
            json.dump(log_output, f, indent=4)

        # Run the base solver model on every non-padding question in the dataset (mean@n over n samples)
        from verl.utils.reward_score.code_with_tests_all import compute_score
        load_dotenv()
        client = Together(api_key=os.getenv("TOGETHER_API_KEY"))
        scores = []
        mean_at_n = self.config.hindsight_val_n
        for question in self.current_dataset:
            if question["data_source"] == "dynamic_batch_padding":
                continue
            answer = query_together(
                client=client,
                messages=question["prompt"],
                model="Qwen/Qwen2.5-7B-Instruct-Turbo",
                max_tokens=self.config.max_response_length,
                temperature=self.config.hindsight_val_temperature,
                top_p=self.config.hindsight_val_top_p,
                top_k=self.config.hindsight_val_top_k,
                n=mean_at_n,
            )
            for i in range(mean_at_n):
                score = compute_score(answer.choices[i].message.content, question["reward_model"]["ground_truth"], "dynamic_generated", question["extra_info"])
                scores.append(score["score"])

        hindsight_from_0 = sum(scores) / len(scores)

        # Token-length statistics of the proposer prompt, its first response and the extracted subgoal.
        # NOTE: lengths of `output_step2` are not computed or logged.
        proposer_prompt_length = _token_lengths(log["input"][1]["content"] for log in log_output)
        proposer_response_lengths_part1 = _token_lengths(log["output_step1"] for log in log_output)
        proposer_subgoals_lengths = _token_lengths(log["extracted_subgoal"]["prompt"][0]["content"] for log in log_output)

        # Mean is taken directly over the 1-D length tensor (wrapping it in a list first breaks as soon
        # as it holds more than one element).
        proposer_prompt_length_mean = torch.mean(proposer_prompt_length)
        proposer_prompt_length_min = torch.min(proposer_prompt_length)
        proposer_prompt_length_max = torch.max(proposer_prompt_length)

        proposer_response_length_part1_mean = torch.mean(proposer_response_lengths_part1)
        proposer_response_length_part1_min = torch.min(proposer_response_lengths_part1)
        proposer_response_length_part1_max = torch.max(proposer_response_lengths_part1)

        proposer_subgoals_length_mean = torch.mean(proposer_subgoals_lengths)
        proposer_subgoals_length_min = torch.min(proposer_subgoals_lengths)
        proposer_subgoals_length_max = torch.max(proposer_subgoals_lengths)

        logger.log(data={
            "buffer/num_from_past_samples": len(sampled_subgoals),
            "buffer/num_from_buffer": len(buffer),
            "buffer/old_subgoal_sample_set_size": len(past_subgoals_sample_set),

            "subgoal_generation/num_subgoals_generated": total_num_questions_generated,
            "subgoal_generation/num_subgoals_valid": len(added_questions),

            "prompt_length/proposer/mean": float(proposer_prompt_length_mean),
            "prompt_length/proposer/min": float(proposer_prompt_length_min),
            "prompt_length/proposer/max": float(proposer_prompt_length_max),

            "response_length/proposer/part1/mean": float(proposer_response_length_part1_mean),
            "response_length/proposer/part1/min": float(proposer_response_length_part1_min),
            "response_length/proposer/part1/max": float(proposer_response_length_part1_max),

            "response_length/proposer/subgoal/mean": float(proposer_subgoals_length_mean),
            "response_length/proposer/subgoal/min": float(proposer_subgoals_length_min),
            "response_length/proposer/subgoal/max": float(proposer_subgoals_length_max),

            f"val/hindsight_from_0/mean@{mean_at_n}": hindsight_from_0,
        }, step=step)
        
    def __len__(self):
        """
        Return the number of rows in the dataset currently held in the Ray object store.

        Side effects: the fetched list is cached on `self.current_dataset` and written back to the
        object store, so `self.dataset_ref` is replaced by a fresh reference on every call.
        """
        snapshot = ray.get(self.dataset_ref)
        self.current_dataset = snapshot
        self.dataset_ref = ray.put(snapshot)
        return len(snapshot)
    
    def _build_messages(self, example: dict):
        """
        Detach and return the chat messages stored under `self.prompt_key`.

        The entry is removed from `example` (the dict is mutated in place). Message contents are
        returned untouched; no multi-modal content wrapping is applied.

        Args:
            example: Dataset row containing the prompt under `self.prompt_key`.

        Returns:
            The value that was stored under `self.prompt_key`.

        Raises:
            KeyError: if `example` has no `self.prompt_key` entry.
        """
        key = self.prompt_key
        messages: list = example.pop(key)
        return messages
    
    def __getitem__(self, item):
        """
        Fetch one row of the current dataset and convert it into model inputs.

        The row is shallow-copied from the dataset stored behind `self.dataset_ref`, its prompt
        messages are rendered with the chat template (processor if available, tokenizer otherwise),
        tokenized, and passed through `verl_F.postprocess_data` with `max_length=self.max_prompt_length`,
        `left_pad=True` and `truncation=self.truncation`.

        Args:
            item: Index of the item to retrieve

        Returns:
            The row dict (without its prompt entry) extended with `input_ids`, `attention_mask`,
            `position_ids`, `raw_prompt_ids`, `index`, `tools_kwargs`, `interaction_kwargs`, and —
            depending on configuration — `multi_modal_data`, `multi_modal_inputs`, `raw_prompt`
            and `full_prompts`.

        Raises:
            RuntimeError: if the dataset is empty, or if the raw prompt is longer than
                `self.max_prompt_length` while `self.truncation == "error"`.
        """
        current_dataset = ray.get(self.dataset_ref)
        if len(current_dataset) == 0:
            raise RuntimeError("Failed to generate new questions. Dataset is still empty.")

        # Shallow copy so popping the prompt below does not alter the stored row
        row_dict = dict(current_dataset[item])

        # Pull the chat messages out of the row (this class's own _build_messages)
        messages = self._build_messages(row_dict)
        model_inputs = {}

        if self.processor is not None:
            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            multi_modal_data = {}

            # No image/video data is attached to dynamic questions
            images = None
            videos = None

            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

            if "second_per_grid_ts" in model_inputs:
                model_inputs.pop("second_per_grid_ts")

            row_dict["multi_modal_data"] = multi_modal_data
            row_dict["multi_modal_inputs"] = dict(model_inputs)
            row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

        else:
            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
            from verl.models.transformers.qwen2_vl import get_rope_index

            # NOTE(review): "second_per_grid_ts" was already popped from model_inputs above, so the
            # .get() below always yields None — confirm that is intended.
            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]
            # NOTE(review): builds a tensor from a one-element list holding get_rope_index's result;
            # verify this conversion is what downstream code expects.
            row_dict["position_ids"] = torch.tensor(position_ids)
        else:
            from verl.utils.model import compute_position_id_with_mask
            position_ids = compute_position_id_with_mask(attention_mask)
            row_dict["position_ids"] = position_ids[0]

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]

        # Untruncated/unpadded token ids of the prompt, cut manually according to self.truncation
        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "middle":
                # Keep the head and the tail, drop the middle
                left_half = self.max_prompt_length // 2
                right_half = self.max_prompt_length - left_half
                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

        row_dict["raw_prompt_ids"] = raw_prompt_ids

        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages

        if self.return_full_prompt:
            row_dict["full_prompts"] = raw_prompt

        # add index for each prompt
        extra_info = row_dict.get("extra_info", {})
        index = extra_info.get("index", 0)
        tools_kwargs = extra_info.get("tools_kwargs", {})
        interaction_kwargs = extra_info.get("interaction_kwargs", {})
        need_tools_kwargs = extra_info.get("need_tools_kwargs", self.need_tools_kwargs)
        if need_tools_kwargs and not tools_kwargs:
            # The module logger is a stdlib logging.Logger, which interpolates arguments with
            # %-style placeholders (str.format-style "{}" would fail at emit time).
            logger.warning("tools_kwargs is empty for index %s, data source: %s", index, row_dict["data_source"])
        row_dict["index"] = index
        row_dict["tools_kwargs"] = tools_kwargs
        row_dict["interaction_kwargs"] = interaction_kwargs

        return row_dict
    
    def __getstate__(self):
        """
        Build the state dict used when this object is pickled.

        Returns a shallow copy of the instance `__dict__`. When `self.serialize_dataset` is falsy,
        the `"dataframe"` entry (if present) is left out of the copy; the instance itself is not
        modified.
        """
        state = self.__dict__.copy()
        if not self.serialize_dataset:
            # Drop the dataframe from the pickled state only; absent key is fine
            state.pop("dataframe", None)
        return state