import itertools
import lzma
import os
import pandas as pd
import pickle

from collections import defaultdict
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import (
    Dict,
    List,
    Optional,
)

import gym
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import torch
import torch.nn.functional as F

from torch.distributions.normal import Normal
from torch.optim import Adam

import transfer.sac.core as core
import wandb

from input_args import sac_parse_args
from transfer.envs.metaworld import (
    get_mt_env,
    get_single_env,
    get_stitched_env,
)
from transfer.sac.cka import feature_space_linear_cka
from transfer.sac.ewc import EWCHelper
from transfer.sac.logx import (
    EpochLogger,
    setup_logger_kwargs,
)
from transfer.sac.utils import (
    append_new_heads,
    custom_leaky_relu,
    load_tf_weights,
    remove_heads,
)
from transfer.utils.utils import set_seed
from transfer.envs.utils.wrappers import META_WORLD_TIME_HORIZON

# Ordered Meta-World v2 task names forming the task sequences used for
# training (CW presumably refers to the Continual World benchmark).
CW10 = [
    "hammer-v2",
    "push-wall-v2",
    "faucet-close-v2",
    "push-back-v2",
    "stick-pull-v2",
    "handle-press-side-v2",
    "push-v2",
    "shelf-place-v2",
    "window-close-v2",
    "peg-unplug-side-v2",
]

# Shorter sequences are prefixes of CW10.
CW2, CW5 = CW10[:2], CW10[:5]

# Sequence name -> ordered list of task names.
SEQUENCES = dict(CW2=CW2, CW5=CW5, CW10=CW10)


def l2_dist(first_ac: Dict[str, torch.tensor], second_ac: Dict[str, torch.tensor]):
    """Return the mean squared difference between matching tensors of two dicts.

    Keys are taken from ``first_ac`` (in its iteration order); each must also be
    present in ``second_ac``. Every value of the result is a 0-d tensor.
    """
    return {name: torch.mean((first_ac[name] - second_ac[name]) ** 2) for name in first_ac}


def l2_act_diff(first_ac: Dict[str, torch.tensor], second_ac: Dict[str, torch.tensor]):
    """Per-key mean squared difference between two activation dictionaries.

    Iterates over the keys of ``first_ac``; ``second_ac`` must contain each of
    them. Returns a dict of 0-d tensors keyed like ``first_ac``.
    """
    diffs = {}
    for layer, first in first_ac.items():
        delta = first - second_ac[layer]
        diffs[layer] = torch.square(delta).mean()
    return diffs


def cka_diff(first_activations, second_activations):
    """Linear CKA similarity, per layer, between two sets of activations.

    Keys come from ``first_activations``; ``second_activations`` must hold a
    tensor for each of them. Tensors are converted with ``.numpy()`` before
    being passed to ``feature_space_linear_cka``.
    """
    return {
        layer: feature_space_linear_cka(acts.numpy(), second_activations[layer].numpy())
        for layer, acts in first_activations.items()
    }


@torch.no_grad()
def gather_activations(ac, data, batch_size=200):
    """Record the outputs of every ``torch.nn.Linear`` layer of ``ac`` on ``data``.

    Runs ``ac.pi(obs)``, ``ac.q1(obs, actions)`` and ``ac.q2(obs, actions)`` over
    ``data`` in mini-batches and captures each Linear layer's output through a
    temporary forward hook. The hooks are always removed, even if a forward
    pass raises.

    Args:
        ac: module exposing ``pi``, ``q1`` and ``q2`` callables; its submodules
            whose type is exactly ``torch.nn.Linear`` are hooked.
        data: mapping with array-like ``"obs"`` and ``"actions"`` of equal length.
        batch_size: number of rows per forward pass.

    Returns:
        defaultdict mapping each hooked layer's qualified module name to a CPU
        tensor with its outputs concatenated along dim 0, in call order.
    """
    activations = defaultdict(list)

    def save_activation(name, mod, inp, out):
        activations[name].append(out.cpu())

    # Exact type match (not isinstance): subclasses of Linear are not hooked.
    handles = [
        module.register_forward_hook(partial(save_activation, name))
        for name, module in ac.named_modules()
        if type(module) is torch.nn.Linear
    ]
    try:
        for idx in range(0, len(data["obs"]), batch_size):
            obs = torch.as_tensor(data["obs"][idx : idx + batch_size], dtype=torch.float32)
            actions = torch.as_tensor(data["actions"][idx : idx + batch_size], dtype=torch.float32)

            ac.pi(obs)
            ac.q1(obs, actions)
            ac.q2(obs, actions)
    finally:
        # Detach the hooks unconditionally so a failed pass does not leave
        # ``ac`` recording activations forever.
        for handle in handles:
            handle.remove()

    for key, chunks in activations.items():
        activations[key] = torch.cat(chunks, dim=0)

    return activations

