import argparse
import datetime
import logging.config
import os
import time
import wandb
import socket

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp

from core.test import test
from core.train import initialize_trainer, initialize_worker
from core.utils import init_logger, make_results_dir, set_seed


def DDP_setup(rank, world_size, master_addr='127.0.0.1', master_port='12355', backend='nccl'):
    """Join the default ``torch.distributed`` process group as ``rank`` of ``world_size``.

    The group is initialised through the ``env://`` rendezvous, i.e. from the
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables, which are (over)written
    here. Every rank must therefore be called with the same address and port.

    Args:
        rank: rank of this process, in ``[0, world_size)``.
        world_size: total number of processes in the group.
        master_addr: host of the rank-0 rendezvous; the loopback default means all
            ranks run on this machine.
        master_port: TCP port used for the rendezvous (must be free on ``master_addr``).
        backend: ``torch.distributed`` backend. ``'nccl'`` (default) requires CUDA;
            ``'gloo'`` also works on CPU-only hosts.
    """
    # Rendezvous location for the master node; read by init_process_group's env:// init.
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)

    # Initialize the process group (blocks until all world_size ranks have joined).
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def _build_parser():
    """Return the ``argparse`` parser holding every SpeedyZero command-line option."""
    parser = argparse.ArgumentParser(description='SpeedyZero')
    parser.add_argument('--env', required=True, help='Name of the environment')
    parser.add_argument('--result_dir', default=os.path.join(os.getcwd(), 'results'),
                        help="Directory Path to store results (default: %(default)s)")
    parser.add_argument('--case', required=True, choices=['atari'],
                        help="It's used for switching between different domains(default: %(default)s)")
    parser.add_argument('--opr', required=True, choices=['train', 'test', 'worker', 'debug'])
    parser.add_argument('--amp_type', required=True, choices=['torch_amp', 'none'],
                        help='choose automated mixed precision type')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='no cuda usage (default: %(default)s)')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='If enabled, logs additional values  '
                             '(gradients, target value, reward distribution, etc.) (default: %(default)s)')
    parser.add_argument('--render', action='store_true', default=False,
                        help='Renders the environment (default: %(default)s)')
    parser.add_argument('--save_video', action='store_true', default=False, help='save video in test.')
    parser.add_argument('--force', action='store_true', default=False,
                        help='Overrides past results (default: %(default)s)')
    parser.add_argument('--cpu_actor', type=int, default=14, help='batch cpu actor')
    parser.add_argument('--gpu_actor', type=int, default=20, help='batch gpu actor')
    parser.add_argument('--master_cpu_actor', type=int, default=14, help='batch cpu actor in master')
    parser.add_argument('--master_gpu_actor', type=int, default=20, help='batch gpu actor in master')
    parser.add_argument('--p_mcts_num', type=int, default=8, help='number of parallel mcts')
    parser.add_argument('--seed', type=int, default=0, help='seed (default: %(default)s)')
    parser.add_argument('--num_gpus', type=int, default=4, help='gpus available')
    parser.add_argument('--num_cpus', type=int, default=80, help='cpus available')
    parser.add_argument('--priority_updater', type=int, default=8, help='update priority')
    parser.add_argument('--value_updater', type=int, default=8, help='update value and priority')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size for each trainer')
    parser.add_argument('--eff_batch_size', type=int, default=512, help='effective batch size')
    parser.add_argument('--revisit_policy_search_rate', type=float, default=0.99,
                        help='Rate at which target policy is re-estimated (default: %(default)s)')
    parser.add_argument('--use_root_value', action='store_true', default=False,
                        help='choose to use root value in reanalyzing')
    parser.add_argument('--use_priority', action='store_true', default=False,
                        help='Uses priority for data sampling in replay buffer. '
                             'Also, priority for new data is calculated based on loss (default: False)')
    parser.add_argument('--use_max_priority', action='store_true', default=False, help='max priority')
    parser.add_argument('--test_episodes', type=int, default=10, help='Evaluation episode count (default: %(default)s)')
    parser.add_argument('--use_augmentation', action='store_true', default=True, help='use augmentation')
    # --use_augmentation already defaults to True, so on its own it can never turn
    # augmentation off; --no_augmentation writes False into the same destination.
    parser.add_argument('--no_augmentation', dest='use_augmentation', action='store_false',
                        help='disable augmentation (overrides the default --use_augmentation)')
    parser.add_argument('--augmentation', type=str, default=['shift', 'intensity'], nargs='+',
                        choices=['none', 'rrc', 'affine', 'crop', 'blur', 'shift', 'intensity'],
                        help='Style of augmentation')
    parser.add_argument('--info', type=str, default='none', help='debug string')
    parser.add_argument('--load_model', action='store_true', default=False, help='choose to load model')
    parser.add_argument('--model_path', type=str, default='./results/test_model.p', help='load model path')
    parser.add_argument('--object_store_memory', type=int, default=150 * 1024 * 1024 * 1024, help='object store memory')
    # System parameters: --local_rank is overwritten by system_setup; --world_size
    # (number of trainer processes) must be given explicitly for --opr train.
    parser.add_argument('--world_size', type=int, default=-1, help='world size')
    parser.add_argument('--local_rank', type=int, default=-1, help='node rank for distributed training')
    parser.add_argument('--worker_node_id', type=int, default=-1, help='worker node id')
    parser.add_argument('--num_worker_nodes', type=int, default=0, help='number of worker nodes')
    parser.add_argument('--wandb_tags', nargs='+', help="wandb tags to your experiment", default=[])
    parser.add_argument('--run_id', default='', type=str)
    parser.add_argument('--num_test_procs', type=int, default=1, help='number of test processes')
    return parser


