#!/usr/bin/env python3
from os.path import join as pJoin
import shutil
from pathlib import Path

import numpy as np
import torch
from params_proto.neo_proto import PrefixProto

from .utils import to_torch

# Number of transitions stored per on-disk buffer file. verify_local_buffer and
# prepare_buffers derive the expected file count of a buffer directory as
# ceil(latent_buffer_size / num_transitions_per_file).
# NOTE(review): presumably must match the chunk size used by the buffer writer — confirm.
num_transitions_per_file = 1000  # NOTE: hardcoded


# Module-level progress counters for adaptation training.
# save_snapshot writes vars(Progress) to checkpoint.pkl, and adapt_offline (whose
# `progress` argument defaults to this class) passes vars(progress) to
# logger.log_metrics_summary, so every field is kept numeric.
class Progress(PrefixProto, cli=False):
    step = 0  # adaptation step; adapt_offline's main loop resumes from this value
    wall_time = 0  # seconds since the logger's 'start' timer, refreshed every training step
    frame = 0  # progress.step * batch_size, set in adapt_offline
    invm_pretrain_done = 0  # HACK: Don't use bool. it'll be passed to logger.log_metrics.


def save_snapshot(agent, checkpoint_root, completed=True):
    """Persist ``agent`` under ``<checkpoint_root>/<logger.prefix>/`` and record progress.

    A finished run is written to ``snapshot.pt``; an intermediate one
    (``completed=False``) goes to ``snapshot_last.pt``. The module-level ``Progress``
    counters are logged to ``checkpoint.pkl`` next to it.
    """
    from ml_logger import logger

    if completed:
        snapshot_name = 'snapshot.pt'
    else:
        snapshot_name = 'snapshot_last.pt'

    snapshot_path = pJoin(checkpoint_root, logger.prefix, snapshot_name)
    logger.print('Saving agent at', snapshot_path)
    logger.save_torch(agent, path=snapshot_path)
    logger.log_params(Progress=vars(Progress), path="checkpoint.pkl", silent=True)


def load_snapshot(checkpoint_root):
    """Load the intermediate agent snapshot of the current run.

    Reads ``<checkpoint_root>/<logger.prefix>/snapshot_last.pt`` (the file written by
    ``save_snapshot(..., completed=False)``) and returns what ``logger.load_torch`` yields.
    """
    from ml_logger import logger

    snapshot_path = pJoin(checkpoint_root, logger.prefix, 'snapshot_last.pt')
    logger.print('Loading agent from', snapshot_path)
    return logger.load_torch(path=snapshot_path)


def verify_local_buffer(buffer_dir, latent_buffer_size, transitions_per_file=None):
    """Return True if ``buffer_dir`` looks like a fully populated on-disk buffer.

    The check is purely a file count: the directory must exist and contain exactly
    ``ceil(latent_buffer_size / transitions_per_file)`` entries.

    Args:
        buffer_dir: Buffer directory, as a ``pathlib.Path`` or a path string.
        latent_buffer_size: Total number of transitions the buffer should hold.
        transitions_per_file: Transitions stored per file. Defaults to the
            module-level ``num_transitions_per_file``.

    Returns:
        bool: False when the directory is missing (or is not a directory) or when the
        number of entries differs from the expected count.
    """
    import math

    # Accept plain strings too; Path(Path(...)) is a no-op for existing callers.
    buffer_dir = Path(buffer_dir)
    if not buffer_dir.is_dir():
        return False

    if transitions_per_file is None:
        transitions_per_file = num_transitions_per_file
    num_expected_files = math.ceil(latent_buffer_size / transitions_per_file)
    num_files = sum(1 for _ in buffer_dir.iterdir())
    return num_files == num_expected_files


