import os

import numpy as np
from datetime import datetime
from stable_baselines.common.vec_env import VecEnv
from tqdm import tqdm

# Month-day stamp such as "3-7" (not zero-padded), computed once at import time.
# The clock is read a single time so month and day always come from the same
# instant (two separate now() calls can straddle midnight).
date = '{0.month}-{0.day}'.format(datetime.now())
def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
                    render=False, callback=None, reward_threshold=None,
                    return_episode_rewards=False, use_airl=False,
                    n_candidates=10):
    """
    Evaluate a policy by best-of-``n_candidates`` rollout selection.

    For each of the ``n_eval_episodes`` evaluation episodes, ``n_candidates``
    rollouts are generated. Each rollout is scored with the model's learned
    reward (``model.reward_giver.get_reward``), the highest-scoring rollout is
    selected, and its true environment return and length are recorded.
    The reported mean/std are over these selected rollouts only.
    This is made to work only with one env.

    :param model: (BaseRLModel) The RL agent you want to evaluate. It must expose
        ``reward_giver.get_reward``. With ``use_airl=True`` it must also expose
        ``step(obs, state, done, deterministic=...)`` returning four values;
        otherwise it must expose ``predict(obs, state=..., deterministic=...)``.
    :param env: (gym.Env or VecEnv) The gym environment. In the case of a `VecEnv`
        this must contain only one environment.
    :param n_eval_episodes: (int) Number of evaluation episodes (each one selects
        a single rollout out of ``n_candidates``).
    :param deterministic: (bool) Whether to use deterministic or stochastic actions
    :param render: (bool) Whether to render the environment or not
    :param callback: (callable) callback function to do additional checks,
        called after each step with ``(locals(), globals())``.
    :param reward_threshold: (float) Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: (bool) If True, per-episode results are
        returned instead of the mean/std (see below).
    :param use_airl: (bool) If True, roll out with ``model.step`` and score a
        rollout by summing ``reward_giver.get_reward(obs, action, -nlogprob, next_obs)``;
        otherwise roll out with ``model.predict`` and score it by summing
        ``reward_giver.get_reward(obs, action)``.
    :param n_candidates: (int) Number of rollouts generated per evaluation
        episode, among which the best-scoring one is selected. Must be >= 1.
    :return: (float, float) Mean and std of the true return of the selected
        rollouts. When `return_episode_rewards` is True, returns a 3-tuple
        ``(selected_rewards, selected_lengths, records)`` where ``records`` is
        ``np.array([scores, true_rewards, lengths])`` covering every candidate
        rollout of every evaluation episode (``scores`` are the learned-reward
        scores used for selection).
    :raises ValueError: if ``n_candidates`` is smaller than 1.
    :raises AssertionError: if ``reward_threshold`` is given and the mean
        reward does not exceed it.
    """
    if isinstance(env, VecEnv):
        assert env.num_envs == 1, "You must pass only one environment when using this function"
    if n_candidates < 1:
        # np.argmax below cannot select from an empty list of candidates.
        raise ValueError("n_candidates must be >= 1, got {}".format(n_candidates))

    final_episode_rewards = []
    final_episode_lengths = []
    # Records for every candidate rollout of every evaluation episode.
    all_weights = []       # learned-reward score per rollout (use_airl=False)
    all_lengths = []       # length per rollout
    all_rewards = []       # true environment return per rollout
    all_pred_rewards = []  # learned-reward score per rollout (use_airl=True)
    for _ in range(n_eval_episodes):
        # Scores and outcomes of this evaluation episode's candidate rollouts.
        log_weights = []
        pred_rewards = []
        episode_rewards, episode_lengths = [], []
        for i in range(n_candidates):
            obs = env.reset()
            done, state = False, None
            episode_reward = 0.0
            episode_length = 0
            log_weight = 0.0
            pred_reward = 0.0
            while not done:
                if use_airl:
                    pre_obs = obs
                    # reshape(-1, *shape) prepends a batch axis of size 1.
                    action, _, _, nlogprob = model.step(pre_obs.reshape(-1, *pre_obs.shape), state, done, deterministic=deterministic)
                    if len(action.shape) == 2:
                        action = action.squeeze(1)
                    obs, reward, done, _info = env.step(action)
                    obs = obs.squeeze()
                    # The fourth value from model.step is negated before scoring.
                    pred_reward += model.reward_giver.get_reward(pre_obs, action, -nlogprob, obs).item()
                else:
                    action, state = model.predict(obs, state=state, deterministic=deterministic)
                    # Score uses the observation *before* the env step.
                    log_weight += model.reward_giver.get_reward(obs, action).squeeze()
                    obs, reward, done, _info = env.step(action)
                episode_reward += reward
                if callback is not None:
                    callback(locals(), globals())
                episode_length += 1
                if render:
                    env.render()
            if use_airl:
                all_pred_rewards.append(pred_reward)
                pred_rewards.append(pred_reward)
            else:
                log_weights.append(log_weight)
                all_weights.append(log_weight)
            all_rewards.append(episode_reward)
            all_lengths.append(episode_length)
            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_length)
        # Greedy selection: keep the candidate with the highest learned score.
        if use_airl:
            idx = np.argmax(pred_rewards)
        else:
            idx = np.argmax(log_weights)
        print("selected trajectory: {} \n reward: {}".format(idx, episode_rewards[idx]))
        final_episode_rewards.append(episode_rewards[idx])
        final_episode_lengths.append(episode_lengths[idx])
    mean_reward = np.mean(final_episode_rewards)
    std_reward = np.std(final_episode_rewards)

    if reward_threshold is not None:
        assert mean_reward > reward_threshold, 'Mean reward below threshold: '\
                                         '{:.2f} < {:.2f}'.format(mean_reward, reward_threshold)
    if return_episode_rewards:
        if use_airl:
            return final_episode_rewards, final_episode_lengths, np.array([all_pred_rewards, all_rewards, all_lengths])
        else:
            return final_episode_rewards, final_episode_lengths, np.array([all_weights, all_rewards, all_lengths])
    return mean_reward, std_reward

