#!/usr/bin/env python3
import random
import argparse
import datetime
import os, json
import pprint

import numpy as np
import torch

from tianshou.data import ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3Policy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic

from algorithms import NECSA_TD3_ALGO, TD3_ALGO, NECSA_ADV_TD3_ALGO
from _tianshou_custom.data import CustomCollector
from data.necsa.necsa_collector import NECSACollector
from data.necsa.abstraction_mode import S, S_A, HS, HS_A, HS_HA
from utils.argparse_util import extend_argparser, override_argument, merge_args
from utils.environment_util import make_env
from utils.experiment_util import get_save_best_fn, get_save_checkpoint_fn, get_necsa_dict
from utils.json_util import custom_serializer
from utils.log_util import get_logger

# Seed torch's global RNG with a random integer in [1, 1000] at import time.
# NOTE(review): test_td3() later calls torch.manual_seed(args.seed), which
# replaces this seed for everything executed after that call.
torch.manual_seed(random.randint(1, 1000))


def _get_general_argparser(*args):
    parser = argparse.ArgumentParser()
    for argparser in args:
        parser = extend_argparser(argparser(), parser)

    parser.add_argument("algo_name", type=str)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--logdir", type=str, default="./log")
    parser.add_argument("--results-dir", type=str, default="./results")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser

def get_td3_argparser(*args):
    """Return the argument parser for plain TD3 runs.

    The general parser (built from ``*args``) is merged into a fresh parser
    with ``extend_argparser``; the TD3 defaults and hyper-parameter options
    are then registered on the result.

    Args:
        *args: Forwarded unchanged to ``_get_general_argparser``.

    Returns:
        The parser with general and TD3-specific options.
    """
    td3_parser = extend_argparser(
        _get_general_argparser(*args), argparse.ArgumentParser()
    )

    # Plain TD3: NECSA and its advantage variant are both switched off.
    td3_parser.set_defaults(algo_name=TD3_ALGO, use_necsa=False, necsa_adv=False)

    # (flag, keyword arguments) pairs; registered strictly in this order.
    hyperparameters = (
        ("--buffer-size", dict(type=int, default=1000000)),
        ("--hidden-sizes", dict(type=int, nargs="*", default=[256, 256])),
        ("--actor-lr", dict(type=float, default=3e-4)),
        ("--critic-lr", dict(type=float, default=3e-4)),
        ("--gamma", dict(type=float, default=0.99)),
        ("--tau", dict(type=float, default=0.005)),
        ("--exploration-noise", dict(type=float, default=0.1)),
        ("--policy-noise", dict(type=float, default=0.2)),
        ("--noise-clip", dict(type=float, default=0.5)),
        ("--update-actor-freq", dict(type=int, default=2)),
        ("--start-timesteps", dict(type=int, default=25000)),
        ("--epoch", dict(type=int, default=200)),
        ("--step-per-epoch", dict(type=int, default=1000)),
        ("--step-per-collect", dict(type=int, default=1)),
        ("--update-per-step", dict(type=int, default=1)),
        ("--n-step", dict(type=int, default=1)),
        ("--batch-size", dict(type=int, default=256)),
        ("--training-num", dict(type=int, default=1)),
        ("--test-num", dict(type=int, default=10)),
    )
    for flag, options in hyperparameters:
        td3_parser.add_argument(flag, **options)
    return td3_parser

