import gzip
import importlib.metadata
import json
import logging
import os
from dotenv import load_dotenv
import pickle
import sys
import time
import traceback
import uuid

from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass, field, is_dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional

import gymnasium as gym
import numpy as np
from browsergym.core.chat import Chat
from PIL import Image
from tqdm import tqdm

from browsergym.experiments.agent import Agent
from browsergym.experiments.utils import count_messages_token, count_tokens

from browsergym.experiments import (
    ExpArgs, 
    EnvArgs, 
    AbstractAgentArgs,
    )

logger = logging.getLogger(__name__)

# Load environment variables from the .env file at import time, so that values
# read later via os.getenv (e.g. EXTENSION_PATH in CustomEnvArgs) can come from it.
load_dotenv()

class ConfigurationError(Exception):
    """Raised when experiment or environment arguments are missing or inconsistent."""

@dataclass
class CustomEnvArgs(EnvArgs):
    """BrowserGym ``EnvArgs`` extended with extra options forwarded to ``gym.make``.

    Attributes (in addition to those of EnvArgs):
    -----------
    enable_nocodeui_pu: bool
        Forwarded to the environment. When True, ``__post_init__`` requires the
        EXTENSION_PATH environment variable and a valid ``pw_extra_args``.
    action_mapping_predefined: str
        If non-empty, forwarded to the environment and the agent's action
        mapping is NOT passed (``action_mapping=None``).
    pw_extra_args: list
        Forwarded as ``pw_extra_args`` (presumably extra Playwright launch
        arguments -- confirm against the environment implementation).
    feedback_collecting: bool
        Forwarded to the environment.
    task_kwargs: dict
        Forwarded as ``task_kwargs`` only when not None (otherwise BrowserGym's
        default is used).
    """

    enable_nocodeui_pu: bool = True
    action_mapping_predefined: str = ""
    pw_extra_args: list = field(default_factory=list)
    feedback_collecting: bool = True
    task_kwargs: dict = None  # use default value from BrowserGym

    def __post_init__(self):
        """Validate the configuration required when ``enable_nocodeui_pu`` is True.

        Raises:
            EnvironmentError: if the EXTENSION_PATH environment variable is unset
                or empty.
            ConfigurationError: if ``pw_extra_args`` is empty, or if any of its
                options ends with "None" (presumably an option formatted from a
                missing value).
        """
        if not self.enable_nocodeui_pu:
            return

        # EXTENSION_PATH may come from the process environment or from the .env
        # file loaded at import time; an empty value is treated as unset.
        if not os.getenv("EXTENSION_PATH"):
            raise EnvironmentError(
                "EXTENSION_PATH must be set in the environment when enable_nocodeui_pu is True."
            )

        # pw_extra_args must be provided ...
        if not self.pw_extra_args:
            raise ConfigurationError("Invalid pw_extra_args: pw_extra_args must be set.")

        # ... and every option must carry an actual value
        unset_opts = [opt for opt in self.pw_extra_args if opt.endswith("None")]
        if unset_opts:
            raise ConfigurationError(
                f"Invalid pw_extra_args: option(s) {unset_opts} end with 'None' (unset value)."
            )

    def make_env(self, action_mapping, exp_dir):
        """Create the gym environment for ``self.task_name``.

        Args:
            action_mapping: the agent's action-to-code mapping. It is ignored
                (``None`` is passed) when ``action_mapping_predefined`` is set.
            exp_dir: directory used for video recording when ``record_video``.

        Returns:
            The environment returned by ``gym.make``.
        """
        extra_kwargs = {}
        if self.record_video:
            extra_kwargs["record_video_dir"] = exp_dir
        if self.viewport:
            extra_kwargs["viewport"] = self.viewport
        if self.slow_mo is not None:
            extra_kwargs["slow_mo"] = self.slow_mo
        if self.storage_state:
            extra_kwargs["pw_context_kwargs"] = {"storage_state": self.storage_state}
        if self.task_kwargs is not None:
            extra_kwargs["task_kwargs"] = self.task_kwargs

        return gym.make(
            _get_env_name(self.task_name),
            disable_env_checker=True,
            max_episode_steps=self.max_steps,
            headless=self.headless,
            wait_for_user_message=self.wait_for_user_message,
            # the agent's mapping is only used when no predefined mapping is given
            action_mapping=None if self.action_mapping_predefined else action_mapping,
            enable_nocodeui_pu=self.enable_nocodeui_pu,
            pw_extra_args=self.pw_extra_args,
            action_mapping_predefined=self.action_mapping_predefined or None,
            feedback_collecting=self.feedback_collecting,
            **extra_kwargs,
        )


