#!/usr/bin/env python3
import os
from os.path import join as pJoin
from pathlib import Path

import numpy as np
import torch
from gym import Wrapper
from . import utils


def upload_and_cleanup(buffer_dir, target_path, archive):
    """Upload a local buffer directory as an archive, then delete the local copy.

    Args:
        buffer_dir: local directory to upload (``str`` or ``pathlib.Path``).
        target_path: destination path handed to ``logger.upload_dir``.
        archive: archive format handed to ``logger.upload_dir`` (callers use ``'tar'``).
    """
    import shutil
    from ml_logger import logger

    # Accept plain strings as well; ``is_dir`` below needs a Path.
    buffer_dir = Path(buffer_dir)

    # Upload the buffer, honoring the requested archive format
    # (previously this was hard-coded to 'tar' and the argument was ignored).
    logger.start('uploading_tar')
    logger.upload_dir(buffer_dir, target_path, archive=archive)
    elapsed = logger.since('uploading_tar')
    logger.print(f'Upload took {elapsed} seconds')

    # Local cleanup.
    if buffer_dir.is_dir():
        logger.print('Removing the local buffer and archive...')
        shutil.rmtree(buffer_dir)


class SaveInfoWrapper(Wrapper):
    """Pass-through wrapper that reports the physics state of a clean environment.

    ``reset`` is forwarded untouched. ``step`` reads the simulator state directly
    from ``env.unwrapped.env.physics`` and returns it as the only entry of
    ``info`` (key ``'sim_state'``). This makes it explicit how and when the
    state is captured.
    """

    def reset(self):
        """Reset the wrapped environment and return its observation as-is."""
        return self.env.reset()

    def step(self, action):
        """Step the wrapped env and replace ``info`` with ``{'sim_state': state}``.

        Raises:
            AssertionError: if the wrapped env's ``info['sim_state']`` differs from
                the state read directly from physics.
        """
        observation, rew, is_done, inner_info = self.env.step(action)
        physics_state = self.env.unwrapped.env.physics.get_state()

        # If this check holds, the wrapper is technically redundant; it is kept
        # because it makes the moment the state is captured explicit.
        assert (inner_info['sim_state'] == physics_state).all(), f"{inner_info['sim_state']}\nvs\n{physics_state}"

        # The returned info carries only the freshly read sim_state.
        return observation, rew, is_done, {'sim_state': physics_state}


