from typing import Any, Dict, Tuple, Optional, List
import copy
import re
import random
import hashlib
import json

from rllm.environments.base.base_env import BaseEnv
import reasoning_gym


class ReasoningGymEnv(BaseEnv):
    """
    Single-turn environment for Reasoning Gym tasks.

    The task dictionary carries only the recipe needed to regenerate one data item
    (dataset name, size, seed, config and index) plus a fingerprint of that item.
    On construction the dataset is rebuilt with ``reasoning_gym.create_dataset``, the
    item at ``task['idx']`` is selected and checked against ``task['hash']``.
    A single ``step`` scores the agent's answer with the dataset's ``score_answer``
    and ends the episode.
    """

    def __init__(
        self,
        task: dict,
    ):
        """
        Initialize the Reasoning Gym (single-turn) environment.

        Args:
            task: Task dictionary describing one data item. Keys read here:
                - ``name``: Reasoning Gym dataset name.
                - ``size``: size of the dataset to regenerate.
                - ``seed``: seed of the dataset to regenerate.
                - ``config``: extra dataset kwargs, as a JSON string or a mapping
                  (optional; missing/``None`` means no extra kwargs).
                - ``idx``: index of the data item inside the regenerated dataset.
                - ``hash``: ``data_item_to_hash`` of the expected data item.

        Raises:
            ValueError: If the regenerated data item does not match ``task['hash']``.
        """
        self.task = task

        self._initialize_from_task(task)
        # Initialise episode bookkeeping up front so that ``step`` does not fail with
        # an AttributeError when it is called before ``reset``.
        self._reset_state()

    @staticmethod
    def data_item_to_hash(data_item: dict) -> str:
        """
        Return a stable MD5 hex digest of ``data_item``.

        Keys are sorted before JSON serialization, so the digest does not depend on
        dict insertion order. It is used as an integrity fingerprint, not for security.
        """
        return hashlib.md5(json.dumps(data_item, sort_keys=True).encode()).hexdigest()

    @staticmethod
    def _parse_config(config: Any) -> dict:
        """Return dataset kwargs from a JSON string, a mapping, or ``None``."""
        if config is None:
            return {}
        if isinstance(config, (str, bytes, bytearray)):
            config = json.loads(config)
        # A JSON "null" config means "no extra kwargs"; copy mappings so the dataset
        # factory can never mutate the caller's task dictionary.
        return {} if config is None else dict(config)

    def _initialize_from_task(self, task: dict) -> None:
        """
        Regenerate the dataset described by ``task`` and load the selected data item.

        Raises:
            ValueError: If the regenerated item's hash differs from ``task['hash']``.
        """
        self.data = reasoning_gym.create_dataset(
            task['name'],
            size=task['size'],
            seed=task['seed'],
            **self._parse_config(task.get('config')),
        )
        self.data_item = self.data[task['idx']]
        self.question = self.data_item['question']
        self.answer = self.data_item['answer']
        self.metadata = self.data_item['metadata']

        # Data integrity check: the item regenerated from (name, size, seed, config, idx)
        # must be exactly the one the task was built from. This is raised explicitly
        # rather than via ``assert`` so that the check survives ``python -O``.
        actual_hash = self.data_item_to_hash(self.data_item)
        if actual_hash != task['hash']:
            raise ValueError(
                f"Data integrity check failed for task {task['name']!r} "
                f"(idx={task['idx']}): expected hash {task['hash']}, got {actual_hash}"
            )

    def _reset_state(self) -> None:
        """Clear all per-episode bookkeeping."""
        self.done = False
        self.current_turn = 0
        self.history = []

        self.env_message = None
        self.termination_reason = None
        self.progress = 0.0

    def reset(self):
        """
        Reset the environment to the initial state.

        Returns:
            A ``(observation, info)`` tuple for the first turn.
        """
        self._reset_state()

        # Return the first observation
        observation = self._get_observation()
        info = self._get_info()

        return observation, info

    def _get_observation(self) -> dict:
        """Return the observation of the current environment state."""
        # NOTE(review): the ground-truth ``answer`` is part of the observation; confirm
        # that the consuming agent never places it in the model prompt.
        observation = {
            "question": self.question,
            "metadata": self.metadata,
            "answer": self.answer,
            "done": self.done,
            "task_name": self.task['name'],
        }

        return observation

    def _get_info(self) -> dict:
        """Return additional information (currently none)."""
        return {}

    def step(self, action: Any) -> tuple[dict, float, bool, dict]:
        """
        Score the agent's answer and end the (single-turn) episode.

        Args:
            action: The agent's response text. The answer is taken from the text
                between the last two ``` fences (see ``_parse_action``).

        Returns:
            ``(next_observation, reward, done, info)``:
            - ``reward`` is the dataset's ``score_answer`` for the parsed answer.
            - ``done`` is always ``True``.
            - ``info`` is the task dictionary.
        """
        # Store the action in history
        self.history.append(action)

        assert self.task is not None, "Task is not set"

        # Increment turn counter
        self.current_turn += 1

        self.done = True  # single turn: one answer ends the episode

        next_obs = self._get_observation()
        reward = self.data.score_answer(answer=self._parse_action(action), entry=self.data_item)

        return next_obs, reward, self.done, self.task

    def _parse_action(self, action: Any) -> Optional[str]:
        """
        Extract the answer from the last fenced code block of ``action``.

        The answer is the stripped text between the last two ``` fences. When the
        response contains fewer than two fences there is no complete code block.
        In that case ``None`` is returned instead of raising ``IndexError`` (zero
        fences) or returning text that lies outside any block (one fence).
        ``score_answer`` takes ``Optional[str]``, so ``None`` is passed through as
        "no answer".
        """
        parts = action.split('```')
        # k fences produce k + 1 parts; a complete block needs at least two fences.
        if len(parts) < 3:
            return None
        return parts[-2].strip()

    @staticmethod
    def from_dict(env_args: dict) -> "ReasoningGymEnv":
        """Create a ``ReasoningGymEnv`` from a task dictionary (see ``__init__``)."""
        return ReasoningGymEnv(task=env_args)