def train_inv_dynamics(Adapt, agent, distr_replay, latent_replay, device):
    """Pre-train the agent's inverse dynamics model on latent transitions.

    Performs ``Adapt.num_invm_steps`` updates, each on one batch drawn from
    ``latent_replay.iterator``. Every ``Adapt.invm_eval_freq`` steps the model is
    evaluated with ``evaluate_invm``; every ``Adapt.invm_log_freq`` steps a metrics
    summary is logged. Side effect: refreshes the module-level ``Progress.wall_time``.
    """
    from ml_logger import logger
    from .evaluate import evaluate_invm

    for step in range(Adapt.num_invm_steps):
        Progress.wall_time = logger.since('start')

        batch = next(latent_replay.iterator)
        # Reward and discount are not needed by the inverse-dynamics update.
        latent, action, _, _, next_latent = to_torch(batch, device)
        agent.update_invm(latent, next_latent, action)

        # logger.every is stateful per key, so it is queried exactly once per step.
        should_eval = logger.every(Adapt.invm_eval_freq, key='invm_adaptation', start_on=1)
        if should_eval:
            logger.print('Running evaluation for invm pretraining...')
            logger.print('batch size len(batch)', len(batch))
            evaluate_invm(agent, latent_replay, distr_replay)
            logger.log(step=step, wall_time=Progress.wall_time)
            logger.flush()

        should_log = logger.every(Adapt.invm_log_freq, key='invm_log')
        if should_log:
            summary = {'step': step, 'wall_time': Progress.wall_time}
            logger.log_metrics_summary(key_values=summary.items(), default_stats='mean')


