"""
This file contains the Environment class, which prepares the environment for the research agent to run in.
"""

import json
import os
import sys
import subprocess
import shutil
import copy
import time
import fnmatch
import signal
from traceback import format_exception
from multiprocessing import active_children
from dacite import from_dict
import argparse, stat
from datetime import datetime
import uuid

from .low_level_actions import LOW_LEVEL_ACTIONS
from .high_level_actions import HIGH_LEVEL_ACTIONS
from .schema import Step, Trace, EnvException, TooLongPromptError, LLMError, EnhancedJSONEncoder 
from .prepare_task import prepare_task, get_task_info

class TimeoutException(Exception): pass

def safe_rmtree(path):
    def onerror(func, path, exc_info):
        if isinstance(exc_info[1], FileNotFoundError):
            return
        raise

    if os.path.exists(path):
        # shutil.rmtree(path, onerror=onerror)
        shutil.rmtree(path, ignore_errors=True, onerror=onerror)


def get_max_free_memory_gpu_index():
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'],
        capture_output=True,
        text=True
    )

    free_memories = []
    for line in result.stdout.strip().split('\n'):
        if line.strip().isdigit():
            free_memories.append(int(line.strip()))
        
    if not free_memories:
        raise RuntimeError("No GPU memory info found")

    max_mem = max(free_memories)
    return free_memories.index(max_mem)

