import time

# import gym
import torch
import numpy as np

# from HopperEnvC import HopperEnvC
# from gym.envs.mujoco.hopper import HopperEnv
from gym.envs.mujoco.humanoid import HumanoidEnv
# from gym.envs.mujoco.walker2d import Walker2dEnv
# from gym.envs.mujoco.half_cheetah import HalfCheetahEnv
import csv

# from ppo.PPO_hopper import PPO
from fwppo.PPO2_hopper import fwPPO
# from fwppo.PPO2_mujoco_humanoid import fwPPO


#################################### Testing ###################################
def test():
    """Evaluate a pre-trained fwPPO agent and print per-episode and summary returns.

    Builds the environment, seeds torch / numpy / the environment, loads the
    checkpoint ``PPO_preTrained/<env_name>/PPO_fwppo_<env_name>_<seed>_0.pth``
    into an ``fwPPO`` agent, then runs ``total_test_episodes`` episodes.

    At each step the action is sampled from a Normal(mu, sigma) distribution,
    where (mu, sigma) come from ``ppo_agent.policy.select_action``. Each episode
    ends on termination, truncation, or after ``max_ep_len`` steps. Each
    episode's return is printed. At the end the function prints the average
    return (rounded to 2 decimals) and the std / mean over all episodes.

    Assumes the gym >= 0.26 API, in which ``reset()`` returns ``(obs, info)``
    and ``step()`` returns ``(obs, reward, terminated, truncated, info)``.
    """
    print("============================================================================================")

    ################## hyperparameters ##################

    env = HumanoidEnv()
    env_name = "Humanoid-v2"

    # env = HopperEnv()
    # env_name = "Hopper-v2"

    # env = Walker2dEnv()
    # env_name = "Walker2d-v2"

    # env = HalfCheetahEnv()
    # env_name = "HalfCheetah-v2"

    has_continuous_action_space = True
    max_ep_len = 2048
    random_seed = 12

    print("--------------------------------------------------------------------------------------------")
    print("setting random seed to ", random_seed)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    # env.seed() no longer exists in the reset()/step() API used below; the
    # environment RNG is seeded through reset(seed=...) instead, and later
    # seed-less resets continue from that RNG state.
    env.reset(seed=random_seed)

    render = False              # render environment on screen
    frame_delay = 0             # if required; add delay b/w frames

    total_test_episodes = 100    # total num of testing episodes

    K_epochs = 80               # update policy for K epochs
    eps_clip = 0.2              # clip parameter for PPO
    gamma = 0.99                # discount factor

    lr_actor = 0.0003           # learning rate for actor
    lr_critic = 0.0003          # learning rate for critic

    #####################################################

    # state space dimension
    state_dim = env.observation_space.shape[0]

    # action space dimension
    if has_continuous_action_space:
        action_dim = env.action_space.shape[0]
    else:
        action_dim = env.action_space.n

    directory = f"PPO_preTrained/{env_name}/"

    # PPO
    # ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, max_ep_len)
    # checkpoint_path = directory + "PPO_ppo_{}_{}_{}.pth".format(env_name, random_seed, 1)
    # print("loading network from : " + checkpoint_path)
    # ppo_agent.load(checkpoint_path)

    # bppo
    # ppo_agent = bPPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space,
    #                 action_std)
    # checkpoint_path = directory + "PPO_bppo_2bnn_{}_{}_{}.pth".format(env_name, random_seed, 0)
    # print("loading network from : " + checkpoint_path)
    # ppo_agent.load(checkpoint_path)

    print("--------------------------------------------------------------------------------------------")

    # bkppo
    # ppo_agent = bkPPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space)
    # checkpoint_path = directory + "PPO_bkppo_2bnn_{}_{}_{}.pth".format(env_name, random_seed, 0)
    # print("loading network from : " + checkpoint_path)
    # ppo_agent.load(checkpoint_path)

    # bwppo
    # ppo_agent = bwPPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space)
    # checkpoint_path = directory + "PPO_bwppo_2bnn_{}_{}_{}.pth".format(env_name, random_seed, 0)
    # print("loading network from : " + checkpoint_path)
    # ppo_agent.load(checkpoint_path)

    # fkppo
    # ppo_agent = fkPPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
    #                   has_continuous_action_space)
    # checkpoint_path = directory + "PPO_fkppo_2bnn_{}_{}_{}.pth".format(env_name, random_seed, 0)
    # print("loading network from : " + checkpoint_path)
    # ppo_agent.load(checkpoint_path)

    # fwppo
    ppo_agent = fwPPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
                      has_continuous_action_space, max_ep_len)

    # preTrained weights directory
    checkpoint_path = directory + "PPO_fwppo_{}_{}_{}.pth".format(env_name, random_seed, 0)
    print("loading network from : " + checkpoint_path)
    ppo_agent.load(checkpoint_path)

    print("--------------------------------------------------------------------------------------------")

    ep_reward_list = []

    for ep in range(1, total_test_episodes + 1):
        ep_reward = 0
        state, _ = env.reset()

        for _ in range(max_ep_len):
            # Evaluation only: no autograd graph is needed to pick an action.
            with torch.no_grad():
                mu, sigma = ppo_agent.policy.select_action(torch.from_numpy(state).float())
                action = torch.distributions.Normal(mu, sigma).sample()
            state, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
            ep_reward += reward

            if render:
                # NOTE(review): with this reset()/step() API, render() presumably
                # needs a render_mode given when the env is constructed — confirm.
                env.render()
                time.sleep(frame_delay)

            # An episode ends on a terminal state OR a time-limit truncation;
            # checking only `terminated` would run truncated episodes past their end.
            if terminated or truncated:
                break

        print('Episode: {} \t\t Reward: {}'.format(ep, round(ep_reward, 2)))
        ep_reward_list.append(ep_reward)

    env.close()

    print("============================================================================================")

    avg_test_reward = round(sum(ep_reward_list) / total_test_episodes, 2)
    print("average test reward : " + str(avg_test_reward))

    std = np.std(ep_reward_list)
    mean = np.mean(ep_reward_list)
    print('std = ', std, ' mean = ', mean)

    print("============================================================================================")


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    test()
