"""
highly based on
https://github.com/alimama-tech/AuctionNet/blob/main/strategy_train_env/bidding_train_env/baseline/dt/dt.py#L7
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math
import numpy as np
from bidding_train_env.baseline.dt_dist.resource_manager import ResourceManager
from bidding_train_env.baseline.dt_dist.dt_embedding import EmbeddingTransformer
def getScore_nips(reward, cpa, cpa_constraint):
    """Return ``reward`` scaled by a squared penalty when the CPA constraint is violated.

    If ``cpa`` exceeds ``cpa_constraint`` the reward is multiplied by
    ``(cpa_constraint / cpa) ** 2`` (with a 1e-10 guard in the denominator);
    otherwise the reward is returned unpenalised (multiplied by 1).
    """
    beta = 2
    # Written as `cpa > cpa_constraint` (not its negation) so a NaN cpa is not penalised.
    violated = cpa > cpa_constraint
    penalty = pow(cpa_constraint / (cpa + 1e-10), beta) if violated else 1
    return penalty * reward

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask and a padding mask.

    Config keys read: ``n_embd``, ``n_head``, ``n_ctx``, ``attn_pdrop``, ``resid_pdrop``.
    The softmax-normalised attention weights of the latest forward call are kept in
    ``self._attn_map`` (a clone taken before attention dropout).
    """

    def __init__(self, config):
        super().__init__()
        width = config['n_embd']
        heads = config['n_head']
        ctx = config['n_ctx']
        assert width % heads == 0

        # Registration order is kept stable so state_dict keys/order and seeded
        # parameter initialisation stay identical.
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)

        self.attn_drop = nn.Dropout(config['attn_pdrop'])
        self.resid_drop = nn.Dropout(config['resid_pdrop'])

        # (1, 1, n_ctx, n_ctx) lower-triangular matrix: 1 where a query may attend.
        causal = torch.tril(torch.ones(ctx, ctx)).view(1, 1, ctx, ctx)
        self.register_buffer("bias", causal)
        # Score substituted for future positions before the softmax.
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.proj = nn.Linear(width, width)
        self.n_head = heads

    def forward(self, x, mask):
        """Attend over ``x`` of shape (batch, T, C).

        ``mask`` is reshaped to (batch, T); entries of 1 are attended to, entries
        of 0 receive an additive -10000 before the softmax.
        """
        batch, steps, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (batch, T, C) -> (batch, n_head, T, C // n_head)
            return t.view(batch, steps, self.n_head, head_dim).transpose(1, 2)

        k = to_heads(self.key(x))
        q = to_heads(self.query(x))
        v = to_heads(self.value(x))

        # Padding mask broadcast to (batch, 1, 1, T): keep -> 0, ignore -> -10000.
        keep = mask.view(batch, -1)[:, None, None, :]
        padding_penalty = (1.0 - keep) * -10000.0

        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        allowed = self.bias[:, :, :steps, :steps].bool()
        scores = torch.where(allowed, scores, self.masked_bias.to(scores.dtype))
        weights = F.softmax(scores + padding_penalty, dim=-1)
        self._attn_map = weights.clone()

        context = self.attn_drop(weights) @ v
        context = context.transpose(1, 2).contiguous().view(batch, steps, width)
        return self.resid_drop(self.proj(context))

class RegressionNet(nn.Module):
    """Table-driven linear regressor: ``y = x * slope + intercept``.

    ``params`` is a float table indexed by
    ``[period, time_ind, cate, cpa_ind, x_segment]``; each selected entry is read as
    ``(slope, intercept)`` via ``[:, 0]`` and ``[:, 1]``.
    """

    def __init__(self, params=None):
        super(RegressionNet, self).__init__()
        table = None
        if params is not None:
            # Independent float32 copy of the table (array, nested list or tensor).
            table = torch.as_tensor(params, dtype=torch.float32).clone()
        # Non-persistent buffer: follows `.to(device)` so it can be indexed with
        # on-device index tensors, yet stays out of state_dict (as before).
        self.register_buffer("params", table, persistent=False)

    def forward(self, period, time_ind, cate, cpa_ind, x_segment, x):
        """Look up ``(slope, intercept)`` per sample and return ``x * slope + intercept``.

        Raises:
            RuntimeError: if the net was built without a parameter table.
        """
        if self.params is None:
            raise RuntimeError("RegressionNet has no parameter table; pass `params` to the constructor.")
        coeffs = self.params[period, time_ind, cate, cpa_ind, x_segment]
        return x * coeffs[:, 0] + coeffs[:, 1]

class Block(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention then a GELU MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        width = config['n_embd']
        inner = config['n_inner']
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        mlp_layers = [
            nn.Linear(width, inner),
            nn.GELU(),
            nn.Linear(inner, width),
            nn.Dropout(config['resid_pdrop']),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, inputs_embeds, attention_mask):
        """Apply the block to ``inputs_embeds`` (batch, tokens, n_embd) with a per-token 0/1 mask."""
        hidden = inputs_embeds + self.attn(self.ln1(inputs_embeds), attention_mask)
        return hidden + self.mlp(self.ln2(hidden))

class GMMHead(nn.Module):
    """Map hidden features to the parameters of a 1-D Gaussian mixture.

    One linear layer produces ``3 * n_mix`` values per input vector, read as a
    (logit, mean, raw scale) triple for each component.
    """

    def __init__(self, hidden_size, n_mix):
        super().__init__()
        self.n_mix = n_mix
        self.proj = nn.Linear(hidden_size, n_mix * 3)  # [pi_hat, mu, sigma_hat] per component

    def forward(self, h):
        """Return ``(pi, mu, sigma)``, each shaped ``h.shape[:-1] + (n_mix,)``.

        Any number of leading dimensions is accepted. ``pi`` is softmax-normalised
        over components; ``sigma = softplus(raw) + 1e-4`` is strictly positive.
        """
        params = self.proj(h)
        # Split the last axis into (component, triple) regardless of input rank.
        params = params.view(*params.shape[:-1], self.n_mix, 3)
        pi_hat, mu, sig_hat = params[..., 0], params[..., 1], params[..., 2]
        pi = F.softmax(pi_hat, dim=-1)
        sigma = F.softplus(sig_hat) + 1e-4
        return pi, mu, sigma
    

def gmm_nll(pi, mu, sigma, a, mask=None):
    """Negative log-likelihood of scalar actions under a 1-D Gaussian mixture.

    Args:
        pi, mu, sigma: mixture weights, means and standard deviations, [B, T, M]
            (the component axis must be last).
        a: target actions, either [B, T] or [B, T, 1]. A tensor with one dimension
            fewer than ``mu`` gets a trailing axis so it broadcasts over components.
        mask: optional [B, T] 0/1 tensor; positions with 0 do not contribute.

    Returns:
        Scalar loss: masked mean of -log p(a), or the plain mean when ``mask`` is None.
    """
    if a.dim() == mu.dim() - 1:
        a = a.unsqueeze(-1)  # [B, T] -> [B, T, 1]
    var = sigma * sigma
    # Per-component Gaussian log-density; the 1e-12 terms guard against log(0) / division by 0.
    log_comp = -0.5 * ((a - mu) ** 2 / (var + 1e-12)) - torch.log(sigma + 1e-12) - 0.5 * math.log(2 * math.pi)
    log_mix = torch.log(pi + 1e-12) + log_comp   # [B, T, M]
    log_prob = torch.logsumexp(log_mix, dim=-1)  # [B, T]
    nll = -log_prob
    if mask is not None:
        return (nll * mask).sum() / (mask.sum() + 1e-12)
    return nll.mean()
    

class DecisionTransformer(nn.Module):
    """Decision Transformer policy with a Gaussian-mixture action head and retrieval.

    Every timestep contributes three tokens (return-to-go, state, action) to a causal
    transformer; the action distribution (mixture weights, means and standard
    deviations from ``GMMHead``) is predicted from each state token. At inference
    time ``take_actions`` also encodes the running trajectory with an external
    embedding model and looks up actions of similar states in a faiss index, both
    obtained through ``ResourceManager``.
    """

    def __init__(self, state_dim, act_dim, state_mean, state_std, hidden_size=512, action_tanh=False, M=5, K=10,
                 max_ep_len=48, scale=2000,
                 target_return=1, target_ctg=1., device="cpu",
                 baseline_method='vanilla_dt',
                 reweight_w=0.2,
                 critic_ensemble=None,
                 model_ref=None,
                 learning_rate=1e-5
                 ):
        """Build the network, the AdamW optimizer and the linear warm-up LR scheduler.

        Args:
            state_dim, act_dim: feature sizes of one state / one action.
            state_mean, state_std: statistics used to normalise states in ``take_actions``.
            hidden_size: NOTE currently ignored -- the width is fixed at 512 below.
            action_tanh, critic_ensemble, model_ref: accepted but unused by this class.
            M: number of mixture components in the action head.
            K: context length (timesteps) used by ``get_action``.
            max_ep_len: number of distinct timestep embeddings.
            scale: divisor applied to rewards when updating the return-to-go.
            target_return, target_ctg: default conditioning values for a new episode.
            device: NOTE currently ignored -- CUDA is used whenever it is available.
            baseline_method, reweight_w: stored as attributes only.
            learning_rate: AdamW learning rate, reached after ``warmup_steps`` scheduler steps.
        """
        super(DecisionTransformer, self).__init__()
        # NOTE(review): the `device` argument is ignored here; kept as-is for compatibility.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.length_times = 3  # tokens per timestep: (return-to-go, state, action)
        self.reweight_w = reweight_w
        self.baseline_method = baseline_method
        # NOTE(review): the `hidden_size` argument is ignored; changing this would alter
        # the architecture and presumably break existing checkpoints.
        self.hidden_size = 512
        self.state_mean = state_mean
        self.state_std = state_std
        self.max_length = K
        self.gmm_num = M
        self.max_ep_len = max_ep_len

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.scale = scale
        self.target_return = target_return
        self.target_ctg = target_ctg

        self.warmup_steps = 10000
        self.weight_decay = 0.0001
        self.learning_rate = learning_rate
        self.time_dim = 8

        self.block_config = {
            "n_ctx": 1024,
            "n_embd": self.hidden_size,  # 512
            "n_layer": 6,
            "n_head": 8,
            "n_inner": 512,
            "activation_function": "relu",
            "n_position": 1024,
            "resid_pdrop": 0.1,
            "attn_pdrop": 0.1
        }
        block_config = self.block_config

        self.hyperparameters = {
            "n_ctx": self.block_config['n_ctx'],
            "n_embd": self.block_config['n_embd'],
            "n_layer": self.block_config['n_layer'],
            "n_head": self.block_config['n_head'],
            "n_inner": self.block_config['n_inner'],
            "activation_function": self.block_config['activation_function'],
            "n_position": self.block_config['n_position'],
            "resid_pdrop": self.block_config['resid_pdrop'],
            "attn_pdrop": self.block_config['attn_pdrop'],
            "length_times": self.length_times,
            "hidden_size": self.hidden_size,
            "state_mean": self.state_mean,
            "state_std": self.state_std,
            "max_length": self.max_length,
            "K": K,
            "M": M,
            "state_dim": state_dim,
            "act_dim": act_dim,
            "scale": scale,
            "target_return": target_return,
            "warmup_steps": self.warmup_steps,
            "weight_decay": self.weight_decay,
            "learning_rate": self.learning_rate,
            "time_dim": self.time_dim
        }

        # Module creation order is unchanged so seeded initialisation and state_dict
        # layout match existing checkpoints.
        self.transformer = nn.ModuleList([Block(block_config) for _ in range(block_config['n_layer'])])

        self.embed_timestep = nn.Embedding(self.max_ep_len, self.time_dim)
        self.embed_return = torch.nn.Linear(1, self.hidden_size)
        self.embed_reward = torch.nn.Linear(1, self.hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, self.hidden_size)
        self.embed_ctg = torch.nn.Linear(1, self.hidden_size)

        # Each token embedding is concatenated with the time embedding and projected back.
        self.trans_return = torch.nn.Linear(self.time_dim + self.hidden_size, self.hidden_size)
        self.trans_reward = torch.nn.Linear(self.time_dim + self.hidden_size, self.hidden_size)
        self.trans_state = torch.nn.Linear(self.time_dim + self.hidden_size, self.hidden_size)
        self.trans_action = torch.nn.Linear(self.time_dim + self.hidden_size, self.hidden_size)
        self.trans_cost = torch.nn.Linear(self.time_dim + self.hidden_size, self.hidden_size)
        self.trans_ctg = torch.nn.Linear(self.time_dim + self.hidden_size, self.hidden_size)

        self.embed_ln = nn.LayerNorm(self.hidden_size)
        self.predict_state = torch.nn.Linear(self.hidden_size, self.state_dim)
        self.predict_action = GMMHead(hidden_size=self.hidden_size, n_mix=self.gmm_num)
        self.predict_return = torch.nn.Linear(self.hidden_size, 1)

        self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        # Linear warm-up: the LR factor grows as (step + 1) / warmup_steps, capped at 1.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer,
                                                           lambda steps: min((steps + 1) / self.warmup_steps, 1))

        self.init_eval()

        # Retrieval resources, filled lazily by load_embed_model / load_faiss_index.
        self.embed_model = None
        self.faiss_index = None
        self.encodings = None
        self.next_actions = None
        self.retrieve_budget = None
        self.retrieve_cpa = None
        self.faiss_loaded = False

        self.resource_manager = ResourceManager()

    def forward(self, states, actions, rewards, returns_to_go, ctg, score_to_go, timesteps, attention_mask=None):
        """Run the transformer over a batch of trajectories.

        Args:
            states: (batch, seq, state_dim); actions: (batch, seq, act_dim);
            returns_to_go: (batch, seq, 1); timesteps: (batch, seq) integer indices.
            rewards, ctg, score_to_go: accepted for interface compatibility; they are
                not turned into tokens.
            attention_mask: (batch, seq) 0/1; defaults to all ones.

        Returns:
            ``(x, action_preds, pi, mu, sigma, None, None)`` where ``x`` is the
            transformer output reshaped to (batch, 3, seq, hidden), ``pi/mu/sigma`` are
            the mixture parameters predicted from the state tokens (batch, seq, M) and
            ``action_preds`` is the mixture mean ``sum(pi * mu)`` (batch, seq).
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # Build the default mask on the inputs' device so CUDA inputs work.
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long, device=states.device)

        time_embeddings = self.embed_timestep(timesteps)

        # To achieve a stable and good dt baseline, we concat instead of the common add,
        # so the model is aware of time.
        state_embeddings = self.trans_state(torch.cat((self.embed_state(states), time_embeddings), dim=-1))
        action_embeddings = self.trans_action(torch.cat((self.embed_action(actions), time_embeddings), dim=-1))
        rtg_embeddings = self.trans_return(torch.cat((self.embed_return(returns_to_go), time_embeddings), dim=-1))

        # batch*length_times*seq*dim -> batch*(seq*length_times)*dim, interleaved per step.
        stacked_inputs = torch.stack(
            (rtg_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, self.length_times * seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # batch*(seq*length_times): every token of a step shares that step's mask value.
        stacked_attention_mask = torch.stack(
            [attention_mask for _ in range(self.length_times)], dim=1
        ).permute(0, 2, 1).reshape(batch_size, self.length_times * seq_length).to(stacked_inputs.dtype)

        x = stacked_inputs
        for block in self.transformer:
            x = block(x, stacked_attention_mask)

        # batch*3*seq*dim
        x = x.reshape(batch_size, seq_length, self.length_times, self.hidden_size).permute(0, 2, 1, 3)

        # Predict the action from the state token (index 1 == -2 of the three).
        pi, mu, sigma = self.predict_action(x[:, -2])
        action_preds = (pi * mu).sum(dim=-1)

        return x, action_preds, pi, mu, sigma, None, None

    @staticmethod
    def _left_pad(seq, length, dtype):
        """Prepend zeros along dim 1 so ``seq`` has ``length`` steps, then cast to ``dtype``."""
        pad_shape = (seq.shape[0], length - seq.shape[1]) + tuple(seq.shape[2:])
        padding = torch.zeros(pad_shape, device=seq.device)
        return torch.cat([padding, seq], dim=1).to(dtype=dtype)

    def get_action(self, states, actions, rewards, returns_to_go, ctg, score_to_go, timesteps, **kwargs):
        """Predict the action for the last step of a single (unbatched) trajectory.

        The history is truncated to the last ``max_length`` steps and left-padded with
        zeros, which the attention mask marks as ignored.

        Returns:
            ``(x, action, pi, mu, sigma)``: the transformer output for the window, the
            mixture-mean action of the last step, and that step's mixture parameters.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        rewards = rewards.reshape(1, -1, 1)
        ctg = ctg.reshape(1, -1, 1)
        score_to_go = score_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            window = self.max_length
            states = states[:, -window:]
            actions = actions[:, -window:]
            returns_to_go = returns_to_go[:, -window:]
            rewards = rewards[:, -window:]
            timesteps = timesteps[:, -window:]
            ctg = ctg[:, -window:]
            score_to_go = score_to_go[:, -window:]

            # 0 for left padding, 1 for the real steps (computed before padding).
            real_steps = states.shape[1]
            attention_mask = torch.cat([torch.zeros(window - real_steps), torch.ones(real_steps)])
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)

            states = self._left_pad(states, window, torch.float32)
            actions = self._left_pad(actions, window, torch.float32)
            returns_to_go = self._left_pad(returns_to_go, window, torch.float32)
            rewards = self._left_pad(rewards, window, torch.float32)
            ctg = self._left_pad(ctg, window, torch.float32)
            score_to_go = self._left_pad(score_to_go, window, torch.float32)
            timesteps = self._left_pad(timesteps, window, torch.long)
        else:
            attention_mask = None

        x, action_preds, pi, mu, sigma, _, _ = self.forward(
            states=states, actions=actions, rewards=rewards, returns_to_go=returns_to_go, ctg=ctg,
            score_to_go=score_to_go, timesteps=timesteps, attention_mask=attention_mask)
        return x, action_preds[0, -1], pi[0, -1], mu[0, -1], sigma[0, -1]

    def step(self, states, actions, rewards, dones, rtg, timesteps, attention_mask, ctg, score_to_go, costs):
        """Run one optimisation step on a training batch and return the action NLL.

        ``rtg``, ``ctg`` and ``score_to_go`` are sliced with ``[:, :-1]`` before the
        forward pass (presumably they carry one extra trailing step -- confirm against
        the data loader). ``dones`` and ``costs`` are accepted but not used.
        NOTE(review): ``self.scheduler`` is not stepped here; presumably the training
        loop does it -- confirm.
        """
        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        rtg = rtg.to(self.device)
        timesteps = timesteps.to(self.device)
        attention_mask = attention_mask.to(self.device)
        ctg = ctg.to(self.device)
        score_to_go = score_to_go.to(self.device)

        _, _, pi, mu, sigma, _, _ = self.forward(
            states=states, actions=actions, rewards=rewards, returns_to_go=rtg[:, :-1], ctg=ctg[:, :-1],
            score_to_go=score_to_go[:, :-1], timesteps=timesteps, attention_mask=attention_mask,
        )

        # forward() does not modify `actions`, so it serves directly as the target.
        action_loss = gmm_nll(pi, mu, sigma, actions, attention_mask)

        self.optimizer.zero_grad()
        action_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), .25)
        self.optimizer.step()

        return action_loss.detach().cpu().item()

    def _as_condition_tensor(self, value):
        """Copy a scalar (number, numpy scalar or tensor) into a (1, 1) float32 tensor on self.device."""
        if isinstance(value, torch.Tensor):
            value = value.detach().clone().to(dtype=torch.float32)
        else:
            value = torch.tensor(value, dtype=torch.float32)
        return value.reshape(1, 1).to(self.device)

    @torch.no_grad()
    def take_actions(self, state, actual_excuted_action, target_return=None, target_ctg=None, pre_reward=None,
                     pre_cost=None, cpa_constrain=None, retrieved_K=5):
        """Choose the action for the current step of an evaluation episode.

        Appends to the running episode buffers (states, actions, rewards, costs,
        return-to-go, ctg, timesteps); call ``init_eval`` to start a new episode.
        Runs under ``torch.no_grad()`` so the stored action history does not keep the
        autograd graphs of earlier steps alive.

        Args:
            state: current un-normalised state, a numpy array with ``state_dim`` values.
            actual_excuted_action: if not None, replaces the previous step's predicted
                action with the action actually executed (numpy array).
            target_return, target_ctg: conditioning values used on the first call of an
                episode (tensor or number); default to ``self.target_return`` / ``self.target_ctg``.
            pre_reward, pre_cost: reward / cost of the previous step; required on every
                call after the first.
            cpa_constrain: accepted but unused.
            retrieved_K: number of retrieved neighbour actions to return.

        Returns:
            ``(action, retrieved_actions, pi, mu, sigma)``: the predicted action as a
            numpy array, the retrieved neighbour actions, and the step's mixture parameters.
        """
        self.load_embed_model()
        self.eval()
        cur_state = torch.from_numpy(state).reshape(1, self.state_dim).to(self.device)
        if self.eval_states is None:
            self.eval_states = cur_state
            ep_return = target_return if target_return is not None else self.target_return
            self.eval_target_return = self._as_condition_tensor(ep_return)
            self.eval_target_score_to_go = self._as_condition_tensor(ep_return)

            ep_ctg = target_ctg if target_ctg is not None else self.target_ctg
            self.eval_target_ctg = self._as_condition_tensor(ep_ctg)
        else:
            assert pre_reward is not None
            assert pre_cost is not None
            self.eval_states = torch.cat([self.eval_states, cur_state], dim=0)

            self.eval_rewards[-1] = pre_reward
            self.eval_costs[-1] = pre_cost

            # Return-to-go decreases by the (scaled) reward just collected.
            pred_return = self.eval_target_return[0, -1] - (pre_reward / self.scale)
            self.eval_target_return = torch.cat([self.eval_target_return, pred_return.reshape(1, 1)], dim=1)

            # ctg is always set to 1 at inference time.
            pred_ctg = torch.ones_like(self.eval_target_ctg[0, -1])
            self.eval_target_ctg = torch.cat([self.eval_target_ctg, pred_ctg.reshape(1, 1)], dim=1)

            self.eval_timesteps = torch.cat([self.eval_timesteps, self.eval_timesteps[:, -1:] + 1], dim=1)

        # The executed action (if given) replaces the placeholder predicted last step.
        if actual_excuted_action is not None:
            self.eval_actions[-1] = torch.from_numpy(actual_excuted_action).reshape(1, self.act_dim).to(self.device)
        self.eval_actions = torch.cat([self.eval_actions, torch.zeros(1, self.act_dim, device=self.device)], dim=0)

        self.eval_rewards = torch.cat([self.eval_rewards, torch.zeros(1, device=self.device)])
        self.eval_costs = torch.cat([self.eval_costs, torch.zeros(1, device=self.device)])

        state_mean_tensor = torch.tensor(self.state_mean, device=self.eval_states.device, dtype=torch.float32)
        state_std_tensor = torch.tensor(self.state_std, device=self.eval_states.device, dtype=torch.float32)
        norm_states = (self.eval_states.to(dtype=torch.float32) - state_mean_tensor) / state_std_tensor

        _, action, pi, mu, sigma = self.get_action(
            norm_states,
            self.eval_actions.to(dtype=torch.float32),
            self.eval_rewards.to(dtype=torch.float32),
            self.eval_target_return.to(dtype=torch.float32),
            self.eval_target_ctg.to(dtype=torch.float32),
            self.eval_target_score_to_go.to(dtype=torch.float32),
            self.eval_timesteps.to(dtype=torch.long),
        )
        self.eval_actions[-1] = action
        action = action.detach().cpu().numpy()

        # Encode the trajectory (including the new action) for retrieval.
        eval_state_code = self.get_code(
            norm_states,
            self.eval_actions.to(dtype=torch.float32),
            self.eval_rewards.to(dtype=torch.float32),
            self.eval_target_return.to(dtype=torch.float32),
            self.eval_timesteps.to(dtype=torch.long),
        )
        retrieved_actions, _, _ = self.retrieve_similar_actions(eval_state_code, k=retrieved_K)

        return action, retrieved_actions, pi, mu, sigma

    def load_embed_model(self):
        """Lazily fetch the encoder model from the resource manager on first use."""
        if self.embed_model is None:
            self.embed_model = self.resource_manager.get_embed_model(EmbeddingTransformer)

    def get_code(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        """Encode a single trajectory with the embedding model and return its state encodings.

        Only the last ``embed_model.max_length`` steps are encoded (no padding). The
        fifth output of ``embed_model.forward`` is returned; ``retrieve_similar_actions``
        indexes it as ``[0, -1]``, so presumably (1, steps, code_dim) -- confirm.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        rewards = rewards.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        window = self.embed_model.max_length
        if window is not None:
            states = states[:, -window:]
            actions = actions[:, -window:]
            returns_to_go = returns_to_go[:, -window:]
            rewards = rewards[:, -window:]
            timesteps = timesteps[:, -window:]
            attention_mask = torch.ones(states.shape[1], dtype=torch.long, device=states.device).reshape(1, -1)
        else:
            attention_mask = None

        _, _, _, _, state_encodings = self.embed_model.forward(
            states, actions, rewards, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)

        return state_encodings

    def retrieve_similar_actions(self, eval_state_code, k=5):
        """Retrieve actions of states similar to the latest state encoding.

        Searches the faiss index for ``3 * k`` neighbours of the (L2-normalised when
        non-zero) latest encoding, then keeps the ``k`` with the highest stored
        return-to-go, best first.

        Returns:
            ``(top_actions, top_rtg, top_distances)`` as float32 numpy arrays; they are
            empty when the index returns no valid neighbours.
        """
        self.load_faiss_index()

        if isinstance(eval_state_code, torch.Tensor):
            eval_state_code = eval_state_code.detach().cpu().numpy()

        query_vector = eval_state_code[0, -1].reshape(1, -1).astype('float32')
        norm = np.linalg.norm(query_vector, axis=1, keepdims=True)
        if norm[0, 0] > 1e-10:  # avoid division by zero
            query_vector = query_vector / norm

        distances, indices = self.faiss_index.search(query_vector, k * 3)
        indices = np.asarray(indices).reshape(-1)
        distances = np.asarray(distances, dtype=np.float32).reshape(-1)

        # faiss pads results with -1 labels when it finds fewer neighbours than
        # requested; a -1 would otherwise silently select the last stored entry.
        valid = indices >= 0
        indices = indices[valid]
        distances = distances[valid]

        retrieve_rtg = np.asarray(self.retrieve_rtg[indices], dtype=np.float32).reshape(-1)
        similar_actions = np.asarray(self.next_actions[indices], dtype=np.float32)

        top_n = min(k, retrieve_rtg.shape[0])
        if top_n <= 0:
            return similar_actions[:0], retrieve_rtg[:0], distances[:0]

        # Top-n by retrieved return-to-go, sorted in descending order.
        idx_part = np.argpartition(retrieve_rtg, -top_n)[-top_n:]
        idx_sorted = idx_part[np.argsort(retrieve_rtg[idx_part])[::-1]]
        return similar_actions[idx_sorted], retrieve_rtg[idx_sorted], distances[idx_sorted]

    def load_faiss_index(self):
        """Load the faiss index and its companion arrays from the resource manager once."""
        if not self.faiss_loaded:
            (self.faiss_index, self.next_actions, self.retrieve_rtg, self.sim_trajectories,
             self.trajectories_indices, self.trajectories_position) = self.resource_manager.get_faiss_resources()
            self.faiss_loaded = True

    def init_eval(self):
        """Reset the per-episode buffers used by ``take_actions``."""
        self.eval_states = None
        self.eval_actions = torch.zeros((0, self.act_dim), dtype=torch.float32).to(self.device)
        self.eval_rewards = torch.zeros(0, dtype=torch.float32).to(self.device)
        self.eval_costs = torch.zeros(0, dtype=torch.float32).to(self.device)

        self.eval_target_return = None
        self.eval_target_ctg = None
        self.eval_target_score_to_go = None

        self.eval_timesteps = torch.tensor(0, dtype=torch.long).reshape(1, 1).to(self.device)

        self.eval_episode_return, self.eval_episode_length = 0, 0

    def save_net(self, save_path, step=None):
        """Save the state_dict to ``save_path/dt.pt`` (or ``dt_{step}.pt``), creating the directory."""
        os.makedirs(save_path, exist_ok=True)
        file_name = f'dt_{step}.pt' if step is not None else 'dt.pt'
        torch.save(self.state_dict(), os.path.join(save_path, file_name))

    def save_jit(self, save_path):
        """Script the model and save it to ``save_path/dt_model.pth``.

        NOTE: ``self.cpu()`` moves this model to the CPU in place as a side effect.
        """
        os.makedirs(save_path, exist_ok=True)
        jit_model = torch.jit.script(self.cpu())
        torch.jit.save(jit_model, os.path.join(save_path, 'dt_model.pth'))

    def load_net(self, load_path="saved_model/DT/dt.pt", device='cpu'):
        """Load a state_dict from ``load_path``; ``strict=False`` silently skips mismatched keys."""
        self.load_state_dict(torch.load(load_path, map_location=device), strict=False)

