import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gymnasium as gym
import numpy as np
import pandas as pd
from stable_baselines3.common.type_aliases import PolicyPredictor
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped

from ucrl.classify.classifier import mujoco_safety_gymnasium_dict
from ucrl.safety.label import LabeledNPData
from ucrl.common.on_policy_algorithm import OnPolicyAlgorithmH, OnPolicyAlgorithmC
from ucrl.common.type_aliases import PolicyHPredictor, PolicyCPredictor


def eval_policy_cost_traj(
        model: Union["PolicyPredictor", "PolicyHPredictor", "PolicyCPredictor"],
        env: Union[gym.Env, VecEnv],
        n_eval_episodes: int = 100,
        deterministic: bool = True,
        render: bool = False,
        callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
        reward_threshold: Optional[float] = None,
        return_episode_rewards: bool = False,
        warn: bool = True,
) -> Union[
    Tuple[float, float, float, float, float, float, float, float, float, float, List[pd.DataFrame]],
    Tuple[List[float], List[float], List[int], List[Any], List[float], List[float], List[pd.DataFrame]],
]:
    """
    Runs policy for ``n_eval_episodes`` episodes and returns reward, cost, safety-label
    and log-score statistics together with one trajectory ``DataFrame`` per episode.
    If a vector env is passed in, this divides the episodes to evaluate onto the
    different elements of the vector env. This static division of work is done to
    remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
    details and discussion.

    .. note::
        If environment has not been wrapped with ``Monitor`` wrapper, reward and
        episode lengths are counted as it appears with ``env.step`` calls. If
        the environment contains wrappers that modify rewards or episode lengths
        (e.g. reward scaling, early episode reset), these will affect the evaluation
        results as well. You can avoid this by wrapping environment with ``Monitor``
        wrapper before anything else.

    .. note::
        Every step's ``info`` dict must contain a ``'cost'`` entry. The entries
        ``'x_velocity'`` / ``'y_velocity'`` are recorded when present in the first
        sub-environment's ``info``.

    Each trajectory ``DataFrame`` holds one row per step with the columns
    ``s0..sN`` (observation), ``a0..aM`` (action), ``r``, ``c``, optionally
    ``x_velocity`` / ``y_velocity``, then either ``log_score`` (plus
    ``logscore_dist_mean`` / ``logscore_dist_var`` when recorded) for
    ``OnPolicyAlgorithmH`` models or ``neg_cost`` / ``budget`` for other models
    with a classifier, and finally ``satisfy`` (per-step label returned by
    ``LabeledNPData.label_data``).

    :param model: The RL agent you want to evaluate. This can be any object
        that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``)
        or policy (``BasePolicy``). If it has a non-``None`` ``_vec_normalize_env``
        attribute, observations and rewards are un-normalized through ``env``
        before being accumulated and recorded.
    :param env: The gym environment or ``VecEnv`` environment.
    :param n_eval_episodes: Number of episode to evaluate the agent
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the environment or not
    :param callback: callback function to do additional checks,
        called after each step. Gets locals() and globals() passed as parameters.
    :param reward_threshold: Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: If True, per-episode lists are returned
        instead of means and standard deviations.
    :param warn: If True (default), warns user about lack of a Monitor wrapper in the
        evaluation environment.
    :return: When ``return_episode_rewards`` is False, the tuple
        ``(mean_reward, std_reward, mean_cost, std_cost, mean_safe_prop, std_safe_prop,
        mean_step_safe_prop, std_step_safe_prop, mean_log_score, std_log_score, traj_dfs)``.
        When it is True, the tuple
        ``(episode_rewards, episode_costs, episode_lengths, episode_safe_props,
        step_safe_props, episode_log_scores, traj_dfs)`` of per-episode lists.
        ``episode_log_scores`` stays empty (so its mean/std are NaN) when the model
        has no classifier.
    :raises AssertionError: If ``reward_threshold`` is given and the mean episode
        reward is not strictly greater than it.
    """
    # NOTE: ``callback`` receives ``locals()``, so the local variable names used in
    # this function are visible to callers' callbacks; keep them stable.

    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor

    if not isinstance(env, VecEnv):
        env = DummyVecEnv([lambda: env])  # type: ignore[list-item, return-value]

    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    n_envs = env.num_envs
    episode_rewards = []
    episode_costs = []
    episode_lengths = []
    episode_safe_props = []
    step_safe_props = []

    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

    current_rewards, current_costs = np.zeros(n_envs), np.zeros(n_envs)
    current_lengths = np.zeros(n_envs, dtype="int")
    observations = env.reset()
    states = None
    episode_starts = np.ones((env.num_envs,), dtype=bool)

    # Environment id, used to look up the labelling horizon of the domain.
    # NOTE(review): ``env.envs`` presumably restricts this to DummyVecEnv-backed
    # environments - confirm before using with subprocess vector envs.
    env_domain = env.envs[0].unwrapped.spec.id

    # Per-env buffers for the trajectory currently being recorded. Each sub-env
    # needs its own list object (``[[]] * n_envs`` would alias a single list).
    traj_dfs = []
    env_traj_s = [[] for _ in range(n_envs)]
    env_traj_a = [[] for _ in range(n_envs)]
    env_traj_r = [[] for _ in range(n_envs)]
    env_traj_c = [[] for _ in range(n_envs)]
    env_traj_xvel = [[] for _ in range(n_envs)]
    env_traj_yvel = [[] for _ in range(n_envs)]

    episode_log_scores = []
    env_traj_log_score = [[] for _ in range(n_envs)]
    env_traj_mean_log_score = [[] for _ in range(n_envs)]
    env_traj_var_log_score = [[] for _ in range(n_envs)]

    has_classifier = getattr(model, "classifier", None) is not None

    if isinstance(model, OnPolicyAlgorithmH):
        augmented = True
        # TODO : Handle PPO-L with augmented state (model.classifier is None)
        hidden = np.zeros((n_envs, model.classifier.nb_gru_units))
        full_hidden = np.zeros((model.classifier.gru_layers, n_envs, model.classifier.nb_gru_units))
    else:
        augmented = False
        if has_classifier:
            env_traj_budget = [[] for _ in range(n_envs)]

    # The model/env pairing does not change during evaluation, so decide once.
    # ``getattr`` keeps plain policies (objects that only implement ``predict``)
    # working instead of raising AttributeError.
    uses_vec_normalize = getattr(model, "_vec_normalize_env", None) is not None

    while (episode_counts < episode_count_targets).any():
        if augmented:  # model.predict() squashes or clips action
            actions, states = model.predict(
                observations,  # type: ignore[arg-type]
                hidden,
                state=states,
                episode_start=episode_starts,
                deterministic=deterministic,
            )
        else:
            actions, states = model.predict(
                observations,  # type: ignore[arg-type]
                state=states,
                episode_start=episode_starts,
                deterministic=deterministic,
            )

        new_observations, rewards, dones, infos = env.step(actions)

        # Unnormalize observations and rewards
        if uses_vec_normalize:
            unnorm_observations = env.unnormalize_obs(observations)
            unnorm_rewards = env.unnormalize_reward(rewards)
        else:
            unnorm_observations = observations
            unnorm_rewards = rewards
        costs = np.array([info['cost'] for info in infos])

        x_velocities = np.array([info['x_velocity'] for info in infos]) if 'x_velocity' in infos[0] else None
        y_velocities = np.array([info['y_velocity'] for info in infos]) if 'y_velocity' in infos[0] else None

        current_rewards += unnorm_rewards
        current_costs += costs
        current_lengths += 1

        if augmented:  # and model.classifier is not None:
            (prob_feasible, log_scores, hidden, full_hidden,
             mean_log_scores, var_log_scores) = model.get_hiddenstates_logscores(
                unnorm_observations, actions, full_hidden)
        elif has_classifier:
            est_costs, est_budget = model.get_cost_sa(unnorm_observations, actions)
            neg_est_costs = -est_costs

        for i in range(n_envs):
            if episode_counts[i] < episode_count_targets[i]:

                # Append s, a, r to list only if applicable (i.e. episode count < episode count target)
                env_traj_s[i].append(unnorm_observations[i])
                env_traj_a[i].append(actions[i])
                env_traj_r[i].append(unnorm_rewards[i])
                env_traj_c[i].append(costs[i])
                if x_velocities is not None:
                    env_traj_xvel[i].append(x_velocities[i])
                if y_velocities is not None:
                    env_traj_yvel[i].append(y_velocities[i])
                if augmented:
                    env_traj_log_score[i].append(log_scores[i])
                    # Mean/variance may be unavailable; also tolerate the helper
                    # returning ``None`` for the whole batch.
                    if mean_log_scores is not None and mean_log_scores[i] is not None:
                        env_traj_mean_log_score[i].append(mean_log_scores[i])
                    if var_log_scores is not None and var_log_scores[i] is not None:
                        env_traj_var_log_score[i].append(var_log_scores[i])
                elif has_classifier:
                    env_traj_log_score[i].append(neg_est_costs[i])
                    # NOTE(review): ``est_budget`` is stored un-indexed, unlike the
                    # per-env costs above - presumably a value shared by all
                    # sub-envs; confirm against ``get_cost_sa``.
                    env_traj_budget[i].append(est_budget)

                # unpack values so that the callback can access the local variables
                reward = rewards[i]
                done = dones[i]
                info = infos[i]
                cost = costs[i]
                episode_starts[i] = done

                if callback is not None:
                    callback(locals(), globals())

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_costs.append(current_costs[i])
                            episode_lengths.append(info["episode"]["l"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_costs.append(current_costs[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1

                    current_rewards[i] = 0
                    current_costs[i] = 0
                    current_lengths[i] = 0

                    # New Trajectory DataFrame
                    traj_df = pd.DataFrame()
                    # Observation Columns
                    obs_np = np.vstack(env_traj_s[i])
                    obs_headers = ['s' + str(col) for col in range(obs_np.shape[1])]
                    traj_df[obs_headers] = obs_np
                    # Action Columns
                    a_np = np.vstack(env_traj_a[i])
                    act_headers = ['a' + str(col) for col in range(a_np.shape[1])]
                    traj_df[act_headers] = a_np
                    # Reward Column
                    r_header = 'r'
                    traj_df[r_header] = env_traj_r[i]
                    # Cost Column
                    c_header = 'c'
                    traj_df[c_header] = env_traj_c[i]
                    # x-vel, y-vel column
                    if len(env_traj_xvel[i]) > 0:
                        xvel_header = 'x_velocity'
                        traj_df[xvel_header] = env_traj_xvel[i]
                    if len(env_traj_yvel[i]) > 0:
                        yvel_header = 'y_velocity'
                        traj_df[yvel_header] = env_traj_yvel[i]

                    if augmented:
                        log_score_header = 'log_score'
                        traj_df[log_score_header] = env_traj_log_score[i]
                        env_traj_log_score[i] = []

                        episode_log_scores.append(traj_df[log_score_header].sum())

                        # Check THIS env's buffer: the outer list always has
                        # ``n_envs`` entries, so testing it would try to assign an
                        # empty column whenever no mean/variance was recorded.
                        if len(env_traj_mean_log_score[i]) > 0:
                            mean_log_score_header = 'logscore_dist_mean'
                            traj_df[mean_log_score_header] = env_traj_mean_log_score[i]
                            env_traj_mean_log_score[i] = []

                        if len(env_traj_var_log_score[i]) > 0:
                            var_log_score_header = 'logscore_dist_var'
                            traj_df[var_log_score_header] = env_traj_var_log_score[i]
                            env_traj_var_log_score[i] = []

                        # Start the next episode of this sub-env from a zero hidden state
                        hidden[i] = np.zeros(model.classifier.nb_gru_units)
                        full_hidden[:, i] = np.zeros(
                            (model.classifier.gru_layers, model.classifier.nb_gru_units)
                        )
                    elif has_classifier:
                        neg_cost_header = 'neg_cost'
                        traj_df[neg_cost_header] = env_traj_log_score[i]
                        env_traj_log_score[i] = []

                        budget_header = 'budget'
                        traj_df[budget_header] = env_traj_budget[i]
                        env_traj_budget[i] = []

                        episode_log_scores.append(traj_df[neg_cost_header].sum())

                    # Label the finished trajectory (episode-level flag + per-step array)
                    labelled_traj_df = LabeledNPData(input_sa_data=traj_df[obs_headers + act_headers].to_numpy(),
                                                     input_cost_data=traj_df[c_header].to_numpy(),
                                                     domain=env_domain,
                                                     horizon=mujoco_safety_gymnasium_dict[env_domain]['horizon'])
                    episode_safe_bool, step_safe_arr = labelled_traj_df.label_data()
                    episode_safe_props.append(episode_safe_bool)
                    step_safe_props.append(step_safe_arr.mean())

                    # Persist complete trajectory into the list
                    traj_df['satisfy'] = step_safe_arr
                    traj_dfs.append(traj_df)

                    # Reset the list at respective env id location
                    env_traj_s[i], env_traj_a[i], env_traj_r[i], env_traj_c[i] = [], [], [], []
                    env_traj_xvel[i], env_traj_yvel[i] = [], []

        observations = new_observations

        if render:
            env.render()

    mean_reward = np.mean(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, \
            "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"

    if return_episode_rewards:
        return (episode_rewards, episode_costs, episode_lengths, episode_safe_props, step_safe_props,
                episode_log_scores, traj_dfs)

    std_reward = np.std(episode_rewards)
    mean_cost = np.mean(episode_costs)
    std_cost = np.std(episode_costs)
    mean_safe_prop = np.mean(episode_safe_props)
    std_safe_prop = np.std(episode_safe_props)
    mean_step_safe_prop = np.mean(step_safe_props)
    std_step_safe_prop = np.std(step_safe_props)

    # Empty when the model has no classifier, in which case both values are NaN.
    mean_log_score = np.mean(episode_log_scores)
    std_log_score = np.std(episode_log_scores)

    return (mean_reward, std_reward, mean_cost, std_cost, mean_safe_prop, std_safe_prop, mean_step_safe_prop,
            std_step_safe_prop, mean_log_score, std_log_score, traj_dfs)