def save_package_versions(exp_dir: Path):
    """Write ``package_versions.txt`` into ``exp_dir``.

    The file lists every installed distribution as ``name==version``, one per
    line, sorted alphabetically, with no trailing newline.
    """
    entries = []
    for dist in importlib.metadata.distributions():
        meta = dist.metadata
        entries.append(f'{meta["Name"]}=={meta["Version"]}')
    entries.sort()

    target = exp_dir / "package_versions.txt"
    target.write_text("\n".join(entries))


@dataclass
class CustomExpArgs(ExpArgs):
    """Arguments to run an experiment, i.e. run agent in an environment until done.

    This dataclass is used to store experiments arguments. It contains
    agent_args and env_args which follows the same principle. It contains helper
    functions to prepare and run experiments.

    Unlike the base class, ``run`` treats ``env_args.task_name`` as a benchmark
    prefix and may run several registered tasks (see ``run``).

    Attributes:
    -----------
    agent_args: AbstractAgentArgs
        The arguments to instantiate the agent.
    env_args: EnvArgs
        The arguments to instantiate the environment.
    exp_dir: str
        The directory where the experiment will be saved.
    exp_name: str
        The name of the experiment. If None, it will be generated from the
        agent and environment names.
    enable_debug: bool
        If python is running in debug mode and `enable_debug` is True, errors
        will be raised instead of only logged
    error_msg: str
        Error that occurred while running the experiment (if any).
    stack_trace: str
        Stack trace of the error (if any).
    order: int (internal)
        The order of the experiment in the batch. It is used to keep track of
        the original order of the experiments in case they are shuffled.
    new_flag: bool
        Extra flag declared by this subclass. It is not read by ``run``;
        NOTE(review): confirm where (if anywhere) it is consumed.
    task_index_range: tuple[int, int] | None
        Optional inclusive ``(first, last)`` bounds on the position of a task in
        the list of environment ids resolved by ``run``. Tasks outside the range
        are skipped. ``None`` (the default) runs every task.
    """

    new_flag: bool = False
    task_index_range: Optional[tuple[int, int]] = None

    def run(self):
        """Run the agent on every task of the benchmark and save the results.

        ``env_args.task_name`` is interpreted as follows:

        - If it contains a digit, it is taken as a single task id
          (``browsergym/<task_name>``) and results are written directly to
          ``exp_dir``.
        - Otherwise every registered environment whose id starts with
          ``browsergym/<task_name>`` is run, each in its own sub-directory
          ``exp_dir/<task>``.

        Side effects: ``self.exp_dir`` and ``self.env_args.task_name`` are
        overwritten for each task. ``self.err_msg`` and ``self.stack_trace``
        keep the last error seen.
        """
        # start writing logs to the root logfile for the setup phase
        self._set_logger()
        self.exp_root_dir = Path(self.exp_dir)
        self.benchmark = self.env_args.task_name

        # Called for its side effect only: lazily imports the benchmark package so
        # its environments are registered in gym's registry before the lookup below.
        _get_env_name(self.benchmark)

        single_task = any(char.isdigit() for char in self.benchmark)
        if single_task:
            env_ids = [f"browsergym/{self.benchmark}"]
        else:
            prefix = f"browsergym/{self.benchmark}"
            env_ids = [env_id for env_id in gym.envs.registry.keys() if env_id.startswith(prefix)]

        if self.task_index_range is not None:
            first, last = self.task_index_range
            env_ids = [env_id for idx, env_id in enumerate(env_ids) if first <= idx <= last]

        if not env_ids:
            logger.warning(f"No task found for benchmark {self.benchmark}.")

        # Detach the setup-phase log handler before the per-task ones are attached.
        # Otherwise it is never removed: _set_logger is called again below, and the
        # finally block only undoes the latest call. In single-task mode two
        # handlers would also write every line twice to the same experiment.log.
        # NOTE(review): assumes _set_logger/_unset_logger are a balanced add/remove
        # of a file handler, as in BrowserGym's ExpArgs.
        self._unset_logger()

        for task in env_ids:
            task_name = task.replace("browsergym/", "")
            self.exp_dir = self.exp_root_dir if single_task else self.exp_root_dir / task_name
            self.exp_dir.mkdir(parents=True, exist_ok=True)
            # start writing logs to this task's logfile
            self._set_logger()
            self.env_args.task_name = task

            # log python environment info
            save_package_versions(self.exp_dir)

            episode_info = []
            env, step_info, err_msg, stack_trace = None, None, None, None
            try:
                logger.info(f"Running experiment {self.exp_name} in:\n  {self.exp_dir}")
                agent = self.agent_args.make_agent()
                logger.debug("Agent created.")
                env = self.env_args.make_env(
                    action_mapping=agent.action_set.to_python_code,
                    exp_dir=self.exp_dir,
                )
                logger.debug("Environment created.")

                step_info = CustomStepInfo(step=0)
                episode_info = [step_info]
                step_info.from_reset(
                    env, seed=self.env_args.task_seed, obs_preprocessor=agent.obs_preprocessor
                )
                logger.debug("Environment reset.")

                while not step_info.is_done:  # set a limit
                    logger.debug(f"Starting step {step_info.step}.")
                    action = step_info.from_action(agent)
                    logger.debug(f"Agent chose action:\n {action}")

                    if action is None:
                        # the episode ends once this step has been saved
                        step_info.truncated = True

                    step_info.save_step_info(self.exp_dir)
                    logger.debug("Step info saved.")

                    _send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
                    logger.debug("Chat info sent.")

                    # Stop *before* creating the next step. An empty trailing step
                    # (never sent to the env, terminated/truncated still None) would
                    # be reported as "Early termination??" in the finally block.
                    if action is None:
                        logger.debug("Agent returned None action. Ending episode.")
                        break

                    step_info = CustomStepInfo(step=step_info.step + 1)
                    episode_info.append(step_info)

                    logger.debug("Sending action to environment.")
                    step_info.from_step(env, action, obs_preprocessor=agent.obs_preprocessor)
                    logger.debug("Environment stepped.")

            except Exception as e:
                err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
                stack_trace = traceback.format_exc()

                self.err_msg = err_msg
                self.stack_trace = stack_trace

                logger.warning(err_msg + "\n" + stack_trace)
                if _is_debugging() and self.enable_debug:
                    raise

            finally:
                # Each cleanup step is isolated so one failure does not skip the others.
                try:
                    if step_info is not None:
                        step_info.save_step_info(self.exp_dir)
                except Exception as e:
                    logger.error(f"Error while saving step info in the finally block: {e}")
                try:
                    if (
                        not err_msg
                        and len(episode_info) > 0
                        and not (episode_info[-1].terminated or episode_info[-1].truncated)
                    ):
                        e = KeyboardInterrupt("Early termination??")
                        err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
                    _save_summary_info(episode_info, self.exp_dir, err_msg, stack_trace)
                except Exception as e:
                    logger.error(f"Error while saving summary info in the finally block: {e}")
                try:
                    if env is not None:
                        env.close()
                except Exception as e:
                    logger.error(f"Error while closing the environment in the finally block: {e}")
                try:
                    self._unset_logger()  # stop writing logs to run logfile
                except Exception as e:
                    logger.error(f"Error while unsetting the logger in the finally block: {e}")


