import argparse
import datetime
import os, json
import pprint
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
from gymnasium.spaces import Dict

from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import Net

from _tianshou_custom.data.buffer.CustomPrioritizedVectorReplayBuffer import CustomPrioritizedVectorReplayBuffer
from _tianshou_custom.data.buffer.CustomVectorReplayBuffer import CustomVectorReplayBuffer
from policy.mpec_ns import MPECNSPolicy
from algorithms import MPEC_ALGO
from data.mpec.mpec_collector import MPECCollector
from data.necsa.abstraction_mode import S, S_A, HS, HS_A, HS_HA
from policy.mpec_s import MPECSPolicy
from utils.argparse_util import extend_argparser, parse_task_wrapper, merge_args, parse_task_params
from utils.environment_util import make_env
from utils.experiment_util import get_test_fn, get_mpec_dict, set_state_shape
from utils.json_util import custom_serializer
from utils.log_util import get_logger

def _get_general_argparser(*args):
    """Build the parser with the options shared by every experiment.

    Each element of ``*args`` is a zero-argument factory returning a parser;
    their options are merged in first via ``extend_argparser``. On top of
    that the parser defines the positional ``algo_name`` plus the seed,
    logging, results, rendering, device, resume, logger-backend, watch and
    task-customisation options.

    Returns:
        The assembled ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    for factory in args:
        parser = extend_argparser(factory(), parser)

    add = parser.add_argument
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    add("algo_name", type=str)
    add("--seed", type=int, default=0)
    add("--logdir", type=str, default="./log")
    add("--log-name", type=str, default=None)
    add("--results-dir", type=str, default="./results")
    # Forwarded as ``render`` to the collector when watching; 0 disables it.
    add("--render", type=float, default=0.)
    add("--device", type=str, default=default_device)
    add("--resume-path", type=str, default=None)
    add("--resume-id", type=str, default=None)
    add("--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"])
    add("--wandb-project", type=str, default="atari.benchmark")
    add("--watch", default=False, action="store_true",
        help="watch the play of pre-trained policy only")

    # Repeatable options: each occurrence is parsed and appended to a list.
    add('--task-params', action='append', type=parse_task_params, default=[])
    add('--task-wrapper', action='append', type=parse_task_wrapper, default=[])
    return parser

def get_mpec_discrete_debug_argparser(*args):
    """Build the parser holding the MPEC debug switches.

    Each element of ``*args`` is a zero-argument factory returning a parser
    whose options are merged in first via ``extend_argparser``. The parser
    then defines ``--debug``, the numbered switches ``--as1`` .. ``--as8``
    (their meaning is decided by the code that reads them) and a family of
    ``--debug-*`` options. Every on/off switch defaults to ``False``; the two
    numeric options are ``--debug-learning-rate`` (1e-3) and
    ``--debug-discount-factor`` (0.9).
    """
    parser = argparse.ArgumentParser()
    for factory in args:
        parser = extend_argparser(factory(), parser)

    switch = {"action": "store_true", "default": False}
    # Declaration order is kept identical so ``--help`` output is unchanged.
    option_specs = [
        ("--debug", switch),
        *((f"--as{index}", switch) for index in range(1, 9)),
        ("--debug-naive-selection", switch),
        ("--debug-track-policies", switch),
        ("--debug-disable-trajectory-length", switch),
        ("--debug-disable-average-reward", switch),
        ("--debug-learning-rate", {"type": float, "default": 1e-3}),
        ("--debug-discount-factor", {"type": float, "default": 0.9}),
        ("--debug-track-trajectories-length", switch),
        ("--debug-disable-cycle-detection", switch),
        ("--debug-disable-ssm", switch),
        ("--debug-disable-reconnection", switch),
        ("--debug-track-trajectories-split-and-mismatches", switch),
    ]
    for option, option_kwargs in option_specs:
        parser.add_argument(option, **option_kwargs)

    return parser

def _str_to_bool(value):
    """Convert a command-line token such as ``true``/``false``/``1``/``0`` to bool.

    Needed because ``argparse`` with ``type=bool`` calls ``bool(token)``. That
    is ``True`` for every non-empty string, including ``"False"`` and ``"0"``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_mpec_discrete_argparser(*args):
    """Build the full command-line parser for discrete MPEC experiments.

    It combines the general options (``_get_general_argparser``), the debug
    switches (``get_mpec_discrete_debug_argparser``) and the MPEC training
    hyper-parameters defined below. ``*args`` are zero-argument parser
    factories forwarded to both sub-parsers.

    ``algo_name`` defaults to ``MPEC_ALGO`` and ``no_priority`` to ``True``.
    """
    parser = argparse.ArgumentParser()
    parser = extend_argparser(_get_general_argparser(*args), parser)
    parser = extend_argparser(get_mpec_discrete_debug_argparser(*args), parser)

    parser.set_defaults(algo_name=MPEC_ALGO)
    parser.set_defaults(no_priority=True)
    parser.add_argument("--stationary", action='store_true')
    parser.add_argument("--eps-test", type=float, default=0)
    parser.add_argument("--eps-train", type=float, default=1)
    parser.add_argument("--eps-train-final", type=float, default=0.1)
    parser.add_argument("--eps-decay", type=float, default=0.9)
    parser.add_argument("--buffer-size", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=320)
    parser.add_argument("--epoch", type=int, default=1)
    parser.add_argument("--step-per-epoch", type=int, default=10000)
    parser.add_argument("--step-per-collect", type=int, default=1)
    parser.add_argument("--update-per-step", type=float, default=1)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--training-num", type=int, default=1)
    parser.add_argument("--test-num", type=int, default=1)
    parser.add_argument("--save-buffer-name", type=str, default=None)

    # Appended to the results file name for NECSA runs; presumably the step
    # length of the NECSA state abstraction — TODO confirm.
    parser.add_argument("--step", type=int, default=1)
    # Presumably the number of grid cells used to discretise states — TODO confirm.
    parser.add_argument("--grid-num", type=int, default=5)
    parser.add_argument("--mpec-lr", type=float, default=1) # TODO what is the default?
    parser.add_argument("--mpec-gamma", type=float, default=1) # TODO what is the default?
    parser.add_argument("--raw-state-dim", type=int, default=64)
    parser.add_argument("--state-dim", type=int, default=24)
    parser.add_argument("--state-min", type=float, default=0)
    parser.add_argument("--state-max", type=float, default=1)
    # Accepts true/false, 1/0, yes/no, on/off (``type=bool`` treated "False" as True).
    parser.add_argument("--circular-buffer", type=_str_to_bool, default=True)
    parser.add_argument("--mode", type=str, default=S_A, choices=[S, S_A, HS, HS_A, HS_HA])
    parser.add_argument("--reduction", action="store_true")
    parser.add_argument("--score-type", type=str, default='score')

    parser.add_argument("--use-normalized-returns", action="store_true", default=False)
    parser.add_argument("--max-trajectory-length", type=int, default=None)
    parser.add_argument("--policy-domination-decimal-places", type=int, nargs='+', default=None)
    parser.add_argument("--dont-ask-for-policy", action="store_true", default=False)
    parser.add_argument("--terminate-if-no-policy", action="store_true", default=False)
    return parser


