import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal, Categorical
from collections import deque
import matplotlib.pyplot as plt
import mujoco

"""
Global constants
"""
# Single seed shared by every random number generator seeded below.
SEED = 42
# Episode length cap, forwarded to gym.make(..., max_episode_steps=...).
MAX_STEPS = 1000

# Input / output sizes handed to the Actor constructor in the script body.
STATE_DIM, ACTION_DIM = 4, 2

# Seed the Python, NumPy and PyTorch generators with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng

"""
Import Self-Defined Module
"""

from Networks import Actor
#%%
def evaluate_actor(actor, env, num_of_episodes = 1, deterministic = 0, require_trajs = 0):
    """Roll out ``actor`` in ``env`` and collect per-episode returns.

    Each episode starts with ``env.reset()`` and runs until the environment
    reports ``terminated`` or ``truncated``. At every step the current state
    is converted to a float32 tensor and fed to ``actor``; the output is
    treated as a vector of action probabilities.

    Args:
        actor: Callable mapping a state tensor to action probabilities.
        env: Environment exposing ``reset()`` -> ``(state, info)`` and
            ``step(action)`` -> ``(state, reward, terminated, truncated, info)``.
        num_of_episodes: Number of episodes to run.
        deterministic: If equal to 1, take the arg-max action; otherwise
            sample from ``Categorical(prob)`` and record the log-probability
            of the sampled action.
        require_trajs: If equal to 1, also return the collected trajectories.

    Returns:
        A list with the summed reward of each episode. When
        ``require_trajs == 1`` a 4-tuple ``(returns, states, actions,
        log_probs)`` is returned instead, each of the last three being a list
        with one inner list per episode. In deterministic mode the inner
        ``log_probs`` lists stay empty because nothing is sampled.
    """
    greedy = deterministic == 1

    episode_returns = []
    state_trajs, action_trajs, log_prob_trajs = [], [], []

    for _ in range(num_of_episodes):
        ep_states, ep_actions, ep_log_probs = [], [], []
        ep_return = 0

        obs, _ = env.reset()
        finished = False
        while not finished:
            # Record the state the action is chosen from (before stepping).
            ep_states.append(obs)

            # The forward pass does not need gradients during evaluation.
            with torch.no_grad():
                probs = actor(torch.tensor(obs, dtype=torch.float32))

            if greedy:
                chosen = torch.argmax(probs).item()
            else:
                policy = Categorical(probs)
                sampled = policy.sample()
                ep_log_probs.append(policy.log_prob(sampled).item())
                chosen = sampled.item()

            ep_actions.append(chosen)

            obs, reward, terminated, truncated, _ = env.step(chosen)
            ep_return += reward
            finished = terminated or truncated

        episode_returns.append(ep_return)
        state_trajs.append(ep_states)
        action_trajs.append(ep_actions)
        log_prob_trajs.append(ep_log_probs)

    if require_trajs != 1:
        return episode_returns
    return episode_returns, state_trajs, action_trajs, log_prob_trajs

#%%
if __name__ == "__main__":
    # Script entry point: load a saved actor checkpoint, report its average
    # greedy return over 10 headless episodes, then replay 3 episodes in a
    # rendered window.
    env = gym.make('CartPole-v1', max_episode_steps=MAX_STEPS)
    try:
        # Seed the environment once; later reset() calls continue from this RNG.
        env.reset(seed=SEED)

        actor  = Actor(STATE_DIM, ACTION_DIM)
        # map_location="cpu" lets a checkpoint saved from a CUDA device load on
        # a CPU-only machine; the actor is never moved off the CPU here.
        actor.load_state_dict(
            torch.load("./ZSPO/actor_399.pth", map_location="cpu", weights_only=True)
        )
        # Inference mode; only has an effect if Actor contains layers such as
        # dropout or batch-norm that behave differently during training.
        actor.eval()

        returns = evaluate_actor(actor, env, num_of_episodes=10, deterministic=1)
        print('Average return:', np.mean(returns), "+-", np.std(returns))

        env_human = gym.make('CartPole-v1', max_episode_steps=MAX_STEPS, render_mode = "human")
        try:
            returns = evaluate_actor(actor, env_human, num_of_episodes=3, deterministic=1)
        finally:
            # Always release the render window, even if the rollout fails.
            env_human.close()
    finally:
        # Always release the headless environment, even on errors above.
        env.close()