@dataclass
class StepTimestamps:
    """Timestamps (seconds) recorded for one step.

    env_* and agent_* are taken with time.time() in CustomStepInfo.from_reset,
    from_step and from_action. action_exec_* are copied from the environment's
    info dict (or derived from env_* after a reset).
    """

    env_start: float = 0  # just before env.reset / env.step is called
    action_exec_start: float = 0  # to extract beginning of visual action from video
    action_exec_stop: float = 0  # to extract end of visual action from video
    # raw env_info["action_exec_stop"], i.e. action_exec_stop before the action
    # timeout is subtracted. NOTE(review): the name looks like a typo of
    # "action_exec_after_timeout"; renaming would break code/pickles that use it.
    action_exect_after_timeout: float = 0
    env_stop: float = 0  # just after env.reset / env.step returns
    agent_start: float = 0  # just before agent.get_action
    agent_stop: float = 0  # just after agent.get_action


@dataclass
class CustomStepInfo:
    """Collects information about step that will be saved and reloaded.
    Helper functions only modify the dataclass attributes and helps keeping the
    information organized.

    Attributes:
    -----------
    step: int
        The step number of the episode.
    obs: dict
        The observation of the environment.
    reward: float
        The reward of the step.
    raw_reward: float
        The raw reward of the step (``env_info["RAW_REWARD_GLOBAL"]``).
    terminated: bool
        Whether the episode is terminated i.e. reached a terminal state.
    truncated: bool
        Whether the episode is truncated i.e. reached a maximum number of steps.
    action: str
        The action taken by the agent.
    agent_info: dict
        Additional information from the agent.
    stats: dict
        Extra statistics about the step.
    profiling: StepTimestamps
        Timestamps of the different events during the episode.
    task_info: dict
        The ``"task_info"`` entry of the env info returned by ``env.step``
        (None if absent, or if the step was never sent to the environment).
    violated_policies: list
        The reports of the env info's ``"safety_report"`` whose ``"violated"``
        flag is truthy. It is empty if no report was given or the step was
        never sent to the environment.
    """

    step: int = None
    obs: dict = None
    reward: float = 0
    raw_reward: float = 0
    terminated: bool = None
    truncated: bool = None
    action: str = None
    agent_info: dict = field(default_factory=dict)
    stats: dict = None
    profiling: StepTimestamps = field(default_factory=StepTimestamps)
    task_info: dict = None
    # Declared as a field (instead of only being assigned in from_reset/from_step)
    # so that steps which never reach the environment still have the attribute.
    violated_policies: list = field(default_factory=list)

    def from_step(self, env: gym.Env, action: str, obs_preprocessor: callable):
        """Send ``action`` to ``env`` and record the transition, env info and timings."""
        t = self.profiling
        t.env_start = time.time()
        self.obs, self.reward, self.terminated, self.truncated, env_info = env.step(action)
        t.env_stop = time.time()

        # keep only the safety-policy reports that flag a violation
        self.violated_policies = [
            policy_report
            for policy_report in env_info.get("safety_report") or []
            if policy_report["violated"]
        ]

        self.task_info = env_info.get("task_info", None)

        self.raw_reward = env_info.get("RAW_REWARD_GLOBAL", None)

        t.action_exec_start = env_info["action_exec_start"]  # start
        t.action_exect_after_timeout = env_info["action_exec_stop"]
        t.action_exec_stop = env_info["action_exec_stop"] - env_info["action_exec_timeout"]

        if obs_preprocessor:
            self.obs = obs_preprocessor(self.obs)

    def from_action(self, agent: Agent):
        """Query ``agent`` for an action on a copy of the current observation.

        Records the agent timings and computes ``stats``, then returns the action.
        """
        self.profiling.agent_start = time.time()
        self.action, self.agent_info = agent.get_action(self.obs.copy())
        self.profiling.agent_stop = time.time()

        self.make_stats()

        return self.action

    def from_reset(self, env: gym.Env, seed: int, obs_preprocessor: callable):
        """Reset ``env`` with ``seed`` and record the initial observation and timings."""
        t = self.profiling
        t.env_start = time.time()
        self.obs, env_info = env.reset(seed=seed)
        self.reward, self.terminated, self.truncated = 0, False, False
        self.violated_policies = []
        t.env_stop = time.time()

        t.action_exec_start = env_info.get("recording_start_time", t.env_start)
        t.action_exect_after_timeout = t.env_stop
        t.action_exec_stop = t.env_stop

        if obs_preprocessor:
            self.obs = obs_preprocessor(self.obs)

    @property
    def is_done(self):
        """True once the step is terminated or truncated."""
        return self.terminated or self.truncated

    def make_stats(self):
        """Fill ``stats`` with token counts, agent-provided stats and elapsed times.

        Note: the ``"stats"`` entry is popped from ``agent_info``.
        """
        stats = {
            f"n_token_{key}": count_tokens(val)
            for key, val in self.obs.items()
            if isinstance(val, str)
        }
        stats.update(self.agent_info.pop("stats", {}))

        messages = self.agent_info.get("chat_messages", None)
        if messages is not None:
            stats["n_token_agent_messages"] = count_messages_token(messages)

        t = self.profiling
        stats["step_elapsed"] = t.env_stop - t.env_start
        stats["agent_elapsed"] = t.agent_stop - t.agent_start

        self.stats = stats

    def save_step_info(self, exp_dir, save_json=False, save_jpg=True):
        """Persist this step into ``exp_dir``.

        Always writes ``step_<n>.pkl.gz``. Writes ``<name>_step_<n>.jpg`` for the
        screenshots present in ``obs`` when ``save_jpg``. When ``save_json``,
        writes ``steps_info.json``; note that this file name does not include
        the step number, so each call overwrites it.
        """
        with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f:
            pickle.dump(self, f)

        if save_jpg and self.obs is not None:
            for name in ("screenshot", "screenshot_som"):
                if name in self.obs:
                    img = Image.fromarray(self.obs[name])
                    img.save(exp_dir / f"{name}_step_{self.step}.jpg")

        if save_json:
            with open(exp_dir / "steps_info.json", "w") as f:
                json.dump(self, f, indent=4, cls=DataclassJSONEncoder)