def main(argv=None):
    """Parse the command line and launch the operation selected by ``--opr``.

    ``worker``, ``test`` and ``debug`` run in this process; ``train`` spawns
    ``--world_size`` trainer processes with ``torch.multiprocessing.spawn``.

    Args:
        argv: argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Raises:
        SystemExit: on invalid arguments (via ``argparse``), including
            ``--opr train`` without a positive ``--world_size``.
    """
    parser = _build_parser()
    input_args = parser.parse_args(argv)
    if input_args.opr in ('worker', 'test', 'debug'):
        system_setup(rank=-1, args=input_args)
    elif input_args.opr == 'train':
        # With nprocs < 1, mp.spawn would start no trainer at all and return quietly,
        # so reject the (default) unset world size up front.
        if input_args.world_size < 1:
            parser.error('--opr train requires --world_size >= 1 (got {})'.format(input_args.world_size))
        mp.spawn(system_setup, args=(input_args,), nprocs=input_args.world_size)
    else:
        raise NotImplementedError

def system_setup(rank, args):
    """Prepare this process (seed, config, DDP, logging) and run the ``--opr`` operation.

    ``main`` calls this once with ``rank=-1`` for ``worker``/``test``/``debug``;
    ``mp.spawn`` calls it once per trainer rank for ``train``.

    Args:
        rank: local rank of this trainer process, or -1 when not training.
        args: namespace built by ``main``'s parser. ``args.local_rank`` and
            ``args.device`` are written back onto it.

    Raises:
        ValueError: ``--revisit_policy_search_rate`` lies outside [0, 1].
        Exception: ``--case`` is unsupported, or the selected operation failed.
            Operation errors are logged on the ``'root'`` logger and then re-raised,
            so a failing trainer rank makes ``mp.spawn`` (and the process exit code)
            report the failure instead of exiting cleanly while peers wait on it.
    """
    args.local_rank = rank
    args.device = 'cuda' if (not args.no_cuda) and torch.cuda.is_available() else 'cpu'
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    rate = args.revisit_policy_search_rate
    if rate is not None and not 0 <= rate <= 1:
        raise ValueError(' Revisit policy search rate should be in [0,1]')

    is_train = args.opr == 'train'
    if is_train:
        DDP_setup(rank=args.local_rank, world_size=args.world_size)
        print(f"[main process] rank {args.local_rank} trainer has been initialized")
        # NOTE(review): torch.cuda.is_available() has already run in this process, so this
        # may not restrict which devices torch sees here -- confirm against the trainer code.
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.local_rank}"
        # Per-rank seed offset so trainer processes draw different random streams.
        set_seed(args.seed + args.local_rank)
    else:
        set_seed(args.seed)

    game_config = _load_game_config(args.case)
    exp_path = game_config.set_config(args)
    # Exactly one process creates the results directory and log files: rank 0 when
    # training, otherwise the single process running this function.
    if not is_train or args.local_rank == 0:
        exp_path, log_base_path = make_results_dir(exp_path, args)
        init_logger(log_base_path)
        logging.getLogger('train').info('Path: {}'.format(exp_path))
        logging.getLogger('train').info('Param: {}'.format(game_config.get_hparams()))

    device = game_config.device
    try:
        if is_train:
            _run_train(args, game_config, exp_path, device)
        elif args.opr == 'worker':
            initialize_worker(config=game_config, exp_path=exp_path, model_path=_resolve_model_path(args))
        elif args.opr == 'test':
            _run_test(args, game_config, device)
        elif args.opr == 'debug':
            _run_debug(game_config)
        else:
            raise Exception('Please select a valid operation(--opr) to be performed')
    except Exception as e:
        logging.getLogger('root').error(e, exc_info=True)
        raise


