import argparse
import asyncio
import importlib.resources
import json
import logging
import os
import re
import random
import uuid
from datetime import datetime
from pathlib import Path
from time import sleep
import langgraph

import playwright
from tenacity import retry, stop_after_attempt, retry_if_exception_type

import gymnasium as gym
import browsergym.core

import browsergym.workarena  # register workarena tasks as gym environments
from dotenv import load_dotenv

from agentS import utils
from agentS.agents.agent_s import AgentSArgs
from agentS.parse_arguments import parse_arguments
from agentS.process_events import process_event
# from agentS.thread_safety_tools import IntegratedBrowserTools
from agentS.tool_processor import transform_action
from agentS.utils import process_action_history
from agentS.consts import STATELESS_ACTIONS, MODEL_NAME, DEFAULT_ENV, DEFAULT_LLM_TYPE, OPENENDED_TASK, \
    WEBARENA_SAFE_TASK, WEBARENA_TASK, WORKARENA_TASK, DEFAULT_WEB, DEFAULT_ARCHITECTURE
from webarenasafe.webvoyager.predefined_action_mapping import CODE

import yaml
from browsergym.core.env import BrowserEnv
from browsergym.workarena import ALL_WORKARENA_TASKS
from browsergym.webarena.instance import WebArenaInstance

import browsergym.webarena
from benchmark.tests.webarena.test_infeasible import test_infeasible, __SLOW_MO
from browsergym.webarena import ALL_WEBARENA_TASK_IDS
from langgraph.graph import END
import browsergym.webarenasafe

from pu_utils.main import analyze_current_page_sync
# class DemoAgent(Agent):
#     """A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities."""
import warnings

from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.python import PythonActionSet

from demo_agent.agents.legacy.dynamic_prompting import Flags

from demo_agent.agents.legacy import GenericAgentArgs
from demo_agent.agents.legacy.utils.chat_api import ChatModelArgs

from benchmark.example_loop import StepInfo, EnvArgs, ExpArgs

from data_collector import DataCollector

from benchmark.webarenasafe.webvoyager.agent import WebVoyagerAgent, WebVoyagerArgs

# Suppress the specific warnings
warnings.filterwarnings("ignore", message="WARN: env.chat to get variables from other wrappers is deprecated")
warnings.filterwarnings("ignore", message="WARN: env.shape to get variables from other wrappers is deprecated")
warnings.filterwarnings("ignore", category=DeprecationWarning, module="beartype")
warnings.filterwarnings("ignore", category=UserWarning, message="Field .* has conflict with protected namespace .*")
warnings.filterwarnings("ignore", category=UserWarning,
                        message="WARN: The obs returned by the `reset()` method is not within the observation space.")
warnings.filterwarnings("ignore", category=UserWarning,
                        message="WARN: env.page to get variables from other wrappers is deprecated")

__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True


class EvaluationFramework:
    def __init__(self, args):
        # self.agent = self.init_agent(args)
        load_dotenv()
        self.args = args
        self.SUPPORTED_ENVS = {WORKARENA_TASK: self.run_workarena, WEBARENA_SAFE_TASK: self.run_webarenasafe,
                               OPENENDED_TASK: self.run_opentask, WEBARENA_TASK: self.run_webarena}

        self.run_id = str(uuid.uuid4())
        # self.base_data_path = os.path.join('./data', self.run_id)
        self.base_data_path = os.path.join('./data')
        os.makedirs(self.base_data_path, exist_ok=True)
        self.data_collector = None

        self.env_args = EnvArgs(
            task_name=args.env_id,
            task_seed=None,
            max_steps=100,
            headless=args.headless,
            viewport={"width": 1500, "height": 1280},
            slow_mo=args.slow_mo,
        )

        if args.env_id == OPENENDED_TASK:
            self.env_args.wait_for_user_message = True
            self.env_args.task_kwargs = {"start_url": args.start_url}

    def init_data_collector(self, env_id, task_name, exp_i):
        self.data_collector = DataCollector(self.base_data_path, env_id, task_name, exp_i)


    def load_exp_args(self, policies=None):
        self.agent = self.init_agent(args, policies)

    def init_agent(self, args, policies):
        agent_obj = AgentSArgs(model_name=args.model_name, llm_type=args.llm_type, sync=args.sync,
                               architecture=args.architecture, env_policies=policies)
        agent_obj = WebVoyagerArgs(model_name=args.model_name, text_only=args.text_only)
        # agent_obj = DemoAgentArgs(args.llm_type, args.sync, args.architecture)

        return agent_obj.make_agent()

    def eval(self):
        try:
            self.SUPPORTED_ENVS[self.args.env_id]()
        except Exception as e:
            import traceback
            self.data_collector.record_failure(str(e), traceback.format_exc())
            print(f"Error: {str(e)}")
            # Print the traceback
            print(traceback.format_exc())
        finally:
            self.data_collector.save_to_csv()
            self.data_collector.save_to_json()

    def setup_webarena(self):
        pass

    @staticmethod
    def get_next_experiment_number(base_path, env_id, task_name):
        exp_path = os.path.join(base_path, env_id, task_name)
        if not os.path.exists(exp_path):
            return 1
        existing_exps = [d for d in os.listdir(exp_path) if
                         d.startswith('exp_') and os.path.isdir(os.path.join(exp_path, d))]
        if not existing_exps:
            return 1
        return max([int(d.split('_')[1]) for d in existing_exps]) + 1

    def run_webarena(self):

        self.setup_webarena()
        action_set = PythonActionSet()

        tasks = ALL_WEBARENA_TASK_IDS
        for task in tasks:
            env_id = self.args.env_id.split('.')[0]
            exp_i = self.get_next_experiment_number(self.base_data_path, env_id, task)
            self.init_data_collector(env_id, task, exp_i)

            task_data = {
                'task_name': str(task),
                'start_time': datetime.now().isoformat()
            }
            env_id = f"{self.args.env_id}.{task.split('.')[1]}"
            print("Task:", task)
            env = gym.make(
                env_id,
                headless=False,
                action_mapping_predefined=args.action_mapping_predefined_code,
                # action_mapping=action_set.to_python_code,
                feedback_collecting=True,
                timeout=30000,

            )

            obs, info = env.reset()

            self.load_exp_args()

            loop_data = self.agent_loop(env, obs, info, self.args.max_steps)

            task_data.update(loop_data)

            task_data.update({
                'reward': '_',
                'stop': '_',
                'validation_message': '_',
                'validation_info': info,
                'end_time': datetime.now().isoformat()
            })
            self.data_collector.collect_data(task_data)
            self.data_collector.save_checkpoint()

            sleep(3)
            env.close()

    def run_opentask(self):
        task = self.args.input

        env_id = self.args.env_id
        exp_i = self.get_next_experiment_number(self.base_data_path, env_id, task)
        self.init_data_collector(env_id, task, exp_i)

        task_data = {
            'task_name': str(task),
            'start_time': datetime.now().isoformat()
        }

        ##### Initialize the agent #####
        self.load_exp_args()

        print("Task:", task)

        env = gym.make(
            self.args.env_id,
            task_kwargs={"start_url": self.args.start_url},
            headless=False,
            action_mapping_predefined=args.action_mapping_predefined_code,
            feedback_collecting=True
        )
        obs, info = env.reset()

        task_data['initial_observation'] = obs

        # Cheat functions use Playwright to automatically solve the task
        env.chat.add_message(role="assistant", msg="On it. Please wait...")
        obs["goal"] = task  # Since they use the user message as the goal and we're simulating the user message
        loop_data = self.agent_loop(env, obs, info, self.args.max_steps)

        task_data.update(loop_data)

        task_data.update({
            'reward': '_',
            'stop': '_',
            'validation_message': '_',
            'validation_info': info,
            'end_time': datetime.now().isoformat()
        })
        self.data_collector.collect_data(task_data)
        self.data_collector.save_checkpoint()

        sleep(3)
        env.close()

    def run_workarena(self):
        tasks = random.sample(ALL_WORKARENA_TASKS, len(ALL_WORKARENA_TASKS))
        specific_tasks = None if not self.args.specific_tasks else self.args.specific_tasks.split(',')
        if self.args.specific_tasks:
            tasks = [task for task in ALL_WORKARENA_TASKS if str(task) in specific_tasks]
            # tasks = [task for task in ALL_WORKARENA_TASKS if str(task) not in specific_tasks]

        for task in tasks:
            env_id = self.args.env_id.split('.')[0]
            exp_i = self.get_next_experiment_number(self.base_data_path, env_id, task)
            self.init_data_collector(env_id, task, exp_i)

            ##### Initialize the agent #####
            self.load_exp_args()

            task_data = {
                'task_name': str(task),
                'start_time': datetime.now().isoformat()
            }

            print("Task:", task)

            # Instantiate a new environment
            env = BrowserEnv(task_entrypoint=task,
                             headless=False, timeout=30000,
                             action_mapping_predefined=args.action_mapping_predefined_code, feedback_collecting=True, )
            # action_mapping=HighLevelActionSet(subsets=['bid', 'chat'], multiaction=True).to_python_code)

            obs, info = env.reset()

            task_data['initial_observation'] = obs

            # Cheat functions use Playwright to automatically solve the task
            env.chat.add_message(role="assistant", msg="On it. Please wait...")

            loop_data = self.agent_loop(env, obs, info, self.args.max_steps)

            task_data.update(loop_data)

            reward, stop, message, info = env.task.validate(env.page, [])
            if reward == 1:
                env.chat.add_message(role="user", msg="Yes, that works. Thanks!")
            else:
                env.chat.add_message(role="user", msg=f"No, that doesn't work. {info.get('message', '')}")

            task_data.update({
                'reward': reward,
                'stop': stop,
                'validation_message': message,
                'validation_info': info,
                'end_time': datetime.now().isoformat()
            })
            self.data_collector.collect_data(task_data)

            self.data_collector.save_checkpoint()

            sleep(3)
            env.close()

    def run_webarenasafe(self):
        current_file_path = os.path.abspath(__file__)
        extension_path = os.path.join(current_file_path, "../pu_utils/prod")
        action_set = self.get_action_set()

        # TODO: Implement specific tasks selection for webarenasafe
        # tasks = random.sample(ALL_WEBARENA_TASK_IDS, len(ALL_WEBARENA_TASK_IDS))
        # specific_tasks = None if not self.args.specific_tasks else self.args.specific_tasks.split(',')
        # if self.args.specific_tasks:
        #     tasks = [task for task in ALL_WEBARENA_TASK_IDS if str(task) in specific_tasks]
        #     # tasks = [task for task in ALL_WORKARENA_TASKS if str(task) not in specific_tasks]

        tasks = browsergym.webarenasafe.ALL_WEBARENA_TASK_IDS  # TODO: Sammi please provide the tasks of webarenasafe
        for idx, task in enumerate(tasks):
            # if idx < 0:
            #     continue
            
            env_id = self.args.env_id.split('.')[0]
            exp_i = self.get_next_experiment_number(self.base_data_path, env_id, task)
            self.init_data_collector(env_id, task, exp_i)

            task_data = {
                'task_name': str(task),
                'start_time': datetime.now().isoformat()
            }

            print("Task:", task)
            args.headless = False
            env = gym.make(task,
                           enable_nocodeui_pu=True, headless=args.headless,
                           action_mapping=None if self.args.action_mapping_predefined else action_set,
                           action_mapping_predefined=action_set if self.args.action_mapping_predefined else None,
                           pw_extra_args=[f"--disable-extensions-except={extension_path}",
                                          f"--load-extension={extension_path}"],
                           feedback_collecting=True,
                           timeout=30000)

            obs, info = env.reset()

            ###### Initialize the agent #####
            self.load_exp_args()

            task_data['initial_observation'] = obs

            # Cheat functions use Playwright to automatically solve the task
            env.chat.add_message(role="assistant", msg="On it. Please wait...")

            loop_data = self.agent.agent_loop(env, obs, info, self.args.max_steps)

            task_data.update(loop_data)

            task_data.update({
                'end_time': datetime.now().isoformat()
            })
            self.data_collector.collect_data(task_data)
            self.data_collector.save_to_csv()
            # self.data_collector.save_checkpoint()

            sleep(3)
            env.close()

    @retry(
        stop=stop_after_attempt(5),
        retry=retry_if_exception_type(playwright.sync_api.TimeoutError),
        reraise=True,
        before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
    )
    def agent_loop(self, env, obs, info, max_steps):
        page = env.page
        print(f"Task goal: {obs['goal']}\n")

        pointer_env = self.get_pointer_env(env)

        state = {
            "next": "",  # Initialize with an empty string or appropriate default
            "pages": [],  # Initialize with an empty list
            "page": page,
            "input": obs["goal"],
            "prediction": None,  # Initialize with None or create a default Prediction object
            "scratchpad": [],
            "observation": "",  # Initialize with an empty string
            "img": "",  # Initialize with an empty string or generate a base64 encoded screenshot
            "annotations": obs,
            "extension_obj": None,  # Initialize with None or self.extension_obj if available
            "actions": "",  # Initialize with an empty string
            "url": page.url,  # Get the current page URL
            "elements": "",  # Initialize with an empty string or fetch elements if possible
            "messages": [],  # Initialize with an empty list
            "env": env,  # Pass the environment object
            "pointer_env": pointer_env,
            "policy": "",
            "update_policy_reason": "First policy to be created",
            "read_page": "",  # Initialize with an empty string or fetch the outer text of the page
            "env_policies": obs.get("policies", ""),
        }

        # thread = {"configurable": {"thread_id": "1"}}

        event_stream = self.agent.graph.stream(
            state,
            # thread,
            # stream_mode="values"
            {"recursion_limit": max_steps},
        )

        loop_data = {
            'steps': [],
            'total_reward': 0,
            'terminated': False,
            'truncated': False,
            'agent_predictions': []  # New field to store agent predictions
        }

        done = False
        violated_policies = []

        for event in event_stream:  # The  agentic system loop
            if END in event:
                print(event)
                done = True
                break

            prediction, has_tool_call = process_event(event)
            print(prediction)

            # Collect agent predictions
            agent_name = next(iter(event.keys()))  # Get the first (and only) key in the event dict
            if agent_name != END:
                agent_prediction = {
                    'agent_name': agent_name,
                    'prediction': prediction
                }

            if has_tool_call:
                tools = event["ActionAgent"].get("messages")[-1].tool_calls
                action = "\n".join([transform_action(act) for act in tools])
                agent_prediction['actions'] = action

                print("#" * 50)
                print(f"\nExecuting actions:\n\n{action}\n")
                print("#" * 50)
                if 'answer' in tools[0]['name']:
                    answer = tools[0]['args']['state']
                    # id answer is a dict take the first value
                    if isinstance(answer, dict):
                        answer = list(answer.values())[0]
                    print(f"Final answer:\n\n {answer}\n")
                    done = True if self.args.env_id == OPENENDED_TASK else False  # We are done if it is an openended task and we have an answer since for other envs we need to validate the answer with env.step

                if 'update_policy' not in [tl['name'] for tl in tools] and not done:

                    # action = transform_action(action)
                    obs, reward, terminated, truncated, info = env.step(action)

                    if "safety_report" in info:
                        for policy_report in info['safety_report']:
                            if policy_report['violated']:
                                violated_policies.append(policy_report)

                    agent_prediction['actions_feedback'] = process_action_history(pointer_env.feedback)

                    step_data = {
                        'action': action,
                        'observation': obs,
                        'reward': reward,
                        'terminated': terminated,
                        'truncated': truncated,
                        # 'info': info
                    }
                    loop_data['steps'].append(step_data)
                    loop_data['total_reward'] += reward
                    loop_data['terminated'] = terminated
                    loop_data['truncated'] = truncated
                    loop_data['task_info_gt'] = info if not 'task_info' in loop_data else loop_data[
                                                                                              'task_info_gt'] + info

                    done = terminated or truncated

                    print(f"Additional Info: {info}\n")
                    if done:
                        print("Environment terminated or truncated.\n")
                        print(f"Reward: {reward}, Terminated: {terminated}, Truncated: {truncated}\n")
                        return loop_data

                    if done:
                        break

            loop_data['agent_predictions'].append(agent_prediction)  # update for each agent prediction
            print("-" * 50)

        print("Violated policies", violated_policies)
        print("Success under policy", len(violated_policies) == 0 * reward == 1.0)

        if done:
            print("Environment terminated or truncated.\n")
            print(f"Reward: {reward}, Terminated: {terminated}, Truncated: {truncated}\n")
            return loop_data

        print("Agent loop completed.\n")
        return loop_data

    def get_action_set(self):
        if not self.args.action_mapping_predefined:
            send_message_to_user: callable = None

            def finish(message):  # Need to verify that its suitable to Ido and Sammi conventions
                """
                When the task is done, this function should be called

                Examples:
                    finish("I finished the task.")
                    finish("I finished the task, the answer is 'value'")
                """
                send_message_to_user(message)

            action_set = HighLevelActionSet(custom_actions=[finish], subsets=["bid", "chat", 'custom'], strict=False,
                                            multiaction=True, demo_mode='off').to_python_code
        else:
            action_set = args.action_mapping_predefined_code
        return action_set

    @staticmethod
    def get_pointer_env(env):
        # For every task except WorkArena tasks env has a wrapper object env.env.env
        if hasattr(env, 'spec'):
            if env.spec.id.split('.')[0] in [OPENENDED_TASK, WEBARENA_TASK, WEBARENA_SAFE_TASK]:
                pointer_env = env.env.env
            else:
                pointer_env = env
        else:
            pointer_env = env

        return pointer_env


