import warnings
import time
import uuid
import tempfile
import shutil
import os

import ray
from appworld import AppWorld, load_task_ids
from appworld.apps.api_lib import set_local_dbs
from appworld.apps.model_lib import get_db_home_path

from rllm.environments.base.base_env import BaseEnv
from rllm.environments.base.multi_turn_env import MultiTurnEnvironment
from rllm.rewards.reward_fn import RewardFunction, zero_reward


@ray.remote
class AppWorldActor:
    """Ray actor that owns one AppWorld episode backed by a private database copy.

    Each actor gets a unique ``instance_id``. On every ``reset()`` it gets a
    fresh temporary directory, and the base AppWorld ``.db`` files are copied
    into it so concurrently running actors do not share database files.
    """

    def __init__(self, task):
        """Store the task configuration; the AppWorld instance is created in ``reset()``.

        Args:
            task: Mapping with required keys ``'task_id'`` and ``'max_steps'`` and
                an optional ``'reward_fn'`` (defaults to ``zero_reward``). The whole
                mapping is echoed back as ``task_info`` in ``reset()``'s observation.
        """
        self.task_id = task['task_id']
        self.max_steps = task['max_steps']
        # NOTE(review): stored but not read by any method of this class.
        self.reward_fn = task.get('reward_fn', zero_reward)
        self.step_count = 0
        self.env = None

        # Unique ID used in both the temp-directory prefix and the experiment name.
        self.instance_id = str(uuid.uuid4())
        # Prefer the RAM disk for speed; where /dev/shm does not exist (e.g. macOS),
        # dir=None makes tempfile.mkdtemp use the platform's default temp directory.
        self.temp_db_dir = "/dev/shm" if os.path.isdir("/dev/shm") else None
        self.temp_dir_path = None

        self.task_info = task

    def _release_resources(self):
        """Best-effort close of the current env and removal of its temp DB directory.

        Both handles are cleared afterwards, so calling this repeatedly is safe.
        """
        if self.env is not None:
            try:
                self.env.close()
            except Exception as e:
                # Teardown is best-effort: a failing close must not prevent the
                # temp directory below from being removed.
                print(f"Warning: AppWorld env close failed ({e})")
            self.env = None

        if self.temp_dir_path:
            shutil.rmtree(self.temp_dir_path, ignore_errors=True)
            self.temp_dir_path = None

    def _copy_base_dbs(self, from_db_home_path, to_db_home_path):
        """Copy only ``.db`` files (keeping the sub-directory layout) into ``to_db_home_path``.

        If the copy raises, ``to_db_home_path`` is (re)created as needed and a
        warning is printed, leaving AppWorld to cope with the missing databases.
        """
        def _ignore_non_db_files(src, names):
            # Directories are never ignored so copytree keeps recursing into them.
            return [
                name for name in names
                if os.path.isfile(os.path.join(src, name)) and not name.endswith('.db')
            ]

        try:
            if os.path.exists(from_db_home_path):
                # mkdtemp already created the destination, but copytree requires
                # that it does not exist yet.
                if os.path.exists(to_db_home_path):
                    shutil.rmtree(to_db_home_path)
                # shutil.copy copies file data and permission bits only, skipping
                # the extra metadata copy2 (copytree's default) would preserve.
                shutil.copytree(
                    from_db_home_path,
                    to_db_home_path,
                    ignore=_ignore_non_db_files,
                    copy_function=shutil.copy,
                )
        except Exception as e:
            # Fallback: make sure the directory exists and let AppWorld handle it.
            os.makedirs(to_db_home_path, exist_ok=True)
            print(f"Warning: Fast DB copy failed ({e}), using fallback")

    def reset(self):
        """Start a new episode on a fresh, isolated copy of the base databases.

        Any previous episode's env and temp directory are released first, so
        repeated resets do not accumulate directories under ``temp_db_dir``.

        Returns:
            ``(obs, info)``: ``obs`` holds the task instruction and the original
            task mapping; ``info`` carries task id, supervisor, instruction,
            ``max_steps`` and a zero ``step_count``.
        """
        self._release_resources()
        self.step_count = 0

        # Unique per-reset directory for this actor's databases.
        self.temp_dir_path = tempfile.mkdtemp(
            dir=self.temp_db_dir, prefix=f"appworld_{self.instance_id}_"
        )
        from_db_home_path = get_db_home_path(storage_type="disk", type="base")
        to_db_home_path = self.temp_dir_path
        self._copy_base_dbs(from_db_home_path, to_db_home_path)

        # Point AppWorld at the per-actor database copy.
        set_local_dbs(
            to_db_home_path=to_db_home_path,
            from_db_home_path=from_db_home_path
        )

        self.env = AppWorld(
            task_id=self.task_id,
            experiment_name=f"default_{self.task_id}_{self.instance_id}",
        )

        obs = {'observation': self.env.task.instruction, 'task_info': self.task_info}

        info = {
            "task_id": self.task_id,
            "supervisor": dict(self.env.task.supervisor),
            "task": self.env.task.instruction,
            "max_steps": self.max_steps,
            "step_count": 0,
        }
        return obs, info

    def step(self, action):
        """Execute ``action`` in AppWorld and return ``(obs, reward, done, info)``.

        The episode is done when AppWorld reports task completion or
        ``max_steps`` has been reached. Only then is the task evaluated:
        reward is 10.0 on success and 0.0 otherwise.

        Raises:
            RuntimeError: if called before ``reset()`` or after ``close()``.
        """
        if self.env is None:
            raise RuntimeError("Environment not reset before step. Please call reset() first.")

        self.step_count += 1
        obs = self.env.execute(action)
        done = self.env.task_completed() or (self.step_count >= self.max_steps)
        # Evaluation is only performed at episode end.
        is_success = self.env.evaluate().success if done else False
        reward = 10.0 if is_success else 0.0
        info = {"won": is_success, "step_count": self.step_count, "max_steps": self.max_steps}
        return {'observation': obs}, reward, done, info

    def compute_final_reward(self):
        """Return 10.0 if the current task evaluates as successful, else 0.0.

        Returns 0.0 when no env is active (before ``reset()`` or after ``close()``).
        """
        if self.env is None:
            return 0.0

        is_success = self.env.evaluate().success
        return 10.0 if is_success else 0.0

    def close(self):
        """Close the AppWorld env and delete this actor's temp DB directory (idempotent)."""
        self._release_resources()


class AppworldEnv(BaseEnv):
    """Environment facade that forwards every call to a dedicated ``AppWorldActor``.

    Each instance spawns its own Ray actor, requesting ``num_cpus=2``, and
    blocks on ``ray.get`` until each forwarded call returns.
    """

    def __init__(self, task, **kwargs):
        super().__init__(**kwargs)
        # One Ray actor per environment instance.
        actor_cls = AppWorldActor.options(num_cpus=2)
        self.actor_ref = actor_cls.remote(task)

    def _call_actor(self, method_name, *args):
        """Invoke ``method_name`` on the actor with ``args`` and wait for the result."""
        remote_method = getattr(self.actor_ref, method_name)
        return ray.get(remote_method.remote(*args))

    def reset(self):
        return self._call_actor("reset")

    def step(self, action):
        return self._call_actor("step", action)

    def compute_final_reward(self):
        return self._call_actor("compute_final_reward")

    def close(self):
        self._call_actor("close")

    @staticmethod
    def from_dict(info: dict) -> "AppworldEnv":
        """Build an environment whose task configuration is the ``info`` mapping."""
        return AppworldEnv(task=info)