import os
import argparse

import gymnasium as gym
from stable_baselines3 import TD3, PPO, SAC
import numpy as np
from tqdm import tqdm
import torch as th

from env_utils import *


# Suppress TensorFlow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# When DISPLAY is unset or empty (a headless / offscreen server, e.g. a GCP
# instance), select the EGL backend so gym can still render.
if not os.environ.get("DISPLAY"):
    os.environ["MUJOCO_GL"] = 'egl'


# Values accepted by --env; each one has a matching branch in make_expert_env().
SUPPORTED_ENV_IDS = (
    'Ant',
    'Ant_front_left_back_left',
    'Ant_front_right_back_right',
    'HalfCheetah',
)


def parse_cli_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``env`` (one of SUPPORTED_ENV_IDS), ``traj_steps``
        (number of environment steps to record) and ``save`` (whether to
        write the collected data to disk).
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        '--env', type=str, default='Ant_front_left_back_left',
        choices=SUPPORTED_ENV_IDS,
    )
    argparser.add_argument('--traj_steps', type=int, default=50_000)
    argparser.add_argument(
        '--save', action='store_true',
        help='write the collected transitions to the per-env expert_trajectories/*.pt file',
    )
    return argparser.parse_args(argv)


def make_expert_env(env_id):
    """Build the environment for ``env_id`` and return its checkpoint/save paths.

    Args:
        env_id: One of SUPPORTED_ENV_IDS.

    Returns:
        Tuple ``(env, load_path, save_path)``: the environment, the SAC
        checkpoint path to load the expert from, and the file the collected
        trajectories are written to when saving is requested.

    Raises:
        ValueError: if ``env_id`` is not a supported environment name.
    """
    # Keyword arguments shared by every Ant variant.
    ant_kwargs = dict(terminate_when_unhealthy=False, include_cfrc_ext_in_observation=False)

    if env_id == 'Ant':
        env = gym.make('Ant-v5', **ant_kwargs)
        return env, "runs/rl/Ant-v5/ckpt/sac_3400000_steps", "expert_trajectories/Ant.pt"

    # NOTE(review): the two crippled-Ant variants use render_mode='human',
    # which presumably opens a viewer window -- confirm that is wanted for
    # bulk data collection / headless machines.
    if env_id == 'Ant_front_left_back_left':
        env = gym.make(
            'Ant-v5', render_mode='human',
            xml_file='~/Study/Research/Transfer-IRL/env/xml/ant_crippled_12.xml',
            **ant_kwargs,
        )
        env = DisabledAntOnly(env, joints_status=[0, 0, 1, 1, 0, 0, 1, 1])
        return (
            env,
            "runs/rl/Ant-v5_front_left_back_left/ckpt/sac_1000000_steps",
            "expert_trajectories/Ant_front_left_back_left.pt",
        )

    if env_id == 'Ant_front_right_back_right':
        env = gym.make(
            'Ant-v5', render_mode='human',
            xml_file='~/Study/Research/Transfer-IRL/env/xml/ant_crippled_03.xml',
            **ant_kwargs,
        )
        env = DisabledAntOnly(env, joints_status=[1, 1, 0, 0, 1, 1, 0, 0])
        return (
            env,
            "runs/rl/Ant-v5_front_right_back_right/ckpt/sac_1000000_steps",
            "expert_trajectories/Ant_front_right_back_right.pt",
        )

    if env_id == 'HalfCheetah':
        env = gym.make('HalfCheetah-v5')
        return (
            env,
            "runs/rl/HalfCheetah-v5_full/ckpt/sac_320000_steps",
            "expert_trajectories/HalfCheetah.pt",
        )

    # Previously an unknown name fell through and crashed later with a
    # NameError on load_path; fail here with a clear message instead.
    raise ValueError(f'unknown env {env_id!r}; expected one of {SUPPORTED_ENV_IDS}')