def evaluate_policy_airl(model, env, n_eval_episodes=10, deterministic=True,
                    render=False, callback=None, reward_threshold=None,
                    return_episode_rewards=False):
    """
    Run an AIRL policy for `n_eval_episodes` episodes and compare true vs. learned rewards.

    Every step is scored both with the environment's true reward and with two
    learned rewards from ``model.reward_giver.get_reward``: the default call
    and the ``use_reward=True`` variant.
    This is made to work only with one env.

    :param model: (BaseRLModel) The RL agent you want to evaluate. It must expose
        ``step(obs, state, done, deterministic=...)`` returning four values and
        ``reward_giver.get_reward(obs, action, logprob, next_obs, use_reward=...)``.
    :param env: (gym.Env or VecEnv) The gym environment. In the case of a `VecEnv`
        this must contain only one environment.
    :param n_eval_episodes: (int) Number of episode to evaluate the agent
    :param deterministic: (bool) Whether to use deterministic or stochastic actions
    :param render: (bool) Whether to render the environment or not
    :param callback: (callable) callback function to do additional checks,
        called after each step with ``(locals(), globals())``.
    :param reward_threshold: (float) Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: (bool) If True, per-episode and per-step
        results are returned instead of the mean/std (see below).
    :return: (float, float) Mean and std of the true return per episode.
        When `return_episode_rewards` is True, returns a 4-tuple
        ``(episode_rewards, episode_lengths, per_episode, per_step)`` where
        ``per_episode`` is ``np.array([true_returns, learned_returns,
        learned_returns_fn, lengths])`` and ``per_step`` is
        ``np.array([true_rewards, learned_rewards, learned_rewards_fn])``.
    :raises AssertionError: if ``reward_threshold`` is given and the mean
        reward does not exceed it.
    """
    if isinstance(env, VecEnv):
        assert env.num_envs == 1, "You must pass only one environment when using this function"

    episode_rewards, episode_lengths = [], []
    episode_learned_rewards, episode_learned_rewards_fn = [], []
    # Per-step records across all episodes.
    all_true_rewards, all_rewards, all_rewards_fn = [], [], []
    for _ in range(n_eval_episodes):
        obs = env.reset()
        done, state = False, None
        episode_reward = 0.0
        episode_learned_reward = 0.0
        episode_learned_reward_fn = 0.0
        episode_length = 0
        while not done:
            pre_obs = obs
            # reshape(-1, *shape) prepends a batch axis of size 1. `deterministic`
            # is forwarded, as in the use_airl path of evaluate_policy; previously
            # it was accepted but silently ignored here.
            action, _, _, nlogprob = model.step(pre_obs.reshape(-1, *pre_obs.shape), state, done, deterministic=deterministic)
            if len(action.shape) == 2:
                action = action.squeeze(1)
            obs, true_reward, done, _info = env.step(action)
            obs = obs.squeeze()
            true_reward = true_reward.item()
            # The fourth value from model.step is negated before scoring.
            learned_reward = model.reward_giver.get_reward(pre_obs, action, -nlogprob, obs).item()
            learned_reward_fn = model.reward_giver.get_reward(pre_obs, action, -nlogprob, obs, use_reward=True).item()

            episode_reward += true_reward
            episode_learned_reward += learned_reward
            episode_learned_reward_fn += learned_reward_fn

            all_true_rewards.append(true_reward)
            all_rewards.append(learned_reward)
            all_rewards_fn.append(learned_reward_fn)
            if callback is not None:
                callback(locals(), globals())
            episode_length += 1
            if render:
                env.render()
        episode_rewards.append(episode_reward)
        episode_learned_rewards.append(episode_learned_reward)
        episode_learned_rewards_fn.append(episode_learned_reward_fn)
        episode_lengths.append(episode_length)

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)

    if reward_threshold is not None:
        assert mean_reward > reward_threshold, 'Mean reward below threshold: '\
                                         '{:.2f} < {:.2f}'.format(mean_reward, reward_threshold)
    if return_episode_rewards:
        return episode_rewards, \
               episode_lengths, \
               np.array([episode_rewards, episode_learned_rewards, episode_learned_rewards_fn, episode_lengths]), \
               np.array([all_true_rewards, all_rewards, all_rewards_fn])
    return mean_reward, std_reward

