import argparse
import datetime
import os, json
import pprint

import numpy as np
import torch

from tianshou.data import PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.trainer import OffpolicyTrainer

from algorithms import MPEC_ALGO
from _tianshou_custom.examples.atari.atari_network import CustomRainbow
from data.mpec.mpec_collector import MPECCollector
from data.necsa.abstraction_mode import S, S_A, HS, HS_A, HS_HA
from policy.mpec_ns import MPECNSPolicy
from utils.argparse_util import extend_argparser, merge_args
from utils.environment_util import make_env
from utils.experiment_util import get_train_fn, get_test_fn, get_stop_fn, get_save_best_fn, \
    get_save_checkpoint_fn, get_mpec_dict
from utils.json_util import custom_serializer
from utils.log_util import get_logger

def _get_general_argparser(*args):
    parser = argparse.ArgumentParser()
    for argparser in args:
        parser = extend_argparser(argparser(), parser)

    parser.add_argument("algo_name", type=str)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--logdir", type=str, default="./log")
    parser.add_argument("--results-dir", type=str, default="./results")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only"
    )
    return parser


def _str2bool(value):
    """Convert a command-line string into a ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so such a flag can never be
    turned off from the command line. This converter accepts the usual
    spellings instead.

    Args:
        value: the raw command-line string (a ``bool`` is passed through).

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_mpec_argparser(*args):
    """Return the argument parser for MPEC experiments.

    The parser contains:

    * the general options from ``_get_general_argparser``, plus those of any
      extra parser factories passed in ``*args``;
    * Rainbow-style network, replay and training hyper-parameters;
    * MPEC / abstraction options.

    ``algo_name`` defaults to ``MPEC_ALGO``.
    """
    parser = argparse.ArgumentParser()
    parser = extend_argparser(_get_general_argparser(*args), parser)

    parser.set_defaults(algo_name=MPEC_ALGO)
    parser.add_argument("--eps-test", type=float, default=0.005)
    parser.add_argument("--eps-train", type=float, default=1.)
    parser.add_argument("--eps-train-final", type=float, default=0.05)
    parser.add_argument("--buffer-size", type=int, default=100000)
    parser.add_argument("--lr", type=float, default=0.0000625)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--num-atoms", type=int, default=51)
    parser.add_argument("--v-min", type=float, default=-10.)
    parser.add_argument("--v-max", type=float, default=10.)
    parser.add_argument("--noisy-std", type=float, default=0.1)
    parser.add_argument("--no-dueling", action="store_true", default=False)
    parser.add_argument("--no-noisy", action="store_true", default=False)
    parser.add_argument("--no-priority", action="store_true", default=False)
    parser.add_argument("--alpha", type=float, default=0.5)
    parser.add_argument("--beta", type=float, default=0.4)
    parser.add_argument("--beta-final", type=float, default=1.)
    parser.add_argument("--beta-anneal-step", type=int, default=5000000)
    parser.add_argument("--no-weight-norm", action="store_true", default=False)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=500)
    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step-per-epoch", type=int, default=10000)
    parser.add_argument("--step-per-collect", type=int, default=10)
    parser.add_argument("--update-per-step", type=float, default=0.1)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--training-num", type=int, default=10)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--save-buffer-name", type=str, default=None)

    # MPEC / abstraction options.
    # NOTE(review): these are presumably consumed by get_mpec_dict and the
    # abstraction code — confirm their exact meaning there.
    # `step` is also appended to the results file name for NECSA algorithms.
    parser.add_argument("--step", type=int, default=1)
    parser.add_argument("--grid_num", type=int, default=5)
    parser.add_argument("--epsilon", type=float, default=0.1)
    parser.add_argument("--mpec_lr", type=float, default=0.1)  # TODO what is the default?
    parser.add_argument("--mpec_gamma", type=float, default=0.99)  # TODO what is the default?
    parser.add_argument("--raw_state_dim", type=int, default=64)
    parser.add_argument("--state_dim", type=int, default=24)
    parser.add_argument("--state_min", type=float, default=0)
    parser.add_argument("--state_max", type=float, default=1)
    # `type=bool` would treat "False" as True; parse the string explicitly.
    parser.add_argument("--circular_buffer", type=_str2bool, default=True)
    parser.add_argument("--mode", type=str, default=S_A, choices=[S, S_A, HS, HS_A, HS_HA])
    parser.add_argument("--reduction", action="store_true")
    parser.add_argument("--score_type", type=str, default='score')
    return parser


def test_mpec(args):
    """Train an MPEC agent with tianshou's off-policy trainer, then evaluate it.

    A fresh run creates ``<logdir>/<task>/<algo_name>/<seed>/<timestamp>`` and
    stores ``config.json`` there. With ``--resume-path`` the config is
    reloaded from ``<logdir>/<resume_path>``, merged into ``args``, and the
    policy weights are loaded from ``policy.pth`` in that directory.

    After training, a final evaluation is run and
    ``test_collector.policy_eval_results`` is written to
    ``<results_dir>/<log_name>[_<step>].json``; the ``_<step>`` suffix is
    added only for NECSA algorithms. With ``--watch``, only the evaluation is
    run and the process exits.

    Args:
        args: parsed namespace. Presumably this comes from
            ``get_mpec_argparser`` combined with an environment parser that
            supplies ``task`` (and ``frames_stack`` when ``--save-buffer-name``
            is used) — TODO confirm.
    """
    import sys

    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # Decide whether we resume *before* merge_args runs. The stored config was
    # written on a fresh run, where resume_path was falsy. If merge_args lets
    # stored values win, re-reading args.resume_path afterwards would silently
    # skip loading the policy.
    is_resuming = bool(args.resume_path)
    if is_resuming:
        log_name = args.resume_path
        log_path = os.path.join(args.logdir, log_name)
        with open(os.path.join(log_path, "config.json"), 'r') as f:
            loaded_config = json.load(f)
        args = merge_args(args, loaded_config)
    else:
        # log_path is nested (task/algo/seed/timestamp), so os.mkdir would fail
        # on a fresh logdir. makedirs creates the parents but still raises if
        # this exact run directory already exists.
        os.makedirs(log_path)
        with open(os.path.join(log_path, "config.json"), 'w') as f:
            json.dump(vars(args), f, indent=2)

    logger = get_logger(args, log_path, log_name)

    env, train_envs, test_envs, watch_env = make_env(
        task=args.task,
        seed=args.seed,
        training_num=args.training_num,
        test_num=args.test_num,
        create_watch_env=args.render
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    MPEC_DICT = get_mpec_dict(args, env)

    # define model
    net = CustomRainbow(
        *args.state_shape,
        args.action_shape,
        args.num_atoms,
        args.noisy_std,
        args.device,
        is_dueling=not args.no_dueling,
        is_noisy=not args.no_noisy
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)

    # define policy
    policy = MPECNSPolicy(
        model=net,
        optim=optim,
        discount_factor=args.gamma,
        action_space=env.action_space,
        estimation_step=args.n_step,
        target_update_freq=args.target_update_freq,
        seed=args.seed,
        MPEC_DICT=MPEC_DICT,
    ).to(args.device)

    # load a previous policy from the same directory config.json came from
    if is_resuming:
        policy.load_state_dict(torch.load(os.path.join(log_path, "policy.pth"), map_location=args.device))
        print("Loaded agent from: ", log_path)

    # Replay buffer. Only the last observation is stored and obs_next is not
    # kept, to save RAM.
    if args.no_priority:
        buffer = VectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            ignore_obs_next=True,
            save_only_last_obs=True,
        )
    else:
        buffer = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            ignore_obs_next=True,
            save_only_last_obs=True,
            alpha=args.alpha,
            beta=args.beta,
            weight_norm=not args.no_weight_norm
        )

    # collector
    train_collector = MPECCollector(policy, train_envs, buffer, exploration_noise=True, MPEC_DICT=MPEC_DICT, seed=args.seed, training=True)
    test_collector = MPECCollector(policy, test_envs, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT)

    def watch(prompt_policies=False):
        """Evaluate the policy, or dump a replay buffer if --save-buffer-name is set."""
        print("Setup test envs ...")
        if args.render:
            envs = watch_env
            # TODO: CHECK THIS
            # envs.seed(args.seed)
        else:
            envs = test_envs
            envs.seed(args.seed)
        policy.eval()
        policy.set_eps(args.eps_test)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            # NOTE(review): args.frames_stack is not defined by the parsers in
            # this file; it is presumably supplied by an environment parser.
            buffer = PrioritizedVectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack,
                alpha=args.alpha,
                beta=args.beta
            )
            collector = MPECCollector(policy, envs, buffer, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT)
            result = collector.collect(
                n_step=args.buffer_size, render=args.render, reset_before_collect=True, prompt_policies=prompt_policies
            )
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            collector = MPECCollector(policy, envs, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT)
            result = collector.collect(
                n_episode=args.test_num, render=args.render, reset_before_collect=True, prompt_policies=prompt_policies
            )
        return_ = result.returns.mean()
        print(f"Mean return (over {result.n_collected_episodes} episodes): {return_}")

    def save_results():
        """Write test_collector.policy_eval_results to <results_dir>/<log_name>[_<step>].json."""
        results_stem = os.path.join(args.results_dir, log_name)
        if 'necsa' in args.algo_name:
            results_stem = f"{results_stem}_{str(args.step)}"
        results_path = f"{results_stem}.json"
        # Create only the parent directory. Creating results_stem itself would
        # leave an empty directory next to the JSON file.
        parent_dir = os.path.dirname(results_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        print(results_path)
        with open(results_path, 'w') as f:
            json.dump(test_collector.policy_eval_results, f, default=custom_serializer)

    if args.watch:
        watch(prompt_policies=True)
        # sys.exit rather than the site-provided `exit`, which is not
        # guaranteed to exist (e.g. under `python -S` or frozen apps).
        sys.exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(
        n_step=args.batch_size * args.training_num, reset_before_collect=True)
    # trainer
    trainer = OffpolicyTrainer(
        policy=policy,
        max_epoch=args.epoch,
        batch_size=args.batch_size,
        train_collector=train_collector,
        test_collector=test_collector,
        step_per_epoch=args.step_per_epoch,
        episode_per_test=args.test_num,
        update_per_step=args.update_per_step,
        step_per_collect=args.step_per_collect,
        train_fn=get_train_fn(policy, args, logger, buffer),
        test_fn=get_test_fn(policy, args),
        stop_fn=get_stop_fn(env, args),
        save_best_fn=get_save_best_fn(log_path),
        save_checkpoint_fn=get_save_checkpoint_fn(policy, log_path),
        resume_from_log=args.resume_id is not None,
        logger=logger,
        test_in_train=False,
    )

    result = trainer.run()
    pprint.pprint(result)

    watch()
    save_results()
