import os
from warnings import simplefilter  # noqa
from os.path import join as pJoin
from pathlib import Path

import gym
import numpy as np
from params_proto.neo_proto import PrefixProto
from tqdm import trange

from . import utils
from .algorithms.factory import make_agent
from .env_helpers import get_env
from tools.notifier.slack_sender import slack_sender

# Install a warnings filter that ignores DeprecationWarning from here on.
simplefilter(action='ignore', category=DeprecationWarning)
# Raise gym's logger threshold to reduce log noise.
# NOTE(review): 40 presumably corresponds to gym.logger.ERROR — confirm against the installed gym version.
gym.logger.set_level(40)


class Progress(PrefixProto, cli=False):
    # Training-progress counters. `main` updates them in place, persists them to
    # "progress.pkl" via `vars(Progress)` on every checkpoint, and restores them with
    # `Progress._update(...)` when resuming.
    step = 0  # training-loop index; `main` iterates `for Progress.step in range(...)`
    episode = 0  # incremented by `main` each time the training env is reset
    wall_time = 0  # set from `logger.since('start')` at the top of every step
    frame = 0  # set to `Progress.step * Args.action_repeat` at the top of every step


def save_final_checkpoint(agent):
    """Write the end-of-training artifacts and mark the job as completed.

    Two files are stored under ``Args.checkpoint_root / logger.prefix``:

    * ``snapshot.pt``    -- the whole ``agent`` object
    * ``actor_state.pt`` -- ``agent.actor.state_dict()`` only

    Afterwards the current ``Progress`` values are logged to ``progress.pkl``
    and ``logger.job_completed()`` is called.

    Args:
        agent: the trained agent; must expose an ``actor`` with ``state_dict()``.
    """
    from .config import Args
    from ml_logger import logger

    checkpoint_dir = pJoin(Args.checkpoint_root, logger.prefix)

    # Full agent object first, then the actor weights on their own.
    logger.save_torch(agent, path=pJoin(checkpoint_dir, 'snapshot.pt'))
    logger.save_torch(agent.actor.state_dict(),
                      path=pJoin(checkpoint_dir, 'actor_state.pt'))

    logger.log_params(Progress=vars(Progress), path="progress.pkl", silent=True)
    logger.job_completed()


