from configparser import ConfigParser
from argparse import ArgumentParser

import torch
import gym
from gym.envs.registration import register
import numpy as np
import os
import random
import sys
sys.path.append('..')

from ppo_tamar import PPO_Tamar

from utils.utils import make_transition, Dict, RunningMeanStd
# Make sure the folder used for network weights exists before anything runs.
os.makedirs('./model_weights', exist_ok=True)

# ---------------------------------------------------------------------------
# Environment registration
# ---------------------------------------------------------------------------
# Episode horizon; passed to gym as ``max_episode_steps`` and reused by the
# rollout code further down in this script.
Cheetah_LEN = 500

# Expose ``HalfCheetahPosEnv`` (module ``half_cheetah_pos``) under the id
# "HCPos-v0" so that ``gym.make`` can build it.
_HCPOS_REGISTRATION = dict(
    id="HCPos-v0",
    entry_point="half_cheetah_pos:HalfCheetahPosEnv",
    max_episode_steps=Cheetah_LEN,
    reward_threshold=None,
    nondeterministic=False,
)
register(**_HCPOS_REGISTRATION)

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so a flag declared that way can
    never be switched off by spelling out a false value.  This converter
    accepts the usual spellings instead.  argparse turns the ``ValueError``
    raised for anything else into a regular usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))


# ---- command-line options -------------------------------------------------
arg_parser = ArgumentParser('parameters')

arg_parser.add_argument("--env_name", type=str, default='HCPos-v0',
                        help="gym environment id (default: HCPos-v0)")
arg_parser.add_argument('--train', type=_str2bool, default=True,
                        help="(default: True)")
arg_parser.add_argument('--epochs', type=int, default=3000,
                        help='number of epochs, (default: 3000)')
arg_parser.add_argument('--tensorboard', type=_str2bool, default=False,
                        help='use_tensorboard, (default: False)')
arg_parser.add_argument("--load", type=str, default='no',
                        help='load network name in ./model_weights')
arg_parser.add_argument("--save_interval", type=int, default=100,
                        help='save interval(default: 100)')
arg_parser.add_argument("--print_interval", type=int, default=100,
                        help='print interval(default : 100)')
arg_parser.add_argument("--use_cuda", type=_str2bool, default=True,
                        help='cuda usage(default : True)')
arg_parser.add_argument("--reward_scaling", type=float, default=1.0,
                        help='reward scaling(default : 1.0)')
arg_parser.add_argument("--seed", type=int, required=True)
args = arg_parser.parse_args()

# ---- agent hyper-parameters -----------------------------------------------
# Read from the 'ppo' section of config.ini.  The name ``parser`` is kept for
# the ConfigParser, which is what it referred to after this point originally.
parser = ConfigParser()
parser.read('config.ini')
agent_args = Dict(parser, 'ppo')

# Run on the GPU only when one is present and the user has not opted out.
device = 'cuda' if (torch.cuda.is_available() and args.use_cuda) else 'cpu'

# TensorBoard is imported lazily so it is only required when requested.
writer = None
if args.tensorboard:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()

