#!/usr/bin/env python3

import os
from os.path import join as pJoin
from pathlib import Path

from invr_thru_inf.collect_offline_data import (
    upload_and_cleanup, SaveInfoWrapper, RailsWrapper, collect_trajectories, collect_orig_latent_buffer, collect_targ_obs_buffer
)


def main(**kwargs):
    """Collect one offline replay buffer and upload it as a tar archive.

    Steps:
      1. Start the ml_logger run, select the EGL device and seed the RNGs.
      2. Apply ``kwargs`` overrides to ``Args`` and ``Adapt``.
      3. If ``Adapt.snapshot_prefix`` is set, verify that the source
         (pre-training) run recorded ``job.completionTime`` and has a
         ``progress.pkl``, then copy its ``Args`` into ``Args`` except for the
         run-specific keys in ``keep_args``.
      4. Apply ``kwargs`` overrides to ``CollectData`` and, depending on
         ``CollectData.buffer_type``:
           - ``'original'``: load the agent via ``load_adpt_agent`` and collect
             a latent buffer, archived as ``orig_latent_buffer.tar``;
           - ``'target'``: collect an observation buffer on
             ``Args.eval_env_name``, archived as ``targ_obs_buffer.tar``.
         The buffer is staged under ``Args.tmp_dir`` and uploaded under
         ``Args.checkpoint_root``. Collection is skipped when the archive
         already exists there (checked with ``logger.glob``).

    Args:
        **kwargs: Overrides passed to ``Args._update``, ``Adapt._update`` and
            ``CollectData._update``. ``kwargs['seed']`` seeds the RNGs; when
            it is absent, ``Args.seed`` is used instead.

    Raises:
        ValueError: ``Adapt.snapshot_prefix`` resolves to the current run's
            prefix, or ``CollectData.buffer_type`` is neither ``'original'``
            nor ``'target'``.
        RuntimeError: the source run has no ``job.completionTime`` param
            (not raised when ``RUN.debug`` is set).
        FileNotFoundError: the source run has no ``progress.pkl`` (not raised
            when ``RUN.debug`` is set).
        NotImplementedError: a data-collection policy other than ``'random'``
            is configured.
    """
    from ml_logger import logger, RUN, ML_Logger
    from .config import Args
    from invr_thru_inf.config import Adapt, CollectData
    from invr_thru_inf.utils import set_egl_id, set_seed_everywhere, get_buffer_prefix

    from warnings import simplefilter  # noqa
    simplefilter(action='ignore', category=DeprecationWarning)

    logger.start('run')
    set_egl_id()
    # Fall back to the configured seed so that ``main()`` without overrides
    # (as called from the ``__main__`` guard) does not fail with a KeyError.
    set_seed_everywhere(kwargs.get('seed', Args.seed))

    Args._update(kwargs)
    Adapt._update(kwargs)

    if Adapt.snapshot_prefix:
        src_logger = ML_Logger(prefix=Adapt.snapshot_prefix)

        # The source run must be a different run than the current one.
        # Raise explicitly instead of asserting so the check survives ``python -O``.
        if src_logger.prefix == logger.prefix:
            raise ValueError(
                f'Adapt.snapshot_prefix must differ from the current run prefix: {logger.prefix}'
            )
        logger.print('src_logger.prefix', src_logger.prefix)
        logger.print('logger.prefix', logger.prefix)

        # Check whether the pre-training job has completed.
        try:
            completion_time = src_logger.read_params('job.completionTime')
        except KeyError as err:
            logger.print(f'training for {src_logger.prefix} has not been finished yet.')
            logger.print('job.completionTime was not found.')
            if not RUN.debug:
                raise RuntimeError(
                    f'Pre-training for {src_logger.prefix} has not been completed '
                    '(job.completionTime was not found).'
                ) from err
        else:
            logger.print(f'Pre-training was completed at {completion_time}.')

        # The source run must have saved a checkpoint (not enforced in debug mode).
        if not RUN.debug and not src_logger.glob('progress.pkl'):
            raise FileNotFoundError(f'progress.pkl was not found under {src_logger.prefix}.')

        # Adopt the source run's Args, except for these run-specific settings.
        keep_args = {'eval_env_name', 'seed', 'tmp_dir', 'checkpoint_root', 'time_limit', 'device'}
        src_args = src_logger.read_params('Args')
        logger.print('original Args', vars(Args))
        logger.print('loading Args', src_args)
        Args._update({key: val for key, val in src_args.items() if key not in keep_args})
        logger.print('updated Args', vars(Args))
    logger.log_params(Args=vars(Args), Adapt=vars(Adapt))

    # Update config parameters based on kwargs
    CollectData._update(kwargs)

    # ===== Prepare the buffers =====
    # Each buffer is staged locally under Args.tmp_dir, then archived and
    # uploaded to Args.checkpoint_root.
    logger.print('eval_env_name', Args.eval_env_name)
    logger.print('env_name', Args.env_name)
    logger.print('action_repeat', Args.action_repeat)

    # Currently only random data-collection policies are supported.
    if not (Adapt.policy_on_clean_buffer == 'random' and Adapt.policy_on_distr_buffer == 'random'):
        raise NotImplementedError(
            'Only policy_on_clean_buffer == "random" and policy_on_distr_buffer == "random" '
            f'are supported, got {Adapt.policy_on_clean_buffer!r} and {Adapt.policy_on_distr_buffer!r}.'
        )

    buffer_type = CollectData.buffer_type
    if buffer_type == 'original':
        from .adapt import load_adpt_agent
        buffer_prefix = get_buffer_prefix(
            Args.env_name, Args.action_repeat, Args.seed, Adapt,
            eval_env=Args.eval_env_name, latent_buffer=True,
        )
        orig_buffer_dir = Path(Args.tmp_dir) / 'data-collection' / buffer_prefix / 'workdir' / 'orig_latent_buffer'
        logger.print('orig_buffer_dir', orig_buffer_dir)
        agent = load_adpt_agent(Args, Adapt)
        target_path = pJoin(Args.checkpoint_root, buffer_prefix, 'orig_latent_buffer.tar')
        logger.print('target_path', target_path)
        if not logger.glob(target_path):
            collect_orig_latent_buffer(
                agent, orig_buffer_dir, Args.env_name, Adapt.latent_buffer_size,
                Args.batch_size, nstep=1, discount=Args.discount, frame_stack=Args.frame_stack,
                action_repeat=Args.action_repeat, seed=Args.seed, time_limit=Args.time_limit, Adapt=Adapt
            )

            logger.print(f'Compressing & Uploading the buffer: {target_path}')
            upload_and_cleanup(orig_buffer_dir, target_path, archive='tar')
            logger.print('Upload completed!')

    elif buffer_type == 'target':
        buffer_prefix = get_buffer_prefix(
            Args.env_name, Args.action_repeat, Args.seed, Adapt, eval_env=Args.eval_env_name,
        )
        targ_buffer_dir = Path(Args.tmp_dir) / 'data-collection' / buffer_prefix / 'workdir' / 'targ_obs_buffer'
        logger.print('targ_buffer_dir', targ_buffer_dir)
        target_path = pJoin(Args.checkpoint_root, buffer_prefix, 'targ_obs_buffer.tar')
        logger.print('target_path', target_path)
        if not logger.glob(target_path):
            # The target buffer is collected on the evaluation environment.
            collect_targ_obs_buffer(
                targ_buffer_dir, Args.eval_env_name, Adapt.latent_buffer_size, Args.batch_size, nstep=1,
                discount=Args.discount, frame_stack=Args.frame_stack, action_repeat=Args.action_repeat,
                seed=Args.seed, time_limit=Args.time_limit, Adapt=Adapt
            )

            logger.print(f'Compressing & Uploading the buffer: {target_path}')
            upload_and_cleanup(targ_buffer_dir, target_path, archive='tar')
            logger.print('Upload completed!')
    else:
        raise ValueError(f'invalid buffer type: {buffer_type!r}')

    logger.print('=========== Completed!!! ==============')

if __name__ == "__main__":
    # Script entry point: invoked without keyword overrides. ``main`` uses
    # relative imports (``from .config import Args``), so launch this file as
    # a module (``python -m <package>.<module>``) rather than as a plain path.
    main()
