import pdb
import copy
import math
import logging
import dataclasses
from collections import OrderedDict
import typing as tp

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore
import omegaconf

from url_benchmark import utils
# from url_benchmark import replay_buffer as rb
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark.dmc import TimeStep
from url_benchmark import goals as _goals
from .ddpg import MetaDict
from .fb_modules import IdentityMap
from .ddpg import Encoder
from .fb_modules import mlp

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

@dataclasses.dataclass
class PSMConfig:
    """Structured configuration consumed by ``PSMAgent``.

    Fields set to ``omegaconf.MISSING`` must be provided by the caller;
    fields set with ``omegaconf.II(...)`` are interpolations of the
    same-named key of the enclosing config.  The dataclass is registered
    right below in Hydra's ``ConfigStore`` under group ``agent``, name
    ``psm``.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.psm.PSMAgent"
    name: str = "psm"
    # reward_free: ${reward_free}
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    device: str = omegaconf.II("device")  # ${device}
    lr: float = 1e-4
    lr_coef: float = 1  # multiplier applied to ``lr`` for the ``w`` network
    fb_target_tau: float = 0.01  # 0.001-0.01
    update_every_steps: int = 2
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING  # ???  # to be specified later
    num_inference_steps: int = 5120
    hidden_dim: int = 1024   # 128, 2048
    backward_hidden_dim: int = 526   # 512
    feature_dim: int = 512   # 128, 1024
    z_dim: int = 50  # 100
    d_dim: int = 50  # 100
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)" #
    stddev_clip: float = 0.3  # 1
    update_z_every_step: int = 300
    update_z_proba: float = 1.0
    nstep: int = 1
    batch_size: int = 32  # 512
    init_fb: bool = True
    update_encoder: bool = omegaconf.II("update_encoder")  # ${update_encoder}
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")
    ortho_coef: float = 1.0  # 0.01-10
    log_std_bounds: tp.Tuple[float, float] = (-5, 2)  # param for DiagGaussianActor
    temp: float = 1  # temperature for DiagGaussianActor
    boltzmann: bool = False  # set to true for DiagGaussianActor
    debug: bool = False
    future_ratio: float = 0.0
    mix_ratio: float = 0.5  # 0-1
    rand_weight: bool = False  # True, False
    preprocess: bool = True
    norm_z: bool = True
    q_loss: bool = False
    q_loss_coef: float = 0.01
    additional_metric: bool = False
    add_trunk: bool = False
    use_dgd: bool = True
    softmax: bool = True
    div_eps: float = 0.1
    div_coef: float = 1.0
    num_actor_inference_steps: int = 10000
    # Fixed penalty weight read by the agent's inference step when
    # ``use_dgd`` is False.  The agent referenced ``cfg.inf_coeff`` without
    # it being declared here, so that code path raised AttributeError.
    # Appended last so existing positional construction is unaffected.
    # NOTE(review): 1.0 is a neutral placeholder -- confirm the intended value.
    inf_coeff: float = 1.0


# Make the config selectable as ``agent=psm`` through Hydra.
cs = ConfigStore.instance()
cs.store(group="agent", name="psm", node=PSMConfig)

class SamplingActor(nn.Module):
    """Latent-conditioned policy producing a truncated-normal action distribution.

    The observation and the latent ``z`` are concatenated, passed through a
    trunk (``self.mlp``) and then a head (``self.policy``) that outputs the
    distribution mean.  The standard deviation is supplied by the caller.
    """

    def __init__(self, obs_dim, z_dim, action_dim, feature_dim, hidden_dim) -> None:
        super().__init__()
        self.obs_dim, self.z_dim, self.action_dim = obs_dim, z_dim, action_dim

        # NOTE: the ``feature_dim`` argument is not used; the trunk always
        # emits ``hidden_dim`` features, which is what the head consumes.
        # (An alternative "preprocess" architecture with separate obs and
        # obs+z nets is currently disabled.)
        trunk_in = obs_dim + z_dim
        self.mlp = mlp(trunk_in, hidden_dim, "ntanh", hidden_dim, "irelu")
        self.policy = mlp(hidden_dim, hidden_dim, "irelu", action_dim)

        self.apply(utils.weight_init)

    def forward(self, obs, z, std):
        """Return ``utils.TruncatedNormal(mean, std)`` for the given inputs.

        ``z`` must have ``z_dim`` entries along its last axis; ``std`` is
        broadcast to the shape of the predicted mean.
        """
        assert z.shape[-1] == self.z_dim

        features = self.mlp(torch.cat((obs, z), dim=-1))
        mean = self.policy(features)
        scale = torch.ones_like(mean) * std
        return utils.TruncatedNormal(mean, scale)

    
    
class DDPGActor(nn.Module):
    """State-only actor that outputs a ``utils.SquashedNormal`` distribution.

    A single MLP emits ``2 * action_dim`` values, split into a mean and a raw
    log-std.  The raw log-std is squashed with ``tanh`` and affinely mapped
    into ``log_std_bounds`` before being exponentiated.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, log_std_bounds,
                 preprocess=False) -> None:
        super().__init__()
        # ``preprocess`` is accepted but not used by this implementation.
        self.log_std_bounds = log_std_bounds

        # The network consumes the raw observation directly.
        self.policy = mlp(obs_dim, hidden_dim, "irelu", hidden_dim, "irelu", 2 * action_dim)
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return the action distribution for ``obs``."""
        mu, raw_log_std = self.policy(obs).chunk(2, dim=-1)

        # Map tanh output from [-1, 1] onto [lower, upper].
        lower, upper = self.log_std_bounds
        squashed = torch.tanh(raw_log_std)
        log_std = lower + 0.5 * (upper - lower) * (squashed + 1)

        return utils.SquashedNormal(mu, log_std.exp())
    
class SamplingSeedActor(nn.Module):
    """Parameter-free pseudo-random policy indexed by a binary code ``z``.

    For every row, a seed is derived from the code ``z[i]`` and from the
    *content* of the observation ``obs[i]``; the action is then drawn
    uniformly from ``[0, 1)`` with a generator seeded by that value.  The same
    ``(obs, z)`` pair therefore always yields the same action.

    Changes relative to the previous implementation:

    * The observation was hashed with ``hash(tensor)``, which is based on the
      tensor object's identity and not on its values, so the result did not
      depend on the observation at all.  A CRC32 of the raw bytes is used
      instead.
    * ``torch.random.manual_seed`` reseeded the *global* RNG for every row,
      which made all later random draws in the process depend on the last
      seed.  A private ``torch.Generator`` is used instead.
    * An empty batch returns an empty ``(0, action_dim)`` tensor rather than
      failing inside ``torch.cat``.
    """

    def __init__(self, action_dim, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.action_dim = action_dim

    def forward(self, obs, z):
        """Return a CPU tensor of shape ``(obs.shape[0], action_dim)``.

        ``z[i]`` is passed to ``utils.binary_to_long``; it is presumably a
        vector of 0/1 entries as produced by the agent's ``sample_z``.
        """
        import zlib  # local import keeps this block self-contained

        n_rows = obs.shape[0]
        if n_rows == 0:
            return torch.empty((0, self.action_dim))

        # Move to host once instead of once per row.
        obs_np = obs.detach().cpu().numpy()
        z_np = z.detach().cpu().numpy()

        gen = torch.Generator()  # private CPU generator; global RNG untouched
        actions = []
        for i, obs_row in enumerate(obs_np):
            seed = int(utils.binary_to_long(z_np[i])) + zlib.crc32(obs_row.tobytes())
            # Keep the seed inside the range accepted by ``manual_seed``.
            gen.manual_seed(seed % (2 ** 63))
            action = torch.rand((self.action_dim,), generator=gen).unsqueeze(0)
            actions.append(action)
        return torch.cat(actions)
    
class PSM(nn.Module):
    """Two-headed network over ``(obs, action, goal)``.

    One branch (``mlp_phi`` -> ``phi_fc``) produces a ``d_dim`` feature
    vector ``phi``; the other (``mlp_b`` -> ``b_fc``) produces a scalar
    ``b``.  Both branches read the same concatenated input.
    """

    def __init__(self, obs_dim, goal_dim, d_dim, action_dim, feature_dim, hidden_dim) -> None:
        super().__init__()
        self.obs_dim = obs_dim
        self.goal_dim = goal_dim
        self.d_dim = d_dim
        self.action_dim = action_dim

        # NOTE: the ``feature_dim`` argument is not used; both trunks emit
        # ``hidden_dim`` features.
        in_dim = obs_dim + action_dim + goal_dim
        trunk_spec = (in_dim, hidden_dim, "irelu",
                      hidden_dim, "irelu",
                      hidden_dim, "irelu")
        self.mlp_phi = mlp(*trunk_spec)
        self.mlp_b = mlp(*trunk_spec)

        self.phi_fc = mlp(hidden_dim, hidden_dim, "irelu", d_dim)
        self.b_fc = mlp(hidden_dim, hidden_dim, "irelu", 1)

        self.apply(utils.weight_init)

    def forward(self, obs, action, goal):
        """Return ``(phi, b)`` with last dimensions ``d_dim`` and ``1``."""
        joint = torch.cat([obs, action, goal], dim=-1)
        phi = self.phi_fc(self.mlp_phi(joint))
        b = self.b_fc(self.mlp_b(joint))
        return phi, b

class PSMAgent:

    # pylint: disable=unused-argument
    def __init__(self,
                 **kwargs: tp.Any
                 ):
        cfg = PSMConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        # self.solved_meta: tp.Any = None

        # models
        # if cfg.obs_type == 'pixels':
        #     self.aug: nn.Module = utils.RandomShiftsAug(pad=4)
        #     self.encoder: nn.Module = Encoder(cfg.obs_shape).to(cfg.device)
        #     self.obs_dim = self.encoder.repr_dim
        # else:
            # self.aug = nn.Identity()
            # self.encoder = nn.Identity()
        self.obs_dim = cfg.obs_shape[0]
        if cfg.feature_dim < self.obs_dim:
            logger.warning(f"feature_dim {cfg.feature_dim} should not be smaller that obs_dim {self.obs_dim}")
        goal_dim = self.obs_dim
        if cfg.goal_space is not None:
            goal_dim = _goals.get_goal_space_dim(cfg.goal_space)
        if cfg.d_dim < goal_dim:
            logger.warning(f"d_dim {cfg.d_dim} should not be smaller that goal_dim {goal_dim}")

        self.goal_dim = goal_dim

        self.psm = PSM(self.obs_dim, goal_dim, cfg.d_dim, self.action_dim, cfg.feature_dim, cfg.hidden_dim).to(cfg.device)

        self.psm_target = PSM(self.obs_dim, goal_dim, cfg.d_dim, self.action_dim, cfg.feature_dim, cfg.hidden_dim).to(cfg.device)

        self.w = mlp(cfg.z_dim, cfg.hidden_dim, "irelu", 
                     cfg.hidden_dim, "irelu",
                     cfg.hidden_dim, "irelu", cfg.d_dim).to(cfg.device)
        
        self.w.apply(utils.weight_init)
        
        self.w_target = mlp(cfg.z_dim, cfg.hidden_dim, "irelu",
                            cfg.hidden_dim, "irelu",
                            cfg.hidden_dim, "irelu", cfg.d_dim).to(cfg.device)
        
        self.w_inf = torch.zeros((cfg.d_dim), requires_grad=True, device=cfg.device)

        # self.sampling_actor = SamplingActor(self.obs_dim, cfg.z_dim, self.action_dim, cfg.feature_dim, cfg.hidden_dim).to(cfg.device)
        self.sampling_actor = SamplingSeedActor(self.action_dim, cfg.z_dim).to(cfg.device)
        self.ddpg_actor = DDPGActor(self.obs_dim, self.action_dim, cfg.hidden_dim, cfg.log_std_bounds).to(cfg.device)
        
        # load the weights into the target networks
        self.psm_target.load_state_dict(self.psm.state_dict())
        self.w_target.load_state_dict(self.w_target.state_dict())

        # optimizer
        self.opt = torch.optim.Adam([{'params': self.psm.parameters()},  # type: ignore
                                        {'params': self.w.parameters(), 'lr': cfg.lr_coef * cfg.lr}],
                                       lr=cfg.lr)
        
        self.inf_opt = torch.optim.Adam([{'params': self.w_inf}], lr=cfg.lr)

        self.actor_opt = torch.optim.Adam([{'params': self.sampling_actor.parameters()}], lr=cfg.lr)
        self.ddpg_actor_opt = torch.optim.Adam([{'params': self.ddpg_actor.parameters()}], lr=cfg.lr)

        self.train()
        self.psm.train()
        self.w.train()

    def train(self, training: bool = True) -> None:
        self.training = training
        for net in [self.psm, self.w]:
            net.train(training)

    def init_from(self, other) -> None:
        # copy parameters over
        names = []
        if self.cfg.init_fb:
            names += ["psm", "w", "psm_target", "w_target"]
        for name in names:
            utils.hard_update_params(getattr(other, name), getattr(self, name))
        for key, val in self.__dict__.items():
            if isinstance(val, torch.optim.Optimizer):
                val.load_state_dict(copy.deepcopy(getattr(other, key).state_dict()))

    # def sample_z(self, size, device: str = "cpu"):
    #     gaussian_rdv = torch.randn((size, self.cfg.z_dim), dtype=torch.float32, device=device)
    #     gaussian_rdv = F.normalize(gaussian_rdv, dim=1)
    #     if self.cfg.norm_z:
    #         z = math.sqrt(self.cfg.z_dim) * gaussian_rdv
    #     else:
    #         uniform_rdv = torch.rand((size, self.cfg.z_dim), dtype=torch.float32, device=device)
    #         z = np.sqrt(self.cfg.z_dim) * uniform_rdv * gaussian_rdv
    #     return z

    def sample_z(self, size, device: str = "cpu"):
        z = torch.randint(0, 4096, (size,))
        bin_z = []
        for i in range(size):
            bin_z.append(utils.long_to_binary(z[i].item(), self.cfg.z_dim))
        z = torch.cat(bin_z).reshape(size, self.cfg.z_dim).float().to(device)
        return z
    
    def update_psm(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        discount: torch.Tensor,
        next_obs: torch.Tensor,
        next_goal: torch.Tensor,
        z: torch.Tensor,
        step: int
    ) -> tp.Dict[str, float]:
        metrics: tp.Dict[str, float] = {}

        idx = torch.arange(obs.shape[0]).to(obs.device)
        mesh = torch.stack(torch.meshgrid(idx, idx, indexing='xy')).T.reshape(-1, 2)
        m_obs = obs[mesh[:, 0]]
        m_next_obs = next_obs[mesh[:, 0]]
        m_action = action[mesh[:, 0]]
        m_next_goal = next_goal[mesh[:, 1]]
        print(obs.shape, m_obs.shape)
        # compute target successor measure
        with torch.no_grad():
            # pi = self.sampling_actor(m_next_obs, z)
            # act = pi.sample()
            next_actions = self.sampling_actor(m_next_obs, z).to(obs.device)
            target_phi, target_b = self.psm_target(m_next_obs, next_actions, m_next_goal)
            
            target_w = self.w_target(z)
            print(target_w.shape)
            # if self.cfg.softmax:
            #     pi_probs = F.softmax(pi.logits, dim=-1)
            #     target_phi = torch.einsum("sa, sad -> sd", pi_probs, target_phi)
            #     target_b = torch.einsum("sa, sa -> s", pi_probs, target_b)
            # else:
            # pi_action = pi.sample()
            # target_phi = target_phi[torch.arange(target_phi.shape[0]), pi_action]
            # target_b = target_b[torch.arange(target_b.shape[0]), pi_action]
            
            target_M = torch.einsum("sd, sd -> s", target_phi, target_w).unsqueeze(1) + target_b

        # compute PSM loss
        phi, b = self.psm(m_obs, m_action, m_next_goal)
        # phi = phi[torch.arange(phi.shape[0]), m_action]
        # b = b[torch.arange(b.shape[0]), m_action]
        print(phi.shape, b.shape)
        M = torch.einsum("sd, sd -> s", phi, self.w(z)).unsqueeze(1) + b
        print(M.shape, target_M.shape)
        M = M.reshape(obs.shape[0], obs.shape[0])
        target_M = target_M.reshape(obs.shape[0], obs.shape[0])
        I = torch.eye(*M.size(), device=M.device)
        off_diag = ~I.bool()
        psm_offdiag: tp.Any = 0.5 * (M - discount * target_M)[off_diag].pow(2).mean()
        psm_diag: tp.Any = -((1 - discount) * M.diag()).mean()
        psm_loss = psm_offdiag + psm_diag

        # Diversity LOSS

        # z1, z2 = self._sample_z(2, device=obs.device)
        # z1 = z1.unsqueeze(0).repeat(obs.shape[0], 1)
        # z2 = z2.unsqueeze(0).repeat(obs.shape[0], 1)
        # pi1 = self.sampling_actor(obs, z1)
        # pi2 = self.sampling_actor(obs, z2)

        # pi_logprobs = pi1.log_prob(action)
        # p2_logprobs = pi2.log_prob(action)

        # div_loss = torch.max(-(torch.abs(pi_logprobs - p2_logprobs)).mean(), -self.cfg.div_eps)
        # psm_loss += self.cfg.div_coef * div_loss
        

        # ORTHONORMALITY LOSS FOR Phi
        # Cov = torch.matmul(phi.T, phi)
        # I = torch.eye(*Cov.size(), device=Cov.device)
        # off_diag = ~I.bool()
        # orth_loss_diag = - 2 * Cov.diag().mean()
        # orth_loss_offdiag = Cov[off_diag].pow(2).mean()
        # orth_loss = orth_loss_offdiag + orth_loss_diag
        # psm_loss += self.cfg.ortho_coef * orth_loss

        # Cov = torch.cov(B.T)  # Vicreg loss
        # var_loss = F.relu(1 - Cov.diag().clamp(1e-4, 1).sqrt()).mean()  # eps avoids inf. sqrt gradient at 0
        # cov_loss = 2 * torch.triu(Cov, diagonal=1).pow(2).mean() # 2x upper triangular part
        # orth_loss =  var_loss + cov_loss
        # fb_loss += self.cfg.ortho_coef * orth_loss

        if self.cfg.use_tb or self.cfg.use_wandb or self.cfg.use_hiplog:
            metrics['psm_loss'] = psm_loss.item()
            metrics['psm_diag'] = psm_diag.item()
            metrics['psm_offdiag'] = psm_offdiag.item()
            # metrics['div_loss'] = div_loss.item()
            # metrics['orth_loss'] = orth_loss.item()
            # metrics['orth_loss_diag'] = orth_loss_diag.item()
            # metrics['orth_loss_offdiag'] = orth_loss_offdiag.item()
            # eye_diff = torch.matmul(B.T, B) / B.shape[0] - torch.eye(B.shape[1], device=B.device)
            # metrics['orth_linf'] = torch.max(torch.abs(eye_diff)).item()
            # metrics['orth_l2'] = eye_diff.norm().item() / math.sqrt(B.shape[1])
            if isinstance(self.opt, torch.optim.Adam):
                metrics["opt_lr"] = self.opt.param_groups[0]["lr"]

        # optimize PSM
        self.opt.zero_grad(set_to_none=True)
        self.actor_opt.zero_grad(set_to_none=True)
        psm_loss.backward()
        self.opt.step()
        self.actor_opt.step()
        return metrics
    
    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        metrics: tp.Dict[str, float] = {}

        if step % self.cfg.update_every_steps != 0:
            return metrics

        batch = replay_loader.sample(self.cfg.batch_size)
        batch = batch.to(self.cfg.device)

        # pdb.set_trace()
        obs = batch.obs
        action = batch.action.type(torch.int64)
        discount = batch.discount
        next_obs = next_goal = batch.next_obs
        if self.cfg.goal_space is not None:
            assert batch.next_goal is not None
            next_goal = batch.next_goal

        # if len(batch.meta) == 1 and batch.meta[0].shape[-1] == self.cfg.z_dim:
        #     z = batch.meta[0]
        #     invalid = torch.linalg.norm(z, dim=1) < 1e-15
        #     if sum(invalid):
        #         z[invalid, :] = self.sample_z(sum(invalid)).to(self.cfg.device)
        # else:
        z = self.sample_z(self.cfg.batch_size, device=self.cfg.device)
        z = z.repeat(self.cfg.batch_size, 1)
        if not z.shape[-1] == self.cfg.z_dim:
            raise RuntimeError("There's something wrong with the logic here")
        # obs = self.aug_and_encode(batch.obs)
        # next_obs = self.aug_and_encode(batch.next_obs)
        # if not self.cfg.update_encoder:
        #     obs = obs.detach()
        #     next_obs = next_obs.detach()

        # backward_input = batch.obs
        # future_goal = batch.future_obs
        # if self.cfg.goal_space is not None:
        #     assert batch.goal is not None
        #     backward_input = batch.goal
        #     future_goal = batch.future_goal

        # perm = torch.randperm(self.cfg.batch_size)
        # backward_input = backward_input[perm]

        # if self.cfg.mix_ratio > 0:
        #     mix_idxs: tp.Any = np.where(np.random.uniform(size=self.cfg.batch_size) < self.cfg.mix_ratio)[0]
        #     if not self.cfg.rand_weight:
        #         with torch.no_grad():
        #             mix_z = self.backward_net(backward_input[mix_idxs]).detach()
        #     else:
        #         # generate random weight
        #         weight = torch.rand(size=(mix_idxs.shape[0], self.cfg.batch_size)).to(self.cfg.device)
        #         weight = F.normalize(weight, dim=1)
        #         uniform_rdv = torch.rand(mix_idxs.shape[0], 1).to(self.cfg.device)
        #         weight = uniform_rdv * weight
        #         with torch.no_grad():
        #             mix_z = torch.matmul(weight, self.backward_net(backward_input).detach())
        #     if self.cfg.norm_z:
        #         mix_z = math.sqrt(self.cfg.z_dim) * F.normalize(mix_z, dim=1)
        #     z[mix_idxs] = mix_z

        # # hindsight replay
        # if self.cfg.future_ratio > 0:
        #     assert future_goal is not None
        #     future_idxs = np.where(np.random.uniform(size=self.cfg.batch_size) < self.cfg.future_ratio)
        #     z[future_idxs] = self.backward_net(future_goal[future_idxs]).detach()

        metrics.update(self.update_psm(obs=obs, action=action, discount=discount,
                                      next_obs=next_obs, next_goal=next_goal, z=z, step=step))

        # update critic target
        utils.soft_update_params(self.psm, self.psm_target,
                                 self.cfg.fb_target_tau)
        utils.soft_update_params(self.w, self.w_target,
                                 self.cfg.fb_target_tau)

        return metrics
    
    def init_inference(self):
        '''
        Initialize the w_inf parameter
        Initialize the lagrange variables, optimizers etc
        '''

        # initialize the w_inf parameter
        self.w_inf = torch.ones((self.cfg.d_dim), requires_grad=True, device=self.cfg.device)
        self.inf_opt = torch.optim.Adam([{'params': self.w_inf}], lr=self.cfg.lr)

        self.ddpg_actor.apply(utils.weight_init)

        self.ddpg_actor_opt = torch.optim.Adam([{'params': self.ddpg_actor.parameters()}], lr=self.cfg.lr)

        if self.cfg.use_dgd:
            self.lmult = mlp(self.obs_dim + self.action_dim + self.goal_dim, self.cfg.hidden_dim, "irelu", 
                            self.cfg.hidden_dim, "irelu", 1, "soft").to(self.cfg.device)
            
            self.lmult.apply(utils.weight_init)
            
            self.lmult_opt = torch.optim.Adam([{'params': self.lmult.parameters()}], lr=self.cfg.lr)

    def infer_w_goal(self, replay_loader: ReplayBuffer, inf_logger, goal):
        metrics: tp.Dict[str, float] = {}
        self.init_inference()
        goal = torch.tensor(goal).unsqueeze(0).to(self.cfg.device)
        for step in range(self.cfg.num_inference_steps):
            batch = replay_loader.sample(self.cfg.batch_size)
            obs = batch.obs
            actions = batch.action
            next_goals = batch.next_goal
            obs = torch.tensor(obs).to(self.cfg.device)
            next_goals = torch.tensor(next_goals).to(self.cfg.device)
            actions = torch.tensor(actions).to(self.cfg.device)
            # goal = torch.repeat(goal, 0, obs.shape[0])
            goal_rep = goal.repeat(obs.shape[0], 1)
            metrics.update(self._infer_step_gc(obs, actions, next_goals, goal_rep, step))
            inf_logger.log_metrics(metrics)
            inf_logger.dump()

        return metrics
    
    def infer_w_from_obs_and_rewards(self, replay_loader: ReplayBuffer, robs, rews):
        NotImplementedError


    def _infer_step_gc(self, obs, actions, next_goals, goal, step):
        perm = torch.randperm(next_goals.shape[0])
        perm_next_goals = next_goals[perm]

        metrics = {}

        with torch.no_grad():
            phi_g, b_g = self.psm(obs, actions, goal)
            phi_perm, b_perm = self.psm(obs, actions, perm_next_goals)
        
        obj = -torch.einsum("sd, d -> s", phi_g, self.w_inf).mean()
        
        if self.cfg.use_dgd:
            with torch.no_grad():
                l_mult = self.lmult(torch.cat([obs, actions, perm_next_goals], dim=-1))
            constraints = - ((torch.einsum("sd, d -> s", phi_perm, self.w_inf) + b_perm) * l_mult).mean()
        else:
            constraints = - (torch.min(torch.einsum("sd, d -> s", phi_perm, self.w_inf) + b_perm, torch.tensor(0.0)) * self.cfg.inf_coeff).mean()
        
        self.inf_opt.zero_grad(set_to_none=True)
        loss = obj + constraints
        loss.backward()
        self.inf_opt.step()

        metrics['obj'] = obj.item()
        metrics['constraints'] = constraints.item()
        metrics['lamb'] = l_mult.mean().item() if self.cfg.use_dgd else self.cfg.inf_coeff

        if self.cfg.use_dgd:
            constraints = ((torch.einsum("sd, d -> s", phi_perm, self.w_inf) + b_perm) * self.lmult(torch.cat([obs, actions, perm_next_goals], dim=-1))).mean()

            self.lmult_opt.zero_grad(set_to_none=True)
            loss = constraints
            loss.backward()
            self.lmult_opt.step()
        
        return metrics
    
    def update_actor(self, obs, goal):
        metrics: tp.Dict[str, float] = {}
        obs = torch.tensor(obs).to(self.cfg.device)
        goal = torch.tensor(goal).unsqueeze(0).to(self.cfg.device)
        goal = goal.repeat(obs.shape[0], 1)
        dist = self.ddpg_actor(obs)
        action = dist.rsample()

        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        Q = self.q_function(obs, action, goal)
        actor_loss = (self.cfg.temp * log_prob - Q).mean()

        # optimize actor
        self.ddpg_actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.ddpg_actor_opt.step()

        if self.cfg.use_tb or self.cfg.use_wandb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['q'] = Q.mean().item()
            metrics['actor_logprob'] = log_prob.mean().item()
            # metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()

        return metrics

    def distill_actor_ddpg(self, replay_loader: ReplayBuffer, actor_logger, goal):
        metrics: tp.Dict[str, float] = {}
        for step in range(self.cfg.num_actor_inference_steps):
            batch = replay_loader.sample(self.cfg.batch_size)
            obs = batch.obs
            metrics.update(self.update_actor(obs, goal))
            actor_logger.log_metrics(metrics)
            actor_logger.dump()

        return metrics
        
    
    def q_function(self, obs, actions, goal):
        # print(obs.shape, actions.shape, goal.shape)
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs).to(self.cfg.device)
            goal = torch.tensor(goal).to(self.cfg.device)
            actions = torch.tensor(actions).to(self.cfg.device)
        with torch.no_grad():
            phi, b = self.psm(obs, actions, goal)
            q = torch.einsum("sd, d -> s", phi, self.w_inf) + b

        return q
    
    def act(self, obs):
        obs = torch.tensor(obs).to(self.cfg.device)
        dist = self.ddpg_actor(obs)
        action = dist.sample()

        return action.cpu().numpy()
    


    