def test_mpec_discrete(args):
    """Train (or only watch) an MPEC agent on ``args.task`` and save its results.

    Workflow:
      1. Resolve the run directory ``<logdir>/<task>/<algo_name>/<seed>/<name>``
         (``<name>`` is ``--log-name`` or a ``%y%m%d-%H%M%S`` timestamp). A fresh
         run writes ``config.json`` there. With ``--resume-path`` the directory
         ``<logdir>/<resume_path>`` is used instead, and its ``config.json`` is
         merged into ``args`` via ``merge_args``.
      2. Build the environments, a ``Net`` with three hidden layers of 128
         units plus Adam, and either ``MPECSPolicy`` (``--stationary``) or
         ``MPECNSPolicy``. When resuming, ``policy.pth`` and the policy
         memories are reloaded.
      3. With ``--watch``, only evaluate the policy and exit the process.
         Otherwise train with ``OffpolicyTrainer``, print the trainer result,
         evaluate once more, and write the evaluation and debug results as
         JSON below ``--results-dir``.

    Args:
        args: ``argparse.Namespace`` from ``get_mpec_discrete_argparser``. It
            must also carry ``task``, which the parsers in this file do not
            define (presumably added by one of the extra parser factories).
            ``args`` is mutated: ``no_priority`` and ``action_shape`` are set
            here, and ``state_shape`` presumably by ``set_state_shape``.
    """
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = args.log_name
    if log_name is None:
        log_name = now
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), log_name)
    log_path = os.path.join(args.logdir, log_name)

    if args.resume_path:
        # Reuse the original run's directory (relative to --logdir) and its config.
        log_name = args.resume_path
        log_path = os.path.join(args.logdir, log_name)
        with open(os.path.join(log_path, "config.json"), 'r') as f:
            loaded_config = json.load(f)
            args = merge_args(args, loaded_config)
    else:
        Path(log_path).mkdir(parents=True, exist_ok=True)
        with open(os.path.join(log_path, "config.json"), 'w') as f:
            json.dump(vars(args), f, indent=2)

    logger = get_logger(args, os.path.join(args.logdir, log_name), log_name)

    # Always use the uniform replay buffer below, whatever the (merged) config says.
    args.no_priority = True
    env, train_envs, test_envs, watch_env = make_env(
        task=args.task,
        seed=args.seed,
        training_num=args.training_num,
        test_num=args.test_num,
        create_watch_env=bool(args.render),
        params=args.task_params,
        wrappers=args.task_wrapper
    )
    set_state_shape(args, env)
    # Discrete spaces have an empty shape, so fall back to the number of actions.
    args.action_shape = env.action_space.shape or env.action_space.n
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    MPEC_DICT = get_mpec_dict(args, env)

    # define model
    net = Net(state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=[128, 128, 128])
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)

    if args.stationary:
        POLICY_CLASS = MPECSPolicy
    else:
        POLICY_CLASS = MPECNSPolicy

    policy = POLICY_CLASS(
        model=net,
        optim=optim,
        discount_factor=args.gamma,
        action_space=env.action_space,
        estimation_step=args.n_step,
        target_update_freq=args.target_update_freq,
        seed=args.seed,
        MPEC_DICT=MPEC_DICT,
        # Fixed key order for Dict observation spaces; None for flat spaces.
        ordered_obs_keys=list(env.observation_space.spaces.keys()) if isinstance(env.observation_space, Dict) else None,
        # NOTE(review): hard-coded; --use-normalized-returns is not forwarded
        # here — confirm whether that is intentional.
        use_normalized_returns=False,
    )

    # load a previous policy
    if args.resume_path:
        resume_path = os.path.join(args.logdir, args.resume_path)
        policy.load_state_dict(torch.load(os.path.join(resume_path, "policy.pth"), map_location=args.device))
        policy.load_memories(resume_path)
        print("Loaded agent from: ", resume_path)

    if args.no_priority:
        buffer = CustomVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
        )
    else:
        # NOTE(review): unreachable while no_priority is forced True above;
        # args.alpha/beta/no_weight_norm are not defined by this file's parsers.
        buffer = CustomPrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            ignore_obs_next=True,
            save_only_last_obs=True,
            alpha=args.alpha,
            beta=args.beta,
            weight_norm=not args.no_weight_norm
        )

    # collector
    train_collector = MPECCollector(policy, train_envs, buffer, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT, training=True)
    test_collector = MPECCollector(policy, test_envs, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT)

    # Latest per-policy evaluation produced by the "test all policies" test_fn;
    # stays None if that test_fn is not used or never runs.
    eval_results = None

    def watch(prompt_policies=False):
        """Evaluate the policy with ``args.eps_test`` and print its mean return.

        Uses the watch env when ``--render`` is set, otherwise the test envs.
        With ``--save-buffer-name`` it instead collects ``args.buffer_size``
        steps into a fresh buffer and saves it as HDF5.
        """
        print("Setup test envs ...")
        if args.render:
            envs = watch_env
            # TODO: check whether the watch env should be seeded as well
            # (envs.seed(args.seed)).
        else:
            envs = test_envs
            envs.seed(args.seed)
        policy.eval()
        policy.set_eps(args.eps_test)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            # NOTE(review): no alpha/beta passed, unlike the prioritized
            # training buffer — confirm the class provides defaults.
            watch_buffer = CustomPrioritizedVectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(envs),
            )
            collector = MPECCollector(policy, envs, watch_buffer, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT)
            result = collector.collect(
                n_step=args.buffer_size, render=args.render, reset_before_collect=True, prompt_policies=prompt_policies
            )
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            watch_buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            collector = MPECCollector(policy, envs, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT)
            result = collector.collect(
                n_episode=args.test_num, render=args.render, reset_before_collect=True, prompt_policies=prompt_policies
            )
        return_ = result.returns.mean()
        print(f"Mean return (over {result.n_collected_episodes} episodes): {return_}")

    def results_file(stem):
        """Return ``<results_dir>/<stem>[_<step>].json`` and create its parent dir.

        The ``_<step>`` suffix is added only for NECSA algorithms.
        """
        path = os.path.join(args.results_dir, stem)
        if 'necsa' in args.algo_name:
            path = f"{path}_{str(args.step)}"
        path = f"{path}.json"
        parent = os.path.dirname(path)
        if parent:
            # Create the directory that will contain the file, not a
            # directory named after the file itself.
            os.makedirs(parent, exist_ok=True)
        return path

    def save_results():
        """Dump the evaluation results as JSON.

        Debug runs using --as1/--as6/--as7 dump the "test all policies"
        results. All other runs dump the test collector's
        ``policy_eval_results``.
        """
        results_path = results_file(log_name)
        print(results_path)
        if args.debug and any([args.as1, args.as6, args.as7]):
            with open(results_path, 'w') as f:
                json.dump(eval_results, f, default=custom_serializer)
            return
        with open(results_path, 'w') as f:
            json.dump(test_collector.policy_eval_results, f, default=custom_serializer)

    def save_debug_results():
        """Dump the training collector's ``policy_debug_results`` as JSON."""
        debug_results_path = results_file(f"{log_name}_debug")
        with open(debug_results_path, 'w') as f:
            json.dump(train_collector.policy_debug_results, f, default=custom_serializer)

    if args.watch:
        watch(prompt_policies=True)
        raise SystemExit(0)

    def get_test_all_policies_fn(policy, args):
        """Build a ``test_fn`` that evaluates each policy from ``fetch_policies``."""
        def test_fn(epoch, env_step):
            nonlocal eval_results

            # Fresh test envs without a time limit, configured like the
            # training envs. The watch env is never used here, so skip it.
            _, _, envs, _ = make_env(
                task=args.task,
                seed=args.seed,
                training_num=args.training_num,
                test_num=args.test_num,
                create_watch_env=False,
                params=args.task_params,
                wrappers=args.task_wrapper,
                disable_timelimit=True
            )
            try:
                envs.seed(args.seed)
                collector = MPECCollector(policy, envs, exploration_noise=True, seed=args.seed, MPEC_DICT=MPEC_DICT)

                policy.set_eps(args.eps_test)
                n_step_test = 1000

                # NOTE(review): hard-coded 2-element argument — presumably one
                # entry per objective; confirm for tasks with more objectives.
                non_dominated, values = policy.fetch_policies(np.array([0, 0]))

                for sapi, return_ in zip(non_dominated, values):
                    n_step = policy.set_n_step_test(n_step_test, return_)
                    collector.collect(n_step=n_step, reset_before_collect=True, chosen_policy=sapi, chosen_return=return_)
                eval_results = deepcopy(collector.policy_eval_results)
            finally:
                # Envs are rebuilt on every call; release them each time.
                envs.close()

        return test_fn

    if any([args.as1, args.as6, args.as7]):
        test_fn = get_test_all_policies_fn(policy, args)
    else:
        test_fn = get_test_fn(policy, args)

    # trainer
    trainer = OffpolicyTrainer(
        policy=policy,
        max_epoch=args.epoch,
        batch_size=args.batch_size,
        train_collector=train_collector,
        test_collector=test_collector,
        step_per_epoch=args.step_per_epoch,
        episode_per_test=args.test_num,
        update_per_step=args.update_per_step,
        step_per_collect=args.step_per_collect,
        train_fn=get_train_fn(policy, args, logger, buffer),
        test_fn=test_fn,
        stop_fn=get_stop_fn(policy, log_path),
        save_best_fn=get_save_best_and_memory_fn(log_path),
        save_checkpoint_fn=get_save_checkpoint_fn(policy, log_path),
        resume_from_log=args.resume_id is not None,
        logger=logger,
        test_in_train=False,
    )

    result = trainer.run()
    pprint.pprint(result)

    watch()
    save_results()
    save_debug_results()

