import logging
import numpy as np

log = logging.getLogger(__name__)

def evaluate(env, agent, episodes, seed, step, timeout=1000):
    """Roll out ``agent`` in ``env`` and report per-episode return statistics.

    Each episode starts from ``env.reset(seed=seed)`` and runs until the
    environment reports ``terminated`` or ``truncated``, or until ``timeout``
    steps have been taken. At least one step is always taken per episode.

    Args:
        env: Environment whose ``step(action)`` returns the 5-tuple
            ``(obs, reward, terminated, truncated, info)``. The value
            returned by ``reset(seed=...)`` is passed unchanged to
            ``agent.act``.
        agent: Object exposing ``act(obs, greedy=...)``.
        episodes: Number of evaluation episodes to run.
        seed: Seed forwarded to every ``env.reset`` call.
        step: Training step, used only to label the log line.
        timeout: Maximum number of environment steps per episode.

    Returns:
        ``(mean_return, std_return)`` over the evaluated episodes, or
        ``(nan, nan)`` when no episode was run.
    """
    episode_returns = []
    episode_lengths = []
    for _ in range(episodes):
        # NOTE(review): every episode is reset with the same seed, and the
        # reset result is used directly as the observation (the standard
        # Gymnasium reset returns ``(obs, info)``) -- confirm both are
        # intended for the env wrapper used here.
        obs = env.reset(seed=seed)
        episode_return = 0.0
        length = 0
        while True:
            # NOTE(review): evaluation acts with greedy=False -- confirm
            # that a non-greedy policy is what should be measured.
            action = agent.act(obs, greedy=False)
            obs, reward, terminated, truncated, _info = env.step(action)
            episode_return += reward
            length += 1
            if terminated or truncated or length >= timeout:
                break
        episode_returns.append(episode_return)
        episode_lengths.append(length)

    if not episode_returns:
        # Nothing was rolled out; avoid NumPy reductions over empty input.
        return float("nan"), float("nan")

    # The next episode's reset already covers intermediate episodes, so a
    # single trailing reset is enough to leave the env freshly reset.
    env.reset(seed=seed)

    mean_reward = np.mean(episode_returns)
    max_reward = np.max(episode_returns)
    min_reward = np.min(episode_returns)
    median_reward = np.median(episode_returns)
    std_reward = np.std(episode_returns)
    mean_length = np.mean(episode_lengths)

    log.info(
        f'Step {step} evaluation \t mean length {int(mean_length)} \t '
        f'Raw reward mean/max/min/median/std: '
        f'{mean_reward:.2f}/{max_reward:.2f}/{min_reward:.2f}/'
        f'{median_reward:.2f}/{std_reward:.2f}'
    )

    return mean_reward, std_reward

def evaluate_episodic(env, agent, episodes, seed, step, timeout=1000):
    """Evaluate ``agent`` on ``env`` and log raw and normalized episode returns.

    Each episode starts from ``env.reset(seed=seed)`` and runs until the
    environment reports ``terminated`` or ``timeout`` steps have been taken.
    At least one step is always taken per episode. Every episode's summed
    reward is also converted with
    ``env.env.unwrapped.get_normalized_score``.

    Args:
        env: Environment whose ``step(action)`` returns the 4-tuple
            ``(obs, reward, terminated, info)`` and which exposes
            ``env.env.unwrapped.get_normalized_score``.
        agent: Object exposing ``act(obs, greedy=...)``; the returned action
            must support ``reshape``.
        episodes: Number of evaluation episodes to run; must be positive.
        seed: Seed forwarded to every ``env.reset`` call.
        step: Training step, used only to label the log line.
        timeout: Maximum number of environment steps per episode.

    Returns:
        ``(mean_reward, std_reward, mean_normalized, std_normalized)`` over
        the evaluated episodes.

    Raises:
        ValueError: If ``episodes`` is not positive, since the statistics
            are undefined without at least one episode.
    """
    if episodes <= 0:
        raise ValueError(
            f"episodes must be positive to compute statistics, got {episodes}"
        )

    episode_returns = []
    for _ in range(episodes):
        # NOTE(review): every episode is reset with the same seed -- confirm
        # that identical initial conditions across episodes are intended.
        obs = env.reset(seed=seed)
        episode_return = 0
        length = 0
        while True:
            # NOTE(review): evaluation acts with greedy=False -- confirm
            # that a non-greedy policy is what should be measured.
            # The action is handed to env.step as a single-row 2-D array.
            action = agent.act(obs, greedy=False).reshape(1, -1)
            obs, reward, terminated, _info = env.step(action)
            length += 1
            episode_return += reward
            if terminated or length >= timeout:
                break
        episode_returns.append(episode_return)

    # The next episode's reset already covers intermediate episodes, so a
    # single trailing reset is enough to leave the env freshly reset.
    env.reset(seed=seed)

    # Log performance as a percentage of the original data-collecting agent
    # (presumably what get_normalized_score reports -- verify against the
    # env's implementation).
    normalized_score = env.env.unwrapped.get_normalized_score
    normalized = np.array([normalized_score(ep_r) for ep_r in episode_returns])
    mean_normalized = np.mean(normalized)
    max_normalized = np.max(normalized)
    min_normalized = np.min(normalized)
    median_normalized = np.median(normalized)
    std_normalized = np.std(normalized)

    mean_reward = np.mean(episode_returns)
    max_reward = np.max(episode_returns)
    min_reward = np.min(episode_returns)
    median_reward = np.median(episode_returns)
    std_reward = np.std(episode_returns)

    # The backslash continuations embed the next line's indentation in the
    # message; kept verbatim so the emitted log text stays unchanged.
    log.info(f'Step {step} evaluation \t \
             mean/max/min/median/std: Normalized {mean_normalized:.2f}/{max_normalized:.2f}/{min_normalized:.2f}/{median_normalized:.2f}/{std_normalized:.2f}, \
             Raw: {mean_reward:.2f}/{max_reward:.2f}/{min_reward:.2f}/{median_reward:.2f}/{std_reward:.2f}')

    return mean_reward, std_reward, mean_normalized, std_normalized
