import sys, os, time
sys.path.append("./")
from ruamel.yaml import YAML
from utils import system
from envs.motion_planning_env import motion_planning_env

import gym
import numpy as np 
import torch
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from math import *
from common.sac import ReplayBuffer, SAC
from utils.plots.train_plot import plot_sac_curve

def train_policy(env):
    """Train a SAC expert on ``env`` and return the trained agent's action function.

    This function reads module-level globals that are set in the ``__main__``
    block: ``seed``, ``device``, ``env_T`` (episode length), ``v`` (the parsed
    YAML config) and ``axs`` (matplotlib axes; the training curve is drawn on
    ``axs[0]``).

    Args:
        env: environment to train on; it is re-seeded with ``seed`` first.

    Returns:
        ``sac_agent.get_action`` of the trained agent.

    Raises:
        ValueError: if the constructed agent does not have ``reinitialize``
            set (presumably supplied via ``v['sac']`` -- confirm in config).
    """
    env.seed(seed)
    sac_cfg = v['sac']
    # Warm-up length in steps: ``random_explore_episodes`` full episodes.
    # It is used both as ``update_after`` and as ``start_steps`` below.
    warmup_steps = env_T * sac_cfg['random_explore_episodes']

    replay_buffer = ReplayBuffer(
        env.observation_space.shape[0],
        env.action_space.shape[0],
        device=device,
        size=sac_cfg['buffer_size'],
    )

    sac_agent = SAC(
        env,
        replay_buffer,
        steps_per_epoch=env_T,
        update_after=warmup_steps,
        max_ep_len=env_T,
        seed=seed,
        start_steps=warmup_steps,
        device=device,
        **sac_cfg,
    )
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not sac_agent.reinitialize:
        raise ValueError("SAC agent must be created with reinitialize=True")

    # Use the agent's ``test_agent_ori_env`` as its test routine during training.
    sac_agent.test_fn = sac_agent.test_agent_ori_env
    sac_test_rets, sac_alphas, sac_log_pis, sac_time_steps = sac_agent.learn_mujoco(print_out=True)

    plot_sac_curve(axs[0], sac_test_rets, sac_alphas, sac_log_pis, sac_time_steps)

    return sac_agent.get_action

def evaluate_policy(policy, env, n_episodes, deterministic=False, horizon=None):
    """Roll out ``policy`` in ``env`` and record the visited states and actions.

    Each episode starts from ``env.reset(np.array(env.initial_state))``. It runs
    for at most ``horizon`` steps and stops early when ``env.step`` reports
    ``done``.

    Args:
        policy: callable ``policy(obs, deterministic) -> action``; the returned
            action must be convertible with ``torch.as_tensor``.
        env: environment exposing ``observation_space.shape``,
            ``action_space.shape``, ``initial_state``, ``reset(state)`` and a
            4-tuple ``step(action)``.
        n_episodes: number of rollouts to collect.
        deterministic: forwarded to ``policy`` on every call.
        horizon: maximum steps per episode. Defaults to the module-level
            ``env_T`` set in the ``__main__`` block; pass it explicitly when
            calling this function from elsewhere.

    Returns:
        ``(states, actions, returns)``:

        - ``states``: float tensor of shape ``(n_episodes, horizon, obs_dim)``
          holding s_0 .. s_{T-1}.
        - ``actions``: float tensor of shape ``(n_episodes, horizon, act_dim)``
          holding a_0 .. a_{T-1}.
        - ``returns``: numpy array of per-episode summed rewards.

        Time steps after an early ``done`` are left as zeros in both tensors.
    """
    if horizon is None:
        horizon = env_T  # module-level global, only defined when run as a script
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    expert_states = torch.zeros((n_episodes, horizon, obs_dim))  # s0 to sT-1
    expert_actions = torch.zeros((n_episodes, horizon, act_dim))  # a0 to aT-1

    returns_reward = []

    for n in range(n_episodes):
        obs = env.reset(np.array(env.initial_state))
        ret = 0
        for t in range(horizon):
            action = policy(obs, deterministic)
            # Slice assignment copies (and casts) the data, so no explicit clone is needed.
            expert_states[n, t] = torch.as_tensor(obs)
            obs, rew, done, _ = env.step(action)
            expert_actions[n, t] = torch.as_tensor(action)
            ret += rew
            if done:
                break
        returns_reward.append(ret)

    return expert_states, expert_actions, np.array(returns_reward)


if __name__ == "__main__":
    yaml = YAML()
    # Read the config inside a context manager so the file handle is closed.
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)
    # common parameters
    env_name, env_T = v['env']['env_name'], v['env']['T']
    seed = v['seed']
    x_lim = v['env']['x_lim']
    y_lim = v['env']['y_lim']
    initial_state = v['env']['initial_state']
    goal = v['env']['goal']
    goal_radius = v['env']['goal_radius']
    is_goal_circle = v['env']['is_goal_circle']

    # system: device, threads, seed
    device = torch.device(f"cuda:{v['cuda']}" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)

    # train_policy draws the SAC training curve on axs[0].
    # NOTE(review): this figure is never saved or closed in this script -- confirm intended.
    fig, axs = plt.subplots(1, 2, figsize=(15, 6))

    env = motion_planning_env(x_lim, y_lim, initial_state, goal, goal_radius, is_goal_circle)
    print(f"training Expert on {env_name}")

    # np.savetxt / torch.save do not create missing parent directories.
    os.makedirs(f'expert_data/states/{env_name}', exist_ok=True)
    os.makedirs(f'expert_data/actions/{env_name}', exist_ok=True)

    for task in range(1):
        policy = train_policy(env)

        # Evaluate with a different seed than the one used for training.
        env.seed(seed + 1)

        # Training set: deterministic rollouts of the trained expert.
        expert_states_det, expert_actions_det, expert_returns_reward = evaluate_policy(policy, env, v['expert']['training_episode'], True)
        return_info = f'Task {task}, Expert(Det) Return Avg: {expert_returns_reward.mean():.2f}, std: {expert_returns_reward.std():.2f}'
        print(return_info)
        np.savetxt(f'expert_data/states/{env_name}/Task{task}.txt', np.r_[expert_returns_reward.mean()])

        torch.save(expert_states_det, f'expert_data/states/{env_name}/Task{task}_training_set.pt')
        torch.save(expert_actions_det, f'expert_data/actions/{env_name}/Task{task}_training_set.pt')

        # Evaluation set: a separate batch of deterministic rollouts.
        expert_states_det, expert_actions_det, expert_returns_reward = evaluate_policy(policy, env, v['expert']['eval_episode'], True)
        env.draw_trajectory(list(expert_states_det[0]))
        torch.save(expert_states_det, f'expert_data/states/{env_name}/Task{task}_eval_set.pt')
        torch.save(expert_actions_det, f'expert_data/actions/{env_name}/Task{task}_eval_set.pt')








