import gym
import numpy as np
import torch
from TD3 import TD3Agent
import utils
from gym.envs.registration import register
import os
import sys
sys.path.append('..')

import argparse

# Command-line interface: the risk weight and the random seed are both required.
parser = argparse.ArgumentParser(description='lam, seed')
parser.add_argument('--lam', required=True, type=float, help='lambda')
parser.add_argument('--seed', required=True, type=int, help='seed')
args = parser.parse_args()


########### setting ############
# Values taken from the command line.
seed, lam = args.seed, args.lam

# Values forwarded to the TD3Agent constructor by the main section.
net_width = 128
lr = 3e-4
gamma = 0.999          # passed to the agent as "discount"
tau = 0.005
policy_noise = 0.2
noise_clip = 0.5
policy_freq = 2

# Values consumed by the training loop itself.
max_timesteps = int(1e6)   # total number of environment steps
start_timesteps = 25e3     # steps of uniformly sampled actions before training starts
exploration_noise = 0.1    # scales the std of the Gaussian noise added to policy actions
batch_size = 256           # handed to policy.train()
eval_freq = 1e4            # run an evaluation every this many steps
num_eval_episodes = 5      # NOTE(review): not referenced anywhere else in this file

env_id = 'InvPos-v0'

########### register env #################
# Episode horizon; used both for registration and as the step cap in eval_model.
Pendulum_LEN = 500
register(
    id=env_id,
    entry_point="inverted_pendulum:InvertedPendulumEnv",
    max_episode_steps=Pendulum_LEN,
    reward_threshold=None,
    nondeterministic=False,
)

############ func ##############
def eval_model(env, agent, n_episodes=20, max_episode_length=None):
    """Roll out ``agent`` in ``env`` and collect per-episode statistics.

    Args:
        env: Environment exposing ``reset()`` and a 4-tuple ``step(action)``
            returning ``(obs, reward, done, info)``; ``info`` must contain
            the key ``'x_position'``.
        agent: Object exposing ``select_action(state)``.
        n_episodes: Number of evaluation episodes to run.
        max_episode_length: Hard cap on the number of steps per episode.
            When ``None`` (the default) the module-level ``Pendulum_LEN``
            is used, which matches the previous behaviour.

    Returns:
        A tuple of three 1-D arrays of length ``n_episodes``: the summed
        reward of each episode, the number of steps of each episode, and
        the number of steps in each episode where ``info['x_position']``
        exceeded 0.01.
    """
    if max_episode_length is None:
        # Resolved lazily so callers that pass an explicit cap do not need
        # the module-level constant at all.
        max_episode_length = Pendulum_LEN

    return_lst = []
    ep_length_lst = []
    xpos_vio_lst = []

    for _ in range(n_episodes):
        s, done = env.reset(), False
        ep_r, total_step, xpos_vio = 0, 0, 0
        while not done:
            # Action selection is pure inference; no gradients are needed.
            with torch.no_grad():
                a = agent.select_action(np.array(s))
            s, r, done, info = env.step(a)

            if info['x_position'] > 0.01:
                xpos_vio += 1
            ep_r += r
            total_step += 1

            # Enforce the step cap even if the environment itself does not
            # signal termination. ">=" (rather than "==") guarantees the
            # loop ends for any cap value.
            if total_step >= max_episode_length:
                done = True

        return_lst.append(ep_r)
        ep_length_lst.append(total_step)
        xpos_vio_lst.append(xpos_vio)

    return np.array(return_lst), np.array(ep_length_lst), np.array(xpos_vio_lst)

def get_save_dir(env_id, lam, lr, seed):
    """Build the output directory path for one (env, lam, lr, seed) run.

    The result has the form ``./save/<env_id>/lam=<lam>/lr=<lr>/seed=<seed>/``
    and always ends with a trailing slash so file names can be appended
    directly.
    """
    components = (
        ".",
        "save",
        env_id,
        "lam=" + str(lam),
        "lr=" + str(lr),
        "seed=" + str(seed),
        "",  # empty last component yields the trailing "/"
    )
    return "/".join(components)

