#!/usr/bin/env python3
from os.path import join as pJoin
import shutil
from pathlib import Path

import numpy as np
import torch

from . import utils
from invr_thru_inf.adapt import Progress


def load_adpt_agent(Args, Adapt):
    """Build an AdaptationAgent around the actor of a pretrained agent.

    The pretrained actor weights are read from
    ``<Args.checkpoint_root>/<Adapt.snapshot_prefix>/actor_state.pt``.

    Args:
        Args: run configuration (env name, frame stack, action repeat, seed,
            image size, projection dim, device, ...).
        Adapt: adaptation configuration (snapshot prefix, agent stem,
            adaptation hidden size).

    Returns:
        The constructed AdaptationAgent, built with ``invm_head=None``.
    """
    from ml_logger import logger
    from invr_thru_inf.env_helpers import get_env
    from invr_thru_inf.agent import AdaptationAgent

    ckpt_dir = pJoin(Args.checkpoint_root, Adapt.snapshot_prefix)

    from invr_thru_inf.dummy_actors import DMCGENDummyActor

    # This env instance is only used to read the observation/action shapes.
    probe_env = get_env(Args.env_name, Args.frame_stack, Args.action_repeat, Args.seed,
                        size=Args.image_size)
    act_shape = probe_env.action_space.shape
    ob_shape = probe_env.observation_space.shape

    weights_path = pJoin(ckpt_dir, 'actor_state.pt')
    logger.print('Loading model from', weights_path)
    actor_weights = logger.load_torch(path=weights_path)

    # Rebuild the pretraining agent's architecture, then restore only its actor.
    from dmc_gen.algorithms.factory import make_agent
    pretrained = make_agent(ob_shape, act_shape, Args)
    pretrained.actor.load_state_dict(actor_weights)
    actor = pretrained.actor

    adapt_agent = AdaptationAgent(
        encoder=actor.encoder,
        actor_from_obs=DMCGENDummyActor(actor),
        algorithm=Adapt.agent_stem,
        action_shape=act_shape,
        # TODO: confirm Args.projection_dim is the right feature size here.
        feature_dim=Args.projection_dim,
        invm_head=None,
        hidden_dim=Adapt.adapt_hidden_dim,
        device=Args.device,
    )
    logger.print('done')
    return adapt_agent


