import argparse
import logging.config
import os
import sys

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from core.test import test
from core.train import train
from core.utils import init_logger, make_results_dir
from core.env import EnvBatcher

def build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with every option of the script registered.
        ``--env``, ``--case`` and ``--opr`` are mandatory.
    """
    parser = argparse.ArgumentParser(description='Inducing Search in Dreamer')

    # --- experiment selection / output locations
    parser.add_argument('--env', required=True, help='Name of the environment')
    parser.add_argument('--result_dir', default=os.path.join(os.getcwd(), 'results'),
                        help="Directory Path to store results (default: %(default)s)")
    parser.add_argument('--wandb_dir', default=os.path.join(os.getcwd(), 'wandb'),
                        help="Directory Path to store wandb files (default: %(default)s)")
    # required option: there is no default to advertise in the help text
    parser.add_argument('--case', required=True, choices=['dm_control', 'box2d', 'classic_control'],
                        help="It's used for switching between different domains")
    parser.add_argument('--opr', required=True, choices=['train', 'test'])

    # --- search / update modes
    parser.add_argument('--test_search_mode', choices=['no-search', 'rollout', 'mcts', 'mcts+fixed'], default='mcts',
                        help="It's used for switching between different test modes(default: %(default)s)")
    parser.add_argument('--explore_mode', choices=['no-search', 'rollout', 'mcts', 'mcts+fixed'], default='mcts',
                        help='perform search during exploration (default: %(default)s)')
    parser.add_argument('--update_mode', choices=['separate', 'together', 'together_with_grad_flow'],
                        default='together',
                        help='Update mode (default: %(default)s)')
    parser.add_argument('--anneal_update_itr', action='store_true', default=False,
                        help='Anneal update iterations (default: %(default)s)')
    parser.add_argument('--uniform_action_sample', type=int, default=50, metavar='C',
                        help='Uniform action sample (default: %(default)s)')
    parser.add_argument('--optimize_with_search', action='store_true', default=False,
                        help='optimize policy using search (default: %(default)s)')

    # --- hardware
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='no cuda usage (default: %(default)s)')

    # --- training hyper-parameters
    parser.add_argument('--automatic_entropy_tuning', action='store_true', default=False,
                        help='Automatic entropy tuning (default: %(default)s)')
    parser.add_argument('--action-noise', type=float, default=0.3, metavar='ε',
                        help='Action noise  (default: %(default)s)')
    parser.add_argument('--collect_interval', type=int, default=100, metavar='C',
                        help='Collect interval  (default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=50, metavar='B',
                        help='Batch size (default: %(default)s)')
    parser.add_argument('--chunk_size', type=int, default=50, metavar='L',
                        help='Chunk size (default: %(default)s)')
    parser.add_argument('--global-kl-beta', type=float, default=0, metavar='βg',
                        help='Global KL weight (0 to disable)')
    parser.add_argument('--free-nats', type=float, default=3, metavar='F',
                        help='Free nats (default: %(default)s)')
    parser.add_argument('--dynamics_lr', type=float, default=1e-3, metavar='α',
                        help='Learning rate (default: %(default)s)')
    parser.add_argument('--actor_lr', type=float, default=8e-5, metavar='α',
                        help='Learning rate (default: %(default)s)')
    parser.add_argument('--value_lr', type=float, default=8e-5, metavar='α',
                        help='Learning rate (default: %(default)s)')
    parser.add_argument('--grad-clip-norm', type=float, default=100.0, metavar='C',
                        help='Gradient clipping norm')
    parser.add_argument('--planning-horizon', type=int, default=15, metavar='H',
                        help='Planning horizon distance')
    parser.add_argument('--discount', type=float, default=0.99, metavar='H',
                        help='Discount factor (default: %(default)s)')
    parser.add_argument('--disclam', type=float, default=0.95, metavar='H',
                        help='discount rate to compute return')
    parser.add_argument('--repeat_entropy_coeff', type=float, default=0.01, metavar='H',
                        help='Repeat entropy coefficient (default: %(default)s)')
    parser.add_argument('--actor_entropy_coeff', type=float, default=0.01, metavar='H',
                        help='Actor entropy coefficient (default: %(default)s)')
    parser.add_argument('--action_repeat_set', type=str,
                        help='list of action repeats separated by commas (eg: "2,3,4")')

    # --- evaluation / misc
    parser.add_argument('--test_episodes', type=int, default=1,
                        help='Evaluation episode count (default: %(default)s)')
    parser.add_argument('--render', action='store_true', default=False,
                        help='Renders the environment (default: %(default)s)')
    parser.add_argument('--force', action='store_true', default=False,
                        help='Overrides past results (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0, help='seed (default: %(default)s)')
    parser.add_argument('--use_wandb', action='store_true', default=False,
                        help='Use Weight and bias visualization lib for logging. (default: %(default)s)')
    return parser


def main(argv=None):
    """Parse the command line and run the requested operation (train or test).

    Args:
        argv: optional list of argument strings; when ``None`` the arguments
            are taken from ``sys.argv[1:]``.

    Returns:
        int: process exit status, ``0`` on success and ``1`` when the
        train/test operation raised an exception (the traceback is logged).
    """
    args = build_parser().parse_args(argv)
    use_gpu = (not args.no_cuda) and torch.cuda.is_available()
    args.device = 'cuda' if use_gpu else 'cpu'

    # seeding random iterators
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # import corresponding configuration , neural networks and envs
    if args.case == 'classic_control':
        from config.classic_control import run_config
    elif args.case == 'box2d':
        from config.box2d import run_config
    elif args.case == 'dm_control':
        from config.dm_control import run_config
    else:
        raise Exception('Invalid --case option.')

    # set config as per cmd arguments
    run_config.set_config(args)
    log_base_path = make_results_dir(run_config.exp_path, args)

    # set-up logger and record the command line that started this run
    init_logger(log_base_path)
    cli_args = sys.argv[1:] if argv is None else list(argv)
    logging.getLogger('root').info('cmd args:{}'.format(' '.join(cli_args)))

    exit_code = 0
    try:
        if args.opr == 'train':
            if args.use_wandb:
                import wandb

                wandb.init(dir=args.wandb_dir, group=args.case + ':' + args.env, project="dreamer-pytorch",
                           config=run_config.get_hparams(), sync_tensorboard=True)

            summary_writer = SummaryWriter(run_config.exp_path, flush_secs=60 * 1)  # flush every 1 minutes
            try:
                train(run_config, summary_writer)
            finally:
                # persist whatever was written, even when training aborts
                summary_writer.flush()
                summary_writer.close()

            if args.use_wandb:
                wandb.join()
        elif args.opr == 'test':
            model_path = run_config.model_path
            # explicit check instead of `assert`, which is stripped under `python -O`
            if not os.path.exists(model_path):
                raise FileNotFoundError('model not found: {}'.format(model_path))

            model = run_config.get_uniform_network()
            model = model.to('cpu')
            # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            env_batch = EnvBatcher(run_config.new_game, args.test_episodes)
            try:
                test_score, _ = test(env_batch, model, render=args.render, save_video=True,
                                     save_test_data=True, save_path=run_config.test_data_path,
                                     recording_path=run_config.recording_path,
                                     config=run_config, mode=args.test_search_mode,
                                     mcts_num_simulations=run_config.num_simulations)
            finally:
                # release the environments even when evaluation fails
                env_batch.close()
            logging.getLogger('test').info('Test Score: {}'.format(test_score))
        else:
            raise NotImplementedError('"--opr {}" is not implemented ( or invalid)'.format(args.opr))

    except Exception as e:
        # Top-level boundary: keep the traceback in the log and report the
        # failure through the exit status instead of silently exiting with 0.
        logging.getLogger('root').error(e, exc_info=True)
        exit_code = 1
    finally:
        logging.shutdown()

    return exit_code


if __name__ == '__main__':
    sys.exit(main())