def adapt_offline(Adapt,
                  checkpoint_root, action_repeat, time_limit, batch_size, device,
                  orig_eval_env, targ_eval_env, agent, latent_replay, distr_replay,
                  num_eval_episodes=30, progress=Progress):
    """Run offline encoder adaptation on pre-collected buffers.

    ``orig_eval_env`` and ``targ_eval_env`` are used only for evaluation purposes.

    Stages:
      1. Save a visualization video of ``distr_replay`` episodes to
         ``videos/prefilled_distr_buffer/episode.mp4`` unless it already exists.
      2. Pre-train the inverse dynamics model unless ``progress.invm_pretrain_done`` is
         set, then save an intermediate snapshot.
      3. Adapt for ``Adapt.num_adapt_steps`` steps starting from ``progress.step``:
         adversarial encoder updates when ``Adapt.gan_lr > 0`` and inverse-dynamics
         encoder updates when ``Adapt.invm_lr > 0``, with periodic logging/evaluation.
         A final ``snapshot.pt`` is saved at the end.

    Raises:
        TimeoutError: once ``logger.since('run')`` exceeds a truthy ``time_limit``
            (seconds). An intermediate snapshot is saved first.
        ValueError: ``Adapt.gan_lr > 0`` but ``Adapt.num_discr_updates < 1`` (no target
            batch would be sampled for the encoder update).
    """
    from ml_logger import logger
    from .utils import visualize_buffer_episodes
    from .evaluate import evaluate

    # NOTE: Should be False for slurm, and True for EC2.
    # I don't want to include this flag in config.py.
    # Since I will need to regenerate sweep.jsonl just to toggle this flag,
    # even when I just want to launch the same exact experiment on slurm.
    save_periodically = False

    fname = "videos/prefilled_distr_buffer/episode.mp4"
    start_vis = logger.since('start')
    if not logger.glob(fname):
        visualize_buffer_episodes(distr_replay, fname, device)
    after_vis = logger.since('start')
    logger.print('visualization took', after_vis - start_vis)

    if not progress.invm_pretrain_done:
        logger.print('Starting invm pretraining...')
        with logger.Prefix(metrics="invm_pretrain"):
            train_inv_dynamics(Adapt, agent, distr_replay, latent_replay, device)
        logger.print('invm pretraining has finished')

        progress.invm_pretrain_done = 1  # HACK: Don't use bool. it'll be passed to logger.log_metrics.
        _save_intermediate_snapshot(agent, checkpoint_root, progress)

    # The loop target is progress.step itself, so the counter is always current and a
    # resumed run continues from the stored step.
    for progress.step in range(progress.step, Adapt.num_adapt_steps):
        progress.wall_time = logger.since('start')
        progress.frame = progress.step * batch_size

        if time_limit and logger.since('run') > time_limit:
            logger.print(f'local time_limit: {time_limit} (sec) has reached!')
            _save_intermediate_snapshot(agent, checkpoint_root, progress)
            raise TimeoutError(f'time_limit of {time_limit} sec exceeded at step {progress.step}')

        if save_periodically and logger.every(Adapt.adapt_save_every, key='adapt_save_every'):
            _save_intermediate_snapshot(agent, checkpoint_root, progress)

        if logger.every(Adapt.adapt_log_freq, key="adapt_log", start_on=1):
            logger.log_metrics_summary(key_values=vars(progress), default_stats='mean')

        # Evaluate current agent
        if logger.every(Adapt.adapt_eval_every_steps, key="adapt_eval", start_on=1):
            logger.print('========== evaluating =======')
            logger.print('orig_eval_env', orig_eval_env)
            logger.print('targ_eval_env', targ_eval_env)
            evaluate(
                agent, orig_eval_env, targ_eval_env, latent_replay, distr_replay,
                action_repeat,
                num_eval_episodes=num_eval_episodes,
                progress=progress,
                # The inverse dynamics model is evaluated only past the seed-frame budget.
                eval_invm=(progress.step * action_repeat > Adapt.num_adapt_seed_frames)
            )

        if Adapt.gan_lr > 0:
            # Adversarial training: discriminator updates, then one encoder update on the
            # observations of the last sampled target batch.
            if Adapt.num_discr_updates < 1:
                # Without this check `obs` would be unbound at update_encoder below.
                raise ValueError(
                    f'Adapt.num_discr_updates must be >= 1 when Adapt.gan_lr > 0 '
                    f'(got {Adapt.num_discr_updates})'
                )
            for _ in range(Adapt.num_discr_updates):
                batch = next(distr_replay.iterator)
                obs, action, reward, discount, next_obs = to_torch(batch, device)

                distr_latent = agent.encode(obs, augment=True, to_numpy=False)
                agent.update_discriminator(latent_replay.iterator, distr_latent, improved_wgan=Adapt.improved_wgan)

            agent.update_encoder(obs)
        else:
            batch = next(distr_replay.iterator)
            obs, action, reward, discount, next_obs = to_torch(batch, device)

        # Update inverse dynamics
        if Adapt.invm_lr > 0:
            agent.update_encoder_with_invm(obs, action, next_obs)

    logger.print('===== Training completed!! =====')
    logger.print('Saving snapshot...')
    save_snapshot(agent, checkpoint_root, completed=True)
    logger.print('Done!')


def _save_intermediate_snapshot(agent, checkpoint_root, progress):
    """Save an intermediate (``completed=False``) snapshot, logging the current progress."""
    from ml_logger import logger

    logger.print(f'Saving snapshot...\t{vars(progress)}')
    # NOTE: save_snapshot records the module-level Progress, not `progress`; the two are
    # the same object when adapt_offline is called with its default `progress`.
    save_snapshot(agent, checkpoint_root, completed=False)
    logger.print('Saving snapshot: Done!')