def _extract_err_msg(episode_info: list[CustomStepInfo]):
    """Return ``(err_msg, stack_trace)`` reported by the latest erroring step.

    Steps whose ``agent_info`` is None, or has no ``"err_msg"``, are ignored.
    Returns ``(None, None)`` when no step reported an error.
    """
    latest = (None, None)
    for info in episode_info:
        agent_info = info.agent_info
        if agent_info is None:
            continue
        message = agent_info.get("err_msg", None)
        if message is None:
            continue
        latest = (message, agent_info.get("stack_trace", None))
    return latest


def _aggregate_episode_stats(episode_info: list[CustomStepInfo]):
    """Aggregate CustomStepInfo.stats across episodes.

    For every stat key, emit ``cum_<key>`` (NaN-ignoring sum) and ``max_<key>``
    (NaN-ignoring max), plus ``cum_steps`` so a mean can be derived. ``None``
    stat values are treated as NaN. Numpy scalars are converted to Python
    scalars, and NaN results are replaced by ``None``. If more summaries are
    needed, the user can reload the individual CustomStepInfo.
    """
    # The final step was never seen by the agent, so it is excluded.
    seen_steps = episode_info[:-1]

    collected = defaultdict(list)
    for info in seen_steps:
        if info.stats is None:
            continue
        for name, value in info.stats.items():
            collected[name].append(np.nan if value is None else value)

    result = {"cum_steps": len(seen_steps)}  # to be able to compute the mean
    for name, values in collected.items():
        result[f"cum_{name}"] = np.nansum(values)
        result[f"max_{name}"] = np.nanmax(values)

    for name in list(result):
        raw = result[name]
        if isinstance(raw, np.generic):
            result[name] = raw.item()
        if np.isnan(raw):
            result[name] = None
    return result


