#!/usr/bin/env python3
from os.path import join as pJoin
import shutil
from pathlib import Path

import numpy as np
import torch
from params_proto.neo_proto import PrefixProto

from . import utils
from .config import Adapt, Agent, Args

# Assumed number of transitions stored in each replay-buffer file. Used to derive the expected
# file count ceil(Adapt.latent_buffer_size / num_transitions_per_file) when checking buffers.
num_transitions_per_file = 1000  # NOTE: hardcoded


# Run-progress counters. save_snapshot() logs vars(Progress) to checkpoint.pkl, main() restores
# them from there on resume, and evaluate()/adapt_offline() log them as metrics.
class Progress(PrefixProto, cli=False):
    step = 0  # adaptation step; adapt_offline() uses Progress.step directly as its loop variable
    wall_time = 0  # value of logger.since('start') at the latest update
    frame = 0  # set to Progress.step * Args.batch_size in adapt_offline()
    invm_pretrain_done = 0  # HACK: Don't use bool. it'll be passed to logger.log_metrics.


def save_snapshot(agent, completed=True):
    """Persist ``agent`` together with the current ``Progress`` counters.

    The agent is written with ``logger.save_torch`` to
    ``<Args.checkpoint_root>/<logger.prefix>/snapshot.pt`` when ``completed`` is true, and to
    ``.../snapshot_last.pt`` otherwise (the file ``load_snapshot`` reads when resuming).
    ``vars(Progress)`` is logged to ``checkpoint.pkl`` under the ``Progress`` key.

    Args:
        agent: object to serialize.
        completed: whether this is the final snapshot of a finished run.
    """
    from ml_logger import logger

    fname = 'snapshot.pt' if completed else 'snapshot_last.pt'
    target_path = pJoin(Args.checkpoint_root, logger.prefix, fname)
    logger.print('Saving agent at', target_path)
    logger.save_torch(agent, path=target_path)
    logger.log_params(Progress=vars(Progress), path="checkpoint.pkl", silent=True)


def load_snapshot():
    """Load the intermediate agent snapshot written by ``save_snapshot(..., completed=False)``.

    Reads ``<Args.checkpoint_root>/<logger.prefix>/snapshot_last.pt`` via ``logger.load_torch``.

    Returns:
        The deserialized agent object.
    """
    from ml_logger import logger

    target_path = pJoin(Args.checkpoint_root, logger.prefix, 'snapshot_last.pt')
    logger.print('Loading agent from', target_path)
    return logger.load_torch(path=target_path)


def verify_local_buffer(buffer_dir):
    """Check whether ``buffer_dir`` already holds a complete local copy of a buffer.

    A buffer counts as complete when ``buffer_dir`` is an existing directory whose number of
    entries equals ``ceil(Adapt.latent_buffer_size / num_transitions_per_file)``.

    Args:
        buffer_dir: ``pathlib.Path`` of the local buffer directory.

    Returns:
        bool: ``True`` if the directory exists and has the expected entry count.
    """
    import math

    if not buffer_dir.is_dir():
        return False

    expected_count = math.ceil(Adapt.latent_buffer_size / num_transitions_per_file)
    found_count = sum(1 for _entry in buffer_dir.iterdir())
    return found_count == expected_count


