import numpy as np
import random
import torch
import wandb
import time
from loguru import logger
from tqdm import trange

# Module-wide numpy print formatting: 4 digits of precision, lines up to 300
# characters wide, and no scientific notation for small values.
np.set_printoptions(precision=4, linewidth=300, suppress=True)

def _step_cost(info, velocity_threshold):
    """Return the safety cost of one environment step from its ``info`` dict.

    Lookup order: ``info['cost_hazards']`` if present, else ``info['cost']``
    if present, else a velocity-based indicator cost (1.0 when the agent's
    speed exceeds ``velocity_threshold``, 0.0 otherwise).

    Raises:
        KeyError: if none of ``'cost_hazards'``, ``'cost'`` or
            ``'x_velocity'`` is present in ``info``.
    """
    if 'cost_hazards' in info:
        return info['cost_hazards']
    if 'cost' in info:
        return info['cost']

    # Velocity-constrained task: use planar speed when a y component is
    # reported, otherwise the absolute x velocity.
    if 'y_velocity' in info:
        agent_velocity = np.sqrt(info['x_velocity'] ** 2 + info['y_velocity'] ** 2)
    else:
        agent_velocity = np.abs(info['x_velocity'])
    return float(agent_velocity > velocity_threshold)


def eval_policy(policy, eval_env, eval_episodes=10, velocity_threshold=3.2096):
    """Roll out ``policy`` in ``eval_env`` and summarise normalized reward and cost.

    Each episode runs until ``eval_env.step`` reports ``done`` or
    ``truncated``. Per-episode reward and cost sums are passed through
    ``eval_env.get_normalized_score(reward, cost)`` before being aggregated.

    Args:
        policy: object exposing ``select_action_from_candidates(state)``; it is
            called with the current state wrapped in ``np.array``.
        eval_env: environment whose ``reset()`` returns a sequence with the
            initial state first, whose ``step(action)`` returns
            ``(state, reward, done, truncated, info)``, and which provides
            ``get_normalized_score(reward, cost)`` returning a pair.
        eval_episodes: number of evaluation episodes; must be at least 1.
        velocity_threshold: speed limit used for the indicator cost when
            ``info`` carries neither ``'cost_hazards'`` nor ``'cost'``. The
            default 3.2096 is the value previously hard-coded for HalfCheetah;
            the original source comments list 0.2282 for Swimmer and 2.6222
            for Ant.

    Returns:
        A 10-tuple ``(avg_reward, std_reward, max_reward, min_reward,
        avg_cost, std_cost, max_cost, min_cost, reward_buffer, cost_buffer)``
        where the two buffers are the per-episode normalized scores.

    Raises:
        ValueError: if ``eval_episodes`` is smaller than 1 (the statistics
            are undefined for an empty set of episodes).
    """
    if eval_episodes < 1:
        # Fail early with a clear message instead of letting np.max fail on
        # an empty buffer after emitting nan warnings.
        raise ValueError(f"eval_episodes must be >= 1, got {eval_episodes}")

    reward_buffer = []
    cost_buffer = []

    for _ in range(eval_episodes):
        reward_ep = 0.
        cost_ep = 0.
        state = eval_env.reset()[0]
        done = truncated = False

        while not (done or truncated):
            action = policy.select_action_from_candidates(np.array(state))
            state, reward, done, truncated, info = eval_env.step(action)

            reward_ep += reward
            cost_ep += _step_cost(info, velocity_threshold)

        _reward_ep, _cost_ep = eval_env.get_normalized_score(reward_ep, cost_ep)
        reward_buffer.append(_reward_ep)
        cost_buffer.append(_cost_ep)

    avg_reward = np.average(reward_buffer)
    avg_cost = np.average(cost_buffer)
    std_reward = np.std(reward_buffer)
    std_cost = np.std(cost_buffer)
    max_reward = np.max(reward_buffer)
    max_cost = np.max(cost_buffer)
    min_reward = np.min(reward_buffer)
    min_cost = np.min(cost_buffer)

    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    print(f"Evaluation over {eval_episodes} episodes: mean {avg_reward:.3f} and std {std_reward:.3f}, cost mean {avg_cost:.3f} and cost std {std_cost:.3f}")
    print(f"Max Reward is: {max_reward:.3f}, and Min Reward is {min_reward:.3f}")
    print(f"Max Cost is: {max_cost:.3f}, and Min Cost is {min_cost:.3f}")
    return avg_reward, std_reward, max_reward, min_reward, avg_cost, std_cost, max_cost, min_cost, reward_buffer, cost_buffer