def main_sync(args):
    eval_framework = EvaluationFramework(args)
    print("Starting evaluation...")
    eval_framework.eval()
    print("Evaluation completed.")


if __name__ == '__main__':
    argparse.ArgumentParser()
    parser = argparse.ArgumentParser(description='Run the agent')
    args = parse_arguments(parser)

    ################ Override the default arguments ################

    # Some changes for running the code, to be removed
    args.architecture = DEFAULT_ARCHITECTURE
    args.env_id = WORKARENA_TASK
    # args.specific_tasks = "<class 'browsergym.workarena.tasks.form.CreateChangeRequestTask'>"
    # args.specific_tasks = "<class 'browsergym.workarena.tasks.form.CreateHardwareAssetTask'>"

    args.env_id = OPENENDED_TASK
    args.input = 'What is the most cited paper by Segev Sholomov, and where does he work?'
    args.start_url = DEFAULT_WEB

    # web arena
    args.env_id = WEBARENA_TASK

    args.env_id = WEBARENA_SAFE_TASK

    args.action_mapping_predefined = True  # To use the code in predefined_action_mapping.py
    args.action_mapping_predefined_code = CODE  # The predefined action mapping code

    args.text_only = True
    args.max_steps = 20

    # asyncio.run(main_async(args))
    main_sync(args)