class Environment:
    def __init__(self, task, args):
        if args.device is None:
            args.device = get_max_free_memory_gpu_index()
            print("using max_free_memory_gpu_index:", args.device)
        else:
            print("using given device:", args.device)
        args.python="miniconda3/envs/mlagentbench/bin/python"
        self._args = args
        self._log_dir = os.path.join(args.log_dir, "env_log")
        self._setup_log_dir()

        self._benchmark_folder_name, self._research_problem = get_task_info(task)
        self._work_dir = args.work_dir
        self._read_only_files = []
        self._initialize_task_env() # set up work dir and log dir

        os.makedirs(self._log_dir, exist_ok=True)
        os.makedirs(self._work_dir, exist_ok=True)

        self._action_infos =  {t.name: t for t in LOW_LEVEL_ACTIONS + HIGH_LEVEL_ACTIONS}

        del self._action_infos["Request Help"]


        self._static_kwargs_for_tools = {
            "device": args.device,
            "python": args.python,
            "work_dir": self.work_dir,
            "args": args,
            "read_only_files": self.read_only_files,
            "research_problem": self.research_problem
        }
        
        self._trace = self._initialize_trace()
        self.current_step = 0
        self._start_time = time.time()

    ############################## getters ########################################

    @property
    def research_problem(self):
        return self._research_problem

    @property
    def benchmark_folder_name(self):
        return self._benchmark_folder_name

    @property
    def log_dir(self):
        return self._log_dir

    @property
    def work_dir(self):
        return self._work_dir
    
    @property
    def read_only_files(self):
        return self._read_only_files

    @property
    def action_infos(self):
        return self._action_infos
    
    @property
    def args(self):
        return self._args

    @property
    def static_kwargs_for_tools(self):
        return self._static_kwargs_for_tools
    
    @property
    def trace(self):
        return copy.deepcopy(self._trace)

    def get_trace(self, id):
        return copy.deepcopy(self._trace[id])

    @property
    def start_time(self):
        return self._start_time
    
    ############################## internal functions ########################################
    
    def _setup_log_dir(self):
        # set up log dir
        if os.path.exists(self.args.log_dir):
            print("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(self.log_dir)

        if os.path.exists(os.path.join(self.log_dir, "tool_logs")):
            print("tools_log_dir {} already exists".format(os.path.join(self.log_dir, "tool_logs")))
            # raise ValueError("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(os.path.join(self.log_dir, "tool_logs"))

        if os.path.exists(os.path.join(self.log_dir, "traces")):
            print("tools_log_dir {} already exists".format(os.path.join(self.log_dir, "traces")))
            # raise ValueError("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(os.path.join(self.log_dir, "traces"))

    def _initialize_task_env(self):

        work_dir = self.work_dir
        print("-----work_dir-----", work_dir)

        # remove the workspace folder if it exists
        if os.path.exists(work_dir):
            shutil.rmtree(work_dir)

        benchmark_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "benchmarks", self.benchmark_folder_name)

        # prepare if there is a prepare.py and it has not been prepared
        prepare_task(benchmark_dir, self.args.python)

        # copy the benchmarks folder to work_dir
        if os.path.exists(os.path.join(benchmark_dir, "env" )):
            shutil.copytree(os.path.join(benchmark_dir, "env"), work_dir, symlinks=True)

        # find all read only files
        if os.path.exists(os.path.join(benchmark_dir, "scripts", "read_only_files.txt")):
            ignore_files = open(os.path.join(benchmark_dir, "scripts", "read_only_files.txt"), "r").read().split("\n")
            for path, subdirs, files in os.walk(os.path.join(work_dir)):

                relpath = os.path.relpath(path, work_dir)
                # filter out the files that are read only
                filenames = [os.path.join(relpath, filename) for filename in files]
                for ignore in ignore_files:
                    ignore_filenames = [n for n in filenames if fnmatch.fnmatch(n, ignore)]
                    self.read_only_files.extend(ignore_filenames)

         # init backup folder and remove all content if it exists
        if os.path.exists(os.path.join(work_dir, "backup")):
            shutil.rmtree(os.path.join(work_dir, "backup"))
        os.mkdir(os.path.join(work_dir, "backup"))

        if self.args.resume and self.args.resume_step >= 0:
            shutil.rmtree(work_dir)
            resume_dir = os.path.join(self.args.resume_dir, "env_log", "traces" , f"step_{self.args.resume_step}_files")
            print("Restoring workspace ing from {}".format(resume_dir))
            shutil.copytree(resume_dir, work_dir, symlinks=True,dirs_exist_ok=True)
            shutil.copytree(os.path.join(benchmark_dir, "env"), work_dir, symlinks=True,dirs_exist_ok=True,ignore=shutil.ignore_patterns('*.py'))
            if not os.path.exists(os.path.join(work_dir, "backup")):
                os.mkdir(os.path.join(work_dir, "backup"))


    def _initialize_trace(self):
        if self.args.resume and self.args.resume_step >= 0:
            print("Restoring trace from {}".format(self.args.resume_dir))
            prev_trace = from_dict(data_class=Trace, data=json.load(open(os.path.join(self.args.resume_dir, "env_log","trace.json"), "r")))
            print("Resetting trace to step {}".format(self.args.resume_step))
            steps = prev_trace.steps[:self.args.resume_step+1]
            t = steps[-1].timestamp
            low_level_steps = [s for s in prev_trace.low_level_steps if s.timestamp < t]
            trace = Trace(
                steps=steps,
                low_level_steps=low_level_steps,
                action_infos=self.action_infos,
                task_description=self.research_problem,
            )
        else:   
            trace = Trace(
            steps=[],
            low_level_steps=[],
            action_infos=self.action_infos,
            task_description=self.research_problem,
        )
        return trace
    
    def __enter__(self):
        # set time out
        def signal_handler(signum, frame):
            raise TimeoutException("Timed out!")
        signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(self.args.max_time)
        return self
    
    def __exit__(self, exc_type, exc_value, traceback):  
        # save error message
        active = active_children()
        print(f'Active Children: {len(active)}')
        # terminate all active children
        for child in active:
            child.terminate()
        # block until all children have closed
        for child in active:
            child.join()
        # report active children
        active = active_children()
        print(f'Active Children: {len(active)}')
            
        if traceback is not None:
            print("Error message saved in error.txt")
            open(os.path.join(self.log_dir, "error.txt"), "w").write(''.join(format_exception(exc_type, exc_value, traceback)))
        open(os.path.join(self.log_dir, "overall_time.txt"), "w").write(str(time.time() - self.start_time))
            
    ################################# public functions ########################################

    def check_final(self):
        """Check if the task has reached a final state, either by reaching the maximum steps or time, or because the agent has submitted a final answer. """
        
        curr_step = len(self.trace.steps)
        # check if any step is final answer
        any_final_answer = any([s.action.name == "Final Answer" for s in self.trace.steps])
        if curr_step >= self.args.agent_max_steps or any_final_answer or time.time() - self.start_time > self.args.max_time:
            print("curr_step", curr_step,"self.args.agent_max_steps", self.args.agent_max_steps,"any_final_answer", any_final_answer,"time.time() - self.start_time", time.time() - self.start_time,"self.args.max_time", self.args.max_time)
        return curr_step >= self.args.agent_max_steps or any_final_answer or time.time() - self.start_time > self.args.max_time

    def execute(self, action):
        """Execute an action and return the observation."""
        trace = self._trace

        curr_step = len(trace.steps)
        action_name = action.name
        action_input = action.args

        if action_name == "Final Answer":
            observation = "end"

        elif self.check_final():
            observation = "The environment has shut down because the maximum number of steps or time has been reached. Please submit your final answer."

        elif action_name not in list(self.action_infos.keys()):
            actions = ", ".join(self.action_infos.keys())
            observation = f"Invalid action: {action_name}. Action did not execute. Please use one of the following actions:\n{actions}"

        else:
            # execute the action and get the observation
            log_file = os.path.join(os.path.join(self.log_dir, "tool_logs") , f"step_{curr_step}_tool_log.log")
            usage = ",\n            ".join([f"{k}: [{v}]" for k, v in self.action_infos[action_name].usage.items()])
            usage = f"""{{
            {usage}
}}"""
            invalid_action_error = f"The action input for {action_name} needs to be a valid json with proper entries. You may have missed the comma between entries. Please use the correct format and try again:\n{usage}"

            if isinstance(action_input, dict):
                try:
                    observation = self.action_infos[action_name].function(**action_input, log_file=log_file, trace=trace, **self.static_kwargs_for_tools)
                except TooLongPromptError:
                    observation="EnvError: too long input for the tool"
                except LLMError as e:
                    observation = "LLMError: " + e.message
                except EnvException as e:
                    observation = "EnvError: " + e.message
                except TypeError as e:
                    print("Step: ", curr_step)
                    print("Error: ", e)
                    print("action_input",action_input)
                    observation = "EnvError: " + invalid_action_error
                except TimeoutException as e:
                    raise e
                except Exception as e:
                    # should not happen
                    print("Step: ", curr_step, file=sys.stderr)
                    print(e, file=sys.stderr)
                    if "Connection aborted." in str(e):
                        raise Exception("Connection aborted for crfm")
                    observation = f"EnvError: Error executing {action_name}."
            else:
                observation = invalid_action_error

        step_time = time.time()

        trace.steps.append(Step(action, observation, step_time))
        self.save(curr_step)
        return observation
    
    def save(self, curr_step):
        """ Save the trace and snapshot of the workspace folder """     
        trace = self.trace
        with open(os.path.join(self.log_dir, f"trace.json"), "w") as f:
            json.dump(trace, f, indent=4, cls=EnhancedJSONEncoder)

        ##### save a snapshot of the current step
        save_folder = os.path.join(self.log_dir, f"traces/step_{curr_step}_files")
        if os.path.exists(save_folder):
            safe_rmtree(save_folder)
        os.makedirs(save_folder)

        # save files in the folder that are not read only
        for path, subdirs, files in os.walk(os.path.join(self.work_dir)):

            relpath = os.path.relpath(path, self.work_dir)
            dest = os.path.join(save_folder, relpath)

            for file_name in files:
                file_path = os.path.join(relpath, file_name)
                if file_path not in self.read_only_files:
                    # check whether the file to copy is part of self.log_dir
                    if  os.path.abspath(os.path.join(self.work_dir, file_path)).startswith(os.path.abspath(self.log_dir.split("/env_log")[0])):
                        continue                    
                    if not os.path.exists(dest):
                        os.makedirs(dest)            
                    shutil.copyfile(os.path.join(self.work_dir, file_path), os.path.join(save_folder, file_path))

    ############## for logging convenience ##############

    def get_task_description(self):
        return self.research_problem, self.benchmark_folder_name

    @property
    def low_level_actions(self):
        return list(filter(lambda x: x.is_primitive, self.action_infos.values()))

    @property
    def high_level_actions(self):
        return list(filter(lambda x: not x.is_primitive, self.action_infos.values()))

    def print_action(self, entries):
        return "".join([ k + ": " + v for k,v in  entries.items()])


class ResumeEnvironment(Environment):
    def __init__(self, info, format_reward=False):
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        unique_id = uuid.uuid4().hex[:8]
        config_dict = {
            "task": info["task"],
            "resume": True,
            "resume_step": info["resume_step"],
            "resume_dir": info["workspace_dir"].replace("ppo_test", "../ppo_test"),
            "log_dir": info["workspace_dir"].replace("ppo_test", "../verl/logs_ppo_test")+f"/{timestamp}_{unique_id}",
            "work_dir": info["workspace_dir"].replace("ppo_test", "../verl/workspaces_ppo_test")+f"/{timestamp}_{unique_id}",
            "edit_script_llm_name": "Qwen2.5-Coder-32B-Instruct",
            "edit_script_llm_max_tokens": 4000,
            "template_name": info["template_name"],
            "agent_max_steps": 22,
            "max_time": 60*60,
        }

        args = argparse.Namespace(**config_dict)

        self.format_reward = format_reward
        super().__init__(info["task"], args)
        # TODO: copy data to work_dir if data is not in work_dir
    