class RailsWrapper(Wrapper):
    """Replay a recorded trajectory on top of a DMCEnv or DistractingEnv.

    This wrapper is injected to wrap DMEnv or DistractingEnv. It completely fakes
    the output of step() and reset(), so the trajectory replays exactly what was
    registered with ``set_episode``:

    * ``reset`` restores the episode's initial physics state.
    * each ``step`` forces the physics to the next recorded state.
    * rewards are always ``0.0``.
    * ``done`` becomes True after the last recorded state.
    """

    def __init__(self, env, *args, **kwargs):
        super().__init__(env, *args, **kwargs)
        self._episode = None
        self._counter = 0
        self._state_sequence = []
        self._initial_state = None
        # Recorded Reacher goal ({'size', 'x', 'y'}); None when not applicable.
        self._goal = None

    def _get_obs_pixels(self):
        """Render the current physics state as the wrapped env's pixel observation.

        Raises:
            ValueError: if the unwrapped env is neither a DMCEnv nor a DistractingEnv.
        """
        from distracting_control.gym_env import DistractingEnv
        from gym_dmc.dmc_env import DMCEnv

        env = self.env.unwrapped
        if isinstance(env, DMCEnv):
            obs = env._get_obs_pixels()
        elif isinstance(env, DistractingEnv):
            pixel_wrapper = env.env
            obs = pixel_wrapper.physics.render(**pixel_wrapper._render_kwargs)
            if self.env.channels_first:
                # HWC -> CHW
                obs = obs.transpose([2, 0, 1])
        else:
            raise ValueError('The env under RailsWrapper is invalid:', self.env.unwrapped)
        return obs

    def _apply_goal(self):
        """Write the recorded Reacher goal (size and x/y position) into the model."""
        if self._goal is None:
            return
        model = self.env.unwrapped.env.physics.named.model
        model.geom_size['target', 0] = self._goal['size']
        # goal_x / goal_y are recorded from ``geom_pos`` during collection, so
        # they must be restored into geom_pos (not geom_size).
        model.geom_pos['target', 'x'] = self._goal['x']
        model.geom_pos['target', 'y'] = self._goal['y']

    def reset(self):
        """Reset to the initial state of the registered episode and render it.

        Raises:
            AssertionError: if ``set_episode`` was not called first.
        """
        # Check the precondition before touching the wrapped env.
        assert self._episode, 'self._episode is empty. self.set_episode(episode) must be called before reset.'
        self.env.reset()
        self._counter = 0

        suite_env = self.env.unwrapped.env
        # Re-apply the recorded goal *after* env.reset(). Values written only in
        # set_episode were observed not to persist (see the TODO that used to be
        # there). NOTE(review): presumably the task re-samples the goal on reset.
        self._apply_goal()
        suite_env.physics.set_state(self._initial_state)  # NOTE: this is in replacement of env.set_state!!!
        # set_state only copies the state vector; recompute derived quantities
        # (e.g. geometry poses used by the renderer) so the rendered observation
        # reflects the restored state and goal.
        suite_env.physics.forward()
        return self._get_obs_pixels()

    def step(self, action):
        """Step the wrapped env, then force the physics to the next recorded state.

        Returns:
            ``(obs, 0.0, done, {'sim_state': state})``. ``done`` becomes True, and
            the episode is cleared, once the last recorded state has been replayed.
        """
        obs, _, _, _ = self.env.step(action)
        # NOTE(review): ``obs`` comes from the wrapped env's own step, i.e. it is
        # rendered *before* the recorded state is written below.

        state = self._state_sequence[self._counter]
        suite_env = self.env.unwrapped.env
        suite_env.physics.set_state(state)  # NOTE: Never use env.set_state! That will call step internally.

        self._counter += 1

        done = self._counter == len(self._state_sequence)
        if done:
            self._episode = None

        # NOTE: the state after self.env.step does not match info['sim_state']
        # when random seeds differ, so info is rebuilt from the forced state.
        info = {'sim_state': suite_env.physics.get_state()}
        fake_reward = 0.0
        return obs, fake_reward, done, info

    def set_episode(self, episode, target=None):
        """Set the episode which RailsWrapper uses to replay.

        ``episode['state'][n]`` contains stacked states that correspond to
        frame-stacked observations. Always use the last element of the stacked
        state, which corresponds to the latest observation.

        Args:
            episode: mapping with a ``'state'`` sequence. For Reacher envs it must
                also contain ``'extra__goal_size'``, ``'extra__goal_x'`` and
                ``'extra__goal_y'``.
            target: unused; kept for backward compatibility.
        """
        self._episode = episode

        # NOTE: during training, just after reset, all rows of stacked frames are identical.
        # episode['state'] contains the sequence of stacked states only after the reset.
        # episode['state'][0] should look like [init_state, init_state, second_state]
        # episode['state'][1] would be         [init_state, second_state, third_state]
        self._initial_state = episode['state'][0][0]
        self._state_sequence = [stacked_states[-1] for stacked_states in episode['state']]

        if 'Reacher' in self.env.unwrapped.spec.id:
            self._goal = {
                'size': episode['extra__goal_size'],
                'x': episode['extra__goal_x'],
                'y': episode['extra__goal_y'],
            }
        else:
            self._goal = None
        # Apply immediately as well; reset() re-applies after env.reset().
        self._apply_goal()