def save_checkpoint(agent, replay):
    """Save an in-progress checkpoint: agent snapshot, metrics copy and replay buffer.

    Steps, in order:

    1. ``agent`` is saved with ``logger.save_torch`` to ``snapshot_in_progress.pt``.
    2. ``metrics.pkl`` is duplicated to ``metrics_latest.pkl``.
    3. ``replay`` is pickled with ``cloudpickle`` to ``replay.pt``; for ``s3://`` and
       ``gs://`` targets it is first written under ``Args.tmp_dir`` and then uploaded.
    4. ``progress.pkl`` is written last, so it only exists/updates once everything
       else has been saved.

    All targets live under ``Args.checkpoint_root / logger.prefix``.

    Args:
        agent: the agent object to snapshot.
        replay: the replay buffer object to pickle.

    Raises:
        ValueError: if the replay path does not start with ``file://``, ``s3://``
            or ``gs://``. The agent snapshot has already been written at that point,
            but ``progress.pkl`` is deliberately *not* updated.
    """
    from .config import Args
    from ml_logger import logger
    import cloudpickle

    logger.print('saving & uploading snapshot...')
    replay_path = pJoin(Args.checkpoint_root, logger.prefix, 'replay.pt')

    snapshot_target = pJoin(Args.checkpoint_root, logger.prefix, 'snapshot_in_progress.pt')
    logger.save_torch(agent, path=snapshot_target)

    logger.duplicate("metrics.pkl", "metrics_latest.pkl")
    logger.print('saving buffer to', replay_path)

    # NOTE: It seems tempfile cannot handle this 20GB+ file.
    logger.start('upload_replay_buffer')
    if replay_path.startswith('file://'):
        local_path = replay_path[len('file://'):]
        Path(local_path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(local_path, 'wb') as f:
            cloudpickle.dump(replay, f)
    elif replay_path.startswith(('s3://', 'gs://')):
        logger.print('uploading buffer to', replay_path)
        # Stage the pickle on local disk, then upload the staged file.
        tmp_path = Path(Args.tmp_dir) / logger.prefix / 'replay.pt'
        tmp_path.parent.mkdir(parents=True, exist_ok=True)
        with open(tmp_path, 'wb') as f:
            cloudpickle.dump(replay, f)

        # Both 's3://' and 'gs://' are five characters long.
        remote_key = replay_path[5:]
        if replay_path.startswith('s3://'):
            logger.upload_s3(str(tmp_path), path=remote_key)
        else:
            logger.upload_gs(str(tmp_path), path=remote_key)
    else:
        # Previously the exception was constructed but never raised, which silently
        # skipped the replay buffer and still recorded the checkpoint as complete.
        raise ValueError(
            f'replay_path must start with s3://, gs:// or file://. Not {replay_path}')

    elapsed = logger.since('upload_replay_buffer')
    logger.print(f'Uploading replay buffer took {elapsed} seconds')

    # Save the progress.pkl last as a fail-safe. To make sure the checkpoints are saving correctly.
    logger.log_params(Progress=vars(Progress), path="progress.pkl", silent=True)


def load_checkpoint():
    """Load the in-progress checkpoint stored under ``Args.checkpoint_root / logger.prefix``.

    Restores the agent from ``snapshot_in_progress.pt``, unpickles the replay buffer
    from ``replay.pt`` (downloading it to ``Args.tmp_dir`` first for ``s3://`` and
    ``gs://`` paths), copies ``metrics_latest.pkl`` back to ``metrics.pkl`` and reads
    ``progress.pkl``.

    Returns:
        tuple: ``(agent, replay, params)`` where ``params`` is whatever
        ``logger.read_params(path="progress.pkl")`` returns.

    Raises:
        AssertionError: if any of the snapshot, replay buffer, ``progress.pkl`` or
            ``metrics_latest.pkl`` cannot be found via ``logger.glob``.
        ValueError: if the replay path does not start with ``file://``, ``s3://``
            or ``gs://``.
    """
    from .config import Args
    from ml_logger import logger
    import cloudpickle
    import torch  # noqa: F401  (unused here; kept from the original — TODO confirm it can go)

    # TODO: check if both checkpoint & replay buffer exist
    snapshot_path = os.path.join(Args.checkpoint_root, logger.prefix, 'snapshot_in_progress.pt')
    replay_path = os.path.join(Args.checkpoint_root, logger.prefix, 'replay.pt')
    assert (logger.glob(snapshot_path) and logger.glob(replay_path)
            and logger.glob('progress.pkl') and logger.glob('metrics_latest.pkl')), \
        'incomplete checkpoint: snapshot, replay buffer, progress.pkl or metrics_latest.pkl is missing'

    logger.print('loading agent from', snapshot_path)
    agent = logger.load_torch(snapshot_path)

    # Load replay buffer
    logger.print('loading from checkpoint (replay)', replay_path)

    # NOTE: unpickling executes arbitrary code; only load checkpoints from trusted storage.
    # NOTE: It seems tempfile cannot handle this 20GB+ file.
    logger.start('download_replay_buffer')
    if replay_path.startswith('file://'):
        with open(replay_path[len('file://'):], 'rb') as f:
            replay = cloudpickle.load(f)
    elif replay_path.startswith(('s3://', 'gs://')):
        # Download to a local staging file, then unpickle from there.
        tmp_path = Path(Args.tmp_dir) / logger.prefix / 'replay.pt'
        tmp_path.parent.mkdir(parents=True, exist_ok=True)

        # Both 's3://' and 'gs://' are five characters long.
        remote_key = replay_path[5:]
        if replay_path.startswith('s3://'):
            logger.download_s3(path=remote_key, to=str(tmp_path))
        else:
            logger.download_gs(path=remote_key, to=str(tmp_path))
        with open(tmp_path, 'rb') as f:
            replay = cloudpickle.load(f)
    else:
        # Previously the exception was constructed but never raised, so execution fell
        # through and failed later with an UnboundLocalError on `replay`.
        raise ValueError(
            f'replay_path must start with s3://, gs:// or file://. Not {replay_path}')

    elapsed = logger.since('download_replay_buffer')
    logger.print(f'Download completed. It took {elapsed} seconds')

    logger.duplicate("metrics_latest.pkl", to="metrics.pkl")
    logger.print('done')

    params = logger.read_params(path="progress.pkl")
    return agent, replay, params


def evaluate(env, agent, num_episodes, video_path=None):
    """Roll out ``agent`` for ``num_episodes`` episodes and return the mean episode reward.

    Actions come from ``agent.select_action`` inside ``utils.eval_mode(agent)``.
    When ``video_path`` is truthy, the first episode is rendered (48x48, camera 0)
    and saved with ``logger.save_video``; later episodes are never recorded.

    Args:
        env: environment with ``reset()``, ``step(action)`` and ``physics.render(...)``.
        agent: agent exposing ``select_action(obs)``.
        num_episodes: number of evaluation episodes to run.
        video_path: optional destination passed to ``logger.save_video``.

    Returns:
        ``np.mean`` over the summed rewards of each episode.
    """
    from ml_logger import logger

    frame_size = 48

    def render_frame():
        return env.physics.render(height=frame_size, width=frame_size, camera_id=0)

    returns = []
    for episode_idx in trange(num_episodes, desc="evaluate"):
        obs = env.reset()
        clip = []  # frames of this episode; stays empty unless it is being recorded
        recording = episode_idx == 0 and video_path
        total_reward = 0
        finished = False
        while not finished:
            if recording:
                clip.append(render_frame())
            with utils.eval_mode(agent):
                action = agent.select_action(obs)
            obs, reward, finished, _ = env.step(action)
            total_reward += reward

        if clip:
            # Add the frame that follows the last step before writing the video.
            clip.append(render_frame())
            logger.save_video(clip, video_path)
        returns.append(total_reward)

    return np.mean(returns)


# NOTE: The Slack notifier wrapper does nothing unless $SLACK_WEBHOOK_URL is set;
# when the variable is absent this is None.
webhook_url = os.getenv("SLACK_WEBHOOK_URL")


# NOTE(review): `slack_sender` presumably reports job status to the given Slack channel.
# TimeoutError (raised below when the time limit is hit) is listed in `ignore_exceptions`,
# presumably so a time-limit exit is not reported as a failure — confirm against
# tools.notifier.slack_sender.
@slack_sender(
    webhook_url=webhook_url,
    channel="rl-under-distraction-job-status",
    progress=Progress,
    ignore_exceptions=(TimeoutError,)
)
def main(**kwargs):
    """Train an agent, resuming from an in-progress checkpoint when one exists.

    Flow:
      1. Seed everything from ``kwargs['seed']`` and apply ``kwargs`` to ``Args``.
      2. Return early if the logger reports a ``job.completionTime``.
      3. If ``progress.pkl`` exists, restore ``Args``, the agent, the replay buffer and
         ``Progress`` from the checkpoint; otherwise log the params and chart config.
      4. Build the training env (and a test env unless ``Args.eval_mode`` is None).
      5. Run the training loop up to ``Args.train_steps`` with periodic logging,
         evaluation, checkpointing, a time-limit check and a step-time health check.
      6. Save the final checkpoint via ``save_final_checkpoint``.

    Args:
        **kwargs: overrides applied to ``Args``; ``'seed'`` is read unconditionally.

    Raises:
        KeyError: if ``'seed'`` is not in ``kwargs``.
        TimeoutError: when ``Args.time_limit`` is set and exceeded (a checkpoint is
            saved first).
        RuntimeError: when the unhealthy-step counter exceeds 10 (a checkpoint is
            saved first).
        OSError, RuntimeError: re-raised from evaluation after a checkpoint is saved.
    """
    from ml_logger import logger
    from .config import Args
    from drqv2_invariance.utils import set_egl_id

    set_egl_id()
    utils.set_seed_everywhere(kwargs['seed'])
    Args._update(kwargs)

    assert logger.prefix, "you will overwrite the entire instrument server"
    # Nothing to do if a completion time has already been recorded for this job.
    if logger.read_params('job.completionTime', default=None):
        logger.print("The job seems to have been already completed!!!")
        return

    # Interval used below for calling logger.job_running().
    # NOTE(review): presumably seconds (5 minutes) — confirm the unit of logger.since.
    update_job_status_every = 5 * 60
    logger.start('update_job_status')
    # Timers: 'start' feeds Progress.wall_time, 'episode' and 'step' are split in the loop,
    # 'run' is compared against Args.time_limit.
    logger.start('start', 'episode', 'run', 'step')

    if logger.glob('progress.pkl'):
        # Resume path: a previous checkpoint exists under this prefix.
        try:
            # Use current config for some args
            keep_args = ['checkpoint_root', 'time_limit', 'checkpoint_freq', 'tmp_dir']
            Args._update({key: val for key, val in logger.read_params("Args").items() if key not in keep_args})
        except KeyError as e:
            # Restoring Args is best-effort; training continues with the current Args.
            print('Captured KeyError during Args update.', e)

        agent, replay_buffer, progress_cache = load_checkpoint()
        Progress._update(progress_cache)
        # Shift the 'start' timer back so logger.since('start') continues from the
        # wall time recorded in the checkpoint. The 'run' timer is not shifted, so the
        # time limit applies to the current invocation only.
        logger.timer_cache['start'] = logger.timer_cache['start'] - Progress.wall_time
        logger.print(f'loaded from checkpoint at {Progress.episode}', color="green")

    else:
        # Fresh run. (Args was already updated with kwargs above; this repeats it.)
        Args._update(kwargs)
        logger.log_params(Args=vars(Args))
        logger.log_text("""
            charts:
            - yKey: train/episode_reward/mean
              xKey: step
            - yKey: eval/episode_reward
              xKey: step
            """, filename=".charts.yml", dedent=True, overwrite=True)

    # NOTE: don't worry about the obs size. The size differs because it crops out the center.
    # shape of the input to the encoder stays the same.
    env = get_env(Args.env_name, Args.frame_stack, Args.action_repeat, Args.seed,
                  size=Args.image_size)
    if Args.eval_mode is None:
        test_env = None
    else:
        # The test env uses a different seed (Args.seed + 42) than the training env.
        test_env = get_env(Args.eval_env_name, Args.frame_stack, Args.action_repeat, Args.seed + 42,
                           size=Args.image_size)

    # `agent` is only bound at this point if a checkpoint was loaded above.
    if 'agent' not in locals():
        # assert torch.cuda.is_available(), 'must have cuda enabled'
        replay_buffer = utils.ReplayBuffer(obs_shape=env.observation_space.shape,
                                           action_shape=env.action_space.shape,
                                           capacity=Args.train_steps,
                                           batch_size=Args.batch_size)
        # The agent is built for the cropped observation, not env.observation_space.shape.
        _cropped_obs_shape = (3 * Args.frame_stack, Args.image_crop_size, Args.image_crop_size)
        agent = make_agent(obs_shape=_cropped_obs_shape, action_shape=env.action_space.shape, args=Args)

    # `done` starts as True so the first iteration resets the env; `episode_step` starts
    # at 0 so the `if episode_step:` section (summary/eval/checkpoint) is skipped then.
    done, episode_step, episode_reward = True, 0, 0
    unhealthy_counter = 0

    # The loop variable is the class attribute itself, so the current step is always
    # available to checkpoints through vars(Progress). On resume it starts at the
    # restored Progress.step.
    for Progress.step in range(Progress.step, Args.train_steps + 1):
        Progress.wall_time = logger.since('start')
        Progress.frame = Progress.step * Args.action_repeat

        if done:
            # Episode boundary: record the finished episode, then optionally
            # summarize / evaluate / checkpoint before resetting the env.
            dt_episode = logger.split('episode')
            logger.store_metrics({'train/episode_reward': episode_reward,
                                  'dt_episode': dt_episode})
            logger.log(step=Progress.step, episode=Progress.episode, wall_time=Progress.wall_time, flush=True)
            if episode_step:
                # The frequency checks below run only at episode boundaries, so they fire
                # only when the boundary step is a multiple of the respective frequency.
                if Progress.step % Args.log_freq == 0:
                    logger.log_metrics_summary(key_values={'step': Progress.step, 'episode': Progress.episode})

                if Progress.step % Args.eval_freq == 0:
                    logger.print('running evaluation...')
                    try:
                        # `video_path` is computed but not passed on: recording is
                        # disabled (video_path=None) because of the ffmpeg errors noted below.
                        with logger.Prefix(metrics="eval"):
                            video_path = f"videos/eval_agent_{Progress.step:08d}.mp4"
                            mean_r = evaluate(env, agent, Args.eval_episodes, video_path=None)  # TEMP: ffmpeg  errors!
                            logger.log({'episode_reward': mean_r})
                        if test_env is not None:
                            with logger.Prefix(metrics="test"):
                                video_path = f"videos/test_agent_{Progress.step:08d}.mp4"
                                mean_r = evaluate(test_env, agent, Args.eval_episodes, video_path=None)  # TEMP: ffmpeg  errors!
                                logger.log({'episode_reward': mean_r})
                        logger.log(step=Progress.step, episode=Progress.episode, flush=True)

                    except (OSError, RuntimeError) as e:
                        # Save a checkpoint before propagating the error so the run can resume.
                        logger.print('Error captured:', e)
                        logger.print('Saving snapshot before exiting...')
                        save_checkpoint(agent, replay_buffer)
                        logger.print('done')
                        raise e
                    logger.print('running evaluation... done')

                if Progress.step % Args.checkpoint_freq == 0:
                    save_checkpoint(agent, replay_buffer)

                # Time limit is measured with the 'run' timer and is skipped when
                # Args.time_limit is falsy.
                if Args.time_limit and logger.since('run') > Args.time_limit:
                    logger.print(f'time limit {Args.time_limit} is reached. Saving snapshot...')
                    save_checkpoint(agent, replay_buffer)
                    raise TimeoutError

            obs, done, episode_reward, episode_step = env.reset(), False, 0, 0
            Progress.episode += 1

        # Periodically report that the job is still running, then restart the interval timer.
        if logger.since('update_job_status') > update_job_status_every:
            logger.job_running()
            logger.start('update_job_status')

        # Health check
        # NOTE: There's an unknown mysterious issue where the training becomes significantly slower
        # suddenly in the middle of training. There may be some memory leak.
        # A step counts as unhealthy when its split time exceeds 0.8 and the step is not a
        # multiple of checkpoint_freq or eval_freq. Any other step resets the counter.
        elapsed = logger.split('step')
        if elapsed > 0.8 and \
           not (Progress.step % Args.checkpoint_freq == 0 or Progress.step % Args.eval_freq == 0):
            logger.print(f'split time for step is {elapsed} ! Unhealthy!! counter: {unhealthy_counter}')
            unhealthy_counter += 1
            if unhealthy_counter > 10:
                logger.print('The instance is unhealthy... Save & exitting...')
                save_checkpoint(agent, replay_buffer)
                raise RuntimeError
        else:
            unhealthy_counter = 0


        # Sample action for data collection
        # Random actions for the first Args.init_steps steps, agent actions afterwards.
        if Progress.step < Args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # Run training update
        # At step == init_steps, run init_steps updates in a row; after that, one per step.
        if Progress.step == Args.init_steps:
            for _ in trange(Args.init_steps, desc=f"updating for {Args.init_steps} init steps"):
                agent.update(replay_buffer, Progress.step)
        elif Progress.step > Args.init_steps:
            for _ in range(1):
                agent.update(replay_buffer, Progress.step)

        # Take step
        next_obs, reward, done, _ = env.step(action)

        # always set done to False to make DMC infinite horizon. -- Ge
        replay_buffer.add(obs, action, reward, next_obs, False)
        episode_reward += reward
        obs = next_obs

        episode_step += 1

    # saving the final agent
    logger.print('Training is done!! Saving snapshot...')
    save_final_checkpoint(agent)
    logger.print('Saving snapshot done!')


# Script entry point.
# NOTE(review): main() is called without kwargs, but main reads kwargs['seed'];
# unless the slack_sender wrapper supplies it, this call raises KeyError — confirm.
if __name__ == '__main__':
    main()
