#!/usr/bin/env python3

import argparse
import datetime
import os, json
import pprint

import numpy as np
import torch

from tianshou.data import ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic

from algorithms import NECSA_DDPG_ALGO, DDPG_ALGO
from _tianshou_custom.data import CustomCollector
from data.necsa.necsa_collector import NECSACollector
from data.necsa.abstraction_mode import S, S_A, HS, HS_A, HS_HA
from utils.argparse_util import extend_argparser, merge_args
from utils.environment_util import make_env
from utils.experiment_util import get_save_best_fn, get_save_checkpoint_fn, get_necsa_dict
from utils.json_util import custom_serializer
from utils.log_util import get_logger


def _get_general_argparser(*args):
    parser = argparse.ArgumentParser()
    for argparser in args:
        parser = extend_argparser(argparser(), parser)

    parser.add_argument("algo_name", type=str)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--logdir", type=str, default="./log")
    parser.add_argument("--results-dir", type=str, default="./results")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser


def get_ddpg_argparser(*args):
    """Build the command-line parser for plain (non-NECSA) DDPG runs.

    Args:
        *args: parser factories forwarded to ``_get_general_argparser``.

    Returns:
        argparse.ArgumentParser: the general options (folded in with the
        project's ``extend_argparser`` helper) plus DDPG hyper-parameters.
        ``algo_name`` defaults to ``DDPG_ALGO`` and ``use_necsa`` to ``False``.
    """
    parser = argparse.ArgumentParser()
    parser = extend_argparser(_get_general_argparser(*args), parser)

    parser.set_defaults(algo_name=DDPG_ALGO)
    parser.set_defaults(use_necsa=False)
    parser.add_argument("--buffer-size", type=int, default=1000000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    # Scaled by the env's max action inside test_ddpg before use.
    parser.add_argument("--exploration-noise", type=float, default=0.1)
    # Number of random-action warm-up steps collected before training.
    parser.add_argument("--start-timesteps", type=int, default=10000)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=1000)
    parser.add_argument("--step-per-collect", type=int, default=1)
    # tianshou's trainer accepts a fractional ratio here (e.g. 0.5 means one
    # gradient step every two collected transitions), so parse it as a float.
    # The default stays the integer 1.
    parser.add_argument("--update-per-step", type=float, default=1)
    parser.add_argument("--n-step", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--training-num", type=int, default=1)
    parser.add_argument("--test-num", type=int, default=10)
    return parser


def get_necsa_ddpg_argparser(*args):
    parser = argparse.ArgumentParser()
    parser = extend_argparser(get_ddpg_argparser(*args), parser)

    parser.set_defaults(algo_name=NECSA_DDPG_ALGO)
    parser.set_defaults(use_necsa=True)
    parser.set_defaults(necsa_adv=False)
    parser.add_argument("--step", type=int, default=3)                  # Directory for storing all experimental data
    parser.add_argument("--grid_num", type=int, default=5)              # Directory for storing all experimental data
    parser.add_argument("--epsilon", type=float, default=0.2)            # Directory for storing all experimental data
    parser.add_argument("--necsa_lr", type=float, default=0.1) # TODO what is the default?
    parser.add_argument("--necsa_gamma", type=float, default=0.99) # TODO what is the default?
    parser.add_argument("--state_dim", type=int, default=16)
    parser.add_argument("--state_min", type=float, default=-10)        #
    parser.add_argument("--state_max", type=float, default=10)         # state_max, state_min
    parser.add_argument("--circular_buffer", type=bool, default=False)
    parser.add_argument("--mode", type=str, default=S_A, choices=[S, S_A, HS, HS_A, HS_HA])
    parser.add_argument("--reduction", action="store_true")   #
    parser.add_argument("--score_type", type=str, default='score')
    return parser


def _prepare_log_dir(args):
    """Resolve the run's log directory and persist or restore its config.

    A fresh run gets ``<logdir>/<task>/<algo_name>/<seed>/<timestamp>``. The
    directory and any missing parents are created, and ``vars(args)`` is
    written to ``config.json``. A resumed run (``--resume-path``) reuses
    ``<logdir>/<resume_path>`` and folds its saved ``config.json`` into
    ``args`` via ``merge_args``.

    Returns:
        tuple: ``(args, log_name, log_path)``. ``args`` is whatever
        ``merge_args`` returned when resuming, otherwise the input object.
    """
    if args.resume_path:
        log_name = args.resume_path
        log_path = os.path.join(args.logdir, log_name)
        with open(os.path.join(log_path, "config.json"), 'r') as f:
            loaded_config = json.load(f)
        args = merge_args(args, loaded_config)
    else:
        now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
        log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
        log_path = os.path.join(args.logdir, log_name)
        # os.mkdir failed on a first run because the task/algo/seed parents
        # did not exist yet. exist_ok is intentionally left off, so two runs
        # that collide on the same timestamp still fail loudly instead of
        # overwriting each other's config.
        os.makedirs(log_path)
        with open(os.path.join(log_path, "config.json"), 'w') as f:
            json.dump(vars(args), f, indent=2)
    return args, log_name, log_path


def _build_ddpg_policy(args, env):
    """Build the actor/critic networks, their Adam optimizers and the DDPG policy.

    Reads ``state_shape``, ``action_shape``, ``max_action``, ``hidden_sizes``,
    ``device``, ``actor_lr``, ``critic_lr``, ``tau``, ``gamma``,
    ``exploration_noise`` and ``n_step`` from ``args``. ``env.action_space``
    is handed to the policy.
    """
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = Actor(
        net_a, args.action_shape, max_action=args.max_action, device=args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    # The critic consumes (state, action); concat=True joins them at the input.
    net_c = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
    return DDPGPolicy(
        actor=actor,
        actor_optim=actor_optim,
        critic=critic,
        critic_optim=critic_optim,
        action_space=env.action_space,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        estimation_step=args.n_step,
    )


def test_ddpg(args):
    """Train, or with ``--watch`` only evaluate, a DDPG / NECSA-DDPG agent.

    Steps:
        1. Set up the log directory and config (fresh or resumed).
        2. Build the environments and the policy, and load ``policy.pth`` when resuming.
        3. Warm up the replay buffer with ``start_timesteps`` random steps.
        4. Run tianshou's ``OffpolicyTrainer``.
        5. Evaluate the final policy and dump the test collector's
           ``policy_eval_results`` as JSON under ``results_dir``.

    Args:
        args: parsed namespace from one of the ``get_*_argparser`` parsers.
            Note that ``args`` gains ``state_shape``, ``action_shape`` and
            ``max_action``, and ``exploration_noise`` is rescaled in place.
    """
    # Capture this before merge_args can replace fields from the saved config.
    resuming = bool(args.resume_path)
    args, log_name, log_path = _prepare_log_dir(args)

    # log_path is where config.json was written or read. Logging and the
    # checkpoint load below use that same directory.
    logger = get_logger(args, log_path, log_name)

    env, train_envs, test_envs, watch_env = make_env(
        task=args.task,
        seed=args.seed,
        num_train_envs=args.training_num,
        num_test_envs=args.test_num,
        obs_norm=args.obs_norm,
        create_watch_env=args.render
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # The CLI noise is relative; scale it to the env's action range.
    args.exploration_noise = args.exploration_noise * args.max_action
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    # Seed before building the networks so their initialisation is reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    policy = _build_ddpg_policy(args, env)

    # load a previous policy
    if resuming:
        policy.load_state_dict(torch.load(os.path.join(log_path, "policy.pth"), map_location=args.device))
        print("Loaded agent from: ", log_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)

    necsa_dict = None
    collector_cls = CustomCollector
    if args.use_necsa:
        necsa_dict = get_necsa_dict(args, env)
        collector_cls = NECSACollector

    train_collector = collector_cls(policy, train_envs, buffer, exploration_noise=True, NECSA_DICT=necsa_dict)
    test_collector = CustomCollector(policy, test_envs)
    # Warm-up: fill the buffer with random actions before any gradient step.
    train_collector.collect(n_step=args.start_timesteps, random=True, reset_before_collect=True)

    def watch():
        """Roll out ``test_num`` episodes with the current policy, then close the envs used."""
        policy.eval()
        if args.render:
            envs = watch_env
            # TODO: CHECK THIS
            # envs.seed(args.seed)
        else:
            envs = test_envs
            envs.seed(args.seed)
        collector = CustomCollector(policy, envs, exploration_noise=True)
        result = collector.collect(
            n_episode=args.test_num, render=args.render, reset_before_collect=True
        )
        envs.close()
        print(f"Final return: {result.returns.mean()}, length: {result.lens.mean()}")

    def save_results():
        """Write ``test_collector.policy_eval_results`` to ``<results_dir>/<log_name>[_<step>].json``."""
        results_path = os.path.join(args.results_dir, log_name)
        if 'necsa' in args.algo_name:
            results_path = f"{results_path}_{str(args.step)}"
        results_path = f"{results_path}.json"
        # Create only the parent directory. The old code created a directory
        # named after the file itself, which left a stray empty directory
        # next to the JSON output.
        parent = os.path.dirname(results_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        print(results_path)
        with open(results_path, 'w') as f:
            json.dump(test_collector.policy_eval_results, f, default=custom_serializer)

    if args.watch:
        watch()
        # Same effect as the site-provided exit(0), but without depending on
        # the `site` module being loaded.
        raise SystemExit(0)

    # trainer
    trainer = OffpolicyTrainer(
        policy=policy,
        max_epoch=args.epoch,
        batch_size=args.batch_size,
        train_collector=train_collector,
        test_collector=test_collector,
        step_per_epoch=args.step_per_epoch,
        episode_per_test=args.test_num,
        update_per_step=args.update_per_step,
        step_per_collect=args.step_per_collect,
        save_best_fn=get_save_best_fn(log_path),
        save_checkpoint_fn=get_save_checkpoint_fn(policy, log_path),
        # NOTE(review): keyed on --resume-id, while weights are loaded via
        # --resume-path. Confirm the two flags are meant to be independent.
        resume_from_log=args.resume_id is not None,
        logger=logger,
        test_in_train=False,
    )

    result = trainer.run()
    pprint.pprint(result)

    watch()
    save_results()
