"""
This was originally based on one of the author's open source SAC implementations, 
which has been removed for anonymity, so this will not run.
However, we ended up using many of the pixel-specific design choices from SAC-AE 
(https://github.com/denisyarats/pytorch_sac_ae/blob/master/sac_ae.py), to ensure
it was an accurate reflection of the pixel-based mujoco literature.
"""

import argparse
import copy
import math
import os
from itertools import chain

import gym
import numpy as np
import tensorboardX
import torch
import torch.nn.functional as F
import tqdm
from anonymous import agents, envs, nets, replay, run, utils
from torch import nn
from torch import distributions as pyd

# Module-wide compute device: the default CUDA device when CUDA is available,
# otherwise the CPU. Networks and sampled batches below are moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` from the reals onto (-1, 1).

    Used to squash an unbounded base distribution into a bounded action
    range while keeping ``log_prob`` tractable.
    """

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        # A cache of one lets the inverse of a just-computed forward value be
        # looked up instead of recomputed through atanh.
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Return the inverse hyperbolic tangent, built from two ``log1p`` terms."""
        positive_branch = torch.log1p(x)
        negative_branch = torch.log1p(-x)
        return 0.5 * (positive_branch - negative_branch)

    def __eq__(self, other):
        # The transform has no parameters, so any two instances are equal.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        """Forward direction: squash ``x`` with tanh."""
        return torch.tanh(x)

    def _inverse(self, y):
        """Inverse direction: recover the pre-squash value from ``y``."""
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        """Return ``log|d tanh(x) / dx|`` evaluated at ``x`` (``y`` is unused).

        Written in the softplus form ``2 * (log 2 - x - softplus(-2x))`` rather
        than ``log(1 - tanh(x)^2)``.
        """
        log_two = math.log(2.0)
        return 2.0 * (log_two - x - F.softplus(-2.0 * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Diagonal Normal distribution whose samples are passed through tanh.

    Args:
        loc: mean of the underlying Normal.
        scale: standard deviation of the underlying Normal.
    """

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale
        self.base_dist = pyd.Normal(loc, scale)
        super().__init__(self.base_dist, [TanhTransform()])

    @property
    def mean(self):
        """Return ``loc`` pushed through every transform in order.

        This is the squashed mode of the base Normal, used as the
        deterministic action; it is not the true mean of the squashed
        distribution.
        """
        squashed = self.loc
        for transform in self.transforms:
            squashed = transform(squashed)
        return squashed


def weight_init(m):
    """Initialise a single module in place; intended for ``nn.Module.apply``.

    * ``nn.Linear``: orthogonal weights, zero bias.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: all weights and biases are
      zeroed, then only the centre tap of each (square) kernel receives an
      orthogonal matrix scaled by the ReLU gain.

    Layers created with ``bias=False`` are supported (their missing bias is
    simply skipped). Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # The centre-tap scheme below only makes sense for square kernels.
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain("relu")
        # Only the spatial centre of each kernel is non-zero after init.
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


class BigEncoder(nn.Module):
    """Four-layer conv encoder mapping image observations to a feature vector.

    The conv stack (3x3 kernels, 32 channels, first layer stride 2) is
    followed by a linear projection to ``out_dim``, LayerNorm and tanh, so
    every output feature lies in [-1, 1].

    Args:
        obs_shape: observation shape as ``(channels, height, width)``.
        out_dim: size of the returned feature vector.
    """

    def __init__(self, obs_shape, out_dim=50):
        super().__init__()
        channels = obs_shape[0]
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1)

        # Track the spatial size through the conv stack to size the fc layer.
        output_height, output_width = utils.compute_conv_output(
            obs_shape[1:], kernel_size=(3, 3), stride=(2, 2)
        )
        for _ in range(3):
            output_height, output_width = utils.compute_conv_output(
                (output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
            )

        self.fc = nn.Linear(output_height * output_width * 32, out_dim)
        self.ln = nn.LayerNorm(out_dim)
        self.out_dim = out_dim

        self.apply(weight_init)

    def forward(self, obs):
        """Encode a batch of images with pixel values on a 0-255 scale.

        The input tensor is NOT modified: scaling is done out of place so the
        same batch can be encoded more than once (or reused by the caller)
        without being rescaled repeatedly, and so integer image tensors are
        accepted.
        """
        obs = obs / 255.0
        x = F.relu(self.conv1(obs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        x = self.ln(x)
        state = torch.tanh(x)
        return state


class SmallEncoder(nn.Module):
    """Three-layer conv encoder (8x8/4, 4x4/2, 3x3/1 kernels/strides).

    Unlike ``BigEncoder`` the linear projection is returned directly, with no
    LayerNorm or tanh applied.

    Args:
        obs_shape: observation shape as ``(channels, height, width)``.
        out_dim: size of the returned feature vector.
    """

    def __init__(self, obs_shape, out_dim=50):
        super().__init__()
        channels = obs_shape[0]
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Track the spatial size through the conv stack to size the fc layer.
        output_height, output_width = utils.compute_conv_output(
            obs_shape[1:], kernel_size=(8, 8), stride=(4, 4)
        )

        output_height, output_width = utils.compute_conv_output(
            (output_height, output_width), kernel_size=(4, 4), stride=(2, 2)
        )

        output_height, output_width = utils.compute_conv_output(
            (output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
        )

        self.fc = nn.Linear(output_height * output_width * 64, out_dim)
        self.out_dim = out_dim

        self.apply(weight_init)

    def forward(self, obs):
        """Encode a batch of images with pixel values on a 0-255 scale.

        The input tensor is NOT modified: scaling is done out of place so the
        same batch can be encoded more than once (or reused by the caller)
        without being rescaled repeatedly, and so integer image tensors are
        accepted.
        """
        obs = obs / 255.0
        x = F.relu(self.conv1(obs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.flatten(start_dim=1)
        state = self.fc(x)
        return state


class BigActor(nn.Module):
    """Policy MLP (two hidden layers of 1024 units) producing a tanh-squashed Normal.

    Args:
        state_space_size: size of the input state representation.
        act_space_size: action dimensionality; the head emits twice this many
            values (mean and log-std).
        max_action: stored on the instance as ``max_act``; not used in ``forward``.
        log_std_low: lower bound of the log standard deviation.
        log_std_high: upper bound of the log standard deviation.
    """

    def __init__(
        self,
        state_space_size,
        act_space_size,
        max_action,
        log_std_low=-10,
        log_std_high=2,
    ):
        super().__init__()
        hidden_units = 1024
        self.fc1 = nn.Linear(state_space_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, 2 * act_space_size)
        self.max_act = max_action
        self.log_std_low = log_std_low
        self.log_std_high = log_std_high
        self.apply(weight_init)

    def forward(self, state):
        """Return the action distribution for a batch of state representations."""
        hidden = F.relu(self.fc1(state))
        hidden = F.relu(self.fc2(hidden))
        # First half of the head is the mean, second half the raw log-std.
        mu, raw_log_std = torch.chunk(self.fc3(hidden), 2, dim=1)
        # tanh bounds the raw value to (-1, 1); it is then mapped affinely
        # onto (log_std_low, log_std_high).
        bounded = torch.tanh(raw_log_std)
        log_std = self.log_std_low + 0.5 * (self.log_std_high - self.log_std_low) * (
            bounded + 1
        )
        return SquashedNormal(mu, log_std.exp())


class BigCritic(nn.Module):
    """Q-function MLP (two hidden layers of 1024 units) over (state, action) pairs.

    Args:
        state_space_size: size of the input state representation.
        act_space_size: action dimensionality.
    """

    def __init__(self, state_space_size, act_space_size):
        super().__init__()
        hidden_units = 1024
        self.fc1 = nn.Linear(state_space_size + act_space_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, 1)

        self.apply(weight_init)

    def forward(self, state, action):
        """Return one value per row for the concatenated state/action batch."""
        joint_input = torch.cat((state, action), dim=1)
        hidden = F.relu(self.fc1(joint_input))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


class PixelSACAgent(agents.SACAgent):
    """SAC agent for image observations: a conv encoder feeding an actor and twin critics.

    Device placement, train/eval mode and (de)serialisation of everything
    except the encoder are delegated to the ``agents.SACAgent`` base class.
    """

    def __init__(self, encoder, act_space_size, max_action):
        # NOTE(review): the base-class initialiser is not called here; the
        # networks are built directly on the instance instead.
        self.encoder = encoder
        feature_dim = encoder.out_dim
        self.actor = BigActor(feature_dim, act_space_size, max_action)
        self.critic1 = BigCritic(feature_dim, act_space_size)
        self.critic2 = BigCritic(feature_dim, act_space_size)
        self.max_act = max_action

    def forward(self, obs):
        """Pick the deterministic (distribution-mean) action for evaluation."""
        processed_obs = self.process_state(obs)
        with torch.no_grad():
            features = self.encoder.forward(processed_obs)
            greedy_action = self.actor.forward(features).mean
        return self.process_act(greedy_action)

    def sample_action(self, obs):
        """Draw a stochastic action for an already-batched observation tensor.

        Unlike ``forward`` this does not run ``process_state`` / ``process_act``
        and does not disable gradient tracking itself.
        """
        features = self.encoder.forward(obs)
        return self.actor.forward(features).sample()

    def to(self, device):
        """Move the encoder, then whatever the base class manages, to ``device``."""
        self.encoder = self.encoder.to(device)
        super().to(device)

    def eval(self):
        """Switch the encoder and the base-class networks to eval mode."""
        self.encoder.eval()
        super().eval()

    def train(self):
        """Switch the encoder and the base-class networks to train mode."""
        self.encoder.train()
        super().train()

    def save(self, path):
        """Write the encoder weights to ``<path>/encoder.pt``, then defer to the base class."""
        torch.save(self.encoder.state_dict(), os.path.join(path, "encoder.pt"))
        super().save(path)

    def load(self, path):
        """Restore the encoder weights from ``<path>/encoder.pt``, then defer to the base class."""
        encoder_state = torch.load(os.path.join(path, "encoder.pt"))
        self.encoder.load_state_dict(encoder_state)
        super().load(path)


def get_grad_norm(model):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters without a gradient are skipped, so a model that has never been
    back-propagated through yields ``0.0``.
    """
    squared_sum = 0.0
    for parameter in model.parameters():
        grad = getattr(parameter, "grad", None)
        if grad is None:
            continue
        squared_sum += grad.data.norm(2).item() ** 2
    return squared_sum ** 0.5


def pixel_sac(
    agent,
    buffer,
    train_envs,
    test_env,
    augmenter,
    num_steps=250_000,
    transitions_per_step=1,
    max_episode_steps_start=1000,
    max_episode_steps_final=1000,
    max_episode_steps_anneal=1.,
    batch_size=256,
    mlp_tau=0.005,
    encoder_tau=0.01,
    actor_lr=1e-3,
    critic_lr=1e-3,
    encoder_lr=1e-3,
    alpha_lr=1e-4,
    gamma=0.99,
    eval_interval=10_000,
    test_eval_episodes=10,
    train_eval_episodes=10,
    warmup_steps=1000,
    actor_clip=None,
    critic_clip=None,
    actor_l2=0.0,
    critic_l2=0.0,
    encoder_l2=0.0,
    delay=2,
    save_interval=10_000,
    name="pixel_sac_run",
    render=False,
    save_to_disk=True,
    log_to_disk=True,
    debug_logs=False,
    verbosity=0,
    gradient_updates_per_step=1,
    init_alpha=0.1,
    feature_matching_imp=1.0,
    aug_mix=0.9,
    **kwargs,
):
    """Train ``agent`` with Soft Actor-Critic on pixel observations.

    Each outer step collects transitions from every environment in
    ``train_envs`` into ``buffer``, runs ``gradient_updates_per_step`` calls to
    ``learn_from_pixels``, and periodically evaluates, logs and checkpoints.

    Args:
        agent: agent to train; must expose ``encoder``, ``actor``, ``critic1``,
            ``critic2``, ``sample_action``, ``to``, ``train``, ``eval`` and ``save``.
        buffer: replay buffer that receives transitions through ``push``.
        train_envs: sequence of environments stepped in lockstep.
        test_env: environment passed to ``run.evaluate_agent`` for test returns.
        augmenter: forwarded unchanged to ``learn_from_pixels``.
        num_steps: number of outer training steps.
        transitions_per_step: environment steps taken (in every train env) per
            outer step.
        max_episode_steps_start: episode step limit at step 0.
        max_episode_steps_final: value the episode step limit is capped at.
        max_episode_steps_anneal: fraction of ``num_steps`` over which the limit
            moves linearly from the start value to the final value.
        batch_size, gamma, actor_clip, critic_clip, feature_matching_imp,
            aug_mix: forwarded unchanged to ``learn_from_pixels``.
        mlp_tau: ``utils.soft_update`` coefficient for the target critics.
        encoder_tau: ``utils.soft_update`` coefficient for the target encoder.
        actor_lr, critic_lr, encoder_lr, alpha_lr: Adam learning rates.
        actor_l2, critic_l2, encoder_l2: Adam ``weight_decay`` values.
        eval_interval: evaluate every this many outer steps (and on the last step).
        test_eval_episodes, train_eval_episodes: episode counts passed to
            ``run.evaluate_agent``.
        warmup_steps: total warm-up budget, split across the train envs with
            integer division before being passed to ``run.warmup_buffer``.
        delay: the actor, alpha and target networks are only updated on outer
            steps where ``step % delay == 0``.
        save_interval: checkpoint every this many outer steps.
        name: run name passed to ``utils.make_process_dirs``.
        render: forwarded to ``run.evaluate_agent``.
        save_to_disk: whether to write agent checkpoints.
        log_to_disk: whether to create a tensorboard writer and log returns.
        debug_logs: additionally log every entry of the learner's info dict.
        verbosity: any truthy value wraps the step loop in a tqdm progress bar.
        gradient_updates_per_step: learner calls per outer step.
        init_alpha: initial entropy coefficient (stored as its log).
        **kwargs: accepted but not used by the training logic.

    Returns:
        The trained ``agent`` (the same object that was passed in).
    """
    if save_to_disk or log_to_disk:
        save_dir = utils.make_process_dirs(name)
    # create tb writer, save hparams
    if log_to_disk:
        writer = tensorboardX.SummaryWriter(save_dir)
        # NOTE(review): locals() at this point contains every argument,
        # including non-scalar objects (agent, buffer, envs, augmenter) --
        # confirm add_hparams accepts them.
        writer.add_hparams(locals(), {})

    agent.to(device)
    agent.train()

    # initialize target networks (target actor isn't used in SAC)
    target_agent = copy.deepcopy(agent)
    target_agent.to(device)
    utils.hard_update(target_agent.critic1, agent.critic1)
    utils.hard_update(target_agent.critic2, agent.critic2)
    utils.hard_update(target_agent.encoder, agent.encoder)
    target_agent.train()

    # create network optimizers
    # (both critics share one optimizer; the encoder and actor get their own)
    critic_optimizer = torch.optim.Adam(
        chain(agent.critic1.parameters(), agent.critic2.parameters(),),
        lr=critic_lr,
        weight_decay=critic_l2,
        betas=(0.9, 0.999),
    )
    encoder_optimizer = torch.optim.Adam(
        agent.encoder.parameters(),
        lr=encoder_lr,
        weight_decay=encoder_l2,
        betas=(0.9, 0.999),
    )
    actor_optimizer = torch.optim.Adam(
        agent.actor.parameters(),
        lr=actor_lr,
        weight_decay=actor_l2,
        betas=(0.9, 0.999),
    )

    # initialize learnable alpha param
    # (optimised in log space so alpha = exp(log_alpha) stays positive)
    log_alpha = torch.Tensor([math.log(init_alpha)]).to(device)
    log_alpha.requires_grad = True
    log_alpha_optimizer = torch.optim.Adam([log_alpha], lr=alpha_lr, betas=(0.5, 0.999))

    # entropy target is the negative of the action dimensionality
    target_entropy = -train_envs[0].action_space.shape[0]

    # warmup the replay buffer with randomly sampled actions
    num_actors = len(train_envs)
    for actor_idx, train_env in enumerate(train_envs):
        run.warmup_buffer(
            buffer, train_env, warmup_steps // num_actors, max_episode_steps_start
        )

    # per-environment episode state; done=True forces a reset on the first step
    done = [True for actor in range(num_actors)]
    obs = [None for actor in range(num_actors)]
    steps_this_ep = [0 for actor in range(num_actors)]

    # linear schedule for the episode step limit, capped at the final value
    # NOTE(review): because of the min(), a final value below the start value
    # takes effect immediately instead of being annealed towards.
    max_episode_steps_slope = (max_episode_steps_final - max_episode_steps_start) / (
        max_episode_steps_anneal * num_steps
    )
    current_max_episode_steps = lambda _step: int(
        min(
            max_episode_steps_final,
            max_episode_steps_slope * _step + max_episode_steps_start,
        )
    )

    steps_iter = range(num_steps)
    if verbosity:
        steps_iter = tqdm.tqdm(steps_iter)

    for step in steps_iter:
        for _ in range(transitions_per_step):
            # reset each env if necessary
            for actor_idx, train_env in enumerate(train_envs):
                if done[actor_idx]:
                    obs[actor_idx] = train_env.reset()
                    steps_this_ep[actor_idx] = 0
                    done[actor_idx] = False

            # batch the actions
            # (one forward pass produces an action for every environment)
            obs_tensor = torch.Tensor(obs).float().to(device)
            agent.eval()
            with torch.no_grad():
                action_tensor = agent.sample_action(obs_tensor)
            action = action_tensor.cpu().numpy()
            agent.train()

            # step each env and add experience to the buffer
            for actor_idx, train_env in enumerate(train_envs):
                next_obs, reward, done[actor_idx], info = train_env.step(
                    action[actor_idx]
                )
                # allow infinite bootstrapping
                # The transition that hits the step limit is stored with
                # done=False so the limit is not recorded as a terminal state.
                # NOTE(review): this also overwrites a done=True reported by
                # the env on that same step.
                if steps_this_ep[actor_idx] + 1 == current_max_episode_steps(step):
                    done[actor_idx] = False
                buffer.push(
                    obs[actor_idx],
                    action[actor_idx],
                    reward,
                    next_obs,
                    done[actor_idx],
                )
                obs[actor_idx] = next_obs
                steps_this_ep[actor_idx] += 1
                # after the push, flag the env as done so it is reset next time
                if steps_this_ep[actor_idx] >= current_max_episode_steps(step):
                    done[actor_idx] = True

        # delayed policy updates: actor/alpha/targets only move every `delay` steps
        update_policy = step % delay == 0
        for _ in range(gradient_updates_per_step):
            learning_info = learn_from_pixels(
                buffer=buffer,
                target_agent=target_agent,
                agent=agent,
                actor_optimizer=actor_optimizer,
                critic_optimizer=critic_optimizer,
                encoder_optimizer=encoder_optimizer,
                log_alpha=log_alpha,
                log_alpha_optimizer=log_alpha_optimizer,
                target_entropy=target_entropy,
                batch_size=batch_size,
                gamma=gamma,
                critic_clip=critic_clip,
                actor_clip=actor_clip,
                update_policy=update_policy,
                augmenter=augmenter,
                feature_matching_imp=feature_matching_imp,
                aug_mix=aug_mix,
            )

            # move target model towards training model
            if update_policy:
                utils.soft_update(target_agent.critic1, agent.critic1, mlp_tau)
                utils.soft_update(target_agent.critic2, agent.critic2, mlp_tau)
                utils.soft_update(target_agent.encoder, agent.encoder, encoder_tau)

            # every gradient update within one outer step is logged at the
            # same x-axis value (step * transitions_per_step)
            if log_to_disk and debug_logs:
                for log_key, log_val in learning_info.items():
                    writer.add_scalar(
                        f"training/{log_key}", log_val, step * transitions_per_step
                    )

        # periodic evaluation, always including the final step
        if step % eval_interval == 0 or step == num_steps - 1:
            mean_test_return = run.evaluate_agent(
                agent, test_env, test_eval_episodes, max_episode_steps_final, render
            )
            # the first training env doubles as the "train" evaluation env
            mean_train_return = run.evaluate_agent(
                agent,
                train_envs[0],
                train_eval_episodes,
                max_episode_steps_final,
                render,
            )
            if log_to_disk:
                writer.add_scalar(
                    "performance/test_return",
                    mean_test_return,
                    step * transitions_per_step,
                )
                writer.add_scalar(
                    "performance/train_return",
                    mean_train_return,
                    step * transitions_per_step,
                )

        if step % save_interval == 0 and save_to_disk:
            agent.save(save_dir)

    # final checkpoint once training has finished
    if save_to_disk:
        agent.save(save_dir)
    return agent


def learn_from_pixels(
    buffer,
    target_agent,
    agent,
    actor_optimizer,
    critic_optimizer,
    encoder_optimizer,
    log_alpha_optimizer,
    target_entropy,
    log_alpha,
    augmenter,
    batch_size=256,
    gamma=0.99,
    critic_clip=None,
    actor_clip=None,
    update_policy=True,
    feature_matching_imp=1e-5,
    aug_mix=0.9,
):
    """Run one SAC gradient update on a batch sampled from ``buffer``.

    The critics and the encoder are always updated; the actor and the entropy
    coefficient are only updated when ``update_policy`` is true. Target
    networks are not modified here.

    Args:
        buffer: replay buffer; a ``replay.PrioritizedReplayBuffer`` additionally
            yields importance weights and has its priorities refreshed.
        target_agent: supplies the target encoder and target critics.
        agent: the agent being trained (encoder, actor, critic1, critic2).
        actor_optimizer: optimizer stepped for the actor loss.
        critic_optimizer: optimizer stepped for the critic loss.
        encoder_optimizer: optimizer stepped together with the critic optimizer.
        log_alpha_optimizer: optimizer stepped for the alpha loss.
        target_entropy: entropy target used in the alpha loss.
        log_alpha: log of the entropy coefficient.
        augmenter: callable applied as ``augmenter(obs, next_obs)`` returning the
            augmented pair.
        batch_size: number of transitions requested from the buffer.
        gamma: discount factor used in the TD target.
        critic_clip: if truthy, max gradient norm for the two critics.
        actor_clip: if truthy, max gradient norm for the actor.
        update_policy: whether to run the actor and alpha updates.
        feature_matching_imp: weight of the feature-matching loss; values
            ``<= 0`` disable it.
        aug_mix: fraction of the batch (taken from the front) replaced by its
            augmented version.

    Returns:
        dict mapping metric names to logged values (losses, gradient norms,
        mean TD target, and actor/alpha statistics when the policy was updated).
    """

    log_dict = {}

    per = isinstance(buffer, replay.PrioritizedReplayBuffer)
    if per:
        batch, imp_weights, priority_idxs = buffer.sample(batch_size)
        imp_weights = imp_weights.to(device)
    else:
        batch = buffer.sample(batch_size)

    # sample unaugmented transitions from the buffer
    og_obs_batch, action_batch, reward_batch, og_next_obs_batch, done_batch = batch
    og_obs_batch = og_obs_batch.to(device)
    og_next_obs_batch = og_next_obs_batch.to(device)
    # at this point, the obs batches are float32s [0., 255.] on the gpu
    # (more precisely: on `device`, which falls back to the CPU without CUDA)

    # create an augmented version of each transition
    # the augmenter applies a random transformation to each batch index,
    # but keeps the random params consistent between obs and next_obs batches
    # NOTE(review): that consistency is a property of the augmenter, which is
    # not visible here -- confirm against its implementation.
    aug_obs_batch, aug_next_obs_batch = augmenter(og_obs_batch, og_next_obs_batch)

    # mix the augmented versions in with the standard
    # no need to shuffle because the replay buffer handles that
    # (the first `aug_mix_idx` rows become augmented, the rest stay original;
    # the clones keep the original batches intact for the feature-matching loss)
    aug_mix_idx = int(batch_size * aug_mix)
    obs_batch = og_obs_batch.clone()
    obs_batch[:aug_mix_idx] = aug_obs_batch[:aug_mix_idx]
    next_obs_batch = og_next_obs_batch.clone()
    next_obs_batch[:aug_mix_idx] = aug_next_obs_batch[:aug_mix_idx]

    action_batch = action_batch.to(device)
    reward_batch = reward_batch.to(device)
    done_batch = done_batch.to(device)

    alpha = torch.exp(log_alpha)

    with torch.no_grad():
        # create critic targets (clipped double Q learning)
        # next-state features come from the target encoder, next actions from
        # the current (online) actor
        next_state_rep = target_agent.encoder(next_obs_batch)
        action_dist_s1 = agent.actor(next_state_rep)
        action_s1 = action_dist_s1.rsample()
        logp_a1 = action_dist_s1.log_prob(action_s1).sum(-1, keepdim=True)

        target_action_value_s1 = torch.min(
            target_agent.critic1(next_state_rep, action_s1),
            target_agent.critic2(next_state_rep, action_s1),
        )
        # soft Bellman target: the entropy bonus enters as -alpha * log pi
        td_target = reward_batch + gamma * (1.0 - done_batch) * (
            target_action_value_s1 - (alpha * logp_a1)
        )

    # update critics with Bellman MSE
    # state_rep keeps its graph so the critic loss also trains the encoder
    state_rep = agent.encoder(obs_batch)
    agent_critic1_pred = agent.critic1(state_rep, action_batch)
    td_error1 = td_target - agent_critic1_pred
    if per:
        # assumes imp_weights broadcasts row-wise against the TD errors -- TODO confirm
        critic1_loss = (imp_weights * 0.5 * (td_error1 ** 2)).mean()
    else:
        critic1_loss = 0.5 * (td_error1 ** 2).mean()

    agent_critic2_pred = agent.critic2(state_rep, action_batch)
    td_error2 = td_target - agent_critic2_pred
    if per:
        critic2_loss = (imp_weights * 0.5 * (td_error2 ** 2)).mean()
    else:
        critic2_loss = 0.5 * (td_error2 ** 2).mean()

    # optional feature matching loss to make state_rep invariant to augs
    # (gradients flow only through the augmented branch; the original-image
    # features are computed without gradient tracking and act as the target)
    if feature_matching_imp > 0.0:
        aug_rep = agent.encoder(aug_obs_batch)
        with torch.no_grad():
            og_rep = agent.encoder(og_obs_batch)
        fm_loss = torch.norm(aug_rep - og_rep)
    else:
        fm_loss = 0.0

    critic_loss = critic1_loss + critic2_loss + feature_matching_imp * fm_loss

    critic_optimizer.zero_grad()
    encoder_optimizer.zero_grad()
    critic_loss.backward()
    # only the critic parameters are clipped; encoder gradients are left as-is
    if critic_clip:
        torch.nn.utils.clip_grad_norm_(
            chain(agent.critic1.parameters(), agent.critic2.parameters(),), critic_clip,
        )
    critic_optimizer.step()
    encoder_optimizer.step()

    if update_policy:
        # actor update
        # the state representation is detached so the actor loss never
        # back-propagates into the encoder
        dist = agent.actor(state_rep.detach())
        agent_actions = dist.rsample()
        logp_a = dist.log_prob(agent_actions).sum(-1, keepdim=True)

        actor_loss = -(
            torch.min(
                agent.critic1(state_rep.detach(), agent_actions),
                agent.critic2(state_rep.detach(), agent_actions),
            )
            - (alpha.detach() * logp_a)
        ).mean()

        actor_optimizer.zero_grad()
        # NOTE(review): this backward pass also accumulates gradients on the
        # critic parameters. They are never stepped (the next call zeroes them
        # first), but the critic grad norms logged below include them.
        actor_loss.backward()
        if actor_clip:
            torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), actor_clip)
        actor_optimizer.step()

        # alpha update
        # the log-prob term is detached, so only log_alpha receives a gradient
        alpha_loss = (-alpha * (logp_a + target_entropy).detach()).mean()
        log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        log_alpha_optimizer.step()

        log_dict.update(
            {
                "actor_loss": actor_loss,
                "alpha_loss": alpha_loss,
                "actor_grad_norm": get_grad_norm(agent.actor),
                "alpha": alpha,
                "mean_action": agent_actions.mean().item(),
            }
        )

    if per:
        # new priorities come from critic1's absolute TD error; the small
        # constant keeps every priority strictly positive
        new_priorities = (abs(td_error1) + 1e-5).cpu().data.squeeze(1).numpy()
        buffer.update_priorities(priority_idxs, new_priorities)

    # the logged critic loss excludes the feature-matching term, which is
    # reported under its own key
    log_dict.update(
        {
            "critic_loss": critic1_loss + critic2_loss,
            "critic1_grad_norm": get_grad_norm(agent.critic1),
            "critic2_grad_norm": get_grad_norm(agent.critic2),
            "encoder_grad_norm": get_grad_norm(agent.encoder),
            "mean_td_target": td_target.mean(),
            "feature_matching_loss": fm_loss,
        }
    )
    return log_dict


def add_args(parser):
    """Register the pixel-SAC command line options on ``parser``.

    Option names, types and defaults are unchanged; only help texts were
    corrected (spelling, and the three ``max_episode_steps_*`` options now
    describe what each one actually controls).

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group) to extend
            in place.
    """
    parser.add_argument(
        "--num_steps", type=int, default=250_000, help="Number of training steps.",
    )
    parser.add_argument(
        "--transitions_per_step",
        type=int,
        default=1,
        help=(
            "Env transitions per training step. Defaults to 1, but will need to "
            "be set higher for replay ratios < 1"
        ),
    )
    parser.add_argument(
        "--max_episode_steps_start",
        type=int,
        default=1000,
        help="Maximum steps per episode at the start of training",
    )
    parser.add_argument(
        "--max_episode_steps_final",
        type=int,
        default=1000,
        help="Maximum steps per episode once annealing has finished",
    )
    parser.add_argument(
        "--max_episode_steps_anneal",
        type=float,
        default=0.4,
        help=(
            "Fraction of num_steps over which the episode step limit is "
            "linearly annealed from max_episode_steps_start to "
            "max_episode_steps_final"
        ),
    )
    parser.add_argument(
        "--batch_size", type=int, default=128, help="Training batch size"
    )
    parser.add_argument(
        "--mlp_tau",
        type=float,
        default=0.01,
        help="Determines how quickly the target agent's critic networks params catch up to the trained agent.",
    )
    parser.add_argument(
        "--encoder_tau",
        type=float,
        default=0.05,
        help="Determines how quickly the target agent's encoder network params catch up to the trained agent. This is typically set higher than mlp_tau because the encoder is used in both actor and critic updates.",
    )
    parser.add_argument(
        "--actor_lr", type=float, default=1e-3, help="Actor network learning rate",
    )
    parser.add_argument(
        "--critic_lr", type=float, default=1e-3, help="Critic networks' learning rate",
    )
    parser.add_argument(
        "--gamma", type=float, default=0.99, help="POMDP discount factor",
    )
    parser.add_argument(
        "--init_alpha",
        type=float,
        default=0.1,
        help="Initial entropy regularization coefficient.",
    )
    parser.add_argument(
        "--alpha_lr",
        type=float,
        default=1e-4,
        help="Alpha (entropy regularization coefficient) learning rate",
    )
    parser.add_argument(
        "--buffer_size",
        type=int,
        default=100_000,
        help="Replay buffer maximum capacity. Note that image observations can take up a lot of memory, especially when using frame stacking. The buffer allocates a large tensor of zeros to fail fast if it will not have enough memory to complete the training run.",
    )
    parser.add_argument(
        "--eval_interval",
        type=int,
        default=10_000,
        help="How often to test the agent without exploration (in training steps)",
    )
    parser.add_argument(
        "--test_eval_episodes",
        type=int,
        default=10,
        help="How many episodes to run for when evaluating on the testing set",
    )
    parser.add_argument(
        "--train_eval_episodes",
        type=int,
        default=10,
        help="How many episodes to run for when evaluating on the training set",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=1000,
        help="Number of uniform random actions to take at the beginning of training",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        help="Flag to enable env rendering during training",
    )
    parser.add_argument(
        "--actor_clip",
        type=float,
        default=None,
        help="Gradient clipping for actor updates",
    )
    parser.add_argument(
        "--critic_clip",
        type=float,
        default=None,
        help="Gradient clipping for critic updates",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="pixel_sac_run",
        help="Dir name for saves, (look in ./saves/{name})",
    )
    parser.add_argument(
        "--actor_l2",
        type=float,
        default=0.0,
        help="L2 regularization coeff for actor network",
    )
    parser.add_argument(
        "--critic_l2",
        type=float,
        default=0.0,
        help="L2 regularization coeff for critic network",
    )
    parser.add_argument(
        "--delay",
        type=int,
        default=2,
        help="How many steps to go between actor and target agent updates",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=10_000,
        help="How many steps to go between saving the agent params to disk",
    )
    parser.add_argument(
        "--verbosity",
        type=int,
        default=1,
        help="Verbosity > 0 displays a progress bar during training",
    )
    parser.add_argument(
        "--gradient_updates_per_step",
        type=int,
        default=1,
        help="How many gradient updates to make per training step",
    )
    parser.add_argument(
        "--prioritized_replay",
        action="store_true",
        help="Flag that enables use of prioritized experience replay",
    )
    parser.add_argument(
        "--skip_save_to_disk",
        action="store_true",
        help="Flag to skip saving agent params to disk during training",
    )
    parser.add_argument(
        "--encoderCls",
        type=str,
        default="BigEncoder",
        help="Class name of the encoder architecture to use",
    )
    parser.add_argument(
        "--skip_log_to_disk",
        action="store_true",
        help="Flag to skip saving agent performance logs to disk during training",
    )
    parser.add_argument(
        "--feature_matching_imp",
        type=float,
        default=0.001,
        help="Coefficient for feature matching loss",
    )
    parser.add_argument(
        "--encoder_lr",
        type=float,
        default=1e-3,
        help="Learning rate for the encoder network",
    )
    parser.add_argument(
        "--encoder_l2",
        type=float,
        default=0.0,
        help="Weight decay coefficient for pixel encoder network",
    )
    parser.add_argument(
        "--debug_logs",
        action="store_true",
        help="Flag that logs things like loss values and grad norms to tensorboard, for debugging",
    )
    parser.add_argument(
        "--aug_mix",
        type=float,
        default=1.0,
        help="Fraction of each update batch that is made up of augmented samples",
    )
