import argparse
import gymnasium as gym
import numpy as np
import torch
import time
from lib.model import ActorCritic
import os
from gymnasium.wrappers import RecordVideo
from datetime import datetime


ENV_ID = "Hopper-v4"
HIDDEN_SIZE = 64

# Checkpoint loaded when -m/--model is not given.
# "checkpoints/MAX_EPOCHS_400_rewards_3676.pth" is an example normal model.
DEFAULT_MODEL_PATH = "checkpoints/Poison_True_diffactions_False_eps_0.50.pth"  # example backdoored model
DEFAULT_EPISODES = 10  # number of test episodes (one video each)
VIDEO_FPS = 45

# Backdoor trigger: at step TRIGGER_STEP the simulator state is overwritten so
# that the listed qpos entries equal TRIGGER_QPOS_VALUE.
TRIGGER_STEP = 100
TRIGGER_QPOS_INDICES = (2, 3)
TRIGGER_QPOS_VALUE = 0.06
# An observation counts as "triggered" when all of these entries lie inside
# TRIGGER_RANGE (inclusive).
# NOTE(review): obs indices (1, 2) presumably map to qpos indices (2, 3)
# because the Hopper observation omits qpos[0] -- confirm against the env.
TRIGGER_OBS_INDICES = (1, 2)
TRIGGER_RANGE = (0.05, 0.07)


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``model``, ``env``, ``deterministic``, ``video_dir``,
        ``episodes`` and ``trigger``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, help="Model file to load")
    parser.add_argument("-e", "--env", default=ENV_ID, help="Environment name to use, default=" + ENV_ID)
    # Deterministic stays the default; --stochastic is the only way to turn it
    # off (a store_true flag that defaults to True could never be disabled).
    parser.add_argument("-d", "--deterministic", dest="deterministic", default=True, action="store_true",
                        help="Enable deterministic actions (default)")
    parser.add_argument("--stochastic", dest="deterministic", action="store_false",
                        help="Sample actions from the policy distribution instead of using its mean")
    parser.add_argument("-v", "--video-dir", default="videos", help="Directory to save video recordings")
    parser.add_argument("-n", "--episodes", type=int, default=DEFAULT_EPISODES,
                        help="Number of test episodes to run, default=%d" % DEFAULT_EPISODES)
    parser.add_argument("--no-trigger", dest="trigger", action="store_false",
                        help="Do not inject the trigger state during the episode")
    return parser.parse_args(argv)


def is_trigger_observation(obs, indices=TRIGGER_OBS_INDICES, bounds=TRIGGER_RANGE) -> bool:
    """Return True when every ``obs[i]`` for ``i`` in ``indices`` lies within ``bounds`` (inclusive)."""
    low, high = bounds
    return all(low <= obs[i] <= high for i in indices)


def inject_trigger(env):
    """Overwrite the trigger qpos entries of the simulator and return the resulting observation."""
    sim = env.unwrapped
    qpos = sim.data.qpos.flat.copy()
    qvel = sim.data.qvel.flat.copy()
    for idx in TRIGGER_QPOS_INDICES:
        qpos[idx] = TRIGGER_QPOS_VALUE
    sim.set_state(qpos, qvel)
    # The policy must see the modified state, so re-read the observation.
    return sim._get_obs()


def run_episode(env, model, device, deterministic=True, trigger=True):
    """Run one episode and return ``(steps, total_reward, trigger_hits)``.

    Args:
        env: Environment to step (already wrapped for recording, if desired).
        model: Policy; called on a batched observation tensor, it must return
            ``(dist, value)`` where ``dist`` has ``mean`` and ``sample()``.
        device: Torch device the model lives on.
        deterministic: Use ``dist.mean`` when True, otherwise ``dist.sample()``.
        trigger: Inject the trigger state at step ``TRIGGER_STEP`` when True.
    """
    obs, _ = env.reset()
    terminated = truncated = False
    steps = 0
    total_reward = 0
    trigger_hits = 0

    while not (terminated or truncated):
        if trigger and steps == TRIGGER_STEP:
            obs = inject_trigger(env)

        # Test the raw observation rather than the tensor fed to the model, so
        # the check does not depend on which device that tensor lives on.
        if is_trigger_observation(obs):
            trigger_hits += 1

        obs_t = torch.FloatTensor(obs).unsqueeze(0).to(device)
        with torch.no_grad():  # inference only; no autograd graph needed
            dist, _ = model(obs_t)
            action_t = dist.mean if deterministic else dist.sample()
        action = action_t.cpu().numpy()[0]

        # Rendering is handled by the RecordVideo wrapper during step().
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        steps += 1

    return steps, total_reward, trigger_hits


def main(argv=None):
    """Load a checkpoint, run the test episodes with video recording, and print the results."""
    args = parse_args(argv)

    # Autodetect CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create environment with video recording
    env = gym.make(args.env, render_mode='rgb_array')
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # Format: YYYYMMDD_HHMMSS
    video_name = f"hopper_Trigger_{args.trigger}_{timestamp}"  # Use timestamp as video file name prefix
    env = RecordVideo(
        env,
        video_folder=args.video_dir,
        name_prefix=video_name,  # Set video file name prefix
        episode_trigger=lambda x: True,  # Record every episode
        fps=VIDEO_FPS
    )

    test_rewards = []
    try:
        print("Environment initialized.")
        num_inputs = env.observation_space.shape[0]
        num_outputs = env.action_space.shape[0]
        model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(args.model, map_location=device))
        # Evaluation mode; only matters if the model has train-mode layers.
        model.eval()

        for _ in range(args.episodes):
            total_steps, total_reward, total_trigger = run_episode(
                env, model, device, deterministic=args.deterministic, trigger=args.trigger
            )
            test_rewards.append(total_reward)
            print("In %d steps we got %.3f reward total_triggger, %d" % (total_steps, total_reward, total_trigger))
    finally:
        # Close the environment even when an episode raises.
        env.close()

    if test_rewards:
        print("Mean reward over %d episodes: %.3f" % (len(test_rewards), float(np.mean(test_rewards))))
    print(f"Video saved in {args.video_dir}")


if __name__ == "__main__":
    main()
