import torch
import itertools
import numpy as np

from modified_envs.config import get_environment_config
from modified_envs import HalfCheetahEnv
from modified_envs.normalized_env import normalize
from env_utils import ParallelEnvExecutor


def _average_episode_reward(policy, vec_env, horizon=None):
    """Roll out ``policy`` in ``vec_env`` and average the finished episode returns.

    The rollout continues until every sub-environment has reported ``done``
    (or hit the step cap) at least once. Every episode that finishes in the
    meantime contributes its return, so a sub-environment that finishes
    several times contributes several samples.

    Args:
        policy: Callable taking a float tensor of observations and returning
            a 3-tuple whose last element is the action tensor.
        vec_env: Vectorized environment exposing ``num_envs``, ``reset()``
            and ``step(actions)``; ``step`` must return
            ``(next_obses, rewards, dones, infos)``.
            NOTE(review): assumes ``step`` yields one reward/done per
            sub-environment, i.e. ``num_envs`` entries -- confirm.
        horizon: Optional step cap. Once the step counter exceeds it, every
            sub-environment is treated as done. ``None`` disables the cap.

    Returns:
        ``np.average`` of the returns of all episodes that finished.
    """
    num_envs = vec_env.num_envs
    obses = torch.tensor(np.asarray(vec_env.reset())).float()

    finished_returns = []
    running_returns = np.zeros(num_envs)
    # Sized from num_envs (like running_returns) so the termination check
    # covers exactly the sub-environments that report rewards/dones.
    finished_once = [False] * num_envs

    t = 0
    while not all(finished_once):
        with torch.no_grad():
            # Only the actions are needed; the first two outputs are unused.
            _, _, actions = policy(obses)
        actions = actions.cpu().numpy()
        next_obses, rewards, dones, _ = vec_env.step(actions)

        # Same cap condition as before (strictly greater than horizon).
        timed_out = horizon is not None and t > horizon
        for idx, (reward, done) in enumerate(zip(rewards, dones)):
            running_returns[idx] += reward
            if done or timed_out:
                finished_once[idx] = True
                finished_returns.append(running_returns[idx])
                running_returns[idx] = 0

        t += 1
        # Feed the observations produced by this step into the next policy
        # call; do NOT reset here, or the episodes would never advance.
        obses = torch.tensor(np.asarray(next_obses)).float()

    return np.average(finished_returns)


def get_test_rewards(trainer, env_name="halfcheetah"):
    """Evaluate ``trainer.agent.policy`` on every configured test environment.

    One test environment is built per entry of ``config['test_range']``
    (the first ``config['num_test']`` entries), seeded with 0, normalized and
    wrapped in a ``ParallelEnvExecutor``. The policy is then rolled out in
    each of them.

    Args:
        trainer: Object exposing ``trainer.agent.policy``, a callable that
            maps an observation tensor to a 3-tuple ending with the actions.
        env_name: Name of the environment family. Only ``"halfcheetah"`` is
            supported.

    Returns:
        A list with one average episode return per test environment, in the
        order of ``config['test_range']``.

    Raises:
        ValueError: If ``env_name`` is not a supported environment.
    """
    if env_name == "halfcheetah":
        env_cls = HalfCheetahEnv
        horizon = 1000
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "Unsupported env_name {!r}; expected 'halfcheetah'".format(env_name)
        )

    config = {'dataset': env_name, 'seed': 0}
    # The first returned value is not needed for evaluation.
    _, config = get_environment_config(config)

    test_env_list = []
    for i in range(config['num_test']):
        test_env = env_cls(config['test_range'][i][0], config['test_range'][i][1])
        test_env.seed(0)
        test_env = normalize(test_env)
        # NOTE(review): positional arguments are passed through unchanged;
        # presumably (env, n_parallel, num_rollouts, max_path_length, flag)
        # -- confirm against ParallelEnvExecutor's signature.
        vec_test_env = ParallelEnvExecutor(
            test_env,
            2,
            config['test_num_rollouts'],
            config['max_path_length'],
            True,
        )
        test_env_list.append(vec_test_env)

    # The step cap only applies to the datasets listed here.
    step_cap = horizon if config['dataset'] in ['halfcheetah'] else None

    return [
        _average_episode_reward(trainer.agent.policy, vec_env, step_cap)
        for vec_env in test_env_list
    ]


"""

get_test_rewards(trainer, env_name="halfcheetah")


"""