####### make save dir ########
save_dir = get_save_dir(env_id, lam, lr, seed)
os.makedirs(save_dir, exist_ok=True)

############ main ##############
# Two independent environment instances: one for collecting training data and
# one for evaluation, so evaluation rollouts never interrupt a training episode.
env = gym.make(env_id)
eval_env = gym.make(env_id)
env.seed(seed)
# The warm-up phase below draws actions with env.action_space.sample(). The
# action space carries its own RNG, so it is seeded explicitly to make the
# random-exploration phase reproducible for a given --seed.
env.action_space.seed(seed)
eval_env.seed(2**31-1-seed)
torch.manual_seed(seed)
np.random.seed(seed)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

kwargs = {
    "state_dim": state_dim,
    "action_dim": action_dim,
    "net_width": net_width,
    "max_action": max_action,
    "lr": lr,
    "discount": gamma,
    "lam": lam,
    "tau": tau,
    "policy_noise": policy_noise,
    "noise_clip": noise_clip,
    "policy_freq": policy_freq,
    "save_dir": save_dir,
}

print('env:',env_id, 'lam:', lam, 'lr:', lr, 'seed:', seed)

policy = TD3Agent(**kwargs)
replay_buffer = utils.ReplayBuffer(state_dim, action_dim)

state, done = env.reset(), False
episode_reward, episode_timesteps, episode_num = 0, 0, 0

test_score_lst = []  # not used below; kept so the module-level name still exists
# Start from -inf so the first evaluation always produces a checkpoint, even
# when the score is very negative (mean - lam * var has no lower bound).
best_mean = -np.inf
best_mean_variance = -np.inf

# One entry (an array with one value per evaluation episode) per evaluation.
eval_r_lst, eval_len_lst, eval_xpos_vio_lst = [],[],[]

for t in range(max_timesteps):
    episode_timesteps += 1

    if t < start_timesteps:
        # Warm-up: uniformly sampled actions.
        action = env.action_space.sample()
    else:
        # Policy action plus Gaussian exploration noise, clipped to the
        # valid action range.
        action = (
            policy.select_action(np.array(state))
            + np.random.normal(0, max_action * exploration_noise, size=action_dim)
        ).clip(-max_action, max_action)

    next_state, reward, done, info = env.step(action)
    # A "done" that coincides with hitting the step limit is stored as 0, so
    # only terminations before the limit are recorded as terminal.
    done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0

    # store data in buffer
    replay_buffer.add(state, action, next_state, reward, done_bool)
    # store reward in online reward buffer
    policy.online_rewards.append(reward)

    state = next_state
    episode_reward += reward

    # train after collecting sufficient data
    if t >= start_timesteps:
        policy.train(replay_buffer, batch_size)

    if done:
        state, done = env.reset(), False
        episode_reward = 0
        episode_timesteps = 0
        episode_num += 1

    if (t+1) % eval_freq == 0:
        # NOTE(review): eval_model runs with its default n_episodes (20); the
        # num_eval_episodes setting is never passed. Confirm which is intended.
        eval_r, eval_len, eval_xpos_vio = eval_model(eval_env, policy)
        eval_r_mean = eval_r.mean()
        # Mean return penalised by lam times the variance across episodes.
        eval_r_mean_var = eval_r_mean - lam * eval_r.var()
        print('eval return:', eval_r_mean)

        eval_r_lst.append(eval_r)
        eval_len_lst.append(eval_len)
        eval_xpos_vio_lst.append(eval_xpos_vio)

        # Checkpoint separately for the best plain mean and the best
        # variance-penalised score.
        if eval_r_mean > best_mean:
            best_mean = eval_r_mean
            policy.save_best(risk=False)
        if eval_r_mean_var > best_mean_variance:
            best_mean_variance = eval_r_mean_var
            policy.save_best(risk=True)

        # Rewrite the full evaluation history after every evaluation so the
        # files on disk are always up to date.
        for fname, history in (
            ('eval_r.npy', eval_r_lst),
            ('eval_len.npy', eval_len_lst),
            ('eval_xpos_vio.npy', eval_xpos_vio_lst),
        ):
            with open(save_dir + fname, 'wb') as f:
                np.save(f, np.array(history))

        