def _load_game_config(case):
    """Import and return the ``game_config`` object for the ``--case`` domain."""
    if case == 'atari':
        from config.atari import game_config
        return game_config
    raise Exception('Invalid --case option')


def _resolve_model_path(args):
    """Return ``args.model_path`` if ``--load_model`` is set and the file exists, else ``None``."""
    if args.load_model and os.path.exists(args.model_path):
        return args.model_path
    return None


def _strip_ddp_prefix(state_dict):
    """Return a copy of ``state_dict`` with leading ``'module.'`` prefixes removed from its keys.

    DistributedDataParallel stores parameters as ``module.<name>``. Only leading
    prefixes are stripped, so an inner name such as ``submodule.weight`` is kept intact.
    """
    prefix = 'module.'
    stripped = {}
    for key, value in state_dict.items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def _episode_shards(test_episodes, num_procs):
    """Split 1-based episode ids ``1..test_episodes`` round-robin over ``num_procs`` processes."""
    return [np.arange(i, test_episodes + 1, num_procs) for i in range(1, num_procs + 1)]


def _run_train(args, game_config, exp_path, device):
    """Train on this DDP rank; rank 0 also owns the wandb run and the final evaluation.

    ``terminate_all`` (returned by ``initialize_trainer``) and the wandb run are
    released in ``finally`` so they are cleaned up even if training or testing raises.
    """
    wandb_run = None
    terminate_all = None
    try:
        if args.local_rank == 0:
            localtime = datetime.datetime.now().strftime("%m-%d-%H")
            wandb_run = wandb.init(
                config=game_config,
                project='Atari-EfficientZero',
                entity='speedyzero',
                notes=socket.gethostname(),
                name='test',
                group=f"{args.env}-bs{args.eff_batch_size}-{args.info}-seed{args.seed}-time{localtime}-run{game_config.run_id}",
                dir=exp_path,
                job_type="train",
                reinit=True,
                tags=args.wandb_tags,
                mode='online'
            )
        model, weights, terminate_all = initialize_trainer(config=game_config, model_path=_resolve_model_path(args),
                                                           local_rank=args.local_rank)
        if args.local_rank == 0:
            _final_train_test(args, game_config, model, weights, device, wandb_run)
    finally:
        if terminate_all is not None:
            terminate_all()
        if wandb_run is not None:
            wandb_run.finish()


def _final_train_test(args, game_config, model, weights, device, wandb_run):
    """Evaluate the final trainer weights and log scores to wandb and the 'train_test' logger."""
    model.set_weights(weights)
    total_steps = game_config.training_steps + game_config.last_steps
    test_score, test_path = test(game_config, model.to(device), total_steps, game_config.test_episodes,
                                 device, render=False, save_video=args.save_video, final_test=True,
                                 use_pb=True)
    mean_score = test_score.mean()
    std_score = test_score.std()

    test_log = {
        'mean_score': mean_score,
        'std_score': std_score,
    }
    for key, val in test_log.items():
        wandb_run.log({'train/{}'.format(key): np.mean(val)}, step=total_steps)

    print("Test Done")
    test_msg = '#{:<10} Test Mean Score of {}: {:<10} (max: {:<10}, min:{:<10}, std: {:<10})' \
               ''.format(total_steps, game_config.env_name, mean_score, test_score.max(), test_score.min(),
                         std_score)
    logging.getLogger('train_test').info(test_msg)
    if args.save_video:
        logging.getLogger('train_test').info('Saving video in path: {}'.format(test_path))