def collect_trajectories(agent, env, buffer_dir, num_steps, pretrain_last_step, replay,
                         random_policy=False,
                         encode_obs=False, augment=None, store_states=False, device='cuda', video_key='noname',
                         time_limit=None, update_job_status_every=5 * 60):
    """Run a policy in ``env`` and store the transitions in ``replay``.

    Each transition is stored with ``replay.storage.add`` as
    (obs, reward, done, discount=1.0, action[, state]). ``obs`` is optionally
    augmented (4 repeated copies passed through ``augment``) and/or encoded with
    ``agent.encode``.

    Collection resumes where a previous run stopped: the ``*.npz`` files already
    in ``buffer_dir`` are counted and subtracted from ``num_steps``.

    Args:
        agent: object with ``act`` / ``encode``. When None, random actions are used.
        env: gym-style env exposing ``physics``, ``action_space`` and ``unwrapped.spec.id``.
        buffer_dir: directory that ``replay`` writes its episode files to.
        num_steps: target size of the buffer, in the same units as the resume count.
        pretrain_last_step: forwarded to ``agent.act``.
        replay: replay buffer whose ``storage`` receives the transitions.
        random_policy: sample from ``env.action_space`` instead of using ``agent``.
        encode_obs: store ``agent.encode(obs)`` instead of the raw observation.
        augment: optional augmentation callable applied to a batch of 4 copies.
        store_states: also store ``info['sim_state']`` (and Reacher goal info).
        device: device for the augmentation batch.
        video_key: logger key prefix for the videos of the first episodes.
        time_limit: seconds since the ``'run'`` timer after which collection aborts.
        update_job_status_every: seconds between status log messages.

    Returns:
        ``replay``.

    Raises:
        TimeoutError: when ``time_limit`` is exceeded (checked between episodes).
        RuntimeError: when an already-filled buffer fails verification.
    """
    from ml_logger import logger

    from .config import Args

    buffer_dir = Path(buffer_dir)

    def manipulate_obs(observation):
        """Optionally augment and/or encode a single observation."""
        if augment:
            num_augments = 4
            # obs: (9, 84, 84)
            # batch_obs: (4, 9, 84, 84)
            batch_obs = utils.repeat_batch(observation, batch_size=num_augments, device=device)
            observation = augment(batch_obs.float())

        if encode_obs:
            observation = agent.encode(observation, to_numpy=True).squeeze()

        return observation

    def calc_remaining_steps(num_steps):
        """Return ``num_steps`` minus the steps already stored in ``buffer_dir``."""
        if not buffer_dir.is_dir():
            return num_steps
        from .adapt import num_transitions_per_file
        num_files = len(list(buffer_dir.glob('*.npz')))
        logger.print(f'{num_files} files are found in {buffer_dir}')
        last_steps = (num_files * num_transitions_per_file) / Args.action_repeat
        return num_steps - last_steps

    remaining_steps = calc_remaining_steps(num_steps)
    assert remaining_steps >= 0, f'remaining_steps is negative!!: {remaining_steps}: {buffer_dir}'
    fill_buffer = remaining_steps > 0

    # The policy only runs if the buffer is not yet complete.
    if fill_buffer:
        logger.start('fill_buffer')
        logger.print(f"Needs to fill buffer: {buffer_dir}\nenvironment: {env}")
        logger.print(f'{remaining_steps} steps to go: {buffer_dir}')

        step = 0
        video_ep = 0
        num_video_episodes = 5
        while step < remaining_steps:
            obs = env.reset()
            # NOTE: Intentionally skip storing the first transition,
            # since we don't have info['sim_state'] for it.
            # This is handled by RailsWrapper._initial_state.

            # Check the time limit.
            if time_limit and logger.since('run') > time_limit:
                logger.print(f'local time_limit: {time_limit} (sec) has reached!')
                # The buffer directory may not exist yet if nothing was written.
                num_files = len(list(buffer_dir.iterdir())) if buffer_dir.is_dir() else 0
                logger.print(f'{num_files} files have been generated in this run!')
                raise TimeoutError

            # Periodically report elapsed time.
            split = logger.split('fill_buffer')
            if split > update_job_status_every:
                logger.print(f'logger.split: {split}\tupdating job timestamp...')

            frames = []
            done = False

            # NOTE: The goal info must be stored separately since it is not part of 'state'.
            if 'Reacher' in env.unwrapped.spec.id and store_states:
                logger.print('saving extra info!!')
                goal_info = {
                    'goal_size': env.physics.named.model.geom_size['target', 0],
                    'goal_x': env.physics.named.model.geom_pos['target', 'x'],
                    'goal_y': env.physics.named.model.geom_pos['target', 'y']
                }
                logger.print('goal_info', goal_info)
                replay.storage.add_extra(goal_info)

            while not done:
                if video_ep < num_video_episodes:
                    frames.append(
                        env.physics.render(height=256, width=256, camera_id=0)
                    )

                if agent is None or random_policy:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad(), utils.eval_mode(agent):
                        action = agent.act(obs, pretrain_last_step, eval_mode=True)

                obs, reward, done, info = env.step(action)

                extra = {'state': info['sim_state']} if store_states else {}
                transition = dict(
                    obs=manipulate_obs(obs),
                    reward=reward,
                    done=done,
                    discount=1.0,
                    action=action,
                    **extra
                )

                eps_fn = replay.storage.add(**transition)
                step += 1

            if frames:
                frames.append(env.physics.render(height=256, width=256, camera_id=0))
                logger.save_video(frames, key=f"videos/{video_key}/{eps_fn[:-4]}.mp4")
                video_ep += 1

        # Count once so the check and the message agree (and glob/log only once).
        leftover_steps = calc_remaining_steps(num_steps)
        assert (
            leftover_steps == 0
        ), f"The number of files does not match!!: {buffer_dir}\ndetected remaining step: {leftover_steps}"

    else:
        logger.print(f"Buffer is filled: {buffer_dir}")

        from .adapt import verify_local_buffer
        if not verify_local_buffer(buffer_dir):
            num_files = len(list(buffer_dir.iterdir()))
            raise RuntimeError(
                f'Downloaded buffer contains {num_files} files!! Aborting.'
            )

    return replay


