#!/usr/bin/env python3
import os
from os.path import join as pJoin
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
from gym import Wrapper

from . import utils
from .config import Args, Agent, Adapt


def check_if_already_uploaded():
    """Abort if ``buffer.tar.gz`` for the current buffer prefix is already on S3.

    Reads ``Args.checkpoint_root`` (must be an ``s3://`` URI), lists the objects
    whose key starts with ``<prefix>/<buffer_prefix>/buffer.tar.gz`` and raises
    ``RuntimeError`` when at least one such object exists.
    """
    from .config import get_buffer_prefix

    # TEMP: look into the S3 bucket for an existing buffer.tar.gz
    root = Args.checkpoint_root
    assert root.startswith('s3://')
    import boto3

    # Drop the 's3://' scheme and separate the bucket name from the key prefix.
    bucket_name, s3_prefix = root[5:].split('/', 1)
    s3_bucket = boto3.Session().resource('s3').Bucket(bucket_name)
    object_key = pJoin(s3_prefix, get_buffer_prefix(), 'buffer.tar.gz')
    matches = list(s3_bucket.objects.filter(Prefix=object_key))
    print(f'glob s3: {object_key}', matches)
    if not matches:
        return
    raise RuntimeError(f'Yeahhhh! Buffer {object_key} is already available!')


def should_generate_buffer(target_path, flock_path, invalidate_lock_after=60*40):
    """Decide whether this process should generate the buffer.

    Args:
        target_path: ``s3://`` URI of the buffer archive.
        flock_path: ``s3://`` URI of the remote lock; it is loaded with
            ``logger.load_torch`` and compared against ``datetime.now()``.
        invalidate_lock_after: age in seconds after which an existing lock is
            considered stale (default: 40 minutes).

    Returns:
        ``False`` if the archive already exists, or if a lock exists whose
        timestamp is younger than ``invalidate_lock_after``; ``True`` otherwise.
    """
    # TODO: remove the KeyError handling once logger.glob_s3 is fixed
    from ml_logger import logger

    assert target_path.startswith('s3://')
    assert flock_path.startswith('s3://')

    def glob_or_empty(s3_uri):
        # glob_s3 takes the path without the 's3://' scheme.
        try:
            return logger.glob_s3(s3_uri[5:])
        except KeyError as e:
            logger.print(f'logger.glob_s3 error: {e}')
            return []

    flist = glob_or_empty(target_path)
    logger.print(f'glob {target_path[5:]}:\t{flist}')
    if flist:
        logger.print(f'The buffer is already available on s3!!: {target_path}')
        return False

    if not glob_or_empty(flock_path):
        return True

    # Another process is (or was) generating the buffer.
    logger.print(f'File lock is found on s3 : {flock_path}')

    # Honor the lock only while its timestamp is recent enough.
    # total_seconds() is required here: timedelta.seconds ignores the days
    # component, so a lock older than a day could look fresh.
    last_timestamp = logger.load_torch(path=flock_path)
    lock_age = (datetime.now() - last_timestamp).total_seconds()
    return not lock_age < invalidate_lock_after


def upload_and_cleanup(buffer_dir, target_path, flock_path, archive):
    """Upload ``buffer_dir`` to ``target_path`` and remove the remote lock.

    Args:
        buffer_dir: local directory to upload.
        target_path: destination handed to ``logger.upload_dir``.
        flock_path: ``s3://bucket/key`` URI of the remote lock to delete.
        archive: archive format forwarded to ``logger.upload_dir``
            (previously ignored in favour of a hard-coded ``'tar'``).
    """
    from ml_logger import logger

    # Upload the buffer and report how long it took.
    logger.start('uploading_tar')
    logger.upload_dir(buffer_dir, target_path, archive=archive)
    elapsed = logger.since('uploading_tar')
    logger.print(f'Upload took {elapsed} seconds')

    # Release remote file lock
    assert flock_path.startswith('s3://')
    bucket, s3_path = flock_path[5:].split('/', 1)
    # Guard against deleting something at the top level of the bucket.
    assert '/' in s3_path, f"Wooo, are you sure the s3 path is correct?\nbucket: {bucket}\ns3_path: {s3_path}"
    logger.remove_s3(bucket, s3_path)