def _run_test(args, game_config, device):
    """Load a checkpoint and evaluate it in this process or across ``--num_test_procs`` processes.

    Raises:
        ValueError: ``--load_model`` was not given.
        FileNotFoundError: the resolved checkpoint path does not exist.
    """
    if not args.load_model:
        raise ValueError('--opr test requires --load_model')
    model_path = game_config.model_path if args.model_path is None else args.model_path
    print(model_path)
    if not os.path.exists(model_path):
        raise FileNotFoundError('model not found at {}'.format(model_path))

    model = game_config.get_uniform_network(is_trainer=True).to(device)
    # NOTE(review): target_model is never used after loading; it is kept so network
    # construction is unchanged from the previous version of this code.
    target_model = game_config.get_uniform_network(is_trainer=True).to(device)
    ckpt = _strip_ddp_prefix(torch.load(model_path, map_location=torch.device(device)))
    model.load_state_dict(ckpt)
    target_model.load_state_dict(ckpt)

    if args.num_test_procs == 1:
        test_score, test_path = test(game_config, model, 0, args.test_episodes, device=device, render=args.render,
                                     save_video=args.save_video, final_test=True, use_pb=True)
        logger = logging.getLogger('test')
        logger.info('Test Mean Score: {} (max: {}, min: {})'.format(test_score.mean(), test_score.max(), test_score.min()))
        logger.info('Test Std Score: {}'.format(test_score.std()))
        if args.save_video:
            logger.info('Saving video in path: {}'.format(test_path))
    else:
        _run_parallel_test(args, game_config)


def _run_parallel_test(args, game_config):
    """Spread evaluation episodes over ``--num_test_procs`` processes and log results as they arrive.

    Raises:
        RuntimeError: every test process exited before all of them reported results.
    """
    import multiprocessing
    from core.storage_config import StorageConfig
    from core.shared_storage import get_shared_storage
    from core.test import start_test

    # Storage config and a 'spawn' multiprocessing context for the test workers.
    storage_config = StorageConfig(label="test")
    ctx = multiprocessing.get_context('spawn')

    # Workers: process i (1-based) evaluates episode ids i, i + n, i + 2n, ...
    # NOTE(review): the trailing `i % 4` looks like a device index assuming 4 GPUs -- confirm.
    test_processes = []
    for i, episodes in enumerate(_episode_shards(args.test_episodes, args.num_test_procs), start=1):
        test_processes.append(ctx.Process(
            target=start_test,
            args=(game_config, storage_config, True, episodes, f'test-p{i}', (i == 1), len(episodes), i % 4)))
    for test_process in test_processes:
        test_process.start()
        time.sleep(0.1)
    print("[main process] Test processes have all been launched.")

    # Logging: poll the shared storage until every worker has reported.
    shared_storage = get_shared_storage(storage_config=storage_config)
    logger = logging.getLogger('test')
    exited_last_poll = False
    while True:
        # Sample liveness before reading results, so a worker that reports and then exits
        # between the two calls is not counted as missing.
        all_exited = not any(p.is_alive() for p in test_processes)
        test_dict_log = shared_storage.get_test_dict_log()
        logger.info('Test finish: {}'.format(test_dict_log.keys()))
        if len(test_dict_log.keys()) > 0:
            test_score = np.concatenate([d['all_score'] for d in test_dict_log.values()], axis=0)
            logger.info('Test Mean Score: {} (max: {}, min: {})'.format(test_score.mean(), test_score.max(), test_score.min()))
            logger.info('Test Std Score: {}'.format(test_score.std()))
        if len(test_dict_log.keys()) == args.num_test_procs:
            break
        # Without this check the loop would wait forever if a worker crashed before
        # reporting. One extra poll after all workers exit gives late writes time to land.
        if all_exited and exited_last_poll:
            raise RuntimeError('all {} test processes exited but only {} reported results'
                               .format(args.num_test_procs, len(test_dict_log)))
        exited_last_poll = all_exited
        time.sleep(5)


def _run_debug(game_config):
    """Play 100 random-action episodes and print length and sign-clipped reward stats for each."""
    env = game_config.new_game(seed=0)
    for _ in range(100):
        env.reset()
        done = False
        rewards = []
        while not done:
            _, rew, done, _ = env.step(np.random.randint(0, env.action_space_size))
            rewards.append(np.sign(rew))
        rewards = np.asarray(rewards)
        # Columns: episode length, summed clipped reward, #positive rewards, #negative rewards.
        print(len(rewards), rewards.sum(), (rewards > 0).astype(np.int32).sum(), (rewards < 0).astype(np.int32).sum())


# Script entry point. The guard is required: torch.multiprocessing.spawn (used for
# --opr train) and the 'spawn' multiprocessing context (used for parallel testing)
# start child processes that re-import this module, and must not re-run main() there.
if __name__ == '__main__':
    main()