def collect_orig_latent_buffer(agent, buffer_dir, pretrain_last_step=None):
    """Fill ``buffer_dir`` with augmented, encoded observations from the clean train env.

    A random policy is rolled out in ``Args.train_env``, seeded with
    ``Args.seed + 1000``. Every observation is augmented with
    ``RandomShiftsAug(pad=4)`` and encoded by ``agent`` before being stored.

    Args:
        agent: encoder used for ``agent.encode``; actions are random.
        buffer_dir: directory the replay buffer writes to.
        pretrain_last_step: forwarded to ``collect_trajectories``.

    Returns:
        The ``Replay`` buffer that was filled.
    """
    from .replay_buffer import Replay
    from .env_helpers import get_env
    from .drqv2_invar import RandomShiftsAug
    from .config import Args, Adapt

    buffer_dir = Path(buffer_dir)
    replay = Replay(
        buffer_dir, buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size, num_workers=1,
        nstep=Args.nstep, discount=Args.discount, store_states=False
    )
    clean_buffer_seed = Args.seed + 1000
    clean_env = get_env(Args.train_env, Args.frame_stack, Args.action_repeat, clean_buffer_seed, save_info_wrapper=True)
    return collect_trajectories(
        agent=agent, env=clean_env, buffer_dir=buffer_dir, num_steps=Adapt.latent_buffer_size // Args.action_repeat,
        pretrain_last_step=pretrain_last_step, replay=replay, store_states=False, augment=RandomShiftsAug(pad=4),
        video_key='original_latent_trajectory', encode_obs=True, random_policy=True,
        time_limit=Args.time_limit
    )


def collect_target_obs_buffer(buffer_dir, pretrain_last_step=None):
    """Fill ``buffer_dir`` with raw observations from the distracting eval env.

    A random policy (no agent) is rolled out in ``Args.eval_env``. Distractions
    are configured by ``Adapt.distraction_types`` and ``Adapt.distraction_intensity``.
    Observations are stored without augmentation or encoding.

    Args:
        buffer_dir: directory the replay buffer writes to.
        pretrain_last_step: forwarded to ``collect_trajectories``.

    Returns:
        The ``Replay`` buffer that was filled.
    """
    from .replay_buffer import Replay
    from .env_helpers import get_env
    from .config import Args, Adapt

    buffer_dir = Path(buffer_dir)
    replay = Replay(
        buffer_dir, buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size, num_workers=1,
        nstep=Args.nstep, discount=Args.discount, store_states=False
    )
    distr_env = get_env(Args.eval_env, Args.frame_stack, Args.action_repeat, Args.seed,
                        distraction_config=Adapt.distraction_types, rails_wrapper=False,
                        intensity=Adapt.distraction_intensity)
    return collect_trajectories(
        agent=None, env=distr_env, buffer_dir=buffer_dir, num_steps=Adapt.latent_buffer_size // Args.action_repeat,
        pretrain_last_step=pretrain_last_step, replay=replay, store_states=False, augment=None,
        video_key='target_obs_trajectory', encode_obs=False, random_policy=True,
        time_limit=Args.time_limit
    )