def get_adapt_agent_config(snapshot_dir, algorithm, obs_shape, action_shape, dmcgen_args=None):
    """Build the keyword arguments for ``AdaptationAgent`` from a pre-trained agent.

    Args:
        snapshot_dir: directory / URL prefix containing the pre-trained weights.
        algorithm: pre-training algorithm; one of ``'drqv2'``, ``'sac'``, ``'svea'``,
            ``'pad'`` or ``'soda'``. This parameter (not the global ``Agent.algorithm``)
            selects the branch.
        obs_shape: observation shape, used to instantiate the dmc_gen mock agent.
        action_shape: action shape, forwarded into the config.
        dmcgen_args: dmc_gen argument namespace passed to ``make_agent`` (needed for the
            dmc_gen algorithms).

    Returns:
        dict with keys ``encoder``, ``actor_from_obs``, ``invm_head`` (always ``None``),
        ``algorithm`` and ``action_shape``.

    Raises:
        ValueError: if ``algorithm`` is not supported.
    """
    from ml_logger import logger
    from .dummy_actors import DMCGENDummyActor, DrQV2DummyActor

    if algorithm == 'drqv2':
        snapshot_path = pJoin(snapshot_dir, 'snapshot.pt')
        logger.print('Loading model from', snapshot_path)
        pret_agent = logger.load_torch(path=snapshot_path)
        # The adaptation encoder is the pre-trained encoder followed by the actor's trunk.
        encoder = torch.nn.Sequential(pret_agent.encoder, pret_agent.actor.trunk)
        actor_from_obs = DrQV2DummyActor(pret_agent)
    elif algorithm in ('sac', 'svea', 'pad', 'soda'):
        # TODO: Check if it works with SODA!
        # NOTE: Directly loading the agent as in drqv2 is not possible because these agents
        # are trained under the dmc_gen directory, whose config.py defines a different set of Args.
        # Thus the actor's state_dict was extracted and uploaded as actor_state.pt.
        # We instantiate a mock dmc_gen agent using DMCGENArgs and run mock_agent.actor.load_state_dict().
        snapshot_path = pJoin(snapshot_dir, 'actor_state.pt')
        logger.print('Loading model from', snapshot_path)
        actor_state = logger.load_torch(path=snapshot_path)

        from dmc_gen.algorithms.factory import make_agent
        mock_agent = make_agent(obs_shape, action_shape, dmcgen_args)
        mock_agent.actor.load_state_dict(actor_state)
        encoder = mock_agent.actor.encoder
        actor_from_obs = DMCGENDummyActor(mock_agent.actor)
    else:
        raise ValueError(f'invalid algorithm: {algorithm}')

    return dict(
        encoder=encoder, actor_from_obs=actor_from_obs,
        invm_head=None, algorithm=algorithm, action_shape=action_shape,
    )


