import pdb
import copy
import math
import logging
import dataclasses
from collections import OrderedDict
import typing as tp

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from hydra.core.config_store import ConfigStore
import omegaconf

# Import-time side effect: caps the number of threads PyTorch uses for CPU
# intra-op parallelism at 4 for the whole process that imports this module.
torch.set_num_threads(4)

from url_benchmark import utils
# from url_benchmark import replay_buffer as rb
from url_benchmark.in_memory_replay_buffer import ReplayBuffer
from url_benchmark.dmc import TimeStep
from url_benchmark import goals as _goals
from .ddpg import MetaDict
from .fb_modules import IdentityMap
from .ddpg import Encoder
from .fb_modules import mlp
from url_benchmark import utils
import time
import xxhash
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class DiscretePSMConfig:
    """Structured configuration consumed by ``DiscretePSMAgent``.

    Fields set to ``omegaconf.MISSING`` must be supplied by the caller; fields
    built with ``omegaconf.II`` are interpolated from the top-level config key
    named in the trailing comment.
    """
    # @package agent
    _target_: str = "url_benchmark.agent.discrete_psm.DiscretePSMAgent"
    name: str = "discrete_psm"
    # reward_free: ${reward_free}
    obs_type: str = omegaconf.MISSING  # to be specified later
    obs_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    action_shape: tp.Tuple[int, ...] = omegaconf.MISSING  # to be specified later
    device: str = omegaconf.II("device")  # ${device}
    lr: float = 1e-4  # learning rate of the PSM optimizer
    lr_w: float = 3e-4  # learning rate used by the inference-time optimizers
    lr_coef: float = 1  # multiplier applied to `lr` for the `w` network's param group
    fb_target_tau: float = 0.01  # 0.001-0.01
    update_every_steps: int = 1
    use_tb: bool = omegaconf.II("use_tb")  # ${use_tb}
    use_wandb: bool = omegaconf.II("use_wandb")  # ${use_wandb}
    use_hiplog: bool = omegaconf.II("use_hiplog")  # ${use_hiplog}
    num_expl_steps: int = omegaconf.MISSING  # ???  # to be specified later
    num_inference_steps: int = 20000
    hidden_dim: int = 1024   # 128, 2048
    # NOTE(review): 526 looks like a typo for 512 -- confirm before changing.
    backward_hidden_dim: int = 526   # 512
    feature_dim: int = 512   # 128, 1024
    z_dim: int = 16  # 100
    d_dim: int = 50  # 100
    stddev_schedule: str = "0.2"  # "linear(1,0.2,200000)" #
    stddev_clip: float = 0.3  # 1
    update_z_every_step: int = 300
    update_z_proba: float = 1.0
    nstep: int = 1
    batch_size: int = 16  # 512
    # Annotated so that it is a real dataclass field (and therefore part of the
    # structured config); without the annotation it was a plain class attribute
    # that could neither be passed to the constructor nor overridden.
    num_neg_samples: int = 512
    init_fb: bool = True
    update_encoder: bool = omegaconf.II("update_encoder")  # ${update_encoder}
    goal_space: tp.Optional[str] = omegaconf.II("goal_space")
    ortho_coef: float = 0.0  # 0.01-10
    cons_coef: float = 0.01  # 0.01-10
    log_std_bounds: tp.Tuple[float, float] = (-5, 2)  # param for DiagGaussianActor
    temp: float = 1  # temperature for DiagGaussianActor
    boltzmann: bool = False  # set to true for DiagGaussianActor
    debug: bool = False
    future_ratio: float = 0.0
    mix_ratio: float = 0.5  # 0-1
    rand_weight: bool = False  # True, False
    preprocess: bool = True
    norm_z: bool = True
    q_loss: bool = False
    q_loss_coef: float = 0.01
    additional_metric: bool = False
    add_trunk: bool = False
    use_dgd: bool = True  # learn the constraint multiplier with a network during inference
    softmax: bool = True
    div_eps: float = 1000.0
    div_coef: float = 0.01
    inf_coeff: float = 5.0  # fixed constraint weight used when `use_dgd` is False



# Register the structured config with Hydra's ConfigStore under the "agent"
# group, so it can be selected by the name "discrete_psm".
cs = ConfigStore.instance()
cs.store(group="agent", name="discrete_psm", node=DiscretePSMConfig)

class SamplingActor(nn.Module):
    """Categorical policy over discrete actions, conditioned on (obs, z).

    The observation and the code ``z`` are concatenated, passed through an
    embedding network and then a policy head that emits one logit per action.
    The exact layers are produced by the project's ``mlp`` helper.
    """

    def __init__(self, obs_dim, z_dim, action_dim, feature_dim, hidden_dim) -> None:
        # `feature_dim` is accepted for signature compatibility only: the
        # width fed to the policy head is always `hidden_dim`.
        super().__init__()
        self.obs_dim = obs_dim
        self.z_dim = z_dim
        self.action_dim = action_dim

        # Embedding of the concatenated (obs, z) input, then the logit head.
        self.mlp = mlp(obs_dim + z_dim, hidden_dim, "ntanh", hidden_dim, "irelu")
        self.policy = mlp(hidden_dim, hidden_dim, "irelu", action_dim)

        self.apply(utils.weight_init)

    def forward(self, obs, z):
        """Return a ``torch.distributions.Categorical`` over actions."""
        assert z.shape[-1] == self.z_dim

        joint_input = torch.cat([obs, z], dim=-1)
        action_logits = self.policy(self.mlp(joint_input))
        return torch.distributions.Categorical(logits=action_logits)
    