def main(**kwargs):
    """Run offline adaptation of a pretrained agent.

    ``kwargs`` are applied on top of the ``Args`` and ``Adapt`` configs. When
    ``Adapt.snapshot_prefix`` is set, the pretraining run's ``Args`` are read
    back from that prefix and merged in, except for the run-specific keys
    listed in ``keep_args``.

    If ``RUN.resume`` is set and a ``checkpoint.pkl`` exists, progress and the
    adaptation agent are restored from it (or the function returns early if the
    job already has a ``job.completionTime``). Otherwise a fresh adaptation
    agent is built with ``load_adpt_agent``.

    Raises:
        RuntimeError: if CUDA is unavailable, or (outside debug mode) if the
            pretraining run has no ``job.completionTime``.
        ValueError: if the pretraining prefix equals this run's prefix.
        FileNotFoundError: if (outside debug mode) the pretraining run has no
            ``progress.pkl``.
        NotImplementedError: if ``Adapt.adapt_offline`` is false.
    """
    from ml_logger import logger, RUN, ML_Logger
    from invr_thru_inf.utils import set_egl_id, set_seed_everywhere
    from invr_thru_inf.adapt import prepare_buffers, adapt_offline
    from invr_thru_inf.env_helpers import get_env

    from invr_thru_inf.replay_buffer import Replay
    from .config import Args
    from invr_thru_inf.config import Adapt

    if not torch.cuda.is_available():
        raise RuntimeError('torch.cuda.is_available() is False!!')

    logger.start('run', 'start')

    set_egl_id()

    # Update config parameters based on kwargs
    Args._update(kwargs)
    Adapt._update(kwargs)
    # Seed from the resolved config instead of kwargs['seed'], so a missing
    # 'seed' kwarg falls back to the Args default instead of raising KeyError.
    # 'seed' is in keep_args below, so the source run never overrides it.
    set_seed_everywhere(Args.seed)

    if Adapt.snapshot_prefix:
        src_logger = ML_Logger(prefix=Adapt.snapshot_prefix)

        # The adaptation run must not log into the pretraining run's prefix.
        if src_logger.prefix == logger.prefix:
            raise ValueError(
                f'Adapt.snapshot_prefix must differ from the current logger prefix: {logger.prefix}')
        logger.print('src_logger.prefix', src_logger.prefix)
        logger.print('logger.prefix', logger.prefix)

        # Check if the pretraining job is completed (required unless debugging)
        try:
            completion_time = src_logger.read_params('job.completionTime')
        except KeyError:
            logger.print(f'training for {src_logger.prefix} has not been finished yet.')
            logger.print('job.completionTime was not found.')
            if not RUN.debug:
                raise RuntimeError(
                    f'Pre-training run {src_logger.prefix} has no job.completionTime.')
        else:
            logger.print(f'Pre-training was completed at {completion_time}.')

        # The pretraining run must have written its progress file (unless debugging)
        if not (RUN.debug or src_logger.glob('progress.pkl')):
            raise FileNotFoundError(f'progress.pkl was not found under {src_logger.prefix}')

        # Adopt the pretraining run's Args, except for keys specific to this run.
        keep_args = {'eval_env_name', 'seed', 'tmp_dir', 'checkpoint_root', 'time_limit', 'device'}
        src_args = src_logger.read_params("Args")
        Args._update({key: val for key, val in src_args.items() if key not in keep_args})
        logger.print('Args after update', vars(Args))
    logger.log_params(Args=vars(Args), Adapt=vars(Adapt))

    logger.log_text("""
            keys:
            - Args.env_name
            - Args.seed
            charts:
            - yKey: eval/episode_reward
              xKey: step
            - yKey: clean_eval/episode_reward
              xKey: step
            - yKey: encoder_loss/mean
              xKey: step
            - yKey: distr_env/ss_loss
              xKey: step
            - yKey: clean_env/ss_loss
              xKey: step
            - yKey: focal_coef/mean
              xKey: step
            - yKeys: ["distr_env/focal_coef/mean", "clean_env/focal_coef/mean"]
              xKey: step
            - yKeys: ["discriminator_loss/mean", "discr_adpt_loss/mean", "discr_org_loss/mean"]
              xKey: step
            - yKey: adpt_reward/mean
              xKey: step
            - yKey: grad_penalty/mean
              xKey: step
            - yKey: ss_loss/mean
              xKey: frame
            - yKey: eval/mean_err
              xKey: step
            - yKeys: ["eval/stddev_err", "eval/stddev_diff"]
              xKey: step
            """, ".charts.yml", overwrite=True, dedent=True)

    orig_buffer_dir, targ_buffer_dir = prepare_buffers(
        Args.checkpoint_root, Args.tmp_dir, Args.env_name, Args.eval_env_name,
        Args.action_repeat, Args.seed, Adapt
    )

    # ===== make or load adaptation-agent =====
    if RUN.resume and logger.glob('checkpoint.pkl'):

        # Verify that arguments are consistent with the run being resumed
        from invr_thru_inf.utils import verify_args
        loaded_args = logger.read_params("Args")
        verify_args(vars(Args), loaded_args)

        Progress._update(logger.read_params(path="checkpoint.pkl"))
        try:
            completion_time = logger.read_params('job.completionTime')
        except KeyError:
            # Not completed yet: load adaptation agent from the checkpoint root
            from invr_thru_inf.adapt import load_snapshot
            logger.print(f'Resuming from the checkpoint. step: {Progress.step}')
            adapt_agent = load_snapshot(Args.checkpoint_root)

            logger.start('episode')
            # Offset 'start' by the recorded wall time — presumably so elapsed
            # time since 'start' continues from the checkpoint.
            logger.timer_cache['start'] = logger.timer_cache['episode'] - Progress.wall_time
        else:
            # Adaptation has been already completed
            logger.print(f'job.completionTime is set:{completion_time}\n',
                         'This job seems to have been completed already.')
            return

    else:
        adapt_agent = load_adpt_agent(Args, Adapt)

    # NOTE: Below only requires Args and Adapt
    latent_replay = Replay(
        buffer_dir=orig_buffer_dir, buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size, num_workers=4,
        store_states=False,
    )
    distr_replay = Replay(
        buffer_dir=targ_buffer_dir,
        buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size,
        num_workers=4,
        store_states=False
    )
    # NOTE(review): orig_eval_env is seeded with Args.seed + 1000 but
    # targ_eval_env with Args.seed — confirm the asymmetry is intended.
    orig_eval_env = get_env(Args.env_name, Args.frame_stack, Args.action_repeat, Args.seed + 1000,
                            size=Args.image_size)
    logger.print('eval_env_name', Args.eval_env_name)
    targ_eval_env = get_env(Args.eval_env_name, Args.frame_stack, Args.action_repeat, Args.seed,
                            distraction_config=Adapt.distraction_types, size=Args.image_size,
                            intensity=Adapt.distraction_intensity)

    logger.print('orig_eval_env', orig_eval_env)
    logger.print('targ_eval_env', targ_eval_env)

    # Only the offline adaptation path is implemented here.
    if not Adapt.adapt_offline:
        raise NotImplementedError('Only offline adaptation (Adapt.adapt_offline=True) is supported.')
    adapt_offline(Adapt,
                  Args.checkpoint_root, Args.action_repeat, Args.time_limit, Args.batch_size, Args.device,
                  orig_eval_env, targ_eval_env, adapt_agent, latent_replay, distr_replay,
                  progress=Progress)
