# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=unused-import
import pdb
import copy
import math
import logging
import dataclasses
from collections import OrderedDict
import typing as tp

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore

from url_benchmark import utils
# from url_benchmark import replay_buffer as rb
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark.dmc import TimeStep
from url_benchmark import goals as _goals
from .ddpg import MetaDict
from .fb_modules import IdentityMap
from .ddpg import Encoder
from .fb_modules import BackwardMap, mlp


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


from .fb_ddpg import FBDDPGAgentConfig


@dataclasses.dataclass
class DiscreteACROFBAgentConfig(FBDDPGAgentConfig):
    """Structured config for ``DiscreteACROFBAgent``.

    Extends ``FBDDPGAgentConfig`` with discrete-action exploration settings and
    the hyper-parameters of the ACRO representation learner.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.acro_sf.DiscreteACROFBAgent"
    name: str = "discrete_acro_fb"
    # Forwarded to ForwardMap(preprocess=...).
    preprocess: bool = False
    # Probability of taking a uniformly random action in act() once the
    # initial exploration phase (num_expl_steps) is over.
    expl_eps: float = 0.2
    # NOTE: these two were previously un-annotated, which made them plain class
    # attributes rather than dataclass fields (not settable through the
    # constructor / config). Annotating them turns them into real fields.
    # update_fb uses a softmax policy over next-state Q-values when True.
    boltzmann: bool = False
    # Softmax temperature used by the Boltzmann policy in update_fb.
    temp: float = 100.0

    # ACRO-specific parameters
    # Learning rate / weight decay of the Adam optimizer for the ACRO networks.
    acro_learning_rate: float = 1e-4
    acro_weight_decay: float = 1e-5
    # Weight of the forward-dynamics loss; the loss is skipped when <= 0.
    acro_forward_weight: float = 1.0
    # Coefficient of the L2 term (used in the total loss when acro_use_l2_norm).
    acro_l2_penalty: float = 0.0
    # Selects the L2 term instead of the L1 term in the total ACRO loss.
    acro_use_l2_norm: bool = False
    # Coefficient of the L1 term on the encoder outputs.
    acro_l1_penalty: float = 0.0
    # Scales the L1 coefficient by a function of the inverse-model accuracy.
    acro_dynamic_l1_penalty: bool = False
    # Not read anywhere in this file.
    acro_train_stop_epochs: int = 1000000
    # update() trains only the ACRO networks while step is below this value.
    acro_representation_train_steps: int = 1000
    # Not read anywhere in this file.
    acro_k_steps: int = 1
    # Output dimension of the ACRO encoder.
    acro_embed_dim: int = 50
    # Number of transitions gathered by infer_meta().
    num_inference_steps: int = 25000
    # Not read anywhere in this file.
    representation_steps: int = 20000



# Register the config node with Hydra's ConfigStore under group "agent" with
# the name "discrete_acro_fb" (runs as an import-time side effect).
cs = ConfigStore.instance()
cs.store(group="agent", name="discrete_acro_fb", node=DiscreteACROFBAgentConfig)



class ForwardMap(nn.Module):
    """Forward representation for a discrete action space.

    Two independent heads (``F1`` and ``F2``) each emit ``z_dim * action_dim``
    values per input row; ``forward`` reshapes them to
    ``(-1, z_dim, action_dim)``.

    Args:
        obs_dim: size of the observation vector fed to the network.
        z_dim: size of the task embedding ``z``.
        action_dim: number of discrete actions.
        feature_dim: width of the preprocessing embeddings (only used when
            ``preprocess`` is True).
        hidden_dim: hidden width passed to the ``mlp`` helper.
        preprocess: if True, embed ``obs`` and ``[obs, z]`` separately before
            the trunk; otherwise the trunk consumes ``[obs, z]`` directly.
        add_trunk: with ``preprocess``, whether to put an extra trunk on top of
            the concatenated embeddings (identity otherwise).
    """

    def __init__(self, obs_dim, z_dim, action_dim, feature_dim, hidden_dim,
                 preprocess=False, add_trunk=True) -> None:
        super().__init__()
        self.obs_dim = obs_dim
        self.z_dim = z_dim
        self.action_dim = action_dim
        self.preprocess = preprocess

        if self.preprocess:
            self.obs_net = mlp(self.obs_dim, hidden_dim, "ntanh", feature_dim, "irelu")
            self.obs_z_net = mlp(self.obs_dim + self.z_dim, hidden_dim, "ntanh", feature_dim, "irelu")
            if not add_trunk:
                self.trunk: nn.Module = nn.Identity()
                feature_dim = 2 * feature_dim
            else:
                self.trunk = mlp(2 * feature_dim, hidden_dim, "irelu")
                feature_dim = hidden_dim
        else:
            self.trunk = mlp(self.obs_dim + self.z_dim, hidden_dim, "ntanh",
                             hidden_dim, "irelu",
                             hidden_dim, "irelu")
            feature_dim = hidden_dim

        # Both heads share the same architecture but have separate weights.
        seq = [feature_dim, hidden_dim, "irelu", self.z_dim * self.action_dim]
        self.F1 = mlp(*seq)
        self.F2 = mlp(*seq)

        self.apply(utils.weight_init)

    def forward(self, obs, z):
        """Return ``(F1, F2)``, each reshaped to ``(-1, z_dim, action_dim)``."""
        assert z.shape[-1] == self.z_dim

        if self.preprocess:
            # Fix: the module registered in __init__ is ``obs_net`` (there is
            # no ``obs_action_net``), and ``obs_z_net`` was built for an input
            # of size obs_dim + z_dim, so it must receive the raw observation
            # rather than the already-embedded one.
            obs_feat = self.obs_net(obs)
            obs_z_feat = self.obs_z_net(torch.cat([obs, z], dim=-1))
            h = torch.cat([obs_feat, obs_z_feat], dim=-1)
        else:
            h = torch.cat([obs, z], dim=-1)
        # ``trunk`` is always set in __init__ (an nn.Identity when disabled).
        h = self.trunk(h)
        F1 = self.F1(h)
        F2 = self.F2(h)
        out_shape = (-1, self.z_dim, self.action_dim)
        return F1.reshape(*out_shape), F2.reshape(*out_shape)
    
class AcroEncoder(nn.Module):
    """Observation encoder of the ACRO representation learner.

    Maps an ``obs_dim`` observation to an ``embed_dim`` embedding through the
    project's ``mlp`` helper; ``output_dim`` exposes the embedding size.
    """

    def __init__(self, obs_dim: int, embed_dim: int, hidden_dim: int) -> None:
        super().__init__()
        layer_spec = (obs_dim, hidden_dim, "ntanh", hidden_dim, "irelu", embed_dim, "ntanh")
        self.encoder = mlp(*layer_spec)
        self.output_dim = embed_dim
        self.apply(utils.weight_init)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode ``obs`` into the ACRO embedding space."""
        embedding = self.encoder(obs)
        return embedding


