#!/usr/bin/env python3
import argparse
import datetime
import os
import json

import gym
import numpy as np
import torch
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR

from collector import Collector
from data import VectorBuffer
from env import BaseVectorEnv
from ppo import PPOPolicy
from trainer import Trainer
from network import Actor, Critic
from utils import Logger
from random_env import get_init_params, get_random_params, get_random_params_target, get_random_params2, get_random_params3


def get_args():
    """Parse the command-line options of the PPO experiment.

    Every option is registered with ``argparse`` in the order listed in the
    table below, as ``(flag, type, default)``. Flag-like settings
    (``--rew-norm``, ``--obs-norm``, ``--lr-decay``, ...) are parsed as ints,
    so they are switched off with ``0`` on the command line.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    option_table = (
        ('--task', str, 'HalfCheetah-v3'),
        ('--seed', int, 0),
        ('--buffer-size', int, 4096),
        ('--lr', float, 3e-4),
        ('--gamma', float, 0.99),
        ('--epoch', int, 100),
        ('--step-per-epoch', int, 30000),
        ('--step-per-collect', int, 2048),
        ('--repeat-per-collect', int, 10),
        ('--batch-size', int, 64),
        ('--training-num', int, 64),
        ('--test-num', int, 10),
        # ppo special
        ('--rew-norm', int, True),
        ('--obs-norm', int, True),
        # In theory, `vf-coef` will not make any difference if using Adam
        # optimizer.
        ('--vf-coef', float, 0.25),
        ('--ent-coef', float, 0.0),
        ('--gae-lambda', float, 0.95),
        ('--bound-action-method', str, "clip"),
        ('--lr-decay', int, True),
        ('--max-grad-norm', float, 0.5),
        ('--eps-clip', float, 0.2),
        ('--dual-clip', float, None),
        ('--value-clip', int, 1),
        ('--norm-adv', int, 1),
        ('--recompute-adv', int, 1),
        ('--logdir', str, 'compare_target'),
        ('--device', str, fallback_device),
        ('--resume-path', str, None),
        # bounds forwarded as the log-scale limits of the parameter sampler
        ('--left-bound', float, 0.5),
        ('--right-bound', float, 1),
    )

    cli = argparse.ArgumentParser()
    for flag, caster, fallback in option_table:
        cli.add_argument(flag, type=caster, default=fallback)
    return cli.parse_args()


def test_ppo(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_envs = BaseVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=args.obs_norm
    )
    test_envs = BaseVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=args.obs_norm,
        obs_rms=train_envs.obs_rms,
        update_obs_rms=False
    )

    target_env = [i for i in range(args.training_num)]
    init_params = get_init_params(env)
    random_params_list = []
    for i in range(16):
        random_params = get_random_params3(init_params, log_scale_limit=[args.left_bound, args.right_bound]) 
        train_envs.set_env_attr("body_mass",random_params["body_mass"],target_env[4*i:4*i+4])
        train_envs.set_env_attr("body_inertia",random_params["body_inertia"],target_env[4*i:4*i+4])
        train_envs.set_env_attr("dof_damping",random_params["dof_damping"],target_env[4*i:4*i+4])
        train_envs.set_env_attr("geom_friction",random_params["geom_friction"],target_env[4*i:4*i+4])
        random_params_list.append(random_params)

    # seed
    train_envs.seed(args.seed)

    test_target_env = [i for i in range(args.test_num)]
    for i in range(5):
        random_params = get_random_params3(init_params, log_scale_limit=[args.left_bound, args.right_bound])
        test_envs.set_env_attr("body_mass",random_params["body_mass"],test_target_env[2*i:2*i+2])
        test_envs.set_env_attr("body_inertia",random_params["body_inertia"],test_target_env[2*i:2*i+2])
        test_envs.set_env_attr("dof_damping",random_params["dof_damping"],test_target_env[2*i:2*i+2])
        test_envs.set_env_attr("geom_friction",random_params["geom_friction"],test_target_env[2*i:2*i+2])
        random_params_list.append(random_params)
    test_envs.seed(args.seed)
    # model
    actor = Actor(args.state_shape[0],args.action_shape[0],device=args.device).to(args.device)
    critic = Critic(args.state_shape[0]+64,device=args.device).to(args.device)

    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
    )

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect
        ) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
        )

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
        if args.obs_norm:
            p = os.path.join(os.path.split(args.resume_path)[0], 'obs_rms.json')
            print("Loaded obs-norm from: ", p)
            with open(p,'r') as f:
                d = json.load(f)
                mean,var,count = np.array(d['mean']), np.array(d['var']),d['count']
                # train_envs.update_obs_rms = False
                train_envs.obs_rms.mean = mean
                train_envs.obs_rms.var = var
                train_envs.obs_rms.count = count
                test_envs.obs_rms = train_envs.obs_rms
                train_envs.update_obs_rms=False

    # collector
    buffer = VectorBuffer(args.buffer_size, len(train_envs))
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)

    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_ppo'
    log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = Logger(log_path)


    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
        if args.obs_norm:
            d = {"mean":test_envs.obs_rms.mean.tolist(),"var":test_envs.obs_rms.var.tolist(),"count":test_envs.obs_rms.count}
            with open(os.path.join(log_path, 'obs_rms.json'),'w',encoding='utf-8') as f:
                json.dump(d,f)
 

    # trainer
    ppo_trainer = Trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        step_per_collect=args.step_per_collect,
        save_best_fn=save_best_fn,
        logger=logger,
    )
    ppo_trainer.run()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num)
    print(f'Final reward: {result["rews"].mean()}')


# Script entry point: run the PPO experiment with its default arguments when
# this file is executed directly (not when it is imported).
if __name__ == '__main__':
    test_ppo()