def _save_summary_info(
    episode_info: list[CustomStepInfo],
    exp_dir,
    err_msg,
    stack_trace,
):
    """Aggregate an episode and write it to ``exp_dir / "summary_info.json"``.

    Args:
        episode_info: the steps of the episode, in order. It may be mutated:
            a placeholder ``CustomStepInfo()`` is appended when it is empty,
            and the last step's ``agent_info`` receives ``err_msg`` and
            ``stack_trace`` when ``err_msg`` is given.
        exp_dir: directory (``Path``) where the summary file is written.
        err_msg: error raised by the agent/environment loop, or None. When
            None, the last error reported in a step's ``agent_info`` is used.
        stack_trace: stack trace matching ``err_msg``.
    """
    if len(episode_info) == 0:
        # nothing was recorded; a placeholder keeps the indexing below valid
        episode_info.append(CustomStepInfo())

    # bring err from agent_info to the top level
    if err_msg is None:
        err_msg, stack_trace = _extract_err_msg(episode_info)
    else:
        # useful until we get a proper place in agent_xray to view error
        # messages.
        episode_info[-1].agent_info["err_msg"] = err_msg
        episode_info[-1].agent_info["stack_trace"] = stack_trace

    # Use the most recent step that received task_info from the environment.
    # The last step may not have one (e.g. only reset was done, or the episode
    # stopped before stepping), in which case its task_info is None.
    task_info = next(
        (step.task_info for step in reversed(episode_info) if step.task_info is not None),
        None,
    ) or {}
    per_task_scores = task_info.get("score_seperated")
    if per_task_scores:
        cum_per_task_reward = sum(v["reward"] for v in per_task_scores) / len(per_task_scores)
    else:
        cum_per_task_reward = 0

    # the first obs is None if the environment reset never completed
    first_obs = episode_info[0].obs
    summary_info = dict(
        goal=first_obs.get("goal") if first_obs else None,
        n_steps=len(episode_info) - 1,
        cum_reward=sum([step.reward for step in episode_info]),
        cum_per_task_reward=cum_per_task_reward,
        cum_raw_reward=sum([step.raw_reward for step in episode_info if step.raw_reward]),
        err_msg=err_msg,
        stack_trace=stack_trace,
    )
    for key, val in _aggregate_episode_stats(episode_info).items():
        summary_info[f"stats.{key}"] = val

    summary_info["terminated"] = episode_info[-1].terminated
    summary_info["truncated"] = episode_info[-1].truncated

    # Steps that never went through from_reset/from_step may lack the attribute
    # (or come from older pickles), so fall back to an empty list.
    violations_per_step = [
        (step, getattr(step, "violated_policies", None) or []) for step in episode_info
    ]
    summary_info["num_violated_policies"] = sum(len(v) for _, v in violations_per_step)
    summary_info["violated_policies"] = [
        policy["policy"]["policy_category"]
        for _, violations in violations_per_step
        for policy in violations
    ]
    for step, violations in violations_per_step:
        summary_info[f"violated_in_step_{step.step}"] = violations

    with open(exp_dir / "summary_info.json", "w") as f:
        json.dump(summary_info, f, indent=4)