def get_train_fn(policy, args, logger, buffer=None):
    """Return a trainer ``train_fn`` that applies exponentially decaying epsilon.

    For a given ``epoch`` the exploration rate is
    ``max(args.eps_train_final, args.eps_train * args.eps_decay ** (epoch - 1))``.
    It is set on ``policy`` and also written to ``logger`` whenever
    ``env_step`` is a multiple of 1000.

    ``buffer`` is accepted for signature compatibility and is not used.
    """
    def train_fn(epoch, env_step):
        decayed = args.eps_train * args.eps_decay ** (epoch - 1)
        eps = max(args.eps_train_final, decayed)
        policy.set_eps(eps)

        if not env_step % 1000:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    return train_fn

def get_save_checkpoint_fn(policy, log_path):
    """Return a trainer ``save_checkpoint_fn`` bound to ``policy`` and ``log_path``.

    Each call creates ``<log_path>/checkpoint_<epoch>/``. It saves
    ``{"model": policy.state_dict()}`` there as ``checkpoint_<epoch>.pth``,
    stores the policy memories in the same directory, and returns the
    directory path.
    """
    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        checkpoint_dir = os.path.join(log_path, f"checkpoint_{epoch}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        weights_file = os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pth")
        torch.save({"model": policy.state_dict()}, weights_file)
        policy.save_memories(checkpoint_dir)

        return checkpoint_dir

    return save_checkpoint_fn

def get_stop_fn(policy, log_path):
    """Return a trainer ``stop_fn`` that never stops training early.

    As a side effect every call snapshots ``policy`` into ``log_path``
    (``policy.pth`` plus the policy memories, via
    ``get_save_best_and_memory_fn``). The mean reward is ignored.
    """
    persist_policy = get_save_best_and_memory_fn(log_path)

    def stop_fn(mean_rewards):
        persist_policy(policy)
        return False

    return stop_fn

def get_save_best_and_memory_fn(log_path):
    """Return a callable that persists a policy into ``log_path``.

    The callable writes ``policy.state_dict()`` to ``<log_path>/policy.pth``
    (overwriting any previous file), then saves the policy memories into
    ``log_path``. It returns ``None``.
    """
    weights_file = os.path.join(log_path, "policy.pth")

    def save_best_and_memory_fn(policy):
        torch.save(policy.state_dict(), weights_file)
        policy.save_memories(log_path)

    return save_best_and_memory_fn