@torch.no_grad()
def evaluate_invm(agent, latent_replay, distr_replay):
    """Log the inverse-dynamics loss on distracted and on clean data, without gradients.

    Draws ``max(1, 512 // Args.batch_size)`` batches from each buffer:

    * ``distr_replay`` yields observations. They are encoded with ``agent.encode``, and the
      loss is logged as ``invm_loss`` under the ``distr_env`` metrics prefix.
    * ``latent_replay`` yields pre-computed latents with an augmentation axis at dim 1 (the
      original note gives ``(batch_size, 4, 1024)``). They are averaged over that axis, and
      the loss is logged as ``invm_loss`` under the ``clean_env`` prefix.

    Raises:
        ValueError: if a clean latent batch is not rank 3.
    """
    from ml_logger import logger

    n_samples = 512
    # Use at least one batch so that a batch_size above n_samples still yields a measurement.
    num_batches = max(1, n_samples // Args.batch_size)

    for _ in range(num_batches):
        batch = next(distr_replay.iterator)
        obs, action, _reward, _discount, next_obs = utils.to_torch(batch, agent.device)
        latent = agent.encode(obs, to_numpy=False)
        next_latent = agent.encode(next_obs, to_numpy=False)
        with logger.Prefix(metrics="distr_env"):
            invm_loss = agent.get_invm_loss(latent, next_latent, action)
            logger.log(invm_loss=invm_loss.item())

    for _ in range(num_batches):
        batch = next(latent_replay.iterator)
        latent, action, _reward, _discount, next_latent = utils.to_torch(batch, agent.device)
        if latent.ndim != 3:
            raise ValueError(
                'expected clean latents of shape (batch_size, num_augments, latent_dim), '
                f'got {tuple(latent.shape)}'
            )
        # Average the augmented copies (dim 1) into one latent per transition.
        latent = latent.mean(dim=1)
        next_latent = next_latent.mean(dim=1)

        with logger.Prefix(metrics="clean_env"):
            invm_loss = agent.get_invm_loss(latent, next_latent, action)
            logger.log(invm_loss=invm_loss.item())


@torch.no_grad()
def compare_latent_spaces(agent, latent_replay, distr_replay):
    """Log statistics comparing clean latents with latents of distracted observations.

    Draws ``max(1, 1024 // Args.batch_size)`` batches from each buffer. Distracted
    observations are encoded with ``agent.encode``. Clean latents are read from
    ``latent_replay`` and summed over dim 1 when ``Adapt.augment`` is set. Statistics are
    taken over the sample axis (dim 0), and the logged metrics are:

    * ``mean_err``: mean over latent dimensions of the squared difference of per-dimension means.
    * ``stddev_err``: the same, for per-dimension standard deviations.
    * ``stddev_diff``: mean over latent dimensions of (clean std - distracted std).

    Runs under ``torch.no_grad()`` because it only computes metrics.
    """
    from ml_logger import logger

    num_samples = 1024
    # Use at least one batch; an empty list would make torch.vstack raise.
    num_batches = max(1, num_samples // Args.batch_size)

    distr_latents = torch.vstack([
        agent.encode(next(distr_replay.iterator)[0], to_numpy=False)
        for _ in range(num_batches)
    ])

    def get_clean_latent():
        clean_latents = torch.as_tensor(next(latent_replay.iterator)[0], device=Args.device).squeeze()
        if Adapt.augment:
            # NOTE(review): evaluate_invm averages over this axis, whereas summing scales the
            # latents by the number of augmentations. Confirm which is intended here.
            clean_latents = clean_latents.sum(dim=1)
        return clean_latents

    # Get all latents from the buffer
    clean_latents = torch.vstack([get_clean_latent() for _ in range(num_batches)])
    assert clean_latents.shape == distr_latents.shape, f'{clean_latents.shape}\nvs\n{distr_latents.shape}'

    # Per-dimension std (dim=0) keeps these metrics consistent with the per-dimension mean_err.
    # A scalar std() would make the trailing .mean() calls no-ops.
    clean_std = clean_latents.std(dim=0)
    distr_std = distr_latents.std(dim=0)
    mean_err = ((clean_latents.mean(dim=0) - distr_latents.mean(dim=0)) ** 2).mean()
    stddev_err = ((clean_std - distr_std) ** 2).mean()
    stddev_diff = (clean_std - distr_std).mean()

    # Log mean_error and stddev_error
    logger.log(mean_err=mean_err.item(),
               stddev_err=stddev_err.item(),
               stddev_diff=stddev_diff.item())


def evaluate(agent, clean_eval_env, distr_eval_env, latent_replay, distr_replay, expl_agent=None):
    """Evaluate ``agent`` on the clean and distracted environments and log the results.

    1. Once ``Progress.step * Args.action_repeat`` exceeds ``Adapt.num_adapt_seed_frames``,
       log the inverse-dynamics loss (of ``expl_agent`` if given, otherwise ``agent``).
    2. If ``expl_agent`` is given, temporarily load its encoder weights into ``agent``.
    3. Roll out ``agent`` on ``clean_eval_env`` (``clean_eval`` metrics) and on
       ``distr_eval_env`` (``eval`` metrics).
    4. If ``expl_agent`` is given, roll it out on ``distr_eval_env`` (``eval_expl`` metrics).
       Then restore ``agent``'s original encoder weights, even if an evaluation raised.
    5. Log the ``Progress`` counters and flush.
    """
    import copy

    from ml_logger import logger
    from .train import eval as run_eval  # aliased so the builtin ``eval`` is not shadowed

    use_expl = expl_agent is not None

    if Progress.step * Args.action_repeat > Adapt.num_adapt_seed_frames:
        # Evaluate inv_dynamics head on clean_env latent
        # NOTE: only use encoder
        logger.print('evaluating invm')
        with logger.Prefix(metrics="eval"):
            evaluate_invm(expl_agent if use_expl else agent, latent_replay, distr_replay)

    org_enc_state = None
    if use_expl:
        # state_dict() returns tensors that share storage with the live parameters, and
        # load_state_dict() copies into those parameters in place. Without a deep copy, the
        # saved "original" weights would be overwritten by the expl weights loaded below.
        org_enc_state = copy.deepcopy(agent.encoder.state_dict())
        agent.encoder.load_state_dict(expl_agent.encoder.state_dict())

    try:
        # Evaluation on clean env
        with logger.Prefix(metrics="clean_eval"):
            logger.print('evaluating clean env')
            path = f'videos/clean/{Progress.step:09d}_eval.mp4'
            run_eval(clean_eval_env, agent, Progress.step, to_video=path if Args.save_video else None,
                     stochastic_video=None)

        # Evaluation on distr env
        with logger.Prefix(metrics="eval"):
            logger.print('evaluating distr env')
            path = f'videos/distr/{Progress.step:09d}_eval.mp4'
            run_eval(distr_eval_env, agent, Progress.step, to_video=path if Args.save_video else None,
                     stochastic_video=None)

        if use_expl:
            with logger.Prefix(metrics="eval_expl"):
                logger.print('evaluating distr env with original policy')
                path = f'videos/distr-with-expl-policy/{Progress.step:09d}_eval.mp4'
                run_eval(distr_eval_env, expl_agent, Progress.step,
                         to_video=path if Args.save_video else None, stochastic_video=None)
    finally:
        # Put back the original encoder weights
        if org_enc_state is not None:
            agent.encoder.load_state_dict(org_enc_state)

    logger.log(**vars(Progress))
    logger.flush()


def train_inv_dynamics(agent, distr_replay, latent_replay):
    """Pre-train the inverse-dynamics head for ``Adapt.num_invm_steps`` steps.

    Each step samples one batch from ``latent_replay`` and calls ``agent.update_invm``.
    Every ``Adapt.invm_eval_freq`` steps it runs ``evaluate_invm``, which also reads from
    ``distr_replay``. Every ``Adapt.invm_log_freq`` steps it writes a metrics summary.
    ``Progress.wall_time`` is refreshed on every step.
    """
    from ml_logger import logger

    for it in range(Adapt.num_invm_steps):
        Progress.wall_time = logger.since('start')
        latent, action, _rew, _disc, next_latent = utils.to_torch(
            next(latent_replay.iterator), Args.device
        )
        # NOTE(review): unlike evaluate_invm, latents are not averaged over dim 1 here.
        agent.update_invm(latent, next_latent, action)

        should_eval = logger.every(Adapt.invm_eval_freq, key='invm_adaptation', start_on=1)
        if should_eval:
            logger.print('Running evaluation for invm pretraining...')
            evaluate_invm(agent, latent_replay, distr_replay)
            logger.log(step=it, wall_time=Progress.wall_time)
            logger.flush()

        should_log = logger.every(Adapt.invm_log_freq, key='invm_log')
        if should_log:
            summary = dict(step=it, wall_time=Progress.wall_time)
            logger.log_metrics_summary(key_values=summary.items(), default_stats='mean')


def _save_progress_snapshot(agent):
    """Save a resumable (``completed=False``) snapshot, logging the current ``Progress`` counters."""
    from ml_logger import logger

    logger.print(f'Saving snapshot...\t{vars(Progress)}')
    save_snapshot(agent, completed=False)
    logger.print('Saving snapshot: Done!')


def adapt_offline(clean_eval_env, distr_eval_env, agent, latent_replay, distr_replay):
    """
    clean_env, distr_env --> only for evaluation purposes!

    Offline adaptation loop:
      1. Render the pre-filled distracted buffer to a video once (skipped if it already exists).
      2. Pre-train the inverse-dynamics head unless ``Progress.invm_pretrain_done`` is set,
         then save a resumable snapshot.
      3. For each step from ``Progress.step`` up to ``Adapt.num_adapt_steps``:
         - enforce ``Args.time_limit`` (snapshot, then ``TimeoutError``);
         - log and evaluate on schedule;
         - update the encoder adversarially (``Adapt.gan_lr > 0``) and/or through the
           inverse-dynamics loss (``Adapt.invm_lr > 0``).
      4. Save the final (``completed=True``) snapshot.

    Raises:
        ValueError: if ``Adapt.gan_lr > 0`` but ``Adapt.num_discr_updates < 1``. The encoder
            update reuses the last discriminator batch, so there would be none to use.
        TimeoutError: when ``Args.time_limit`` is exceeded.
    """
    from ml_logger import logger
    from .utils import visualize_buffer_episodes

    # NOTE: Should be False for slurm, and True for EC2.
    # I don't want to include this flag in config.py,
    # since I would need to regenerate sweep.jsonl just to toggle it,
    # even when I just want to launch the exact same experiment on slurm.
    save_periodically = False

    # Fail before the (expensive) pretraining instead of hitting an UnboundLocalError on `obs`.
    if Adapt.gan_lr > 0 and Adapt.num_discr_updates < 1:
        raise ValueError(
            f'Adapt.num_discr_updates must be >= 1 when Adapt.gan_lr > 0 '
            f'(got {Adapt.num_discr_updates})'
        )

    fname = "videos/prefilled_distr_buffer/episode.mp4"
    start_vis = logger.since('start')
    if not logger.glob(fname):
        visualize_buffer_episodes(distr_replay, fname)
    after_vis = logger.since('start')
    logger.print('visualization took', after_vis - start_vis)

    if not Progress.invm_pretrain_done:
        logger.print('Starting invm pretraining...')
        with logger.Prefix(metrics="invm_pretrain"):
            train_inv_dynamics(agent, distr_replay, latent_replay)
        logger.print('invm pretraining has finished')

        Progress.invm_pretrain_done = 1  # HACK: Don't use bool. it'll be passed to logger.log_metrics.
        _save_progress_snapshot(agent)

    # Progress.step itself is the loop variable, so every snapshot records the current step.
    for Progress.step in range(Progress.step, Adapt.num_adapt_steps):
        Progress.wall_time = logger.since('start')
        Progress.frame = Progress.step * Args.batch_size

        if Args.time_limit and logger.since('run') > Args.time_limit:
            logger.print(f'local time_limit: {Args.time_limit} (sec) has reached!')
            _save_progress_snapshot(agent)
            raise TimeoutError

        if save_periodically and logger.every(Adapt.adapt_save_every, key='adapt_save_every'):
            _save_progress_snapshot(agent)

        if logger.every(Adapt.adapt_log_freq, key="adapt_log", start_on=1):
            logger.log_metrics_summary(key_values=vars(Progress), default_stats='mean')

        # Evaluate current agent
        if logger.every(Adapt.adapt_eval_every_steps, key="adapt_eval", start_on=1):
            evaluate(agent, clean_eval_env, distr_eval_env, latent_replay, distr_replay)

        if Adapt.gan_lr > 0:
            # Adversarial training: several discriminator updates, then one encoder update
            # on the observations of the last discriminator batch.
            for _ in range(Adapt.num_discr_updates):
                batch = next(distr_replay.iterator)
                obs, action, reward, discount, next_obs = utils.to_torch(batch, Args.device)

                distr_latent = agent.encode(obs, augment=True, to_numpy=False)
                agent.update_discriminator(latent_replay.iterator, distr_latent, improved_wgan=Adapt.improved_wgan)

            agent.update_encoder(obs)
        else:
            batch = next(distr_replay.iterator)
            obs, action, reward, discount, next_obs = utils.to_torch(batch, Args.device)

        # Update inverse dynamics (uses the most recently sampled batch)
        if Adapt.invm_lr > 0:
            agent.update_encoder_with_invm(obs, action, next_obs)

    logger.print('===== Training completed!! =====')
    logger.print('Saving snapshot...')
    save_snapshot(agent, completed=True)
    logger.print('Done!')


def startup(kwargs):
    """
    Load meta-info of the pre-trained agent (Adapt.snapshot_prefix), and
    update config parameters.

    Applies ``kwargs`` to ``Args``, ``Agent`` and ``Adapt``, seeds everything with
    ``kwargs['seed']``, and checks that the pre-training run has completed. It then
    overwrites the config with the pre-training run's parameters: ``Args`` (except
    ``keep_args``) and ``Agent`` for drqv2, or ``DMCGENArgs`` for the dmc_gen algorithms.

    Args:
        kwargs: config overrides; must contain ``'seed'``.

    Raises:
        RuntimeError: if the pre-training run has no ``job.completionTime`` (unless ``RUN.debug``).
        ValueError: if ``Agent.algorithm`` is not supported.
    """
    from ml_logger import ML_Logger, logger, RUN

    from .dmc_gen_config import DMCGENArgs
    from .utils import set_egl_id, update_args

    set_egl_id()
    utils.set_seed_everywhere(kwargs['seed'])
    Args._update(kwargs)
    Agent._update(kwargs)
    Adapt._update(kwargs)

    # Logger pointing at the pre-training run.
    src_logger = ML_Logger(
        prefix=Adapt.snapshot_prefix
    )

    assert src_logger.prefix != logger.prefix
    logger.print('src_logger.prefix', src_logger.prefix)
    logger.print('logger.prefix', logger.prefix)

    # Check that the pre-training job has completed
    try:
        completion_time = src_logger.read_params('job.completionTime')
    except KeyError:
        logger.print(f'training for {src_logger.prefix} has not been finished yet.')
        logger.print('job.completionTime was not found.')
        if not RUN.debug:
            raise RuntimeError(f'pre-training for {src_logger.prefix} has not been completed')
    else:
        logger.print(f'Pre-training was completed at {completion_time}.')

    # The pre-training run must have produced a checkpoint
    assert RUN.debug or src_logger.glob('checkpoint.pkl'), \
        f'checkpoint.pkl not found under {src_logger.prefix}'

    # Update parameters
    logger.print('algorithm:', Agent.algorithm)
    if Agent.algorithm == 'drqv2':
        # drqv2_invariance
        keep_args = ['eval_env', 'seed', 'tmp_dir', 'checkpoint_root', 'time_limit']
        Args._update({key: val for key, val in src_logger.read_params("Args").items() if key not in keep_args})
        Agent._update(**src_logger.read_params('Agent'))
    elif Agent.algorithm in ['soda', 'pad', 'svea', 'sac']:
        # dmc_gen
        logger.print('loading args to DMCGENArgs...')
        # NOTE: DMCGENArgs._update(...) is impossible since DMCGENArgs *cannot be* ParamsProto.
        update_args(DMCGENArgs, src_logger.read_params('Args'))
    else:
        raise ValueError(f'Invalid algorithm: {Agent.algorithm}')

    # Remove deprecated Agent attributes (e.g. ones carried over from the source run's config).
    for attr in ('use_attn', 'use_adapt'):
        if hasattr(Agent, attr):
            delattr(Agent, attr)


def load_adaptation_agent(
        algorithm, checkpoint_root, snapshot_prefix, train_env, frame_stack, action_repeat, seed,
        feature_dim, adapt_hidden_dim, device, dmcgen_args, **ignored_kwargs
):
    """
    Load from the pretrained agent specified by snapshot_prefix.

    Builds a throwaway ``train_env`` only to read the observation/action shapes. It then
    assembles the config with ``get_adapt_agent_config`` and wraps it in an ``AdaptationAgent``.
    Extra keyword arguments are accepted and ignored, so callers can splat whole config
    namespaces.

    Returns:
        The constructed ``AdaptationAgent``.

    Raises:
        RuntimeError: re-raised unchanged after printing node and CUDA diagnostics.
    """
    from ml_logger import logger
    from .env_helpers import get_env
    from .drqv2_invar import AdaptationAgent

    try:
        snapshot_dir = pJoin(checkpoint_root, snapshot_prefix)

        dummy_env = get_env(train_env, frame_stack, action_repeat, seed)
        action_shape = dummy_env.action_space.shape
        obs_shape = dummy_env.observation_space.shape
        config = get_adapt_agent_config(snapshot_dir, algorithm, obs_shape, action_shape, dmcgen_args=dmcgen_args)
        adapt_agent = AdaptationAgent(
            **config,
            feature_dim=feature_dim,
            hidden_dim=adapt_hidden_dim,
            device=device
        )
        logger.print('done')

    except RuntimeError as e:
        import os
        from .utils import get_cuda_variables

        message = (
            '==================================\n'
            f'Error: {e}\n'
            f'Errored node: {os.uname().nodename}\n'
            f'CUDA_VISIBLE_DEVICES: {os.environ.get("CUDA_VISIBLE_DEVICES")}\n'
            f'CUDA*: {get_cuda_variables()}\n'
            '=================================='
        )
        logger.print(message)
        raise

    return adapt_agent


def main(**kwargs):
    """Run offline adaptation end to end.

    Steps:
      1. Check CUDA.
      2. Update the configs via ``startup``, then log the params and chart layout.
      3. Download (if needed) and verify the clean-latent and distracted-observation buffer
         archives from ``Args.checkpoint_root`` (must be an ``s3://`` URL).
      4. Build a new adaptation agent, or resume from ``checkpoint.pkl``.
      5. Call ``adapt_offline``.

    Args:
        **kwargs: config overrides forwarded to ``startup`` (must include ``seed``).

    Returns:
        None. Returns early when resuming a run whose ``job.completionTime`` is already set.

    Raises:
        RuntimeError: if CUDA is unavailable, or if a buffer does not contain the expected
            number of files (``startup`` may also raise RuntimeError).
    """
    import math
    from ml_logger import logger, RUN
    from .utils import get_distr_string
    from .config import get_buffer_prefix, Args, Agent, Adapt
    from .dmc_gen_config import DMCGENArgs
    from .env_helpers import get_env
    from .replay_buffer import Replay, LatentReplay

    # Fail fast, after printing node / CUDA diagnostics, when torch sees no GPU.
    if not torch.cuda.is_available():
        import os
        from .utils import get_cuda_variables

        logger.print(
            '==================================\n'
            f'Errored node: {os.uname().nodename}\n'
            f'CUDA_VISIBLE_DEVICES: {os.environ.get("CUDA_VISIBLE_DEVICES")}\n'
            f'CUDA*: {get_cuda_variables()}\n'
            '=================================='
        )

        raise RuntimeError('torch.cuda.is_available() is False!!')

    # 'run' is compared against Args.time_limit in adapt_offline; 'start' drives Progress.wall_time.
    logger.start('run', 'start')

    # Update config parameters based on kwargs
    startup(kwargs)

    logger.log_params(Args=vars(Args), Agent=vars(Agent), Adapt=vars(Adapt))
    # Chart definitions written to .charts.yml.
    logger.log_text("""
            keys:
            - Args.train_env
            - Args.seed
            charts:
            - yKey: eval/episode_reward
              xKey: step
            - yKey: clean_eval/episode_reward
              xKey: step
            - yKey: encoder_loss/mean
              xKey: step
            - yKey: distr_env/ss_loss
              xKey: step
            - yKey: clean_env/ss_loss
              xKey: step
            - yKey: focal_coef/mean
              xKey: step
            - yKeys: ["distr_env/focal_coef/mean", "clean_env/focal_coef/mean"]
              xKey: step
            - yKeys: ["discriminator_loss/mean", "discr_adpt_loss/mean", "discr_org_loss/mean"]
              xKey: step
            - yKey: adpt_reward/mean
              xKey: step
            - yKey: grad_penalty/mean
              xKey: step
            - yKey: ss_loss/mean
              xKey: frame
            - yKey: eval/mean_err
              xKey: step
            - yKeys: ["eval/stddev_err", "eval/stddev_diff"]
              xKey: step
            """, ".charts.yml", overwrite=True, dedent=True)


    # ===== Prepare the buffers =====
    # Seed for the clean buffer directory and the clean evaluation env (offset from Args.seed).
    clean_buffer_seed = Args.seed + 1000

    # NOTE: Args.local_buffer was pointing to Args.snapshot_dir, which is /share/data/ripl/takuma/snapshots
    # shared_buffer_dir = Path(Args.local_buffer) / get_buffer_prefix() / f"{'offline' if Adapt.adapt_offline else 'online'}"
    shared_buffer_dir = Path(Args.tmp_dir) / get_buffer_prefix() / f"{'offline' if Adapt.adapt_offline else 'online'}"
    clean_buffer_dir = shared_buffer_dir / 'workdir' / f"clean_latent_buffer/{clean_buffer_seed}"
    distr_str = get_distr_string(Args.eval_env, Adapt.distraction_types)
    distr_buffer_dir = shared_buffer_dir / 'workdir' / f"distr_obs_buffer/{distr_str}/{Args.seed}"
    logger.print('clean_buffer_dir', clean_buffer_dir)
    logger.print('distr_buffer_dir', distr_buffer_dir)

    # Offline training
    assert Adapt.adapt_offline, "online adaptation is not supported right now"
    assert Adapt.download_buffer
    (shared_buffer_dir / "workdir").mkdir(parents=True, exist_ok=True)

    # Both archives must exist on S3; the [5:] slice strips the 's3://' scheme before glob_s3.
    assert Args.checkpoint_root.startswith('s3://')
    s3_orig_buffer = pJoin(Args.checkpoint_root, get_buffer_prefix(latent_buffer=True), 'orig_latent_buffer.tar')
    s3_targ_buffer = pJoin(Args.checkpoint_root, get_buffer_prefix(), 'target_obs_buffer.tar')
    assert logger.glob_s3(s3_orig_buffer[5:]), f'orig latent buffer is not found at {s3_orig_buffer}'
    assert logger.glob_s3(s3_targ_buffer[5:]), f'targ obs buffer is not found at {s3_targ_buffer}'

    # TODO: haven't tested if it works yet
    # Download only when the local directory is missing or has the wrong number of files.
    if not verify_local_buffer(clean_buffer_dir):
        logger.print('downloading & unpacking buffer archive from', s3_orig_buffer)
        # logger.download_dir(s3_orig_buffer, to=clean_buffer_dir.parents[0], unpack='tar')
        logger.download_dir(s3_orig_buffer, to=clean_buffer_dir, unpack='tar')

    if not verify_local_buffer(distr_buffer_dir):
        logger.print('downloading & unpacking buffer archive from', s3_targ_buffer)
        logger.download_dir(s3_targ_buffer, to=distr_buffer_dir, unpack='tar')

    # Verify the number of files in the downloaded buffer
    num_expected_files = math.ceil(Adapt.latent_buffer_size / num_transitions_per_file)
    num_clean_files = len(list(clean_buffer_dir.iterdir()))
    num_distr_files = len(list(distr_buffer_dir.iterdir()))
    if not (num_clean_files == num_expected_files and num_distr_files == num_expected_files):
        raise RuntimeError(
            f'Donwloaded buffer does not contain {num_expected_files} files as expected.\n'
            f'Instead {num_clean_files} files found in {clean_buffer_dir}\n'
            f'and {num_distr_files} files found in {distr_buffer_dir}\n'
        )


    # ===== make or load adaptation-agent =====
    if RUN.resume and logger.glob('checkpoint.pkl'):

        # Verify that arguments are consistent
        from .utils import verify_args
        loaded_args = logger.read_params("Args")
        loaded_agent_args = logger.read_params('Agent')
        verify_args(vars(Args), loaded_args)
        verify_args(vars(Agent), loaded_agent_args)

        # Restore the Progress counters written by save_snapshot.
        Progress._update(logger.read_params(path="checkpoint.pkl"))
        try:
            # Adaptation has been already completed
            completion_time = logger.read_params('job.completionTime')
            logger.print(f'job.completionTime is set:{completion_time}\n',
                         'This job seems to have been completed already.')
            return
        except KeyError:
            # Load adaptation agent from s3
            logger.print(f'Resuming from the checkpoint. step: {Progress.step}')
            adapt_agent = load_snapshot()

            # Move the 'start' timer back by the saved wall_time, so that
            # logger.since('start') (and thus Progress.wall_time) continues from the previous run.
            logger.start('episode')
            logger.timer_cache['start'] = logger.timer_cache['episode'] - Progress.wall_time

    else:
        # Fresh run: build the adaptation agent from the pre-trained snapshot.
        adapt_agent = load_adaptation_agent(**vars(Args), **vars(Agent), **vars(Adapt), dmcgen_args=DMCGENArgs)

    latent_replay = LatentReplay(
        buffer_dir=clean_buffer_dir, buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size, num_workers=4,
        store_states=False,
    )
    distr_replay = Replay(
        buffer_dir=distr_buffer_dir,
        buffer_size=Adapt.latent_buffer_size,
        batch_size=Args.batch_size,
        num_workers=4,
        store_states=False
    )
    # Evaluation-only environments: the clean train env and the distracted eval env.
    clean_eval_env = get_env(Args.train_env, Args.frame_stack, Args.action_repeat, clean_buffer_seed)
    distr_eval_env = get_env(Args.eval_env, Args.frame_stack, Args.action_repeat, Args.seed,
                             distraction_config=Adapt.distraction_types)

    assert Adapt.adapt_offline
    adapt_offline(clean_eval_env, distr_eval_env, adapt_agent,
                  latent_replay, distr_replay)