def _is_debugging():
    """Tells you if your code is currently running in debug mode."""
    return sys.gettrace() is not None


class CustomExpResult:
    """Helper class to load and visualize the results of an experiment.

    attributes are loaded lazily.

    Attributes (lazily loaded):
        exp_args: ExpArgs, the arguments of the experiment.
        steps_info: list[CustomStepInfo], the information of each steps so far
        summary_info: dict, the summary of the experiment.
        screenshots: list[Image], the screenshots of each step.
        screenshots_som: list[Image], the screenshots of each step with set of
            marks imprinted.
        flat_exp_args: dict, the flattened version of exp_args.
        chat_video_path: Path, the path to the chat video. (if record_video=True)
        task_video_path: Path, the path to the task video. (if record_video=True)
        combined_video_path: Path, the path to the combined video. (if video was
            combined)
        logs: str, the content of experiment.log.
    """

    def __init__(self, exp_dir) -> None:
        self.exp_dir = Path(exp_dir)
        self._exp_args = None
        self._steps_info = {}  # step number -> CustomStepInfo
        self._summary_info = None
        self._screenshots = {}  # (step number, som) -> Image
        self._flat_exp_args = None
        self._logs = None

    @property
    def exp_args(self) -> ExpArgs:
        """The pickled experiment arguments, with ``exp_dir`` rebased on this directory."""
        # NOTE: pickle.load can execute arbitrary code; only load trusted experiment dirs.
        if self._exp_args is None:
            with open(self.exp_dir / "exp_args.pkl", "rb") as f:
                self._exp_args = pickle.load(f)
                # in case experiments were moved
                self._exp_args.exp_dir = self.exp_dir
        return self._exp_args

    def get_step_info(self, step: int) -> CustomStepInfo:
        """Load the step info from the file and return it."""
        if self._steps_info.get(step, None) is None:
            with gzip.open(self.exp_dir / f"step_{step}.pkl.gz", "rb") as f:
                self._steps_info[step] = pickle.load(f)
        return self._steps_info[step]

    @property
    def steps_info(self) -> list[CustomStepInfo]:
        """All step infos found on disk, ordered by step number (assumes no gaps)."""
        step_files = list(self.exp_dir.glob("step_*.pkl.gz"))
        for file in step_files:
            step = int(file.name.split("_")[-1].split(".")[0])
            self.get_step_info(step)

        return [self._steps_info[i] for i in range(len(self._steps_info))]

    @property
    def summary_info(self) -> dict:
        """The content of summary_info.json.

        Raises FileNotFoundError if the file is missing or empty.
        """
        if self._summary_info is None:
            with open(self.exp_dir / "summary_info.json", "r") as f:
                # if length is zero raise file not found error
                if os.fstat(f.fileno()).st_size == 0:
                    raise FileNotFoundError(f"summary_info.json is empty.")
                self._summary_info = json.load(f)
        return self._summary_info

    def get_screenshot(self, step: int, som=False) -> Image:
        """Open (and cache) the screenshot of ``step``; the SOM variant if ``som``."""
        key = (step, som)
        if self._screenshots.get(key, None) is None:
            file_name = f"screenshot_{'som_' if som else ''}step_{step}.jpg"
            self._screenshots[key] = Image.open(self.exp_dir / file_name)
        return self._screenshots[key]

    def get_screenshots(self, som=False):
        """Load every screenshot of the requested kind and return them indexed by step.

        Only files of the requested kind (``screenshot_step_*.jpg`` or
        ``screenshot_som_step_*.jpg``) are considered, so a missing step yields
        ``None`` at that position instead of raising FileNotFoundError. The list
        runs from step 0 to the highest step found (``[None]`` if there is none).
        """
        pattern = f"screenshot_{'som_' if som else ''}step_*.jpg"
        max_step = 0
        for file in self.exp_dir.glob(pattern):
            step = int(file.name.split("_")[-1].split(".")[0])
            self.get_screenshot(step, som=som)
            max_step = max(max_step, step)
        return [self._screenshots.get((i, som), None) for i in range(max_step + 1)]

    @property
    def screenshots(self):
        return self.get_screenshots(som=False)

    @property
    def screenshots_som(self):
        return self.get_screenshots(som=True)

    @property
    def flat_exp_args(self) -> dict:
        """Return a dict with exp_args flattened."""
        if self._flat_exp_args is None:
            exp_args = asdict(self.exp_args)
            # this will flatten nested dicts
            self._flat_exp_args = _flatten_dict(exp_args)
        return self._flat_exp_args

    def get_exp_record(self) -> dict:
        """Return a dict with exp_args flattened and summary_info.

        Missing exp_args or summary files are skipped silently.
        """
        record = {"exp_dir": self.exp_dir}
        try:
            record.update(self.flat_exp_args)
        except FileNotFoundError:
            pass
        try:
            record.update(self.summary_info)
        except FileNotFoundError:
            pass
        return record

    @property
    def chat_video_path(self) -> Path:
        try:
            return next(self.exp_dir.glob("chat_video/*.webm"))
        except StopIteration:
            raise FileNotFoundError(f"No chat_video found in {self.exp_dir}")

    @property
    def task_video_path(self) -> Path:
        try:
            return next(self.exp_dir.glob("task_video/*.webm"))
        except StopIteration:
            raise FileNotFoundError(f"No task_video found in {self.exp_dir}")

    @property
    def combined_video_path(self) -> Path:
        return self.exp_dir / "combined_video.mp4"

    @property
    def logs(self):
        """The text of experiment.log (read once, then cached)."""
        if self._logs is None:
            self._logs = (self.exp_dir / "experiment.log").read_text()
        return self._logs