class SaveInfoWrapper(Wrapper):
    """Expose the simulator state of the wrapped environment through ``info``.

    ``step`` returns an ``info`` dict holding only ``'sim_state'``, read directly
    from ``env.unwrapped.env.physics``.
    """

    def __init__(self, env, *args, **kwargs):
        super().__init__(env, *args, **kwargs)

    def reset(self):
        """Reset the wrapped environment and return its observation."""
        return self.env.reset()

    def step(self, action):
        """Step the wrapped env and replace ``info`` with ``{'sim_state': state}``."""
        observation, rew, terminal, step_info = self.env.step(action)
        physics = self.env.unwrapped.env.physics
        sim_state = physics.get_state()

        # If this always holds the wrapper is strictly redundant; it is kept so
        # that how and when the state gets recorded stays explicit.
        assert (step_info['sim_state'] == sim_state).all(), f"{step_info['sim_state']}\nvs\n{sim_state}"

        # Everything else in the inner info dict is dropped on purpose.
        return observation, rew, terminal, {'sim_state': sim_state}


class RailsWrapper(Wrapper):
    """Replay a recorded episode on the wrapped environment.

    This wrapper is injected around a DMCEnv or DistractingEnv and fakes the
    output of step() and reset() so that the trajectory follows exactly the
    simulator states registered through ``set_episode``.
    """
    def __init__(self, env, *args, **kwargs):
        super().__init__(env, *args, **kwargs)
        self._episode = None
        self._counter = 0
        # Set by set_episode(); initialised here so the attribute always exists.
        self._initial_state = None
        self._state_sequence = []

    def _get_obs_pixels(self):
        """Render the pixel observation for the current physics state.

        Raises:
            ValueError: if the unwrapped env is neither DMCEnv nor DistractingEnv.
        """
        from distracting_control.gym_env import DistractingEnv
        from gym_dmc.dmc_env import DMCEnv

        env = self.env.unwrapped
        if isinstance(env, DMCEnv):
            obs = env._get_obs_pixels()
        elif isinstance(env, DistractingEnv):
            pixel_wrapper = env.env
            obs = pixel_wrapper.physics.render(**pixel_wrapper._render_kwargs)
            if self.env.channels_first:
                # Rendered image is transposed from (H, W, C) order to channels-first.
                obs = obs.transpose([2, 0, 1])
        else:
            raise ValueError('The env under RailsWrapper is invalid:', self.env.unwrapped)
        return obs

    def reset(self):
        """Reset the env, force the recorded initial state and return its rendering.

        ``set_episode`` must have been called since the last finished episode.
        """
        self.env.reset()
        assert self._episode, 'self._episode is empty. self.set_episode(episode) must be called before reset.'
        self._counter = 0

        suite_env = self.env.unwrapped.env
        suite_env.physics.set_state(self._initial_state)  # NOTE: this is in replacement of env.set_state!!!
        return self._get_obs_pixels()

    def step(self, action):
        """Step the env, then overwrite the physics state with the recorded one.

        Returns:
            ``(obs, 0.0, done, {'sim_state': state})`` where ``obs`` is what the
            wrapped ``step`` returned and ``done`` becomes True once every
            recorded state has been replayed.
        """
        obs, _, _, info = self.env.step(action)

        state = self._state_sequence[self._counter]
        suite_env = self.env.unwrapped.env
        suite_env.physics.set_state(state)  # NOTE: Never use env.set_state! That will call step internally.

        self._counter += 1

        done = False
        if self._counter == len(self._state_sequence):
            done = True
            # Require a fresh set_episode() before the next reset().
            self._episode = None

        # NOTE: the state reported by the wrapped step() was observed not to
        # match physics.get_state() when random seeds differ, hence no assert:
        # assert (suite_env.physics.get_state() == info['sim_state']).all()

        # Overwrite info since self.env.step does not work properly.
        info = {'sim_state': self.env.unwrapped.env.physics.get_state()}
        fake_reward = 0.0
        return obs, fake_reward, done, info

    def set_episode(self, episode, target=None):
        """Set the episode which RailsWrapper uses to replay.

        episode['state'][n] contains stacked states that correspond to frame-stacked observations.
        Always use the last element of the stacked state. That corresponds to the latest observation.

        Args:
            episode: mapping with a ``'state'`` sequence; for Reacher envs it
                must also hold ``'extra__goal_size'``, ``'extra__goal_x'`` and
                ``'extra__goal_y'``.
            target: unused.
        """
        self._episode = episode

        # NOTE: during training, just after reset, all rows of stacked frames are identical.
        # episode['state'] contains sequence of stacked states only after the reset.
        # episode['state'][0] should look like [init_state, init_state, second_state]
        # episode['state'][1] would be         [init_state, second_state, third_state]
        self._initial_state = episode['state'][0][0]
        self._state_sequence = [stacked_states[-1] for stacked_states in episode['state']]

        dm_env = self.env.unwrapped.env

        # The goal is not part of the physics state, so it is set on the model.
        # The x/y coordinates go to geom_pos (they were previously written to
        # geom_size, which cannot move the target).
        # NOTE(review): the original code carried a TODO saying the values were
        # not reflected; confirm that this now takes effect. Aligning random
        # seeds across the two environments remains the documented fallback.
        if 'Reacher' in self.env.unwrapped.spec.id:
            dm_env.physics.named.model.geom_size['target', 0] = episode['extra__goal_size']
            dm_env.physics.named.model.geom_pos['target', 'x'] = episode['extra__goal_x']
            dm_env.physics.named.model.geom_pos['target', 'y'] = episode['extra__goal_y']


def wait_till_acquire_lock(buffer_dir):
    """ File-lock-like mechanism to avoid multiple processes filling the same buffer.

    NOTE: This is not exactly like file-lock. Only the process filling the buffer needs to
    write in the directory. All other processes only need to wait for it to be filled.

    Creates ``buffer_dir`` if needed and tries to create ``buffer_dir/dir.lock``.

    Returns:
        ``True`` if this process created the lock and the directory holds
        nothing but the lock (the caller must fill the buffer and then release
        the lock). ``False`` if another process held the lock (after waiting
        for its release) or if the directory already contained files (the lock
        is released again before returning).
    """
    import time

    from ml_logger import logger

    buffer_dir = Path(buffer_dir)
    buffer_dir.mkdir(parents=True, exist_ok=True)

    lock_file = buffer_dir / 'dir.lock'
    try:
        # Mode 'x' creates the file exclusively and fails if it already exists,
        # so two processes cannot both believe they created the lock.
        with open(lock_file, 'x'):
            pass
    except FileExistsError:
        # The buffer is being filled by another process: wait for the release.
        while os.path.isfile(lock_file):
            logger.print(f'waiting for dir.lock to be released..: {buffer_dir}')
            # Randomised sleep so that waiting processes do not poll in lockstep.
            time.sleep(np.random.random() * 10)
        logger.print(f'Detected the lock being released: {buffer_dir}')
        return False

    logger.print(f'The file lock is created: {buffer_dir}')

    # Fill the buffer only when the directory contains nothing but the lock.
    if len(list(buffer_dir.iterdir())) <= 1:
        return True

    # Somebody filled it earlier; nothing to do, so give the lock back.
    logger.print('Releasing the lock..')
    lock_file.unlink()
    return False


def release_file_lock(buffer_dir):
    """Remove ``dir.lock`` from ``buffer_dir``; a missing lock file is not an error."""
    from ml_logger import logger

    lock_path = Path(buffer_dir) / 'dir.lock'
    logger.print('Releasing the lock..')
    lock_path.unlink(missing_ok=True)