def _pca_project(x: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project the rows of ``x`` onto its ``n_components`` leading principal components.

    Args:
        x: 2-D array, samples in rows and features in columns.
        n_components: number of principal directions to keep.

    Returns:
        Array of shape ``(len(x), min(n_components, x.shape[1]))`` whose column 0
        is the coordinate along the direction of largest variance.
    """
    centered = x - x.mean(axis=0)
    if len(centered) < 2:
        # The sample covariance is undefined for fewer than two rows.
        return centered[:, :n_components]
    # The covariance matrix is symmetric, so ``eigh`` returns real eigenpairs
    # (``eig`` can return complex values and leaves eigenvalues unsorted).
    eigvals, eigvecs = np.linalg.eigh(np.atleast_2d(np.cov(centered, rowvar=False)))
    leading = np.argsort(eigvals)[::-1][:n_components]
    # Eigenvectors are the *columns* of ``eigvecs``; project as X @ V.
    return centered @ eigvecs[:, leading]


def _plot_likelihood_pca(projected_obs, dones, likelihoods):
    """Plot PCA-projected trajectories coloured by log-likelihood and return the figure.

    A trajectory starts at index 0 and right after every ``done`` flag; samples
    after the last ``done`` (an unfinished trajectory) are not plotted.
    """
    traj_starts = [0] + (np.flatnonzero(dones) + 1).tolist()

    cmap = plt.get_cmap("viridis")
    # Colour range is clamped to [-50, 5]; clipping vmin into that range keeps
    # vmin <= vmax even when every likelihood exceeds 5.
    vmin = float(np.clip(likelihoods.min(), -50.0, 5.0)) if len(likelihoods) else -50.0
    norm = mpl.colors.Normalize(vmin=vmin, vmax=5.0)

    fig, ax = plt.subplots()
    for start, end in zip(traj_starts[:-1], traj_starts[1:]):
        traj_obs = projected_obs[start:end]
        traj_likelihoods = likelihoods[start:end]

        ax.plot(traj_obs[:, 0], traj_obs[:, 1], c="k", linewidth=0.2, zorder=-5)
        # A star marks the first state of each trajectory.
        ax.scatter(traj_obs[:1, 0], traj_obs[:1, 1], s=30.0, c=traj_likelihoods[:1],
                   cmap=cmap, norm=norm, marker="*", zorder=10)
        ax.scatter(traj_obs[1:, 0], traj_obs[1:, 1], s=10.0, c=traj_likelihoods[1:],
                   cmap=cmap, norm=norm, zorder=10)
    # An explicit mappable keeps the colorbar working when nothing was plotted.
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax)
    return fig


@torch.no_grad()
def measure_likelihood(ac, data, batch_size):
    """Score the stored actor means under the current policy and log the results to wandb.

    For every stored transition, the mean ``target_mu`` saved in the first half
    of ``data.actor_dist_buf`` is evaluated under the current policy's Gaussian
    ``Normal(mu, exp(log_std))`` and a tanh change-of-variables correction is
    subtracted.

    Logged with ``wandb.log(..., commit=False)``, per task id:
      * ``log_likelihood/<task>/avg``: mean log-likelihood;
      * ``log_likelihood/<task>/per_timestep``: line plot of the per-timestep mean;
      * ``log_likelihood/<task>/pca``: trajectories projected on the first two
        principal components of the observations, coloured by log-likelihood.

    Args:
        ac: actor-critic whose ``pi(obs, return_dist=True)`` returns ``(_, _, mu, log_std)``.
        data: EpisodicMemory-like object created with ``save_targets=True``;
            only its first ``data.size`` (filled) rows are used.
        batch_size: number of rows per forward pass.

    Returns:
        DataFrame with one row per transition and columns
        ``task_id``, ``timestep``, ``likelihood`` and ``dones``.
    """
    num_rows = data.size  # rows past ``size`` were never written
    act_dim = data.actions_buf.shape[-1]
    obs_buf = data.obs_buf[:num_rows]
    task_ids = data.task_id_buf[:num_rows]
    dones = data.done_buf[:num_rows]

    chunks = []
    for start in range(0, num_rows, batch_size):
        end = min(start + batch_size, num_rows)
        obs = torch.as_tensor(obs_buf[start:end], dtype=torch.float32)

        target_dist = torch.as_tensor(data.actor_dist_buf[start:end], dtype=torch.float32)
        target_mu, _ = target_dist.split(act_dim, dim=-1)
        _, _, mu, log_std = ac.pi(obs, return_dist=True)

        logp = Normal(mu, log_std.exp()).log_prob(target_mu).sum(dim=-1)
        # tanh-squashing correction; assumes target_mu is the pre-squash mean.
        logp -= (2 * (np.log(2) - target_mu - F.softplus(-2 * target_mu))).sum(dim=1)
        chunks.append(logp)
    # [num_rows] numpy array, so pandas, numpy masks and matplotlib all accept it.
    if chunks:
        log_likelihoods = torch.cat(chunks, dim=0).cpu().numpy()
    else:
        log_likelihoods = np.zeros(0, dtype=np.float32)

    df = pd.DataFrame(
        data={
            "task_id": task_ids,
            "timestep": data.timestep_buf[:num_rows],
            "likelihood": log_likelihoods,
            "dones": dones,
        }
    )
    df_per_timestep = df.groupby(["task_id", "timestep"]).mean().reset_index()
    per_task_avg = df.groupby("task_id")["likelihood"].mean()

    results_dict = {}
    for task_id, avg_likelihood in per_task_avg.items():
        key = f"log_likelihood/{task_id}"
        results_dict[f"{key}/avg"] = avg_likelihood

        task_df = df_per_timestep[df_per_timestep["task_id"] == task_id].copy().sort_values(by=["timestep"])
        table = wandb.Table(data=task_df)
        plot = wandb.plot.line(table, "timestep", "likelihood",
                               title=f"Log likelihoods per timestep for task {task_id}")
        results_dict[f"{key}/per_timestep"] = plot

        task_mask = task_ids == task_id
        fig = _plot_likelihood_pca(
            _pca_project(obs_buf[task_mask]), dones[task_mask], log_likelihoods[task_mask]
        )
        # Convert to a wandb.Image before releasing the figure.
        results_dict[f"{key}/pca"] = wandb.Image(fig)
        plt.close(fig)

    wandb.log(results_dict, commit=False)

    return df




class ReplayBuffer:
    """
    A simple FIFO experience replay buffer for SAC agents.

    Transitions live in preallocated float32 numpy arrays used as a ring: once
    ``max_size`` transitions have been written, new ones overwrite the oldest.
    """

    def __init__(self, obs_dim, act_dim, size):
        obs_shape = core.combined_shape(size, obs_dim)
        act_shape = core.combined_shape(size, act_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.obs2_buf = np.zeros(obs_shape, dtype=np.float32)
        self.act_buf = np.zeros(act_shape, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # Write head, number of valid entries, and capacity.
        self.ptr = 0
        self.size = 0
        self.max_size = size

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the write head and advance it (wrapping around)."""
        slot = self.ptr
        for target, value in (
            (self.obs_buf, obs),
            (self.obs2_buf, next_obs),
            (self.act_buf, act),
            (self.rew_buf, rew),
            (self.done_buf, done),
        ):
            target[slot] = value
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` stored transitions uniformly, with replacement, as float32 tensors."""
        idxs = np.random.randint(0, self.size, size=batch_size)
        sources = {
            "obs": self.obs_buf,
            "obs2": self.obs2_buf,
            "act": self.act_buf,
            "rew": self.rew_buf,
            "done": self.done_buf,
        }
        return {name: torch.as_tensor(buf[idxs], dtype=torch.float32) for name, buf in sources.items()}


class EpisodicMemory:
    """Buffer which does not support overwriting old samples.

    Transitions are appended in order by ``store_multiple`` until ``max_size``
    is reached. With ``save_targets=True`` it also stores, per sample, the
    actor's distribution parameters and both critics' predictions. With a
    truthy ``one_hot_len`` it also records each sample's task id (argmax of the
    trailing one-hot) and timestep.
    """

    def __init__(self, obs_dim: int, act_dim: int, size: int,
                 save_targets: bool = False, one_hot_len: Optional[int] = None) -> None:
        """Preallocate ``size`` zero-filled slots.

        Args:
            obs_dim: observation dimension(s), combined with ``size`` via
                ``core.combined_shape``.
            act_dim: action dimensionality.
            size: maximum number of transitions.
            save_targets: also allocate buffers for actor/critic targets.
            one_hot_len: length of the task one-hot at the end of each
                observation; ``None`` or ``0`` disables task-id extraction.
        """
        self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.next_obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.actions_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
        self.rewards_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.save_targets = save_targets
        self.one_hot_len = one_hot_len

        # Per-sample task id / timestep; stay 0 when one_hot_len is falsy.
        self.task_id_buf = np.zeros([size], dtype=int)
        self.timestep_buf = np.zeros([size], dtype=int)
        if self.save_targets:
            # Actor distribution parameters: two act_dim-wide halves concatenated
            # on the last axis (sac's gather_buffer stores [mu, log_std]).
            self.actor_dist_buf = np.zeros([size, act_dim * 2], dtype=np.float32)
            self.critic1_pred_buf = np.zeros([size], dtype=np.float32)
            self.critic2_pred_buf = np.zeros([size], dtype=np.float32)
        # Number of stored samples and capacity.
        self.size, self.max_size = 0, size

    def store_multiple(
        self,
        obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        next_obs: np.ndarray,
        done: np.ndarray,
        actor_dists: Optional[np.ndarray] = None,
        critic1_preds: Optional[np.ndarray] = None,
        critic2_preds: Optional[np.ndarray] = None,
    ) -> None:
        """Append a batch of transitions after the already stored ones.

        Raises:
            AssertionError: if the inputs differ in length, the batch does not
                fit in the remaining capacity, or targets are missing while
                ``save_targets`` is set.
        """
        assert len(obs) == len(actions) == len(rewards) == len(next_obs) == len(done)
        assert self.size + len(obs) <= self.max_size

        range_start = self.size
        range_end = self.size + len(obs)
        self.obs_buf[range_start:range_end] = obs
        self.next_obs_buf[range_start:range_end] = next_obs
        self.actions_buf[range_start:range_end] = actions
        self.rewards_buf[range_start:range_end] = rewards
        self.done_buf[range_start:range_end] = done

        # 0 and None both mean "no one-hot": ``obs[:, -0:]`` would otherwise
        # select the whole observation and produce meaningless task ids.
        if self.one_hot_len:
            obs_np = np.asarray(obs)
            task_ids = np.argmax(obs_np[:, -self.one_hot_len:], axis=1)
            # NOTE(review): assumes the feature right before the one-hot is a
            # timestep normalized by META_WORLD_TIME_HORIZON — confirm with the env wrapper.
            timestep = np.around(obs_np[:, -self.one_hot_len - 1] * META_WORLD_TIME_HORIZON)

            self.task_id_buf[range_start:range_end] = task_ids
            self.timestep_buf[range_start:range_end] = timestep

        if self.save_targets:
            assert (actor_dists is not None) and (critic1_preds is not None) and (critic2_preds is not None)
            self.actor_dist_buf[range_start:range_end] = actor_dists
            self.critic1_pred_buf[range_start:range_end] = critic1_preds
            self.critic2_pred_buf[range_start:range_end] = critic2_preds
        self.size = self.size + len(obs)

    def sample_batch(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """Sample up to ``batch_size`` distinct stored transitions, uniformly without replacement."""
        batch_size = min(batch_size, self.size)
        idxs = np.random.choice(self.size, size=batch_size, replace=False)
        return self.get_by_indices(idxs)

    def get_by_indices(self, idxs) -> Dict[str, torch.Tensor]:
        """Return the transitions at ``idxs`` as float32 tensors (plus targets when stored)."""
        batch_dict = dict(
            obs=torch.as_tensor(self.obs_buf[idxs], dtype=torch.float32),
            next_obs=torch.as_tensor(self.next_obs_buf[idxs], dtype=torch.float32),
            actions=torch.as_tensor(self.actions_buf[idxs], dtype=torch.float32),
            rewards=torch.as_tensor(self.rewards_buf[idxs], dtype=torch.float32),
            done=torch.as_tensor(self.done_buf[idxs], dtype=torch.float32),
        )

        if self.save_targets:
            batch_dict["actor_dists"] = torch.as_tensor(self.actor_dist_buf[idxs], dtype=torch.float32)
            batch_dict["critic1_preds"] = torch.as_tensor(self.critic1_pred_buf[idxs], dtype=torch.float32)
            batch_dict["critic2_preds"] = torch.as_tensor(self.critic2_pred_buf[idxs], dtype=torch.float32)

        return batch_dict


def sac(
    env: gym.Env,
    test_envs: List[gym.Env],
    ac: core.MLPActorCritic,
    ac_targ: core.MLPActorCritic,
    seed: int = 0,
    steps: int = 1_000_000,
    log_every: int = 20_000,
    replay_size: int = 1_000_000,
    train_alpha: bool = True,
    gamma=0.99,
    polyak=0.995,
    lr=1e-3,
    batch_size=128,
    start_steps=10000,
    update_after=1000,
    update_every=50,
    num_test_eps_stochastic=10,
    num_test_eps_deterministic=10,
    num_render_eps_stochastic=0,
    num_render_eps_deterministic=0,
    max_ep_len=1000,
    logger_kwargs=None,
    save_freq=5,
    bootstrap_on_time_limit=False,
    num_tasks=1,
    one_hot_len=0,
    done_on_transition=False,
    checkpoint=None,
    memory=None,
    memory_envs=None,
    apply_kl_loss=False,
    reg_method=None,
    actor_memory_weight=1.0,
    critic_memory_weight=1.0,
    reset_weights_after_memory=False,
    save_buffer=False,
):
    """
    Soft Actor-Critic (SAC), optionally with an episodic memory used for
    distillation losses and/or EWC / L2 weight regularization.


    Args:
        env: Training env

        test_envs: A list of testing envs on which we'll evaluate throughout the training.

        ac: A PyTorch Module with an ``act``
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of
            observations as inputs, and ``q1`` and ``q2`` should accept a batch
            of observations and a batch of actions as inputs. When called,
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly: gradients
                                           | should be able to flow back into ``a``.
            ===========  ================  ======================================

            ``ac`` must also provide ``get_alpha(obs)`` -> ``(alpha, log_alpha)``,
            ``all_log_alpha`` and ``target_entropy``, and ``pi(obs, return_dist=True)``
            returning ``(_, _, mu, log_std)``.

        ac_targ: Target network for critics. Its parameters are frozen and only
            updated by polyak averaging towards ``ac``.

        seed (int): Seed for random number generators.

        steps (int): Total number of environment interactions.

        log_every (int): Number of env steps between evaluation/logging rounds
            (one "epoch" in the logs).

        replay_size (int): Maximum length of replay buffer.

        train_alpha (bool): If True, ``ac.all_log_alpha`` is learned with its own
            Adam optimizer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for policy, value and alpha learning).

        batch_size (int): Minibatch size for SGD (also for memory minibatches).

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long
            you wait between updates, the ratio of env steps to gradient steps
            is locked to 1.

        num_test_eps_deterministic (int): Number of episodes per test env used to
            test the deterministic policy at every logging round.

        num_test_eps_stochastic (int): Number of episodes per test env used to
            test the stochastic policy at every logging round.

        num_render_eps_deterministic, num_render_eps_stochastic (int): Number of
            episodes per test env rendered to video at every logging round.
            0 disables rendering; otherwise ``logger_kwargs["output_dir"]`` is required.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger (``None`` = no args).

        save_freq (int): Number of logging rounds between ``logger.save_state`` calls.

        bootstrap_on_time_limit (bool): If True, transitions that end because of
            the time limit (``max_ep_len`` or ``info["TimeLimit.truncated"]``)
            are stored as non-terminal.

        num_tasks (int): Not used by this function.

        one_hot_len (int): Length of the task one-hot at the end of observations;
            forwarded to EpisodicMemory (0 disables task-id extraction).

        done_on_transition (bool): If True, store ``done=True`` whenever
            ``info["transition"]`` is set.

        checkpoint: Not used; restoring optimizer state from it is disabled.

        memory (int): If not None, number of transitions gathered from
            ``memory_envs`` with the current stochastic policy before training,
            labelled with the policy's ``[mu, log_std]`` and both critics'
            predictions. A separate 1000-transition validation memory is also
            gathered and pickled to ``validation_memory.pickle``.

        memory_envs (list): Envs used to fill the memories (required with ``memory``).

        apply_kl_loss (bool): If True, every update adds, on a memory minibatch,
            a Gaussian KL term for the actor and a squared-error term for the critics.

        reg_method (str): ``"ewc"`` or ``"l2"`` enables an EWCHelper weight
            regularizer (requires ``memory``); anything else disables it.

        actor_memory_weight, critic_memory_weight (float): Weights of the actor /
            critic memory terms; also passed to EWCHelper.

        reset_weights_after_memory (bool): Re-initialise ``ac`` and re-create
            ``ac_targ`` after the memories have been gathered.

        save_buffer (bool): If True, lzma-pickle the replay buffer to
            ``buffer.pickle`` at the end of training.

    """
    if logger_kwargs is None:
        logger_kwargs = {}
    logger = EpochLogger(**logger_kwargs)

    set_seed(seed)

    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    # Whether a weight-space regularizer (EWC or plain L2) is active.
    use_weight_reg = reg_method in ("ewc", "l2")

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # Parameters of both Q-networks. Materialized as a list: an itertools.chain
    # iterator would be exhausted by the optimizer constructor below, turning
    # the freeze/unfreeze loops in ``update`` into silent no-ops.
    q_params = list(itertools.chain(ac.q1.parameters(), ac.q2.parameters()))

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)

    # Count variables
    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log("\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n" % var_counts)

    def compute_loss_q(data):
        """Clipped double-Q Bellman MSE for both critics on a replay minibatch."""
        o, a, r, o2, d = data["obs"], data["act"], data["rew"], data["obs2"], data["done"]
        alpha, _ = ac.get_alpha(o)

        q1 = ac.q1(o, a)
        q2 = ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = ac.pi(o2)

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        return loss_q1 + loss_q2

    def compute_kl_critic(memory_data):
        """Squared error between the critics' predictions and the stored memory predictions."""
        critic1_pred = ac.q1(memory_data["obs"], memory_data["actions"])
        critic1_loss = (critic1_pred - memory_data["critic1_preds"]) ** 2

        critic2_pred = ac.q2(memory_data["obs"], memory_data["actions"])
        critic2_loss = (critic2_pred - memory_data["critic2_preds"]) ** 2

        return critic1_loss.mean() + critic2_loss.mean()

    def compute_loss_pi(data):
        """Entropy-regularized SAC policy loss on a replay minibatch."""
        o = data["obs"]
        alpha, _ = ac.get_alpha(o)

        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        return (alpha * logp_pi - q_pi).mean()

    def compute_kl_pi(memory_data):
        """KL(stored Gaussian || current Gaussian), summed over action dims, averaged over the batch."""
        eps = 1e-6
        o = memory_data["obs"]
        _, _, mu, log_std = ac.pi(o, return_dist=True)
        target_mu, target_log_std = memory_data["actor_dists"].split(act_dim, dim=-1)

        var = (torch.exp(log_std) + eps) ** 2
        target_var = (torch.exp(target_log_std) + eps) ** 2

        log_std_term = log_std - target_log_std
        mu_term = (target_var + (target_mu - mu) ** 2) / (2 * var)

        return (log_std_term + mu_term - 0.5).sum(-1).mean()

    def compute_loss_alpha(data):
        """Temperature loss pushing the policy entropy towards ``ac.target_entropy``."""
        o = data["obs"]
        _, log_alpha = ac.get_alpha(o)

        _, logp_pi = ac.pi(o)
        return -(log_alpha * (logp_pi + ac.target_entropy).detach()).mean()

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q_optimizer = Adam(q_params, lr=lr)
    if train_alpha:
        alpha_optimizer = Adam([ac.all_log_alpha], lr=lr)

    # NOTE: ``checkpoint`` is currently unused; restoring the optimizer states
    # ("pi_optimizer", "q_optimizer", "alpha_optimizer") from it is disabled.

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def compute_reg_loss():
        """EWCHelper regularization term, or 0.0 when no weight regularizer is configured."""
        if use_weight_reg:
            return reg_helper.regularization_loss()
        return 0.0

    def update(data, memory_data=None):
        """One SAC update: critics, then actor (critics frozen), then alpha, then polyak."""
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q = compute_loss_q(data)
        if memory_data is not None:
            loss_q = (loss_q + critic_memory_weight * compute_kl_critic(memory_data)) / (1.0 + critic_memory_weight)
        loss_q = loss_q + compute_reg_loss()

        loss_q.backward()
        q_optimizer.step()
        logger.store({"train/loss_q": loss_q.item()})

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi = compute_loss_pi(data)
        if memory_data is not None:
            loss_pi = (loss_pi + actor_memory_weight * compute_kl_pi(memory_data)) / (1.0 + actor_memory_weight)
        loss_pi = loss_pi + compute_reg_loss()
        loss_pi.backward()
        pi_optimizer.step()

        if train_alpha:
            alpha_optimizer.zero_grad()
            loss_alpha = compute_loss_alpha(data)
            loss_alpha.backward()
            alpha_optimizer.step()
            logger.store({"train/loss_alpha": loss_alpha.item()})

        # Unfreeze Q-networks so you can optimize them at the next step.
        for p in q_params:
            p.requires_grad = True

        logger.store({"train/loss_pi": loss_pi.item()})

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # In-place "mul_"/"add_" update the existing target tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32).unsqueeze(0), deterministic).squeeze(0)

    @torch.no_grad()
    def get_q_preds(o, a):
        """Both critics' predictions for a single (observation, action) pair."""
        o = torch.as_tensor(o, dtype=torch.float32).unsqueeze(0)
        a = torch.as_tensor(a, dtype=torch.float32).unsqueeze(0)
        return ac.q1(o, a).squeeze(0), ac.q2(o, a).squeeze(0)

    def render_agent(deterministic, num_episodes, savepath) -> None:
        """Roll out ``num_episodes`` per test env and save each one as an mp4 under ``savepath``."""
        if num_episodes <= 0:
            return
        mode = "deterministic" if deterministic else "stochastic"
        for seq_idx, test_env in enumerate(test_envs):
            key_prefix = f"test_{mode}/{seq_idx}_{test_env.name}"
            for j in range(num_episodes):
                imgs = []
                obs, done = test_env.reset(), False

                while not done:
                    # NOTE(review): ``view`` and ``save_video`` are not defined or
                    # imported in the visible part of this file; rendering raises
                    # NameError unless they are provided elsewhere.
                    imgs.append(view(test_env.viewer))
                    obs, _, done, _ = test_env.step(get_action(obs, deterministic))

                save_video(savepath / key_prefix / f"rollout_{j}.mp4", imgs, fps=24)

            # Drop the success flags recorded during rendering.
            test_env.pop_successes()

    def test_agent(deterministic, num_episodes) -> None:
        """Evaluate on every test env and log return, length, mean Q-values and success."""
        avg_success = []
        mode = "deterministic" if deterministic else "stochastic"
        for seq_idx, test_env in enumerate(test_envs):
            key_prefix = f"test_{mode}/{seq_idx}_{test_env.name}/"

            for _ in range(num_episodes):
                obs, done, episode_return, episode_len = test_env.reset(), False, 0, 0
                q1_sum = 0.0
                q2_sum = 0.0
                while not done:
                    action = get_action(obs, deterministic)
                    # Evaluate the critics on the state the action is taken in,
                    # i.e. Q(s, a), before ``obs`` is replaced by the next state.
                    q1_pred, q2_pred = get_q_preds(obs, action)
                    q1_sum += q1_pred.item()
                    q2_sum += q2_pred.item()

                    obs, reward, done, _ = test_env.step(action)
                    episode_return += reward
                    episode_len += 1
                logger.store(
                    {
                        key_prefix + "return": episode_return,
                        key_prefix + "ep_length": episode_len,
                        key_prefix + "q1_val": q1_sum / episode_len,
                        key_prefix + "q2_val": q2_sum / episode_len,
                    }
                )

            logger.log_tabular(key_prefix + "return", with_min_and_max=True)
            logger.log_tabular(key_prefix + "ep_length", average_only=True)
            logger.log_tabular(key_prefix + "q1_val", average_only=True)
            logger.log_tabular(key_prefix + "q2_val", average_only=True)
            env_success = test_env.pop_successes()
            avg_success += env_success
            logger.log_tabular(key_prefix + "success", np.mean(env_success))
        key = f"test/{mode}/average_success"
        logger.log_tabular(key, np.mean(avg_success))

    def gather_buffer(buffer_size):
        """Collect up to ``buffer_size`` transitions from ``memory_envs`` and label them with targets.

        The budget is split evenly across ``memory_envs``; transitions are
        gathered with the current stochastic policy and annotated with the
        actor's ``[mu, log_std]`` and both critics' predictions.

        Returns:
            EpisodicMemory holding exactly the gathered transitions.
        """
        if not memory_envs:
            raise ValueError("`memory_envs` must be a non-empty list when `memory` is set")
        tmp_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=buffer_size)

        memory_per_env = buffer_size // len(memory_envs)
        for memory_idx, memory_env in enumerate(memory_envs):
            print(f"Gathering data for memory env {memory_idx}")
            obs = memory_env.reset()
            for _ in range(memory_per_env):
                action = get_action(obs, False)
                next_obs, reward, done, _ = memory_env.step(action)
                tmp_buffer.store(obs, action, reward, next_obs, done)
                obs = memory_env.reset() if done else next_obs
            print(f"Memory env {memory_idx} success rate: {np.mean(memory_env.pop_successes())}")

        # Only ``memory_per_env * len(memory_envs)`` rows were written; sizing the
        # memory to that count keeps zero-filled padding rows out of it.
        num_samples = tmp_buffer.size
        episodic_memory = EpisodicMemory(obs_dim, act_dim, num_samples,
                                         save_targets=True, one_hot_len=one_hot_len)
        with torch.no_grad():
            for start in range(0, num_samples, batch_size):
                end = min(start + batch_size, num_samples)
                obs = torch.as_tensor(tmp_buffer.obs_buf[start:end], dtype=torch.float32)
                next_obs = torch.as_tensor(tmp_buffer.obs2_buf[start:end], dtype=torch.float32)
                actions = torch.as_tensor(tmp_buffer.act_buf[start:end], dtype=torch.float32)
                rewards = torch.as_tensor(tmp_buffer.rew_buf[start:end], dtype=torch.float32)
                dones = torch.as_tensor(tmp_buffer.done_buf[start:end], dtype=torch.float32)

                _, _, mu, log_std = ac.pi(obs, return_dist=True)
                actor_dists = torch.cat([mu, log_std], dim=-1).numpy()
                critic1_pred = ac.q1(obs, actions).numpy()
                critic2_pred = ac.q2(obs, actions).numpy()

                episodic_memory.store_multiple(
                    obs, actions, rewards, next_obs, dones, actor_dists, critic1_pred, critic2_pred
                )
        return episodic_memory

    # Gather samples for the episodic memory
    if memory is not None:
        episodic_memory = gather_buffer(memory)

        validation_memory = gather_buffer(1000)
        with lzma.open("validation_memory.pickle", "wb") as handle:
            pickle.dump(validation_memory, handle, protocol=4)
        activations_before_train = gather_activations(
            ac, {"obs": validation_memory.obs_buf, "actions": validation_memory.actions_buf}
        )

        if reset_weights_after_memory:
            # NOTE(review): the optimizers above were built from ac's parameters;
            # this assumes reset_weights re-initialises them in place — confirm.
            ac.reset_weights()
            ac_targ = deepcopy(ac)
            # deepcopy carries requires_grad=True over from ac; re-freeze targets.
            for p in ac_targ.parameters():
                p.requires_grad = False

    if use_weight_reg:
        if memory is None:
            raise ValueError(f"reg_method={reg_method!r} requires `memory` to be set")
        reg_helper = EWCHelper(ac, actor_memory_weight, critic_memory_weight, l2_mode=(reg_method == "l2"))
        num_batches = 10
        memory_batch_size = episodic_memory.size // num_batches

        # Importances are computed on consecutive, non-overlapping slices of the
        # memory and then averaged.
        all_importances = []
        for batch_idx in range(num_batches):
            print(f"Processing batch number {batch_idx}")
            start_idx = batch_idx * memory_batch_size
            end_idx = start_idx + memory_batch_size

            memory_batch = episodic_memory.get_by_indices(np.arange(start_idx, end_idx))
            grads = reg_helper.get_grads(**memory_batch)
            all_importances.append(reg_helper.compute_importance(grads))
            del grads, memory_batch

        reg_helper.importance = {
            key: torch.stack([importances[key] for importances in all_importances], dim=0).mean(0).detach()
            for key in all_importances[0]
        }
        del all_importances

    # Prepare for interaction with environment
    o, ep_ret, ep_len = env.reset(), 0, 0
    render_any = num_render_eps_deterministic > 0 or num_render_eps_stochastic > 0

    # Main loop: collect experience in env and update/log each epoch
    for timestep in range(steps):
        if timestep % log_every == 0:
            epoch = timestep // log_every

            # Save model
            if (epoch % save_freq == 0) or (timestep == steps - 1):
                logger.save_state({"env": env}, None)

            # Test the deterministic and stochastic versions of the agent.
            test_agent(True, num_test_eps_deterministic)
            test_agent(False, num_test_eps_stochastic)
            if render_any:
                # The output dir is only needed (and only looked up) when rendering.
                render_dir = Path(logger_kwargs["output_dir"])
                render_agent(True, num_render_eps_deterministic, render_dir)
                render_agent(False, num_render_eps_stochastic, render_dir)

            # Memory diagnostics on every 5th logging round
            if memory is not None and timestep % (log_every * 5) == 0:
                likelihood_df = measure_likelihood(ac, validation_memory, batch_size)
                likelihood_df.to_csv(f"likelihoods_{timestep}.csv")

                activations_trained = gather_activations(
                    ac, {"obs": validation_memory.obs_buf, "actions": validation_memory.actions_buf}
                )
                cka_diff_vals = cka_diff(activations_before_train, activations_trained)
                l2_diff_vals = l2_act_diff(activations_before_train, activations_trained)

                del activations_trained

                for key, val in l2_diff_vals.items():
                    logger.log_tabular(f"l2diff/{key}", val)
                for key, val in cka_diff_vals.items():
                    logger.log_tabular(f"cka/{key}", val)

            # Log info about epoch
            logger.log_tabular("epoch", epoch)
            logger.log_tabular("train/return", with_min_and_max=True)
            logger.log_tabular("train/ep_length", average_only=True)
            logger.log_tabular("timestep", timestep)
            logger.log_tabular("train/loss_pi", average_only=True)
            logger.log_tabular("train/loss_q", average_only=True)

            avg_success = np.mean(env.pop_successes())
            logger.log_tabular("train/success", avg_success)
            logger.dump_tabular()

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if timestep >= start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, info = env.step(a)
        ep_ret += r
        ep_len += 1

        # Optionally ignore the "done" signal if it comes from hitting the time
        # horizon (an artificial terminal signal not based on the agent's state).
        done_to_store = d
        if (ep_len == max_ep_len or info.get("TimeLimit.truncated")) and bootstrap_on_time_limit:
            done_to_store = False
        # Overrides the previous one on purpose.
        if done_on_transition and info.get("transition"):
            done_to_store = True

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, done_to_store)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.store({"train/return": ep_ret, "train/ep_length": ep_len})
            o, ep_ret, ep_len = env.reset(), 0, 0

        # Update handling
        if timestep >= update_after and timestep % update_every == 0:
            for _ in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                memory_batch = episodic_memory.sample_batch(batch_size) if apply_kl_loss else None
                update(data=batch, memory_data=memory_batch)

    # Save everything at the end
    all_state_dicts = {
        "ac": ac.state_dict(),
        "ac_targ": ac_targ.state_dict(),
        "pi_optimizer": pi_optimizer.state_dict(),
        "q_optimizer": q_optimizer.state_dict(),
    }
    if train_alpha:
        all_state_dicts["alpha_optimizer"] = alpha_optimizer.state_dict()

    torch.save(all_state_dicts, "checkpoint.pt")
    if save_buffer:
        with lzma.open("buffer.pickle", "wb") as handle:
            pickle.dump(replay_buffer, handle, protocol=4)


def _resolve_tasks(variant):
    """Return the list of task names selected by ``variant``.

    A named sequence (``variant["sequence_name"]``, looked up in ``SEQUENCES``)
    takes precedence; otherwise the explicit ``variant["tasks"]`` list is used.
    """
    if variant["sequence_name"] is not None:
        return SEQUENCES[variant["sequence_name"]]
    return variant["tasks"]


def _num_heads_in_checkpoint(ac_state_dict):
    """Return the number of output heads stored in an actor-critic state dict.

    The head count is read as ``shape[0]`` of the highest-indexed
    ``q1.q.<idx>.weight`` tensor, i.e. the final layer of the first critic.

    Args:
        ac_state_dict: mapping of parameter name -> tensor (``checkpoint["ac"]``).

    Returns:
        The output width of the last ``q1.q.<idx>`` weight.

    Raises:
        ValueError: if no ``q1.q.<idx>.weight`` parameter is present.
    """
    layer_indices = []
    for key in ac_state_dict:
        parts = key.split(".")
        # Only keys of the exact form "q1.q.<int>.weight" are considered, so the
        # lookup below is guaranteed to find the tensor we selected.
        if len(parts) == 4 and parts[0] == "q1" and parts[1] == "q" and parts[2].isdigit() and parts[3] == "weight":
            layer_indices.append(int(parts[2]))
    if not layer_indices:
        raise ValueError("Checkpoint contains no 'q1.q.<idx>.weight' parameters; cannot infer number of heads.")
    # max() rather than "last key seen" so the result does not depend on dict ordering.
    last_layer = max(layer_indices)
    return ac_state_dict[f"q1.q.{last_layer}.weight"].shape[0]


def run_sac(variant):
    """Build environments and the actor-critic described by ``variant``, then run ``sac``.

    ``variant`` is the dict of parsed command-line arguments (``vars`` of
    ``sac_parse_args()``). ``variant["scenario"]`` selects the setup:

    * ``"stitched_env"``: train on one environment that chains ``tasks`` in
      ``ordering``; evaluate on that chained env plus every task on its own.
      When ``variant["memory"]`` is set, the single-task envs listed in
      ``variant["tasks_to_memorize"]`` are passed on as memory envs.
    * ``"multi"``: train on a multi-task env over ``tasks``; evaluate on each
      task separately.
    * ``"single"``: train and evaluate on ``CW10[variant["task_idx"]]``.

    Pretrained weights are loaded from ``variant["init_weights_path"]`` if set.
    A directory is treated as TensorFlow weights. A file is treated as a
    PyTorch checkpoint, whose heads are removed or appended to match the
    number of heads of the freshly built model.

    Raises:
        ValueError: if ``variant["scenario"]`` is not one of the values above,
            or the PyTorch checkpoint has no recognizable critic head layer.
    """
    # Options shared by every environment constructor below.
    env_kwargs = dict(
        append_timestep=variant["append_timestep"],
        success_as_reward=variant["sparse_rewards"],
        done_on_success=variant["finish_on_done"],
        reward_early_finish=variant["reward_early_finish"],
    )
    # Only the stitched-env scenario builds memory envs. Every other scenario
    # must still pass None to sac() (previously unset for "single").
    memory_envs = None
    scenario = variant["scenario"]

    if scenario == "stitched_env":
        tasks = _resolve_tasks(variant)
        one_hot_len = variant["one_hot_len"] or len(tasks)
        ordering = variant["ordering"] if variant["ordering"] is not None else list(range(len(tasks)))

        stitched_kwargs = dict(
            ordering=ordering,
            accumulate_rewards=variant["accumulate_rewards"],
            continue_from_pos=variant["continue_from_pos"],
            **env_kwargs,
        )
        test_envs = [get_stitched_env(tasks, verbose=True, **stitched_kwargs)]
        single_task_test_envs = [
            get_single_env(env, add_one_hot=True, one_hot_idx=i, one_hot_len=one_hot_len, **env_kwargs)
            for i, env in enumerate(tasks)
        ]

        if variant["memory"] is not None:
            memory_envs = [single_task_test_envs[idx] for idx in variant["tasks_to_memorize"]]

        test_envs += single_task_test_envs
        train_env = get_stitched_env(tasks, **stitched_kwargs)
        num_heads = len(tasks)

    elif scenario == "multi":
        one_hot_len = variant["one_hot_len"] or len(CW10)
        tasks = _resolve_tasks(variant)

        print(one_hot_len, tasks)

        test_envs = [
            get_single_env(env, add_one_hot=True, one_hot_idx=i, one_hot_len=one_hot_len, **env_kwargs)
            for i, env in enumerate(tasks)
        ]
        train_env = get_mt_env(
            tasks,
            steps_per_task=variant["steps"] * 10,
            one_hot_len=one_hot_len,
            **env_kwargs,
        )
        num_heads = one_hot_len

    elif scenario == "single":
        task_idx = variant["task_idx"]
        task = CW10[task_idx]
        one_hot_len = variant["one_hot_len"] or len(CW10)

        def make_env():
            return get_single_env(
                task, add_one_hot=True, one_hot_idx=task_idx, one_hot_len=one_hot_len, **env_kwargs
            )

        train_env = make_env()
        test_envs = [make_env()]
        num_heads = len(CW10)

    else:
        raise ValueError(f"Unknown scenario: {scenario!r}; expected 'stitched_env', 'multi' or 'single'.")

    hidden_sizes = [variant["hidden_dim"]] * variant["num_layers"]
    ac = core.MLPActorCritic(
        train_env.observation_space,
        train_env.action_space,
        hidden_sizes=hidden_sizes,
        activation=custom_leaky_relu,
        hide_task_id=True,
        num_heads=num_heads,
        use_layer_norm=variant["use_layer_norm"],
        one_hot_len=one_hot_len,
        alpha_init=variant["alpha_init"],
        target_output_std=variant["target_output_std"],
    )

    # Load pretrained weights, if available.
    weights_path = variant["init_weights_path"]

    if weights_path and os.path.isdir(weights_path):
        # A directory holds TensorFlow weights.
        ac, ac_targ = load_tf_weights(ac, weights_path, use_layer_norm=variant["use_layer_norm"])
        # normpath strips a trailing separator so basename is never empty.
        group_name = f"sac_{os.path.basename(os.path.normpath(weights_path))}"
    elif weights_path:
        # A file is a PyTorch checkpoint. The model above lives on CPU, so map
        # tensors there (lets GPU-saved checkpoints load on CPU-only hosts).
        checkpoint = torch.load(weights_path, map_location="cpu")

        num_heads_ckpt = _num_heads_in_checkpoint(checkpoint["ac"])
        if num_heads < num_heads_ckpt:
            print("Removing heads!")
            checkpoint = remove_heads(checkpoint, num_heads, num_heads_ckpt)
        elif num_heads > num_heads_ckpt:
            print("Adding heads!")
            checkpoint = append_new_heads(checkpoint, num_heads, num_heads_ckpt)

        ac.load_state_dict(checkpoint["ac"])
        ac_targ = deepcopy(ac)
        ac_targ.load_state_dict(checkpoint["ac_targ"])
        group_name = "sac_pytorch_weights"
    else:
        ac_targ = deepcopy(ac)
        group_name = "sac"

    logger_kwargs = setup_logger_kwargs(variant["exp_name"], variant["seed"])
    if variant["log_to_wandb"]:
        config = variant.copy()
        config["my_PWD"] = os.getenv("PWD", "")
        config["my_WANDBPWD"] = os.getenv("WANDBPWD", "")
        wandb.init(
            name=variant["exp_name"] + f"_{variant['seed']}",
            group=group_name,
            entity="<PLACEHOLDER>",
            project="offlinerl_pretraining",
            config=config,
            tags=variant["exp_tags"],
        )

    # NOTE(review): this re-applies the current intra-op thread count; kept
    # as-is since it may also sync OpenMP/MKL settings — confirm intent.
    torch.set_num_threads(torch.get_num_threads())

    sac(
        train_env,
        test_envs,
        ac=ac,
        ac_targ=ac_targ,
        lr=variant["learning_rate"],
        gamma=variant["gamma"],
        seed=variant["seed"],
        steps=variant["steps"],
        replay_size=variant["replay_size"],
        logger_kwargs=logger_kwargs,
        start_steps=variant["start_steps"],
        log_every=variant["log_every"],
        bootstrap_on_time_limit=variant["bootstrap_on_time_limit"],
        num_tasks=num_heads,
        one_hot_len=one_hot_len,
        done_on_transition=variant["done_on_transition"],
        train_alpha=variant["alpha_init"] == "auto",
        memory=variant["memory"],
        memory_envs=memory_envs,
        actor_memory_weight=variant["actor_memory_weight"],
        num_test_eps_stochastic=variant["num_test_eps_stochastic"],
        num_test_eps_deterministic=variant["num_test_eps_deterministic"],
        num_render_eps_stochastic=variant["num_render_eps_stochastic"],
        num_render_eps_deterministic=variant["num_render_eps_deterministic"],
        critic_memory_weight=variant["critic_memory_weight"],
        update_after=variant["update_after"],
        reset_weights_after_memory=variant["reset_weights_after_memory"],
        save_buffer=variant["save_buffer"],
        apply_kl_loss=variant["apply_kl_loss"],
        reg_method=variant["reg_method"],
    )


if __name__ == "__main__":
    # Entry point: parse CLI flags and pass them to run_sac as a plain dict.
    run_sac(variant=vars(sac_parse_args()))