def evaluate_policy_original(model, env, n_eval_episodes=10, deterministic=True,
                    render=False, callback=None, reward_threshold=None,
                    return_episode_rewards=False):
    """
    Run the policy for `n_eval_episodes` episodes and return the average true reward.

    One rollout per episode, with no candidate selection. Progress is shown
    with a tqdm bar.
    This is made to work only with one env.

    :param model: (BaseRLModel) The RL agent you want to evaluate; must expose
        ``predict(obs, state=..., deterministic=...)``.
    :param env: (gym.Env or VecEnv) The gym environment. In the case of a `VecEnv`
        this must contain only one environment.
    :param n_eval_episodes: (int) Number of episode to evaluate the agent
    :param deterministic: (bool) Whether to use deterministic or stochastic actions
    :param render: (bool) Whether to render the environment or not
    :param callback: (callable) callback function to do additional checks,
        called after each step with ``(locals(), globals())``.
    :param reward_threshold: (float) Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: (bool) If True, a list of reward per episode
        will be returned instead of the mean.
    :return: (float, float) Mean reward per episode, std of reward per episode
        returns ([float], [int]) when `return_episode_rewards` is True
    :raises AssertionError: if ``reward_threshold`` is given and the mean
        reward does not exceed it.
    """
    if isinstance(env, VecEnv):
        assert env.num_envs == 1, "You must pass only one environment when using this function"

    episode_rewards, episode_lengths = [], []
    for i in tqdm(range(n_eval_episodes)):
        # tqdm.write prints the same line without corrupting the active
        # progress bar (a bare print() interleaves with the bar's redraws).
        tqdm.write("Trajectory: {}".format(i))
        obs = env.reset()
        done, state = False, None
        episode_reward = 0.0
        episode_length = 0
        while not done:
            action, state = model.predict(obs, state=state, deterministic=deterministic)
            obs, reward, done, _info = env.step(action)
            episode_reward += reward
            if callback is not None:
                callback(locals(), globals())
            episode_length += 1
            if render:
                env.render()
        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_length)

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)

    if reward_threshold is not None:
        assert mean_reward > reward_threshold, 'Mean reward below threshold: '\
                                         '{:.2f} < {:.2f}'.format(mean_reward, reward_threshold)
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward