import torch
import numpy as np

from wrappers.episode_monitor import EpisodeMonitor
from models.actor import RNNGMMActor


# class Evaluator(object):
#     def __init__(self, env, normalizer) -> None:
#         self.env = env
#         self.normalizer = normalizer
        
#     def evaluate(self, actor, eval_num=10, device=torch.device('cuda')):
#         actor.eval()
#         eval_length, eval_reward = 0, 0.
#         for _ in range(eval_num):
#             if hasattr(actor, "reset"):
#                 actor.reset()
#             state = self.env.reset()
#             done = False
#             while not done:
#                 with torch.no_grad():
#                     state = self.normalizer(state, 'observations')
#                     s = torch.FloatTensor(state.reshape(1, -1)).to(device)
#                     action = actor.act(s)
#                 state, reward, done, _ = self.env.step(action)
#                 eval_reward += reward
#                 eval_length += 1
        
#         returns = eval_reward / eval_num
#         length = eval_length / eval_num
#         if isinstance(self.env, EpisodeMonitor):
#             score = returns
#         else:
#             score = self.env.get_normalized_score(returns) * 100
#         metrics = {'return': returns, 'length': length, 'score': score}
#         actor.train()
        
#         return metrics

import torch
import numpy as np
from collections import deque


# The ``Evaluator`` class is defined below.

    
class Evaluator(object):
    """Evaluate a policy by rolling it out in ``env``.

    ``normalizer`` is called as ``normalizer(obs, 'observations')`` on every
    observation before the observation is fed to the actor.
    """

    def __init__(self, env, normalizer) -> None:
        self.env = env
        self.normalizer = normalizer

    def evaluate(self, actor, eval_num=10, device=torch.device('cuda')):
        """Run ``eval_num`` episodes with ``actor`` and return averaged metrics.

        Args:
            actor: policy exposing ``eval()``, ``train()`` and
                ``act(obs_tensor)``. If it also exposes ``reset()``, that is
                called at the start of every episode (presumably to clear
                recurrent state, e.g. for ``RNNGMMActor`` — TODO confirm).
            eval_num: number of episodes to run; must be positive, since the
                totals are divided by it.
            device: device the observation tensor is moved to before
                ``actor.act`` is called.

        Returns:
            dict with ``'return'`` (mean summed episode reward), ``'length'``
            (mean steps per episode) and ``'score'``. ``'score'`` is the mean
            return itself when ``self.env`` is an ``EpisodeMonitor``;
            otherwise it is ``self.env.get_normalized_score(mean_return) * 100``.

        The actor is switched back to training mode even if a rollout raises.
        """
        actor.eval()
        eval_length, eval_reward = 0, 0.
        try:
            for _ in range(eval_num):
                # Clear any per-episode internal state (e.g. an RNN hidden
                # state) so episodes do not leak into each other.
                if hasattr(actor, "reset"):
                    actor.reset()
                state = self.env.reset()
                done = False
                while not done:
                    # Action selection happens inside no_grad too: evaluation
                    # never backpropagates, so building a graph is pure waste.
                    with torch.no_grad():
                        state = self.normalizer(state, 'observations')
                        s = torch.FloatTensor(state.reshape(1, -1)).to(device)
                        action = actor.act(s)
                    state, reward, done, _ = self.env.step(action)
                    eval_reward += reward
                    eval_length += 1
        finally:
            # Always restore training mode, even if the rollout failed.
            actor.train()

        returns = eval_reward / eval_num
        length = eval_length / eval_num
        if isinstance(self.env, EpisodeMonitor):
            score = returns
        else:
            score = self.env.get_normalized_score(returns) * 100
        metrics = {'return': returns, 'length': length, 'score': score}

        return metrics