# Process-wide cache: str(exp_dir) -> CustomExpResult.
EXP_RESULT_CACHE = {}


def get_exp_result(exp_dir) -> CustomExpResult:
    """Return the cached CustomExpResult for ``exp_dir``, creating it on first use.

    The key is ``str(exp_dir)``, so ``Path`` and ``str`` arguments share one entry.
    """
    key = str(exp_dir)
    cached = EXP_RESULT_CACHE.get(key)
    if cached is not None:
        return cached
    cached = CustomExpResult(key)
    EXP_RESULT_CACHE[key] = cached
    return cached


def yield_all_exp_results(
    savedir_base: str | Path, progress_fn=tqdm, load_hidden=False, use_cache=True
):
    """Yield a result object for every experiment found under ``savedir_base``.

    ``savedir_base`` may be one directory or a list of directories. Each is
    searched recursively for ``exp_args.pkl``, and the containing directory is
    one experiment.

    Experiment directories whose name starts with "_" or "." are skipped
    unless ``load_hidden=True``. With ``use_cache`` the cached
    ``get_exp_result`` is used; otherwise a fresh CustomExpResult is built.
    ``progress_fn`` wraps the list of found paths; pass None to disable it.
    """
    roots = savedir_base if isinstance(savedir_base, list) else [savedir_base]

    exp_args_paths = [path for root in roots for path in Path(root).glob("**/exp_args.pkl")]

    if progress_fn is not None:
        exp_args_paths = progress_fn(exp_args_paths, desc="Searching experiments directories.")

    for exp_args_path in exp_args_paths:
        exp_dir = exp_args_path.parent
        is_hidden = exp_dir.name.startswith(("_", "."))
        if is_hidden and not load_hidden:
            continue
        yield get_exp_result(exp_dir) if use_cache else CustomExpResult(exp_dir)