def startup(kwargs):
    """Configure the run from ``kwargs`` and the pre-trained agent's metadata.

    The steps are:
      1. Set the EGL device and seed RNGs with ``kwargs['seed']``, then apply ``kwargs``
         to ``Adapt``.
      2. When ``Adapt.snapshot_prefix`` is set, check that the pre-training run under that
         prefix has completed. In ``RUN.debug`` mode a missing completion only prints a
         warning.
      3. Import the agent-specific ``Args`` config selected by ``Adapt.agent_stem`` and
         update it from ``kwargs``. For ``'drqv2'``, also apply the pre-training run's
         logged ``Args``/``Agent`` params, keeping a few run-local keys.
      4. Log the resulting parameters.

    Args:
        kwargs: Parameter overrides; must contain ``'seed'``.

    Returns:
        tuple: ``(Args, Adapt)`` config classes.

    Raises:
        RuntimeError: pre-training has not completed and ``RUN.debug`` is off.
        ValueError: ``Adapt.agent_stem`` is neither ``'dmc_gen'`` nor ``'drqv2'``.
    """
    from ml_logger import ML_Logger, logger, RUN

    from .env_helpers import get_env
    from .utils import get_distr_string, set_egl_id, update_args, set_seed_everywhere
    from .agent import AdaptationAgent

    from .config import Adapt

    set_egl_id()
    set_seed_everywhere(kwargs['seed'])

    Adapt._update(kwargs)

    # Logger for the pre-training run; stays None when no snapshot_prefix is given.
    src_logger = None
    if Adapt.snapshot_prefix:
        src_logger = ML_Logger(prefix=Adapt.snapshot_prefix)

        assert src_logger.prefix != logger.prefix, \
            'Adapt.snapshot_prefix must differ from the current logger prefix'
        logger.print('src_logger.prefix', src_logger.prefix)
        logger.print('logger.prefix', logger.prefix)

        # Check if the job is completed
        try:
            completion_time = src_logger.read_params('job.completionTime')
        except KeyError:
            logger.print(f'training for {src_logger.prefix} has not been finished yet.')
            logger.print('job.completionTime was not found.')
            if not RUN.debug:
                raise RuntimeError(f'Pre-training under {src_logger.prefix} has not completed.')
        else:
            logger.print(f'Pre-training was completed at {completion_time}.')

        # Load from the checkpoint
        assert RUN.debug or src_logger.glob('checkpoint.pkl'), \
            f'checkpoint.pkl not found under {src_logger.prefix}'

    # TODO: According to the value of Adapt.agent_stem, perform conditional import!
    if Adapt.agent_stem == 'dmc_gen':
        from dmc_gen.config import Args
        Args._update(kwargs)
        logger.log_params(Args=vars(Args), Adapt=vars(Adapt))
    elif Adapt.agent_stem == 'drqv2':
        from drqv2_invariance.config import Args, Agent
        Args._update(kwargs)
        Agent._update(kwargs)

        # Update parameters
        logger.print('algorithm:', Agent.algorithm)
        if src_logger is not None:
            # Inherit the pre-training configuration, except for these run-local settings.
            keep_args = {'eval_env', 'seed', 'tmp_dir', 'checkpoint_root', 'time_limit'}
            Args._update({key: val for key, val in src_logger.read_params("Args").items() if key not in keep_args})
            Agent._update(**src_logger.read_params('Agent'))
        logger.log_params(Args=vars(Args), Agent=vars(Agent), Adapt=vars(Adapt))
    else:
        # Previously this fell through to `return Args, ...` with Args unbound.
        raise ValueError(
            f"Unsupported Adapt.agent_stem {Adapt.agent_stem!r}; expected 'dmc_gen' or 'drqv2'."
        )

    return Args, Adapt



