from configparser import ConfigParser
from argparse import ArgumentParser

import torch
import gym
from gym.envs.registration import register
import numpy as np
import os

from agents.ppo import PPO
from agents.sac import SAC
from agents.ddpg import DDPG

from utils.utils import make_transition, Dict, RunningMeanStd
import sys

# NOTE(review): presumably this lets the "half_cheetah_pos" module named in the
# entry point below be imported from the parent of the working directory — confirm.
sys.path.append('..')

# Checkpoint directory; the training loops torch.save(...) into it.
os.makedirs('./model_weights', exist_ok=True)

# Step cap for the custom HalfCheetah environment registered below (passed as
# max_episode_steps); eval_model also reads it as its evaluation episode cap.
Cheetah_LEN = 500
register(
    id="HCPos-v0",
    entry_point="half_cheetah_pos:HalfCheetahPosEnv",
    nondeterministic=False,
    reward_threshold=None,
    max_episode_steps=Cheetah_LEN,
)

def str2bool(value):
    """Parse a command-line boolean flag value such as ``True`` / ``false`` / ``1``.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string, so ``--render False`` would *enable* rendering. This
    converter parses the text instead.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling; argparse
            reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError('expected a boolean, got {!r}'.format(value))


parser = ArgumentParser('parameters')

parser.add_argument("--env_name", type=str, default='HCPos-v0',
                    help="gym environment id, e.g. 'HCPos-v0', 'Ant-v2', 'HalfCheetah-v2', "
                         "'Hopper-v2', 'Humanoid-v2', 'HumanoidStandup-v2', "
                         "'InvertedDoublePendulum-v2', 'InvertedPendulum-v2' (default : HCPos-v0)")
parser.add_argument("--algo", type=str, default='ppo', choices=['ppo', 'sac', 'ddpg'],
                    help='algorithm to train: ppo, sac or ddpg (default : ppo)')
parser.add_argument('--train', type=str2bool, default=True, help="(default: True)")
parser.add_argument('--render', type=str2bool, default=False, help="(default: False)")
parser.add_argument('--epochs', type=int, default=3000, help='number of epochs, (default: 3000)')
parser.add_argument('--tensorboard', type=str2bool, default=False, help='use_tensorboard, (default: False)')
parser.add_argument("--load", type=str, default='no', help='load network name in ./model_weights')
parser.add_argument("--save_interval", type=int, default=1000, help='save interval(default: 1000)')
parser.add_argument("--print_interval", type=int, default=100, help='print interval(default : 100)')
parser.add_argument("--use_cuda", type=str2bool, default=True, help='cuda usage(default : True)')
parser.add_argument("--reward_scaling", type=float, default=1.0, help='reward scaling(default : 1.0)')
parser.add_argument("--seed", type=int, required=True, help='random seed (required)')
args = parser.parse_args()

# Algorithm hyper-parameters are read from config.ini, section named by --algo.
# A separate name keeps the ArgumentParser above from being shadowed.
config = ConfigParser()
config.read('config.ini')
agent_args = Dict(config, args.algo)

# Use the GPU only when one is present AND the user has not disabled it.
device = 'cuda' if (args.use_cuda and torch.cuda.is_available()) else 'cpu'

if args.tensorboard:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()
else:
    writer = None

# Separate env instances for training and evaluation.
env = gym.make(args.env_name)
eval_env = gym.make(args.env_name)
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

seed = args.seed
env.seed(seed)
# NOTE(review): presumably mirrored from the top of the int32 range so the eval
# env's episodes differ from the training env's — confirm.
eval_env.seed(2**31 - 1 - seed)
np.random.seed(seed)
torch.manual_seed(seed)

########## functional #########
def eval_model(env, agent, n_episodes=20, max_episode_length=None, xpos_limit=-3):
    """Roll out ``agent``'s stochastic policy on ``env`` and collect per-episode stats.

    Actions are sampled from ``Normal(mu, sigma[0])`` (not the mean action) under
    ``torch.no_grad()``. The state tensor is moved to the module-level ``device``.

    Args:
        env: gym-style env whose ``step`` returns ``(obs, reward, done, info)``
            with an ``'x_position'`` entry in ``info``.
        agent: object whose ``get_action(state_tensor)`` returns ``(mu, sigma)``.
        n_episodes: number of evaluation episodes.
        max_episode_length: hard cap on steps per episode; ``None`` uses the
            module-level ``Cheetah_LEN``.
        xpos_limit: a step counts as a violation when ``info['x_position']`` is
            strictly below this value.

    Returns:
        Three 1-D numpy arrays with one entry per episode: undiscounted return,
        episode length in steps, and number of violating steps.
    """
    if max_episode_length is None:
        max_episode_length = Cheetah_LEN

    return_lst, ep_length_lst, xpos_vio_lst = [], [], []

    for _ in range(n_episodes):
        s = env.reset()
        ep_r, total_step, xpos_vio = 0, 0, 0
        while True:
            with torch.no_grad():
                mu, sigma = agent.get_action(torch.from_numpy(s).float().to(device))
                dist = torch.distributions.Normal(mu, sigma[0])
                a = dist.sample().cpu().numpy()
            s, r, done, info = env.step(a)

            if info['x_position'] < xpos_limit:
                xpos_vio += 1
            ep_r += r
            total_step += 1

            # Stop on env termination or at the step cap, whichever comes first;
            # `>=` also guarantees termination for a non-positive cap.
            if done or total_step >= max_episode_length:
                break

        return_lst.append(ep_r)
        ep_length_lst.append(total_step)
        xpos_vio_lst.append(xpos_vio)

    return np.array(return_lst), np.array(ep_length_lst), np.array(xpos_vio_lst)

def get_save_dir(env_id, lr, seed):
    """Build the results directory for one (env, learning rate, seed) run.

    The returned path always ends with '/', so callers can append a file name
    directly, e.g. ``get_save_dir(...) + 'eval_r.npy'``.
    """
    segments = ('./save', env_id, 'lr=%s' % (lr,), 'seed_%s' % (seed,), '')
    return '/'.join(segments)
######################


# Build the agent selected by --algo.
if args.algo == 'ppo':
    agent = PPO(writer, device, state_dim, action_dim, agent_args)
elif args.algo == 'sac':
    agent = SAC(writer, device, state_dim, action_dim, agent_args)
elif args.algo == 'ddpg':
    # DDPG is given an exploration-noise object from utils.noise
    # (presumably Ornstein-Uhlenbeck, per the class name — confirm).
    from utils.noise import OUNoise
    noise = OUNoise(action_dim, 0)
    agent = DDPG(writer, device, state_dim, action_dim, agent_args, noise)
else:
    # Without this branch an unknown --algo only surfaces later as a NameError on `agent`.
    raise ValueError("unsupported --algo {!r}; expected 'ppo', 'sac' or 'ddpg'".format(args.algo))

if torch.cuda.is_available() and args.use_cuda:
    agent = agent.cuda()

if args.load != 'no':
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only run.
    agent.load_state_dict(torch.load("./model_weights/" + args.load, map_location=device))

# Training-episode scores accumulated between periodic progress prints.
score_lst = []

print('env:', args.env_name, 'algo:', args.algo)
print('seed:', seed)
total_episodes = 0
# Evaluate every `eval_intvl` finished training episodes (on-policy loop).
eval_intvl = 100

save_dir = get_save_dir(args.env_name, agent_args.actor_lr, seed)
os.makedirs(save_dir, exist_ok=True)

# Histories of evaluation results, one entry per evaluation round.
eval_r_lst, eval_len_lst, eval_xpos_vio_lst = [], [], []

# Explicit `== True` is kept on purpose: the type of the config value is not
# visible here, and plain truthiness would treat e.g. the string "False" as true.
if agent_args.on_policy == True:
    # On-policy: each epoch collects `traj_length` environment steps (rollouts
    # may span episode boundaries), then performs one agent.train_net update.
    score = 0.0
    state = env.reset()

    for n_epi in range(args.epochs):
        for _ in range(agent_args.traj_length):
            if args.render:
                env.render()

            mu, sigma = agent.get_action(torch.from_numpy(state).float().to(device))
            dist = torch.distributions.Normal(mu, sigma[0])
            action = dist.sample()
            # Log-probability summed over the last (action) axis.
            log_prob = dist.log_prob(action).sum(-1, keepdim=True)
            next_state, reward, done, info = env.step(action.cpu().numpy())

            transition = make_transition(
                state,
                action.cpu().numpy(),
                np.array([reward * args.reward_scaling]),
                next_state,
                np.array([done]),
                log_prob.detach().cpu().numpy(),
            )
            agent.put_data(transition)
            score += reward
            if done:
                state = env.reset()
                score_lst.append(score)
                if args.tensorboard:
                    writer.add_scalar("score/score", score, n_epi)
                score = 0

                total_episodes += 1

                # Periodically evaluate on eval_env and rewrite the full history files.
                if total_episodes % eval_intvl == 0:
                    eval_r, eval_len, eval_xpos_vio = eval_model(eval_env, agent)
                    print('eval_return:', eval_r.mean())
                    eval_r_lst.append(eval_r)
                    eval_len_lst.append(eval_len)
                    eval_xpos_vio_lst.append(eval_xpos_vio)

                    with open(save_dir + 'eval_r.npy', 'wb') as f:
                        np.save(f, np.array(eval_r_lst))
                    with open(save_dir + 'eval_len.npy', 'wb') as f:
                        np.save(f, np.array(eval_len_lst))
                    with open(save_dir + 'eval_xpos_vio.npy', 'wb') as f:
                        np.save(f, np.array(eval_xpos_vio_lst))

            else:
                state = next_state

        agent.train_net(n_epi)
        if (n_epi + 1) % args.print_interval == 0:
            if score_lst:
                print("# of epoch :{}, avg score : {:.1f}".format(n_epi, sum(score_lst) / len(score_lst)))
            else:
                # No episode finished in this window (short traj_length / print_interval);
                # report that instead of dividing by zero.
                print("# of epoch :{}, avg score : n/a (no finished episode)".format(n_epi))
            score_lst = []
        if (n_epi + 1) % args.save_interval == 0:
            torch.save(agent.state_dict(), './model_weights/agent_' + str(n_epi))

else:  # off policy
    # Off-policy: one full episode per epoch; the agent updates after every step
    # once the replay data index exceeds learn_start_size.
    for n_epi in range(args.epochs):
        score = 0.0
        state = env.reset()
        done = False
        while not done:
            if args.render:
                env.render()
            action, _ = agent.get_action(torch.from_numpy(state).float().to(device))
            action = action.cpu().detach().numpy()
            next_state, reward, done, info = env.step(action)
            transition = make_transition(
                state,
                action,
                np.array([reward * args.reward_scaling]),
                next_state,
                np.array([done]),
            )
            agent.put_data(transition)

            state = next_state

            score += reward
            if agent.data.data_idx > agent_args.learn_start_size:
                agent.train_net(agent_args.batch_size, n_epi)
        # Exactly one score per epoch, so score_lst is never empty when printed below.
        score_lst.append(score)
        if args.tensorboard:
            writer.add_scalar("score/score", score, n_epi)
        if n_epi % args.print_interval == 0 and n_epi != 0:
            print("# of episode :{}, avg score : {:.1f}".format(n_epi, sum(score_lst) / len(score_lst)))
            score_lst = []
        if n_epi % args.save_interval == 0 and n_epi != 0:
            torch.save(agent.state_dict(), './model_weights/agent_' + str(n_epi))