class SamplingSeedActor(nn.Module):
    """Parameter-free deterministic policy backed by a seed -> action table.

    At construction, entry ``i`` of ``seed_to_action`` is filled with the
    action that ``torch.randint(0, action_dim, (1,))`` yields from a CPU
    generator seeded with ``i``.  ``forward`` turns each ``(obs_hash, z)``
    pair into an integer seed and looks the action up in that table.

    The table and the bit weights are registered as non-persistent buffers, so
    ``module.to(device)`` moves them and they stay out of ``state_dict()``.
    They are created on CUDA when it is available (as before) and on CPU
    otherwise.

    Args:
        action_dim: number of discrete actions.
        z_dim: number of bits in the code ``z``.
        batch_size: number of rows ``forward`` will receive; ``powers`` is
            tiled to this many rows.
    """

    def __init__(self, action_dim, z_dim, batch_size):
        super().__init__()
        self.z_dim = z_dim
        self.action_dim = action_dim
        self.max_seed = 2**z_dim + 20000

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Bit weights, most-significant bit first, one identical row per sample.
        powers = torch.tensor([2**i for i in range(z_dim)][::-1]).repeat(batch_size, 1)

        # A private generator keeps the table identical to the one obtained by
        # seeding the global RNG, without clobbering the caller's RNG state
        # (the global RNG used to be left seeded with `max_seed - 1`).
        generator = torch.Generator()
        table = []
        for seed in range(self.max_seed):
            generator.manual_seed(seed)
            table.append(torch.randint(0, action_dim, (1,), generator=generator).item())
        seed_to_action = torch.tensor(table, dtype=torch.int64).reshape(-1, 1, 1)

        self.register_buffer("powers", powers.to(device), persistent=False)
        self.register_buffer("seed_to_action", seed_to_action.to(device), persistent=False)

    def forward(self, obs_hash, z):
        """Look up one action per row.

        Args:
            obs_hash: per-sample hash values; flattened before use.
                NOTE(review): ``seed = code(z) + obs_hash`` must stay below
                ``max_seed`` (``2**z_dim + 20000``), so the hashes presumably
                lie in ``[0, 20000]`` -- confirm against the replay buffer.
            z: tensor with the same shape as ``powers``, i.e.
                ``(batch_size, z_dim)``, holding the bits of the code.

        Returns:
            Integer tensor of shape ``(num_rows, 1)`` on the table's device.
        """
        z_code = (z * self.powers).sum(1)
        seed = z_code + obs_hash.reshape(-1)
        # Advanced indexing already returns a fresh tensor on the right device.
        return self.seed_to_action[seed.long()].reshape(-1, 1)

    # def _sample(self, inps):
    #     # print('Inps shape', inps.shape)
    #     d = inps.shape[0]
    #     obs_dim = d - self.z_dim
    #     obs = inps[:obs_dim]
    #     z = inps[obs_dim:]
    #     # print('Obs shape', obs.shape)
    #     # print('Z shape', z.shape)
    #     seed = utils.binary_to_long(z)
    #     seed = seed + hash(torch.tensor(obs))
    #     torch.random.manual_seed(seed)
    #     # print('Seed', seed)
    #     return torch.rand(size=(self.action_dim,)).unsqueeze(0)

    # def forward(self, obs, z):
    #     actions = []
    #     # for i in range(obs.shape[0]):
    #     #     seed = utils.binary_to_long(z[i].cpu().numpy())
    #     #     seed = seed + hash(obs[i].cpu())
    #     #     torch.random.manual_seed(seed)
    #     #     action = torch.rand(size=(self.action_dim,)).unsqueeze(0)
    #     #     actions.append(action)
    #     # print('Obs shape', obs.shape)
    #     # print('Z shape', z.shape)
    #     obs = obs.cpu().numpy()
    #     z = z.cpu().numpy()
    #     inps = np.concatenate((obs, z), axis=1)
    #     with ThreadPool() as p:
    #         # actions = [p.apply_async(self._sample, args=(obs, z))]
    #         actions = p.map(self._sample, inps)
    #     # print('Actions shape', actions.shape)
    #     return torch.cat(actions)

# class PSM(nn.Module):
#     def __init__(self, obs_dim, goal_dim, d_dim, action_dim, feature_dim, hidden_dim) -> None:
#         super().__init__()
#         self.obs_dim = obs_dim
#         self.goal_dim = goal_dim
#         self.d_dim = d_dim
#         self.action_dim = action_dim
#         self.mlp_phi = mlp(self.obs_dim + self.goal_dim, hidden_dim, "ntanh",
#                             hidden_dim, "irelu",
#                             hidden_dim, "irelu")
        
#         self.mlp_b = mlp(self.obs_dim + self.goal_dim, hidden_dim, "ntanh",
#                             hidden_dim, "irelu",
#                             hidden_dim, "irelu")
#         feature_dim = hidden_dim

#         seq_phi = [feature_dim, hidden_dim, "irelu", self.d_dim * self.action_dim]
#         seq_b = [feature_dim, hidden_dim, "irelu", self.action_dim]

#         self.phi_fc = mlp(*seq_phi)
#         self.b_fc = mlp(*seq_b)

#         self.apply(utils.weight_init)

#     def forward(self, obs, goal):
#         phi = self.phi_fc(self.mlp_phi(torch.cat([obs, goal], dim=-1)))
#         b = self.b_fc(self.mlp_b(torch.cat([obs, goal], dim=-1)))
#         phi = phi.reshape(-1, self.action_dim, self.d_dim)
#         b = b.reshape(-1, self.action_dim)
#         return phi, b
    
class PSM(nn.Module):
    """Two parallel networks over the concatenated (obs, goal) input.

    ``mlp_phi`` emits ``action_dim * d_dim`` values and ``mlp_b`` emits
    ``action_dim`` values per sample; ``forward`` reshapes them to
    ``(N, action_dim, d_dim)`` and ``(N, action_dim)``.
    """

    def __init__(self, obs_dim, goal_dim, d_dim, action_dim, feature_dim, hidden_dim) -> None:
        # `feature_dim` is accepted for signature compatibility; it is unused.
        super().__init__()
        self.obs_dim = obs_dim
        self.goal_dim = goal_dim
        self.d_dim = d_dim
        self.action_dim = action_dim

        input_dim = obs_dim + goal_dim

        def build_head(output_dim):
            # Both heads share the same layer spec and differ only in width
            # of the final output.
            return mlp(input_dim, hidden_dim, "ntanh",
                       hidden_dim, "irelu",
                       hidden_dim, "irelu",
                       hidden_dim, "irelu", output_dim)

        self.mlp_phi = build_head(d_dim * action_dim)
        self.mlp_b = build_head(action_dim)

        self.apply(utils.weight_init)

    def forward(self, obs, goal):
        """Return ``(phi, b)`` for every (obs, goal) row."""
        obs_goal = torch.cat([obs, goal], dim=-1)
        phi = self.mlp_phi(obs_goal).reshape(-1, self.action_dim, self.d_dim)
        b = self.mlp_b(obs_goal).reshape(-1, self.action_dim)
        return phi, b

