import sys
import os
import time
import timeit
import logging
from arguments import parser

import torch
import gym
import matplotlib as mpl
# mpl.use("macOSX")
import matplotlib.pyplot as plt

from baselines.logger import HumanOutputFormat

import minihack
from envs.multigrid import *
from envs.minihack import *
from envs.runners.runner import Runner 
from util import make_agent, FileWriter, safe_checkpoint, create_parallel_env, make_plr_args, seed
from eval import Evaluator


if __name__ == '__main__':
    # Cap OpenMP at a single thread before any heavy work starts.
    os.environ["OMP_NUM_THREADS"] = "1"

    args = parser.parse_args()
    seed(args.seed)

    # === Configure logging ===
    # Fall back to a timestamped experiment id when none was supplied.
    if args.xpid is None:
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        args.xpid = "lr-%s" % timestamp

    # Expand "~" first, then environment variables, in the log root.
    log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))

    # The parsed arguments are handed to the writer as a plain dict.
    filewriter = FileWriter(xpid=args.xpid, xp_args=vars(args), rootdir=log_dir)

    # Screenshots are stored under <log_dir>/<xpid>/screenshots.
    screenshot_dir = os.path.join(os.path.join(log_dir, args.xpid), 'screenshots')
    if not os.path.exists(screenshot_dir):
        os.makedirs(screenshot_dir, exist_ok=True)

    def log_stats(stats):
        """Record one row of statistics.

        The stats are always passed to ``filewriter.log``; when
        ``args.verbose`` is set they are additionally written to stdout via
        ``HumanOutputFormat``.
        """
        filewriter.log(stats)
        if not args.verbose:
            return
        stdout_writer = HumanOutputFormat(sys.stdout)
        stdout_writer.writekvs(stats)

    # Quiet runs disable every record at CRITICAL severity and below;
    # verbose runs lower the root logger threshold to INFO instead.
    if not args.verbose:
        logging.disable(logging.CRITICAL)
    else:
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)

    # === Determine device ====
    # CUDA is used only when it was not disabled by flag and is available.
    args.cuda = (not args.no_cuda) and torch.cuda.is_available()
    if args.cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if 'cuda' in device.type:
        # Enable the cuDNN benchmark mode when running on GPU.
        torch.backends.cudnn.benchmark = True
        print('Using CUDA\n')

    # === Create parallel envs ===
    venv = create_parallel_env(args)

    agent = make_agent(name='agent', env=venv, args=args, device=device)

    # === Create runner ===
    # PLR arguments are built only when PLR is enabled; otherwise the runner
    # receives None.
    plr_args = (
        make_plr_args(args, venv.observation_space, venv.action_space)
        if args.use_plr
        else None
    )
    runner_kwargs = {
        'args': args,
        'venv': venv,
        'agent': agent,
        'train': True,
        'plr_args': plr_args,
        'device': device,
    }
    train_runner = Runner(**runner_kwargs)

    # === Configure checkpointing ===
    timer = timeit.default_timer

    # Bookkeeping that the checkpoint-loading and training code updates later.
    initial_update_count = 0
    last_logged_update_at_restart = -1
    last_checkpoint_time = None

    # The model file sits directly under <log_dir>/<xpid>/.
    raw_checkpoint_path = "%s/%s/%s" % (log_dir, args.xpid, "model.tar")
    checkpoint_path = os.path.expandvars(os.path.expanduser(raw_checkpoint_path))

    def checkpoint(index=None):
        """Save the training runner's state to ``checkpoint_path``.

        Does nothing when ``args.disable_checkpoint`` is set. ``index`` and
        ``args.archive_interval`` are forwarded unchanged to
        ``safe_checkpoint``.
        """
        if not args.disable_checkpoint:
            runner_state = {'runner_state_dict': train_runner.state_dict()}
            safe_checkpoint(
                runner_state,
                checkpoint_path,
                index=index,
                archive_interval=args.archive_interval,
            )
            logging.info("Saved checkpoint to %s", checkpoint_path)


    # === Load checkpoint ===
    # Resume only when requested and a checkpoint file is actually present.
    if args.checkpoint and os.path.exists(checkpoint_path):
        # A checkpoint holding CUDA tensors cannot be deserialized on a host
        # where CUDA is not in use unless its storages are remapped to the
        # CPU. When CUDA is in use, keep torch.load's default placement.
        checkpoint_map_location = None if args.cuda else 'cpu'
        checkpoint_states = torch.load(
            checkpoint_path, map_location=checkpoint_map_location)
        last_logged_update_at_restart = filewriter.latest_tick() # ticks are 0-indexed updates
        train_runner.load_state_dict(checkpoint_states['runner_state_dict'])
        # Training continues from the update count stored in the runner.
        initial_update_count = train_runner.num_updates
        logging.info(f"Resuming preempted job after {initial_update_count} updates\n") # 0-indexed next update

    # Set up Evaluator
    evaluator = None
    test_env_names = args.test_env_names.split(',')
    # str.split always returns at least one element, so a non-empty argument
    # string is the only condition that decides whether to evaluate.
    if args.test_env_names:
        first_test_env = test_env_names[0]

        # NOTE(review): eval() runs arbitrary expressions taken from the
        # command line (here and for the stochastic-choice options below);
        # only launch this script with trusted arguments.
        test_env_kwargs = {}
        test_env_kwargs['p'] = eval(args.p)
        test_env_kwargs['reward_dist'] = args.reward_dist

        # Assumes all test envs belong to the same environment family
        is_multigrid = 'MultiGrid' in first_test_env
        is_minihack = 'MiniHack' in first_test_env

        if is_minihack:
            test_env_kwargs['observation_keys'] = ("glyphs", "blstats", "message")
            test_env_kwargs['fully_observable'] = args.fully_observable

        # Binary-choice style test envs take explicit reward settings.
        if any(tag in first_test_env for tag in ('MultiRoomBC', 'BinaryChoice')):
            test_env_kwargs['rewards'] = eval(args.stochastic_choice_rewards)
            test_env_kwargs['reward_spreads'] = eval(args.stochastic_choice_reward_spreads)

        # The remaining options are keyed on the *training* env name.
        if is_multigrid and 'BinaryChoice' in args.env_name:
            test_env_kwargs['use_walls'] = args.stochastic_choice_use_walls

        if is_minihack and 'BinaryChoice' in args.env_name:
            test_env_kwargs['goal_hint_p'] = args.goal_hint_p

        evaluator = Evaluator(
            test_env_names,
            num_processes=args.test_num_processes,
            num_episodes=args.test_num_episodes,
            device=device,
            **test_env_kwargs,
        )

    # === Train ===
    update_start_time = timer()
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes
    # Index of the last update covered by an sps measurement. Starting one
    # before the first update makes the first log count exactly the updates
    # that ran, both for fresh runs and for resumed ones.
    last_sps_update = initial_update_count - 1
    for j in range(initial_update_count, num_updates):
        stats = train_runner.run()

        # === Perform logging ===
        # After a restart, updates at or below the last tick already recorded
        # by the file writer are re-run but skipped here (this also skips
        # checkpointing and screenshots for those updates).
        if train_runner.num_updates <= last_logged_update_at_restart:
            continue

        is_last_update = j == num_updates - 1
        log = (j % args.log_interval == 0) or is_last_update
        save_screenshot = \
            args.screenshot_interval > 0 and \
                (j % args.screenshot_interval == 0)

        if log:
            # Eval
            test_stats = {}
            if evaluator is not None:
                if j % args.test_interval == 0 or is_last_update:
                    test_stats = evaluator.evaluate(train_runner.agent)
                    stats.update(test_stats)
                else:
                    # Keep the test columns present (as None) on rows where
                    # no evaluation ran.
                    stats.update({k: None for k in evaluator.get_stats_keys()})

            update_end_time = timer()
            # Count the updates that actually ran since the previous
            # measurement; the forced final log and the first log after a
            # resume usually cover fewer than log_interval updates.
            num_incremental_updates = j - last_sps_update
            last_sps_update = j
            sps = num_incremental_updates*(args.num_processes * args.num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            stats.update({'sps': sps})
            # NOTE(review): re-applying test_stats rewrites the same values
            # but does not move keys that are already in the dict, so the
            # test-stat keys inserted above stay ahead of 'sps'.
            stats.update(test_stats)
            log_stats(stats)

        if last_checkpoint_time is None:
            last_checkpoint_time = timer()

        # Save on the final update or once save_interval minutes have passed;
        # otherwise save when the update count hits an archive boundary.
        save_due = is_last_update or \
            (args.save_interval > 0 and timer() - last_checkpoint_time > args.save_interval * 60)
        archive_due = save_due or (
            train_runner.num_updates > 0 and args.archive_interval > 0
            and train_runner.num_updates % args.archive_interval == 0)
        if archive_due:
            checkpoint(train_runner.num_updates)
            last_checkpoint_time = timer()
            logging.info(f"\nSaved checkpoint after update {j}")

        if save_screenshot:
            # Only screenshot the first seed
            venv.reset_agent()

            if args.env_name.startswith('MiniHack'):
                # ASCII obs
                with open(os.path.join(screenshot_dir, f'update_{j}.txt'), 'w+') as fout:
                    fout.write(venv.get_grid_str()[0])

                # des file
                with open(os.path.join(screenshot_dir, f'update_{j}.des'), 'w+') as fout:
                    fout.write(venv.get_des_file()[0])

            images = venv.get_images()
            plt.axis('off')
            plt.imshow(images[0])
            plt.savefig(os.path.join(screenshot_dir, f'update_{j}.png'), bbox_inches='tight')
            plt.close()

    # No evaluator exists when no test environments were configured.
    if evaluator is not None:
        evaluator.close()
    venv.close()
