import argparse
import datetime
import os, json
import pprint

import numpy as np
import torch

from tianshou.data import VectorReplayBuffer
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.discrete import IntrinsicCuriosityModule

from algorithms import DQN_ALGO, NECSA_DQN_ALGO
from _tianshou_custom.policy import CustomDQNPolicy
from _tianshou_custom.data import CustomCollector
from data.necsa.necsa_collector import NECSACollector
from data.necsa.abstraction_mode import S, S_A, HS, HS_A, HS_HA
from _tianshou_custom.examples.atari.atari_network import CustomDQN
from utils.argparse_util import extend_argparser, merge_args
from utils.environment_util import make_env
from utils.experiment_util import get_train_fn, get_test_fn, get_stop_fn, get_save_best_fn, \
    get_save_checkpoint_fn, get_necsa_dict
from utils.json_util import custom_serializer
from utils.log_util import get_logger


def _get_general_argparser(*args):
    """Build the parser holding the options shared by every algorithm.

    Args:
        *args: zero-argument callables, each returning an
            ``argparse.ArgumentParser``. They are called in the order given
            and combined with the parser built so far via
            ``extend_argparser``.

    Returns:
        The resulting ``argparse.ArgumentParser``: one positional
        ``algo_name`` plus the generic run options (seed, log/result
        directories, rendering, device, resume, logger and watch flags).
    """
    general_parser = argparse.ArgumentParser()
    for parser_factory in args:
        general_parser = extend_argparser(parser_factory(), general_parser)

    # Prefer the GPU whenever torch reports that CUDA is usable.
    if torch.cuda.is_available():
        default_device = "cuda"
    else:
        default_device = "cpu"

    general_parser.add_argument("algo_name", type=str)

    # (flag, type, default) triples for the plain typed options, kept in
    # registration order so the generated --help output is unchanged.
    typed_options = (
        ("--seed", int, 0),
        ("--logdir", str, "./log"),
        ("--results-dir", str, "./results"),
        ("--render", float, 0.),
        ("--device", str, default_device),
        ("--resume-path", str, None),
        ("--resume-id", str, None),
    )
    for flag, value_type, default_value in typed_options:
        general_parser.add_argument(flag, type=value_type, default=default_value)

    general_parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    general_parser.add_argument(
        "--wandb-project", type=str, default="atari.benchmark"
    )
    general_parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only"
    )
    return general_parser


def get_dqn_argparser(*args):
    """Build the argument parser for plain DQN runs.

    Args:
        *args: parser factories forwarded unchanged to
            ``_get_general_argparser``.

    Returns:
        An ``argparse.ArgumentParser`` combining the general options with the
        DQN hyper-parameters and the ICM options. The defaults for
        ``algo_name`` and ``use_necsa`` are set to ``DQN_ALGO`` and ``False``.
    """
    dqn_parser = argparse.ArgumentParser()
    dqn_parser = extend_argparser(_get_general_argparser(*args), dqn_parser)
    dqn_parser.set_defaults(algo_name=DQN_ALGO, use_necsa=False)

    # (flag, type, default) for every DQN hyper-parameter, in the original
    # registration order.
    hyper_parameters = (
        ("--eps-test", float, 0.005),
        ("--eps-train", float, 1.),
        ("--eps-train-final", float, 0.05),
        ("--buffer-size", int, 100000),
        ("--lr", float, 0.0001),
        ("--gamma", float, 0.99),
        ("--n-step", int, 3),
        ("--target-update-freq", int, 500),
        ("--epoch", int, 1000),
        ("--step-per-epoch", int, 10000),
        ("--step-per-collect", int, 10),
        ("--update-per-step", float, 0.1),
        ("--batch-size", int, 32),
        ("--training-num", int, 10),
        ("--test-num", int, 10),
        ("--save-buffer-name", str, None),
    )
    for flag, value_type, default_value in hyper_parameters:
        dqn_parser.add_argument(flag, type=value_type, default=default_value)

    # (flag, default, help) for the float-valued ICM options.
    icm_options = (
        ("--icm-lr-scale", 0.,
         "use intrinsic curiosity module with this lr scale"),
        ("--icm-reward-scale", 0.01,
         "scaling factor for intrinsic curiosity reward"),
        ("--icm-forward-loss-weight", 0.2,
         "weight for the forward model loss in ICM"),
    )
    for flag, default_value, help_text in icm_options:
        dqn_parser.add_argument(
            flag, type=float, default=default_value, help=help_text
        )
    return dqn_parser