def load_drqv2_agent(Args, Agent, Adapt):
    """Build an ``AdaptationAgent`` around the pre-trained DrQ-v2 agent.

    The pre-trained agent is loaded from
    ``<Args.checkpoint_root>/<Adapt.snapshot_prefix>/snapshot.pt``. Its ``encoder``
    followed by ``actor.trunk`` becomes the adaptable encoder, and the whole agent is
    wrapped in ``DrQV2DummyActor`` as the observation-to-action actor. ``Args.train_env``
    is instantiated only to read its space shapes.
    """
    from ml_logger import logger
    from .env_helpers import get_env
    from .agent import AdaptationAgent
    from .dummy_actors import DrQV2DummyActor

    snapshot_dir = pJoin(Args.checkpoint_root, Adapt.snapshot_prefix)

    probe_env = get_env(Args.train_env, Args.frame_stack, Args.action_repeat, Args.seed)
    # obs_shape is read for parity with load_dmcgen_agent but is not used below.
    action_shape, obs_shape = probe_env.action_space.shape, probe_env.observation_space.shape

    snapshot_path = pJoin(snapshot_dir, 'snapshot.pt')
    logger.print('Loading model from', snapshot_path)
    pretrained = logger.load_torch(path=snapshot_path)

    adapt_agent = AdaptationAgent(
        encoder=torch.nn.Sequential(pretrained.encoder, pretrained.actor.trunk),
        actor_from_obs=DrQV2DummyActor(pretrained),
        feature_dim=Agent.feature_dim,
        invm_head=None,
        algorithm='drqv2',
        action_shape=action_shape,
        hidden_dim=Adapt.adapt_hidden_dim,
        device=Args.device,
    )
    logger.print('done')

    return adapt_agent


def load_dmcgen_agent(Args, Adapt):
    """Build an ``AdaptationAgent`` around a pre-trained dmc-gen actor.

    The dmc-gen agent architecture is rebuilt with ``make_agent``, using shapes read from a
    freshly built ``Args.env_name`` environment. Actor weights are restored from
    ``<Args.checkpoint_root>/<Adapt.snapshot_prefix>/actor_state.pt``. The actor's encoder
    then becomes the adaptable encoder, and the actor is wrapped in ``DMCGENDummyActor``.
    """
    from ml_logger import logger
    from .env_helpers import get_env
    from .agent import AdaptationAgent
    # dmc-gen repository
    from .dummy_actors import DMCGENDummyActor
    from dmc_gen.algorithms.factory import make_agent

    snapshot_dir = pJoin(Args.checkpoint_root, Adapt.snapshot_prefix)

    probe_env = get_env(Args.env_name, Args.frame_stack, Args.action_repeat, Args.seed,
                        size=Args.image_size)
    action_shape = probe_env.action_space.shape
    obs_shape = probe_env.observation_space.shape

    actor_state_path = pJoin(snapshot_dir, 'actor_state.pt')
    logger.print('Loading model from', actor_state_path)
    actor_state = logger.load_torch(path=actor_state_path)

    # Recreate the architecture, then restore only the actor's weights into it.
    template = make_agent(obs_shape, action_shape, Args)
    template.actor.load_state_dict(actor_state)

    adapt_agent = AdaptationAgent(
        encoder=template.actor.encoder,
        actor_from_obs=DMCGENDummyActor(template.actor),
        feature_dim=Args.projection_dim,
        invm_head=None,
        algorithm=Adapt.agent_stem,
        action_shape=action_shape,
        hidden_dim=Adapt.adapt_hidden_dim,
        device=Args.device,
    )
    logger.print('done')

    return adapt_agent


