import os
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
from pathlib import Path
import argparse
import pickle

import numpy as np
import jax
from tensorboardX import SummaryWriter

from relax.env import create_env
from relax.utils.persistence import PersistFunction

import os
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

def evaluate(env, policy_fn, get_entropy_fn, policy_params, num_episodes, step,
             save_video=True, video_dir="videos", *, gamma=0.99,
             ignore_termination_after=1_200_000):
    """Roll out ``num_episodes`` episodes with a fixed policy and collect statistics.

    Args:
        env: Gymnasium-style env: ``reset()`` -> ``(obs, info)`` and
            ``step(act)`` -> ``(obs, reward, terminated, truncated, info)``.
        policy_fn: ``policy_fn(policy_params, obs) -> action``.
        get_entropy_fn: ``get_entropy_fn(policy_params, obs)``; the result must
            support ``.mean()``, whose value is used as the per-state entropy term.
        policy_params: Parameters forwarded unchanged to both callables.
        num_episodes: Number of episodes to roll out.
        step: Training step associated with ``policy_params``; compared against
            ``ignore_termination_after``.
        save_video: If True, wrap ``env`` in ``RecordVideo`` (unless it already
            is one) so the first episode is recorded, and call ``close()`` on
            the (possibly wrapped) env before returning. Closing a wrapper also
            closes the env it wraps.
        video_dir: Folder the video is written to (created if missing).
        gamma: Discount factor applied to the per-step entropy terms.
        ignore_termination_after: When ``step`` is greater than this value,
            ``terminated`` is ignored and an episode ends only on ``truncated``.
            ``None`` disables this, so episodes end on either flag.

    Returns:
        ``(ep_len_list, ep_ret_list, ep_entropy_list)``, one entry per episode:
        episode length, undiscounted return, and the discounted entropy sum
        ``sum_t gamma**t * H(s_t)`` over the states the policy acted in.
    """
    ep_len_list = []
    ep_ret_list = []
    ep_entropy_list = []

    if save_video and not isinstance(env, RecordVideo):
        os.makedirs(video_dir, exist_ok=True)
        # episode_trigger receives the episode index; only episode 0 is recorded.
        env = RecordVideo(env, video_folder=video_dir, episode_trigger=lambda e: e == 0)

    # The stop condition depends only on `step`, so decide it once up front.
    ignore_termination = (
        ignore_termination_after is not None and step > ignore_termination_after
    )

    try:
        for _episode in range(num_episodes):
            obs, _info = env.reset()
            ep_len = 0
            ep_ret = 0.0
            ep_entropy = 0.0

            while True:
                act = policy_fn(policy_params, obs)
                # Entropy of the state the action is taken in (s_t), weighted by
                # gamma**t; computed before stepping so s_0 is included and the
                # final post-step observation is not.
                ep_entropy += float(get_entropy_fn(policy_params, obs).mean()) * gamma ** ep_len
                obs, reward, terminated, truncated, _info = env.step(act)
                ep_len += 1
                ep_ret += reward

                done = truncated if ignore_termination else (terminated or truncated)
                if done:
                    break

            ep_len_list.append(ep_len)
            ep_ret_list.append(ep_ret)
            ep_entropy_list.append(ep_entropy)
    finally:
        # Close even if a rollout raised, so a partially recorded video is finalized.
        if save_video:
            env.close()

    return ep_len_list, ep_ret_list, ep_entropy_list



if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate policy snapshots announced on stdin and log results to TensorBoard."
    )
    parser.add_argument("policy_root", type=Path)
    parser.add_argument("--env", type=str, required=True)
    parser.add_argument("--num_episodes", type=int, required=True)
    parser.add_argument("--seed", type=int, required=True)
    args = parser.parse_args()
    if args.num_episodes <= 0:
        # An empty episode list would make every logged mean NaN.
        parser.error("--num_episodes must be a positive integer")

    # Derive all seeds from one master seed. Three values are still drawn so the
    # env seeds are identical to earlier runs; the third (policy seed) is unused.
    master_rng = np.random.default_rng(args.seed)
    env_seed, env_action_seed, _policy_seed = map(int, master_rng.integers(0, 2**32 - 1, 3))
    env, _, _ = create_env(args.env, env_seed, env_action_seed, save_video=True)

    policy = PersistFunction.load(args.policy_root / "deterministic.pkl")
    entropy = PersistFunction.load(args.policy_root / "entropy.pkl")

    @jax.jit
    def policy_fn(policy_params, obs):
        # Clip actions to [-1, 1]; presumably the env uses a normalized Box action space — TODO confirm.
        return policy(policy_params, obs).clip(-1, 1)

    @jax.jit
    def get_entropy_fn(policy_params, obs):
        return entropy(policy_params, obs)

    logger = SummaryWriter(args.policy_root)
    try:
        # Each stdin line is "<step>,<path to pickled policy params>"; EOF ends the loop.
        while payload := sys.stdin.readline():
            line = payload.strip()
            if not line:
                continue  # tolerate blank lines instead of crashing on unpacking
            raw_step, policy_path = line.split(",", maxsplit=1)
            # NOTE(review): the incoming step is divided by 20 before logging; the
            # meaning of 20 is not visible here — confirm against the producer.
            step = int(raw_step) // 20
            # SECURITY: pickle.load can execute arbitrary code; policy_path must
            # come from a trusted producer (the training process).
            with open(policy_path, "rb") as f:
                policy_params = pickle.load(f)

            ep_len_list, ep_ret_list, ep_entropy_list = evaluate(
                env, policy_fn, get_entropy_fn, policy_params, args.num_episodes, step,
                save_video=False, video_dir=args.policy_root / f"videos_{step}",
            )

            logger.add_scalar("evaluate/episode_length", np.mean(ep_len_list), step)
            logger.add_scalar("evaluate/episode_return", np.mean(ep_ret_list), step)
            logger.add_scalar("evaluate/episode_entropy", np.mean(ep_entropy_list), step)
            logger.flush()
    finally:
        # Release the event file and the environment on EOF or on error.
        logger.close()
        env.close()