def main(**kwargs):
    """Collect one offline buffer (selected by ``CollectData.buffer_type``) and upload it.

    * ``'original'``: encoded latents from the clean env, uploaded as
      ``orig_latent_buffer.tar``.
    * ``'target'``: raw observations from the distracting env, uploaded as
      ``target_obs_buffer.tar``.

    Collection is skipped when the target archive already exists.

    Args:
        **kwargs: config overrides forwarded to ``Args``, ``Agent``, ``Adapt``,
            ``CollectData`` and ``startup``. ``seed`` is optional and defaults to
            ``Args.seed``.

    Raises:
        ValueError: if ``CollectData.buffer_type`` is not ``'original'`` or ``'target'``.
    """
    from ml_logger import logger
    from .utils import set_egl_id
    from .config import get_buffer_prefix, Args, Agent, Adapt, CollectData
    from .dmc_gen_config import DMCGENArgs
    from .adapt import startup, load_adaptation_agent

    from warnings import simplefilter  # noqa
    simplefilter(action='ignore', category=DeprecationWarning)

    logger.start('run')
    set_egl_id()
    # Fall back to the configured seed so that a bare `main()` (as invoked by the
    # __main__ entry point) does not raise KeyError.
    utils.set_seed_everywhere(kwargs.get('seed', Args.seed))
    Args._update(kwargs)
    Agent._update(kwargs)
    Adapt._update(kwargs)
    CollectData._update(kwargs)

    # Update config parameters based on kwargs
    startup(kwargs)

    logger.log_params(Args=vars(Args), Agent=vars(Agent), Adapt=vars(Adapt))

    # ===== Prepare the buffers =====

    # Generate offline trajectories
    # NOTE: Args.local_buffer was pointing to Args.snapshot_dir, which is /share/data/ripl/takuma/snapshots
    shared_buffer_dir = Path(Args.tmp_dir) / 'data-collection' / get_buffer_prefix(action_repeat=Args.action_repeat)

    # Currently only this configuration is supported.
    assert Adapt.policy_on_clean_buffer == "random" and Adapt.policy_on_distr_buffer == "random"

    if CollectData.buffer_type == 'original':
        clean_buffer_dir = shared_buffer_dir / 'workdir' / "clean_latent_buffer"
        logger.print('clean_buffer_dir', clean_buffer_dir)
        agent = load_adaptation_agent(**vars(Args), **vars(Agent), **vars(Adapt), dmcgen_args=DMCGENArgs)
        target_path = pJoin(Args.checkpoint_root, get_buffer_prefix(latent_buffer=True, action_repeat=Args.action_repeat), 'orig_latent_buffer.tar')
        if not logger.glob(target_path):
            collect_orig_latent_buffer(agent, clean_buffer_dir)

            logger.print(f'Compressing & Uploading the buffer: {target_path}')
            upload_and_cleanup(clean_buffer_dir, target_path, archive='tar')
            logger.print('Upload completed!')
        else:
            logger.print(f'Buffer already exists, skipping collection: {target_path}')

    elif CollectData.buffer_type == 'target':
        distr_buffer_dir = shared_buffer_dir / 'workdir' / "distr_obs_buffer"
        logger.print('distr_buffer_dir', distr_buffer_dir)
        target_path = pJoin(Args.checkpoint_root, get_buffer_prefix(action_repeat=Args.action_repeat), 'target_obs_buffer.tar')
        if not logger.glob(target_path):
            collect_target_obs_buffer(distr_buffer_dir)

            logger.print(f'Compressing & Uploading the buffer: {target_path}')
            upload_and_cleanup(distr_buffer_dir, target_path, archive='tar')
            logger.print('Upload completed!')
        else:
            logger.print(f'Buffer already exists, skipping collection: {target_path}')
    else:
        raise ValueError('invalid buffer type', CollectData.buffer_type)

    logger.print('=========== Completed!!! ==============')


if __name__ == "__main__":
    # Script entry point: run data collection with the default configuration.
    main()