def collect_trajectories(agent, env, buffer_dir, num_steps, pretrain_last_step, replay,
                         random_policy=False,
                         encode_obs=False, augment=None, store_states=False, device='cuda', video_key='noname',
                         time_limit=None, remote_flock_path=None, update_flock_every=20 * 60):
    """Run a policy in ``env`` and store every transition in ``replay``.

    Args:
        agent: policy providing ``act`` (and ``encode`` when ``encode_obs``);
            may be ``None``, in which case actions are sampled randomly.
        env: gym-style environment; ``info['sim_state']`` is read when
            ``store_states`` is set.
        buffer_dir: directory whose ``*.npz`` files are counted to work out how
            many steps are still missing.
        num_steps: total number of environment steps the buffer should hold.
        pretrain_last_step: step value passed to ``agent.act``.
        replay: replay object; transitions go through ``replay.storage.add``.
        random_policy: sample actions from ``env.action_space`` even if an
            agent is given.
        encode_obs: store ``agent.encode(obs)`` instead of the raw observation.
        augment: optional callable applied to a batch of 4 copies of the
            observation before encoding.
        store_states: add ``info['sim_state']`` to each transition as ``state``.
        device: device handed to ``utils.repeat_batch``.
        video_key, remote_flock_path, update_flock_every: currently unused
            (video saving and remote-lock refresh are disabled).
        time_limit: if set, abort once ``logger.since('run')`` exceeds it.

    Returns:
        The ``replay`` object that was passed in.

    Raises:
        TimeoutError: when ``time_limit`` is exceeded (checked once per episode).

    NOTE(review): this function used to contain a debugging ``while True: sleep``
    right after ``env.reset()`` and had the ``replay.storage.add`` / ``step += 1``
    lines commented out, so it could never return. Both were restored from the
    commented-out code; confirm ``replay.storage.add(**transition)`` is still the
    intended API.
    """
    from ml_logger import logger

    buffer_dir = Path(buffer_dir)

    def manipulate_obs(observation):
        # Optionally augment, then optionally encode, a single observation.
        if augment is not None:
            num_augments = 4
            # obs: (9, 84, 84)
            # batch_obs: (4, 9, 84, 84)
            batch_obs = utils.repeat_batch(observation, batch_size=num_augments, device=device)
            observation = augment(batch_obs.float())

        if encode_obs:
            # No trailing comma here: it would wrap the array in a 1-tuple.
            observation = agent.encode(observation, to_numpy=True).squeeze()

        return observation

    def calc_remaining_steps(num_steps):
        # Steps still missing, judged from the number of files already on disk.
        if buffer_dir.is_dir():
            from .adapt import num_transitions_per_file
            num_files = len(list(buffer_dir.glob('*.npz')))
            logger.print(f'{num_files} files are found in {buffer_dir}')
            last_steps = (num_files * num_transitions_per_file) / Args.action_repeat
            return num_steps - last_steps
        return num_steps

    # The local directory lock (wait_till_acquire_lock) is currently disabled,
    # so the buffer is always filled by this process.
    fill_buffer = True

    if fill_buffer:
        logger.start('fill_buffer')
        logger.print(f"Needs to fill buffer: {buffer_dir}\nenvironment: {env}")

        remaining_steps = calc_remaining_steps(num_steps)
        logger.print(f'{remaining_steps} steps to go: {buffer_dir}')

        step = 0
        while step < remaining_steps:
            obs = env.reset()
            # NOTE: the first (reset) observation is intentionally not stored
            # since there is no info['sim_state'] for it yet; RailsWrapper
            # recovers it through its _initial_state.

            # Enforce the wall-clock limit once per episode.
            if time_limit and logger.since('run') > time_limit:
                logger.print(f'local time_limit: {time_limit} (sec) has reached!')
                num_files = len(list(buffer_dir.iterdir()))
                logger.print(f'{num_files} files have been generated in this run!')
                raise TimeoutError(f'time_limit of {time_limit} seconds exceeded')

            done = False

            # NOTE: the goal is not part of 'state'. It is only printed for now;
            # persisting it (replay.storage.add_extra) is disabled.
            if 'Reacher' in env.unwrapped.spec.id and store_states:
                print('saving extra info!!')
                goal_info = {
                    'goal_size': env.physics.named.model.geom_size['target', 0],
                    'goal_x': env.physics.named.model.geom_pos['target', 'x'],
                    'goal_y': env.physics.named.model.geom_pos['target', 'y']
                }
                print('goal_info', goal_info)

            while not done:
                if agent is None or random_policy:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad(), utils.eval_mode(agent):
                        action = agent.act(obs, pretrain_last_step, eval_mode=True)

                obs, reward, done, info = env.step(action)

                extra = {'state': info['sim_state']} if store_states else {}
                transition = dict(
                    obs=manipulate_obs(obs),
                    reward=reward,
                    done=done,
                    discount=1.0,
                    action=action,
                    **extra
                )

                replay.storage.add(**transition)
                step += 1

        # The files on disk must now account for exactly num_steps steps.
        assert (
            calc_remaining_steps(num_steps) == 0
        ), f"The number of files does not match!!: {buffer_dir}\ndetected remaining step: {calc_remaining_steps(num_steps)}"

    else:
        logger.print(f"Buffer is filled: {buffer_dir}")

        # TEMP: Verify the number of files
        if Adapt.latent_buffer_size == 1_000_000:
            num_files = len(list(buffer_dir.iterdir()))
            if num_files != 1000:
                raise RuntimeError(
                    'Donwloaded buffer does not contain 1,000 files as expected.\n'
                    f'Instead {num_files} files found in {buffer_dir}'
                )

    return replay