def _parse_bool_flag(value):
    """Convert a command-line token into a bool (for use as argparse ``type``).

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter interprets the
    usual spellings explicitly instead.

    Args:
        value: the raw token from the command line (or an actual bool).

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off and
        for the empty string (which ``bool`` also mapped to ``False``).
        Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (e.g. true/false), got {value!r}"
    )


def get_necsa_dqn_argparser(*args):
    """Build the argument parser for NECSA-augmented DQN runs.

    Args:
        *args: parser factories forwarded unchanged to ``get_dqn_argparser``.

    Returns:
        An ``argparse.ArgumentParser`` combining the DQN options with the
        NECSA options. The defaults for ``algo_name``, ``use_necsa`` and
        ``necsa_adv`` are set to ``NECSA_DQN_ALGO``, ``True`` and ``False``.
    """
    parser = argparse.ArgumentParser()
    parser = extend_argparser(get_dqn_argparser(*args), parser)

    parser.set_defaults(algo_name=NECSA_DQN_ALGO)
    parser.set_defaults(use_necsa=True)
    parser.set_defaults(necsa_adv=False)

    # NECSA options. Their exact semantics are defined by whatever consumes
    # the parsed namespace (presumably get_necsa_dict) -- TODO confirm and
    # document each one there.
    parser.add_argument("--step", type=int, default=1)
    parser.add_argument("--grid_num", type=int, default=5)
    parser.add_argument("--epsilon", type=float, default=0.1)
    parser.add_argument("--necsa_lr", type=float, default=0.1)  # TODO what is the default?
    parser.add_argument("--necsa_gamma", type=float, default=0.99)  # TODO what is the default?
    parser.add_argument("--raw_state_dim", type=int, default=64)
    parser.add_argument("--state_dim", type=int, default=24)
    # state_min / state_max form a pair of bounds.
    parser.add_argument("--state_min", type=float, default=0)
    parser.add_argument("--state_max", type=float, default=1)
    # Parsed with an explicit converter so that "--circular_buffer False"
    # really yields False (type=bool would have produced True).
    parser.add_argument("--circular_buffer", type=_parse_bool_flag, default=True)
    parser.add_argument("--mode", type=str, default=HS, choices=[S, S_A, HS, HS_A, HS_HA])
    parser.add_argument("--reduction", action="store_true")
    parser.add_argument("--score_type", type=str, default='score')
    return parser


def test_dqn(args):
    """Train a DQN agent (optionally with ICM and/or NECSA), or only watch it.

    The function sets up the log directory and logger, builds the
    environments, network, policy, replay buffer and collectors, then either
    evaluates a loaded policy (``args.watch``) or runs the off-policy trainer,
    evaluates the final policy and writes the evaluation results to JSON.

    Args:
        args: parsed argument namespace. Besides the options defined by the
            parsers in this module it must provide ``task``, ``scale_obs`` and
            ``frames_stack`` (presumably contributed by the extra parser
            factories passed to the ``get_*_argparser`` functions -- TODO
            confirm). ``args.state_shape`` and ``args.action_shape`` are set
            here, and ``args.algo_name`` gets an ``_icm`` suffix when
            ``args.icm_lr_scale > 0``. When resuming, ``args`` is replaced by
            the result of ``merge_args`` with the stored ``config.json``.

    Raises:
        FileExistsError: if a fresh run's log directory already exists.
        SystemExit: with code 0 after watching when ``args.watch`` is set.
    """
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    if args.resume_path:
        # Resuming: reuse the stored run directory and its saved configuration.
        log_name = args.resume_path
        log_path = os.path.join(args.logdir, log_name)
        with open(os.path.join(log_path, "config.json"), 'r') as f:
            loaded_config = json.load(f)
            args = merge_args(args, loaded_config)
    else:
        # log_path is nested (logdir/task/algo/seed/timestamp); os.mkdir would
        # raise FileNotFoundError whenever a parent is missing. makedirs
        # creates the parents but still fails if the run directory itself
        # already exists, so an existing config.json is never overwritten.
        os.makedirs(log_path)
        with open(os.path.join(log_path, "config.json"), 'w') as f:
            json.dump(vars(args), f, indent=2)

    logger = get_logger(args, os.path.join(args.logdir, log_name), log_name)

    # The suffix is applied only after the log directory name was derived.
    args.algo_name = f"{args.algo_name}_icm" if args.icm_lr_scale > 0 else args.algo_name

    env, train_envs, test_envs, watch_env = make_env(
        task=args.task,
        seed=args.seed,
        training_num=args.training_num,
        test_num=args.test_num,
        scale=args.scale_obs,
        frame_stack=args.frames_stack,
        create_watch_env=args.render
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # define model
    net = CustomDQN(
        *args.state_shape,
        args.action_shape,
        args.device
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)

    # define policy
    policy = CustomDQNPolicy(
        model=net,
        optim=optim,
        action_space=env.action_space,
        discount_factor=args.gamma,
        estimation_step=args.n_step,
        target_update_freq=args.target_update_freq
    ).to(args.device)

    if args.icm_lr_scale > 0:
        # Wrap the DQN policy with an intrinsic curiosity module that has its
        # own feature network and optimizer.
        feature_net = CustomDQN(
            *args.state_shape, args.action_shape, args.device, features_only=True
        )
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.net,
            feature_dim,
            action_dim,
            hidden_sizes=[512],
            device=args.device
        )
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(
            policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
            args.icm_forward_loss_weight
        ).to(args.device)

    # load a previous policy
    if args.resume_path:
        resume_path = os.path.join(args.logdir, args.resume_path)
        policy.load_state_dict(torch.load(os.path.join(resume_path, "policy.pth"), map_location=args.device))
        print("Loaded agent from: ", resume_path)

    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack
    )

    # NECSA runs use their own training collector and an extra settings dict;
    # otherwise the plain custom collector is used with NECSA_DICT=None.
    necsa_dict = None
    collector_cls = CustomCollector
    if args.use_necsa:
        necsa_dict = get_necsa_dict(args, env)
        collector_cls = NECSACollector

    # collector
    train_collector = collector_cls(
        policy, train_envs, buffer, exploration_noise=True,
        NECSA_DICT=necsa_dict, env_type='atari'
    )
    test_collector = CustomCollector(policy, test_envs, exploration_noise=True)

    # watch agent's performance
    def watch():
        """Evaluate the current policy, or fill and save a replay buffer."""
        print("Setup test envs ...")
        if args.render:
            envs = watch_env
            # TODO: CHECK THIS
            # envs.seed(args.seed)
        else:
            envs = test_envs
            envs.seed(args.seed)
        policy.eval()
        policy.set_eps(args.eps_test)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            watch_buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack
            )
            collector = CustomCollector(policy, envs, watch_buffer, exploration_noise=True)
            # A freshly built collector is reset on its first collect
            # everywhere else in this function; do the same here (Tianshou 1.x
            # collectors refuse to collect before a reset -- NOTE(review):
            # confirm CustomCollector keeps that behaviour).
            result = collector.collect(
                n_step=args.buffer_size, reset_before_collect=True
            )
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            watch_buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            collector = CustomCollector(policy, envs, exploration_noise=True)
            result = collector.collect(
                n_episode=args.test_num, render=args.render, reset_before_collect=True
            )
        return_ = result.returns.mean()
        print(f"Mean return (over {result.n_collected_episodes} episodes): {return_}")

    def save_results():
        """Dump the test collector's evaluation results to a JSON file."""
        results_path = os.path.join(args.results_dir, log_name)
        if 'necsa' in args.algo_name:
            results_path = f"{results_path}_{str(args.step)}"
        results_path = f"{results_path}.json"
        # Only the parent directory is required; creating a directory named
        # like the file (minus ".json") would just leave an empty folder.
        parent_dir = os.path.dirname(results_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        print(results_path)
        with open(results_path, 'w') as f:
            json.dump(test_collector.policy_eval_results, f, default=custom_serializer)

    if args.watch:
        watch()
        # Same effect as exit(0) without depending on the `site` builtin.
        raise SystemExit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
    # trainer
    trainer = OffpolicyTrainer(
        policy=policy,
        max_epoch=args.epoch,
        batch_size=args.batch_size,
        train_collector=train_collector,
        test_collector=test_collector,
        step_per_epoch=args.step_per_epoch,
        episode_per_test=args.test_num,
        update_per_step=args.update_per_step,
        step_per_collect=args.step_per_collect,
        train_fn=get_train_fn(policy, args, logger),
        test_fn=get_test_fn(policy, args),
        stop_fn=get_stop_fn(env, args),
        save_best_fn=get_save_best_fn(log_path),
        save_checkpoint_fn=get_save_checkpoint_fn(policy, log_path),
        resume_from_log=args.resume_id is not None,
        logger=logger,
        test_in_train=False,
    )

    result = trainer.run()
    pprint.pprint(result)

    watch()
    save_results()