def prepare_buffers(checkpoint_root, tmp_dir, train_env, eval_env, action_repeat, seed, Adapt):
    """Make sure the offline-adaptation buffers are available locally under ``tmp_dir``.

    Two buffers are used:

    - the original-environment latent buffer (``orig_latent_buffer``);
    - the target-environment observation buffer (``targ_obs_buffer``).

    Each is downloaded from its ``.tar`` archive under ``checkpoint_root`` and unpacked,
    unless the local copy already holds the expected number of files.

    Returns:
        tuple[Path, Path]: ``(orig_buffer_dir, targ_buffer_dir)``.

    Raises:
        AssertionError: online adaptation is requested, ``Adapt.download_buffer`` is off,
            or a remote archive is missing.
        RuntimeError: after downloading, a buffer directory is missing or does not hold
            ``ceil(Adapt.latent_buffer_size / num_transitions_per_file)`` files.
    """
    import math
    from ml_logger import logger
    from .utils import get_buffer_prefix

    # NOTE: Args.local_buffer was pointing to Args.snapshot_dir, which is /share/data/ripl/takuma/snapshots
    logger.print('prep buffers eval_env', eval_env)

    # Each prefix is computed once and reused for the local and remote paths
    # (assumes get_buffer_prefix is deterministic for identical arguments).
    targ_prefix = get_buffer_prefix(train_env, action_repeat, seed, Adapt, eval_env=eval_env)
    orig_prefix = get_buffer_prefix(
        train_env, action_repeat, seed, Adapt, eval_env=eval_env, latent_buffer=True
    )

    mode = 'offline' if Adapt.adapt_offline else 'online'
    shared_buffer_dir = Path(tmp_dir) / targ_prefix / mode

    collection_root = Path(tmp_dir) / 'data-collection'
    orig_buffer_dir = collection_root / orig_prefix / 'workdir' / 'orig_latent_buffer'
    targ_buffer_dir = collection_root / targ_prefix / 'workdir' / "targ_obs_buffer"
    logger.print('org_buffer_dir', orig_buffer_dir)
    logger.print('targ_buffer_dir', targ_buffer_dir)

    # Offline training
    assert Adapt.adapt_offline, "online adaptation is not supported right now"
    assert Adapt.download_buffer
    # NOTE(review): this directory is created but not otherwise used in this function.
    (shared_buffer_dir / "workdir").mkdir(parents=True, exist_ok=True)

    remote_orig_buffer = pJoin(checkpoint_root, orig_prefix, 'orig_latent_buffer.tar')
    remote_targ_buffer = pJoin(checkpoint_root, targ_prefix, 'targ_obs_buffer.tar')
    assert logger.glob(remote_orig_buffer), f'orig latent buffer is not found at {remote_orig_buffer}'
    assert logger.glob(remote_targ_buffer), f'targ obs buffer is not found at {remote_targ_buffer}'

    # Check each local buffer once and reuse the result for logging and the download decision.
    orig_ready = verify_local_buffer(orig_buffer_dir, Adapt.latent_buffer_size)
    targ_ready = verify_local_buffer(targ_buffer_dir, Adapt.latent_buffer_size)
    logger.print('org_latent_buffer', orig_buffer_dir)
    logger.print('verify_local_buff', orig_ready)
    logger.print('targ_obs_buffer', targ_buffer_dir)
    logger.print('verify_local_buff', targ_ready)

    if not orig_ready:
        logger.print('downloading & unpacking buffer archive from', remote_orig_buffer)
        logger.download_dir(remote_orig_buffer, to=orig_buffer_dir, unpack='tar')

    if not targ_ready:
        logger.print('downloading & unpacking buffer archive from', remote_targ_buffer)
        logger.download_dir(remote_targ_buffer, to=targ_buffer_dir, unpack='tar')

    # Verify the number of files in the downloaded buffer. A directory that is still
    # missing counts as 0 files, so the RuntimeError below fires instead of a
    # FileNotFoundError from iterdir().
    num_expected_files = math.ceil(Adapt.latent_buffer_size / num_transitions_per_file)
    num_clean_files = _count_entries(orig_buffer_dir)
    num_distr_files = _count_entries(targ_buffer_dir)
    if num_clean_files != num_expected_files or num_distr_files != num_expected_files:
        raise RuntimeError(
            f'Downloaded buffer does not contain {num_expected_files} files as expected.\n'
            f'Instead {num_clean_files} files found in {orig_buffer_dir}\n'
            f'and {num_distr_files} files found in {targ_buffer_dir}\n'
        )
    return orig_buffer_dir, targ_buffer_dir


def _count_entries(directory):
    """Return the number of entries directly inside ``directory``, or 0 if it is not a directory."""
    directory = Path(directory)
    if not directory.is_dir():
        return 0
    return sum(1 for _ in directory.iterdir())