class DiscretePSMAgent:

    # pylint: disable=unused-argument
    def __init__(self, **kwargs: tp.Any):
        """Build the networks, their target copies and the optimizers.

        Args:
            **kwargs: forwarded verbatim to ``DiscretePSMConfig``.
        """
        cfg = DiscretePSMConfig(**kwargs)
        self.cfg = cfg
        assert len(cfg.action_shape) == 1
        self.action_dim = cfg.action_shape[0]
        self.obs_dim = cfg.obs_shape[0]
        if cfg.feature_dim < self.obs_dim:
            logger.warning(f"feature_dim {cfg.feature_dim} should not be smaller that obs_dim {self.obs_dim}")

        # Goals live in observation space unless a goal space is configured.
        if cfg.goal_space is None:
            goal_dim = self.obs_dim
        else:
            goal_dim = _goals.get_goal_space_dim(cfg.goal_space)
        if cfg.d_dim < goal_dim:
            logger.warning(f"d_dim {cfg.d_dim} should not be smaller that goal_dim {goal_dim}")
        self.goal_dim = goal_dim

        def build_psm():
            return PSM(self.obs_dim, goal_dim, cfg.d_dim, self.action_dim,
                       cfg.feature_dim, cfg.hidden_dim).to(cfg.device)

        def build_w():
            # Three hidden "irelu" layers, then a `d_dim` output with "L2".
            hidden_spec = [cfg.hidden_dim, "irelu"] * 3
            return mlp(cfg.z_dim, *hidden_spec, cfg.d_dim, "L2").to(cfg.device)

        # Construction order is kept as-is: it determines how the RNG is
        # consumed while the layers are initialised.
        self.psm = build_psm()
        self.psm_target = build_psm()
        self.w = build_w()
        self.w.apply(utils.weight_init)
        self.w_target = build_w()

        self.w_inf = torch.zeros((cfg.d_dim), requires_grad=True, device=cfg.device)

        # One row per (state, goal) pair of the batch_size x batch_size grid.
        self.sampling_actor = SamplingSeedActor(
            self.action_dim, cfg.z_dim, cfg.batch_size * cfg.batch_size
        ).to(cfg.device)

        # Target networks start as exact copies of the online networks.
        self.psm_target.load_state_dict(self.psm.state_dict())
        self.w_target.load_state_dict(self.w.state_dict())

        # A single optimizer with two parameter groups: `w` uses a scaled lr.
        param_groups = [
            {'params': self.psm.parameters()},
            {'params': self.w.parameters(), 'lr': cfg.lr_coef * cfg.lr},
        ]
        self.opt = torch.optim.Adam(param_groups, lr=cfg.lr)  # type: ignore
        self.inf_opt = torch.optim.Adam([self.w_inf], lr=cfg.lr_w)

        # Puts `psm` and `w` in training mode and sets `self.training`.
        self.train()

    def train(self, training: bool = True) -> None:
        self.training = training
        for net in [self.psm, self.w]:
            net.train(training)

    def init_from(self, other) -> None:
        # copy parameters over
        names = []
        if self.cfg.init_fb:
            names += ["psm", "w", "psm_target", "w_target"]
        for name in names:
            utils.hard_update_params(getattr(other, name), getattr(self, name))
        for key, val in self.__dict__.items():
            if isinstance(val, torch.optim.Optimizer):
                val.load_state_dict(copy.deepcopy(getattr(other, key).state_dict()))

    def int_to_binary_array(self, int_vector, num_bits=None):
        """Expand non-negative integers into their bits, least-significant first.

        Args:
            int_vector: 1-D array-like of integers.
            num_bits: number of bit columns to emit.  When ``None``, the bit
                length of the largest entry is used (0 for an empty input).

        Returns:
            Integer array of shape ``(len(int_vector), num_bits)`` where column
            ``k`` is 1 iff bit ``k`` of the corresponding integer is set.
        """
        int_vector = np.asarray(int_vector)
        if num_bits is None:
            # `max()` yields a NumPy scalar; go through a Python int to get
            # `bit_length`, and avoid calling `max()` on an empty array.
            num_bits = int(int_vector.max()).bit_length() if int_vector.size else 0

        bit_masks = 1 << np.arange(num_bits)
        return ((int_vector[:, None] & bit_masks) > 0).astype(int)
    
    def sample_z(self, size, device: str = "cpu"):
        z_np = np.random.randint(0, 2**self.cfg.z_dim, (size,))
        binary_array = self.int_to_binary_array(z_np, self.cfg.z_dim)
        return torch.FloatTensor(binary_array).to(device)
    
    def update_psm(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        discount: torch.Tensor,
        next_obs: torch.Tensor,
        next_obs_hash: torch.Tensor,
        next_goal: torch.Tensor,
        z: torch.Tensor,
    ) -> tp.Dict[str, float]:
        metrics: tp.Dict[str, float] = {}

        #==========================================================
        
        mesh_making_time= time.time()

        idx = torch.arange(obs.shape[0]).to(obs.device)
        mesh = torch.stack(torch.meshgrid(idx, idx, indexing='xy')).T.reshape(-1, 2)
        
        
        # print(mesh[:, 0].shape)
        # print(obs[mesh[:, 0]].shape)
        m_obs = obs[mesh[:, 0]]
        m_next_obs = next_obs[mesh[:, 0]]
        m_next_obs_hash = next_obs_hash[mesh[:, 0]]
        m_action = action[mesh[:, 0]]
        m_next_goal = next_goal[mesh[:, 1]]
        # print("Time to make mesh: ", time.time()-mesh_making_time)
        # compute target successor measure
        # perm = torch.randperm(obs.shape[0])
        # perm_goals = next_goal[perm]
        with torch.no_grad():
            # compute greedy action
            target_phi, target_b = self.psm_target(m_next_obs, m_next_goal)
            # pi = self.sampling_actor(m_next_obs, z)
            target_w = self.w_target(z)
            
            # if self.cfg.softmax:
            #     pi_probs = F.softmax(pi.logits, dim=-1)
            #     target_phi = torch.einsum("sa, sad -> sd", pi_probs, target_phi)
            #     target_b = torch.einsum("sa, sa -> s", pi_probs, target_b)
            # else:
            #     pi_action = pi.sample()
            #     target_phi = target_phi[torch.arange(target_phi.shape[0]), pi_action]
            #     target_b = target_b[torch.arange(target_b.shape[0]), pi_action]
            # actions = self.sampling_actor(m_next_obs, z)
            # target_phi = target_phi[torch.arange(target_phi.shape[0]), ]
            # print('before sampling actions')
            sampling_time = time.time()
            next_actions = self.sampling_actor(m_next_obs_hash, z)
            # print("Time to sample actions: ", time.time()-sampling_time)
            # print('after sampling actions')
            # import ipdb; ipdb.set_trace()
            target_phi = target_phi[torch.arange(target_phi.shape[0]), next_actions.squeeze(1)]
            target_b = target_b[torch.arange(target_b.shape[0]), next_actions.squeeze(1)]

            
            # print(target_phi.shape, target_w.shape, target_b.shape)
            target_M = torch.einsum("sd, sd -> s", target_phi, target_w) + target_b

        start_t = time.time()
        # compute PSM loss
        phi, b = self.psm(m_obs, m_next_goal)
        # print(m_obs.device)
        # print(phi.shape, m_action.shape)
        phi = phi[torch.arange(phi.shape[0]), m_action.squeeze(1)]
        # print(phi.shape)
        b = b[torch.arange(b.shape[0]), m_action.squeeze(1)]
        # print(phi.shape, self.w(z).shape, b.shape, target_M.shape)
        M = torch.einsum("sd, sd -> s", phi, self.w(z)) + b
        M = M.reshape(obs.shape[0], obs.shape[0])
        target_M = target_M.reshape(obs.shape[0], obs.shape[0])
        I = torch.eye(*M.size(), device=M.device)
        off_diag = ~I.bool()
        psm_offdiag: tp.Any = 0.5 * (M - discount * target_M)[off_diag].pow(2).mean()
        psm_diag: tp.Any = -(1-discount).mean()*(M.diag()).mean()
        psm_loss = psm_offdiag + psm_diag
        
        #==========================================================
        
        # with torch.no_grad():
        #     # compute greedy action
        #     target_phi, target_b = self.psm_target(next_obs, next_goal)
        #     target_w = self.w_target(z)
        #     next_actions = self.sampling_actor(next_obs_hash, z)
        #     target_phi = target_phi[torch.arange(target_phi.shape[0]), next_actions.squeeze(1)]
        #     target_b = target_b[torch.arange(target_b.shape[0]), next_actions.squeeze(1)]
        #     target_M = torch.einsum("sd, sd -> s", target_phi, target_w) + target_b

        
        # start_t = time.time()
        # # compute PSM loss
        # phi, b = self.psm(obs, next_goal)
        # phi = phi[torch.arange(phi.shape[0]), action.squeeze(1)]
        # # print(phi.shape)
        # b = b[torch.arange(b.shape[0]), action.squeeze(1)]
        # # print(phi.shape, self.w(z).shape, b.shape, target_M.shape)
        # M = torch.einsum("sd, sd -> s", phi, self.w(z)) + b
        # # M = M.reshape(obs.shape[0], obs.shape[0])
        # # target_M = target_M.reshape(obs.shape[0], obs.shape[0])
        # # I = torch.eye(*M.size(), device=M.device)
        # indices = torch.arange(self.cfg.batch_size * self.cfg.num_neg_samples).to(M.device)
        # off_diag = torch.where(indices % self.cfg.num_neg_samples != 0)
        # diag = torch.where(indices % self.cfg.num_neg_samples == 0)
        # # diag_indices = np.arange(self.cfg.batch_size)*self.cfg.num_neg_samples
        # # off_diag = np.arange(self.cfg.batch_size*self.cfg.num_neg_samples) - diag_indices
        # # off_diag = ~I.bool()
        # psm_offdiag: tp.Any = 0.5 * (M - discount * target_M)[off_diag].pow(2).mean()
        # psm_diag: tp.Any = -(1-discount).mean()*(M[diag]).mean()
        # psm_loss = psm_offdiag + psm_diag

        
        #==========================================================

        # with torch.no_grad():
        #     # compute greedy action
        #     target_phi, target_b = self.psm_target(next_obs, next_goal)
        #     target_w = self.w_target(z)
        #     next_actions = self.sampling_actor(next_obs_hash, z)
        #     target_phi = target_phi[torch.arange(target_phi.shape[0]), next_actions.squeeze(1)]
        #     target_b = target_b[torch.arange(target_b.shape[0]), next_actions.squeeze(1)]
        #     target_M = torch.einsum("sd, sd -> s", target_phi, target_w) + target_b

        
        # start_t = time.time()
        # # compute PSM loss
        # phi, b = self.psm(obs, next_goal)
        # phi = phi[torch.arange(phi.shape[0]), action.squeeze(1)]
        # b = b[torch.arange(b.shape[0]), action.squeeze(1)]
        # M = torch.einsum("sd, sd -> s", phi, self.w(z)) + b
        # # target_M = target_M.reshape(obs.shape[0], obs.shape[0])
        # # print(rewards.shape)
        # # print(rewards.sum())
        # # ids = torch.where(rewards == 1)
        # # print(obs[ids[0]], next_obs[ids[0]], next_goal[ids[0]], action[ids[0]], phi[ids[0]], b[ids[0]], M[ids[0]], target_M[ids[0]])
        # discount = discount.squeeze(1)
        # # print('rewards', rewards.shape)
        # # print('discount', discount.shape)
        # # print('prod', ((1 - discount) * rewards).shape)
        # # print('target', target_M.shape)
        # # print('target_prod', (discount * target_M).shape)
        # psm_loss = (M - (1 - discount) * rewards -  discount * target_M).pow(2).mean()


        #==========================================================

        # mean_phi = torch.mean(phi, dim=0)
        # mean_b = torch.mean(b, dim=0)

        # phi_l2_norm = torch.linalg.norm(phi, dim=1).mean()


        # rphi, rb = self.psm(robs, rnext_obs)
        # rphi = rphi[torch.arange(rphi.shape[0]), raction.squeeze(1)]
        # rb = rb[torch.arange(rb.shape[0]), raction.squeeze(1)]

        # rM = torch.einsum("sd, sd -> s", rphi, self.w(rz)) + rb

        # regularization_loss = torch.square(rM).mean()

        # psm_loss += self.cfg.cons_coef * regularization_loss

        full_loss_computation_time = time.time()-start_t
        # print("Time to compute full loss: ", full_loss_computation_time)


        # Cov = torch.matmul(phi.T, phi)
        # I = torch.eye(*Cov.size(), device=Cov.device)
        # off_diag = ~I.bool()
        # orth_loss_diag = - 2 * Cov.diag().mean()
        # orth_loss_offdiag = Cov[off_diag].pow(2).mean()
        # orth_loss = orth_loss_offdiag + orth_loss_diag
        # psm_loss += self.cfg.ortho_coef * orth_loss


        if self.cfg.use_tb or self.cfg.use_wandb or self.cfg.use_hiplog:
            metrics['psm_loss'] = psm_loss.item()
            metrics['psm_diag'] = psm_diag.item()
            metrics['psm_offdiag'] = psm_offdiag.item()
            # metrics['mean_phi'] = mean_phi.mean().item()
            # metrics['mean_b'] = mean_b.mean().item()
            # metrics['phi_l2_norm'] = phi_l2_norm.item()
            # metrics['regularization_loss'] = regularization_loss.item()
            # metrics['div_loss'] = div_loss.item()
            # metrics['orth_loss'] = orth_loss.item()
            # metrics['orth_loss_diag'] = orth_loss_diag.item()
            # metrics['orth_loss_offdiag'] = orth_loss_offdiag.item()
 
            if isinstance(self.opt, torch.optim.Adam):
                metrics["opt_lr"] = self.opt.param_groups[0]["lr"]

        # optimize PSM
        start_t = time.time()
        self.opt.zero_grad(set_to_none=True)
        # self.actor_opt.zero_grad(set_to_none=True)
        psm_loss.backward()
        self.opt.step()
        # print("Time to optimize PSM: ", time.time()-start_t)
        # self.actor_opt.step()

        # print('Step: ', step, ' | PSM Loss: ', psm_loss.item(), ' | Div Loss: ', div_loss.item(), ' | Diversity: ', diversity.item(), ' | Orth Loss: ', orth_loss.item())
        return metrics
    
    def update(self, replay_loader: ReplayBuffer, step: int) -> tp.Dict[str, float]:
        """Sample a batch, run one PSM update and soft-update the targets.

        Nothing happens (and an empty dict is returned) on steps that are not
        a multiple of ``cfg.update_every_steps``.

        Args:
            replay_loader: buffer whose ``sample(batch_size)`` returns a batch
                exposing ``obs``, ``action``, ``discount``, ``next_obs``,
                ``next_goal`` and ``next_obs_hash``, and a ``to(device)`` method.
            step: current training step.

        Returns:
            Metrics produced by ``update_psm``.

        Raises:
            RuntimeError: if the sampled codes do not have ``cfg.z_dim`` columns.
        """
        metrics: tp.Dict[str, float] = {}
        if step % self.cfg.update_every_steps != 0:
            return metrics

        batch_size = self.cfg.batch_size
        batch = replay_loader.sample(batch_size)
        batch = batch.to(self.cfg.device)

        # One code per state, repeated for each of the `batch_size` goals that
        # the state is paired with inside `update_psm`.
        z = self.sample_z(batch_size, device=self.cfg.device)
        z = torch.repeat_interleave(z, batch_size, 0)
        if z.shape[-1] != self.cfg.z_dim:
            raise RuntimeError("There's something wrong with the logic here")

        metrics.update(self.update_psm(
            obs=batch.obs,
            action=batch.action.type(torch.int64),
            discount=batch.discount,
            next_obs=batch.next_obs,
            next_obs_hash=batch.next_obs_hash,
            next_goal=batch.next_goal,
            z=z,
        ))

        # Move the target networks toward the freshly updated online networks.
        utils.soft_update_params(self.psm, self.psm_target,
                                 self.cfg.fb_target_tau)
        utils.soft_update_params(self.w, self.w_target,
                                 self.cfg.fb_target_tau)

        return metrics
    
    def init_inference(self):
        '''
        Initialize the w_inf parameter
        Initialize the lagrange variables, optimizers etc
        '''

        # initialize the w_inf parameter
        self.w_inf = torch.randn((self.cfg.d_dim), requires_grad=True, device=self.cfg.device)
        self.inf_opt = torch.optim.Adam([{'params': self.w_inf}], lr=self.cfg.lr_w)

        if self.cfg.use_dgd:
            # self.lmult = mlp(self.obs_dim + self.goal_dim, self.cfg.hidden_dim, "irelu", 
            #                 self.cfg.hidden_dim, "irelu", self.action_dim, "soft").to(self.cfg.device)
            self.lmult = mlp(self.obs_dim + self.goal_dim, 256, "irelu", 
                            256, "irelu", self.action_dim, "soft").to(self.cfg.device)
            
            self.lmult.apply(utils.weight_init)
            
            self.lmult_opt = torch.optim.Adam([{'params': self.lmult.parameters()}], lr=self.cfg.lr_w)

    def infer_w(self, replay_loader: ReplayBuffer, inf_logger, goal):
        """Optimise ``w_inf`` for a single goal.

        Re-initialises the inference state, then runs
        ``cfg.num_inference_steps`` calls of ``_infer_step`` on freshly sampled
        observations, logging after each one.  Afterwards ``w_inf`` is
        replaced by its rescaled version (unit direction times
        ``sqrt(d_dim)``).

        Args:
            replay_loader: buffer whose ``sample(batch_size)`` returns a batch
                with an ``obs`` field.
            inf_logger: object exposing ``log_metrics(dict)`` and ``dump()``.
            goal: a single goal vector (array-like or tensor).

        Returns:
            The metrics of the last inference step.
        """
        metrics: tp.Dict[str, float] = {}
        self.init_inference()
        device = self.cfg.device

        # `as_tensor` accepts arrays and tensors alike; `torch.tensor` would
        # copy an existing tensor and warn on every one of the many steps.
        goal = torch.as_tensor(goal).unsqueeze(0).to(device)

        for step in range(self.cfg.num_inference_steps):
            batch = replay_loader.sample(self.cfg.batch_size)
            obs = torch.as_tensor(batch.obs).to(device)
            # The same goal is paired with every sampled observation.
            goal_rep = goal.repeat(obs.shape[0], 1)
            metrics.update(self._infer_step(obs, goal_rep, step))
            inf_logger.log_metrics(metrics)
            inf_logger.dump()

        self.w_inf = (F.normalize(self.w_inf.reshape(1, -1)) * math.sqrt(self.w_inf.shape[0])).reshape(-1)
        return metrics

    def infer_w_pos_neg(self, replay_loader: ReplayBuffer, inf_logger, goal, neg_goal):
        """Optimise ``w_inf`` for a goal to reach and a goal to avoid.

        Same loop as ``infer_w`` but each step calls ``_infer_step_pos_neg``
        with both the goal and the negative goal tiled over the sampled
        observations.  Afterwards ``w_inf`` is replaced by its rescaled version
        (unit direction times ``sqrt(d_dim)``).

        Args:
            replay_loader: buffer whose ``sample(batch_size)`` returns a batch
                with an ``obs`` field.
            inf_logger: object exposing ``log_metrics(dict)`` and ``dump()``.
            goal: a single goal vector (array-like or tensor).
            neg_goal: a single negative-goal vector (array-like or tensor).

        Returns:
            The metrics of the last inference step.
        """
        metrics: tp.Dict[str, float] = {}
        self.init_inference()
        device = self.cfg.device

        # `as_tensor` accepts arrays and tensors alike; `torch.tensor` would
        # copy an existing tensor and warn on every one of the many steps.
        goal = torch.as_tensor(goal).unsqueeze(0).to(device)
        neg_goal = torch.as_tensor(neg_goal).unsqueeze(0).to(device)

        for step in range(self.cfg.num_inference_steps):
            batch = replay_loader.sample(self.cfg.batch_size)
            obs = torch.as_tensor(batch.obs).to(device)
            num_rows = obs.shape[0]
            goal_rep = goal.repeat(num_rows, 1)
            neg_goal_rep = neg_goal.repeat(num_rows, 1)
            metrics.update(self._infer_step_pos_neg(obs, goal_rep, neg_goal_rep, step))
            inf_logger.log_metrics(metrics)
            inf_logger.dump()

        self.w_inf = (F.normalize(self.w_inf.reshape(1, -1)) * math.sqrt(self.w_inf.shape[0])).reshape(-1)
        return metrics


    def _infer_step(self, obs, goal, step):
        perm = torch.randperm(obs.shape[0])
        perm_obs = obs[perm]

        metrics = {}

        with torch.no_grad():
            phi_g, b_g = self.psm(obs, goal)
            phi_perm, b_perm = self.psm(obs, perm_obs)
        # import ipdb;ipdb.set_trace()
        # obj = -(phi_g.T * self.w_inf).mean()
        obj = -torch.einsum("sad, d -> sa", phi_g, (F.normalize(self.w_inf.reshape(1,-1))*math.sqrt(self.w_inf.shape[0])).reshape(-1)).mean()
        # obj = -torch.min(torch.einsum("sad, d -> sa", phi_g, (F.normalize(self.w_inf.reshape(1,-1))*math.sqrt(self.w_inf.shape[0])).reshape(-1)), torch.tensor(0.0)).mean()
        # obj = -torch.min(torch.einsum("sad, d -> sa", phi_g, (F.normalize(self.w_inf.reshape(1,-1))*math.sqrt(self.w_inf.shape[0])).reshape(-1)) + b_g, torch.tensor(-1.0)).mean()
        if self.cfg.use_dgd:
            with torch.no_grad():
                l_mult = self.lmult(torch.cat([obs, perm_obs], dim=-1))
            constraints = - ((torch.einsum("sad, d -> sa", phi_perm, (F.normalize(self.w_inf.reshape(1,-1))*math.sqrt(self.w_inf.shape[0])).reshape(-1)) + b_perm) * l_mult).mean()
        else:
            constraints = - (torch.min(torch.einsum("sad, d -> sa", phi_perm, (F.normalize(self.w_inf.reshape(1,-1))*math.sqrt(self.w_inf.shape[0])).reshape(-1)) + b_perm, torch.tensor(0.0)) * self.cfg.inf_coeff).mean()
        
        self.inf_opt.zero_grad(set_to_none=True)
        loss = obj + constraints
        loss.backward()
        self.inf_opt.step()

        metrics['obj'] = obj.item()
        metrics['constraints'] = constraints.item()
        metrics['lamb'] = l_mult.mean().item() if self.cfg.use_dgd else self.cfg.inf_coeff
        # metrics['step'] = step

        if self.cfg.use_dgd:        
            constraints = ((torch.einsum("sad, d -> sa", phi_perm, (F.normalize(self.w_inf.reshape(1,-1))*math.sqrt(self.w_inf.shape[0])).reshape(-1)) + b_perm) * self.lmult(torch.cat([obs, perm_obs], dim=-1))).mean()

            self.lmult_opt.zero_grad(set_to_none=True)
            loss = constraints
            loss.backward()
            self.lmult_opt.step()
        
        # print('Step: ', step, ' | Objective: ', metrics["obj"], ' | Constraints: ', metrics["constraints"])
        return metrics

    def _infer_step_pos_neg(self, obs, goal, neg_goal, step):
        """Take one optimisation step on ``self.w_inf`` for a positive/negative goal pair.

        The objective is minus the mean difference between the projections of
        ``psm(obs, goal)`` and ``psm(obs, neg_goal)`` onto the rescaled weight
        vector. A constraint term is built from ``psm(obs, shuffled obs)``:
        with ``cfg.use_dgd`` it is weighted by the output of ``self.lmult``,
        which afterwards receives its own update through ``self.lmult_opt``;
        otherwise the term is clipped from above at zero and scaled by
        ``cfg.inf_coeff``.

        Args:
            obs: batch of observations (first axis is the batch axis).
            goal: positive goals, passed to ``self.psm`` together with ``obs``.
            neg_goal: negative goals, passed to ``self.psm`` together with ``obs``.
            step: not used by this method.

        Returns:
            dict with the scalar entries ``'obj'``, ``'constraints'`` and ``'lamb'``.
        """

        def scaled_w():
            # w_inf rescaled to L2 norm sqrt(d). Rebuilt on every use so each
            # backward pass owns a fresh graph and sees the current w_inf.
            unit = F.normalize(self.w_inf.reshape(1, -1))
            return (unit * math.sqrt(self.w_inf.shape[0])).reshape(-1)

        def project(phi):
            return torch.einsum("sad, d -> sa", phi, scaled_w())

        shuffled_obs = obs[torch.randperm(obs.shape[0])]

        with torch.no_grad():
            phi_pos, _ = self.psm(obs, goal)
            phi_neg, _ = self.psm(obs, neg_goal)
            phi_shuf, b_shuf = self.psm(obs, shuffled_obs)

        obj = -(project(phi_pos) - project(phi_neg)).mean()

        use_dgd = self.cfg.use_dgd
        if use_dgd:
            lmult_input = torch.cat([obs, shuffled_obs], dim=-1)
            # Multipliers are held fixed while w_inf is being updated.
            with torch.no_grad():
                l_mult = self.lmult(lmult_input)
            constraints = -((project(phi_shuf) + b_shuf) * l_mult).mean()
        else:
            clipped = torch.min(project(phi_shuf) + b_shuf, torch.tensor(0.0))
            constraints = -(clipped * self.cfg.inf_coeff).mean()

        self.inf_opt.zero_grad(set_to_none=True)
        (obj + constraints).backward()
        self.inf_opt.step()

        metrics = {
            'obj': obj.item(),
            'constraints': constraints.item(),
            'lamb': l_mult.mean().item() if use_dgd else self.cfg.inf_coeff,
        }

        if use_dgd:
            # Multiplier update: same expression with the opposite sign,
            # evaluated with the freshly updated w_inf and a differentiable
            # lmult forward pass.
            dual_loss = ((project(phi_shuf) + b_shuf) * self.lmult(lmult_input)).mean()
            self.lmult_opt.zero_grad(set_to_none=True)
            dual_loss.backward()
            self.lmult_opt.step()

        return metrics
    
    def q_function(self, obs, goal):
        """Evaluate Q-values of every action for ``obs`` with respect to ``goal``.

        Args:
            obs: observations; anything ``torch.as_tensor`` accepts
                (tensor, numpy array, nested list).
            goal: goals, passed to ``self.psm`` together with ``obs``.

        Returns:
            Tensor of shape (state, action): the projection of the ``psm``
            features onto ``self.w_inf`` plus the ``psm`` bias term. Computed
            without gradient tracking.
        """
        device = self.cfg.device
        # as_tensor instead of torch.tensor: callers in this class already pass
        # tensors, and torch.tensor(tensor) warns and makes a needless copy.
        obs = torch.as_tensor(obs, device=device)
        goal = torch.as_tensor(goal, device=device)
        with torch.no_grad():
            phi, b = self.psm(obs, goal)
            q = torch.einsum("sad, d -> sa", phi, self.w_inf) + b

        return q

    def q_function_pos_neg(self, obs, goal, neg_goal):
        """Evaluate Q-values for reaching ``goal`` while avoiding ``neg_goal``.

        Args:
            obs: observations; anything ``torch.as_tensor`` accepts.
            goal: positive goals, passed to ``self.psm`` together with ``obs``.
            neg_goal: negative goals, passed to ``self.psm`` together with ``obs``.

        Returns:
            Tensor of shape (state, action): the Q-value for ``goal`` minus the
            Q-value for ``neg_goal``, each being the ``psm`` features projected
            onto ``self.w_inf`` plus the ``psm`` bias. Computed without
            gradient tracking.
        """
        device = self.cfg.device
        # as_tensor instead of torch.tensor: callers in this class already pass
        # tensors, and torch.tensor(tensor) warns and makes a needless copy.
        obs = torch.as_tensor(obs, device=device)
        goal = torch.as_tensor(goal, device=device)
        neg_goal = torch.as_tensor(neg_goal, device=device)
        with torch.no_grad():
            phi, b = self.psm(obs, goal)
            phi_ng, b_ng = self.psm(obs, neg_goal)
            q = torch.einsum("sad, d -> sa", phi, self.w_inf) + b - torch.einsum("sad, d -> sa", phi_ng, self.w_inf) - b_ng

        return q
    
    
    def act(self, obs, goal):
        """Pick the greedy action for each observation with respect to ``goal``.

        Returns:
            Tensor of action indices: the argmax over axis 1 of
            ``self.q_function(obs, goal)``.
        """
        # Deterministic greedy choice; sampling from
        # Categorical(logits=q_values) would be the stochastic alternative.
        q_values = self.q_function(obs, goal)
        return q_values.argmax(dim=1)

    def act_pos_neg(self, obs, goal, neg_goal):
        """Pick the greedy action for a positive/negative goal pair.

        Returns:
            Tensor of action indices: the argmax over axis 1 of
            ``self.q_function_pos_neg(obs, goal, neg_goal)``.
        """
        # Deterministic greedy choice; sampling from
        # Categorical(logits=q_values) would be the stochastic alternative.
        q_values = self.q_function_pos_neg(obs, goal, neg_goal)
        return q_values.argmax(dim=1)
    
    def inference(self, replay_loader: ReplayBuffer, inf_logger, goal_set, neg_goal_set, reward_fn):
        """Fit ``self.w_inf`` to a set of positive (and optionally negative) goals.

        Runs ``cfg.num_inference_steps`` calls of ``self._inference``, each on
        a fresh batch from ``replay_loader``, logging the metrics after every
        step.

        Args:
            replay_loader: source of batches; ``sample(batch_size).obs`` is used.
            inf_logger: object exposing ``log_metrics(dict)`` and ``dump()``.
            goal_set: positive goals, one per row.
            neg_goal_set: negative goals, one per row; may be empty.
            reward_fn: not used by this method.

        Returns:
            The metrics of the last step (empty if no step was run).
        """
        metrics: tp.Dict[str, float] = {}
        self.init_inference()
        device = self.cfg.device
        # Convert once, up front. Re-wrapping the resulting tensors with
        # torch.tensor on every step would copy them and emit a warning.
        goal_set = torch.as_tensor(goal_set, device=device)
        neg_goal_set = torch.as_tensor(neg_goal_set, device=device)
        num_pos_goals = len(goal_set)
        num_neg_goals = len(neg_goal_set)
        # Placeholder handed to _inference when there are no negative goals.
        empty = torch.tensor([]).to(device)

        for step in range(self.cfg.num_inference_steps):
            batch = replay_loader.sample(self.cfg.batch_size)
            obs = torch.as_tensor(batch.obs, device=device)
            batch_size = obs.shape[0]

            # Goal-major layout expected by psm_set: each goal is repeated for
            # the whole batch, and the batch is tiled once per goal.
            goal_tensor = goal_set.repeat_interleave(batch_size, dim=0)
            pos_obs_rep = obs.repeat(num_pos_goals, 1)

            if num_neg_goals != 0:
                neg_goal_tensor = neg_goal_set.repeat_interleave(batch_size, dim=0)
                neg_obs_rep = obs.repeat(num_neg_goals, 1)
            else:
                neg_goal_tensor = empty
                neg_obs_rep = empty

            metrics.update(self._inference(pos_obs_rep, goal_tensor, num_pos_goals, neg_obs_rep, neg_goal_tensor, num_neg_goals, step))
            inf_logger.log_metrics(metrics)
            inf_logger.dump()

        # The optimisation works on the rescaled (norm sqrt(d)) weights, while
        # the q_function* methods read self.w_inf directly, so store the
        # rescaled vector.
        self.w_inf = (F.normalize(self.w_inf.reshape(1, -1)) * math.sqrt(self.w_inf.shape[0])).reshape(-1)
        return metrics

    def psm_set(self, obs, obs_prime, num_goals):
        """Average the ``self.psm`` outputs over a set of goals.

        ``obs`` and ``obs_prime`` are expected in goal-major layout:
        ``num_goals`` consecutive groups of rows, each group pairing the same
        batch of observations with one goal (callers build this with
        ``obs.repeat(num_goals, 1)`` and
        ``goals.repeat_interleave(batch_size, dim=0)``).

        Args:
            obs: tiled observations, ``num_goals * batch_size`` rows.
            obs_prime: repeated goals, same number of rows as ``obs``.
            num_goals: number of goals in the set (must be non-zero).

        Returns:
            ``(phi, b)`` averaged over the goal axis, with ``batch_size``
            leading rows each. Computed without gradient tracking.
        """
        with torch.no_grad():
            phi_g, b_g = self.psm(obs, obs_prime)
        batch_size = obs.shape[0] // num_goals
        # The action count is read from phi itself (axis 1) instead of being
        # hard-coded, so environments with any number of actions work.
        phi_g = phi_g.reshape(num_goals, batch_size, phi_g.shape[1], -1)
        b_g = b_g.reshape(num_goals, batch_size, -1)
        phi = phi_g.mean(dim=0)
        b = b_g.mean(dim=0)
        return phi, b
    
    def _inference(self, obs, goal_tensor, num_pos_goals, nobs, neg_goal_tensor, num_neg_goals, step):
        """Take one optimisation step on ``self.w_inf`` for a set of goals.

        The objective is minus the mean projection of the goal-averaged
        ``psm_set`` features onto the rescaled weight vector; when negative
        goals are given, their projection is subtracted first. No constraint
        term is applied here.

        Args:
            obs: observations tiled once per positive goal.
            goal_tensor: positive goals, each repeated for the whole batch.
            num_pos_goals: number of positive goals.
            nobs: observations tiled once per negative goal (ignored when
                ``num_neg_goals`` is 0).
            neg_goal_tensor: negative goals, each repeated for the whole batch
                (ignored when ``num_neg_goals`` is 0).
            num_neg_goals: number of negative goals; 0 disables the negative term.
            step: not used by this method.

        Returns:
            dict with the single scalar entry ``'obj'``.
        """

        def project(phi):
            # Project onto w_inf rescaled to L2 norm sqrt(d).
            scaled_w = (F.normalize(self.w_inf.reshape(1, -1)) * math.sqrt(self.w_inf.shape[0])).reshape(-1)
            return torch.einsum("sad, d -> sa", phi, scaled_w)

        # Features are treated as constants; only w_inf receives gradients.
        with torch.no_grad():
            phi_g, _ = self.psm_set(obs, goal_tensor, num_pos_goals)
            if num_neg_goals != 0:
                phi_ng, _ = self.psm_set(nobs, neg_goal_tensor, num_neg_goals)

        if num_neg_goals == 0:
            obj = -project(phi_g).mean()
        else:
            obj = -(project(phi_g) - project(phi_ng)).mean()

        self.inf_opt.zero_grad(set_to_none=True)
        obj.backward()
        self.inf_opt.step()

        return {'obj': obj.item()}
    
    def q_function_inference(self, obs, goal_set, neg_goal_set, z):
        """Evaluate Q-values for a set of positive and optional negative goals.

        Uses ``self.w_inf`` as stored, without rescaling it (``inference``
        stores the rescaled vector when it finishes).

        Args:
            obs: observations, one per row; anything ``torch.as_tensor`` accepts.
            goal_set: positive goals, one per row.
            neg_goal_set: negative goals, one per row; may be empty.
            z: not used by this method.

        Returns:
            Tensor of shape (state, action): goal-averaged Q-value of the
            positive set, minus that of the negative set when one is given.
        """
        device = self.cfg.device
        # as_tensor instead of torch.tensor: inputs that are already tensors
        # are neither copied nor trigger a copy-construct warning.
        obs = torch.as_tensor(obs, device=device)
        batch_size = obs.shape[0]

        goal_set = torch.as_tensor(goal_set, device=device)
        num_pos_goals = goal_set.shape[0]
        # Goal-major layout expected by psm_set: batch tiled once per goal,
        # each goal repeated for the whole batch.
        phi, b = self.psm_set(
            obs.repeat(num_pos_goals, 1),
            goal_set.repeat_interleave(batch_size, dim=0),
            num_pos_goals,
        )
        q = torch.einsum("sad, d -> sa", phi, self.w_inf) + b

        num_neg_goals = len(neg_goal_set)
        if num_neg_goals != 0:
            neg_goal_set = torch.as_tensor(neg_goal_set, device=device)
            phi_ng, b_ng = self.psm_set(
                obs.repeat(num_neg_goals, 1),
                neg_goal_set.repeat_interleave(batch_size, dim=0),
                num_neg_goals,
            )
            q = q - torch.einsum("sad, d -> sa", phi_ng, self.w_inf) - b_ng
        return q
    
    def act_set(self, obs, goal_set, neg_goal_set):
        """Pick the greedy action for each observation given goal sets.

        Returns:
            Tensor of action indices: the argmax over axis 1 of
            ``self.q_function_set(obs, goal_set, neg_goal_set)``.
        """
        # Deterministic greedy choice; sampling from
        # Categorical(logits=q_values) would be the stochastic alternative.
        q_values = self.q_function_set(obs, goal_set, neg_goal_set)
        return q_values.argmax(dim=1)

    def visualize_sampling_policy(self, work_dir, step, env, num_z, comp_log_prob=False):
        """Plot the sampling policy over all states for ``num_z`` sampled z vectors.

        For each sampled z, the actor is evaluated on every state returned by
        ``env.get_state_list()`` and the result is handed to
        ``env.plot_policy_from_list``.

        Args:
            work_dir: forwarded to ``env.plot_policy_from_list``.
            step: training step, used in the plot name.
            env: must provide ``get_state_list``, ``get_obs_from_state`` and
                ``plot_policy_from_list``.
            num_z: number of z vectors to sample and plot.
            comp_log_prob: if true, also pass a diversity score: the mean
                absolute difference between the log-probabilities under the
                current z and under the first z (0.0 for the first z).
        """
        z_samples = self.sample_z(num_z)
        state_list = env.get_state_list()
        obs_list = [env.get_obs_from_state(state) for state in state_list] # implement this function
        # The stacked observations do not depend on z, so build them once.
        obs = torch.cat(obs_list, dim=0).to(self.cfg.device)
        log_pi_first = None
        for i in range(num_z):
            # The sampled z's live in their own variable. Assigning the tiled
            # row back to the same name would make every later iteration index
            # into the tiled tensor and reuse the first z.
            z = z_samples[i].unsqueeze(0).repeat(obs.shape[0], 1)
            act_list = self.sampling_actor(obs, z)
            diversity = 0.0
            if comp_log_prob:
                # NOTE(review): log_prob receives the python list obs_list, not
                # the stacked obs tensor given to the actor above. Confirm this
                # is what sampling_actor.log_prob expects.
                log_pi = self.sampling_actor.log_prob(act_list, obs_list, z)
                if i == 0:
                    log_pi_first = log_pi
                else:
                    diversity = torch.abs(log_pi - log_pi_first).mean()

            env.plot_policy_from_list(work_dir, obs_list, act_list, diversity, f"training_step_{step}_sampling_policy_{i}") #implement this function

    def plot_q_function(self, work_dir, step, env, goal, bf_action = None):
        """Plot the greedy value function for ``goal`` over all states of ``env``.

        Q-values are computed with ``self.q_function`` for every state from
        ``env.get_state_list()``; their per-state maximum and argmax are passed
        to ``env.plot_v_function``.

        Args:
            work_dir: forwarded to ``env.plot_v_function``.
            step: training step, used in the plot name.
            env: must provide ``get_state_list``, ``get_obs_from_state`` and
                ``plot_v_function``.
            goal: a single goal; it is tiled so every state is paired with it.
            bf_action: optional mapping from ``(state[1], state[0])`` to the
                collection of optimal actions for that state.

        Returns:
            ``(num_pos, num_neg)``, the number of states whose greedy action
            is / is not among the optimal ones, when ``bf_action`` is given;
            otherwise ``None``.
        """
        device = self.cfg.device
        state_list = env.get_state_list()
        print('in plot_q_function')
        per_state_obs = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list] # implement this function
        obs_list = torch.cat(per_state_obs, dim=0).to(device)
        goal = torch.tensor(goal).unsqueeze(0).repeat(obs_list.shape[0], 1).to(device)

        q_list = self.q_function(obs_list, goal)
        v_list = q_list.max(dim=1).values
        a_list = q_list.argmax(dim=1)
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"training_step_{step}_v_function") # write this function

        if bf_action is None:
            return None

        # Compare the greedy action of every state with the reference actions.
        num_pos = 0
        num_neg = 0
        for state, policy_action in zip(state_list, a_list.tolist()):
            optimal_actions = bf_action[(state[1], state[0])]
            print('State: ', state, ' | Optimal Action: ', optimal_actions, ' | Policy action: ', policy_action)
            if policy_action in optimal_actions:
                num_pos += 1
            else:
                num_neg += 1
        print('Positive: ', num_pos, ' | Negative: ', num_neg)
        return num_pos, num_neg
        
    def plot_q_function_pos_neg(self, work_dir, step, env, goal, neg_goal):
        """Plot the greedy value function for a positive/negative goal pair.

        Q-values are computed with ``self.q_function_pos_neg`` for every state
        from ``env.get_state_list()``; their per-state maximum and argmax are
        passed to ``env.plot_v_function``.

        Args:
            work_dir: forwarded to ``env.plot_v_function``.
            step: training step, used in the plot name.
            env: must provide ``get_state_list``, ``get_obs_from_state`` and
                ``plot_v_function``.
            goal: a single positive goal, tiled to pair with every state.
            neg_goal: a single negative goal, tiled to pair with every state.
        """
        device = self.cfg.device
        state_list = env.get_state_list()
        print('in plot_q_function')
        per_state_obs = [torch.tensor(env.get_obs_from_state(state)).unsqueeze(0) for state in state_list] # implement this function
        obs_list = torch.cat(per_state_obs, dim=0).to(device)
        num_states = obs_list.shape[0]

        # Pair every state with the same positive and negative goal.
        goal = torch.tensor(goal).unsqueeze(0).repeat(num_states, 1).to(device)
        neg_goal = torch.tensor(neg_goal).unsqueeze(0).repeat(num_states, 1).to(device)

        q_list = self.q_function_pos_neg(obs_list, goal, neg_goal)
        v_list = q_list.max(dim=1).values
        a_list = q_list.argmax(dim=1)
        env.plot_v_function(work_dir, obs_list.cpu(), v_list, a_list, f"training_step_{step}_v_function") # write this function





    