def collect_orig_latent_buffer(agent, buffer_dir, pretrain_last_step, flock_path):
    """Collect encoded, augmented observations from the clean training env.

    Builds a ``LatentReplay`` on ``buffer_dir`` and a clean environment seeded
    with ``Args.seed + 1000``, then delegates to ``collect_trajectories`` with a
    random policy. Nothing is returned.
    """
    from ml_logger import logger
    from .replay_buffer import LatentReplay
    from .env_helpers import get_env
    from .drqv2_invar import RandomShiftsAug

    # The remote timestamp lock is not written at the moment:
    # logger.save_torch(datetime.now(), path=flock_path)

    buffer_dir = Path(buffer_dir)
    latent_replay = LatentReplay(
        buffer_dir=buffer_dir,
        buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size,
        num_workers=1,
        store_states=False,
    )

    seed = Args.seed + 1000
    env = get_env(Args.train_env, Args.frame_stack, Args.action_repeat, seed, save_info_wrapper=True)

    collect_kwargs = dict(
        agent=agent,
        env=env,
        buffer_dir=buffer_dir,
        num_steps=Adapt.latent_buffer_size // Args.action_repeat,
        pretrain_last_step=pretrain_last_step,
        replay=latent_replay,
        store_states=False,
        augment=RandomShiftsAug(pad=4),
        video_key='original_latent_trajectory',
        encode_obs=True,
        random_policy=True,
        time_limit=Args.time_limit,
        remote_flock_path=flock_path,
    )
    collect_trajectories(**collect_kwargs)


def collect_target_obs_buffer(buffer_dir, flock_path):
    """Collect raw observations from the distracted evaluation env.

    Builds a ``Replay`` on ``buffer_dir`` and the evaluation environment with
    ``Adapt.distraction_types``, then delegates to ``collect_trajectories``
    without an agent (random actions). Nothing is returned.
    """
    from ml_logger import logger
    from .replay_buffer import Replay
    from .env_helpers import get_env

    # The remote timestamp lock is not written at the moment:
    # logger.save_torch(datetime.now(), path=flock_path)

    buffer_dir = Path(buffer_dir)
    obs_replay = Replay(
        buffer_dir,
        buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size,
        num_workers=1,
        store_states=False,
    )

    env = get_env(
        Args.eval_env, Args.frame_stack, Args.action_repeat, Args.seed,
        distraction_config=Adapt.distraction_types, rails_wrapper=False,
    )

    collect_kwargs = dict(
        agent=None,
        env=env,
        buffer_dir=buffer_dir,
        num_steps=Adapt.latent_buffer_size // Args.action_repeat,
        pretrain_last_step=-1,
        replay=obs_replay,
        store_states=False,
        augment=None,
        video_key='target_obs_trajectory',
        encode_obs=False,
        random_policy=True,
        time_limit=Args.time_limit,
        remote_flock_path=flock_path,
    )
    collect_trajectories(**collect_kwargs)