class DataclassJSONEncoder(json.JSONEncoder):
    """JSON encoder that also serializes dataclass instances and numpy values.

    - Dataclass instances become dicts via ``dataclasses.asdict``; their
      contents are then encoded recursively.
    - numpy integer, floating and bool scalars become the matching Python
      scalars.
    - numpy arrays become (nested) lists.
    """

    def default(self, obj):
        # is_dataclass() is also True for dataclass *classes*, but asdict() only
        # accepts instances, so classes fall through to the base-class TypeError.
        if is_dataclass(obj) and not isinstance(obj, type):
            return asdict(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither np.integer nor a Python bool, so json rejects it by default
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def _move_old_exp(exp_dir):
    """Move the old experiment directory to a new name."""
    exp_dir = Path(exp_dir)
    if exp_dir.exists():
        exp_dir.rename(exp_dir.with_name("_" + exp_dir.name))


def _get_env_name(task_name: str):
    """Register tasks if needed (lazy import) and return environment name."""

    # lazy benchmark import
    if task_name.startswith("miniwob"):
        import browsergym.miniwob
    elif task_name.startswith("workarena"):
        import browsergym.workarena
    elif task_name.startswith("webarena"):
        import browsergym.webarena
    elif task_name.startswith("visualwebarena"):
        import browsergym.visualwebarena
    elif task_name.startswith("WebArenaSafeEnv"):
        import browsergym.webarenasafe
    elif task_name.startswith("browsergym/WebArenaSafeEnv"):
        import browsergym.webarenasafe
        return task_name

    return f"browsergym/{task_name}"


def _send_chat_info(chat: Chat, action: str, agent_info: dict):
    """Log the agent's reasoning (when present) and chosen action, and post it to the chat.

    The message is added with role "info"; the ``"think"`` entry of
    ``agent_info``, if any, precedes the action.
    """
    think_block = ""
    if "think" in agent_info:
        think_block = f"""\
{agent_info["think"]}

"""

    msg = think_block + f"""\
action:
{action}
"""

    logger.info(msg)
    chat.add_message(role="info", msg=msg)


def _flatten_dict(d, parent_key="", sep="."):
    """Recursively flatten a nested dictionary."""
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            items.extend(_flatten_dict(v, new_key, sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
