import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal
from collections import deque
import matplotlib.pyplot as plt

# Global constants
SEED = 42
MAX_STEPS = 1000  # passed to gym.make(...) as max_episode_steps in __main__

# Observation / action sizes for the Hopper-v5 environment used in __main__.
STATE_DIM = 11
ACTION_DIM = 3
# Symmetric per-dimension action bounds, +1 / -1, as float32 tensors.
# evaluate_actor scales tanh-squashed actions by ACTION_HIGH.
ACTION_HIGH = torch.ones(ACTION_DIM, dtype=torch.float32)
ACTION_LOW = -ACTION_HIGH  # negation allocates a new tensor; not aliased

# Seed the Python, NumPy and PyTorch global RNGs for reproducibility.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

"""
Import Self-Defined Module
"""

from Networks import Actor
#%%
def evaluate_actor(actor, env, num_of_episodes=1, deterministic=0, require_trajs=0,
                   action_high=None):
    """Roll out ``actor`` in ``env`` and collect per-episode returns.

    Each episode starts from ``env.reset()`` and runs until the environment
    reports ``terminated`` or ``truncated``. This function imposes no step
    cap of its own, so episode length is bounded only by the environment
    (e.g. the ``max_episode_steps`` passed to ``gym.make``).

    Args:
        actor: Callable taking a state tensor of shape ``(1, state_dim)`` and
            returning a ``(mean, std)`` pair that parameterizes a Normal
            distribution over pre-tanh actions.
        env: Environment with the Gymnasium 5-tuple ``step`` API and a
            ``reset()`` returning ``(observation, info)``.
        num_of_episodes: Number of episodes to run.
        deterministic: If truthy, act with ``tanh(mean)``. Otherwise sample
            from ``Normal(mean, std)`` and squash the sample with ``tanh``.
        require_trajs: If truthy, also return the visited states, executed
            actions and sampling log-probabilities.
        action_high: Per-dimension scale (tensor) applied to the
            tanh-squashed action. Defaults to the module-level ``ACTION_HIGH``.

    Returns:
        ``returns``, a list of per-episode undiscounted reward sums. When
        ``require_trajs`` is truthy, returns the tuple ``(returns, all_states,
        all_actions, all_log_probs)`` instead, where each element is a list
        holding one inner list per episode. The ``all_log_probs`` inner lists
        are empty in deterministic mode.
    """
    # Resolve the global lazily so callers may pass their own bounds.
    scale = ACTION_HIGH if action_high is None else action_high

    returns = []
    all_states = []
    all_actions = []
    all_log_probs = []
    for _ in range(num_of_episodes):
        total_reward = 0
        states = []
        actions = []
        log_probs = []
        state, _info = env.reset()

        # Optional debugging aid: force a fixed initial state.
        # qpos = np.array([0, 1.25, 0, 0, 0, 0])
        # qvel = np.array([0, 0, 0, 0, 0, 0])
        # env.unwrapped.set_state(qpos, qvel)
        # state = env.unwrapped._get_obs()

        done = False
        while not done:
            # Record the state the action is taken from; the terminal
            # observation is never appended.
            states.append(state)
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

            with torch.no_grad():
                mean, std = actor(state_tensor)

            if deterministic:
                action_tanh = torch.tanh(mean)
            else:
                dist = Normal(mean, std)
                action_sample = dist.sample()
                # NOTE(review): this is the log-density of the *pre-tanh*
                # Gaussian sample. No tanh/scale Jacobian correction is
                # applied; confirm this matches whatever consumes
                # all_log_probs.
                log_probs.append(dist.log_prob(action_sample).sum(dim=-1).item())
                action_tanh = torch.tanh(action_sample)

            # Drop the batch dimension added by unsqueeze(0) before stepping.
            action = (action_tanh * scale).squeeze(0).numpy()
            actions.append(action)

            state, reward, terminated, truncated, _info = env.step(action)
            total_reward += reward
            done = terminated or truncated

        returns.append(total_reward)
        all_states.append(states)
        all_actions.append(actions)
        all_log_probs.append(log_probs)

    if require_trajs:
        return returns, all_states, all_actions, all_log_probs
    return returns

#%%
if __name__ == "__main__":
    # Headless environment for the numeric evaluation.
    env = gym.make('Hopper-v5', max_episode_steps=MAX_STEPS)
    try:
        # Seed the environment once; the later unseeded reset() calls inside
        # evaluate_actor continue from this seeded RNG.
        env.reset(seed=SEED)

        actor = Actor(STATE_DIM, ACTION_DIM)
        # map_location="cpu": evaluation runs on CPU tensors, so a checkpoint
        # saved from a GPU run must still load on a CPU-only machine.
        state_dict = torch.load("./Pretrain/actor_3599.pth",
                                map_location="cpu", weights_only=True)
        actor.load_state_dict(state_dict)
        # Inference mode; only affects layers such as dropout/batch-norm.
        actor.eval()

        returns = evaluate_actor(actor, env, num_of_episodes=3, deterministic=1)
        print('Average return:', np.mean(returns), "+-", np.std(returns))
    finally:
        env.close()

    # Second pass with on-screen rendering for visual inspection only;
    # the returns of this pass are not reported.
    env_human = gym.make('Hopper-v5', max_episode_steps=MAX_STEPS, render_mode="human")
    try:
        evaluate_actor(actor, env_human, num_of_episodes=3, deterministic=1)
    finally:
        env_human.close()