def convert_to_distrobs_trajectories(env, buffer_dir, latent_replay):
    """Generate distraction trajectories (observations) and store them in a buffer,

    by replaying the existing clean trajectories (with states) on ``env``.

    Args:
        env: environment exposing ``set_episode`` (see RailsWrapper), ``reset``,
            ``step`` and ``physics.render``.
        buffer_dir: directory of the distraction buffer; guarded by ``dir.lock``.
        latent_replay: replay holding the clean episodes; must have been created
            with ``store_states`` enabled.

    Returns:
        The ``Replay`` created on ``buffer_dir``.
    """
    from ml_logger import logger

    from .replay_buffer import Replay

    buffer_dir = Path(buffer_dir)
    fill_buffer = wait_till_acquire_lock(buffer_dir)

    try:
        distr_replay = Replay(
            buffer_dir=buffer_dir,
            buffer_size=Adapt.latent_buffer_size,
            batch_size=Args.batch_size,
            num_workers=1,
            store_states=False
        )

        num_video_episodes = 5
        # Nothing is replayed if another process already filled the directory.
        if fill_buffer:
            logger.print(f"Needs to fill buffer: {buffer_dir}\nFilling it...")
            assert latent_replay.store_states

            latent_replay._replay_buffer._try_fetch()  # load episodes
            for i, (eps, episode) in enumerate(sorted(latent_replay._replay_buffer._episodes.items())):

                env.set_episode(episode)
                env.reset()
                done = False
                obs_episode = []
                frames = []
                counter = 0
                while not done:

                    # Record video frames for the first few episodes only.
                    if i < num_video_episodes:
                        frames.append(
                            env.physics.render(height=256, width=256, camera_id=0)
                        )

                    obs, _, done, info = env.step(episode['action'][counter])
                    obs_episode.append(obs)

                    dm_env = env.unwrapped.env

                    # Verify the replayed state matches the recorded one.
                    # TODO: This fails when seeds are different !!
                    assert (episode['state'][counter] == np.array(info['sim_state'])).all(), f"{episode['state'][counter]}\nvs\n{np.array(info['sim_state'])}"
                    assert (episode['state'][counter][-1] == dm_env.physics.get_state()).all(), f"{episode['state'][counter]}\nvs\n{dm_env.physics.get_state()}"
                    counter += 1

                if frames:
                    frames.append(env.physics.render(height=256, width=256, camera_id=0))
                    logger.save_video(frames, key=f"videos/distr_buffer/{os.path.basename(eps)[:-4]}.mp4")

                # Keep action/reward/discount, replace the observations with
                # the ones rendered on the distracted environment.
                distr_episode = {key: episode[key] for key in ['action', 'reward', 'discount']}
                distr_episode.update({'observation': obs_episode})
                # Store episode
                distr_replay.storage._store_episode(distr_episode)
        else:
            logger.print(f"Buffer is filled: {buffer_dir}")
    finally:
        # Release the lock even when filling fails; otherwise every process
        # waiting in wait_till_acquire_lock would block forever.
        if fill_buffer:
            release_file_lock(buffer_dir)

    return distr_replay


def fill_both_buffers(agent, pretrain_last_step, clean_buffer_dir, distr_buffer_dir, clean_buffer_seed=100):
    """Fill the clean-latent buffer and the distracted-observation buffer.

    Args:
        agent: pre-trained agent used for encoding and, unless the configured
            policy is ``'random'``, for acting.
        pretrain_last_step: step value forwarded to ``collect_trajectories``.
        clean_buffer_dir: directory of the clean latent buffer.
        distr_buffer_dir: directory of the distracted observation buffer.
        clean_buffer_seed: seed of the clean environment.

    The policies are taken from ``Adapt.policy_on_clean_buffer`` and
    ``Adapt.policy_on_distr_buffer``; a random clean policy combined with a
    non-random distraction policy is rejected.
    """
    from .drqv2_invar import RandomShiftsAug
    from .env_helpers import get_env
    from .replay_buffer import LatentReplay, Replay

    clean_buffer_dir = Path(clean_buffer_dir)
    distr_buffer_dir = Path(distr_buffer_dir)

    def get_args(env, policy, mode, buffer_dir=None, **kwargs):
        # Build the keyword arguments of collect_trajectories for one buffer.
        # buffer_dir overrides the default directory of the selected mode.
        use_random = policy == 'random'
        num_steps = Adapt.latent_buffer_size // Args.action_repeat
        if mode == 'clean_latent_trajectory':
            target_dir = clean_buffer_dir if buffer_dir is None else buffer_dir
            replay = LatentReplay(
                buffer_dir=target_dir, buffer_size=Adapt.latent_buffer_size,
                batch_size=Args.batch_size, num_workers=1,
                store_states=not use_random,
            )
            return dict(
                agent=agent, env=env, buffer_dir=target_dir, num_steps=num_steps,
                pretrain_last_step=pretrain_last_step, replay=replay,
                store_states=not use_random, augment=RandomShiftsAug(pad=4),
                video_key='original_latent_trajectory', encode_obs=True, random_policy=use_random,
                time_limit=Args.time_limit, **kwargs
            )
        elif mode == 'distr_obs_trajectory':
            target_dir = distr_buffer_dir if buffer_dir is None else buffer_dir
            replay = Replay(
                target_dir, buffer_size=Adapt.latent_buffer_size,
                batch_size=Args.batch_size, num_workers=1, store_states=False
            )
            return dict(
                agent=None if use_random else agent, env=env, buffer_dir=target_dir, num_steps=num_steps,
                pretrain_last_step=pretrain_last_step, replay=replay, store_states=False, augment=None,
                video_key='target_obs_trajectory', encode_obs=False, random_policy=use_random,
                time_limit=Args.time_limit, **kwargs
            )
        else:
            raise ValueError(f'Invalid mode: {mode}')

    # This particular configuration is NOT allowed.
    assert not (Adapt.policy_on_clean_buffer == "random" and Adapt.policy_on_distr_buffer != "random")

    # Collect trajectories for clean-latent-buffer
    clean_env = get_env(Args.train_env, Args.frame_stack, Args.action_repeat, clean_buffer_seed, save_info_wrapper=True)
    collect_trajectories(**get_args(clean_env, policy=Adapt.policy_on_clean_buffer, mode='clean_latent_trajectory'))

    # Collect trajectories for distr-obs-buffer
    if Adapt.policy_on_distr_buffer == 'random':
        distr_env = get_env(Args.eval_env, Args.frame_stack, Args.action_repeat, Args.seed,
                            distraction_config=Adapt.distraction_types, rails_wrapper=False)
        collect_trajectories(**get_args(distr_env, policy='random', mode='distr_obs_trajectory'))
    else:
        # Collect trajectories with the pre-trained policy on the clean env,
        # then replay them on the distracted env through the rails wrapper.
        # A separate seed keeps these episodes distinct from the ones in the
        # clean-latent-buffer generated above.
        # NOTE(review): the original comment asked for Args.seed + 1000 while
        # the code uses the constant 1000 — confirm which one is intended.
        clean_seed = 1000
        replay_source_dir = clean_buffer_dir.parents[0] / f'{clean_seed}'
        clean_env = get_env(Args.train_env, Args.frame_stack, Args.action_repeat, clean_seed, save_info_wrapper=True)
        latent_replay = collect_trajectories(
            **get_args(clean_env, policy='pretrained', mode='clean_latent_trajectory', buffer_dir=replay_source_dir)
        )

        distr_env_on_rails = get_env(Args.eval_env, Args.frame_stack, Args.action_repeat, Args.seed,
                                     distraction_config=Adapt.distraction_types, rails_wrapper=True)
        convert_to_distrobs_trajectories(distr_env_on_rails, distr_buffer_dir, latent_replay)