# One environment for collecting training data, a second one for evaluation.
env = gym.make(args.env_name)
eval_env = gym.make(args.env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# Observation normalisation is currently switched off:
# state_rms = RunningMeanStd(state_dim)

# Seed every random number source.  The evaluation environment gets a
# different seed, derived from the training seed.
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
env.seed(seed)
eval_env.seed((2 ** 31 - 1) - seed)

########## functional #########
def eval_model(env, agent, n_episodes=20, max_episode_length=None,
               xpos_limit=-3, eval_device=None):
    """Roll out ``agent`` in ``env`` and collect per-episode statistics.

    At every step an action is sampled from ``Normal(mu, sigma[0])``, where
    ``mu, sigma`` is what ``agent.get_action`` returns for the current
    observation.  An episode ends when the environment reports ``done`` or
    when ``max_episode_length`` steps have been taken, whichever comes first.

    Args:
        env: Environment with ``reset()`` returning a numpy observation and
            ``step(action)`` returning ``(obs, reward, done, info)``; ``info``
            must contain the key ``'x_position'``.
        agent: Object whose ``get_action(tensor)`` returns ``(mu, sigma)``.
        n_episodes: Number of evaluation episodes to run.
        max_episode_length: Step cap per episode.  ``None`` (default) uses the
            module-level ``Cheetah_LEN``.
        xpos_limit: A step counts as a violation when ``info['x_position']``
            is strictly below this value (default ``-3``).
        eval_device: Torch device the observation is moved to before calling
            the agent.  ``None`` (default) uses the module-level ``device``.

    Returns:
        Three numpy arrays, each with one entry per episode: the undiscounted
        return, the episode length in steps, and the number of steps on which
        the x-position limit was violated.
    """
    # Fall back to the script-level settings when the caller gives no override.
    if max_episode_length is None:
        max_episode_length = Cheetah_LEN
    if eval_device is None:
        eval_device = device

    return_lst = []
    ep_length_lst = []
    xpos_vio_lst = []

    for _ in range(n_episodes):
        s, done = env.reset(), False
        ep_r, total_step, xpos_vio = 0, 0, 0
        while not done:
            # Pure inference: no autograd graph is needed here.
            with torch.no_grad():
                mu, sigma = agent.get_action(
                    torch.from_numpy(s).float().to(eval_device))
                dist = torch.distributions.Normal(mu, sigma[0])
                a = dist.sample().cpu().numpy()
            s, r, done, info = env.step(a)

            if info['x_position'] < xpos_limit:
                xpos_vio += 1
            ep_r += r
            total_step += 1

            # Enforce the step cap even if the environment itself keeps going.
            if total_step >= max_episode_length:
                done = True

        return_lst.append(ep_r)
        ep_length_lst.append(total_step)
        xpos_vio_lst.append(xpos_vio)

    return np.array(return_lst), np.array(ep_length_lst), np.array(xpos_vio_lst)

def get_save_dir(env_id, lam, lr, seed):
    """Return the results directory for one (env, lam, lr, seed) run.

    The path has the form ``./save/<env_id>/lam=<lam>/lr=<lr>/seed_<seed>/``
    and always ends with a slash, so file names can be appended directly.
    """
    segments = [
        '.',
        'save',
        env_id,
        'lam=' + str(lam),
        'lr=' + str(lr),
        'seed_' + str(seed),
        '',  # empty last segment yields the trailing '/'
    ]
    return '/'.join(segments)
######################

# Build the PPO agent.  ``writer`` is None unless TensorBoard logging is on.
agent = PPO_Tamar(writer, device, state_dim, action_dim, agent_args)


# Move the agent to the GPU only when one exists and the user allows it.
if (torch.cuda.is_available()) and (args.use_cuda):
    agent = agent.cuda()

# Optionally resume from a checkpoint stored under ./model_weights.
if args.load != 'no':
    # map_location remaps the stored tensors onto the device chosen for this
    # run; without it a checkpoint written on a GPU cannot be loaded on a
    # CPU-only host (or with --use_cuda turned off).
    checkpoint = torch.load("./model_weights/" + args.load, map_location=device)
    agent.load_state_dict(checkpoint)
    
# ---- output location ------------------------------------------------------
save_dir = get_save_dir(args.env_name, agent_args.lam, agent_args.actor_lr, seed)
os.makedirs(save_dir, exist_ok=True)
print('lam:', agent_args.lam, 'lr:', agent_args.actor_lr, 'seed:', seed)

# ---- evaluation bookkeeping -----------------------------------------------
# An evaluation is triggered every ``eval_intvl`` finished training episodes;
# each history list receives one entry per evaluation.
eval_intvl = 100
total_episodes = 0
eval_r_lst = []
eval_len_lst = []
eval_xpos_vio_lst = []

# ---- training rollout state -----------------------------------------------
score_lst = []      # returns of the episodes finished since the last report
score = 0.0         # running return of the episode in progress
episode_length = 0  # steps taken in the episode in progress
state = env.reset()

# Number of outer iterations.  Each one collects ``traj_length`` environment
# steps, hands the finished episodes to the agent and runs one
# ``agent.train_net`` call.
train_episodes = int(args.epochs * 10)
for i_ep in range(train_episodes):

    # Episodes completed during this iteration, one list per field.
    episodes_buf = {'state': [], 'action': [], 'reward': [], 'log_prob': []}
    # Steps of the episode currently in progress.
    state_lst, action_lst, reward_lst, log_prob_lst = [], [], [], []

    for _ in range(agent_args.traj_length):

        # Only detached values are stored below, so no autograd graph is
        # needed while choosing the action.
        with torch.no_grad():
            mu, sigma = agent.get_action(torch.from_numpy(state).float().to(device))
            dist = torch.distributions.Normal(mu, sigma[0])
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(-1, keepdim=True)

        # Convert once and reuse for the env, the transition and the buffers.
        action_np = action.cpu().numpy()
        log_prob_np = log_prob.detach().cpu().numpy()

        next_state, reward, done, info = env.step(action_np)
        episode_length += 1
        scaled_reward = reward * args.reward_scaling

        # The transition stores the ``done`` flag exactly as the environment
        # reported it, before the horizon override further down.
        transition = make_transition(
            state,
            action_np,
            np.array([scaled_reward]),
            next_state,
            np.array([done]),
            log_prob_np,
        )
        agent.put_data(transition)

        state_lst.append(state)
        action_lst.append(action_np)
        reward_lst.append(scaled_reward)
        log_prob_lst.append(log_prob_np)

        # The reported score uses the raw (unscaled) reward.
        score += reward

        # Enforce the episode horizon.
        if episode_length >= Cheetah_LEN:
            done = True
        if done:
            state = env.reset()
            # Restart the step counter; otherwise the horizon check above
            # would compare against a count spanning several episodes.
            episode_length = 0
            #state = np.clip((state_ - state_rms.mean) / (state_rms.var ** 0.5 + 1e-8), -5, 5)
            score_lst.append(score)
            if args.tensorboard:
                writer.add_scalar("score/score", score, i_ep)
            score = 0

            episodes_buf['state'].append(state_lst)
            episodes_buf['action'].append(action_lst)
            episodes_buf['reward'].append(reward_lst)
            episodes_buf['log_prob'].append(log_prob_lst)
            state_lst, action_lst, reward_lst, log_prob_lst = [], [], [], []

            total_episodes += 1
            if total_episodes % eval_intvl == 0:
                eval_r, eval_len, eval_xpos_vio = eval_model(eval_env, agent)
                print('eval_return:', eval_r.mean())
                eval_r_lst.append(eval_r)
                eval_len_lst.append(eval_len)
                eval_xpos_vio_lst.append(eval_xpos_vio)

                # Rewrite the full evaluation history after every evaluation.
                for file_name, history in (
                    ('eval_r.npy', eval_r_lst),
                    ('eval_len.npy', eval_len_lst),
                    ('eval_xpos_vio.npy', eval_xpos_vio_lst),
                ):
                    with open(save_dir + file_name, 'wb') as f:
                        np.save(f, np.array(history))

        else:
            state = next_state
    agent.put_episodes(episodes_buf)
    agent.train_net(i_ep)

    if (i_ep + 1) % args.print_interval == 0:
        # No episode may have finished in this window; report nan instead of
        # dividing by zero.
        avg_score = sum(score_lst) / len(score_lst) if score_lst else float('nan')
        print("# of epoch :{}, avg score : {:.1f}".format(i_ep, avg_score))
        score_lst = []
    # Periodic checkpointing is currently disabled:
    # if (i_ep+1)%args.save_interval==0:
    #     torch.save(agent.state_dict(),'./model_weights/agent_'+str(i_ep))

# Flush and release the TensorBoard event file, if one was opened.
if writer is not None:
    writer.close()
        
