# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=unused-import
import pdb
import copy
import math
import logging
import dataclasses
from collections import OrderedDict
import typing as tp
import omegaconf

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore

from url_benchmark import utils
# from url_benchmark import replay_buffer as rb
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark.dmc import TimeStep
from url_benchmark import goals as _goals
from .ddpg import MetaDict
from .fb_modules import IdentityMap
from .ddpg import Encoder
from .fb_modules import BackwardMap, mlp
import time


logger = logging.getLogger(__name__)  # module logger; the agent uses it for config sanity warnings


from .fb_ddpg import FBDDPGAgentConfig


@dataclasses.dataclass
class DiscretePSMFBAgentConfig(FBDDPGAgentConfig):
    """Structured (Hydra/OmegaConf) config for ``DiscretePSMFBAgent``.

    Every entry carries a type annotation on purpose: ``dataclasses`` only
    treats *annotated* class attributes as fields, so an un-annotated entry
    can be neither passed to the constructor nor overridden from Hydra.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.discrete_psm_fb.DiscretePSMFBAgent"
    name: str = "discrete_psm_fb"
    preprocess: bool = False
    expl_eps: float = 0.2  # epsilon of the epsilon-greedy exploration in act()
    boltzmann: bool = True  # update_fb: softmax over target Q instead of a greedy max
    temp: float = 100.0  # update_fb: softmax temperature used when boltzmann is set
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    device: str = omegaconf.II("device")  # ${device}
    lr: float = 1e-4
    lr_w: float = 3e-4
    lr_coef: float = 1
    fb_target_tau: float = 0.01  # 0.001-0.01
    update_every_steps: int = 1
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING  # ???  # to be specified later
    num_inference_steps: int = 20000
    hidden_dim: int = 1024   # 128, 2048
    backward_hidden_dim: int = 526   # 512  NOTE(review): 526 vs hinted 512 — confirm not a typo
    feature_dim: int = 512   # 128, 1024
    z_dim: int = 16  # number of bits of the binary PSM policy code (see sample_z_psm)
    d_dim: int = 50  # latent dimension of the forward/backward representations
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)" #
    stddev_clip: float = 0.3  # 1
    update_z_every_step: int = 300
    update_z_proba: float = 1.0
    nstep: int = 1
    batch_size: int = 16 # 512
    num_neg_samples: int = 512
    init_fb: bool = True
    update_encoder: bool = omegaconf.II("update_encoder")  # ${update_encoder}
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")
    ortho_coef: float = 0.0  # 0.01-10
    cons_coef: float = 0.01  # 0.01-10
    log_std_bounds: tp.Tuple[float, float] = (-5, 2)  # param for DiagGaussianActor
    debug: bool = False
    future_ratio: float = 0.0
    mix_ratio: float = 0.5  # 0-1
    rand_weight: bool = False  # True, False
    norm_z: bool = True
    q_loss: bool = False
    q_loss_coef: float = 0.01
    additional_metric: bool = False
    add_trunk: bool = False
    use_dgd: bool = True
    softmax: bool = True
    div_eps: float = 1000.0
    div_coef: float = 0.01
    inf_coeff: float = 5.0
    representation_learning_steps: int = 40000



cs = ConfigStore.instance()
# Expose the structured config to Hydra under group "agent", name "discrete_psm_fb".
cs.store(name="discrete_psm_fb", group="agent", node=DiscretePSMFBAgentConfig)



class ForwardMap(nn.Module):
    """Forward representation F(s, z) for a discrete action space.

    Maps an observation and a latent ``z`` to two independent estimates
    ``(F1, F2)``, each reshaped to ``(batch, z_dim, action_dim)``.

    Args:
        obs_dim: observation feature size.
        z_dim: latent size (both the conditioning input and the per-action
            output dimension).
        action_dim: number of discrete actions.
        feature_dim: embedding size of the preprocessing nets.
        hidden_dim: hidden width of the MLPs.
        preprocess: embed ``obs`` and ``(obs, z)`` separately before the trunk.
        add_trunk: with ``preprocess``, add an MLP trunk after the embeddings.
    """

    def __init__(self, obs_dim, z_dim, action_dim, feature_dim, hidden_dim,
                 preprocess=False, add_trunk=True) -> None:
        super().__init__()
        self.obs_dim = obs_dim
        self.z_dim = z_dim
        self.action_dim = action_dim
        self.preprocess = preprocess

        if self.preprocess:
            self.obs_net = mlp(self.obs_dim, hidden_dim, "ntanh", feature_dim, "irelu")
            # Consumes the *raw* observation concatenated with z.
            self.obs_z_net = mlp(self.obs_dim + self.z_dim, hidden_dim, "ntanh", feature_dim, "irelu")
            if not add_trunk:
                self.trunk: nn.Module = nn.Identity()
                feature_dim = 2 * feature_dim
            else:
                self.trunk = mlp(2 * feature_dim, hidden_dim, "irelu")
                feature_dim = hidden_dim
        else:
            self.trunk = mlp(self.obs_dim + self.z_dim, hidden_dim, "ntanh",
                             hidden_dim, "irelu",
                             hidden_dim, "irelu")
            feature_dim = hidden_dim

        seq = [feature_dim, hidden_dim, "irelu", self.z_dim * self.action_dim]
        self.F1 = mlp(*seq)
        self.F2 = mlp(*seq)

        self.apply(utils.weight_init)

    def forward(self, obs, z):
        """Return ``(F1, F2)``, each of shape ``(batch, z_dim, action_dim)``."""
        assert z.shape[-1] == self.z_dim

        if self.preprocess:
            # Keep the raw observation for obs_z_net (built for obs_dim + z_dim
            # inputs); overwriting ``obs`` with its embedding breaks the shapes
            # whenever feature_dim != obs_dim.
            obs_emb = self.obs_net(obs)
            obs_z = self.obs_z_net(torch.cat([obs, z], dim=-1))
            h = torch.cat([obs_emb, obs_z], dim=-1)
        else:
            h = torch.cat([obs, z], dim=-1)
        # ``trunk`` is always set in __init__ (possibly as nn.Identity).
        h = self.trunk(h)
        F1 = self.F1(h)
        F2 = self.F2(h)
        return (F1.reshape(-1, self.z_dim, self.action_dim),
                F2.reshape(-1, self.z_dim, self.action_dim))
    
class SamplingSeedActor(nn.Module):
    """Deterministic pseudo-random policy indexed by a (code, observation) seed.

    Each binary policy code ``z`` (``z_dim`` bits, column 0 being the most
    significant bit) is read as an integer and added to an integer
    observation hash. The sum indexes a precomputed table whose entry ``i`` is
    the action ``torch.randint(0, action_dim, (1,))`` draws right after the
    CPU generator is seeded with ``i``. The same (z, observation) pair
    therefore always maps to the same action.

    Args:
        action_dim: number of discrete actions.
        z_dim: number of bits in the policy code.
        batch_size: kept for backward compatibility; the bit weights now
            broadcast over any batch size.
        device: device on which the lookup tables are created. They are
            registered as non-persistent buffers, so ``.to(device)`` moves them
            with the module.
        hash_range: extra table entries beyond ``2**z_dim``; observation
            hashes are assumed to lie in ``[0, hash_range)`` (TODO confirm
            against the replay buffer's hashing). Defaults to the original 20000.
    """

    def __init__(self, action_dim, z_dim, batch_size, device="cuda", hash_range=20000):
        super().__init__()
        self.z_dim = z_dim
        self.action_dim = action_dim
        self.max_seed = 2 ** z_dim + hash_range

        # Bit weights 2**(z_dim-1), ..., 2, 1 as a (1, z_dim) row that
        # broadcasts against any (batch, z_dim) code tensor.
        powers = torch.tensor([2 ** i for i in range(z_dim)][::-1], dtype=torch.int64)
        self.register_buffer("powers", powers.unsqueeze(0).to(device), persistent=False)

        # Re-seeding a private generator reproduces exactly what
        # torch.manual_seed(i) + torch.randint would draw, without clobbering
        # the process-wide RNG state that the rest of training relies on.
        gen = torch.Generator()
        table = []
        for seed in range(self.max_seed):
            gen.manual_seed(seed)
            table.append(int(torch.randint(0, action_dim, (1,), generator=gen)))
        seed_to_action = torch.tensor(table, dtype=torch.int64).reshape(-1, 1, 1)
        self.register_buffer("seed_to_action", seed_to_action.to(device), persistent=False)

    def forward(self, obs_hash, z):
        """Return one int64 action per sample, shape ``(batch, 1)``.

        Args:
            obs_hash: integer observation hash with ``batch`` elements (flattened).
            z: ``(batch, z_dim)`` code with entries in {0, 1}.
        """
        # Integer arithmetic keeps the seed exact for any z_dim (a float32 sum
        # would start rounding once codes exceed 2**24).
        code = (z.long() * self.powers).sum(dim=1)
        final_seed = code + obs_hash.reshape(-1).long()
        # NOTE(review): seeds >= self.max_seed (hash >= hash_range) raise an
        # index error; the table is intentionally not wrapped.
        return self.seed_to_action[final_seed].reshape(-1, 1)

class DiscretePSMFBAgent:

    # pylint: disable=unused-argument
    def __init__(self,
                 **kwargs: tp.Any
                 ):
        """Build networks, target copies and optimizers from config kwargs.

        ``kwargs`` are forwarded to ``DiscretePSMFBAgentConfig``. Two groups of
        networks are created:

        * FB: ``forward_net`` (+ ``forward_target_net``), stepped by ``fb_opt``.
        * PSM: ``forward_psm_net`` (+ target), ``backward_net`` (+ target) and
          the code-to-latent MLP ``w`` (+ ``w_target``), stepped by ``psm_opt``.

        Every target network starts as an exact copy of its online network.
        """
        cfg = DiscretePSMFBAgentConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        self.solved_meta: tp.Any = None

        # models
        if cfg.obs_type == 'pixels':
            self.aug: nn.Module = utils.RandomShiftsAug(pad=4)
            self.encoder: nn.Module = Encoder(cfg.obs_shape).to(cfg.device)
            self.obs_dim = self.encoder.repr_dim
        else:
            self.aug = nn.Identity()
            self.encoder = nn.Identity()
            self.obs_dim = cfg.obs_shape[0]
        if cfg.feature_dim < self.obs_dim:
            logger.warning(f"feature_dim {cfg.feature_dim} should not be smaller that obs_dim {self.obs_dim}")
        goal_dim = self.obs_dim
        if cfg.goal_space is not None:
            goal_dim = _goals.get_goal_space_dim(cfg.goal_space)
        # NOTE(review): the backward networks below output cfg.d_dim features,
        # so this sanity check may be meant for d_dim rather than z_dim — confirm.
        if cfg.z_dim < goal_dim:
            logger.warning(f"z_dim {cfg.z_dim} should not be smaller that goal_dim {goal_dim}")

        # FB forward map over a d_dim latent, one d_dim output per action.
        self.forward_net = ForwardMap(self.obs_dim, cfg.d_dim, self.action_dim,
                                      cfg.feature_dim, cfg.hidden_dim,
                                      preprocess=cfg.preprocess, add_trunk=self.cfg.add_trunk).to(cfg.device)
        
        # PSM forward map; update_psm conditions it on w(z) rather than on z.
        self.forward_psm_net = ForwardMap(self.obs_dim, cfg.d_dim, self.action_dim,
                                      cfg.feature_dim, cfg.hidden_dim,
                                      preprocess=cfg.preprocess, add_trunk=self.cfg.add_trunk).to(cfg.device)
        if cfg.debug:
            self.backward_net: nn.Module = IdentityMap().to(cfg.device)
            self.backward_target_net: nn.Module = IdentityMap().to(cfg.device)
        else:
            self.backward_net = BackwardMap(goal_dim, cfg.d_dim, cfg.backward_hidden_dim, norm_z=cfg.norm_z).to(
                cfg.device)
            self.backward_target_net = BackwardMap(goal_dim,
                                                   cfg.d_dim, cfg.backward_hidden_dim, norm_z=cfg.norm_z).to(cfg.device)
        # build up the target network
        self.forward_target_net = ForwardMap(self.obs_dim, cfg.d_dim, self.action_dim,
                                             cfg.feature_dim, cfg.hidden_dim,
                                             preprocess=cfg.preprocess, add_trunk=self.cfg.add_trunk).to(cfg.device)

        self.forward_psm_target_net = ForwardMap(self.obs_dim, cfg.d_dim, self.action_dim,
                                             cfg.feature_dim, cfg.hidden_dim,
                                             preprocess=cfg.preprocess, add_trunk=self.cfg.add_trunk).to(cfg.device)
        
        # w maps a z_dim binary policy code to a d_dim latent; the trailing
        # "L2" presumably normalizes the output — TODO confirm in fb_modules.mlp.
        self.w = mlp(cfg.z_dim, cfg.hidden_dim, "irelu", 
                     cfg.hidden_dim, "irelu",
                     cfg.hidden_dim, "irelu", cfg.d_dim,"L2").to(cfg.device)
        
        self.w.apply(utils.weight_init)
        
        # No weight_init needed: w_target is overwritten from w just below.
        self.w_target = mlp(cfg.z_dim, cfg.hidden_dim, "irelu",
                            cfg.hidden_dim, "irelu",
                            cfg.hidden_dim, "irelu", cfg.d_dim,"L2").to(cfg.device)
        
        # Fixed per-code policy that picks next actions in update_psm.
        self.sampling_actor = SamplingSeedActor(self.action_dim, cfg.z_dim, cfg.batch_size).to(cfg.device)

        self.w_target.load_state_dict(self.w.state_dict())
        self.forward_psm_target_net.load_state_dict(self.forward_psm_net.state_dict())
        # load the weights into the target networks
        self.forward_target_net.load_state_dict(self.forward_net.state_dict())
        self.backward_target_net.load_state_dict(self.backward_net.state_dict())
        # optimizers
        self.encoder_opt: tp.Optional[torch.optim.Adam] = None
        if cfg.obs_type == 'pixels':
            self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=cfg.lr)

        # FB optimizer: forward_net only (backward_net belongs to psm_opt).
        self.fb_opt = torch.optim.Adam([{'params': self.forward_net.parameters()}],
                                       lr=cfg.lr)

        # PSM optimizer: PSM forward map, shared backward map and code MLP w.
        self.psm_opt = torch.optim.Adam([{'params': self.forward_psm_net.parameters()},
                                         {'params': self.backward_net.parameters()},
                                         {'params': self.w.parameters()}],
                                        lr=cfg.lr)

        self.train()
        # Target networks are kept in train mode as well.
        self.forward_target_net.train()
        self.backward_target_net.train()
        self.forward_psm_target_net.train()
        self.w_target.train()

    def train(self, training: bool = True) -> None:
        self.training = training
        for net in [self.encoder, self.forward_net, self.backward_net, self.w, self.forward_psm_net]:
            net.train(training)

    def init_from(self, other) -> None:
        """Copy network weights and optimizer states from another agent.

        The encoder is always copied. With ``cfg.init_fb``, the FB networks and
        the PSM networks (``forward_psm_net``, ``w`` and their targets) are
        copied too. All optimizer states are copied afterwards, and
        ``psm_opt`` holds moments for ``forward_psm_net``, ``backward_net`` and
        ``w``. Copying those parameters keeps the moments paired with the
        weights they were accumulated for.
        """
        names = ["encoder"]
        if self.cfg.init_fb:
            names += ["forward_net", "backward_net", "backward_target_net", "forward_target_net",
                      "forward_psm_net", "forward_psm_target_net", "w", "w_target"]
        for name in names:
            utils.hard_update_params(getattr(other, name), getattr(self, name))
        for key, val in self.__dict__.items():
            if isinstance(val, torch.optim.Optimizer):
                val.load_state_dict(copy.deepcopy(getattr(other, key).state_dict()))

    def get_goal_meta(self, goal_array: np.ndarray) -> MetaDict:
        desired_goal = torch.tensor(goal_array).unsqueeze(0).to(self.cfg.device)
        with torch.no_grad():
            z = self.backward_net(desired_goal)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * F.normalize(z, dim=1)
        z = z.squeeze(0).cpu().numpy()
        meta = OrderedDict()
        meta['z'] = z
        return meta

    def get_neg_goal_meta(self, goal_array: np.ndarray, neg_goal_array: np.ndarray) -> MetaDict:
        print('in get_neg_goal_meta')
        desired_goal = torch.tensor(goal_array).unsqueeze(0).to(self.cfg.device)
        neg_goal = torch.tensor(neg_goal_array).unsqueeze(0).to(self.cfg.device)
        with torch.no_grad():
            z = self.backward_net(desired_goal) - self.backward_net(neg_goal)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * F.normalize(z, dim=1)
        z = z.squeeze(0).cpu().numpy()
        meta = OrderedDict()
        meta['z'] = z
        return meta

    def infer_meta(self, replay_loader: ReplayBuffer) -> MetaDict:
        obs_list, reward_list = [], []
        batch_size = 0
        while batch_size < self.cfg.num_inference_steps:
            batch = replay_loader.sample(self.cfg.batch_size)
            batch = batch.to(self.cfg.device)
            obs_list.append(batch.next_goal if self.cfg.goal_space is not None else batch.next_obs)
            reward_list.append(batch.reward)
            batch_size += batch.next_obs.size(0)
        obs, reward = torch.cat(obs_list, 0), torch.cat(reward_list, 0)  # type: ignore
        obs, reward = obs[:self.cfg.num_inference_steps], reward[:self.cfg.num_inference_steps]
        return self.infer_meta_from_obs_and_rewards(obs, reward)

    def infer_meta_from_obs_and_rewards(self, obs: torch.Tensor, reward: torch.Tensor) -> MetaDict:
        print('max reward: ', reward.max().cpu().item())
        print('99 percentile: ', torch.quantile(reward, 0.99).cpu().item())
        print('median reward: ', reward.median().cpu().item())
        print('min reward: ', reward.min().cpu().item())
        print('mean reward: ', reward.mean().cpu().item())
        print('num reward: ', reward.shape[0])

        # filter out small reward
        # pdb.set_trace()
        # idx = torch.where(reward >= torch.quantile(reward, 0.99))[0]
        # obs = obs[idx]
        # reward = reward[idx]
        with torch.no_grad():
            B = self.backward_net(obs)
        z = torch.matmul(reward.T, B) / reward.shape[0]
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.z_dim) * F.normalize(z, dim=1)
        meta = OrderedDict()
        meta['z'] = z.squeeze().cpu().numpy()
        # self.solved_meta = meta
        return meta

    def sample_z(self, size, device: str = "cpu"):
        gaussian_rdv = torch.randn((size, self.cfg.d_dim), dtype=torch.float32, device=device)
        gaussian_rdv = F.normalize(gaussian_rdv, dim=1)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.d_dim) * gaussian_rdv
        else:
            uniform_rdv = torch.rand((size, self.cfg.d_dim), dtype=torch.float32, device=device)
            z = np.sqrt(self.cfg.d_dim) * uniform_rdv * gaussian_rdv
        return z
    
    def int_to_binary_array(self, int_vector, num_bits=None):
        if num_bits is None:
            num_bits = int_vector.max().bit_length()
        
        binary_array = ((int_vector[:, None] & (1 << np.arange(num_bits))) > 0).astype(int)
        return binary_array
    
    def sample_z_psm(self, size, device: str = "cpu"):
        z_np = np.random.randint(0, 2**self.cfg.z_dim, (size,))
        binary_array = self.int_to_binary_array(z_np, self.cfg.z_dim)
        return torch.FloatTensor(binary_array).to(device)

    def init_meta(self) -> MetaDict:
        if self.solved_meta is not None:
            print('solved_meta')
            return self.solved_meta
        else:
            z = self.sample_z(1)
            z = z.squeeze().numpy()
            meta = OrderedDict()
            meta['z'] = z
        return meta

    # pylint: disable=unused-argument
    def update_meta(
            self,
            meta: MetaDict,
            global_step: int,
            time_step: TimeStep,
            finetune: bool = False,
            replay_loader: tp.Optional[ReplayBuffer] = None
    ) -> MetaDict:
        if global_step % self.cfg.update_z_every_step == 0 and np.random.rand() < self.cfg.update_z_proba:
            return self.init_meta()
        return meta

    def act(self, obs, meta, step, eval_mode) -> tp.Any:
        obs = torch.as_tensor(obs, device=self.cfg.device, dtype=torch.float32).unsqueeze(0)  # type: ignore
        h = self.encoder(obs)
        z = torch.as_tensor(meta['z'], device=self.cfg.device).unsqueeze(0)  # type: ignore
        F1, F2 = self.forward_net(h, z)
        Q1, Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [F1, F2]]
        Q = torch.min(Q1, Q2)
        action = Q.max(1)[1]

        if not eval_mode:
            if step < self.cfg.num_expl_steps:
                action = torch.randint_like(action, self.action_dim)
            else:
                action = torch.randint_like(action, self.action_dim) \
                    if np.random.rand() < self.cfg.expl_eps else action
        return action.item()


    def update_psm(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        discount: torch.Tensor,
        next_obs: torch.Tensor,
        next_obs_hash: torch.Tensor,
        next_goal: torch.Tensor,
        z: torch.Tensor,
        step: int
    ) -> tp.Dict[str, float]:
        metrics: tp.Dict[str, float] = {}
        # compute target successor measure
        with torch.no_grad():
            # compute greedy action
            target_w = self.w_target(z)
            target_F1, target_F2 = self.forward_psm_target_net(next_obs, target_w)
            next_actions = self.sampling_actor(next_obs_hash, z)

            # next_Q1, next_Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [target_F1, target_F2]]
            # next_Q = torch.min(next_Q1, next_Q2)

            # if self.cfg.boltzmann:
            #     pi = F.softmax(next_Q / self.cfg.temp, dim=-1)
            #     target_F1, target_F2 = [torch.einsum("sa, sda -> sd", pi, Fi) for Fi in [target_F1, target_F2]] # batch x z_dim
            #     next_Q = torch.einsum("sa, sa -> s", pi, next_Q)
            # else:
            #     next_action = next_Q.max(1)[1]
            #     next_idx = next_action[:, None].repeat(1, self.cfg.z_dim)[:, :, None]
            #     target_F1, target_F2 = [Fi.gather(-1, next_idx).squeeze() for Fi in [target_F1, target_F2]]  # batch x z_dim
            #     next_Q = next_Q.max(1)[0]
            next_idx = next_actions.repeat(1, self.cfg.d_dim)[:, :, None]
            target_F1, target_F2 = [Fi.gather(-1, next_idx).squeeze() for Fi in [target_F1, target_F2]]  # batch x z
            target_B = self.backward_target_net(next_goal)  # batch x z_dim
            # import pdb; pdb.set_trace()
            target_M1, target_M2 = [torch.einsum('sd, td -> st', Fi, target_B) \
                                    for Fi in [target_F1, target_F2]] # batch x batch
            target_M = torch.min(target_M1, target_M2)

        # compute FB loss
        idxs = action.repeat(1, self.cfg.d_dim)[:, :, None]
        w = self.w(z)
        F1, F2 = [Fi.gather(-1, idxs).squeeze() for Fi in self.forward_psm_net(obs, w)]
        B = self.backward_net(next_goal)
        M1 = torch.einsum('sd, td -> st', F1, B)  # batch x batch
        M2 = torch.einsum('sd, td -> st', F2, B)  # batch x batch
        I = torch.eye(*M1.size(), device=M1.device)
        off_diag = ~I.bool()
        fb_offdiag: tp.Any = 0.5 * sum((M - discount * target_M)[off_diag].pow(2).mean() for M in [M1, M2])
        fb_diag: tp.Any = -sum(M.diag().mean() for M in [M1, M2])
        fb_loss = fb_offdiag + fb_diag

        # Q LOSS

        # if self.cfg.q_loss:
        #     with torch.no_grad():
        #         # next_Q1, nextQ2 = [torch.einsum('sd, sd -> s', target_Fi, z) for target_Fi in [target_F1, target_F2]]
        #         # next_Q = torch.min(next_Q1, nextQ2)
        #         cov = torch.matmul(B.T, B) / B.shape[0]
        #         inv_cov = torch.linalg.pinv(cov)
        #         implicit_reward = (torch.matmul(B, inv_cov) * z).sum(dim=1)  # batch_size
        #         target_Q = implicit_reward.detach() + discount.squeeze(1) * next_Q  # batch_size

        #     Q1, Q2 = [torch.einsum('sd, sd -> s', Fi, z) for Fi in [F1, F2]]
        #     q_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
        #     fb_loss += self.cfg.q_loss_coef * q_loss

        # ORTHONORMALITY LOSS FOR BACKWARD EMBEDDING

        Cov = torch.matmul(B, B.T)
        orth_loss_diag = - 2 * Cov.diag().mean()
        orth_loss_offdiag = Cov[off_diag].pow(2).mean()
        orth_loss = orth_loss_offdiag + orth_loss_diag
        fb_loss += self.cfg.ortho_coef * orth_loss

        # Cov = torch.cov(B.T)  # Vicreg loss
        # var_loss = F.relu(1 - Cov.diag().clamp(1e-4, 1).sqrt()).mean()  # eps avoids inf. sqrt gradient at 0
        # cov_loss = 2 * torch.triu(Cov, diagonal=1).pow(2).mean() # 2x upper triangular part
        # orth_loss =  var_loss + cov_loss
        # fb_loss += self.cfg.ortho_coef * orth_loss

        if self.cfg.use_tb or self.cfg.use_wandb or self.cfg.use_hiplog:
            metrics['target_M'] = target_M.mean().item()
            metrics['M1'] = M1.mean().item()
            metrics['F1'] = F1.mean().item()
            metrics['B'] = B.mean().item()
            metrics['B_norm'] = torch.norm(B, dim=-1).mean().item()
            metrics['z_norm'] = torch.norm(z, dim=-1).mean().item()
            metrics['fb_loss'] = fb_loss.item()
            metrics['fb_diag'] = fb_diag.item()
            metrics['fb_offdiag'] = fb_offdiag.item()
            if self.cfg.q_loss:
                metrics['q_loss'] = q_loss.item()
            metrics['orth_loss'] = orth_loss.item()
            metrics['orth_loss_diag'] = orth_loss_diag.item()
            metrics['orth_loss_offdiag'] = orth_loss_offdiag.item()
            if self.cfg.q_loss:
                metrics['q_loss'] = q_loss.item()
            eye_diff = torch.matmul(B.T, B) / B.shape[0] - torch.eye(B.shape[1], device=B.device)
            metrics['orth_linf'] = torch.max(torch.abs(eye_diff)).item()
            metrics['orth_l2'] = eye_diff.norm().item() / math.sqrt(B.shape[1])
            if isinstance(self.fb_opt, torch.optim.Adam):
                metrics["fb_opt_lr"] = self.fb_opt.param_groups[0]["lr"]

        # optimize FB
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        self.psm_opt.zero_grad(set_to_none=True)
        fb_loss.backward()
        self.psm_opt.step()
        if self.encoder_opt is not None:
            self.encoder_opt.step()
        return metrics

    def update_fb(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        discount: torch.Tensor,
        next_obs: torch.Tensor,
        next_goal: torch.Tensor,
        z: torch.Tensor,
        step: int
    ) -> tp.Dict[str, float]:
        """One gradient step on the forward-backward (FB) measure loss.

        ``forward_net`` is conditioned on the latent ``z``, and
        ``M = F(s, a) B(g')^T`` is computed. Diagonal entries (matched
        transitions) are pushed up. Off-diagonal entries are regressed onto
        ``discount * min(target_M1, target_M2)``. The target next-action
        features come from either a softmax over the target Q
        (``cfg.boltzmann``, temperature ``cfg.temp``) or its argmax, where
        ``Q = min(F1^T z, F2^T z)``.

        With ``cfg.q_loss``, an extra MSE term regresses ``F^T z`` onto
        ``r_implicit + discount * next_Q``. Here ``r_implicit`` is the row-wise
        ``(B pinv(B^T B / N)) . z``. Only ``fb_opt`` (``forward_net``) and, if
        present, ``encoder_opt`` are stepped.

        Returns:
            Scalar diagnostics (filled only when a logger is enabled).
        """
        metrics: tp.Dict[str, float] = {}
        # compute target successor measure
        with torch.no_grad():
            # compute greedy action
            target_F1, target_F2 = self.forward_target_net(next_obs, z)
            next_Q1, next_Q2 = [torch.einsum('sda, sd -> sa', Fi, z) for Fi in [target_F1, target_F2]]
            next_Q = torch.min(next_Q1, next_Q2)

            if self.cfg.boltzmann:
                pi = F.softmax(next_Q / self.cfg.temp, dim=-1)
                target_F1, target_F2 = [torch.einsum("sa, sda -> sd", pi, Fi) for Fi in [target_F1, target_F2]] # batch x d_dim
                next_Q = torch.einsum("sa, sa -> s", pi, next_Q)
            else:
                next_action = next_Q.max(1)[1]
                next_idx = next_action[:, None].repeat(1, self.cfg.d_dim)[:, :, None]
                # NOTE: .squeeze() would also drop the batch axis if batch == 1.
                target_F1, target_F2 = [Fi.gather(-1, next_idx).squeeze() for Fi in [target_F1, target_F2]]  # batch x d_dim
                next_Q = next_Q.max(1)[0]
            # NOTE(review): uses the online backward_net, whereas update_psm uses
            # backward_target_net; fb_opt does not step backward_net — confirm intent.
            target_B = self.backward_net(next_goal)  # batch x d_dim
            target_M1, target_M2 = [torch.einsum('sd, td -> st', Fi, target_B) \
                                    for Fi in [target_F1, target_F2]] # batch x batch
            target_M = torch.min(target_M1, target_M2)

        # import pdb; pdb.set_trace()
        # compute FB loss
        idxs = action.repeat(1, self.cfg.d_dim)[:, :, None]
        F1, F2 = [Fi.gather(-1, idxs).squeeze() for Fi in self.forward_net(obs, z)]
        # fb_opt only holds forward_net parameters (see __init__), so backward_net
        # is not updated by this step even though gradients reach it.
        B = self.backward_net(next_goal)
        M1 = torch.einsum('sd, td -> st', F1, B)  # batch x batch
        M2 = torch.einsum('sd, td -> st', F2, B)  # batch x batch
        I = torch.eye(*M1.size(), device=M1.device)
        off_diag = ~I.bool()
        fb_offdiag: tp.Any = 0.5 * sum((M - discount * target_M)[off_diag].pow(2).mean() for M in [M1, M2])
        fb_diag: tp.Any = -sum(M.diag().mean() for M in [M1, M2])
        fb_loss = fb_offdiag + fb_diag

        # Q LOSS

        if self.cfg.q_loss:
            with torch.no_grad():
                # next_Q1, nextQ2 = [torch.einsum('sd, sd -> s', target_Fi, z) for target_Fi in [target_F1, target_F2]]
                # next_Q = torch.min(next_Q1, nextQ2)
                cov = torch.matmul(B.T, B) / B.shape[0]
                inv_cov = torch.linalg.pinv(cov)
                implicit_reward = (torch.matmul(B, inv_cov) * z).sum(dim=1)  # batch_size
                target_Q = implicit_reward.detach() + discount.squeeze(1) * next_Q  # batch_size

            Q1, Q2 = [torch.einsum('sd, sd -> s', Fi, z) for Fi in [F1, F2]]
            q_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
            fb_loss += self.cfg.q_loss_coef * q_loss

        # ORTHONORMALITY LOSS FOR BACKWARD EMBEDDING

        # Cov = torch.matmul(B, B.T)
        # orth_loss_diag = - 2 * Cov.diag().mean()
        # orth_loss_offdiag = Cov[off_diag].pow(2).mean()
        # orth_loss = orth_loss_offdiag + orth_loss_diag
        # fb_loss += self.cfg.ortho_coef * orth_loss

        # Cov = torch.cov(B.T)  # Vicreg loss
        # var_loss = F.relu(1 - Cov.diag().clamp(1e-4, 1).sqrt()).mean()  # eps avoids inf. sqrt gradient at 0
        # cov_loss = 2 * torch.triu(Cov, diagonal=1).pow(2).mean() # 2x upper triangular part
        # orth_loss =  var_loss + cov_loss
        # fb_loss += self.cfg.ortho_coef * orth_loss

        if self.cfg.use_tb or self.cfg.use_wandb or self.cfg.use_hiplog:
            metrics['target_M'] = target_M.mean().item()
            metrics['M1'] = M1.mean().item()
            metrics['F1'] = F1.mean().item()
            metrics['B'] = B.mean().item()
            metrics['B_norm'] = torch.norm(B, dim=-1).mean().item()
            metrics['z_norm'] = torch.norm(z, dim=-1).mean().item()
            metrics['fb_loss'] = fb_loss.item()
            metrics['fb_diag'] = fb_diag.item()
            metrics['fb_offdiag'] = fb_offdiag.item()
            if self.cfg.q_loss:
                metrics['q_loss'] = q_loss.item()
            # metrics['orth_loss'] = orth_loss.item()
            # metrics['orth_loss_diag'] = orth_loss_diag.item()
            # metrics['orth_loss_offdiag'] = orth_loss_offdiag.item()
            if self.cfg.q_loss:
                metrics['q_loss'] = q_loss.item()
            # Deviation of the empirical B covariance from the identity.
            eye_diff = torch.matmul(B.T, B) / B.shape[0] - torch.eye(B.shape[1], device=B.device)
            metrics['orth_linf'] = torch.max(torch.abs(eye_diff)).item()
            metrics['orth_l2'] = eye_diff.norm().item() / math.sqrt(B.shape[1])
            if isinstance(self.fb_opt, torch.optim.Adam):
                metrics["fb_opt_lr"] = self.fb_opt.param_groups[0]["lr"]

        # optimize FB
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        self.fb_opt.zero_grad(set_to_none=True)
        fb_loss.backward()
        self.fb_opt.step()
        if self.encoder_opt is not None:
            self.encoder_opt.step()
        return metrics


    def aug_and_encode(self, obs: torch.Tensor) -> torch.Tensor:
        obs = self.aug(obs)
        return self.encoder(obs)

    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        # if step < self.cfg.representation_learning_steps:
            metrics: tp.Dict[str, float] = {}
            # print(step)
            if step % self.cfg.update_every_steps != 0:
                return metrics
            
            start_t = time.time()
            batch = replay_loader.sample(self.cfg.batch_size)
            batch = batch.to(self.cfg.device)
            # batch2 = replay_loader.sample(self.cfg.batch_size ** 2)
            # batch2 = batch2.to(self.cfg.device)
            replay_buffer_sampling_time = time.time() - start_t
            start_t = time.time()
            # print("Batch sampled")
            # pdb.set_trace()
            obs = batch.obs
            action = batch.action.type(torch.int64)
            discount = batch.discount
            next_obs = batch.next_obs
            next_goal = batch.next_obs
            if self.cfg.goal_space is not None:
                assert batch.next_goal is not None
                next_goal = batch.next_goal
            next_obs_hash = batch.next_obs_hash
            rewards = batch.reward
            
            # robs = batch2.obs
            # raction = batch2.action.type(torch.int64)
            # rnext_obs = batch2.next_obs
            # perm = torch.randperm(robs.shape[0])
            # rnext_obs = rnext_obs[perm]
            batch_time = time.time() - start_t
            start_t = time.time()
            z = self.sample_z_psm(self.cfg.batch_size, device=self.cfg.device)
            # print(z.shape)
            # print("Z sampling took: ",time.time()-start_sample)
            # print("Z sampled")
            # z = torch.repeat_interleave(z, self.cfg.batch_size, 0)
            # z = z.repeat_interleave(self.cfg.num_neg_samples, 1)
            # print(z.shape)
            if not z.shape[-1] == self.cfg.z_dim:
                raise RuntimeError("There's something wrong with the logic here")

            # rz = self.sample_z(self.cfg.batch_size ** 2, device=self.cfg.device)

            # z_sampling_time = time.time() - start_t
            # start_t = time.time()
            # robs = None
            # raction = None
            # rnext_obs = None
            # rz = None
            metrics.update(self.update_psm(obs=obs, action=action, discount=discount,
                                        next_obs=next_obs,next_obs_hash=next_obs_hash, next_goal=next_goal, 
                                        z=z, step=step))
            # print("Time to update: {}".format(time.time()-update_start))
            # update critic target
            psm_update_time = time.time() - start_t
            start_t = time.time()

            # print(f"Replay buffer sampling time: {replay_buffer_sampling_time}, Batch time: {batch_time}, Z sampling time: {z_sampling_time}, PSM update time: {psm_update_time}")
            utils.soft_update_params(self.forward_psm_net, self.forward_psm_target_net,
                                    self.cfg.fb_target_tau)
            utils.soft_update_params(self.backward_net, self.backward_target_net,
                                    self.cfg.fb_target_tau)
            utils.soft_update_params(self.w, self.w_target,
                                    self.cfg.fb_target_tau)

            # return metrics
        
            # else:
            # print('in discrete fb update')
            metrics: tp.Dict[str, float] = {}

            # if step % self.cfg.update_every_steps != 0:
            #     return metrics

            batch = replay_loader.sample(1024)
            batch = batch.to(self.cfg.device)

            # pdb.set_trace()
            obs = batch.obs
            action = batch.action.type(torch.int64)
            discount = batch.discount
            next_obs = next_goal = batch.next_obs
            if self.cfg.goal_space is not None:
                assert batch.next_goal is not None
                next_goal = batch.next_goal

            # if len(batch.meta) == 1 and batch.meta[0].shape[-1] == self.cfg.z_dim:
            #     z = batch.meta[0]
            #     invalid = torch.linalg.norm(z, dim=1) < 1e-15
            #     if sum(invalid):
            #         z[invalid, :] = self.sample_z(sum(invalid)).to(self.cfg.device)
            # else:
            z = self.sample_z(1024, device=self.cfg.device)
            if not z.shape[-1] == self.cfg.d_dim:
                raise RuntimeError("There's something wrong with the logic here")
            # obs = self.aug_and_encode(batch.obs)
            # next_obs = self.aug_and_encode(batch.next_obs)
            # if not self.cfg.update_encoder:
            #     obs = obs.detach()
            #     next_obs = next_obs.detach()

            # backward_input = batch.obs
            # future_goal = batch.future_obs
            # if self.cfg.goal_space is not None:
            #     assert batch.goal is not None
            #     backward_input = batch.goal
            #     future_goal = batch.future_goal

            # perm = torch.randperm(self.cfg.batch_size)
            # backward_input = backward_input[perm]

            # if self.cfg.mix_ratio > 0:
            #     mix_idxs: tp.Any = np.where(np.random.uniform(size=self.cfg.batch_size) < self.cfg.mix_ratio)[0]
            #     if not self.cfg.rand_weight:
            #         with torch.no_grad():
            #             mix_z = self.backward_net(backward_input[mix_idxs]).detach()
            #     else:
            #         # generate random weight
            #         weight = torch.rand(size=(mix_idxs.shape[0], self.cfg.batch_size)).to(self.cfg.device)
            #         weight = F.normalize(weight, dim=1)
            #         uniform_rdv = torch.rand(mix_idxs.shape[0], 1).to(self.cfg.device)
            #         weight = uniform_rdv * weight
            #         with torch.no_grad():
            #             mix_z = torch.matmul(weight, self.backward_net(backward_input).detach())
            #     if self.cfg.norm_z:
            #         mix_z = math.sqrt(self.cfg.z_dim) * F.normalize(mix_z, dim=1)
            #     z[mix_idxs] = mix_z

            # hindsight replay
            # if self.cfg.future_ratio > 0:
            #     assert future_goal is not None
            #     future_idxs = np.where(np.random.uniform(size=self.cfg.batch_size) < self.cfg.future_ratio)
            #     z[future_idxs] = self.backward_net(future_goal[future_idxs]).detach()

            metrics.update(self.update_fb(obs=obs, action=action, discount=discount,
                                        next_obs=next_obs, next_goal=next_goal, z=z, step=step))

            # update critic target
            utils.soft_update_params(self.forward_net, self.forward_target_net,
                                    self.cfg.fb_target_tau)
            # utils.soft_update_params(self.backward_net, self.backward_target_net,
            #                         self.cfg.fb_target_tau)

            # return metrics
            return {}

    def q_function(self, obs, goal):
        """Return the pessimistic per-action Q-values for reaching ``goal``.

        The goal is embedded by the backward network (``z = B(goal)``), the
        observation by the encoder, and the forward network yields two heads
        ``F1, F2``. Each head is contracted with ``z`` as
        ``Q_i[s, a] = sum_d F_i[s, d, a] * z[s, d]`` and the elementwise
        minimum of the two Q heads is returned.

        Args:
            obs: batch of observations passed to ``self.encoder``.
            goal: batch of goals passed to ``self.backward_net``; the einsum
                requires one embedded goal per observation row.

        Returns:
            Tensor indexed ``[state, action]``.
        """
        encoded_obs = self.encoder(obs)
        goal_embedding = self.backward_net(goal)
        head_a, head_b = self.forward_net(encoded_obs, goal_embedding)
        q_a, q_b = (
            torch.einsum('sda, sd -> sa', head, goal_embedding)
            for head in (head_a, head_b)
        )
        # Taking the min of the two heads gives a conservative estimate.
        return torch.min(q_a, q_b)
    
    def q_function_pos_neg(self, obs, goal, neg_goal):
        """Return pessimistic per-action Q-values for a positive/negative goal pair.

        The task vector is the difference of backward embeddings,
        ``z = B(goal) - B(neg_goal)``. It is contracted with both forward heads
        (``Q_i[s, a] = sum_d F_i[s, d, a] * z[s, d]``) and the elementwise
        minimum of the two Q heads is returned.

        Args:
            obs: batch of observations passed to ``self.encoder``.
            goal: batch of goals to move toward.
            neg_goal: batch of goals to move away from.

        Returns:
            Tensor indexed ``[state, action]``.
        """
        # Debug trace kept from the original implementation.
        print('in q_function_pos_neg')
        encoded_obs = self.encoder(obs)
        pos_embedding = self.backward_net(goal)
        neg_embedding = self.backward_net(neg_goal)
        task = pos_embedding - neg_embedding
        head_a, head_b = self.forward_net(encoded_obs, task)
        q_a = torch.einsum('sda, sd -> sa', head_a, task)
        q_b = torch.einsum('sda, sd -> sa', head_b, task)
        return torch.min(q_a, q_b)
        
    def plot_q_function(self, work_dir, step, env, goal, bf_action=None):
        """Plot the greedy value/action map toward ``goal`` and optionally score it.

        Every state from ``env.get_state_list()`` is turned into an observation,
        Q-values are computed with ``self.q_function`` against the same goal,
        and ``env.plot_v_function`` receives the per-state max Q-value and
        argmax action under the name ``training_step_{step}_v_function``.

        Args:
            work_dir: forwarded unchanged to ``env.plot_v_function``.
            step: training step, used only in the plot name.
            env: object providing ``get_state_list``, ``get_obs_from_state``
                and ``plot_v_function``.
            goal: a single goal; it is tiled to one copy per state.
            bf_action: optional mapping from ``(state[1], state[0])`` to a
                collection of acceptable actions for that state.

        Returns:
            ``(num_pos, num_neg)`` -- the number of states whose greedy action
            is / is not in ``bf_action`` for that state -- when ``bf_action``
            is given; otherwise ``None``.
        """
        states = env.get_state_list()
        print('in plot_q_function')
        obs_batch = torch.cat(
            [torch.tensor(env.get_obs_from_state(s)).unsqueeze(0) for s in states],
            dim=0,
        ).to(self.cfg.device)
        n_rows = obs_batch.shape[0]
        goal_batch = torch.tensor(goal).unsqueeze(0).repeat(n_rows, 1).to(self.cfg.device)
        q_values = self.q_function(obs_batch, goal_batch).detach()
        best_values, _ = torch.max(q_values, dim=1)
        greedy_actions = torch.argmax(q_values, dim=1).cpu()
        # NOTE(review): greedy_actions is moved to CPU but best_values is not --
        # confirm env.plot_v_function accepts a tensor on self.cfg.device.
        env.plot_v_function(
            work_dir, obs_batch.cpu(), best_values, greedy_actions,
            f"training_step_{step}_v_function",
        )

        if bf_action is None:
            return None

        num_pos = num_neg = 0
        for idx, state in enumerate(states):
            # bf_action is keyed with the first two state components swapped.
            optimal = bf_action[(state[1], state[0])]
            chosen = greedy_actions[idx].item()
            print('State: ', state, ' | Optimal Action: ', optimal, ' | Policy action: ', chosen)
            if chosen in optimal:
                num_pos += 1
            else:
                num_neg += 1
        print('Positive: ', num_pos, ' | Negative: ', num_neg)
        return num_pos, num_neg
        
    def plot_q_function_pos_neg(self, work_dir, step, env, goal, neg_goal, bf_action=None):
        """Plot the greedy value/action map for a positive/negative goal pair.

        Like ``plot_q_function`` but Q-values come from
        ``self.q_function_pos_neg`` with ``goal`` and ``neg_goal`` each tiled to
        one copy per state. The plot is saved under the same name,
        ``training_step_{step}_v_function``.

        Args:
            work_dir: forwarded unchanged to ``env.plot_v_function``.
            step: training step, used only in the plot name.
            env: object providing ``get_state_list``, ``get_obs_from_state``
                and ``plot_v_function``.
            goal: a single goal to move toward.
            neg_goal: a single goal to move away from.
            bf_action: optional mapping from ``(state[1], state[0])`` to a
                collection of acceptable actions for that state.

        Returns:
            ``(num_pos, num_neg)`` when ``bf_action`` is given, else ``None``.
        """
        states = env.get_state_list()
        print('in plot_q_pos_neg_function')
        obs_batch = torch.cat(
            [torch.tensor(env.get_obs_from_state(s)).unsqueeze(0) for s in states],
            dim=0,
        ).to(self.cfg.device)
        n_rows = obs_batch.shape[0]
        pos_batch = torch.tensor(goal).unsqueeze(0).repeat(n_rows, 1).to(self.cfg.device)
        neg_batch = torch.tensor(neg_goal).unsqueeze(0).repeat(n_rows, 1).to(self.cfg.device)
        q_values = self.q_function_pos_neg(obs_batch, pos_batch, neg_batch).detach()
        best_values, _ = torch.max(q_values, dim=1)
        greedy_actions = torch.argmax(q_values, dim=1).cpu()
        # NOTE(review): this uses the same file name as plot_q_function, so the
        # two plots for one step overwrite each other -- confirm that is intended.
        env.plot_v_function(
            work_dir, obs_batch.cpu(), best_values, greedy_actions,
            f"training_step_{step}_v_function",
        )

        if bf_action is None:
            return None

        num_pos = num_neg = 0
        for idx, state in enumerate(states):
            # bf_action is keyed with the first two state components swapped.
            optimal = bf_action[(state[1], state[0])]
            chosen = greedy_actions[idx].item()
            print('State: ', state, ' | Optimal Action: ', optimal, ' | Policy action: ', chosen)
            if chosen in optimal:
                num_pos += 1
            else:
                num_neg += 1
        print('Positive: ', num_pos, ' | Negative: ', num_neg)
        return num_pos, num_neg
        

    def inference(self, replay_loader: "ReplayBuffer", inf_logger, pos_goal_set, neg_goal_set, reward_fn):
        """Infer a task vector ``z`` from positive and negative example goals.

        ``z = sum_{g in pos_goal_set} B(g) - sum_{g in neg_goal_set} B(g)``,
        where ``B`` is ``self.backward_net`` evaluated without gradients on one
        goal at a time. If ``self.cfg.norm_z`` is set, ``z`` is rescaled to L2
        norm ``sqrt(self.cfg.d_dim)`` along dim 1.

        Args:
            replay_loader: unused here; kept for interface compatibility.
            inf_logger: unused here; kept for interface compatibility.
            pos_goal_set: iterable of goals to move toward.
            neg_goal_set: iterable of goals to move away from.
            reward_fn: unused here; kept for interface compatibility.

        Returns:
            ``z`` with a leading batch dimension of size 1.

        Raises:
            ValueError: if both goal sets are empty, so there is no backward
                embedding to build ``z`` from.
        """

        def _sum_backward(goals):
            # Sum of B(g) over ``goals``; remains the float 0.0 if ``goals`` is empty.
            total = 0.0
            for goal in goals:
                goal_tensor = torch.tensor(goal).unsqueeze(0).to(self.cfg.device)
                with torch.no_grad():
                    embedding = self.backward_net(goal_tensor)
                total = total + embedding.squeeze(0)
            return total

        pos_backward = _sum_backward(pos_goal_set)
        neg_backward = _sum_backward(neg_goal_set)
        # With no goals at all both sums are plain floats and ``.unsqueeze``
        # would fail with an opaque AttributeError; report the real problem.
        if not (torch.is_tensor(pos_backward) or torch.is_tensor(neg_backward)):
            raise ValueError(
                "inference needs at least one positive or negative goal to infer z"
            )

        z = (pos_backward - neg_backward).unsqueeze(0)
        if self.cfg.norm_z:
            z = math.sqrt(self.cfg.d_dim) * F.normalize(z, dim=1)
        return z
    
    def q_function_inference(self, obs, pos_goal, neg_goal, z):
        """Return pessimistic per-action Q-values for a batch under a fixed task ``z``.

        ``z`` is tiled to one row per observation (``z.repeat(obs.shape[0], 1)``)
        and contracted with both forward heads
        (``Q_i[s, a] = sum_d F_i[s, d, a] * z[s, d]``); the elementwise minimum
        of the two heads is returned.

        Args:
            obs: batch of observations passed to ``self.encoder``.
            pos_goal: accepted but not used by this method.
            neg_goal: accepted but not used by this method.
            z: task vector, presumably of shape ``(1, d)`` as produced by
                ``inference`` (TODO confirm against callers).

        Returns:
            Tensor indexed ``[state, action]``.
        """
        tiled_z = z.repeat(obs.shape[0], 1)
        head_a, head_b = self.forward_net(self.encoder(obs), tiled_z)
        q_a = torch.einsum('sda, sd -> sa', head_a, tiled_z)
        q_b = torch.einsum('sda, sd -> sa', head_b, tiled_z)
        return torch.min(q_a, q_b)
    
    def plot_q_function_inference(self, work_dir, step, env, z, bf_action=None):
        """Plot the greedy value/action map for an inferred task ``z`` and score it.

        Every state from ``env.get_state_list()`` is turned into an observation,
        Q-values are computed with ``self.q_function_inference`` for the task
        vector ``z``, and ``env.plot_v_function`` receives the per-state max
        Q-value and argmax action under the name
        ``inference_step_{step}_v_function``.

        Args:
            work_dir: forwarded unchanged to ``env.plot_v_function``.
            step: step counter, used only in the plot name.
            env: object providing ``get_state_list``, ``get_obs_from_state``
                and ``plot_v_function``.
            z: task vector passed through to ``self.q_function_inference``.
            bf_action: optional mapping from ``(state[1], state[0])`` to a
                collection of acceptable actions for that state.

        Returns:
            ``(num_pos, num_neg)`` -- the number of states whose greedy action
            is / is not in ``bf_action`` for that state -- when ``bf_action``
            is given; otherwise ``None``.
        """
        state_list = env.get_state_list()
        print('in plot_q_function_inference')
        obs_list = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list]
        obs_list = torch.cat(obs_list, dim=0).to(self.cfg.device)
        # q_function_inference's signature is (obs, pos_goal, neg_goal, z) and it
        # ignores the two goal arguments. Calling it as (obs_list, z) raised a
        # TypeError for the missing positional argument, so pass them explicitly.
        q_list = self.q_function_inference(obs_list, None, None, z).detach()
        v_list = torch.max(q_list, dim=1)[0]
        a_list = torch.argmax(q_list, dim=1).cpu()
        # NOTE(review): a_list is moved to CPU but v_list is not -- confirm
        # env.plot_v_function accepts a tensor on self.cfg.device.
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"inference_step_{step}_v_function")

        if bf_action is None:
            return None

        num_pos = 0
        num_neg = 0
        for i, state in enumerate(state_list):
            # bf_action is keyed with the first two state components swapped.
            optimal = bf_action[(state[1], state[0])]
            chosen = a_list[i].item()
            print('State: ', state, ' | Optimal Action: ', optimal, ' | Policy action: ', chosen)
            if chosen in optimal:
                num_pos += 1
            else:
                num_neg += 1
        print('Positive: ', num_pos, ' | Negative: ', num_neg)
        return num_pos, num_neg






