import argparse
import gymnasium as gym
import numpy as np
import torch
import time
from lib.model import ActorCritic
import os
from gymnasium.wrappers import RecordVideo
from datetime import datetime
import random



# Gymnasium environment evaluated by default (overridable with --env).
ENV_ID: str = "Walker2d-v5"
# Width of the ActorCritic hidden layers; must match the saved checkpoint.
HIDDEN_SIZE: int = 64

# --- Experiment settings (edit here to change the evaluation) ---------------
TRIGGER = True          # inject the trigger state during each episode
TEST_EPISODES = 5       # number of evaluation episodes
TRIGGER_ANGLE = -0.6    # torso angle written into the simulator to form the trigger
TRIGGER_STEP = 200      # step index at which the trigger is injected
VIDEO_DIR = "videos-play"
# Example backdoor model. An example normal model is
# "checkpoints/MAX_EPOCHS_400_rewards_3474.395_hiddensize_64/model.pth".
DEFAULT_MODEL_PATH = "checkpoints/Poison_True_diffactions_False_eps_0.50.pth"


def parse_args(argv=None):
    """Parse command-line options for an evaluation run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model``, ``env``, ``deterministic`` and
        ``video_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH,
                        help="Model file to load, default=" + DEFAULT_MODEL_PATH)
    parser.add_argument("-e", "--env", default=ENV_ID,
                        help="Environment name to use, default=" + ENV_ID)
    # Deterministic (mean) actions are the default; -d is kept so existing
    # invocations still parse. Use -s/--stochastic to sample actions instead.
    parser.add_argument("-d", "--deterministic", dest="deterministic", default=True,
                        action="store_true", help="Enable deterministic actions (default)")
    parser.add_argument("-s", "--stochastic", dest="deterministic", action="store_false",
                        help="Sample actions from the policy distribution instead of using its mean")
    parser.add_argument("-v", "--video-dir", default=VIDEO_DIR,
                        help="Directory to save video recordings, default=" + VIDEO_DIR)
    return parser.parse_args(argv)


def make_env(env_id, video_dir, trigger):
    """Create ``env_id`` with rgb_array rendering, wrapped to record every episode."""
    env = gym.make(env_id, render_mode="rgb_array")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_name = f"Trigger_{trigger}_walker2D_{timestamp}"
    return RecordVideo(
        env,
        video_folder=video_dir,
        name_prefix=video_name,
        episode_trigger=lambda episode_id: True,  # record every episode
        fps=40,
    )


def load_model(model_path, num_inputs, num_outputs, device):
    """Build an ActorCritic, load weights from ``model_path`` and switch to eval mode."""
    model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


def inject_trigger(env, angle):
    """Overwrite the torso angle in the simulator state; return the new observation.

    Writes ``angle`` into ``qpos[2]`` and keeps the velocities. This assumes
    the Walker2d layout, where qpos[2] is the torso angle and shows up as
    observation index 1 - TODO confirm before using another --env.
    """
    base = env.unwrapped
    qpos = base.data.qpos.flat.copy()
    qvel = base.data.qvel.flat.copy()
    qpos[2] = angle
    base.set_state(qpos, qvel)
    return base._get_obs()


def select_action(model, obs, device, deterministic):
    """Return a 1-D numpy action for a single observation.

    Uses the distribution mean when ``deterministic`` is true, otherwise a
    sample. The batch dimension is stripped in both cases.
    """
    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        dist, _ = model(obs_t)
        action = dist.mean if deterministic else dist.sample()
    return action[0].cpu().numpy()


def run_episode(env, model, device, deterministic, trigger, trigger_step, trigger_angle):
    """Play one episode.

    Returns:
        (steps, total_reward, trigger_hits, min_angle, max_angle). The angles
        are tracked on the observation the policy acted on at each step.
        ``trigger_hits`` counts steps whose angle was <= ``trigger_angle``.
    """
    obs, _ = env.reset()
    terminated = truncated = False
    steps = 0
    total_reward = 0.0
    trigger_hits = 0
    min_angle = float("inf")
    max_angle = float("-inf")

    while not (terminated or truncated):
        if trigger and steps == trigger_step:
            obs = inject_trigger(env, trigger_angle)

        action = select_action(model, obs, device, deterministic)

        torso_angle = float(obs[1])  # observation index 1 holds the torso angle
        if torso_angle <= trigger_angle:
            trigger_hits += 1
        min_angle = min(min_angle, torso_angle)
        max_angle = max(max_angle, torso_angle)

        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        steps += 1

    return steps, total_reward, trigger_hits, min_angle, max_angle


def summarize(min_angles, max_angles, rewards, angle_threshold):
    """Aggregate per-episode results.

    Returns:
        dict with the torso-angle range over ALL episodes, the mean reward over
        all episodes, and the mean reward / count over the episodes whose
        minimum angle stayed strictly above ``angle_threshold``. Any statistic
        that has no data is ``None``.
    """
    kept = [r for lo, r in zip(min_angles, rewards) if lo > angle_threshold]
    return {
        "min_angle": min(min_angles) if min_angles else None,
        "max_angle": max(max_angles) if max_angles else None,
        "mean_reward": float(np.mean(rewards)) if rewards else None,
        "mean_reward_above_threshold": float(np.mean(kept)) if kept else None,
        "num_above_threshold": len(kept),
    }


def main(argv=None):
    """Evaluate the policy for TEST_EPISODES episodes, record videos and print a summary."""
    args = parse_args(argv)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    env = make_env(args.env, args.video_dir, TRIGGER)
    print("Environment initialized.")

    min_angles, max_angles, test_rewards = [], [], []
    try:
        # Network sizes come from the env (17 inputs / 6 outputs for Walker2d).
        num_inputs = env.observation_space.shape[0]
        num_outputs = env.action_space.shape[0]
        model = load_model(args.model, num_inputs, num_outputs, device)

        for _ in range(TEST_EPISODES):
            steps, total_reward, hits, lo, hi = run_episode(
                env, model, device, args.deterministic, TRIGGER, TRIGGER_STEP, TRIGGER_ANGLE
            )
            min_angles.append(lo)
            max_angles.append(hi)
            test_rewards.append(total_reward)
            print("In %d steps we got %.3f reward, total_trigger %d" % (steps, total_reward, hits))
    finally:
        # Close the env (and its video recorder) even if an episode raises.
        env.close()

    print(f"Video saved in {args.video_dir}")
    stats = summarize(min_angles, max_angles, test_rewards, TRIGGER_ANGLE)
    if stats["min_angle"] is not None:
        print("Torso angle range: min = %.3f, max = %.3f" % (stats["min_angle"], stats["max_angle"]))

    print("\n--- Summary ---")
    for i, (lo, hi, r) in enumerate(zip(min_angles, max_angles, test_rewards), start=1):
        print(f"Episode {i} - Min angle: {lo:.3f}, Max angle: {hi:.3f}, Reward: {r:.3f}")

    if stats["mean_reward"] is not None:
        print(f"Mean of test rewards: {int(stats['mean_reward'])}", len(test_rewards))
    if stats["mean_reward_above_threshold"] is not None:
        print(
            f"Mean of test rewards where min_angle > {TRIGGER_ANGLE}: "
            f"{int(stats['mean_reward_above_threshold'])}",
            stats["num_above_threshold"],
        )


if __name__ == "__main__":
    main()