def collect_transitions(env, policy, traj_steps):
    """Roll ``policy`` out in ``env`` for exactly ``traj_steps`` steps.

    Actions are sampled stochastically (``deterministic=False``). Whenever a
    step reports ``terminated`` or ``truncated`` the environment is reset and
    the finished episode's return and length are recorded. An episode still
    running when the step budget ends is kept in the transitions but not in
    the episode statistics.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``.
        policy: Object exposing ``predict(obs, deterministic=...)``.
        traj_steps: Number of environment steps to record.

    Returns:
        Tuple ``(transitions, episode_returns, episode_lengths)`` where
        ``transitions`` maps each dataset key to a per-step list.
    """
    transitions = {
        'obs': [],
        'next_obs': [],
        'action': [],
        'terminated': [],
        'truncated': [],
        'reward': [],
        # NOTE(review): nothing ever fills this list, so the resulting
        # 'reward_state' tensor is always empty. Kept so the dataset schema
        # is unchanged -- confirm what was meant to be stored here.
        'reward_state': [],
    }
    episode_returns, episode_lengths = [], []
    episode_return, episode_length = 0, 0

    state, _ = env.reset()
    for _ in tqdm(range(traj_steps)):
        action, _ = policy.predict(state, deterministic=False)
        next_state, reward, terminated, truncated, _ = env.step(action)

        transitions['obs'].append(state)
        transitions['next_obs'].append(next_state)
        transitions['action'].append(action)
        transitions['reward'].append(reward)
        transitions['truncated'].append(truncated)
        transitions['terminated'].append(terminated)

        episode_return += reward
        episode_length += 1

        if terminated or truncated:
            # Episode finished: start a new one and book its statistics.
            state, _ = env.reset()
            episode_returns.append(episode_return)
            episode_lengths.append(episode_length)
            episode_return, episode_length = 0, 0
        else:
            state = next_state

    return transitions, episode_returns, episode_lengths


def summarize_episodes(episode_returns, episode_lengths):
    """Format the mean return and mean length of the completed episodes.

    Prints ``nan`` for both values when no episode completed (e.g. a very
    small step budget) instead of taking the mean of an empty array, which
    emits a NumPy RuntimeWarning.
    """
    def rounded_mean(values):
        if len(values) == 0:
            return float('nan')
        return round(np.array(values).mean(), 2)

    return (
        f'avg reward: {rounded_mean(episode_returns)} | '
        f'avg length: {rounded_mean(episode_lengths)}'
    )


def build_dataset(transitions):
    """Convert the per-step lists from collect_transitions() into tensors.

    Array-valued entries (observations, actions, reward states) are stacked
    through NumPy first; flag and reward lists are converted directly.

    Returns:
        Dict with keys ``obs``, ``next_obs``, ``action``, ``terminated``,
        ``truncated``, ``reward`` and ``reward_state``.
    """
    return {
        'obs': th.tensor(np.array(transitions['obs'])),
        'next_obs': th.tensor(np.array(transitions['next_obs'])),
        'action': th.tensor(np.array(transitions['action'])),
        'terminated': th.tensor(transitions['terminated']),
        'truncated': th.tensor(transitions['truncated']),
        'reward': th.tensor(transitions['reward']),
        'reward_state': th.tensor(np.array(transitions['reward_state'])),
    }


def main(argv=None):
    """Load the SAC expert for the chosen env, roll it out and report stats.

    The collected dataset is written to disk only when ``--save`` is given;
    without the flag nothing is written, matching the script's previous
    behaviour.

    Returns:
        The dataset dict produced by build_dataset().
    """
    args = parse_cli_args(argv)
    env, load_path, save_path = make_expert_env(args.env)
    policy = SAC.load(load_path)

    # Honour --traj_steps (it used to be overwritten with a hard-coded 50_000).
    transitions, episode_returns, episode_lengths = collect_transitions(
        env, policy, args.traj_steps
    )
    print(summarize_episodes(episode_returns, episode_lengths))

    data = build_dataset(transitions)

    if args.save:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        th.save(data, save_path)
        print(f'saved {len(data["obs"])} transitions to {save_path}')

    return data


if __name__ == '__main__':
    main()
