from abc import abstractmethod
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
import random
import time
from typing import List, Dict, Sequence, Any, Union, Tuple

import torch
from datasets import Dataset
from pydantic import BaseModel

from vllm import LLM, SamplingParams
# from ..inference.vllm_client import VLLMClient
from open_r1.inference.vllm_client import VLLMClient

from .multiturn_env import MultiTurnEnv, dict_to_chat_response, ChatResponse, ChatResponseItem, ChatOutput
from ..utils import format_prompt_for_reverse_cot, select_rule_indices, format_dataset_for_reverse_cot


class AugmentedQuestionEnv(MultiTurnEnv):
    """
    A multi-turn environment for handling augmented questions before answering the original question.
    
    This environment implements a specific multi-turn chat flow:
    1. First, the model is asked to solve a series of augmented questions
    2. After answering all augmented questions, the model is asked to answer the original question
    
    The environment tracks the conversation state and determines when the conversation is complete.
    """
    
    def __init__(self,
                 dataset: Dataset | None = None,
                 eval_dataset: Dataset | None = None,
                 support_system_prompt: bool = True,
                 system_prompt: str = "",
                 few_shot: List[Dict[str, str]] | None = None,
                 task: str = "",
                 answer_key: str = "answer",
                 cot_response_key: str = "cot_solution",
                 total_rules: int = 10,
                 sampling_args: Dict[str, Any] | None = None,
                 mask_env_response: bool = True,
                 max_workers: int = 10,
                 max_steps: int = 10,
                 max_tokens: int = 4096,
                 sleep_time: float = 1.0,
                 eot_id: int = 151643,
                 message_end_id: int = 151645,
                 message_end_newline_id: int | None = None,
                 filter_easy_or_difficult_samples: bool = True,
                 **kwargs):
        """
        Initialize the AugmentedQuestionEnv.

        The parent constructor is deliberately called with ``dataset=None`` and
        ``eval_dataset=None``; the datasets are formatted and attached here, after
        all attributes read by ``_format_dataset`` have been set.

        Args:
            dataset: The dataset to use for training (formatted via ``_format_dataset``)
            eval_dataset: The dataset to use for evaluation (formatted via ``_format_dataset``)
            support_system_prompt: Forwarded to the reverse-CoT prompt formatters
            system_prompt: The system prompt to use
            few_shot: Few-shot examples to use; ``None`` (the default) means no examples
            task: Task identifier forwarded to the reverse-CoT prompt formatters
            answer_key: Dataset key of the answer, forwarded to the dataset formatter
            cot_response_key: Metadata key holding the solution used as the final
                assistant turn in ``generate_off_policy``
            total_rules: Forwarded to the reverse-CoT prompt formatters
            sampling_args: Overrides applied onto the sampling params in ``generate``;
                ``None`` (the default) means no overrides
            mask_env_response: Whether to mask environment responses
            max_workers: Maximum number of workers for parallel processing
            max_steps: Maximum number of steps in a conversation
            max_tokens: Maximum number of completion tokens kept per conversation
                (|o_i| in the GRPO paper); ``step`` truncates to this length
            sleep_time: Upper bound (seconds) of the random sleep done per state update
            eot_id: Token id stored on the instance; not read by the methods visible
                in this class. NOTE(review): default looks like a Qwen2 id - confirm
            message_end_id: Token id that must terminate each assistant message
            message_end_newline_id: Optional token id expected right after
                ``message_end_id``; ``None`` disables the newline handling
            filter_easy_or_difficult_samples: Whether ``_format_dataset`` first drops
                samples that are too easy or too difficult
            **kwargs: Additional arguments forwarded to the parent constructor
        """
        # Fresh containers per instance: a mutable default would be shared by every
        # instance that relies on the default value.
        few_shot = [] if few_shot is None else few_shot
        sampling_args = {} if sampling_args is None else sampling_args

        super().__init__(
            dataset=None,
            eval_dataset=None,
            system_prompt=system_prompt,
            few_shot=few_shot,
            sampling_args=sampling_args,
            mask_env_response=mask_env_response,
            max_workers=max_workers,
            max_steps=max_steps,
            sleep_time=sleep_time,
            **kwargs
        )
        self.task = task
        self.answer_key = answer_key
        self.cot_response_key = cot_response_key
        self.total_rules = total_rules
        self.max_tokens = max_tokens                # |o_i| in the GRPO paper, the maximum number of tokens in the completion
        self.eot_id = eot_id
        self.message_end_id = message_end_id
        self.message_end_newline_id = message_end_newline_id
        self.filter_easy_or_difficult_samples = filter_easy_or_difficult_samples
        self.support_system_prompt = support_system_prompt
        # Datasets are formatted last because _format_dataset reads the attributes set above.
        if dataset is not None:
            self.dataset = self._format_dataset(dataset)
        if eval_dataset is not None:
            self.eval_dataset = self._format_dataset(eval_dataset)
    
    def _filter_dataset_by_difficulty(self, dataset: Dataset) -> Dataset:
        """
        Filter out samples that are too easy (sample_correct_num=8) or too difficult (sample_correct_num=0).
        These samples would have zero advantage in GRPO training as all samples in the group would have
        the same reward, leading to zero gradients.
        
        Args:
            dataset: The dataset to filter
            
        Returns:
            The filtered dataset
        """
        print(f"Dataset size before filtering: {len(dataset)}")
        
        def has_appropriate_difficulty(example):
            # Check if the example has the metadata field
            if "metadata" not in example:
                return True  # Keep examples without metadata
            
            metadata = example["metadata"]
            # Check if metadata has sample_correct_num field
            if "sample_correct_num" not in metadata:
                return True  # Keep examples without sample_correct_num
            
            # Filter out too easy (8/8) or too difficult (0/8) examples
            sample_correct_num = metadata["sample_correct_num"]
            return 0 < sample_correct_num < 8
        
        filtered_dataset = dataset.filter(has_appropriate_difficulty)
        print(f"Dataset size after filtering: {len(filtered_dataset)}")
        return filtered_dataset
    
    def _format_dataset(self, dataset: Dataset) -> Dataset:
        """
        Prepare a raw dataset for the reverse-CoT multi-turn flow.

        When ``filter_easy_or_difficult_samples`` is enabled, samples judged too easy
        or too difficult are dropped first (see ``_filter_dataset_by_difficulty``).
        The remaining samples are handed to ``format_dataset_for_reverse_cot`` with
        ``turn_idx=0`` and no pre-selected rule indices.

        Args:
            dataset: The dataset to format

        Returns:
            The formatted dataset
        """
        source = dataset
        if self.filter_easy_or_difficult_samples:
            source = self._filter_dataset_by_difficulty(dataset)

        # Per the original author's note, the formatter builds the first turn of the
        # conversation (system prompt, if any, plus the first augmented question) as
        # chat messages stored in the "prompt" field.
        formatter_kwargs = dict(
            dataset=source,
            support_system_prompt=self.support_system_prompt,
            system_prompt=self.system_prompt,
            task=self.task,
            answer_key=self.answer_key,
            turn_idx=0,
            max_turns=self.max_steps,
            total_rules=self.total_rules,
            rule_indices=None,
        )
        return format_dataset_for_reverse_cot(**formatter_kwargs)
    

    def step(self,
             states: List[Dict[str, Any]],
             metadata: List[Dict[str, Any]],
             rule_indices: List[List[int]],
             llm: LLM | VLLMClient,
             sampling_params: SamplingParams) -> List[Dict[str, Any]]:
        """
        Advance every unfinished conversation by one assistant turn.

        For each state whose "completed" flag is False, the current chat history is sent
        to the LLM, the reply is appended to "messages", and "completion_ids" /
        "completion_mask" are updated. A conversation is then either marked completed
        (all turns answered, or ``self.max_tokens`` reached) or extended with the next
        environment (user) message.

        Args:
            states: Per-prompt state dicts with keys "messages", "prompt_ids",
                "completion_ids", "completion_mask" and "completed"
            metadata: Per-prompt dataset rows, passed to ``env_response``
            rule_indices: Per-prompt rule indices, passed to ``env_response``
            llm: Either a local vLLM ``LLM`` or a ``VLLMClient``
            sampling_params: Sampling parameters for this turn

        Returns:
            The same ``states`` list, with the live entries replaced by updated copies
        """
        # Get the indices of the prompts that are not completed
        live_indices = [i for i, s in enumerate(states) if not s["completed"]]
        if not live_indices:
            # Nothing to generate; do not call the LLM with an empty batch
            return states
        messages_to_step = [states[i]["messages"] for i in live_indices]

        if isinstance(llm, VLLMClient):
            # VLLMClient
            # max_tokens that is used here is the max_completion_length_each_turn when we sampling as a multi-turn env
            # This will help us to avoid some extreme cases where in one turn the model generates a completion that is too long
            llm_responses = llm.chat(
                messages_to_step,
                n=1,
                repetition_penalty=sampling_params.repetition_penalty,
                temperature=sampling_params.temperature,
                top_p=sampling_params.top_p,
                top_k=sampling_params.top_k,
                min_p=sampling_params.min_p,
                max_tokens=sampling_params.max_tokens,
                stop=sampling_params.stop,
                include_stop_str_in_output=sampling_params.include_stop_str_in_output,
                skip_special_tokens=sampling_params.skip_special_tokens,
                spaces_between_special_tokens=sampling_params.spaces_between_special_tokens
            )
            llm_responses = dict_to_chat_response(llm_responses).responses
        else:
            # vLLM
            llm_responses = llm.chat(messages_to_step, sampling_params=sampling_params, use_tqdm=False)

        def update_state(j, llm_response):
            # sleep for a random time in [0, sleep_time) seconds to avoid rate limiting
            time.sleep(self.sleep_time * random.random())

            # Work on a copy so that the caller's state is only replaced once fully updated
            state = deepcopy(states[j])
            if len(state["prompt_ids"]) == 0:
                # Only add the prompt token ids if it is not set yet
                state["prompt_ids"] = llm_response.prompt_token_ids         # This is the first prompt's token ids
            state["messages"].append({"role": "assistant", "content": llm_response.outputs[0].text})

            current_prompt_ids = list(llm_response.prompt_token_ids)
            new_completion_ids = list(llm_response.outputs[0].token_ids)

            ### Get token lengths of env response and new completion
            # This is the combination of the first prompt and completions (including the env response) up to the previous step
            total_prev_len = len(state["prompt_ids"]) + len(state["completion_ids"])
            # The length of the env response in the current step
            env_response_len = len(current_prompt_ids) - total_prev_len
            # The length of the new completion in the current step
            new_completion_len = len(new_completion_ids)

            ### Update completion masks
            # Environment response tokens get self.env_mask as their mask value
            # (env_mask is defined outside this class; presumably derived from mask_env_response)
            state["completion_mask"].extend([self.env_mask] * env_response_len)
            # The new completion at this step is not masked
            state["completion_mask"].extend([1] * new_completion_len)

            ### Update completion ids
            # Current prompt (including this step's env response) + new completion,
            # with the first prompt's tokens removed
            state["completion_ids"] = (current_prompt_ids + new_completion_ids)[len(state["prompt_ids"]):]

            # Make the completion end with the message-end marker the chat template
            # adds after each assistant message, appending only what is missing.
            self._append_message_end_tokens(state["completion_ids"], state["completion_mask"])

            # Make sure the completion mask and completion ids are the same length, generally we shouldn't need to do this.
            if len(state["completion_ids"]) > len(state["completion_mask"]):
                state["completion_mask"].extend([1] * (len(state["completion_ids"]) - len(state["completion_mask"])))
            if len(state["completion_mask"]) > len(state["completion_ids"]):
                state["completion_mask"] = state["completion_mask"][:len(state["completion_ids"])]

            # The length cap is self.max_tokens (whole multi-turn completion), not
            # sampling_params.max_tokens (which caps a single turn).
            if self.is_completed(state["messages"]) or len(state["completion_ids"]) > self.max_tokens - 1:
                if len(state["completion_ids"]) > self.max_tokens - 1:
                    print(f"++++++ max tokens ({self.max_tokens}) reached, we have {len(state['messages'])} messages ++++++")
                    print(state["messages"])
                state["completed"] = True
                state["completion_ids"] = state["completion_ids"][:self.max_tokens]                     # Truncate the completion ids to the max generation length
                state["completion_mask"] = state["completion_mask"][:len(state["completion_ids"])]      # Truncate the completion mask to the max generation length
            else:
                state["messages"].append(self.env_response(state["messages"],
                                                           metadata[j],
                                                           rule_indices[j]))

            ### Enforce that the completion mask and completion ids are the same length
            # weird bug that happens rarely and only for certain models; something tokenizer related :(
            if not len(state["completion_mask"]) == len(state["completion_ids"]):
                print(state["messages"])
                print(state["completion_mask"])
                print(state["completion_ids"])
                min_len = min(len(state["completion_mask"]), len(state["completion_ids"]))
                state["completion_mask"] = state["completion_mask"][:min_len]
                state["completion_ids"] = state["completion_ids"][:min_len]

            return j, state

        # llm_responses[i] belongs to the i-th live prompt, i.e. states[live_indices[i]]
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            results = list(executor.map(
                lambda args: update_state(*args),
                [(j, llm_responses[i]) for i, j in enumerate(live_indices)]
            ))

        for j, state in results:
            states[j] = state

        return states

    def _append_message_end_tokens(self, completion_ids: List[int], completion_mask: List[int]) -> None:
        """
        Ensure ``completion_ids`` ends with the message-end marker, in place.

        The marker is ``[message_end_id, message_end_newline_id]`` when a newline id is
        configured, otherwise just ``[message_end_id]``. Only the missing suffix is
        appended (so a completion already ending with ``message_end_id`` receives just
        the newline token), and every appended token gets mask value 1. Slices are used
        so that completions with fewer than two tokens are handled without IndexError.

        Args:
            completion_ids: Completion token ids, extended in place
            completion_mask: Completion mask, extended in place by the same amount
        """
        end_id = self.message_end_id
        newline_id = self.message_end_newline_id

        if newline_id is None:
            missing = [] if completion_ids[-1:] == [end_id] else [end_id]
        elif completion_ids[-2:] == [end_id, newline_id]:
            missing = []
        elif completion_ids[-1:] == [end_id]:
            missing = [newline_id]
        else:
            missing = [end_id, newline_id]

        completion_ids.extend(missing)
        completion_mask.extend([1] * len(missing))
    
    def generate(self, prompts: List[List[Dict[str, Any]]],
                 metadata: List[Dict[str, Any]],
                 rule_indices: List[List[int]],
                 llm: LLM | VLLMClient,  
                 sampling_params: SamplingParams,
                 **kwargs: Any) -> Dict[str, List[Sequence[int]] | List[str] |  List[List[Dict[str, Any]]]]:
        """
        Run the on-policy multi-turn rollout for a batch of prompts.

        ``step`` is called repeatedly (at least once) until every conversation is
        marked completed.

        Args:
            prompts: Initial chat messages for each prompt
            metadata: Per-prompt dataset rows, forwarded to ``step``
            rule_indices: Per-prompt rule indices, forwarded to ``step``
            llm: Either a local vLLM ``LLM`` or a ``VLLMClient``
            sampling_params: Base sampling params; cloned, then overridden with
                ``self.sampling_args``
            **kwargs: Additional arguments (unused here)

        Returns:
            Dict with "ids" (completion token ids), "messages" (chat messages added
            after the initial prompt) and "mask" (completion masks), one entry per prompt
        """
        # Apply the user-specified overrides on a clone so the caller's params stay untouched
        custom_sp = sampling_params.clone()
        for name, value in self.sampling_args.items():
            setattr(custom_sp, name, value)

        # One state per prompt, tracking its chat history and token bookkeeping
        states = []
        for prompt in prompts:
            states.append({
                "messages": prompt,                 # Chat history, starts as the initial prompt
                "prompt_messages": len(prompt),     # Number of messages belonging to the initial prompt
                "prompt_ids": [],
                "completed": False,                 # Whether the multi-turn chat has terminated
                "completion_ids": [],
                "completion_mask": [],
            })

        # Keep stepping until every conversation has terminated
        while True:
            states = self.step(states, metadata, rule_indices, llm, custom_sp)
            if all(state["completed"] for state in states):
                break

        return {
            "ids": [s["completion_ids"] for s in states],
            "messages": [s["messages"][s["prompt_messages"]:] for s in states],
            "mask": [s["completion_mask"] for s in states],
        }
    
    def generate_mix_policy(self, prompts: List[List[Dict[str, Any]]],
                            metadata: List[Dict[str, Any]],
                            rule_indices: List[List[int]],
                            **kwargs: Any) -> Dict[str, List[Sequence[int]] | List[str] |  List[List[Dict[str, Any]]]]:
        """
        Generate the mixed policy multi-turn completions for the given prompts.
        For mixed policy, we will use the on-policy's rollout before the current turn and the off-policy's response for the current turn.
        """
        pass

    def generate_luffy_policy(self, prompts: List[List[Dict[str, Any]]],
                              metadata: List[Dict[str, Any]],
                              rule_indices: List[List[int]],
                              **kwargs: Any) -> Dict[str, List[Sequence[int]] | List[str] |  List[List[Dict[str, Any]]]]:
        """
        Generate the Luffy policy multi-turn completions for the given prompts.
        LUFFY computes the advantage of each turn using both on-policy and off-policy rollouts.
        Ref: https://github.com/ElliottYan/LUFFY
        """
        pass


    def generate_off_policy(self, prompts: List[List[Dict[str, Any]]],
                            metadata: List[Dict[str, Any]],
                            rule_indices: List[List[int]],
                            **kwargs: Any) -> Dict[str, List[Sequence[int]] | List[str] |  List[List[Dict[str, Any]]]]:
        """
        Generate the off-policy multi-turn completions for the given prompts.
        The off-policy generated completions are pre-generated/collected using a different policy (e.g., usually the Teacher model).
        """
        off_policy_turns_batch = []
        for example_idx, example_meta in enumerate(metadata):
            current_messages = []
            sample_off_policy_turns = []
            if self.system_prompt:
                current_messages.append({"role": "system", "content": self.system_prompt})

            # Add augmented turns
            for turn_idx in range(self.max_steps):
                history = deepcopy(current_messages)
                # Get user message (augmented question)
                user_msg_dict = format_prompt_for_reverse_cot(
                    example=example_meta,
                    support_system_prompt=self.support_system_prompt,
                    system_prompt=self.system_prompt, # Pass system prompt if needed by formatter
                    task=self.task,
                    turn_idx=turn_idx, 
                    max_turns=self.max_steps,
                    total_rules=self.total_rules,
                    rule_indices=rule_indices[example_idx]
                )[-1]
                # Get the user message content, which is the augmented question
                user_content = user_msg_dict["content"]
                current_messages.append(user_msg_dict)

                # Get assistant message (off policy / ground truth solution)
                assistant_content = example_meta["augmented_solutions"][turn_idx]
                current_messages.append({"role": "assistant", "content": assistant_content})

                sample_off_policy_turns.append({
                    "history": history,                     # The chat history before the current turn
                    "user": user_content,                   # The (augmented) question for the current turn
                    "assistant": assistant_content          # The off policy / ground truth solution for the current turn
                })

            # Add final turn (original question)
            history = deepcopy(current_messages)
            final_user_msg_dict = format_prompt_for_reverse_cot(
                example=example_meta,
                support_system_prompt=self.support_system_prompt,
                system_prompt=self.system_prompt, # Pass system prompt if needed by formatter
                task=self.task,
                turn_idx=self.max_steps, # Index for the final question
                max_turns=self.max_steps,
                total_rules=self.total_rules,
                rule_indices=rule_indices[example_idx]
            )[-1]
            # Get the user message content, which is the original question
            final_user_content = final_user_msg_dict["content"]
            current_messages.append(final_user_msg_dict)

            final_assistant_content = example_meta[self.cot_response_key] # Use cot_response_key
            current_messages.append({"role": "assistant", "content": final_assistant_content})

            sample_off_policy_turns.append({
                "history": history,
                "user": final_user_content,
                "assistant": final_assistant_content
            })
            off_policy_turns_batch.append(sample_off_policy_turns)

        return off_policy_turns_batch


    def is_completed(self, messages: List[Dict[str, str]], **kwargs: Any) -> bool:
        """
        Determine if the conversation is complete.
        
        The conversation is complete when:
        1. The model has answered all augmented questions and the original question
        2. The last message is from the assistant (model)
        
        Args:
            messages: The conversation history
            **kwargs: Additional arguments
            
        Returns:
            True if the conversation is complete, False otherwise
        """
        # Check if we have at least one message
        if not messages:
            return False
        
        # Check if the last message is from the assistant
        if messages[-1]["role"] != "assistant":
            return False
        
        # Check if we've reached the maximum number of steps
        # If we have answered all the augmented questions and also the original question, we are done
        # This requires 2 * (self.max_steps+1) messages (user and assistant messages)
        # we should only check the turns which are user and assistant but not the system messages
        # if len(messages) > self.max_steps * 2:
        #     return True
        
        # Filter out only user and assistant messages
        non_system_messages = [msg for msg in messages if msg["role"] in ("user", "assistant")]

        # Check if we've hit the max number of steps (each step is 1 user + 1 assistant message)
        if len(non_system_messages) >= (self.max_steps + 1) * 2:
            return True
        
        # # Check if the last user message was the original question
        # # This assumes the original question is the last user message
        # for i in range(len(messages) - 1, -1, -1):
        #     if messages[i]["role"] == "user":
        #         # Check if this is the original question (not an augmented question)
        #         content = messages[i]["content"].lower()
        #         if "let's return to the original problem" in content:
        #             return True
        #         break
        
        return False
    
    def env_response(self, messages: List[Dict[str, str]],
                     metadata: Dict[str, Any],
                     rule_indices: List[int],
                     **kwargs: Any) -> Dict[str, str]:
        """
        Produce the next environment (user) message for a conversation.

        The turn index is ``len(messages) // 2``, counting every message in the
        history; a single leading system message does not change the result
        (e.g. 3 // 2 == 2 // 2). The prompt for that turn is built by
        ``format_prompt_for_reverse_cot`` and only its last message is returned.

        Args:
            messages: The conversation history
            metadata: The dataset row for this conversation
            rule_indices: The rule indices for this conversation
            **kwargs: Additional arguments (unused here)

        Returns:
            The environment response message
        """
        # Find the turn index from the number of chat messages so far
        turn_idx = len(messages) // 2

        formatter_kwargs = dict(
            example=metadata,
            support_system_prompt=self.support_system_prompt,
            system_prompt=self.system_prompt,
            task=self.task,
            turn_idx=turn_idx,
            max_turns=self.max_steps,
            total_rules=self.total_rules,
            rule_indices=rule_indices,
        )
        turn_messages = format_prompt_for_reverse_cot(**formatter_kwargs)

        # Taking the last message avoids returning a system message as the env response
        return turn_messages[-1]
    
    
    def get_reward_funcs(self, **kwargs: Any) -> List[Any]:
        """
        Get the reward functions for this environment.
        
        Args:
            **kwargs: Additional arguments
            
        Returns:
            The reward functions
        """
        return self.reward_funcs
    
    def get_reward_weights(self, **kwargs: Any) -> List[float]:
        """
        Get the reward weights for this environment.
        
        Args:
            **kwargs: Additional arguments
            
        Returns:
            The reward weights
        """
        return self.reward_weights 