import random

# reward_func.py
import os
import sys, random
import torch
import re
from envs import Agent, Action, ResumeEnvironment
from verl.utils.dataset.mla_dataset import MATCH_TASK_PATTERN_DICT, scale_rewards
from envs.MLAgentBench import high_level_actions

from concurrent.futures import ThreadPoolExecutor


def compute_score(data_source, solution, task_extra_info, cpus=None, ids=None):
    """Score one model *solution* for the task described by *task_extra_info*.

    A fresh ``ResumeAgent`` (and therefore a fresh environment) is built for
    every call, and its ``single_update`` result is returned unchanged.

    Args:
        data_source: Accepted for interface compatibility; not used here.
        solution: Raw completion text produced by the model.
        task_extra_info: Task description forwarded to ``ResumeAgent``.
        cpus: CPU allocation forwarded to ``ResumeAgent``.
        ids: Optional identifier used only in log output.

    Returns:
        The scaled reward produced by ``ResumeAgent.single_update``.
    """
    scorer = ResumeAgent(task_extra_info, cpus)
    print("\033[33mResumeAgent inited\033[0m")
    return scorer.single_update(solution, ids)

class ResumeAgent(Agent):
    """Agent that turns a single LLM completion into a scalar reward.

    The completion is parsed into an environment action. For the
    "Edit Script (AI)" action the edit is applied, the saved script is run, and
    the metric it prints (matched via ``MATCH_TASK_PATTERN_DICT[task]``) is
    compared against ``task_info["state"]``. The raw reward is then passed
    through ``scale_rewards``.
    """

    # Raw (pre-scaling) reward for an unparseable completion, or for a script
    # whose output contains no metric matching the task pattern.
    _FAILURE_REWARD = -100
    # Substring in the script output that marks a CUDA out-of-memory failure.
    _OOM_MARKER = "torch.cuda.OutOfMemoryError: CUDA out of memory"

    def __init__(self, task_info, cpus) -> None:
        """Build the environment for *task_info* and configure the edit model.

        Args:
            task_info: Task description dict. This class reads the keys
                ``"template_name"``, ``"task"`` and ``"state"``. It is also
                forwarded to ``ResumeEnvironment``.
            cpus: CPU allocation forwarded to ``ResumeEnvironment``.
        """
        self.task_info = task_info
        self.cpus = cpus
        resume_env = ResumeEnvironment(self.task_info, self.cpus)
        self.env = resume_env
        # Entries a completion must contain; consumed by parse_entries.
        self.env.args.valid_format_entries = [
            "Reflection",
            "Research Plan and Status",
            "Fact Check",
            "Thought",
            "Action",
            "Action Input",
        ]

        self.action_infos = resume_env.action_infos
        self.valid_format_entries = self.env.args.valid_format_entries

        # These are module attributes of high_level_actions, so they apply
        # process-wide, not only to this agent instance.
        high_level_actions.EDIT_SCRIPT_MODEL = "Qwen2.5-Coder-32B-Instruct"
        high_level_actions.EDIT_SCRIPT_MAX_TOKENS = 4000
        # NOTE(review): `global` rebinds FAST_MODEL in *this* module's globals
        # only. It does not patch a FAST_MODEL defined in any imported module.
        # Confirm whether another module was meant to be changed.
        global FAST_MODEL
        FAST_MODEL = "Qwen2.5-Coder-32B-Instruct"
        self.template_name = task_info["template_name"]
        super().__init__(self.env.args, self.env)

    def _parse_completion_action(self, completion):
        """Parse *completion* into ``(action_name, action_input)``.

        Raises:
            ValueError: if the parsed action is not an allowed tool name.
            Exception: whatever parse_entries / parse_action_input raise on
                malformed output (e.g. a missing entry key).
        """
        entries = self.parse_entries(completion, self.valid_format_entries)
        action = entries["Action"].strip()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if action not in self.prompt_tool_names:
            raise ValueError(f"Invalid action: {action}")
        raw_action_input = entries["Action Input"]
        action_input = self.parse_action_input(raw_action_input, self.action_infos[action])
        return action, action_input

    def _score_edited_script(self, task, action, action_input, ids):
        """Apply the edit action, run the saved script, and return the raw reward."""
        print("[single_update] | ready to edit script")
        self.env.execute(Action(action, action_input))
        print("[single_update] | edit script done")

        script_name = action_input["save_name"]
        print("[single_update] | ready to execute script")
        execute_result = self.env.execute(Action("Execute Script", {"script_name": script_name}))
        print("[single_update] | execute script done")

        if self._OOM_MARKER in execute_result:
            # OOM is treated as neutral rather than as a failure.
            return 0

        pattern = MATCH_TASK_PATTERN_DICT[task]
        # NOTE(review): float() below assumes the pattern has at most one
        # capture group, so findall yields strings rather than tuples. Confirm.
        task_strings = re.findall(pattern, execute_result)
        if task_strings:
            # The last printed metric is the final result; the reward is the
            # improvement over the previous state.
            reward = float(task_strings[-1]) - self.task_info["state"]
            print(f"\033[33m[single_update{ids}] | old state: {self.task_info['state']}, new state: {task_strings[-1]}, reward: {reward}\033[0m")
            return reward

        # No metric printed: log a single-line tail of the output and penalize.
        if len(execute_result) > 1000:
            execute_result = "..." + execute_result[-1000:]
        execute_result = execute_result.replace("\n", " ")
        print(f"\033[31m[single_update{ids}] | Error in executing script ({task}): {execute_result}\033[0m")
        return self._FAILURE_REWARD

    def single_update(self, completion, ids=None):
        """Score one LLM completion.

        Args:
            completion: Raw model output text.
            ids: Optional identifier used only in log output.

        Returns:
            ``scale_rewards(task, [raw])[0]``, where the raw reward is:

            * -100 if the completion cannot be parsed into a valid action, or
              if the edited script prints no metric matching the task pattern;
            * 0 for any action other than "Edit Script (AI)", or on CUDA OOM;
            * otherwise, the last printed metric minus ``task_info["state"]``.
        """
        task = self.task_info["task"]

        try:
            action, action_input = self._parse_completion_action(completion)
        except Exception as e:
            # Any parse failure is a format error, not a crash of the reward function.
            print(f"\033[33m[single_update] | Error in parsing: {e}\033[0m")
            reward = self._FAILURE_REWARD
        else:
            # Errors raised while editing/executing are deliberately NOT caught here.
            if action == "Edit Script (AI)":
                reward = self._score_edited_script(task, action, action_input, ids)
            else:
                reward = 0
        return scale_rewards(task, [reward])[0]


class TestAgent:
    """Environment-free helper that only applies ``scale_rewards``.

    Presumably used to exercise reward scaling without building a
    ``ResumeEnvironment``. TODO confirm with callers.
    """

    def __init__(self) -> None:
        # Task -> metric regex table, kept for callers that inspect it.
        self.pattern = MATCH_TASK_PATTERN_DICT

    def single_update(self, task="cifar10", reward=0):
        """Return *reward* scaled for *task*.

        ``scale_rewards`` works on a batch, so the single value is wrapped in a
        one-element list and the first result is returned.
        """
        scaled_batch = scale_rewards(task, [reward])
        return scaled_batch[0]
    