def _str_to_bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` calls ``bool()`` on the raw string, so every non-empty
    token (including ``"False"``) becomes True. This converter parses the
    usual spellings instead. The empty string still maps to False, which is
    what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_necsa_td3_argparser(*args):
    """Return the argument parser for NECSA-augmented TD3 runs.

    Starts from the TD3 parser (built from ``*args``), switches the defaults
    to the NECSA algorithm, overrides the ``--start-timesteps`` default to
    10000, and registers the NECSA-specific options.

    Args:
        *args: Forwarded unchanged to ``get_td3_argparser``.

    Returns:
        The parser with general, TD3 and NECSA options.
    """
    parser = argparse.ArgumentParser()
    parser = extend_argparser(get_td3_argparser(*args), parser)

    parser.set_defaults(algo_name=NECSA_TD3_ALGO)
    parser.set_defaults(use_necsa=True)
    parser.set_defaults(necsa_adv=False)
    override_argument(parser, "--start-timesteps", type=int, default=10000)

    # NOTE(review): the previous trailing comments on the next three options
    # ("Directory for storing all experimental data") were copy-paste
    # leftovers and did not describe them. Their exact meaning is defined by
    # the NECSA code that consumes ``args`` -- confirm before documenting.
    parser.add_argument("--step", type=int, default=3)
    parser.add_argument("--grid_num", type=int, default=5)
    parser.add_argument("--epsilon", type=float, default=0.2)
    parser.add_argument("--necsa_lr", type=float, default=0.1) # TODO what is the default?
    parser.add_argument("--necsa_gamma", type=float, default=0.99) # TODO what is the default?
    parser.add_argument("--state_dim", type=int, default=16)
    # Lower / upper bound for state values -- presumably the range used by
    # the state abstraction; TODO confirm against the NECSA implementation.
    parser.add_argument("--state_min", type=float, default=-10)
    parser.add_argument("--state_max", type=float, default=10)
    # Parsed with _str_to_bool so that "--circular_buffer False" yields False
    # (with type=bool any non-empty string, including "False", was True).
    parser.add_argument("--circular_buffer", type=_str_to_bool, default=False)
    parser.add_argument("--mode", type=str, default=S_A, choices=[S, S_A, HS, HS_A, HS_HA])
    parser.add_argument("--reduction", action="store_true")
    parser.add_argument("--score_type", type=str, default='score')
    return parser

def get_necsa_adv_td3_argparser(*args):
    """Return the argument parser for the NECSA "adv" TD3 variant.

    Accepts the same options as the NECSA TD3 parser; only the defaults for
    ``algo_name``, ``use_necsa`` and ``necsa_adv`` are replaced.

    Args:
        *args: Forwarded unchanged to ``get_necsa_td3_argparser``.

    Returns:
        The parser with the NECSA-ADV defaults applied.
    """
    adv_parser = extend_argparser(
        get_necsa_td3_argparser(*args), argparse.ArgumentParser()
    )
    adv_parser.set_defaults(
        algo_name=NECSA_ADV_TD3_ALGO,
        use_necsa=True,
        necsa_adv=True,
    )
    return adv_parser


def test_td3(args):
    """Train a TD3 agent configured by ``args``, or only watch a saved one.

    Steps, in order:
      1. Resolve the log directory. When ``args.resume_path`` is set, the
         stored ``config.json`` is read and merged into ``args``; otherwise
         a new directory is created and ``args`` is written to it.
      2. Build the environments, actor/critic networks and the TD3 policy.
      3. Optionally load ``policy.pth`` from the resume directory.
      4. Create the collectors (the NECSA collector when ``args.use_necsa``)
         and pre-fill the buffer with ``args.start_timesteps`` random steps.
      5. If ``args.watch`` is set, run the evaluation episodes and exit the
         process with status 0; otherwise train, evaluate, and write the
         test collector's evaluation results to a JSON file.

    Args:
        args: Parsed command-line namespace (see the ``get_*_argparser``
            functions in this file). Several attributes are added to or
            rescaled on it in place (``state_shape``, ``action_shape``,
            ``max_action``, and the three noise parameters).

    Raises:
        SystemExit: With code 0 after watching when ``args.watch`` is set.
        FileExistsError: If the freshly generated log directory already
            exists.
    """
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    if args.resume_path:
        log_name = args.resume_path
        log_path = os.path.join(args.logdir, log_name)
        with open(os.path.join(log_path, "config.json"), 'r') as f:
            loaded_config = json.load(f)
            # NOTE(review): if merge_args lets the stored config override the
            # command line, args.resume_path / args.logdir may change here,
            # which affects the policy loading below -- confirm precedence.
            args = merge_args(args, loaded_config)
    else:
        # log_path is nested (<logdir>/<task>/<algo>/<seed>/<timestamp>), so
        # the missing parents must be created as well; os.mkdir would raise
        # FileNotFoundError on a fresh run. exist_ok is left at its default
        # so a clash on the final directory still fails loudly.
        os.makedirs(log_path)
        with open(os.path.join(log_path, "config.json"), 'w') as f:
            json.dump(vars(args), f, indent=2)

    logger = get_logger(args, os.path.join(args.logdir, log_name), log_name)

    env, train_envs, test_envs, watch_env = make_env(
        task=args.task,
        seed=args.seed,
        num_train_envs=args.training_num,
        num_test_envs=args.test_num,
        obs_norm=False,
        create_watch_env=args.render
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # The noise options are given relative to the action magnitude; scale
    # them by the first component of the action-space upper bound.
    args.exploration_noise = args.exploration_noise * args.max_action
    args.policy_noise = args.policy_noise * args.max_action
    args.noise_clip = args.noise_clip * args.max_action
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = Actor(
        net_a, args.action_shape, max_action=args.max_action, device=args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    policy = TD3Policy(
        actor=actor,
        actor_optim=actor_optim,
        critic=critic1,
        critic_optim=critic1_optim,
        action_space=env.action_space,
        critic2=critic2,
        critic2_optim=critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
        policy_noise=args.policy_noise,
        update_actor_freq=args.update_actor_freq,
        noise_clip=args.noise_clip,
        estimation_step=args.n_step,
    )

    # load a previous policy
    if args.resume_path:
        resume_path = os.path.join(args.logdir, args.resume_path)
        policy.load_state_dict(torch.load(os.path.join(resume_path, "policy.pth"), map_location=args.device))
        print("Loaded agent from: ", resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)

    NECSA_DICT = None
    COLLECTOR_CLASS = CustomCollector
    if args.use_necsa:
        NECSA_DICT = get_necsa_dict(args, env)
        COLLECTOR_CLASS = NECSACollector

    train_collector = COLLECTOR_CLASS(policy, train_envs, buffer, exploration_noise=True, NECSA_DICT=NECSA_DICT)
    test_collector = CustomCollector(policy, test_envs)
    # Warm-up: fill the buffer with randomly sampled actions before training.
    train_collector.collect(n_step=args.start_timesteps, random=True, reset_before_collect=True)

    def watch():
        """Run ``args.test_num`` evaluation episodes and print the mean result."""
        # Let's watch its performance!
        policy.eval()
        if args.render:
            envs = watch_env
            # TODO: CHECK THIS
            # envs.seed(args.seed)
        else:
            envs = test_envs
            envs.seed(args.seed)
        collector = CustomCollector(policy, envs, exploration_noise=True)
        result = collector.collect(
            n_episode=args.test_num, render=args.render, reset_before_collect=True
        )
        envs.close()
        print(f"Final return: {result.returns.mean()}, length: {result.lens.mean()}")

    def save_results():
        """Write the test collector's evaluation results to a JSON file.

        The file is ``<results_dir>/<log_name>.json``, with ``_<step>``
        appended to the stem when the algorithm name contains 'necsa'.
        """
        results_path = os.path.join(args.results_dir, log_name)
        if 'necsa' in args.algo_name:
            results_path = f"{results_path}_{str(args.step)}"
        # Only the parent directory has to exist. Creating a directory named
        # after the file stem (as done previously) left an empty folder next
        # to every results file.
        parent_dir = os.path.dirname(results_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        results_path = f"{results_path}.json"
        print(results_path)
        with open(results_path, 'w') as f:
            json.dump(test_collector.policy_eval_results, f, default=custom_serializer)

    if args.watch:
        watch()
        # Equivalent to exit(0), but does not depend on the helper that the
        # ``site`` module injects for interactive use.
        raise SystemExit(0)

    # trainer
    trainer = OffpolicyTrainer(
        policy=policy,
        max_epoch=args.epoch,
        batch_size=args.batch_size,
        train_collector=train_collector,
        test_collector=test_collector,
        step_per_epoch=args.step_per_epoch,
        episode_per_test=args.test_num,
        update_per_step=args.update_per_step,
        step_per_collect=args.step_per_collect,
        save_best_fn=get_save_best_fn(log_path),
        save_checkpoint_fn=get_save_checkpoint_fn(policy, log_path),
        resume_from_log=args.resume_id is not None,
        logger=logger,
        test_in_train=False,
    )

    result = trainer.run()
    pprint.pprint(result)

    watch()
    save_results()