class AcroForwardDynamics(nn.Module):
    """ACRO forward model: predicts the next embedding from (embedding, action).

    Actions are treated as discrete indices and one-hot encoded over
    ``action_dim`` classes before being concatenated with the embedding.
    """

    def __init__(self, embed_dim: int, action_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.action_dim = action_dim
        input_dim = embed_dim + action_dim
        self.forward_model = mlp(input_dim, hidden_dim, "ntanh", hidden_dim, "irelu", embed_dim, "ntanh")
        self.apply(utils.weight_init)

    def forward(self, encoded_obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Predict the next embedding for ``encoded_obs`` under ``action``."""
        # one_hot appends a class axis; squeeze(1) then removes axis 1 when it
        # has size one (presumably the (batch, 1) index layout -- TODO confirm).
        one_hot = F.one_hot(action.long(), num_classes=self.action_dim)
        action_vec = one_hot.squeeze(1).float()
        model_input = torch.cat([encoded_obs, action_vec], dim=-1)
        return self.forward_model(model_input)


class AcroInverseDynamics(nn.Module):
    """ACRO inverse model: maps a pair of embeddings to ``action_dim`` outputs."""

    def __init__(self, embed_dim: int, action_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # The input is the current and next embeddings concatenated.
        pair_dim = 2 * embed_dim
        self.inverse_model = mlp(pair_dim, hidden_dim, "ntanh", hidden_dim, "irelu", action_dim)
        self.apply(utils.weight_init)

    def forward(self, encoded_obs: torch.Tensor, encoded_next_obs: torch.Tensor) -> torch.Tensor:
        """Return the action prediction for the transition between the two embeddings."""
        pair = torch.cat([encoded_obs, encoded_next_obs], dim=-1)
        prediction = self.inverse_model(pair)
        return prediction

class DiscreteACROFBAgent:

    # pylint: disable=unused-argument
    def __init__(self,
                 **kwargs: tp.Any
                 ):
        """Build the agent's networks and optimizers from config keyword arguments.

        All keyword arguments are forwarded to ``DiscreteACROFBAgentConfig``.
        The constructor creates the forward network and its target copy, the
        ACRO encoder / forward / inverse models (plus a target copy of the
        encoder), and the Adam optimizers.

        NOTE(review): no ``backward_net`` / ``backward_target_net`` attribute is
        created here (their construction is commented out below), yet
        ``get_goal_meta``, ``get_neg_goal_meta``,
        ``infer_meta_from_obs_and_rewards``, ``q_function``,
        ``q_function_pos_neg`` and ``init_from`` reference ``backward_net``.
        """
        cfg = DiscreteACROFBAgentConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        # When set, init_meta() returns this instead of sampling a fresh z.
        self.solved_meta: tp.Any = None

        # models
        if cfg.obs_type == 'pixels':
            self.aug: nn.Module = utils.RandomShiftsAug(pad=4)
            self.encoder: nn.Module = Encoder(cfg.obs_shape).to(cfg.device)
            self.obs_dim = self.encoder.repr_dim
        else:
            # Non-pixel observations are used as-is (no augmentation, no encoder).
            self.aug = nn.Identity()
            self.encoder = nn.Identity()
            self.obs_dim = cfg.obs_shape[0]
        if cfg.feature_dim < self.obs_dim:
            logger.warning(f"feature_dim {cfg.feature_dim} should not be smaller that obs_dim {self.obs_dim}")
        goal_dim = self.obs_dim
        if cfg.goal_space is not None:
            goal_dim = _goals.get_goal_space_dim(cfg.goal_space)
        if cfg.z_dim < goal_dim:
            logger.warning(f"z_dim {cfg.z_dim} should not be smaller that goal_dim {goal_dim}")

        self.forward_net = ForwardMap(self.obs_dim, cfg.z_dim, self.action_dim,
                                      cfg.feature_dim, cfg.hidden_dim,
                                      preprocess=cfg.preprocess, add_trunk=self.cfg.add_trunk).to(cfg.device)
        # if cfg.debug:
        #     self.backward_net: nn.Module = IdentityMap().to(cfg.device)
        #     self.backward_target_net: nn.Module = IdentityMap().to(cfg.device)
        # else:
        #     self.backward_net = BackwardMap(goal_dim, cfg.z_dim, cfg.backward_hidden_dim, norm_z=cfg.norm_z).to(
        #         cfg.device)
        #     self.backward_target_net = BackwardMap(goal_dim,
                                                #    cfg.z_dim, cfg.backward_hidden_dim, norm_z=cfg.norm_z).to(cfg.device)
        # build up the target network
        self.forward_target_net = ForwardMap(self.obs_dim, cfg.z_dim, self.action_dim,
                                             cfg.feature_dim, cfg.hidden_dim,
                                             preprocess=cfg.preprocess, add_trunk=self.cfg.add_trunk).to(cfg.device)
        # load the weights into the target networks
        self.forward_target_net.load_state_dict(self.forward_net.state_dict())
        # self.backward_target_net.load_state_dict(self.backward_net.state_dict())

        # ACRO encoder
        # update_fb and inference use acro_encoder's output as the backward
        # embedding B. NOTE(review): update_fb contracts it against F's z axis,
        # which presumably requires acro_embed_dim == z_dim -- TODO confirm.
        self.acro_encoder = AcroEncoder(self.obs_dim, cfg.acro_embed_dim, cfg.hidden_dim).to(cfg.device)
        self.acro_forward = AcroForwardDynamics(cfg.acro_embed_dim, self.action_dim, cfg.hidden_dim).to(cfg.device)
        self.acro_inverse = AcroInverseDynamics(cfg.acro_embed_dim, self.action_dim, cfg.hidden_dim).to(cfg.device)

        # Target copy of the ACRO encoder, initialised with the same weights.
        self.acro_encoder_target = AcroEncoder(self.obs_dim, cfg.acro_embed_dim, cfg.hidden_dim).to(cfg.device)
        self.acro_encoder_target.load_state_dict(self.acro_encoder.state_dict())

        # ACRO optimizer
        # A single Adam instance covers encoder, forward and inverse models.
        self.acro_optimizer = torch.optim.Adam(
            list(self.acro_encoder.parameters()) +
            list(self.acro_forward.parameters()) +
            list(self.acro_inverse.parameters()),
            lr=cfg.acro_learning_rate,
            weight_decay=cfg.acro_weight_decay,
        )

        # optimizers
        # The pixel encoder gets its own optimizer; None for non-pixel inputs.
        self.encoder_opt: tp.Optional[torch.optim.Adam] = None
        if cfg.obs_type == 'pixels':
            self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=cfg.lr)

        # fb_opt only holds the forward network's parameters.
        self.fb_opt = torch.optim.Adam([{'params': self.forward_net.parameters()},],  # type: ignore
                                       lr=cfg.lr)

        # Put the online networks (via self.train) and both target networks
        # in training mode.
        self.train()
        self.forward_target_net.train()
        self.acro_encoder_target.train()
        # self.backward_target_net.train()

    def train(self, training: bool = True) -> None:
        self.training = training
        for net in [self.encoder, self.forward_net, self.acro_encoder, self.acro_forward, self.acro_inverse]:
            net.train(training)

    def init_from(self, other) -> None:
        """Copy network parameters and optimizer states from ``other`` into this agent.

        The encoder is always copied; the forward/backward networks and their
        targets are copied as well when ``cfg.init_fb`` is set. Every optimizer
        stored on this agent then loads a deep copy of the state of the
        same-named optimizer on ``other``.

        NOTE(review): this class does not create ``backward_net`` /
        ``backward_target_net`` in ``__init__``, so the ``init_fb`` path will
        raise ``AttributeError`` unless those attributes are set elsewhere.
        """
        names = ["encoder"]
        if self.cfg.init_fb:
            names.extend(["forward_net", "backward_net", "backward_target_net", "forward_target_net"])
        for name in names:
            source_net = getattr(other, name)
            target_net = getattr(self, name)
            utils.hard_update_params(source_net, target_net)
        for key, val in vars(self).items():
            if not isinstance(val, torch.optim.Optimizer):
                continue
            other_state = getattr(other, key).state_dict()
            val.load_state_dict(copy.deepcopy(other_state))

    def get_goal_meta(self, goal_array: np.ndarray) -> MetaDict:
        desired_goal = torch.tensor(goal_array).unsqueeze(0).to(self.cfg.device)
        with torch.no_grad():
            z = self.backward_net(desired_goal)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * F.normalize(z, dim=1)
        z = z.squeeze(0).cpu().numpy()
        meta = OrderedDict()
        meta['z'] = z
        return meta

    def get_neg_goal_meta(self, goal_array: np.ndarray, neg_goal_array: np.ndarray) -> MetaDict:
        print('in get_neg_goal_meta')
        desired_goal = torch.tensor(goal_array).unsqueeze(0).to(self.cfg.device)
        neg_goal = torch.tensor(neg_goal_array).unsqueeze(0).to(self.cfg.device)
        with torch.no_grad():
            z = self.backward_net(desired_goal) - self.backward_net(neg_goal)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * F.normalize(z, dim=1)
        z = z.squeeze(0).cpu().numpy()
        meta = OrderedDict()
        meta['z'] = z
        return meta

    def infer_meta(self, replay_loader: ReplayBuffer) -> MetaDict:
        obs_list, reward_list = [], []
        batch_size = 0
        while batch_size < self.cfg.num_inference_steps:
            batch = replay_loader.sample(self.cfg.batch_size)
            batch = batch.to(self.cfg.device)
            obs_list.append(batch.next_goal if self.cfg.goal_space is not None else batch.next_obs)
            reward_list.append(batch.reward)
            batch_size += batch.next_obs.size(0)
        obs, reward = torch.cat(obs_list, 0), torch.cat(reward_list, 0)  # type: ignore
        obs, reward = obs[:self.cfg.num_inference_steps], reward[:self.cfg.num_inference_steps]
        return self.infer_meta_from_obs_and_rewards(obs, reward)

    def infer_meta_from_obs_and_rewards(self, obs: torch.Tensor, reward: torch.Tensor) -> MetaDict:
        print('max reward: ', reward.max().cpu().item())
        print('99 percentile: ', torch.quantile(reward, 0.99).cpu().item())
        print('median reward: ', reward.median().cpu().item())
        print('min reward: ', reward.min().cpu().item())
        print('mean reward: ', reward.mean().cpu().item())
        print('num reward: ', reward.shape[0])

        # filter out small reward
        # pdb.set_trace()
        # idx = torch.where(reward >= torch.quantile(reward, 0.99))[0]
        # obs = obs[idx]
        # reward = reward[idx]
        with torch.no_grad():
            B = self.backward_net(obs)
        z = torch.matmul(reward.T, B) / reward.shape[0]
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * F.normalize(z, dim=1)
        meta = OrderedDict()
        meta['z'] = z.squeeze().cpu().numpy()
        # self.solved_meta = meta
        return meta

    def sample_z(self, size, device: str = "cpu"):
        gaussian_rdv = torch.randn((size, self.cfg.z_dim), dtype=torch.float32, device=device)
        gaussian_rdv = F.normalize(gaussian_rdv, dim=1)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * gaussian_rdv
        else:
            uniform_rdv = torch.rand((size, self.cfg.z_dim), dtype=torch.float32, device=device)
            z = np.sqrt(self.cfg.z_dim) * uniform_rdv * gaussian_rdv
        return z

    def init_meta(self) -> MetaDict:
        if self.solved_meta is not None:
            print('solved_meta')
            return self.solved_meta
        else:
            z = self.sample_z(1)
            z = z.squeeze().numpy()
            meta = OrderedDict()
            meta['z'] = z
        return meta

    # pylint: disable=unused-argument
    def update_meta(
            self,
            meta: MetaDict,
            global_step: int,
            time_step: TimeStep,
            finetune: bool = False,
            replay_loader: tp.Optional[ReplayBuffer] = None
    ) -> MetaDict:
        if global_step % self.cfg.update_z_every_step == 0 and np.random.rand() < self.cfg.update_z_proba:
            return self.init_meta()
        return meta

    def act(self, obs, meta, step, eval_mode) -> tp.Any:
        obs = torch.as_tensor(obs, device=self.cfg.device, dtype=torch.float32).unsqueeze(0)  # type: ignore
        h = self.encoder(obs)
        z = torch.as_tensor(meta['z'], device=self.cfg.device).unsqueeze(0)  # type: ignore
        F1, F2 = self.forward_net(h, z)
        Q1, Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [F1, F2]]
        Q = torch.min(Q1, Q2)
        action = Q.max(1)[1]

        if not eval_mode:
            if step < self.cfg.num_expl_steps:
                action = torch.randint_like(action, self.action_dim)
            else:
                action = torch.randint_like(action, self.action_dim) \
                    if np.random.rand() < self.cfg.expl_eps else action
        return action.item()
    
    def update_acro(self, obs: torch.Tensor, action: torch.Tensor, next_obs: torch.Tensor) -> tp.Dict[str, float]:
        """Update ACRO representation learning components."""
        # Encode observations
        o_encoded = self.acro_encoder(obs)
        on_encoded = self.acro_encoder(next_obs)

        # Forward dynamics loss
        if self.cfg.acro_forward_weight > 0:
            forward_model_loss = F.mse_loss(
                self.acro_forward(o_encoded, action),
                on_encoded,
            )
        else:
            forward_model_loss = torch.tensor(0.0, device=obs.device)

        # L1 regularization loss
        if self.cfg.acro_l1_penalty > 0 and not self.cfg.acro_use_l2_norm:
            l1_loss = (
                torch.linalg.vector_norm(o_encoded, ord=1, dim=1).mean()
                + torch.linalg.vector_norm(on_encoded, ord=1, dim=1).mean()
            ) / 2
        else:
            l1_loss = torch.zeros(1, device=obs.device)

        # Inverse dynamics loss
        inverse_model_pred = self.acro_inverse(o_encoded, on_encoded)
        inverse_model_loss = F.mse_loss(inverse_model_pred, action)  # Use MSE for continuous actions

        # Calculate accuracy for inverse model (for monitoring)
        if self.action_dim == 1:  # Continuous action
            accuracy = torch.mean((torch.abs(inverse_model_pred - action) < 0.1).float())
        else:  # Discrete action
            accuracy = torch.mean((torch.argmax(inverse_model_pred, dim=1) == action).float())

        # Dynamic L1 penalty
        if self.cfg.acro_dynamic_l1_penalty:
            gain = 5
            cur_l1_penalty = self.cfg.acro_l1_penalty * np.exp(-gain * (accuracy.detach().item() - 1) ** 2)
        else:
            cur_l1_penalty = self.cfg.acro_l1_penalty

        l1_loss = cur_l1_penalty * l1_loss
        forward_model_loss = self.cfg.acro_forward_weight * forward_model_loss

        # L2 regularization loss
        p = F.softmax(o_encoded, dim=1)
        l2_per = torch.linalg.norm(p, ord=2, dim=1)
        l2_loss = l2_per.mean() * self.cfg.acro_l2_penalty

        # Total ACRO loss
        if self.cfg.acro_use_l2_norm:
            total_acro_loss = forward_model_loss + l2_loss + inverse_model_loss
        else:
            total_acro_loss = forward_model_loss + l1_loss + inverse_model_loss

        # Update ACRO components
        self.acro_optimizer.zero_grad(set_to_none=True)
        total_acro_loss.backward()
        self.acro_optimizer.step()

        return {
            "acro_inverse_loss": inverse_model_loss.detach().item(),
            "acro_forward_loss": forward_model_loss.detach().item(),
            "acro_l1_loss": l1_loss.detach().item(),
            "acro_l2_loss": l2_loss.detach().item(),
            "acro_total_loss": total_acro_loss.detach().item(),
            "acro_accuracy": accuracy.detach().item(),
            "acro_cur_l1_penalty": cur_l1_penalty,
            "acro_mean_element_magnitude": torch.abs(o_encoded).float().mean().detach().item(),
            "acro_mean_representation_magnitude": torch.linalg.vector_norm(o_encoded, ord=1, dim=1).mean().detach().item(),
        }


    def update_fb(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        discount: torch.Tensor,
        next_obs: torch.Tensor,
        next_goal: torch.Tensor,
        z: torch.Tensor,
        step: int
    ) -> tp.Dict[str, float]:
        """Run one optimisation step of the forward network on the FB loss.

        The backward embedding ``B`` is produced by ``self.acro_encoder``
        applied to ``next_goal``. The loss is the FB off-diagonal/diagonal
        objective, plus an optional Q loss (``cfg.q_loss``) and an
        orthonormality term on ``B`` weighted by ``cfg.ortho_coef``.

        Args:
            obs: batch of observations.
            action: integer action indices; ``action.repeat(1, z_dim)`` is used
                as a gather index, so presumably shaped (batch, 1) -- TODO confirm.
            discount: per-transition discount; ``discount.squeeze(1)`` is used
                in the Q loss, so presumably shaped (batch, 1) -- TODO confirm.
            next_obs: batch of successor observations.
            next_goal: batch of goals fed to the ACRO encoder.
            z: batch of task embeddings.
            step: not used in the body.

        Returns:
            Metrics dict; populated only when one of ``cfg.use_tb``,
            ``cfg.use_wandb`` or ``cfg.use_hiplog`` is set.
        """
        metrics: tp.Dict[str, float] = {}
        # compute target successor measure
        with torch.no_grad():
            # compute greedy action
            target_F1, target_F2 = self.forward_target_net(next_obs, z)
            next_Q1, next_Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [target_F1, target_F2]]
            next_Q = torch.min(next_Q1, next_Q2)

            if self.cfg.boltzmann:
                # Softmax policy over actions: average F and Q under pi.
                pi = F.softmax(next_Q / self.cfg.temp, dim=-1)
                target_F1, target_F2 = [torch.einsum("sa, sda -> sd", pi, Fi) for Fi in [target_F1, target_F2]] # batch x z_dim
                next_Q = torch.einsum("sa, sa -> s", pi, next_Q)
            else:
                # Greedy policy: pick the slice of F at the argmax action.
                next_action = next_Q.max(1)[1]
                next_idx = next_action[:, None].repeat(1, self.cfg.z_dim)[:, :, None]
                # NOTE(review): squeeze() drops every size-1 axis, so a batch
                # of one (or z_dim == 1) would lose an axis here.
                target_F1, target_F2 = [Fi.gather(-1, next_idx).squeeze() for Fi in [target_F1, target_F2]]  # batch x z_dim
                next_Q = next_Q.max(1)[0]
            # NOTE(review): the online acro_encoder is used for the target
            # too; acro_encoder_target is not read in this method. The einsum
            # below contracts F's z axis with the encoder output axis, which
            # presumably requires acro_embed_dim == z_dim -- TODO confirm.
            target_B = self.acro_encoder(next_goal)  # batch x z_dim
            target_M1, target_M2 = [torch.einsum('sd, td -> st', Fi, target_B) \
                                    for Fi in [target_F1, target_F2]] # batch x batch
            target_M = torch.min(target_M1, target_M2)

        # compute FB loss
        # Select, for each sample, the F slice of the action actually taken.
        idxs = action.repeat(1, self.cfg.z_dim)[:, :, None]
        F1, F2 = [Fi.gather(-1, idxs).squeeze() for Fi in self.forward_net(obs, z)]
        # B is computed with gradients enabled, but fb_opt (built in __init__)
        # only holds forward_net parameters, so this step does not update the
        # ACRO encoder.
        B = self.acro_encoder(next_goal)
        M1 = torch.einsum('sd, td -> st', F1, B)  # batch x batch
        M2 = torch.einsum('sd, td -> st', F2, B)  # batch x batch
        I = torch.eye(*M1.size(), device=M1.device)
        off_diag = ~I.bool()
        fb_offdiag: tp.Any = 0.5 * sum((M - discount * target_M)[off_diag].pow(2).mean() for M in [M1, M2])
        fb_diag: tp.Any = -sum(M.diag().mean() for M in [M1, M2])
        fb_loss = fb_offdiag + fb_diag

        # Q LOSS

        if self.cfg.q_loss:
            with torch.no_grad():
                # next_Q1, nextQ2 = [torch.einsum('sd, sd -> s', target_Fi, z) for target_Fi in [target_F1, target_F2]]
                # next_Q = torch.min(next_Q1, nextQ2)
                # Implicit reward: z projected through the pseudo-inverse of
                # the empirical covariance of B.
                cov = torch.matmul(B.T, B) / B.shape[0]
                inv_cov = torch.linalg.pinv(cov)
                implicit_reward = (torch.matmul(B, inv_cov) * z).sum(dim=1)  # batch_size
                target_Q = implicit_reward.detach() + discount.squeeze(1) * next_Q  # batch_size

            Q1, Q2 = [torch.einsum('sd, sd -> s', Fi, z) for Fi in [F1, F2]]
            q_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
            fb_loss += self.cfg.q_loss_coef * q_loss

        # ORTHONORMALITY LOSS FOR BACKWARD EMBEDDING

        Cov = torch.matmul(B, B.T)
        orth_loss_diag = - 2 * Cov.diag().mean()
        orth_loss_offdiag = Cov[off_diag].pow(2).mean()
        orth_loss = orth_loss_offdiag + orth_loss_diag
        fb_loss += self.cfg.ortho_coef * orth_loss

        # Cov = torch.cov(B.T)  # Vicreg loss
        # var_loss = F.relu(1 - Cov.diag().clamp(1e-4, 1).sqrt()).mean()  # eps avoids inf. sqrt gradient at 0
        # cov_loss = 2 * torch.triu(Cov, diagonal=1).pow(2).mean() # 2x upper triangular part
        # orth_loss =  var_loss + cov_loss
        # fb_loss += self.cfg.ortho_coef * orth_loss

        # Metrics are only computed when some logging backend is enabled.
        if self.cfg.use_tb or self.cfg.use_wandb or self.cfg.use_hiplog:
            metrics['target_M'] = target_M.mean().item()
            metrics['M1'] = M1.mean().item()
            metrics['F1'] = F1.mean().item()
            metrics['B'] = B.mean().item()
            metrics['B_norm'] = torch.norm(B, dim=-1).mean().item()
            metrics['z_norm'] = torch.norm(z, dim=-1).mean().item()
            metrics['fb_loss'] = fb_loss.item()
            metrics['fb_diag'] = fb_diag.item()
            metrics['fb_offdiag'] = fb_offdiag.item()
            if self.cfg.q_loss:
                metrics['q_loss'] = q_loss.item()
            metrics['orth_loss'] = orth_loss.item()
            metrics['orth_loss_diag'] = orth_loss_diag.item()
            metrics['orth_loss_offdiag'] = orth_loss_offdiag.item()
            # NOTE(review): duplicate of the q_loss assignment a few lines up.
            if self.cfg.q_loss:
                metrics['q_loss'] = q_loss.item()
            # Deviation of the empirical covariance of B from the identity.
            eye_diff = torch.matmul(B.T, B) / B.shape[0] - torch.eye(B.shape[1], device=B.device)
            metrics['orth_linf'] = torch.max(torch.abs(eye_diff)).item()
            metrics['orth_l2'] = eye_diff.norm().item() / math.sqrt(B.shape[1])
            if isinstance(self.fb_opt, torch.optim.Adam):
                metrics["fb_opt_lr"] = self.fb_opt.param_groups[0]["lr"]

        # optimize FB
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        self.fb_opt.zero_grad(set_to_none=True)
        fb_loss.backward()
        self.fb_opt.step()
        if self.encoder_opt is not None:
            self.encoder_opt.step()
        return metrics


    def aug_and_encode(self, obs: torch.Tensor) -> torch.Tensor:
        obs = self.aug(obs)
        return self.encoder(obs)

    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        if step< self.cfg.acro_representation_train_steps:
            metrics: tp.Dict[str, float] = {}

            # Only update ACRO representation learning
            batch = replay_loader.sample(self.cfg.batch_size)
            batch = batch.to(self.cfg.device)
            obs = batch.obs
            action = batch.action
            next_obs = batch.next_obs
            acro_metrics = self.update_acro(obs, action, next_obs)
            metrics.update(acro_metrics)

            # Update ACRO target encoder
            utils.soft_update_params(self.acro_encoder, self.acro_encoder_target, self.cfg.fb_target_tau)

            return metrics
        else:

            # print('in discrete fb update')
            metrics: tp.Dict[str, float] = {}

            if step % self.cfg.update_every_steps != 0:
                return metrics

            batch = replay_loader.sample(self.cfg.batch_size)
            batch = batch.to(self.cfg.device)

            # pdb.set_trace()
            obs = batch.obs
            action = batch.action.type(torch.int64)
            discount = batch.discount
            next_obs = next_goal = batch.next_obs
            if self.cfg.goal_space is not None:
                assert batch.next_goal is not None
                next_goal = batch.next_goal

            # if len(batch.meta) == 1 and batch.meta[0].shape[-1] == self.cfg.z_dim:
            #     z = batch.meta[0]
            #     invalid = torch.linalg.norm(z, dim=1) < 1e-15
            #     if sum(invalid):
            #         z[invalid, :] = self.sample_z(sum(invalid)).to(self.cfg.device)
            # else:
            z = self.sample_z(self.cfg.batch_size, device=self.cfg.device)
            if not z.shape[-1] == self.cfg.z_dim:
                raise RuntimeError("There's something wrong with the logic here")
            # obs = self.aug_and_encode(batch.obs)
            # next_obs = self.aug_and_encode(batch.next_obs)
            # if not self.cfg.update_encoder:
            #     obs = obs.detach()
            #     next_obs = next_obs.detach()

            # backward_input = batch.obs
            # future_goal = batch.future_obs
            # if self.cfg.goal_space is not None:
            #     assert batch.goal is not None
            #     backward_input = batch.goal
            #     future_goal = batch.future_goal

            # perm = torch.randperm(self.cfg.batch_size)
            # backward_input = backward_input[perm]

            # if self.cfg.mix_ratio > 0:
            #     mix_idxs: tp.Any = np.where(np.random.uniform(size=self.cfg.batch_size) < self.cfg.mix_ratio)[0]
            #     if not self.cfg.rand_weight:
            #         with torch.no_grad():
            #             mix_z = self.backward_net(backward_input[mix_idxs]).detach()
            #     else:
            #         # generate random weight
            #         weight = torch.rand(size=(mix_idxs.shape[0], self.cfg.batch_size)).to(self.cfg.device)
            #         weight = F.normalize(weight, dim=1)
            #         uniform_rdv = torch.rand(mix_idxs.shape[0], 1).to(self.cfg.device)
            #         weight = uniform_rdv * weight
            #         with torch.no_grad():
            #             mix_z = torch.matmul(weight, self.backward_net(backward_input).detach())
            #     if self.cfg.norm_z:
            #         mix_z = math.sqrt(self.cfg.z_dim) * F.normalize(mix_z, dim=1)
            #     z[mix_idxs] = mix_z

            # hindsight replay
            # if self.cfg.future_ratio > 0:
            #     assert future_goal is not None
            #     future_idxs = np.where(np.random.uniform(size=self.cfg.batch_size) < self.cfg.future_ratio)
            #     z[future_idxs] = self.backward_net(future_goal[future_idxs]).detach()

            metrics.update(self.update_fb(obs=obs, action=action, discount=discount,
                                        next_obs=next_obs, next_goal=next_goal, z=z, step=step))

            # update critic target
            utils.soft_update_params(self.forward_net, self.forward_target_net,
                                    self.cfg.fb_target_tau)

            return metrics

    def q_function(self, obs, goal):
        h = self.encoder(obs)
        z = self.backward_net(goal)
        F1, F2 = self.forward_net(h, z)
        Q1, Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [F1, F2]]
        Q = torch.min(Q1, Q2)
        return Q
    
    def q_function_pos_neg(self, obs, goal, neg_goal):
        print('in q_function_pos_neg')
        h = self.encoder(obs)
        z = self.backward_net(goal) - self.backward_net(neg_goal)
        F1, F2 = self.forward_net(h, z)
        Q1, Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [F1, F2]]
        Q = torch.min(Q1, Q2)
        return Q
        
    def plot_q_function(self, work_dir, step, env, goal, bf_action=None):
        state_list = env.get_state_list()
        print('in plot_q_function')
        # print(state_list)
        obs_list = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list] # implement this function
        # print(obs_list)
        # print(len(state_list))
        obs_list = torch.cat(obs_list, dim=0).to(self.cfg.device)
        goal = torch.tensor(goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(self.cfg.device)
        # print(obs_list.shape, goal.shape)
        q_list = self.q_function(obs_list, goal).detach()
        v_list = torch.max(q_list, dim=1)[0]
        # v_list = v_list
        a_list = torch.argmax(q_list, dim=1).cpu()
        # print(v_list, a_list)
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"training_step_{step}_v_function") # write this function

        num_pos = 0
        num_neg = 0
        if bf_action is not None:
            for i in range(len(state_list)):
                print('State: ', state_list[i], ' | Optimal Action: ', bf_action[(state_list[i][1], state_list[i][0])], ' | Policy action: ', a_list[i].item())

                if a_list[i].item() in bf_action[(state_list[i][1], state_list[i][0])]:
                    num_pos += 1
                else:
                    num_neg += 1
            print('Positive: ', num_pos, ' | Negative: ', num_neg)
            return num_pos, num_neg
        
    def plot_q_function_pos_neg(self, work_dir, step, env, goal, neg_goal, bf_action=None):
        state_list = env.get_state_list()
        print('in plot_q_pos_neg_function')
        # print(state_list)
        obs_list = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list] # implement this function
        # print(obs_list)
        # print(len(state_list))
        obs_list = torch.cat(obs_list, dim=0).to(self.cfg.device)
        goal = torch.tensor(goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(self.cfg.device)
        neg_goal = torch.tensor(neg_goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(self.cfg.device)
        # print(obs_list.shape, goal.shape)
        q_list = self.q_function_pos_neg(obs_list, goal, neg_goal).detach()
        v_list = torch.max(q_list, dim=1)[0]
        # v_list = v_list
        a_list = torch.argmax(q_list, dim=1).cpu()
        # print(v_list, a_list)
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"training_step_{step}_v_function") # write this function

        num_pos = 0
        num_neg = 0
        if bf_action is not None:
            for i in range(len(state_list)):
                print('State: ', state_list[i], ' | Optimal Action: ', bf_action[(state_list[i][1], state_list[i][0])], ' | Policy action: ', a_list[i].item())

                if a_list[i].item() in bf_action[(state_list[i][1], state_list[i][0])]:
                    num_pos += 1
                else:
                    num_neg += 1
            print('Positive: ', num_pos, ' | Negative: ', num_neg)
            return num_pos, num_neg
        

    def inference(self, replay_loader: ReplayBuffer, inf_logger, pos_goal_set, neg_goal_set, reward_fn):
        # infer z from pos_goal_set and neg_goal_set
        pos_backward = 0.0
        neg_backward = 0.0
        # print(pos_goal_set, neg_goal_set)
        for goal in pos_goal_set:
            goal_tensor = torch.tensor(goal).unsqueeze(0).to(self.cfg.device)
            with torch.no_grad():
                b = self.acro_encoder(goal_tensor)
            pos_backward += b.squeeze(0)
        for neg_goal in neg_goal_set:
            neg_goal_tensor = torch.tensor(neg_goal).unsqueeze(0).to(self.cfg.device)
            with torch.no_grad():
                b = self.acro_encoder(neg_goal_tensor)
            neg_backward += b.squeeze(0) 

        z = (pos_backward - neg_backward).unsqueeze(0)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * F.normalize(z, dim=1)
        return z
    
    def q_function_inference(self, obs, pos_goal, neg_goal, z):
        z = z.repeat(obs.shape[0], 1)
        h = self.encoder(obs)
        F1, F2 = self.forward_net(h, z)
        Q1, Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [F1, F2]]
        Q = torch.min(Q1, Q2)
        return Q
    
    def plot_q_function_inference(self, work_dir, step, env, z, bf_action=None):
        state_list = env.get_state_list()
        print('in plot_q_function_inference')
        # print(state_list)
        obs_list = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list] # implement this function
        # print(obs_list)
        # print(len(state_list))
        obs_list = torch.cat(obs_list, dim=0).to(self.cfg.device)
        # goal = torch.tensor(goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(self.cfg.device)
        # print(obs_list.shape, goal.shape)
        q_list = self.q_function_inference(obs_list, z).detach()
        v_list = torch.max(q_list, dim=1)[0]
        # v_list = v_list
        a_list = torch.argmax(q_list, dim=1).cpu()
        # print(v_list, a_list)
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"inference_step_{step}_v_function") # write this function

        num_pos = 0
        num_neg = 0
        if bf_action is not None:
            for i in range(len(state_list)):
                print('State: ', state_list[i], ' | Optimal Action: ', bf_action[(state_list[i][1], state_list[i][0])], ' | Policy action: ', a_list[i].item())

                if a_list[i].item() in bf_action[(state_list[i][1], state_list[i][0])]:
                    num_pos += 1
                else:
                    num_neg += 1
            print('Positive: ', num_pos, ' | Negative: ', num_neg)
            return num_pos, num_neg