def main(**kwargs):
    """Generate the original-latent and target-observation buffers.

    ``kwargs`` must contain ``seed``; the whole mapping is applied to ``Adapt``
    and ``Args``. The pre-trained run is located through
    ``Adapt.snapshot_prefix``; its ``Args`` (except a few keys kept from this
    run) and ``Agent`` parameters are restored and its snapshot is loaded before
    the buffers are collected.

    Raises:
        RuntimeError: if the pre-training job has no ``job.completionTime``, or
            (re-raised) if loading the snapshot fails.
    """
    from ml_logger import ML_Logger, logger

    from . import utils
    from .config import Adapt, Agent, Args, get_buffer_prefix

    # NOTE: CUDA_VISIBLE_DEVICES is set a little late by the node. Thus, putting
    # export EGL_DEVICE_ID=$CUDA_VISIBLE_DEVICES does NOT work. Thus I do it here manually.
    logger.print('pre EGL_DEVICE_ID', os.environ.get('EGL_DEVICE_ID', 'variable not found'))
    cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if cuda_devices is not None:
        # Only mirror the variable when it exists; the log lines below already
        # cope with it being absent.
        os.environ['EGL_DEVICE_ID'] = cuda_devices
    logger.print('CUDA_VISIBLE_DEVICES:', os.environ.get('CUDA_VISIBLE_DEVICES', 'variable not found'))
    logger.print('EGL_DEVICE_ID', os.environ.get('EGL_DEVICE_ID', 'variable not found'))

    logger.start('run')
    utils.set_seed_everywhere(kwargs['seed'])

    Adapt._update(kwargs)

    # TODO: use logger_prefix
    src_logger = ML_Logger(
        prefix=Adapt.snapshot_prefix
    )

    assert src_logger.prefix != logger.prefix
    logger.print('src_logger.prefix', src_logger.prefix)
    logger.print('logger.prefix', logger.prefix)

    # Check that the pre-training job has completed.
    try:
        completion_time = src_logger.read_params('job.completionTime')
    except KeyError as err:
        logger.print(f'training for {src_logger.prefix} has not been finished yet.')
        logger.print('job.completionTime was not found.')
        raise RuntimeError(f'job.completionTime was not found for {src_logger.prefix}') from err
    logger.print(f'Pre-training was completed at {completion_time}.')

    # Load from the checkpoint
    assert src_logger.glob('checkpoint.pkl')

    # Update parameters
    Args._update(kwargs)

    # Everything except these keys is overwritten by the pre-training Args.
    keep_args = ['eval_env', 'seed', 'tmp_dir', 'checkpoint_root', 'time_limit']
    Args._update({key: val for key, val in src_logger.read_params("Args").items() if key not in keep_args})
    Agent._update(**src_logger.read_params('Agent'))

    deprecated_attrs = ['use_attn', 'use_adapt']
    for attr in deprecated_attrs:
        if hasattr(Agent, attr):
            delattr(Agent, attr)

    logger.log_params(Args=vars(Args), Agent=vars(Agent), Adapt=vars(Adapt))

    # TEMP: check if the buffer is already available on s3 bucket
    # check_if_already_uploaded()

    try:
        snapshot = pJoin(Args.checkpoint_root, Adapt.snapshot_prefix, 'snapshot.pt')
        logger.print('Loading model from', snapshot)
        agent = logger.load_torch(path=snapshot)

    except RuntimeError as e:
        from .utils import get_gpumemory_info, get_cuda_variables

        # Dump node and GPU diagnostics before propagating the error.
        logger.print(
            '==================================\n'
            f'Error: {e}\n'
            f'Errored node: {os.uname().nodename}\n'
            f'CUDA_VISIBLE_DEVICES: {os.environ.get("CUDA_VISIBLE_DEVICES")}\n'
            f'CUDA*: {get_cuda_variables()}\n'
            f'gpumem: {get_gpumemory_info()}\n'
            '=================================='
        )
        raise

    # Show some stats of the pretrained model
    wall_time, frame, step = src_logger.read_metrics('wall_time@max', 'frame@max', 'step@max', default=None)
    pretrain_last_step = int(step.iloc[-1])
    logger.print('===== pre-training stats =====')
    logger.print(f'step\t{pretrain_last_step}')
    logger.print(f'frames\t{int(frame.iloc[-1])}')
    logger.print(f'wall_time\t{int(wall_time.iloc[-1])}')

    # Prepare the buffers: offline trajectories are generated under Args.tmp_dir.
    shared_buffer_dir = Path(Args.tmp_dir) / 'data-collection' / get_buffer_prefix()

    clean_buffer_dir = shared_buffer_dir / 'workdir' / "clean_latent_buffer"
    distr_buffer_dir = shared_buffer_dir / 'workdir' / "distr_obs_buffer"
    logger.print('clean_buffer_dir', clean_buffer_dir)
    logger.print('distr_buffer_dir', distr_buffer_dir)

    # Currently only this configuration is supported.
    assert Adapt.policy_on_clean_buffer == "random" and Adapt.policy_on_distr_buffer == "random"

    target_path = pJoin(Args.checkpoint_root, get_buffer_prefix(latent_buffer=True), 'orig_latent_buffer.tar')
    flock_path = pJoin(Args.checkpoint_root, get_buffer_prefix(latent_buffer=True), 'orig_latent_buffer.lock')
    # NOTE(review): the S3 existence/lock check is bypassed for this buffer, so
    # it is regenerated on every run — looks like a debugging override.
    force_regenerate = True
    if force_regenerate or should_generate_buffer(target_path, flock_path):
        collect_orig_latent_buffer(agent, clean_buffer_dir, pretrain_last_step, flock_path)

        # NOTE(review): the upload itself is disabled; only the log lines remain.
        logger.print(f'Compressing & Uploading the buffer: {target_path}')
        # upload_and_cleanup(distr_buffer_dir, target_path, flock_path, archive='tar')
        logger.print('Upload completed!')

    target_path = pJoin(Args.checkpoint_root, get_buffer_prefix(), 'target_obs_buffer.tar')
    flock_path = pJoin(Args.checkpoint_root, get_buffer_prefix(), 'target_obs_buffer.lock')
    if should_generate_buffer(target_path, flock_path):
        collect_target_obs_buffer(distr_buffer_dir, flock_path)

        # NOTE(review): the upload itself is disabled; only the log lines remain.
        logger.print(f'Compressing & Uploading the buffer: {target_path}')
        # upload_and_cleanup(distr_buffer_dir, target_path, flock_path, archive='tar')
        logger.print('Upload completed!')

    logger.print('=========== Completed!!! ==============')


# Script entry point.
# NOTE(review): main() indexes kwargs['seed'], so calling it without arguments
# as done here cannot get past that lookup — confirm how this module is launched.
if __name__ == '__main__':
    main()
